From d05211648d4cd7115dc9be1b733102f30bddf675 Mon Sep 17 00:00:00 2001 From: Philipp Heckel Date: Mon, 20 Jun 2022 12:11:52 -0400 Subject: [PATCH] Fix `since=` implementation for multiple topics, closes #336 --- docs/releases.md | 1 + server/message_cache.go | 4 ++-- server/server.go | 18 +++++++++++----- server/server_test.go | 47 +++++++++++++++++++++++++++++++++++++++++ 4 files changed, 63 insertions(+), 7 deletions(-) diff --git a/docs/releases.md b/docs/releases.md index 3f5a7b6c..935f5950 100644 --- a/docs/releases.md +++ b/docs/releases.md @@ -14,6 +14,7 @@ and the [ntfy Android app](https://github.com/binwiederhier/ntfy-android/release * Return HTTP 500 for GET /_matrix/push/v1/notify when base-url is not configured (no ticket) * Disallow setting `upstream-base-url` to the same value as `base-url` ([#334](https://github.com/binwiederhier/ntfy/issues/334), thanks to [@oester](https://github.com/oester) for reporting) +* Fix `since=` implementation for multiple topics ([#336](https://github.com/binwiederhier/ntfy/issues/336), thanks to [@karmanyaahm](https://github.com/karmanyaahm) for reporting) ## ntfy Android app v1.14.0 (UNRELEASED) diff --git a/server/message_cache.go b/server/message_cache.go index 77aa4f78..afd4bf17 100644 --- a/server/message_cache.go +++ b/server/message_cache.go @@ -49,7 +49,7 @@ const ( VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?) ` pruneMessagesQuery = `DELETE FROM messages WHERE time < ? AND published = 1` - selectRowIDFromMessageID = `SELECT id FROM messages WHERE topic = ? AND mid = ?` + selectRowIDFromMessageID = `SELECT id FROM messages WHERE mid = ?` // Do not include topic, see #336 and TestServer_PollSinceID_MultipleTopics selectMessagesSinceTimeQuery = ` SELECT mid, time, topic, message, title, priority, tags, click, actions, attachment_name, attachment_type, attachment_size, attachment_expires, attachment_url, sender, encoding FROM messages @@ -294,7 +294,7 @@ func (c *messageCache) messagesSinceTime(topic string, since sinceMarker, schedu } func (c *messageCache) messagesSinceID(topic string, since sinceMarker, scheduled bool) ([]*message, error) { - idrows, err := c.db.Query(selectRowIDFromMessageID, topic, since.ID()) + idrows, err := c.db.Query(selectRowIDFromMessageID, since.ID()) if err != nil { return nil, err } diff --git a/server/server.go b/server/server.go index dacfa743..4d028d91 100644 --- a/server/server.go +++ b/server/server.go @@ -16,6 +16,7 @@ import ( "path" "path/filepath" "regexp" + "sort" "strconv" "strings" "sync" @@ -972,19 +973,26 @@ func parseSubscribeParams(r *http.Request) (poll bool, since sinceMarker, schedu return } +// sendOldMessages selects old messages from the messageCache and calls sub for each of them. It uses since as the +// marker, returning only messages that are newer than the marker. func (s *Server) sendOldMessages(topics []*topic, since sinceMarker, scheduled bool, v *visitor, sub subscriber) error { if since.IsNone() { return nil } + messages := make([]*message, 0) for _, t := range topics { - messages, err := s.messageCache.Messages(t.ID, since, scheduled) + topicMessages, err := s.messageCache.Messages(t.ID, since, scheduled) if err != nil { return err } - for _, m := range messages { - if err := sub(v, m); err != nil { - return err - } + messages = append(messages, topicMessages...) + } + sort.Slice(messages, func(i, j int) bool { + return messages[i].Time < messages[j].Time + }) + for _, m := range messages { + if err := sub(v, m); err != nil { + return err } } return nil diff --git a/server/server_test.go b/server/server_test.go index 66ad6e1b..9fc9aa88 100644 --- a/server/server_test.go +++ b/server/server_test.go @@ -437,6 +437,53 @@ func TestServer_PublishAndPollSince(t *testing.T) { require.Equal(t, 40008, toHTTPError(t, response.Body.String()).Code) } +func newMessageWithTimestamp(topic, message string, timestamp int64) *message { + m := newDefaultMessage(topic, message) + m.Time = timestamp + return m +} + +func TestServer_PollSinceID_MultipleTopics(t *testing.T) { + s := newTestServer(t, newTestConfig(t)) + + require.Nil(t, s.messageCache.AddMessage(newMessageWithTimestamp("mytopic1", "test 1", 1655740277))) + markerMessage := newMessageWithTimestamp("mytopic2", "test 2", 1655740283) + require.Nil(t, s.messageCache.AddMessage(markerMessage)) + require.Nil(t, s.messageCache.AddMessage(newMessageWithTimestamp("mytopic1", "test 3", 1655740289))) + require.Nil(t, s.messageCache.AddMessage(newMessageWithTimestamp("mytopic2", "test 4", 1655740293))) + require.Nil(t, s.messageCache.AddMessage(newMessageWithTimestamp("mytopic1", "test 5", 1655740297))) + require.Nil(t, s.messageCache.AddMessage(newMessageWithTimestamp("mytopic2", "test 6", 1655740303))) + + response := request(t, s, "GET", fmt.Sprintf("/mytopic1,mytopic2/json?poll=1&since=%s", markerMessage.ID), "", nil) + messages := toMessages(t, response.Body.String()) + require.Equal(t, 4, len(messages)) + require.Equal(t, "test 3", messages[0].Message) + require.Equal(t, "mytopic1", messages[0].Topic) + require.Equal(t, "test 4", messages[1].Message) + require.Equal(t, "mytopic2", messages[1].Topic) + require.Equal(t, "test 5", messages[2].Message) + require.Equal(t, "mytopic1", messages[2].Topic) + require.Equal(t, "test 6", messages[3].Message) + require.Equal(t, "mytopic2", messages[3].Topic) +} + +func TestServer_PollSinceID_MultipleTopics_IDDoesNotMatch(t *testing.T) { + s := newTestServer(t, newTestConfig(t)) + + require.Nil(t, s.messageCache.AddMessage(newMessageWithTimestamp("mytopic1", "test 3", 1655740289))) + require.Nil(t, s.messageCache.AddMessage(newMessageWithTimestamp("mytopic2", "test 4", 1655740293))) + require.Nil(t, s.messageCache.AddMessage(newMessageWithTimestamp("mytopic1", "test 5", 1655740297))) + require.Nil(t, s.messageCache.AddMessage(newMessageWithTimestamp("mytopic2", "test 6", 1655740303))) + + response := request(t, s, "GET", "/mytopic1,mytopic2/json?poll=1&since=NoMatchForID", "", nil) + messages := toMessages(t, response.Body.String()) + require.Equal(t, 4, len(messages)) + require.Equal(t, "test 3", messages[0].Message) + require.Equal(t, "test 4", messages[1].Message) + require.Equal(t, "test 5", messages[2].Message) + require.Equal(t, "test 6", messages[3].Message) +} + func TestServer_PublishViaGET(t *testing.T) { s := newTestServer(t, newTestConfig(t))