diff --git a/cmd/serve.go b/cmd/serve.go index aff7c7c8..ecc4d4a1 100644 --- a/cmd/serve.go +++ b/cmd/serve.go @@ -44,6 +44,8 @@ var flagsServe = append( altsrc.NewStringFlag(&cli.StringFlag{Name: "firebase-key-file", Aliases: []string{"firebase_key_file", "F"}, EnvVars: []string{"NTFY_FIREBASE_KEY_FILE"}, Usage: "Firebase credentials file; if set additionally publish to FCM topic"}), altsrc.NewStringFlag(&cli.StringFlag{Name: "cache-file", Aliases: []string{"cache_file", "C"}, EnvVars: []string{"NTFY_CACHE_FILE"}, Usage: "cache file used for message caching"}), altsrc.NewDurationFlag(&cli.DurationFlag{Name: "cache-duration", Aliases: []string{"cache_duration", "b"}, EnvVars: []string{"NTFY_CACHE_DURATION"}, Value: server.DefaultCacheDuration, Usage: "buffer messages for this time to allow `since` requests"}), + altsrc.NewIntFlag(&cli.IntFlag{Name: "cache-batch-size", Aliases: []string{"cache_batch_size"}, EnvVars: []string{"NTFY_CACHE_BATCH_SIZE"}, Usage: "max size of messages to batch together when writing to message cache (if zero, writes are synchronous)"}), + altsrc.NewDurationFlag(&cli.DurationFlag{Name: "cache-batch-timeout", Aliases: []string{"cache_batch_timeout"}, EnvVars: []string{"NTFY_CACHE_BATCH_TIMEOUT"}, Usage: "timeout for batched async writes to the message cache (if zero, writes are synchronous)"}), altsrc.NewStringFlag(&cli.StringFlag{Name: "cache-startup-queries", Aliases: []string{"cache_startup_queries"}, EnvVars: []string{"NTFY_CACHE_STARTUP_QUERIES"}, Usage: "queries run when the cache database is initialized"}), altsrc.NewStringFlag(&cli.StringFlag{Name: "auth-file", Aliases: []string{"auth_file", "H"}, EnvVars: []string{"NTFY_AUTH_FILE"}, Usage: "auth database file used for access control"}), altsrc.NewStringFlag(&cli.StringFlag{Name: "auth-default-access", Aliases: []string{"auth_default_access", "p"}, EnvVars: []string{"NTFY_AUTH_DEFAULT_ACCESS"}, Value: "read-write", Usage: "default permissions if no matching entries in the auth database are 
found"}), @@ -110,6 +112,8 @@ func execServe(c *cli.Context) error { cacheFile := c.String("cache-file") cacheDuration := c.Duration("cache-duration") cacheStartupQueries := c.String("cache-startup-queries") + cacheBatchSize := c.Int("cache-batch-size") + cacheBatchTimeout := c.Duration("cache-batch-timeout") authFile := c.String("auth-file") authDefaultAccess := c.String("auth-default-access") attachmentCacheDir := c.String("attachment-cache-dir") @@ -233,6 +237,8 @@ func execServe(c *cli.Context) error { conf.CacheFile = cacheFile conf.CacheDuration = cacheDuration conf.CacheStartupQueries = cacheStartupQueries + conf.CacheBatchSize = cacheBatchSize + conf.CacheBatchTimeout = cacheBatchTimeout conf.AuthFile = authFile conf.AuthDefaultRead = authDefaultRead conf.AuthDefaultWrite = authDefaultWrite diff --git a/server/config.go b/server/config.go index d8fd429e..1e2b517c 100644 --- a/server/config.go +++ b/server/config.go @@ -61,6 +61,8 @@ type Config struct { CacheFile string CacheDuration time.Duration CacheStartupQueries string + CacheBatchSize int + CacheBatchTimeout time.Duration AuthFile string AuthDefaultRead bool AuthDefaultWrite bool @@ -114,6 +116,8 @@ func NewConfig() *Config { FirebaseKeyFile: "", CacheFile: "", CacheDuration: DefaultCacheDuration, + CacheBatchSize: 0, + CacheBatchTimeout: 0, AuthFile: "", AuthDefaultRead: true, AuthDefaultWrite: true, diff --git a/server/message_cache.go b/server/message_cache.go index ec710e4f..7eb37cf9 100644 --- a/server/message_cache.go +++ b/server/message_cache.go @@ -44,6 +44,7 @@ const ( published INT NOT NULL ); CREATE INDEX IF NOT EXISTS idx_mid ON messages (mid); + CREATE INDEX IF NOT EXISTS idx_time ON messages (time); CREATE INDEX IF NOT EXISTS idx_topic ON messages (topic); COMMIT; ` @@ -92,7 +93,7 @@ const ( // Schema management queries const ( - currentSchemaVersion = 8 + currentSchemaVersion = 9 createSchemaVersionTableQuery = ` CREATE TABLE IF NOT EXISTS schemaVersion ( id INT PRIMARY KEY, @@ -185,6 
+186,11 @@ const ( migrate7To8AlterMessagesTableQuery = ` ALTER TABLE messages ADD COLUMN icon TEXT NOT NULL DEFAULT(''); ` + + // 8 -> 9 + migrate8To9AlterMessagesTableQuery = ` + CREATE INDEX IF NOT EXISTS idx_time ON messages (time); + ` ) type messageCache struct { @@ -194,7 +200,7 @@ type messageCache struct { } // newSqliteCache creates a SQLite file-backed cache -func newSqliteCache(filename, startupQueries string, nop bool) (*messageCache, error) { +func newSqliteCache(filename, startupQueries string, batchSize int, batchTimeout time.Duration, nop bool) (*messageCache, error) { db, err := sql.Open("sqlite3", filename) if err != nil { return nil, err @@ -202,32 +208,28 @@ func newSqliteCache(filename, startupQueries string, nop bool) (*messageCache, e if err := setupCacheDB(db, startupQueries); err != nil { return nil, err } - queue := util.NewBatchingQueue[*message](20, 500*time.Millisecond) + var queue *util.BatchingQueue[*message] + if batchSize > 0 || batchTimeout > 0 { + queue = util.NewBatchingQueue[*message](batchSize, batchTimeout) + } cache := &messageCache{ db: db, queue: queue, nop: nop, } - go func() { - for messages := range queue.Pop() { - log.Debug("Adding %d messages to cache", len(messages)) - if err := cache.addMessages(messages); err != nil { - log.Error("error: %s", err.Error()) - } - } - }() + go cache.processMessageBatches() return cache, nil } // newMemCache creates an in-memory cache func newMemCache() (*messageCache, error) { - return newSqliteCache(createMemoryFilename(), "", false) + return newSqliteCache(createMemoryFilename(), "", 0, 0, false) } // newNopCache creates an in-memory cache that discards all messages; // it is always empty and can be used if caching is entirely disabled func newNopCache() (*messageCache, error) { - return newSqliteCache(createMemoryFilename(), "", true) + return newSqliteCache(createMemoryFilename(), "", 0, 0, true) } // createMemoryFilename creates a unique memory filename to use for the SQLite 
backend. @@ -240,18 +242,23 @@ func createMemoryFilename() string { return fmt.Sprintf("file:%s?mode=memory&cache=shared", util.RandomString(10)) } +// AddMessage stores a message to the message cache synchronously, or queues it to be stored at a later date asynchronously. +// The message is queued only if "batchSize" or "batchTimeout" are passed to the constructor. func (c *messageCache) AddMessage(m *message) error { + if c.queue != nil { + c.queue.Enqueue(m) + return nil + } return c.addMessages([]*message{m}) } -func (c *messageCache) QueueMessage(m *message) { - c.queue.Push(m) -} - +// addMessages synchronously stores a batch of messages. If the database is locked, the transaction waits until +// SQLite's busy_timeout is exceeded before erroring out. func (c *messageCache) addMessages(ms []*message) error { if c.nop { return nil } + start := time.Now() tx, err := c.db.Begin() if err != nil { return err } @@ -305,7 +312,12 @@ func (c *messageCache) addMessages(ms []*message) error { return err } } - return tx.Commit() + if err := tx.Commit(); err != nil { + log.Warn("Cache: Writing %d message(s) failed (took %v)", len(ms), time.Since(start)) + return err + } + log.Debug("Cache: Wrote %d message(s) in %v", len(ms), time.Since(start)) + return nil } func (c *messageCache) Messages(topic string, since sinceMarker, scheduled bool) ([]*message, error) { @@ -411,8 +423,13 @@ func (c *messageCache) Topics() (map[string]*topic, error) { } func (c *messageCache) Prune(olderThan time.Time) error { - _, err := c.db.Exec(pruneMessagesQuery, olderThan.Unix()) - return err + start := time.Now() + if _, err := c.db.Exec(pruneMessagesQuery, olderThan.Unix()); err != nil { + log.Warn("Cache: Pruning failed (after %v): %s", time.Since(start), err.Error()) + return err + } + log.Debug("Cache: Pruning successful (took %v)", time.Since(start)) + return nil } func (c *messageCache) AttachmentBytesUsed(sender string) (int64, error) { @@ -433,6 +449,17 @@ func (c *messageCache) 
AttachmentBytesUsed(sender string) (int64, error) { return size, nil } +func (c *messageCache) processMessageBatches() { + if c.queue == nil { + return + } + for messages := range c.queue.Dequeue() { + if err := c.addMessages(messages); err != nil { + log.Error("Cache: %s", err.Error()) + } + } +} + func readMessages(rows *sql.Rows) ([]*message, error) { defer rows.Close() messages := make([]*message, 0) @@ -558,6 +585,8 @@ func setupCacheDB(db *sql.DB, startupQueries string) error { return migrateFrom6(db) } else if schemaVersion == 7 { return migrateFrom7(db) + } else if schemaVersion == 8 { + return migrateFrom8(db) } return fmt.Errorf("unexpected schema version found: %d", schemaVersion) } @@ -663,5 +692,16 @@ func migrateFrom7(db *sql.DB) error { if _, err := db.Exec(updateSchemaVersion, 8); err != nil { return err } + return migrateFrom8(db) +} + +func migrateFrom8(db *sql.DB) error { + log.Info("Migrating cache database schema: from 8 to 9") + if _, err := db.Exec(migrate8To9AlterMessagesTableQuery); err != nil { + return err + } + if _, err := db.Exec(updateSchemaVersion, 9); err != nil { + return err + } return nil // Update this when a new version is added } diff --git a/server/message_cache_test.go b/server/message_cache_test.go index c72debca..c3b7305e 100644 --- a/server/message_cache_test.go +++ b/server/message_cache_test.go @@ -450,7 +450,7 @@ func TestSqliteCache_StartupQueries_WAL(t *testing.T) { startupQueries := `pragma journal_mode = WAL; pragma synchronous = normal; pragma temp_store = memory;` - db, err := newSqliteCache(filename, startupQueries, false) + db, err := newSqliteCache(filename, startupQueries, 0, 0, false) require.Nil(t, err) require.Nil(t, db.AddMessage(newDefaultMessage("mytopic", "some message"))) require.FileExists(t, filename) @@ -461,7 +461,7 @@ pragma temp_store = memory;` func TestSqliteCache_StartupQueries_None(t *testing.T) { filename := newSqliteTestCacheFile(t) startupQueries := "" - db, err := 
newSqliteCache(filename, startupQueries, false) + db, err := newSqliteCache(filename, startupQueries, 0, 0, false) require.Nil(t, err) require.Nil(t, db.AddMessage(newDefaultMessage("mytopic", "some message"))) require.FileExists(t, filename) @@ -472,7 +472,7 @@ func TestSqliteCache_StartupQueries_None(t *testing.T) { func TestSqliteCache_StartupQueries_Fail(t *testing.T) { filename := newSqliteTestCacheFile(t) startupQueries := `xx error` - _, err := newSqliteCache(filename, startupQueries, false) + _, err := newSqliteCache(filename, startupQueries, 0, 0, false) require.Error(t, err) } @@ -501,7 +501,7 @@ func TestMemCache_NopCache(t *testing.T) { } func newSqliteTestCache(t *testing.T) *messageCache { - c, err := newSqliteCache(newSqliteTestCacheFile(t), "", false) + c, err := newSqliteCache(newSqliteTestCacheFile(t), "", 0, 0, false) if err != nil { t.Fatal(err) } @@ -513,7 +513,7 @@ func newSqliteTestCacheFile(t *testing.T) string { } func newSqliteTestCacheFromFile(t *testing.T, filename, startupQueries string) *messageCache { - c, err := newSqliteCache(filename, startupQueries, false) + c, err := newSqliteCache(filename, startupQueries, 0, 0, false) if err != nil { t.Fatal(err) } diff --git a/server/server.go b/server/server.go index b90b7630..fe729b1b 100644 --- a/server/server.go +++ b/server/server.go @@ -159,7 +159,7 @@ func createMessageCache(conf *Config) (*messageCache, error) { if conf.CacheDuration == 0 { return newNopCache() } else if conf.CacheFile != "" { - return newSqliteCache(conf.CacheFile, conf.CacheStartupQueries, false) + return newSqliteCache(conf.CacheFile, conf.CacheStartupQueries, conf.CacheBatchSize, conf.CacheBatchTimeout, false) } return newMemCache() } @@ -491,11 +491,10 @@ func (s *Server) handlePublishWithoutResponse(r *http.Request, v *visitor) (*mes log.Debug("%s Message delayed, will process later", logMessagePrefix(v, m)) } if cache { - log.Trace("%s Queuing for cache", logMessagePrefix(v, m)) - s.messageCache.QueueMessage(m) 
- /*if err := s.messageCache.AddMessage(m); err != nil { + log.Debug("%s Adding message to cache", logMessagePrefix(v, m)) + if err := s.messageCache.AddMessage(m); err != nil { return nil, err - }*/ + } } s.mu.Lock() s.messages++ diff --git a/server/server.yml b/server/server.yml index 9476478f..4b08129b 100644 --- a/server/server.yml +++ b/server/server.yml @@ -65,6 +65,8 @@ # cache-file: # cache-duration: "12h" # cache-startup-queries: +# cache-batch-size: 0 +# cache-batch-timeout: "0ms" # If set, access to the ntfy server and API can be controlled on a granular level using # the 'ntfy user' and 'ntfy access' commands. See the --help pages for details, or check the docs. diff --git a/util/batching_queue.go b/util/batching_queue.go index 78116470..86901bcd 100644 --- a/util/batching_queue.go +++ b/util/batching_queue.go @@ -5,6 +5,24 @@ import ( "time" ) +// BatchingQueue is a queue that creates batches of the enqueued elements based on a +// max batch size and a batch timeout. +// +// Example: +// +// q := NewBatchingQueue[int](2, 500 * time.Millisecond) +// go func() { +// for batch := range q.Dequeue() { +// fmt.Println(batch) +// } +// }() +// q.Enqueue(1) +// q.Enqueue(2) +// q.Enqueue(3) +// time.Sleep(time.Second) +// +// This example will emit batch [1, 2] immediately (because the batch size is 2), and +// a batch [3] after 500ms. 
type BatchingQueue[T any] struct { batchSize int timeout time.Duration @@ -13,6 +31,7 @@ type BatchingQueue[T any] struct { mu sync.Mutex } +// NewBatchingQueue creates a new BatchingQueue func NewBatchingQueue[T any](batchSize int, timeout time.Duration) *BatchingQueue[T] { q := &BatchingQueue[T]{ batchSize: batchSize, @@ -20,37 +39,45 @@ func NewBatchingQueue[T any](batchSize int, timeout time.Duration) *BatchingQueu in: make([]T, 0), out: make(chan []T), } - ticker := time.NewTicker(timeout) - go func() { - for range ticker.C { - elements := q.popAll() - if len(elements) > 0 { - q.out <- elements - } - } - }() + go q.timeoutTicker() return q } -func (c *BatchingQueue[T]) Push(element T) { - c.mu.Lock() - c.in = append(c.in, element) - limitReached := len(c.in) == c.batchSize - c.mu.Unlock() +// Enqueue enqueues an element to the queue. If the configured batch size is reached, +// the batch will be emitted immediately. +func (q *BatchingQueue[T]) Enqueue(element T) { + q.mu.Lock() + q.in = append(q.in, element) + limitReached := len(q.in) == q.batchSize + q.mu.Unlock() if limitReached { - c.out <- c.popAll() + q.out <- q.dequeueAll() } } -func (c *BatchingQueue[T]) Pop() <-chan []T { - return c.out +// Dequeue returns a channel emitting batches of elements +func (q *BatchingQueue[T]) Dequeue() <-chan []T { + return q.out } -func (c *BatchingQueue[T]) popAll() []T { - c.mu.Lock() - defer c.mu.Unlock() - elements := make([]T, len(c.in)) - copy(elements, c.in) - c.in = c.in[:0] +func (q *BatchingQueue[T]) dequeueAll() []T { + q.mu.Lock() + defer q.mu.Unlock() + elements := make([]T, len(q.in)) + copy(elements, q.in) + q.in = q.in[:0] return elements } + +func (q *BatchingQueue[T]) timeoutTicker() { + if q.timeout == 0 { + return + } + ticker := time.NewTicker(q.timeout) + for range ticker.C { + elements := q.dequeueAll() + if len(elements) > 0 { + q.out <- elements + } + } +} diff --git a/util/batching_queue_test.go b/util/batching_queue_test.go index 
46bc06b8..28764f18 100644 --- a/util/batching_queue_test.go +++ b/util/batching_queue_test.go @@ -2,24 +2,51 @@ package util_test import ( "fmt" + "github.com/stretchr/testify/require" "heckel.io/ntfy/util" "math/rand" "testing" "time" ) -func TestConcurrentQueue_Next(t *testing.T) { - q := util.NewBatchingQueue[int](25, 200*time.Millisecond) +func TestBatchingQueue_InfTimeout(t *testing.T) { + q := util.NewBatchingQueue[int](25, 1*time.Hour) + batches := make([][]int, 0) + total := 0 go func() { - for batch := range q.Pop() { - fmt.Printf("Batch of %d items\n", len(batch)) + for batch := range q.Dequeue() { + batches = append(batches, batch) + total += len(batch) } }() - for i := 0; i < 1000; i++ { + for i := 0; i < 101; i++ { + go q.Enqueue(i) + } + time.Sleep(500 * time.Millisecond) + require.Equal(t, 100, total) // One is missing, stuck in the last batch! + require.Equal(t, 4, len(batches)) +} + +func TestBatchingQueue_WithTimeout(t *testing.T) { + q := util.NewBatchingQueue[int](25, 100*time.Millisecond) + batches := make([][]int, 0) + total := 0 + go func() { + for batch := range q.Dequeue() { + batches = append(batches, batch) + total += len(batch) + } + }() + for i := 0; i < 101; i++ { go func(i int) { - time.Sleep(time.Duration(rand.Intn(1000)) * time.Millisecond) - q.Push(i) + time.Sleep(time.Duration(rand.Intn(700)) * time.Millisecond) + q.Enqueue(i) }(i) } - time.Sleep(2 * time.Second) + time.Sleep(time.Second) + fmt.Println(len(batches)) + fmt.Println(batches) + require.Equal(t, 101, total) + require.True(t, len(batches) > 4) // 101/25 + require.True(t, len(batches) < 21) }