From 2b6363474e18f195108a10beca1305c198e71eaf Mon Sep 17 00:00:00 2001 From: Philipp Heckel Date: Sat, 15 Jan 2022 23:17:46 -0500 Subject: [PATCH] Combine things, move stuff --- server/server.go | 169 ++++++++------------------------ server/server_test.go | 58 ----------- server/{message.go => types.go} | 70 +++++++++++++ server/util.go | 55 +++++++++++ server/util_test.go | 66 +++++++++++++ 5 files changed, 231 insertions(+), 187 deletions(-) rename server/{message.go => types.go} (54%) create mode 100644 server/util.go create mode 100644 server/util_test.go diff --git a/server/server.go b/server/server.go index 57dd0f38..06097a2f 100644 --- a/server/server.go +++ b/server/server.go @@ -32,9 +32,6 @@ import ( "unicode/utf8" ) -// TODO add "max messages in a topic" limit -// TODO implement "since=" - // Server is the main server, providing the UI and API for ntfy type Server struct { config *Config @@ -59,25 +56,6 @@ type indexPage struct { CacheDuration time.Duration } -type sinceTime time.Time - -func (t sinceTime) IsAll() bool { - return t == sinceAllMessages -} - -func (t sinceTime) IsNone() bool { - return t == sinceNoMessages -} - -func (t sinceTime) Time() time.Time { - return time.Time(t) -} - -var ( - sinceAllMessages = sinceTime(time.Unix(0, 0)) - sinceNoMessages = sinceTime(time.Unix(1, 0)) -) - var ( topicRegex = regexp.MustCompile(`^[-_A-Za-z0-9]{1,64}$`) // No /! topicPathRegex = regexp.MustCompile(`^/[-_A-Za-z0-9]{1,64}$`) // Regex must match JS & Android app! @@ -117,7 +95,6 @@ const ( firebaseControlTopic = "~control" // See Android if changed emptyMessageBody = "triggered" // Used if message body is empty defaultAttachmentMessage = "You received a file: %s" // Used if message body is empty, and there is an attachment - fcmMessageLimit = 4000 // see maybeTruncateFCMMessage for details ) // WebSocket constants @@ -232,25 +209,6 @@ func createFirebaseSubscriber(conf *Config) (subscriber, error) { }, nil } -// maybeTruncateFCMMessage performs best-effort truncation of FCM messages. -// The docs say the limit is 4000 characters, but during testing it wasn't quite clear -// what fields matter; so we're just capping the serialized JSON to 4000 bytes. -func maybeTruncateFCMMessage(m *messaging.Message) *messaging.Message { - s, err := json.Marshal(m) - if err != nil { - return m - } - if len(s) > fcmMessageLimit { - over := len(s) - fcmMessageLimit + 16 // = len("truncated":"1",), sigh ... - message, ok := m.Data["message"] - if ok && len(message) > over { - m.Data["truncated"] = "1" - m.Data["message"] = message[:len(message)-over] - } - } - return m -} - // Run executes the main server. It listens on HTTP (+ HTTPS, if configured), and starts // a manager go routine to print stats and prune messages. func (s *Server) Run() error { @@ -391,7 +349,7 @@ func (s *Server) handleHome(w http.ResponseWriter, r *http.Request) error { } func (s *Server) handleTopic(w http.ResponseWriter, r *http.Request) error { - unifiedpush := readParam(r, "x-unifiedpush", "unifiedpush", "up") == "1" // see PUT/POST too! + unifiedpush := readBoolParam(r, false, "x-unifiedpush", "unifiedpush", "up") // see PUT/POST too! if unifiedpush { w.Header().Set("Content-Type", "application/json") w.Header().Set("Access-Control-Allow-Origin", "*") // CORS, allow cross-origin requests @@ -497,13 +455,15 @@ func (s *Server) handlePublish(w http.ResponseWriter, r *http.Request, v *visito if err := json.NewEncoder(w).Encode(m); err != nil { return err } - s.inc(&s.messages) + s.mu.Lock() + s.messages++ + s.mu.Unlock() return nil } func (s *Server) parsePublishParams(r *http.Request, v *visitor, m *message) (cache bool, firebase bool, email string, err error) { - cache = readParam(r, "x-cache", "cache") != "no" - firebase = readParam(r, "x-firebase", "firebase") != "no" + cache = readBoolParam(r, true, "x-cache", "cache") + firebase = readBoolParam(r, true, "x-firebase", "firebase") m.Title = readParam(r, "x-title", "title", "t") m.Click = readParam(r, "x-click", "click") filename := readParam(r, "x-filename", "filename", "file", "f") @@ -574,29 +534,13 @@ func (s *Server) parsePublishParams(r *http.Request, v *visitor, m *message) (ca } m.Time = delay.Unix() } - unifiedpush := readParam(r, "x-unifiedpush", "unifiedpush", "up") == "1" // see GET too! + unifiedpush := readBoolParam(r, false, "x-unifiedpush", "unifiedpush", "up") // see GET too! if unifiedpush { firebase = false } return cache, firebase, email, nil } -func readParam(r *http.Request, names ...string) string { - for _, name := range names { - value := r.Header.Get(name) - if value != "" { - return strings.TrimSpace(value) - } - } - for _, name := range names { - value := r.URL.Query().Get(strings.ToLower(name)) - if value != "" { - return strings.TrimSpace(value) - } - } - return "" -} - // handlePublishBody consumes the PUT/POST body and decides whether the body is an attachment or the message. // // 1. curl -H "Attach: http://example.com/file.jpg" ntfy.sh/mytopic @@ -680,7 +624,7 @@ func (s *Server) handleSubscribeJSON(w http.ResponseWriter, r *http.Request, v * } return buf.String(), nil } - return s.handleSubscribe(w, r, v, "json", "application/x-ndjson", encoder) + return s.handleSubscribeHTTP(w, r, v, "application/x-ndjson", encoder) } func (s *Server) handleSubscribeSSE(w http.ResponseWriter, r *http.Request, v *visitor) error { @@ -694,7 +638,7 @@ func (s *Server) handleSubscribeSSE(w http.ResponseWriter, r *http.Request, v *v } return fmt.Sprintf("data: %s\n", buf.String()), nil } - return s.handleSubscribe(w, r, v, "sse", "text/event-stream", encoder) + return s.handleSubscribeHTTP(w, r, v, "text/event-stream", encoder) } func (s *Server) handleSubscribeRaw(w http.ResponseWriter, r *http.Request, v *visitor) error { @@ -704,33 +648,25 @@ func (s *Server) handleSubscribeRaw(w http.ResponseWriter, r *http.Request, v *v } return "\n", nil // "keepalive" and "open" events just send an empty line } - return s.handleSubscribe(w, r, v, "raw", "text/plain", encoder) + return s.handleSubscribeHTTP(w, r, v, "text/plain", encoder) } -func (s *Server) handleSubscribe(w http.ResponseWriter, r *http.Request, v *visitor, format string, contentType string, encoder messageEncoder) error { +func (s *Server) handleSubscribeHTTP(w http.ResponseWriter, r *http.Request, v *visitor, contentType string, encoder messageEncoder) error { if err := v.SubscriptionAllowed(); err != nil { return errHTTPTooManyRequestsLimitSubscriptions } defer v.RemoveSubscription() - topicsStr := strings.TrimSuffix(r.URL.Path[1:], "/"+format) // Hack - topicIDs := util.SplitNoEmpty(topicsStr, ",") - topics, err := s.topicsFromIDs(topicIDs...) + topics, topicsStr, err := s.topicsFromPath(r.URL.Path) if err != nil { return err } - poll := readParam(r, "x-poll", "poll", "po") == "1" - scheduled := readParam(r, "x-scheduled", "scheduled", "sched") == "1" - since, err := parseSince(r, poll) - if err != nil { - return err - } - messageFilter, titleFilter, priorityFilter, tagsFilter, err := parseQueryFilters(r) + poll, since, scheduled, filters, err := parseSubscribeParams(r) if err != nil { return err } var wlock sync.Mutex sub := func(msg *message) error { - if !passesQueryFilter(msg, messageFilter, titleFilter, priorityFilter, tagsFilter) { + if !filters.Pass(msg) { return nil } m, err := encoder(msg) @@ -785,19 +721,11 @@ func (s *Server) handleSubscribeWS(w http.ResponseWriter, r *http.Request, v *vi return errHTTPTooManyRequestsLimitSubscriptions } defer v.RemoveSubscription() - topicsStr := strings.TrimSuffix(r.URL.Path[1:], "/ws") // Hack - topicIDs := util.SplitNoEmpty(topicsStr, ",") - topics, err := s.topicsFromIDs(topicIDs...) + topics, topicsStr, err := s.topicsFromPath(r.URL.Path) if err != nil { return err } - poll := readParam(r, "x-poll", "poll", "po") == "1" - scheduled := readParam(r, "x-scheduled", "scheduled", "sched") == "1" - since, err := parseSince(r, poll) - if err != nil { - return err - } - messageFilter, titleFilter, priorityFilter, tagsFilter, err := parseQueryFilters(r) + poll, since, scheduled, filters, err := parseSubscribeParams(r) if err != nil { return err } @@ -850,7 +778,7 @@ func (s *Server) handleSubscribeWS(w http.ResponseWriter, r *http.Request, v *vi } }) sub := func(msg *message) error { - if !passesQueryFilter(msg, messageFilter, titleFilter, priorityFilter, tagsFilter) { + if !filters.Pass(msg) { return nil } if err := conn.SetWriteDeadline(time.Now().Add(wsWriteWait)); err != nil { @@ -884,44 +812,20 @@ func (s *Server) handleSubscribeWS(w http.ResponseWriter, r *http.Request, v *vi return err } -func parseQueryFilters(r *http.Request) (messageFilter string, titleFilter string, priorityFilter []int, tagsFilter []string, err error) { - messageFilter = readParam(r, "x-message", "message", "m") - titleFilter = readParam(r, "x-title", "title", "t") - tagsFilter = util.SplitNoEmpty(readParam(r, "x-tags", "tags", "tag", "ta"), ",") - priorityFilter = make([]int, 0) - for _, p := range util.SplitNoEmpty(readParam(r, "x-priority", "priority", "prio", "p"), ",") { - priority, err := util.ParsePriority(p) - if err != nil { - return "", "", nil, nil, err - } - priorityFilter = append(priorityFilter, priority) +func parseSubscribeParams(r *http.Request) (poll bool, since sinceTime, scheduled bool, filters *queryFilter, err error) { + poll = readBoolParam(r, false, "x-poll", "poll", "po") + scheduled = readBoolParam(r, false, "x-scheduled", "scheduled", "sched") + since, err = parseSince(r, poll) + if err != nil { + return + } + filters, err = parseQueryFilters(r) + if err != nil { + return } return } -func passesQueryFilter(msg *message, messageFilter string, titleFilter string, priorityFilter []int, tagsFilter []string) bool { - if msg.Event != messageEvent { - return true // filters only apply to messages - } - if messageFilter != "" && msg.Message != messageFilter { - return false - } - if titleFilter != "" && msg.Title != titleFilter { - return false - } - messagePriority := msg.Priority - if messagePriority == 0 { - messagePriority = 3 // For query filters, default priority (3) is the same as "not set" (0) - } - if len(priorityFilter) > 0 && !util.InIntList(priorityFilter, messagePriority) { - return false - } - if len(tagsFilter) > 0 && !util.InStringListAll(msg.Tags, tagsFilter) { - return false - } - return true -} - func (s *Server) sendOldMessages(topics []*topic, since sinceTime, scheduled bool, sub subscriber) error { if since.IsNone() { return nil @@ -980,6 +884,19 @@ func (s *Server) topicFromPath(path string) (*topic, error) { return topics[0], nil } +func (s *Server) topicsFromPath(path string) ([]*topic, string, error) { + parts := strings.Split(path, "/") + if len(parts) < 2 { + return nil, "", errHTTPBadRequestTopicInvalid + } + topicIDs := util.SplitNoEmpty(parts[1], ",") + topics, err := s.topicsFromIDs(topicIDs...) + if err != nil { + return nil, "", errHTTPBadRequestTopicInvalid + } + return topics, parts[1], nil +} + func (s *Server) topicsFromIDs(ids ...string) ([]*topic, error) { s.mu.Lock() defer s.mu.Unlock() @@ -1180,9 +1097,3 @@ func (s *Server) visitor(r *http.Request) *visitor { v.Keepalive() return v } - -func (s *Server) inc(counter *int64) { - s.mu.Lock() - defer s.mu.Unlock() - *counter++ -} diff --git a/server/server_test.go b/server/server_test.go index 492edf91..f888136c 100644 --- a/server/server_test.go +++ b/server/server_test.go @@ -4,7 +4,6 @@ import ( "bufio" "context" "encoding/json" - "firebase.google.com/go/messaging" "fmt" "github.com/stretchr/testify/require" "heckel.io/ntfy/util" @@ -624,63 +623,6 @@ func TestServer_UnifiedPushDiscovery(t *testing.T) { require.Equal(t, `{"unifiedpush":{"version":1}}`+"\n", response.Body.String()) } -func TestServer_MaybeTruncateFCMMessage(t *testing.T) { - origMessage := strings.Repeat("this is a long string", 300) - origFCMMessage := &messaging.Message{ - Topic: "mytopic", - Data: map[string]string{ - "id": "abcdefg", - "time": "1641324761", - "event": "message", - "topic": "mytopic", - "priority": "0", - "tags": "", - "title": "", - "message": origMessage, - }, - Android: &messaging.AndroidConfig{ - Priority: "high", - }, - } - origMessageLength := len(origFCMMessage.Data["message"]) - serializedOrigFCMMessage, _ := json.Marshal(origFCMMessage) - require.Greater(t, len(serializedOrigFCMMessage), fcmMessageLimit) // Pre-condition - - truncatedFCMMessage := maybeTruncateFCMMessage(origFCMMessage) - truncatedMessageLength := len(truncatedFCMMessage.Data["message"]) - serializedTruncatedFCMMessage, _ := json.Marshal(truncatedFCMMessage) - require.Equal(t, fcmMessageLimit, len(serializedTruncatedFCMMessage)) - require.Equal(t, "1", truncatedFCMMessage.Data["truncated"]) - require.NotEqual(t, origMessageLength, truncatedMessageLength) -} - -func TestServer_MaybeTruncateFCMMessage_NotTooLong(t *testing.T) { - origMessage := "not really a long string" - origFCMMessage := &messaging.Message{ - Topic: "mytopic", - Data: map[string]string{ - "id": "abcdefg", - "time": "1641324761", - "event": "message", - "topic": "mytopic", - "priority": "0", - "tags": "", - "title": "", - "message": origMessage, - }, - } - origMessageLength := len(origFCMMessage.Data["message"]) - serializedOrigFCMMessage, _ := json.Marshal(origFCMMessage) - require.LessOrEqual(t, len(serializedOrigFCMMessage), fcmMessageLimit) // Pre-condition - - notTruncatedFCMMessage := maybeTruncateFCMMessage(origFCMMessage) - notTruncatedMessageLength := len(notTruncatedFCMMessage.Data["message"]) - serializedNotTruncatedFCMMessage, _ := json.Marshal(notTruncatedFCMMessage) - require.Equal(t, origMessageLength, notTruncatedMessageLength) - require.Equal(t, len(serializedOrigFCMMessage), len(serializedNotTruncatedFCMMessage)) - require.Equal(t, "", notTruncatedFCMMessage.Data["truncated"]) -} - func TestServer_PublishAttachment(t *testing.T) { content := util.RandomString(5000) // > 4096 s := newTestServer(t, newTestConfig(t)) diff --git a/server/message.go b/server/types.go similarity index 54% rename from server/message.go rename to server/types.go index 27695f14..357a3780 100644 --- a/server/message.go +++ b/server/types.go @@ -2,6 +2,7 @@ package server import ( "heckel.io/ntfy/util" + "net/http" "time" ) @@ -70,3 +71,72 @@ func newKeepaliveMessage(topic string) *message { func newDefaultMessage(topic, msg string) *message { return newMessage(messageEvent, topic, msg) } + +type sinceTime time.Time + +func (t sinceTime) IsAll() bool { + return t == sinceAllMessages +} + +func (t sinceTime) IsNone() bool { + return t == sinceNoMessages +} + +func (t sinceTime) Time() time.Time { + return time.Time(t) +} + +var ( + sinceAllMessages = sinceTime(time.Unix(0, 0)) + sinceNoMessages = sinceTime(time.Unix(1, 0)) +) + +type queryFilter struct { + Message string + Title string + Tags []string + Priority []int +} + +func parseQueryFilters(r *http.Request) (*queryFilter, error) { + messageFilter := readParam(r, "x-message", "message", "m") + titleFilter := readParam(r, "x-title", "title", "t") + tagsFilter := util.SplitNoEmpty(readParam(r, "x-tags", "tags", "tag", "ta"), ",") + priorityFilter := make([]int, 0) + for _, p := range util.SplitNoEmpty(readParam(r, "x-priority", "priority", "prio", "p"), ",") { + priority, err := util.ParsePriority(p) + if err != nil { + return nil, err + } + priorityFilter = append(priorityFilter, priority) + } + return &queryFilter{ + Message: messageFilter, + Title: titleFilter, + Tags: tagsFilter, + Priority: priorityFilter, + }, nil +} + +func (q *queryFilter) Pass(msg *message) bool { + if msg.Event != messageEvent { + return true // filters only apply to messages + } + if q.Message != "" && msg.Message != q.Message { + return false + } + if q.Title != "" && msg.Title != q.Title { + return false + } + messagePriority := msg.Priority + if messagePriority == 0 { + messagePriority = 3 // For query filters, default priority (3) is the same as "not set" (0) + } + if len(q.Priority) > 0 && !util.InIntList(q.Priority, messagePriority) { + return false + } + if len(q.Tags) > 0 && !util.InStringListAll(msg.Tags, q.Tags) { + return false + } + return true +} diff --git a/server/util.go b/server/util.go new file mode 100644 index 00000000..4966cb0f --- /dev/null +++ b/server/util.go @@ -0,0 +1,55 @@ +package server + +import ( + "encoding/json" + "firebase.google.com/go/messaging" + "net/http" + "strings" +) + +const ( + fcmMessageLimit = 4000 +) + +// maybeTruncateFCMMessage performs best-effort truncation of FCM messages. +// The docs say the limit is 4000 characters, but during testing it wasn't quite clear +// what fields matter; so we're just capping the serialized JSON to 4000 bytes. +func maybeTruncateFCMMessage(m *messaging.Message) *messaging.Message { + s, err := json.Marshal(m) + if err != nil { + return m + } + if len(s) > fcmMessageLimit { + over := len(s) - fcmMessageLimit + 16 // = len("truncated":"1",), sigh ... + message, ok := m.Data["message"] + if ok && len(message) > over { + m.Data["truncated"] = "1" + m.Data["message"] = message[:len(message)-over] + } + } + return m +} + +func readBoolParam(r *http.Request, defaultValue bool, names ...string) bool { + value := strings.ToLower(readParam(r, names...)) + if value == "" { + return defaultValue + } + return value == "1" || value == "yes" || value == "true" +} + +func readParam(r *http.Request, names ...string) string { + for _, name := range names { + value := r.Header.Get(name) + if value != "" { + return strings.TrimSpace(value) + } + } + for _, name := range names { + value := r.URL.Query().Get(strings.ToLower(name)) + if value != "" { + return strings.TrimSpace(value) + } + } + return "" +} diff --git a/server/util_test.go b/server/util_test.go new file mode 100644 index 00000000..75a76a24 --- /dev/null +++ b/server/util_test.go @@ -0,0 +1,66 @@ +package server + +import ( + "encoding/json" + "firebase.google.com/go/messaging" + "github.com/stretchr/testify/require" + "strings" + "testing" +) + +func TestMaybeTruncateFCMMessage(t *testing.T) { + origMessage := strings.Repeat("this is a long string", 300) + origFCMMessage := &messaging.Message{ + Topic: "mytopic", + Data: map[string]string{ + "id": "abcdefg", + "time": "1641324761", + "event": "message", + "topic": "mytopic", + "priority": "0", + "tags": "", + "title": "", + "message": origMessage, + }, + Android: &messaging.AndroidConfig{ + Priority: "high", + }, + } + origMessageLength := len(origFCMMessage.Data["message"]) + serializedOrigFCMMessage, _ := json.Marshal(origFCMMessage) + require.Greater(t, len(serializedOrigFCMMessage), fcmMessageLimit) // Pre-condition + + truncatedFCMMessage := maybeTruncateFCMMessage(origFCMMessage) + truncatedMessageLength := len(truncatedFCMMessage.Data["message"]) + serializedTruncatedFCMMessage, _ := json.Marshal(truncatedFCMMessage) + require.Equal(t, fcmMessageLimit, len(serializedTruncatedFCMMessage)) + require.Equal(t, "1", truncatedFCMMessage.Data["truncated"]) + require.NotEqual(t, origMessageLength, truncatedMessageLength) +} + +func TestMaybeTruncateFCMMessage_NotTooLong(t *testing.T) { + origMessage := "not really a long string" + origFCMMessage := &messaging.Message{ + Topic: "mytopic", + Data: map[string]string{ + "id": "abcdefg", + "time": "1641324761", + "event": "message", + "topic": "mytopic", + "priority": "0", + "tags": "", + "title": "", + "message": origMessage, + }, + } + origMessageLength := len(origFCMMessage.Data["message"]) + serializedOrigFCMMessage, _ := json.Marshal(origFCMMessage) + require.LessOrEqual(t, len(serializedOrigFCMMessage), fcmMessageLimit) // Pre-condition + + notTruncatedFCMMessage := maybeTruncateFCMMessage(origFCMMessage) + notTruncatedMessageLength := len(notTruncatedFCMMessage.Data["message"]) + serializedNotTruncatedFCMMessage, _ := json.Marshal(notTruncatedFCMMessage) + require.Equal(t, origMessageLength, notTruncatedMessageLength) + require.Equal(t, len(serializedOrigFCMMessage), len(serializedNotTruncatedFCMMessage)) + require.Equal(t, "", notTruncatedFCMMessage.Data["truncated"]) +}