From 2329695a47f7b3b91a7244a52b1440115f4d3918 Mon Sep 17 00:00:00 2001 From: binwiederhier Date: Thu, 23 Feb 2023 20:46:53 -0500 Subject: [PATCH] Polishing --- server/server.go | 32 +++++++++++++------------------- server/server_manager.go | 20 +++++++++++++------- server/server_test.go | 35 ++++++++++++++++++++++++++++++++++- server/topic.go | 36 +++++++++++++++++++----------------- server/util.go | 9 +++++++++ 5 files changed, 88 insertions(+), 44 deletions(-) diff --git a/server/server.go b/server/server.go index 0466cd23..cdd85c6b 100644 --- a/server/server.go +++ b/server/server.go @@ -570,14 +570,8 @@ func (s *Server) handleMatrixDiscovery(w http.ResponseWriter) error { } func (s *Server) handlePublishWithoutResponse(r *http.Request, v *visitor) (*message, error) { - vrate, ok := r.Context().Value(contextRateVisitor).(*visitor) - if !ok { - return nil, errHTTPInternalError - } - t, ok := r.Context().Value(contextTopic).(*topic) - if !ok { - return nil, errHTTPInternalError - } + t := fromContext[topic](r, contextTopic) + vrate := fromContext[visitor](r, contextRateVisitor) if !vrate.MessageAllowed() { return nil, errHTTPTooManyRequestsLimitMessages } @@ -586,10 +580,13 @@ func (s *Server) handlePublishWithoutResponse(r *http.Request, v *visitor) (*mes return nil, err } m := newDefaultMessage(t.ID, "") - cache, firebase, email, unifiedpush, err := s.parsePublishParams(r, vrate, m) + cache, firebase, email, unifiedpush, err := s.parsePublishParams(r, m) if err != nil { return nil, err } + if email != "" && !vrate.EmailAllowed() { + return nil, errHTTPTooManyRequestsLimitEmails + } if m.PollID != "" { m = newPollRequestMessage(t.ID, m.PollID) } @@ -605,13 +602,15 @@ func (s *Server) handlePublishWithoutResponse(r *http.Request, v *visitor) (*mes m.Message = emptyMessageBody } delayed := m.Time > time.Now().Unix() - ev := logvrm(vrate, r, m). + ev := logvrm(v, r, m). Tag(tagPublish). Fields(log.Context{ "message_delayed": delayed, "message_firebase": firebase, "message_unifiedpush": unifiedpush, "message_email": email, + "rate_visitor_ip": vrate.IP().String(), + "rate_user_id": vrate.MaybeUserID(), }) if ev.IsTrace() { ev.Field("message_body", util.MaybeMarshalJSON(m)).Trace("Received message") @@ -623,7 +622,7 @@ func (s *Server) handlePublishWithoutResponse(r *http.Request, v *visitor) (*mes return nil, err } if s.firebaseClient != nil && firebase { - go s.sendToFirebase(vrate, m) + go s.sendToFirebase(v, m) } if s.smtpSender != nil && email != "" { go s.sendEmail(v, m, email) @@ -708,7 +707,7 @@ func (s *Server) forwardPollRequest(v *visitor, m *message) { } } -func (s *Server) parsePublishParams(r *http.Request, vrate *visitor, m *message) (cache bool, firebase bool, email string, unifiedpush bool, err error) { +func (s *Server) parsePublishParams(r *http.Request, m *message) (cache bool, firebase bool, email string, unifiedpush bool, err error) { cache = readBoolParam(r, true, "x-cache", "cache") firebase = readBoolParam(r, true, "x-firebase", "firebase") m.Title = readParam(r, "x-title", "title", "t") @@ -747,11 +746,6 @@ func (s *Server) parsePublishParams(r *http.Request, vrate *visitor, m *message) m.Icon = icon } email = readParam(r, "x-email", "x-e-mail", "email", "e-mail", "mail", "e") - if email != "" { - if !vrate.EmailAllowed() { - return false, false, "", false, errHTTPTooManyRequestsLimitEmails - } - } if s.smtpSender == nil && email != "" { return false, false, "", false, errHTTPBadRequestEmailDisabled } @@ -993,7 +987,7 @@ func (s *Server) handleSubscribeHTTP(w http.ResponseWriter, r *http.Request, v * defer cancel() subscriberIDs := make([]int, 0) for _, t := range topics { - subscriberIDs = append(subscriberIDs, t.Subscribe(sub, v, cancel)) + subscriberIDs = append(subscriberIDs, t.Subscribe(sub, v.MaybeUserID(), cancel)) } defer func() { for i, subscriberID := range subscriberIDs { @@ -1126,7 +1120,7 @@ func (s *Server) handleSubscribeWS(w http.ResponseWriter, r *http.Request, v *vi } subscriberIDs := make([]int, 0) for _, t := range topics { - subscriberIDs = append(subscriberIDs, t.Subscribe(sub, v, cancel)) + subscriberIDs = append(subscriberIDs, t.Subscribe(sub, v.MaybeUserID(), cancel)) } defer func() { for i, subscriberID := range subscriberIDs { diff --git a/server/server_manager.go b/server/server_manager.go index f232ee0f..82367884 100644 --- a/server/server_manager.go +++ b/server/server_manager.go @@ -3,7 +3,6 @@ package server import ( "heckel.io/ntfy/log" "strings" - "time" ) func (s *Server) execManager() { @@ -38,16 +37,23 @@ func (s *Server) execManager() { subs := t.SubscribersCount() ev := log.Tag(tagManager) if ev.IsTrace() { - expiryMessage := "" - if subs == 0 { - expiryTime := time.Until(t.expires) - expiryMessage = ", expires in " + expiryTime.String() + vrate := t.RateVisitor() + if vrate != nil { + ev.Fields(log.Context{ + "rate_visitor_ip": vrate.IP(), + "rate_visitor_user_id": vrate.MaybeUserID(), + }) } - ev.Trace("- topic %s: %d subscribers%s", t.ID, subs, expiryMessage) + ev. + Fields(log.Context{ + "message_topic": t.ID, + "message_topic_subscribers": subs, + }). + Trace("- topic %s: %d subscribers", t.ID, subs) } msgs, exists := messageCounts[t.ID] if t.Stale() && (!exists || msgs == 0) { - log.Tag(tagManager).Trace("Deleting empty topic %s", t.ID) + log.Tag(tagManager).Field("message_topic", t.ID).Trace("Deleting empty topic %s", t.ID) emptyTopics++ delete(s.topics, t.ID) continue diff --git a/server/server_test.go b/server/server_test.go index b26cf780..802e8742 100644 --- a/server/server_test.go +++ b/server/server_test.go @@ -2030,7 +2030,40 @@ func TestServer_Matrix_SubscriberRateLimiting_UP_Only(t *testing.T) { } } -// FIXME add test for rate visitor expiration +func TestServer_SubscriberRateLimiting_VisitorExpiration(t *testing.T) { + c := newTestConfig(t) + c.VisitorRequestLimitBurst = 3 + s := newTestServer(t, c) + + // "Register" rate visitor + subscriberFn := func(r *http.Request) { + r.RemoteAddr = "1.2.3.4" + } + rr := request(t, s, "GET", "/mytopic/json?poll=1", "", map[string]string{ + "rate-topics": "*", + }, subscriberFn) + require.Equal(t, 200, rr.Code) + require.Equal(t, "1.2.3.4", s.topics["mytopic"].rateVisitor.ip.String()) + require.Equal(t, s.visitors["ip:1.2.3.4"], s.topics["mytopic"].rateVisitor) + + // Publish message, observe rate visitor tokens being decreased + response := request(t, s, "POST", "/mytopic", "some message", nil) + require.Equal(t, 200, response.Code) + require.Equal(t, int64(0), s.visitors["ip:9.9.9.9"].messagesLimiter.Value()) + require.Equal(t, int64(1), s.topics["mytopic"].rateVisitor.messagesLimiter.Value()) + require.Equal(t, s.visitors["ip:1.2.3.4"], s.topics["mytopic"].rateVisitor) + + // Expire visitor + s.visitors["ip:1.2.3.4"].seen = time.Now().Add(-1 * 25 * time.Hour) + s.pruneVisitors() + + // Publish message again, observe that rateVisitor is not used anymore and is reset + response = request(t, s, "POST", "/mytopic", "some message", nil) + require.Equal(t, 200, response.Code) + require.Equal(t, int64(1), s.visitors["ip:9.9.9.9"].messagesLimiter.Value()) + require.Nil(t, s.topics["mytopic"].rateVisitor) + require.Nil(t, s.visitors["ip:1.2.3.4"]) +} func newTestConfig(t *testing.T) *Config { conf := NewConfig() diff --git a/server/topic.go b/server/topic.go index 2bd5a472..e6d4687e 100644 --- a/server/topic.go +++ b/server/topic.go @@ -4,11 +4,6 @@ import ( "heckel.io/ntfy/log" "math/rand" "sync" - "time" -) - -const ( - topicExpiryDuration = 6 * time.Hour ) // topic represents a channel to which subscribers can subscribe, and publishers @@ -17,13 +12,12 @@ type topic struct { ID string subscribers map[int]*topicSubscriber rateVisitor *visitor - expires time.Time mu sync.RWMutex } type topicSubscriber struct { + userID string // User ID associated with this subscription, may be empty subscriber subscriber - visitor *visitor // User ID associated with this subscription, may be empty cancel func() } @@ -39,12 +33,12 @@ func newTopic(id string) *topic { } // Subscribe subscribes to this topic -func (t *topic) Subscribe(s subscriber, visitor *visitor, cancel func()) int { +func (t *topic) Subscribe(s subscriber, userID string, cancel func()) int { t.mu.Lock() defer t.mu.Unlock() subscriberID := rand.Int() t.subscribers[subscriberID] = &topicSubscriber{ - visitor: visitor, // May be empty + userID: userID, // May be empty subscriber: s, cancel: cancel, } @@ -54,7 +48,10 @@ func (t *topic) Subscribe(s subscriber, visitor *visitor, cancel func()) int { func (t *topic) Stale() bool { t.mu.Lock() defer t.mu.Unlock() - return len(t.subscribers) == 0 && t.expires.Before(time.Now()) + if t.rateVisitor != nil && !t.rateVisitor.Stale() { + return false + } + return len(t.subscribers) == 0 } func (t *topic) SetRateVisitor(v *visitor) { @@ -66,6 +63,9 @@ func (t *topic) SetRateVisitor(v *visitor) { func (t *topic) RateVisitor() *visitor { t.mu.Lock() defer t.mu.Unlock() + if t.rateVisitor != nil && t.rateVisitor.Stale() { + t.rateVisitor = nil + } return t.rateVisitor } @@ -74,9 +74,6 @@ func (t *topic) Unsubscribe(id int) { t.mu.Lock() defer t.mu.Unlock() delete(t.subscribers, id) - if len(t.subscribers) == 0 { - t.expires = time.Now().Add(topicExpiryDuration) - } } // Publish asynchronously publishes to all subscribers @@ -115,9 +112,14 @@ func (t *topic) CancelSubscribers(exceptUserID string) { t.mu.Lock() defer t.mu.Unlock() for _, s := range t.subscribers { - if s.visitor.MaybeUserID() != exceptUserID { - // TODO: Shouldn't this log the IP for anonymous visitors? It was s.userID before my change. - log.Tag(tagSubscribe).Field("topic", t.ID).Debug("Canceling subscriber %s", s.visitor.MaybeUserID()) + if s.userID != exceptUserID { + log. + Tag(tagSubscribe). + Fields(log.Context{ + "message_topic": t.ID, + "user_id": s.userID, + }). + Debug("Canceling subscriber %s", s.userID) s.cancel() } } @@ -130,7 +132,7 @@ func (t *topic) subscribersCopy() map[int]*topicSubscriber { subscribers := make(map[int]*topicSubscriber) for k, sub := range t.subscribers { subscribers[k] = &topicSubscriber{ - visitor: sub.visitor, + userID: sub.userID, subscriber: sub.subscriber, cancel: sub.cancel, } diff --git a/server/util.go b/server/util.go index c6f78102..75810b59 100644 --- a/server/util.go +++ b/server/util.go @@ -2,6 +2,7 @@ package server import ( "context" + "fmt" "heckel.io/ntfy/util" "io" "net/http" @@ -105,3 +106,11 @@ func withContext(r *http.Request, ctx map[contextKey]any) *http.Request { } return r.WithContext(c) } + +func fromContext[T any](r *http.Request, key contextKey) *T { + t, ok := r.Context().Value(key).(*T) + if !ok { + panic(fmt.Sprintf("cannot find key %v in request context", key)) + } + return t +}