From b145e693a5f9d750807a6ecb7de02d500b0e71c8 Mon Sep 17 00:00:00 2001 From: Philipp Heckel Date: Fri, 29 Oct 2021 13:58:14 -0400 Subject: [PATCH] Add firebase support --- cmd/app.go | 27 ++++-- config/config.go | 33 ++++---- server/index.html | 84 +++++++++++-------- server/message.go | 33 +++++--- server/server.go | 209 ++++++++++++++++++++++++++++------------------ server/topic.go | 39 +++++++-- util/util.go | 29 +++++++ 7 files changed, 293 insertions(+), 161 deletions(-) create mode 100644 util/util.go diff --git a/cmd/app.go b/cmd/app.go index 357db1e8..3c2153e9 100644 --- a/cmd/app.go +++ b/cmd/app.go @@ -8,8 +8,10 @@ import ( "github.com/urfave/cli/v2/altsrc" "heckel.io/ntfy/config" "heckel.io/ntfy/server" + "heckel.io/ntfy/util" "log" "os" + "time" ) // New creates a new CLI application @@ -18,7 +20,9 @@ func New() *cli.App { &cli.StringFlag{Name: "config", Aliases: []string{"c"}, EnvVars: []string{"NTFY_CONFIG_FILE"}, Value: "/etc/ntfy/config.yml", DefaultText: "/etc/ntfy/config.yml", Usage: "config file"}, altsrc.NewStringFlag(&cli.StringFlag{Name: "listen-http", Aliases: []string{"l"}, EnvVars: []string{"NTFY_LISTEN_HTTP"}, Value: config.DefaultListenHTTP, Usage: "ip:port used to as listen address"}), altsrc.NewStringFlag(&cli.StringFlag{Name: "firebase-key-file", Aliases: []string{"F"}, EnvVars: []string{"NTFY_FIREBASE_KEY_FILE"}, Usage: "Firebase credentials file; if set additionally publish to FCM topic"}), + altsrc.NewDurationFlag(&cli.DurationFlag{Name: "message-buffer-duration", Aliases: []string{"b"}, EnvVars: []string{"NTFY_MESSAGE_BUFFER_DURATION"}, Value: config.DefaultMessageBufferDuration, Usage: "buffer messages in memory for this time to allow `since` requests"}), altsrc.NewDurationFlag(&cli.DurationFlag{Name: "keepalive-interval", Aliases: []string{"k"}, EnvVars: []string{"NTFY_KEEPALIVE_INTERVAL"}, Value: config.DefaultKeepaliveInterval, Usage: "default interval of keepalive messages"}), + altsrc.NewDurationFlag(&cli.DurationFlag{Name: "manager-interval", Aliases: []string{"m"}, EnvVars: []string{"NTFY_MANAGER_INTERVAL"}, Value: config.DefaultManagerInterval, Usage: "default interval of for message pruning and stats printing"}), } return &cli.App{ Name: "ntfy", @@ -41,17 +45,27 @@ func execRun(c *cli.Context) error { // Read all the options listenHTTP := c.String("listen-http") firebaseKeyFile := c.String("firebase-key-file") + messageBufferDuration := c.Duration("message-buffer-duration") keepaliveInterval := c.Duration("keepalive-interval") + managerInterval := c.Duration("manager-interval") // Check values - if firebaseKeyFile != "" && !fileExists(firebaseKeyFile) { + if firebaseKeyFile != "" && !util.FileExists(firebaseKeyFile) { return errors.New("if set, FCM key file must exist") + } else if keepaliveInterval < 5*time.Second { + return errors.New("keepalive interval cannot be lower than five seconds") + } else if managerInterval < 5*time.Second { + return errors.New("manager interval cannot be lower than five seconds") + } else if messageBufferDuration < managerInterval { + return errors.New("message buffer duration cannot be lower than manager interval") } - // Run main bot, can be killed by signal + // Run server conf := config.New(listenHTTP) conf.FirebaseKeyFile = firebaseKeyFile + conf.MessageBufferDuration = messageBufferDuration conf.KeepaliveInterval = keepaliveInterval + conf.ManagerInterval = managerInterval s, err := server.New(conf) if err != nil { log.Fatalln(err) @@ -68,9 +82,9 @@ func execRun(c *cli.Context) error { func initConfigFileInputSource(configFlag string, flags []cli.Flag) cli.BeforeFunc { return func(context *cli.Context) error { configFile := context.String(configFlag) - if context.IsSet(configFlag) && !fileExists(configFile) { + if context.IsSet(configFlag) && !util.FileExists(configFile) { return fmt.Errorf("config file %s does not exist", configFile) - } else if !context.IsSet(configFlag) && !fileExists(configFile) { + } else if !context.IsSet(configFlag) && !util.FileExists(configFile) { return nil } inputSource, err := altsrc.NewYamlSourceFromFile(configFile) @@ -80,8 +94,3 @@ func initConfigFileInputSource(configFlag string, flags []cli.Flag) cli.BeforeFu return altsrc.ApplyInputSourceValues(context, inputSource, flags) } } - -func fileExists(filename string) bool { - stat, _ := os.Stat(filename) - return stat != nil -} diff --git a/config/config.go b/config/config.go index 7de949a1..6e3f242a 100644 --- a/config/config.go +++ b/config/config.go @@ -8,9 +8,10 @@ import ( // Defines default config settings const ( - DefaultListenHTTP = ":80" - DefaultKeepaliveInterval = 30 * time.Second - defaultManagerInterval = time.Minute + DefaultListenHTTP = ":80" + DefaultMessageBufferDuration = 12 * time.Hour + DefaultKeepaliveInterval = 30 * time.Second + DefaultManagerInterval = time.Minute ) // Defines the max number of requests, here: @@ -22,22 +23,24 @@ var ( // Config is the main config struct for the application. Use New to instantiate a default config struct. type Config struct { - ListenHTTP string - Limit rate.Limit - LimitBurst int - FirebaseKeyFile string - KeepaliveInterval time.Duration - ManagerInterval time.Duration + ListenHTTP string + FirebaseKeyFile string + MessageBufferDuration time.Duration + KeepaliveInterval time.Duration + ManagerInterval time.Duration + Limit rate.Limit + LimitBurst int } // New instantiates a default new config func New(listenHTTP string) *Config { return &Config{ - ListenHTTP: listenHTTP, - Limit: defaultLimit, - LimitBurst: defaultLimitBurst, - FirebaseKeyFile: "", - KeepaliveInterval: DefaultKeepaliveInterval, - ManagerInterval: defaultManagerInterval, + ListenHTTP: listenHTTP, + FirebaseKeyFile: "", + MessageBufferDuration: DefaultMessageBufferDuration, + KeepaliveInterval: DefaultKeepaliveInterval, + ManagerInterval: DefaultManagerInterval, + Limit: defaultLimit, + LimitBurst: defaultLimitBurst, } } diff --git a/server/index.html b/server/index.html index 0c9f5392..9ad0d701 100644 --- a/server/index.html +++ b/server/index.html @@ -38,14 +38,31 @@

+

Publishing messages

+

+ Publishing messages can be done via PUT or POST using. Topics are created on the fly by subscribing or publishing to them. + Because there is no sign-up, the topic is essentially a password, so pick something that's not easily guessable. +

+

+ Here's an example showing how to publish a message using curl: +

+ + curl -d "long process is done" ntfy.sh/mytopic + +

+ Here's an example in JS with fetch() (see full example): +

+ + fetch('https://ntfy.sh/mytopic', {
+   method: 'POST', // PUT works too
+   body: 'Hello from the other side.'
+ }) +
+

Subscribe to a topic

- Topics are created on the fly by subscribing to them. You can create and subscribe to a topic either in this web UI, or in - your own app by subscribing to an EventSource, - a JSON feed, or raw feed. -

-

- Because there is no sign-up, the topic is essentially a password, so pick something that's not easily guessable. + You can create and subscribe to a topic either in this web UI, or in your own app by subscribing to an + EventSource, a JSON feed, or raw feed.

Subscribe via web

@@ -66,7 +83,7 @@

Subscribe via your app, or via the CLI

- Using EventSource, you can consume + Using EventSource in JS, you can consume notifications like this (see full example):

@@ -76,30 +93,29 @@ };

- Or you can use curl or any other HTTP library. Here's an example for the /json endpoint, - which prints one JSON message per line (keepalive and open messages have an "event" field): -

- - $ curl -s ntfy.sh/mytopic/json
- {"time":1635359841,"event":"open"}
- {"time":1635359844,"message":"This is a notification"}
- {"time":1635359851,"event":"keepalive"} -
-

- Using the /sse endpoint (SSE, server-sent events stream): + You can also use the same /sse endpoint via curl or any other HTTP library:

$ curl -s ntfy.sh/mytopic/sse
event: open
- data: {"time":1635359796,"event":"open"}

+ data: {"id":"weSj9RtNkj","time":1635528898,"event":"open","topic":"mytopic"}

- data: {"time":1635359803,"message":"This is a notification"}

+ data: {"id":"p0M5y6gcCY","time":1635528909,"event":"message","topic":"mytopic","message":"Hi!"}

event: keepalive
- data: {"time":1635359806,"event":"keepalive"} + data: {"id":"VNxNIg5fpt","time":1635528928,"event":"keepalive","topic":"test"}

- Using the /raw endpoint (empty lines are keepalive messages): + To consume JSON instead, use the /json endpoint, which prints one message per line: +

+ + $ curl -s ntfy.sh/mytopic/json
+ {"id":"SLiKI64DOt","time":1635528757,"event":"open","topic":"mytopic"}
+ {"id":"hwQ2YpKdmg","time":1635528741,"event":"message","topic":"mytopic","message":"Hi!"}
+ {"id":"DGUDShMCsc","time":1635528787,"event":"keepalive","topic":"mytopic"} +
+

+ Or use the /raw endpoint if you need something super simple (empty lines are keepalive messages):

$ curl -s ntfy.sh/mytopic/raw
@@ -107,27 +123,25 @@ This is a notification
-

Publishing messages

+

Message buffering and polling

- Publishing messages can be done via PUT or POST using. Here's an example using curl: + Messages are buffered in memory for a few hours to account for network interruptions of subscribers. + You can read back what you missed by using the since=... query parameter. It takes either a + duration (e.g. 10m or 30s) or a Unix timestamp (e.g. 1635528757):

- curl -d "long process is done" ntfy.sh/mytopic + $ curl -s "ntfy.sh/mytopic/json?since=10m"
+ # Same output as above, but includes messages from up to 10 minutes ago

- Here's an example in JS with fetch() (see full example): + You can also just poll for messages if you don't like the long-standing connection using the poll=1 + query parameter. The connection will end after all available messages have been read. This parameter has to be + combined with since=.

- fetch('https://ntfy.sh/mytopic', {
-   method: 'POST', // PUT works too
-   body: 'Hello from the other side.'
- }) + $ curl -s "ntfy.sh/mytopic/json?poll=1&since=10m"
+ # Returns messages from up to 10 minutes ago and ends the connection
-

- Messages published to a non-existing topic or a topic without subscribers will not be delivered later. - There is (currently) no buffering of any kind. If you're not listening, the message won't be delivered. -

-

FAQ

Isn't this like ...?
diff --git a/server/message.go b/server/message.go index eb9a4bfd..5b91a284 100644 --- a/server/message.go +++ b/server/message.go @@ -1,18 +1,27 @@ package server -import "time" +import ( + "heckel.io/ntfy/util" + "time" +) // List of possible events const ( openEvent = "open" keepaliveEvent = "keepalive" - messageEvent = "message" + messageEvent = "message" +) + +const ( + messageIDLength = 10 ) // message represents a message published to a topic type message struct { - Time int64 `json:"time"` // Unix time in seconds - Event string `json:"event,omitempty"` // One of the above + ID string `json:"id"` // Random message ID + Time int64 `json:"time"` // Unix time in seconds + Event string `json:"event"` // One of the above + Topic string `json:"topic"` Message string `json:"message,omitempty"` } @@ -20,25 +29,27 @@ type message struct { type messageEncoder func(msg *message) (string, error) // newMessage creates a new message with the current timestamp -func newMessage(event string, msg string) *message { +func newMessage(event, topic, msg string) *message { return &message{ + ID: util.RandomString(messageIDLength), Time: time.Now().Unix(), Event: event, + Topic: topic, Message: msg, } } // newOpenMessage is a convenience method to create an open message -func newOpenMessage() *message { - return newMessage(openEvent, "") +func newOpenMessage(topic string) *message { + return newMessage(openEvent, topic, "") } // newKeepaliveMessage is a convenience method to create a keepalive message -func newKeepaliveMessage() *message { - return newMessage(keepaliveEvent, "") +func newKeepaliveMessage(topic string) *message { + return newMessage(keepaliveEvent, topic, "") } // newDefaultMessage is a convenience method to create a notification message -func newDefaultMessage(msg string) *message { - return newMessage(messageEvent, msg) +func newDefaultMessage(topic, msg string) *message { + return newMessage(messageEvent, topic, msg) } diff --git a/server/server.go b/server/server.go index caa6534a..fcc9f007 100644 --- a/server/server.go +++ b/server/server.go @@ -17,17 +17,23 @@ import ( "net" "net/http" "regexp" + "strconv" "strings" "sync" "time" ) +// TODO add "max connections open" limit +// TODO add "max messages in a topic" limit +// TODO add "max topics" limit + // Server is the main server type Server struct { config *config.Config topics map[string]*topic visitors map[string]*visitor - firebase *messaging.Client + firebase subscriber + messages int64 mu sync.Mutex } @@ -53,10 +59,11 @@ const ( ) var ( - topicRegex = regexp.MustCompile(`^/[^/]+$`) - jsonRegex = regexp.MustCompile(`^/[^/]+/json$`) - sseRegex = regexp.MustCompile(`^/[^/]+/sse$`) - rawRegex = regexp.MustCompile(`^/[^/]+/raw$`) + topicRegex = regexp.MustCompile(`^/[^/]+$`) + jsonRegex = regexp.MustCompile(`^/[^/]+/json$`) + sseRegex = regexp.MustCompile(`^/[^/]+/sse$`) + rawRegex = regexp.MustCompile(`^/[^/]+/raw$`) + staticRegex = regexp.MustCompile(`^/static/.+`) //go:embed "index.html" @@ -65,30 +72,57 @@ var ( //go:embed static webStaticFs embed.FS + errHTTPBadRequest = &errHTTP{http.StatusBadRequest, http.StatusText(http.StatusBadRequest)} errHTTPNotFound = &errHTTP{http.StatusNotFound, http.StatusText(http.StatusNotFound)} errHTTPTooManyRequests = &errHTTP{http.StatusTooManyRequests, http.StatusText(http.StatusTooManyRequests)} ) func New(conf *config.Config) (*Server, error) { - var fcm *messaging.Client + var firebaseSubscriber subscriber if conf.FirebaseKeyFile != "" { - fb, err := firebase.NewApp(context.Background(), nil, option.WithCredentialsFile(conf.FirebaseKeyFile)) - if err != nil { - return nil, err - } - fcm, err = fb.Messaging(context.Background()) + var err error + firebaseSubscriber, err = createFirebaseSubscriber(conf) if err != nil { return nil, err } } return &Server{ config: conf, - firebase: fcm, + firebase: firebaseSubscriber, topics: make(map[string]*topic), visitors: make(map[string]*visitor), }, nil } +func createFirebaseSubscriber(conf *config.Config) (subscriber, error) { + fb, err := firebase.NewApp(context.Background(), nil, option.WithCredentialsFile(conf.FirebaseKeyFile)) + if err != nil { + return nil, err + } + msg, err := fb.Messaging(context.Background()) + if err != nil { + return nil, err + } + return func(m *message) error { + _, err := msg.Send(context.Background(), &messaging.Message{ + Data: map[string]string{ + "id": m.ID, + "time": fmt.Sprintf("%d", m.Time), + "event": m.Event, + "topic": m.Topic, + "message": m.Message, + }, + Notification: &messaging.Notification{ + Title: m.Topic, // FIXME convert to ntfy.sh/$topic instead + Body: m.Message, + ImageURL: "", + }, + Topic: m.Topic, + }) + return err + }, nil +} + func (s *Server) Run() error { go func() { ticker := time.NewTicker(s.config.ManagerInterval) @@ -106,28 +140,6 @@ func (s *Server) listenAndServe() error { return http.ListenAndServe(s.config.ListenHTTP, nil) } -func (s *Server) updateStatsAndExpire() { - s.mu.Lock() - defer s.mu.Unlock() - - // Expire visitors from rate visitors map - for ip, v := range s.visitors { - if time.Since(v.seen) > visitorExpungeAfter { - delete(s.visitors, ip) - } - } - - // Print stats - var subscribers, messages int - for _, t := range s.topics { - subs, msgs := t.Stats() - subscribers += subs - messages += msgs - } - log.Printf("Stats: %d topic(s), %d subscriber(s), %d message(s) sent, %d visitor(s)", - len(s.topics), subscribers, messages, len(s.visitors)) -} - func (s *Server) handle(w http.ResponseWriter, r *http.Request) { if err := s.handleInternal(w, r); err != nil { if e, ok := err.(*errHTTP); ok { @@ -147,14 +159,14 @@ func (s *Server) handleInternal(w http.ResponseWriter, r *http.Request) error { return s.handleHome(w, r) } else if r.Method == http.MethodGet && staticRegex.MatchString(r.URL.Path) { return s.handleStatic(w, r) + } else if (r.Method == http.MethodPut || r.Method == http.MethodPost) && topicRegex.MatchString(r.URL.Path) { + return s.handlePublish(w, r) } else if r.Method == http.MethodGet && jsonRegex.MatchString(r.URL.Path) { return s.handleSubscribeJSON(w, r) } else if r.Method == http.MethodGet && sseRegex.MatchString(r.URL.Path) { return s.handleSubscribeSSE(w, r) } else if r.Method == http.MethodGet && rawRegex.MatchString(r.URL.Path) { return s.handleSubscribeRaw(w, r) - } else if (r.Method == http.MethodPut || r.Method == http.MethodPost) && topicRegex.MatchString(r.URL.Path) { - return s.handlePublishHTTP(w, r) } else if r.Method == http.MethodOptions { return s.handleOptions(w, r) } @@ -166,42 +178,28 @@ func (s *Server) handleHome(w http.ResponseWriter, r *http.Request) error { return err } -func (s *Server) handlePublishHTTP(w http.ResponseWriter, r *http.Request) error { - t, err := s.topic(r.URL.Path[1:]) - if err != nil { - return err - } +func (s *Server) handleStatic(w http.ResponseWriter, r *http.Request) error { + http.FileServer(http.FS(webStaticFs)).ServeHTTP(w, r) + return nil +} + +func (s *Server) handlePublish(w http.ResponseWriter, r *http.Request) error { + t := s.createTopic(r.URL.Path[1:]) reader := io.LimitReader(r.Body, messageLimit) b, err := io.ReadAll(reader) if err != nil { return err } - if err := t.Publish(newDefaultMessage(string(b))); err != nil { - return err - } - if err := s.maybePublishFirebase(t.id, string(b)); err != nil { + if err := t.Publish(newDefaultMessage(t.id, string(b))); err != nil { return err } w.Header().Set("Access-Control-Allow-Origin", "*") // CORS, allow cross-origin requests + s.mu.Lock() + s.messages++ + s.mu.Unlock() return nil } -func (s *Server) maybePublishFirebase(topic, message string) error { - _, err := s.firebase.Send(context.Background(), &messaging.Message{ - Data: map[string]string{ - "topic": topic, - "message": message, - }, - Notification: &messaging.Notification{ - Title: "ntfy.sh/" + topic, - Body: message, - ImageURL: "", - }, - Topic: topic, - }) - return err -} - func (s *Server) handleSubscribeJSON(w http.ResponseWriter, r *http.Request) error { encoder := func(msg *message) (string, error) { var buf bytes.Buffer @@ -239,6 +237,11 @@ func (s *Server) handleSubscribeRaw(w http.ResponseWriter, r *http.Request) erro func (s *Server) handleSubscribe(w http.ResponseWriter, r *http.Request, format string, contentType string, encoder messageEncoder) error { t := s.createTopic(strings.TrimSuffix(r.URL.Path[1:], "/"+format)) // Hack + since, err := parseSince(r) + if err != nil { + return err + } + poll := r.URL.Query().Has("poll") sub := func(msg *message) error { m, err := encoder(msg) if err != nil { @@ -252,11 +255,17 @@ func (s *Server) handleSubscribe(w http.ResponseWriter, r *http.Request, format } return nil } - subscriberID := t.Subscribe(sub) - defer s.unsubscribe(t, subscriberID) w.Header().Set("Access-Control-Allow-Origin", "*") // CORS, allow cross-origin requests w.Header().Set("Content-Type", contentType) - if err := sub(newOpenMessage()); err != nil { // Send out open message + if poll { + return sendOldMessages(t, since, sub) + } + subscriberID := t.Subscribe(sub) + defer t.Unsubscribe(subscriberID) + if err := sub(newOpenMessage(t.id)); err != nil { // Send out open message + return err + } + if err := sendOldMessages(t, since, sub); err != nil { return err } for { @@ -266,49 +275,85 @@ func (s *Server) handleSubscribe(w http.ResponseWriter, r *http.Request, format case <-r.Context().Done(): return nil case <-time.After(s.config.KeepaliveInterval): - if err := sub(newKeepaliveMessage()); err != nil { // Send keepalive message + if err := sub(newKeepaliveMessage(t.id)); err != nil { // Send keepalive message return err } } } } +func sendOldMessages(t *topic, since time.Time, sub subscriber) error { + if since.IsZero() { + return nil + } + for _, m := range t.Messages(since) { + if err := sub(m); err != nil { + return err + } + } + return nil +} + +func parseSince(r *http.Request) (time.Time, error) { + if !r.URL.Query().Has("since") { + return time.Time{}, nil + } + if since, err := strconv.ParseInt(r.URL.Query().Get("since"), 10, 64); err == nil { + return time.Unix(since, 0), nil + } + if d, err := time.ParseDuration(r.URL.Query().Get("since")); err == nil { + return time.Now().Add(-1 * d), nil + } + return time.Time{}, errHTTPBadRequest +} + func (s *Server) handleOptions(w http.ResponseWriter, r *http.Request) error { w.Header().Set("Access-Control-Allow-Origin", "*") // CORS, allow cross-origin requests w.Header().Set("Access-Control-Allow-Methods", "GET, PUT, POST") return nil } -func (s *Server) handleStatic(w http.ResponseWriter, r *http.Request) error { - http.FileServer(http.FS(webStaticFs)).ServeHTTP(w, r) - return nil -} - func (s *Server) createTopic(id string) *topic { s.mu.Lock() defer s.mu.Unlock() if _, ok := s.topics[id]; !ok { s.topics[id] = newTopic(id) + if s.firebase != nil { + s.topics[id].Subscribe(s.firebase) + } } return s.topics[id] } -func (s *Server) topic(topicID string) (*topic, error) { +func (s *Server) updateStatsAndExpire() { s.mu.Lock() defer s.mu.Unlock() - c, ok := s.topics[topicID] - if !ok { - return nil, errHTTPNotFound - } - return c, nil -} -func (s *Server) unsubscribe(t *topic, subscriberID int) { - s.mu.Lock() - defer s.mu.Unlock() - if subscribers := t.Unsubscribe(subscriberID); subscribers == 0 { - delete(s.topics, t.id) + // Expire visitors from rate visitors map + for ip, v := range s.visitors { + if time.Since(v.seen) > visitorExpungeAfter { + delete(s.visitors, ip) + } } + + // Prune old messages, remove topics without subscribers + for _, t := range s.topics { + t.Prune(s.config.MessageBufferDuration) + subs, msgs := t.Stats() + if msgs == 0 && (subs == 0 || (s.firebase != nil && subs == 1)) { + delete(s.topics, t.id) + } + } + + // Print stats + var subscribers, messages int + for _, t := range s.topics { + subs, msgs := t.Stats() + subscribers += subs + messages += msgs + } + log.Printf("Stats: %d message(s) published, %d topic(s) active, %d subscriber(s), %d message(s) buffered, %d visitor(s)", + s.messages, len(s.topics), subscribers, messages, len(s.visitors)) } // visitor creates or retrieves a rate.Limiter for the given visitor. diff --git a/server/topic.go b/server/topic.go index ab9f26b8..037f4857 100644 --- a/server/topic.go +++ b/server/topic.go @@ -2,7 +2,6 @@ package server import ( "context" - "errors" "log" "math/rand" "sync" @@ -14,7 +13,7 @@ import ( type topic struct { id string subscribers map[int]subscriber - messages int + messages []*message last time.Time ctx context.Context cancel context.CancelFunc @@ -45,21 +44,17 @@ func (t *topic) Subscribe(s subscriber) int { return subscriberID } -func (t *topic) Unsubscribe(id int) int { +func (t *topic) Unsubscribe(id int) { t.mu.Lock() defer t.mu.Unlock() delete(t.subscribers, id) - return len(t.subscribers) } func (t *topic) Publish(m *message) error { t.mu.Lock() defer t.mu.Unlock() - if len(t.subscribers) == 0 { - return errors.New("no subscribers") - } t.last = time.Now() - t.messages++ + t.messages = append(t.messages, m) for _, s := range t.subscribers { if err := s(m); err != nil { log.Printf("error publishing message to subscriber") @@ -68,10 +63,36 @@ func (t *topic) Publish(m *message) error { return nil } +func (t *topic) Messages(since time.Time) []*message { + t.mu.Lock() + defer t.mu.Unlock() + messages := make([]*message, 0) // copy! + for _, m := range t.messages { + msgTime := time.Unix(m.Time, 0) + if msgTime == since || msgTime.After(since) { + messages = append(messages, m) + } + } + return messages +} + +func (t *topic) Prune(keep time.Duration) { + t.mu.Lock() + defer t.mu.Unlock() + for i, m := range t.messages { + msgTime := time.Unix(m.Time, 0) + if time.Since(msgTime) < keep { + t.messages = t.messages[i:] + return + } + } + t.messages = make([]*message, 0) +} + func (t *topic) Stats() (subscribers int, messages int) { t.mu.Lock() defer t.mu.Unlock() - return len(t.subscribers), t.messages + return len(t.subscribers), len(t.messages) } func (t *topic) Close() { diff --git a/util/util.go b/util/util.go new file mode 100644 index 00000000..73516220 --- /dev/null +++ b/util/util.go @@ -0,0 +1,29 @@ +package util + +import ( + "math/rand" + "os" + "time" +) + +const ( + randomStringCharset = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789" +) + +var ( + random = rand.New(rand.NewSource(time.Now().UnixNano())) +) + +func FileExists(filename string) bool { + stat, _ := os.Stat(filename) + return stat != nil +} + +// RandomString returns a random string with a given length +func RandomString(length int) string { + b := make([]byte, length) + for i := range b { + b[i] = randomStringCharset[random.Intn(len(randomStringCharset))] + } + return string(b) +}