Limit number of webpush subscriptions per subscriber IP
parent
c43a1166e2
commit
341e84f643
|
@ -35,8 +35,7 @@ func generateWebPushKeys(c *cli.Context) error {
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
_, err = fmt.Fprintf(c.App.ErrWriter, `Web Push keys generated. Add the following lines to your config file:
|
||||||
fmt.Fprintf(c.App.ErrWriter, `Web Push keys generated. Add the following lines to your config file:
|
|
||||||
|
|
||||||
web-push-public-key: %s
|
web-push-public-key: %s
|
||||||
web-push-private-key: %s
|
web-push-private-key: %s
|
||||||
|
@ -45,6 +44,5 @@ web-push-email-address: <email address>
|
||||||
|
|
||||||
See https://ntfy.sh/docs/config/#web-push for details.
|
See https://ntfy.sh/docs/config/#web-push for details.
|
||||||
`, publicKey, privateKey)
|
`, publicKey, privateKey)
|
||||||
|
return err
|
||||||
return nil
|
|
||||||
}
|
}
|
||||||
|
|
|
@ -59,7 +59,7 @@ func (s *Server) handleWebPushUpdate(w http.ResponseWriter, r *http.Request, v *
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
if err := s.webPush.UpsertSubscription(req.Endpoint, req.Auth, req.P256dh, v.MaybeUserID(), req.Topics); err != nil {
|
if err := s.webPush.UpsertSubscription(req.Endpoint, req.Auth, req.P256dh, v.MaybeUserID(), v.IP(), req.Topics); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
return s.writeJSON(w, newSuccessResponse())
|
return s.writeJSON(w, newSuccessResponse())
|
|
@ -9,6 +9,7 @@ import (
|
||||||
"io"
|
"io"
|
||||||
"net/http"
|
"net/http"
|
||||||
"net/http/httptest"
|
"net/http/httptest"
|
||||||
|
"net/netip"
|
||||||
"strings"
|
"strings"
|
||||||
"sync/atomic"
|
"sync/atomic"
|
||||||
"testing"
|
"testing"
|
||||||
|
@ -225,7 +226,7 @@ func payloadForTopics(t *testing.T, topics []string, endpoint string) string {
|
||||||
}
|
}
|
||||||
|
|
||||||
func addSubscription(t *testing.T, s *Server, endpoint string, topics ...string) {
|
func addSubscription(t *testing.T, s *Server, endpoint string, topics ...string) {
|
||||||
require.Nil(t, s.webPush.UpsertSubscription(endpoint, "kSC3T8aN1JCQxxPdrFLrZg", "BMKKbxdUU_xLS7G1Wh5AN8PvWOjCzkCuKZYb8apcqYrDxjOF_2piggBnoJLQYx9IeSD70fNuwawI3e9Y8m3S3PE", "u_123", topics)) // Test auth and p256dh
|
require.Nil(t, s.webPush.UpsertSubscription(endpoint, "kSC3T8aN1JCQxxPdrFLrZg", "BMKKbxdUU_xLS7G1Wh5AN8PvWOjCzkCuKZYb8apcqYrDxjOF_2piggBnoJLQYx9IeSD70fNuwawI3e9Y8m3S3PE", "u_123", netip.MustParseAddr("1.2.3.4"), topics)) // Test auth and p256dh
|
||||||
}
|
}
|
||||||
|
|
||||||
func requireSubscriptionCount(t *testing.T, s *Server, topic string, expectedLength int) {
|
func requireSubscriptionCount(t *testing.T, s *Server, topic string, expectedLength int) {
|
|
@ -2,15 +2,23 @@ package server
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"database/sql"
|
"database/sql"
|
||||||
|
"errors"
|
||||||
"heckel.io/ntfy/util"
|
"heckel.io/ntfy/util"
|
||||||
|
"net/netip"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
_ "github.com/mattn/go-sqlite3" // SQLite driver
|
_ "github.com/mattn/go-sqlite3" // SQLite driver
|
||||||
)
|
)
|
||||||
|
|
||||||
const (
|
const (
|
||||||
subscriptionIDPrefix = "wps_"
|
subscriptionIDPrefix = "wps_"
|
||||||
subscriptionIDLength = 10
|
subscriptionIDLength = 10
|
||||||
|
subscriptionLimitPerSubscriberIP = 10
|
||||||
|
)
|
||||||
|
|
||||||
|
var (
|
||||||
|
errWebPushNoRows = errors.New("no rows found")
|
||||||
|
errWebPushTooManySubscriptions = errors.New("too many subscriptions")
|
||||||
)
|
)
|
||||||
|
|
||||||
const (
|
const (
|
||||||
|
@ -21,11 +29,13 @@ const (
|
||||||
endpoint TEXT NOT NULL,
|
endpoint TEXT NOT NULL,
|
||||||
key_auth TEXT NOT NULL,
|
key_auth TEXT NOT NULL,
|
||||||
key_p256dh TEXT NOT NULL,
|
key_p256dh TEXT NOT NULL,
|
||||||
user_id TEXT NOT NULL,
|
user_id TEXT NOT NULL,
|
||||||
|
subscriber_ip TEXT NOT NULL,
|
||||||
updated_at INT NOT NULL,
|
updated_at INT NOT NULL,
|
||||||
warned_at INT NOT NULL DEFAULT 0
|
warned_at INT NOT NULL DEFAULT 0
|
||||||
);
|
);
|
||||||
CREATE UNIQUE INDEX IF NOT EXISTS idx_endpoint ON subscription (endpoint);
|
CREATE UNIQUE INDEX IF NOT EXISTS idx_endpoint ON subscription (endpoint);
|
||||||
|
CREATE INDEX IF NOT EXISTS idx_subscriber_ip ON subscription (subscriber_ip);
|
||||||
CREATE TABLE IF NOT EXISTS subscription_topic (
|
CREATE TABLE IF NOT EXISTS subscription_topic (
|
||||||
subscription_id TEXT NOT NULL,
|
subscription_id TEXT NOT NULL,
|
||||||
topic TEXT NOT NULL,
|
topic TEXT NOT NULL,
|
||||||
|
@ -43,8 +53,9 @@ const (
|
||||||
PRAGMA foreign_keys = ON;
|
PRAGMA foreign_keys = ON;
|
||||||
`
|
`
|
||||||
|
|
||||||
selectWebPushSubscriptionIDByEndpoint = `SELECT id FROM subscription WHERE endpoint = ?`
|
selectWebPushSubscriptionIDByEndpoint = `SELECT id FROM subscription WHERE endpoint = ?`
|
||||||
selectWebPushSubscriptionsForTopicQuery = `
|
selectWebPushSubscriptionCountBySubscriberIP = `SELECT COUNT(*) FROM subscription WHERE subscriber_ip = ?`
|
||||||
|
selectWebPushSubscriptionsForTopicQuery = `
|
||||||
SELECT id, endpoint, key_auth, key_p256dh, user_id
|
SELECT id, endpoint, key_auth, key_p256dh, user_id
|
||||||
FROM subscription_topic st
|
FROM subscription_topic st
|
||||||
JOIN subscription s ON s.id = st.subscription_id
|
JOIN subscription s ON s.id = st.subscription_id
|
||||||
|
@ -52,10 +63,10 @@ const (
|
||||||
`
|
`
|
||||||
selectWebPushSubscriptionsExpiringSoonQuery = `SELECT id, endpoint, key_auth, key_p256dh, user_id FROM subscription WHERE warned_at = 0 AND updated_at <= ?`
|
selectWebPushSubscriptionsExpiringSoonQuery = `SELECT id, endpoint, key_auth, key_p256dh, user_id FROM subscription WHERE warned_at = 0 AND updated_at <= ?`
|
||||||
insertWebPushSubscriptionQuery = `
|
insertWebPushSubscriptionQuery = `
|
||||||
INSERT INTO subscription (id, endpoint, key_auth, key_p256dh, user_id, updated_at, warned_at)
|
INSERT INTO subscription (id, endpoint, key_auth, key_p256dh, user_id, subscriber_ip, updated_at, warned_at)
|
||||||
VALUES (?, ?, ?, ?, ?, ?, ?)
|
VALUES (?, ?, ?, ?, ?, ?, ?, ?)
|
||||||
ON CONFLICT (endpoint)
|
ON CONFLICT (endpoint)
|
||||||
DO UPDATE SET key_auth = excluded.key_auth, key_p256dh = excluded.key_p256dh, user_id = excluded.user_id, updated_at = excluded.updated_at, warned_at = excluded.warned_at
|
DO UPDATE SET key_auth = excluded.key_auth, key_p256dh = excluded.key_p256dh, user_id = excluded.user_id, subscriber_ip = excluded.subscriber_ip, updated_at = excluded.updated_at, warned_at = excluded.warned_at
|
||||||
`
|
`
|
||||||
updateWebPushSubscriptionWarningSentQuery = `UPDATE subscription SET warned_at = ? WHERE id = ?`
|
updateWebPushSubscriptionWarningSentQuery = `UPDATE subscription SET warned_at = ? WHERE id = ?`
|
||||||
deleteWebPushSubscriptionByEndpointQuery = `DELETE FROM subscription WHERE endpoint = ?`
|
deleteWebPushSubscriptionByEndpointQuery = `DELETE FROM subscription WHERE endpoint = ?`
|
||||||
|
@ -119,12 +130,28 @@ func runWebPushStartupQueries(db *sql.DB) error {
|
||||||
|
|
||||||
// UpsertSubscription adds or updates Web Push subscriptions for the given topics and user ID. It always first deletes all
|
// UpsertSubscription adds or updates Web Push subscriptions for the given topics and user ID. It always first deletes all
|
||||||
// existing entries for a given endpoint.
|
// existing entries for a given endpoint.
|
||||||
func (c *webPushStore) UpsertSubscription(endpoint string, auth, p256dh, userID string, topics []string) error {
|
func (c *webPushStore) UpsertSubscription(endpoint string, auth, p256dh, userID string, subscriberIP netip.Addr, topics []string) error {
|
||||||
tx, err := c.db.Begin()
|
tx, err := c.db.Begin()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
defer tx.Rollback()
|
defer tx.Rollback()
|
||||||
|
// Read number of subscriptions for subscriber IP address
|
||||||
|
rowsCount, err := tx.Query(selectWebPushSubscriptionCountBySubscriberIP, subscriberIP.String())
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
defer rowsCount.Close()
|
||||||
|
var subscriptionCount int
|
||||||
|
if !rowsCount.Next() {
|
||||||
|
return errWebPushNoRows
|
||||||
|
}
|
||||||
|
if err := rowsCount.Scan(&subscriptionCount); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
if err := rowsCount.Close(); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
// Read existing subscription ID for endpoint (or create new ID)
|
// Read existing subscription ID for endpoint (or create new ID)
|
||||||
rows, err := tx.Query(selectWebPushSubscriptionIDByEndpoint, endpoint)
|
rows, err := tx.Query(selectWebPushSubscriptionIDByEndpoint, endpoint)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
@ -137,6 +164,9 @@ func (c *webPushStore) UpsertSubscription(endpoint string, auth, p256dh, userID
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
|
if subscriptionCount >= subscriptionLimitPerSubscriberIP {
|
||||||
|
return errWebPushTooManySubscriptions
|
||||||
|
}
|
||||||
subscriptionID = util.RandomStringPrefix(subscriptionIDPrefix, subscriptionIDLength)
|
subscriptionID = util.RandomStringPrefix(subscriptionIDPrefix, subscriptionIDLength)
|
||||||
}
|
}
|
||||||
if err := rows.Close(); err != nil {
|
if err := rows.Close(); err != nil {
|
||||||
|
@ -144,7 +174,7 @@ func (c *webPushStore) UpsertSubscription(endpoint string, auth, p256dh, userID
|
||||||
}
|
}
|
||||||
// Insert or update subscription
|
// Insert or update subscription
|
||||||
updatedAt, warnedAt := time.Now().Unix(), 0
|
updatedAt, warnedAt := time.Now().Unix(), 0
|
||||||
if _, err = tx.Exec(insertWebPushSubscriptionQuery, subscriptionID, endpoint, auth, p256dh, userID, updatedAt, warnedAt); err != nil {
|
if _, err = tx.Exec(insertWebPushSubscriptionQuery, subscriptionID, endpoint, auth, p256dh, userID, subscriberIP.String(), updatedAt, warnedAt); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
// Replace all subscription topics
|
// Replace all subscription topics
|
||||||
|
@ -159,6 +189,7 @@ func (c *webPushStore) UpsertSubscription(endpoint string, auth, p256dh, userID
|
||||||
return tx.Commit()
|
return tx.Commit()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// SubscriptionsForTopic returns all subscriptions for the given topic
|
||||||
func (c *webPushStore) SubscriptionsForTopic(topic string) ([]*webPushSubscription, error) {
|
func (c *webPushStore) SubscriptionsForTopic(topic string) ([]*webPushSubscription, error) {
|
||||||
rows, err := c.db.Query(selectWebPushSubscriptionsForTopicQuery, topic)
|
rows, err := c.db.Query(selectWebPushSubscriptionsForTopicQuery, topic)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
@ -168,6 +199,7 @@ func (c *webPushStore) SubscriptionsForTopic(topic string) ([]*webPushSubscripti
|
||||||
return c.subscriptionsFromRows(rows)
|
return c.subscriptionsFromRows(rows)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// SubscriptionsExpiring returns all subscriptions that have not been updated for a given time period
|
||||||
func (c *webPushStore) SubscriptionsExpiring(warnAfter time.Duration) ([]*webPushSubscription, error) {
|
func (c *webPushStore) SubscriptionsExpiring(warnAfter time.Duration) ([]*webPushSubscription, error) {
|
||||||
rows, err := c.db.Query(selectWebPushSubscriptionsExpiringSoonQuery, time.Now().Add(-warnAfter).Unix())
|
rows, err := c.db.Query(selectWebPushSubscriptionsExpiringSoonQuery, time.Now().Add(-warnAfter).Unix())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
@ -177,6 +209,7 @@ func (c *webPushStore) SubscriptionsExpiring(warnAfter time.Duration) ([]*webPus
|
||||||
return c.subscriptionsFromRows(rows)
|
return c.subscriptionsFromRows(rows)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// MarkExpiryWarningSent marks the given subscriptions as having received a warning about expiring soon
|
||||||
func (c *webPushStore) MarkExpiryWarningSent(subscriptions []*webPushSubscription) error {
|
func (c *webPushStore) MarkExpiryWarningSent(subscriptions []*webPushSubscription) error {
|
||||||
tx, err := c.db.Begin()
|
tx, err := c.db.Begin()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
@ -209,21 +242,25 @@ func (c *webPushStore) subscriptionsFromRows(rows *sql.Rows) ([]*webPushSubscrip
|
||||||
return subscriptions, nil
|
return subscriptions, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// RemoveSubscriptionsByEndpoint removes the subscription for the given endpoint
|
||||||
func (c *webPushStore) RemoveSubscriptionsByEndpoint(endpoint string) error {
|
func (c *webPushStore) RemoveSubscriptionsByEndpoint(endpoint string) error {
|
||||||
_, err := c.db.Exec(deleteWebPushSubscriptionByEndpointQuery, endpoint)
|
_, err := c.db.Exec(deleteWebPushSubscriptionByEndpointQuery, endpoint)
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// RemoveSubscriptionsByUserID removes all subscriptions for the given user ID
|
||||||
func (c *webPushStore) RemoveSubscriptionsByUserID(userID string) error {
|
func (c *webPushStore) RemoveSubscriptionsByUserID(userID string) error {
|
||||||
_, err := c.db.Exec(deleteWebPushSubscriptionByUserIDQuery, userID)
|
_, err := c.db.Exec(deleteWebPushSubscriptionByUserIDQuery, userID)
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// RemoveExpiredSubscriptions removes all subscriptions that have not been updated for a given time period
|
||||||
func (c *webPushStore) RemoveExpiredSubscriptions(expireAfter time.Duration) error {
|
func (c *webPushStore) RemoveExpiredSubscriptions(expireAfter time.Duration) error {
|
||||||
_, err := c.db.Exec(deleteWebPushSubscriptionByAgeQuery, time.Now().Add(-expireAfter).Unix())
|
_, err := c.db.Exec(deleteWebPushSubscriptionByAgeQuery, time.Now().Add(-expireAfter).Unix())
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Close closes the underlying database connection
|
||||||
func (c *webPushStore) Close() error {
|
func (c *webPushStore) Close() error {
|
||||||
return c.db.Close()
|
return c.db.Close()
|
||||||
}
|
}
|
||||||
|
|
|
@ -1,7 +1,10 @@
|
||||||
package server
|
package server
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"fmt"
|
||||||
"github.com/stretchr/testify/require"
|
"github.com/stretchr/testify/require"
|
||||||
|
"net/netip"
|
||||||
|
"path/filepath"
|
||||||
"testing"
|
"testing"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -10,3 +13,43 @@ func newTestWebPushStore(t *testing.T, filename string) *webPushStore {
|
||||||
require.Nil(t, err)
|
require.Nil(t, err)
|
||||||
return webPush
|
return webPush
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestWebPushStore_UpsertSubscription_SubscriptionsForTopic(t *testing.T) {
|
||||||
|
webPush := newTestWebPushStore(t, filepath.Join(t.TempDir(), "webpush.db"))
|
||||||
|
defer webPush.Close()
|
||||||
|
|
||||||
|
require.Nil(t, webPush.UpsertSubscription(testWebPushEndpoint, "auth-key", "p256dh-key", "u_1234", netip.MustParseAddr("1.2.3.4"), []string{"test-topic", "mytopic"}))
|
||||||
|
|
||||||
|
subs, err := webPush.SubscriptionsForTopic("test-topic")
|
||||||
|
require.Nil(t, err)
|
||||||
|
require.Len(t, subs, 1)
|
||||||
|
require.Equal(t, subs[0].Endpoint, testWebPushEndpoint)
|
||||||
|
require.Equal(t, subs[0].P256dh, "p256dh-key")
|
||||||
|
require.Equal(t, subs[0].Auth, "auth-key")
|
||||||
|
require.Equal(t, subs[0].UserID, "u_1234")
|
||||||
|
|
||||||
|
subs2, err := webPush.SubscriptionsForTopic("mytopic")
|
||||||
|
require.Nil(t, err)
|
||||||
|
require.Len(t, subs2, 1)
|
||||||
|
require.Equal(t, subs[0].Endpoint, subs2[0].Endpoint)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestWebPushStore_UpsertSubscription_SubscriberIPLimitReached(t *testing.T) {
|
||||||
|
webPush := newTestWebPushStore(t, filepath.Join(t.TempDir(), "webpush.db"))
|
||||||
|
defer webPush.Close()
|
||||||
|
|
||||||
|
// Insert 10 subscriptions with the same IP address
|
||||||
|
for i := 0; i < 10; i++ {
|
||||||
|
endpoint := fmt.Sprintf(testWebPushEndpoint+"%d", i)
|
||||||
|
require.Nil(t, webPush.UpsertSubscription(endpoint, "auth-key", "p256dh-key", "u_1234", netip.MustParseAddr("1.2.3.4"), []string{"test-topic", "mytopic"}))
|
||||||
|
}
|
||||||
|
|
||||||
|
// Another one for the same endpoint should be fine
|
||||||
|
require.Nil(t, webPush.UpsertSubscription(testWebPushEndpoint+"0", "auth-key", "p256dh-key", "u_1234", netip.MustParseAddr("1.2.3.4"), []string{"test-topic", "mytopic"}))
|
||||||
|
|
||||||
|
// But with a different endpoint it should fail
|
||||||
|
require.Equal(t, errWebPushTooManySubscriptions, webPush.UpsertSubscription(testWebPushEndpoint+"11", "auth-key", "p256dh-key", "u_1234", netip.MustParseAddr("1.2.3.4"), []string{"test-topic", "mytopic"}))
|
||||||
|
|
||||||
|
// But with a different IP address it should be fine again
|
||||||
|
require.Nil(t, webPush.UpsertSubscription(testWebPushEndpoint+"99", "auth-key", "p256dh-key", "u_1234", netip.MustParseAddr("9.9.9.9"), []string{"test-topic", "mytopic"}))
|
||||||
|
}
|
||||||
|
|
Loading…
Reference in New Issue