Merge branch 'main' into markdown
This commit is contained in:
commit
43981bb675
129 changed files with 6741 additions and 1183 deletions
|
@ -1,10 +1,11 @@
|
|||
package server
|
||||
|
||||
import (
|
||||
"heckel.io/ntfy/user"
|
||||
"io/fs"
|
||||
"net/netip"
|
||||
"time"
|
||||
|
||||
"heckel.io/ntfy/user"
|
||||
)
|
||||
|
||||
// Defines default config settings (excluding limits, see below)
|
||||
|
@ -22,6 +23,12 @@ const (
|
|||
DefaultStripePriceCacheDuration = 3 * time.Hour // Time to keep Stripe prices cached in memory before a refresh is needed
|
||||
)
|
||||
|
||||
// Defines default Web Push settings
|
||||
const (
|
||||
DefaultWebPushExpiryWarningDuration = 7 * 24 * time.Hour
|
||||
DefaultWebPushExpiryDuration = 9 * 24 * time.Hour
|
||||
)
|
||||
|
||||
// Defines all global and per-visitor limits
|
||||
// - message size limit: the max number of bytes for a message
|
||||
// - total topic limit: max number of topics overall
|
||||
|
@ -146,6 +153,13 @@ type Config struct {
|
|||
EnableMetrics bool
|
||||
AccessControlAllowOrigin string // CORS header field to restrict access from web clients
|
||||
Version string // injected by App
|
||||
WebPushPrivateKey string
|
||||
WebPushPublicKey string
|
||||
WebPushFile string
|
||||
WebPushEmailAddress string
|
||||
WebPushStartupQueries string
|
||||
WebPushExpiryDuration time.Duration
|
||||
WebPushExpiryWarningDuration time.Duration
|
||||
}
|
||||
|
||||
// NewConfig instantiates a default new server config
|
||||
|
@ -227,5 +241,11 @@ func NewConfig() *Config {
|
|||
EnableReservations: false,
|
||||
AccessControlAllowOrigin: "*",
|
||||
Version: "",
|
||||
WebPushPrivateKey: "",
|
||||
WebPushPublicKey: "",
|
||||
WebPushFile: "",
|
||||
WebPushEmailAddress: "",
|
||||
WebPushExpiryDuration: DefaultWebPushExpiryDuration,
|
||||
WebPushExpiryWarningDuration: DefaultWebPushExpiryWarningDuration,
|
||||
}
|
||||
}
|
||||
|
|
|
@ -114,6 +114,9 @@ var (
|
|||
errHTTPBadRequestAnonymousCallsNotAllowed = &errHTTP{40035, http.StatusBadRequest, "invalid request: anonymous phone calls are not allowed", "https://ntfy.sh/docs/publish/#phone-calls", nil}
|
||||
errHTTPBadRequestPhoneNumberVerifyChannelInvalid = &errHTTP{40036, http.StatusBadRequest, "invalid request: verification channel must be 'sms' or 'call'", "https://ntfy.sh/docs/publish/#phone-calls", nil}
|
||||
errHTTPBadRequestDelayNoCall = &errHTTP{40037, http.StatusBadRequest, "delayed call notifications are not supported", "", nil}
|
||||
errHTTPBadRequestWebPushSubscriptionInvalid = &errHTTP{40038, http.StatusBadRequest, "invalid request: web push payload malformed", "", nil}
|
||||
errHTTPBadRequestWebPushEndpointUnknown = &errHTTP{40039, http.StatusBadRequest, "invalid request: web push endpoint unknown", "", nil}
|
||||
errHTTPBadRequestWebPushTopicCountTooHigh = &errHTTP{40040, http.StatusBadRequest, "invalid request: too many web push topic subscriptions", "", nil}
|
||||
errHTTPNotFound = &errHTTP{40401, http.StatusNotFound, "page not found", "", nil}
|
||||
errHTTPUnauthorized = &errHTTP{40101, http.StatusUnauthorized, "unauthorized", "https://ntfy.sh/docs/publish/#authentication", nil}
|
||||
errHTTPForbidden = &errHTTP{40301, http.StatusForbidden, "forbidden", "https://ntfy.sh/docs/publish/#authentication", nil}
|
||||
|
@ -138,5 +141,6 @@ var (
|
|||
errHTTPInternalError = &errHTTP{50001, http.StatusInternalServerError, "internal server error", "", nil}
|
||||
errHTTPInternalErrorInvalidPath = &errHTTP{50002, http.StatusInternalServerError, "internal server error: invalid path", "", nil}
|
||||
errHTTPInternalErrorMissingBaseURL = &errHTTP{50003, http.StatusInternalServerError, "internal server error: base-url must be be configured for this feature", "https://ntfy.sh/docs/config/", nil}
|
||||
errHTTPInternalErrorWebPushUnableToPublish = &errHTTP{50004, http.StatusInternalServerError, "internal server error: unable to publish web push message", "", nil}
|
||||
errHTTPInsufficientStorageUnifiedPush = &errHTTP{50701, http.StatusInsufficientStorage, "cannot publish to UnifiedPush topic without previously active subscriber", "", nil}
|
||||
)
|
||||
|
|
|
@ -29,6 +29,7 @@ const (
|
|||
tagResetter = "resetter"
|
||||
tagWebsocket = "websocket"
|
||||
tagMatrix = "matrix"
|
||||
tagWebPush = "webpush"
|
||||
)
|
||||
|
||||
var (
|
||||
|
|
|
@ -270,7 +270,7 @@ func newSqliteCache(filename, startupQueries string, cacheDuration time.Duration
|
|||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if err := setupDB(db, startupQueries, cacheDuration); err != nil {
|
||||
if err := setupMessagesDB(db, startupQueries, cacheDuration); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
var queue *util.BatchingQueue[*message]
|
||||
|
@ -749,7 +749,7 @@ func (c *messageCache) Close() error {
|
|||
return c.db.Close()
|
||||
}
|
||||
|
||||
func setupDB(db *sql.DB, startupQueries string, cacheDuration time.Duration) error {
|
||||
func setupMessagesDB(db *sql.DB, startupQueries string, cacheDuration time.Duration) error {
|
||||
// Run startup queries
|
||||
if startupQueries != "" {
|
||||
if _, err := db.Exec(startupQueries); err != nil {
|
||||
|
|
|
@ -9,13 +9,6 @@ import (
|
|||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"github.com/emersion/go-smtp"
|
||||
"github.com/gorilla/websocket"
|
||||
"github.com/prometheus/client_golang/prometheus/promhttp"
|
||||
"golang.org/x/sync/errgroup"
|
||||
"heckel.io/ntfy/log"
|
||||
"heckel.io/ntfy/user"
|
||||
"heckel.io/ntfy/util"
|
||||
"io"
|
||||
"net"
|
||||
"net/http"
|
||||
|
@ -32,6 +25,14 @@ import (
|
|||
"sync"
|
||||
"time"
|
||||
"unicode/utf8"
|
||||
|
||||
"github.com/emersion/go-smtp"
|
||||
"github.com/gorilla/websocket"
|
||||
"github.com/prometheus/client_golang/prometheus/promhttp"
|
||||
"golang.org/x/sync/errgroup"
|
||||
"heckel.io/ntfy/log"
|
||||
"heckel.io/ntfy/user"
|
||||
"heckel.io/ntfy/util"
|
||||
)
|
||||
|
||||
// Server is the main server, providing the UI and API for ntfy
|
||||
|
@ -52,6 +53,7 @@ type Server struct {
|
|||
messagesHistory []int64 // Last n values of the messages counter, used to determine rate
|
||||
userManager *user.Manager // Might be nil!
|
||||
messageCache *messageCache // Database that stores the messages
|
||||
webPush *webPushStore // Database that stores web push subscriptions
|
||||
fileCache *fileCache // File system based cache that stores attachments
|
||||
stripe stripeAPI // Stripe API, can be replaced with a mock
|
||||
priceCache *util.LookupCache[map[string]int64] // Stripe price ID -> price as cents (USD implied!)
|
||||
|
@ -76,11 +78,15 @@ var (
|
|||
publishPathRegex = regexp.MustCompile(`^/[-_A-Za-z0-9]{1,64}/(publish|send|trigger)$`)
|
||||
|
||||
webConfigPath = "/config.js"
|
||||
webManifestPath = "/manifest.webmanifest"
|
||||
webRootHTMLPath = "/app.html"
|
||||
webServiceWorkerPath = "/sw.js"
|
||||
accountPath = "/account"
|
||||
matrixPushPath = "/_matrix/push/v1/notify"
|
||||
metricsPath = "/metrics"
|
||||
apiHealthPath = "/v1/health"
|
||||
apiStatsPath = "/v1/stats"
|
||||
apiWebPushPath = "/v1/webpush"
|
||||
apiTiersPath = "/v1/tiers"
|
||||
apiUsersPath = "/v1/users"
|
||||
apiUsersAccessPath = "/v1/users/access"
|
||||
|
@ -151,6 +157,13 @@ func New(conf *Config) (*Server, error) {
|
|||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
var webPush *webPushStore
|
||||
if conf.WebPushPublicKey != "" {
|
||||
webPush, err = newWebPushStore(conf.WebPushFile, conf.WebPushStartupQueries)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
topics, err := messageCache.Topics()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
|
@ -190,6 +203,7 @@ func New(conf *Config) (*Server, error) {
|
|||
s := &Server{
|
||||
config: conf,
|
||||
messageCache: messageCache,
|
||||
webPush: webPush,
|
||||
fileCache: fileCache,
|
||||
firebaseClient: firebaseClient,
|
||||
smtpSender: mailer,
|
||||
|
@ -342,6 +356,9 @@ func (s *Server) closeDatabases() {
|
|||
s.userManager.Close()
|
||||
}
|
||||
s.messageCache.Close()
|
||||
if s.webPush != nil {
|
||||
s.webPush.Close()
|
||||
}
|
||||
}
|
||||
|
||||
// handle is the main entry point for all HTTP requests
|
||||
|
@ -416,6 +433,8 @@ func (s *Server) handleInternal(w http.ResponseWriter, r *http.Request, v *visit
|
|||
return s.handleHealth(w, r, v)
|
||||
} else if r.Method == http.MethodGet && r.URL.Path == webConfigPath {
|
||||
return s.ensureWebEnabled(s.handleWebConfig)(w, r, v)
|
||||
} else if r.Method == http.MethodGet && r.URL.Path == webManifestPath {
|
||||
return s.ensureWebPushEnabled(s.handleWebManifest)(w, r, v)
|
||||
} else if r.Method == http.MethodGet && r.URL.Path == apiUsersPath {
|
||||
return s.ensureAdmin(s.handleUsersGet)(w, r, v)
|
||||
} else if r.Method == http.MethodPut && r.URL.Path == apiUsersPath {
|
||||
|
@ -470,6 +489,10 @@ func (s *Server) handleInternal(w http.ResponseWriter, r *http.Request, v *visit
|
|||
return s.ensureUser(s.ensureCallsEnabled(s.withAccountSync(s.handleAccountPhoneNumberAdd)))(w, r, v)
|
||||
} else if r.Method == http.MethodDelete && r.URL.Path == apiAccountPhonePath {
|
||||
return s.ensureUser(s.ensureCallsEnabled(s.withAccountSync(s.handleAccountPhoneNumberDelete)))(w, r, v)
|
||||
} else if r.Method == http.MethodPost && apiWebPushPath == r.URL.Path {
|
||||
return s.ensureWebPushEnabled(s.limitRequests(s.handleWebPushUpdate))(w, r, v)
|
||||
} else if r.Method == http.MethodDelete && apiWebPushPath == r.URL.Path {
|
||||
return s.ensureWebPushEnabled(s.limitRequests(s.handleWebPushDelete))(w, r, v)
|
||||
} else if r.Method == http.MethodGet && r.URL.Path == apiStatsPath {
|
||||
return s.handleStats(w, r, v)
|
||||
} else if r.Method == http.MethodGet && r.URL.Path == apiTiersPath {
|
||||
|
@ -478,7 +501,7 @@ func (s *Server) handleInternal(w http.ResponseWriter, r *http.Request, v *visit
|
|||
return s.handleMatrixDiscovery(w)
|
||||
} else if r.Method == http.MethodGet && r.URL.Path == metricsPath && s.metricsHandler != nil {
|
||||
return s.handleMetrics(w, r, v)
|
||||
} else if r.Method == http.MethodGet && staticRegex.MatchString(r.URL.Path) {
|
||||
} else if r.Method == http.MethodGet && (staticRegex.MatchString(r.URL.Path) || r.URL.Path == webServiceWorkerPath || r.URL.Path == webRootHTMLPath) {
|
||||
return s.ensureWebEnabled(s.handleStatic)(w, r, v)
|
||||
} else if r.Method == http.MethodGet && docsRegex.MatchString(r.URL.Path) {
|
||||
return s.ensureWebEnabled(s.handleDocs)(w, r, v)
|
||||
|
@ -552,7 +575,9 @@ func (s *Server) handleWebConfig(w http.ResponseWriter, _ *http.Request, _ *visi
|
|||
EnableCalls: s.config.TwilioAccount != "",
|
||||
EnableEmails: s.config.SMTPSenderFrom != "",
|
||||
EnableReservations: s.config.EnableReservations,
|
||||
EnableWebPush: s.config.WebPushPublicKey != "",
|
||||
BillingContact: s.config.BillingContact,
|
||||
WebPushPublicKey: s.config.WebPushPublicKey,
|
||||
DisallowedTopics: s.config.DisallowedTopics,
|
||||
}
|
||||
b, err := json.MarshalIndent(response, "", " ")
|
||||
|
@ -564,6 +589,25 @@ func (s *Server) handleWebConfig(w http.ResponseWriter, _ *http.Request, _ *visi
|
|||
return err
|
||||
}
|
||||
|
||||
// handleWebManifest serves the web app manifest for the progressive web app (PWA)
|
||||
func (s *Server) handleWebManifest(w http.ResponseWriter, _ *http.Request, _ *visitor) error {
|
||||
response := &webManifestResponse{
|
||||
Name: "ntfy web",
|
||||
Description: "ntfy lets you send push notifications via scripts from any computer or phone",
|
||||
ShortName: "ntfy",
|
||||
Scope: "/",
|
||||
StartURL: s.config.WebRoot,
|
||||
Display: "standalone",
|
||||
BackgroundColor: "#ffffff",
|
||||
ThemeColor: "#317f6f",
|
||||
Icons: []*webManifestIcon{
|
||||
{SRC: "/static/images/pwa-192x192.png", Sizes: "192x192", Type: "image/png"},
|
||||
{SRC: "/static/images/pwa-512x512.png", Sizes: "512x512", Type: "image/png"},
|
||||
},
|
||||
}
|
||||
return s.writeJSONWithContentType(w, response, "application/manifest+json")
|
||||
}
|
||||
|
||||
// handleMetrics returns Prometheus metrics. This endpoint is only called if enable-metrics is set,
|
||||
// and listen-metrics-http is not set.
|
||||
func (s *Server) handleMetrics(w http.ResponseWriter, r *http.Request, _ *visitor) error {
|
||||
|
@ -760,9 +804,12 @@ func (s *Server) handlePublishInternal(r *http.Request, v *visitor) (*message, e
|
|||
if s.config.TwilioAccount != "" && call != "" {
|
||||
go s.callPhone(v, r, m, call)
|
||||
}
|
||||
if s.config.UpstreamBaseURL != "" {
|
||||
if s.config.UpstreamBaseURL != "" && !unifiedpush { // UP messages are not sent to upstream
|
||||
go s.forwardPollRequest(v, m)
|
||||
}
|
||||
if s.config.WebPushPublicKey != "" {
|
||||
go s.publishToWebPushEndpoints(v, m)
|
||||
}
|
||||
} else {
|
||||
logvrm(v, r, m).Tag(tagPublish).Debug("Message delayed, will process later")
|
||||
}
|
||||
|
@ -1696,6 +1743,9 @@ func (s *Server) sendDelayedMessage(v *visitor, m *message) error {
|
|||
if s.config.UpstreamBaseURL != "" {
|
||||
go s.forwardPollRequest(v, m)
|
||||
}
|
||||
if s.config.WebPushPublicKey != "" {
|
||||
go s.publishToWebPushEndpoints(v, m)
|
||||
}
|
||||
if err := s.messageCache.MarkPublished(m); err != nil {
|
||||
return err
|
||||
}
|
||||
|
@ -1916,7 +1966,11 @@ func (s *Server) visitor(ip netip.Addr, user *user.User) *visitor {
|
|||
}
|
||||
|
||||
func (s *Server) writeJSON(w http.ResponseWriter, v any) error {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
return s.writeJSONWithContentType(w, v, "application/json")
|
||||
}
|
||||
|
||||
func (s *Server) writeJSONWithContentType(w http.ResponseWriter, v any, contentType string) error {
|
||||
w.Header().Set("Content-Type", contentType)
|
||||
w.Header().Set("Access-Control-Allow-Origin", s.config.AccessControlAllowOrigin) // CORS, allow cross-origin requests
|
||||
if err := json.NewEncoder(w).Encode(v); err != nil {
|
||||
return err
|
||||
|
|
|
@ -144,6 +144,27 @@
|
|||
# smtp-server-domain:
|
||||
# smtp-server-addr-prefix:
|
||||
|
||||
# Web Push support (background notifications for browsers)
|
||||
#
|
||||
# If enabled, allows ntfy to receive push notifications, even when the ntfy web app is closed. When enabled, users
|
||||
# can enable background notifications in the web app. Once enabled, ntfy will forward published messages to the push
|
||||
# endpoint, which will then forward it to the browser.
|
||||
#
|
||||
# You must configure web-push-public/private key, web-push-file, and web-push-email-address below to enable Web Push.
|
||||
# Run "ntfy webpush keys" to generate the keys.
|
||||
#
|
||||
# - web-push-public-key is the generated VAPID public key, e.g. AA1234BBCCddvveekaabcdfqwertyuiopasdfghjklzxcvbnm1234567890
|
||||
# - web-push-private-key is the generated VAPID private key, e.g. AA2BB1234567890abcdefzxcvbnm1234567890
|
||||
# - web-push-file is a database file to keep track of browser subscription endpoints, e.g. `/var/cache/ntfy/webpush.db`
|
||||
# - web-push-email-address is the admin email address send to the push provider, e.g. `sysadmin@example.com`
|
||||
# - web-push-startup-queries is an optional list of queries to run on startup`
|
||||
#
|
||||
# web-push-public-key:
|
||||
# web-push-private-key:
|
||||
# web-push-file:
|
||||
# web-push-email-address:
|
||||
# web-push-startup-queries:
|
||||
|
||||
# If enabled, ntfy can perform voice calls via Twilio via the "X-Call" header.
|
||||
#
|
||||
# - twilio-account is the Twilio account SID, e.g. AC12345beefbeef67890beefbeef122586
|
||||
|
|
|
@ -170,6 +170,11 @@ func (s *Server) handleAccountDelete(w http.ResponseWriter, r *http.Request, v *
|
|||
if _, err := s.userManager.Authenticate(u.Name, req.Password); err != nil {
|
||||
return errHTTPBadRequestIncorrectPasswordConfirmation
|
||||
}
|
||||
if s.webPush != nil && u.ID != "" {
|
||||
if err := s.webPush.RemoveSubscriptionsByUserID(u.ID); err != nil {
|
||||
logvr(v, r).Err(err).Warn("Error removing web push subscriptions for %s", u.Name)
|
||||
}
|
||||
}
|
||||
if u.Billing.StripeSubscriptionID != "" {
|
||||
logvr(v, r).Tag(tagStripe).Info("Canceling billing subscription for user %s", u.Name)
|
||||
if _, err := s.stripe.CancelSubscription(u.Billing.StripeSubscriptionID); err != nil {
|
||||
|
|
|
@ -15,6 +15,7 @@ func (s *Server) execManager() {
|
|||
s.pruneTokens()
|
||||
s.pruneAttachments()
|
||||
s.pruneMessages()
|
||||
s.pruneAndNotifyWebPushSubscriptions()
|
||||
|
||||
// Message count per topic
|
||||
var messagesCached int
|
||||
|
|
|
@ -58,6 +58,15 @@ func (s *Server) ensureWebEnabled(next handleFunc) handleFunc {
|
|||
}
|
||||
}
|
||||
|
||||
func (s *Server) ensureWebPushEnabled(next handleFunc) handleFunc {
|
||||
return func(w http.ResponseWriter, r *http.Request, v *visitor) error {
|
||||
if s.config.WebRoot == "" || s.config.WebPushPublicKey == "" {
|
||||
return errHTTPNotFound
|
||||
}
|
||||
return next(w, r, v)
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Server) ensureUserManager(next handleFunc) handleFunc {
|
||||
return func(w http.ResponseWriter, r *http.Request, v *visitor) error {
|
||||
if s.userManager == nil {
|
||||
|
|
|
@ -22,6 +22,7 @@ import (
|
|||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/SherClockHolmes/webpush-go"
|
||||
"github.com/stretchr/testify/require"
|
||||
"heckel.io/ntfy/log"
|
||||
"heckel.io/ntfy/util"
|
||||
|
@ -238,6 +239,12 @@ func TestServer_WebEnabled(t *testing.T) {
|
|||
rr = request(t, s, "GET", "/config.js", "", nil)
|
||||
require.Equal(t, 404, rr.Code)
|
||||
|
||||
rr = request(t, s, "GET", "/sw.js", "", nil)
|
||||
require.Equal(t, 404, rr.Code)
|
||||
|
||||
rr = request(t, s, "GET", "/app.html", "", nil)
|
||||
require.Equal(t, 404, rr.Code)
|
||||
|
||||
rr = request(t, s, "GET", "/static/css/home.css", "", nil)
|
||||
require.Equal(t, 404, rr.Code)
|
||||
|
||||
|
@ -250,6 +257,35 @@ func TestServer_WebEnabled(t *testing.T) {
|
|||
|
||||
rr = request(t, s2, "GET", "/config.js", "", nil)
|
||||
require.Equal(t, 200, rr.Code)
|
||||
|
||||
rr = request(t, s2, "GET", "/sw.js", "", nil)
|
||||
require.Equal(t, 200, rr.Code)
|
||||
|
||||
rr = request(t, s2, "GET", "/app.html", "", nil)
|
||||
require.Equal(t, 200, rr.Code)
|
||||
}
|
||||
|
||||
func TestServer_WebPushEnabled(t *testing.T) {
|
||||
conf := newTestConfig(t)
|
||||
conf.WebRoot = "" // Disable web app
|
||||
s := newTestServer(t, conf)
|
||||
|
||||
rr := request(t, s, "GET", "/manifest.webmanifest", "", nil)
|
||||
require.Equal(t, 404, rr.Code)
|
||||
|
||||
conf2 := newTestConfig(t)
|
||||
s2 := newTestServer(t, conf2)
|
||||
|
||||
rr = request(t, s2, "GET", "/manifest.webmanifest", "", nil)
|
||||
require.Equal(t, 404, rr.Code)
|
||||
|
||||
conf3 := newTestConfigWithWebPush(t)
|
||||
s3 := newTestServer(t, conf3)
|
||||
|
||||
rr = request(t, s3, "GET", "/manifest.webmanifest", "", nil)
|
||||
require.Equal(t, 200, rr.Code)
|
||||
require.Equal(t, "application/manifest+json", rr.Header().Get("Content-Type"))
|
||||
|
||||
}
|
||||
|
||||
func TestServer_PublishLargeMessage(t *testing.T) {
|
||||
|
@ -2559,6 +2595,29 @@ func TestServer_UpstreamBaseURL_With_Access_Token_Success(t *testing.T) {
|
|||
})
|
||||
}
|
||||
|
||||
func TestServer_UpstreamBaseURL_DoNotForwardUnifiedPush(t *testing.T) {
|
||||
upstreamServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
t.Fatal("UnifiedPush messages should not be forwarded")
|
||||
}))
|
||||
defer upstreamServer.Close()
|
||||
|
||||
c := newTestConfigWithAuthFile(t)
|
||||
c.BaseURL = "http://myserver.internal"
|
||||
c.UpstreamBaseURL = upstreamServer.URL
|
||||
s := newTestServer(t, c)
|
||||
|
||||
// Send UP message, this should not forward to upstream server
|
||||
response := request(t, s, "PUT", "/mytopic?up=1", `hi there`, nil)
|
||||
require.Equal(t, 200, response.Code)
|
||||
m := toMessage(t, response.Body.String())
|
||||
require.NotEmpty(t, m.ID)
|
||||
require.Equal(t, "hi there", m.Message)
|
||||
|
||||
// Forwarding is done asynchronously, so wait a bit.
|
||||
// This ensures that the t.Fatal above is actually not triggered.
|
||||
time.Sleep(500 * time.Millisecond)
|
||||
}
|
||||
|
||||
func newTestConfig(t *testing.T) *Config {
|
||||
conf := NewConfig()
|
||||
conf.BaseURL = "http://127.0.0.1:12345"
|
||||
|
@ -2568,19 +2627,33 @@ func newTestConfig(t *testing.T) *Config {
|
|||
return conf
|
||||
}
|
||||
|
||||
func newTestConfigWithAuthFile(t *testing.T) *Config {
|
||||
conf := newTestConfig(t)
|
||||
func configureAuth(t *testing.T, conf *Config) *Config {
|
||||
conf.AuthFile = filepath.Join(t.TempDir(), "user.db")
|
||||
conf.AuthStartupQueries = "pragma journal_mode = WAL; pragma synchronous = normal; pragma temp_store = memory;"
|
||||
conf.AuthBcryptCost = bcrypt.MinCost // This speeds up tests a lot
|
||||
return conf
|
||||
}
|
||||
|
||||
func newTestConfigWithAuthFile(t *testing.T) *Config {
|
||||
conf := newTestConfig(t)
|
||||
conf = configureAuth(t, conf)
|
||||
return conf
|
||||
}
|
||||
|
||||
func newTestConfigWithWebPush(t *testing.T) *Config {
|
||||
conf := newTestConfig(t)
|
||||
privateKey, publicKey, err := webpush.GenerateVAPIDKeys()
|
||||
require.Nil(t, err)
|
||||
conf.WebPushFile = filepath.Join(t.TempDir(), "webpush.db")
|
||||
conf.WebPushEmailAddress = "testing@example.com"
|
||||
conf.WebPushPrivateKey = privateKey
|
||||
conf.WebPushPublicKey = publicKey
|
||||
return conf
|
||||
}
|
||||
|
||||
func newTestServer(t *testing.T, config *Config) *Server {
|
||||
server, err := New(config)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
require.Nil(t, err)
|
||||
return server
|
||||
}
|
||||
|
||||
|
|
171
server/server_webpush.go
Normal file
171
server/server_webpush.go
Normal file
|
@ -0,0 +1,171 @@
|
|||
package server
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"regexp"
|
||||
"strings"
|
||||
|
||||
"github.com/SherClockHolmes/webpush-go"
|
||||
"heckel.io/ntfy/log"
|
||||
"heckel.io/ntfy/user"
|
||||
)
|
||||
|
||||
const (
|
||||
webPushTopicSubscribeLimit = 50
|
||||
)
|
||||
|
||||
var (
|
||||
webPushAllowedEndpointsPatterns = []string{
|
||||
"https://*.google.com/",
|
||||
"https://*.googleapis.com/",
|
||||
"https://*.mozilla.com/",
|
||||
"https://*.mozaws.net/",
|
||||
"https://*.windows.com/",
|
||||
"https://*.microsoft.com/",
|
||||
"https://*.apple.com/",
|
||||
}
|
||||
webPushAllowedEndpointsRegex *regexp.Regexp
|
||||
)
|
||||
|
||||
func init() {
|
||||
for i, pattern := range webPushAllowedEndpointsPatterns {
|
||||
webPushAllowedEndpointsPatterns[i] = strings.ReplaceAll(strings.ReplaceAll(pattern, ".", "\\."), "*", ".+")
|
||||
}
|
||||
allPatterns := fmt.Sprintf("^(%s)", strings.Join(webPushAllowedEndpointsPatterns, "|"))
|
||||
webPushAllowedEndpointsRegex = regexp.MustCompile(allPatterns)
|
||||
}
|
||||
|
||||
func (s *Server) handleWebPushUpdate(w http.ResponseWriter, r *http.Request, v *visitor) error {
|
||||
req, err := readJSONWithLimit[apiWebPushUpdateSubscriptionRequest](r.Body, jsonBodyBytesLimit, false)
|
||||
if err != nil || req.Endpoint == "" || req.P256dh == "" || req.Auth == "" {
|
||||
return errHTTPBadRequestWebPushSubscriptionInvalid
|
||||
} else if !webPushAllowedEndpointsRegex.MatchString(req.Endpoint) {
|
||||
return errHTTPBadRequestWebPushEndpointUnknown
|
||||
} else if len(req.Topics) > webPushTopicSubscribeLimit {
|
||||
return errHTTPBadRequestWebPushTopicCountTooHigh
|
||||
}
|
||||
topics, err := s.topicsFromIDs(req.Topics...)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if s.userManager != nil {
|
||||
u := v.User()
|
||||
for _, t := range topics {
|
||||
if err := s.userManager.Authorize(u, t.ID, user.PermissionRead); err != nil {
|
||||
logvr(v, r).With(t).Err(err).Debug("Access to topic %s not authorized", t.ID)
|
||||
return errHTTPForbidden.With(t)
|
||||
}
|
||||
}
|
||||
}
|
||||
if err := s.webPush.UpsertSubscription(req.Endpoint, req.Auth, req.P256dh, v.MaybeUserID(), v.IP(), req.Topics); err != nil {
|
||||
return err
|
||||
}
|
||||
return s.writeJSON(w, newSuccessResponse())
|
||||
}
|
||||
|
||||
func (s *Server) handleWebPushDelete(w http.ResponseWriter, r *http.Request, _ *visitor) error {
|
||||
req, err := readJSONWithLimit[apiWebPushUpdateSubscriptionRequest](r.Body, jsonBodyBytesLimit, false)
|
||||
if err != nil || req.Endpoint == "" {
|
||||
return errHTTPBadRequestWebPushSubscriptionInvalid
|
||||
}
|
||||
if err := s.webPush.RemoveSubscriptionsByEndpoint(req.Endpoint); err != nil {
|
||||
return err
|
||||
}
|
||||
return s.writeJSON(w, newSuccessResponse())
|
||||
}
|
||||
|
||||
func (s *Server) publishToWebPushEndpoints(v *visitor, m *message) {
|
||||
subscriptions, err := s.webPush.SubscriptionsForTopic(m.Topic)
|
||||
if err != nil {
|
||||
logvm(v, m).Err(err).With(v, m).Warn("Unable to publish web push messages")
|
||||
return
|
||||
}
|
||||
log.Tag(tagWebPush).With(v, m).Debug("Publishing web push message to %d subscribers", len(subscriptions))
|
||||
payload, err := json.Marshal(newWebPushPayload(fmt.Sprintf("%s/%s", s.config.BaseURL, m.Topic), m))
|
||||
if err != nil {
|
||||
log.Tag(tagWebPush).Err(err).With(v, m).Warn("Unable to marshal expiring payload")
|
||||
return
|
||||
}
|
||||
for _, subscription := range subscriptions {
|
||||
if err := s.sendWebPushNotification(subscription, payload, v, m); err != nil {
|
||||
log.Tag(tagWebPush).Err(err).With(v, m, subscription).Warn("Unable to publish web push message")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Server) pruneAndNotifyWebPushSubscriptions() {
|
||||
if s.config.WebPushPublicKey == "" {
|
||||
return
|
||||
}
|
||||
go func() {
|
||||
if err := s.pruneAndNotifyWebPushSubscriptionsInternal(); err != nil {
|
||||
log.Tag(tagWebPush).Err(err).Warn("Unable to prune or notify web push subscriptions")
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
func (s *Server) pruneAndNotifyWebPushSubscriptionsInternal() error {
|
||||
// Expire old subscriptions
|
||||
if err := s.webPush.RemoveExpiredSubscriptions(s.config.WebPushExpiryDuration); err != nil {
|
||||
return err
|
||||
}
|
||||
// Notify subscriptions that will expire soon
|
||||
subscriptions, err := s.webPush.SubscriptionsExpiring(s.config.WebPushExpiryWarningDuration)
|
||||
if err != nil {
|
||||
return err
|
||||
} else if len(subscriptions) == 0 {
|
||||
return nil
|
||||
}
|
||||
payload, err := json.Marshal(newWebPushSubscriptionExpiringPayload())
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
warningSent := make([]*webPushSubscription, 0)
|
||||
for _, subscription := range subscriptions {
|
||||
if err := s.sendWebPushNotification(subscription, payload); err != nil {
|
||||
log.Tag(tagWebPush).Err(err).With(subscription).Warn("Unable to publish expiry imminent warning")
|
||||
continue
|
||||
}
|
||||
warningSent = append(warningSent, subscription)
|
||||
}
|
||||
if err := s.webPush.MarkExpiryWarningSent(warningSent); err != nil {
|
||||
return err
|
||||
}
|
||||
log.Tag(tagWebPush).Debug("Expired old subscriptions and published %d expiry imminent warnings", len(subscriptions))
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *Server) sendWebPushNotification(sub *webPushSubscription, message []byte, contexters ...log.Contexter) error {
|
||||
log.Tag(tagWebPush).With(sub).With(contexters...).Debug("Sending web push message")
|
||||
payload := &webpush.Subscription{
|
||||
Endpoint: sub.Endpoint,
|
||||
Keys: webpush.Keys{
|
||||
Auth: sub.Auth,
|
||||
P256dh: sub.P256dh,
|
||||
},
|
||||
}
|
||||
resp, err := webpush.SendNotification(message, payload, &webpush.Options{
|
||||
Subscriber: s.config.WebPushEmailAddress,
|
||||
VAPIDPublicKey: s.config.WebPushPublicKey,
|
||||
VAPIDPrivateKey: s.config.WebPushPrivateKey,
|
||||
Urgency: webpush.UrgencyHigh, // iOS requires this to ensure delivery
|
||||
TTL: int(s.config.CacheDuration.Seconds()),
|
||||
})
|
||||
if err != nil {
|
||||
log.Tag(tagWebPush).With(sub).With(contexters...).Err(err).Debug("Unable to publish web push message, removing endpoint")
|
||||
if err := s.webPush.RemoveSubscriptionsByEndpoint(sub.Endpoint); err != nil {
|
||||
return err
|
||||
}
|
||||
return err
|
||||
}
|
||||
if (resp.StatusCode < 200 || resp.StatusCode > 299) && resp.StatusCode != 429 {
|
||||
log.Tag(tagWebPush).With(sub).With(contexters...).Field("response_code", resp.StatusCode).Debug("Unable to publish web push message, unexpected response")
|
||||
if err := s.webPush.RemoveSubscriptionsByEndpoint(sub.Endpoint); err != nil {
|
||||
return err
|
||||
}
|
||||
return errHTTPInternalErrorWebPushUnableToPublish.With(sub).With(contexters...)
|
||||
}
|
||||
return nil
|
||||
}
|
256
server/server_webpush_test.go
Normal file
256
server/server_webpush_test.go
Normal file
|
@ -0,0 +1,256 @@
|
|||
package server
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"github.com/stretchr/testify/require"
|
||||
"heckel.io/ntfy/user"
|
||||
"heckel.io/ntfy/util"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"net/netip"
|
||||
"strings"
|
||||
"sync/atomic"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
const (
|
||||
testWebPushEndpoint = "https://updates.push.services.mozilla.com/wpush/v1/AAABBCCCDDEEEFFF"
|
||||
)
|
||||
|
||||
func TestServer_WebPush_Disabled(t *testing.T) {
|
||||
s := newTestServer(t, newTestConfig(t))
|
||||
|
||||
response := request(t, s, "POST", "/v1/webpush", payloadForTopics(t, []string{"test-topic"}, testWebPushEndpoint), nil)
|
||||
require.Equal(t, 404, response.Code)
|
||||
}
|
||||
|
||||
func TestServer_WebPush_TopicAdd(t *testing.T) {
|
||||
s := newTestServer(t, newTestConfigWithWebPush(t))
|
||||
|
||||
response := request(t, s, "POST", "/v1/webpush", payloadForTopics(t, []string{"test-topic"}, testWebPushEndpoint), nil)
|
||||
require.Equal(t, 200, response.Code)
|
||||
require.Equal(t, `{"success":true}`+"\n", response.Body.String())
|
||||
|
||||
subs, err := s.webPush.SubscriptionsForTopic("test-topic")
|
||||
require.Nil(t, err)
|
||||
|
||||
require.Len(t, subs, 1)
|
||||
require.Equal(t, subs[0].Endpoint, testWebPushEndpoint)
|
||||
require.Equal(t, subs[0].P256dh, "p256dh-key")
|
||||
require.Equal(t, subs[0].Auth, "auth-key")
|
||||
require.Equal(t, subs[0].UserID, "")
|
||||
}
|
||||
|
||||
func TestServer_WebPush_TopicAdd_InvalidEndpoint(t *testing.T) {
|
||||
s := newTestServer(t, newTestConfigWithWebPush(t))
|
||||
|
||||
response := request(t, s, "POST", "/v1/webpush", payloadForTopics(t, []string{"test-topic"}, "https://ddos-target.example.com/webpush"), nil)
|
||||
require.Equal(t, 400, response.Code)
|
||||
require.Equal(t, `{"code":40039,"http":400,"error":"invalid request: web push endpoint unknown"}`+"\n", response.Body.String())
|
||||
}
|
||||
|
||||
func TestServer_WebPush_TopicAdd_TooManyTopics(t *testing.T) {
|
||||
s := newTestServer(t, newTestConfigWithWebPush(t))
|
||||
|
||||
topicList := make([]string, 51)
|
||||
for i := range topicList {
|
||||
topicList[i] = util.RandomString(5)
|
||||
}
|
||||
|
||||
response := request(t, s, "POST", "/v1/webpush", payloadForTopics(t, topicList, testWebPushEndpoint), nil)
|
||||
require.Equal(t, 400, response.Code)
|
||||
require.Equal(t, `{"code":40040,"http":400,"error":"invalid request: too many web push topic subscriptions"}`+"\n", response.Body.String())
|
||||
}
|
||||
|
||||
func TestServer_WebPush_TopicUnsubscribe(t *testing.T) {
|
||||
s := newTestServer(t, newTestConfigWithWebPush(t))
|
||||
|
||||
addSubscription(t, s, testWebPushEndpoint, "test-topic")
|
||||
requireSubscriptionCount(t, s, "test-topic", 1)
|
||||
|
||||
response := request(t, s, "POST", "/v1/webpush", payloadForTopics(t, []string{}, testWebPushEndpoint), nil)
|
||||
require.Equal(t, 200, response.Code)
|
||||
require.Equal(t, `{"success":true}`+"\n", response.Body.String())
|
||||
|
||||
requireSubscriptionCount(t, s, "test-topic", 0)
|
||||
}
|
||||
|
||||
func TestServer_WebPush_Delete(t *testing.T) {
|
||||
s := newTestServer(t, newTestConfigWithWebPush(t))
|
||||
|
||||
addSubscription(t, s, testWebPushEndpoint, "test-topic")
|
||||
requireSubscriptionCount(t, s, "test-topic", 1)
|
||||
|
||||
response := request(t, s, "DELETE", "/v1/webpush", fmt.Sprintf(`{"endpoint":"%s"}`, testWebPushEndpoint), nil)
|
||||
require.Equal(t, 200, response.Code)
|
||||
require.Equal(t, `{"success":true}`+"\n", response.Body.String())
|
||||
|
||||
requireSubscriptionCount(t, s, "test-topic", 0)
|
||||
}
|
||||
|
||||
func TestServer_WebPush_TopicSubscribeProtected_Allowed(t *testing.T) {
|
||||
config := configureAuth(t, newTestConfigWithWebPush(t))
|
||||
config.AuthDefault = user.PermissionDenyAll
|
||||
s := newTestServer(t, config)
|
||||
|
||||
require.Nil(t, s.userManager.AddUser("ben", "ben", user.RoleUser))
|
||||
require.Nil(t, s.userManager.AllowAccess("ben", "test-topic", user.PermissionReadWrite))
|
||||
|
||||
response := request(t, s, "POST", "/v1/webpush", payloadForTopics(t, []string{"test-topic"}, testWebPushEndpoint), map[string]string{
|
||||
"Authorization": util.BasicAuth("ben", "ben"),
|
||||
})
|
||||
require.Equal(t, 200, response.Code)
|
||||
require.Equal(t, `{"success":true}`+"\n", response.Body.String())
|
||||
|
||||
subs, err := s.webPush.SubscriptionsForTopic("test-topic")
|
||||
require.Nil(t, err)
|
||||
require.Len(t, subs, 1)
|
||||
require.True(t, strings.HasPrefix(subs[0].UserID, "u_"))
|
||||
}
|
||||
|
||||
func TestServer_WebPush_TopicSubscribeProtected_Denied(t *testing.T) {
|
||||
config := configureAuth(t, newTestConfigWithWebPush(t))
|
||||
config.AuthDefault = user.PermissionDenyAll
|
||||
s := newTestServer(t, config)
|
||||
|
||||
response := request(t, s, "POST", "/v1/webpush", payloadForTopics(t, []string{"test-topic"}, testWebPushEndpoint), nil)
|
||||
require.Equal(t, 403, response.Code)
|
||||
|
||||
requireSubscriptionCount(t, s, "test-topic", 0)
|
||||
}
|
||||
|
||||
func TestServer_WebPush_DeleteAccountUnsubscribe(t *testing.T) {
|
||||
config := configureAuth(t, newTestConfigWithWebPush(t))
|
||||
s := newTestServer(t, config)
|
||||
|
||||
require.Nil(t, s.userManager.AddUser("ben", "ben", user.RoleUser))
|
||||
require.Nil(t, s.userManager.AllowAccess("ben", "test-topic", user.PermissionReadWrite))
|
||||
|
||||
response := request(t, s, "POST", "/v1/webpush", payloadForTopics(t, []string{"test-topic"}, testWebPushEndpoint), map[string]string{
|
||||
"Authorization": util.BasicAuth("ben", "ben"),
|
||||
})
|
||||
|
||||
require.Equal(t, 200, response.Code)
|
||||
require.Equal(t, `{"success":true}`+"\n", response.Body.String())
|
||||
|
||||
requireSubscriptionCount(t, s, "test-topic", 1)
|
||||
|
||||
request(t, s, "DELETE", "/v1/account", `{"password":"ben"}`, map[string]string{
|
||||
"Authorization": util.BasicAuth("ben", "ben"),
|
||||
})
|
||||
// should've been deleted with the account
|
||||
requireSubscriptionCount(t, s, "test-topic", 0)
|
||||
}
|
||||
|
||||
func TestServer_WebPush_Publish(t *testing.T) {
|
||||
s := newTestServer(t, newTestConfigWithWebPush(t))
|
||||
|
||||
var received atomic.Bool
|
||||
pushService := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
_, err := io.ReadAll(r.Body)
|
||||
require.Nil(t, err)
|
||||
require.Equal(t, "/push-receive", r.URL.Path)
|
||||
require.Equal(t, "high", r.Header.Get("Urgency"))
|
||||
require.Equal(t, "", r.Header.Get("Topic"))
|
||||
received.Store(true)
|
||||
}))
|
||||
defer pushService.Close()
|
||||
|
||||
addSubscription(t, s, pushService.URL+"/push-receive", "test-topic")
|
||||
request(t, s, "POST", "/test-topic", "web push test", nil)
|
||||
|
||||
waitFor(t, func() bool {
|
||||
return received.Load()
|
||||
})
|
||||
}
|
||||
|
||||
func TestServer_WebPush_Publish_RemoveOnError(t *testing.T) {
|
||||
s := newTestServer(t, newTestConfigWithWebPush(t))
|
||||
|
||||
var received atomic.Bool
|
||||
pushService := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
_, err := io.ReadAll(r.Body)
|
||||
require.Nil(t, err)
|
||||
w.WriteHeader(http.StatusGone)
|
||||
received.Store(true)
|
||||
}))
|
||||
defer pushService.Close()
|
||||
|
||||
addSubscription(t, s, pushService.URL+"/push-receive", "test-topic", "test-topic-abc")
|
||||
requireSubscriptionCount(t, s, "test-topic", 1)
|
||||
requireSubscriptionCount(t, s, "test-topic-abc", 1)
|
||||
|
||||
request(t, s, "POST", "/test-topic", "web push test", nil)
|
||||
|
||||
waitFor(t, func() bool {
|
||||
return received.Load()
|
||||
})
|
||||
|
||||
// Receiving the 410 should've caused the publisher to expire all subscriptions on the endpoint
|
||||
|
||||
requireSubscriptionCount(t, s, "test-topic", 0)
|
||||
requireSubscriptionCount(t, s, "test-topic-abc", 0)
|
||||
}
|
||||
|
||||
func TestServer_WebPush_Expiry(t *testing.T) {
|
||||
s := newTestServer(t, newTestConfigWithWebPush(t))
|
||||
|
||||
var received atomic.Bool
|
||||
|
||||
pushService := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
_, err := io.ReadAll(r.Body)
|
||||
require.Nil(t, err)
|
||||
w.WriteHeader(200)
|
||||
w.Write([]byte(``))
|
||||
received.Store(true)
|
||||
}))
|
||||
defer pushService.Close()
|
||||
|
||||
addSubscription(t, s, pushService.URL+"/push-receive", "test-topic")
|
||||
requireSubscriptionCount(t, s, "test-topic", 1)
|
||||
|
||||
_, err := s.webPush.db.Exec("UPDATE subscription SET updated_at = ?", time.Now().Add(-7*24*time.Hour).Unix())
|
||||
require.Nil(t, err)
|
||||
|
||||
s.pruneAndNotifyWebPushSubscriptions()
|
||||
requireSubscriptionCount(t, s, "test-topic", 1)
|
||||
|
||||
waitFor(t, func() bool {
|
||||
return received.Load()
|
||||
})
|
||||
|
||||
_, err = s.webPush.db.Exec("UPDATE subscription SET updated_at = ?", time.Now().Add(-9*24*time.Hour).Unix())
|
||||
require.Nil(t, err)
|
||||
|
||||
s.pruneAndNotifyWebPushSubscriptions()
|
||||
waitFor(t, func() bool {
|
||||
subs, err := s.webPush.SubscriptionsForTopic("test-topic")
|
||||
require.Nil(t, err)
|
||||
return len(subs) == 0
|
||||
})
|
||||
}
|
||||
|
||||
func payloadForTopics(t *testing.T, topics []string, endpoint string) string {
|
||||
topicsJSON, err := json.Marshal(topics)
|
||||
require.Nil(t, err)
|
||||
|
||||
return fmt.Sprintf(`{
|
||||
"topics": %s,
|
||||
"endpoint": "%s",
|
||||
"p256dh": "p256dh-key",
|
||||
"auth": "auth-key"
|
||||
}`, topicsJSON, endpoint)
|
||||
}
|
||||
|
||||
func addSubscription(t *testing.T, s *Server, endpoint string, topics ...string) {
|
||||
require.Nil(t, s.webPush.UpsertSubscription(endpoint, "kSC3T8aN1JCQxxPdrFLrZg", "BMKKbxdUU_xLS7G1Wh5AN8PvWOjCzkCuKZYb8apcqYrDxjOF_2piggBnoJLQYx9IeSD70fNuwawI3e9Y8m3S3PE", "u_123", netip.MustParseAddr("1.2.3.4"), topics)) // Test auth and p256dh
|
||||
}
|
||||
|
||||
func requireSubscriptionCount(t *testing.T, s *Server, topic string, expectedLength int) {
|
||||
subs, err := s.webPush.SubscriptionsForTopic(topic)
|
||||
require.Nil(t, err)
|
||||
require.Len(t, subs, expectedLength)
|
||||
}
|
|
@ -1,12 +1,13 @@
|
|||
package server
|
||||
|
||||
import (
|
||||
"heckel.io/ntfy/log"
|
||||
"heckel.io/ntfy/user"
|
||||
"net/http"
|
||||
"net/netip"
|
||||
"time"
|
||||
|
||||
"heckel.io/ntfy/log"
|
||||
"heckel.io/ntfy/user"
|
||||
|
||||
"heckel.io/ntfy/util"
|
||||
)
|
||||
|
||||
|
@ -41,7 +42,7 @@ type message struct {
|
|||
ContentType string `json:"content_type,omitempty"` // text/plain by default (if empty), or text/markdown
|
||||
Encoding string `json:"encoding,omitempty"` // empty for raw UTF-8, or "base64" for encoded bytes
|
||||
Sender netip.Addr `json:"-"` // IP address of uploader, used for rate limiting
|
||||
User string `json:"-"` // Username of the uploader, used to associated attachments
|
||||
User string `json:"-"` // UserID of the uploader, used to associated attachments
|
||||
}
|
||||
|
||||
func (m *message) Context() log.Context {
|
||||
|
@ -398,7 +399,9 @@ type apiConfigResponse struct {
|
|||
EnableCalls bool `json:"enable_calls"`
|
||||
EnableEmails bool `json:"enable_emails"`
|
||||
EnableReservations bool `json:"enable_reservations"`
|
||||
EnableWebPush bool `json:"enable_web_push"`
|
||||
BillingContact string `json:"billing_contact"`
|
||||
WebPushPublicKey string `json:"web_push_public_key"`
|
||||
DisallowedTopics []string `json:"disallowed_topics"`
|
||||
}
|
||||
|
||||
|
@ -463,3 +466,75 @@ type apiStripeSubscriptionDeletedEvent struct {
|
|||
ID string `json:"id"`
|
||||
Customer string `json:"customer"`
|
||||
}
|
||||
|
||||
type apiWebPushUpdateSubscriptionRequest struct {
|
||||
Endpoint string `json:"endpoint"`
|
||||
Auth string `json:"auth"`
|
||||
P256dh string `json:"p256dh"`
|
||||
Topics []string `json:"topics"`
|
||||
}
|
||||
|
||||
// List of possible Web Push events (see sw.js)
|
||||
const (
|
||||
webPushMessageEvent = "message"
|
||||
webPushExpiringEvent = "subscription_expiring"
|
||||
)
|
||||
|
||||
type webPushPayload struct {
|
||||
Event string `json:"event"`
|
||||
SubscriptionID string `json:"subscription_id"`
|
||||
Message *message `json:"message"`
|
||||
}
|
||||
|
||||
func newWebPushPayload(subscriptionID string, message *message) *webPushPayload {
|
||||
return &webPushPayload{
|
||||
Event: webPushMessageEvent,
|
||||
SubscriptionID: subscriptionID,
|
||||
Message: message,
|
||||
}
|
||||
}
|
||||
|
||||
type webPushControlMessagePayload struct {
|
||||
Event string `json:"event"`
|
||||
}
|
||||
|
||||
func newWebPushSubscriptionExpiringPayload() *webPushControlMessagePayload {
|
||||
return &webPushControlMessagePayload{
|
||||
Event: webPushExpiringEvent,
|
||||
}
|
||||
}
|
||||
|
||||
type webPushSubscription struct {
|
||||
ID string
|
||||
Endpoint string
|
||||
Auth string
|
||||
P256dh string
|
||||
UserID string
|
||||
}
|
||||
|
||||
func (w *webPushSubscription) Context() log.Context {
|
||||
return map[string]any{
|
||||
"web_push_subscription_id": w.ID,
|
||||
"web_push_subscription_user_id": w.UserID,
|
||||
"web_push_subscription_endpoint": w.Endpoint,
|
||||
}
|
||||
}
|
||||
|
||||
// https://developer.mozilla.org/en-US/docs/Web/Manifest
|
||||
type webManifestResponse struct {
|
||||
Name string `json:"name"`
|
||||
Description string `json:"description"`
|
||||
ShortName string `json:"short_name"`
|
||||
Scope string `json:"scope"`
|
||||
StartURL string `json:"start_url"`
|
||||
Display string `json:"display"`
|
||||
BackgroundColor string `json:"background_color"`
|
||||
ThemeColor string `json:"theme_color"`
|
||||
Icons []*webManifestIcon `json:"icons"`
|
||||
}
|
||||
|
||||
type webManifestIcon struct {
|
||||
SRC string `json:"src"`
|
||||
Sizes string `json:"sizes"`
|
||||
Type string `json:"type"`
|
||||
}
|
||||
|
|
280
server/webpush_store.go
Normal file
280
server/webpush_store.go
Normal file
|
@ -0,0 +1,280 @@
|
|||
package server
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"errors"
|
||||
"heckel.io/ntfy/util"
|
||||
"net/netip"
|
||||
"time"
|
||||
|
||||
_ "github.com/mattn/go-sqlite3" // SQLite driver
|
||||
)
|
||||
|
||||
const (
|
||||
subscriptionIDPrefix = "wps_"
|
||||
subscriptionIDLength = 10
|
||||
subscriptionEndpointLimitPerSubscriberIP = 10
|
||||
)
|
||||
|
||||
var (
|
||||
errWebPushNoRows = errors.New("no rows found")
|
||||
errWebPushTooManySubscriptions = errors.New("too many subscriptions")
|
||||
errWebPushUserIDCannotBeEmpty = errors.New("user ID cannot be empty")
|
||||
)
|
||||
|
||||
const (
|
||||
createWebPushSubscriptionsTableQuery = `
|
||||
BEGIN;
|
||||
CREATE TABLE IF NOT EXISTS subscription (
|
||||
id TEXT PRIMARY KEY,
|
||||
endpoint TEXT NOT NULL,
|
||||
key_auth TEXT NOT NULL,
|
||||
key_p256dh TEXT NOT NULL,
|
||||
user_id TEXT NOT NULL,
|
||||
subscriber_ip TEXT NOT NULL,
|
||||
updated_at INT NOT NULL,
|
||||
warned_at INT NOT NULL DEFAULT 0
|
||||
);
|
||||
CREATE UNIQUE INDEX IF NOT EXISTS idx_endpoint ON subscription (endpoint);
|
||||
CREATE INDEX IF NOT EXISTS idx_subscriber_ip ON subscription (subscriber_ip);
|
||||
CREATE TABLE IF NOT EXISTS subscription_topic (
|
||||
subscription_id TEXT NOT NULL,
|
||||
topic TEXT NOT NULL,
|
||||
PRIMARY KEY (subscription_id, topic),
|
||||
FOREIGN KEY (subscription_id) REFERENCES subscription (id) ON DELETE CASCADE
|
||||
);
|
||||
CREATE INDEX IF NOT EXISTS idx_topic ON subscription_topic (topic);
|
||||
CREATE TABLE IF NOT EXISTS schemaVersion (
|
||||
id INT PRIMARY KEY,
|
||||
version INT NOT NULL
|
||||
);
|
||||
COMMIT;
|
||||
`
|
||||
builtinStartupQueries = `
|
||||
PRAGMA foreign_keys = ON;
|
||||
`
|
||||
|
||||
selectWebPushSubscriptionIDByEndpoint = `SELECT id FROM subscription WHERE endpoint = ?`
|
||||
selectWebPushSubscriptionCountBySubscriberIP = `SELECT COUNT(*) FROM subscription WHERE subscriber_ip = ?`
|
||||
selectWebPushSubscriptionsForTopicQuery = `
|
||||
SELECT id, endpoint, key_auth, key_p256dh, user_id
|
||||
FROM subscription_topic st
|
||||
JOIN subscription s ON s.id = st.subscription_id
|
||||
WHERE st.topic = ?
|
||||
ORDER BY endpoint
|
||||
`
|
||||
selectWebPushSubscriptionsExpiringSoonQuery = `
|
||||
SELECT id, endpoint, key_auth, key_p256dh, user_id
|
||||
FROM subscription
|
||||
WHERE warned_at = 0 AND updated_at <= ?
|
||||
`
|
||||
insertWebPushSubscriptionQuery = `
|
||||
INSERT INTO subscription (id, endpoint, key_auth, key_p256dh, user_id, subscriber_ip, updated_at, warned_at)
|
||||
VALUES (?, ?, ?, ?, ?, ?, ?, ?)
|
||||
ON CONFLICT (endpoint)
|
||||
DO UPDATE SET key_auth = excluded.key_auth, key_p256dh = excluded.key_p256dh, user_id = excluded.user_id, subscriber_ip = excluded.subscriber_ip, updated_at = excluded.updated_at, warned_at = excluded.warned_at
|
||||
`
|
||||
updateWebPushSubscriptionWarningSentQuery = `UPDATE subscription SET warned_at = ? WHERE id = ?`
|
||||
deleteWebPushSubscriptionByEndpointQuery = `DELETE FROM subscription WHERE endpoint = ?`
|
||||
deleteWebPushSubscriptionByUserIDQuery = `DELETE FROM subscription WHERE user_id = ?`
|
||||
deleteWebPushSubscriptionByAgeQuery = `DELETE FROM subscription WHERE updated_at <= ?` // Full table scan!
|
||||
|
||||
insertWebPushSubscriptionTopicQuery = `INSERT INTO subscription_topic (subscription_id, topic) VALUES (?, ?)`
|
||||
deleteWebPushSubscriptionTopicAllQuery = `DELETE FROM subscription_topic WHERE subscription_id = ?`
|
||||
)
|
||||
|
||||
// Schema management queries
|
||||
const (
|
||||
currentWebPushSchemaVersion = 1
|
||||
insertWebPushSchemaVersion = `INSERT INTO schemaVersion VALUES (1, ?)`
|
||||
selectWebPushSchemaVersionQuery = `SELECT version FROM schemaVersion WHERE id = 1`
|
||||
)
|
||||
|
||||
type webPushStore struct {
|
||||
db *sql.DB
|
||||
}
|
||||
|
||||
func newWebPushStore(filename, startupQueries string) (*webPushStore, error) {
|
||||
db, err := sql.Open("sqlite3", filename)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if err := setupWebPushDB(db); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if err := runWebPushStartupQueries(db, startupQueries); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &webPushStore{
|
||||
db: db,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func setupWebPushDB(db *sql.DB) error {
|
||||
// If 'schemaVersion' table does not exist, this must be a new database
|
||||
rows, err := db.Query(selectWebPushSchemaVersionQuery)
|
||||
if err != nil {
|
||||
return setupNewWebPushDB(db)
|
||||
}
|
||||
return rows.Close()
|
||||
}
|
||||
|
||||
func setupNewWebPushDB(db *sql.DB) error {
|
||||
if _, err := db.Exec(createWebPushSubscriptionsTableQuery); err != nil {
|
||||
return err
|
||||
}
|
||||
if _, err := db.Exec(insertWebPushSchemaVersion, currentWebPushSchemaVersion); err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func runWebPushStartupQueries(db *sql.DB, startupQueries string) error {
|
||||
if _, err := db.Exec(startupQueries); err != nil {
|
||||
return err
|
||||
}
|
||||
if _, err := db.Exec(builtinStartupQueries); err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// UpsertSubscription adds or updates Web Push subscriptions for the given topics and user ID. It always first deletes all
|
||||
// existing entries for a given endpoint.
|
||||
func (c *webPushStore) UpsertSubscription(endpoint string, auth, p256dh, userID string, subscriberIP netip.Addr, topics []string) error {
|
||||
tx, err := c.db.Begin()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer tx.Rollback()
|
||||
// Read number of subscriptions for subscriber IP address
|
||||
rowsCount, err := tx.Query(selectWebPushSubscriptionCountBySubscriberIP, subscriberIP.String())
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer rowsCount.Close()
|
||||
var subscriptionCount int
|
||||
if !rowsCount.Next() {
|
||||
return errWebPushNoRows
|
||||
}
|
||||
if err := rowsCount.Scan(&subscriptionCount); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := rowsCount.Close(); err != nil {
|
||||
return err
|
||||
}
|
||||
// Read existing subscription ID for endpoint (or create new ID)
|
||||
rows, err := tx.Query(selectWebPushSubscriptionIDByEndpoint, endpoint)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer rows.Close()
|
||||
var subscriptionID string
|
||||
if rows.Next() {
|
||||
if err := rows.Scan(&subscriptionID); err != nil {
|
||||
return err
|
||||
}
|
||||
} else {
|
||||
if subscriptionCount >= subscriptionEndpointLimitPerSubscriberIP {
|
||||
return errWebPushTooManySubscriptions
|
||||
}
|
||||
subscriptionID = util.RandomStringPrefix(subscriptionIDPrefix, subscriptionIDLength)
|
||||
}
|
||||
if err := rows.Close(); err != nil {
|
||||
return err
|
||||
}
|
||||
// Insert or update subscription
|
||||
updatedAt, warnedAt := time.Now().Unix(), 0
|
||||
if _, err = tx.Exec(insertWebPushSubscriptionQuery, subscriptionID, endpoint, auth, p256dh, userID, subscriberIP.String(), updatedAt, warnedAt); err != nil {
|
||||
return err
|
||||
}
|
||||
// Replace all subscription topics
|
||||
if _, err := tx.Exec(deleteWebPushSubscriptionTopicAllQuery, subscriptionID); err != nil {
|
||||
return err
|
||||
}
|
||||
for _, topic := range topics {
|
||||
if _, err = tx.Exec(insertWebPushSubscriptionTopicQuery, subscriptionID, topic); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return tx.Commit()
|
||||
}
|
||||
|
||||
// SubscriptionsForTopic returns all subscriptions for the given topic
|
||||
func (c *webPushStore) SubscriptionsForTopic(topic string) ([]*webPushSubscription, error) {
|
||||
rows, err := c.db.Query(selectWebPushSubscriptionsForTopicQuery, topic)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer rows.Close()
|
||||
return c.subscriptionsFromRows(rows)
|
||||
}
|
||||
|
||||
// SubscriptionsExpiring returns all subscriptions that have not been updated for a given time period
|
||||
func (c *webPushStore) SubscriptionsExpiring(warnAfter time.Duration) ([]*webPushSubscription, error) {
|
||||
rows, err := c.db.Query(selectWebPushSubscriptionsExpiringSoonQuery, time.Now().Add(-warnAfter).Unix())
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer rows.Close()
|
||||
return c.subscriptionsFromRows(rows)
|
||||
}
|
||||
|
||||
// MarkExpiryWarningSent marks the given subscriptions as having received a warning about expiring soon
|
||||
func (c *webPushStore) MarkExpiryWarningSent(subscriptions []*webPushSubscription) error {
|
||||
tx, err := c.db.Begin()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer tx.Rollback()
|
||||
for _, subscription := range subscriptions {
|
||||
if _, err := tx.Exec(updateWebPushSubscriptionWarningSentQuery, time.Now().Unix(), subscription.ID); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return tx.Commit()
|
||||
}
|
||||
|
||||
func (c *webPushStore) subscriptionsFromRows(rows *sql.Rows) ([]*webPushSubscription, error) {
|
||||
subscriptions := make([]*webPushSubscription, 0)
|
||||
for rows.Next() {
|
||||
var id, endpoint, auth, p256dh, userID string
|
||||
if err := rows.Scan(&id, &endpoint, &auth, &p256dh, &userID); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
subscriptions = append(subscriptions, &webPushSubscription{
|
||||
ID: id,
|
||||
Endpoint: endpoint,
|
||||
Auth: auth,
|
||||
P256dh: p256dh,
|
||||
UserID: userID,
|
||||
})
|
||||
}
|
||||
return subscriptions, nil
|
||||
}
|
||||
|
||||
// RemoveSubscriptionsByEndpoint removes the subscription for the given endpoint
|
||||
func (c *webPushStore) RemoveSubscriptionsByEndpoint(endpoint string) error {
|
||||
_, err := c.db.Exec(deleteWebPushSubscriptionByEndpointQuery, endpoint)
|
||||
return err
|
||||
}
|
||||
|
||||
// RemoveSubscriptionsByUserID removes all subscriptions for the given user ID
|
||||
func (c *webPushStore) RemoveSubscriptionsByUserID(userID string) error {
|
||||
if userID == "" {
|
||||
return errWebPushUserIDCannotBeEmpty
|
||||
}
|
||||
_, err := c.db.Exec(deleteWebPushSubscriptionByUserIDQuery, userID)
|
||||
return err
|
||||
}
|
||||
|
||||
// RemoveExpiredSubscriptions removes all subscriptions that have not been updated for a given time period
|
||||
func (c *webPushStore) RemoveExpiredSubscriptions(expireAfter time.Duration) error {
|
||||
_, err := c.db.Exec(deleteWebPushSubscriptionByAgeQuery, time.Now().Add(-expireAfter).Unix())
|
||||
return err
|
||||
}
|
||||
|
||||
// Close closes the underlying database connection
|
||||
func (c *webPushStore) Close() error {
|
||||
return c.db.Close()
|
||||
}
|
199
server/webpush_store_test.go
Normal file
199
server/webpush_store_test.go
Normal file
|
@ -0,0 +1,199 @@
|
|||
package server
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"github.com/stretchr/testify/require"
|
||||
"net/netip"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
func TestWebPushStore_UpsertSubscription_SubscriptionsForTopic(t *testing.T) {
|
||||
webPush := newTestWebPushStore(t)
|
||||
defer webPush.Close()
|
||||
|
||||
require.Nil(t, webPush.UpsertSubscription(testWebPushEndpoint, "auth-key", "p256dh-key", "u_1234", netip.MustParseAddr("1.2.3.4"), []string{"test-topic", "mytopic"}))
|
||||
|
||||
subs, err := webPush.SubscriptionsForTopic("test-topic")
|
||||
require.Nil(t, err)
|
||||
require.Len(t, subs, 1)
|
||||
require.Equal(t, subs[0].Endpoint, testWebPushEndpoint)
|
||||
require.Equal(t, subs[0].P256dh, "p256dh-key")
|
||||
require.Equal(t, subs[0].Auth, "auth-key")
|
||||
require.Equal(t, subs[0].UserID, "u_1234")
|
||||
|
||||
subs2, err := webPush.SubscriptionsForTopic("mytopic")
|
||||
require.Nil(t, err)
|
||||
require.Len(t, subs2, 1)
|
||||
require.Equal(t, subs[0].Endpoint, subs2[0].Endpoint)
|
||||
}
|
||||
|
||||
func TestWebPushStore_UpsertSubscription_SubscriberIPLimitReached(t *testing.T) {
|
||||
webPush := newTestWebPushStore(t)
|
||||
defer webPush.Close()
|
||||
|
||||
// Insert 10 subscriptions with the same IP address
|
||||
for i := 0; i < 10; i++ {
|
||||
endpoint := fmt.Sprintf(testWebPushEndpoint+"%d", i)
|
||||
require.Nil(t, webPush.UpsertSubscription(endpoint, "auth-key", "p256dh-key", "u_1234", netip.MustParseAddr("1.2.3.4"), []string{"test-topic", "mytopic"}))
|
||||
}
|
||||
|
||||
// Another one for the same endpoint should be fine
|
||||
require.Nil(t, webPush.UpsertSubscription(testWebPushEndpoint+"0", "auth-key", "p256dh-key", "u_1234", netip.MustParseAddr("1.2.3.4"), []string{"test-topic", "mytopic"}))
|
||||
|
||||
// But with a different endpoint it should fail
|
||||
require.Equal(t, errWebPushTooManySubscriptions, webPush.UpsertSubscription(testWebPushEndpoint+"11", "auth-key", "p256dh-key", "u_1234", netip.MustParseAddr("1.2.3.4"), []string{"test-topic", "mytopic"}))
|
||||
|
||||
// But with a different IP address it should be fine again
|
||||
require.Nil(t, webPush.UpsertSubscription(testWebPushEndpoint+"99", "auth-key", "p256dh-key", "u_1234", netip.MustParseAddr("9.9.9.9"), []string{"test-topic", "mytopic"}))
|
||||
}
|
||||
|
||||
func TestWebPushStore_UpsertSubscription_UpdateTopics(t *testing.T) {
|
||||
webPush := newTestWebPushStore(t)
|
||||
defer webPush.Close()
|
||||
|
||||
// Insert subscription with two topics, and another with one topic
|
||||
require.Nil(t, webPush.UpsertSubscription(testWebPushEndpoint+"0", "auth-key", "p256dh-key", "u_1234", netip.MustParseAddr("1.2.3.4"), []string{"topic1", "topic2"}))
|
||||
require.Nil(t, webPush.UpsertSubscription(testWebPushEndpoint+"1", "auth-key", "p256dh-key", "", netip.MustParseAddr("9.9.9.9"), []string{"topic1"}))
|
||||
|
||||
subs, err := webPush.SubscriptionsForTopic("topic1")
|
||||
require.Nil(t, err)
|
||||
require.Len(t, subs, 2)
|
||||
require.Equal(t, testWebPushEndpoint+"0", subs[0].Endpoint)
|
||||
require.Equal(t, testWebPushEndpoint+"1", subs[1].Endpoint)
|
||||
|
||||
subs, err = webPush.SubscriptionsForTopic("topic2")
|
||||
require.Nil(t, err)
|
||||
require.Len(t, subs, 1)
|
||||
require.Equal(t, testWebPushEndpoint+"0", subs[0].Endpoint)
|
||||
|
||||
// Update the first subscription to have only one topic
|
||||
require.Nil(t, webPush.UpsertSubscription(testWebPushEndpoint+"0", "auth-key", "p256dh-key", "u_1234", netip.MustParseAddr("1.2.3.4"), []string{"topic1"}))
|
||||
|
||||
subs, err = webPush.SubscriptionsForTopic("topic1")
|
||||
require.Nil(t, err)
|
||||
require.Len(t, subs, 2)
|
||||
require.Equal(t, testWebPushEndpoint+"0", subs[0].Endpoint)
|
||||
|
||||
subs, err = webPush.SubscriptionsForTopic("topic2")
|
||||
require.Nil(t, err)
|
||||
require.Len(t, subs, 0)
|
||||
}
|
||||
|
||||
func TestWebPushStore_RemoveSubscriptionsByEndpoint(t *testing.T) {
|
||||
webPush := newTestWebPushStore(t)
|
||||
defer webPush.Close()
|
||||
|
||||
// Insert subscription with two topics
|
||||
require.Nil(t, webPush.UpsertSubscription(testWebPushEndpoint, "auth-key", "p256dh-key", "u_1234", netip.MustParseAddr("1.2.3.4"), []string{"topic1", "topic2"}))
|
||||
subs, err := webPush.SubscriptionsForTopic("topic1")
|
||||
require.Nil(t, err)
|
||||
require.Len(t, subs, 1)
|
||||
|
||||
// And remove it again
|
||||
require.Nil(t, webPush.RemoveSubscriptionsByEndpoint(testWebPushEndpoint))
|
||||
subs, err = webPush.SubscriptionsForTopic("topic1")
|
||||
require.Nil(t, err)
|
||||
require.Len(t, subs, 0)
|
||||
}
|
||||
|
||||
func TestWebPushStore_RemoveSubscriptionsByUserID(t *testing.T) {
|
||||
webPush := newTestWebPushStore(t)
|
||||
defer webPush.Close()
|
||||
|
||||
// Insert subscription with two topics
|
||||
require.Nil(t, webPush.UpsertSubscription(testWebPushEndpoint, "auth-key", "p256dh-key", "u_1234", netip.MustParseAddr("1.2.3.4"), []string{"topic1", "topic2"}))
|
||||
subs, err := webPush.SubscriptionsForTopic("topic1")
|
||||
require.Nil(t, err)
|
||||
require.Len(t, subs, 1)
|
||||
|
||||
// And remove it again
|
||||
require.Nil(t, webPush.RemoveSubscriptionsByUserID("u_1234"))
|
||||
subs, err = webPush.SubscriptionsForTopic("topic1")
|
||||
require.Nil(t, err)
|
||||
require.Len(t, subs, 0)
|
||||
}
|
||||
|
||||
func TestWebPushStore_RemoveSubscriptionsByUserID_Empty(t *testing.T) {
|
||||
webPush := newTestWebPushStore(t)
|
||||
defer webPush.Close()
|
||||
require.Equal(t, errWebPushUserIDCannotBeEmpty, webPush.RemoveSubscriptionsByUserID(""))
|
||||
}
|
||||
|
||||
func TestWebPushStore_MarkExpiryWarningSent(t *testing.T) {
|
||||
webPush := newTestWebPushStore(t)
|
||||
defer webPush.Close()
|
||||
|
||||
// Insert subscription with two topics
|
||||
require.Nil(t, webPush.UpsertSubscription(testWebPushEndpoint, "auth-key", "p256dh-key", "u_1234", netip.MustParseAddr("1.2.3.4"), []string{"topic1", "topic2"}))
|
||||
subs, err := webPush.SubscriptionsForTopic("topic1")
|
||||
require.Nil(t, err)
|
||||
require.Len(t, subs, 1)
|
||||
|
||||
// Mark them as warning sent
|
||||
require.Nil(t, webPush.MarkExpiryWarningSent(subs))
|
||||
|
||||
rows, err := webPush.db.Query("SELECT endpoint FROM subscription WHERE warned_at > 0")
|
||||
require.Nil(t, err)
|
||||
defer rows.Close()
|
||||
var endpoint string
|
||||
require.True(t, rows.Next())
|
||||
require.Nil(t, rows.Scan(&endpoint))
|
||||
require.Nil(t, err)
|
||||
require.Equal(t, testWebPushEndpoint, endpoint)
|
||||
require.False(t, rows.Next())
|
||||
}
|
||||
|
||||
func TestWebPushStore_SubscriptionsExpiring(t *testing.T) {
|
||||
webPush := newTestWebPushStore(t)
|
||||
defer webPush.Close()
|
||||
|
||||
// Insert subscription with two topics
|
||||
require.Nil(t, webPush.UpsertSubscription(testWebPushEndpoint, "auth-key", "p256dh-key", "u_1234", netip.MustParseAddr("1.2.3.4"), []string{"topic1", "topic2"}))
|
||||
subs, err := webPush.SubscriptionsForTopic("topic1")
|
||||
require.Nil(t, err)
|
||||
require.Len(t, subs, 1)
|
||||
|
||||
// Fake-mark them as soon-to-expire
|
||||
_, err = webPush.db.Exec("UPDATE subscription SET updated_at = ? WHERE endpoint = ?", time.Now().Add(-8*24*time.Hour).Unix(), testWebPushEndpoint)
|
||||
require.Nil(t, err)
|
||||
|
||||
// Should not be cleaned up yet
|
||||
require.Nil(t, webPush.RemoveExpiredSubscriptions(9*24*time.Hour))
|
||||
|
||||
// Run expiration
|
||||
subs, err = webPush.SubscriptionsExpiring(7 * 24 * time.Hour)
|
||||
require.Nil(t, err)
|
||||
require.Len(t, subs, 1)
|
||||
require.Equal(t, testWebPushEndpoint, subs[0].Endpoint)
|
||||
}
|
||||
|
||||
func TestWebPushStore_RemoveExpiredSubscriptions(t *testing.T) {
|
||||
webPush := newTestWebPushStore(t)
|
||||
defer webPush.Close()
|
||||
|
||||
// Insert subscription with two topics
|
||||
require.Nil(t, webPush.UpsertSubscription(testWebPushEndpoint, "auth-key", "p256dh-key", "u_1234", netip.MustParseAddr("1.2.3.4"), []string{"topic1", "topic2"}))
|
||||
subs, err := webPush.SubscriptionsForTopic("topic1")
|
||||
require.Nil(t, err)
|
||||
require.Len(t, subs, 1)
|
||||
|
||||
// Fake-mark them as expired
|
||||
_, err = webPush.db.Exec("UPDATE subscription SET updated_at = ? WHERE endpoint = ?", time.Now().Add(-10*24*time.Hour).Unix(), testWebPushEndpoint)
|
||||
require.Nil(t, err)
|
||||
|
||||
// Run expiration
|
||||
require.Nil(t, webPush.RemoveExpiredSubscriptions(9*24*time.Hour))
|
||||
|
||||
// List again, should be 0
|
||||
subs, err = webPush.SubscriptionsForTopic("topic1")
|
||||
require.Nil(t, err)
|
||||
require.Len(t, subs, 0)
|
||||
}
|
||||
|
||||
func newTestWebPushStore(t *testing.T) *webPushStore {
|
||||
webPush, err := newWebPushStore(filepath.Join(t.TempDir(), "webpush.db"), "")
|
||||
require.Nil(t, err)
|
||||
return webPush
|
||||
}
|
Loading…
Add table
Add a link
Reference in a new issue