Limit number of webpush subscriptions per subscriber IP
This commit is contained in:
		
							parent
							
								
									c43a1166e2
								
							
						
					
					
						commit
						341e84f643
					
				
					 5 changed files with 95 additions and 16 deletions
				
			
		|  | @ -35,8 +35,7 @@ func generateWebPushKeys(c *cli.Context) error { | |||
| 	if err != nil { | ||||
| 		return err | ||||
| 	} | ||||
| 
 | ||||
| 	fmt.Fprintf(c.App.ErrWriter, `Web Push keys generated. Add the following lines to your config file: | ||||
| 	_, err = fmt.Fprintf(c.App.ErrWriter, `Web Push keys generated. Add the following lines to your config file: | ||||
| 
 | ||||
| web-push-public-key: %s | ||||
| web-push-private-key: %s | ||||
|  | @ -45,6 +44,5 @@ web-push-email-address: <email address> | |||
| 
 | ||||
| See https://ntfy.sh/docs/config/#web-push for details. | ||||
| `, publicKey, privateKey) | ||||
| 
 | ||||
| 	return nil | ||||
| 	return err | ||||
| } | ||||
|  |  | |||
|  | @ -59,7 +59,7 @@ func (s *Server) handleWebPushUpdate(w http.ResponseWriter, r *http.Request, v * | |||
| 			} | ||||
| 		} | ||||
| 	} | ||||
| 	if err := s.webPush.UpsertSubscription(req.Endpoint, req.Auth, req.P256dh, v.MaybeUserID(), req.Topics); err != nil { | ||||
| 	if err := s.webPush.UpsertSubscription(req.Endpoint, req.Auth, req.P256dh, v.MaybeUserID(), v.IP(), req.Topics); err != nil { | ||||
| 		return err | ||||
| 	} | ||||
| 	return s.writeJSON(w, newSuccessResponse()) | ||||
|  | @ -9,6 +9,7 @@ import ( | |||
| 	"io" | ||||
| 	"net/http" | ||||
| 	"net/http/httptest" | ||||
| 	"net/netip" | ||||
| 	"strings" | ||||
| 	"sync/atomic" | ||||
| 	"testing" | ||||
|  | @ -225,7 +226,7 @@ func payloadForTopics(t *testing.T, topics []string, endpoint string) string { | |||
| } | ||||
| 
 | ||||
| func addSubscription(t *testing.T, s *Server, endpoint string, topics ...string) { | ||||
| 	require.Nil(t, s.webPush.UpsertSubscription(endpoint, "kSC3T8aN1JCQxxPdrFLrZg", "BMKKbxdUU_xLS7G1Wh5AN8PvWOjCzkCuKZYb8apcqYrDxjOF_2piggBnoJLQYx9IeSD70fNuwawI3e9Y8m3S3PE", "u_123", topics)) // Test auth and p256dh | ||||
| 	require.Nil(t, s.webPush.UpsertSubscription(endpoint, "kSC3T8aN1JCQxxPdrFLrZg", "BMKKbxdUU_xLS7G1Wh5AN8PvWOjCzkCuKZYb8apcqYrDxjOF_2piggBnoJLQYx9IeSD70fNuwawI3e9Y8m3S3PE", "u_123", netip.MustParseAddr("1.2.3.4"), topics)) // Test auth and p256dh | ||||
| } | ||||
| 
 | ||||
| func requireSubscriptionCount(t *testing.T, s *Server, topic string, expectedLength int) { | ||||
|  | @ -2,15 +2,23 @@ package server | |||
| 
 | ||||
| import ( | ||||
| 	"database/sql" | ||||
| 	"errors" | ||||
| 	"heckel.io/ntfy/util" | ||||
| 	"net/netip" | ||||
| 	"time" | ||||
| 
 | ||||
| 	_ "github.com/mattn/go-sqlite3" // SQLite driver | ||||
| ) | ||||
| 
 | ||||
| const ( | ||||
| 	subscriptionIDPrefix = "wps_" | ||||
| 	subscriptionIDLength = 10 | ||||
| 	subscriptionIDPrefix             = "wps_" | ||||
| 	subscriptionIDLength             = 10 | ||||
| 	subscriptionLimitPerSubscriberIP = 10 | ||||
| ) | ||||
| 
 | ||||
| var ( | ||||
| 	errWebPushNoRows               = errors.New("no rows found") | ||||
| 	errWebPushTooManySubscriptions = errors.New("too many subscriptions") | ||||
| ) | ||||
| 
 | ||||
| const ( | ||||
|  | @ -21,11 +29,13 @@ const ( | |||
| 			endpoint TEXT NOT NULL, | ||||
| 			key_auth TEXT NOT NULL, | ||||
| 			key_p256dh TEXT NOT NULL, | ||||
| 			user_id TEXT NOT NULL, | ||||
| 			user_id TEXT NOT NULL,		 | ||||
| 			subscriber_ip TEXT NOT NULL, | ||||
| 			updated_at INT NOT NULL, | ||||
| 			warned_at INT NOT NULL DEFAULT 0 | ||||
| 		); | ||||
| 		CREATE UNIQUE INDEX IF NOT EXISTS idx_endpoint ON subscription (endpoint); | ||||
| 		CREATE INDEX IF NOT EXISTS idx_subscriber_ip ON subscription (subscriber_ip); | ||||
| 		CREATE TABLE IF NOT EXISTS subscription_topic ( | ||||
| 			subscription_id TEXT NOT NULL, | ||||
| 			topic TEXT NOT NULL, | ||||
|  | @ -43,8 +53,9 @@ const ( | |||
| 		PRAGMA foreign_keys = ON; | ||||
| 	` | ||||
| 
 | ||||
| 	selectWebPushSubscriptionIDByEndpoint   = `SELECT id FROM subscription WHERE endpoint = ?` | ||||
| 	selectWebPushSubscriptionsForTopicQuery = ` | ||||
| 	selectWebPushSubscriptionIDByEndpoint        = `SELECT id FROM subscription WHERE endpoint = ?` | ||||
| 	selectWebPushSubscriptionCountBySubscriberIP = `SELECT COUNT(*) FROM subscription WHERE subscriber_ip = ?` | ||||
| 	selectWebPushSubscriptionsForTopicQuery      = ` | ||||
| 		SELECT id, endpoint, key_auth, key_p256dh, user_id | ||||
| 		FROM subscription_topic st | ||||
| 		JOIN subscription s ON s.id = st.subscription_id | ||||
|  | @ -52,10 +63,10 @@ const ( | |||
| 	` | ||||
| 	selectWebPushSubscriptionsExpiringSoonQuery = `SELECT id, endpoint, key_auth, key_p256dh, user_id FROM subscription WHERE warned_at = 0 AND updated_at <= ?` | ||||
| 	insertWebPushSubscriptionQuery              = ` | ||||
| 		INSERT INTO subscription (id, endpoint, key_auth, key_p256dh, user_id, updated_at, warned_at) | ||||
| 		VALUES (?, ?, ?, ?, ?, ?, ?) | ||||
| 		INSERT INTO subscription (id, endpoint, key_auth, key_p256dh, user_id, subscriber_ip, updated_at, warned_at) | ||||
| 		VALUES (?, ?, ?, ?, ?, ?, ?, ?) | ||||
| 		ON CONFLICT (endpoint)  | ||||
| 		DO UPDATE SET key_auth = excluded.key_auth, key_p256dh = excluded.key_p256dh, user_id = excluded.user_id, updated_at = excluded.updated_at, warned_at = excluded.warned_at | ||||
| 		DO UPDATE SET key_auth = excluded.key_auth, key_p256dh = excluded.key_p256dh, user_id = excluded.user_id, subscriber_ip = excluded.subscriber_ip, updated_at = excluded.updated_at, warned_at = excluded.warned_at | ||||
| 	` | ||||
| 	updateWebPushSubscriptionWarningSentQuery = `UPDATE subscription SET warned_at = ? WHERE id = ?` | ||||
| 	deleteWebPushSubscriptionByEndpointQuery  = `DELETE FROM subscription WHERE endpoint = ?` | ||||
|  | @ -119,12 +130,28 @@ func runWebPushStartupQueries(db *sql.DB) error { | |||
| 
 | ||||
| // UpsertSubscription adds or updates Web Push subscriptions for the given topics and user ID. It always first deletes all | ||||
| // existing entries for a given endpoint. | ||||
| func (c *webPushStore) UpsertSubscription(endpoint string, auth, p256dh, userID string, topics []string) error { | ||||
| func (c *webPushStore) UpsertSubscription(endpoint string, auth, p256dh, userID string, subscriberIP netip.Addr, topics []string) error { | ||||
| 	tx, err := c.db.Begin() | ||||
| 	if err != nil { | ||||
| 		return err | ||||
| 	} | ||||
| 	defer tx.Rollback() | ||||
| 	// Read number of subscriptions for subscriber IP address | ||||
| 	rowsCount, err := tx.Query(selectWebPushSubscriptionCountBySubscriberIP, subscriberIP.String()) | ||||
| 	if err != nil { | ||||
| 		return err | ||||
| 	} | ||||
| 	defer rowsCount.Close() | ||||
| 	var subscriptionCount int | ||||
| 	if !rowsCount.Next() { | ||||
| 		return errWebPushNoRows | ||||
| 	} | ||||
| 	if err := rowsCount.Scan(&subscriptionCount); err != nil { | ||||
| 		return err | ||||
| 	} | ||||
| 	if err := rowsCount.Close(); err != nil { | ||||
| 		return err | ||||
| 	} | ||||
| 	// Read existing subscription ID for endpoint (or create new ID) | ||||
| 	rows, err := tx.Query(selectWebPushSubscriptionIDByEndpoint, endpoint) | ||||
| 	if err != nil { | ||||
|  | @ -137,6 +164,9 @@ func (c *webPushStore) UpsertSubscription(endpoint string, auth, p256dh, userID | |||
| 			return err | ||||
| 		} | ||||
| 	} else { | ||||
| 		if subscriptionCount >= subscriptionLimitPerSubscriberIP { | ||||
| 			return errWebPushTooManySubscriptions | ||||
| 		} | ||||
| 		subscriptionID = util.RandomStringPrefix(subscriptionIDPrefix, subscriptionIDLength) | ||||
| 	} | ||||
| 	if err := rows.Close(); err != nil { | ||||
|  | @ -144,7 +174,7 @@ func (c *webPushStore) UpsertSubscription(endpoint string, auth, p256dh, userID | |||
| 	} | ||||
| 	// Insert or update subscription | ||||
| 	updatedAt, warnedAt := time.Now().Unix(), 0 | ||||
| 	if _, err = tx.Exec(insertWebPushSubscriptionQuery, subscriptionID, endpoint, auth, p256dh, userID, updatedAt, warnedAt); err != nil { | ||||
| 	if _, err = tx.Exec(insertWebPushSubscriptionQuery, subscriptionID, endpoint, auth, p256dh, userID, subscriberIP.String(), updatedAt, warnedAt); err != nil { | ||||
| 		return err | ||||
| 	} | ||||
| 	// Replace all subscription topics | ||||
|  | @ -159,6 +189,7 @@ func (c *webPushStore) UpsertSubscription(endpoint string, auth, p256dh, userID | |||
| 	return tx.Commit() | ||||
| } | ||||
| 
 | ||||
| // SubscriptionsForTopic returns all subscriptions for the given topic | ||||
| func (c *webPushStore) SubscriptionsForTopic(topic string) ([]*webPushSubscription, error) { | ||||
| 	rows, err := c.db.Query(selectWebPushSubscriptionsForTopicQuery, topic) | ||||
| 	if err != nil { | ||||
|  | @ -168,6 +199,7 @@ func (c *webPushStore) SubscriptionsForTopic(topic string) ([]*webPushSubscripti | |||
| 	return c.subscriptionsFromRows(rows) | ||||
| } | ||||
| 
 | ||||
| // SubscriptionsExpiring returns all subscriptions that have not been updated for a given time period | ||||
| func (c *webPushStore) SubscriptionsExpiring(warnAfter time.Duration) ([]*webPushSubscription, error) { | ||||
| 	rows, err := c.db.Query(selectWebPushSubscriptionsExpiringSoonQuery, time.Now().Add(-warnAfter).Unix()) | ||||
| 	if err != nil { | ||||
|  | @ -177,6 +209,7 @@ func (c *webPushStore) SubscriptionsExpiring(warnAfter time.Duration) ([]*webPus | |||
| 	return c.subscriptionsFromRows(rows) | ||||
| } | ||||
| 
 | ||||
| // MarkExpiryWarningSent marks the given subscriptions as having received a warning about expiring soon | ||||
| func (c *webPushStore) MarkExpiryWarningSent(subscriptions []*webPushSubscription) error { | ||||
| 	tx, err := c.db.Begin() | ||||
| 	if err != nil { | ||||
|  | @ -209,21 +242,25 @@ func (c *webPushStore) subscriptionsFromRows(rows *sql.Rows) ([]*webPushSubscrip | |||
| 	return subscriptions, nil | ||||
| } | ||||
| 
 | ||||
| // RemoveSubscriptionsByEndpoint removes the subscription for the given endpoint | ||||
| func (c *webPushStore) RemoveSubscriptionsByEndpoint(endpoint string) error { | ||||
| 	_, err := c.db.Exec(deleteWebPushSubscriptionByEndpointQuery, endpoint) | ||||
| 	return err | ||||
| } | ||||
| 
 | ||||
| // RemoveSubscriptionsByUserID removes all subscriptions for the given user ID | ||||
| func (c *webPushStore) RemoveSubscriptionsByUserID(userID string) error { | ||||
| 	_, err := c.db.Exec(deleteWebPushSubscriptionByUserIDQuery, userID) | ||||
| 	return err | ||||
| } | ||||
| 
 | ||||
| // RemoveExpiredSubscriptions removes all subscriptions that have not been updated for a given time period | ||||
| func (c *webPushStore) RemoveExpiredSubscriptions(expireAfter time.Duration) error { | ||||
| 	_, err := c.db.Exec(deleteWebPushSubscriptionByAgeQuery, time.Now().Add(-expireAfter).Unix()) | ||||
| 	return err | ||||
| } | ||||
| 
 | ||||
| // Close closes the underlying database connection | ||||
| func (c *webPushStore) Close() error { | ||||
| 	return c.db.Close() | ||||
| } | ||||
|  |  | |||
|  | @ -1,7 +1,10 @@ | |||
| package server | ||||
| 
 | ||||
| import ( | ||||
| 	"fmt" | ||||
| 	"github.com/stretchr/testify/require" | ||||
| 	"net/netip" | ||||
| 	"path/filepath" | ||||
| 	"testing" | ||||
| ) | ||||
| 
 | ||||
|  | @ -10,3 +13,43 @@ func newTestWebPushStore(t *testing.T, filename string) *webPushStore { | |||
| 	require.Nil(t, err) | ||||
| 	return webPush | ||||
| } | ||||
| 
 | ||||
| func TestWebPushStore_UpsertSubscription_SubscriptionsForTopic(t *testing.T) { | ||||
| 	webPush := newTestWebPushStore(t, filepath.Join(t.TempDir(), "webpush.db")) | ||||
| 	defer webPush.Close() | ||||
| 
 | ||||
| 	require.Nil(t, webPush.UpsertSubscription(testWebPushEndpoint, "auth-key", "p256dh-key", "u_1234", netip.MustParseAddr("1.2.3.4"), []string{"test-topic", "mytopic"})) | ||||
| 
 | ||||
| 	subs, err := webPush.SubscriptionsForTopic("test-topic") | ||||
| 	require.Nil(t, err) | ||||
| 	require.Len(t, subs, 1) | ||||
| 	require.Equal(t, subs[0].Endpoint, testWebPushEndpoint) | ||||
| 	require.Equal(t, subs[0].P256dh, "p256dh-key") | ||||
| 	require.Equal(t, subs[0].Auth, "auth-key") | ||||
| 	require.Equal(t, subs[0].UserID, "u_1234") | ||||
| 
 | ||||
| 	subs2, err := webPush.SubscriptionsForTopic("mytopic") | ||||
| 	require.Nil(t, err) | ||||
| 	require.Len(t, subs2, 1) | ||||
| 	require.Equal(t, subs[0].Endpoint, subs2[0].Endpoint) | ||||
| } | ||||
| 
 | ||||
| func TestWebPushStore_UpsertSubscription_SubscriberIPLimitReached(t *testing.T) { | ||||
| 	webPush := newTestWebPushStore(t, filepath.Join(t.TempDir(), "webpush.db")) | ||||
| 	defer webPush.Close() | ||||
| 
 | ||||
| 	// Insert 10 subscriptions with the same IP address | ||||
| 	for i := 0; i < 10; i++ { | ||||
| 		endpoint := fmt.Sprintf(testWebPushEndpoint+"%d", i) | ||||
| 		require.Nil(t, webPush.UpsertSubscription(endpoint, "auth-key", "p256dh-key", "u_1234", netip.MustParseAddr("1.2.3.4"), []string{"test-topic", "mytopic"})) | ||||
| 	} | ||||
| 
 | ||||
| 	// Another one for the same endpoint should be fine | ||||
| 	require.Nil(t, webPush.UpsertSubscription(testWebPushEndpoint+"0", "auth-key", "p256dh-key", "u_1234", netip.MustParseAddr("1.2.3.4"), []string{"test-topic", "mytopic"})) | ||||
| 
 | ||||
| 	// But with a different endpoint it should fail | ||||
| 	require.Equal(t, errWebPushTooManySubscriptions, webPush.UpsertSubscription(testWebPushEndpoint+"11", "auth-key", "p256dh-key", "u_1234", netip.MustParseAddr("1.2.3.4"), []string{"test-topic", "mytopic"})) | ||||
| 
 | ||||
| 	// But with a different IP address it should be fine again | ||||
| 	require.Nil(t, webPush.UpsertSubscription(testWebPushEndpoint+"99", "auth-key", "p256dh-key", "u_1234", netip.MustParseAddr("9.9.9.9"), []string{"test-topic", "mytopic"})) | ||||
| } | ||||
|  |  | |||
		Loading…
	
	Add table
		Add a link
		
	
		Reference in a new issue