diff --git a/server/cache_mem.go b/server/cache_mem.go index 6922f2dc..cc63ff64 100644 --- a/server/cache_mem.go +++ b/server/cache_mem.go @@ -15,25 +15,6 @@ type memCache struct { var _ cache = (*memCache)(nil) -// newMemCache creates an in-memory cache -func newMemCache() *memCache { - return &memCache{ - messages: make(map[string][]*message), - scheduled: make(map[string]*message), - nop: false, - } -} - -// newNopCache creates an in-memory cache that discards all messages; -// it is always empty and can be used if caching is entirely disabled -func newNopCache() *memCache { - return &memCache{ - messages: make(map[string][]*message), - scheduled: make(map[string]*message), - nop: true, - } -} - func (c *memCache) AddMessage(m *message) error { c.mu.Lock() defer c.mu.Unlock() diff --git a/server/cache_mem_test.go b/server/cache_mem_test.go index 6d8a17dd..561f461e 100644 --- a/server/cache_mem_test.go +++ b/server/cache_mem_test.go @@ -6,35 +6,35 @@ import ( ) func TestMemCache_Messages(t *testing.T) { - testCacheMessages(t, newMemCache()) + testCacheMessages(t, newMemTestCache(t)) } func TestMemCache_MessagesScheduled(t *testing.T) { - testCacheMessagesScheduled(t, newMemCache()) + testCacheMessagesScheduled(t, newMemTestCache(t)) } func TestMemCache_Topics(t *testing.T) { - testCacheTopics(t, newMemCache()) + testCacheTopics(t, newMemTestCache(t)) } func TestMemCache_MessagesTagsPrioAndTitle(t *testing.T) { - testCacheMessagesTagsPrioAndTitle(t, newMemCache()) + testCacheMessagesTagsPrioAndTitle(t, newMemTestCache(t)) } func TestMemCache_MessagesSinceID(t *testing.T) { - testCacheMessagesSinceID(t, newMemCache()) + testCacheMessagesSinceID(t, newMemTestCache(t)) } func TestMemCache_Prune(t *testing.T) { - testCachePrune(t, newMemCache()) + testCachePrune(t, newMemTestCache(t)) } func TestMemCache_Attachments(t *testing.T) { - testCacheAttachments(t, newMemCache()) + testCacheAttachments(t, newMemTestCache(t)) } func TestMemCache_NopCache(t *testing.T) { - c := newNopCache() + c, _ := newNopCache() assert.Nil(t, c.AddMessage(newDefaultMessage("mytopic", "my message"))) messages, err := c.Messages("mytopic", sinceAllMessages, false) @@ -45,3 +45,11 @@ func TestMemCache_NopCache(t *testing.T) { assert.Nil(t, err) assert.Empty(t, topics) } + +func newMemTestCache(t *testing.T) cache { + c, err := newMemCache() + if err != nil { + t.Fatal(err) + } + return c +} diff --git a/server/cache_sqlite.go b/server/cache_sqlite.go index a2d9636a..beade6c8 100644 --- a/server/cache_sqlite.go +++ b/server/cache_sqlite.go @@ -5,6 +5,7 @@ import ( "errors" "fmt" _ "github.com/mattn/go-sqlite3" // SQLite driver + "heckel.io/ntfy/util" "log" "strings" "time" @@ -42,6 +43,7 @@ const ( VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?) ` pruneMessagesQuery = `DELETE FROM messages WHERE time < ? AND published = 1` + selectRowIDFromMessageID = `SELECT id FROM messages WHERE topic = ? AND mid = ?` selectMessagesSinceTimeQuery = ` SELECT mid, time, topic, message, title, priority, tags, click, attachment_name, attachment_type, attachment_size, attachment_expires, attachment_url, attachment_owner, encoding FROM messages @@ -57,16 +59,13 @@ const ( selectMessagesSinceIDQuery = ` SELECT mid, time, topic, message, title, priority, tags, click, attachment_name, attachment_type, attachment_size, attachment_expires, attachment_url, attachment_owner, encoding FROM messages - WHERE topic = ? - AND published = 1 - AND id > (SELECT IFNULL(id,0) FROM messages WHERE mid = ?) + WHERE topic = ? AND id > ? AND published = 1 ORDER BY time, id ` selectMessagesSinceIDIncludeScheduledQuery = ` SELECT mid, time, topic, message, title, priority, tags, click, attachment_name, attachment_type, attachment_size, attachment_expires, attachment_url, attachment_owner, encoding FROM messages - WHERE topic = ? - AND id > (SELECT IFNULL(id,0) FROM messages WHERE mid = ?) + WHERE topic = ? AND id > ? ORDER BY time, id ` selectMessagesDueQuery = ` @@ -166,12 +165,13 @@ const ( ) type sqliteCache struct { - db *sql.DB + db *sql.DB + nop bool } var _ cache = (*sqliteCache)(nil) -func newSqliteCache(filename string) (*sqliteCache, error) { +func newSqliteCache(filename string, nop bool) (*sqliteCache, error) { db, err := sql.Open("sqlite3", filename) if err != nil { return nil, err @@ -180,14 +180,39 @@ func newSqliteCache(filename string) (*sqliteCache, error) { return nil, err } return &sqliteCache{ - db: db, + db: db, + nop: nop, }, nil } +// newMemCache creates an in-memory cache +func newMemCache() (*sqliteCache, error) { + return newSqliteCache(createMemoryFilename(), false) +} + +// newNopCache creates an in-memory cache that discards all messages; +// it is always empty and can be used if caching is entirely disabled +func newNopCache() (*sqliteCache, error) { + return newSqliteCache(createMemoryFilename(), true) +} + +// createMemoryFilename creates a unique filename to use for the SQLite backend. +// From mattn/go-sqlite3: "Each connection to ":memory:" opens a brand new in-memory +// sql database, so if the stdlib's sql engine happens to open another connection and +// you've only specified ":memory:", that connection will see a brand new database. +// A workaround is to use "file::memory:?cache=shared" (or "file:foobar?mode=memory&cache=shared"). +// Every connection to this string will point to the same in-memory database." +func createMemoryFilename() string { + return fmt.Sprintf("file:%s?mode=memory&cache=shared", util.RandomString(10)) +} + func (c *sqliteCache) AddMessage(m *message) error { if m.Event != messageEvent { return errUnexpectedMessageType } + if c.nop { + return nil + } published := m.Time <= time.Now().Unix() tags := strings.Join(m.Tags, ",") var attachmentName, attachmentType, attachmentURL, attachmentOwner string @@ -225,21 +250,44 @@ func (c *sqliteCache) AddMessage(m *message) error { func (c *sqliteCache) Messages(topic string, since sinceMarker, scheduled bool) ([]*message, error) { if since.IsNone() { return make([]*message, 0), nil + } else if since.IsID() { + return c.messagesSinceID(topic, since, scheduled) } + return c.messagesSinceTime(topic, since, scheduled) +} + +func (c *sqliteCache) messagesSinceTime(topic string, since sinceMarker, scheduled bool) ([]*message, error) { var rows *sql.Rows var err error - if since.IsID() { - if scheduled { - rows, err = c.db.Query(selectMessagesSinceIDIncludeScheduledQuery, topic, since.ID()) - } else { - rows, err = c.db.Query(selectMessagesSinceIDQuery, topic, since.ID()) - } + if scheduled { + rows, err = c.db.Query(selectMessagesSinceTimeIncludeScheduledQuery, topic, since.Time().Unix()) } else { - if scheduled { - rows, err = c.db.Query(selectMessagesSinceTimeIncludeScheduledQuery, topic, since.Time().Unix()) - } else { - rows, err = c.db.Query(selectMessagesSinceTimeQuery, topic, since.Time().Unix()) - } + rows, err = c.db.Query(selectMessagesSinceTimeQuery, topic, since.Time().Unix()) + } + if err != nil { + return nil, err + } + return readMessages(rows) +} + +func (c *sqliteCache) messagesSinceID(topic string, since sinceMarker, scheduled bool) ([]*message, error) { + idrows, err := c.db.Query(selectRowIDFromMessageID, topic, since.ID()) + if err != nil { + return nil, err + } + defer idrows.Close() + if !idrows.Next() { + return c.messagesSinceTime(topic, sinceAllMessages, scheduled) + } + var rowID int64 + if err := idrows.Scan(&rowID); err != nil { + return nil, err + } + var rows *sql.Rows + if scheduled { + rows, err = c.db.Query(selectMessagesSinceIDIncludeScheduledQuery, topic, rowID) + } else { + rows, err = c.db.Query(selectMessagesSinceIDQuery, topic, rowID) } if err != nil { return nil, err diff --git a/server/cache_sqlite_test.go b/server/cache_sqlite_test.go index 11c29bf2..9c99c6f4 100644 --- a/server/cache_sqlite_test.go +++ b/server/cache_sqlite_test.go @@ -142,7 +142,7 @@ func checkSchemaVersion(t *testing.T, db *sql.DB) { } func newSqliteTestCache(t *testing.T) *sqliteCache { - c, err := newSqliteCache(newSqliteTestCacheFile(t)) + c, err := newSqliteCache(newSqliteTestCacheFile(t), false) if err != nil { t.Fatal(err) } @@ -154,7 +154,7 @@ func newSqliteTestCacheFile(t *testing.T) string { } func newSqliteTestCacheFromFile(t *testing.T, filename string) *sqliteCache { - c, err := newSqliteCache(filename) + c, err := newSqliteCache(filename, false) if err != nil { t.Fatal(err) } diff --git a/server/cache_test.go b/server/cache_test.go index ab727114..a80c0552 100644 --- a/server/cache_test.go +++ b/server/cache_test.go @@ -212,7 +212,7 @@ func testCacheMessagesSinceID(t *testing.T, c cache) { require.Equal(t, 0, len(messages)) // Case 5: Since ID exists and is last message (-> Return no messages), include scheduled - messages, _ = c.Messages("mytopic", newSinceID(m7.ID), false) + messages, _ = c.Messages("mytopic", newSinceID(m7.ID), true) require.Equal(t, 1, len(messages)) require.Equal(t, "message 5", messages[0].Message) diff --git a/server/server.go b/server/server.go index 49eb681c..4ced5bfb 100644 --- a/server/server.go +++ b/server/server.go @@ -162,11 +162,11 @@ func New(conf *Config) (*Server, error) { func createCache(conf *Config) (cache, error) { if conf.CacheDuration == 0 { - return newNopCache(), nil + return newNopCache() } else if conf.CacheFile != "" { - return newSqliteCache(conf.CacheFile) + return newSqliteCache(conf.CacheFile, false) } - return newMemCache(), nil + return newMemCache() } // Run executes the main server. It listens on HTTP (+ HTTPS, if configured), and starts