From 83de87989407b7ec63b33d7e253e4ed6ff2d6650 Mon Sep 17 00:00:00 2001 From: binwiederhier Date: Mon, 16 Jan 2023 16:35:37 -0500 Subject: [PATCH] publishSyncEvent, Stripe endpoint changes --- cmd/serve.go | 22 ++-- server/config.go | 2 +- server/server.go | 72 +++-------- server/server.yml | 12 +- server/server_account.go | 43 ++++++- server/server_middleware.go | 63 ++++++++++ server/server_payments.go | 188 ++++++++++++++-------------- server/types.go | 6 +- user/manager.go | 27 ++-- web/public/static/langs/en.json | 6 +- web/src/app/AccountApi.js | 16 ++- web/src/components/Account.js | 143 ++++++++++++--------- web/src/components/Navigation.js | 14 ++- web/src/components/UpgradeDialog.js | 72 ++++++++--- 14 files changed, 424 insertions(+), 262 deletions(-) create mode 100644 server/server_middleware.go diff --git a/cmd/serve.go b/cmd/serve.go index 8fd97ae2..631ef38b 100644 --- a/cmd/serve.go +++ b/cmd/serve.go @@ -80,8 +80,8 @@ var flagsServe = append( altsrc.NewIntFlag(&cli.IntFlag{Name: "visitor-email-limit-burst", Aliases: []string{"visitor_email_limit_burst"}, EnvVars: []string{"NTFY_VISITOR_EMAIL_LIMIT_BURST"}, Value: server.DefaultVisitorEmailLimitBurst, Usage: "initial limit of e-mails per visitor"}), altsrc.NewDurationFlag(&cli.DurationFlag{Name: "visitor-email-limit-replenish", Aliases: []string{"visitor_email_limit_replenish"}, EnvVars: []string{"NTFY_VISITOR_EMAIL_LIMIT_REPLENISH"}, Value: server.DefaultVisitorEmailLimitReplenish, Usage: "interval at which burst limit is replenished (one per x)"}), altsrc.NewBoolFlag(&cli.BoolFlag{Name: "behind-proxy", Aliases: []string{"behind_proxy", "P"}, EnvVars: []string{"NTFY_BEHIND_PROXY"}, Value: false, Usage: "if set, use X-Forwarded-For header to determine visitor IP address (for rate limiting)"}), - altsrc.NewStringFlag(&cli.StringFlag{Name: "stripe-key", Aliases: []string{"stripe_key"}, EnvVars: []string{"NTFY_STRIPE_KEY"}, Value: "", Usage: "xxxxxxxxxxxxx"}), - altsrc.NewStringFlag(&cli.StringFlag{Name: "stripe-webhook-key", Aliases: []string{"stripe_webhook_key"}, EnvVars: []string{"NTFY_STRIPE_WEBHOOK_KEY"}, Value: "", Usage: "xxxxxxxxxxxx"}), + altsrc.NewStringFlag(&cli.StringFlag{Name: "stripe-secret-key", Aliases: []string{"stripe_secret_key"}, EnvVars: []string{"NTFY_STRIPE_SECRET_KEY"}, Value: "", Usage: "key used for the Stripe API communication, this enables payments"}), + altsrc.NewStringFlag(&cli.StringFlag{Name: "stripe-webhook-key", Aliases: []string{"stripe_webhook_key"}, EnvVars: []string{"NTFY_STRIPE_WEBHOOK_KEY"}, Value: "", Usage: "key required to validate the authenticity of incoming webhooks from Stripe"}), ) var cmdServe = &cli.Command{ @@ -153,7 +153,7 @@ func execServe(c *cli.Context) error { visitorEmailLimitBurst := c.Int("visitor-email-limit-burst") visitorEmailLimitReplenish := c.Duration("visitor-email-limit-replenish") behindProxy := c.Bool("behind-proxy") - stripeKey := c.String("stripe-key") + stripeSecretKey := c.String("stripe-secret-key") stripeWebhookKey := c.String("stripe-webhook-key") // Check values @@ -191,17 +191,17 @@ func execServe(c *cli.Context) error { return errors.New("if upstream-base-url is set, base-url must also be set") } else if upstreamBaseURL != "" && baseURL != "" && baseURL == upstreamBaseURL { return errors.New("base-url and upstream-base-url cannot be identical, you'll likely want to set upstream-base-url to https://ntfy.sh, see https://ntfy.sh/docs/config/#ios-instant-notifications") - } else if authFile == "" && (enableSignup || enableLogin || enableReservations || stripeKey != "") { - return errors.New("cannot set enable-signup, enable-login, enable-reserve-topics, or stripe-key if auth-file is not set") + } else if authFile == "" && (enableSignup || enableLogin || enableReservations || stripeSecretKey != "") { + return errors.New("cannot set enable-signup, enable-login, enable-reserve-topics, or stripe-secret-key if auth-file is not set") } else if enableSignup && !enableLogin { return errors.New("cannot set enable-signup without also setting enable-login") - } else if stripeKey != "" && (stripeWebhookKey == "" || baseURL == "") { - return errors.New("if stripe-key is set, stripe-webhook-key and base-url must also be set") + } else if stripeSecretKey != "" && (stripeWebhookKey == "" || baseURL == "") { + return errors.New("if stripe-secret-key is set, stripe-webhook-key and base-url must also be set") } webRootIsApp := webRoot == "app" enableWeb := webRoot != "disable" - enablePayments := stripeKey != "" + enablePayments := stripeSecretKey != "" // Default auth permissions authDefault, err := user.ParsePermission(authDefaultAccess) @@ -246,8 +246,8 @@ func execServe(c *cli.Context) error { } // Stripe things - if stripeKey != "" { - stripe.Key = stripeKey + if stripeSecretKey != "" { + stripe.Key = stripeSecretKey } // Run server @@ -293,7 +293,7 @@ func execServe(c *cli.Context) error { conf.VisitorEmailLimitBurst = visitorEmailLimitBurst conf.VisitorEmailLimitReplenish = visitorEmailLimitReplenish conf.BehindProxy = behindProxy - conf.StripeKey = stripeKey + conf.StripeSecretKey = stripeSecretKey conf.StripeWebhookKey = stripeWebhookKey conf.EnableWeb = enableWeb conf.EnableSignup = enableSignup diff --git a/server/config.go b/server/config.go index c63f2e37..c1e67f62 100644 --- a/server/config.go +++ b/server/config.go @@ -110,7 +110,7 @@ type Config struct { VisitorAccountCreateLimitReplenish time.Duration VisitorStatsResetTime time.Time // Time of the day at which to reset visitor stats BehindProxy bool - StripeKey string + StripeSecretKey string StripeWebhookKey string EnableWeb bool EnableSignup bool // Enable creation of accounts via API and UI diff --git a/server/server.go b/server/server.go index c8d94400..8f4370c5 100644 --- a/server/server.go +++ b/server/server.go @@ -40,12 +40,10 @@ import ( - send dunning emails when overdue - payment methods - unmarshal to stripe.Subscription instead of gjson - - Make ResetTier reset the stripe fields - delete subscription when account deleted - - remove tier.paid - add tier.visible - fix tier selection boxes - - account sync after switching tiers + - delete messages + reserved topics on ResetTier Limits & rate limiting: users without tier: should the stats be persisted? are they meaningful? @@ -360,7 +358,7 @@ func (s *Server) handleInternal(w http.ResponseWriter, r *http.Request, v *visit } else if r.Method == http.MethodGet && r.URL.Path == accountPath { return s.handleAccountGet(w, r, v) // Allowed by anonymous } else if r.Method == http.MethodDelete && r.URL.Path == accountPath { - return s.ensureUser(s.handleAccountDelete)(w, r, v) + return s.ensureUser(s.withAccountSync(s.handleAccountDelete))(w, r, v) } else if r.Method == http.MethodPost && r.URL.Path == accountPasswordPath { return s.ensureUser(s.handleAccountPasswordChange)(w, r, v) } else if r.Method == http.MethodPatch && r.URL.Path == accountTokenPath { @@ -368,27 +366,29 @@ func (s *Server) handleInternal(w http.ResponseWriter, r *http.Request, v *visit } else if r.Method == http.MethodDelete && r.URL.Path == accountTokenPath { return s.ensureUser(s.handleAccountTokenDelete)(w, r, v) } else if r.Method == http.MethodPatch && r.URL.Path == accountSettingsPath { - return s.ensureUser(s.handleAccountSettingsChange)(w, r, v) + return s.ensureUser(s.withAccountSync(s.handleAccountSettingsChange))(w, r, v) } else if r.Method == http.MethodPost && r.URL.Path == accountSubscriptionPath { - return s.ensureUser(s.handleAccountSubscriptionAdd)(w, r, v) + return s.ensureUser(s.withAccountSync(s.handleAccountSubscriptionAdd))(w, r, v) } else if r.Method == http.MethodPatch && accountSubscriptionSingleRegex.MatchString(r.URL.Path) { - return s.ensureUser(s.handleAccountSubscriptionChange)(w, r, v) + return s.ensureUser(s.withAccountSync(s.handleAccountSubscriptionChange))(w, r, v) } else if r.Method == http.MethodDelete && accountSubscriptionSingleRegex.MatchString(r.URL.Path) { - return s.ensureUser(s.handleAccountSubscriptionDelete)(w, r, v) + return s.ensureUser(s.withAccountSync(s.handleAccountSubscriptionDelete))(w, r, v) } else if r.Method == http.MethodPost && r.URL.Path == accountReservationPath { - return s.ensureUser(s.handleAccountReservationAdd)(w, r, v) + return s.ensureUser(s.withAccountSync(s.handleAccountReservationAdd))(w, r, v) } else if r.Method == http.MethodDelete && accountReservationSingleRegex.MatchString(r.URL.Path) { - return s.ensureUser(s.handleAccountReservationDelete)(w, r, v) + return s.ensureUser(s.withAccountSync(s.handleAccountReservationDelete))(w, r, v) } else if r.Method == http.MethodPost && r.URL.Path == accountBillingSubscriptionPath { - return s.ensureUser(s.handleAccountBillingSubscriptionChange)(w, r, v) - } else if r.Method == http.MethodDelete && r.URL.Path == accountBillingSubscriptionPath { - return s.ensureStripeCustomer(s.handleAccountBillingSubscriptionDelete)(w, r, v) + return s.ensurePaymentsEnabled(s.ensureUser(s.handleAccountBillingSubscriptionCreate))(w, r, v) // Account sync via incoming Stripe webhook } else if r.Method == http.MethodGet && accountBillingSubscriptionCheckoutSuccessRegex.MatchString(r.URL.Path) { - return s.ensureUserManager(s.handleAccountCheckoutSessionSuccessGet)(w, r, v) // No user context! + return s.ensurePaymentsEnabled(s.ensureUserManager(s.handleAccountBillingSubscriptionCreateSuccess))(w, r, v) // No user context! + } else if r.Method == http.MethodPut && r.URL.Path == accountBillingSubscriptionPath { + return s.ensurePaymentsEnabled(s.ensureUser(s.handleAccountBillingSubscriptionUpdate))(w, r, v) // Account sync via incoming Stripe webhook + } else if r.Method == http.MethodDelete && r.URL.Path == accountBillingSubscriptionPath { + return s.ensurePaymentsEnabled(s.ensureStripeCustomer(s.handleAccountBillingSubscriptionDelete))(w, r, v) // Account sync via incoming Stripe webhook } else if r.Method == http.MethodPost && r.URL.Path == accountBillingPortalPath { - return s.ensureStripeCustomer(s.handleAccountBillingPortalSessionCreate)(w, r, v) + return s.ensurePaymentsEnabled(s.ensureStripeCustomer(s.handleAccountBillingPortalSessionCreate))(w, r, v) } else if r.Method == http.MethodPost && r.URL.Path == accountBillingWebhookPath { - return s.ensureUserManager(s.handleAccountBillingWebhook)(w, r, v) + return s.ensurePaymentsEnabled(s.ensureUserManager(s.handleAccountBillingWebhook))(w, r, v) } else if r.Method == http.MethodGet && r.URL.Path == matrixPushPath { return s.handleMatrixDiscovery(w) } else if r.Method == http.MethodGet && staticRegex.MatchString(r.URL.Path) { @@ -1423,12 +1423,12 @@ func (s *Server) sendDelayedMessages() error { for _, m := range messages { var v *visitor if s.userManager != nil && m.User != "" { - user, err := s.userManager.User(m.User) + u, err := s.userManager.User(m.User) if err != nil { log.Warn("%s Error sending delayed message: %s", logMessagePrefix(v, m), err.Error()) continue } - v = s.visitorFromUser(user, m.Sender) + v = s.visitorFromUser(u, m.Sender) } else { v = s.visitorFromIP(m.Sender) } @@ -1475,42 +1475,6 @@ func (s *Server) limitRequests(next handleFunc) handleFunc { } } -func (s *Server) ensureWebEnabled(next handleFunc) handleFunc { - return func(w http.ResponseWriter, r *http.Request, v *visitor) error { - if !s.config.EnableWeb { - return errHTTPNotFound - } - return next(w, r, v) - } -} - -func (s *Server) ensureUserManager(next handleFunc) handleFunc { - return func(w http.ResponseWriter, r *http.Request, v *visitor) error { - if s.userManager == nil { - return errHTTPNotFound - } - return next(w, r, v) - } -} - -func (s *Server) ensureUser(next handleFunc) handleFunc { - return s.ensureUserManager(func(w http.ResponseWriter, r *http.Request, v *visitor) error { - if v.user == nil { - return errHTTPUnauthorized - } - return next(w, r, v) - }) -} - -func (s *Server) ensureStripeCustomer(next handleFunc) handleFunc { - return s.ensureUser(func(w http.ResponseWriter, r *http.Request, v *visitor) error { - if v.user.Billing.StripeCustomerID == "" { - return errHTTPBadRequestNotAPaidUser - } - return next(w, r, v) - }) -} - // transformBodyJSON peeks the request body, reads the JSON, and converts it to headers // before passing it on to the next handler. This is meant to be used in combination with handlePublish. func (s *Server) transformBodyJSON(next handleFunc) handleFunc { diff --git a/server/server.yml b/server/server.yml index 37256231..51e29c41 100644 --- a/server/server.yml +++ b/server/server.yml @@ -164,12 +164,10 @@ # - enable-signup allows users to sign up via the web app, or API # - enable-login allows users to log in via the web app, or API # - enable-reservations allows users to reserve topics (if their tier allows it) -# - enable-payments enables payments integration [preliminary option, may change] # # enable-signup: false # enable-login: false # enable-reservations: false -# enable-payments: false # Server URL of a Firebase/APNS-connected ntfy server (likely "https://ntfy.sh"). # @@ -216,6 +214,16 @@ # visitor-attachment-total-size-limit: "100M" # visitor-attachment-daily-bandwidth-limit: "500M" +# Payments integration via Stripe +# +# - stripe-secret-key is the key used for the Stripe API communication. Setting this values +# enables payments in the ntfy web app (e.g. Upgrade dialog). See https://dashboard.stripe.com/apikeys. +# - stripe-webhook-key is the key required to validate the authenticity of incoming webhooks from Stripe. +# Webhooks are essential up keep the local database in sync with the payment provider. See https://dashboard.stripe.com/webhooks. +# +# stripe-secret-key: +# stripe-webhook-key: + # Log level, can be TRACE, DEBUG, INFO, WARN or ERROR # This option can be hot-reloaded by calling "kill -HUP $pid" or "systemctl reload ntfy". # diff --git a/server/server_account.go b/server/server_account.go index 66250db1..9159ea47 100644 --- a/server/server_account.go +++ b/server/server_account.go @@ -2,15 +2,18 @@ package server import ( "encoding/json" + "errors" + "heckel.io/ntfy/log" "heckel.io/ntfy/user" "heckel.io/ntfy/util" "net/http" ) const ( - jsonBodyBytesLimit = 4096 - subscriptionIDLength = 16 - createdByAPI = "api" + jsonBodyBytesLimit = 4096 + subscriptionIDLength = 16 + createdByAPI = "api" + syncTopicAccountSyncEvent = "sync" ) func (s *Server) handleAccountCreate(w http.ResponseWriter, r *http.Request, v *visitor) error { @@ -395,3 +398,37 @@ func (s *Server) handleAccountReservationDelete(w http.ResponseWriter, r *http.R w.Header().Set("Access-Control-Allow-Origin", "*") // FIXME remove this return nil } + +func (s *Server) publishSyncEvent(v *visitor) error { + if v.user == nil || v.user.SyncTopic == "" { + return nil + } + log.Trace("Publishing sync event to user %s's sync topic %s", v.user.Name, v.user.SyncTopic) + topics, err := s.topicsFromIDs(v.user.SyncTopic) + if err != nil { + return err + } else if len(topics) == 0 { + return errors.New("cannot retrieve sync topic") + } + syncTopic := topics[0] + messageBytes, err := json.Marshal(&apiAccountSyncTopicResponse{Event: syncTopicAccountSyncEvent}) + if err != nil { + return err + } + m := newDefaultMessage(syncTopic.ID, string(messageBytes)) + if err := syncTopic.Publish(v, m); err != nil { + return err + } + return nil +} + +func (s *Server) publishSyncEventAsync(v *visitor) { + go func() { + if v.user == nil || v.user.SyncTopic == "" { + return + } + if err := s.publishSyncEvent(v); err != nil { + log.Trace("Error publishing to user %s's sync topic %s: %s", v.user.Name, v.user.SyncTopic, err.Error()) + } + }() +} diff --git a/server/server_middleware.go b/server/server_middleware.go new file mode 100644 index 00000000..33f784cb --- /dev/null +++ b/server/server_middleware.go @@ -0,0 +1,63 @@ +package server + +import ( + "net/http" +) + +func (s *Server) ensureWebEnabled(next handleFunc) handleFunc { + return func(w http.ResponseWriter, r *http.Request, v *visitor) error { + if !s.config.EnableWeb { + return errHTTPNotFound + } + return next(w, r, v) + } +} + +func (s *Server) ensureUserManager(next handleFunc) handleFunc { + return func(w http.ResponseWriter, r *http.Request, v *visitor) error { + if s.userManager == nil { + return errHTTPNotFound + } + return next(w, r, v) + } +} + +func (s *Server) ensureUser(next handleFunc) handleFunc { + return s.ensureUserManager(func(w http.ResponseWriter, r *http.Request, v *visitor) error { + if v.user == nil { + return errHTTPUnauthorized + } + return next(w, r, v) + }) +} + +func (s *Server) ensurePaymentsEnabled(next handleFunc) handleFunc { + return func(w http.ResponseWriter, r *http.Request, v *visitor) error { + if !s.config.EnablePayments { + return errHTTPNotFound + } + return next(w, r, v) + } +} + +func (s *Server) ensureStripeCustomer(next handleFunc) handleFunc { + return s.ensureUser(func(w http.ResponseWriter, r *http.Request, v *visitor) error { + if v.user.Billing.StripeCustomerID == "" { + return errHTTPBadRequestNotAPaidUser + } + return next(w, r, v) + }) +} + +func (s *Server) withAccountSync(next handleFunc) handleFunc { + return func(w http.ResponseWriter, r *http.Request, v *visitor) error { + if v.user == nil { + return next(w, r, v) + } + err := next(w, r, v) + if err == nil { + s.publishSyncEventAsync(v) + } + return err + } +} diff --git a/server/server_payments.go b/server/server_payments.go index b78a94a7..298945af 100644 --- a/server/server_payments.go +++ b/server/server_payments.go @@ -6,13 +6,14 @@ import ( "github.com/stripe/stripe-go/v74" portalsession "github.com/stripe/stripe-go/v74/billingportal/session" "github.com/stripe/stripe-go/v74/checkout/session" + "github.com/stripe/stripe-go/v74/customer" "github.com/stripe/stripe-go/v74/subscription" "github.com/stripe/stripe-go/v74/webhook" "github.com/tidwall/gjson" "heckel.io/ntfy/log" - "heckel.io/ntfy/user" "heckel.io/ntfy/util" "net/http" + "net/netip" "time" ) @@ -20,15 +21,13 @@ const ( stripeBodyBytesLimit = 16384 ) -// handleAccountBillingSubscriptionChange facilitates all subscription/tier changes, including payment flows. -// -// FIXME this should be two functions! -// -// It handles two cases: -// - Create subscription: Transition from a user without Stripe subscription to a paid subscription (Checkout flow) -// - Change subscription: Switching between Stripe prices (& tiers) by changing the Stripe subscription -func (s *Server) handleAccountBillingSubscriptionChange(w http.ResponseWriter, r *http.Request, v *visitor) error { - req, err := readJSONWithLimit[apiAccountTierChangeRequest](r.Body, jsonBodyBytesLimit) +// handleAccountBillingSubscriptionCreate creates a Stripe checkout flow to create a user subscription. The tier +// will be updated by a subsequent webhook from Stripe, once the subscription becomes active. +func (s *Server) handleAccountBillingSubscriptionCreate(w http.ResponseWriter, r *http.Request, v *visitor) error { + if v.user.Billing.StripeSubscriptionID != "" { + return errors.New("subscription already exists") //FIXME + } + req, err := readJSONWithLimit[apiAccountBillingSubscriptionChangeRequest](r.Body, jsonBodyBytesLimit) if err != nil { return err } @@ -36,46 +35,21 @@ func (s *Server) handleAccountBillingSubscriptionChange(w http.ResponseWriter, r if err != nil { return err } - if v.user.Billing.StripeSubscriptionID == "" && tier.StripePriceID != "" { - return s.handleAccountBillingSubscriptionAdd(w, v, tier) - } else if v.user.Billing.StripeSubscriptionID != "" { - return s.handleAccountBillingSubscriptionUpdate(w, v, tier) + if tier.StripePriceID == "" { + return errors.New("invalid tier") //FIXME } - return errors.New("invalid state") -} - -// handleAccountBillingSubscriptionDelete facilitates downgrading a paid user to a tier-less user, -// and cancelling the Stripe subscription entirely -func (s *Server) handleAccountBillingSubscriptionDelete(w http.ResponseWriter, r *http.Request, v *visitor) error { - if v.user.Billing.StripeCustomerID == "" { - return errHTTPBadRequestNotAPaidUser - } - if v.user.Billing.StripeSubscriptionID != "" { - _, err := subscription.Cancel(v.user.Billing.StripeSubscriptionID, nil) - if err != nil { - return err - } - } - if err := s.userManager.ResetTier(v.user.Name); err != nil { - return err - } - v.user.Billing.StripeSubscriptionID = "" - v.user.Billing.StripeSubscriptionStatus = "" - v.user.Billing.StripeSubscriptionPaidUntil = time.Unix(0, 0) - v.user.Billing.StripeSubscriptionCancelAt = time.Unix(0, 0) - if err := s.userManager.ChangeBilling(v.user); err != nil { - return err - } - return nil -} - -func (s *Server) handleAccountBillingSubscriptionAdd(w http.ResponseWriter, v *visitor, tier *user.Tier) error { log.Info("Stripe: No existing subscription, creating checkout flow") var stripeCustomerID *string if v.user.Billing.StripeCustomerID != "" { stripeCustomerID = &v.user.Billing.StripeCustomerID + stripeCustomer, err := customer.Get(v.user.Billing.StripeCustomerID, nil) + if err != nil { + return err + } else if stripeCustomer.Subscriptions != nil && len(stripeCustomer.Subscriptions.Data) > 0 { + return errors.New("customer cannot have more than one subscription") //FIXME + } } - successURL := s.config.BaseURL + accountBillingSubscriptionCheckoutSuccessTemplate + successURL := s.config.BaseURL + "/account" //+ accountBillingSubscriptionCheckoutSuccessTemplate params := &stripe.CheckoutSessionParams{ Customer: stripeCustomerID, // A user may have previously deleted their subscription ClientReferenceID: &v.user.Name, // FIXME Should be user ID @@ -106,36 +80,7 @@ func (s *Server) handleAccountBillingSubscriptionAdd(w http.ResponseWriter, v *v return nil } -func (s *Server) handleAccountBillingSubscriptionUpdate(w http.ResponseWriter, v *visitor, tier *user.Tier) error { - log.Info("Stripe: Changing tier and subscription to %s", tier.Code) - sub, err := subscription.Get(v.user.Billing.StripeSubscriptionID, nil) - if err != nil { - return err - } - params := &stripe.SubscriptionParams{ - CancelAtPeriodEnd: stripe.Bool(false), - ProrationBehavior: stripe.String(string(stripe.SubscriptionSchedulePhaseProrationBehaviorCreateProrations)), - Items: []*stripe.SubscriptionItemsParams{ - { - ID: stripe.String(sub.Items.Data[0].ID), - Price: stripe.String(tier.StripePriceID), - }, - }, - } - _, err = subscription.Update(sub.ID, params) - if err != nil { - return err - } - response := &apiAccountCheckoutResponse{} - w.Header().Set("Content-Type", "application/json") - w.Header().Set("Access-Control-Allow-Origin", "*") // FIXME remove this - if err := json.NewEncoder(w).Encode(response); err != nil { - return err - } - return nil -} - -func (s *Server) handleAccountCheckoutSessionSuccessGet(w http.ResponseWriter, r *http.Request, v *visitor) error { +func (s *Server) handleAccountBillingSubscriptionCreateSuccess(w http.ResponseWriter, r *http.Request, _ *visitor) error { // We don't have a v.user in this endpoint, only a userManager! matches := accountBillingSubscriptionCheckoutSuccessRegex.FindStringSubmatch(r.URL.Path) if len(matches) != 2 { @@ -183,6 +128,66 @@ func (s *Server) handleAccountCheckoutSessionSuccessGet(w http.ResponseWriter, r return nil } +// handleAccountBillingSubscriptionUpdate updates an existing Stripe subscription to a new price, and updates +// a user's tier accordingly. This endpoint only works if there is an existing subscription. +func (s *Server) handleAccountBillingSubscriptionUpdate(w http.ResponseWriter, r *http.Request, v *visitor) error { + if v.user.Billing.StripeSubscriptionID != "" { + return errors.New("no existing subscription for user") + } + req, err := readJSONWithLimit[apiAccountBillingSubscriptionChangeRequest](r.Body, jsonBodyBytesLimit) + if err != nil { + return err + } + tier, err := s.userManager.Tier(req.Tier) + if err != nil { + return err + } + log.Info("Stripe: Changing tier and subscription to %s", tier.Code) + sub, err := subscription.Get(v.user.Billing.StripeSubscriptionID, nil) + if err != nil { + return err + } + params := &stripe.SubscriptionParams{ + CancelAtPeriodEnd: stripe.Bool(false), + ProrationBehavior: stripe.String(string(stripe.SubscriptionSchedulePhaseProrationBehaviorCreateProrations)), + Items: []*stripe.SubscriptionItemsParams{ + { + ID: stripe.String(sub.Items.Data[0].ID), + Price: stripe.String(tier.StripePriceID), + }, + }, + } + _, err = subscription.Update(sub.ID, params) + if err != nil { + return err + } + response := &apiAccountCheckoutResponse{} // FIXME + w.Header().Set("Content-Type", "application/json") + w.Header().Set("Access-Control-Allow-Origin", "*") // FIXME remove this + if err := json.NewEncoder(w).Encode(response); err != nil { + return err + } + return nil +} + +// handleAccountBillingSubscriptionDelete facilitates downgrading a paid user to a tier-less user, +// and cancelling the Stripe subscription entirely +func (s *Server) handleAccountBillingSubscriptionDelete(w http.ResponseWriter, r *http.Request, v *visitor) error { + if v.user.Billing.StripeCustomerID == "" { + return errHTTPBadRequestNotAPaidUser + } + if v.user.Billing.StripeSubscriptionID != "" { + params := &stripe.SubscriptionParams{ + CancelAtPeriodEnd: stripe.Bool(true), + } + _, err := subscription.Update(v.user.Billing.StripeSubscriptionID, params) + if err != nil { + return err + } + } + return nil +} + func (s *Server) handleAccountBillingPortalSessionCreate(w http.ResponseWriter, r *http.Request, v *visitor) error { if v.user.Billing.StripeCustomerID == "" { return errHTTPBadRequestNotAPaidUser @@ -206,8 +211,8 @@ func (s *Server) handleAccountBillingPortalSessionCreate(w http.ResponseWriter, return nil } -func (s *Server) handleAccountBillingWebhook(w http.ResponseWriter, r *http.Request, v *visitor) error { - // We don't have a v.user in this endpoint, only a userManager! +func (s *Server) handleAccountBillingWebhook(w http.ResponseWriter, r *http.Request, _ *visitor) error { + // Note that the visitor (v) in this endpoint is the Stripe API, so we don't have v.user available stripeSignature := r.Header.Get("Stripe-Signature") if stripeSignature == "" { return errHTTPBadRequestInvalidStripeRequest @@ -225,30 +230,27 @@ func (s *Server) handleAccountBillingWebhook(w http.ResponseWriter, r *http.Requ return errHTTPBadRequestInvalidStripeRequest } log.Info("Stripe: webhook event %s received", event.Type) - stripeCustomerID := gjson.GetBytes(event.Data.Raw, "customer") - if !stripeCustomerID.Exists() { - return errHTTPBadRequestInvalidStripeRequest - } switch event.Type { case "customer.subscription.updated": - return s.handleAccountBillingWebhookSubscriptionUpdated(stripeCustomerID.String(), event.Data.Raw) + return s.handleAccountBillingWebhookSubscriptionUpdated(event.Data.Raw) case "customer.subscription.deleted": - return s.handleAccountBillingWebhookSubscriptionDeleted(stripeCustomerID.String(), event.Data.Raw) + return s.handleAccountBillingWebhookSubscriptionDeleted(event.Data.Raw) default: return nil } } -func (s *Server) handleAccountBillingWebhookSubscriptionUpdated(stripeCustomerID string, event json.RawMessage) error { +func (s *Server) handleAccountBillingWebhookSubscriptionUpdated(event json.RawMessage) error { + subscriptionID := gjson.GetBytes(event, "id") + customerID := gjson.GetBytes(event, "customer") status := gjson.GetBytes(event, "status") currentPeriodEnd := gjson.GetBytes(event, "current_period_end") cancelAt := gjson.GetBytes(event, "cancel_at") priceID := gjson.GetBytes(event, "items.data.0.price.id") - if !status.Exists() || !currentPeriodEnd.Exists() || !cancelAt.Exists() || !priceID.Exists() { + if !subscriptionID.Exists() || !status.Exists() || !currentPeriodEnd.Exists() || !cancelAt.Exists() || !priceID.Exists() { return errHTTPBadRequestInvalidStripeRequest } - log.Info("Stripe: customer %s: subscription updated to %s, with price %s", stripeCustomerID, status, priceID) - u, err := s.userManager.UserByStripeCustomer(stripeCustomerID) + u, err := s.userManager.UserByStripeCustomer(customerID.String()) if err != nil { return err } @@ -259,22 +261,25 @@ func (s *Server) handleAccountBillingWebhookSubscriptionUpdated(stripeCustomerID if err := s.userManager.ChangeTier(u.Name, tier.Code); err != nil { return err } + u.Billing.StripeSubscriptionID = subscriptionID.String() u.Billing.StripeSubscriptionStatus = stripe.SubscriptionStatus(status.String()) u.Billing.StripeSubscriptionPaidUntil = time.Unix(currentPeriodEnd.Int(), 0) u.Billing.StripeSubscriptionCancelAt = time.Unix(cancelAt.Int(), 0) if err := s.userManager.ChangeBilling(u); err != nil { return err } + log.Info("Stripe: customer %s: subscription updated to %s, with price %s", customerID.String(), status, priceID) + s.publishSyncEventAsync(s.visitorFromUser(u, netip.IPv4Unspecified())) return nil } -func (s *Server) handleAccountBillingWebhookSubscriptionDeleted(stripeCustomerID string, event json.RawMessage) error { - status := gjson.GetBytes(event, "status") - if !status.Exists() { +func (s *Server) handleAccountBillingWebhookSubscriptionDeleted(event json.RawMessage) error { + stripeCustomerID := gjson.GetBytes(event, "customer") + if !stripeCustomerID.Exists() { return errHTTPBadRequestInvalidStripeRequest } - log.Info("Stripe: customer %s: subscription deleted, downgrading to unpaid tier", stripeCustomerID) - u, err := s.userManager.UserByStripeCustomer(stripeCustomerID) + log.Info("Stripe: customer %s: subscription deleted, downgrading to unpaid tier", stripeCustomerID.String()) + u, err := s.userManager.UserByStripeCustomer(stripeCustomerID.String()) if err != nil { return err } @@ -288,5 +293,6 @@ func (s *Server) handleAccountBillingWebhookSubscriptionDeleted(stripeCustomerID if err := s.userManager.ChangeBilling(u); err != nil { return err } + s.publishSyncEventAsync(s.visitorFromUser(u, netip.IPv4Unspecified())) return nil } diff --git a/server/types.go b/server/types.go index fc81a2a6..0e37f553 100644 --- a/server/types.go +++ b/server/types.go @@ -305,7 +305,7 @@ type apiConfigResponse struct { DisallowedTopics []string `json:"disallowed_topics"` } -type apiAccountTierChangeRequest struct { +type apiAccountBillingSubscriptionChangeRequest struct { Tier string `json:"tier"` } @@ -316,3 +316,7 @@ type apiAccountCheckoutResponse struct { type apiAccountBillingPortalRedirectResponse struct { RedirectURL string `json:"redirect_url"` } + +type apiAccountSyncTopicResponse struct { + Event string `json:"event"` +} diff --git a/user/manager.go b/user/manager.go index 7b37b8f8..de4fc747 100644 --- a/user/manager.go +++ b/user/manager.go @@ -38,7 +38,6 @@ const ( id INTEGER PRIMARY KEY AUTOINCREMENT, code TEXT NOT NULL, name TEXT NOT NULL, - paid INT NOT NULL, messages_limit INT NOT NULL, messages_expiry_duration INT NOT NULL, emails_limit INT NOT NULL, @@ -104,20 +103,20 @@ const ( ` selectUserByNameQuery = ` - SELECT u.user, u.pass, u.role, u.prefs, u.sync_topic, u.stats_messages, u.stats_emails, u.stripe_customer_id, u.stripe_subscription_id, u.stripe_subscription_status, u.stripe_subscription_paid_until, u.stripe_subscription_cancel_at, p.code, p.name, p.paid, p.messages_limit, p.messages_expiry_duration, p.emails_limit, p.reservations_limit, p.attachment_file_size_limit, p.attachment_total_size_limit, p.attachment_expiry_duration, p.stripe_price_id + SELECT u.user, u.pass, u.role, u.prefs, u.sync_topic, u.stats_messages, u.stats_emails, u.stripe_customer_id, u.stripe_subscription_id, u.stripe_subscription_status, u.stripe_subscription_paid_until, u.stripe_subscription_cancel_at, p.code, p.name, p.messages_limit, p.messages_expiry_duration, p.emails_limit, p.reservations_limit, p.attachment_file_size_limit, p.attachment_total_size_limit, p.attachment_expiry_duration, p.stripe_price_id FROM user u LEFT JOIN tier p on p.id = u.tier_id WHERE user = ? ` selectUserByTokenQuery = ` - SELECT u.user, u.pass, u.role, u.prefs, u.sync_topic, u.stats_messages, u.stats_emails, u.stripe_customer_id, u.stripe_subscription_id, u.stripe_subscription_status, u.stripe_subscription_paid_until, u.stripe_subscription_cancel_at , p.code, p.name, p.paid, p.messages_limit, p.messages_expiry_duration, p.emails_limit, p.reservations_limit, p.attachment_file_size_limit, p.attachment_total_size_limit, p.attachment_expiry_duration, p.stripe_price_id + SELECT u.user, u.pass, u.role, u.prefs, u.sync_topic, u.stats_messages, u.stats_emails, u.stripe_customer_id, u.stripe_subscription_id, u.stripe_subscription_status, u.stripe_subscription_paid_until, u.stripe_subscription_cancel_at, p.code, p.name, p.messages_limit, p.messages_expiry_duration, p.emails_limit, p.reservations_limit, p.attachment_file_size_limit, p.attachment_total_size_limit, p.attachment_expiry_duration, p.stripe_price_id FROM user u JOIN user_token t on u.id = t.user_id LEFT JOIN tier p on p.id = u.tier_id WHERE t.token = ? AND t.expires >= ? ` selectUserByStripeCustomerIDQuery = ` - SELECT u.user, u.pass, u.role, u.prefs, u.sync_topic, u.stats_messages, u.stats_emails, u.stripe_customer_id, u.stripe_subscription_id, u.stripe_subscription_status, u.stripe_subscription_paid_until, u.stripe_subscription_cancel_at , p.code, p.name, p.paid, p.messages_limit, p.messages_expiry_duration, p.emails_limit, p.reservations_limit, p.attachment_file_size_limit, p.attachment_total_size_limit, p.attachment_expiry_duration, p.stripe_price_id + SELECT u.user, u.pass, u.role, u.prefs, u.sync_topic, u.stats_messages, u.stats_emails, u.stripe_customer_id, u.stripe_subscription_id, u.stripe_subscription_status, u.stripe_subscription_paid_until, u.stripe_subscription_cancel_at, p.code, p.name, p.messages_limit, p.messages_expiry_duration, p.emails_limit, p.reservations_limit, p.attachment_file_size_limit, p.attachment_total_size_limit, p.attachment_expiry_duration, p.stripe_price_id FROM user u LEFT JOIN tier p on p.id = u.tier_id WHERE u.stripe_customer_id = ? @@ -218,17 +217,17 @@ const ( ` insertTierQuery = ` - INSERT INTO tier (code, name, paid, messages_limit, messages_expiry_duration, emails_limit, reservations_limit, attachment_file_size_limit, attachment_total_size_limit, attachment_expiry_duration) - VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?) + INSERT INTO tier (code, name, messages_limit, messages_expiry_duration, emails_limit, reservations_limit, attachment_file_size_limit, attachment_total_size_limit, attachment_expiry_duration) + VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?) ` selectTierIDQuery = `SELECT id FROM tier WHERE code = ?` selectTierByCodeQuery = ` - SELECT code, name, paid, messages_limit, messages_expiry_duration, emails_limit, reservations_limit, attachment_file_size_limit, attachment_total_size_limit, attachment_expiry_duration, stripe_price_id + SELECT code, name, messages_limit, messages_expiry_duration, emails_limit, reservations_limit, attachment_file_size_limit, attachment_total_size_limit, attachment_expiry_duration, stripe_price_id FROM tier WHERE code = ? ` selectTierByPriceIDQuery = ` - SELECT code, name, paid, messages_limit, messages_expiry_duration, emails_limit, reservations_limit, attachment_file_size_limit, attachment_total_size_limit, attachment_expiry_duration, stripe_price_id + SELECT code, name, messages_limit, messages_expiry_duration, emails_limit, reservations_limit, attachment_file_size_limit, attachment_total_size_limit, attachment_expiry_duration, stripe_price_id FROM tier WHERE stripe_price_id = ? ` @@ -606,13 +605,12 @@ func (a *Manager) readUser(rows *sql.Rows) (*User, error) { defer rows.Close() var username, hash, role, prefs, syncTopic string var stripeCustomerID, stripeSubscriptionID, stripeSubscriptionStatus, stripePriceID, tierCode, tierName sql.NullString - var paid sql.NullBool var messages, emails int64 var messagesLimit, messagesExpiryDuration, emailsLimit, reservationsLimit, attachmentFileSizeLimit, attachmentTotalSizeLimit, attachmentExpiryDuration, stripeSubscriptionPaidUntil, stripeSubscriptionCancelAt sql.NullInt64 if !rows.Next() { return nil, ErrUserNotFound } - if err := rows.Scan(&username, &hash, &role, &prefs, &syncTopic, &messages, &emails, &stripeCustomerID, &stripeSubscriptionID, &stripeSubscriptionStatus, &stripeSubscriptionPaidUntil, &stripeSubscriptionCancelAt, &tierCode, &tierName, &paid, &messagesLimit, &messagesExpiryDuration, &emailsLimit, &reservationsLimit, &attachmentFileSizeLimit, &attachmentTotalSizeLimit, &attachmentExpiryDuration, &stripePriceID); err != nil { + if err := rows.Scan(&username, &hash, &role, &prefs, &syncTopic, &messages, &emails, &stripeCustomerID, &stripeSubscriptionID, &stripeSubscriptionStatus, &stripeSubscriptionPaidUntil, &stripeSubscriptionCancelAt, &tierCode, &tierName, &messagesLimit, &messagesExpiryDuration, &emailsLimit, &reservationsLimit, &attachmentFileSizeLimit, &attachmentTotalSizeLimit, &attachmentExpiryDuration, &stripePriceID); err != nil { return nil, err } else if err := rows.Err(); err != nil { return nil, err @@ -643,7 +641,7 @@ func (a *Manager) readUser(rows *sql.Rows) (*User, error) { user.Tier = &Tier{ Code: tierCode.String, Name: tierName.String, - Paid: paid.Bool, + Paid: stripePriceID.Valid, // If there is a price, it's a paid tier MessagesLimit: messagesLimit.Int64, MessagesExpiryDuration: time.Duration(messagesExpiryDuration.Int64) * time.Second, EmailsLimit: emailsLimit.Int64, @@ -870,7 +868,7 @@ func (a *Manager) DefaultAccess() Permission { // CreateTier creates a new tier in the database func (a *Manager) CreateTier(tier *Tier) error { - if _, err := a.db.Exec(insertTierQuery, tier.Code, tier.Name, tier.Paid, tier.MessagesLimit, int64(tier.MessagesExpiryDuration.Seconds()), tier.EmailsLimit, tier.ReservationsLimit, tier.AttachmentFileSizeLimit, tier.AttachmentTotalSizeLimit, int64(tier.AttachmentExpiryDuration.Seconds())); err != nil { + if _, err := a.db.Exec(insertTierQuery, tier.Code, tier.Name, tier.MessagesLimit, int64(tier.MessagesExpiryDuration.Seconds()), tier.EmailsLimit, tier.ReservationsLimit, tier.AttachmentFileSizeLimit, tier.AttachmentTotalSizeLimit, int64(tier.AttachmentExpiryDuration.Seconds())); err != nil { return err } return nil @@ -903,12 +901,11 @@ func (a *Manager) readTier(rows *sql.Rows) (*Tier, error) { defer rows.Close() var code, name string var stripePriceID sql.NullString - var paid bool var messagesLimit, messagesExpiryDuration, emailsLimit, reservationsLimit, attachmentFileSizeLimit, attachmentTotalSizeLimit, attachmentExpiryDuration sql.NullInt64 if !rows.Next() { return nil, ErrTierNotFound } - if err := rows.Scan(&code, &name, &paid, &messagesLimit, &messagesExpiryDuration, &emailsLimit, &reservationsLimit, &attachmentFileSizeLimit, &attachmentTotalSizeLimit, &attachmentExpiryDuration, &stripePriceID); err != nil { + if err := rows.Scan(&code, &name, &messagesLimit, &messagesExpiryDuration, &emailsLimit, &reservationsLimit, &attachmentFileSizeLimit, &attachmentTotalSizeLimit, &attachmentExpiryDuration, &stripePriceID); err != nil { return nil, err } else if err := rows.Err(); err != nil { return nil, err @@ -917,7 +914,7 @@ func (a *Manager) readTier(rows *sql.Rows) (*Tier, error) { return &Tier{ Code: code, Name: name, - Paid: paid, + Paid: stripePriceID.Valid, // If there is a price, it's a paid tier MessagesLimit: messagesLimit.Int64, MessagesExpiryDuration: time.Duration(messagesExpiryDuration.Int64) * time.Second, EmailsLimit: emailsLimit.Int64, diff --git a/web/public/static/langs/en.json b/web/public/static/langs/en.json index 56e6366b..0ae7e453 100644 --- a/web/public/static/langs/en.json +++ b/web/public/static/langs/en.json @@ -179,8 +179,10 @@ "account_usage_unlimited": "Unlimited", "account_usage_limits_reset_daily": "Usage limits are reset daily at midnight (UTC)", "account_usage_tier_title": "Account type", + "account_usage_tier_description": "Your account's power level", "account_usage_tier_admin": "Admin", - "account_usage_tier_none": "Basic", + "account_usage_tier_basic": "Basic", + "account_usage_tier_free": "Free", "account_usage_tier_upgrade_button": "Upgrade to Pro", "account_usage_tier_change_button": "Change", "account_usage_tier_paid_until": "Subscription paid until {{date}}, and will auto-renew", @@ -199,6 +201,8 @@ "account_delete_dialog_label": "Type '{{username}}' to delete account", "account_delete_dialog_button_cancel": "Cancel", "account_delete_dialog_button_submit": "Permanently delete account", + "account_upgrade_dialog_title": "Change billing plan", + "account_upgrade_dialog_cancel_warning": "This will cancel your subscription, and downgrade your account on {{date}}. On that date, topic reservations as well as messages cached on the server will be deleted.", "prefs_notifications_title": "Notifications", "prefs_notifications_sound_title": "Notification sound", "prefs_notifications_sound_description_none": "Notifications do not play any sound when they arrive", diff --git a/web/src/app/AccountApi.js b/web/src/app/AccountApi.js index fe918b45..ef9f13a9 100644 --- a/web/src/app/AccountApi.js +++ b/web/src/app/AccountApi.js @@ -264,11 +264,20 @@ class AccountApi { this.triggerChange(); // Dangle! } + async createBillingSubscription(tier) { + console.log(`[AccountApi] Creating billing subscription with ${tier}`); + return await this.upsertBillingSubscription("POST", tier) + } + async updateBillingSubscription(tier) { + console.log(`[AccountApi] Updating billing subscription with ${tier}`); + return await this.upsertBillingSubscription("PUT", tier) + } + + async upsertBillingSubscription(method, tier) { const url = accountBillingSubscriptionUrl(config.base_url); - console.log(`[AccountApi] Requesting tier change to ${tier}`); const response = await fetch(url, { - method: "POST", + method: method, headers: withBearerAuth({}, session.token()), body: JSON.stringify({ tier: tier @@ -284,7 +293,7 @@ class AccountApi { async deleteBillingSubscription() { const url = accountBillingSubscriptionUrl(config.base_url); - console.log(`[AccountApi] Cancelling paid subscription`); + console.log(`[AccountApi] Cancelling billing subscription`); const response = await fetch(url, { method: "DELETE", headers: withBearerAuth({}, session.token()) @@ -345,6 +354,7 @@ class AccountApi { } async triggerChange() { + return null; const account = await this.get(); if (!account.sync_topic) { return; diff --git a/web/src/components/Account.js b/web/src/components/Account.js index 622ca42d..f451e2be 100644 --- a/web/src/components/Account.js +++ b/web/src/components/Account.js @@ -56,6 +56,7 @@ const Basics = () => { + ); @@ -168,18 +169,20 @@ const ChangePasswordDialog = (props) => { ); }; -const Stats = () => { +const AccountType = () => { const { t } = useTranslation(); const { account } = useContext(AccountContext); + const [upgradeDialogKey, setUpgradeDialogKey] = useState(0); const [upgradeDialogOpen, setUpgradeDialogOpen] = useState(false); if (!account) { return <>; } - const normalize = (value, max) => { - return Math.min(value / max * 100, 100); - }; + const handleUpgradeClick = () => { + setUpgradeDialogKey(k => k + 1); + setUpgradeDialogOpen(true); + } const handleManageBilling = async () => { try { @@ -194,67 +197,89 @@ const Stats = () => { } }; + let accountType; + if (account.role === "admin") { + const tierSuffix = (account.tier) ? `(with ${account.tier.name} tier)` : `(no tier)`; + accountType = `${t("account_usage_tier_admin")} ${tierSuffix}`; + } else if (!account.tier) { + accountType = (config.enable_payments) ? t("account_usage_tier_free") : t("account_usage_tier_basic"); + } else { + accountType = account.tier.name; + } + + return ( + 0} + title={t("account_usage_tier_title")} + description={t("account_usage_tier_description")} + > +
+ {accountType} + {account.billing?.paid_until && !account.billing?.cancel_at && + + + + } + {config.enable_payments && account.role === "user" && !account.billing?.subscription && + + } + {config.enable_payments && account.role === "user" && account.billing?.subscription && + + } + {config.enable_payments && account.role === "user" && account.billing?.customer && + + } + setUpgradeDialogOpen(false)} + /> +
+ {account.billing?.status === "past_due" && + {t("account_usage_tier_payment_overdue")} + } + {account.billing?.cancel_at > 0 && + {t("account_usage_tier_canceled_subscription", { date: formatShortDate(account.billing.cancel_at) })} + } +
+ ) +}; + +const Stats = () => { + const { t } = useTranslation(); + const { account } = useContext(AccountContext); + const [upgradeDialogOpen, setUpgradeDialogOpen] = useState(false); + + if (!account) { + return <>; + } + + const normalize = (value, max) => { + return Math.min(value / max * 100, 100); + }; + return ( {t("account_usage_title")} - 0} - title={t("account_usage_tier_title")} - > -
- {account.role === "admin" && - <> - {t("account_usage_tier_admin")} - {" "}{account.tier ? `(with ${account.tier.name} tier)` : `(no tier)`} - - } - {account.role === "user" && account.tier && account.tier.name} - {account.role === "user" && !account.tier && t("account_usage_tier_none")} - {account.billing?.paid_until && - - - - } - {config.enable_payments && account.role === "user" && (!account.tier || !account.tier.paid) && - - } - {config.enable_payments && account.role === "user" && account.tier?.paid && - - } - {config.enable_payments && account.role === "user" && account.billing?.customer && - - } - setUpgradeDialogOpen(false)} - /> -
- {account.billing?.status === "past_due" && - {t("account_usage_tier_payment_overdue")} - } - {account.billing?.cancel_at > 0 && - {t("account_usage_tier_canceled_subscription", { date: formatShortDate(account.billing.cancel_at) })} - } -
{account.role !== "admin" && {account.limits.reservations > 0 && diff --git a/web/src/components/Navigation.js b/web/src/components/Navigation.js index 3a54e913..81dbb476 100644 --- a/web/src/components/Navigation.js +++ b/web/src/components/Navigation.js @@ -103,8 +103,8 @@ const NavList = (props) => { }; const isAdmin = account?.role === "admin"; - const isPaid = account?.tier?.paid; - const showUpgradeBanner = config.enable_payments && !isAdmin && !isPaid;// && (!props.account || !props.account.tier || !props.account.tier.paid || props.account); + const isPaid = account?.billing?.subscription; + const showUpgradeBanner = config.enable_payments && !isAdmin && !isPaid; const showSubscriptionsList = props.subscriptions?.length > 0; const showNotificationBrowserNotSupportedBox = !notifier.browserSupported(); const showNotificationContextNotSupportedBox = notifier.browserSupported() && !notifier.contextSupported(); // Only show if notifications are generally supported in the browser @@ -174,7 +174,14 @@ const NavList = (props) => { }; const UpgradeBanner = () => { + const [dialogKey, setDialogKey] = useState(0); const [dialogOpen, setDialogOpen] = useState(false); + + const handleClick = () => { + setDialogKey(k => k + 1); + setDialogOpen(true); + }; + return ( { background: "linear-gradient(150deg, rgba(196, 228, 221, 0.46) 0%, rgb(255, 255, 255) 100%)", }}> - setDialogOpen(true)} sx={{pt: 2, pb: 2}}> + { /> setDialogOpen(false)} /> diff --git a/web/src/components/UpgradeDialog.js b/web/src/components/UpgradeDialog.js index a50fdb82..5fb175fd 100644 --- a/web/src/components/UpgradeDialog.js +++ b/web/src/components/UpgradeDialog.js @@ -2,7 +2,7 @@ import * as React from 'react'; import Dialog from '@mui/material/Dialog'; import DialogContent from '@mui/material/DialogContent'; import DialogTitle from '@mui/material/DialogTitle'; -import {CardActionArea, CardContent, useMediaQuery} from "@mui/material"; +import {Alert, CardActionArea, CardContent, useMediaQuery} from "@mui/material"; import theme from "./theme"; import DialogFooter from "./DialogFooter"; import Button from "@mui/material/Button"; @@ -13,28 +13,53 @@ import {useContext, useState} from "react"; import Card from "@mui/material/Card"; import Typography from "@mui/material/Typography"; import {AccountContext} from "./App"; +import {formatShortDate} from "../app/utils"; +import {useTranslation} from "react-i18next"; const UpgradeDialog = (props) => { + const { t } = useTranslation(); const { account } = useContext(AccountContext); const fullScreen = useMediaQuery(theme.breakpoints.down('sm')); const [newTier, setNewTier] = useState(account?.tier?.code || null); const [errorText, setErrorText] = useState(""); - const handleCheckout = async () => { - try { - if (newTier == null) { - await accountApi.deleteBillingSubscription(); - } else { - const response = await accountApi.updateBillingSubscription(newTier); - if (response.redirect_url) { - window.location.href = response.redirect_url; - } else { - await accountApi.sync(); - } - } + if (!account) { + return <>; + } + const currentTier = account.tier?.code || null; + let action, submitButtonLabel, submitButtonEnabled; + if (currentTier === newTier) { + submitButtonLabel = "Update subscription"; + submitButtonEnabled = false; + action = null; + } else if (currentTier === null) { + submitButtonLabel = "Pay $5 now and subscribe"; + submitButtonEnabled = true; + action = Action.CREATE; + } else if (newTier === null) { + submitButtonLabel = "Cancel subscription"; + submitButtonEnabled = true; + action = Action.CANCEL; + } else { + submitButtonLabel = "Update subscription"; + submitButtonEnabled = true; + action = Action.UPDATE; + } + + const handleSubmit = async () => { + try { + if (action === Action.CREATE) { + const response = await accountApi.createBillingSubscription(newTier); + window.location.href = response.redirect_url; + } else if (action === Action.UPDATE) { + await accountApi.updateBillingSubscription(newTier); + } else if (action === Action.CANCEL) { + await accountApi.deleteBillingSubscription(); + } + props.onCancel(); } catch (e) { - console.log(`[UpgradeDialog] Error creating checkout session`, e); + console.log(`[UpgradeDialog] Error changing billing subscription`, e); if ((e instanceof UnauthorizedError)) { session.resetAndRedirect(routes.login); } @@ -44,7 +69,7 @@ const UpgradeDialog = (props) => { return ( - Upgrade to Pro + Change billing plan
{ setNewTier("pro")}/> setNewTier("business")}/>
+ {action === Action.CANCEL && + + {t("account_upgrade_dialog_cancel_warning", { date: formatShortDate(account.billing.paid_until) })} + + }
- + +
); @@ -65,8 +96,7 @@ const UpgradeDialog = (props) => { const TierCard = (props) => { const cardStyle = (props.selected) ? { - border: "1px solid red", - + background: "#eee" } : {}; return ( @@ -85,4 +115,10 @@ const TierCard = (props) => { ); } +const Action = { + CREATE: 1, + UPDATE: 2, + CANCEL: 3 +}; + export default UpgradeDialog;