WIP_ Add 'At:'/'Delay:' headers to support scheduled messages

This commit is contained in:
Philipp Heckel 2021-12-10 11:31:42 -05:00
parent aacdda94e1
commit 196c86d12b
8 changed files with 311 additions and 112 deletions

View file

@ -11,6 +11,7 @@ const (
DefaultCacheDuration = 12 * time.Hour DefaultCacheDuration = 12 * time.Hour
DefaultKeepaliveInterval = 30 * time.Second DefaultKeepaliveInterval = 30 * time.Second
DefaultManagerInterval = time.Minute DefaultManagerInterval = time.Minute
DefaultAtSenderInterval = 10 * time.Second
) )
// Defines all the limits // Defines all the limits
@ -35,6 +36,7 @@ type Config struct {
CacheDuration time.Duration CacheDuration time.Duration
KeepaliveInterval time.Duration KeepaliveInterval time.Duration
ManagerInterval time.Duration ManagerInterval time.Duration
AtSenderInterval time.Duration
GlobalTopicLimit int GlobalTopicLimit int
VisitorRequestLimitBurst int VisitorRequestLimitBurst int
VisitorRequestLimitReplenish time.Duration VisitorRequestLimitReplenish time.Duration
@ -54,6 +56,7 @@ func New(listenHTTP string) *Config {
CacheDuration: DefaultCacheDuration, CacheDuration: DefaultCacheDuration,
KeepaliveInterval: DefaultKeepaliveInterval, KeepaliveInterval: DefaultKeepaliveInterval,
ManagerInterval: DefaultManagerInterval, ManagerInterval: DefaultManagerInterval,
AtSenderInterval: DefaultAtSenderInterval,
GlobalTopicLimit: DefaultGlobalTopicLimit, GlobalTopicLimit: DefaultGlobalTopicLimit,
VisitorRequestLimitBurst: DefaultVisitorRequestLimitBurst, VisitorRequestLimitBurst: DefaultVisitorRequestLimitBurst,
VisitorRequestLimitReplenish: DefaultVisitorRequestLimitReplenish, VisitorRequestLimitReplenish: DefaultVisitorRequestLimitReplenish,

View file

@ -14,8 +14,10 @@ var (
// i.e. message structs with the Event messageEvent. // i.e. message structs with the Event messageEvent.
type cache interface { type cache interface {
AddMessage(m *message) error AddMessage(m *message) error
Messages(topic string, since sinceTime) ([]*message, error) Messages(topic string, since sinceTime, scheduled bool) ([]*message, error)
MessagesDue() ([]*message, error)
MessageCount(topic string) (int, error) MessageCount(topic string) (int, error)
Topics() (map[string]*topic, error) Topics() (map[string]*topic, error)
Prune(olderThan time.Time) error Prune(olderThan time.Time) error
MarkPublished(m *message) error
} }

View file

@ -1,14 +1,16 @@
package server package server
import ( import (
"sort"
"sync" "sync"
"time" "time"
) )
type memCache struct { type memCache struct {
messages map[string][]*message messages map[string][]*message
nop bool scheduled map[string]*message // Message ID -> message
mu sync.Mutex nop bool
mu sync.Mutex
} }
var _ cache = (*memCache)(nil) var _ cache = (*memCache)(nil)
@ -16,8 +18,9 @@ var _ cache = (*memCache)(nil)
// newMemCache creates an in-memory cache // newMemCache creates an in-memory cache
func newMemCache() *memCache { func newMemCache() *memCache {
return &memCache{ return &memCache{
messages: make(map[string][]*message), messages: make(map[string][]*message),
nop: false, scheduled: make(map[string]*message),
nop: false,
} }
} }
@ -25,77 +28,109 @@ func newMemCache() *memCache {
// it is always empty and can be used if caching is entirely disabled // it is always empty and can be used if caching is entirely disabled
func newNopCache() *memCache { func newNopCache() *memCache {
return &memCache{ return &memCache{
messages: make(map[string][]*message), messages: make(map[string][]*message),
nop: true, scheduled: make(map[string]*message),
nop: true,
} }
} }
func (s *memCache) AddMessage(m *message) error { func (c *memCache) AddMessage(m *message) error {
s.mu.Lock() c.mu.Lock()
defer s.mu.Unlock() defer c.mu.Unlock()
if s.nop { if c.nop {
return nil return nil
} }
if m.Event != messageEvent { if m.Event != messageEvent {
return errUnexpectedMessageType return errUnexpectedMessageType
} }
if _, ok := s.messages[m.Topic]; !ok { if _, ok := c.messages[m.Topic]; !ok {
s.messages[m.Topic] = make([]*message, 0) c.messages[m.Topic] = make([]*message, 0)
} }
s.messages[m.Topic] = append(s.messages[m.Topic], m) delayed := m.Time > time.Now().Unix()
if delayed {
c.scheduled[m.ID] = m
}
c.messages[m.Topic] = append(c.messages[m.Topic], m)
return nil return nil
} }
func (s *memCache) Messages(topic string, since sinceTime) ([]*message, error) { func (c *memCache) Messages(topic string, since sinceTime, scheduled bool) ([]*message, error) {
s.mu.Lock() c.mu.Lock()
defer s.mu.Unlock() defer c.mu.Unlock()
if _, ok := s.messages[topic]; !ok || since.IsNone() { if _, ok := c.messages[topic]; !ok || since.IsNone() {
return make([]*message, 0), nil return make([]*message, 0), nil
} }
messages := make([]*message, 0) // copy! messages := make([]*message, 0)
for _, m := range s.messages[topic] { for _, m := range c.messages[topic] {
msgTime := time.Unix(m.Time, 0) _, messageScheduled := c.scheduled[m.ID]
if msgTime == since.Time() || msgTime.After(since.Time()) { include := m.Time >= since.Time().Unix() && (!messageScheduled || scheduled)
if include {
messages = append(messages, m) messages = append(messages, m)
} }
} }
sort.Slice(messages, func(i, j int) bool {
return messages[i].Time < messages[j].Time
})
return messages, nil return messages, nil
} }
func (s *memCache) MessageCount(topic string) (int, error) { func (c *memCache) MessagesDue() ([]*message, error) {
s.mu.Lock() c.mu.Lock()
defer s.mu.Unlock() defer c.mu.Unlock()
if _, ok := s.messages[topic]; !ok { messages := make([]*message, 0)
return 0, nil for _, m := range c.scheduled {
due := time.Now().Unix() >= m.Time
if due {
messages = append(messages, m)
}
} }
return len(s.messages[topic]), nil sort.Slice(messages, func(i, j int) bool {
return messages[i].Time < messages[j].Time
})
return messages, nil
} }
func (s *memCache) Topics() (map[string]*topic, error) { func (c *memCache) MarkPublished(m *message) error {
s.mu.Lock() c.mu.Lock()
defer s.mu.Unlock() delete(c.scheduled, m.ID)
c.mu.Unlock()
return nil
}
func (c *memCache) MessageCount(topic string) (int, error) {
c.mu.Lock()
defer c.mu.Unlock()
if _, ok := c.messages[topic]; !ok {
return 0, nil
}
return len(c.messages[topic]), nil
}
func (c *memCache) Topics() (map[string]*topic, error) {
c.mu.Lock()
defer c.mu.Unlock()
topics := make(map[string]*topic) topics := make(map[string]*topic)
for topic := range s.messages { for topic := range c.messages {
topics[topic] = newTopic(topic) topics[topic] = newTopic(topic)
} }
return topics, nil return topics, nil
} }
func (s *memCache) Prune(olderThan time.Time) error { func (c *memCache) Prune(olderThan time.Time) error {
s.mu.Lock() c.mu.Lock()
defer s.mu.Unlock() defer c.mu.Unlock()
for topic := range s.messages { for topic := range c.messages {
s.pruneTopic(topic, olderThan) c.pruneTopic(topic, olderThan)
} }
return nil return nil
} }
func (s *memCache) pruneTopic(topic string, olderThan time.Time) { func (c *memCache) pruneTopic(topic string, olderThan time.Time) {
messages := make([]*message, 0) messages := make([]*message, 0)
for _, m := range s.messages[topic] { for _, m := range c.messages[topic] {
if m.Time >= olderThan.Unix() { if m.Time >= olderThan.Unix() {
messages = append(messages, m) messages = append(messages, m)
} }
} }
s.messages[topic] = messages c.messages[topic] = messages
} }

View file

@ -9,6 +9,10 @@ func TestMemCache_Messages(t *testing.T) {
testCacheMessages(t, newMemCache()) testCacheMessages(t, newMemCache())
} }
func TestMemCache_MessagesScheduled(t *testing.T) {
testCacheMessagesScheduled(t, newMemCache())
}
func TestMemCache_Topics(t *testing.T) { func TestMemCache_Topics(t *testing.T) {
testCacheTopics(t, newMemCache()) testCacheTopics(t, newMemCache())
} }
@ -25,7 +29,7 @@ func TestMemCache_NopCache(t *testing.T) {
c := newNopCache() c := newNopCache()
assert.Nil(t, c.AddMessage(newDefaultMessage("mytopic", "my message"))) assert.Nil(t, c.AddMessage(newDefaultMessage("mytopic", "my message")))
messages, err := c.Messages("mytopic", sinceAllMessages) messages, err := c.Messages("mytopic", sinceAllMessages, false)
assert.Nil(t, err) assert.Nil(t, err)
assert.Empty(t, messages) assert.Empty(t, messages)

View file

@ -21,19 +21,32 @@ const (
message VARCHAR(512) NOT NULL, message VARCHAR(512) NOT NULL,
title VARCHAR(256) NOT NULL, title VARCHAR(256) NOT NULL,
priority INT NOT NULL, priority INT NOT NULL,
tags VARCHAR(256) NOT NULL tags VARCHAR(256) NOT NULL,
published INT NOT NULL
); );
CREATE INDEX IF NOT EXISTS idx_topic ON messages (topic); CREATE INDEX IF NOT EXISTS idx_topic ON messages (topic);
COMMIT; COMMIT;
` `
insertMessageQuery = `INSERT INTO messages (id, time, topic, message, title, priority, tags) VALUES (?, ?, ?, ?, ?, ?, ?)` insertMessageQuery = `INSERT INTO messages (id, time, topic, message, title, priority, tags, published) VALUES (?, ?, ?, ?, ?, ?, ?, ?)`
pruneMessagesQuery = `DELETE FROM messages WHERE time < ?` pruneMessagesQuery = `DELETE FROM messages WHERE time < ?`
selectMessagesSinceTimeQuery = ` selectMessagesSinceTimeQuery = `
SELECT id, time, message, title, priority, tags SELECT id, time, topic, message, title, priority, tags
FROM messages
WHERE topic = ? AND time >= ? AND published = 1
ORDER BY time ASC
`
selectMessagesSinceTimeIncludeScheduledQuery = `
SELECT id, time, topic, message, title, priority, tags
FROM messages FROM messages
WHERE topic = ? AND time >= ? WHERE topic = ? AND time >= ?
ORDER BY time ASC ORDER BY time ASC
` `
selectMessagesDueQuery = `
SELECT id, time, topic, message, title, priority, tags
FROM messages
WHERE time <= ? AND published = 0
`
updateMessagePublishedQuery = `UPDATE messages SET published = 1 WHERE id = ?`
selectMessagesCountQuery = `SELECT COUNT(*) FROM messages` selectMessagesCountQuery = `SELECT COUNT(*) FROM messages`
selectMessageCountForTopicQuery = `SELECT COUNT(*) FROM messages WHERE topic = ?` selectMessageCountForTopicQuery = `SELECT COUNT(*) FROM messages WHERE topic = ?`
selectTopicsQuery = `SELECT topic FROM messages GROUP BY topic` selectTopicsQuery = `SELECT topic FROM messages GROUP BY topic`
@ -41,7 +54,7 @@ const (
// Schema management queries // Schema management queries
const ( const (
currentSchemaVersion = 1 currentSchemaVersion = 2
createSchemaVersionTableQuery = ` createSchemaVersionTableQuery = `
CREATE TABLE IF NOT EXISTS schemaVersion ( CREATE TABLE IF NOT EXISTS schemaVersion (
id INT PRIMARY KEY, id INT PRIMARY KEY,
@ -49,6 +62,7 @@ const (
); );
` `
insertSchemaVersion = `INSERT INTO schemaVersion VALUES (1, ?)` insertSchemaVersion = `INSERT INTO schemaVersion VALUES (1, ?)`
updateSchemaVersion = `UPDATE schemaVersion SET version = ? WHERE id = 1`
selectSchemaVersionQuery = `SELECT version FROM schemaVersion WHERE id = 1` selectSchemaVersionQuery = `SELECT version FROM schemaVersion WHERE id = 1`
// 0 -> 1 // 0 -> 1
@ -59,6 +73,13 @@ const (
ALTER TABLE messages ADD COLUMN tags VARCHAR(256) NOT NULL DEFAULT(''); ALTER TABLE messages ADD COLUMN tags VARCHAR(256) NOT NULL DEFAULT('');
COMMIT; COMMIT;
` `
// 1 -> 2
migrate1To2AlterMessagesTableQuery = `
BEGIN;
ALTER TABLE messages ADD COLUMN published INT NOT NULL DEFAULT(1);
COMMIT;
`
) )
type sqliteCache struct { type sqliteCache struct {
@ -84,46 +105,39 @@ func (c *sqliteCache) AddMessage(m *message) error {
if m.Event != messageEvent { if m.Event != messageEvent {
return errUnexpectedMessageType return errUnexpectedMessageType
} }
_, err := c.db.Exec(insertMessageQuery, m.ID, m.Time, m.Topic, m.Message, m.Title, m.Priority, strings.Join(m.Tags, ",")) published := m.Time <= time.Now().Unix()
_, err := c.db.Exec(insertMessageQuery, m.ID, m.Time, m.Topic, m.Message, m.Title, m.Priority, strings.Join(m.Tags, ","), published)
return err return err
} }
func (c *sqliteCache) Messages(topic string, since sinceTime) ([]*message, error) { func (c *sqliteCache) Messages(topic string, since sinceTime, scheduled bool) ([]*message, error) {
if since.IsNone() { if since.IsNone() {
return make([]*message, 0), nil return make([]*message, 0), nil
} }
rows, err := c.db.Query(selectMessagesSinceTimeQuery, topic, since.Time().Unix()) var rows *sql.Rows
var err error
if scheduled {
rows, err = c.db.Query(selectMessagesSinceTimeIncludeScheduledQuery, topic, since.Time().Unix())
} else {
rows, err = c.db.Query(selectMessagesSinceTimeQuery, topic, since.Time().Unix())
}
if err != nil { if err != nil {
return nil, err return nil, err
} }
defer rows.Close() return readMessages(rows)
messages := make([]*message, 0) }
for rows.Next() {
var timestamp int64 func (c *sqliteCache) MessagesDue() ([]*message, error) {
var priority int rows, err := c.db.Query(selectMessagesDueQuery, time.Now().Unix())
var id, msg, title, tagsStr string if err != nil {
if err := rows.Scan(&id, &timestamp, &msg, &title, &priority, &tagsStr); err != nil {
return nil, err
}
var tags []string
if tagsStr != "" {
tags = strings.Split(tagsStr, ",")
}
messages = append(messages, &message{
ID: id,
Time: timestamp,
Event: messageEvent,
Topic: topic,
Message: msg,
Title: title,
Priority: priority,
Tags: tags,
})
}
if err := rows.Err(); err != nil {
return nil, err return nil, err
} }
return messages, nil return readMessages(rows)
}
func (c *sqliteCache) MarkPublished(m *message) error {
_, err := c.db.Exec(updateMessagePublishedQuery, m.ID)
return err
} }
func (c *sqliteCache) MessageCount(topic string) (int, error) { func (c *sqliteCache) MessageCount(topic string) (int, error) {
@ -169,6 +183,37 @@ func (c *sqliteCache) Prune(olderThan time.Time) error {
return err return err
} }
func readMessages(rows *sql.Rows) ([]*message, error) {
defer rows.Close()
messages := make([]*message, 0)
for rows.Next() {
var timestamp int64
var priority int
var id, topic, msg, title, tagsStr string
if err := rows.Scan(&id, &timestamp, &topic, &msg, &title, &priority, &tagsStr); err != nil {
return nil, err
}
var tags []string
if tagsStr != "" {
tags = strings.Split(tagsStr, ",")
}
messages = append(messages, &message{
ID: id,
Time: timestamp,
Event: messageEvent,
Topic: topic,
Message: msg,
Title: title,
Priority: priority,
Tags: tags,
})
}
if err := rows.Err(); err != nil {
return nil, err
}
return messages, nil
}
func setupDB(db *sql.DB) error { func setupDB(db *sql.DB) error {
// If 'messages' table does not exist, this must be a new database // If 'messages' table does not exist, this must be a new database
rowsMC, err := db.Query(selectMessagesCountQuery) rowsMC, err := db.Query(selectMessagesCountQuery)
@ -194,7 +239,9 @@ func setupDB(db *sql.DB) error {
if schemaVersion == currentSchemaVersion { if schemaVersion == currentSchemaVersion {
return nil return nil
} else if schemaVersion == 0 { } else if schemaVersion == 0 {
return migrateFrom0To1(db) return migrateFrom0(db)
} else if schemaVersion == 1 {
return migrateFrom1(db)
} }
return fmt.Errorf("unexpected schema version found: %d", schemaVersion) return fmt.Errorf("unexpected schema version found: %d", schemaVersion)
} }
@ -212,7 +259,7 @@ func setupNewDB(db *sql.DB) error {
return nil return nil
} }
func migrateFrom0To1(db *sql.DB) error { func migrateFrom0(db *sql.DB) error {
log.Print("Migrating cache database schema: from 0 to 1") log.Print("Migrating cache database schema: from 0 to 1")
if _, err := db.Exec(migrate0To1AlterMessagesTableQuery); err != nil { if _, err := db.Exec(migrate0To1AlterMessagesTableQuery); err != nil {
return err return err
@ -223,5 +270,16 @@ func migrateFrom0To1(db *sql.DB) error {
if _, err := db.Exec(insertSchemaVersion, 1); err != nil { if _, err := db.Exec(insertSchemaVersion, 1); err != nil {
return err return err
} }
return nil return migrateFrom1(db)
}
func migrateFrom1(db *sql.DB) error {
log.Print("Migrating cache database schema: from 1 to 2")
if _, err := db.Exec(migrate1To2AlterMessagesTableQuery); err != nil {
return err
}
if _, err := db.Exec(updateSchemaVersion, 2); err != nil {
return err
}
return nil // Update this when a new version is added
} }

View file

@ -9,10 +9,14 @@ import (
"time" "time"
) )
func TestSqliteCache_AddMessage(t *testing.T) { func TestSqliteCache_Messages(t *testing.T) {
testCacheMessages(t, newSqliteTestCache(t)) testCacheMessages(t, newSqliteTestCache(t))
} }
func TestSqliteCache_MessagesScheduled(t *testing.T) {
testCacheMessagesScheduled(t, newSqliteTestCache(t))
}
func TestSqliteCache_Topics(t *testing.T) { func TestSqliteCache_Topics(t *testing.T) {
testCacheTopics(t, newSqliteTestCache(t)) testCacheTopics(t, newSqliteTestCache(t))
} }
@ -25,7 +29,7 @@ func TestSqliteCache_Prune(t *testing.T) {
testCachePrune(t, newSqliteTestCache(t)) testCachePrune(t, newSqliteTestCache(t))
} }
func TestSqliteCache_Migration_0to1(t *testing.T) { func TestSqliteCache_Migration_From0(t *testing.T) {
filename := newSqliteTestCacheFile(t) filename := newSqliteTestCacheFile(t)
db, err := sql.Open("sqlite3", filename) db, err := sql.Open("sqlite3", filename)
assert.Nil(t, err) assert.Nil(t, err)
@ -53,7 +57,7 @@ func TestSqliteCache_Migration_0to1(t *testing.T) {
// Create cache to trigger migration // Create cache to trigger migration
c := newSqliteTestCacheFromFile(t, filename) c := newSqliteTestCacheFromFile(t, filename)
messages, err := c.Messages("mytopic", sinceAllMessages) messages, err := c.Messages("mytopic", sinceAllMessages, false)
assert.Nil(t, err) assert.Nil(t, err)
assert.Equal(t, 10, len(messages)) assert.Equal(t, 10, len(messages))
assert.Equal(t, "some message 5", messages[5].Message) assert.Equal(t, "some message 5", messages[5].Message)
@ -67,7 +71,7 @@ func TestSqliteCache_Migration_0to1(t *testing.T) {
var schemaVersion int var schemaVersion int
assert.Nil(t, rows.Scan(&schemaVersion)) assert.Nil(t, rows.Scan(&schemaVersion))
assert.Equal(t, 1, schemaVersion) assert.Equal(t, 2, schemaVersion)
} }
func newSqliteTestCache(t *testing.T) *sqliteCache { func newSqliteTestCache(t *testing.T) *sqliteCache {

View file

@ -27,7 +27,7 @@ func testCacheMessages(t *testing.T, c cache) {
assert.Equal(t, 2, count) assert.Equal(t, 2, count)
// mytopic: since all // mytopic: since all
messages, _ := c.Messages("mytopic", sinceAllMessages) messages, _ := c.Messages("mytopic", sinceAllMessages, false)
assert.Equal(t, 2, len(messages)) assert.Equal(t, 2, len(messages))
assert.Equal(t, "my message", messages[0].Message) assert.Equal(t, "my message", messages[0].Message)
assert.Equal(t, "mytopic", messages[0].Topic) assert.Equal(t, "mytopic", messages[0].Topic)
@ -38,11 +38,11 @@ func testCacheMessages(t *testing.T, c cache) {
assert.Equal(t, "my other message", messages[1].Message) assert.Equal(t, "my other message", messages[1].Message)
// mytopic: since none // mytopic: since none
messages, _ = c.Messages("mytopic", sinceNoMessages) messages, _ = c.Messages("mytopic", sinceNoMessages, false)
assert.Empty(t, messages) assert.Empty(t, messages)
// mytopic: since 2 // mytopic: since 2
messages, _ = c.Messages("mytopic", sinceTime(time.Unix(2, 0))) messages, _ = c.Messages("mytopic", sinceTime(time.Unix(2, 0)), false)
assert.Equal(t, 1, len(messages)) assert.Equal(t, 1, len(messages))
assert.Equal(t, "my other message", messages[0].Message) assert.Equal(t, "my other message", messages[0].Message)
@ -52,7 +52,7 @@ func testCacheMessages(t *testing.T, c cache) {
assert.Equal(t, 1, count) assert.Equal(t, 1, count)
// example: since all // example: since all
messages, _ = c.Messages("example", sinceAllMessages) messages, _ = c.Messages("example", sinceAllMessages, false)
assert.Equal(t, "my example message", messages[0].Message) assert.Equal(t, "my example message", messages[0].Message)
// non-existing: count // non-existing: count
@ -61,7 +61,7 @@ func testCacheMessages(t *testing.T, c cache) {
assert.Equal(t, 0, count) assert.Equal(t, 0, count)
// non-existing: since all // non-existing: since all
messages, _ = c.Messages("doesnotexist", sinceAllMessages) messages, _ = c.Messages("doesnotexist", sinceAllMessages, false)
assert.Empty(t, messages) assert.Empty(t, messages)
} }
@ -103,7 +103,7 @@ func testCachePrune(t *testing.T, c cache) {
assert.Nil(t, err) assert.Nil(t, err)
assert.Equal(t, 0, count) assert.Equal(t, 0, count)
messages, err := c.Messages("mytopic", sinceAllMessages) messages, err := c.Messages("mytopic", sinceAllMessages, false)
assert.Nil(t, err) assert.Nil(t, err)
assert.Equal(t, 1, len(messages)) assert.Equal(t, 1, len(messages))
assert.Equal(t, "my other message", messages[0].Message) assert.Equal(t, "my other message", messages[0].Message)
@ -116,8 +116,34 @@ func testCacheMessagesTagsPrioAndTitle(t *testing.T, c cache) {
m.Title = "some title" m.Title = "some title"
assert.Nil(t, c.AddMessage(m)) assert.Nil(t, c.AddMessage(m))
messages, _ := c.Messages("mytopic", sinceAllMessages) messages, _ := c.Messages("mytopic", sinceAllMessages, false)
assert.Equal(t, []string{"tag1", "tag2"}, messages[0].Tags) assert.Equal(t, []string{"tag1", "tag2"}, messages[0].Tags)
assert.Equal(t, 5, messages[0].Priority) assert.Equal(t, 5, messages[0].Priority)
assert.Equal(t, "some title", messages[0].Title) assert.Equal(t, "some title", messages[0].Title)
} }
func testCacheMessagesScheduled(t *testing.T, c cache) {
m1 := newDefaultMessage("mytopic", "message 1")
m2 := newDefaultMessage("mytopic", "message 2")
m2.Time = time.Now().Add(time.Hour).Unix()
m3 := newDefaultMessage("mytopic", "message 3")
m3.Time = time.Now().Add(time.Minute).Unix() // earlier than m2!
m4 := newDefaultMessage("mytopic2", "message 4")
m4.Time = time.Now().Add(time.Minute).Unix()
assert.Nil(t, c.AddMessage(m1))
assert.Nil(t, c.AddMessage(m2))
assert.Nil(t, c.AddMessage(m3))
messages, _ := c.Messages("mytopic", sinceAllMessages, false) // exclude scheduled
assert.Equal(t, 1, len(messages))
assert.Equal(t, "message 1", messages[0].Message)
messages, _ = c.Messages("mytopic", sinceAllMessages, true) // include scheduled
assert.Equal(t, 3, len(messages))
assert.Equal(t, "message 1", messages[0].Message)
assert.Equal(t, "message 3", messages[1].Message) // Order!
assert.Equal(t, "message 2", messages[2].Message)
messages, _ = c.MessagesDue()
assert.Empty(t, messages)
}

View file

@ -73,6 +73,7 @@ var (
const ( const (
messageLimit = 512 messageLimit = 512
minDelay = 10 * time.Second
) )
var ( var (
@ -183,6 +184,15 @@ func (s *Server) Run() error {
s.updateStatsAndExpire() s.updateStatsAndExpire()
} }
}() }()
go func() {
ticker := time.NewTicker(s.config.AtSenderInterval)
for {
<-ticker.C
if err := s.sendDelayedMessages(); err != nil {
log.Printf("error sending scheduled messages: %s", err.Error())
}
}
}()
listenStr := fmt.Sprintf("%s/http", s.config.ListenHTTP) listenStr := fmt.Sprintf("%s/http", s.config.ListenHTTP)
if s.config.ListenHTTPS != "" { if s.config.ListenHTTPS != "" {
listenStr += fmt.Sprintf(" %s/https", s.config.ListenHTTPS) listenStr += fmt.Sprintf(" %s/https", s.config.ListenHTTPS)
@ -279,14 +289,17 @@ func (s *Server) handlePublish(w http.ResponseWriter, r *http.Request, _ *visito
if m.Message == "" { if m.Message == "" {
return errHTTPBadRequest return errHTTPBadRequest
} }
title, priority, tags, cache, firebase := parseHeaders(r.Header) cache, firebase, err := parseHeaders(r.Header, m)
m.Title = title if err != nil {
m.Priority = priority
m.Tags = tags
if err := t.Publish(m); err != nil {
return err return err
} }
if s.firebase != nil && firebase { delayed := m.Time > time.Now().Unix()
if !delayed {
if err := t.Publish(m); err != nil {
return err
}
}
if s.firebase != nil && firebase && !delayed {
go func() { go func() {
if err := s.firebase(m); err != nil { if err := s.firebase(m); err != nil {
log.Printf("Unable to publish to Firebase: %v", err.Error()) log.Printf("Unable to publish to Firebase: %v", err.Error())
@ -308,35 +321,62 @@ func (s *Server) handlePublish(w http.ResponseWriter, r *http.Request, _ *visito
return nil return nil
} }
func parseHeaders(header http.Header) (title string, priority int, tags []string, cache bool, firebase bool) { func parseHeaders(header http.Header, m *message) (cache bool, firebase bool, err error) {
title = readHeader(header, "x-title", "title", "ti", "t") cache = readHeader(header, "x-cache", "cache") != "no"
firebase = readHeader(header, "x-firebase", "firebase") != "no"
m.Title = readHeader(header, "x-title", "title", "ti", "t")
priorityStr := readHeader(header, "x-priority", "priority", "prio", "p") priorityStr := readHeader(header, "x-priority", "priority", "prio", "p")
if priorityStr != "" { if priorityStr != "" {
switch strings.ToLower(priorityStr) { switch strings.ToLower(priorityStr) {
case "1", "min": case "1", "min":
priority = 1 m.Priority = 1
case "2", "low": case "2", "low":
priority = 2 m.Priority = 2
case "3", "default": case "3", "default":
priority = 3 m.Priority = 3
case "4", "high": case "4", "high":
priority = 4 m.Priority = 4
case "5", "max", "urgent": case "5", "max", "urgent":
priority = 5 m.Priority = 5
default: default:
priority = 0 return false, false, errHTTPBadRequest
} }
} }
tagsStr := readHeader(header, "x-tags", "tag", "tags", "ta") tagsStr := readHeader(header, "x-tags", "tag", "tags", "ta")
if tagsStr != "" { if tagsStr != "" {
tags = make([]string, 0) m.Tags = make([]string, 0)
for _, s := range strings.Split(tagsStr, ",") { for _, s := range strings.Split(tagsStr, ",") {
tags = append(tags, strings.TrimSpace(s)) m.Tags = append(m.Tags, strings.TrimSpace(s))
} }
} }
cache = readHeader(header, "x-cache", "cache") != "no" atStr := readHeader(header, "x-at", "at", "x-schedule", "schedule", "sched")
firebase = readHeader(header, "x-firebase", "firebase") != "no" if atStr != "" {
return title, priority, tags, cache, firebase if !cache {
return false, false, errHTTPBadRequest
}
at, err := strconv.Atoi(atStr)
if err != nil {
return false, false, errHTTPBadRequest
} else if int64(at) < time.Now().Add(minDelay).Unix() {
return false, false, errHTTPBadRequest
}
m.Time = int64(at)
} else {
delayStr := readHeader(header, "x-delay", "delay", "x-in", "in")
if delayStr != "" {
if !cache {
return false, false, errHTTPBadRequest
}
delay, err := time.ParseDuration(delayStr)
if err != nil {
return false, false, errHTTPBadRequest
} else if delay < minDelay {
return false, false, errHTTPBadRequest
}
m.Time = time.Now().Add(delay).Unix()
}
}
return cache, firebase, nil
} }
func readHeader(header http.Header, names ...string) string { func readHeader(header http.Header, names ...string) string {
@ -401,6 +441,7 @@ func (s *Server) handleSubscribe(w http.ResponseWriter, r *http.Request, v *visi
} }
var wlock sync.Mutex var wlock sync.Mutex
poll := r.URL.Query().Has("poll") poll := r.URL.Query().Has("poll")
scheduled := r.URL.Query().Has("scheduled") || r.URL.Query().Has("sched")
sub := func(msg *message) error { sub := func(msg *message) error {
wlock.Lock() wlock.Lock()
defer wlock.Unlock() defer wlock.Unlock()
@ -419,7 +460,7 @@ func (s *Server) handleSubscribe(w http.ResponseWriter, r *http.Request, v *visi
w.Header().Set("Access-Control-Allow-Origin", "*") // CORS, allow cross-origin requests w.Header().Set("Access-Control-Allow-Origin", "*") // CORS, allow cross-origin requests
w.Header().Set("Content-Type", contentType+"; charset=utf-8") // Android/Volley client needs charset! w.Header().Set("Content-Type", contentType+"; charset=utf-8") // Android/Volley client needs charset!
if poll { if poll {
return s.sendOldMessages(topics, since, sub) return s.sendOldMessages(topics, since, scheduled, sub)
} }
subscriberIDs := make([]int, 0) subscriberIDs := make([]int, 0)
for _, t := range topics { for _, t := range topics {
@ -433,7 +474,7 @@ func (s *Server) handleSubscribe(w http.ResponseWriter, r *http.Request, v *visi
if err := sub(newOpenMessage(topicsStr)); err != nil { // Send out open message if err := sub(newOpenMessage(topicsStr)); err != nil { // Send out open message
return err return err
} }
if err := s.sendOldMessages(topics, since, sub); err != nil { if err := s.sendOldMessages(topics, since, scheduled, sub); err != nil {
return err return err
} }
for { for {
@ -449,12 +490,12 @@ func (s *Server) handleSubscribe(w http.ResponseWriter, r *http.Request, v *visi
} }
} }
func (s *Server) sendOldMessages(topics []*topic, since sinceTime, sub subscriber) error { func (s *Server) sendOldMessages(topics []*topic, since sinceTime, scheduled bool, sub subscriber) error {
if since.IsNone() { if since.IsNone() {
return nil return nil
} }
for _, t := range topics { for _, t := range topics {
messages, err := s.cache.Messages(t.ID, since) messages, err := s.cache.Messages(t.ID, since, scheduled)
if err != nil { if err != nil {
return err return err
} }
@ -560,6 +601,32 @@ func (s *Server) updateStatsAndExpire() {
s.messages, len(s.topics), subscribers, messages, len(s.visitors)) s.messages, len(s.topics), subscribers, messages, len(s.visitors))
} }
func (s *Server) sendDelayedMessages() error {
s.mu.Lock()
defer s.mu.Unlock()
messages, err := s.cache.MessagesDue()
if err != nil {
return err
}
for _, m := range messages {
t, ok := s.topics[m.Topic] // If no subscribers, just mark message as published
if ok {
if err := t.Publish(m); err != nil {
log.Printf("unable to publish message %s to topic %s: %v", m.ID, m.Topic, err.Error())
}
if s.firebase != nil {
if err := s.firebase(m); err != nil {
log.Printf("unable to publish to Firebase: %v", err.Error())
}
}
}
if err := s.cache.MarkPublished(m); err != nil {
return err
}
}
return nil
}
func (s *Server) withRateLimit(w http.ResponseWriter, r *http.Request, handler func(w http.ResponseWriter, r *http.Request, v *visitor) error) error { func (s *Server) withRateLimit(w http.ResponseWriter, r *http.Request, handler func(w http.ResponseWriter, r *http.Request, v *visitor) error) error {
v := s.visitor(r) v := s.visitor(r)
if err := v.RequestAllowed(); err != nil { if err := v.RequestAllowed(); err != nil {