Payment stuff, cont'd

This commit is contained in:
binwiederhier 2023-01-15 23:29:46 -05:00
parent f7f7f469ad
commit c06bfb989e
11 changed files with 457 additions and 309 deletions

View file

@ -37,8 +37,13 @@ import (
/* /*
TODO TODO
payments: payments:
- handle overdue payment (-> downgrade after 7 days) - send dunning emails when overdue
- delete stripe subscription when acocunt is deleted - payment methods
- unmarshal to stripe.Subscription instead of gjson
- Make ResetTier reset the stripe fields
- delete subscription when account deleted
- remove tier.paid
- add tier.visible
Limits & rate limiting: Limits & rate limiting:
users without tier: should the stats be persisted? are they meaningful? users without tier: should the stats be persisted? are they meaningful?
@ -97,27 +102,27 @@ var (
authPathRegex = regexp.MustCompile(`^/[-_A-Za-z0-9]{1,64}(,[-_A-Za-z0-9]{1,64})*/auth$`) authPathRegex = regexp.MustCompile(`^/[-_A-Za-z0-9]{1,64}(,[-_A-Za-z0-9]{1,64})*/auth$`)
publishPathRegex = regexp.MustCompile(`^/[-_A-Za-z0-9]{1,64}/(publish|send|trigger)$`) publishPathRegex = regexp.MustCompile(`^/[-_A-Za-z0-9]{1,64}/(publish|send|trigger)$`)
webConfigPath = "/config.js" webConfigPath = "/config.js"
healthPath = "/v1/health" healthPath = "/v1/health"
accountPath = "/v1/account" accountPath = "/v1/account"
accountTokenPath = "/v1/account/token" accountTokenPath = "/v1/account/token"
accountPasswordPath = "/v1/account/password" accountPasswordPath = "/v1/account/password"
accountSettingsPath = "/v1/account/settings" accountSettingsPath = "/v1/account/settings"
accountSubscriptionPath = "/v1/account/subscription" accountSubscriptionPath = "/v1/account/subscription"
accountReservationPath = "/v1/account/reservation" accountReservationPath = "/v1/account/reservation"
accountBillingPortalPath = "/v1/account/billing/portal" accountBillingPortalPath = "/v1/account/billing/portal"
accountBillingWebhookPath = "/v1/account/billing/webhook" accountBillingWebhookPath = "/v1/account/billing/webhook"
accountCheckoutPath = "/v1/account/checkout" accountBillingSubscriptionPath = "/v1/account/billing/subscription"
accountCheckoutSuccessTemplate = "/v1/account/checkout/success/{CHECKOUT_SESSION_ID}" accountBillingSubscriptionCheckoutSuccessTemplate = "/v1/account/billing/subscription/success/{CHECKOUT_SESSION_ID}"
accountCheckoutSuccessRegex = regexp.MustCompile(`/v1/account/checkout/success/(.+)$`) accountBillingSubscriptionCheckoutSuccessRegex = regexp.MustCompile(`/v1/account/billing/subscription/success/(.+)$`)
accountReservationSingleRegex = regexp.MustCompile(`/v1/account/reservation/([-_A-Za-z0-9]{1,64})$`) accountReservationSingleRegex = regexp.MustCompile(`/v1/account/reservation/([-_A-Za-z0-9]{1,64})$`)
accountSubscriptionSingleRegex = regexp.MustCompile(`^/v1/account/subscription/([-_A-Za-z0-9]{16})$`) accountSubscriptionSingleRegex = regexp.MustCompile(`^/v1/account/subscription/([-_A-Za-z0-9]{16})$`)
matrixPushPath = "/_matrix/push/v1/notify" matrixPushPath = "/_matrix/push/v1/notify"
staticRegex = regexp.MustCompile(`^/static/.+`) staticRegex = regexp.MustCompile(`^/static/.+`)
docsRegex = regexp.MustCompile(`^/docs(|/.*)$`) docsRegex = regexp.MustCompile(`^/docs(|/.*)$`)
fileRegex = regexp.MustCompile(`^/file/([-_A-Za-z0-9]{1,64})(?:\.[A-Za-z0-9]{1,16})?$`) fileRegex = regexp.MustCompile(`^/file/([-_A-Za-z0-9]{1,64})(?:\.[A-Za-z0-9]{1,16})?$`)
disallowedTopics = []string{"docs", "static", "file", "app", "account", "settings", "pricing", "signup", "login", "reset-password"} // If updated, also update in Android and web app disallowedTopics = []string{"docs", "static", "file", "app", "account", "settings", "pricing", "signup", "login", "reset-password"} // If updated, also update in Android and web app
urlRegex = regexp.MustCompile(`^https?://`) urlRegex = regexp.MustCompile(`^https?://`)
//go:embed site //go:embed site
webFs embed.FS webFs embed.FS
@ -372,14 +377,16 @@ func (s *Server) handleInternal(w http.ResponseWriter, r *http.Request, v *visit
return s.ensureUser(s.handleAccountReservationAdd)(w, r, v) return s.ensureUser(s.handleAccountReservationAdd)(w, r, v)
} else if r.Method == http.MethodDelete && accountReservationSingleRegex.MatchString(r.URL.Path) { } else if r.Method == http.MethodDelete && accountReservationSingleRegex.MatchString(r.URL.Path) {
return s.ensureUser(s.handleAccountReservationDelete)(w, r, v) return s.ensureUser(s.handleAccountReservationDelete)(w, r, v)
} else if r.Method == http.MethodPost && r.URL.Path == accountCheckoutPath { } else if r.Method == http.MethodPost && r.URL.Path == accountBillingSubscriptionPath {
return s.ensureUser(s.handleAccountCheckoutSessionCreate)(w, r, v) return s.ensureUser(s.handleAccountBillingSubscriptionChange)(w, r, v)
} else if r.Method == http.MethodGet && accountCheckoutSuccessRegex.MatchString(r.URL.Path) { } else if r.Method == http.MethodDelete && r.URL.Path == accountBillingSubscriptionPath {
return s.ensureStripeCustomer(s.handleAccountBillingSubscriptionDelete)(w, r, v)
} else if r.Method == http.MethodGet && accountBillingSubscriptionCheckoutSuccessRegex.MatchString(r.URL.Path) {
return s.ensureUserManager(s.handleAccountCheckoutSessionSuccessGet)(w, r, v) // No user context! return s.ensureUserManager(s.handleAccountCheckoutSessionSuccessGet)(w, r, v) // No user context!
} else if r.Method == http.MethodPost && r.URL.Path == accountBillingPortalPath { } else if r.Method == http.MethodPost && r.URL.Path == accountBillingPortalPath {
return s.ensureUser(s.handleAccountBillingPortalSessionCreate)(w, r, v) return s.ensureStripeCustomer(s.handleAccountBillingPortalSessionCreate)(w, r, v)
} else if r.Method == http.MethodPost && r.URL.Path == accountBillingWebhookPath { } else if r.Method == http.MethodPost && r.URL.Path == accountBillingWebhookPath {
return s.ensureUserManager(s.handleAccountBillingWebhookTrigger)(w, r, v) return s.ensureUserManager(s.handleAccountBillingWebhook)(w, r, v)
} else if r.Method == http.MethodGet && r.URL.Path == matrixPushPath { } else if r.Method == http.MethodGet && r.URL.Path == matrixPushPath {
return s.handleMatrixDiscovery(w) return s.handleMatrixDiscovery(w)
} else if r.Method == http.MethodGet && staticRegex.MatchString(r.URL.Path) { } else if r.Method == http.MethodGet && staticRegex.MatchString(r.URL.Path) {
@ -1493,6 +1500,15 @@ func (s *Server) ensureUser(next handleFunc) handleFunc {
}) })
} }
func (s *Server) ensureStripeCustomer(next handleFunc) handleFunc {
return s.ensureUser(func(w http.ResponseWriter, r *http.Request, v *visitor) error {
if v.user.Billing.StripeCustomerID == "" {
return errHTTPBadRequestNotAPaidUser
}
return next(w, r, v)
})
}
// transformBodyJSON peeks the request body, reads the JSON, and converts it to headers // transformBodyJSON peeks the request body, reads the JSON, and converts it to headers
// before passing it on to the next handler. This is meant to be used in combination with handlePublish. // before passing it on to the next handler. This is meant to be used in combination with handlePublish.
func (s *Server) transformBodyJSON(next handleFunc) handleFunc { func (s *Server) transformBodyJSON(next handleFunc) handleFunc {

View file

@ -2,14 +2,6 @@ package server
import ( import (
"encoding/json" "encoding/json"
"errors"
"github.com/stripe/stripe-go/v74"
portalsession "github.com/stripe/stripe-go/v74/billingportal/session"
"github.com/stripe/stripe-go/v74/checkout/session"
"github.com/stripe/stripe-go/v74/subscription"
"github.com/stripe/stripe-go/v74/webhook"
"github.com/tidwall/gjson"
"heckel.io/ntfy/log"
"heckel.io/ntfy/user" "heckel.io/ntfy/user"
"heckel.io/ntfy/util" "heckel.io/ntfy/util"
"net/http" "net/http"
@ -17,7 +9,6 @@ import (
const ( const (
jsonBodyBytesLimit = 4096 jsonBodyBytesLimit = 4096
stripeBodyBytesLimit = 16384
subscriptionIDLength = 16 subscriptionIDLength = 16
createdByAPI = "api" createdByAPI = "api"
) )
@ -100,6 +91,14 @@ func (s *Server) handleAccountGet(w http.ResponseWriter, _ *http.Request, v *vis
Paid: v.user.Tier.Paid, Paid: v.user.Tier.Paid,
} }
} }
if v.user.Billing.StripeCustomerID != "" {
response.Billing = &apiAccountBilling{
Customer: true,
Subscription: v.user.Billing.StripeSubscriptionID != "",
Status: string(v.user.Billing.StripeSubscriptionStatus),
PaidUntil: v.user.Billing.StripeSubscriptionPaidUntil.Unix(),
}
}
reservations, err := s.userManager.Reservations(v.user.Name) reservations, err := s.userManager.Reservations(v.user.Name)
if err != nil { if err != nil {
return err return err
@ -395,226 +394,3 @@ func (s *Server) handleAccountReservationDelete(w http.ResponseWriter, r *http.R
w.Header().Set("Access-Control-Allow-Origin", "*") // FIXME remove this w.Header().Set("Access-Control-Allow-Origin", "*") // FIXME remove this
return nil return nil
} }
func (s *Server) handleAccountCheckoutSessionCreate(w http.ResponseWriter, r *http.Request, v *visitor) error {
req, err := readJSONWithLimit[apiAccountTierChangeRequest](r.Body, jsonBodyBytesLimit)
if err != nil {
return err
}
tier, err := s.userManager.Tier(req.Tier)
if err != nil {
return err
}
if tier.StripePriceID == "" {
log.Info("Checkout: Downgrading to no tier")
return errors.New("not a paid tier")
} else if v.user.Billing != nil && v.user.Billing.StripeSubscriptionID != "" {
log.Info("Checkout: Changing tier and subscription to %s", tier.Code)
// Upgrade/downgrade tier
sub, err := subscription.Get(v.user.Billing.StripeSubscriptionID, nil)
if err != nil {
return err
}
params := &stripe.SubscriptionParams{
CancelAtPeriodEnd: stripe.Bool(false),
ProrationBehavior: stripe.String(string(stripe.SubscriptionSchedulePhaseProrationBehaviorCreateProrations)),
Items: []*stripe.SubscriptionItemsParams{
{
ID: stripe.String(sub.Items.Data[0].ID),
Price: stripe.String(tier.StripePriceID),
},
},
}
_, err = subscription.Update(sub.ID, params)
if err != nil {
return err
}
response := &apiAccountCheckoutResponse{}
w.Header().Set("Content-Type", "application/json")
w.Header().Set("Access-Control-Allow-Origin", "*") // FIXME remove this
if err := json.NewEncoder(w).Encode(response); err != nil {
return err
}
return nil
} else {
// Checkout flow
log.Info("Checkout: No existing subscription, creating checkout flow")
}
successURL := s.config.BaseURL + accountCheckoutSuccessTemplate
var stripeCustomerID *string
if v.user.Billing != nil {
stripeCustomerID = &v.user.Billing.StripeCustomerID
}
params := &stripe.CheckoutSessionParams{
ClientReferenceID: &v.user.Name, // FIXME Should be user ID
Customer: stripeCustomerID,
SuccessURL: &successURL,
Mode: stripe.String(string(stripe.CheckoutSessionModeSubscription)),
LineItems: []*stripe.CheckoutSessionLineItemParams{
{
Price: stripe.String(tier.StripePriceID),
Quantity: stripe.Int64(1),
},
},
}
sess, err := session.New(params)
if err != nil {
return err
}
response := &apiAccountCheckoutResponse{
RedirectURL: sess.URL,
}
w.Header().Set("Content-Type", "application/json")
w.Header().Set("Access-Control-Allow-Origin", "*") // FIXME remove this
if err := json.NewEncoder(w).Encode(response); err != nil {
return err
}
return nil
}
func (s *Server) handleAccountCheckoutSessionSuccessGet(w http.ResponseWriter, r *http.Request, v *visitor) error {
// We don't have a v.user in this endpoint, only a userManager!
matches := accountCheckoutSuccessRegex.FindStringSubmatch(r.URL.Path)
if len(matches) != 2 {
return errHTTPInternalErrorInvalidPath
}
sessionID := matches[1]
// FIXME how do I rate limit this?
sess, err := session.Get(sessionID, nil)
if err != nil {
log.Warn("Stripe: %s", err)
return errHTTPBadRequestInvalidStripeRequest
} else if sess.Customer == nil || sess.Subscription == nil || sess.ClientReferenceID == "" {
log.Warn("Stripe: Unexpected session, customer or subscription not found")
return errHTTPBadRequestInvalidStripeRequest
}
sub, err := subscription.Get(sess.Subscription.ID, nil)
if err != nil {
return err
} else if sub.Items == nil || len(sub.Items.Data) != 1 || sub.Items.Data[0].Price == nil {
log.Error("Stripe: Unexpected subscription, expected exactly one line item")
return errHTTPBadRequestInvalidStripeRequest
}
priceID := sub.Items.Data[0].Price.ID
tier, err := s.userManager.TierByStripePrice(priceID)
if err != nil {
return err
}
u, err := s.userManager.User(sess.ClientReferenceID)
if err != nil {
return err
}
if u.Billing == nil {
u.Billing = &user.Billing{}
}
u.Billing.StripeCustomerID = sess.Customer.ID
u.Billing.StripeSubscriptionID = sess.Subscription.ID
if err := s.userManager.ChangeBilling(u); err != nil {
return err
}
if err := s.userManager.ChangeTier(u.Name, tier.Code); err != nil {
return err
}
accountURL := s.config.BaseURL + "/account" // FIXME
http.Redirect(w, r, accountURL, http.StatusSeeOther)
return nil
}
func (s *Server) handleAccountBillingPortalSessionCreate(w http.ResponseWriter, r *http.Request, v *visitor) error {
if v.user.Billing == nil {
return errHTTPBadRequestNotAPaidUser
}
params := &stripe.BillingPortalSessionParams{
Customer: stripe.String(v.user.Billing.StripeCustomerID),
ReturnURL: stripe.String(s.config.BaseURL),
}
ps, err := portalsession.New(params)
if err != nil {
return err
}
response := &apiAccountBillingPortalRedirectResponse{
RedirectURL: ps.URL,
}
w.Header().Set("Content-Type", "application/json")
w.Header().Set("Access-Control-Allow-Origin", "*") // FIXME remove this
if err := json.NewEncoder(w).Encode(response); err != nil {
return err
}
return nil
}
func (s *Server) handleAccountBillingWebhookTrigger(w http.ResponseWriter, r *http.Request, v *visitor) error {
// We don't have a v.user in this endpoint, only a userManager!
stripeSignature := r.Header.Get("Stripe-Signature")
if stripeSignature == "" {
return errHTTPBadRequestInvalidStripeRequest
}
body, err := util.Peek(r.Body, stripeBodyBytesLimit)
if err != nil {
return err
} else if body.LimitReached {
return errHTTPEntityTooLargeJSONBody
}
event, err := webhook.ConstructEvent(body.PeekedBytes, stripeSignature, s.config.StripeWebhookKey)
if err != nil {
log.Warn("Stripe: invalid request: %s", err.Error())
return errHTTPBadRequestInvalidStripeRequest
} else if event.Data == nil || event.Data.Raw == nil {
log.Warn("Stripe: invalid request, data is nil")
return errHTTPBadRequestInvalidStripeRequest
}
log.Info("Stripe: webhook event %s received", event.Type)
stripeCustomerID := gjson.GetBytes(event.Data.Raw, "customer")
if !stripeCustomerID.Exists() {
return errHTTPBadRequestInvalidStripeRequest
}
switch event.Type {
case "checkout.session.completed":
// Payment is successful and the subscription is created.
// Provision the subscription, save the customer ID.
return s.handleAccountBillingWebhookCheckoutCompleted(stripeCustomerID.String(), event.Data.Raw)
case "customer.subscription.updated":
return s.handleAccountBillingWebhookSubscriptionUpdated(stripeCustomerID.String(), event.Data.Raw)
case "invoice.paid":
// Continue to provision the subscription as payments continue to be made.
// Store the status in your database and check when a user accesses your service.
// This approach helps you avoid hitting rate limits.
return nil // FIXME
case "invoice.payment_failed":
// The payment failed or the customer does not have a valid payment method.
// The subscription becomes past_due. Notify your customer and send them to the
// customer portal to update their payment information.
return nil // FIXME
default:
log.Warn("Stripe: unhandled webhook %s", event.Type)
return nil
}
}
func (s *Server) handleAccountBillingWebhookCheckoutCompleted(stripeCustomerID string, event json.RawMessage) error {
log.Info("Stripe: checkout completed for customer %s", stripeCustomerID)
return nil
}
func (s *Server) handleAccountBillingWebhookSubscriptionUpdated(stripeCustomerID string, event json.RawMessage) error {
status := gjson.GetBytes(event, "status")
priceID := gjson.GetBytes(event, "items.data.0.price.id")
if !status.Exists() || !priceID.Exists() {
return errHTTPBadRequestInvalidStripeRequest
}
log.Info("Stripe: customer %s: subscription updated to %s, with price %s", stripeCustomerID, status, priceID)
u, err := s.userManager.UserByStripeCustomer(stripeCustomerID)
if err != nil {
return err
}
tier, err := s.userManager.TierByStripePrice(priceID.String())
if err != nil {
return err
}
if err := s.userManager.ChangeTier(u.Name, tier.Code); err != nil {
return err
}
return nil
}

287
server/server_payments.go Normal file
View file

@ -0,0 +1,287 @@
package server
import (
"encoding/json"
"errors"
"github.com/stripe/stripe-go/v74"
portalsession "github.com/stripe/stripe-go/v74/billingportal/session"
"github.com/stripe/stripe-go/v74/checkout/session"
"github.com/stripe/stripe-go/v74/subscription"
"github.com/stripe/stripe-go/v74/webhook"
"github.com/tidwall/gjson"
"heckel.io/ntfy/log"
"heckel.io/ntfy/user"
"heckel.io/ntfy/util"
"net/http"
"time"
)
const (
stripeBodyBytesLimit = 16384
)
// handleAccountBillingSubscriptionChange facilitates all subscription/tier changes, including payment flows.
//
// FIXME this should be two functions!
//
// It handles two cases:
// - Create subscription: Transition from a user without Stripe subscription to a paid subscription (Checkout flow)
// - Change subscription: Switching between Stripe prices (& tiers) by changing the Stripe subscription
func (s *Server) handleAccountBillingSubscriptionChange(w http.ResponseWriter, r *http.Request, v *visitor) error {
req, err := readJSONWithLimit[apiAccountTierChangeRequest](r.Body, jsonBodyBytesLimit)
if err != nil {
return err
}
tier, err := s.userManager.Tier(req.Tier)
if err != nil {
return err
}
if v.user.Billing.StripeSubscriptionID == "" && tier.StripePriceID != "" {
return s.handleAccountBillingSubscriptionAdd(w, v, tier)
} else if v.user.Billing.StripeSubscriptionID != "" {
return s.handleAccountBillingSubscriptionUpdate(w, v, tier)
}
return errors.New("invalid state")
}
// handleAccountBillingSubscriptionDelete facilitates downgrading a paid user to a tier-less user,
// and cancelling the Stripe subscription entirely
func (s *Server) handleAccountBillingSubscriptionDelete(w http.ResponseWriter, r *http.Request, v *visitor) error {
if v.user.Billing.StripeCustomerID == "" {
return errHTTPBadRequestNotAPaidUser
}
if v.user.Billing.StripeSubscriptionID != "" {
_, err := subscription.Cancel(v.user.Billing.StripeSubscriptionID, nil)
if err != nil {
return err
}
}
if err := s.userManager.ResetTier(v.user.Name); err != nil {
return err
}
v.user.Billing.StripeSubscriptionID = ""
v.user.Billing.StripeSubscriptionStatus = ""
v.user.Billing.StripeSubscriptionPaidUntil = time.Unix(0, 0)
if err := s.userManager.ChangeBilling(v.user); err != nil {
return err
}
return nil
}
func (s *Server) handleAccountBillingSubscriptionAdd(w http.ResponseWriter, v *visitor, tier *user.Tier) error {
log.Info("Stripe: No existing subscription, creating checkout flow")
var stripeCustomerID *string
if v.user.Billing.StripeCustomerID != "" {
stripeCustomerID = &v.user.Billing.StripeCustomerID
}
successURL := s.config.BaseURL + accountBillingSubscriptionCheckoutSuccessTemplate
params := &stripe.CheckoutSessionParams{
Customer: stripeCustomerID, // A user may have previously deleted their subscription
ClientReferenceID: &v.user.Name, // FIXME Should be user ID
SuccessURL: &successURL,
Mode: stripe.String(string(stripe.CheckoutSessionModeSubscription)),
LineItems: []*stripe.CheckoutSessionLineItemParams{
{
Price: stripe.String(tier.StripePriceID),
Quantity: stripe.Int64(1),
},
},
/*AutomaticTax: &stripe.CheckoutSessionAutomaticTaxParams{
Enabled: stripe.Bool(true),
},*/
}
sess, err := session.New(params)
if err != nil {
return err
}
response := &apiAccountCheckoutResponse{
RedirectURL: sess.URL,
}
w.Header().Set("Content-Type", "application/json")
w.Header().Set("Access-Control-Allow-Origin", "*") // FIXME remove this
if err := json.NewEncoder(w).Encode(response); err != nil {
return err
}
return nil
}
func (s *Server) handleAccountBillingSubscriptionUpdate(w http.ResponseWriter, v *visitor, tier *user.Tier) error {
log.Info("Stripe: Changing tier and subscription to %s", tier.Code)
sub, err := subscription.Get(v.user.Billing.StripeSubscriptionID, nil)
if err != nil {
return err
}
params := &stripe.SubscriptionParams{
CancelAtPeriodEnd: stripe.Bool(false),
ProrationBehavior: stripe.String(string(stripe.SubscriptionSchedulePhaseProrationBehaviorCreateProrations)),
Items: []*stripe.SubscriptionItemsParams{
{
ID: stripe.String(sub.Items.Data[0].ID),
Price: stripe.String(tier.StripePriceID),
},
},
}
_, err = subscription.Update(sub.ID, params)
if err != nil {
return err
}
response := &apiAccountCheckoutResponse{}
w.Header().Set("Content-Type", "application/json")
w.Header().Set("Access-Control-Allow-Origin", "*") // FIXME remove this
if err := json.NewEncoder(w).Encode(response); err != nil {
return err
}
return nil
}
func (s *Server) handleAccountCheckoutSessionSuccessGet(w http.ResponseWriter, r *http.Request, v *visitor) error {
// We don't have a v.user in this endpoint, only a userManager!
matches := accountBillingSubscriptionCheckoutSuccessRegex.FindStringSubmatch(r.URL.Path)
if len(matches) != 2 {
return errHTTPInternalErrorInvalidPath
}
sessionID := matches[1]
// FIXME how do I rate limit this?
sess, err := session.Get(sessionID, nil)
if err != nil {
log.Warn("Stripe: %s", err)
return errHTTPBadRequestInvalidStripeRequest
} else if sess.Customer == nil || sess.Subscription == nil || sess.ClientReferenceID == "" {
log.Warn("Stripe: Unexpected session, customer or subscription not found")
return errHTTPBadRequestInvalidStripeRequest
}
sub, err := subscription.Get(sess.Subscription.ID, nil)
if err != nil {
return err
} else if sub.Items == nil || len(sub.Items.Data) != 1 || sub.Items.Data[0].Price == nil {
log.Error("Stripe: Unexpected subscription, expected exactly one line item")
return errHTTPBadRequestInvalidStripeRequest
}
priceID := sub.Items.Data[0].Price.ID
tier, err := s.userManager.TierByStripePrice(priceID)
if err != nil {
return err
}
u, err := s.userManager.User(sess.ClientReferenceID)
if err != nil {
return err
}
u.Billing.StripeCustomerID = sess.Customer.ID
u.Billing.StripeSubscriptionID = sub.ID
u.Billing.StripeSubscriptionStatus = sub.Status
u.Billing.StripeSubscriptionPaidUntil = time.Unix(sub.CurrentPeriodEnd, 0)
if err := s.userManager.ChangeBilling(u); err != nil {
return err
}
if err := s.userManager.ChangeTier(u.Name, tier.Code); err != nil {
return err
}
accountURL := s.config.BaseURL + "/account" // FIXME
http.Redirect(w, r, accountURL, http.StatusSeeOther)
return nil
}
func (s *Server) handleAccountBillingPortalSessionCreate(w http.ResponseWriter, r *http.Request, v *visitor) error {
if v.user.Billing.StripeCustomerID == "" {
return errHTTPBadRequestNotAPaidUser
}
params := &stripe.BillingPortalSessionParams{
Customer: stripe.String(v.user.Billing.StripeCustomerID),
ReturnURL: stripe.String(s.config.BaseURL),
}
ps, err := portalsession.New(params)
if err != nil {
return err
}
response := &apiAccountBillingPortalRedirectResponse{
RedirectURL: ps.URL,
}
w.Header().Set("Content-Type", "application/json")
w.Header().Set("Access-Control-Allow-Origin", "*") // FIXME remove this
if err := json.NewEncoder(w).Encode(response); err != nil {
return err
}
return nil
}
func (s *Server) handleAccountBillingWebhook(w http.ResponseWriter, r *http.Request, v *visitor) error {
// We don't have a v.user in this endpoint, only a userManager!
stripeSignature := r.Header.Get("Stripe-Signature")
if stripeSignature == "" {
return errHTTPBadRequestInvalidStripeRequest
}
body, err := util.Peek(r.Body, stripeBodyBytesLimit)
if err != nil {
return err
} else if body.LimitReached {
return errHTTPEntityTooLargeJSONBody
}
event, err := webhook.ConstructEvent(body.PeekedBytes, stripeSignature, s.config.StripeWebhookKey)
if err != nil {
return errHTTPBadRequestInvalidStripeRequest
} else if event.Data == nil || event.Data.Raw == nil {
return errHTTPBadRequestInvalidStripeRequest
}
log.Info("Stripe: webhook event %s received", event.Type)
stripeCustomerID := gjson.GetBytes(event.Data.Raw, "customer")
if !stripeCustomerID.Exists() {
return errHTTPBadRequestInvalidStripeRequest
}
switch event.Type {
case "customer.subscription.updated":
return s.handleAccountBillingWebhookSubscriptionUpdated(stripeCustomerID.String(), event.Data.Raw)
case "customer.subscription.deleted":
return s.handleAccountBillingWebhookSubscriptionDeleted(stripeCustomerID.String(), event.Data.Raw)
default:
return nil
}
}
func (s *Server) handleAccountBillingWebhookSubscriptionUpdated(stripeCustomerID string, event json.RawMessage) error {
status := gjson.GetBytes(event, "status")
currentPeriodEnd := gjson.GetBytes(event, "current_period_end")
priceID := gjson.GetBytes(event, "items.data.0.price.id")
if !status.Exists() || !currentPeriodEnd.Exists() || !priceID.Exists() {
return errHTTPBadRequestInvalidStripeRequest
}
log.Info("Stripe: customer %s: subscription updated to %s, with price %s", stripeCustomerID, status, priceID)
u, err := s.userManager.UserByStripeCustomer(stripeCustomerID)
if err != nil {
return err
}
tier, err := s.userManager.TierByStripePrice(priceID.String())
if err != nil {
return err
}
if err := s.userManager.ChangeTier(u.Name, tier.Code); err != nil {
return err
}
u.Billing.StripeSubscriptionStatus = stripe.SubscriptionStatus(status.String())
u.Billing.StripeSubscriptionPaidUntil = time.Unix(currentPeriodEnd.Int(), 0)
if err := s.userManager.ChangeBilling(u); err != nil {
return err
}
return nil
}
func (s *Server) handleAccountBillingWebhookSubscriptionDeleted(stripeCustomerID string, event json.RawMessage) error {
status := gjson.GetBytes(event, "status")
if !status.Exists() {
return errHTTPBadRequestInvalidStripeRequest
}
log.Info("Stripe: customer %s: subscription deleted, downgrading to unpaid tier", stripeCustomerID)
u, err := s.userManager.UserByStripeCustomer(stripeCustomerID)
if err != nil {
return err
}
if err := s.userManager.ResetTier(u.Name); err != nil {
return err
}
u.Billing.StripeSubscriptionID = ""
u.Billing.StripeSubscriptionStatus = ""
u.Billing.StripeSubscriptionPaidUntil = time.Unix(0, 0)
if err := s.userManager.ChangeBilling(u); err != nil {
return err
}
return nil
}

View file

@ -268,6 +268,13 @@ type apiAccountReservation struct {
Everyone string `json:"everyone"` Everyone string `json:"everyone"`
} }
type apiAccountBilling struct {
Customer bool `json:"customer"`
Subscription bool `json:"subscription"`
Status string `json:"status,omitempty"`
PaidUntil int64 `json:"paid_until,omitempty"`
}
type apiAccountResponse struct { type apiAccountResponse struct {
Username string `json:"username"` Username string `json:"username"`
Role string `json:"role,omitempty"` Role string `json:"role,omitempty"`
@ -279,6 +286,7 @@ type apiAccountResponse struct {
Tier *apiAccountTier `json:"tier,omitempty"` Tier *apiAccountTier `json:"tier,omitempty"`
Limits *apiAccountLimits `json:"limits,omitempty"` Limits *apiAccountLimits `json:"limits,omitempty"`
Stats *apiAccountStats `json:"stats,omitempty"` Stats *apiAccountStats `json:"stats,omitempty"`
Billing *apiAccountBilling `json:"billing,omitempty"`
} }
type apiAccountReservationRequest struct { type apiAccountReservationRequest struct {

View file

@ -6,6 +6,7 @@ import (
"errors" "errors"
"fmt" "fmt"
_ "github.com/mattn/go-sqlite3" // SQLite driver _ "github.com/mattn/go-sqlite3" // SQLite driver
"github.com/stripe/stripe-go/v74"
"golang.org/x/crypto/bcrypt" "golang.org/x/crypto/bcrypt"
"heckel.io/ntfy/log" "heckel.io/ntfy/log"
"heckel.io/ntfy/util" "heckel.io/ntfy/util"
@ -60,7 +61,9 @@ const (
stats_messages INT NOT NULL DEFAULT (0), stats_messages INT NOT NULL DEFAULT (0),
stats_emails INT NOT NULL DEFAULT (0), stats_emails INT NOT NULL DEFAULT (0),
stripe_customer_id TEXT, stripe_customer_id TEXT,
stripe_subscription_id TEXT, stripe_subscription_id TEXT,
stripe_subscription_status TEXT,
stripe_subscription_paid_until INT,
created_by TEXT NOT NULL, created_by TEXT NOT NULL,
created_at INT NOT NULL, created_at INT NOT NULL,
last_seen INT NOT NULL, last_seen INT NOT NULL,
@ -100,20 +103,20 @@ const (
` `
selectUserByNameQuery = ` selectUserByNameQuery = `
SELECT u.user, u.pass, u.role, u.prefs, u.sync_topic, u.stats_messages, u.stats_emails, u.stripe_customer_id, u.stripe_subscription_id, p.code, p.name, p.paid, p.messages_limit, p.messages_expiry_duration, p.emails_limit, p.reservations_limit, p.attachment_file_size_limit, p.attachment_total_size_limit, p.attachment_expiry_duration, p.stripe_price_id SELECT u.user, u.pass, u.role, u.prefs, u.sync_topic, u.stats_messages, u.stats_emails, u.stripe_customer_id, u.stripe_subscription_id, u.stripe_subscription_status, u.stripe_subscription_paid_until, p.code, p.name, p.paid, p.messages_limit, p.messages_expiry_duration, p.emails_limit, p.reservations_limit, p.attachment_file_size_limit, p.attachment_total_size_limit, p.attachment_expiry_duration, p.stripe_price_id
FROM user u FROM user u
LEFT JOIN tier p on p.id = u.tier_id LEFT JOIN tier p on p.id = u.tier_id
WHERE user = ? WHERE user = ?
` `
selectUserByTokenQuery = ` selectUserByTokenQuery = `
SELECT u.user, u.pass, u.role, u.prefs, u.sync_topic, u.stats_messages, u.stats_emails, u.stripe_customer_id, u.stripe_subscription_id, p.code, p.name, p.paid, p.messages_limit, p.messages_expiry_duration, p.emails_limit, p.reservations_limit, p.attachment_file_size_limit, p.attachment_total_size_limit, p.attachment_expiry_duration, p.stripe_price_id SELECT u.user, u.pass, u.role, u.prefs, u.sync_topic, u.stats_messages, u.stats_emails, u.stripe_customer_id, u.stripe_subscription_id, u.stripe_subscription_status, u.stripe_subscription_paid_until, p.code, p.name, p.paid, p.messages_limit, p.messages_expiry_duration, p.emails_limit, p.reservations_limit, p.attachment_file_size_limit, p.attachment_total_size_limit, p.attachment_expiry_duration, p.stripe_price_id
FROM user u FROM user u
JOIN user_token t on u.id = t.user_id JOIN user_token t on u.id = t.user_id
LEFT JOIN tier p on p.id = u.tier_id LEFT JOIN tier p on p.id = u.tier_id
WHERE t.token = ? AND t.expires >= ? WHERE t.token = ? AND t.expires >= ?
` `
selectUserByStripeCustomerIDQuery = ` selectUserByStripeCustomerIDQuery = `
SELECT u.user, u.pass, u.role, u.prefs, u.sync_topic, u.stats_messages, u.stats_emails, u.stripe_customer_id, u.stripe_subscription_id, p.code, p.name, p.paid, p.messages_limit, p.messages_expiry_duration, p.emails_limit, p.reservations_limit, p.attachment_file_size_limit, p.attachment_total_size_limit, p.attachment_expiry_duration, p.stripe_price_id SELECT u.user, u.pass, u.role, u.prefs, u.sync_topic, u.stats_messages, u.stats_emails, u.stripe_customer_id, u.stripe_subscription_id, u.stripe_subscription_status, u.stripe_subscription_paid_until, p.code, p.name, p.paid, p.messages_limit, p.messages_expiry_duration, p.emails_limit, p.reservations_limit, p.attachment_file_size_limit, p.attachment_total_size_limit, p.attachment_expiry_duration, p.stripe_price_id
FROM user u FROM user u
LEFT JOIN tier p on p.id = u.tier_id LEFT JOIN tier p on p.id = u.tier_id
WHERE u.stripe_customer_id = ? WHERE u.stripe_customer_id = ?
@ -231,7 +234,11 @@ const (
updateUserTierQuery = `UPDATE user SET tier_id = ? WHERE user = ?` updateUserTierQuery = `UPDATE user SET tier_id = ? WHERE user = ?`
deleteUserTierQuery = `UPDATE user SET tier_id = null WHERE user = ?` deleteUserTierQuery = `UPDATE user SET tier_id = null WHERE user = ?`
updateBillingQuery = `UPDATE user SET stripe_customer_id = ?, stripe_subscription_id = ? WHERE user = ?` updateBillingQuery = `
UPDATE user
SET stripe_customer_id = ?, stripe_subscription_id = ?, stripe_subscription_status = ?, stripe_subscription_paid_until = ?
WHERE user = ?
`
) )
// Schema management queries // Schema management queries
@ -597,14 +604,14 @@ func (a *Manager) userByToken(token string) (*User, error) {
func (a *Manager) readUser(rows *sql.Rows) (*User, error) { func (a *Manager) readUser(rows *sql.Rows) (*User, error) {
defer rows.Close() defer rows.Close()
var username, hash, role, prefs, syncTopic string var username, hash, role, prefs, syncTopic string
var stripeCustomerID, stripeSubscriptionID, stripePriceID, tierCode, tierName sql.NullString var stripeCustomerID, stripeSubscriptionID, stripeSubscriptionStatus, stripePriceID, tierCode, tierName sql.NullString
var paid sql.NullBool var paid sql.NullBool
var messages, emails int64 var messages, emails int64
var messagesLimit, messagesExpiryDuration, emailsLimit, reservationsLimit, attachmentFileSizeLimit, attachmentTotalSizeLimit, attachmentExpiryDuration sql.NullInt64 var messagesLimit, messagesExpiryDuration, emailsLimit, reservationsLimit, attachmentFileSizeLimit, attachmentTotalSizeLimit, attachmentExpiryDuration, stripeSubscriptionPaidUntil sql.NullInt64
if !rows.Next() { if !rows.Next() {
return nil, ErrUserNotFound return nil, ErrUserNotFound
} }
if err := rows.Scan(&username, &hash, &role, &prefs, &syncTopic, &messages, &emails, &stripeCustomerID, &stripeSubscriptionID, &tierCode, &tierName, &paid, &messagesLimit, &messagesExpiryDuration, &emailsLimit, &reservationsLimit, &attachmentFileSizeLimit, &attachmentTotalSizeLimit, &attachmentExpiryDuration, &stripePriceID); err != nil { if err := rows.Scan(&username, &hash, &role, &prefs, &syncTopic, &messages, &emails, &stripeCustomerID, &stripeSubscriptionID, &stripeSubscriptionStatus, &stripeSubscriptionPaidUntil, &tierCode, &tierName, &paid, &messagesLimit, &messagesExpiryDuration, &emailsLimit, &reservationsLimit, &attachmentFileSizeLimit, &attachmentTotalSizeLimit, &attachmentExpiryDuration, &stripePriceID); err != nil {
return nil, err return nil, err
} else if err := rows.Err(); err != nil { } else if err := rows.Err(); err != nil {
return nil, err return nil, err
@ -619,16 +626,16 @@ func (a *Manager) readUser(rows *sql.Rows) (*User, error) {
Messages: messages, Messages: messages,
Emails: emails, Emails: emails,
}, },
Billing: &Billing{
StripeCustomerID: stripeCustomerID.String, // May be empty
StripeSubscriptionID: stripeSubscriptionID.String, // May be empty
StripeSubscriptionStatus: stripe.SubscriptionStatus(stripeSubscriptionStatus.String), // May be empty
StripeSubscriptionPaidUntil: time.Unix(stripeSubscriptionPaidUntil.Int64, 0), // May be zero
},
} }
if err := json.Unmarshal([]byte(prefs), user.Prefs); err != nil { if err := json.Unmarshal([]byte(prefs), user.Prefs); err != nil {
return nil, err return nil, err
} }
if stripeCustomerID.Valid && stripeSubscriptionID.Valid {
user.Billing = &Billing{
StripeCustomerID: stripeCustomerID.String,
StripeSubscriptionID: stripeSubscriptionID.String,
}
}
if tierCode.Valid { if tierCode.Valid {
// See readTier() when this is changed! // See readTier() when this is changed!
user.Tier = &Tier{ user.Tier = &Tier{
@ -868,7 +875,7 @@ func (a *Manager) CreateTier(tier *Tier) error {
} }
func (a *Manager) ChangeBilling(user *User) error { func (a *Manager) ChangeBilling(user *User) error {
if _, err := a.db.Exec(updateBillingQuery, user.Billing.StripeCustomerID, user.Billing.StripeSubscriptionID, user.Name); err != nil { if _, err := a.db.Exec(updateBillingQuery, nullString(user.Billing.StripeCustomerID), nullString(user.Billing.StripeSubscriptionID), nullString(string(user.Billing.StripeSubscriptionStatus)), nullInt64(user.Billing.StripeSubscriptionPaidUntil.Unix()), user.Name); err != nil {
return err return err
} }
return nil return nil
@ -1020,3 +1027,17 @@ func migrateFrom1(db *sql.DB) error {
} }
return nil // Update this when a new version is added return nil // Update this when a new version is added
} }
func nullString(s string) sql.NullString {
if s == "" {
return sql.NullString{}
}
return sql.NullString{String: s, Valid: true}
}
func nullInt64(v int64) sql.NullInt64 {
if v == 0 {
return sql.NullInt64{}
}
return sql.NullInt64{Int64: v, Valid: true}
}

View file

@ -3,6 +3,7 @@ package user
import ( import (
"errors" "errors"
"github.com/stripe/stripe-go/v74"
"regexp" "regexp"
"time" "time"
) )
@ -85,8 +86,10 @@ type Stats struct {
// Billing is a struct holding a user's billing information // Billing is a struct holding a user's billing information
type Billing struct { type Billing struct {
StripeCustomerID string StripeCustomerID string
StripeSubscriptionID string StripeSubscriptionID string
StripeSubscriptionStatus stripe.SubscriptionStatus
StripeSubscriptionPaidUntil time.Time
} }
// Grant is a struct that represents an access control entry to a topic by a user // Grant is a struct that represents an access control entry to a topic by a user
@ -223,3 +226,13 @@ var (
ErrUserNotFound = errors.New("user not found") ErrUserNotFound = errors.New("user not found")
ErrTierNotFound = errors.New("tier not found") ErrTierNotFound = errors.New("tier not found")
) )
// BillingStatus represents the status of a Stripe subscription
type BillingStatus string
// BillingStatus values, subset of https://stripe.com/docs/billing/subscriptions/overview
const (
BillingStatusIncomplete = BillingStatus("incomplete")
BillingStatusActive = BillingStatus("active")
BillingStatusPastDue = BillingStatus("past_due")
)

View file

@ -183,6 +183,8 @@
"account_usage_tier_none": "Basic", "account_usage_tier_none": "Basic",
"account_usage_tier_upgrade_button": "Upgrade to Pro", "account_usage_tier_upgrade_button": "Upgrade to Pro",
"account_usage_tier_change_button": "Change", "account_usage_tier_change_button": "Change",
"account_usage_tier_payment_overdue": "Your payment is overdue. Please update your payment method, or your account will be downgraded soon.",
"account_usage_manage_billing_button": "Manage billing",
"account_usage_messages_title": "Published messages", "account_usage_messages_title": "Published messages",
"account_usage_emails_title": "Emails sent", "account_usage_emails_title": "Emails sent",
"account_usage_reservations_title": "Reserved topics", "account_usage_reservations_title": "Reserved topics",

View file

@ -8,7 +8,7 @@ import {
accountTokenUrl, accountTokenUrl,
accountUrl, maybeWithAuth, topicUrl, accountUrl, maybeWithAuth, topicUrl,
withBasicAuth, withBasicAuth,
withBearerAuth, accountCheckoutUrl, accountBillingPortalUrl withBearerAuth, accountBillingSubscriptionUrl, accountBillingPortalUrl
} from "./utils"; } from "./utils";
import session from "./Session"; import session from "./Session";
import subscriptionManager from "./SubscriptionManager"; import subscriptionManager from "./SubscriptionManager";
@ -264,9 +264,9 @@ class AccountApi {
this.triggerChange(); // Dangle! this.triggerChange(); // Dangle!
} }
async createCheckoutSession(tier) { async updateBillingSubscription(tier) {
const url = accountCheckoutUrl(config.base_url); const url = accountBillingSubscriptionUrl(config.base_url);
console.log(`[AccountApi] Creating checkout session`); console.log(`[AccountApi] Requesting tier change to ${tier}`);
const response = await fetch(url, { const response = await fetch(url, {
method: "POST", method: "POST",
headers: withBearerAuth({}, session.token()), headers: withBearerAuth({}, session.token()),
@ -282,6 +282,20 @@ class AccountApi {
return await response.json(); return await response.json();
} }
async deleteBillingSubscription() {
const url = accountBillingSubscriptionUrl(config.base_url);
console.log(`[AccountApi] Cancelling paid subscription`);
const response = await fetch(url, {
method: "DELETE",
headers: withBearerAuth({}, session.token())
});
if (response.status === 401 || response.status === 403) {
throw new UnauthorizedError();
} else if (response.status !== 200) {
throw new Error(`Unexpected server response ${response.status}`);
}
}
async createBillingPortalSession() { async createBillingPortalSession() {
const url = accountBillingPortalUrl(config.base_url); const url = accountBillingPortalUrl(config.base_url);
console.log(`[AccountApi] Creating billing portal session`); console.log(`[AccountApi] Creating billing portal session`);

View file

@ -26,7 +26,7 @@ export const accountSubscriptionUrl = (baseUrl) => `${baseUrl}/v1/account/subscr
export const accountSubscriptionSingleUrl = (baseUrl, id) => `${baseUrl}/v1/account/subscription/${id}`; export const accountSubscriptionSingleUrl = (baseUrl, id) => `${baseUrl}/v1/account/subscription/${id}`;
export const accountReservationUrl = (baseUrl) => `${baseUrl}/v1/account/reservation`; export const accountReservationUrl = (baseUrl) => `${baseUrl}/v1/account/reservation`;
export const accountReservationSingleUrl = (baseUrl, topic) => `${baseUrl}/v1/account/reservation/${topic}`; export const accountReservationSingleUrl = (baseUrl, topic) => `${baseUrl}/v1/account/reservation/${topic}`;
export const accountCheckoutUrl = (baseUrl) => `${baseUrl}/v1/account/checkout`; export const accountBillingSubscriptionUrl = (baseUrl) => `${baseUrl}/v1/account/billing/subscription`;
export const accountBillingPortalUrl = (baseUrl) => `${baseUrl}/v1/account/billing/portal`; export const accountBillingPortalUrl = (baseUrl) => `${baseUrl}/v1/account/billing/portal`;
export const shortUrl = (url) => url.replaceAll(/https?:\/\//g, ""); export const shortUrl = (url) => url.replaceAll(/https?:\/\//g, "");
export const expandUrl = (url) => [`https://${url}`, `http://${url}`]; export const expandUrl = (url) => [`https://${url}`, `http://${url}`];

View file

@ -1,6 +1,6 @@
import * as React from 'react'; import * as React from 'react';
import {useContext, useState} from 'react'; import {useContext, useState} from 'react';
import {LinearProgress, Stack, useMediaQuery} from "@mui/material"; import {Alert, LinearProgress, Stack, useMediaQuery} from "@mui/material";
import Tooltip from '@mui/material/Tooltip'; import Tooltip from '@mui/material/Tooltip';
import Typography from "@mui/material/Typography"; import Typography from "@mui/material/Typography";
import EditIcon from '@mui/icons-material/Edit'; import EditIcon from '@mui/icons-material/Edit';
@ -18,7 +18,7 @@ import TextField from "@mui/material/TextField";
import DialogActions from "@mui/material/DialogActions"; import DialogActions from "@mui/material/DialogActions";
import routes from "./routes"; import routes from "./routes";
import IconButton from "@mui/material/IconButton"; import IconButton from "@mui/material/IconButton";
import {formatBytes} from "../app/utils"; import {formatBytes, formatShortDateTime} from "../app/utils";
import accountApi, {UnauthorizedError} from "../app/AccountApi"; import accountApi, {UnauthorizedError} from "../app/AccountApi";
import InfoOutlinedIcon from '@mui/icons-material/InfoOutlined'; import InfoOutlinedIcon from '@mui/icons-material/InfoOutlined';
import {Pref, PrefGroup} from "./Pref"; import {Pref, PrefGroup} from "./Pref";
@ -28,6 +28,7 @@ import humanizeDuration from "humanize-duration";
import UpgradeDialog from "./UpgradeDialog"; import UpgradeDialog from "./UpgradeDialog";
import CelebrationIcon from "@mui/icons-material/Celebration"; import CelebrationIcon from "@mui/icons-material/Celebration";
import {AccountContext} from "./App"; import {AccountContext} from "./App";
import {Warning, WarningAmber} from "@mui/icons-material";
const Account = () => { const Account = () => {
if (!session.exists()) { if (!session.exists()) {
@ -183,7 +184,7 @@ const Stats = () => {
const handleManageBilling = async () => { const handleManageBilling = async () => {
try { try {
const response = await accountApi.createBillingPortalSession(); const response = await accountApi.createBillingPortalSession();
window.location.href = response.redirect_url; window.open(response.redirect_url, "billing_portal");
} catch (e) { } catch (e) {
console.log(`[Account] Error changing password`, e); console.log(`[Account] Error changing password`, e);
if ((e instanceof UnauthorizedError)) { if ((e instanceof UnauthorizedError)) {
@ -199,7 +200,10 @@ const Stats = () => {
{t("account_usage_title")} {t("account_usage_title")}
</Typography> </Typography>
<PrefGroup> <PrefGroup>
<Pref title={t("account_usage_tier_title")}> <Pref
alignTop={account.billing?.status === "past_due"}
title={t("account_usage_tier_title")}
>
<div> <div>
{account.role === "admin" && {account.role === "admin" &&
<> <>
@ -219,26 +223,29 @@ const Stats = () => {
>{t("account_usage_tier_upgrade_button")}</Button> >{t("account_usage_tier_upgrade_button")}</Button>
} }
{config.enable_payments && account.role === "user" && account.tier?.paid && {config.enable_payments && account.role === "user" && account.tier?.paid &&
<> <Button
<Button variant="outlined"
variant="outlined" size="small"
size="small" onClick={() => setUpgradeDialogOpen(true)}
onClick={() => setUpgradeDialogOpen(true)} sx={{ml: 1}}
sx={{ml: 1}} >{t("account_usage_tier_change_button")}</Button>
>{t("account_usage_tier_change_button")}</Button> }
<Button {config.enable_payments && account.role === "user" && account.billing?.customer &&
variant="outlined" <Button
size="small" variant="outlined"
onClick={handleManageBilling} size="small"
sx={{ml: 1}} onClick={handleManageBilling}
>Manage billing</Button> sx={{ml: 1}}
</> >{t("account_usage_manage_billing_button")}</Button>
} }
<UpgradeDialog <UpgradeDialog
open={upgradeDialogOpen} open={upgradeDialogOpen}
onCancel={() => setUpgradeDialogOpen(false)} onCancel={() => setUpgradeDialogOpen(false)}
/> />
</div> </div>
{account.billing?.status === "past_due" &&
<Alert severity="error" sx={{mt: 1}}>{t("account_usage_tier_payment_overdue")}</Alert>
}
</Pref> </Pref>
{account.role !== "admin" && {account.role !== "admin" &&
<Pref title={t("account_usage_reservations_title")}> <Pref title={t("account_usage_reservations_title")}>

View file

@ -17,16 +17,20 @@ import {AccountContext} from "./App";
const UpgradeDialog = (props) => { const UpgradeDialog = (props) => {
const { account } = useContext(AccountContext); const { account } = useContext(AccountContext);
const fullScreen = useMediaQuery(theme.breakpoints.down('sm')); const fullScreen = useMediaQuery(theme.breakpoints.down('sm'));
const [selected, setSelected] = useState(account?.tier?.code || null); const [newTier, setNewTier] = useState(account?.tier?.code || null);
const [errorText, setErrorText] = useState(""); const [errorText, setErrorText] = useState("");
const handleCheckout = async () => { const handleCheckout = async () => {
try { try {
const response = await accountApi.createCheckoutSession(selected); if (newTier == null) {
if (response.redirect_url) { await accountApi.deleteBillingSubscription();
window.location.href = response.redirect_url;
} else { } else {
await accountApi.sync(); const response = await accountApi.updateBillingSubscription(newTier);
if (response.redirect_url) {
window.location.href = response.redirect_url;
} else {
await accountApi.sync();
}
} }
} catch (e) { } catch (e) {
@ -46,10 +50,10 @@ const UpgradeDialog = (props) => {
display: "flex", display: "flex",
flexDirection: "row" flexDirection: "row"
}}> }}>
<TierCard code={null} name={"Free"} selected={selected === null} onClick={() => setSelected(null)}/> <TierCard code={null} name={"Free"} selected={newTier === null} onClick={() => setNewTier(null)}/>
<TierCard code="starter" name={"Starter"} selected={selected === "starter"} onClick={() => setSelected("starter")}/> <TierCard code="starter" name={"Starter"} selected={newTier === "starter"} onClick={() => setNewTier("starter")}/>
<TierCard code="pro" name={"Pro"} selected={selected === "pro"} onClick={() => setSelected("pro")}/> <TierCard code="pro" name={"Pro"} selected={newTier === "pro"} onClick={() => setNewTier("pro")}/>
<TierCard code="business" name={"Business"} selected={selected === "business"} onClick={() => setSelected("business")}/> <TierCard code="business" name={"Business"} selected={newTier === "business"} onClick={() => setNewTier("business")}/>
</div> </div>
</DialogContent> </DialogContent>
<DialogFooter status={errorText}> <DialogFooter status={errorText}>