Limit number of webpush subscriptions per subscriber IP

pull/751/head
binwiederhier 2023-06-16 21:59:07 -04:00
parent c43a1166e2
commit 341e84f643
5 changed files with 95 additions and 16 deletions

View File

@ -35,8 +35,7 @@ func generateWebPushKeys(c *cli.Context) error {
if err != nil { if err != nil {
return err return err
} }
_, err = fmt.Fprintf(c.App.ErrWriter, `Web Push keys generated. Add the following lines to your config file:
fmt.Fprintf(c.App.ErrWriter, `Web Push keys generated. Add the following lines to your config file:
web-push-public-key: %s web-push-public-key: %s
web-push-private-key: %s web-push-private-key: %s
@ -45,6 +44,5 @@ web-push-email-address: <email address>
See https://ntfy.sh/docs/config/#web-push for details. See https://ntfy.sh/docs/config/#web-push for details.
`, publicKey, privateKey) `, publicKey, privateKey)
return err
return nil
} }

View File

@ -59,7 +59,7 @@ func (s *Server) handleWebPushUpdate(w http.ResponseWriter, r *http.Request, v *
} }
} }
} }
if err := s.webPush.UpsertSubscription(req.Endpoint, req.Auth, req.P256dh, v.MaybeUserID(), req.Topics); err != nil { if err := s.webPush.UpsertSubscription(req.Endpoint, req.Auth, req.P256dh, v.MaybeUserID(), v.IP(), req.Topics); err != nil {
return err return err
} }
return s.writeJSON(w, newSuccessResponse()) return s.writeJSON(w, newSuccessResponse())

View File

@ -9,6 +9,7 @@ import (
"io" "io"
"net/http" "net/http"
"net/http/httptest" "net/http/httptest"
"net/netip"
"strings" "strings"
"sync/atomic" "sync/atomic"
"testing" "testing"
@ -225,7 +226,7 @@ func payloadForTopics(t *testing.T, topics []string, endpoint string) string {
} }
func addSubscription(t *testing.T, s *Server, endpoint string, topics ...string) { func addSubscription(t *testing.T, s *Server, endpoint string, topics ...string) {
require.Nil(t, s.webPush.UpsertSubscription(endpoint, "kSC3T8aN1JCQxxPdrFLrZg", "BMKKbxdUU_xLS7G1Wh5AN8PvWOjCzkCuKZYb8apcqYrDxjOF_2piggBnoJLQYx9IeSD70fNuwawI3e9Y8m3S3PE", "u_123", topics)) // Test auth and p256dh require.Nil(t, s.webPush.UpsertSubscription(endpoint, "kSC3T8aN1JCQxxPdrFLrZg", "BMKKbxdUU_xLS7G1Wh5AN8PvWOjCzkCuKZYb8apcqYrDxjOF_2piggBnoJLQYx9IeSD70fNuwawI3e9Y8m3S3PE", "u_123", netip.MustParseAddr("1.2.3.4"), topics)) // Test auth and p256dh
} }
func requireSubscriptionCount(t *testing.T, s *Server, topic string, expectedLength int) { func requireSubscriptionCount(t *testing.T, s *Server, topic string, expectedLength int) {

View File

@ -2,15 +2,23 @@ package server
import ( import (
"database/sql" "database/sql"
"errors"
"heckel.io/ntfy/util" "heckel.io/ntfy/util"
"net/netip"
"time" "time"
_ "github.com/mattn/go-sqlite3" // SQLite driver _ "github.com/mattn/go-sqlite3" // SQLite driver
) )
const ( const (
subscriptionIDPrefix = "wps_" subscriptionIDPrefix = "wps_"
subscriptionIDLength = 10 subscriptionIDLength = 10
subscriptionLimitPerSubscriberIP = 10
)
var (
errWebPushNoRows = errors.New("no rows found")
errWebPushTooManySubscriptions = errors.New("too many subscriptions")
) )
const ( const (
@ -22,10 +30,12 @@ const (
key_auth TEXT NOT NULL, key_auth TEXT NOT NULL,
key_p256dh TEXT NOT NULL, key_p256dh TEXT NOT NULL,
user_id TEXT NOT NULL, user_id TEXT NOT NULL,
subscriber_ip TEXT NOT NULL,
updated_at INT NOT NULL, updated_at INT NOT NULL,
warned_at INT NOT NULL DEFAULT 0 warned_at INT NOT NULL DEFAULT 0
); );
CREATE UNIQUE INDEX IF NOT EXISTS idx_endpoint ON subscription (endpoint); CREATE UNIQUE INDEX IF NOT EXISTS idx_endpoint ON subscription (endpoint);
CREATE INDEX IF NOT EXISTS idx_subscriber_ip ON subscription (subscriber_ip);
CREATE TABLE IF NOT EXISTS subscription_topic ( CREATE TABLE IF NOT EXISTS subscription_topic (
subscription_id TEXT NOT NULL, subscription_id TEXT NOT NULL,
topic TEXT NOT NULL, topic TEXT NOT NULL,
@ -43,8 +53,9 @@ const (
PRAGMA foreign_keys = ON; PRAGMA foreign_keys = ON;
` `
selectWebPushSubscriptionIDByEndpoint = `SELECT id FROM subscription WHERE endpoint = ?` selectWebPushSubscriptionIDByEndpoint = `SELECT id FROM subscription WHERE endpoint = ?`
selectWebPushSubscriptionsForTopicQuery = ` selectWebPushSubscriptionCountBySubscriberIP = `SELECT COUNT(*) FROM subscription WHERE subscriber_ip = ?`
selectWebPushSubscriptionsForTopicQuery = `
SELECT id, endpoint, key_auth, key_p256dh, user_id SELECT id, endpoint, key_auth, key_p256dh, user_id
FROM subscription_topic st FROM subscription_topic st
JOIN subscription s ON s.id = st.subscription_id JOIN subscription s ON s.id = st.subscription_id
@ -52,10 +63,10 @@ const (
` `
selectWebPushSubscriptionsExpiringSoonQuery = `SELECT id, endpoint, key_auth, key_p256dh, user_id FROM subscription WHERE warned_at = 0 AND updated_at <= ?` selectWebPushSubscriptionsExpiringSoonQuery = `SELECT id, endpoint, key_auth, key_p256dh, user_id FROM subscription WHERE warned_at = 0 AND updated_at <= ?`
insertWebPushSubscriptionQuery = ` insertWebPushSubscriptionQuery = `
INSERT INTO subscription (id, endpoint, key_auth, key_p256dh, user_id, updated_at, warned_at) INSERT INTO subscription (id, endpoint, key_auth, key_p256dh, user_id, subscriber_ip, updated_at, warned_at)
VALUES (?, ?, ?, ?, ?, ?, ?) VALUES (?, ?, ?, ?, ?, ?, ?, ?)
ON CONFLICT (endpoint) ON CONFLICT (endpoint)
DO UPDATE SET key_auth = excluded.key_auth, key_p256dh = excluded.key_p256dh, user_id = excluded.user_id, updated_at = excluded.updated_at, warned_at = excluded.warned_at DO UPDATE SET key_auth = excluded.key_auth, key_p256dh = excluded.key_p256dh, user_id = excluded.user_id, subscriber_ip = excluded.subscriber_ip, updated_at = excluded.updated_at, warned_at = excluded.warned_at
` `
updateWebPushSubscriptionWarningSentQuery = `UPDATE subscription SET warned_at = ? WHERE id = ?` updateWebPushSubscriptionWarningSentQuery = `UPDATE subscription SET warned_at = ? WHERE id = ?`
deleteWebPushSubscriptionByEndpointQuery = `DELETE FROM subscription WHERE endpoint = ?` deleteWebPushSubscriptionByEndpointQuery = `DELETE FROM subscription WHERE endpoint = ?`
@ -119,12 +130,28 @@ func runWebPushStartupQueries(db *sql.DB) error {
// UpsertSubscription adds or updates Web Push subscriptions for the given topics and user ID. It always first deletes all // UpsertSubscription adds or updates Web Push subscriptions for the given topics and user ID. It always first deletes all
// existing entries for a given endpoint. // existing entries for a given endpoint.
func (c *webPushStore) UpsertSubscription(endpoint string, auth, p256dh, userID string, topics []string) error { func (c *webPushStore) UpsertSubscription(endpoint string, auth, p256dh, userID string, subscriberIP netip.Addr, topics []string) error {
tx, err := c.db.Begin() tx, err := c.db.Begin()
if err != nil { if err != nil {
return err return err
} }
defer tx.Rollback() defer tx.Rollback()
// Read number of subscriptions for subscriber IP address
rowsCount, err := tx.Query(selectWebPushSubscriptionCountBySubscriberIP, subscriberIP.String())
if err != nil {
return err
}
defer rowsCount.Close()
var subscriptionCount int
if !rowsCount.Next() {
return errWebPushNoRows
}
if err := rowsCount.Scan(&subscriptionCount); err != nil {
return err
}
if err := rowsCount.Close(); err != nil {
return err
}
// Read existing subscription ID for endpoint (or create new ID) // Read existing subscription ID for endpoint (or create new ID)
rows, err := tx.Query(selectWebPushSubscriptionIDByEndpoint, endpoint) rows, err := tx.Query(selectWebPushSubscriptionIDByEndpoint, endpoint)
if err != nil { if err != nil {
@ -137,6 +164,9 @@ func (c *webPushStore) UpsertSubscription(endpoint string, auth, p256dh, userID
return err return err
} }
} else { } else {
if subscriptionCount >= subscriptionLimitPerSubscriberIP {
return errWebPushTooManySubscriptions
}
subscriptionID = util.RandomStringPrefix(subscriptionIDPrefix, subscriptionIDLength) subscriptionID = util.RandomStringPrefix(subscriptionIDPrefix, subscriptionIDLength)
} }
if err := rows.Close(); err != nil { if err := rows.Close(); err != nil {
@ -144,7 +174,7 @@ func (c *webPushStore) UpsertSubscription(endpoint string, auth, p256dh, userID
} }
// Insert or update subscription // Insert or update subscription
updatedAt, warnedAt := time.Now().Unix(), 0 updatedAt, warnedAt := time.Now().Unix(), 0
if _, err = tx.Exec(insertWebPushSubscriptionQuery, subscriptionID, endpoint, auth, p256dh, userID, updatedAt, warnedAt); err != nil { if _, err = tx.Exec(insertWebPushSubscriptionQuery, subscriptionID, endpoint, auth, p256dh, userID, subscriberIP.String(), updatedAt, warnedAt); err != nil {
return err return err
} }
// Replace all subscription topics // Replace all subscription topics
@ -159,6 +189,7 @@ func (c *webPushStore) UpsertSubscription(endpoint string, auth, p256dh, userID
return tx.Commit() return tx.Commit()
} }
// SubscriptionsForTopic returns all subscriptions for the given topic
func (c *webPushStore) SubscriptionsForTopic(topic string) ([]*webPushSubscription, error) { func (c *webPushStore) SubscriptionsForTopic(topic string) ([]*webPushSubscription, error) {
rows, err := c.db.Query(selectWebPushSubscriptionsForTopicQuery, topic) rows, err := c.db.Query(selectWebPushSubscriptionsForTopicQuery, topic)
if err != nil { if err != nil {
@ -168,6 +199,7 @@ func (c *webPushStore) SubscriptionsForTopic(topic string) ([]*webPushSubscripti
return c.subscriptionsFromRows(rows) return c.subscriptionsFromRows(rows)
} }
// SubscriptionsExpiring returns all subscriptions that have not been updated for a given time period
func (c *webPushStore) SubscriptionsExpiring(warnAfter time.Duration) ([]*webPushSubscription, error) { func (c *webPushStore) SubscriptionsExpiring(warnAfter time.Duration) ([]*webPushSubscription, error) {
rows, err := c.db.Query(selectWebPushSubscriptionsExpiringSoonQuery, time.Now().Add(-warnAfter).Unix()) rows, err := c.db.Query(selectWebPushSubscriptionsExpiringSoonQuery, time.Now().Add(-warnAfter).Unix())
if err != nil { if err != nil {
@ -177,6 +209,7 @@ func (c *webPushStore) SubscriptionsExpiring(warnAfter time.Duration) ([]*webPus
return c.subscriptionsFromRows(rows) return c.subscriptionsFromRows(rows)
} }
// MarkExpiryWarningSent marks the given subscriptions as having received a warning about expiring soon
func (c *webPushStore) MarkExpiryWarningSent(subscriptions []*webPushSubscription) error { func (c *webPushStore) MarkExpiryWarningSent(subscriptions []*webPushSubscription) error {
tx, err := c.db.Begin() tx, err := c.db.Begin()
if err != nil { if err != nil {
@ -209,21 +242,25 @@ func (c *webPushStore) subscriptionsFromRows(rows *sql.Rows) ([]*webPushSubscrip
return subscriptions, nil return subscriptions, nil
} }
// RemoveSubscriptionsByEndpoint removes the subscription for the given endpoint
func (c *webPushStore) RemoveSubscriptionsByEndpoint(endpoint string) error { func (c *webPushStore) RemoveSubscriptionsByEndpoint(endpoint string) error {
_, err := c.db.Exec(deleteWebPushSubscriptionByEndpointQuery, endpoint) _, err := c.db.Exec(deleteWebPushSubscriptionByEndpointQuery, endpoint)
return err return err
} }
// RemoveSubscriptionsByUserID removes all subscriptions for the given user ID
func (c *webPushStore) RemoveSubscriptionsByUserID(userID string) error { func (c *webPushStore) RemoveSubscriptionsByUserID(userID string) error {
_, err := c.db.Exec(deleteWebPushSubscriptionByUserIDQuery, userID) _, err := c.db.Exec(deleteWebPushSubscriptionByUserIDQuery, userID)
return err return err
} }
// RemoveExpiredSubscriptions removes all subscriptions that have not been updated for a given time period
func (c *webPushStore) RemoveExpiredSubscriptions(expireAfter time.Duration) error { func (c *webPushStore) RemoveExpiredSubscriptions(expireAfter time.Duration) error {
_, err := c.db.Exec(deleteWebPushSubscriptionByAgeQuery, time.Now().Add(-expireAfter).Unix()) _, err := c.db.Exec(deleteWebPushSubscriptionByAgeQuery, time.Now().Add(-expireAfter).Unix())
return err return err
} }
// Close closes the underlying database connection
func (c *webPushStore) Close() error { func (c *webPushStore) Close() error {
return c.db.Close() return c.db.Close()
} }

View File

@ -1,7 +1,10 @@
package server package server
import ( import (
"fmt"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
"net/netip"
"path/filepath"
"testing" "testing"
) )
@ -10,3 +13,43 @@ func newTestWebPushStore(t *testing.T, filename string) *webPushStore {
require.Nil(t, err) require.Nil(t, err)
return webPush return webPush
} }
func TestWebPushStore_UpsertSubscription_SubscriptionsForTopic(t *testing.T) {
webPush := newTestWebPushStore(t, filepath.Join(t.TempDir(), "webpush.db"))
defer webPush.Close()
require.Nil(t, webPush.UpsertSubscription(testWebPushEndpoint, "auth-key", "p256dh-key", "u_1234", netip.MustParseAddr("1.2.3.4"), []string{"test-topic", "mytopic"}))
subs, err := webPush.SubscriptionsForTopic("test-topic")
require.Nil(t, err)
require.Len(t, subs, 1)
require.Equal(t, subs[0].Endpoint, testWebPushEndpoint)
require.Equal(t, subs[0].P256dh, "p256dh-key")
require.Equal(t, subs[0].Auth, "auth-key")
require.Equal(t, subs[0].UserID, "u_1234")
subs2, err := webPush.SubscriptionsForTopic("mytopic")
require.Nil(t, err)
require.Len(t, subs2, 1)
require.Equal(t, subs[0].Endpoint, subs2[0].Endpoint)
}
func TestWebPushStore_UpsertSubscription_SubscriberIPLimitReached(t *testing.T) {
webPush := newTestWebPushStore(t, filepath.Join(t.TempDir(), "webpush.db"))
defer webPush.Close()
// Insert 10 subscriptions with the same IP address
for i := 0; i < 10; i++ {
endpoint := fmt.Sprintf(testWebPushEndpoint+"%d", i)
require.Nil(t, webPush.UpsertSubscription(endpoint, "auth-key", "p256dh-key", "u_1234", netip.MustParseAddr("1.2.3.4"), []string{"test-topic", "mytopic"}))
}
// Another one for the same endpoint should be fine
require.Nil(t, webPush.UpsertSubscription(testWebPushEndpoint+"0", "auth-key", "p256dh-key", "u_1234", netip.MustParseAddr("1.2.3.4"), []string{"test-topic", "mytopic"}))
// But with a different endpoint it should fail
require.Equal(t, errWebPushTooManySubscriptions, webPush.UpsertSubscription(testWebPushEndpoint+"11", "auth-key", "p256dh-key", "u_1234", netip.MustParseAddr("1.2.3.4"), []string{"test-topic", "mytopic"}))
// But with a different IP address it should be fine again
require.Nil(t, webPush.UpsertSubscription(testWebPushEndpoint+"99", "auth-key", "p256dh-key", "u_1234", netip.MustParseAddr("9.9.9.9"), []string{"test-topic", "mytopic"}))
}