diff --git a/cmd/webpush.go b/cmd/webpush.go index 2aec4d7f..6e74def2 100644 --- a/cmd/webpush.go +++ b/cmd/webpush.go @@ -35,8 +35,7 @@ func generateWebPushKeys(c *cli.Context) error { if err != nil { return err } - - fmt.Fprintf(c.App.ErrWriter, `Web Push keys generated. Add the following lines to your config file: + _, err = fmt.Fprintf(c.App.ErrWriter, `Web Push keys generated. Add the following lines to your config file: web-push-public-key: %s web-push-private-key: %s @@ -45,6 +44,5 @@ web-push-email-address: See https://ntfy.sh/docs/config/#web-push for details. `, publicKey, privateKey) - - return nil + return err } diff --git a/server/server_web_push.go b/server/server_webpush.go similarity index 98% rename from server/server_web_push.go rename to server/server_webpush.go index 0875b94f..209cb2d7 100644 --- a/server/server_web_push.go +++ b/server/server_webpush.go @@ -59,7 +59,7 @@ func (s *Server) handleWebPushUpdate(w http.ResponseWriter, r *http.Request, v * } } } - if err := s.webPush.UpsertSubscription(req.Endpoint, req.Auth, req.P256dh, v.MaybeUserID(), req.Topics); err != nil { + if err := s.webPush.UpsertSubscription(req.Endpoint, req.Auth, req.P256dh, v.MaybeUserID(), v.IP(), req.Topics); err != nil { return err } return s.writeJSON(w, newSuccessResponse()) diff --git a/server/server_web_push_test.go b/server/server_webpush_test.go similarity index 98% rename from server/server_web_push_test.go rename to server/server_webpush_test.go index 7db82b8e..665156c8 100644 --- a/server/server_web_push_test.go +++ b/server/server_webpush_test.go @@ -9,6 +9,7 @@ import ( "io" "net/http" "net/http/httptest" + "net/netip" "strings" "sync/atomic" "testing" @@ -225,7 +226,7 @@ func payloadForTopics(t *testing.T, topics []string, endpoint string) string { } func addSubscription(t *testing.T, s *Server, endpoint string, topics ...string) { - require.Nil(t, s.webPush.UpsertSubscription(endpoint, "kSC3T8aN1JCQxxPdrFLrZg", "BMKKbxdUU_xLS7G1Wh5AN8PvWOjCzkCuKZYb8apcqYrDxjOF_2piggBnoJLQYx9IeSD70fNuwawI3e9Y8m3S3PE", "u_123", topics)) // Test auth and p256dh + require.Nil(t, s.webPush.UpsertSubscription(endpoint, "kSC3T8aN1JCQxxPdrFLrZg", "BMKKbxdUU_xLS7G1Wh5AN8PvWOjCzkCuKZYb8apcqYrDxjOF_2piggBnoJLQYx9IeSD70fNuwawI3e9Y8m3S3PE", "u_123", netip.MustParseAddr("1.2.3.4"), topics)) // Test auth and p256dh } func requireSubscriptionCount(t *testing.T, s *Server, topic string, expectedLength int) { diff --git a/server/webpush_store.go b/server/webpush_store.go index 6dc1ddef..d2b7ef27 100644 --- a/server/webpush_store.go +++ b/server/webpush_store.go @@ -2,15 +2,23 @@ package server import ( "database/sql" + "errors" "heckel.io/ntfy/util" + "net/netip" "time" _ "github.com/mattn/go-sqlite3" // SQLite driver ) const ( - subscriptionIDPrefix = "wps_" - subscriptionIDLength = 10 + subscriptionIDPrefix = "wps_" + subscriptionIDLength = 10 + subscriptionLimitPerSubscriberIP = 10 +) + +var ( + errWebPushNoRows = errors.New("no rows found") + errWebPushTooManySubscriptions = errors.New("too many subscriptions") ) const ( @@ -21,11 +29,13 @@ const ( endpoint TEXT NOT NULL, key_auth TEXT NOT NULL, key_p256dh TEXT NOT NULL, - user_id TEXT NOT NULL, + user_id TEXT NOT NULL, + subscriber_ip TEXT NOT NULL, updated_at INT NOT NULL, warned_at INT NOT NULL DEFAULT 0 ); CREATE UNIQUE INDEX IF NOT EXISTS idx_endpoint ON subscription (endpoint); + CREATE INDEX IF NOT EXISTS idx_subscriber_ip ON subscription (subscriber_ip); CREATE TABLE IF NOT EXISTS subscription_topic ( subscription_id TEXT NOT NULL, topic TEXT NOT NULL, @@ -43,8 +53,9 @@ const ( PRAGMA foreign_keys = ON; ` - selectWebPushSubscriptionIDByEndpoint = `SELECT id FROM subscription WHERE endpoint = ?` - selectWebPushSubscriptionsForTopicQuery = ` + selectWebPushSubscriptionIDByEndpoint = `SELECT id FROM subscription WHERE endpoint = ?` + selectWebPushSubscriptionCountBySubscriberIP = `SELECT COUNT(*) FROM subscription WHERE subscriber_ip = ?` + selectWebPushSubscriptionsForTopicQuery = ` SELECT id, endpoint, key_auth, key_p256dh, user_id FROM subscription_topic st JOIN subscription s ON s.id = st.subscription_id @@ -52,10 +63,10 @@ const ( ` selectWebPushSubscriptionsExpiringSoonQuery = `SELECT id, endpoint, key_auth, key_p256dh, user_id FROM subscription WHERE warned_at = 0 AND updated_at <= ?` insertWebPushSubscriptionQuery = ` - INSERT INTO subscription (id, endpoint, key_auth, key_p256dh, user_id, updated_at, warned_at) - VALUES (?, ?, ?, ?, ?, ?, ?) + INSERT INTO subscription (id, endpoint, key_auth, key_p256dh, user_id, subscriber_ip, updated_at, warned_at) + VALUES (?, ?, ?, ?, ?, ?, ?, ?) ON CONFLICT (endpoint) - DO UPDATE SET key_auth = excluded.key_auth, key_p256dh = excluded.key_p256dh, user_id = excluded.user_id, updated_at = excluded.updated_at, warned_at = excluded.warned_at + DO UPDATE SET key_auth = excluded.key_auth, key_p256dh = excluded.key_p256dh, user_id = excluded.user_id, subscriber_ip = excluded.subscriber_ip, updated_at = excluded.updated_at, warned_at = excluded.warned_at ` updateWebPushSubscriptionWarningSentQuery = `UPDATE subscription SET warned_at = ? WHERE id = ?` deleteWebPushSubscriptionByEndpointQuery = `DELETE FROM subscription WHERE endpoint = ?` @@ -119,12 +130,28 @@ func runWebPushStartupQueries(db *sql.DB) error { // UpsertSubscription adds or updates Web Push subscriptions for the given topics and user ID. It always first deletes all // existing entries for a given endpoint. -func (c *webPushStore) UpsertSubscription(endpoint string, auth, p256dh, userID string, topics []string) error { +func (c *webPushStore) UpsertSubscription(endpoint string, auth, p256dh, userID string, subscriberIP netip.Addr, topics []string) error { tx, err := c.db.Begin() if err != nil { return err } defer tx.Rollback() + // Read number of subscriptions for subscriber IP address + rowsCount, err := tx.Query(selectWebPushSubscriptionCountBySubscriberIP, subscriberIP.String()) + if err != nil { + return err + } + defer rowsCount.Close() + var subscriptionCount int + if !rowsCount.Next() { + return errWebPushNoRows + } + if err := rowsCount.Scan(&subscriptionCount); err != nil { + return err + } + if err := rowsCount.Close(); err != nil { + return err + } // Read existing subscription ID for endpoint (or create new ID) rows, err := tx.Query(selectWebPushSubscriptionIDByEndpoint, endpoint) if err != nil { @@ -137,6 +164,9 @@ func (c *webPushStore) UpsertSubscription(endpoint string, auth, p256dh, userID return err } } else { + if subscriptionCount >= subscriptionLimitPerSubscriberIP { + return errWebPushTooManySubscriptions + } subscriptionID = util.RandomStringPrefix(subscriptionIDPrefix, subscriptionIDLength) } if err := rows.Close(); err != nil { @@ -144,7 +174,7 @@ func (c *webPushStore) UpsertSubscription(endpoint string, auth, p256dh, userID } // Insert or update subscription updatedAt, warnedAt := time.Now().Unix(), 0 - if _, err = tx.Exec(insertWebPushSubscriptionQuery, subscriptionID, endpoint, auth, p256dh, userID, updatedAt, warnedAt); err != nil { + if _, err = tx.Exec(insertWebPushSubscriptionQuery, subscriptionID, endpoint, auth, p256dh, userID, subscriberIP.String(), updatedAt, warnedAt); err != nil { return err } // Replace all subscription topics @@ -159,6 +189,7 @@ func (c *webPushStore) UpsertSubscription(endpoint string, auth, p256dh, userID return tx.Commit() } +// SubscriptionsForTopic returns all subscriptions for the given topic func (c *webPushStore) SubscriptionsForTopic(topic string) ([]*webPushSubscription, error) { rows, err := c.db.Query(selectWebPushSubscriptionsForTopicQuery, topic) if err != nil { @@ -168,6 +199,7 @@ func (c *webPushStore) SubscriptionsForTopic(topic string) ([]*webPushSubscripti return c.subscriptionsFromRows(rows) } +// SubscriptionsExpiring returns all subscriptions that have not been updated for a given time period func (c *webPushStore) SubscriptionsExpiring(warnAfter time.Duration) ([]*webPushSubscription, error) { rows, err := c.db.Query(selectWebPushSubscriptionsExpiringSoonQuery, time.Now().Add(-warnAfter).Unix()) if err != nil { @@ -177,6 +209,7 @@ func (c *webPushStore) SubscriptionsExpiring(warnAfter time.Duration) ([]*webPus return c.subscriptionsFromRows(rows) } +// MarkExpiryWarningSent marks the given subscriptions as having received a warning about expiring soon func (c *webPushStore) MarkExpiryWarningSent(subscriptions []*webPushSubscription) error { tx, err := c.db.Begin() if err != nil { @@ -209,21 +242,25 @@ func (c *webPushStore) subscriptionsFromRows(rows *sql.Rows) ([]*webPushSubscrip return subscriptions, nil } +// RemoveSubscriptionsByEndpoint removes the subscription for the given endpoint func (c *webPushStore) RemoveSubscriptionsByEndpoint(endpoint string) error { _, err := c.db.Exec(deleteWebPushSubscriptionByEndpointQuery, endpoint) return err } +// RemoveSubscriptionsByUserID removes all subscriptions for the given user ID func (c *webPushStore) RemoveSubscriptionsByUserID(userID string) error { _, err := c.db.Exec(deleteWebPushSubscriptionByUserIDQuery, userID) return err } +// RemoveExpiredSubscriptions removes all subscriptions that have not been updated for a given time period func (c *webPushStore) RemoveExpiredSubscriptions(expireAfter time.Duration) error { _, err := c.db.Exec(deleteWebPushSubscriptionByAgeQuery, time.Now().Add(-expireAfter).Unix()) return err } +// Close closes the underlying database connection func (c *webPushStore) Close() error { return c.db.Close() } diff --git a/server/webpush_store_test.go b/server/webpush_store_test.go index 28068ebe..03951a07 100644 --- a/server/webpush_store_test.go +++ b/server/webpush_store_test.go @@ -1,7 +1,10 @@ package server import ( + "fmt" "github.com/stretchr/testify/require" + "net/netip" + "path/filepath" "testing" ) @@ -10,3 +13,43 @@ func newTestWebPushStore(t *testing.T, filename string) *webPushStore { require.Nil(t, err) return webPush } + +func TestWebPushStore_UpsertSubscription_SubscriptionsForTopic(t *testing.T) { + webPush := newTestWebPushStore(t, filepath.Join(t.TempDir(), "webpush.db")) + defer webPush.Close() + + require.Nil(t, webPush.UpsertSubscription(testWebPushEndpoint, "auth-key", "p256dh-key", "u_1234", netip.MustParseAddr("1.2.3.4"), []string{"test-topic", "mytopic"})) + + subs, err := webPush.SubscriptionsForTopic("test-topic") + require.Nil(t, err) + require.Len(t, subs, 1) + require.Equal(t, subs[0].Endpoint, testWebPushEndpoint) + require.Equal(t, subs[0].P256dh, "p256dh-key") + require.Equal(t, subs[0].Auth, "auth-key") + require.Equal(t, subs[0].UserID, "u_1234") + + subs2, err := webPush.SubscriptionsForTopic("mytopic") + require.Nil(t, err) + require.Len(t, subs2, 1) + require.Equal(t, subs[0].Endpoint, subs2[0].Endpoint) +} + +func TestWebPushStore_UpsertSubscription_SubscriberIPLimitReached(t *testing.T) { + webPush := newTestWebPushStore(t, filepath.Join(t.TempDir(), "webpush.db")) + defer webPush.Close() + + // Insert 10 subscriptions with the same IP address + for i := 0; i < 10; i++ { + endpoint := fmt.Sprintf(testWebPushEndpoint+"%d", i) + require.Nil(t, webPush.UpsertSubscription(endpoint, "auth-key", "p256dh-key", "u_1234", netip.MustParseAddr("1.2.3.4"), []string{"test-topic", "mytopic"})) + } + + // Another one for the same endpoint should be fine + require.Nil(t, webPush.UpsertSubscription(testWebPushEndpoint+"0", "auth-key", "p256dh-key", "u_1234", netip.MustParseAddr("1.2.3.4"), []string{"test-topic", "mytopic"})) + + // But with a different endpoint it should fail + require.Equal(t, errWebPushTooManySubscriptions, webPush.UpsertSubscription(testWebPushEndpoint+"11", "auth-key", "p256dh-key", "u_1234", netip.MustParseAddr("1.2.3.4"), []string{"test-topic", "mytopic"})) + + // But with a different IP address it should be fine again + require.Nil(t, webPush.UpsertSubscription(testWebPushEndpoint+"99", "auth-key", "p256dh-key", "u_1234", netip.MustParseAddr("9.9.9.9"), []string{"test-topic", "mytopic"})) +}