From e3b39f670ffd63815a22ec8e0521776925daf5f1 Mon Sep 17 00:00:00 2001 From: binwiederhier Date: Mon, 6 Feb 2023 22:38:22 -0500 Subject: [PATCH] WIP tier CLI --- cmd/access.go | 6 +- cmd/access_test.go | 8 +- cmd/tier.go | 289 +++++++++++++++++++++++++++++++++ server/server.go | 1 - server/server_account_test.go | 12 +- server/server_payments_test.go | 24 +-- server/server_test.go | 12 +- user/manager.go | 30 +++- user/manager_test.go | 6 +- util/util.go | 14 ++ 10 files changed, 367 insertions(+), 35 deletions(-) create mode 100644 cmd/tier.go diff --git a/cmd/access.go b/cmd/access.go index 40c84f2b..87f01d11 100644 --- a/cmd/access.go +++ b/cmd/access.go @@ -189,7 +189,11 @@ func showUsers(c *cli.Context, manager *user.Manager, users []*user.User) error if err != nil { return err } - fmt.Fprintf(c.App.ErrWriter, "user %s (%s)\n", u.Name, u.Role) + tier := "none" + if u.Tier != nil { + tier = u.Tier.Name + } + fmt.Fprintf(c.App.ErrWriter, "user %s (role: %s, tier: %s)\n", u.Name, u.Role, tier) if u.Role == user.RoleAdmin { fmt.Fprintf(c.App.ErrWriter, "- read-write access to all topics (admin role)\n") } else if len(grants) > 0 { diff --git a/cmd/access_test.go b/cmd/access_test.go index 6e3c5ba3..6582fab0 100644 --- a/cmd/access_test.go +++ b/cmd/access_test.go @@ -15,7 +15,7 @@ func TestCLI_Access_Show(t *testing.T) { app, _, _, stderr := newTestApp() require.Nil(t, runAccessCommand(app, conf)) - require.Contains(t, stderr.String(), "user * (anonymous)\n- no topic-specific permissions\n- no access to any (other) topics (server config)") + require.Contains(t, stderr.String(), "user * (role: anonymous, tier: none)\n- no topic-specific permissions\n- no access to any (other) topics (server config)") } func TestCLI_Access_Grant_And_Publish(t *testing.T) { @@ -32,12 +32,12 @@ func TestCLI_Access_Grant_And_Publish(t *testing.T) { app, _, _, stderr := newTestApp() require.Nil(t, runAccessCommand(app, conf)) - expected := `user phil (admin) + expected := `user phil (role: admin, tier: none) - read-write access to all topics (admin role) -user ben (user) +user ben (role: user, tier: none) - read-write access to topic announcements - read-only access to topic sometopic -user * (anonymous) +user * (role: anonymous, tier: none) - read-only access to topic announcements - no access to any (other) topics (server config) ` diff --git a/cmd/tier.go b/cmd/tier.go new file mode 100644 index 00000000..4bddebf6 --- /dev/null +++ b/cmd/tier.go @@ -0,0 +1,289 @@ +//go:build !noserver + +package cmd + +import ( + "errors" + "fmt" + "github.com/urfave/cli/v2" + "heckel.io/ntfy/user" + "heckel.io/ntfy/util" + "time" +) + +func init() { + commands = append(commands, cmdTier) +} + +const ( + defaultMessageLimit = 5000 + defaultMessageExpiryDuration = 12 * time.Hour + defaultEmailLimit = 20 + defaultReservationLimit = 3 + defaultAttachmentFileSizeLimit = "15M" + defaultAttachmentTotalSizeLimit = "100M" + defaultAttachmentExpiryDuration = 6 * time.Hour + defaultAttachmentBandwidthLimit = "1G" +) + +var ( + flagsTier = append([]cli.Flag{}, flagsUser...) +) + +var cmdTier = &cli.Command{ + Name: "tier", + Usage: "Manage/show tiers", + UsageText: "ntfy tier [list|add|remove] ...", + Flags: flagsTier, + Before: initConfigFileInputSourceFunc("config", flagsUser, initLogFunc), + Category: categoryServer, + Subcommands: []*cli.Command{ + { + Name: "add", + Aliases: []string{"a"}, + Usage: "Adds a new tier", + UsageText: "ntfy tier add [OPTIONS] CODE", + Action: execTierAdd, + Flags: []cli.Flag{ + &cli.StringFlag{Name: "name", Usage: "tier name"}, + &cli.Int64Flag{Name: "message-limit", Value: defaultMessageLimit, Usage: "daily message limit"}, + &cli.DurationFlag{Name: "message-expiry-duration", Value: defaultMessageExpiryDuration, Usage: "duration after which messages are deleted"}, + &cli.Int64Flag{Name: "email-limit", Value: defaultEmailLimit, Usage: "daily email limit"}, + &cli.Int64Flag{Name: "reservation-limit", Value: defaultReservationLimit, Usage: "topic reservation limit"}, + &cli.StringFlag{Name: "attachment-file-size-limit", Value: defaultAttachmentFileSizeLimit, Usage: "per-attachment file size limit"}, + &cli.StringFlag{Name: "attachment-total-size-limit", Value: defaultAttachmentTotalSizeLimit, Usage: "total size limit of attachments for the user"}, + &cli.DurationFlag{Name: "attachment-expiry-duration", Value: defaultAttachmentExpiryDuration, Usage: "duration after which attachments are deleted"}, + &cli.StringFlag{Name: "attachment-bandwidth-limit", Value: defaultAttachmentBandwidthLimit, Usage: "daily bandwidth limit for attachment uploads/downloads"}, + &cli.StringFlag{Name: "stripe-price-id", Usage: "Stripe price ID for paid tiers (e.g. price_12345)"}, + }, + Description: ` +FIXME +`, + }, + { + Name: "change", + Aliases: []string{"ch"}, + Usage: "Change a tier", + UsageText: "ntfy tier change [OPTIONS] CODE", + Action: execTierChange, + Flags: []cli.Flag{ + &cli.StringFlag{Name: "name", Usage: "tier name"}, + &cli.Int64Flag{Name: "message-limit", Usage: "daily message limit"}, + &cli.DurationFlag{Name: "message-expiry-duration", Usage: "duration after which messages are deleted"}, + &cli.Int64Flag{Name: "email-limit", Usage: "daily email limit"}, + &cli.Int64Flag{Name: "reservation-limit", Usage: "topic reservation limit"}, + &cli.StringFlag{Name: "attachment-file-size-limit", Usage: "per-attachment file size limit"}, + &cli.StringFlag{Name: "attachment-total-size-limit", Usage: "total size limit of attachments for the user"}, + &cli.DurationFlag{Name: "attachment-expiry-duration", Usage: "duration after which attachments are deleted"}, + &cli.StringFlag{Name: "attachment-bandwidth-limit", Usage: "daily bandwidth limit for attachment uploads/downloads"}, + &cli.StringFlag{Name: "stripe-price-id", Usage: "Stripe price ID for paid tiers (e.g. price_12345)"}, + }, + Description: ` +FIXME +`, + }, + { + Name: "remove", + Aliases: []string{"del", "rm"}, + Usage: "Removes a tier", + UsageText: "ntfy tier remove CODE", + Action: execTierDel, + Description: ` +FIXME +`, + }, + { + Name: "list", + Aliases: []string{"l"}, + Usage: "Shows a list of tiers", + Action: execTierList, + Description: ` +FIXME +`, + }, + }, + Description: `Manage tier of the ntfy server. + +The command allows you to add/remove/change tier in the ntfy user database. Tiers are used +to grant users higher limits based on their tier. + +This is a server-only command. It directly manages the user.db as defined in the server config +file server.yml. The command only works if 'auth-file' is properly defined. Please also refer +to the related command 'ntfy access'. + +FIXME + +`, +} + +func execTierAdd(c *cli.Context) error { + code := c.Args().Get(0) + if code == "" { + return errors.New("tier code expected, type 'ntfy tier add --help' for help") + } else if !user.AllowedTier(code) { + return errors.New("tier code must consist only of numbers and letters") + } + manager, err := createUserManager(c) + if err != nil { + return err + } + if tier, _ := manager.Tier(code); tier != nil { + return fmt.Errorf("tier %s already exists", code) + } + name := c.String("name") + if name == "" { + name = code + } + attachmentFileSizeLimit, err := util.ParseSize(c.String("attachment-file-size-limit")) + if err != nil { + return err + } + attachmentTotalSizeLimit, err := util.ParseSize(c.String("attachment-total-size-limit")) + if err != nil { + return err + } + attachmentBandwidthLimit, err := util.ParseSize(c.String("attachment-bandwidth-limit")) + if err != nil { + return err + } + tier := &user.Tier{ + ID: "", // Generated + Code: code, + Name: name, + MessageLimit: c.Int64("message-limit"), + MessageExpiryDuration: c.Duration("message-expiry-duration"), + EmailLimit: c.Int64("email-limit"), + ReservationLimit: c.Int64("reservation-limit"), + AttachmentFileSizeLimit: attachmentFileSizeLimit, + AttachmentTotalSizeLimit: attachmentTotalSizeLimit, + AttachmentExpiryDuration: c.Duration("attachment-expiry-duration"), + AttachmentBandwidthLimit: attachmentBandwidthLimit, + StripePriceID: c.String("stripe-price-id"), + } + if err := manager.AddTier(tier); err != nil { + return err + } + tier, err = manager.Tier(code) + if err != nil { + return err + } + fmt.Fprintf(c.App.ErrWriter, "tier added\n\n") + printTier(c, tier) + return nil +} + +func execTierChange(c *cli.Context) error { + code := c.Args().Get(0) + if code == "" { + return errors.New("tier code expected, type 'ntfy tier change --help' for help") + } else if !user.AllowedTier(code) { + return errors.New("tier code must consist only of numbers and letters") + } + manager, err := createUserManager(c) + if err != nil { + return err + } + tier, err := manager.Tier(code) + if err == user.ErrTierNotFound { + return fmt.Errorf("tier %s does not exist", code) + } else if err != nil { + return err + } + if c.IsSet("name") { + tier.Name = c.String("name") + } + if c.IsSet("message-limit") { + tier.MessageLimit = c.Int64("message-limit") + } + if c.IsSet("message-expiry-duration") { + tier.MessageExpiryDuration = c.Duration("message-expiry-duration") + } + if c.IsSet("email-limit") { + tier.EmailLimit = c.Int64("email-limit") + } + if c.IsSet("reservation-limit") { + tier.ReservationLimit = c.Int64("reservation-limit") + } + if c.IsSet("attachment-file-size-limit") { + tier.AttachmentFileSizeLimit, err = util.ParseSize(c.String("attachment-file-size-limit")) + if err != nil { + return err + } + } + if c.IsSet("attachment-total-size-limit") { + tier.AttachmentTotalSizeLimit, err = util.ParseSize(c.String("attachment-total-size-limit")) + if err != nil { + return err + } + } + if c.IsSet("attachment-expiry-duration") { + tier.AttachmentExpiryDuration = c.Duration("attachment-expiry-duration") + } + if c.IsSet("attachment-bandwidth-limit") { + tier.AttachmentBandwidthLimit, err = util.ParseSize(c.String("attachment-bandwidth-limit")) + if err != nil { + return err + } + } + if c.IsSet("stripe-price-id") { + tier.StripePriceID = c.String("stripe-price-id") + } + if err := manager.UpdateTier(tier); err != nil { + return err + } + fmt.Fprintf(c.App.ErrWriter, "tier updated\n\n") + printTier(c, tier) + return nil +} + +func execTierDel(c *cli.Context) error { + code := c.Args().Get(0) + if code == "" { + return errors.New("tier code expected, type 'ntfy tier del --help' for help") + } + manager, err := createUserManager(c) + if err != nil { + return err + } + if _, err := manager.Tier(code); err == user.ErrTierNotFound { + return fmt.Errorf("tier %s does not exist", code) + } + if err := manager.RemoveTier(code); err != nil { + return err + } + fmt.Fprintf(c.App.ErrWriter, "tier %s removed\n", code) + return nil +} + +func execTierList(c *cli.Context) error { + manager, err := createUserManager(c) + if err != nil { + return err + } + tiers, err := manager.Tiers() + if err != nil { + return err + } + for _, tier := range tiers { + printTier(c, tier) + } + return nil +} + +func printTier(c *cli.Context, tier *user.Tier) { + stripePriceID := tier.StripePriceID + if stripePriceID == "" { + stripePriceID = "(none)" + } + fmt.Fprintf(c.App.ErrWriter, "tier %s (id: %s)\n", tier.Code, tier.ID) + fmt.Fprintf(c.App.ErrWriter, "- Name: %s\n", tier.Name) + fmt.Fprintf(c.App.ErrWriter, "- Message limit: %d\n", tier.MessageLimit) + fmt.Fprintf(c.App.ErrWriter, "- Message expiry duration: %s (%d seconds)\n", tier.MessageExpiryDuration.String(), int64(tier.MessageExpiryDuration.Seconds())) + fmt.Fprintf(c.App.ErrWriter, "- Email limit: %d\n", tier.EmailLimit) + fmt.Fprintf(c.App.ErrWriter, "- Reservation limit: %d\n", tier.ReservationLimit) + fmt.Fprintf(c.App.ErrWriter, "- Attachment file size limit: %s\n", util.FormatSize(tier.AttachmentFileSizeLimit)) + fmt.Fprintf(c.App.ErrWriter, "- Attachment total size limit: %s\n", util.FormatSize(tier.AttachmentTotalSizeLimit)) + fmt.Fprintf(c.App.ErrWriter, "- Attachment expiry duration: %s (%d seconds)\n", tier.AttachmentExpiryDuration.String(), int64(tier.AttachmentExpiryDuration.Seconds())) + fmt.Fprintf(c.App.ErrWriter, "- Attachment daily bandwidth limit: %s\n", util.FormatSize(tier.AttachmentBandwidthLimit)) + fmt.Fprintf(c.App.ErrWriter, "- Stripe price: %s\n", stripePriceID) +} diff --git a/server/server.go b/server/server.go index 3a5b5ffb..fe7abe0a 100644 --- a/server/server.go +++ b/server/server.go @@ -38,7 +38,6 @@ import ( - HIGH Account limit creation triggers when account is taken! - HIGH Docs - HIGH CLI "ntfy tier [add|list|delete]" -- HIGH CLI "ntfy user" should show tier - HIGH Self-review - MEDIUM: Test for expiring messages after reservation removal - MEDIUM: Test new token endpoints & never-expiring token diff --git a/server/server_account_test.go b/server/server_account_test.go index 2febd64e..dabe76b2 100644 --- a/server/server_account_test.go +++ b/server/server_account_test.go @@ -437,7 +437,7 @@ func TestAccount_Reservation_AddAdminSuccess(t *testing.T) { s := newTestServer(t, conf) // A user, an admin, and a reservation walk into a bar - require.Nil(t, s.userManager.CreateTier(&user.Tier{ + require.Nil(t, s.userManager.AddTier(&user.Tier{ Code: "pro", ReservationLimit: 2, })) @@ -493,7 +493,7 @@ func TestAccount_Reservation_AddRemoveUserWithTierSuccess(t *testing.T) { require.Equal(t, 200, rr.Code) // Create a tier - require.Nil(t, s.userManager.CreateTier(&user.Tier{ + require.Nil(t, s.userManager.AddTier(&user.Tier{ Code: "pro", MessageLimit: 123, MessageExpiryDuration: 86400 * time.Second, @@ -575,7 +575,7 @@ func TestAccount_Reservation_PublishByAnonymousFails(t *testing.T) { rr := request(t, s, "POST", "/v1/account", `{"username":"phil", "password":"mypass"}`, nil) require.Equal(t, 200, rr.Code) - require.Nil(t, s.userManager.CreateTier(&user.Tier{ + require.Nil(t, s.userManager.AddTier(&user.Tier{ Code: "pro", MessageLimit: 20, ReservationLimit: 2, @@ -610,7 +610,7 @@ func TestAccount_Reservation_Add_Kills_Other_Subscribers(t *testing.T) { rr := request(t, s, "POST", "/v1/account", `{"username":"phil", "password":"mypass"}`, nil) require.Equal(t, 200, rr.Code) - require.Nil(t, s.userManager.CreateTier(&user.Tier{ + require.Nil(t, s.userManager.AddTier(&user.Tier{ Code: "pro", MessageLimit: 20, ReservationLimit: 2, @@ -689,11 +689,11 @@ func TestAccount_Persist_UserStats_After_Tier_Change(t *testing.T) { // Create user with tier require.Nil(t, s.userManager.AddUser("phil", "phil", user.RoleUser)) - require.Nil(t, s.userManager.CreateTier(&user.Tier{ + require.Nil(t, s.userManager.AddTier(&user.Tier{ Code: "starter", MessageLimit: 10, })) - require.Nil(t, s.userManager.CreateTier(&user.Tier{ + require.Nil(t, s.userManager.AddTier(&user.Tier{ Code: "pro", MessageLimit: 20, })) diff --git a/server/server_payments_test.go b/server/server_payments_test.go index c1903812..7e2f0054 100644 --- a/server/server_payments_test.go +++ b/server/server_payments_test.go @@ -42,12 +42,12 @@ func TestPayments_Tiers(t *testing.T) { }, nil) // Create tiers - require.Nil(t, s.userManager.CreateTier(&user.Tier{ + require.Nil(t, s.userManager.AddTier(&user.Tier{ ID: "ti_1", Code: "admin", Name: "Admin", })) - require.Nil(t, s.userManager.CreateTier(&user.Tier{ + require.Nil(t, s.userManager.AddTier(&user.Tier{ ID: "ti_123", Code: "pro", Name: "Pro", @@ -60,7 +60,7 @@ func TestPayments_Tiers(t *testing.T) { AttachmentExpiryDuration: time.Minute, StripePriceID: "price_123", })) - require.Nil(t, s.userManager.CreateTier(&user.Tier{ + require.Nil(t, s.userManager.AddTier(&user.Tier{ ID: "ti_444", Code: "business", Name: "Business", @@ -135,7 +135,7 @@ func TestPayments_SubscriptionCreate_NotAStripeCustomer_Success(t *testing.T) { Return(&stripe.CheckoutSession{URL: "https://billing.stripe.com/abc/def"}, nil) // Create tier and user - require.Nil(t, s.userManager.CreateTier(&user.Tier{ + require.Nil(t, s.userManager.AddTier(&user.Tier{ ID: "ti_123", Code: "pro", StripePriceID: "price_123", @@ -171,7 +171,7 @@ func TestPayments_SubscriptionCreate_StripeCustomer_Success(t *testing.T) { Return(&stripe.CheckoutSession{URL: "https://billing.stripe.com/abc/def"}, nil) // Create tier and user - require.Nil(t, s.userManager.CreateTier(&user.Tier{ + require.Nil(t, s.userManager.AddTier(&user.Tier{ ID: "ti_123", Code: "pro", StripePriceID: "price_123", @@ -213,7 +213,7 @@ func TestPayments_AccountDelete_Cancels_Subscription(t *testing.T) { Return(&stripe.Subscription{}, nil) // Create tier and user - require.Nil(t, s.userManager.CreateTier(&user.Tier{ + require.Nil(t, s.userManager.AddTier(&user.Tier{ ID: "ti_123", Code: "pro", StripePriceID: "price_123", @@ -264,7 +264,7 @@ func TestPayments_Checkout_Success_And_Increase_Rate_Limits_Reset_Visitor(t *tes s.stripe = stripeMock // Create a user with a Stripe subscription and 3 reservations - require.Nil(t, s.userManager.CreateTier(&user.Tier{ + require.Nil(t, s.userManager.AddTier(&user.Tier{ ID: "ti_123", Code: "starter", StripePriceID: "price_1234", @@ -420,7 +420,7 @@ func TestPayments_Webhook_Subscription_Updated_Downgrade_From_PastDue_To_Active( Return(jsonToStripeEvent(t, subscriptionUpdatedEventJSON), nil) // Create a user with a Stripe subscription and 3 reservations - require.Nil(t, s.userManager.CreateTier(&user.Tier{ + require.Nil(t, s.userManager.AddTier(&user.Tier{ ID: "ti_1", Code: "starter", StripePriceID: "price_1234", // ! @@ -432,7 +432,7 @@ func TestPayments_Webhook_Subscription_Updated_Downgrade_From_PastDue_To_Active( AttachmentTotalSizeLimit: 1000000, AttachmentBandwidthLimit: 1000000, })) - require.Nil(t, s.userManager.CreateTier(&user.Tier{ + require.Nil(t, s.userManager.AddTier(&user.Tier{ ID: "ti_2", Code: "pro", StripePriceID: "price_1111", // ! @@ -545,7 +545,7 @@ func TestPayments_Webhook_Subscription_Deleted(t *testing.T) { Return(jsonToStripeEvent(t, subscriptionDeletedEventJSON), nil) // Create a user with a Stripe subscription and 3 reservations - require.Nil(t, s.userManager.CreateTier(&user.Tier{ + require.Nil(t, s.userManager.AddTier(&user.Tier{ ID: "ti_1", Code: "pro", StripePriceID: "price_1234", @@ -626,12 +626,12 @@ func TestPayments_Subscription_Update_Different_Tier(t *testing.T) { Return(&stripe.Subscription{}, nil) // Create tier and user - require.Nil(t, s.userManager.CreateTier(&user.Tier{ + require.Nil(t, s.userManager.AddTier(&user.Tier{ ID: "ti_123", Code: "pro", StripePriceID: "price_123", })) - require.Nil(t, s.userManager.CreateTier(&user.Tier{ + require.Nil(t, s.userManager.AddTier(&user.Tier{ ID: "ti_456", Code: "business", StripePriceID: "price_456", diff --git a/server/server_test.go b/server/server_test.go index 4119e483..711d45d8 100644 --- a/server/server_test.go +++ b/server/server_test.go @@ -761,7 +761,7 @@ func TestServer_StatsResetter(t *testing.T) { go s.runStatsResetter() // Create user with tier (tieruser) and user without tier (phil) - require.Nil(t, s.userManager.CreateTier(&user.Tier{ + require.Nil(t, s.userManager.AddTier(&user.Tier{ Code: "test", MessageLimit: 5, MessageExpiryDuration: -5 * time.Second, // Second, what a hack! @@ -898,7 +898,7 @@ func TestServer_DailyMessageQuotaFromDatabase(t *testing.T) { s := newTestServer(t, c) // Create user, and update it with some message and email stats - require.Nil(t, s.userManager.CreateTier(&user.Tier{ + require.Nil(t, s.userManager.AddTier(&user.Tier{ Code: "test", })) require.Nil(t, s.userManager.AddUser("phil", "phil", user.RoleUser)) @@ -1275,7 +1275,7 @@ func TestServer_PublishWithTierBasedMessageLimitAndExpiry(t *testing.T) { s := newTestServer(t, c) // Create tier with certain limits - require.Nil(t, s.userManager.CreateTier(&user.Tier{ + require.Nil(t, s.userManager.AddTier(&user.Tier{ Code: "test", MessageLimit: 5, MessageExpiryDuration: -5 * time.Second, // Second, what a hack! @@ -1504,7 +1504,7 @@ func TestServer_PublishAttachmentWithTierBasedExpiry(t *testing.T) { // Create tier with certain limits sevenDays := time.Duration(604800) * time.Second - require.Nil(t, s.userManager.CreateTier(&user.Tier{ + require.Nil(t, s.userManager.AddTier(&user.Tier{ Code: "test", MessageLimit: 10, MessageExpiryDuration: sevenDays, @@ -1549,7 +1549,7 @@ func TestServer_PublishAttachmentWithTierBasedBandwidthLimit(t *testing.T) { s := newTestServer(t, c) // Create tier with certain limits - require.Nil(t, s.userManager.CreateTier(&user.Tier{ + require.Nil(t, s.userManager.AddTier(&user.Tier{ Code: "test", MessageLimit: 10, MessageExpiryDuration: time.Hour, @@ -1588,7 +1588,7 @@ func TestServer_PublishAttachmentWithTierBasedLimits(t *testing.T) { s := newTestServer(t, c) // Create tier with certain limits - require.Nil(t, s.userManager.CreateTier(&user.Tier{ + require.Nil(t, s.userManager.AddTier(&user.Tier{ Code: "test", MessageLimit: 100, AttachmentFileSizeLimit: 50_000, diff --git a/user/manager.go b/user/manager.go index a83974e6..c9883774 100644 --- a/user/manager.go +++ b/user/manager.go @@ -248,6 +248,11 @@ const ( INSERT INTO tier (id, code, name, messages_limit, messages_expiry_duration, emails_limit, reservations_limit, attachment_file_size_limit, attachment_total_size_limit, attachment_expiry_duration, attachment_bandwidth_limit, stripe_price_id) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?) ` + updateTierQuery = ` + UPDATE tier + SET name = ?, messages_limit = ?, messages_expiry_duration = ?, emails_limit = ?, reservations_limit = ?, attachment_file_size_limit = ?, attachment_total_size_limit = ?, attachment_expiry_duration = ?, attachment_bandwidth_limit = ?, stripe_price_id = ? + WHERE code = ? + ` selectTiersQuery = ` SELECT id, code, name, messages_limit, messages_expiry_duration, emails_limit, reservations_limit, attachment_file_size_limit, attachment_total_size_limit, attachment_expiry_duration, attachment_bandwidth_limit, stripe_price_id FROM tier @@ -264,6 +269,7 @@ const ( ` updateUserTierQuery = `UPDATE user SET tier_id = (SELECT id FROM tier WHERE code = ?) WHERE user = ?` deleteUserTierQuery = `UPDATE user SET tier_id = null WHERE user = ?` + deleteTierQuery = `DELETE FROM tier WHERE code = ?` updateBillingQuery = ` UPDATE user @@ -1116,8 +1122,8 @@ func (a *Manager) DefaultAccess() Permission { return a.defaultAccess } -// CreateTier creates a new tier in the database -func (a *Manager) CreateTier(tier *Tier) error { +// AddTier creates a new tier in the database +func (a *Manager) AddTier(tier *Tier) error { if tier.ID == "" { tier.ID = util.RandomStringPrefix(tierIDPrefix, tierIDLength) } @@ -1127,6 +1133,26 @@ func (a *Manager) CreateTier(tier *Tier) error { return nil } +// UpdateTier updates a tier's properties in the database +func (a *Manager) UpdateTier(tier *Tier) error { + if _, err := a.db.Exec(updateTierQuery, tier.Name, tier.MessageLimit, int64(tier.MessageExpiryDuration.Seconds()), tier.EmailLimit, tier.ReservationLimit, tier.AttachmentFileSizeLimit, tier.AttachmentTotalSizeLimit, int64(tier.AttachmentExpiryDuration.Seconds()), tier.AttachmentBandwidthLimit, nullString(tier.StripePriceID), tier.Code); err != nil { + return err + } + return nil +} + +// RemoveTier deletes the tier with the given code +func (a *Manager) RemoveTier(code string) error { + if !AllowedTier(code) { + return ErrInvalidArgument + } + // This fails if any user has this tier + if _, err := a.db.Exec(deleteTierQuery, code); err != nil { + return err + } + return nil +} + // ChangeBilling updates a user's billing fields, namely the Stripe customer ID, and subscription information func (a *Manager) ChangeBilling(username string, billing *Billing) error { if _, err := a.db.Exec(updateBillingQuery, nullString(billing.StripeCustomerID), nullString(billing.StripeSubscriptionID), nullString(string(billing.StripeSubscriptionStatus)), nullInt64(billing.StripeSubscriptionPaidUntil.Unix()), nullInt64(billing.StripeSubscriptionCancelAt.Unix()), username); err != nil { diff --git a/user/manager_test.go b/user/manager_test.go index c3bdb145..8e22ac23 100644 --- a/user/manager_test.go +++ b/user/manager_test.go @@ -333,7 +333,7 @@ func TestManager_Reservations(t *testing.T) { func TestManager_ChangeRoleFromTierUserToAdmin(t *testing.T) { a := newTestManager(t, PermissionDenyAll) - require.Nil(t, a.CreateTier(&Tier{ + require.Nil(t, a.AddTier(&Tier{ Code: "pro", Name: "ntfy Pro", StripePriceID: "price123", @@ -629,7 +629,7 @@ func TestManager_Tier_Create(t *testing.T) { a := newTestManager(t, PermissionDenyAll) // Create tier and user - require.Nil(t, a.CreateTier(&Tier{ + require.Nil(t, a.AddTier(&Tier{ Code: "pro", Name: "Pro", MessageLimit: 123, @@ -670,7 +670,7 @@ func TestManager_Tier_Create(t *testing.T) { func TestAccount_Tier_Create_With_ID(t *testing.T) { a := newTestManager(t, PermissionDenyAll) - require.Nil(t, a.CreateTier(&Tier{ + require.Nil(t, a.AddTier(&Tier{ ID: "ti_123", Code: "pro", })) diff --git a/util/util.go b/util/util.go index 20baed56..3e0c1064 100644 --- a/util/util.go +++ b/util/util.go @@ -222,6 +222,20 @@ func ParseSize(s string) (int64, error) { } } +// FormatSize formats bytes into a human-readable notation, e.g. 2.1 MB +func FormatSize(b int64) string { + const unit = 1024 + if b < unit { + return fmt.Sprintf("%d bytes", b) + } + div, exp := int64(unit), 0 + for n := b / unit; n >= unit; n /= unit { + div *= unit + exp++ + } + return fmt.Sprintf("%.1f %ciB", float64(b)/float64(div), "KMGTPE"[exp]) +} + // ReadPassword will read a password from STDIN. If the terminal supports it, it will not print the // input characters to the screen. If not, it'll just read using normal readline semantics (useful for testing). func ReadPassword(in io.Reader) ([]byte, error) {