Add limiters and database changes

This commit is contained in:
binwiederhier 2023-05-07 11:59:15 -04:00
parent 113b7c8a08
commit f9e2d6ddcb
11 changed files with 173 additions and 32 deletions

View file

@ -83,6 +83,8 @@ var flagsServe = append(
altsrc.NewStringFlag(&cli.StringFlag{Name: "visitor-request-limit-exempt-hosts", Aliases: []string{"visitor_request_limit_exempt_hosts"}, EnvVars: []string{"NTFY_VISITOR_REQUEST_LIMIT_EXEMPT_HOSTS"}, Value: "", Usage: "hostnames and/or IP addresses of hosts that will be exempt from the visitor request limit"}),
altsrc.NewIntFlag(&cli.IntFlag{Name: "visitor-message-daily-limit", Aliases: []string{"visitor_message_daily_limit"}, EnvVars: []string{"NTFY_VISITOR_MESSAGE_DAILY_LIMIT"}, Value: server.DefaultVisitorMessageDailyLimit, Usage: "max messages per visitor per day, derived from request limit if unset"}),
altsrc.NewIntFlag(&cli.IntFlag{Name: "visitor-email-limit-burst", Aliases: []string{"visitor_email_limit_burst"}, EnvVars: []string{"NTFY_VISITOR_EMAIL_LIMIT_BURST"}, Value: server.DefaultVisitorEmailLimitBurst, Usage: "initial limit of e-mails per visitor"}),
altsrc.NewIntFlag(&cli.IntFlag{Name: "visitor-sms-daily-limit", Aliases: []string{"visitor_sms_daily_limit"}, EnvVars: []string{"NTFY_VISITOR_SMS_DAILY_LIMIT"}, Value: server.DefaultVisitorSMSDailyLimit, Usage: "max number of SMS messages per visitor per day"}),
altsrc.NewIntFlag(&cli.IntFlag{Name: "visitor-call-daily-limit", Aliases: []string{"visitor_call_daily_limit"}, EnvVars: []string{"NTFY_VISITOR_CALL_DAILY_LIMIT"}, Value: server.DefaultVisitorCallDailyLimit, Usage: "max number of phone calls per visitor per day"}),
altsrc.NewDurationFlag(&cli.DurationFlag{Name: "visitor-email-limit-replenish", Aliases: []string{"visitor_email_limit_replenish"}, EnvVars: []string{"NTFY_VISITOR_EMAIL_LIMIT_REPLENISH"}, Value: server.DefaultVisitorEmailLimitReplenish, Usage: "interval at which burst limit is replenished (one per x)"}),
altsrc.NewBoolFlag(&cli.BoolFlag{Name: "visitor-subscriber-rate-limiting", Aliases: []string{"visitor_subscriber_rate_limiting"}, EnvVars: []string{"NTFY_VISITOR_SUBSCRIBER_RATE_LIMITING"}, Value: false, Usage: "enables subscriber-based rate limiting"}),
altsrc.NewBoolFlag(&cli.BoolFlag{Name: "behind-proxy", Aliases: []string{"behind_proxy", "P"}, EnvVars: []string{"NTFY_BEHIND_PROXY"}, Value: false, Usage: "if set, use X-Forwarded-For header to determine visitor IP address (for rate limiting)"}),
@ -168,6 +170,8 @@ func execServe(c *cli.Context) error {
visitorMessageDailyLimit := c.Int("visitor-message-daily-limit")
visitorEmailLimitBurst := c.Int("visitor-email-limit-burst")
visitorEmailLimitReplenish := c.Duration("visitor-email-limit-replenish")
visitorSMSDailyLimit := c.Int("visitor-sms-daily-limit")
visitorCallDailyLimit := c.Int("visitor-call-daily-limit")
behindProxy := c.Bool("behind-proxy")
stripeSecretKey := c.String("stripe-secret-key")
stripeWebhookKey := c.String("stripe-webhook-key")
@ -329,6 +333,8 @@ func execServe(c *cli.Context) error {
conf.VisitorMessageDailyLimit = visitorMessageDailyLimit
conf.VisitorEmailLimitBurst = visitorEmailLimitBurst
conf.VisitorEmailLimitReplenish = visitorEmailLimitReplenish
conf.VisitorSMSDailyLimit = visitorSMSDailyLimit
conf.VisitorCallDailyLimit = visitorCallDailyLimit
conf.VisitorSubscriberRateLimiting = visitorSubscriberRateLimiting
conf.BehindProxy = behindProxy
conf.StripeSecretKey = stripeSecretKey

View file

@ -18,6 +18,8 @@ const (
defaultMessageLimit = 5000
defaultMessageExpiryDuration = "12h"
defaultEmailLimit = 20
defaultSMSLimit = 10
defaultCallLimit = 10
defaultReservationLimit = 3
defaultAttachmentFileSizeLimit = "15M"
defaultAttachmentTotalSizeLimit = "100M"
@ -48,6 +50,8 @@ var cmdTier = &cli.Command{
&cli.Int64Flag{Name: "message-limit", Value: defaultMessageLimit, Usage: "daily message limit"},
&cli.StringFlag{Name: "message-expiry-duration", Value: defaultMessageExpiryDuration, Usage: "duration after which messages are deleted"},
&cli.Int64Flag{Name: "email-limit", Value: defaultEmailLimit, Usage: "daily email limit"},
&cli.Int64Flag{Name: "sms-limit", Value: defaultSMSLimit, Usage: "daily SMS limit"},
&cli.Int64Flag{Name: "call-limit", Value: defaultCallLimit, Usage: "daily phone call limit"},
&cli.Int64Flag{Name: "reservation-limit", Value: defaultReservationLimit, Usage: "topic reservation limit"},
&cli.StringFlag{Name: "attachment-file-size-limit", Value: defaultAttachmentFileSizeLimit, Usage: "per-attachment file size limit"},
&cli.StringFlag{Name: "attachment-total-size-limit", Value: defaultAttachmentTotalSizeLimit, Usage: "total size limit of attachments for the user"},
@ -91,6 +95,8 @@ Examples:
&cli.Int64Flag{Name: "message-limit", Usage: "daily message limit"},
&cli.StringFlag{Name: "message-expiry-duration", Usage: "duration after which messages are deleted"},
&cli.Int64Flag{Name: "email-limit", Usage: "daily email limit"},
&cli.Int64Flag{Name: "sms-limit", Usage: "daily SMS limit"},
&cli.Int64Flag{Name: "call-limit", Usage: "daily phone call limit"},
&cli.Int64Flag{Name: "reservation-limit", Usage: "topic reservation limit"},
&cli.StringFlag{Name: "attachment-file-size-limit", Usage: "per-attachment file size limit"},
&cli.StringFlag{Name: "attachment-total-size-limit", Usage: "total size limit of attachments for the user"},
@ -215,6 +221,8 @@ func execTierAdd(c *cli.Context) error {
MessageLimit: c.Int64("message-limit"),
MessageExpiryDuration: messageExpiryDuration,
EmailLimit: c.Int64("email-limit"),
SMSLimit: c.Int64("sms-limit"),
CallLimit: c.Int64("call-limit"),
ReservationLimit: c.Int64("reservation-limit"),
AttachmentFileSizeLimit: attachmentFileSizeLimit,
AttachmentTotalSizeLimit: attachmentTotalSizeLimit,
@ -267,6 +275,12 @@ func execTierChange(c *cli.Context) error {
if c.IsSet("email-limit") {
tier.EmailLimit = c.Int64("email-limit")
}
if c.IsSet("sms-limit") {
tier.SMSLimit = c.Int64("sms-limit")
}
if c.IsSet("call-limit") {
tier.CallLimit = c.Int64("call-limit")
}
if c.IsSet("reservation-limit") {
tier.ReservationLimit = c.Int64("reservation-limit")
}
@ -357,6 +371,8 @@ func printTier(c *cli.Context, tier *user.Tier) {
fmt.Fprintf(c.App.ErrWriter, "- Message limit: %d\n", tier.MessageLimit)
fmt.Fprintf(c.App.ErrWriter, "- Message expiry duration: %s (%d seconds)\n", tier.MessageExpiryDuration.String(), int64(tier.MessageExpiryDuration.Seconds()))
fmt.Fprintf(c.App.ErrWriter, "- Email limit: %d\n", tier.EmailLimit)
fmt.Fprintf(c.App.ErrWriter, "- SMS limit: %d\n", tier.SMSLimit)
fmt.Fprintf(c.App.ErrWriter, "- Phone call limit: %d\n", tier.CallLimit)
fmt.Fprintf(c.App.ErrWriter, "- Reservation limit: %d\n", tier.ReservationLimit)
fmt.Fprintf(c.App.ErrWriter, "- Attachment file size limit: %s\n", util.FormatSize(tier.AttachmentFileSizeLimit))
fmt.Fprintf(c.App.ErrWriter, "- Attachment total size limit: %s\n", util.FormatSize(tier.AttachmentTotalSizeLimit))

View file

@ -47,6 +47,8 @@ const (
DefaultVisitorMessageDailyLimit = 0
DefaultVisitorEmailLimitBurst = 16
DefaultVisitorEmailLimitReplenish = time.Hour
DefaultVisitorSMSDailyLimit = 10
DefaultVisitorCallDailyLimit = 10
DefaultVisitorAccountCreationLimitBurst = 3
DefaultVisitorAccountCreationLimitReplenish = 24 * time.Hour
DefaultVisitorAuthFailureLimitBurst = 30
@ -126,6 +128,8 @@ type Config struct {
VisitorMessageDailyLimit int
VisitorEmailLimitBurst int
VisitorEmailLimitReplenish time.Duration
VisitorSMSDailyLimit int
VisitorCallDailyLimit int
VisitorAccountCreationLimitBurst int
VisitorAccountCreationLimitReplenish time.Duration
VisitorAuthFailureLimitBurst int

View file

@ -126,6 +126,8 @@ var (
errHTTPTooManyRequestsLimitReservations = &errHTTP{42907, http.StatusTooManyRequests, "limit reached: too many topic reservations for this user", "", nil}
errHTTPTooManyRequestsLimitMessages = &errHTTP{42908, http.StatusTooManyRequests, "limit reached: daily message quota reached", "https://ntfy.sh/docs/publish/#limitations", nil}
errHTTPTooManyRequestsLimitAuthFailure = &errHTTP{42909, http.StatusTooManyRequests, "limit reached: too many auth failures", "https://ntfy.sh/docs/publish/#limitations", nil} // FIXME document limit
errHTTPTooManyRequestsLimitSMS = &errHTTP{42910, http.StatusTooManyRequests, "limit reached: daily SMS quota reached", "https://ntfy.sh/docs/publish/#limitations", nil}
errHTTPTooManyRequestsLimitCalls = &errHTTP{42911, http.StatusTooManyRequests, "limit reached: daily phone call quota reached", "https://ntfy.sh/docs/publish/#limitations", nil}
errHTTPInternalError = &errHTTP{50001, http.StatusInternalServerError, "internal server error", "", nil}
errHTTPInternalErrorInvalidPath = &errHTTP{50002, http.StatusInternalServerError, "internal server error: invalid path", "", nil}
errHTTPInternalErrorMissingBaseURL = &errHTTP{50003, http.StatusInternalServerError, "internal server error: base-url must be be configured for this feature", "https://ntfy.sh/docs/config/", nil}

View file

@ -683,6 +683,10 @@ func (s *Server) handlePublishInternal(r *http.Request, v *visitor) (*message, e
return nil, errHTTPTooManyRequestsLimitMessages.With(t)
} else if email != "" && !vrate.EmailAllowed() {
return nil, errHTTPTooManyRequestsLimitEmails.With(t)
} else if sms != "" && !vrate.SMSAllowed() {
return nil, errHTTPTooManyRequestsLimitSMS.With(t)
} else if call != "" && !vrate.CallAllowed() {
return nil, errHTTPTooManyRequestsLimitCalls.With(t)
}
if m.PollID != "" {
m = newPollRequestMessage(t.ID, m.PollID)
@ -726,7 +730,7 @@ func (s *Server) handlePublishInternal(r *http.Request, v *visitor) (*message, e
if s.config.TwilioAccount != "" && sms != "" {
go s.sendSMS(v, r, m, sms)
}
if call != "" {
if s.config.TwilioAccount != "" && call != "" {
go s.callPhone(v, r, m, call)
}
if s.config.UpstreamBaseURL != "" {

View file

@ -224,11 +224,17 @@
# visitor-request-limit-exempt-hosts: ""
# Rate limiting: Hard daily limit of messages per visitor and day. The limit is reset
# every day at midnight UTC. If the limit is not set (or set to zero), the request
# limit (see above) governs the upper limit.
# every day at midnight UTC. If the limit is not set (or set to zero), the request limit (see above)
# governs the upper limit. SMS and calls are only supported if the twilio-settings are properly configured.
#
# visitor-message-daily-limit: 0
# Rate limiting: Daily limit of SMS and calls per visitor and day. The limit is reset every day
# at midnight UTC. SMS and calls are only supported if the twilio-settings are properly configured.
#
# visitor-sms-daily-limit: 10
# visitor-call-daily-limit: 10
# Rate limiting: Allowed emails per visitor:
# - visitor-email-limit-burst is the initial bucket of emails each visitor has
# - visitor-email-limit-replenish is the rate at which the bucket is refilled

View file

@ -56,6 +56,8 @@ func (s *Server) handleAccountGet(w http.ResponseWriter, r *http.Request, v *vis
Messages: limits.MessageLimit,
MessagesExpiryDuration: int64(limits.MessageExpiryDuration.Seconds()),
Emails: limits.EmailLimit,
SMS: limits.SMSLimit,
Calls: limits.CallLimit,
Reservations: limits.ReservationsLimit,
AttachmentTotalSize: limits.AttachmentTotalSizeLimit,
AttachmentFileSize: limits.AttachmentFileSizeLimit,
@ -67,6 +69,10 @@ func (s *Server) handleAccountGet(w http.ResponseWriter, r *http.Request, v *vis
MessagesRemaining: stats.MessagesRemaining,
Emails: stats.Emails,
EmailsRemaining: stats.EmailsRemaining,
SMS: stats.SMS,
SMSRemaining: stats.SMSRemaining,
Calls: stats.Calls,
CallsRemaining: stats.CallsRemaining,
Reservations: stats.Reservations,
ReservationsRemaining: stats.ReservationsRemaining,
AttachmentTotalSize: stats.AttachmentTotalSize,

View file

@ -287,6 +287,8 @@ type apiAccountLimits struct {
Messages int64 `json:"messages"`
MessagesExpiryDuration int64 `json:"messages_expiry_duration"`
Emails int64 `json:"emails"`
SMS int64 `json:"sms"`
Calls int64 `json:"calls"`
Reservations int64 `json:"reservations"`
AttachmentTotalSize int64 `json:"attachment_total_size"`
AttachmentFileSize int64 `json:"attachment_file_size"`
@ -299,6 +301,10 @@ type apiAccountStats struct {
MessagesRemaining int64 `json:"messages_remaining"`
Emails int64 `json:"emails"`
EmailsRemaining int64 `json:"emails_remaining"`
SMS int64 `json:"sms"`
SMSRemaining int64 `json:"sms_remaining"`
Calls int64 `json:"calls"`
CallsRemaining int64 `json:"calls_remaining"`
Reservations int64 `json:"reservations"`
ReservationsRemaining int64 `json:"reservations_remaining"`
AttachmentTotalSize int64 `json:"attachment_total_size"`

View file

@ -56,6 +56,8 @@ type visitor struct {
requestLimiter *rate.Limiter // Rate limiter for (almost) all requests (including messages)
messagesLimiter *util.FixedLimiter // Rate limiter for messages
emailsLimiter *util.RateLimiter // Rate limiter for emails
smsLimiter *util.FixedLimiter // Rate limiter for SMS
callsLimiter *util.FixedLimiter // Rate limiter for calls
subscriptionLimiter *util.FixedLimiter // Fixed limiter for active subscriptions (ongoing connections)
bandwidthLimiter *util.RateLimiter // Limiter for attachment bandwidth downloads
accountLimiter *rate.Limiter // Rate limiter for account creation, may be nil
@ -79,6 +81,8 @@ type visitorLimits struct {
EmailLimit int64
EmailLimitBurst int
EmailLimitReplenish rate.Limit
SMSLimit int64
CallLimit int64
ReservationsLimit int64
AttachmentTotalSizeLimit int64
AttachmentFileSizeLimit int64
@ -91,6 +95,10 @@ type visitorStats struct {
MessagesRemaining int64
Emails int64
EmailsRemaining int64
SMS int64
SMSRemaining int64
Calls int64
CallsRemaining int64
Reservations int64
ReservationsRemaining int64
AttachmentTotalSize int64
@ -107,10 +115,12 @@ const (
)
func newVisitor(conf *Config, messageCache *messageCache, userManager *user.Manager, ip netip.Addr, user *user.User) *visitor {
var messages, emails int64
var messages, emails, sms, calls int64
if user != nil {
messages = user.Stats.Messages
emails = user.Stats.Emails
sms = user.Stats.SMS
calls = user.Stats.Calls
}
v := &visitor{
config: conf,
@ -124,11 +134,13 @@ func newVisitor(conf *Config, messageCache *messageCache, userManager *user.Mana
requestLimiter: nil, // Set in resetLimiters
messagesLimiter: nil, // Set in resetLimiters, may be nil
emailsLimiter: nil, // Set in resetLimiters
smsLimiter: nil, // Set in resetLimiters, may be nil
callsLimiter: nil, // Set in resetLimiters, may be nil
bandwidthLimiter: nil, // Set in resetLimiters
accountLimiter: nil, // Set in resetLimiters, may be nil
authLimiter: nil, // Set in resetLimiters, may be nil
}
v.resetLimitersNoLock(messages, emails, false)
v.resetLimitersNoLock(messages, emails, sms, calls, false)
return v
}
@ -147,12 +159,22 @@ func (v *visitor) contextNoLock() log.Context {
"visitor_messages": info.Stats.Messages,
"visitor_messages_limit": info.Limits.MessageLimit,
"visitor_messages_remaining": info.Stats.MessagesRemaining,
"visitor_emails": info.Stats.Emails,
"visitor_emails_limit": info.Limits.EmailLimit,
"visitor_emails_remaining": info.Stats.EmailsRemaining,
"visitor_request_limiter_limit": v.requestLimiter.Limit(),
"visitor_request_limiter_tokens": v.requestLimiter.Tokens(),
}
if v.config.SMTPSenderFrom != "" {
fields["visitor_emails"] = info.Stats.Emails
fields["visitor_emails_limit"] = info.Limits.EmailLimit
fields["visitor_emails_remaining"] = info.Stats.EmailsRemaining
}
if v.config.TwilioAccount != "" {
fields["visitor_sms"] = info.Stats.SMS
fields["visitor_sms_limit"] = info.Limits.SMSLimit
fields["visitor_sms_remaining"] = info.Stats.SMSRemaining
fields["visitor_calls"] = info.Stats.Calls
fields["visitor_calls_limit"] = info.Limits.CallLimit
fields["visitor_calls_remaining"] = info.Stats.CallsRemaining
}
if v.authLimiter != nil {
fields["visitor_auth_limiter_limit"] = v.authLimiter.Limit()
fields["visitor_auth_limiter_tokens"] = v.authLimiter.Tokens()
@ -216,6 +238,18 @@ func (v *visitor) EmailAllowed() bool {
return v.emailsLimiter.Allow()
}
func (v *visitor) SMSAllowed() bool {
v.mu.RLock() // limiters could be replaced!
defer v.mu.RUnlock()
return v.smsLimiter.Allow()
}
func (v *visitor) CallAllowed() bool {
v.mu.RLock() // limiters could be replaced!
defer v.mu.RUnlock()
return v.callsLimiter.Allow()
}
func (v *visitor) SubscriptionAllowed() bool {
v.mu.RLock() // limiters could be replaced!
defer v.mu.RUnlock()
@ -296,6 +330,8 @@ func (v *visitor) Stats() *user.Stats {
return &user.Stats{
Messages: v.messagesLimiter.Value(),
Emails: v.emailsLimiter.Value(),
SMS: v.smsLimiter.Value(),
Calls: v.callsLimiter.Value(),
}
}
@ -304,6 +340,8 @@ func (v *visitor) ResetStats() {
defer v.mu.RUnlock()
v.emailsLimiter.Reset()
v.messagesLimiter.Reset()
v.smsLimiter.Reset()
v.callsLimiter.Reset()
}
// User returns the visitor user, or nil if there is none
@ -334,11 +372,11 @@ func (v *visitor) SetUser(u *user.User) {
shouldResetLimiters := v.user.TierID() != u.TierID() // TierID works with nil receiver
v.user = u // u may be nil!
if shouldResetLimiters {
var messages, emails int64
var messages, emails, sms, calls int64
if u != nil {
messages, emails = u.Stats.Messages, u.Stats.Emails
messages, emails, sms, calls = u.Stats.Messages, u.Stats.Emails, u.Stats.SMS, u.Stats.Calls
}
v.resetLimitersNoLock(messages, emails, true)
v.resetLimitersNoLock(messages, emails, sms, calls, true)
}
}
@ -353,11 +391,13 @@ func (v *visitor) MaybeUserID() string {
return ""
}
func (v *visitor) resetLimitersNoLock(messages, emails int64, enqueueUpdate bool) {
func (v *visitor) resetLimitersNoLock(messages, emails, sms, calls int64, enqueueUpdate bool) {
limits := v.limitsNoLock()
v.requestLimiter = rate.NewLimiter(limits.RequestLimitReplenish, limits.RequestLimitBurst)
v.messagesLimiter = util.NewFixedLimiterWithValue(limits.MessageLimit, messages)
v.emailsLimiter = util.NewRateLimiterWithValue(limits.EmailLimitReplenish, limits.EmailLimitBurst, emails)
v.smsLimiter = util.NewFixedLimiterWithValue(limits.SMSLimit, sms)
v.callsLimiter = util.NewFixedLimiterWithValue(limits.CallLimit, calls)
v.bandwidthLimiter = util.NewBytesLimiter(int(limits.AttachmentBandwidthLimit), oneDay)
if v.user == nil {
v.accountLimiter = rate.NewLimiter(rate.Every(v.config.VisitorAccountCreationLimitReplenish), v.config.VisitorAccountCreationLimitBurst)
@ -370,6 +410,8 @@ func (v *visitor) resetLimitersNoLock(messages, emails int64, enqueueUpdate bool
go v.userManager.EnqueueUserStats(v.user.ID, &user.Stats{
Messages: messages,
Emails: emails,
SMS: sms,
Calls: calls,
})
}
log.Fields(v.contextNoLock()).Debug("Rate limiters reset for visitor") // Must be after function, because contextNoLock() describes rate limiters
@ -398,6 +440,8 @@ func tierBasedVisitorLimits(conf *Config, tier *user.Tier) *visitorLimits {
EmailLimit: tier.EmailLimit,
EmailLimitBurst: util.MinMax(int(float64(tier.EmailLimit)*visitorEmailLimitBurstRate), conf.VisitorEmailLimitBurst, visitorEmailLimitBurstMax),
EmailLimitReplenish: dailyLimitToRate(tier.EmailLimit),
SMSLimit: tier.SMSLimit,
CallLimit: tier.CallLimit,
ReservationsLimit: tier.ReservationLimit,
AttachmentTotalSizeLimit: tier.AttachmentTotalSizeLimit,
AttachmentFileSizeLimit: tier.AttachmentFileSizeLimit,
@ -420,6 +464,8 @@ func configBasedVisitorLimits(conf *Config) *visitorLimits {
EmailLimit: replenishDurationToDailyLimit(conf.VisitorEmailLimitReplenish), // Approximation!
EmailLimitBurst: conf.VisitorEmailLimitBurst,
EmailLimitReplenish: rate.Every(conf.VisitorEmailLimitReplenish),
SMSLimit: int64(conf.VisitorSMSDailyLimit),
CallLimit: int64(conf.VisitorCallDailyLimit),
ReservationsLimit: visitorDefaultReservationsLimit,
AttachmentTotalSizeLimit: conf.VisitorAttachmentTotalSizeLimit,
AttachmentFileSizeLimit: conf.AttachmentFileSizeLimit,
@ -465,12 +511,18 @@ func (v *visitor) Info() (*visitorInfo, error) {
func (v *visitor) infoLightNoLock() *visitorInfo {
messages := v.messagesLimiter.Value()
emails := v.emailsLimiter.Value()
sms := v.smsLimiter.Value()
calls := v.callsLimiter.Value()
limits := v.limitsNoLock()
stats := &visitorStats{
Messages: messages,
MessagesRemaining: zeroIfNegative(limits.MessageLimit - messages),
Emails: emails,
EmailsRemaining: zeroIfNegative(limits.EmailLimit - emails),
SMS: sms,
SMSRemaining: zeroIfNegative(limits.SMSLimit - sms),
Calls: calls,
CallsRemaining: zeroIfNegative(limits.CallLimit - calls),
}
return &visitorInfo{
Limits: limits,

View file

@ -55,6 +55,8 @@ const (
messages_limit INT NOT NULL,
messages_expiry_duration INT NOT NULL,
emails_limit INT NOT NULL,
sms_limit INT NOT NULL,
calls_limit INT NOT NULL,
reservations_limit INT NOT NULL,
attachment_file_size_limit INT NOT NULL,
attachment_total_size_limit INT NOT NULL,
@ -76,6 +78,8 @@ const (
sync_topic TEXT NOT NULL,
stats_messages INT NOT NULL DEFAULT (0),
stats_emails INT NOT NULL DEFAULT (0),
stats_sms INT NOT NULL DEFAULT (0),
stats_calls INT NOT NULL DEFAULT (0),
stripe_customer_id TEXT,
stripe_subscription_id TEXT,
stripe_subscription_status TEXT,
@ -123,26 +127,26 @@ const (
`
selectUserByIDQuery = `
SELECT u.id, u.user, u.pass, u.role, u.prefs, u.sync_topic, u.stats_messages, u.stats_emails, u.stripe_customer_id, u.stripe_subscription_id, u.stripe_subscription_status, u.stripe_subscription_interval, u.stripe_subscription_paid_until, u.stripe_subscription_cancel_at, deleted, t.id, t.code, t.name, t.messages_limit, t.messages_expiry_duration, t.emails_limit, t.reservations_limit, t.attachment_file_size_limit, t.attachment_total_size_limit, t.attachment_expiry_duration, t.attachment_bandwidth_limit, t.stripe_monthly_price_id, t.stripe_yearly_price_id
SELECT u.id, u.user, u.pass, u.role, u.prefs, u.sync_topic, u.stats_messages, u.stats_emails, u.stats_sms, u.stats_calls, u.stripe_customer_id, u.stripe_subscription_id, u.stripe_subscription_status, u.stripe_subscription_interval, u.stripe_subscription_paid_until, u.stripe_subscription_cancel_at, deleted, t.id, t.code, t.name, t.messages_limit, t.messages_expiry_duration, t.emails_limit, t.reservations_limit, t.attachment_file_size_limit, t.attachment_total_size_limit, t.attachment_expiry_duration, t.attachment_bandwidth_limit, t.stripe_monthly_price_id, t.stripe_yearly_price_id
FROM user u
LEFT JOIN tier t on t.id = u.tier_id
WHERE u.id = ?
`
selectUserByNameQuery = `
SELECT u.id, u.user, u.pass, u.role, u.prefs, u.sync_topic, u.stats_messages, u.stats_emails, u.stripe_customer_id, u.stripe_subscription_id, u.stripe_subscription_status, u.stripe_subscription_interval, u.stripe_subscription_paid_until, u.stripe_subscription_cancel_at, deleted, t.id, t.code, t.name, t.messages_limit, t.messages_expiry_duration, t.emails_limit, t.reservations_limit, t.attachment_file_size_limit, t.attachment_total_size_limit, t.attachment_expiry_duration, t.attachment_bandwidth_limit, t.stripe_monthly_price_id, t.stripe_yearly_price_id
SELECT u.id, u.user, u.pass, u.role, u.prefs, u.sync_topic, u.stats_messages, u.stats_emails, u.stats_sms, u.stats_calls, u.stripe_customer_id, u.stripe_subscription_id, u.stripe_subscription_status, u.stripe_subscription_interval, u.stripe_subscription_paid_until, u.stripe_subscription_cancel_at, deleted, t.id, t.code, t.name, t.messages_limit, t.messages_expiry_duration, t.emails_limit, t.reservations_limit, t.attachment_file_size_limit, t.attachment_total_size_limit, t.attachment_expiry_duration, t.attachment_bandwidth_limit, t.stripe_monthly_price_id, t.stripe_yearly_price_id
FROM user u
LEFT JOIN tier t on t.id = u.tier_id
WHERE user = ?
`
selectUserByTokenQuery = `
SELECT u.id, u.user, u.pass, u.role, u.prefs, u.sync_topic, u.stats_messages, u.stats_emails, u.stripe_customer_id, u.stripe_subscription_id, u.stripe_subscription_status, u.stripe_subscription_interval, u.stripe_subscription_paid_until, u.stripe_subscription_cancel_at, deleted, t.id, t.code, t.name, t.messages_limit, t.messages_expiry_duration, t.emails_limit, t.reservations_limit, t.attachment_file_size_limit, t.attachment_total_size_limit, t.attachment_expiry_duration, t.attachment_bandwidth_limit, t.stripe_monthly_price_id, t.stripe_yearly_price_id
SELECT u.id, u.user, u.pass, u.role, u.prefs, u.sync_topic, u.stats_messages, u.stats_emails, u.stats_sms, u.stats_calls, u.stripe_customer_id, u.stripe_subscription_id, u.stripe_subscription_status, u.stripe_subscription_interval, u.stripe_subscription_paid_until, u.stripe_subscription_cancel_at, deleted, t.id, t.code, t.name, t.messages_limit, t.messages_expiry_duration, t.emails_limit, t.reservations_limit, t.attachment_file_size_limit, t.attachment_total_size_limit, t.attachment_expiry_duration, t.attachment_bandwidth_limit, t.stripe_monthly_price_id, t.stripe_yearly_price_id
FROM user u
JOIN user_token tk on u.id = tk.user_id
LEFT JOIN tier t on t.id = u.tier_id
WHERE tk.token = ? AND (tk.expires = 0 OR tk.expires >= ?)
`
selectUserByStripeCustomerIDQuery = `
SELECT u.id, u.user, u.pass, u.role, u.prefs, u.sync_topic, u.stats_messages, u.stats_emails, u.stripe_customer_id, u.stripe_subscription_id, u.stripe_subscription_status, u.stripe_subscription_interval, u.stripe_subscription_paid_until, u.stripe_subscription_cancel_at, deleted, t.id, t.code, t.name, t.messages_limit, t.messages_expiry_duration, t.emails_limit, t.reservations_limit, t.attachment_file_size_limit, t.attachment_total_size_limit, t.attachment_expiry_duration, t.attachment_bandwidth_limit, t.stripe_monthly_price_id, t.stripe_yearly_price_id
SELECT u.id, u.user, u.pass, u.role, u.prefs, u.sync_topic, u.stats_messages, u.stats_emails, u.stats_sms, u.stats_calls, u.stripe_customer_id, u.stripe_subscription_id, u.stripe_subscription_status, u.stripe_subscription_interval, u.stripe_subscription_paid_until, u.stripe_subscription_cancel_at, deleted, t.id, t.code, t.name, t.messages_limit, t.messages_expiry_duration, t.emails_limit, t.reservations_limit, t.attachment_file_size_limit, t.attachment_total_size_limit, t.attachment_expiry_duration, t.attachment_bandwidth_limit, t.stripe_monthly_price_id, t.stripe_yearly_price_id
FROM user u
LEFT JOIN tier t on t.id = u.tier_id
WHERE u.stripe_customer_id = ?
@ -173,8 +177,8 @@ const (
updateUserPassQuery = `UPDATE user SET pass = ? WHERE user = ?`
updateUserRoleQuery = `UPDATE user SET role = ? WHERE user = ?`
updateUserPrefsQuery = `UPDATE user SET prefs = ? WHERE id = ?`
updateUserStatsQuery = `UPDATE user SET stats_messages = ?, stats_emails = ? WHERE id = ?`
updateUserStatsResetAllQuery = `UPDATE user SET stats_messages = 0, stats_emails = 0`
updateUserStatsQuery = `UPDATE user SET stats_messages = ?, stats_emails = ?, stats_sms = ?, stats_calls = ? WHERE id = ?`
updateUserStatsResetAllQuery = `UPDATE user SET stats_messages = 0, stats_emails = 0, stats_sms = 0, stats_calls = 0`
updateUserDeletedQuery = `UPDATE user SET deleted = ? WHERE id = ?`
deleteUsersMarkedQuery = `DELETE FROM user WHERE deleted < ?`
deleteUserQuery = `DELETE FROM user WHERE user = ?`
@ -258,25 +262,25 @@ const (
`
insertTierQuery = `
INSERT INTO tier (id, code, name, messages_limit, messages_expiry_duration, emails_limit, reservations_limit, attachment_file_size_limit, attachment_total_size_limit, attachment_expiry_duration, attachment_bandwidth_limit, stripe_monthly_price_id, stripe_yearly_price_id)
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
INSERT INTO tier (id, code, name, messages_limit, messages_expiry_duration, emails_limit, sms_limit, calls_limit, reservations_limit, attachment_file_size_limit, attachment_total_size_limit, attachment_expiry_duration, attachment_bandwidth_limit, stripe_monthly_price_id, stripe_yearly_price_id)
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
`
updateTierQuery = `
UPDATE tier
SET name = ?, messages_limit = ?, messages_expiry_duration = ?, emails_limit = ?, reservations_limit = ?, attachment_file_size_limit = ?, attachment_total_size_limit = ?, attachment_expiry_duration = ?, attachment_bandwidth_limit = ?, stripe_monthly_price_id = ?, stripe_yearly_price_id = ?
SET name = ?, messages_limit = ?, messages_expiry_duration = ?, emails_limit = ?, sms_limit = ?, calls_limit = ?, reservations_limit = ?, attachment_file_size_limit = ?, attachment_total_size_limit = ?, attachment_expiry_duration = ?, attachment_bandwidth_limit = ?, stripe_monthly_price_id = ?, stripe_yearly_price_id = ?
WHERE code = ?
`
selectTiersQuery = `
SELECT id, code, name, messages_limit, messages_expiry_duration, emails_limit, reservations_limit, attachment_file_size_limit, attachment_total_size_limit, attachment_expiry_duration, attachment_bandwidth_limit, stripe_monthly_price_id, stripe_yearly_price_id
SELECT id, code, name, messages_limit, messages_expiry_duration, emails_limit, sms_limit, calls_limit, reservations_limit, attachment_file_size_limit, attachment_total_size_limit, attachment_expiry_duration, attachment_bandwidth_limit, stripe_monthly_price_id, stripe_yearly_price_id
FROM tier
`
selectTierByCodeQuery = `
SELECT id, code, name, messages_limit, messages_expiry_duration, emails_limit, reservations_limit, attachment_file_size_limit, attachment_total_size_limit, attachment_expiry_duration, attachment_bandwidth_limit, stripe_monthly_price_id, stripe_yearly_price_id
SELECT id, code, name, messages_limit, messages_expiry_duration, emails_limit, sms_limit, calls_limit, reservations_limit, attachment_file_size_limit, attachment_total_size_limit, attachment_expiry_duration, attachment_bandwidth_limit, stripe_monthly_price_id, stripe_yearly_price_id
FROM tier
WHERE code = ?
`
selectTierByPriceIDQuery = `
SELECT id, code, name, messages_limit, messages_expiry_duration, emails_limit, reservations_limit, attachment_file_size_limit, attachment_total_size_limit, attachment_expiry_duration, attachment_bandwidth_limit, stripe_monthly_price_id, stripe_yearly_price_id
SELECT id, code, name, messages_limit, messages_expiry_duration, emails_limit, sms_limit, calls_limit, reservations_limit, attachment_file_size_limit, attachment_total_size_limit, attachment_expiry_duration, attachment_bandwidth_limit, stripe_monthly_price_id, stripe_yearly_price_id
FROM tier
WHERE (stripe_monthly_price_id = ? OR stripe_yearly_price_id = ?)
`
@ -293,7 +297,7 @@ const (
// Schema management queries
const (
currentSchemaVersion = 3
currentSchemaVersion = 4
insertSchemaVersion = `INSERT INTO schemaVersion VALUES (1, ?)`
updateSchemaVersion = `UPDATE schemaVersion SET version = ? WHERE id = 1`
selectSchemaVersionQuery = `SELECT version FROM schemaVersion WHERE id = 1`
@ -391,12 +395,21 @@ const (
CREATE UNIQUE INDEX idx_tier_stripe_monthly_price_id ON tier (stripe_monthly_price_id);
CREATE UNIQUE INDEX idx_tier_stripe_yearly_price_id ON tier (stripe_yearly_price_id);
`
// 3 -> 4
migrate3To4UpdateQueries = `
ALTER TABLE tier ADD COLUMN sms_limit INT NOT NULL DEFAULT (0);
ALTER TABLE tier ADD COLUMN calls_limit INT NOT NULL DEFAULT (0);
ALTER TABLE user ADD COLUMN stats_sms INT NOT NULL DEFAULT (0);
ALTER TABLE user ADD COLUMN stats_calls INT NOT NULL DEFAULT (0);
`
)
var (
migrations = map[int]func(db *sql.DB) error{
1: migrateFrom1,
2: migrateFrom2,
3: migrateFrom3,
}
)
@ -700,9 +713,11 @@ func (a *Manager) writeUserStatsQueue() error {
"user_id": userID,
"messages_count": update.Messages,
"emails_count": update.Emails,
"sms_count": update.SMS,
"calls_count": update.Calls,
}).
Trace("Updating stats for user %s", userID)
if _, err := tx.Exec(updateUserStatsQuery, update.Messages, update.Emails, userID); err != nil {
if _, err := tx.Exec(updateUserStatsQuery, update.Messages, update.Emails, update.SMS, update.Calls, userID); err != nil {
return err
}
}
@ -911,12 +926,12 @@ func (a *Manager) readUser(rows *sql.Rows) (*User, error) {
defer rows.Close()
var id, username, hash, role, prefs, syncTopic string
var stripeCustomerID, stripeSubscriptionID, stripeSubscriptionStatus, stripeSubscriptionInterval, stripeMonthlyPriceID, stripeYearlyPriceID, tierID, tierCode, tierName sql.NullString
var messages, emails int64
var messages, emails, sms, calls int64
var messagesLimit, messagesExpiryDuration, emailsLimit, reservationsLimit, attachmentFileSizeLimit, attachmentTotalSizeLimit, attachmentExpiryDuration, attachmentBandwidthLimit, stripeSubscriptionPaidUntil, stripeSubscriptionCancelAt, deleted sql.NullInt64
if !rows.Next() {
return nil, ErrUserNotFound
}
if err := rows.Scan(&id, &username, &hash, &role, &prefs, &syncTopic, &messages, &emails, &stripeCustomerID, &stripeSubscriptionID, &stripeSubscriptionStatus, &stripeSubscriptionInterval, &stripeSubscriptionPaidUntil, &stripeSubscriptionCancelAt, &deleted, &tierID, &tierCode, &tierName, &messagesLimit, &messagesExpiryDuration, &emailsLimit, &reservationsLimit, &attachmentFileSizeLimit, &attachmentTotalSizeLimit, &attachmentExpiryDuration, &attachmentBandwidthLimit, &stripeMonthlyPriceID, &stripeYearlyPriceID); err != nil {
if err := rows.Scan(&id, &username, &hash, &role, &prefs, &syncTopic, &messages, &emails, &sms, &calls, &stripeCustomerID, &stripeSubscriptionID, &stripeSubscriptionStatus, &stripeSubscriptionInterval, &stripeSubscriptionPaidUntil, &stripeSubscriptionCancelAt, &deleted, &tierID, &tierCode, &tierName, &messagesLimit, &messagesExpiryDuration, &emailsLimit, &reservationsLimit, &attachmentFileSizeLimit, &attachmentTotalSizeLimit, &attachmentExpiryDuration, &attachmentBandwidthLimit, &stripeMonthlyPriceID, &stripeYearlyPriceID); err != nil {
return nil, err
} else if err := rows.Err(); err != nil {
return nil, err
@ -931,6 +946,8 @@ func (a *Manager) readUser(rows *sql.Rows) (*User, error) {
Stats: &Stats{
Messages: messages,
Emails: emails,
SMS: sms,
Calls: calls,
},
Billing: &Billing{
StripeCustomerID: stripeCustomerID.String, // May be empty
@ -1259,7 +1276,7 @@ func (a *Manager) AddTier(tier *Tier) error {
if tier.ID == "" {
tier.ID = util.RandomStringPrefix(tierIDPrefix, tierIDLength)
}
if _, err := a.db.Exec(insertTierQuery, tier.ID, tier.Code, tier.Name, tier.MessageLimit, int64(tier.MessageExpiryDuration.Seconds()), tier.EmailLimit, tier.ReservationLimit, tier.AttachmentFileSizeLimit, tier.AttachmentTotalSizeLimit, int64(tier.AttachmentExpiryDuration.Seconds()), tier.AttachmentBandwidthLimit, nullString(tier.StripeMonthlyPriceID), nullString(tier.StripeYearlyPriceID)); err != nil {
if _, err := a.db.Exec(insertTierQuery, tier.ID, tier.Code, tier.Name, tier.MessageLimit, int64(tier.MessageExpiryDuration.Seconds()), tier.EmailLimit, tier.SMSLimit, tier.CallLimit, tier.ReservationLimit, tier.AttachmentFileSizeLimit, tier.AttachmentTotalSizeLimit, int64(tier.AttachmentExpiryDuration.Seconds()), tier.AttachmentBandwidthLimit, nullString(tier.StripeMonthlyPriceID), nullString(tier.StripeYearlyPriceID)); err != nil {
return err
}
return nil
@ -1267,7 +1284,7 @@ func (a *Manager) AddTier(tier *Tier) error {
// UpdateTier updates a tier's properties in the database
func (a *Manager) UpdateTier(tier *Tier) error {
if _, err := a.db.Exec(updateTierQuery, tier.Name, tier.MessageLimit, int64(tier.MessageExpiryDuration.Seconds()), tier.EmailLimit, tier.ReservationLimit, tier.AttachmentFileSizeLimit, tier.AttachmentTotalSizeLimit, int64(tier.AttachmentExpiryDuration.Seconds()), tier.AttachmentBandwidthLimit, nullString(tier.StripeMonthlyPriceID), nullString(tier.StripeYearlyPriceID), tier.Code); err != nil {
if _, err := a.db.Exec(updateTierQuery, tier.Name, tier.MessageLimit, int64(tier.MessageExpiryDuration.Seconds()), tier.EmailLimit, tier.SMSLimit, tier.CallLimit, tier.ReservationLimit, tier.AttachmentFileSizeLimit, tier.AttachmentTotalSizeLimit, int64(tier.AttachmentExpiryDuration.Seconds()), tier.AttachmentBandwidthLimit, nullString(tier.StripeMonthlyPriceID), nullString(tier.StripeYearlyPriceID), tier.Code); err != nil {
return err
}
return nil
@ -1336,11 +1353,11 @@ func (a *Manager) TierByStripePrice(priceID string) (*Tier, error) {
func (a *Manager) readTier(rows *sql.Rows) (*Tier, error) {
var id, code, name string
var stripeMonthlyPriceID, stripeYearlyPriceID sql.NullString
var messagesLimit, messagesExpiryDuration, emailsLimit, reservationsLimit, attachmentFileSizeLimit, attachmentTotalSizeLimit, attachmentExpiryDuration, attachmentBandwidthLimit sql.NullInt64
var messagesLimit, messagesExpiryDuration, emailsLimit, smsLimit, callsLimit, reservationsLimit, attachmentFileSizeLimit, attachmentTotalSizeLimit, attachmentExpiryDuration, attachmentBandwidthLimit sql.NullInt64
if !rows.Next() {
return nil, ErrTierNotFound
}
if err := rows.Scan(&id, &code, &name, &messagesLimit, &messagesExpiryDuration, &emailsLimit, &reservationsLimit, &attachmentFileSizeLimit, &attachmentTotalSizeLimit, &attachmentExpiryDuration, &attachmentBandwidthLimit, &stripeMonthlyPriceID, &stripeYearlyPriceID); err != nil {
if err := rows.Scan(&id, &code, &name, &messagesLimit, &messagesExpiryDuration, &emailsLimit, &smsLimit, &callsLimit, &reservationsLimit, &attachmentFileSizeLimit, &attachmentTotalSizeLimit, &attachmentExpiryDuration, &attachmentBandwidthLimit, &stripeMonthlyPriceID, &stripeYearlyPriceID); err != nil {
return nil, err
} else if err := rows.Err(); err != nil {
return nil, err
@ -1353,6 +1370,8 @@ func (a *Manager) readTier(rows *sql.Rows) (*Tier, error) {
MessageLimit: messagesLimit.Int64,
MessageExpiryDuration: time.Duration(messagesExpiryDuration.Int64) * time.Second,
EmailLimit: emailsLimit.Int64,
SMSLimit: smsLimit.Int64,
CallLimit: callsLimit.Int64,
ReservationLimit: reservationsLimit.Int64,
AttachmentFileSizeLimit: attachmentFileSizeLimit.Int64,
AttachmentTotalSizeLimit: attachmentTotalSizeLimit.Int64,
@ -1495,6 +1514,22 @@ func migrateFrom2(db *sql.DB) error {
return tx.Commit()
}
func migrateFrom3(db *sql.DB) error {
log.Tag(tag).Info("Migrating user database schema: from 3 to 4")
tx, err := db.Begin()
if err != nil {
return err
}
defer tx.Rollback()
if _, err := tx.Exec(migrate3To4UpdateQueries); err != nil {
return err
}
if _, err := tx.Exec(updateSchemaVersion, 4); err != nil {
return err
}
return tx.Commit()
}
func nullString(s string) sql.NullString {
if s == "" {
return sql.NullString{}

View file

@ -86,6 +86,8 @@ type Tier struct {
MessageLimit int64 // Daily message limit
MessageExpiryDuration time.Duration // Cache duration for messages
EmailLimit int64 // Daily email limit
SMSLimit int64 // Daily SMS limit
CallLimit int64 // Daily phone call limit
ReservationLimit int64 // Number of topic reservations allowed by user
AttachmentFileSizeLimit int64 // Max file size per file (bytes)
AttachmentTotalSizeLimit int64 // Total file size for all files of this user (bytes)
@ -131,6 +133,8 @@ type NotificationPrefs struct {
type Stats struct {
Messages int64
Emails int64
SMS int64
Calls int64
}
// Billing is a struct holding a user's billing information