diff --git a/go.mod b/go.mod index b668bd0a..a34978ca 100644 --- a/go.mod +++ b/go.mod @@ -49,6 +49,7 @@ require ( github.com/googleapis/gax-go/v2 v2.7.0 // indirect github.com/pmezard/go-difflib v1.0.0 // indirect github.com/russross/blackfriday/v2 v2.1.0 // indirect + github.com/stretchr/objx v0.5.0 // indirect github.com/xrash/smetrics v0.0.0-20201216005158-039620a65673 // indirect go.opencensus.io v0.24.0 // indirect golang.org/x/net v0.4.0 // indirect diff --git a/go.sum b/go.sum index 98096334..787eaf11 100644 --- a/go.sum +++ b/go.sum @@ -94,6 +94,7 @@ github.com/russross/blackfriday/v2 v2.1.0 h1:JIOH55/0cWyOuilr9/qlrm0BSXldqnqwMsf github.com/russross/blackfriday/v2 v2.1.0/go.mod h1:+Rmxgy9KzJVeS9/2gXHxylqXiyQDYRxCVz55jmeOWTM= github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw= +github.com/stretchr/objx v0.5.0 h1:1zr/of2m5FGMsad5YfcqgdqdWrIhu+EBEJRhR1U7z/c= github.com/stretchr/objx v0.5.0/go.mod h1:Yh+to48EsGEfYuaHDzXPcE3xhTkx73EhmCGUpEOglKo= github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= diff --git a/server/server.go b/server/server.go index 16323b2d..66a46d8e 100644 --- a/server/server.go +++ b/server/server.go @@ -37,8 +37,6 @@ import ( /* TODO payments: - - send dunning emails when overdue - - payment methods - delete subscription when account deleted - delete messages + reserved topics on ResetTier @@ -76,9 +74,10 @@ type Server struct { visitors map[string]*visitor // ip: or user: firebaseClient *firebaseClient messages int64 - userManager *user.Manager // Might be nil! - messageCache *messageCache - fileCache *fileCache + userManager *user.Manager // Might be nil! + messageCache *messageCache // Database that stores the messages + fileCache *fileCache // File system based cache that stores attachments + stripe stripeAPI // Stripe API, can be replaced with a mock priceCache *util.LookupCache[map[string]string] // Stripe price ID -> formatted price closeChan chan bool mu sync.Mutex @@ -160,6 +159,10 @@ func New(conf *Config) (*Server, error) { if conf.SMTPSenderAddr != "" { mailer = &smtpSender{config: conf} } + var stripe stripeAPI + if conf.StripeSecretKey != "" { + stripe = newStripeAPI() + } messageCache, err := createMessageCache(conf) if err != nil { return nil, err @@ -190,7 +193,7 @@ func New(conf *Config) (*Server, error) { } firebaseClient = newFirebaseClient(sender, userManager) } - return &Server{ + s := &Server{ config: conf, messageCache: messageCache, fileCache: fileCache, @@ -199,8 +202,10 @@ func New(conf *Config) (*Server, error) { topics: topics, userManager: userManager, visitors: make(map[string]*visitor), - priceCache: util.NewLookupCache(fetchStripePrices, conf.StripePriceCacheDuration), - }, nil + stripe: stripe, + } + s.priceCache = util.NewLookupCache(s.fetchStripePrices, conf.StripePriceCacheDuration) + return s, nil } func createMessageCache(conf *Config) (*messageCache, error) { diff --git a/server/server_middleware.go b/server/server_middleware.go index 163384fa..544bb590 100644 --- a/server/server_middleware.go +++ b/server/server_middleware.go @@ -33,7 +33,7 @@ func (s *Server) ensureUser(next handleFunc) handleFunc { func (s *Server) ensurePaymentsEnabled(next handleFunc) handleFunc { return func(w http.ResponseWriter, r *http.Request, v *visitor) error { - if s.config.StripeSecretKey == "" { + if s.config.StripeSecretKey == "" || s.stripe == nil { return errHTTPNotFound } return next(w, r, v) diff --git a/server/server_payments.go b/server/server_payments.go index 5a3440be..ee1bb1a2 100644 --- a/server/server_payments.go +++ b/server/server_payments.go @@ -96,7 +96,7 @@ func (s *Server) handleAccountBillingSubscriptionCreate(w http.ResponseWriter, r var stripeCustomerID *string if v.user.Billing.StripeCustomerID != "" { stripeCustomerID = &v.user.Billing.StripeCustomerID - stripeCustomer, err := customer.Get(v.user.Billing.StripeCustomerID, nil) + stripeCustomer, err := s.stripe.GetCustomer(v.user.Billing.StripeCustomerID) if err != nil { return err } else if stripeCustomer.Subscriptions != nil && len(stripeCustomer.Subscriptions.Data) > 0 { @@ -120,7 +120,7 @@ func (s *Server) handleAccountBillingSubscriptionCreate(w http.ResponseWriter, r Enabled: stripe.Bool(true), },*/ } - sess, err := session.New(params) + sess, err := s.stripe.NewCheckoutSession(params) if err != nil { return err } @@ -137,14 +137,14 @@ func (s *Server) handleAccountBillingSubscriptionCreateSuccess(w http.ResponseWr return errHTTPInternalErrorInvalidPath } sessionID := matches[1] - sess, err := session.Get(sessionID, nil) // FIXME how do I rate limit this? + sess, err := s.stripe.GetSession(sessionID) // FIXME How do we rate limit this? if err != nil { log.Warn("Stripe: %s", err) return errHTTPBadRequestBillingRequestInvalid } else if sess.Customer == nil || sess.Subscription == nil || sess.ClientReferenceID == "" { return wrapErrHTTP(errHTTPBadRequestBillingRequestInvalid, "customer or subscription not found") } - sub, err := subscription.Get(sess.Subscription.ID, nil) + sub, err := s.stripe.GetSubscription(sess.Subscription.ID) if err != nil { return err } else if sub.Items == nil || len(sub.Items.Data) != 1 || sub.Items.Data[0].Price == nil { @@ -180,7 +180,7 @@ func (s *Server) handleAccountBillingSubscriptionUpdate(w http.ResponseWriter, r return err } log.Info("Stripe: Changing tier and subscription to %s", tier.Code) - sub, err := subscription.Get(v.user.Billing.StripeSubscriptionID, nil) + sub, err := s.stripe.GetSubscription(v.user.Billing.StripeSubscriptionID) if err != nil { return err } @@ -194,7 +194,7 @@ func (s *Server) handleAccountBillingSubscriptionUpdate(w http.ResponseWriter, r }, }, } - _, err = subscription.Update(sub.ID, params) + _, err = s.stripe.UpdateSubscription(sub.ID, params) if err != nil { return err } @@ -208,7 +208,7 @@ func (s *Server) handleAccountBillingSubscriptionDelete(w http.ResponseWriter, r params := &stripe.SubscriptionParams{ CancelAtPeriodEnd: stripe.Bool(true), } - _, err := subscription.Update(v.user.Billing.StripeSubscriptionID, params) + _, err := s.stripe.UpdateSubscription(v.user.Billing.StripeSubscriptionID, params) if err != nil { return err } @@ -224,7 +224,7 @@ func (s *Server) handleAccountBillingPortalSessionCreate(w http.ResponseWriter, Customer: stripe.String(v.user.Billing.StripeCustomerID), ReturnURL: stripe.String(s.config.BaseURL), } - ps, err := portalsession.New(params) + ps, err := s.stripe.NewPortalSession(params) if err != nil { return err } @@ -248,7 +248,7 @@ func (s *Server) handleAccountBillingWebhook(w http.ResponseWriter, r *http.Requ } else if body.LimitReached { return errHTTPEntityTooLargeJSONBody } - event, err := webhook.ConstructEvent(body.PeekedBytes, stripeSignature, s.config.StripeWebhookKey) + event, err := s.stripe.ConstructWebhookEvent(body.PeekedBytes, stripeSignature, s.config.StripeWebhookKey) if err != nil { return errHTTPBadRequestBillingRequestInvalid } else if event.Data == nil || event.Data.Raw == nil { @@ -331,24 +331,82 @@ func (s *Server) updateSubscriptionAndTier(u *user.User, customerID, subscriptio // fetchStripePrices contacts the Stripe API to retrieve all prices. This is used by the server to cache the prices // in memory, and ultimately for the web app to display the price table. -func fetchStripePrices() (map[string]string, error) { +func (s *Server) fetchStripePrices() (map[string]string, error) { log.Debug("Caching prices from Stripe API") - prices := make(map[string]string) - iter := price.List(&stripe.PriceListParams{ - Active: stripe.Bool(true), - }) - for iter.Next() { - p := iter.Price() + priceMap := make(map[string]string) + prices, err := s.stripe.ListPrices(&stripe.PriceListParams{Active: stripe.Bool(true)}) + if err != nil { + log.Warn("Fetching Stripe prices failed: %s", err.Error()) + return nil, err + } + for _, p := range prices { if p.UnitAmount%100 == 0 { - prices[p.ID] = fmt.Sprintf("$%d", p.UnitAmount/100) + priceMap[p.ID] = fmt.Sprintf("$%d", p.UnitAmount/100) } else { - prices[p.ID] = fmt.Sprintf("$%.2f", float64(p.UnitAmount)/100) + priceMap[p.ID] = fmt.Sprintf("$%.2f", float64(p.UnitAmount)/100) } - log.Trace("- Caching price %s = %v", p.ID, prices[p.ID]) + log.Trace("- Caching price %s = %v", p.ID, priceMap[p.ID]) + } + return priceMap, nil +} + +// stripeAPI is a small interface to facilitate mocking of the Stripe API +type stripeAPI interface { + NewCheckoutSession(params *stripe.CheckoutSessionParams) (*stripe.CheckoutSession, error) + NewPortalSession(params *stripe.BillingPortalSessionParams) (*stripe.BillingPortalSession, error) + ListPrices(params *stripe.PriceListParams) ([]*stripe.Price, error) + GetCustomer(id string) (*stripe.Customer, error) + GetSession(id string) (*stripe.CheckoutSession, error) + GetSubscription(id string) (*stripe.Subscription, error) + UpdateSubscription(id string, params *stripe.SubscriptionParams) (*stripe.Subscription, error) + ConstructWebhookEvent(payload []byte, header string, secret string) (stripe.Event, error) +} + +// realStripeAPI is a thin shim around the Stripe functions to facilitate mocking +type realStripeAPI struct{} + +var _ stripeAPI = (*realStripeAPI)(nil) + +func newStripeAPI() stripeAPI { + return &realStripeAPI{} +} + +func (s *realStripeAPI) NewCheckoutSession(params *stripe.CheckoutSessionParams) (*stripe.CheckoutSession, error) { + return session.New(params) +} + +func (s *realStripeAPI) NewPortalSession(params *stripe.BillingPortalSessionParams) (*stripe.BillingPortalSession, error) { + return portalsession.New(params) +} + +func (s *realStripeAPI) ListPrices(params *stripe.PriceListParams) ([]*stripe.Price, error) { + prices := make([]*stripe.Price, 0) + iter := price.List(params) + for iter.Next() { + prices = append(prices, iter.Price()) } if iter.Err() != nil { - log.Warn("Fetching Stripe prices failed: %s", iter.Err().Error()) return nil, iter.Err() } return prices, nil } + +func (s *realStripeAPI) GetCustomer(id string) (*stripe.Customer, error) { + return customer.Get(id, nil) +} + +func (s *realStripeAPI) GetSession(id string) (*stripe.CheckoutSession, error) { + return session.Get(id, nil) +} + +func (s *realStripeAPI) GetSubscription(id string) (*stripe.Subscription, error) { + return subscription.Get(id, nil) +} + +func (s *realStripeAPI) UpdateSubscription(id string, params *stripe.SubscriptionParams) (*stripe.Subscription, error) { + return subscription.Update(id, params) +} + +func (s *realStripeAPI) ConstructWebhookEvent(payload []byte, header string, secret string) (stripe.Event, error) { + return webhook.ConstructEvent(payload, header, secret) +} diff --git a/server/server_payments_test.go b/server/server_payments_test.go new file mode 100644 index 00000000..43375d62 --- /dev/null +++ b/server/server_payments_test.go @@ -0,0 +1,130 @@ +package server + +import ( + "github.com/stretchr/testify/mock" + "github.com/stretchr/testify/require" + "github.com/stripe/stripe-go/v74" + "heckel.io/ntfy/user" + "heckel.io/ntfy/util" + "io" + "testing" +) + +func TestPayments_SubscriptionCreate_NotAStripeCustomer_Success(t *testing.T) { + stripeMock := &testStripeAPI{} + defer stripeMock.AssertExpectations(t) + + c := newTestConfigWithAuthFile(t) + c.StripeSecretKey = "secret key" + c.StripeWebhookKey = "webhook key" + s := newTestServer(t, c) + s.stripe = stripeMock + + // Define how the mock should react + stripeMock. + On("NewCheckoutSession", mock.Anything). + Return(&stripe.CheckoutSession{URL: "https://billing.stripe.com/abc/def"}, nil) + + // Create tier and user + require.Nil(t, s.userManager.CreateTier(&user.Tier{ + Code: "pro", + StripePriceID: "price_123", + })) + require.Nil(t, s.userManager.AddUser("phil", "phil", user.RoleUser, "unit-test")) + + // Create subscription + response := request(t, s, "POST", "/v1/account/billing/subscription", `{"tier": "pro"}`, map[string]string{ + "Authorization": util.BasicAuth("phil", "phil"), + }) + require.Equal(t, 200, response.Code) + redirectResponse, err := util.UnmarshalJSON[apiAccountBillingSubscriptionCreateResponse](io.NopCloser(response.Body)) + require.Nil(t, err) + require.Equal(t, "https://billing.stripe.com/abc/def", redirectResponse.RedirectURL) +} + +func TestPayments_SubscriptionCreate_StripeCustomer_Success(t *testing.T) { + stripeMock := &testStripeAPI{} + defer stripeMock.AssertExpectations(t) + + c := newTestConfigWithAuthFile(t) + c.StripeSecretKey = "secret key" + c.StripeWebhookKey = "webhook key" + s := newTestServer(t, c) + s.stripe = stripeMock + + // Define how the mock should react + stripeMock. + On("GetCustomer", "acct_123"). + Return(&stripe.Customer{Subscriptions: &stripe.SubscriptionList{}}, nil) + stripeMock. + On("NewCheckoutSession", mock.Anything). + Return(&stripe.CheckoutSession{URL: "https://billing.stripe.com/abc/def"}, nil) + + // Create tier and user + require.Nil(t, s.userManager.CreateTier(&user.Tier{ + Code: "pro", + StripePriceID: "price_123", + })) + require.Nil(t, s.userManager.AddUser("phil", "phil", user.RoleUser, "unit-test")) + + u, err := s.userManager.User("phil") + require.Nil(t, err) + + u.Billing.StripeCustomerID = "acct_123" + require.Nil(t, s.userManager.ChangeBilling(u)) + + // Create subscription + response := request(t, s, "POST", "/v1/account/billing/subscription", `{"tier": "pro"}`, map[string]string{ + "Authorization": util.BasicAuth("phil", "phil"), + }) + require.Equal(t, 200, response.Code) + redirectResponse, err := util.UnmarshalJSON[apiAccountBillingSubscriptionCreateResponse](io.NopCloser(response.Body)) + require.Nil(t, err) + require.Equal(t, "https://billing.stripe.com/abc/def", redirectResponse.RedirectURL) +} + +type testStripeAPI struct { + mock.Mock +} + +func (s *testStripeAPI) NewCheckoutSession(params *stripe.CheckoutSessionParams) (*stripe.CheckoutSession, error) { + args := s.Called(params) + return args.Get(0).(*stripe.CheckoutSession), args.Error(1) +} + +func (s *testStripeAPI) NewPortalSession(params *stripe.BillingPortalSessionParams) (*stripe.BillingPortalSession, error) { + args := s.Called(params) + return args.Get(0).(*stripe.BillingPortalSession), args.Error(1) +} + +func (s *testStripeAPI) ListPrices(params *stripe.PriceListParams) ([]*stripe.Price, error) { + args := s.Called(params) + return args.Get(0).([]*stripe.Price), args.Error(1) +} + +func (s *testStripeAPI) GetCustomer(id string) (*stripe.Customer, error) { + args := s.Called(id) + return args.Get(0).(*stripe.Customer), args.Error(1) +} + +func (s *testStripeAPI) GetSession(id string) (*stripe.CheckoutSession, error) { + args := s.Called(id) + return args.Get(0).(*stripe.CheckoutSession), args.Error(1) +} + +func (s *testStripeAPI) GetSubscription(id string) (*stripe.Subscription, error) { + args := s.Called(id) + return args.Get(0).(*stripe.Subscription), args.Error(1) +} + +func (s *testStripeAPI) UpdateSubscription(id string, params *stripe.SubscriptionParams) (*stripe.Subscription, error) { + args := s.Called(id) + return args.Get(0).(*stripe.Subscription), args.Error(1) +} + +func (s *testStripeAPI) ConstructWebhookEvent(payload []byte, header string, secret string) (stripe.Event, error) { + args := s.Called(payload, header, secret) + return args.Get(0).(stripe.Event), args.Error(1) +} + +var _ stripeAPI = (*testStripeAPI)(nil)