Add bandwidth limit to tier; fix display name sync issues

This commit is contained in:
binwiederhier 2023-01-25 10:05:54 -05:00
parent 1771cb3fdb
commit 236254d907
13 changed files with 119 additions and 51 deletions

View file

@ -285,7 +285,7 @@ func execServe(c *cli.Context) error {
conf.TotalTopicLimit = totalTopicLimit conf.TotalTopicLimit = totalTopicLimit
conf.VisitorSubscriptionLimit = visitorSubscriptionLimit conf.VisitorSubscriptionLimit = visitorSubscriptionLimit
conf.VisitorAttachmentTotalSizeLimit = visitorAttachmentTotalSizeLimit conf.VisitorAttachmentTotalSizeLimit = visitorAttachmentTotalSizeLimit
conf.VisitorAttachmentDailyBandwidthLimit = int(visitorAttachmentDailyBandwidthLimit) conf.VisitorAttachmentDailyBandwidthLimit = visitorAttachmentDailyBandwidthLimit
conf.VisitorRequestLimitBurst = visitorRequestLimitBurst conf.VisitorRequestLimitBurst = visitorRequestLimitBurst
conf.VisitorRequestLimitReplenish = visitorRequestLimitReplenish conf.VisitorRequestLimitReplenish = visitorRequestLimitReplenish
conf.VisitorRequestExemptIPAddrs = visitorRequestLimitExemptIPs conf.VisitorRequestExemptIPAddrs = visitorRequestLimitExemptIPs

View file

@ -101,7 +101,7 @@ type Config struct {
TotalAttachmentSizeLimit int64 TotalAttachmentSizeLimit int64
VisitorSubscriptionLimit int VisitorSubscriptionLimit int
VisitorAttachmentTotalSizeLimit int64 VisitorAttachmentTotalSizeLimit int64
VisitorAttachmentDailyBandwidthLimit int VisitorAttachmentDailyBandwidthLimit int64
VisitorRequestLimitBurst int VisitorRequestLimitBurst int
VisitorRequestLimitReplenish time.Duration VisitorRequestLimitReplenish time.Duration
VisitorRequestExemptIPAddrs []netip.Prefix VisitorRequestExemptIPAddrs []netip.Prefix

View file

@ -40,7 +40,6 @@ TODO
- HIGH Rate limiting: dailyLimitToRate is wrong? + TESTS - HIGH Rate limiting: dailyLimitToRate is wrong? + TESTS
- HIGH Rate limiting: Sensitive endpoints (account/login/change-password/...) - HIGH Rate limiting: Sensitive endpoints (account/login/change-password/...)
- HIGH Rate limiting: Bandwidth limit must be in tier + TESTS
- MEDIUM: Races with v.user (see publishSyncEventAsync test) - MEDIUM: Races with v.user (see publishSyncEventAsync test)
- MEDIUM: Reservation (UI): Show "This topic is reserved" error message when trying to reserve a reserved topic (Thorben) - MEDIUM: Reservation (UI): Show "This topic is reserved" error message when trying to reserve a reserved topic (Thorben)
- MEDIUM: Reservation (UI): Ask for confirmation when removing reservation (deadcade) - MEDIUM: Reservation (UI): Ask for confirmation when removing reservation (deadcade)
@ -866,7 +865,6 @@ func (s *Server) handleBodyAsAttachment(r *http.Request, v *visitor, m *message,
util.NewFixedLimiter(vinfo.Limits.AttachmentFileSizeLimit), util.NewFixedLimiter(vinfo.Limits.AttachmentFileSizeLimit),
util.NewFixedLimiter(vinfo.Stats.AttachmentTotalSizeRemaining), util.NewFixedLimiter(vinfo.Stats.AttachmentTotalSizeRemaining),
} }
fmt.Printf("limiters = %#v\nv = %#v\n", limiters, v)
m.Attachment.Size, err = s.fileCache.Write(m.ID, body, limiters...) m.Attachment.Size, err = s.fileCache.Write(m.ID, body, limiters...)
if err == util.ErrLimitReached { if err == util.ErrLimitReached {
return errHTTPEntityTooLargeAttachment return errHTTPEntityTooLargeAttachment

View file

@ -11,6 +11,7 @@ import (
const ( const (
subscriptionIDLength = 16 subscriptionIDLength = 16
subscriptionIDPrefix = "su_"
syncTopicAccountSyncEvent = "sync" syncTopicAccountSyncEvent = "sync"
) )
@ -55,6 +56,7 @@ func (s *Server) handleAccountGet(w http.ResponseWriter, _ *http.Request, v *vis
AttachmentTotalSize: limits.AttachmentTotalSizeLimit, AttachmentTotalSize: limits.AttachmentTotalSizeLimit,
AttachmentFileSize: limits.AttachmentFileSizeLimit, AttachmentFileSize: limits.AttachmentFileSizeLimit,
AttachmentExpiryDuration: int64(limits.AttachmentExpiryDuration.Seconds()), AttachmentExpiryDuration: int64(limits.AttachmentExpiryDuration.Seconds()),
AttachmentBandwidth: limits.AttachmentBandwidthLimit,
}, },
Stats: &apiAccountStats{ Stats: &apiAccountStats{
Messages: stats.Messages, Messages: stats.Messages,
@ -249,7 +251,7 @@ func (s *Server) handleAccountSubscriptionAdd(w http.ResponseWriter, r *http.Req
} }
} }
if newSubscription.ID == "" { if newSubscription.ID == "" {
newSubscription.ID = util.RandomString(subscriptionIDLength) newSubscription.ID = util.RandomStringPrefix(subscriptionIDPrefix, subscriptionIDLength)
v.user.Prefs.Subscriptions = append(v.user.Prefs.Subscriptions, newSubscription) v.user.Prefs.Subscriptions = append(v.user.Prefs.Subscriptions, newSubscription)
if err := s.userManager.ChangeSettings(v.user); err != nil { if err := s.userManager.ChangeSettings(v.user); err != nil {
return err return err

View file

@ -153,9 +153,9 @@ func TestAccount_ChangeSettings(t *testing.T) {
require.Equal(t, 200, rr.Code) require.Equal(t, 200, rr.Code)
account, _ := util.UnmarshalJSON[apiAccountResponse](io.NopCloser(rr.Body)) account, _ := util.UnmarshalJSON[apiAccountResponse](io.NopCloser(rr.Body))
require.Equal(t, "de", account.Language) require.Equal(t, "de", account.Language)
require.Equal(t, 86400, account.Notification.DeleteAfter) require.Equal(t, util.Int(86400), account.Notification.DeleteAfter)
require.Equal(t, "juntos", account.Notification.Sound) require.Equal(t, util.String("juntos"), account.Notification.Sound)
require.Equal(t, 0, account.Notification.MinPriority) // Not set require.Nil(t, account.Notification.MinPriority) // Not set
} }
func TestAccount_Subscription_AddUpdateDelete(t *testing.T) { func TestAccount_Subscription_AddUpdateDelete(t *testing.T) {
@ -176,7 +176,7 @@ func TestAccount_Subscription_AddUpdateDelete(t *testing.T) {
require.NotEmpty(t, account.Subscriptions[0].ID) require.NotEmpty(t, account.Subscriptions[0].ID)
require.Equal(t, "http://abc.com", account.Subscriptions[0].BaseURL) require.Equal(t, "http://abc.com", account.Subscriptions[0].BaseURL)
require.Equal(t, "def", account.Subscriptions[0].Topic) require.Equal(t, "def", account.Subscriptions[0].Topic)
require.Equal(t, "", account.Subscriptions[0].DisplayName) require.Nil(t, account.Subscriptions[0].DisplayName)
subscriptionID := account.Subscriptions[0].ID subscriptionID := account.Subscriptions[0].ID
rr = request(t, s, "PATCH", "/v1/account/subscription/"+subscriptionID, `{"display_name": "ding dong"}`, map[string]string{ rr = request(t, s, "PATCH", "/v1/account/subscription/"+subscriptionID, `{"display_name": "ding dong"}`, map[string]string{
@ -193,7 +193,7 @@ func TestAccount_Subscription_AddUpdateDelete(t *testing.T) {
require.Equal(t, subscriptionID, account.Subscriptions[0].ID) require.Equal(t, subscriptionID, account.Subscriptions[0].ID)
require.Equal(t, "http://abc.com", account.Subscriptions[0].BaseURL) require.Equal(t, "http://abc.com", account.Subscriptions[0].BaseURL)
require.Equal(t, "def", account.Subscriptions[0].Topic) require.Equal(t, "def", account.Subscriptions[0].Topic)
require.Equal(t, "ding dong", account.Subscriptions[0].DisplayName) require.Equal(t, util.String("ding dong"), account.Subscriptions[0].DisplayName)
rr = request(t, s, "DELETE", "/v1/account/subscription/"+subscriptionID, "", map[string]string{ rr = request(t, s, "DELETE", "/v1/account/subscription/"+subscriptionID, "", map[string]string{
"Authorization": util.BasicAuth("phil", "phil"), "Authorization": util.BasicAuth("phil", "phil"),
@ -402,6 +402,7 @@ func TestAccount_Reservation_AddRemoveUserWithTierSuccess(t *testing.T) {
AttachmentFileSizeLimit: 1231231, AttachmentFileSizeLimit: 1231231,
AttachmentTotalSizeLimit: 123123, AttachmentTotalSizeLimit: 123123,
AttachmentExpiryDuration: 10800 * time.Second, AttachmentExpiryDuration: 10800 * time.Second,
AttachmentBandwidthLimit: 21474836480,
})) }))
require.Nil(t, s.userManager.ChangeTier("phil", "pro")) require.Nil(t, s.userManager.ChangeTier("phil", "pro"))
@ -442,6 +443,7 @@ func TestAccount_Reservation_AddRemoveUserWithTierSuccess(t *testing.T) {
require.Equal(t, int64(1231231), account.Limits.AttachmentFileSize) require.Equal(t, int64(1231231), account.Limits.AttachmentFileSize)
require.Equal(t, int64(123123), account.Limits.AttachmentTotalSize) require.Equal(t, int64(123123), account.Limits.AttachmentTotalSize)
require.Equal(t, int64(10800), account.Limits.AttachmentExpiryDuration) require.Equal(t, int64(10800), account.Limits.AttachmentExpiryDuration)
require.Equal(t, int64(21474836480), account.Limits.AttachmentBandwidth)
require.Equal(t, 2, len(account.Reservations)) require.Equal(t, 2, len(account.Reservations))
require.Equal(t, "another", account.Reservations[0].Topic) require.Equal(t, "another", account.Reservations[0].Topic)
require.Equal(t, "write-only", account.Reservations[0].Everyone) require.Equal(t, "write-only", account.Reservations[0].Everyone)

View file

@ -265,6 +265,7 @@ func TestPayments_Webhook_Subscription_Updated_Downgrade_From_PastDue_To_Active(
AttachmentExpiryDuration: time.Hour, AttachmentExpiryDuration: time.Hour,
AttachmentFileSizeLimit: 1000000, AttachmentFileSizeLimit: 1000000,
AttachmentTotalSizeLimit: 1000000, AttachmentTotalSizeLimit: 1000000,
AttachmentBandwidthLimit: 1000000,
})) }))
require.Nil(t, s.userManager.CreateTier(&user.Tier{ require.Nil(t, s.userManager.CreateTier(&user.Tier{
Code: "pro", Code: "pro",
@ -275,6 +276,7 @@ func TestPayments_Webhook_Subscription_Updated_Downgrade_From_PastDue_To_Active(
AttachmentExpiryDuration: time.Hour, AttachmentExpiryDuration: time.Hour,
AttachmentFileSizeLimit: 1000000, AttachmentFileSizeLimit: 1000000,
AttachmentTotalSizeLimit: 1000000, AttachmentTotalSizeLimit: 1000000,
AttachmentBandwidthLimit: 1000000,
})) }))
require.Nil(t, s.userManager.AddUser("phil", "phil", user.RoleUser)) require.Nil(t, s.userManager.AddUser("phil", "phil", user.RoleUser))
require.Nil(t, s.userManager.ChangeTier("phil", "pro")) require.Nil(t, s.userManager.ChangeTier("phil", "pro"))

View file

@ -1368,6 +1368,7 @@ func TestServer_PublishAttachmentWithTierBasedExpiry(t *testing.T) {
AttachmentFileSizeLimit: 50_000, AttachmentFileSizeLimit: 50_000,
AttachmentTotalSizeLimit: 200_000, AttachmentTotalSizeLimit: 200_000,
AttachmentExpiryDuration: sevenDays, // 7 days AttachmentExpiryDuration: sevenDays, // 7 days
AttachmentBandwidthLimit: 100000,
})) }))
require.Nil(t, s.userManager.AddUser("phil", "phil", user.RoleUser)) require.Nil(t, s.userManager.AddUser("phil", "phil", user.RoleUser))
require.Nil(t, s.userManager.ChangeTier("phil", "test")) require.Nil(t, s.userManager.ChangeTier("phil", "test"))
@ -1376,6 +1377,7 @@ func TestServer_PublishAttachmentWithTierBasedExpiry(t *testing.T) {
response := request(t, s, "PUT", "/mytopic", content, map[string]string{ response := request(t, s, "PUT", "/mytopic", content, map[string]string{
"Authorization": util.BasicAuth("phil", "phil"), "Authorization": util.BasicAuth("phil", "phil"),
}) })
require.Equal(t, 200, response.Code)
msg := toMessage(t, response.Body.String()) msg := toMessage(t, response.Body.String())
require.Contains(t, msg.Attachment.URL, "http://127.0.0.1:12345/file/") require.Contains(t, msg.Attachment.URL, "http://127.0.0.1:12345/file/")
require.True(t, msg.Attachment.Expires > time.Now().Add(sevenDays-30*time.Second).Unix()) require.True(t, msg.Attachment.Expires > time.Now().Add(sevenDays-30*time.Second).Unix())
@ -1396,6 +1398,46 @@ func TestServer_PublishAttachmentWithTierBasedExpiry(t *testing.T) {
require.Equal(t, 200, response.Code) require.Equal(t, 200, response.Code)
} }
func TestServer_PublishAttachmentWithTierBasedBandwidthLimit(t *testing.T) {
content := util.RandomString(5000) // > 4096
c := newTestConfigWithAuthFile(t)
s := newTestServer(t, c)
// Create tier with certain limits
require.Nil(t, s.userManager.CreateTier(&user.Tier{
Code: "test",
MessagesLimit: 10,
MessagesExpiryDuration: time.Hour,
AttachmentFileSizeLimit: 50_000,
AttachmentTotalSizeLimit: 200_000,
AttachmentExpiryDuration: time.Hour,
AttachmentBandwidthLimit: 14000, // < 3x5000 bytes -> enough for one upload, one download
}))
require.Nil(t, s.userManager.AddUser("phil", "phil", user.RoleUser))
require.Nil(t, s.userManager.ChangeTier("phil", "test"))
// Publish and make sure we can retrieve it
rr := request(t, s, "PUT", "/mytopic", content, map[string]string{
"Authorization": util.BasicAuth("phil", "phil"),
})
require.Equal(t, 200, rr.Code)
msg := toMessage(t, rr.Body.String())
// Retrieve it (first time succeeds)
rr = request(t, s, "GET", "/file/"+msg.ID, content, map[string]string{
"Authorization": util.BasicAuth("phil", "phil"),
})
require.Equal(t, 200, rr.Code)
require.Equal(t, content, rr.Body.String())
// Retrieve it AGAIN (fails, due to bandwidth limit)
rr = request(t, s, "GET", "/file/"+msg.ID, content, map[string]string{
"Authorization": util.BasicAuth("phil", "phil"),
})
require.Equal(t, 429, rr.Code)
}
func TestServer_PublishAttachmentWithTierBasedLimits(t *testing.T) { func TestServer_PublishAttachmentWithTierBasedLimits(t *testing.T) {
smallFile := util.RandomString(20_000) smallFile := util.RandomString(20_000)
largeFile := util.RandomString(50_000) largeFile := util.RandomString(50_000)
@ -1412,6 +1454,7 @@ func TestServer_PublishAttachmentWithTierBasedLimits(t *testing.T) {
AttachmentFileSizeLimit: 50_000, AttachmentFileSizeLimit: 50_000,
AttachmentTotalSizeLimit: 200_000, AttachmentTotalSizeLimit: 200_000,
AttachmentExpiryDuration: 30 * time.Second, AttachmentExpiryDuration: 30 * time.Second,
AttachmentBandwidthLimit: 1000000,
})) }))
require.Nil(t, s.userManager.AddUser("phil", "phil", user.RoleUser)) require.Nil(t, s.userManager.AddUser("phil", "phil", user.RoleUser))
require.Nil(t, s.userManager.ChangeTier("phil", "test")) require.Nil(t, s.userManager.ChangeTier("phil", "test"))

View file

@ -246,7 +246,7 @@ type apiAccountTier struct {
} }
type apiAccountLimits struct { type apiAccountLimits struct {
Basis string `json:"basis,omitempty"` // "ip", "role" or "tier" Basis string `json:"basis,omitempty"` // "ip" or "tier"
Messages int64 `json:"messages"` Messages int64 `json:"messages"`
MessagesExpiryDuration int64 `json:"messages_expiry_duration"` MessagesExpiryDuration int64 `json:"messages_expiry_duration"`
Emails int64 `json:"emails"` Emails int64 `json:"emails"`
@ -254,6 +254,7 @@ type apiAccountLimits struct {
AttachmentTotalSize int64 `json:"attachment_total_size"` AttachmentTotalSize int64 `json:"attachment_total_size"`
AttachmentFileSize int64 `json:"attachment_file_size"` AttachmentFileSize int64 `json:"attachment_file_size"`
AttachmentExpiryDuration int64 `json:"attachment_expiry_duration"` AttachmentExpiryDuration int64 `json:"attachment_expiry_duration"`
AttachmentBandwidth int64 `json:"attachment_bandwidth"`
} }
type apiAccountStats struct { type apiAccountStats struct {

View file

@ -31,9 +31,9 @@ var (
type visitor struct { type visitor struct {
config *Config config *Config
messageCache *messageCache messageCache *messageCache
userManager *user.Manager // May be nil! userManager *user.Manager // May be nil
ip netip.Addr ip netip.Addr // Visitor IP address
user *user.User user *user.User // Only set if authenticated user, otherwise nil
messages int64 // Number of messages sent, reset every day messages int64 // Number of messages sent, reset every day
emails int64 // Number of emails sent, reset every day emails int64 // Number of emails sent, reset every day
requestLimiter *rate.Limiter // Rate limiter for (almost) all requests (including messages) requestLimiter *rate.Limiter // Rate limiter for (almost) all requests (including messages)
@ -61,6 +61,7 @@ type visitorLimits struct {
AttachmentTotalSizeLimit int64 AttachmentTotalSizeLimit int64
AttachmentFileSizeLimit int64 AttachmentFileSizeLimit int64
AttachmentExpiryDuration time.Duration AttachmentExpiryDuration time.Duration
AttachmentBandwidthLimit int64
} }
type visitorStats struct { type visitorStats struct {
@ -84,7 +85,7 @@ const (
) )
func newVisitor(conf *Config, messageCache *messageCache, userManager *user.Manager, ip netip.Addr, user *user.User) *visitor { func newVisitor(conf *Config, messageCache *messageCache, userManager *user.Manager, ip netip.Addr, user *user.User) *visitor {
var messagesLimiter util.Limiter var messagesLimiter, attachmentBandwidthLimiter util.Limiter
var requestLimiter, emailsLimiter, accountLimiter *rate.Limiter var requestLimiter, emailsLimiter, accountLimiter *rate.Limiter
var messages, emails int64 var messages, emails int64
if user != nil { if user != nil {
@ -97,9 +98,11 @@ func newVisitor(conf *Config, messageCache *messageCache, userManager *user.Mana
requestLimiter = rate.NewLimiter(dailyLimitToRate(user.Tier.MessagesLimit), conf.VisitorRequestLimitBurst) requestLimiter = rate.NewLimiter(dailyLimitToRate(user.Tier.MessagesLimit), conf.VisitorRequestLimitBurst)
messagesLimiter = util.NewFixedLimiter(user.Tier.MessagesLimit) messagesLimiter = util.NewFixedLimiter(user.Tier.MessagesLimit)
emailsLimiter = rate.NewLimiter(dailyLimitToRate(user.Tier.EmailsLimit), conf.VisitorEmailLimitBurst) emailsLimiter = rate.NewLimiter(dailyLimitToRate(user.Tier.EmailsLimit), conf.VisitorEmailLimitBurst)
attachmentBandwidthLimiter = util.NewBytesLimiter(int(user.Tier.AttachmentBandwidthLimit), 24*time.Hour)
} else { } else {
requestLimiter = rate.NewLimiter(rate.Every(conf.VisitorRequestLimitReplenish), conf.VisitorRequestLimitBurst) requestLimiter = rate.NewLimiter(rate.Every(conf.VisitorRequestLimitReplenish), conf.VisitorRequestLimitBurst)
emailsLimiter = rate.NewLimiter(rate.Every(conf.VisitorEmailLimitReplenish), conf.VisitorEmailLimitBurst) emailsLimiter = rate.NewLimiter(rate.Every(conf.VisitorEmailLimitReplenish), conf.VisitorEmailLimitBurst)
attachmentBandwidthLimiter = util.NewBytesLimiter(int(conf.VisitorAttachmentDailyBandwidthLimit), 24*time.Hour)
} }
return &visitor{ return &visitor{
config: conf, config: conf,
@ -113,7 +116,7 @@ func newVisitor(conf *Config, messageCache *messageCache, userManager *user.Mana
messagesLimiter: messagesLimiter, // May be nil messagesLimiter: messagesLimiter, // May be nil
emailsLimiter: emailsLimiter, emailsLimiter: emailsLimiter,
subscriptionLimiter: util.NewFixedLimiter(int64(conf.VisitorSubscriptionLimit)), subscriptionLimiter: util.NewFixedLimiter(int64(conf.VisitorSubscriptionLimit)),
bandwidthLimiter: util.NewBytesLimiter(conf.VisitorAttachmentDailyBandwidthLimit, 24*time.Hour), bandwidthLimiter: attachmentBandwidthLimiter,
accountLimiter: accountLimiter, // May be nil accountLimiter: accountLimiter, // May be nil
firebase: time.Unix(0, 0), firebase: time.Unix(0, 0),
seen: time.Now(), seen: time.Now(),
@ -259,6 +262,7 @@ func (v *visitor) Limits() *visitorLimits {
limits.AttachmentTotalSizeLimit = v.user.Tier.AttachmentTotalSizeLimit limits.AttachmentTotalSizeLimit = v.user.Tier.AttachmentTotalSizeLimit
limits.AttachmentFileSizeLimit = v.user.Tier.AttachmentFileSizeLimit limits.AttachmentFileSizeLimit = v.user.Tier.AttachmentFileSizeLimit
limits.AttachmentExpiryDuration = v.user.Tier.AttachmentExpiryDuration limits.AttachmentExpiryDuration = v.user.Tier.AttachmentExpiryDuration
limits.AttachmentBandwidthLimit = v.user.Tier.AttachmentBandwidthLimit
} }
return limits return limits
} }
@ -327,5 +331,6 @@ func defaultVisitorLimits(conf *Config) *visitorLimits {
AttachmentTotalSizeLimit: conf.VisitorAttachmentTotalSizeLimit, AttachmentTotalSizeLimit: conf.VisitorAttachmentTotalSizeLimit,
AttachmentFileSizeLimit: conf.AttachmentFileSizeLimit, AttachmentFileSizeLimit: conf.AttachmentFileSizeLimit,
AttachmentExpiryDuration: conf.AttachmentExpiryDuration, AttachmentExpiryDuration: conf.AttachmentExpiryDuration,
AttachmentBandwidthLimit: conf.VisitorAttachmentDailyBandwidthLimit,
} }
} }

View file

@ -52,6 +52,7 @@ const (
attachment_file_size_limit INT NOT NULL, attachment_file_size_limit INT NOT NULL,
attachment_total_size_limit INT NOT NULL, attachment_total_size_limit INT NOT NULL,
attachment_expiry_duration INT NOT NULL, attachment_expiry_duration INT NOT NULL,
attachment_bandwidth_limit INT NOT NULL,
stripe_price_id TEXT stripe_price_id TEXT
); );
CREATE UNIQUE INDEX idx_tier_code ON tier (code); CREATE UNIQUE INDEX idx_tier_code ON tier (code);
@ -109,26 +110,26 @@ const (
` `
selectUserByIDQuery = ` selectUserByIDQuery = `
SELECT u.id, u.user, u.pass, u.role, u.prefs, u.sync_topic, u.stats_messages, u.stats_emails, u.stripe_customer_id, u.stripe_subscription_id, u.stripe_subscription_status, u.stripe_subscription_paid_until, u.stripe_subscription_cancel_at, deleted, t.code, t.name, t.messages_limit, t.messages_expiry_duration, t.emails_limit, t.reservations_limit, t.attachment_file_size_limit, t.attachment_total_size_limit, t.attachment_expiry_duration, t.stripe_price_id SELECT u.id, u.user, u.pass, u.role, u.prefs, u.sync_topic, u.stats_messages, u.stats_emails, u.stripe_customer_id, u.stripe_subscription_id, u.stripe_subscription_status, u.stripe_subscription_paid_until, u.stripe_subscription_cancel_at, deleted, t.code, t.name, t.messages_limit, t.messages_expiry_duration, t.emails_limit, t.reservations_limit, t.attachment_file_size_limit, t.attachment_total_size_limit, t.attachment_expiry_duration, t.attachment_bandwidth_limit, t.stripe_price_id
FROM user u FROM user u
LEFT JOIN tier t on t.id = u.tier_id LEFT JOIN tier t on t.id = u.tier_id
WHERE u.id = ? WHERE u.id = ?
` `
selectUserByNameQuery = ` selectUserByNameQuery = `
SELECT u.id, u.user, u.pass, u.role, u.prefs, u.sync_topic, u.stats_messages, u.stats_emails, u.stripe_customer_id, u.stripe_subscription_id, u.stripe_subscription_status, u.stripe_subscription_paid_until, u.stripe_subscription_cancel_at, deleted, t.code, t.name, t.messages_limit, t.messages_expiry_duration, t.emails_limit, t.reservations_limit, t.attachment_file_size_limit, t.attachment_total_size_limit, t.attachment_expiry_duration, t.stripe_price_id SELECT u.id, u.user, u.pass, u.role, u.prefs, u.sync_topic, u.stats_messages, u.stats_emails, u.stripe_customer_id, u.stripe_subscription_id, u.stripe_subscription_status, u.stripe_subscription_paid_until, u.stripe_subscription_cancel_at, deleted, t.code, t.name, t.messages_limit, t.messages_expiry_duration, t.emails_limit, t.reservations_limit, t.attachment_file_size_limit, t.attachment_total_size_limit, t.attachment_expiry_duration, t.attachment_bandwidth_limit, t.stripe_price_id
FROM user u FROM user u
LEFT JOIN tier t on t.id = u.tier_id LEFT JOIN tier t on t.id = u.tier_id
WHERE user = ? WHERE user = ?
` `
selectUserByTokenQuery = ` selectUserByTokenQuery = `
SELECT u.id, u.user, u.pass, u.role, u.prefs, u.sync_topic, u.stats_messages, u.stats_emails, u.stripe_customer_id, u.stripe_subscription_id, u.stripe_subscription_status, u.stripe_subscription_paid_until, u.stripe_subscription_cancel_at, deleted, t.code, t.name, t.messages_limit, t.messages_expiry_duration, t.emails_limit, t.reservations_limit, t.attachment_file_size_limit, t.attachment_total_size_limit, t.attachment_expiry_duration, t.stripe_price_id SELECT u.id, u.user, u.pass, u.role, u.prefs, u.sync_topic, u.stats_messages, u.stats_emails, u.stripe_customer_id, u.stripe_subscription_id, u.stripe_subscription_status, u.stripe_subscription_paid_until, u.stripe_subscription_cancel_at, deleted, t.code, t.name, t.messages_limit, t.messages_expiry_duration, t.emails_limit, t.reservations_limit, t.attachment_file_size_limit, t.attachment_total_size_limit, t.attachment_expiry_duration, t.attachment_bandwidth_limit, t.stripe_price_id
FROM user u FROM user u
JOIN user_token t on u.id = t.user_id JOIN user_token t on u.id = t.user_id
LEFT JOIN tier t on t.id = u.tier_id LEFT JOIN tier t on t.id = u.tier_id
WHERE t.token = ? AND t.expires >= ? WHERE t.token = ? AND t.expires >= ?
` `
selectUserByStripeCustomerIDQuery = ` selectUserByStripeCustomerIDQuery = `
SELECT u.id, u.user, u.pass, u.role, u.prefs, u.sync_topic, u.stats_messages, u.stats_emails, u.stripe_customer_id, u.stripe_subscription_id, u.stripe_subscription_status, u.stripe_subscription_paid_until, u.stripe_subscription_cancel_at, deleted, t.code, t.name, t.messages_limit, t.messages_expiry_duration, t.emails_limit, t.reservations_limit, t.attachment_file_size_limit, t.attachment_total_size_limit, t.attachment_expiry_duration, t.stripe_price_id SELECT u.id, u.user, u.pass, u.role, u.prefs, u.sync_topic, u.stats_messages, u.stats_emails, u.stripe_customer_id, u.stripe_subscription_id, u.stripe_subscription_status, u.stripe_subscription_paid_until, u.stripe_subscription_cancel_at, deleted, t.code, t.name, t.messages_limit, t.messages_expiry_duration, t.emails_limit, t.reservations_limit, t.attachment_file_size_limit, t.attachment_total_size_limit, t.attachment_expiry_duration, t.attachment_bandwidth_limit, t.stripe_price_id
FROM user u FROM user u
LEFT JOIN tier t on t.id = u.tier_id LEFT JOIN tier t on t.id = u.tier_id
WHERE u.stripe_customer_id = ? WHERE u.stripe_customer_id = ?
@ -232,20 +233,20 @@ const (
` `
insertTierQuery = ` insertTierQuery = `
INSERT INTO tier (id, code, name, messages_limit, messages_expiry_duration, emails_limit, reservations_limit, attachment_file_size_limit, attachment_total_size_limit, attachment_expiry_duration, stripe_price_id) INSERT INTO tier (id, code, name, messages_limit, messages_expiry_duration, emails_limit, reservations_limit, attachment_file_size_limit, attachment_total_size_limit, attachment_expiry_duration, attachment_bandwidth_limit, stripe_price_id)
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
` `
selectTiersQuery = ` selectTiersQuery = `
SELECT id, code, name, messages_limit, messages_expiry_duration, emails_limit, reservations_limit, attachment_file_size_limit, attachment_total_size_limit, attachment_expiry_duration, stripe_price_id SELECT id, code, name, messages_limit, messages_expiry_duration, emails_limit, reservations_limit, attachment_file_size_limit, attachment_total_size_limit, attachment_expiry_duration, attachment_bandwidth_limit, stripe_price_id
FROM tier FROM tier
` `
selectTierByCodeQuery = ` selectTierByCodeQuery = `
SELECT id, code, name, messages_limit, messages_expiry_duration, emails_limit, reservations_limit, attachment_file_size_limit, attachment_total_size_limit, attachment_expiry_duration, stripe_price_id SELECT id, code, name, messages_limit, messages_expiry_duration, emails_limit, reservations_limit, attachment_file_size_limit, attachment_total_size_limit, attachment_expiry_duration, attachment_bandwidth_limit, stripe_price_id
FROM tier FROM tier
WHERE code = ? WHERE code = ?
` `
selectTierByPriceIDQuery = ` selectTierByPriceIDQuery = `
SELECT id, code, name, messages_limit, messages_expiry_duration, emails_limit, reservations_limit, attachment_file_size_limit, attachment_total_size_limit, attachment_expiry_duration, stripe_price_id SELECT id, code, name, messages_limit, messages_expiry_duration, emails_limit, reservations_limit, attachment_file_size_limit, attachment_total_size_limit, attachment_expiry_duration, attachment_bandwidth_limit, stripe_price_id
FROM tier FROM tier
WHERE stripe_price_id = ? WHERE stripe_price_id = ?
` `
@ -670,11 +671,11 @@ func (a *Manager) readUser(rows *sql.Rows) (*User, error) {
var id, username, hash, role, prefs, syncTopic string var id, username, hash, role, prefs, syncTopic string
var stripeCustomerID, stripeSubscriptionID, stripeSubscriptionStatus, stripePriceID, tierCode, tierName sql.NullString var stripeCustomerID, stripeSubscriptionID, stripeSubscriptionStatus, stripePriceID, tierCode, tierName sql.NullString
var messages, emails int64 var messages, emails int64
var messagesLimit, messagesExpiryDuration, emailsLimit, reservationsLimit, attachmentFileSizeLimit, attachmentTotalSizeLimit, attachmentExpiryDuration, stripeSubscriptionPaidUntil, stripeSubscriptionCancelAt, deleted sql.NullInt64 var messagesLimit, messagesExpiryDuration, emailsLimit, reservationsLimit, attachmentFileSizeLimit, attachmentTotalSizeLimit, attachmentExpiryDuration, attachmentBandwidthLimit, stripeSubscriptionPaidUntil, stripeSubscriptionCancelAt, deleted sql.NullInt64
if !rows.Next() { if !rows.Next() {
return nil, ErrUserNotFound return nil, ErrUserNotFound
} }
if err := rows.Scan(&id, &username, &hash, &role, &prefs, &syncTopic, &messages, &emails, &stripeCustomerID, &stripeSubscriptionID, &stripeSubscriptionStatus, &stripeSubscriptionPaidUntil, &stripeSubscriptionCancelAt, &deleted, &tierCode, &tierName, &messagesLimit, &messagesExpiryDuration, &emailsLimit, &reservationsLimit, &attachmentFileSizeLimit, &attachmentTotalSizeLimit, &attachmentExpiryDuration, &stripePriceID); err != nil { if err := rows.Scan(&id, &username, &hash, &role, &prefs, &syncTopic, &messages, &emails, &stripeCustomerID, &stripeSubscriptionID, &stripeSubscriptionStatus, &stripeSubscriptionPaidUntil, &stripeSubscriptionCancelAt, &deleted, &tierCode, &tierName, &messagesLimit, &messagesExpiryDuration, &emailsLimit, &reservationsLimit, &attachmentFileSizeLimit, &attachmentTotalSizeLimit, &attachmentExpiryDuration, &attachmentBandwidthLimit, &stripePriceID); err != nil {
return nil, err return nil, err
} else if err := rows.Err(); err != nil { } else if err := rows.Err(); err != nil {
return nil, err return nil, err
@ -714,6 +715,7 @@ func (a *Manager) readUser(rows *sql.Rows) (*User, error) {
AttachmentFileSizeLimit: attachmentFileSizeLimit.Int64, AttachmentFileSizeLimit: attachmentFileSizeLimit.Int64,
AttachmentTotalSizeLimit: attachmentTotalSizeLimit.Int64, AttachmentTotalSizeLimit: attachmentTotalSizeLimit.Int64,
AttachmentExpiryDuration: time.Duration(attachmentExpiryDuration.Int64) * time.Second, AttachmentExpiryDuration: time.Duration(attachmentExpiryDuration.Int64) * time.Second,
AttachmentBandwidthLimit: attachmentBandwidthLimit.Int64,
StripePriceID: stripePriceID.String, // May be empty StripePriceID: stripePriceID.String, // May be empty
} }
} }
@ -994,7 +996,7 @@ func (a *Manager) DefaultAccess() Permission {
// CreateTier creates a new tier in the database // CreateTier creates a new tier in the database
func (a *Manager) CreateTier(tier *Tier) error { func (a *Manager) CreateTier(tier *Tier) error {
tierID := util.RandomStringPrefix(tierIDPrefix, tierIDLength) tierID := util.RandomStringPrefix(tierIDPrefix, tierIDLength)
if _, err := a.db.Exec(insertTierQuery, tierID, tier.Code, tier.Name, tier.MessagesLimit, int64(tier.MessagesExpiryDuration.Seconds()), tier.EmailsLimit, tier.ReservationsLimit, tier.AttachmentFileSizeLimit, tier.AttachmentTotalSizeLimit, int64(tier.AttachmentExpiryDuration.Seconds()), tier.StripePriceID); err != nil { if _, err := a.db.Exec(insertTierQuery, tierID, tier.Code, tier.Name, tier.MessagesLimit, int64(tier.MessagesExpiryDuration.Seconds()), tier.EmailsLimit, tier.ReservationsLimit, tier.AttachmentFileSizeLimit, tier.AttachmentTotalSizeLimit, int64(tier.AttachmentExpiryDuration.Seconds()), tier.AttachmentBandwidthLimit, tier.StripePriceID); err != nil {
return err return err
} }
return nil return nil
@ -1051,11 +1053,11 @@ func (a *Manager) TierByStripePrice(priceID string) (*Tier, error) {
func (a *Manager) readTier(rows *sql.Rows) (*Tier, error) { func (a *Manager) readTier(rows *sql.Rows) (*Tier, error) {
var id, code, name string var id, code, name string
var stripePriceID sql.NullString var stripePriceID sql.NullString
var messagesLimit, messagesExpiryDuration, emailsLimit, reservationsLimit, attachmentFileSizeLimit, attachmentTotalSizeLimit, attachmentExpiryDuration sql.NullInt64 var messagesLimit, messagesExpiryDuration, emailsLimit, reservationsLimit, attachmentFileSizeLimit, attachmentTotalSizeLimit, attachmentExpiryDuration, attachmentBandwidthLimit sql.NullInt64
if !rows.Next() { if !rows.Next() {
return nil, ErrTierNotFound return nil, ErrTierNotFound
} }
if err := rows.Scan(&id, &code, &name, &messagesLimit, &messagesExpiryDuration, &emailsLimit, &reservationsLimit, &attachmentFileSizeLimit, &attachmentTotalSizeLimit, &attachmentExpiryDuration, &stripePriceID); err != nil { if err := rows.Scan(&id, &code, &name, &messagesLimit, &messagesExpiryDuration, &emailsLimit, &reservationsLimit, &attachmentFileSizeLimit, &attachmentTotalSizeLimit, &attachmentExpiryDuration, &attachmentBandwidthLimit, &stripePriceID); err != nil {
return nil, err return nil, err
} else if err := rows.Err(); err != nil { } else if err := rows.Err(); err != nil {
return nil, err return nil, err
@ -1072,6 +1074,7 @@ func (a *Manager) readTier(rows *sql.Rows) (*Tier, error) {
AttachmentFileSizeLimit: attachmentFileSizeLimit.Int64, AttachmentFileSizeLimit: attachmentFileSizeLimit.Int64,
AttachmentTotalSizeLimit: attachmentTotalSizeLimit.Int64, AttachmentTotalSizeLimit: attachmentTotalSizeLimit.Int64,
AttachmentExpiryDuration: time.Duration(attachmentExpiryDuration.Int64) * time.Second, AttachmentExpiryDuration: time.Duration(attachmentExpiryDuration.Int64) * time.Second,
AttachmentBandwidthLimit: attachmentBandwidthLimit.Int64,
StripePriceID: stripePriceID.String, // May be empty StripePriceID: stripePriceID.String, // May be empty
}, nil }, nil
} }

View file

@ -3,6 +3,7 @@ package user
import ( import (
"database/sql" "database/sql"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
"heckel.io/ntfy/util"
"path/filepath" "path/filepath"
"strings" "strings"
"testing" "testing"
@ -583,21 +584,21 @@ func TestManager_ChangeSettings(t *testing.T) {
require.Nil(t, err) require.Nil(t, err)
require.Nil(t, u.Prefs.Subscriptions) require.Nil(t, u.Prefs.Subscriptions)
require.Nil(t, u.Prefs.Notification) require.Nil(t, u.Prefs.Notification)
require.Equal(t, "", u.Prefs.Language) require.Nil(t, u.Prefs.Language)
// Save with new settings // Save with new settings
u.Prefs = &Prefs{ u.Prefs = &Prefs{
Language: "de", Language: util.String("de"),
Notification: &NotificationPrefs{ Notification: &NotificationPrefs{
Sound: "ding", Sound: util.String("ding"),
MinPriority: 2, MinPriority: util.Int(2),
}, },
Subscriptions: []*Subscription{ Subscriptions: []*Subscription{
{ {
ID: "someID", ID: "someID",
BaseURL: "https://ntfy.sh", BaseURL: "https://ntfy.sh",
Topic: "mytopic", Topic: "mytopic",
DisplayName: "My Topic", DisplayName: util.String("My Topic"),
}, },
}, },
} }
@ -606,14 +607,14 @@ func TestManager_ChangeSettings(t *testing.T) {
// Read again // Read again
u, err = a.User("ben") u, err = a.User("ben")
require.Nil(t, err) require.Nil(t, err)
require.Equal(t, "de", u.Prefs.Language) require.Equal(t, util.String("de"), u.Prefs.Language)
require.Equal(t, "ding", u.Prefs.Notification.Sound) require.Equal(t, util.String("ding"), u.Prefs.Notification.Sound)
require.Equal(t, 2, u.Prefs.Notification.MinPriority) require.Equal(t, util.Int(2), u.Prefs.Notification.MinPriority)
require.Equal(t, 0, u.Prefs.Notification.DeleteAfter) require.Nil(t, u.Prefs.Notification.DeleteAfter)
require.Equal(t, "someID", u.Prefs.Subscriptions[0].ID) require.Equal(t, "someID", u.Prefs.Subscriptions[0].ID)
require.Equal(t, "https://ntfy.sh", u.Prefs.Subscriptions[0].BaseURL) require.Equal(t, "https://ntfy.sh", u.Prefs.Subscriptions[0].BaseURL)
require.Equal(t, "mytopic", u.Prefs.Subscriptions[0].Topic) require.Equal(t, "mytopic", u.Prefs.Subscriptions[0].Topic)
require.Equal(t, "My Topic", u.Prefs.Subscriptions[0].DisplayName) require.Equal(t, util.String("My Topic"), u.Prefs.Subscriptions[0].DisplayName)
} }
func TestSqliteCache_Migration_From1(t *testing.T) { func TestSqliteCache_Migration_From1(t *testing.T) {

View file

@ -50,17 +50,18 @@ type Prefs struct {
// Tier represents a user's account type, including its account limits // Tier represents a user's account type, including its account limits
type Tier struct { type Tier struct {
ID string ID string // Tier identifier (ti_...)
Code string Code string // Code of the tier
Name string Name string // Name of the tier
MessagesLimit int64 MessagesLimit int64 // Daily message limit
MessagesExpiryDuration time.Duration MessagesExpiryDuration time.Duration // Cache duration for messages
EmailsLimit int64 EmailsLimit int64 // Daily email limit
ReservationsLimit int64 ReservationsLimit int64 // Number of topic reservations allowed by user
AttachmentFileSizeLimit int64 AttachmentFileSizeLimit int64 // Max file size per file (bytes)
AttachmentTotalSizeLimit int64 AttachmentTotalSizeLimit int64 // Total file size for all files of this user (bytes)
AttachmentExpiryDuration time.Duration AttachmentExpiryDuration time.Duration // Duration after which attachments will be deleted
StripePriceID string AttachmentBandwidthLimit int64 // Daily bandwidth limit for the user
StripePriceID string // Price ID for paid tiers (price_...)
} }
// Subscription represents a user's topic subscription // Subscription represents a user's topic subscription

View file

@ -336,3 +336,13 @@ func Retry[T any](f func() (*T, error), after ...time.Duration) (t *T, err error
} }
return nil, err return nil, err
} }
// String turns a string into a pointer of a string
func String(v string) *string {
return &v
}
// Int turns a string into a pointer of an int
func Int(v int) *int {
return &v
}