From c06bfb989ef7c0a8ebf1ec23a1cb35e0e52c5eaf Mon Sep 17 00:00:00 2001 From: binwiederhier Date: Sun, 15 Jan 2023 23:29:46 -0500 Subject: [PATCH] Payment stuff, cont'd --- server/server.go | 72 ++++--- server/server_account.go | 240 +---------------------- server/server_payments.go | 287 ++++++++++++++++++++++++++++ server/types.go | 8 + user/manager.go | 51 +++-- user/types.go | 17 +- web/public/static/langs/en.json | 2 + web/src/app/AccountApi.js | 22 ++- web/src/app/utils.js | 2 +- web/src/components/Account.js | 43 +++-- web/src/components/UpgradeDialog.js | 22 ++- 11 files changed, 457 insertions(+), 309 deletions(-) create mode 100644 server/server_payments.go diff --git a/server/server.go b/server/server.go index 0ad06a08..64891313 100644 --- a/server/server.go +++ b/server/server.go @@ -37,8 +37,13 @@ import ( /* TODO payments: - - handle overdue payment (-> downgrade after 7 days) - - delete stripe subscription when acocunt is deleted + - send dunning emails when overdue + - payment methods + - unmarshal to stripe.Subscription instead of gjson + - Make ResetTier reset the stripe fields + - delete subscription when account deleted + - remove tier.paid + - add tier.visible Limits & rate limiting: users without tier: should the stats be persisted? are they meaningful? @@ -97,27 +102,27 @@ var ( authPathRegex = regexp.MustCompile(`^/[-_A-Za-z0-9]{1,64}(,[-_A-Za-z0-9]{1,64})*/auth$`) publishPathRegex = regexp.MustCompile(`^/[-_A-Za-z0-9]{1,64}/(publish|send|trigger)$`) - webConfigPath = "/config.js" - healthPath = "/v1/health" - accountPath = "/v1/account" - accountTokenPath = "/v1/account/token" - accountPasswordPath = "/v1/account/password" - accountSettingsPath = "/v1/account/settings" - accountSubscriptionPath = "/v1/account/subscription" - accountReservationPath = "/v1/account/reservation" - accountBillingPortalPath = "/v1/account/billing/portal" - accountBillingWebhookPath = "/v1/account/billing/webhook" - accountCheckoutPath = "/v1/account/checkout" - accountCheckoutSuccessTemplate = "/v1/account/checkout/success/{CHECKOUT_SESSION_ID}" - accountCheckoutSuccessRegex = regexp.MustCompile(`/v1/account/checkout/success/(.+)$`) - accountReservationSingleRegex = regexp.MustCompile(`/v1/account/reservation/([-_A-Za-z0-9]{1,64})$`) - accountSubscriptionSingleRegex = regexp.MustCompile(`^/v1/account/subscription/([-_A-Za-z0-9]{16})$`) - matrixPushPath = "/_matrix/push/v1/notify" - staticRegex = regexp.MustCompile(`^/static/.+`) - docsRegex = regexp.MustCompile(`^/docs(|/.*)$`) - fileRegex = regexp.MustCompile(`^/file/([-_A-Za-z0-9]{1,64})(?:\.[A-Za-z0-9]{1,16})?$`) - disallowedTopics = []string{"docs", "static", "file", "app", "account", "settings", "pricing", "signup", "login", "reset-password"} // If updated, also update in Android and web app - urlRegex = regexp.MustCompile(`^https?://`) + webConfigPath = "/config.js" + healthPath = "/v1/health" + accountPath = "/v1/account" + accountTokenPath = "/v1/account/token" + accountPasswordPath = "/v1/account/password" + accountSettingsPath = "/v1/account/settings" + accountSubscriptionPath = "/v1/account/subscription" + accountReservationPath = "/v1/account/reservation" + accountBillingPortalPath = "/v1/account/billing/portal" + accountBillingWebhookPath = "/v1/account/billing/webhook" + accountBillingSubscriptionPath = "/v1/account/billing/subscription" + accountBillingSubscriptionCheckoutSuccessTemplate = "/v1/account/billing/subscription/success/{CHECKOUT_SESSION_ID}" + accountBillingSubscriptionCheckoutSuccessRegex = regexp.MustCompile(`/v1/account/billing/subscription/success/(.+)$`) + accountReservationSingleRegex = regexp.MustCompile(`/v1/account/reservation/([-_A-Za-z0-9]{1,64})$`) + accountSubscriptionSingleRegex = regexp.MustCompile(`^/v1/account/subscription/([-_A-Za-z0-9]{16})$`) + matrixPushPath = "/_matrix/push/v1/notify" + staticRegex = regexp.MustCompile(`^/static/.+`) + docsRegex = regexp.MustCompile(`^/docs(|/.*)$`) + fileRegex = regexp.MustCompile(`^/file/([-_A-Za-z0-9]{1,64})(?:\.[A-Za-z0-9]{1,16})?$`) + disallowedTopics = []string{"docs", "static", "file", "app", "account", "settings", "pricing", "signup", "login", "reset-password"} // If updated, also update in Android and web app + urlRegex = regexp.MustCompile(`^https?://`) //go:embed site webFs embed.FS @@ -372,14 +377,16 @@ func (s *Server) handleInternal(w http.ResponseWriter, r *http.Request, v *visit return s.ensureUser(s.handleAccountReservationAdd)(w, r, v) } else if r.Method == http.MethodDelete && accountReservationSingleRegex.MatchString(r.URL.Path) { return s.ensureUser(s.handleAccountReservationDelete)(w, r, v) - } else if r.Method == http.MethodPost && r.URL.Path == accountCheckoutPath { - return s.ensureUser(s.handleAccountCheckoutSessionCreate)(w, r, v) - } else if r.Method == http.MethodGet && accountCheckoutSuccessRegex.MatchString(r.URL.Path) { + } else if r.Method == http.MethodPost && r.URL.Path == accountBillingSubscriptionPath { + return s.ensureUser(s.handleAccountBillingSubscriptionChange)(w, r, v) + } else if r.Method == http.MethodDelete && r.URL.Path == accountBillingSubscriptionPath { + return s.ensureStripeCustomer(s.handleAccountBillingSubscriptionDelete)(w, r, v) + } else if r.Method == http.MethodGet && accountBillingSubscriptionCheckoutSuccessRegex.MatchString(r.URL.Path) { return s.ensureUserManager(s.handleAccountCheckoutSessionSuccessGet)(w, r, v) // No user context! } else if r.Method == http.MethodPost && r.URL.Path == accountBillingPortalPath { - return s.ensureUser(s.handleAccountBillingPortalSessionCreate)(w, r, v) + return s.ensureStripeCustomer(s.handleAccountBillingPortalSessionCreate)(w, r, v) } else if r.Method == http.MethodPost && r.URL.Path == accountBillingWebhookPath { - return s.ensureUserManager(s.handleAccountBillingWebhookTrigger)(w, r, v) + return s.ensureUserManager(s.handleAccountBillingWebhook)(w, r, v) } else if r.Method == http.MethodGet && r.URL.Path == matrixPushPath { return s.handleMatrixDiscovery(w) } else if r.Method == http.MethodGet && staticRegex.MatchString(r.URL.Path) { @@ -1493,6 +1500,15 @@ func (s *Server) ensureUser(next handleFunc) handleFunc { }) } +func (s *Server) ensureStripeCustomer(next handleFunc) handleFunc { + return s.ensureUser(func(w http.ResponseWriter, r *http.Request, v *visitor) error { + if v.user.Billing.StripeCustomerID == "" { + return errHTTPBadRequestNotAPaidUser + } + return next(w, r, v) + }) +} + // transformBodyJSON peeks the request body, reads the JSON, and converts it to headers // before passing it on to the next handler. This is meant to be used in combination with handlePublish. func (s *Server) transformBodyJSON(next handleFunc) handleFunc { diff --git a/server/server_account.go b/server/server_account.go index 27c7f40c..fe7d4c11 100644 --- a/server/server_account.go +++ b/server/server_account.go @@ -2,14 +2,6 @@ package server import ( "encoding/json" - "errors" - "github.com/stripe/stripe-go/v74" - portalsession "github.com/stripe/stripe-go/v74/billingportal/session" - "github.com/stripe/stripe-go/v74/checkout/session" - "github.com/stripe/stripe-go/v74/subscription" - "github.com/stripe/stripe-go/v74/webhook" - "github.com/tidwall/gjson" - "heckel.io/ntfy/log" "heckel.io/ntfy/user" "heckel.io/ntfy/util" "net/http" @@ -17,7 +9,6 @@ import ( const ( jsonBodyBytesLimit = 4096 - stripeBodyBytesLimit = 16384 subscriptionIDLength = 16 createdByAPI = "api" ) @@ -100,6 +91,14 @@ func (s *Server) handleAccountGet(w http.ResponseWriter, _ *http.Request, v *vis Paid: v.user.Tier.Paid, } } + if v.user.Billing.StripeCustomerID != "" { + response.Billing = &apiAccountBilling{ + Customer: true, + Subscription: v.user.Billing.StripeSubscriptionID != "", + Status: string(v.user.Billing.StripeSubscriptionStatus), + PaidUntil: v.user.Billing.StripeSubscriptionPaidUntil.Unix(), + } + } reservations, err := s.userManager.Reservations(v.user.Name) if err != nil { return err @@ -395,226 +394,3 @@ func (s *Server) handleAccountReservationDelete(w http.ResponseWriter, r *http.R w.Header().Set("Access-Control-Allow-Origin", "*") // FIXME remove this return nil } - -func (s *Server) handleAccountCheckoutSessionCreate(w http.ResponseWriter, r *http.Request, v *visitor) error { - req, err := readJSONWithLimit[apiAccountTierChangeRequest](r.Body, jsonBodyBytesLimit) - if err != nil { - return err - } - tier, err := s.userManager.Tier(req.Tier) - if err != nil { - return err - } - if tier.StripePriceID == "" { - log.Info("Checkout: Downgrading to no tier") - return errors.New("not a paid tier") - } else if v.user.Billing != nil && v.user.Billing.StripeSubscriptionID != "" { - log.Info("Checkout: Changing tier and subscription to %s", tier.Code) - - // Upgrade/downgrade tier - sub, err := subscription.Get(v.user.Billing.StripeSubscriptionID, nil) - if err != nil { - return err - } - params := &stripe.SubscriptionParams{ - CancelAtPeriodEnd: stripe.Bool(false), - ProrationBehavior: stripe.String(string(stripe.SubscriptionSchedulePhaseProrationBehaviorCreateProrations)), - Items: []*stripe.SubscriptionItemsParams{ - { - ID: stripe.String(sub.Items.Data[0].ID), - Price: stripe.String(tier.StripePriceID), - }, - }, - } - _, err = subscription.Update(sub.ID, params) - if err != nil { - return err - } - response := &apiAccountCheckoutResponse{} - w.Header().Set("Content-Type", "application/json") - w.Header().Set("Access-Control-Allow-Origin", "*") // FIXME remove this - if err := json.NewEncoder(w).Encode(response); err != nil { - return err - } - return nil - } else { - // Checkout flow - log.Info("Checkout: No existing subscription, creating checkout flow") - } - - successURL := s.config.BaseURL + accountCheckoutSuccessTemplate - var stripeCustomerID *string - if v.user.Billing != nil { - stripeCustomerID = &v.user.Billing.StripeCustomerID - } - params := &stripe.CheckoutSessionParams{ - ClientReferenceID: &v.user.Name, // FIXME Should be user ID - Customer: stripeCustomerID, - SuccessURL: &successURL, - Mode: stripe.String(string(stripe.CheckoutSessionModeSubscription)), - LineItems: []*stripe.CheckoutSessionLineItemParams{ - { - Price: stripe.String(tier.StripePriceID), - Quantity: stripe.Int64(1), - }, - }, - } - sess, err := session.New(params) - if err != nil { - return err - } - response := &apiAccountCheckoutResponse{ - RedirectURL: sess.URL, - } - w.Header().Set("Content-Type", "application/json") - w.Header().Set("Access-Control-Allow-Origin", "*") // FIXME remove this - if err := json.NewEncoder(w).Encode(response); err != nil { - return err - } - return nil -} - -func (s *Server) handleAccountCheckoutSessionSuccessGet(w http.ResponseWriter, r *http.Request, v *visitor) error { - // We don't have a v.user in this endpoint, only a userManager! - matches := accountCheckoutSuccessRegex.FindStringSubmatch(r.URL.Path) - if len(matches) != 2 { - return errHTTPInternalErrorInvalidPath - } - sessionID := matches[1] - // FIXME how do I rate limit this? - sess, err := session.Get(sessionID, nil) - if err != nil { - log.Warn("Stripe: %s", err) - return errHTTPBadRequestInvalidStripeRequest - } else if sess.Customer == nil || sess.Subscription == nil || sess.ClientReferenceID == "" { - log.Warn("Stripe: Unexpected session, customer or subscription not found") - return errHTTPBadRequestInvalidStripeRequest - } - sub, err := subscription.Get(sess.Subscription.ID, nil) - if err != nil { - return err - } else if sub.Items == nil || len(sub.Items.Data) != 1 || sub.Items.Data[0].Price == nil { - log.Error("Stripe: Unexpected subscription, expected exactly one line item") - return errHTTPBadRequestInvalidStripeRequest - } - priceID := sub.Items.Data[0].Price.ID - tier, err := s.userManager.TierByStripePrice(priceID) - if err != nil { - return err - } - u, err := s.userManager.User(sess.ClientReferenceID) - if err != nil { - return err - } - if u.Billing == nil { - u.Billing = &user.Billing{} - } - u.Billing.StripeCustomerID = sess.Customer.ID - u.Billing.StripeSubscriptionID = sess.Subscription.ID - if err := s.userManager.ChangeBilling(u); err != nil { - return err - } - if err := s.userManager.ChangeTier(u.Name, tier.Code); err != nil { - return err - } - accountURL := s.config.BaseURL + "/account" // FIXME - http.Redirect(w, r, accountURL, http.StatusSeeOther) - return nil -} - -func (s *Server) handleAccountBillingPortalSessionCreate(w http.ResponseWriter, r *http.Request, v *visitor) error { - if v.user.Billing == nil { - return errHTTPBadRequestNotAPaidUser - } - params := &stripe.BillingPortalSessionParams{ - Customer: stripe.String(v.user.Billing.StripeCustomerID), - ReturnURL: stripe.String(s.config.BaseURL), - } - ps, err := portalsession.New(params) - if err != nil { - return err - } - response := &apiAccountBillingPortalRedirectResponse{ - RedirectURL: ps.URL, - } - w.Header().Set("Content-Type", "application/json") - w.Header().Set("Access-Control-Allow-Origin", "*") // FIXME remove this - if err := json.NewEncoder(w).Encode(response); err != nil { - return err - } - return nil -} - -func (s *Server) handleAccountBillingWebhookTrigger(w http.ResponseWriter, r *http.Request, v *visitor) error { - // We don't have a v.user in this endpoint, only a userManager! - stripeSignature := r.Header.Get("Stripe-Signature") - if stripeSignature == "" { - return errHTTPBadRequestInvalidStripeRequest - } - body, err := util.Peek(r.Body, stripeBodyBytesLimit) - if err != nil { - return err - } else if body.LimitReached { - return errHTTPEntityTooLargeJSONBody - } - event, err := webhook.ConstructEvent(body.PeekedBytes, stripeSignature, s.config.StripeWebhookKey) - if err != nil { - log.Warn("Stripe: invalid request: %s", err.Error()) - return errHTTPBadRequestInvalidStripeRequest - } else if event.Data == nil || event.Data.Raw == nil { - log.Warn("Stripe: invalid request, data is nil") - return errHTTPBadRequestInvalidStripeRequest - } - log.Info("Stripe: webhook event %s received", event.Type) - stripeCustomerID := gjson.GetBytes(event.Data.Raw, "customer") - if !stripeCustomerID.Exists() { - return errHTTPBadRequestInvalidStripeRequest - } - switch event.Type { - case "checkout.session.completed": - // Payment is successful and the subscription is created. - // Provision the subscription, save the customer ID. - return s.handleAccountBillingWebhookCheckoutCompleted(stripeCustomerID.String(), event.Data.Raw) - case "customer.subscription.updated": - return s.handleAccountBillingWebhookSubscriptionUpdated(stripeCustomerID.String(), event.Data.Raw) - case "invoice.paid": - // Continue to provision the subscription as payments continue to be made. - // Store the status in your database and check when a user accesses your service. - // This approach helps you avoid hitting rate limits. - return nil // FIXME - case "invoice.payment_failed": - // The payment failed or the customer does not have a valid payment method. - // The subscription becomes past_due. Notify your customer and send them to the - // customer portal to update their payment information. - return nil // FIXME - default: - log.Warn("Stripe: unhandled webhook %s", event.Type) - return nil - } -} - -func (s *Server) handleAccountBillingWebhookCheckoutCompleted(stripeCustomerID string, event json.RawMessage) error { - log.Info("Stripe: checkout completed for customer %s", stripeCustomerID) - return nil -} - -func (s *Server) handleAccountBillingWebhookSubscriptionUpdated(stripeCustomerID string, event json.RawMessage) error { - status := gjson.GetBytes(event, "status") - priceID := gjson.GetBytes(event, "items.data.0.price.id") - if !status.Exists() || !priceID.Exists() { - return errHTTPBadRequestInvalidStripeRequest - } - log.Info("Stripe: customer %s: subscription updated to %s, with price %s", stripeCustomerID, status, priceID) - u, err := s.userManager.UserByStripeCustomer(stripeCustomerID) - if err != nil { - return err - } - tier, err := s.userManager.TierByStripePrice(priceID.String()) - if err != nil { - return err - } - if err := s.userManager.ChangeTier(u.Name, tier.Code); err != nil { - return err - } - return nil -} diff --git a/server/server_payments.go b/server/server_payments.go new file mode 100644 index 00000000..81e52217 --- /dev/null +++ b/server/server_payments.go @@ -0,0 +1,287 @@ +package server + +import ( + "encoding/json" + "errors" + "github.com/stripe/stripe-go/v74" + portalsession "github.com/stripe/stripe-go/v74/billingportal/session" + "github.com/stripe/stripe-go/v74/checkout/session" + "github.com/stripe/stripe-go/v74/subscription" + "github.com/stripe/stripe-go/v74/webhook" + "github.com/tidwall/gjson" + "heckel.io/ntfy/log" + "heckel.io/ntfy/user" + "heckel.io/ntfy/util" + "net/http" + "time" +) + +const ( + stripeBodyBytesLimit = 16384 +) + +// handleAccountBillingSubscriptionChange facilitates all subscription/tier changes, including payment flows. +// +// FIXME this should be two functions! +// +// It handles two cases: +// - Create subscription: Transition from a user without Stripe subscription to a paid subscription (Checkout flow) +// - Change subscription: Switching between Stripe prices (& tiers) by changing the Stripe subscription +func (s *Server) handleAccountBillingSubscriptionChange(w http.ResponseWriter, r *http.Request, v *visitor) error { + req, err := readJSONWithLimit[apiAccountTierChangeRequest](r.Body, jsonBodyBytesLimit) + if err != nil { + return err + } + tier, err := s.userManager.Tier(req.Tier) + if err != nil { + return err + } + if v.user.Billing.StripeSubscriptionID == "" && tier.StripePriceID != "" { + return s.handleAccountBillingSubscriptionAdd(w, v, tier) + } else if v.user.Billing.StripeSubscriptionID != "" { + return s.handleAccountBillingSubscriptionUpdate(w, v, tier) + } + return errors.New("invalid state") +} + +// handleAccountBillingSubscriptionDelete facilitates downgrading a paid user to a tier-less user, +// and cancelling the Stripe subscription entirely +func (s *Server) handleAccountBillingSubscriptionDelete(w http.ResponseWriter, r *http.Request, v *visitor) error { + if v.user.Billing.StripeCustomerID == "" { + return errHTTPBadRequestNotAPaidUser + } + if v.user.Billing.StripeSubscriptionID != "" { + _, err := subscription.Cancel(v.user.Billing.StripeSubscriptionID, nil) + if err != nil { + return err + } + } + if err := s.userManager.ResetTier(v.user.Name); err != nil { + return err + } + v.user.Billing.StripeSubscriptionID = "" + v.user.Billing.StripeSubscriptionStatus = "" + v.user.Billing.StripeSubscriptionPaidUntil = time.Unix(0, 0) + if err := s.userManager.ChangeBilling(v.user); err != nil { + return err + } + return nil +} + +func (s *Server) handleAccountBillingSubscriptionAdd(w http.ResponseWriter, v *visitor, tier *user.Tier) error { + log.Info("Stripe: No existing subscription, creating checkout flow") + var stripeCustomerID *string + if v.user.Billing.StripeCustomerID != "" { + stripeCustomerID = &v.user.Billing.StripeCustomerID + } + successURL := s.config.BaseURL + accountBillingSubscriptionCheckoutSuccessTemplate + params := &stripe.CheckoutSessionParams{ + Customer: stripeCustomerID, // A user may have previously deleted their subscription + ClientReferenceID: &v.user.Name, // FIXME Should be user ID + SuccessURL: &successURL, + Mode: stripe.String(string(stripe.CheckoutSessionModeSubscription)), + LineItems: []*stripe.CheckoutSessionLineItemParams{ + { + Price: stripe.String(tier.StripePriceID), + Quantity: stripe.Int64(1), + }, + }, + /*AutomaticTax: &stripe.CheckoutSessionAutomaticTaxParams{ + Enabled: stripe.Bool(true), + },*/ + } + sess, err := session.New(params) + if err != nil { + return err + } + response := &apiAccountCheckoutResponse{ + RedirectURL: sess.URL, + } + w.Header().Set("Content-Type", "application/json") + w.Header().Set("Access-Control-Allow-Origin", "*") // FIXME remove this + if err := json.NewEncoder(w).Encode(response); err != nil { + return err + } + return nil +} + +func (s *Server) handleAccountBillingSubscriptionUpdate(w http.ResponseWriter, v *visitor, tier *user.Tier) error { + log.Info("Stripe: Changing tier and subscription to %s", tier.Code) + sub, err := subscription.Get(v.user.Billing.StripeSubscriptionID, nil) + if err != nil { + return err + } + params := &stripe.SubscriptionParams{ + CancelAtPeriodEnd: stripe.Bool(false), + ProrationBehavior: stripe.String(string(stripe.SubscriptionSchedulePhaseProrationBehaviorCreateProrations)), + Items: []*stripe.SubscriptionItemsParams{ + { + ID: stripe.String(sub.Items.Data[0].ID), + Price: stripe.String(tier.StripePriceID), + }, + }, + } + _, err = subscription.Update(sub.ID, params) + if err != nil { + return err + } + response := &apiAccountCheckoutResponse{} + w.Header().Set("Content-Type", "application/json") + w.Header().Set("Access-Control-Allow-Origin", "*") // FIXME remove this + if err := json.NewEncoder(w).Encode(response); err != nil { + return err + } + return nil +} + +func (s *Server) handleAccountCheckoutSessionSuccessGet(w http.ResponseWriter, r *http.Request, v *visitor) error { + // We don't have a v.user in this endpoint, only a userManager! + matches := accountBillingSubscriptionCheckoutSuccessRegex.FindStringSubmatch(r.URL.Path) + if len(matches) != 2 { + return errHTTPInternalErrorInvalidPath + } + sessionID := matches[1] + // FIXME how do I rate limit this? + sess, err := session.Get(sessionID, nil) + if err != nil { + log.Warn("Stripe: %s", err) + return errHTTPBadRequestInvalidStripeRequest + } else if sess.Customer == nil || sess.Subscription == nil || sess.ClientReferenceID == "" { + log.Warn("Stripe: Unexpected session, customer or subscription not found") + return errHTTPBadRequestInvalidStripeRequest + } + sub, err := subscription.Get(sess.Subscription.ID, nil) + if err != nil { + return err + } else if sub.Items == nil || len(sub.Items.Data) != 1 || sub.Items.Data[0].Price == nil { + log.Error("Stripe: Unexpected subscription, expected exactly one line item") + return errHTTPBadRequestInvalidStripeRequest + } + priceID := sub.Items.Data[0].Price.ID + tier, err := s.userManager.TierByStripePrice(priceID) + if err != nil { + return err + } + u, err := s.userManager.User(sess.ClientReferenceID) + if err != nil { + return err + } + u.Billing.StripeCustomerID = sess.Customer.ID + u.Billing.StripeSubscriptionID = sub.ID + u.Billing.StripeSubscriptionStatus = sub.Status + u.Billing.StripeSubscriptionPaidUntil = time.Unix(sub.CurrentPeriodEnd, 0) + if err := s.userManager.ChangeBilling(u); err != nil { + return err + } + if err := s.userManager.ChangeTier(u.Name, tier.Code); err != nil { + return err + } + accountURL := s.config.BaseURL + "/account" // FIXME + http.Redirect(w, r, accountURL, http.StatusSeeOther) + return nil +} + +func (s *Server) handleAccountBillingPortalSessionCreate(w http.ResponseWriter, r *http.Request, v *visitor) error { + if v.user.Billing.StripeCustomerID == "" { + return errHTTPBadRequestNotAPaidUser + } + params := &stripe.BillingPortalSessionParams{ + Customer: stripe.String(v.user.Billing.StripeCustomerID), + ReturnURL: stripe.String(s.config.BaseURL), + } + ps, err := portalsession.New(params) + if err != nil { + return err + } + response := &apiAccountBillingPortalRedirectResponse{ + RedirectURL: ps.URL, + } + w.Header().Set("Content-Type", "application/json") + w.Header().Set("Access-Control-Allow-Origin", "*") // FIXME remove this + if err := json.NewEncoder(w).Encode(response); err != nil { + return err + } + return nil +} + +func (s *Server) handleAccountBillingWebhook(w http.ResponseWriter, r *http.Request, v *visitor) error { + // We don't have a v.user in this endpoint, only a userManager! + stripeSignature := r.Header.Get("Stripe-Signature") + if stripeSignature == "" { + return errHTTPBadRequestInvalidStripeRequest + } + body, err := util.Peek(r.Body, stripeBodyBytesLimit) + if err != nil { + return err + } else if body.LimitReached { + return errHTTPEntityTooLargeJSONBody + } + event, err := webhook.ConstructEvent(body.PeekedBytes, stripeSignature, s.config.StripeWebhookKey) + if err != nil { + return errHTTPBadRequestInvalidStripeRequest + } else if event.Data == nil || event.Data.Raw == nil { + return errHTTPBadRequestInvalidStripeRequest + } + log.Info("Stripe: webhook event %s received", event.Type) + stripeCustomerID := gjson.GetBytes(event.Data.Raw, "customer") + if !stripeCustomerID.Exists() { + return errHTTPBadRequestInvalidStripeRequest + } + switch event.Type { + case "customer.subscription.updated": + return s.handleAccountBillingWebhookSubscriptionUpdated(stripeCustomerID.String(), event.Data.Raw) + case "customer.subscription.deleted": + return s.handleAccountBillingWebhookSubscriptionDeleted(stripeCustomerID.String(), event.Data.Raw) + default: + return nil + } +} + +func (s *Server) handleAccountBillingWebhookSubscriptionUpdated(stripeCustomerID string, event json.RawMessage) error { + status := gjson.GetBytes(event, "status") + currentPeriodEnd := gjson.GetBytes(event, "current_period_end") + priceID := gjson.GetBytes(event, "items.data.0.price.id") + if !status.Exists() || !currentPeriodEnd.Exists() || !priceID.Exists() { + return errHTTPBadRequestInvalidStripeRequest + } + log.Info("Stripe: customer %s: subscription updated to %s, with price %s", stripeCustomerID, status, priceID) + u, err := s.userManager.UserByStripeCustomer(stripeCustomerID) + if err != nil { + return err + } + tier, err := s.userManager.TierByStripePrice(priceID.String()) + if err != nil { + return err + } + if err := s.userManager.ChangeTier(u.Name, tier.Code); err != nil { + return err + } + u.Billing.StripeSubscriptionStatus = stripe.SubscriptionStatus(status.String()) + u.Billing.StripeSubscriptionPaidUntil = time.Unix(currentPeriodEnd.Int(), 0) + if err := s.userManager.ChangeBilling(u); err != nil { + return err + } + return nil +} + +func (s *Server) handleAccountBillingWebhookSubscriptionDeleted(stripeCustomerID string, event json.RawMessage) error { + status := gjson.GetBytes(event, "status") + if !status.Exists() { + return errHTTPBadRequestInvalidStripeRequest + } + log.Info("Stripe: customer %s: subscription deleted, downgrading to unpaid tier", stripeCustomerID) + u, err := s.userManager.UserByStripeCustomer(stripeCustomerID) + if err != nil { + return err + } + if err := s.userManager.ResetTier(u.Name); err != nil { + return err + } + u.Billing.StripeSubscriptionID = "" + u.Billing.StripeSubscriptionStatus = "" + u.Billing.StripeSubscriptionPaidUntil = time.Unix(0, 0) + if err := s.userManager.ChangeBilling(u); err != nil { + return err + } + return nil +} diff --git a/server/types.go b/server/types.go index e6ebc28f..cee114dc 100644 --- a/server/types.go +++ b/server/types.go @@ -268,6 +268,13 @@ type apiAccountReservation struct { Everyone string `json:"everyone"` } +type apiAccountBilling struct { + Customer bool `json:"customer"` + Subscription bool `json:"subscription"` + Status string `json:"status,omitempty"` + PaidUntil int64 `json:"paid_until,omitempty"` +} + type apiAccountResponse struct { Username string `json:"username"` Role string `json:"role,omitempty"` @@ -279,6 +286,7 @@ type apiAccountResponse struct { Tier *apiAccountTier `json:"tier,omitempty"` Limits *apiAccountLimits `json:"limits,omitempty"` Stats *apiAccountStats `json:"stats,omitempty"` + Billing *apiAccountBilling `json:"billing,omitempty"` } type apiAccountReservationRequest struct { diff --git a/user/manager.go b/user/manager.go index f440838a..7e50b4a1 100644 --- a/user/manager.go +++ b/user/manager.go @@ -6,6 +6,7 @@ import ( "errors" "fmt" _ "github.com/mattn/go-sqlite3" // SQLite driver + "github.com/stripe/stripe-go/v74" "golang.org/x/crypto/bcrypt" "heckel.io/ntfy/log" "heckel.io/ntfy/util" @@ -60,7 +61,9 @@ const ( stats_messages INT NOT NULL DEFAULT (0), stats_emails INT NOT NULL DEFAULT (0), stripe_customer_id TEXT, - stripe_subscription_id TEXT, + stripe_subscription_id TEXT, + stripe_subscription_status TEXT, + stripe_subscription_paid_until INT, created_by TEXT NOT NULL, created_at INT NOT NULL, last_seen INT NOT NULL, @@ -100,20 +103,20 @@ const ( ` selectUserByNameQuery = ` - SELECT u.user, u.pass, u.role, u.prefs, u.sync_topic, u.stats_messages, u.stats_emails, u.stripe_customer_id, u.stripe_subscription_id, p.code, p.name, p.paid, p.messages_limit, p.messages_expiry_duration, p.emails_limit, p.reservations_limit, p.attachment_file_size_limit, p.attachment_total_size_limit, p.attachment_expiry_duration, p.stripe_price_id + SELECT u.user, u.pass, u.role, u.prefs, u.sync_topic, u.stats_messages, u.stats_emails, u.stripe_customer_id, u.stripe_subscription_id, u.stripe_subscription_status, u.stripe_subscription_paid_until, p.code, p.name, p.paid, p.messages_limit, p.messages_expiry_duration, p.emails_limit, p.reservations_limit, p.attachment_file_size_limit, p.attachment_total_size_limit, p.attachment_expiry_duration, p.stripe_price_id FROM user u LEFT JOIN tier p on p.id = u.tier_id WHERE user = ? ` selectUserByTokenQuery = ` - SELECT u.user, u.pass, u.role, u.prefs, u.sync_topic, u.stats_messages, u.stats_emails, u.stripe_customer_id, u.stripe_subscription_id, p.code, p.name, p.paid, p.messages_limit, p.messages_expiry_duration, p.emails_limit, p.reservations_limit, p.attachment_file_size_limit, p.attachment_total_size_limit, p.attachment_expiry_duration, p.stripe_price_id + SELECT u.user, u.pass, u.role, u.prefs, u.sync_topic, u.stats_messages, u.stats_emails, u.stripe_customer_id, u.stripe_subscription_id, u.stripe_subscription_status, u.stripe_subscription_paid_until, p.code, p.name, p.paid, p.messages_limit, p.messages_expiry_duration, p.emails_limit, p.reservations_limit, p.attachment_file_size_limit, p.attachment_total_size_limit, p.attachment_expiry_duration, p.stripe_price_id FROM user u JOIN user_token t on u.id = t.user_id LEFT JOIN tier p on p.id = u.tier_id WHERE t.token = ? AND t.expires >= ? ` selectUserByStripeCustomerIDQuery = ` - SELECT u.user, u.pass, u.role, u.prefs, u.sync_topic, u.stats_messages, u.stats_emails, u.stripe_customer_id, u.stripe_subscription_id, p.code, p.name, p.paid, p.messages_limit, p.messages_expiry_duration, p.emails_limit, p.reservations_limit, p.attachment_file_size_limit, p.attachment_total_size_limit, p.attachment_expiry_duration, p.stripe_price_id + SELECT u.user, u.pass, u.role, u.prefs, u.sync_topic, u.stats_messages, u.stats_emails, u.stripe_customer_id, u.stripe_subscription_id, u.stripe_subscription_status, u.stripe_subscription_paid_until, p.code, p.name, p.paid, p.messages_limit, p.messages_expiry_duration, p.emails_limit, p.reservations_limit, p.attachment_file_size_limit, p.attachment_total_size_limit, p.attachment_expiry_duration, p.stripe_price_id FROM user u LEFT JOIN tier p on p.id = u.tier_id WHERE u.stripe_customer_id = ? @@ -231,7 +234,11 @@ const ( updateUserTierQuery = `UPDATE user SET tier_id = ? WHERE user = ?` deleteUserTierQuery = `UPDATE user SET tier_id = null WHERE user = ?` - updateBillingQuery = `UPDATE user SET stripe_customer_id = ?, stripe_subscription_id = ? WHERE user = ?` + updateBillingQuery = ` + UPDATE user + SET stripe_customer_id = ?, stripe_subscription_id = ?, stripe_subscription_status = ?, stripe_subscription_paid_until = ? + WHERE user = ? + ` ) // Schema management queries @@ -597,14 +604,14 @@ func (a *Manager) userByToken(token string) (*User, error) { func (a *Manager) readUser(rows *sql.Rows) (*User, error) { defer rows.Close() var username, hash, role, prefs, syncTopic string - var stripeCustomerID, stripeSubscriptionID, stripePriceID, tierCode, tierName sql.NullString + var stripeCustomerID, stripeSubscriptionID, stripeSubscriptionStatus, stripePriceID, tierCode, tierName sql.NullString var paid sql.NullBool var messages, emails int64 - var messagesLimit, messagesExpiryDuration, emailsLimit, reservationsLimit, attachmentFileSizeLimit, attachmentTotalSizeLimit, attachmentExpiryDuration sql.NullInt64 + var messagesLimit, messagesExpiryDuration, emailsLimit, reservationsLimit, attachmentFileSizeLimit, attachmentTotalSizeLimit, attachmentExpiryDuration, stripeSubscriptionPaidUntil sql.NullInt64 if !rows.Next() { return nil, ErrUserNotFound } - if err := rows.Scan(&username, &hash, &role, &prefs, &syncTopic, &messages, &emails, &stripeCustomerID, &stripeSubscriptionID, &tierCode, &tierName, &paid, &messagesLimit, &messagesExpiryDuration, &emailsLimit, &reservationsLimit, &attachmentFileSizeLimit, &attachmentTotalSizeLimit, &attachmentExpiryDuration, &stripePriceID); err != nil { + if err := rows.Scan(&username, &hash, &role, &prefs, &syncTopic, &messages, &emails, &stripeCustomerID, &stripeSubscriptionID, &stripeSubscriptionStatus, &stripeSubscriptionPaidUntil, &tierCode, &tierName, &paid, &messagesLimit, &messagesExpiryDuration, &emailsLimit, &reservationsLimit, &attachmentFileSizeLimit, &attachmentTotalSizeLimit, &attachmentExpiryDuration, &stripePriceID); err != nil { return nil, err } else if err := rows.Err(); err != nil { return nil, err @@ -619,16 +626,16 @@ func (a *Manager) readUser(rows *sql.Rows) (*User, error) { Messages: messages, Emails: emails, }, + Billing: &Billing{ + StripeCustomerID: stripeCustomerID.String, // May be empty + StripeSubscriptionID: stripeSubscriptionID.String, // May be empty + StripeSubscriptionStatus: stripe.SubscriptionStatus(stripeSubscriptionStatus.String), // May be empty + StripeSubscriptionPaidUntil: time.Unix(stripeSubscriptionPaidUntil.Int64, 0), // May be zero + }, } if err := json.Unmarshal([]byte(prefs), user.Prefs); err != nil { return nil, err } - if stripeCustomerID.Valid && stripeSubscriptionID.Valid { - user.Billing = &Billing{ - StripeCustomerID: stripeCustomerID.String, - StripeSubscriptionID: stripeSubscriptionID.String, - } - } if tierCode.Valid { // See readTier() when this is changed! user.Tier = &Tier{ @@ -868,7 +875,7 @@ func (a *Manager) CreateTier(tier *Tier) error { } func (a *Manager) ChangeBilling(user *User) error { - if _, err := a.db.Exec(updateBillingQuery, user.Billing.StripeCustomerID, user.Billing.StripeSubscriptionID, user.Name); err != nil { + if _, err := a.db.Exec(updateBillingQuery, nullString(user.Billing.StripeCustomerID), nullString(user.Billing.StripeSubscriptionID), nullString(string(user.Billing.StripeSubscriptionStatus)), nullInt64(user.Billing.StripeSubscriptionPaidUntil.Unix()), user.Name); err != nil { return err } return nil @@ -1020,3 +1027,17 @@ func migrateFrom1(db *sql.DB) error { } return nil // Update this when a new version is added } + +func nullString(s string) sql.NullString { + if s == "" { + return sql.NullString{} + } + return sql.NullString{String: s, Valid: true} +} + +func nullInt64(v int64) sql.NullInt64 { + if v == 0 { + return sql.NullInt64{} + } + return sql.NullInt64{Int64: v, Valid: true} +} diff --git a/user/types.go b/user/types.go index 77a34749..e9a689fa 100644 --- a/user/types.go +++ b/user/types.go @@ -3,6 +3,7 @@ package user import ( "errors" + "github.com/stripe/stripe-go/v74" "regexp" "time" ) @@ -85,8 +86,10 @@ type Stats struct { // Billing is a struct holding a user's billing information type Billing struct { - StripeCustomerID string - StripeSubscriptionID string + StripeCustomerID string + StripeSubscriptionID string + StripeSubscriptionStatus stripe.SubscriptionStatus + StripeSubscriptionPaidUntil time.Time } // Grant is a struct that represents an access control entry to a topic by a user @@ -223,3 +226,13 @@ var ( ErrUserNotFound = errors.New("user not found") ErrTierNotFound = errors.New("tier not found") ) + +// BillingStatus represents the status of a Stripe subscription +type BillingStatus string + +// BillingStatus values, subset of https://stripe.com/docs/billing/subscriptions/overview +const ( + BillingStatusIncomplete = BillingStatus("incomplete") + BillingStatusActive = BillingStatus("active") + BillingStatusPastDue = BillingStatus("past_due") +) diff --git a/web/public/static/langs/en.json b/web/public/static/langs/en.json index 0efc0a1c..f18aedfb 100644 --- a/web/public/static/langs/en.json +++ b/web/public/static/langs/en.json @@ -183,6 +183,8 @@ "account_usage_tier_none": "Basic", "account_usage_tier_upgrade_button": "Upgrade to Pro", "account_usage_tier_change_button": "Change", + "account_usage_tier_payment_overdue": "Your payment is overdue. Please update your payment method, or your account will be downgraded soon.", + "account_usage_manage_billing_button": "Manage billing", "account_usage_messages_title": "Published messages", "account_usage_emails_title": "Emails sent", "account_usage_reservations_title": "Reserved topics", diff --git a/web/src/app/AccountApi.js b/web/src/app/AccountApi.js index 38adfffb..fe918b45 100644 --- a/web/src/app/AccountApi.js +++ b/web/src/app/AccountApi.js @@ -8,7 +8,7 @@ import { accountTokenUrl, accountUrl, maybeWithAuth, topicUrl, withBasicAuth, - withBearerAuth, accountCheckoutUrl, accountBillingPortalUrl + withBearerAuth, accountBillingSubscriptionUrl, accountBillingPortalUrl } from "./utils"; import session from "./Session"; import subscriptionManager from "./SubscriptionManager"; @@ -264,9 +264,9 @@ class AccountApi { this.triggerChange(); // Dangle! } - async createCheckoutSession(tier) { - const url = accountCheckoutUrl(config.base_url); - console.log(`[AccountApi] Creating checkout session`); + async updateBillingSubscription(tier) { + const url = accountBillingSubscriptionUrl(config.base_url); + console.log(`[AccountApi] Requesting tier change to ${tier}`); const response = await fetch(url, { method: "POST", headers: withBearerAuth({}, session.token()), @@ -282,6 +282,20 @@ class AccountApi { return await response.json(); } + async deleteBillingSubscription() { + const url = accountBillingSubscriptionUrl(config.base_url); + console.log(`[AccountApi] Cancelling paid subscription`); + const response = await fetch(url, { + method: "DELETE", + headers: withBearerAuth({}, session.token()) + }); + if (response.status === 401 || response.status === 403) { + throw new UnauthorizedError(); + } else if (response.status !== 200) { + throw new Error(`Unexpected server response ${response.status}`); + } + } + async createBillingPortalSession() { const url = accountBillingPortalUrl(config.base_url); console.log(`[AccountApi] Creating billing portal session`); diff --git a/web/src/app/utils.js b/web/src/app/utils.js index 8001933e..8603ec55 100644 --- a/web/src/app/utils.js +++ b/web/src/app/utils.js @@ -26,7 +26,7 @@ export const accountSubscriptionUrl = (baseUrl) => `${baseUrl}/v1/account/subscr export const accountSubscriptionSingleUrl = (baseUrl, id) => `${baseUrl}/v1/account/subscription/${id}`; export const accountReservationUrl = (baseUrl) => `${baseUrl}/v1/account/reservation`; export const accountReservationSingleUrl = (baseUrl, topic) => `${baseUrl}/v1/account/reservation/${topic}`; -export const accountCheckoutUrl = (baseUrl) => `${baseUrl}/v1/account/checkout`; +export const accountBillingSubscriptionUrl = (baseUrl) => `${baseUrl}/v1/account/billing/subscription`; export const accountBillingPortalUrl = (baseUrl) => `${baseUrl}/v1/account/billing/portal`; export const shortUrl = (url) => url.replaceAll(/https?:\/\//g, ""); export const expandUrl = (url) => [`https://${url}`, `http://${url}`]; diff --git a/web/src/components/Account.js b/web/src/components/Account.js index 9e68ce94..8744dbfc 100644 --- a/web/src/components/Account.js +++ b/web/src/components/Account.js @@ -1,6 +1,6 @@ import * as React from 'react'; import {useContext, useState} from 'react'; -import {LinearProgress, Stack, useMediaQuery} from "@mui/material"; +import {Alert, LinearProgress, Stack, useMediaQuery} from "@mui/material"; import Tooltip from '@mui/material/Tooltip'; import Typography from "@mui/material/Typography"; import EditIcon from '@mui/icons-material/Edit'; @@ -18,7 +18,7 @@ import TextField from "@mui/material/TextField"; import DialogActions from "@mui/material/DialogActions"; import routes from "./routes"; import IconButton from "@mui/material/IconButton"; -import {formatBytes} from "../app/utils"; +import {formatBytes, formatShortDateTime} from "../app/utils"; import accountApi, {UnauthorizedError} from "../app/AccountApi"; import InfoOutlinedIcon from '@mui/icons-material/InfoOutlined'; import {Pref, PrefGroup} from "./Pref"; @@ -28,6 +28,7 @@ import humanizeDuration from "humanize-duration"; import UpgradeDialog from "./UpgradeDialog"; import CelebrationIcon from "@mui/icons-material/Celebration"; import {AccountContext} from "./App"; +import {Warning, WarningAmber} from "@mui/icons-material"; const Account = () => { if (!session.exists()) { @@ -183,7 +184,7 @@ const Stats = () => { const handleManageBilling = async () => { try { const response = await accountApi.createBillingPortalSession(); - window.location.href = response.redirect_url; + window.open(response.redirect_url, "billing_portal"); } catch (e) { console.log(`[Account] Error changing password`, e); if ((e instanceof UnauthorizedError)) { @@ -199,7 +200,10 @@ const Stats = () => { {t("account_usage_title")} - +
{account.role === "admin" && <> @@ -219,26 +223,29 @@ const Stats = () => { >{t("account_usage_tier_upgrade_button")} } {config.enable_payments && account.role === "user" && account.tier?.paid && - <> - - - + + } + {config.enable_payments && account.role === "user" && account.billing?.customer && + } setUpgradeDialogOpen(false)} />
+ {account.billing?.status === "past_due" && + {t("account_usage_tier_payment_overdue")} + }
{account.role !== "admin" && diff --git a/web/src/components/UpgradeDialog.js b/web/src/components/UpgradeDialog.js index 2204c6cf..a50fdb82 100644 --- a/web/src/components/UpgradeDialog.js +++ b/web/src/components/UpgradeDialog.js @@ -17,16 +17,20 @@ import {AccountContext} from "./App"; const UpgradeDialog = (props) => { const { account } = useContext(AccountContext); const fullScreen = useMediaQuery(theme.breakpoints.down('sm')); - const [selected, setSelected] = useState(account?.tier?.code || null); + const [newTier, setNewTier] = useState(account?.tier?.code || null); const [errorText, setErrorText] = useState(""); const handleCheckout = async () => { try { - const response = await accountApi.createCheckoutSession(selected); - if (response.redirect_url) { - window.location.href = response.redirect_url; + if (newTier == null) { + await accountApi.deleteBillingSubscription(); } else { - await accountApi.sync(); + const response = await accountApi.updateBillingSubscription(newTier); + if (response.redirect_url) { + window.location.href = response.redirect_url; + } else { + await accountApi.sync(); + } } } catch (e) { @@ -46,10 +50,10 @@ const UpgradeDialog = (props) => { display: "flex", flexDirection: "row" }}> - setSelected(null)}/> - setSelected("starter")}/> - setSelected("pro")}/> - setSelected("business")}/> + setNewTier(null)}/> + setNewTier("starter")}/> + setNewTier("pro")}/> + setNewTier("business")}/>