From 00af52411c6ab936c29f6e9c554621ac298307bf Mon Sep 17 00:00:00 2001 From: binwiederhier Date: Sun, 29 Jan 2023 16:15:08 -0500 Subject: [PATCH] More billing unit tests --- docs/config.md | 8 +- server/server.go | 8 +- server/server_account_test.go | 20 ++++- server/server_payments.go | 2 + server/server_payments_test.go | 142 ++++++++++++++++++++++++++++++++- 5 files changed, 168 insertions(+), 12 deletions(-) diff --git a/docs/config.md b/docs/config.md index 33434769..fd5c6154 100644 --- a/docs/config.md +++ b/docs/config.md @@ -504,7 +504,7 @@ or the root domain: proxy_send_timeout 3m; proxy_read_timeout 3m; - client_max_body_size 20m; # Must be >= attachment-file-size-limit in /etc/ntfy/server.yml + client_max_body_size 0; # Stream request body to backend } } @@ -540,7 +540,7 @@ or the root domain: proxy_send_timeout 3m; proxy_read_timeout 3m; - client_max_body_size 20m; # Must be >= attachment-file-size-limit in /etc/ntfy/server.yml + client_max_body_size 0; # Stream request body to backend } } ``` @@ -571,7 +571,7 @@ or the root domain: proxy_send_timeout 3m; proxy_read_timeout 3m; - client_max_body_size 20m; # Must be >= attachment-file-size-limit in /etc/ntfy/server.yml + client_max_body_size 0; # Stream request body to backend } } @@ -603,7 +603,7 @@ or the root domain: proxy_send_timeout 3m; proxy_read_timeout 3m; - client_max_body_size 20m; # Must be >= attachment-file-size-limit in /etc/ntfy/server.yml + client_max_body_size 0; # Stream request body to backend } } ``` diff --git a/server/server.go b/server/server.go index d4c1573a..17362b1e 100644 --- a/server/server.go +++ b/server/server.go @@ -38,7 +38,6 @@ import ( - HIGH Rate limiting: Sensitive endpoints (account/login/change-password/...) - HIGH Docs -- Large uploads for higher tiers (nginx config!) - MEDIUM: Test new token endpoints & never-expiring token - MEDIUM: Make sure account endpoints make sense for admins - MEDIUM: Reservation (UI): Show "This topic is reserved" error message when trying to reserve a reserved topic (Thorben) @@ -1641,7 +1640,7 @@ func (s *Server) authenticate(r *http.Request) (user *user.User, err error) { return nil, errHTTPUnauthorized } if strings.HasPrefix(value, "Bearer") { - return s.authenticateBearerAuth(r, value) + return s.authenticateBearerAuth(r, strings.TrimSpace(strings.TrimPrefix(value, "Bearer"))) } return s.authenticateBasicAuth(r, value) } @@ -1651,12 +1650,13 @@ func (s *Server) authenticateBasicAuth(r *http.Request, value string) (user *use username, password, ok := r.BasicAuth() if !ok { return nil, errors.New("invalid basic auth") + } else if username == "" { + return s.authenticateBearerAuth(r, password) // Treat password as token } return s.userManager.Authenticate(username, password) } -func (s *Server) authenticateBearerAuth(r *http.Request, value string) (*user.User, error) { - token := strings.TrimSpace(strings.TrimPrefix(value, "Bearer")) +func (s *Server) authenticateBearerAuth(r *http.Request, token string) (*user.User, error) { u, err := s.userManager.AuthenticateToken(token) if err != nil { return nil, err diff --git a/server/server_account_test.go b/server/server_account_test.go index fe5fe0fc..62df104b 100644 --- a/server/server_account_test.go +++ b/server/server_account_test.go @@ -41,6 +41,13 @@ func TestAccount_Signup_Success(t *testing.T) { account, _ := util.UnmarshalJSON[apiAccountResponse](io.NopCloser(rr.Body)) require.Equal(t, "phil", account.Username) require.Equal(t, "user", account.Role) + + rr = request(t, s, "GET", "/v1/account", "", map[string]string{ + "Authorization": util.BasicAuth("", token.Token), // We allow a fake basic auth to make curl-ing easier (curl -u :) + }) + require.Equal(t, 200, rr.Code) + account, _ = util.UnmarshalJSON[apiAccountResponse](io.NopCloser(rr.Body)) + require.Equal(t, "phil", account.Username) } func TestAccount_Signup_UserExists(t *testing.T) { @@ -247,7 +254,18 @@ func TestAccount_ChangePassword(t *testing.T) { require.Nil(t, s.userManager.AddUser("phil", "phil", user.RoleUser)) - rr := request(t, s, "POST", "/v1/account/password", `{"password": "phil", "new_password": "new password"}`, map[string]string{ + rr := request(t, s, "POST", "/v1/account/password", `{"password": "WRONG", "new_password": ""}`, map[string]string{ + "Authorization": util.BasicAuth("phil", "phil"), + }) + require.Equal(t, 400, rr.Code) + + rr = request(t, s, "POST", "/v1/account/password", `{"password": "WRONG", "new_password": "new password"}`, map[string]string{ + "Authorization": util.BasicAuth("phil", "phil"), + }) + require.Equal(t, 400, rr.Code) + require.Equal(t, 40030, toHTTPError(t, rr.Body.String()).Code) + + rr = request(t, s, "POST", "/v1/account/password", `{"password": "phil", "new_password": "new password"}`, map[string]string{ "Authorization": util.BasicAuth("phil", "phil"), }) require.Equal(t, 200, rr.Code) diff --git a/server/server_payments.go b/server/server_payments.go index bc616e7c..9af7a42e 100644 --- a/server/server_payments.go +++ b/server/server_payments.go @@ -229,6 +229,8 @@ func (s *Server) handleAccountBillingSubscriptionUpdate(w http.ResponseWriter, r sub, err := s.stripe.GetSubscription(u.Billing.StripeSubscriptionID) if err != nil { return err + } else if sub.Items == nil || len(sub.Items.Data) != 1 { + return wrapErrHTTP(errHTTPBadRequestBillingRequestInvalid, "no items, or more than one item") } params := &stripe.SubscriptionParams{ CancelAtPeriodEnd: stripe.Bool(false), diff --git a/server/server_payments_test.go b/server/server_payments_test.go index 7206a651..d1c8de4c 100644 --- a/server/server_payments_test.go +++ b/server/server_payments_test.go @@ -304,7 +304,14 @@ func TestPayments_Checkout_Success_And_Increase_Rate_Limits_Reset_Visitor(t *tes }, }, nil) stripeMock. - On("UpdateCustomer", mock.Anything). + On("UpdateCustomer", "acct_5555", &stripe.CustomerParams{ + Params: stripe.Params{ + Metadata: map[string]string{ + "user_id": u.ID, + "user_name": u.Name, + }, + }, + }). Return(&stripe.Customer{}, nil) // Send messages until rate limit of free tier is hit @@ -517,6 +524,135 @@ func TestPayments_Webhook_Subscription_Updated_Downgrade_From_PastDue_To_Active( require.NoFileExists(t, filepath.Join(s.config.AttachmentCacheDir, z2.ID)) } +func TestPayments_Subscription_Update_Different_Tier(t *testing.T) { + stripeMock := &testStripeAPI{} + defer stripeMock.AssertExpectations(t) + + c := newTestConfigWithAuthFile(t) + c.StripeSecretKey = "secret key" + c.StripeWebhookKey = "webhook key" + s := newTestServer(t, c) + s.stripe = stripeMock + + // Define how the mock should react + stripeMock. + On("GetSubscription", "sub_123"). + Return(&stripe.Subscription{ + ID: "sub_123", + Items: &stripe.SubscriptionItemList{ + Data: []*stripe.SubscriptionItem{ + { + ID: "someid_123", + Price: &stripe.Price{ID: "price_123"}, + }, + }, + }, + }, nil) + stripeMock. + On("UpdateSubscription", "sub_123", &stripe.SubscriptionParams{ + CancelAtPeriodEnd: stripe.Bool(false), + ProrationBehavior: stripe.String(string(stripe.SubscriptionSchedulePhaseProrationBehaviorCreateProrations)), + Items: []*stripe.SubscriptionItemsParams{ + { + ID: stripe.String("someid_123"), + Price: stripe.String("price_456"), + }, + }, + }). + Return(&stripe.Subscription{}, nil) + + // Create tier and user + require.Nil(t, s.userManager.CreateTier(&user.Tier{ + ID: "ti_123", + Code: "pro", + StripePriceID: "price_123", + })) + require.Nil(t, s.userManager.CreateTier(&user.Tier{ + ID: "ti_456", + Code: "business", + StripePriceID: "price_456", + })) + require.Nil(t, s.userManager.AddUser("phil", "phil", user.RoleUser)) + require.Nil(t, s.userManager.ChangeTier("phil", "pro")) + require.Nil(t, s.userManager.ChangeBilling("phil", &user.Billing{ + StripeCustomerID: "acct_123", + StripeSubscriptionID: "sub_123", + })) + + // Call endpoint to change subscription + rr := request(t, s, "PUT", "/v1/account/billing/subscription", `{"tier":"business"}`, map[string]string{ + "Authorization": util.BasicAuth("phil", "phil"), + }) + require.Equal(t, 200, rr.Code) +} + +func TestPayments_Subscription_Delete_At_Period_End(t *testing.T) { + stripeMock := &testStripeAPI{} + defer stripeMock.AssertExpectations(t) + + c := newTestConfigWithAuthFile(t) + c.StripeSecretKey = "secret key" + c.StripeWebhookKey = "webhook key" + s := newTestServer(t, c) + s.stripe = stripeMock + + // Define how the mock should react + stripeMock. + On("UpdateSubscription", "sub_123", mock.MatchedBy(func(s *stripe.SubscriptionParams) bool { + return *s.CancelAtPeriodEnd // Is true + })). + Return(&stripe.Subscription{}, nil) + + // Create user + require.Nil(t, s.userManager.AddUser("phil", "phil", user.RoleUser)) + require.Nil(t, s.userManager.ChangeBilling("phil", &user.Billing{ + StripeCustomerID: "acct_123", + StripeSubscriptionID: "sub_123", + })) + + // Delete subscription + rr := request(t, s, "DELETE", "/v1/account/billing/subscription", "", map[string]string{ + "Authorization": util.BasicAuth("phil", "phil"), + }) + require.Equal(t, 200, rr.Code) +} + +func TestPayments_CreatePortalSession(t *testing.T) { + stripeMock := &testStripeAPI{} + defer stripeMock.AssertExpectations(t) + + c := newTestConfigWithAuthFile(t) + c.StripeSecretKey = "secret key" + c.StripeWebhookKey = "webhook key" + s := newTestServer(t, c) + s.stripe = stripeMock + + // Define how the mock should react + stripeMock. + On("NewPortalSession", &stripe.BillingPortalSessionParams{ + Customer: stripe.String("acct_123"), + ReturnURL: stripe.String(s.config.BaseURL), + }). + Return(&stripe.BillingPortalSession{ + URL: "https://billing.stripe.com/blablabla", + }, nil) + + // Create user + require.Nil(t, s.userManager.AddUser("phil", "phil", user.RoleUser)) + require.Nil(t, s.userManager.ChangeBilling("phil", &user.Billing{ + StripeCustomerID: "acct_123", + StripeSubscriptionID: "sub_123", + })) + + // Create portal session + rr := request(t, s, "POST", "/v1/account/billing/portal", "", map[string]string{ + "Authorization": util.BasicAuth("phil", "phil"), + }) + require.Equal(t, 200, rr.Code) + ps, _ := util.UnmarshalJSON[apiAccountBillingPortalRedirectResponse](io.NopCloser(rr.Body)) + require.Equal(t, "https://billing.stripe.com/blablabla", ps.RedirectURL) +} + type testStripeAPI struct { mock.Mock } @@ -554,12 +690,12 @@ func (s *testStripeAPI) GetSubscription(id string) (*stripe.Subscription, error) } func (s *testStripeAPI) UpdateCustomer(id string, params *stripe.CustomerParams) (*stripe.Customer, error) { - args := s.Called(id) + args := s.Called(id, params) return args.Get(0).(*stripe.Customer), args.Error(1) } func (s *testStripeAPI) UpdateSubscription(id string, params *stripe.SubscriptionParams) (*stripe.Subscription, error) { - args := s.Called(id) + args := s.Called(id, params) return args.Get(0).(*stripe.Subscription), args.Error(1) }