diff --git a/go.mod b/go.mod index dda58c9b..96ffff98 100644 --- a/go.mod +++ b/go.mod @@ -27,6 +27,7 @@ require github.com/pkg/errors v0.9.1 // indirect require ( firebase.google.com/go/v4 v4.11.0 + github.com/SherClockHolmes/webpush-go v1.2.0 github.com/prometheus/client_golang v1.15.1 github.com/stripe/stripe-go/v74 v74.21.0 ) @@ -39,7 +40,6 @@ require ( cloud.google.com/go/longrunning v0.5.0 // indirect github.com/AlekSi/pointer v1.2.0 // indirect github.com/MicahParks/keyfunc v1.9.0 // indirect - github.com/SherClockHolmes/webpush-go v1.2.0 // indirect github.com/beorn7/perks v1.0.1 // indirect github.com/cespare/xxhash/v2 v2.2.0 // indirect github.com/davecgh/go-spew v1.1.1 // indirect diff --git a/server/server_account.go b/server/server_account.go index 0336f816..b42496db 100644 --- a/server/server_account.go +++ b/server/server_account.go @@ -171,9 +171,7 @@ func (s *Server) handleAccountDelete(w http.ResponseWriter, r *http.Request, v * return errHTTPBadRequestIncorrectPasswordConfirmation } if s.webPush != nil { - err := s.webPush.ExpireWebPushForUser(u.Name) - - if err != nil { + if err := s.webPush.RemoveByUserID(u.ID); err != nil { logvr(v, r).Err(err).Warn("Error removing web push subscriptions for %s", u.Name) } } diff --git a/server/server_test.go b/server/server_test.go index f264d096..76f83eea 100644 --- a/server/server_test.go +++ b/server/server_test.go @@ -2620,12 +2620,8 @@ func newTestConfigWithAuthFile(t *testing.T) *Config { func newTestConfigWithWebPush(t *testing.T) *Config { conf := newTestConfig(t) - privateKey, publicKey, err := webpush.GenerateVAPIDKeys() - if err != nil { - t.Fatal(err) - } - + require.Nil(t, err) conf.WebPushEnabled = true conf.WebPushSubscriptionsFile = filepath.Join(t.TempDir(), "subscriptions.db") conf.WebPushEmailAddress = "testing@example.com" @@ -2636,9 +2632,7 @@ func newTestConfigWithWebPush(t *testing.T) *Config { func newTestServer(t *testing.T, config *Config) *Server { server, err := New(config) - if err != nil { - t.Fatal(err) - } + require.Nil(t, err) return server } diff --git a/server/server_web_push.go b/server/server_web_push.go index d3f669cf..d8a25e61 100644 --- a/server/server_web_push.go +++ b/server/server_web_push.go @@ -10,15 +10,8 @@ import ( ) func (s *Server) handleTopicWebPushSubscribe(w http.ResponseWriter, r *http.Request, v *visitor) error { - var username string - u := v.User() - if u != nil { - username = u.Name - } - var sub webPushSubscribePayload err := json.NewDecoder(r.Body).Decode(&sub) - if err != nil || sub.BrowserSubscription.Endpoint == "" || sub.BrowserSubscription.Keys.P256dh == "" || sub.BrowserSubscription.Keys.Auth == "" { return errHTTPBadRequestWebPushSubscriptionInvalid } @@ -27,12 +20,9 @@ func (s *Server) handleTopicWebPushSubscribe(w http.ResponseWriter, r *http.Requ if err != nil { return err } - - err = s.webPush.AddSubscription(topic.ID, username, sub) - if err != nil { + if err = s.webPush.AddSubscription(topic.ID, v.MaybeUserID(), sub); err != nil { return err } - return s.writeJSON(w, newSuccessResponse()) } @@ -59,7 +49,7 @@ func (s *Server) handleTopicWebPushUnsubscribe(w http.ResponseWriter, r *http.Re } func (s *Server) publishToWebPushEndpoints(v *visitor, m *message) { - subscriptions, err := s.webPush.GetSubscriptionsForTopic(m.Topic) + subscriptions, err := s.webPush.SubscriptionsForTopic(m.Topic) if err != nil { logvm(v, m).Err(err).Warn("Unable to publish web push messages") return @@ -69,21 +59,17 @@ func (s *Server) publishToWebPushEndpoints(v *visitor, m *message) { // Importing the emojis in the service worker would add unnecessary complexity, // simply do it here for web push notifications instead - var titleWithDefault string - var formattedTitle string - + var titleWithDefault, formattedTitle string emojis, _, err := toEmojis(m.Tags) if err != nil { logvm(v, m).Err(err).Fields(ctx).Debug("Unable to publish web push message") return } - if m.Title == "" { titleWithDefault = m.Topic } else { titleWithDefault = m.Title } - if len(emojis) > 0 { formattedTitle = fmt.Sprintf("%s %s", strings.Join(emojis[:], " "), titleWithDefault) } else { @@ -92,7 +78,7 @@ func (s *Server) publishToWebPushEndpoints(v *visitor, m *message) { for i, xi := range subscriptions { go func(i int, sub webPushSubscription) { - ctx := log.Context{"endpoint": sub.BrowserSubscription.Endpoint, "username": sub.Username, "topic": m.Topic, "message_id": m.ID} + ctx := log.Context{"endpoint": sub.BrowserSubscription.Endpoint, "username": sub.UserID, "topic": m.Topic, "message_id": m.ID} payload := &webPushPayload{ SubscriptionID: fmt.Sprintf("%s/%s", s.config.BaseURL, m.Topic), @@ -110,31 +96,25 @@ func (s *Server) publishToWebPushEndpoints(v *visitor, m *message) { Subscriber: s.config.WebPushEmailAddress, VAPIDPublicKey: s.config.WebPushPublicKey, VAPIDPrivateKey: s.config.WebPushPrivateKey, - // deliverability on iOS isn't great with lower urgency values, + // Deliverability on iOS isn't great with lower urgency values, // and thus we can't really map lower ntfy priorities to lower urgency values Urgency: webpush.UrgencyHigh, }) if err != nil { logvm(v, m).Err(err).Fields(ctx).Debug("Unable to publish web push message") - - err = s.webPush.ExpireWebPushEndpoint(sub.BrowserSubscription.Endpoint) - if err != nil { + if err := s.webPush.RemoveByEndpoint(sub.BrowserSubscription.Endpoint); err != nil { logvm(v, m).Err(err).Fields(ctx).Warn("Unable to expire subscription") } - return } // May want to handle at least 429 differently, but for now treat all errors the same if !(200 <= resp.StatusCode && resp.StatusCode <= 299) { logvm(v, m).Fields(ctx).Field("response", resp).Debug("Unable to publish web push message") - - err = s.webPush.ExpireWebPushEndpoint(sub.BrowserSubscription.Endpoint) - if err != nil { + if err := s.webPush.RemoveByEndpoint(sub.BrowserSubscription.Endpoint); err != nil { logvm(v, m).Err(err).Fields(ctx).Warn("Unable to expire subscription") } - return } }(i, xi) diff --git a/server/server_web_push_test.go b/server/server_web_push_test.go index 3b8863d4..56936529 100644 --- a/server/server_web_push_test.go +++ b/server/server_web_push_test.go @@ -5,6 +5,7 @@ import ( "io" "net/http" "net/http/httptest" + "strings" "sync/atomic" "testing" @@ -41,7 +42,7 @@ func TestServer_WebPush_TopicSubscribe(t *testing.T) { require.Equal(t, 200, response.Code) require.Equal(t, `{"success":true}`+"\n", response.Body.String()) - subs, err := s.webPush.GetSubscriptionsForTopic("test-topic") + subs, err := s.webPush.SubscriptionsForTopic("test-topic") if err != nil { t.Fatal(err) } @@ -50,7 +51,7 @@ func TestServer_WebPush_TopicSubscribe(t *testing.T) { require.Equal(t, subs[0].BrowserSubscription.Endpoint, "https://example.com/webpush") require.Equal(t, subs[0].BrowserSubscription.Keys.P256dh, "p256dh-key") require.Equal(t, subs[0].BrowserSubscription.Keys.Auth, "auth-key") - require.Equal(t, subs[0].Username, "") + require.Equal(t, subs[0].UserID, "") } func TestServer_WebPush_TopicSubscribeProtected_Allowed(t *testing.T) { @@ -64,17 +65,13 @@ func TestServer_WebPush_TopicSubscribeProtected_Allowed(t *testing.T) { response := request(t, s, "POST", "/test-topic/web-push/subscribe", webPushSubscribePayloadExample, map[string]string{ "Authorization": util.BasicAuth("ben", "ben"), }) - require.Equal(t, 200, response.Code) require.Equal(t, `{"success":true}`+"\n", response.Body.String()) - subs, err := s.webPush.GetSubscriptionsForTopic("test-topic") - if err != nil { - t.Fatal(err) - } - + subs, err := s.webPush.SubscriptionsForTopic("test-topic") + require.Nil(t, err) require.Len(t, subs, 1) - require.Equal(t, subs[0].Username, "ben") + require.True(t, strings.HasPrefix(subs[0].UserID, "u_")) } func TestServer_WebPush_TopicSubscribeProtected_Denied(t *testing.T) { @@ -203,7 +200,7 @@ func addSubscription(t *testing.T, s *Server, topic string, url string) { } func requireSubscriptionCount(t *testing.T, s *Server, topic string, expectedLength int) { - subs, err := s.webPush.GetSubscriptionsForTopic("test-topic") + subs, err := s.webPush.SubscriptionsForTopic("test-topic") if err != nil { t.Fatal(err) } diff --git a/server/types.go b/server/types.go index 6eed5eef..bac4a478 100644 --- a/server/types.go +++ b/server/types.go @@ -41,7 +41,7 @@ type message struct { PollID string `json:"poll_id,omitempty"` Encoding string `json:"encoding,omitempty"` // empty for raw UTF-8, or "base64" for encoded bytes Sender netip.Addr `json:"-"` // IP address of uploader, used for rate limiting - User string `json:"-"` // Username of the uploader, used to associated attachments + User string `json:"-"` // UserID of the uploader, used to associated attachments } func (m *message) Context() log.Context { @@ -476,7 +476,7 @@ type webPushPayload struct { type webPushSubscription struct { BrowserSubscription webpush.Subscription - Username string + UserID string } type webPushSubscribePayload struct { diff --git a/server/web_push.go b/server/web_push.go index 2fafb2a8..8969af68 100644 --- a/server/web_push.go +++ b/server/web_push.go @@ -12,7 +12,7 @@ const ( CREATE TABLE IF NOT EXISTS subscriptions ( id INTEGER PRIMARY KEY AUTOINCREMENT, topic TEXT NOT NULL, - username TEXT, + user_id TEXT, endpoint TEXT NOT NULL, key_auth TEXT NOT NULL, key_p256dh TEXT NOT NULL, @@ -24,14 +24,14 @@ const ( COMMIT; ` insertWebPushSubscriptionQuery = ` - INSERT OR REPLACE INTO subscriptions (topic, username, endpoint, key_auth, key_p256dh) + INSERT OR REPLACE INTO subscriptions (topic, user_id, endpoint, key_auth, key_p256dh) VALUES (?, ?, ?, ?, ?) ` deleteWebPushSubscriptionByEndpointQuery = `DELETE FROM subscriptions WHERE endpoint = ?` - deleteWebPushSubscriptionByUsernameQuery = `DELETE FROM subscriptions WHERE username = ?` + deleteWebPushSubscriptionByUserIDQuery = `DELETE FROM subscriptions WHERE user_id = ?` deleteWebPushSubscriptionByTopicAndEndpointQuery = `DELETE FROM subscriptions WHERE topic = ? AND endpoint = ?` - selectWebPushSubscriptionsForTopicQuery = `SELECT endpoint, key_auth, key_p256dh, username FROM subscriptions WHERE topic = ?` + selectWebPushSubscriptionsForTopicQuery = `SELECT endpoint, key_auth, key_p256dh, user_id FROM subscriptions WHERE topic = ?` selectWebPushSubscriptionsCountQuery = `SELECT COUNT(*) FROM subscriptions` ) @@ -69,11 +69,11 @@ func setupNewSubscriptionsDB(db *sql.DB) error { return nil } -func (c *webPushStore) AddSubscription(topic string, username string, subscription webPushSubscribePayload) error { +func (c *webPushStore) AddSubscription(topic string, userID string, subscription webPushSubscribePayload) error { _, err := c.db.Exec( insertWebPushSubscriptionQuery, topic, - username, + userID, subscription.BrowserSubscription.Endpoint, subscription.BrowserSubscription.Keys.Auth, subscription.BrowserSubscription.Keys.P256dh, @@ -90,7 +90,7 @@ func (c *webPushStore) RemoveSubscription(topic string, endpoint string) error { return err } -func (c *webPushStore) GetSubscriptionsForTopic(topic string) (subscriptions []webPushSubscription, err error) { +func (c *webPushStore) SubscriptionsForTopic(topic string) (subscriptions []webPushSubscription, err error) { rows, err := c.db.Query(selectWebPushSubscriptionsForTopicQuery, topic) if err != nil { return nil, err @@ -100,7 +100,7 @@ func (c *webPushStore) GetSubscriptionsForTopic(topic string) (subscriptions []w var data []webPushSubscription for rows.Next() { i := webPushSubscription{} - err = rows.Scan(&i.BrowserSubscription.Endpoint, &i.BrowserSubscription.Keys.Auth, &i.BrowserSubscription.Keys.P256dh, &i.Username) + err = rows.Scan(&i.BrowserSubscription.Endpoint, &i.BrowserSubscription.Keys.Auth, &i.BrowserSubscription.Keys.P256dh, &i.UserID) if err != nil { return nil, err } @@ -109,7 +109,7 @@ func (c *webPushStore) GetSubscriptionsForTopic(topic string) (subscriptions []w return data, nil } -func (c *webPushStore) ExpireWebPushEndpoint(endpoint string) error { +func (c *webPushStore) RemoveByEndpoint(endpoint string) error { _, err := c.db.Exec( deleteWebPushSubscriptionByEndpointQuery, endpoint, @@ -117,10 +117,10 @@ func (c *webPushStore) ExpireWebPushEndpoint(endpoint string) error { return err } -func (c *webPushStore) ExpireWebPushForUser(username string) error { +func (c *webPushStore) RemoveByUserID(userID string) error { _, err := c.db.Exec( - deleteWebPushSubscriptionByUsernameQuery, - username, + deleteWebPushSubscriptionByUserIDQuery, + userID, ) return err }