diff --git a/server/server.go b/server/server.go index 656902f4..f5dbcd97 100644 --- a/server/server.go +++ b/server/server.go @@ -35,27 +35,19 @@ import ( ) /* -TODO --- - HIGH Rate limiting: Sensitive endpoints (account/login/change-password/...) -- HIGH Rate limiting: When ResetStats() is run, reset messagesLimiter (and others)? -- MEDIUM Rate limiting: Test daily message quota read from database initially - MEDIUM: Races with v.user (see publishSyncEventAsync test) +- MEDIUM: Test that anonymous user and user without tier are the same visitor +- MEDIUM: Make sure account endpoints make sense for admins - MEDIUM: Reservation (UI): Show "This topic is reserved" error message when trying to reserve a reserved topic (Thorben) - MEDIUM: Reservation (UI): Ask for confirmation when removing reservation (deadcade) - MEDIUM: Reservation table delete button: dialog "keep or delete messages?" +- MEDIUM: Tests for remaining payment endpoints - LOW: UI: Flickering upgrade banner when logging in - LOW: JS constants - LOW: Payments reconciliation process -Limits & rate limiting: - users without tier: should the stats be persisted? are they meaningful? -> test that the visitor is based on the IP address! - -Make sure account endpoints make sense for admins - -Tests: -- Payment endpoints (make mocks) */ // Server is the main server, providing the UI and API for ntfy @@ -513,7 +505,7 @@ func (s *Server) handleFile(w http.ResponseWriter, r *http.Request, v *visitor) return errHTTPNotFound } if r.Method == http.MethodGet { - if err := v.BandwidthLimiter().Allow(stat.Size()); err != nil { + if !v.BandwidthAllowed(stat.Size()) { return errHTTPTooManyRequestsLimitAttachmentBandwidth } } @@ -543,7 +535,7 @@ func (s *Server) handlePublishWithoutResponse(r *http.Request, v *visitor) (*mes if err != nil { return nil, err } - if err := v.MessageAllowed(); err != nil { + if !v.MessageAllowed() { return nil, errHTTPTooManyRequestsLimitMessages } body, err := util.Peek(r.Body, s.config.MessageLimit) @@ -558,9 +550,7 @@ func (s *Server) handlePublishWithoutResponse(r *http.Request, v *visitor) (*mes if m.PollID != "" { m = newPollRequestMessage(t.ID, m.PollID) } - if v.user != nil { - m.User = v.user.ID - } + m.User = v.MaybeUserID() m.Expires = time.Now().Add(v.Limits().MessageExpiryDuration).Unix() if err := s.handlePublishBody(r, v, m, body, unifiedpush); err != nil { return nil, err @@ -582,7 +572,6 @@ func (s *Server) handlePublishWithoutResponse(r *http.Request, v *visitor) (*mes go s.sendToFirebase(v, m) } if s.smtpSender != nil && email != "" { - v.IncrementEmails() go s.sendEmail(v, m, email) } if s.config.UpstreamBaseURL != "" { @@ -597,8 +586,9 @@ func (s *Server) handlePublishWithoutResponse(r *http.Request, v *visitor) (*mes return nil, err } } - if s.userManager != nil && v.user != nil { - s.userManager.EnqueueStats(v.user.ID, v.Stats()) // FIXME this makes no sense for tier-less users + u := v.User() + if s.userManager != nil && u != nil && u.Tier != nil { + s.userManager.EnqueueStats(u.ID, v.Stats()) } s.mu.Lock() s.messages++ @@ -704,7 +694,7 @@ func (s *Server) parsePublishParams(r *http.Request, v *visitor, m *message) (ca } email = readParam(r, "x-email", "x-e-mail", "email", "e-mail", "mail", "e") if email != "" { - if err := v.EmailAllowed(); err != nil { + if !v.EmailAllowed() { return false, false, "", false, errHTTPTooManyRequestsLimitEmails } } @@ -909,7 +899,7 @@ func (s *Server) handleSubscribeRaw(w http.ResponseWriter, r *http.Request, v *v func (s *Server) handleSubscribeHTTP(w http.ResponseWriter, r *http.Request, v *visitor, contentType string, encoder messageEncoder) error { log.Debug("%s HTTP stream connection opened", logHTTPPrefix(v, r)) defer log.Debug("%s HTTP stream connection closed", logHTTPPrefix(v, r)) - if err := v.SubscriptionAllowed(); err != nil { + if !v.SubscriptionAllowed() { return errHTTPTooManyRequestsLimitSubscriptions } defer v.RemoveSubscription() @@ -989,7 +979,7 @@ func (s *Server) handleSubscribeWS(w http.ResponseWriter, r *http.Request, v *vi if strings.ToLower(r.Header.Get("Upgrade")) != "websocket" { return errHTTPBadRequestWebSocketsUpgradeHeaderMissing } - if err := v.SubscriptionAllowed(); err != nil { + if !v.SubscriptionAllowed() { return errHTTPTooManyRequestsLimitSubscriptions } defer v.RemoveSubscription() diff --git a/server/server_account.go b/server/server_account.go index 718f0225..5f1c82aa 100644 --- a/server/server_account.go +++ b/server/server_account.go @@ -23,7 +23,7 @@ func (s *Server) handleAccountCreate(w http.ResponseWriter, r *http.Request, v * } else if v.user != nil { return errHTTPUnauthorized // Cannot create account from user context } - if err := v.AccountCreationAllowed(); err != nil { + if !v.AccountCreationAllowed() { return errHTTPTooManyRequestsLimitAccountCreation } } @@ -428,11 +428,12 @@ func (s *Server) publishSyncEvent(v *visitor) error { func (s *Server) publishSyncEventAsync(v *visitor) { go func() { - if v.user == nil || v.user.SyncTopic == "" { + u := v.User() + if u == nil || u.SyncTopic == "" { return } if err := s.publishSyncEvent(v); err != nil { - log.Trace("Error publishing to user %s's sync topic %s: %s", v.user.Name, v.user.SyncTopic, err.Error()) + log.Trace("Error publishing to user %s's sync topic %s: %s", u.Name, u.SyncTopic, err.Error()) } }() } diff --git a/server/server_test.go b/server/server_test.go index f3789e0f..410473fb 100644 --- a/server/server_test.go +++ b/server/server_test.go @@ -841,23 +841,35 @@ func TestServer_StatsResetter(t *testing.T) { require.Equal(t, int64(0), account.Stats.Messages) } -func TestServer_StatsResetter_MessageLimiter(t *testing.T) { - // This tests that the messageLimiter (the only fixed limiter) is reset by the stats resetter +func TestServer_StatsResetter_MessageLimiter_EmailsLimiter(t *testing.T) { + // This tests that the messageLimiter (the only fixed limiter) and the emailsLimiter (token bucket) + // is reset by the stats resetter c := newTestConfigWithAuthFile(t) s := newTestServer(t, c) + s.smtpSender = &testMailer{} // Publish some messages, and check stats for i := 0; i < 3; i++ { response := request(t, s, "PUT", "/mytopic", "test", nil) require.Equal(t, 200, response.Code) } + response := request(t, s, "PUT", "/mytopic", "test", map[string]string{ + "Email": "test@email.com", + }) + require.Equal(t, 200, response.Code) + rr := request(t, s, "GET", "/v1/account", "", nil) require.Equal(t, 200, rr.Code) account, err := util.UnmarshalJSON[apiAccountResponse](io.NopCloser(rr.Body)) require.Nil(t, err) - require.Equal(t, int64(3), account.Stats.Messages) - require.Equal(t, int64(3), s.visitor(netip.MustParseAddr("9.9.9.9"), nil).messagesLimiter.Value()) + require.Equal(t, int64(4), account.Stats.Messages) + require.Equal(t, int64(1), account.Stats.Emails) + v := s.visitor(netip.MustParseAddr("9.9.9.9"), nil) + require.Equal(t, int64(4), v.Stats().Messages) + require.Equal(t, int64(4), v.messagesLimiter.Value()) + require.Equal(t, int64(1), v.Stats().Emails) + require.Equal(t, int64(1), v.emailsLimiter.Value()) // Reset stats and check again s.resetStats() @@ -866,7 +878,53 @@ func TestServer_StatsResetter_MessageLimiter(t *testing.T) { account, err = util.UnmarshalJSON[apiAccountResponse](io.NopCloser(rr.Body)) require.Nil(t, err) require.Equal(t, int64(0), account.Stats.Messages) - require.Equal(t, int64(0), s.visitor(netip.MustParseAddr("9.9.9.9"), nil).messagesLimiter.Value()) + require.Equal(t, int64(0), account.Stats.Emails) + v = s.visitor(netip.MustParseAddr("9.9.9.9"), nil) + require.Equal(t, int64(0), v.Stats().Messages) + require.Equal(t, int64(0), v.messagesLimiter.Value()) + require.Equal(t, int64(0), v.Stats().Emails) + require.Equal(t, int64(0), v.emailsLimiter.Value()) +} + +func TestServer_DailyMessageQuotaFromDatabase(t *testing.T) { + // This tests that the daily message quota is prefilled originally from the database, + // if the visitor is unknown + + c := newTestConfigWithAuthFile(t) + s := newTestServer(t, c) + var err error + s.userManager, err = user.NewManagerWithStatsInterval(c.AuthFile, c.AuthStartupQueries, c.AuthDefault, 100*time.Millisecond) + require.Nil(t, err) + + // Create user, and update it with some message and email stats + require.Nil(t, s.userManager.CreateTier(&user.Tier{ + Code: "test", + })) + require.Nil(t, s.userManager.AddUser("phil", "phil", user.RoleUser)) + require.Nil(t, s.userManager.ChangeTier("phil", "test")) + + u, err := s.userManager.User("phil") + require.Nil(t, err) + s.userManager.EnqueueStats(u.ID, &user.Stats{ + Messages: 123456, + Emails: 999, + }) + time.Sleep(400 * time.Millisecond) + + // Get account and verify stats are read from the DB, and that the visitor also has these stats + rr := request(t, s, "GET", "/v1/account", "", map[string]string{ + "Authorization": util.BasicAuth("phil", "phil"), + }) + require.Equal(t, 200, rr.Code) + account, err := util.UnmarshalJSON[apiAccountResponse](io.NopCloser(rr.Body)) + require.Nil(t, err) + require.Equal(t, int64(123456), account.Stats.Messages) + require.Equal(t, int64(999), account.Stats.Emails) + v := s.visitor(netip.MustParseAddr("9.9.9.9"), u) + require.Equal(t, int64(123456), v.Stats().Messages) + require.Equal(t, int64(123456), v.messagesLimiter.Value()) + require.Equal(t, int64(999), v.Stats().Emails) + require.Equal(t, int64(999), v.emailsLimiter.Value()) } type testMailer struct { diff --git a/server/visitor.go b/server/visitor.go index b3a8bbe3..0fdd98d6 100644 --- a/server/visitor.go +++ b/server/visitor.go @@ -58,12 +58,11 @@ type visitor struct { userManager *user.Manager // May be nil ip netip.Addr // Visitor IP address user *user.User // Only set if authenticated user, otherwise nil - emails int64 // Number of emails sent, reset every day requestLimiter *rate.Limiter // Rate limiter for (almost) all requests (including messages) messagesLimiter *util.FixedLimiter // Rate limiter for messages - emailsLimiter *rate.Limiter // Rate limiter for emails - subscriptionLimiter util.Limiter // Fixed limiter for active subscriptions (ongoing connections) - bandwidthLimiter util.Limiter // Limiter for attachment bandwidth downloads + emailsLimiter *util.RateLimiter // Rate limiter for emails + subscriptionLimiter *util.FixedLimiter // Fixed limiter for active subscriptions (ongoing connections) + bandwidthLimiter *util.RateLimiter // Limiter for attachment bandwidth downloads accountLimiter *rate.Limiter // Rate limiter for account creation, may be nil firebase time.Time // Next allowed Firebase message seen time.Time // Last seen time of this visitor (needed for removal of stale visitors) @@ -123,7 +122,6 @@ func newVisitor(conf *Config, messageCache *messageCache, userManager *user.Mana userManager: userManager, // May be nil ip: ip, user: user, - emails: emails, firebase: time.Unix(0, 0), seen: time.Now(), subscriptionLimiter: util.NewFixedLimiter(int64(conf.VisitorSubscriptionLimit)), @@ -133,7 +131,7 @@ func newVisitor(conf *Config, messageCache *messageCache, userManager *user.Mana bandwidthLimiter: nil, // Set in resetLimiters accountLimiter: nil, // Set in resetLimiters, may be nil } - v.resetLimiters(messages) + v.resetLimitersNoLock(messages, emails) return v } @@ -153,6 +151,8 @@ func (v *visitor) stringNoLock() string { } func (v *visitor) RequestAllowed() error { + v.mu.Lock() // limiters could be replaced! + defer v.mu.Unlock() if !v.requestLimiter.Allow() { return errVisitorLimitReached } @@ -174,40 +174,43 @@ func (v *visitor) FirebaseTemporarilyDeny() { v.firebase = time.Now().Add(v.config.FirebaseQuotaExceededPenaltyDuration) } -func (v *visitor) MessageAllowed() error { - if v.messagesLimiter.Allow(1) != nil { - return errVisitorLimitReached - } - return nil -} - -func (v *visitor) EmailAllowed() error { - if !v.emailsLimiter.Allow() { - return errVisitorLimitReached - } - return nil -} - -func (v *visitor) SubscriptionAllowed() error { - v.mu.Lock() +func (v *visitor) MessageAllowed() bool { + v.mu.Lock() // limiters could be replaced! defer v.mu.Unlock() - if err := v.subscriptionLimiter.Allow(1); err != nil { - return errVisitorLimitReached - } - return nil + return v.messagesLimiter.Allow() } -func (v *visitor) AccountCreationAllowed() error { +func (v *visitor) EmailAllowed() bool { + v.mu.Lock() // limiters could be replaced! + defer v.mu.Unlock() + return v.emailsLimiter.Allow() +} + +func (v *visitor) SubscriptionAllowed() bool { + v.mu.Lock() // limiters could be replaced! + defer v.mu.Unlock() + return v.subscriptionLimiter.Allow() +} + +func (v *visitor) AccountCreationAllowed() bool { + v.mu.Lock() // limiters could be replaced! + defer v.mu.Unlock() if v.accountLimiter == nil || (v.accountLimiter != nil && !v.accountLimiter.Allow()) { - return errVisitorLimitReached + return false } - return nil + return true +} + +func (v *visitor) BandwidthAllowed(bytes int64) bool { + v.mu.Lock() // limiters could be replaced! + defer v.mu.Unlock() + return v.bandwidthLimiter.AllowN(bytes) } func (v *visitor) RemoveSubscription() { v.mu.Lock() defer v.mu.Unlock() - v.subscriptionLimiter.Allow(-1) + v.subscriptionLimiter.AllowN(-1) } func (v *visitor) Keepalive() { @@ -217,6 +220,8 @@ func (v *visitor) Keepalive() { } func (v *visitor) BandwidthLimiter() util.Limiter { + v.mu.Lock() // limiters could be replaced! + defer v.mu.Unlock() return v.bandwidthLimiter } @@ -226,26 +231,27 @@ func (v *visitor) Stale() bool { return time.Since(v.seen) > visitorExpungeAfter } -func (v *visitor) IncrementEmails() { - v.mu.Lock() - defer v.mu.Unlock() - v.emails++ -} - func (v *visitor) Stats() *user.Stats { - v.mu.Lock() + v.mu.Lock() // limiters could be replaced! defer v.mu.Unlock() return &user.Stats{ Messages: v.messagesLimiter.Value(), - Emails: v.emails, + Emails: v.emailsLimiter.Value(), } } func (v *visitor) ResetStats() { + v.mu.Lock() // limiters could be replaced! + defer v.mu.Unlock() + v.emailsLimiter.Reset() + v.messagesLimiter.Reset() +} + +// User returns the visitor user, or nil if there is none +func (v *visitor) User() *user.User { v.mu.Lock() defer v.mu.Unlock() - v.emails = 0 - v.messagesLimiter.Reset() + return v.user // May be nil } // SetUser sets the visitors user to the given value @@ -255,7 +261,7 @@ func (v *visitor) SetUser(u *user.User) { shouldResetLimiters := v.user.TierID() != u.TierID() // TierID works with nil receiver v.user = u if shouldResetLimiters { - v.resetLimiters(0) + v.resetLimitersNoLock(0, 0) } } @@ -270,12 +276,12 @@ func (v *visitor) MaybeUserID() string { return "" } -func (v *visitor) resetLimiters(messages int64) { +func (v *visitor) resetLimitersNoLock(messages, emails int64) { log.Debug("%s Resetting limiters for visitor", v.stringNoLock()) limits := v.limitsNoLock() v.requestLimiter = rate.NewLimiter(limits.RequestLimitReplenish, limits.RequestLimitBurst) v.messagesLimiter = util.NewFixedLimiterWithValue(limits.MessageLimit, messages) - v.emailsLimiter = rate.NewLimiter(limits.EmailLimitReplenish, limits.EmailLimitBurst) + v.emailsLimiter = util.NewRateLimiterWithValue(limits.EmailLimitReplenish, limits.EmailLimitBurst, emails) v.bandwidthLimiter = util.NewBytesLimiter(int(limits.AttachmentBandwidthLimit), oneDay) if v.user == nil { v.accountLimiter = rate.NewLimiter(rate.Every(v.config.VisitorAccountCreationLimitReplenish), v.config.VisitorAccountCreationLimitBurst) @@ -340,12 +346,13 @@ func configBasedVisitorLimits(conf *Config) *visitorLimits { func (v *visitor) Info() (*visitorInfo, error) { v.mu.Lock() messages := v.messagesLimiter.Value() - emails := v.emails + emails := v.emailsLimiter.Value() v.mu.Unlock() var attachmentsBytesUsed int64 var err error - if v.user != nil { - attachmentsBytesUsed, err = v.messageCache.AttachmentBytesUsedByUser(v.user.ID) + u := v.User() + if u != nil { + attachmentsBytesUsed, err = v.messageCache.AttachmentBytesUsedByUser(u.ID) } else { attachmentsBytesUsed, err = v.messageCache.AttachmentBytesUsedBySender(v.ip.String()) } @@ -353,8 +360,8 @@ func (v *visitor) Info() (*visitorInfo, error) { return nil, err } var reservations int64 - if v.user != nil && v.userManager != nil { - reservations, err = v.userManager.ReservationsCount(v.user.Name) + if v.userManager != nil && u != nil { + reservations, err = v.userManager.ReservationsCount(u.Name) if err != nil { return nil, err } diff --git a/user/manager.go b/user/manager.go index 600bcedc..5f147a78 100644 --- a/user/manager.go +++ b/user/manager.go @@ -301,11 +301,11 @@ var _ Auther = (*Manager)(nil) // NewManager creates a new Manager instance func NewManager(filename, startupQueries string, defaultAccess Permission) (*Manager, error) { - return newManager(filename, startupQueries, defaultAccess, userStatsQueueWriterInterval) + return NewManagerWithStatsInterval(filename, startupQueries, defaultAccess, userStatsQueueWriterInterval) } -// NewManager creates a new Manager instance -func newManager(filename, startupQueries string, defaultAccess Permission, statsWriterInterval time.Duration) (*Manager, error) { +// NewManagerWithStatsInterval creates a new Manager instance +func NewManagerWithStatsInterval(filename, startupQueries string, defaultAccess Permission, statsWriterInterval time.Duration) (*Manager, error) { db, err := sql.Open("sqlite3", filename) if err != nil { return nil, err diff --git a/user/manager_test.go b/user/manager_test.go index 335cdf87..860799ea 100644 --- a/user/manager_test.go +++ b/user/manager_test.go @@ -545,7 +545,7 @@ func TestManager_Token_MaxCount_AutoDelete(t *testing.T) { } func TestManager_EnqueueStats(t *testing.T) { - a, err := newManager(filepath.Join(t.TempDir(), "db"), "", PermissionReadWrite, 1500*time.Millisecond) + a, err := NewManagerWithStatsInterval(filepath.Join(t.TempDir(), "db"), "", PermissionReadWrite, 1500*time.Millisecond) require.Nil(t, err) require.Nil(t, a.AddUser("ben", "ben", RoleUser)) @@ -575,7 +575,7 @@ func TestManager_EnqueueStats(t *testing.T) { } func TestManager_ChangeSettings(t *testing.T) { - a, err := newManager(filepath.Join(t.TempDir(), "db"), "", PermissionReadWrite, 1500*time.Millisecond) + a, err := NewManagerWithStatsInterval(filepath.Join(t.TempDir(), "db"), "", PermissionReadWrite, 1500*time.Millisecond) require.Nil(t, err) require.Nil(t, a.AddUser("ben", "ben", RoleUser)) @@ -718,7 +718,7 @@ func newTestManager(t *testing.T, defaultAccess Permission) *Manager { } func newTestManagerFromFile(t *testing.T, filename, startupQueries string, defaultAccess Permission, statsWriterInterval time.Duration) *Manager { - a, err := newManager(filename, startupQueries, defaultAccess, statsWriterInterval) + a, err := NewManagerWithStatsInterval(filename, startupQueries, defaultAccess, statsWriterInterval) require.Nil(t, err) return a } diff --git a/util/limit.go b/util/limit.go index dd3b56fb..ad2118c7 100644 --- a/util/limit.go +++ b/util/limit.go @@ -13,8 +13,17 @@ var ErrLimitReached = errors.New("limit reached") // Limiter is an interface that implements a rate limiting mechanism, e.g. based on time or a fixed value type Limiter interface { - // Allow adds n to the limiters internal value, or returns ErrLimitReached if the limit has been reached - Allow(n int64) error + // Allow adds one to the limiters value, or returns false if the limit has been reached + Allow() bool + + // AllowN adds n to the limiters value, or returns false if the limit has been reached + AllowN(n int64) bool + + // Value returns the current internal limiter value + Value() int64 + + // Reset resets the state of the limiter + Reset() } // FixedLimiter is a helper that allows adding values up to a well-defined limit. Once the limit is reached @@ -25,6 +34,8 @@ type FixedLimiter struct { mu sync.Mutex } +var _ Limiter = (*FixedLimiter)(nil) + // NewFixedLimiter creates a new Limiter func NewFixedLimiter(limit int64) *FixedLimiter { return NewFixedLimiterWithValue(limit, 0) @@ -38,16 +49,22 @@ func NewFixedLimiterWithValue(limit, value int64) *FixedLimiter { } } -// Allow adds n to the limiters internal value, but only if the limit has not been reached. If the limit was -// exceeded after adding n, ErrLimitReached is returned. -func (l *FixedLimiter) Allow(n int64) error { +// Allow adds one to the limiters internal value, but only if the limit has not been reached. If the limit was +// exceeded, false is returned. +func (l *FixedLimiter) Allow() bool { + return l.AllowN(1) +} + +// AllowN adds n to the limiters internal value, but only if the limit has not been reached. If the limit was +// exceeded after adding n, false is returned. +func (l *FixedLimiter) AllowN(n int64) bool { l.mu.Lock() defer l.mu.Unlock() if l.value+n > l.limit { - return ErrLimitReached + return false } l.value += n - return nil + return true } // Value returns the current limiter value @@ -66,12 +83,29 @@ func (l *FixedLimiter) Reset() { // RateLimiter is a Limiter that wraps a rate.Limiter, allowing a floating time-based limit. type RateLimiter struct { + r rate.Limit + b int + value int64 limiter *rate.Limiter + mu sync.Mutex } +var _ Limiter = (*RateLimiter)(nil) + // NewRateLimiter creates a new RateLimiter func NewRateLimiter(r rate.Limit, b int) *RateLimiter { + return NewRateLimiterWithValue(r, b, 0) +} + +// NewRateLimiterWithValue creates a new RateLimiter with the given starting value. +// +// Note that the starting value only has informational value. It does not impact the underlying +// value of the rate.Limiter. +func NewRateLimiterWithValue(r rate.Limit, b int, value int64) *RateLimiter { return &RateLimiter{ + r: r, + b: b, + value: value, limiter: rate.NewLimiter(r, b), } } @@ -82,16 +116,40 @@ func NewBytesLimiter(bytes int, interval time.Duration) *RateLimiter { return NewRateLimiter(rate.Limit(bytes)*rate.Every(interval), bytes) } -// Allow adds n to the limiters internal value, but only if the limit has not been reached. If the limit was -// exceeded after adding n, ErrLimitReached is returned. -func (l *RateLimiter) Allow(n int64) error { +// Allow adds one to the limiters internal value, but only if the limit has not been reached. If the limit was +// exceeded, false is returned. +func (l *RateLimiter) Allow() bool { + return l.AllowN(1) +} + +// AllowN adds n to the limiters internal value, but only if the limit has not been reached. If the limit was +// exceeded after adding n, false is returned. +func (l *RateLimiter) AllowN(n int64) bool { if n <= 0 { - return nil // No-op. Can't take back bytes you're written! + return false // No-op. Can't take back bytes you're written! } + l.mu.Lock() + defer l.mu.Unlock() if !l.limiter.AllowN(time.Now(), int(n)) { - return ErrLimitReached + return false } - return nil + l.value += n + return true +} + +// Value returns the current limiter value +func (l *RateLimiter) Value() int64 { + l.mu.Lock() + defer l.mu.Unlock() + return l.value +} + +// Reset sets the limiter's value back to zero, and resets the underlying rate.Limiter +func (l *RateLimiter) Reset() { + l.mu.Lock() + defer l.mu.Unlock() + l.limiter = rate.NewLimiter(l.r, l.b) + l.value = 0 } // LimitWriter implements an io.Writer that will pass through all Write calls to the underlying @@ -117,9 +175,9 @@ func (w *LimitWriter) Write(p []byte) (n int, err error) { w.mu.Lock() defer w.mu.Unlock() for i := 0; i < len(w.limiters); i++ { - if err := w.limiters[i].Allow(int64(len(p))); err != nil { + if !w.limiters[i].AllowN(int64(len(p))) { for j := i - 1; j >= 0; j-- { - w.limiters[j].Allow(-int64(len(p))) // Revert limiters limits if allowed + w.limiters[j].AllowN(-int64(len(p))) // Revert limiters limits if not allowed } return 0, ErrLimitReached } diff --git a/util/limit_test.go b/util/limit_test.go index 53e10b78..51595351 100644 --- a/util/limit_test.go +++ b/util/limit_test.go @@ -7,26 +7,31 @@ import ( "time" ) -func TestFixedLimiter_Add(t *testing.T) { +func TestFixedLimiter_AllowValueReset(t *testing.T) { l := NewFixedLimiter(10) - if err := l.Allow(5); err != nil { - t.Fatal(err) - } - if err := l.Allow(5); err != nil { - t.Fatal(err) - } - if err := l.Allow(5); err != ErrLimitReached { - t.Fatalf("expected ErrLimitReached, got %#v", err) - } + require.True(t, l.AllowN(5)) + require.Equal(t, int64(5), l.Value()) + + require.True(t, l.AllowN(5)) + require.Equal(t, int64(10), l.Value()) + + require.False(t, l.Allow()) + require.Equal(t, int64(10), l.Value()) + + l.Reset() + require.Equal(t, int64(0), l.Value()) + require.True(t, l.Allow()) + require.True(t, l.AllowN(9)) + require.False(t, l.Allow()) } func TestFixedLimiter_AddSub(t *testing.T) { l := NewFixedLimiter(10) - l.Allow(5) + l.AllowN(5) if l.value != 5 { t.Fatalf("expected value to be %d, got %d", 5, l.value) } - l.Allow(-2) + l.AllowN(-2) if l.value != 3 { t.Fatalf("expected value to be %d, got %d", 7, l.value) } @@ -34,17 +39,22 @@ func TestFixedLimiter_AddSub(t *testing.T) { func TestBytesLimiter_Add_Simple(t *testing.T) { l := NewBytesLimiter(250*1024*1024, 24*time.Hour) // 250 MB per 24h - require.Nil(t, l.Allow(100*1024*1024)) - require.Nil(t, l.Allow(100*1024*1024)) - require.Equal(t, ErrLimitReached, l.Allow(300*1024*1024)) + require.True(t, l.AllowN(100*1024*1024)) + require.Equal(t, int64(100*1024*1024), l.Value()) + + require.True(t, l.AllowN(100*1024*1024)) + require.Equal(t, int64(200*1024*1024), l.Value()) + + require.False(t, l.AllowN(300*1024*1024)) + require.Equal(t, int64(200*1024*1024), l.Value()) } func TestBytesLimiter_Add_Wait(t *testing.T) { l := NewBytesLimiter(250*1024*1024, 24*time.Hour) // 250 MB per 24h (~ 303 bytes per 100ms) - require.Nil(t, l.Allow(250*1024*1024)) - require.Equal(t, ErrLimitReached, l.Allow(400)) + require.True(t, l.AllowN(250*1024*1024)) + require.False(t, l.AllowN(400)) time.Sleep(200 * time.Millisecond) - require.Nil(t, l.Allow(400)) + require.True(t, l.AllowN(400)) } func TestLimitWriter_WriteNoLimiter(t *testing.T) {