From bdeec4d2977c650a04ac694b45d93fc7e7a243c1 Mon Sep 17 00:00:00 2001 From: binwiederhier Date: Wed, 22 Feb 2023 22:26:43 -0500 Subject: [PATCH] Polish a little --- server/server.go | 67 +++++++++++++++++++++---------------- server/server_middleware.go | 14 ++++++-- server/server_test.go | 2 +- server/topic.go | 4 +++ server/util.go | 16 +++++---- 5 files changed, 63 insertions(+), 40 deletions(-) diff --git a/server/server.go b/server/server.go index ec4ca670..d1b0122d 100644 --- a/server/server.go +++ b/server/server.go @@ -104,15 +104,15 @@ var ( ) const ( - firebaseControlTopic = "~control" // See Android if changed - firebasePollTopic = "~poll" // See iOS if changed - emptyMessageBody = "triggered" // Used if message body is empty - newMessageBody = "New message" // Used in poll requests as generic message - defaultAttachmentMessage = "You received a file: %s" // Used if message body is empty, and there is an attachment - encodingBase64 = "base64" // Used mainly for binary UnifiedPush messages - jsonBodyBytesLimit = 16384 - unifiedPushTopicPrefix = "up" // Temporarily, we rate limit all "up*" topics based on the subscriber - rateVisitorExpiryDuration = 12 * time.Hour + firebaseControlTopic = "~control" // See Android if changed + firebasePollTopic = "~poll" // See iOS if changed + emptyMessageBody = "triggered" // Used if message body is empty + newMessageBody = "New message" // Used in poll requests as generic message + defaultAttachmentMessage = "You received a file: %s" // Used if message body is empty, and there is an attachment + encodingBase64 = "base64" // Used mainly for binary UnifiedPush messages + jsonBodyBytesLimit = 16384 + unifiedPushTopicPrefix = "up" // Temporarily, we rate limit all "up*" topics based on the subscriber + rateTopicsWildcard = "*" // Allows defining all topics in the request subscriber-rate-limited topics ) // WebSocket constants @@ -571,11 +571,11 @@ func (s *Server) handleMatrixDiscovery(w http.ResponseWriter) error { } func (s *Server) handlePublishWithoutResponse(r *http.Request, v *visitor) (*message, error) { - vrate, ok := r.Context().Value("vRate").(*visitor) + vrate, ok := r.Context().Value(contextRateVisitor).(*visitor) if !ok { return nil, errHTTPInternalError } - t, ok := r.Context().Value("topic").(*topic) + t, ok := r.Context().Value(contextTopic).(*topic) if !ok { return nil, errHTTPInternalError } @@ -709,7 +709,7 @@ func (s *Server) forwardPollRequest(v *visitor, m *message) { } } -func (s *Server) parsePublishParams(r *http.Request, vRate *visitor, m *message) (cache bool, firebase bool, email string, unifiedpush bool, err error) { +func (s *Server) parsePublishParams(r *http.Request, vrate *visitor, m *message) (cache bool, firebase bool, email string, unifiedpush bool, err error) { cache = readBoolParam(r, true, "x-cache", "cache") firebase = readBoolParam(r, true, "x-firebase", "firebase") m.Title = readParam(r, "x-title", "title", "t") @@ -749,7 +749,7 @@ func (s *Server) parsePublishParams(r *http.Request, vRate *visitor, m *message) } email = readParam(r, "x-email", "x-e-mail", "email", "e-mail", "mail", "e") if email != "" { - if !vRate.EmailAllowed() { + if !vrate.EmailAllowed() { return false, false, "", false, errHTTPTooManyRequestsLimitEmails } } @@ -954,7 +954,7 @@ func (s *Server) handleSubscribeHTTP(w http.ResponseWriter, r *http.Request, v * if err != nil { return err } - poll, since, scheduled, filters, subscriberRateTopics, err := parseSubscribeParams(r) + poll, since, scheduled, filters, rateTopics, err := parseSubscribeParams(r) if err != nil { return err } @@ -984,12 +984,7 @@ func (s *Server) handleSubscribeHTTP(w http.ResponseWriter, r *http.Request, v * } return nil } - for _, t := range topics { - subscriberRateLimited := util.Contains(subscriberRateTopics, t.ID) || strings.HasPrefix(t.ID, unifiedPushTopicPrefix) // temporarily do prefix as well - if subscriberRateLimited { - t.SetRateVisitor(v) - } - } + registerRateVisitors(topics, rateTopics, v) w.Header().Set("Access-Control-Allow-Origin", s.config.AccessControlAllowOrigin) // CORS, allow cross-origin requests w.Header().Set("Content-Type", contentType+"; charset=utf-8") // Android/Volley client needs charset! if poll { @@ -1042,7 +1037,7 @@ func (s *Server) handleSubscribeWS(w http.ResponseWriter, r *http.Request, v *vi if err != nil { return err } - poll, since, scheduled, filters, subscriberRateTopics, err := parseSubscribeParams(r) + poll, since, scheduled, filters, rateTopics, err := parseSubscribeParams(r) if err != nil { return err } @@ -1125,12 +1120,7 @@ func (s *Server) handleSubscribeWS(w http.ResponseWriter, r *http.Request, v *vi } return conn.WriteJSON(msg) } - for _, t := range topics { - subscriberRateLimited := util.Contains(subscriberRateTopics, t.ID) || strings.HasPrefix(t.ID, unifiedPushTopicPrefix) // temporarily do prefix as well - if subscriberRateLimited { - t.SetRateVisitor(v) - } - } + registerRateVisitors(topics, rateTopics, v) w.Header().Set("Access-Control-Allow-Origin", s.config.AccessControlAllowOrigin) // CORS, allow cross-origin requests if poll { return s.sendOldMessages(topics, since, scheduled, v, sub) @@ -1158,7 +1148,7 @@ func (s *Server) handleSubscribeWS(w http.ResponseWriter, r *http.Request, v *vi return err } -func parseSubscribeParams(r *http.Request) (poll bool, since sinceMarker, scheduled bool, filters *queryFilter, subscriberTopics []string, err error) { +func parseSubscribeParams(r *http.Request) (poll bool, since sinceMarker, scheduled bool, filters *queryFilter, rateTopics []string, err error) { poll = readBoolParam(r, false, "x-poll", "poll", "po") scheduled = readBoolParam(r, false, "x-scheduled", "scheduled", "sched") since, err = parseSince(r, poll) @@ -1169,10 +1159,29 @@ func parseSubscribeParams(r *http.Request) (poll bool, since sinceMarker, schedu if err != nil { return } - subscriberTopics = readCommaSeparatedParam(r, "subscriber-rate-limit-topics", "x-subscriber-rate-limit-topics", "srlt") + rateTopics = readCommaSeparatedParam(r, "x-rate-topics", "rate-topics") return } +// registerRateVisitors sets the rate visitor on a topic, indicating that all messages published to that topic +// will be rate limited against the rate visitor instead of the publishing visitor. +// +// Note: This TEMPORARILY also registers all topics starting with "up" (= UnifiedPush). This is to ease the transition +// until the Android app will send the "Rate-Topics" header. +func registerRateVisitors(topics []*topic, rateTopics []string, v *visitor) { + if len(rateTopics) > 0 && rateTopics[0] == rateTopicsWildcard { + for _, t := range topics { + t.SetRateVisitor(v) + } + } else { + for _, t := range topics { + if util.Contains(rateTopics, t.ID) || strings.HasPrefix(t.ID, unifiedPushTopicPrefix) { + t.SetRateVisitor(v) + } + } + } +} + // sendOldMessages selects old messages from the messageCache and calls sub for each of them. It uses since as the // marker, returning only messages that are newer than the marker. func (s *Server) sendOldMessages(topics []*topic, since sinceMarker, scheduled bool, v *visitor, sub subscriber) error { diff --git a/server/server_middleware.go b/server/server_middleware.go index 712223c4..750a02e0 100644 --- a/server/server_middleware.go +++ b/server/server_middleware.go @@ -1,12 +1,18 @@ package server import ( - "context" "net/http" "heckel.io/ntfy/util" ) +type contextKey int + +const ( + contextRateVisitor contextKey = iota + 2586 + contextTopic +) + func (s *Server) limitRequests(next handleFunc) handleFunc { return func(w http.ResponseWriter, r *http.Request, v *visitor) error { if util.ContainsIP(s.config.VisitorRequestExemptIPAddrs, v.ip) { @@ -29,8 +35,10 @@ func (s *Server) limitRequestsWithTopic(next handleFunc) handleFunc { if rateVisitor := t.RateVisitor(); rateVisitor != nil { vrate = rateVisitor } - r = r.WithContext(context.WithValue(context.WithValue(r.Context(), "vRate", vrate), "topic", t)) - + r = withContext(r, map[contextKey]any{ + contextRateVisitor: vrate, + contextTopic: t, + }) if util.ContainsIP(s.config.VisitorRequestExemptIPAddrs, v.ip) { return next(w, r, v) } else if !vrate.RequestAllowed() { diff --git a/server/server_test.go b/server/server_test.go index 7f2665a0..5e2a30a7 100644 --- a/server/server_test.go +++ b/server/server_test.go @@ -1899,7 +1899,7 @@ func TestServer_SubscriberRateLimiting(t *testing.T) { r.RemoteAddr = "1.2.3.4" } rr := request(t, s, "GET", "/subscriber1topic/json?poll=1", "", map[string]string{ - "Subscriber-Rate-Limit-Topics": "subscriber1topic", + "Rate-Topics": "subscriber1topic", }, subscriber1Fn) require.Equal(t, 200, rr.Code) require.Equal(t, "", rr.Body.String()) diff --git a/server/topic.go b/server/topic.go index 3b0cb542..9ba60e2d 100644 --- a/server/topic.go +++ b/server/topic.go @@ -8,6 +8,10 @@ import ( "heckel.io/ntfy/log" ) +const ( + rateVisitorExpiryDuration = 12 * time.Hour +) + // topic represents a channel to which subscribers can subscribe, and publishers // can publish a message type topic struct { diff --git a/server/util.go b/server/util.go index 8ec258fc..c6f78102 100644 --- a/server/util.go +++ b/server/util.go @@ -1,6 +1,7 @@ package server import ( + "context" "heckel.io/ntfy/util" "io" "net/http" @@ -45,13 +46,6 @@ func readHeaderParam(r *http.Request, names ...string) string { return "" } -func readHeaderParamValues(r *http.Request, names ...string) (values []string) { - for _, name := range names { - values = append(values, r.Header.Values(name)...) - } - return -} - func readQueryParam(r *http.Request, names ...string) string { for _, name := range names { value := r.URL.Query().Get(strings.ToLower(name)) @@ -103,3 +97,11 @@ func readJSONWithLimit[T any](r io.ReadCloser, limit int, allowEmpty bool) (*T, } return obj, nil } + +func withContext(r *http.Request, ctx map[contextKey]any) *http.Request { + c := r.Context() + for k, v := range ctx { + c = context.WithValue(c, k, v) + } + return r.WithContext(c) +}