Polish a little
parent
21b27b5dbe
commit
bdeec4d297
|
@ -104,15 +104,15 @@ var (
|
||||||
)
|
)
|
||||||
|
|
||||||
const (
|
const (
|
||||||
firebaseControlTopic = "~control" // See Android if changed
|
firebaseControlTopic = "~control" // See Android if changed
|
||||||
firebasePollTopic = "~poll" // See iOS if changed
|
firebasePollTopic = "~poll" // See iOS if changed
|
||||||
emptyMessageBody = "triggered" // Used if message body is empty
|
emptyMessageBody = "triggered" // Used if message body is empty
|
||||||
newMessageBody = "New message" // Used in poll requests as generic message
|
newMessageBody = "New message" // Used in poll requests as generic message
|
||||||
defaultAttachmentMessage = "You received a file: %s" // Used if message body is empty, and there is an attachment
|
defaultAttachmentMessage = "You received a file: %s" // Used if message body is empty, and there is an attachment
|
||||||
encodingBase64 = "base64" // Used mainly for binary UnifiedPush messages
|
encodingBase64 = "base64" // Used mainly for binary UnifiedPush messages
|
||||||
jsonBodyBytesLimit = 16384
|
jsonBodyBytesLimit = 16384
|
||||||
unifiedPushTopicPrefix = "up" // Temporarily, we rate limit all "up*" topics based on the subscriber
|
unifiedPushTopicPrefix = "up" // Temporarily, we rate limit all "up*" topics based on the subscriber
|
||||||
rateVisitorExpiryDuration = 12 * time.Hour
|
rateTopicsWildcard = "*" // Allows defining all topics in the request subscriber-rate-limited topics
|
||||||
)
|
)
|
||||||
|
|
||||||
// WebSocket constants
|
// WebSocket constants
|
||||||
|
@ -571,11 +571,11 @@ func (s *Server) handleMatrixDiscovery(w http.ResponseWriter) error {
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *Server) handlePublishWithoutResponse(r *http.Request, v *visitor) (*message, error) {
|
func (s *Server) handlePublishWithoutResponse(r *http.Request, v *visitor) (*message, error) {
|
||||||
vrate, ok := r.Context().Value("vRate").(*visitor)
|
vrate, ok := r.Context().Value(contextRateVisitor).(*visitor)
|
||||||
if !ok {
|
if !ok {
|
||||||
return nil, errHTTPInternalError
|
return nil, errHTTPInternalError
|
||||||
}
|
}
|
||||||
t, ok := r.Context().Value("topic").(*topic)
|
t, ok := r.Context().Value(contextTopic).(*topic)
|
||||||
if !ok {
|
if !ok {
|
||||||
return nil, errHTTPInternalError
|
return nil, errHTTPInternalError
|
||||||
}
|
}
|
||||||
|
@ -709,7 +709,7 @@ func (s *Server) forwardPollRequest(v *visitor, m *message) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *Server) parsePublishParams(r *http.Request, vRate *visitor, m *message) (cache bool, firebase bool, email string, unifiedpush bool, err error) {
|
func (s *Server) parsePublishParams(r *http.Request, vrate *visitor, m *message) (cache bool, firebase bool, email string, unifiedpush bool, err error) {
|
||||||
cache = readBoolParam(r, true, "x-cache", "cache")
|
cache = readBoolParam(r, true, "x-cache", "cache")
|
||||||
firebase = readBoolParam(r, true, "x-firebase", "firebase")
|
firebase = readBoolParam(r, true, "x-firebase", "firebase")
|
||||||
m.Title = readParam(r, "x-title", "title", "t")
|
m.Title = readParam(r, "x-title", "title", "t")
|
||||||
|
@ -749,7 +749,7 @@ func (s *Server) parsePublishParams(r *http.Request, vRate *visitor, m *message)
|
||||||
}
|
}
|
||||||
email = readParam(r, "x-email", "x-e-mail", "email", "e-mail", "mail", "e")
|
email = readParam(r, "x-email", "x-e-mail", "email", "e-mail", "mail", "e")
|
||||||
if email != "" {
|
if email != "" {
|
||||||
if !vRate.EmailAllowed() {
|
if !vrate.EmailAllowed() {
|
||||||
return false, false, "", false, errHTTPTooManyRequestsLimitEmails
|
return false, false, "", false, errHTTPTooManyRequestsLimitEmails
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -954,7 +954,7 @@ func (s *Server) handleSubscribeHTTP(w http.ResponseWriter, r *http.Request, v *
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
poll, since, scheduled, filters, subscriberRateTopics, err := parseSubscribeParams(r)
|
poll, since, scheduled, filters, rateTopics, err := parseSubscribeParams(r)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
@ -984,12 +984,7 @@ func (s *Server) handleSubscribeHTTP(w http.ResponseWriter, r *http.Request, v *
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
for _, t := range topics {
|
registerRateVisitors(topics, rateTopics, v)
|
||||||
subscriberRateLimited := util.Contains(subscriberRateTopics, t.ID) || strings.HasPrefix(t.ID, unifiedPushTopicPrefix) // temporarily do prefix as well
|
|
||||||
if subscriberRateLimited {
|
|
||||||
t.SetRateVisitor(v)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
w.Header().Set("Access-Control-Allow-Origin", s.config.AccessControlAllowOrigin) // CORS, allow cross-origin requests
|
w.Header().Set("Access-Control-Allow-Origin", s.config.AccessControlAllowOrigin) // CORS, allow cross-origin requests
|
||||||
w.Header().Set("Content-Type", contentType+"; charset=utf-8") // Android/Volley client needs charset!
|
w.Header().Set("Content-Type", contentType+"; charset=utf-8") // Android/Volley client needs charset!
|
||||||
if poll {
|
if poll {
|
||||||
|
@ -1042,7 +1037,7 @@ func (s *Server) handleSubscribeWS(w http.ResponseWriter, r *http.Request, v *vi
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
poll, since, scheduled, filters, subscriberRateTopics, err := parseSubscribeParams(r)
|
poll, since, scheduled, filters, rateTopics, err := parseSubscribeParams(r)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
@ -1125,12 +1120,7 @@ func (s *Server) handleSubscribeWS(w http.ResponseWriter, r *http.Request, v *vi
|
||||||
}
|
}
|
||||||
return conn.WriteJSON(msg)
|
return conn.WriteJSON(msg)
|
||||||
}
|
}
|
||||||
for _, t := range topics {
|
registerRateVisitors(topics, rateTopics, v)
|
||||||
subscriberRateLimited := util.Contains(subscriberRateTopics, t.ID) || strings.HasPrefix(t.ID, unifiedPushTopicPrefix) // temporarily do prefix as well
|
|
||||||
if subscriberRateLimited {
|
|
||||||
t.SetRateVisitor(v)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
w.Header().Set("Access-Control-Allow-Origin", s.config.AccessControlAllowOrigin) // CORS, allow cross-origin requests
|
w.Header().Set("Access-Control-Allow-Origin", s.config.AccessControlAllowOrigin) // CORS, allow cross-origin requests
|
||||||
if poll {
|
if poll {
|
||||||
return s.sendOldMessages(topics, since, scheduled, v, sub)
|
return s.sendOldMessages(topics, since, scheduled, v, sub)
|
||||||
|
@ -1158,7 +1148,7 @@ func (s *Server) handleSubscribeWS(w http.ResponseWriter, r *http.Request, v *vi
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
func parseSubscribeParams(r *http.Request) (poll bool, since sinceMarker, scheduled bool, filters *queryFilter, subscriberTopics []string, err error) {
|
func parseSubscribeParams(r *http.Request) (poll bool, since sinceMarker, scheduled bool, filters *queryFilter, rateTopics []string, err error) {
|
||||||
poll = readBoolParam(r, false, "x-poll", "poll", "po")
|
poll = readBoolParam(r, false, "x-poll", "poll", "po")
|
||||||
scheduled = readBoolParam(r, false, "x-scheduled", "scheduled", "sched")
|
scheduled = readBoolParam(r, false, "x-scheduled", "scheduled", "sched")
|
||||||
since, err = parseSince(r, poll)
|
since, err = parseSince(r, poll)
|
||||||
|
@ -1169,10 +1159,29 @@ func parseSubscribeParams(r *http.Request) (poll bool, since sinceMarker, schedu
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
subscriberTopics = readCommaSeparatedParam(r, "subscriber-rate-limit-topics", "x-subscriber-rate-limit-topics", "srlt")
|
rateTopics = readCommaSeparatedParam(r, "x-rate-topics", "rate-topics")
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// registerRateVisitors sets the rate visitor on a topic, indicating that all messages published to that topic
|
||||||
|
// will be rate limited against the rate visitor instead of the publishing visitor.
|
||||||
|
//
|
||||||
|
// Note: This TEMPORARILY also registers all topics starting with "up" (= UnifiedPush). This is to ease the transition
|
||||||
|
// until the Android app will send the "Rate-Topics" header.
|
||||||
|
func registerRateVisitors(topics []*topic, rateTopics []string, v *visitor) {
|
||||||
|
if len(rateTopics) > 0 && rateTopics[0] == rateTopicsWildcard {
|
||||||
|
for _, t := range topics {
|
||||||
|
t.SetRateVisitor(v)
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
for _, t := range topics {
|
||||||
|
if util.Contains(rateTopics, t.ID) || strings.HasPrefix(t.ID, unifiedPushTopicPrefix) {
|
||||||
|
t.SetRateVisitor(v)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// sendOldMessages selects old messages from the messageCache and calls sub for each of them. It uses since as the
|
// sendOldMessages selects old messages from the messageCache and calls sub for each of them. It uses since as the
|
||||||
// marker, returning only messages that are newer than the marker.
|
// marker, returning only messages that are newer than the marker.
|
||||||
func (s *Server) sendOldMessages(topics []*topic, since sinceMarker, scheduled bool, v *visitor, sub subscriber) error {
|
func (s *Server) sendOldMessages(topics []*topic, since sinceMarker, scheduled bool, v *visitor, sub subscriber) error {
|
||||||
|
|
|
@ -1,12 +1,18 @@
|
||||||
package server
|
package server
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
|
||||||
"net/http"
|
"net/http"
|
||||||
|
|
||||||
"heckel.io/ntfy/util"
|
"heckel.io/ntfy/util"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
type contextKey int
|
||||||
|
|
||||||
|
const (
|
||||||
|
contextRateVisitor contextKey = iota + 2586
|
||||||
|
contextTopic
|
||||||
|
)
|
||||||
|
|
||||||
func (s *Server) limitRequests(next handleFunc) handleFunc {
|
func (s *Server) limitRequests(next handleFunc) handleFunc {
|
||||||
return func(w http.ResponseWriter, r *http.Request, v *visitor) error {
|
return func(w http.ResponseWriter, r *http.Request, v *visitor) error {
|
||||||
if util.ContainsIP(s.config.VisitorRequestExemptIPAddrs, v.ip) {
|
if util.ContainsIP(s.config.VisitorRequestExemptIPAddrs, v.ip) {
|
||||||
|
@ -29,8 +35,10 @@ func (s *Server) limitRequestsWithTopic(next handleFunc) handleFunc {
|
||||||
if rateVisitor := t.RateVisitor(); rateVisitor != nil {
|
if rateVisitor := t.RateVisitor(); rateVisitor != nil {
|
||||||
vrate = rateVisitor
|
vrate = rateVisitor
|
||||||
}
|
}
|
||||||
r = r.WithContext(context.WithValue(context.WithValue(r.Context(), "vRate", vrate), "topic", t))
|
r = withContext(r, map[contextKey]any{
|
||||||
|
contextRateVisitor: vrate,
|
||||||
|
contextTopic: t,
|
||||||
|
})
|
||||||
if util.ContainsIP(s.config.VisitorRequestExemptIPAddrs, v.ip) {
|
if util.ContainsIP(s.config.VisitorRequestExemptIPAddrs, v.ip) {
|
||||||
return next(w, r, v)
|
return next(w, r, v)
|
||||||
} else if !vrate.RequestAllowed() {
|
} else if !vrate.RequestAllowed() {
|
||||||
|
|
|
@ -1899,7 +1899,7 @@ func TestServer_SubscriberRateLimiting(t *testing.T) {
|
||||||
r.RemoteAddr = "1.2.3.4"
|
r.RemoteAddr = "1.2.3.4"
|
||||||
}
|
}
|
||||||
rr := request(t, s, "GET", "/subscriber1topic/json?poll=1", "", map[string]string{
|
rr := request(t, s, "GET", "/subscriber1topic/json?poll=1", "", map[string]string{
|
||||||
"Subscriber-Rate-Limit-Topics": "subscriber1topic",
|
"Rate-Topics": "subscriber1topic",
|
||||||
}, subscriber1Fn)
|
}, subscriber1Fn)
|
||||||
require.Equal(t, 200, rr.Code)
|
require.Equal(t, 200, rr.Code)
|
||||||
require.Equal(t, "", rr.Body.String())
|
require.Equal(t, "", rr.Body.String())
|
||||||
|
|
|
@ -8,6 +8,10 @@ import (
|
||||||
"heckel.io/ntfy/log"
|
"heckel.io/ntfy/log"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
rateVisitorExpiryDuration = 12 * time.Hour
|
||||||
|
)
|
||||||
|
|
||||||
// topic represents a channel to which subscribers can subscribe, and publishers
|
// topic represents a channel to which subscribers can subscribe, and publishers
|
||||||
// can publish a message
|
// can publish a message
|
||||||
type topic struct {
|
type topic struct {
|
||||||
|
|
|
@ -1,6 +1,7 @@
|
||||||
package server
|
package server
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"context"
|
||||||
"heckel.io/ntfy/util"
|
"heckel.io/ntfy/util"
|
||||||
"io"
|
"io"
|
||||||
"net/http"
|
"net/http"
|
||||||
|
@ -45,13 +46,6 @@ func readHeaderParam(r *http.Request, names ...string) string {
|
||||||
return ""
|
return ""
|
||||||
}
|
}
|
||||||
|
|
||||||
func readHeaderParamValues(r *http.Request, names ...string) (values []string) {
|
|
||||||
for _, name := range names {
|
|
||||||
values = append(values, r.Header.Values(name)...)
|
|
||||||
}
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
func readQueryParam(r *http.Request, names ...string) string {
|
func readQueryParam(r *http.Request, names ...string) string {
|
||||||
for _, name := range names {
|
for _, name := range names {
|
||||||
value := r.URL.Query().Get(strings.ToLower(name))
|
value := r.URL.Query().Get(strings.ToLower(name))
|
||||||
|
@ -103,3 +97,11 @@ func readJSONWithLimit[T any](r io.ReadCloser, limit int, allowEmpty bool) (*T,
|
||||||
}
|
}
|
||||||
return obj, nil
|
return obj, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func withContext(r *http.Request, ctx map[contextKey]any) *http.Request {
|
||||||
|
c := r.Context()
|
||||||
|
for k, v := range ctx {
|
||||||
|
c = context.WithValue(c, k, v)
|
||||||
|
}
|
||||||
|
return r.WithContext(c)
|
||||||
|
}
|
||||||
|
|
Loading…
Reference in New Issue