Merge branch 'main' into markdown

This commit is contained in:
binwiederhier 2023-07-04 21:15:08 -04:00
commit 43981bb675
129 changed files with 6741 additions and 1183 deletions

View file

@ -1,10 +1,11 @@
package server
import (
"heckel.io/ntfy/user"
"io/fs"
"net/netip"
"time"
"heckel.io/ntfy/user"
)
// Defines default config settings (excluding limits, see below)
@ -22,6 +23,12 @@ const (
DefaultStripePriceCacheDuration = 3 * time.Hour // Time to keep Stripe prices cached in memory before a refresh is needed
)
// Defines default Web Push settings
const (
DefaultWebPushExpiryWarningDuration = 7 * 24 * time.Hour
DefaultWebPushExpiryDuration = 9 * 24 * time.Hour
)
// Defines all global and per-visitor limits
// - message size limit: the max number of bytes for a message
// - total topic limit: max number of topics overall
@ -146,6 +153,13 @@ type Config struct {
EnableMetrics bool
AccessControlAllowOrigin string // CORS header field to restrict access from web clients
Version string // injected by App
WebPushPrivateKey string
WebPushPublicKey string
WebPushFile string
WebPushEmailAddress string
WebPushStartupQueries string
WebPushExpiryDuration time.Duration
WebPushExpiryWarningDuration time.Duration
}
// NewConfig instantiates a default new server config
@ -227,5 +241,11 @@ func NewConfig() *Config {
EnableReservations: false,
AccessControlAllowOrigin: "*",
Version: "",
WebPushPrivateKey: "",
WebPushPublicKey: "",
WebPushFile: "",
WebPushEmailAddress: "",
WebPushExpiryDuration: DefaultWebPushExpiryDuration,
WebPushExpiryWarningDuration: DefaultWebPushExpiryWarningDuration,
}
}

View file

@ -114,6 +114,9 @@ var (
errHTTPBadRequestAnonymousCallsNotAllowed = &errHTTP{40035, http.StatusBadRequest, "invalid request: anonymous phone calls are not allowed", "https://ntfy.sh/docs/publish/#phone-calls", nil}
errHTTPBadRequestPhoneNumberVerifyChannelInvalid = &errHTTP{40036, http.StatusBadRequest, "invalid request: verification channel must be 'sms' or 'call'", "https://ntfy.sh/docs/publish/#phone-calls", nil}
errHTTPBadRequestDelayNoCall = &errHTTP{40037, http.StatusBadRequest, "delayed call notifications are not supported", "", nil}
errHTTPBadRequestWebPushSubscriptionInvalid = &errHTTP{40038, http.StatusBadRequest, "invalid request: web push payload malformed", "", nil}
errHTTPBadRequestWebPushEndpointUnknown = &errHTTP{40039, http.StatusBadRequest, "invalid request: web push endpoint unknown", "", nil}
errHTTPBadRequestWebPushTopicCountTooHigh = &errHTTP{40040, http.StatusBadRequest, "invalid request: too many web push topic subscriptions", "", nil}
errHTTPNotFound = &errHTTP{40401, http.StatusNotFound, "page not found", "", nil}
errHTTPUnauthorized = &errHTTP{40101, http.StatusUnauthorized, "unauthorized", "https://ntfy.sh/docs/publish/#authentication", nil}
errHTTPForbidden = &errHTTP{40301, http.StatusForbidden, "forbidden", "https://ntfy.sh/docs/publish/#authentication", nil}
@ -138,5 +141,6 @@ var (
errHTTPInternalError = &errHTTP{50001, http.StatusInternalServerError, "internal server error", "", nil}
errHTTPInternalErrorInvalidPath = &errHTTP{50002, http.StatusInternalServerError, "internal server error: invalid path", "", nil}
errHTTPInternalErrorMissingBaseURL = &errHTTP{50003, http.StatusInternalServerError, "internal server error: base-url must be be configured for this feature", "https://ntfy.sh/docs/config/", nil}
errHTTPInternalErrorWebPushUnableToPublish = &errHTTP{50004, http.StatusInternalServerError, "internal server error: unable to publish web push message", "", nil}
errHTTPInsufficientStorageUnifiedPush = &errHTTP{50701, http.StatusInsufficientStorage, "cannot publish to UnifiedPush topic without previously active subscriber", "", nil}
)

View file

@ -29,6 +29,7 @@ const (
tagResetter = "resetter"
tagWebsocket = "websocket"
tagMatrix = "matrix"
tagWebPush = "webpush"
)
var (

View file

@ -270,7 +270,7 @@ func newSqliteCache(filename, startupQueries string, cacheDuration time.Duration
if err != nil {
return nil, err
}
if err := setupDB(db, startupQueries, cacheDuration); err != nil {
if err := setupMessagesDB(db, startupQueries, cacheDuration); err != nil {
return nil, err
}
var queue *util.BatchingQueue[*message]
@ -749,7 +749,7 @@ func (c *messageCache) Close() error {
return c.db.Close()
}
func setupDB(db *sql.DB, startupQueries string, cacheDuration time.Duration) error {
func setupMessagesDB(db *sql.DB, startupQueries string, cacheDuration time.Duration) error {
// Run startup queries
if startupQueries != "" {
if _, err := db.Exec(startupQueries); err != nil {

View file

@ -9,13 +9,6 @@ import (
"encoding/json"
"errors"
"fmt"
"github.com/emersion/go-smtp"
"github.com/gorilla/websocket"
"github.com/prometheus/client_golang/prometheus/promhttp"
"golang.org/x/sync/errgroup"
"heckel.io/ntfy/log"
"heckel.io/ntfy/user"
"heckel.io/ntfy/util"
"io"
"net"
"net/http"
@ -32,6 +25,14 @@ import (
"sync"
"time"
"unicode/utf8"
"github.com/emersion/go-smtp"
"github.com/gorilla/websocket"
"github.com/prometheus/client_golang/prometheus/promhttp"
"golang.org/x/sync/errgroup"
"heckel.io/ntfy/log"
"heckel.io/ntfy/user"
"heckel.io/ntfy/util"
)
// Server is the main server, providing the UI and API for ntfy
@ -52,6 +53,7 @@ type Server struct {
messagesHistory []int64 // Last n values of the messages counter, used to determine rate
userManager *user.Manager // Might be nil!
messageCache *messageCache // Database that stores the messages
webPush *webPushStore // Database that stores web push subscriptions
fileCache *fileCache // File system based cache that stores attachments
stripe stripeAPI // Stripe API, can be replaced with a mock
priceCache *util.LookupCache[map[string]int64] // Stripe price ID -> price as cents (USD implied!)
@ -76,11 +78,15 @@ var (
publishPathRegex = regexp.MustCompile(`^/[-_A-Za-z0-9]{1,64}/(publish|send|trigger)$`)
webConfigPath = "/config.js"
webManifestPath = "/manifest.webmanifest"
webRootHTMLPath = "/app.html"
webServiceWorkerPath = "/sw.js"
accountPath = "/account"
matrixPushPath = "/_matrix/push/v1/notify"
metricsPath = "/metrics"
apiHealthPath = "/v1/health"
apiStatsPath = "/v1/stats"
apiWebPushPath = "/v1/webpush"
apiTiersPath = "/v1/tiers"
apiUsersPath = "/v1/users"
apiUsersAccessPath = "/v1/users/access"
@ -151,6 +157,13 @@ func New(conf *Config) (*Server, error) {
if err != nil {
return nil, err
}
var webPush *webPushStore
if conf.WebPushPublicKey != "" {
webPush, err = newWebPushStore(conf.WebPushFile, conf.WebPushStartupQueries)
if err != nil {
return nil, err
}
}
topics, err := messageCache.Topics()
if err != nil {
return nil, err
@ -190,6 +203,7 @@ func New(conf *Config) (*Server, error) {
s := &Server{
config: conf,
messageCache: messageCache,
webPush: webPush,
fileCache: fileCache,
firebaseClient: firebaseClient,
smtpSender: mailer,
@ -342,6 +356,9 @@ func (s *Server) closeDatabases() {
s.userManager.Close()
}
s.messageCache.Close()
if s.webPush != nil {
s.webPush.Close()
}
}
// handle is the main entry point for all HTTP requests
@ -416,6 +433,8 @@ func (s *Server) handleInternal(w http.ResponseWriter, r *http.Request, v *visit
return s.handleHealth(w, r, v)
} else if r.Method == http.MethodGet && r.URL.Path == webConfigPath {
return s.ensureWebEnabled(s.handleWebConfig)(w, r, v)
} else if r.Method == http.MethodGet && r.URL.Path == webManifestPath {
return s.ensureWebPushEnabled(s.handleWebManifest)(w, r, v)
} else if r.Method == http.MethodGet && r.URL.Path == apiUsersPath {
return s.ensureAdmin(s.handleUsersGet)(w, r, v)
} else if r.Method == http.MethodPut && r.URL.Path == apiUsersPath {
@ -470,6 +489,10 @@ func (s *Server) handleInternal(w http.ResponseWriter, r *http.Request, v *visit
return s.ensureUser(s.ensureCallsEnabled(s.withAccountSync(s.handleAccountPhoneNumberAdd)))(w, r, v)
} else if r.Method == http.MethodDelete && r.URL.Path == apiAccountPhonePath {
return s.ensureUser(s.ensureCallsEnabled(s.withAccountSync(s.handleAccountPhoneNumberDelete)))(w, r, v)
} else if r.Method == http.MethodPost && apiWebPushPath == r.URL.Path {
return s.ensureWebPushEnabled(s.limitRequests(s.handleWebPushUpdate))(w, r, v)
} else if r.Method == http.MethodDelete && apiWebPushPath == r.URL.Path {
return s.ensureWebPushEnabled(s.limitRequests(s.handleWebPushDelete))(w, r, v)
} else if r.Method == http.MethodGet && r.URL.Path == apiStatsPath {
return s.handleStats(w, r, v)
} else if r.Method == http.MethodGet && r.URL.Path == apiTiersPath {
@ -478,7 +501,7 @@ func (s *Server) handleInternal(w http.ResponseWriter, r *http.Request, v *visit
return s.handleMatrixDiscovery(w)
} else if r.Method == http.MethodGet && r.URL.Path == metricsPath && s.metricsHandler != nil {
return s.handleMetrics(w, r, v)
} else if r.Method == http.MethodGet && staticRegex.MatchString(r.URL.Path) {
} else if r.Method == http.MethodGet && (staticRegex.MatchString(r.URL.Path) || r.URL.Path == webServiceWorkerPath || r.URL.Path == webRootHTMLPath) {
return s.ensureWebEnabled(s.handleStatic)(w, r, v)
} else if r.Method == http.MethodGet && docsRegex.MatchString(r.URL.Path) {
return s.ensureWebEnabled(s.handleDocs)(w, r, v)
@ -552,7 +575,9 @@ func (s *Server) handleWebConfig(w http.ResponseWriter, _ *http.Request, _ *visi
EnableCalls: s.config.TwilioAccount != "",
EnableEmails: s.config.SMTPSenderFrom != "",
EnableReservations: s.config.EnableReservations,
EnableWebPush: s.config.WebPushPublicKey != "",
BillingContact: s.config.BillingContact,
WebPushPublicKey: s.config.WebPushPublicKey,
DisallowedTopics: s.config.DisallowedTopics,
}
b, err := json.MarshalIndent(response, "", " ")
@ -564,6 +589,25 @@ func (s *Server) handleWebConfig(w http.ResponseWriter, _ *http.Request, _ *visi
return err
}
// handleWebManifest serves the web app manifest for the progressive web app (PWA)
func (s *Server) handleWebManifest(w http.ResponseWriter, _ *http.Request, _ *visitor) error {
response := &webManifestResponse{
Name: "ntfy web",
Description: "ntfy lets you send push notifications via scripts from any computer or phone",
ShortName: "ntfy",
Scope: "/",
StartURL: s.config.WebRoot,
Display: "standalone",
BackgroundColor: "#ffffff",
ThemeColor: "#317f6f",
Icons: []*webManifestIcon{
{SRC: "/static/images/pwa-192x192.png", Sizes: "192x192", Type: "image/png"},
{SRC: "/static/images/pwa-512x512.png", Sizes: "512x512", Type: "image/png"},
},
}
return s.writeJSONWithContentType(w, response, "application/manifest+json")
}
// handleMetrics returns Prometheus metrics. This endpoint is only called if enable-metrics is set,
// and listen-metrics-http is not set.
func (s *Server) handleMetrics(w http.ResponseWriter, r *http.Request, _ *visitor) error {
@ -760,9 +804,12 @@ func (s *Server) handlePublishInternal(r *http.Request, v *visitor) (*message, e
if s.config.TwilioAccount != "" && call != "" {
go s.callPhone(v, r, m, call)
}
if s.config.UpstreamBaseURL != "" {
if s.config.UpstreamBaseURL != "" && !unifiedpush { // UP messages are not sent to upstream
go s.forwardPollRequest(v, m)
}
if s.config.WebPushPublicKey != "" {
go s.publishToWebPushEndpoints(v, m)
}
} else {
logvrm(v, r, m).Tag(tagPublish).Debug("Message delayed, will process later")
}
@ -1696,6 +1743,9 @@ func (s *Server) sendDelayedMessage(v *visitor, m *message) error {
if s.config.UpstreamBaseURL != "" {
go s.forwardPollRequest(v, m)
}
if s.config.WebPushPublicKey != "" {
go s.publishToWebPushEndpoints(v, m)
}
if err := s.messageCache.MarkPublished(m); err != nil {
return err
}
@ -1916,7 +1966,11 @@ func (s *Server) visitor(ip netip.Addr, user *user.User) *visitor {
}
func (s *Server) writeJSON(w http.ResponseWriter, v any) error {
w.Header().Set("Content-Type", "application/json")
return s.writeJSONWithContentType(w, v, "application/json")
}
func (s *Server) writeJSONWithContentType(w http.ResponseWriter, v any, contentType string) error {
w.Header().Set("Content-Type", contentType)
w.Header().Set("Access-Control-Allow-Origin", s.config.AccessControlAllowOrigin) // CORS, allow cross-origin requests
if err := json.NewEncoder(w).Encode(v); err != nil {
return err

View file

@ -144,6 +144,27 @@
# smtp-server-domain:
# smtp-server-addr-prefix:
# Web Push support (background notifications for browsers)
#
# If enabled, allows ntfy to receive push notifications, even when the ntfy web app is closed. When enabled, users
# can enable background notifications in the web app. Once enabled, ntfy will forward published messages to the push
# endpoint, which will then forward it to the browser.
#
# You must configure web-push-public/private key, web-push-file, and web-push-email-address below to enable Web Push.
# Run "ntfy webpush keys" to generate the keys.
#
# - web-push-public-key is the generated VAPID public key, e.g. AA1234BBCCddvveekaabcdfqwertyuiopasdfghjklzxcvbnm1234567890
# - web-push-private-key is the generated VAPID private key, e.g. AA2BB1234567890abcdefzxcvbnm1234567890
# - web-push-file is a database file to keep track of browser subscription endpoints, e.g. `/var/cache/ntfy/webpush.db`
# - web-push-email-address is the admin email address send to the push provider, e.g. `sysadmin@example.com`
# - web-push-startup-queries is an optional list of queries to run on startup`
#
# web-push-public-key:
# web-push-private-key:
# web-push-file:
# web-push-email-address:
# web-push-startup-queries:
# If enabled, ntfy can perform voice calls via Twilio via the "X-Call" header.
#
# - twilio-account is the Twilio account SID, e.g. AC12345beefbeef67890beefbeef122586

View file

@ -170,6 +170,11 @@ func (s *Server) handleAccountDelete(w http.ResponseWriter, r *http.Request, v *
if _, err := s.userManager.Authenticate(u.Name, req.Password); err != nil {
return errHTTPBadRequestIncorrectPasswordConfirmation
}
if s.webPush != nil && u.ID != "" {
if err := s.webPush.RemoveSubscriptionsByUserID(u.ID); err != nil {
logvr(v, r).Err(err).Warn("Error removing web push subscriptions for %s", u.Name)
}
}
if u.Billing.StripeSubscriptionID != "" {
logvr(v, r).Tag(tagStripe).Info("Canceling billing subscription for user %s", u.Name)
if _, err := s.stripe.CancelSubscription(u.Billing.StripeSubscriptionID); err != nil {

View file

@ -15,6 +15,7 @@ func (s *Server) execManager() {
s.pruneTokens()
s.pruneAttachments()
s.pruneMessages()
s.pruneAndNotifyWebPushSubscriptions()
// Message count per topic
var messagesCached int

View file

@ -58,6 +58,15 @@ func (s *Server) ensureWebEnabled(next handleFunc) handleFunc {
}
}
func (s *Server) ensureWebPushEnabled(next handleFunc) handleFunc {
return func(w http.ResponseWriter, r *http.Request, v *visitor) error {
if s.config.WebRoot == "" || s.config.WebPushPublicKey == "" {
return errHTTPNotFound
}
return next(w, r, v)
}
}
func (s *Server) ensureUserManager(next handleFunc) handleFunc {
return func(w http.ResponseWriter, r *http.Request, v *visitor) error {
if s.userManager == nil {

View file

@ -22,6 +22,7 @@ import (
"testing"
"time"
"github.com/SherClockHolmes/webpush-go"
"github.com/stretchr/testify/require"
"heckel.io/ntfy/log"
"heckel.io/ntfy/util"
@ -238,6 +239,12 @@ func TestServer_WebEnabled(t *testing.T) {
rr = request(t, s, "GET", "/config.js", "", nil)
require.Equal(t, 404, rr.Code)
rr = request(t, s, "GET", "/sw.js", "", nil)
require.Equal(t, 404, rr.Code)
rr = request(t, s, "GET", "/app.html", "", nil)
require.Equal(t, 404, rr.Code)
rr = request(t, s, "GET", "/static/css/home.css", "", nil)
require.Equal(t, 404, rr.Code)
@ -250,6 +257,35 @@ func TestServer_WebEnabled(t *testing.T) {
rr = request(t, s2, "GET", "/config.js", "", nil)
require.Equal(t, 200, rr.Code)
rr = request(t, s2, "GET", "/sw.js", "", nil)
require.Equal(t, 200, rr.Code)
rr = request(t, s2, "GET", "/app.html", "", nil)
require.Equal(t, 200, rr.Code)
}
func TestServer_WebPushEnabled(t *testing.T) {
conf := newTestConfig(t)
conf.WebRoot = "" // Disable web app
s := newTestServer(t, conf)
rr := request(t, s, "GET", "/manifest.webmanifest", "", nil)
require.Equal(t, 404, rr.Code)
conf2 := newTestConfig(t)
s2 := newTestServer(t, conf2)
rr = request(t, s2, "GET", "/manifest.webmanifest", "", nil)
require.Equal(t, 404, rr.Code)
conf3 := newTestConfigWithWebPush(t)
s3 := newTestServer(t, conf3)
rr = request(t, s3, "GET", "/manifest.webmanifest", "", nil)
require.Equal(t, 200, rr.Code)
require.Equal(t, "application/manifest+json", rr.Header().Get("Content-Type"))
}
func TestServer_PublishLargeMessage(t *testing.T) {
@ -2559,6 +2595,29 @@ func TestServer_UpstreamBaseURL_With_Access_Token_Success(t *testing.T) {
})
}
func TestServer_UpstreamBaseURL_DoNotForwardUnifiedPush(t *testing.T) {
upstreamServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
t.Fatal("UnifiedPush messages should not be forwarded")
}))
defer upstreamServer.Close()
c := newTestConfigWithAuthFile(t)
c.BaseURL = "http://myserver.internal"
c.UpstreamBaseURL = upstreamServer.URL
s := newTestServer(t, c)
// Send UP message, this should not forward to upstream server
response := request(t, s, "PUT", "/mytopic?up=1", `hi there`, nil)
require.Equal(t, 200, response.Code)
m := toMessage(t, response.Body.String())
require.NotEmpty(t, m.ID)
require.Equal(t, "hi there", m.Message)
// Forwarding is done asynchronously, so wait a bit.
// This ensures that the t.Fatal above is actually not triggered.
time.Sleep(500 * time.Millisecond)
}
func newTestConfig(t *testing.T) *Config {
conf := NewConfig()
conf.BaseURL = "http://127.0.0.1:12345"
@ -2568,19 +2627,33 @@ func newTestConfig(t *testing.T) *Config {
return conf
}
func newTestConfigWithAuthFile(t *testing.T) *Config {
conf := newTestConfig(t)
func configureAuth(t *testing.T, conf *Config) *Config {
conf.AuthFile = filepath.Join(t.TempDir(), "user.db")
conf.AuthStartupQueries = "pragma journal_mode = WAL; pragma synchronous = normal; pragma temp_store = memory;"
conf.AuthBcryptCost = bcrypt.MinCost // This speeds up tests a lot
return conf
}
func newTestConfigWithAuthFile(t *testing.T) *Config {
conf := newTestConfig(t)
conf = configureAuth(t, conf)
return conf
}
func newTestConfigWithWebPush(t *testing.T) *Config {
conf := newTestConfig(t)
privateKey, publicKey, err := webpush.GenerateVAPIDKeys()
require.Nil(t, err)
conf.WebPushFile = filepath.Join(t.TempDir(), "webpush.db")
conf.WebPushEmailAddress = "testing@example.com"
conf.WebPushPrivateKey = privateKey
conf.WebPushPublicKey = publicKey
return conf
}
func newTestServer(t *testing.T, config *Config) *Server {
server, err := New(config)
if err != nil {
t.Fatal(err)
}
require.Nil(t, err)
return server
}

171
server/server_webpush.go Normal file
View file

@ -0,0 +1,171 @@
package server
import (
"encoding/json"
"fmt"
"net/http"
"regexp"
"strings"
"github.com/SherClockHolmes/webpush-go"
"heckel.io/ntfy/log"
"heckel.io/ntfy/user"
)
const (
webPushTopicSubscribeLimit = 50
)
var (
webPushAllowedEndpointsPatterns = []string{
"https://*.google.com/",
"https://*.googleapis.com/",
"https://*.mozilla.com/",
"https://*.mozaws.net/",
"https://*.windows.com/",
"https://*.microsoft.com/",
"https://*.apple.com/",
}
webPushAllowedEndpointsRegex *regexp.Regexp
)
func init() {
for i, pattern := range webPushAllowedEndpointsPatterns {
webPushAllowedEndpointsPatterns[i] = strings.ReplaceAll(strings.ReplaceAll(pattern, ".", "\\."), "*", ".+")
}
allPatterns := fmt.Sprintf("^(%s)", strings.Join(webPushAllowedEndpointsPatterns, "|"))
webPushAllowedEndpointsRegex = regexp.MustCompile(allPatterns)
}
func (s *Server) handleWebPushUpdate(w http.ResponseWriter, r *http.Request, v *visitor) error {
req, err := readJSONWithLimit[apiWebPushUpdateSubscriptionRequest](r.Body, jsonBodyBytesLimit, false)
if err != nil || req.Endpoint == "" || req.P256dh == "" || req.Auth == "" {
return errHTTPBadRequestWebPushSubscriptionInvalid
} else if !webPushAllowedEndpointsRegex.MatchString(req.Endpoint) {
return errHTTPBadRequestWebPushEndpointUnknown
} else if len(req.Topics) > webPushTopicSubscribeLimit {
return errHTTPBadRequestWebPushTopicCountTooHigh
}
topics, err := s.topicsFromIDs(req.Topics...)
if err != nil {
return err
}
if s.userManager != nil {
u := v.User()
for _, t := range topics {
if err := s.userManager.Authorize(u, t.ID, user.PermissionRead); err != nil {
logvr(v, r).With(t).Err(err).Debug("Access to topic %s not authorized", t.ID)
return errHTTPForbidden.With(t)
}
}
}
if err := s.webPush.UpsertSubscription(req.Endpoint, req.Auth, req.P256dh, v.MaybeUserID(), v.IP(), req.Topics); err != nil {
return err
}
return s.writeJSON(w, newSuccessResponse())
}
func (s *Server) handleWebPushDelete(w http.ResponseWriter, r *http.Request, _ *visitor) error {
req, err := readJSONWithLimit[apiWebPushUpdateSubscriptionRequest](r.Body, jsonBodyBytesLimit, false)
if err != nil || req.Endpoint == "" {
return errHTTPBadRequestWebPushSubscriptionInvalid
}
if err := s.webPush.RemoveSubscriptionsByEndpoint(req.Endpoint); err != nil {
return err
}
return s.writeJSON(w, newSuccessResponse())
}
func (s *Server) publishToWebPushEndpoints(v *visitor, m *message) {
subscriptions, err := s.webPush.SubscriptionsForTopic(m.Topic)
if err != nil {
logvm(v, m).Err(err).With(v, m).Warn("Unable to publish web push messages")
return
}
log.Tag(tagWebPush).With(v, m).Debug("Publishing web push message to %d subscribers", len(subscriptions))
payload, err := json.Marshal(newWebPushPayload(fmt.Sprintf("%s/%s", s.config.BaseURL, m.Topic), m))
if err != nil {
log.Tag(tagWebPush).Err(err).With(v, m).Warn("Unable to marshal expiring payload")
return
}
for _, subscription := range subscriptions {
if err := s.sendWebPushNotification(subscription, payload, v, m); err != nil {
log.Tag(tagWebPush).Err(err).With(v, m, subscription).Warn("Unable to publish web push message")
}
}
}
func (s *Server) pruneAndNotifyWebPushSubscriptions() {
if s.config.WebPushPublicKey == "" {
return
}
go func() {
if err := s.pruneAndNotifyWebPushSubscriptionsInternal(); err != nil {
log.Tag(tagWebPush).Err(err).Warn("Unable to prune or notify web push subscriptions")
}
}()
}
func (s *Server) pruneAndNotifyWebPushSubscriptionsInternal() error {
// Expire old subscriptions
if err := s.webPush.RemoveExpiredSubscriptions(s.config.WebPushExpiryDuration); err != nil {
return err
}
// Notify subscriptions that will expire soon
subscriptions, err := s.webPush.SubscriptionsExpiring(s.config.WebPushExpiryWarningDuration)
if err != nil {
return err
} else if len(subscriptions) == 0 {
return nil
}
payload, err := json.Marshal(newWebPushSubscriptionExpiringPayload())
if err != nil {
return err
}
warningSent := make([]*webPushSubscription, 0)
for _, subscription := range subscriptions {
if err := s.sendWebPushNotification(subscription, payload); err != nil {
log.Tag(tagWebPush).Err(err).With(subscription).Warn("Unable to publish expiry imminent warning")
continue
}
warningSent = append(warningSent, subscription)
}
if err := s.webPush.MarkExpiryWarningSent(warningSent); err != nil {
return err
}
log.Tag(tagWebPush).Debug("Expired old subscriptions and published %d expiry imminent warnings", len(subscriptions))
return nil
}
func (s *Server) sendWebPushNotification(sub *webPushSubscription, message []byte, contexters ...log.Contexter) error {
log.Tag(tagWebPush).With(sub).With(contexters...).Debug("Sending web push message")
payload := &webpush.Subscription{
Endpoint: sub.Endpoint,
Keys: webpush.Keys{
Auth: sub.Auth,
P256dh: sub.P256dh,
},
}
resp, err := webpush.SendNotification(message, payload, &webpush.Options{
Subscriber: s.config.WebPushEmailAddress,
VAPIDPublicKey: s.config.WebPushPublicKey,
VAPIDPrivateKey: s.config.WebPushPrivateKey,
Urgency: webpush.UrgencyHigh, // iOS requires this to ensure delivery
TTL: int(s.config.CacheDuration.Seconds()),
})
if err != nil {
log.Tag(tagWebPush).With(sub).With(contexters...).Err(err).Debug("Unable to publish web push message, removing endpoint")
if err := s.webPush.RemoveSubscriptionsByEndpoint(sub.Endpoint); err != nil {
return err
}
return err
}
if (resp.StatusCode < 200 || resp.StatusCode > 299) && resp.StatusCode != 429 {
log.Tag(tagWebPush).With(sub).With(contexters...).Field("response_code", resp.StatusCode).Debug("Unable to publish web push message, unexpected response")
if err := s.webPush.RemoveSubscriptionsByEndpoint(sub.Endpoint); err != nil {
return err
}
return errHTTPInternalErrorWebPushUnableToPublish.With(sub).With(contexters...)
}
return nil
}

View file

@ -0,0 +1,256 @@
package server
import (
"encoding/json"
"fmt"
"github.com/stretchr/testify/require"
"heckel.io/ntfy/user"
"heckel.io/ntfy/util"
"io"
"net/http"
"net/http/httptest"
"net/netip"
"strings"
"sync/atomic"
"testing"
"time"
)
const (
testWebPushEndpoint = "https://updates.push.services.mozilla.com/wpush/v1/AAABBCCCDDEEEFFF"
)
func TestServer_WebPush_Disabled(t *testing.T) {
s := newTestServer(t, newTestConfig(t))
response := request(t, s, "POST", "/v1/webpush", payloadForTopics(t, []string{"test-topic"}, testWebPushEndpoint), nil)
require.Equal(t, 404, response.Code)
}
func TestServer_WebPush_TopicAdd(t *testing.T) {
s := newTestServer(t, newTestConfigWithWebPush(t))
response := request(t, s, "POST", "/v1/webpush", payloadForTopics(t, []string{"test-topic"}, testWebPushEndpoint), nil)
require.Equal(t, 200, response.Code)
require.Equal(t, `{"success":true}`+"\n", response.Body.String())
subs, err := s.webPush.SubscriptionsForTopic("test-topic")
require.Nil(t, err)
require.Len(t, subs, 1)
require.Equal(t, subs[0].Endpoint, testWebPushEndpoint)
require.Equal(t, subs[0].P256dh, "p256dh-key")
require.Equal(t, subs[0].Auth, "auth-key")
require.Equal(t, subs[0].UserID, "")
}
func TestServer_WebPush_TopicAdd_InvalidEndpoint(t *testing.T) {
s := newTestServer(t, newTestConfigWithWebPush(t))
response := request(t, s, "POST", "/v1/webpush", payloadForTopics(t, []string{"test-topic"}, "https://ddos-target.example.com/webpush"), nil)
require.Equal(t, 400, response.Code)
require.Equal(t, `{"code":40039,"http":400,"error":"invalid request: web push endpoint unknown"}`+"\n", response.Body.String())
}
func TestServer_WebPush_TopicAdd_TooManyTopics(t *testing.T) {
s := newTestServer(t, newTestConfigWithWebPush(t))
topicList := make([]string, 51)
for i := range topicList {
topicList[i] = util.RandomString(5)
}
response := request(t, s, "POST", "/v1/webpush", payloadForTopics(t, topicList, testWebPushEndpoint), nil)
require.Equal(t, 400, response.Code)
require.Equal(t, `{"code":40040,"http":400,"error":"invalid request: too many web push topic subscriptions"}`+"\n", response.Body.String())
}
func TestServer_WebPush_TopicUnsubscribe(t *testing.T) {
s := newTestServer(t, newTestConfigWithWebPush(t))
addSubscription(t, s, testWebPushEndpoint, "test-topic")
requireSubscriptionCount(t, s, "test-topic", 1)
response := request(t, s, "POST", "/v1/webpush", payloadForTopics(t, []string{}, testWebPushEndpoint), nil)
require.Equal(t, 200, response.Code)
require.Equal(t, `{"success":true}`+"\n", response.Body.String())
requireSubscriptionCount(t, s, "test-topic", 0)
}
func TestServer_WebPush_Delete(t *testing.T) {
s := newTestServer(t, newTestConfigWithWebPush(t))
addSubscription(t, s, testWebPushEndpoint, "test-topic")
requireSubscriptionCount(t, s, "test-topic", 1)
response := request(t, s, "DELETE", "/v1/webpush", fmt.Sprintf(`{"endpoint":"%s"}`, testWebPushEndpoint), nil)
require.Equal(t, 200, response.Code)
require.Equal(t, `{"success":true}`+"\n", response.Body.String())
requireSubscriptionCount(t, s, "test-topic", 0)
}
func TestServer_WebPush_TopicSubscribeProtected_Allowed(t *testing.T) {
config := configureAuth(t, newTestConfigWithWebPush(t))
config.AuthDefault = user.PermissionDenyAll
s := newTestServer(t, config)
require.Nil(t, s.userManager.AddUser("ben", "ben", user.RoleUser))
require.Nil(t, s.userManager.AllowAccess("ben", "test-topic", user.PermissionReadWrite))
response := request(t, s, "POST", "/v1/webpush", payloadForTopics(t, []string{"test-topic"}, testWebPushEndpoint), map[string]string{
"Authorization": util.BasicAuth("ben", "ben"),
})
require.Equal(t, 200, response.Code)
require.Equal(t, `{"success":true}`+"\n", response.Body.String())
subs, err := s.webPush.SubscriptionsForTopic("test-topic")
require.Nil(t, err)
require.Len(t, subs, 1)
require.True(t, strings.HasPrefix(subs[0].UserID, "u_"))
}
func TestServer_WebPush_TopicSubscribeProtected_Denied(t *testing.T) {
config := configureAuth(t, newTestConfigWithWebPush(t))
config.AuthDefault = user.PermissionDenyAll
s := newTestServer(t, config)
response := request(t, s, "POST", "/v1/webpush", payloadForTopics(t, []string{"test-topic"}, testWebPushEndpoint), nil)
require.Equal(t, 403, response.Code)
requireSubscriptionCount(t, s, "test-topic", 0)
}
func TestServer_WebPush_DeleteAccountUnsubscribe(t *testing.T) {
config := configureAuth(t, newTestConfigWithWebPush(t))
s := newTestServer(t, config)
require.Nil(t, s.userManager.AddUser("ben", "ben", user.RoleUser))
require.Nil(t, s.userManager.AllowAccess("ben", "test-topic", user.PermissionReadWrite))
response := request(t, s, "POST", "/v1/webpush", payloadForTopics(t, []string{"test-topic"}, testWebPushEndpoint), map[string]string{
"Authorization": util.BasicAuth("ben", "ben"),
})
require.Equal(t, 200, response.Code)
require.Equal(t, `{"success":true}`+"\n", response.Body.String())
requireSubscriptionCount(t, s, "test-topic", 1)
request(t, s, "DELETE", "/v1/account", `{"password":"ben"}`, map[string]string{
"Authorization": util.BasicAuth("ben", "ben"),
})
// should've been deleted with the account
requireSubscriptionCount(t, s, "test-topic", 0)
}
func TestServer_WebPush_Publish(t *testing.T) {
s := newTestServer(t, newTestConfigWithWebPush(t))
var received atomic.Bool
pushService := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
_, err := io.ReadAll(r.Body)
require.Nil(t, err)
require.Equal(t, "/push-receive", r.URL.Path)
require.Equal(t, "high", r.Header.Get("Urgency"))
require.Equal(t, "", r.Header.Get("Topic"))
received.Store(true)
}))
defer pushService.Close()
addSubscription(t, s, pushService.URL+"/push-receive", "test-topic")
request(t, s, "POST", "/test-topic", "web push test", nil)
waitFor(t, func() bool {
return received.Load()
})
}
func TestServer_WebPush_Publish_RemoveOnError(t *testing.T) {
s := newTestServer(t, newTestConfigWithWebPush(t))
var received atomic.Bool
pushService := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
_, err := io.ReadAll(r.Body)
require.Nil(t, err)
w.WriteHeader(http.StatusGone)
received.Store(true)
}))
defer pushService.Close()
addSubscription(t, s, pushService.URL+"/push-receive", "test-topic", "test-topic-abc")
requireSubscriptionCount(t, s, "test-topic", 1)
requireSubscriptionCount(t, s, "test-topic-abc", 1)
request(t, s, "POST", "/test-topic", "web push test", nil)
waitFor(t, func() bool {
return received.Load()
})
// Receiving the 410 should've caused the publisher to expire all subscriptions on the endpoint
requireSubscriptionCount(t, s, "test-topic", 0)
requireSubscriptionCount(t, s, "test-topic-abc", 0)
}
func TestServer_WebPush_Expiry(t *testing.T) {
s := newTestServer(t, newTestConfigWithWebPush(t))
var received atomic.Bool
pushService := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
_, err := io.ReadAll(r.Body)
require.Nil(t, err)
w.WriteHeader(200)
w.Write([]byte(``))
received.Store(true)
}))
defer pushService.Close()
addSubscription(t, s, pushService.URL+"/push-receive", "test-topic")
requireSubscriptionCount(t, s, "test-topic", 1)
_, err := s.webPush.db.Exec("UPDATE subscription SET updated_at = ?", time.Now().Add(-7*24*time.Hour).Unix())
require.Nil(t, err)
s.pruneAndNotifyWebPushSubscriptions()
requireSubscriptionCount(t, s, "test-topic", 1)
waitFor(t, func() bool {
return received.Load()
})
_, err = s.webPush.db.Exec("UPDATE subscription SET updated_at = ?", time.Now().Add(-9*24*time.Hour).Unix())
require.Nil(t, err)
s.pruneAndNotifyWebPushSubscriptions()
waitFor(t, func() bool {
subs, err := s.webPush.SubscriptionsForTopic("test-topic")
require.Nil(t, err)
return len(subs) == 0
})
}
func payloadForTopics(t *testing.T, topics []string, endpoint string) string {
topicsJSON, err := json.Marshal(topics)
require.Nil(t, err)
return fmt.Sprintf(`{
"topics": %s,
"endpoint": "%s",
"p256dh": "p256dh-key",
"auth": "auth-key"
}`, topicsJSON, endpoint)
}
func addSubscription(t *testing.T, s *Server, endpoint string, topics ...string) {
require.Nil(t, s.webPush.UpsertSubscription(endpoint, "kSC3T8aN1JCQxxPdrFLrZg", "BMKKbxdUU_xLS7G1Wh5AN8PvWOjCzkCuKZYb8apcqYrDxjOF_2piggBnoJLQYx9IeSD70fNuwawI3e9Y8m3S3PE", "u_123", netip.MustParseAddr("1.2.3.4"), topics)) // Test auth and p256dh
}
func requireSubscriptionCount(t *testing.T, s *Server, topic string, expectedLength int) {
subs, err := s.webPush.SubscriptionsForTopic(topic)
require.Nil(t, err)
require.Len(t, subs, expectedLength)
}

View file

@ -1,12 +1,13 @@
package server
import (
"heckel.io/ntfy/log"
"heckel.io/ntfy/user"
"net/http"
"net/netip"
"time"
"heckel.io/ntfy/log"
"heckel.io/ntfy/user"
"heckel.io/ntfy/util"
)
@ -41,7 +42,7 @@ type message struct {
ContentType string `json:"content_type,omitempty"` // text/plain by default (if empty), or text/markdown
Encoding string `json:"encoding,omitempty"` // empty for raw UTF-8, or "base64" for encoded bytes
Sender netip.Addr `json:"-"` // IP address of uploader, used for rate limiting
User string `json:"-"` // Username of the uploader, used to associated attachments
User string `json:"-"` // UserID of the uploader, used to associated attachments
}
func (m *message) Context() log.Context {
@ -398,7 +399,9 @@ type apiConfigResponse struct {
EnableCalls bool `json:"enable_calls"`
EnableEmails bool `json:"enable_emails"`
EnableReservations bool `json:"enable_reservations"`
EnableWebPush bool `json:"enable_web_push"`
BillingContact string `json:"billing_contact"`
WebPushPublicKey string `json:"web_push_public_key"`
DisallowedTopics []string `json:"disallowed_topics"`
}
@ -463,3 +466,75 @@ type apiStripeSubscriptionDeletedEvent struct {
ID string `json:"id"`
Customer string `json:"customer"`
}
type apiWebPushUpdateSubscriptionRequest struct {
Endpoint string `json:"endpoint"`
Auth string `json:"auth"`
P256dh string `json:"p256dh"`
Topics []string `json:"topics"`
}
// List of possible Web Push events (see sw.js)
const (
webPushMessageEvent = "message"
webPushExpiringEvent = "subscription_expiring"
)
type webPushPayload struct {
Event string `json:"event"`
SubscriptionID string `json:"subscription_id"`
Message *message `json:"message"`
}
func newWebPushPayload(subscriptionID string, message *message) *webPushPayload {
return &webPushPayload{
Event: webPushMessageEvent,
SubscriptionID: subscriptionID,
Message: message,
}
}
type webPushControlMessagePayload struct {
Event string `json:"event"`
}
func newWebPushSubscriptionExpiringPayload() *webPushControlMessagePayload {
return &webPushControlMessagePayload{
Event: webPushExpiringEvent,
}
}
type webPushSubscription struct {
ID string
Endpoint string
Auth string
P256dh string
UserID string
}
func (w *webPushSubscription) Context() log.Context {
return map[string]any{
"web_push_subscription_id": w.ID,
"web_push_subscription_user_id": w.UserID,
"web_push_subscription_endpoint": w.Endpoint,
}
}
// https://developer.mozilla.org/en-US/docs/Web/Manifest
type webManifestResponse struct {
Name string `json:"name"`
Description string `json:"description"`
ShortName string `json:"short_name"`
Scope string `json:"scope"`
StartURL string `json:"start_url"`
Display string `json:"display"`
BackgroundColor string `json:"background_color"`
ThemeColor string `json:"theme_color"`
Icons []*webManifestIcon `json:"icons"`
}
type webManifestIcon struct {
SRC string `json:"src"`
Sizes string `json:"sizes"`
Type string `json:"type"`
}

280
server/webpush_store.go Normal file
View file

@ -0,0 +1,280 @@
package server
import (
"database/sql"
"errors"
"heckel.io/ntfy/util"
"net/netip"
"time"
_ "github.com/mattn/go-sqlite3" // SQLite driver
)
const (
subscriptionIDPrefix = "wps_"
subscriptionIDLength = 10
subscriptionEndpointLimitPerSubscriberIP = 10
)
var (
errWebPushNoRows = errors.New("no rows found")
errWebPushTooManySubscriptions = errors.New("too many subscriptions")
errWebPushUserIDCannotBeEmpty = errors.New("user ID cannot be empty")
)
const (
createWebPushSubscriptionsTableQuery = `
BEGIN;
CREATE TABLE IF NOT EXISTS subscription (
id TEXT PRIMARY KEY,
endpoint TEXT NOT NULL,
key_auth TEXT NOT NULL,
key_p256dh TEXT NOT NULL,
user_id TEXT NOT NULL,
subscriber_ip TEXT NOT NULL,
updated_at INT NOT NULL,
warned_at INT NOT NULL DEFAULT 0
);
CREATE UNIQUE INDEX IF NOT EXISTS idx_endpoint ON subscription (endpoint);
CREATE INDEX IF NOT EXISTS idx_subscriber_ip ON subscription (subscriber_ip);
CREATE TABLE IF NOT EXISTS subscription_topic (
subscription_id TEXT NOT NULL,
topic TEXT NOT NULL,
PRIMARY KEY (subscription_id, topic),
FOREIGN KEY (subscription_id) REFERENCES subscription (id) ON DELETE CASCADE
);
CREATE INDEX IF NOT EXISTS idx_topic ON subscription_topic (topic);
CREATE TABLE IF NOT EXISTS schemaVersion (
id INT PRIMARY KEY,
version INT NOT NULL
);
COMMIT;
`
builtinStartupQueries = `
PRAGMA foreign_keys = ON;
`
selectWebPushSubscriptionIDByEndpoint = `SELECT id FROM subscription WHERE endpoint = ?`
selectWebPushSubscriptionCountBySubscriberIP = `SELECT COUNT(*) FROM subscription WHERE subscriber_ip = ?`
selectWebPushSubscriptionsForTopicQuery = `
SELECT id, endpoint, key_auth, key_p256dh, user_id
FROM subscription_topic st
JOIN subscription s ON s.id = st.subscription_id
WHERE st.topic = ?
ORDER BY endpoint
`
selectWebPushSubscriptionsExpiringSoonQuery = `
SELECT id, endpoint, key_auth, key_p256dh, user_id
FROM subscription
WHERE warned_at = 0 AND updated_at <= ?
`
insertWebPushSubscriptionQuery = `
INSERT INTO subscription (id, endpoint, key_auth, key_p256dh, user_id, subscriber_ip, updated_at, warned_at)
VALUES (?, ?, ?, ?, ?, ?, ?, ?)
ON CONFLICT (endpoint)
DO UPDATE SET key_auth = excluded.key_auth, key_p256dh = excluded.key_p256dh, user_id = excluded.user_id, subscriber_ip = excluded.subscriber_ip, updated_at = excluded.updated_at, warned_at = excluded.warned_at
`
updateWebPushSubscriptionWarningSentQuery = `UPDATE subscription SET warned_at = ? WHERE id = ?`
deleteWebPushSubscriptionByEndpointQuery = `DELETE FROM subscription WHERE endpoint = ?`
deleteWebPushSubscriptionByUserIDQuery = `DELETE FROM subscription WHERE user_id = ?`
deleteWebPushSubscriptionByAgeQuery = `DELETE FROM subscription WHERE updated_at <= ?` // Full table scan!
insertWebPushSubscriptionTopicQuery = `INSERT INTO subscription_topic (subscription_id, topic) VALUES (?, ?)`
deleteWebPushSubscriptionTopicAllQuery = `DELETE FROM subscription_topic WHERE subscription_id = ?`
)
// Schema management queries
const (
currentWebPushSchemaVersion = 1
insertWebPushSchemaVersion = `INSERT INTO schemaVersion VALUES (1, ?)`
selectWebPushSchemaVersionQuery = `SELECT version FROM schemaVersion WHERE id = 1`
)
type webPushStore struct {
db *sql.DB
}
func newWebPushStore(filename, startupQueries string) (*webPushStore, error) {
db, err := sql.Open("sqlite3", filename)
if err != nil {
return nil, err
}
if err := setupWebPushDB(db); err != nil {
return nil, err
}
if err := runWebPushStartupQueries(db, startupQueries); err != nil {
return nil, err
}
return &webPushStore{
db: db,
}, nil
}
func setupWebPushDB(db *sql.DB) error {
// If 'schemaVersion' table does not exist, this must be a new database
rows, err := db.Query(selectWebPushSchemaVersionQuery)
if err != nil {
return setupNewWebPushDB(db)
}
return rows.Close()
}
func setupNewWebPushDB(db *sql.DB) error {
if _, err := db.Exec(createWebPushSubscriptionsTableQuery); err != nil {
return err
}
if _, err := db.Exec(insertWebPushSchemaVersion, currentWebPushSchemaVersion); err != nil {
return err
}
return nil
}
func runWebPushStartupQueries(db *sql.DB, startupQueries string) error {
if _, err := db.Exec(startupQueries); err != nil {
return err
}
if _, err := db.Exec(builtinStartupQueries); err != nil {
return err
}
return nil
}
// UpsertSubscription adds or updates Web Push subscriptions for the given topics and user ID. It always first deletes all
// existing entries for a given endpoint.
func (c *webPushStore) UpsertSubscription(endpoint string, auth, p256dh, userID string, subscriberIP netip.Addr, topics []string) error {
tx, err := c.db.Begin()
if err != nil {
return err
}
defer tx.Rollback()
// Read number of subscriptions for subscriber IP address
rowsCount, err := tx.Query(selectWebPushSubscriptionCountBySubscriberIP, subscriberIP.String())
if err != nil {
return err
}
defer rowsCount.Close()
var subscriptionCount int
if !rowsCount.Next() {
return errWebPushNoRows
}
if err := rowsCount.Scan(&subscriptionCount); err != nil {
return err
}
if err := rowsCount.Close(); err != nil {
return err
}
// Read existing subscription ID for endpoint (or create new ID)
rows, err := tx.Query(selectWebPushSubscriptionIDByEndpoint, endpoint)
if err != nil {
return err
}
defer rows.Close()
var subscriptionID string
if rows.Next() {
if err := rows.Scan(&subscriptionID); err != nil {
return err
}
} else {
if subscriptionCount >= subscriptionEndpointLimitPerSubscriberIP {
return errWebPushTooManySubscriptions
}
subscriptionID = util.RandomStringPrefix(subscriptionIDPrefix, subscriptionIDLength)
}
if err := rows.Close(); err != nil {
return err
}
// Insert or update subscription
updatedAt, warnedAt := time.Now().Unix(), 0
if _, err = tx.Exec(insertWebPushSubscriptionQuery, subscriptionID, endpoint, auth, p256dh, userID, subscriberIP.String(), updatedAt, warnedAt); err != nil {
return err
}
// Replace all subscription topics
if _, err := tx.Exec(deleteWebPushSubscriptionTopicAllQuery, subscriptionID); err != nil {
return err
}
for _, topic := range topics {
if _, err = tx.Exec(insertWebPushSubscriptionTopicQuery, subscriptionID, topic); err != nil {
return err
}
}
return tx.Commit()
}
// SubscriptionsForTopic returns all subscriptions for the given topic
func (c *webPushStore) SubscriptionsForTopic(topic string) ([]*webPushSubscription, error) {
rows, err := c.db.Query(selectWebPushSubscriptionsForTopicQuery, topic)
if err != nil {
return nil, err
}
defer rows.Close()
return c.subscriptionsFromRows(rows)
}
// SubscriptionsExpiring returns all subscriptions that have not been updated for a given time period
func (c *webPushStore) SubscriptionsExpiring(warnAfter time.Duration) ([]*webPushSubscription, error) {
rows, err := c.db.Query(selectWebPushSubscriptionsExpiringSoonQuery, time.Now().Add(-warnAfter).Unix())
if err != nil {
return nil, err
}
defer rows.Close()
return c.subscriptionsFromRows(rows)
}
// MarkExpiryWarningSent marks the given subscriptions as having received a warning about expiring soon
func (c *webPushStore) MarkExpiryWarningSent(subscriptions []*webPushSubscription) error {
tx, err := c.db.Begin()
if err != nil {
return err
}
defer tx.Rollback()
for _, subscription := range subscriptions {
if _, err := tx.Exec(updateWebPushSubscriptionWarningSentQuery, time.Now().Unix(), subscription.ID); err != nil {
return err
}
}
return tx.Commit()
}
func (c *webPushStore) subscriptionsFromRows(rows *sql.Rows) ([]*webPushSubscription, error) {
subscriptions := make([]*webPushSubscription, 0)
for rows.Next() {
var id, endpoint, auth, p256dh, userID string
if err := rows.Scan(&id, &endpoint, &auth, &p256dh, &userID); err != nil {
return nil, err
}
subscriptions = append(subscriptions, &webPushSubscription{
ID: id,
Endpoint: endpoint,
Auth: auth,
P256dh: p256dh,
UserID: userID,
})
}
return subscriptions, nil
}
// RemoveSubscriptionsByEndpoint removes the subscription for the given endpoint
func (c *webPushStore) RemoveSubscriptionsByEndpoint(endpoint string) error {
_, err := c.db.Exec(deleteWebPushSubscriptionByEndpointQuery, endpoint)
return err
}
// RemoveSubscriptionsByUserID removes all subscriptions for the given user ID
func (c *webPushStore) RemoveSubscriptionsByUserID(userID string) error {
if userID == "" {
return errWebPushUserIDCannotBeEmpty
}
_, err := c.db.Exec(deleteWebPushSubscriptionByUserIDQuery, userID)
return err
}
// RemoveExpiredSubscriptions removes all subscriptions that have not been updated for a given time period
func (c *webPushStore) RemoveExpiredSubscriptions(expireAfter time.Duration) error {
_, err := c.db.Exec(deleteWebPushSubscriptionByAgeQuery, time.Now().Add(-expireAfter).Unix())
return err
}
// Close closes the underlying database connection
func (c *webPushStore) Close() error {
return c.db.Close()
}

View file

@ -0,0 +1,199 @@
package server
import (
"fmt"
"github.com/stretchr/testify/require"
"net/netip"
"path/filepath"
"testing"
"time"
)
func TestWebPushStore_UpsertSubscription_SubscriptionsForTopic(t *testing.T) {
webPush := newTestWebPushStore(t)
defer webPush.Close()
require.Nil(t, webPush.UpsertSubscription(testWebPushEndpoint, "auth-key", "p256dh-key", "u_1234", netip.MustParseAddr("1.2.3.4"), []string{"test-topic", "mytopic"}))
subs, err := webPush.SubscriptionsForTopic("test-topic")
require.Nil(t, err)
require.Len(t, subs, 1)
require.Equal(t, subs[0].Endpoint, testWebPushEndpoint)
require.Equal(t, subs[0].P256dh, "p256dh-key")
require.Equal(t, subs[0].Auth, "auth-key")
require.Equal(t, subs[0].UserID, "u_1234")
subs2, err := webPush.SubscriptionsForTopic("mytopic")
require.Nil(t, err)
require.Len(t, subs2, 1)
require.Equal(t, subs[0].Endpoint, subs2[0].Endpoint)
}
func TestWebPushStore_UpsertSubscription_SubscriberIPLimitReached(t *testing.T) {
webPush := newTestWebPushStore(t)
defer webPush.Close()
// Insert 10 subscriptions with the same IP address
for i := 0; i < 10; i++ {
endpoint := fmt.Sprintf(testWebPushEndpoint+"%d", i)
require.Nil(t, webPush.UpsertSubscription(endpoint, "auth-key", "p256dh-key", "u_1234", netip.MustParseAddr("1.2.3.4"), []string{"test-topic", "mytopic"}))
}
// Another one for the same endpoint should be fine
require.Nil(t, webPush.UpsertSubscription(testWebPushEndpoint+"0", "auth-key", "p256dh-key", "u_1234", netip.MustParseAddr("1.2.3.4"), []string{"test-topic", "mytopic"}))
// But with a different endpoint it should fail
require.Equal(t, errWebPushTooManySubscriptions, webPush.UpsertSubscription(testWebPushEndpoint+"11", "auth-key", "p256dh-key", "u_1234", netip.MustParseAddr("1.2.3.4"), []string{"test-topic", "mytopic"}))
// But with a different IP address it should be fine again
require.Nil(t, webPush.UpsertSubscription(testWebPushEndpoint+"99", "auth-key", "p256dh-key", "u_1234", netip.MustParseAddr("9.9.9.9"), []string{"test-topic", "mytopic"}))
}
func TestWebPushStore_UpsertSubscription_UpdateTopics(t *testing.T) {
webPush := newTestWebPushStore(t)
defer webPush.Close()
// Insert subscription with two topics, and another with one topic
require.Nil(t, webPush.UpsertSubscription(testWebPushEndpoint+"0", "auth-key", "p256dh-key", "u_1234", netip.MustParseAddr("1.2.3.4"), []string{"topic1", "topic2"}))
require.Nil(t, webPush.UpsertSubscription(testWebPushEndpoint+"1", "auth-key", "p256dh-key", "", netip.MustParseAddr("9.9.9.9"), []string{"topic1"}))
subs, err := webPush.SubscriptionsForTopic("topic1")
require.Nil(t, err)
require.Len(t, subs, 2)
require.Equal(t, testWebPushEndpoint+"0", subs[0].Endpoint)
require.Equal(t, testWebPushEndpoint+"1", subs[1].Endpoint)
subs, err = webPush.SubscriptionsForTopic("topic2")
require.Nil(t, err)
require.Len(t, subs, 1)
require.Equal(t, testWebPushEndpoint+"0", subs[0].Endpoint)
// Update the first subscription to have only one topic
require.Nil(t, webPush.UpsertSubscription(testWebPushEndpoint+"0", "auth-key", "p256dh-key", "u_1234", netip.MustParseAddr("1.2.3.4"), []string{"topic1"}))
subs, err = webPush.SubscriptionsForTopic("topic1")
require.Nil(t, err)
require.Len(t, subs, 2)
require.Equal(t, testWebPushEndpoint+"0", subs[0].Endpoint)
subs, err = webPush.SubscriptionsForTopic("topic2")
require.Nil(t, err)
require.Len(t, subs, 0)
}
func TestWebPushStore_RemoveSubscriptionsByEndpoint(t *testing.T) {
webPush := newTestWebPushStore(t)
defer webPush.Close()
// Insert subscription with two topics
require.Nil(t, webPush.UpsertSubscription(testWebPushEndpoint, "auth-key", "p256dh-key", "u_1234", netip.MustParseAddr("1.2.3.4"), []string{"topic1", "topic2"}))
subs, err := webPush.SubscriptionsForTopic("topic1")
require.Nil(t, err)
require.Len(t, subs, 1)
// And remove it again
require.Nil(t, webPush.RemoveSubscriptionsByEndpoint(testWebPushEndpoint))
subs, err = webPush.SubscriptionsForTopic("topic1")
require.Nil(t, err)
require.Len(t, subs, 0)
}
func TestWebPushStore_RemoveSubscriptionsByUserID(t *testing.T) {
webPush := newTestWebPushStore(t)
defer webPush.Close()
// Insert subscription with two topics
require.Nil(t, webPush.UpsertSubscription(testWebPushEndpoint, "auth-key", "p256dh-key", "u_1234", netip.MustParseAddr("1.2.3.4"), []string{"topic1", "topic2"}))
subs, err := webPush.SubscriptionsForTopic("topic1")
require.Nil(t, err)
require.Len(t, subs, 1)
// And remove it again
require.Nil(t, webPush.RemoveSubscriptionsByUserID("u_1234"))
subs, err = webPush.SubscriptionsForTopic("topic1")
require.Nil(t, err)
require.Len(t, subs, 0)
}
func TestWebPushStore_RemoveSubscriptionsByUserID_Empty(t *testing.T) {
webPush := newTestWebPushStore(t)
defer webPush.Close()
require.Equal(t, errWebPushUserIDCannotBeEmpty, webPush.RemoveSubscriptionsByUserID(""))
}
func TestWebPushStore_MarkExpiryWarningSent(t *testing.T) {
webPush := newTestWebPushStore(t)
defer webPush.Close()
// Insert subscription with two topics
require.Nil(t, webPush.UpsertSubscription(testWebPushEndpoint, "auth-key", "p256dh-key", "u_1234", netip.MustParseAddr("1.2.3.4"), []string{"topic1", "topic2"}))
subs, err := webPush.SubscriptionsForTopic("topic1")
require.Nil(t, err)
require.Len(t, subs, 1)
// Mark them as warning sent
require.Nil(t, webPush.MarkExpiryWarningSent(subs))
rows, err := webPush.db.Query("SELECT endpoint FROM subscription WHERE warned_at > 0")
require.Nil(t, err)
defer rows.Close()
var endpoint string
require.True(t, rows.Next())
require.Nil(t, rows.Scan(&endpoint))
require.Nil(t, err)
require.Equal(t, testWebPushEndpoint, endpoint)
require.False(t, rows.Next())
}
func TestWebPushStore_SubscriptionsExpiring(t *testing.T) {
webPush := newTestWebPushStore(t)
defer webPush.Close()
// Insert subscription with two topics
require.Nil(t, webPush.UpsertSubscription(testWebPushEndpoint, "auth-key", "p256dh-key", "u_1234", netip.MustParseAddr("1.2.3.4"), []string{"topic1", "topic2"}))
subs, err := webPush.SubscriptionsForTopic("topic1")
require.Nil(t, err)
require.Len(t, subs, 1)
// Fake-mark them as soon-to-expire
_, err = webPush.db.Exec("UPDATE subscription SET updated_at = ? WHERE endpoint = ?", time.Now().Add(-8*24*time.Hour).Unix(), testWebPushEndpoint)
require.Nil(t, err)
// Should not be cleaned up yet
require.Nil(t, webPush.RemoveExpiredSubscriptions(9*24*time.Hour))
// Run expiration
subs, err = webPush.SubscriptionsExpiring(7 * 24 * time.Hour)
require.Nil(t, err)
require.Len(t, subs, 1)
require.Equal(t, testWebPushEndpoint, subs[0].Endpoint)
}
func TestWebPushStore_RemoveExpiredSubscriptions(t *testing.T) {
webPush := newTestWebPushStore(t)
defer webPush.Close()
// Insert subscription with two topics
require.Nil(t, webPush.UpsertSubscription(testWebPushEndpoint, "auth-key", "p256dh-key", "u_1234", netip.MustParseAddr("1.2.3.4"), []string{"topic1", "topic2"}))
subs, err := webPush.SubscriptionsForTopic("topic1")
require.Nil(t, err)
require.Len(t, subs, 1)
// Fake-mark them as expired
_, err = webPush.db.Exec("UPDATE subscription SET updated_at = ? WHERE endpoint = ?", time.Now().Add(-10*24*time.Hour).Unix(), testWebPushEndpoint)
require.Nil(t, err)
// Run expiration
require.Nil(t, webPush.RemoveExpiredSubscriptions(9*24*time.Hour))
// List again, should be 0
subs, err = webPush.SubscriptionsForTopic("topic1")
require.Nil(t, err)
require.Len(t, subs, 0)
}
func newTestWebPushStore(t *testing.T) *webPushStore {
webPush, err := newWebPushStore(filepath.Join(t.TempDir(), "webpush.db"), "")
require.Nil(t, err)
return webPush
}