From 01fd4754f9308f70afea37a1d60d9678810f02a9 Mon Sep 17 00:00:00 2001 From: binwiederhier Date: Sat, 14 Jan 2023 06:43:44 -0500 Subject: [PATCH] WIP: Stripe integration --- cmd/access.go | 4 +- cmd/serve.go | 21 +- cmd/user.go | 8 +- go.mod | 4 + go.sum | 12 + server/config.go | 2 + server/errors.go | 2 + server/server.go | 18 ++ server/server_account.go | 232 ++++++++++++++++++ server/types.go | 12 + user/manager.go | 110 ++++++++- user/manager_test.go | 2 +- user/types.go | 11 +- web/src/app/AccountApi.js | 39 ++- web/src/app/utils.js | 2 + web/src/components/Account.js | 40 ++- web/src/components/Preferences.js | 6 +- web/src/components/SubscribeDialog.js | 2 +- .../components/SubscriptionSettingsDialog.js | 4 +- web/src/components/UpgradeDialog.js | 69 +++++- 20 files changed, 557 insertions(+), 43 deletions(-) diff --git a/cmd/access.go b/cmd/access.go index c304acd5..76c375bd 100644 --- a/cmd/access.go +++ b/cmd/access.go @@ -103,7 +103,7 @@ func changeAccess(c *cli.Context, manager *user.Manager, username string, topic read := util.Contains([]string{"read-write", "rw", "read-only", "read", "ro"}, perms) write := util.Contains([]string{"read-write", "rw", "write-only", "write", "wo"}, perms) u, err := manager.User(username) - if err == user.ErrNotFound { + if err == user.ErrUserNotFound { return fmt.Errorf("user %s does not exist", username) } else if u.Role == user.RoleAdmin { return fmt.Errorf("user %s is an admin user, access control entries have no effect", username) @@ -173,7 +173,7 @@ func showAllAccess(c *cli.Context, manager *user.Manager) error { func showUserAccess(c *cli.Context, manager *user.Manager, username string) error { users, err := manager.User(username) - if err == user.ErrNotFound { + if err == user.ErrUserNotFound { return fmt.Errorf("user %s does not exist", username) } else if err != nil { return err diff --git a/cmd/serve.go b/cmd/serve.go index 566cce8b..8fd97ae2 100644 --- a/cmd/serve.go +++ b/cmd/serve.go @@ -5,6 +5,7 @@ package cmd import ( "errors" "fmt" + "github.com/stripe/stripe-go/v74" "heckel.io/ntfy/user" "io/fs" "math" @@ -61,7 +62,6 @@ var flagsServe = append( altsrc.NewBoolFlag(&cli.BoolFlag{Name: "enable-signup", Aliases: []string{"enable_signup"}, EnvVars: []string{"NTFY_ENABLE_SIGNUP"}, Value: false, Usage: "allows users to sign up via the web app, or API"}), altsrc.NewBoolFlag(&cli.BoolFlag{Name: "enable-login", Aliases: []string{"enable_login"}, EnvVars: []string{"NTFY_ENABLE_LOGIN"}, Value: false, Usage: "allows users to log in via the web app, or API"}), altsrc.NewBoolFlag(&cli.BoolFlag{Name: "enable-reservations", Aliases: []string{"enable_reservations"}, EnvVars: []string{"NTFY_ENABLE_RESERVATIONS"}, Value: false, Usage: "allows users to reserve topics (if their tier allows it)"}), - altsrc.NewBoolFlag(&cli.BoolFlag{Name: "enable-payments", Aliases: []string{"enable_payments"}, EnvVars: []string{"NTFY_ENABLE_PAYMENTS"}, Value: false, Usage: "enables payments integration [preliminary option, may change]"}), altsrc.NewStringFlag(&cli.StringFlag{Name: "upstream-base-url", Aliases: []string{"upstream_base_url"}, EnvVars: []string{"NTFY_UPSTREAM_BASE_URL"}, Value: "", Usage: "forward poll request to an upstream server, this is needed for iOS push notifications for self-hosted servers"}), altsrc.NewStringFlag(&cli.StringFlag{Name: "smtp-sender-addr", Aliases: []string{"smtp_sender_addr"}, EnvVars: []string{"NTFY_SMTP_SENDER_ADDR"}, Usage: "SMTP server address (host:port) for outgoing emails"}), altsrc.NewStringFlag(&cli.StringFlag{Name: "smtp-sender-user", Aliases: []string{"smtp_sender_user"}, EnvVars: []string{"NTFY_SMTP_SENDER_USER"}, Usage: "SMTP user (if e-mail sending is enabled)"}), @@ -80,6 +80,8 @@ var flagsServe = append( altsrc.NewIntFlag(&cli.IntFlag{Name: "visitor-email-limit-burst", Aliases: []string{"visitor_email_limit_burst"}, EnvVars: []string{"NTFY_VISITOR_EMAIL_LIMIT_BURST"}, Value: server.DefaultVisitorEmailLimitBurst, Usage: "initial limit of e-mails per visitor"}), altsrc.NewDurationFlag(&cli.DurationFlag{Name: "visitor-email-limit-replenish", Aliases: []string{"visitor_email_limit_replenish"}, EnvVars: []string{"NTFY_VISITOR_EMAIL_LIMIT_REPLENISH"}, Value: server.DefaultVisitorEmailLimitReplenish, Usage: "interval at which burst limit is replenished (one per x)"}), altsrc.NewBoolFlag(&cli.BoolFlag{Name: "behind-proxy", Aliases: []string{"behind_proxy", "P"}, EnvVars: []string{"NTFY_BEHIND_PROXY"}, Value: false, Usage: "if set, use X-Forwarded-For header to determine visitor IP address (for rate limiting)"}), + altsrc.NewStringFlag(&cli.StringFlag{Name: "stripe-key", Aliases: []string{"stripe_key"}, EnvVars: []string{"NTFY_STRIPE_KEY"}, Value: "", Usage: "xxxxxxxxxxxxx"}), + altsrc.NewStringFlag(&cli.StringFlag{Name: "stripe-webhook-key", Aliases: []string{"stripe_webhook_key"}, EnvVars: []string{"NTFY_STRIPE_WEBHOOK_KEY"}, Value: "", Usage: "xxxxxxxxxxxx"}), ) var cmdServe = &cli.Command{ @@ -132,7 +134,6 @@ func execServe(c *cli.Context) error { webRoot := c.String("web-root") enableSignup := c.Bool("enable-signup") enableLogin := c.Bool("enable-login") - enablePayments := c.Bool("enable-payments") enableReservations := c.Bool("enable-reservations") upstreamBaseURL := c.String("upstream-base-url") smtpSenderAddr := c.String("smtp-sender-addr") @@ -152,6 +153,8 @@ func execServe(c *cli.Context) error { visitorEmailLimitBurst := c.Int("visitor-email-limit-burst") visitorEmailLimitReplenish := c.Duration("visitor-email-limit-replenish") behindProxy := c.Bool("behind-proxy") + stripeKey := c.String("stripe-key") + stripeWebhookKey := c.String("stripe-webhook-key") // Check values if firebaseKeyFile != "" && !util.FileExists(firebaseKeyFile) { @@ -188,14 +191,17 @@ func execServe(c *cli.Context) error { return errors.New("if upstream-base-url is set, base-url must also be set") } else if upstreamBaseURL != "" && baseURL != "" && baseURL == upstreamBaseURL { return errors.New("base-url and upstream-base-url cannot be identical, you'll likely want to set upstream-base-url to https://ntfy.sh, see https://ntfy.sh/docs/config/#ios-instant-notifications") - } else if authFile == "" && (enableSignup || enableLogin || enableReservations || enablePayments) { - return errors.New("cannot set enable-signup, enable-login, enable-reserve-topics, or enable-payments if auth-file is not set") + } else if authFile == "" && (enableSignup || enableLogin || enableReservations || stripeKey != "") { + return errors.New("cannot set enable-signup, enable-login, enable-reserve-topics, or stripe-key if auth-file is not set") } else if enableSignup && !enableLogin { return errors.New("cannot set enable-signup without also setting enable-login") + } else if stripeKey != "" && (stripeWebhookKey == "" || baseURL == "") { + return errors.New("if stripe-key is set, stripe-webhook-key and base-url must also be set") } webRootIsApp := webRoot == "app" enableWeb := webRoot != "disable" + enablePayments := stripeKey != "" // Default auth permissions authDefault, err := user.ParsePermission(authDefaultAccess) @@ -239,6 +245,11 @@ func execServe(c *cli.Context) error { visitorRequestLimitExemptIPs = append(visitorRequestLimitExemptIPs, ips...) } + // Stripe things + if stripeKey != "" { + stripe.Key = stripeKey + } + // Run server conf := server.NewConfig() conf.BaseURL = baseURL @@ -282,6 +293,8 @@ func execServe(c *cli.Context) error { conf.VisitorEmailLimitBurst = visitorEmailLimitBurst conf.VisitorEmailLimitReplenish = visitorEmailLimitReplenish conf.BehindProxy = behindProxy + conf.StripeKey = stripeKey + conf.StripeWebhookKey = stripeWebhookKey conf.EnableWeb = enableWeb conf.EnableSignup = enableSignup conf.EnableLogin = enableLogin diff --git a/cmd/user.go b/cmd/user.go index 98a94490..ff870188 100644 --- a/cmd/user.go +++ b/cmd/user.go @@ -215,7 +215,7 @@ func execUserDel(c *cli.Context) error { if err != nil { return err } - if _, err := manager.User(username); err == user.ErrNotFound { + if _, err := manager.User(username); err == user.ErrUserNotFound { return fmt.Errorf("user %s does not exist", username) } if err := manager.RemoveUser(username); err != nil { @@ -237,7 +237,7 @@ func execUserChangePass(c *cli.Context) error { if err != nil { return err } - if _, err := manager.User(username); err == user.ErrNotFound { + if _, err := manager.User(username); err == user.ErrUserNotFound { return fmt.Errorf("user %s does not exist", username) } if password == "" { @@ -265,7 +265,7 @@ func execUserChangeRole(c *cli.Context) error { if err != nil { return err } - if _, err := manager.User(username); err == user.ErrNotFound { + if _, err := manager.User(username); err == user.ErrUserNotFound { return fmt.Errorf("user %s does not exist", username) } if err := manager.ChangeRole(username, role); err != nil { @@ -289,7 +289,7 @@ func execUserChangeTier(c *cli.Context) error { if err != nil { return err } - if _, err := manager.User(username); err == user.ErrNotFound { + if _, err := manager.User(username); err == user.ErrUserNotFound { return fmt.Errorf("user %s does not exist", username) } if tier == tierReset { diff --git a/go.mod b/go.mod index dda12d4c..f31bd218 100644 --- a/go.mod +++ b/go.mod @@ -46,6 +46,10 @@ require ( github.com/googleapis/gax-go/v2 v2.7.0 // indirect github.com/pmezard/go-difflib v1.0.0 // indirect github.com/russross/blackfriday/v2 v2.1.0 // indirect + github.com/stripe/stripe-go/v74 v74.5.0 // indirect + github.com/tidwall/gjson v1.14.4 // indirect + github.com/tidwall/match v1.1.1 // indirect + github.com/tidwall/pretty v1.2.1 // indirect github.com/xrash/smetrics v0.0.0-20201216005158-039620a65673 // indirect go.opencensus.io v0.24.0 // indirect golang.org/x/net v0.4.0 // indirect diff --git a/go.sum b/go.sum index 0ad00577..ce4367cb 100644 --- a/go.sum +++ b/go.sum @@ -95,10 +95,20 @@ github.com/russross/blackfriday/v2 v2.1.0/go.mod h1:+Rmxgy9KzJVeS9/2gXHxylqXiyQD github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw= github.com/stretchr/objx v0.5.0/go.mod h1:Yh+to48EsGEfYuaHDzXPcE3xhTkx73EhmCGUpEOglKo= +github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU= github.com/stretchr/testify v1.8.1 h1:w7B6lhMri9wdJUVmEZPGGhZzrYTPvgJArz7wNPgYKsk= github.com/stretchr/testify v1.8.1/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4= +github.com/stripe/stripe-go/v74 v74.5.0 h1:YyqTvVQdS34KYGCfVB87EMn9eDV3FCFkSwfdOQhiVL4= +github.com/stripe/stripe-go/v74 v74.5.0/go.mod h1:5PoXNp30AJ3tGq57ZcFuaMylzNi8KpwlrYAFmO1fHZw= +github.com/tidwall/gjson v1.14.4 h1:uo0p8EbA09J7RQaflQ1aBRffTR7xedD2bcIVSYxLnkM= +github.com/tidwall/gjson v1.14.4/go.mod h1:/wbyibRr2FHMks5tjHJ5F8dMZh3AcwJEMf5vlfC0lxk= +github.com/tidwall/match v1.1.1 h1:+Ho715JplO36QYgwN9PGYNhgZvoUSc9X2c80KVTi+GA= +github.com/tidwall/match v1.1.1/go.mod h1:eRSPERbgtNPcGhD8UCthc6PmLEQXEWd3PRB5JTxsfmM= +github.com/tidwall/pretty v1.2.0/go.mod h1:ITEVvHYasfjBbM0u2Pg8T2nJnzm8xPwvNhhsoaGGjNU= +github.com/tidwall/pretty v1.2.1 h1:qjsOFOWWQl+N3RsoF5/ssm1pHmJJwhjlSbZ51I6wMl4= +github.com/tidwall/pretty v1.2.1/go.mod h1:ITEVvHYasfjBbM0u2Pg8T2nJnzm8xPwvNhhsoaGGjNU= github.com/urfave/cli/v2 v2.23.7 h1:YHDQ46s3VghFHFf1DdF+Sh7H4RqhcM+t0TmZRJx4oJY= github.com/urfave/cli/v2 v2.23.7/go.mod h1:GHupkWPMM0M/sj1a2b4wUrWBPzazNrIjouW6fmdJLxc= github.com/xrash/smetrics v0.0.0-20201216005158-039620a65673 h1:bAn7/zixMGCfxrRTfdpNzjtPYqr8smhKouy9mxVdGPU= @@ -119,6 +129,7 @@ golang.org/x/net v0.0.0-20190213061140-3a22650c66bd/go.mod h1:mL1N/T3taQHkDXs73r golang.org/x/net v0.0.0-20190311183353-d8887717615a/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= golang.org/x/net v0.0.0-20190603091049-60506f45cf65/go.mod h1:HSz+uSET+XFnRR8LxR5pz3Of3rY3CfYBVs4xY44aLks= +golang.org/x/net v0.0.0-20200324143707-d3edc9973b7e/go.mod h1:qpuaurCH72eLCgpAm/N6yyVIVM9cpaDIP3A8BGJEC5A= golang.org/x/net v0.0.0-20201110031124-69a78807bb2b/go.mod h1:sp8m0HH+o8qH0wwXwYZr8TS3Oi6o0r6Gce1SSxlDquU= golang.org/x/net v0.0.0-20220624214902-1bab6f366d9e/go.mod h1:XRhObCWvk6IyKnWLug+ECip1KBveYUHfp+8e9klMJ9c= golang.org/x/net v0.0.0-20220708220712-1185a9018129/go.mod h1:XRhObCWvk6IyKnWLug+ECip1KBveYUHfp+8e9klMJ9c= @@ -135,6 +146,7 @@ golang.org/x/sync v0.1.0/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sys v0.0.0-20180830151530-49385e6e1522/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20200323222414-85ca7c5b95cd/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20200930185726-fdedc70b468f/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20210615035016-665e8c7367d1/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20220520151302-bc2c85ada10a/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= diff --git a/server/config.go b/server/config.go index 2267ec3d..c63f2e37 100644 --- a/server/config.go +++ b/server/config.go @@ -110,6 +110,8 @@ type Config struct { VisitorAccountCreateLimitReplenish time.Duration VisitorStatsResetTime time.Time // Time of the day at which to reset visitor stats BehindProxy bool + StripeKey string + StripeWebhookKey string EnableWeb bool EnableSignup bool // Enable creation of accounts via API and UI EnableLogin bool diff --git a/server/errors.go b/server/errors.go index 0b6645d4..aaac8127 100644 --- a/server/errors.go +++ b/server/errors.go @@ -58,6 +58,8 @@ var ( errHTTPBadRequestJSONInvalid = &errHTTP{40024, http.StatusBadRequest, "invalid request: request body must be valid JSON", ""} errHTTPBadRequestPermissionInvalid = &errHTTP{40025, http.StatusBadRequest, "invalid request: incorrect permission string", ""} errHTTPBadRequestMakesNoSenseForAdmin = &errHTTP{40026, http.StatusBadRequest, "invalid request: this makes no sense for admins", ""} + errHTTPBadRequestNotAPaidUser = &errHTTP{40027, http.StatusBadRequest, "invalid request: not a paid user", ""} + errHTTPBadRequestInvalidStripeRequest = &errHTTP{40028, http.StatusBadRequest, "invalid request: not a valid Stripe request", ""} errHTTPNotFound = &errHTTP{40401, http.StatusNotFound, "page not found", ""} errHTTPUnauthorized = &errHTTP{40101, http.StatusUnauthorized, "unauthorized", "https://ntfy.sh/docs/publish/#authentication"} errHTTPForbidden = &errHTTP{40301, http.StatusForbidden, "forbidden", "https://ntfy.sh/docs/publish/#authentication"} diff --git a/server/server.go b/server/server.go index d5ac543f..0ad06a08 100644 --- a/server/server.go +++ b/server/server.go @@ -36,6 +36,10 @@ import ( /* TODO + payments: + - handle overdue payment (-> downgrade after 7 days) + - delete stripe subscription when acocunt is deleted + Limits & rate limiting: users without tier: should the stats be persisted? are they meaningful? -> test that the visitor is based on the IP address! @@ -43,6 +47,7 @@ import ( update last_seen when API is accessed Make sure account endpoints make sense for admins + triggerChange after publishing a message UI: - flicker of upgrade banner - JS constants @@ -100,6 +105,11 @@ var ( accountSettingsPath = "/v1/account/settings" accountSubscriptionPath = "/v1/account/subscription" accountReservationPath = "/v1/account/reservation" + accountBillingPortalPath = "/v1/account/billing/portal" + accountBillingWebhookPath = "/v1/account/billing/webhook" + accountCheckoutPath = "/v1/account/checkout" + accountCheckoutSuccessTemplate = "/v1/account/checkout/success/{CHECKOUT_SESSION_ID}" + accountCheckoutSuccessRegex = regexp.MustCompile(`/v1/account/checkout/success/(.+)$`) accountReservationSingleRegex = regexp.MustCompile(`/v1/account/reservation/([-_A-Za-z0-9]{1,64})$`) accountSubscriptionSingleRegex = regexp.MustCompile(`^/v1/account/subscription/([-_A-Za-z0-9]{16})$`) matrixPushPath = "/_matrix/push/v1/notify" @@ -362,6 +372,14 @@ func (s *Server) handleInternal(w http.ResponseWriter, r *http.Request, v *visit return s.ensureUser(s.handleAccountReservationAdd)(w, r, v) } else if r.Method == http.MethodDelete && accountReservationSingleRegex.MatchString(r.URL.Path) { return s.ensureUser(s.handleAccountReservationDelete)(w, r, v) + } else if r.Method == http.MethodPost && r.URL.Path == accountCheckoutPath { + return s.ensureUser(s.handleAccountCheckoutSessionCreate)(w, r, v) + } else if r.Method == http.MethodGet && accountCheckoutSuccessRegex.MatchString(r.URL.Path) { + return s.ensureUserManager(s.handleAccountCheckoutSessionSuccessGet)(w, r, v) // No user context! + } else if r.Method == http.MethodPost && r.URL.Path == accountBillingPortalPath { + return s.ensureUser(s.handleAccountBillingPortalSessionCreate)(w, r, v) + } else if r.Method == http.MethodPost && r.URL.Path == accountBillingWebhookPath { + return s.ensureUserManager(s.handleAccountBillingWebhookTrigger)(w, r, v) } else if r.Method == http.MethodGet && r.URL.Path == matrixPushPath { return s.handleMatrixDiscovery(w) } else if r.Method == http.MethodGet && staticRegex.MatchString(r.URL.Path) { diff --git a/server/server_account.go b/server/server_account.go index 23549b21..27c7f40c 100644 --- a/server/server_account.go +++ b/server/server_account.go @@ -2,6 +2,14 @@ package server import ( "encoding/json" + "errors" + "github.com/stripe/stripe-go/v74" + portalsession "github.com/stripe/stripe-go/v74/billingportal/session" + "github.com/stripe/stripe-go/v74/checkout/session" + "github.com/stripe/stripe-go/v74/subscription" + "github.com/stripe/stripe-go/v74/webhook" + "github.com/tidwall/gjson" + "heckel.io/ntfy/log" "heckel.io/ntfy/user" "heckel.io/ntfy/util" "net/http" @@ -9,6 +17,7 @@ import ( const ( jsonBodyBytesLimit = 4096 + stripeBodyBytesLimit = 16384 subscriptionIDLength = 16 createdByAPI = "api" ) @@ -386,3 +395,226 @@ func (s *Server) handleAccountReservationDelete(w http.ResponseWriter, r *http.R w.Header().Set("Access-Control-Allow-Origin", "*") // FIXME remove this return nil } + +func (s *Server) handleAccountCheckoutSessionCreate(w http.ResponseWriter, r *http.Request, v *visitor) error { + req, err := readJSONWithLimit[apiAccountTierChangeRequest](r.Body, jsonBodyBytesLimit) + if err != nil { + return err + } + tier, err := s.userManager.Tier(req.Tier) + if err != nil { + return err + } + if tier.StripePriceID == "" { + log.Info("Checkout: Downgrading to no tier") + return errors.New("not a paid tier") + } else if v.user.Billing != nil && v.user.Billing.StripeSubscriptionID != "" { + log.Info("Checkout: Changing tier and subscription to %s", tier.Code) + + // Upgrade/downgrade tier + sub, err := subscription.Get(v.user.Billing.StripeSubscriptionID, nil) + if err != nil { + return err + } + params := &stripe.SubscriptionParams{ + CancelAtPeriodEnd: stripe.Bool(false), + ProrationBehavior: stripe.String(string(stripe.SubscriptionSchedulePhaseProrationBehaviorCreateProrations)), + Items: []*stripe.SubscriptionItemsParams{ + { + ID: stripe.String(sub.Items.Data[0].ID), + Price: stripe.String(tier.StripePriceID), + }, + }, + } + _, err = subscription.Update(sub.ID, params) + if err != nil { + return err + } + response := &apiAccountCheckoutResponse{} + w.Header().Set("Content-Type", "application/json") + w.Header().Set("Access-Control-Allow-Origin", "*") // FIXME remove this + if err := json.NewEncoder(w).Encode(response); err != nil { + return err + } + return nil + } else { + // Checkout flow + log.Info("Checkout: No existing subscription, creating checkout flow") + } + + successURL := s.config.BaseURL + accountCheckoutSuccessTemplate + var stripeCustomerID *string + if v.user.Billing != nil { + stripeCustomerID = &v.user.Billing.StripeCustomerID + } + params := &stripe.CheckoutSessionParams{ + ClientReferenceID: &v.user.Name, // FIXME Should be user ID + Customer: stripeCustomerID, + SuccessURL: &successURL, + Mode: stripe.String(string(stripe.CheckoutSessionModeSubscription)), + LineItems: []*stripe.CheckoutSessionLineItemParams{ + { + Price: stripe.String(tier.StripePriceID), + Quantity: stripe.Int64(1), + }, + }, + } + sess, err := session.New(params) + if err != nil { + return err + } + response := &apiAccountCheckoutResponse{ + RedirectURL: sess.URL, + } + w.Header().Set("Content-Type", "application/json") + w.Header().Set("Access-Control-Allow-Origin", "*") // FIXME remove this + if err := json.NewEncoder(w).Encode(response); err != nil { + return err + } + return nil +} + +func (s *Server) handleAccountCheckoutSessionSuccessGet(w http.ResponseWriter, r *http.Request, v *visitor) error { + // We don't have a v.user in this endpoint, only a userManager! + matches := accountCheckoutSuccessRegex.FindStringSubmatch(r.URL.Path) + if len(matches) != 2 { + return errHTTPInternalErrorInvalidPath + } + sessionID := matches[1] + // FIXME how do I rate limit this? + sess, err := session.Get(sessionID, nil) + if err != nil { + log.Warn("Stripe: %s", err) + return errHTTPBadRequestInvalidStripeRequest + } else if sess.Customer == nil || sess.Subscription == nil || sess.ClientReferenceID == "" { + log.Warn("Stripe: Unexpected session, customer or subscription not found") + return errHTTPBadRequestInvalidStripeRequest + } + sub, err := subscription.Get(sess.Subscription.ID, nil) + if err != nil { + return err + } else if sub.Items == nil || len(sub.Items.Data) != 1 || sub.Items.Data[0].Price == nil { + log.Error("Stripe: Unexpected subscription, expected exactly one line item") + return errHTTPBadRequestInvalidStripeRequest + } + priceID := sub.Items.Data[0].Price.ID + tier, err := s.userManager.TierByStripePrice(priceID) + if err != nil { + return err + } + u, err := s.userManager.User(sess.ClientReferenceID) + if err != nil { + return err + } + if u.Billing == nil { + u.Billing = &user.Billing{} + } + u.Billing.StripeCustomerID = sess.Customer.ID + u.Billing.StripeSubscriptionID = sess.Subscription.ID + if err := s.userManager.ChangeBilling(u); err != nil { + return err + } + if err := s.userManager.ChangeTier(u.Name, tier.Code); err != nil { + return err + } + accountURL := s.config.BaseURL + "/account" // FIXME + http.Redirect(w, r, accountURL, http.StatusSeeOther) + return nil +} + +func (s *Server) handleAccountBillingPortalSessionCreate(w http.ResponseWriter, r *http.Request, v *visitor) error { + if v.user.Billing == nil { + return errHTTPBadRequestNotAPaidUser + } + params := &stripe.BillingPortalSessionParams{ + Customer: stripe.String(v.user.Billing.StripeCustomerID), + ReturnURL: stripe.String(s.config.BaseURL), + } + ps, err := portalsession.New(params) + if err != nil { + return err + } + response := &apiAccountBillingPortalRedirectResponse{ + RedirectURL: ps.URL, + } + w.Header().Set("Content-Type", "application/json") + w.Header().Set("Access-Control-Allow-Origin", "*") // FIXME remove this + if err := json.NewEncoder(w).Encode(response); err != nil { + return err + } + return nil +} + +func (s *Server) handleAccountBillingWebhookTrigger(w http.ResponseWriter, r *http.Request, v *visitor) error { + // We don't have a v.user in this endpoint, only a userManager! + stripeSignature := r.Header.Get("Stripe-Signature") + if stripeSignature == "" { + return errHTTPBadRequestInvalidStripeRequest + } + body, err := util.Peek(r.Body, stripeBodyBytesLimit) + if err != nil { + return err + } else if body.LimitReached { + return errHTTPEntityTooLargeJSONBody + } + event, err := webhook.ConstructEvent(body.PeekedBytes, stripeSignature, s.config.StripeWebhookKey) + if err != nil { + log.Warn("Stripe: invalid request: %s", err.Error()) + return errHTTPBadRequestInvalidStripeRequest + } else if event.Data == nil || event.Data.Raw == nil { + log.Warn("Stripe: invalid request, data is nil") + return errHTTPBadRequestInvalidStripeRequest + } + log.Info("Stripe: webhook event %s received", event.Type) + stripeCustomerID := gjson.GetBytes(event.Data.Raw, "customer") + if !stripeCustomerID.Exists() { + return errHTTPBadRequestInvalidStripeRequest + } + switch event.Type { + case "checkout.session.completed": + // Payment is successful and the subscription is created. + // Provision the subscription, save the customer ID. + return s.handleAccountBillingWebhookCheckoutCompleted(stripeCustomerID.String(), event.Data.Raw) + case "customer.subscription.updated": + return s.handleAccountBillingWebhookSubscriptionUpdated(stripeCustomerID.String(), event.Data.Raw) + case "invoice.paid": + // Continue to provision the subscription as payments continue to be made. + // Store the status in your database and check when a user accesses your service. + // This approach helps you avoid hitting rate limits. + return nil // FIXME + case "invoice.payment_failed": + // The payment failed or the customer does not have a valid payment method. + // The subscription becomes past_due. Notify your customer and send them to the + // customer portal to update their payment information. + return nil // FIXME + default: + log.Warn("Stripe: unhandled webhook %s", event.Type) + return nil + } +} + +func (s *Server) handleAccountBillingWebhookCheckoutCompleted(stripeCustomerID string, event json.RawMessage) error { + log.Info("Stripe: checkout completed for customer %s", stripeCustomerID) + return nil +} + +func (s *Server) handleAccountBillingWebhookSubscriptionUpdated(stripeCustomerID string, event json.RawMessage) error { + status := gjson.GetBytes(event, "status") + priceID := gjson.GetBytes(event, "items.data.0.price.id") + if !status.Exists() || !priceID.Exists() { + return errHTTPBadRequestInvalidStripeRequest + } + log.Info("Stripe: customer %s: subscription updated to %s, with price %s", stripeCustomerID, status, priceID) + u, err := s.userManager.UserByStripeCustomer(stripeCustomerID) + if err != nil { + return err + } + tier, err := s.userManager.TierByStripePrice(priceID.String()) + if err != nil { + return err + } + if err := s.userManager.ChangeTier(u.Name, tier.Code); err != nil { + return err + } + return nil +} diff --git a/server/types.go b/server/types.go index 9d6bb9e9..e6ebc28f 100644 --- a/server/types.go +++ b/server/types.go @@ -295,3 +295,15 @@ type apiConfigResponse struct { EnableReservations bool `json:"enable_reservations"` DisallowedTopics []string `json:"disallowed_topics"` } + +type apiAccountTierChangeRequest struct { + Tier string `json:"tier"` +} + +type apiAccountCheckoutResponse struct { + RedirectURL string `json:"redirect_url"` +} + +type apiAccountBillingPortalRedirectResponse struct { + RedirectURL string `json:"redirect_url"` +} diff --git a/user/manager.go b/user/manager.go index 955d5e19..f440838a 100644 --- a/user/manager.go +++ b/user/manager.go @@ -44,8 +44,11 @@ const ( reservations_limit INT NOT NULL, attachment_file_size_limit INT NOT NULL, attachment_total_size_limit INT NOT NULL, - attachment_expiry_duration INT NOT NULL + attachment_expiry_duration INT NOT NULL, + stripe_price_id TEXT ); + CREATE UNIQUE INDEX idx_tier_code ON tier (code); + CREATE UNIQUE INDEX idx_tier_price_id ON tier (stripe_price_id); CREATE TABLE IF NOT EXISTS user ( id INTEGER PRIMARY KEY AUTOINCREMENT, tier_id INT, @@ -56,12 +59,16 @@ const ( sync_topic TEXT NOT NULL, stats_messages INT NOT NULL DEFAULT (0), stats_emails INT NOT NULL DEFAULT (0), + stripe_customer_id TEXT, + stripe_subscription_id TEXT, created_by TEXT NOT NULL, created_at INT NOT NULL, last_seen INT NOT NULL, FOREIGN KEY (tier_id) REFERENCES tier (id) ); CREATE UNIQUE INDEX idx_user ON user (user); + CREATE UNIQUE INDEX idx_user_stripe_customer_id ON user (stripe_customer_id); + CREATE UNIQUE INDEX idx_user_stripe_subscription_id ON user (stripe_subscription_id); CREATE TABLE IF NOT EXISTS user_access ( user_id INT NOT NULL, topic TEXT NOT NULL, @@ -93,18 +100,24 @@ const ( ` selectUserByNameQuery = ` - SELECT u.user, u.pass, u.role, u.prefs, u.sync_topic, u.stats_messages, u.stats_emails, p.code, p.name, p.paid, p.messages_limit, p.messages_expiry_duration, p.emails_limit, p.reservations_limit, p.attachment_file_size_limit, p.attachment_total_size_limit, p.attachment_expiry_duration + SELECT u.user, u.pass, u.role, u.prefs, u.sync_topic, u.stats_messages, u.stats_emails, u.stripe_customer_id, u.stripe_subscription_id, p.code, p.name, p.paid, p.messages_limit, p.messages_expiry_duration, p.emails_limit, p.reservations_limit, p.attachment_file_size_limit, p.attachment_total_size_limit, p.attachment_expiry_duration, p.stripe_price_id FROM user u LEFT JOIN tier p on p.id = u.tier_id WHERE user = ? ` selectUserByTokenQuery = ` - SELECT u.user, u.pass, u.role, u.prefs, u.sync_topic, u.stats_messages, u.stats_emails, p.code, p.name, p.paid, p.messages_limit, p.messages_expiry_duration, p.emails_limit, p.reservations_limit, p.attachment_file_size_limit, p.attachment_total_size_limit, p.attachment_expiry_duration + SELECT u.user, u.pass, u.role, u.prefs, u.sync_topic, u.stats_messages, u.stats_emails, u.stripe_customer_id, u.stripe_subscription_id, p.code, p.name, p.paid, p.messages_limit, p.messages_expiry_duration, p.emails_limit, p.reservations_limit, p.attachment_file_size_limit, p.attachment_total_size_limit, p.attachment_expiry_duration, p.stripe_price_id FROM user u JOIN user_token t on u.id = t.user_id LEFT JOIN tier p on p.id = u.tier_id WHERE t.token = ? AND t.expires >= ? ` + selectUserByStripeCustomerIDQuery = ` + SELECT u.user, u.pass, u.role, u.prefs, u.sync_topic, u.stats_messages, u.stats_emails, u.stripe_customer_id, u.stripe_subscription_id, p.code, p.name, p.paid, p.messages_limit, p.messages_expiry_duration, p.emails_limit, p.reservations_limit, p.attachment_file_size_limit, p.attachment_total_size_limit, p.attachment_expiry_duration, p.stripe_price_id + FROM user u + LEFT JOIN tier p on p.id = u.tier_id + WHERE u.stripe_customer_id = ? + ` selectTopicPermsQuery = ` SELECT read, write FROM user_access a @@ -204,9 +217,21 @@ const ( INSERT INTO tier (code, name, paid, messages_limit, messages_expiry_duration, emails_limit, reservations_limit, attachment_file_size_limit, attachment_total_size_limit, attachment_expiry_duration) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?) ` - selectTierIDQuery = `SELECT id FROM tier WHERE code = ?` + selectTierIDQuery = `SELECT id FROM tier WHERE code = ?` + selectTierByCodeQuery = ` + SELECT code, name, paid, messages_limit, messages_expiry_duration, emails_limit, reservations_limit, attachment_file_size_limit, attachment_total_size_limit, attachment_expiry_duration, stripe_price_id + FROM tier + WHERE code = ? + ` + selectTierByPriceIDQuery = ` + SELECT code, name, paid, messages_limit, messages_expiry_duration, emails_limit, reservations_limit, attachment_file_size_limit, attachment_total_size_limit, attachment_expiry_duration, stripe_price_id + FROM tier + WHERE stripe_price_id = ? + ` updateUserTierQuery = `UPDATE user SET tier_id = ? WHERE user = ?` deleteUserTierQuery = `UPDATE user SET tier_id = null WHERE user = ?` + + updateBillingQuery = `UPDATE user SET stripe_customer_id = ?, stripe_subscription_id = ? WHERE user = ?` ) // Schema management queries @@ -543,7 +568,7 @@ func (a *Manager) Users() ([]*User, error) { return users, nil } -// User returns the user with the given username if it exists, or ErrNotFound otherwise. +// User returns the user with the given username if it exists, or ErrUserNotFound otherwise. // You may also pass Everyone to retrieve the anonymous user and its Grant list. func (a *Manager) User(username string) (*User, error) { rows, err := a.db.Query(selectUserByNameQuery, username) @@ -553,6 +578,14 @@ func (a *Manager) User(username string) (*User, error) { return a.readUser(rows) } +func (a *Manager) UserByStripeCustomer(stripeCustomerID string) (*User, error) { + rows, err := a.db.Query(selectUserByStripeCustomerIDQuery, stripeCustomerID) + if err != nil { + return nil, err + } + return a.readUser(rows) +} + func (a *Manager) userByToken(token string) (*User, error) { rows, err := a.db.Query(selectUserByTokenQuery, token, time.Now().Unix()) if err != nil { @@ -564,14 +597,14 @@ func (a *Manager) userByToken(token string) (*User, error) { func (a *Manager) readUser(rows *sql.Rows) (*User, error) { defer rows.Close() var username, hash, role, prefs, syncTopic string - var tierCode, tierName sql.NullString + var stripeCustomerID, stripeSubscriptionID, stripePriceID, tierCode, tierName sql.NullString var paid sql.NullBool var messages, emails int64 var messagesLimit, messagesExpiryDuration, emailsLimit, reservationsLimit, attachmentFileSizeLimit, attachmentTotalSizeLimit, attachmentExpiryDuration sql.NullInt64 if !rows.Next() { - return nil, ErrNotFound + return nil, ErrUserNotFound } - if err := rows.Scan(&username, &hash, &role, &prefs, &syncTopic, &messages, &emails, &tierCode, &tierName, &paid, &messagesLimit, &messagesExpiryDuration, &emailsLimit, &reservationsLimit, &attachmentFileSizeLimit, &attachmentTotalSizeLimit, &attachmentExpiryDuration); err != nil { + if err := rows.Scan(&username, &hash, &role, &prefs, &syncTopic, &messages, &emails, &stripeCustomerID, &stripeSubscriptionID, &tierCode, &tierName, &paid, &messagesLimit, &messagesExpiryDuration, &emailsLimit, &reservationsLimit, &attachmentFileSizeLimit, &attachmentTotalSizeLimit, &attachmentExpiryDuration, &stripePriceID); err != nil { return nil, err } else if err := rows.Err(); err != nil { return nil, err @@ -590,7 +623,14 @@ func (a *Manager) readUser(rows *sql.Rows) (*User, error) { if err := json.Unmarshal([]byte(prefs), user.Prefs); err != nil { return nil, err } + if stripeCustomerID.Valid && stripeSubscriptionID.Valid { + user.Billing = &Billing{ + StripeCustomerID: stripeCustomerID.String, + StripeSubscriptionID: stripeSubscriptionID.String, + } + } if tierCode.Valid { + // See readTier() when this is changed! user.Tier = &Tier{ Code: tierCode.String, Name: tierName.String, @@ -602,6 +642,7 @@ func (a *Manager) readUser(rows *sql.Rows) (*User, error) { AttachmentFileSizeLimit: attachmentFileSizeLimit.Int64, AttachmentTotalSizeLimit: attachmentTotalSizeLimit.Int64, AttachmentExpiryDuration: time.Duration(attachmentExpiryDuration.Int64) * time.Second, + StripePriceID: stripePriceID.String, } } return user, nil @@ -826,6 +867,59 @@ func (a *Manager) CreateTier(tier *Tier) error { return nil } +func (a *Manager) ChangeBilling(user *User) error { + if _, err := a.db.Exec(updateBillingQuery, user.Billing.StripeCustomerID, user.Billing.StripeSubscriptionID, user.Name); err != nil { + return err + } + return nil +} + +func (a *Manager) Tier(code string) (*Tier, error) { + rows, err := a.db.Query(selectTierByCodeQuery, code) + if err != nil { + return nil, err + } + return a.readTier(rows) +} + +func (a *Manager) TierByStripePrice(priceID string) (*Tier, error) { + rows, err := a.db.Query(selectTierByPriceIDQuery, priceID) + if err != nil { + return nil, err + } + return a.readTier(rows) +} + +func (a *Manager) readTier(rows *sql.Rows) (*Tier, error) { + defer rows.Close() + var code, name string + var stripePriceID sql.NullString + var paid bool + var messagesLimit, messagesExpiryDuration, emailsLimit, reservationsLimit, attachmentFileSizeLimit, attachmentTotalSizeLimit, attachmentExpiryDuration sql.NullInt64 + if !rows.Next() { + return nil, ErrTierNotFound + } + if err := rows.Scan(&code, &name, &paid, &messagesLimit, &messagesExpiryDuration, &emailsLimit, &reservationsLimit, &attachmentFileSizeLimit, &attachmentTotalSizeLimit, &attachmentExpiryDuration, &stripePriceID); err != nil { + return nil, err + } else if err := rows.Err(); err != nil { + return nil, err + } + // When changed, note readUser() as well + return &Tier{ + Code: code, + Name: name, + Paid: paid, + MessagesLimit: messagesLimit.Int64, + MessagesExpiryDuration: time.Duration(messagesExpiryDuration.Int64) * time.Second, + EmailsLimit: emailsLimit.Int64, + ReservationsLimit: reservationsLimit.Int64, + AttachmentFileSizeLimit: attachmentFileSizeLimit.Int64, + AttachmentTotalSizeLimit: attachmentTotalSizeLimit.Int64, + AttachmentExpiryDuration: time.Duration(attachmentExpiryDuration.Int64) * time.Second, + StripePriceID: stripePriceID.String, // May be empty! + }, nil +} + func toSQLWildcard(s string) string { return strings.ReplaceAll(s, "*", "%") } diff --git a/user/manager_test.go b/user/manager_test.go index 3c7b35a8..3c7b2704 100644 --- a/user/manager_test.go +++ b/user/manager_test.go @@ -208,7 +208,7 @@ func TestManager_UserManagement(t *testing.T) { // Remove user require.Nil(t, a.RemoveUser("ben")) _, err = a.User("ben") - require.Equal(t, ErrNotFound, err) + require.Equal(t, ErrUserNotFound, err) users, err = a.Users() require.Nil(t, err) diff --git a/user/types.go b/user/types.go index f0247d48..77a34749 100644 --- a/user/types.go +++ b/user/types.go @@ -16,6 +16,7 @@ type User struct { Prefs *Prefs Tier *Tier Stats *Stats + Billing *Billing SyncTopic string Created time.Time LastSeen time.Time @@ -58,6 +59,7 @@ type Tier struct { AttachmentFileSizeLimit int64 AttachmentTotalSizeLimit int64 AttachmentExpiryDuration time.Duration + StripePriceID string } // Subscription represents a user's topic subscription @@ -81,6 +83,12 @@ type Stats struct { Emails int64 } +// Billing is a struct holding a user's billing information +type Billing struct { + StripeCustomerID string + StripeSubscriptionID string +} + // Grant is a struct that represents an access control entry to a topic by a user type Grant struct { TopicPattern string // May include wildcard (*) @@ -212,5 +220,6 @@ var ( ErrUnauthenticated = errors.New("unauthenticated") ErrUnauthorized = errors.New("unauthorized") ErrInvalidArgument = errors.New("invalid argument") - ErrNotFound = errors.New("not found") + ErrUserNotFound = errors.New("user not found") + ErrTierNotFound = errors.New("tier not found") ) diff --git a/web/src/app/AccountApi.js b/web/src/app/AccountApi.js index 1ce27663..38adfffb 100644 --- a/web/src/app/AccountApi.js +++ b/web/src/app/AccountApi.js @@ -8,7 +8,7 @@ import { accountTokenUrl, accountUrl, maybeWithAuth, topicUrl, withBasicAuth, - withBearerAuth + withBearerAuth, accountCheckoutUrl, accountBillingPortalUrl } from "./utils"; import session from "./Session"; import subscriptionManager from "./SubscriptionManager"; @@ -228,7 +228,7 @@ class AccountApi { this.triggerChange(); // Dangle! } - async upsertAccess(topic, everyone) { + async upsertReservation(topic, everyone) { const url = accountReservationUrl(config.base_url); console.log(`[AccountApi] Upserting user access to topic ${topic}, everyone=${everyone}`); const response = await fetch(url, { @@ -249,7 +249,7 @@ class AccountApi { this.triggerChange(); // Dangle! } - async deleteAccess(topic) { + async deleteReservation(topic) { const url = accountReservationSingleUrl(config.base_url, topic); console.log(`[AccountApi] Removing topic reservation ${url}`); const response = await fetch(url, { @@ -264,6 +264,39 @@ class AccountApi { this.triggerChange(); // Dangle! } + async createCheckoutSession(tier) { + const url = accountCheckoutUrl(config.base_url); + console.log(`[AccountApi] Creating checkout session`); + const response = await fetch(url, { + method: "POST", + headers: withBearerAuth({}, session.token()), + body: JSON.stringify({ + tier: tier + }) + }); + if (response.status === 401 || response.status === 403) { + throw new UnauthorizedError(); + } else if (response.status !== 200) { + throw new Error(`Unexpected server response ${response.status}`); + } + return await response.json(); + } + + async createBillingPortalSession() { + const url = accountBillingPortalUrl(config.base_url); + console.log(`[AccountApi] Creating billing portal session`); + const response = await fetch(url, { + method: "POST", + headers: withBearerAuth({}, session.token()) + }); + if (response.status === 401 || response.status === 403) { + throw new UnauthorizedError(); + } else if (response.status !== 200) { + throw new Error(`Unexpected server response ${response.status}`); + } + return await response.json(); + } + async sync() { try { if (!session.token()) { diff --git a/web/src/app/utils.js b/web/src/app/utils.js index 4b06da7e..8001933e 100644 --- a/web/src/app/utils.js +++ b/web/src/app/utils.js @@ -26,6 +26,8 @@ export const accountSubscriptionUrl = (baseUrl) => `${baseUrl}/v1/account/subscr export const accountSubscriptionSingleUrl = (baseUrl, id) => `${baseUrl}/v1/account/subscription/${id}`; export const accountReservationUrl = (baseUrl) => `${baseUrl}/v1/account/reservation`; export const accountReservationSingleUrl = (baseUrl, topic) => `${baseUrl}/v1/account/reservation/${topic}`; +export const accountCheckoutUrl = (baseUrl) => `${baseUrl}/v1/account/checkout`; +export const accountBillingPortalUrl = (baseUrl) => `${baseUrl}/v1/account/billing/portal`; export const shortUrl = (url) => url.replaceAll(/https?:\/\//g, ""); export const expandUrl = (url) => [`https://${url}`, `http://${url}`]; export const expandSecureUrl = (url) => `https://${url}`; diff --git a/web/src/components/Account.js b/web/src/components/Account.js index 734577f0..9e68ce94 100644 --- a/web/src/components/Account.js +++ b/web/src/components/Account.js @@ -171,10 +171,28 @@ const Stats = () => { const { t } = useTranslation(); const { account } = useContext(AccountContext); const [upgradeDialogOpen, setUpgradeDialogOpen] = useState(false); + if (!account) { return <>; } - const normalize = (value, max) => Math.min(value / max * 100, 100); + + const normalize = (value, max) => { + return Math.min(value / max * 100, 100); + }; + + const handleManageBilling = async () => { + try { + const response = await accountApi.createBillingPortalSession(); + window.location.href = response.redirect_url; + } catch (e) { + console.log(`[Account] Error changing password`, e); + if ((e instanceof UnauthorizedError)) { + session.resetAndRedirect(routes.login); + } + // TODO show error + } + }; + return ( @@ -201,12 +219,20 @@ const Stats = () => { >{t("account_usage_tier_upgrade_button")} } {config.enable_payments && account.role === "user" && account.tier?.paid && - + <> + + + } { const handleDialogSubmit = async (reservation) => { setDialogOpen(false); try { - await accountApi.upsertAccess(reservation.topic, reservation.everyone); + await accountApi.upsertReservation(reservation.topic, reservation.everyone); await accountApi.sync(); console.debug(`[Preferences] Added topic reservation`, reservation); } catch (e) { @@ -557,7 +557,7 @@ const ReservationsTable = (props) => { const handleDialogSubmit = async (reservation) => { setDialogOpen(false); try { - await accountApi.upsertAccess(reservation.topic, reservation.everyone); + await accountApi.upsertReservation(reservation.topic, reservation.everyone); await accountApi.sync(); console.debug(`[Preferences] Added topic reservation`, reservation); } catch (e) { @@ -568,7 +568,7 @@ const ReservationsTable = (props) => { const handleDeleteClick = async (reservation) => { try { - await accountApi.deleteAccess(reservation.topic); + await accountApi.deleteReservation(reservation.topic); await accountApi.sync(); console.debug(`[Preferences] Deleted topic reservation`, reservation); } catch (e) { diff --git a/web/src/components/SubscribeDialog.js b/web/src/components/SubscribeDialog.js index 460ee5df..d07ec3b6 100644 --- a/web/src/components/SubscribeDialog.js +++ b/web/src/components/SubscribeDialog.js @@ -110,7 +110,7 @@ const SubscribePage = (props) => { if (session.exists() && baseUrl === config.base_url && reserveTopicVisible) { console.log(`[SubscribeDialog] Reserving topic ${topic} with everyone access ${everyone}`); try { - await accountApi.upsertAccess(topic, everyone); + await accountApi.upsertReservation(topic, everyone); // Account sync later after it was added } catch (e) { console.log(`[SubscribeDialog] Error reserving topic`, e); diff --git a/web/src/components/SubscriptionSettingsDialog.js b/web/src/components/SubscriptionSettingsDialog.js index 85d77c7c..23c5ec05 100644 --- a/web/src/components/SubscriptionSettingsDialog.js +++ b/web/src/components/SubscriptionSettingsDialog.js @@ -37,9 +37,9 @@ const SubscriptionSettingsDialog = (props) => { // Reservation if (reserveTopicVisible) { - await accountApi.upsertAccess(subscription.topic, everyone); + await accountApi.upsertReservation(subscription.topic, everyone); } else if (!reserveTopicVisible && subscription.reservation) { // Was removed - await accountApi.deleteAccess(subscription.topic); + await accountApi.deleteReservation(subscription.topic); } // Sync account diff --git a/web/src/components/UpgradeDialog.js b/web/src/components/UpgradeDialog.js index 1a44c97d..2204c6cf 100644 --- a/web/src/components/UpgradeDialog.js +++ b/web/src/components/UpgradeDialog.js @@ -2,28 +2,83 @@ import * as React from 'react'; import Dialog from '@mui/material/Dialog'; import DialogContent from '@mui/material/DialogContent'; import DialogTitle from '@mui/material/DialogTitle'; -import {useMediaQuery} from "@mui/material"; +import {CardActionArea, CardContent, useMediaQuery} from "@mui/material"; import theme from "./theme"; import DialogFooter from "./DialogFooter"; +import Button from "@mui/material/Button"; +import accountApi, {TopicReservedError, UnauthorizedError} from "../app/AccountApi"; +import session from "../app/Session"; +import routes from "./routes"; +import {useContext, useState} from "react"; +import Card from "@mui/material/Card"; +import Typography from "@mui/material/Typography"; +import {AccountContext} from "./App"; const UpgradeDialog = (props) => { + const { account } = useContext(AccountContext); const fullScreen = useMediaQuery(theme.breakpoints.down('sm')); + const [selected, setSelected] = useState(account?.tier?.code || null); + const [errorText, setErrorText] = useState(""); - const handleSuccess = async () => { - // TODO + const handleCheckout = async () => { + try { + const response = await accountApi.createCheckoutSession(selected); + if (response.redirect_url) { + window.location.href = response.redirect_url; + } else { + await accountApi.sync(); + } + + } catch (e) { + console.log(`[UpgradeDialog] Error creating checkout session`, e); + if ((e instanceof UnauthorizedError)) { + session.resetAndRedirect(routes.login); + } + // FIXME show error + } } return ( - + Upgrade to Pro - Content +
+ setSelected(null)}/> + setSelected("starter")}/> + setSelected("pro")}/> + setSelected("business")}/> +
- - Footer + +
); }; +const TierCard = (props) => { + const cardStyle = (props.selected) ? { + border: "1px solid red", + + } : {}; + return ( + + + + + {props.name} + + + Lizards are a widespread group of squamate reptiles, with over 6,000 + species, ranging across all continents except Antarctica + + + + + ); +} + export default UpgradeDialog;