From e780b176220325ef390ef3283d252a725f8b304a Mon Sep 17 00:00:00 2001
From: Philipp Heckel <pheckel@datto.com>
Date: Mon, 16 May 2022 13:45:54 -0400
Subject: [PATCH] Update notifications; canary notifications

---
 server/errors.go        |  2 ++
 server/message_cache.go | 46 +++++++++++++++++++++++--
 server/server.go        | 75 +++++++++++++++++++++++++++++++++++------
 3 files changed, 111 insertions(+), 12 deletions(-)

diff --git a/server/errors.go b/server/errors.go
index 32c1b3b9..97af5472 100644
--- a/server/errors.go
+++ b/server/errors.go
@@ -50,7 +50,9 @@ var (
 	errHTTPBadRequestWebSocketsUpgradeHeaderMissing  = &errHTTP{40016, http.StatusBadRequest, "invalid request: client not using the websocket protocol", "https://ntfy.sh/docs/subscribe/api/#websockets"}
 	errHTTPBadRequestJSONInvalid                     = &errHTTP{40017, http.StatusBadRequest, "invalid request: request body must be message JSON", "https://ntfy.sh/docs/publish/#publish-as-json"}
 	errHTTPBadRequestActionsInvalid                  = &errHTTP{40018, http.StatusBadRequest, "invalid request: actions invalid", "https://ntfy.sh/docs/publish/#action-buttons"}
+	errHTTPBadRequestDelayExpected                   = &errHTTP{40019, http.StatusBadRequest, "invalid request: expected delay in request, but none found", ""}
 	errHTTPNotFound                                  = &errHTTP{40401, http.StatusNotFound, "page not found", ""}
+	errHTTPNotFoundMessageDoesNotExist               = &errHTTP{40402, http.StatusNotFound, "message not found", ""}
 	errHTTPUnauthorized                              = &errHTTP{40101, http.StatusUnauthorized, "unauthorized", "https://ntfy.sh/docs/publish/#authentication"}
 	errHTTPForbidden                                 = &errHTTP{40301, http.StatusForbidden, "forbidden", "https://ntfy.sh/docs/publish/#authentication"}
 	errHTTPEntityTooLargeAttachmentTooLarge          = &errHTTP{41301, http.StatusRequestEntityTooLarge, "attachment too large, or bandwidth limit reached", "https://ntfy.sh/docs/publish/#limitations"}
diff --git a/server/message_cache.go b/server/message_cache.go
index b55c34ba..bdc36c4c 100644
--- a/server/message_cache.go
+++ b/server/message_cache.go
@@ -48,6 +48,11 @@ const (
 		INSERT INTO messages (mid, time, topic, message, title, priority, tags, click, actions, attachment_name, attachment_type, attachment_size, attachment_expires, attachment_url, attachment_owner, encoding, published) 
 		VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
 	`
+	updateMessageQuery = `
+		UPDATE messages 
+		SET time = ?
+		WHERE topic = ? AND mid = ? AND published = 0 
+	`
 	pruneMessagesQuery           = `DELETE FROM messages WHERE time < ? AND published = 1`
 	selectRowIDFromMessageID     = `SELECT id FROM messages WHERE topic = ? AND mid = ?`
 	selectMessagesSinceTimeQuery = `
@@ -80,6 +85,11 @@ const (
 		WHERE time <= ? AND published = 0
 		ORDER BY time, id
 	`
+	selectMessagesScheduledByTagOrID = `
+		SELECT mid, time, topic, message, title, priority, tags, click, actions, attachment_name, attachment_type, attachment_size, attachment_expires, attachment_url, attachment_owner, encoding
+		FROM messages 
+		WHERE topic = ? AND (tags LIKE ? OR mid = ?) AND published = 0
+	`
 	updateMessagePublishedQuery     = `UPDATE messages SET published = 1 WHERE mid = ?`
 	selectMessagesCountQuery        = `SELECT COUNT(*) FROM messages`
 	selectMessageCountForTopicQuery = `SELECT COUNT(*) FROM messages WHERE topic = ?`
@@ -219,8 +229,7 @@ func createMemoryFilename() string {
 func (c *messageCache) AddMessage(m *message) error {
 	if m.Event != messageEvent {
 		return errUnexpectedMessageType
-	}
-	if c.nop {
+	} else if c.nop {
 		return nil
 	}
 	published := m.Time <= time.Now().Unix()
@@ -266,6 +275,21 @@ func (c *messageCache) AddMessage(m *message) error {
 	return err
 }
 
+func (c *messageCache) UpdateMessage(m *message) error {
+	if m.Event != messageEvent {
+		return errUnexpectedMessageType
+	} else if c.nop {
+		return nil
+	}
+	_, err := c.db.Exec(
+		updateMessageQuery,
+		m.Time,
+		m.Topic,
+		m.ID,
+	)
+	return err
+}
+
 func (c *messageCache) Messages(topic string, since sinceMarker, scheduled bool) ([]*message, error) {
 	if since.IsNone() {
 		return make([]*message, 0), nil
@@ -409,6 +433,24 @@ func (c *messageCache) AttachmentsExpired() ([]string, error) {
 	return ids, nil
 }
 
+func (c *messageCache) MessagesScheduledByTagOrID(topic, selector string) ([]*message, error) {
+	rows, err := c.db.Query(selectMessagesScheduledByTagOrID, topic, "%"+selector+"%", selector) // Ugly string matching search first, later match exactly
+	if err != nil {
+		return nil, err
+	}
+	maybeMatchingMessages, err := readMessages(rows)
+	if err != nil {
+		return nil, err
+	}
+	messages := make([]*message, 0)
+	for _, m := range maybeMatchingMessages {
+		if util.InStringList(m.Tags, selector) || m.ID == selector {
+			messages = append(messages, m)
+		}
+	}
+	return messages, nil
+}
+
 func readMessages(rows *sql.Rows) ([]*message, error) {
 	defer rows.Close()
 	messages := make([]*message, 0)
diff --git a/server/server.go b/server/server.go
index 1a643c23..738fba0b 100644
--- a/server/server.go
+++ b/server/server.go
@@ -56,8 +56,9 @@ type handleFunc func(http.ResponseWriter, *http.Request, *visitor) error
 
 var (
 	// If changed, don't forget to update Android App and auth_sqlite.go
-	topicRegex             = regexp.MustCompile(`^[-_A-Za-z0-9]{1,64}$`)               // No /!
-	topicPathRegex         = regexp.MustCompile(`^/[-_A-Za-z0-9]{1,64}$`)              // Regex must match JS & Android app!
+	topicRegex             = regexp.MustCompile(`^[-_A-Za-z0-9]{1,64}$`)  // No /!
+	topicPathRegex         = regexp.MustCompile(`^/[-_A-Za-z0-9]{1,64}$`) // Regex must match JS & Android app!
+	updatePathRegex        = regexp.MustCompile(`^/[-_A-Za-z0-9]{1,64}/[^/]+$`)
 	externalTopicPathRegex = regexp.MustCompile(`^/[^/]+\.[^/]+/[-_A-Za-z0-9]{1,64}$`) // Extended topic path, for web-app, e.g. /example.com/mytopic
 	jsonPathRegex          = regexp.MustCompile(`^/[-_A-Za-z0-9]{1,64}(,[-_A-Za-z0-9]{1,64})*/json$`)
 	ssePathRegex           = regexp.MustCompile(`^/[-_A-Za-z0-9]{1,64}(,[-_A-Za-z0-9]{1,64})*/sse$`)
@@ -287,6 +288,8 @@ func (s *Server) handleInternal(w http.ResponseWriter, r *http.Request, v *visit
 		return s.limitRequests(s.authWrite(s.handlePublish))(w, r, v)
 	} else if r.Method == http.MethodGet && publishPathRegex.MatchString(r.URL.Path) {
 		return s.limitRequests(s.authWrite(s.handlePublish))(w, r, v)
+	} else if (r.Method == http.MethodPut || r.Method == http.MethodPost) && updatePathRegex.MatchString(r.URL.Path) {
+		return s.limitRequests(s.authWrite(s.handleUpdate))(w, r, v)
 	} else if r.Method == http.MethodGet && jsonPathRegex.MatchString(r.URL.Path) {
 		return s.limitRequests(s.authRead(s.handleSubscribeJSON))(w, r, v)
 	} else if r.Method == http.MethodGet && ssePathRegex.MatchString(r.URL.Path) {
@@ -518,7 +521,7 @@ func (s *Server) parsePublishParams(r *http.Request, v *visitor, m *message) (ca
 			m.Tags = append(m.Tags, strings.TrimSpace(s))
 		}
 	}
-	delayStr := readParam(r, "x-delay", "delay", "x-at", "at", "x-in", "in")
+	delayStr := readDelayParam(r)
 	if delayStr != "" {
 		if !cache {
 			return false, false, "", false, errHTTPBadRequestDelayNoCache
@@ -526,15 +529,11 @@ func (s *Server) parsePublishParams(r *http.Request, v *visitor, m *message) (ca
 		if email != "" {
 			return false, false, "", false, errHTTPBadRequestDelayNoEmail // we cannot store the email address (yet)
 		}
-		delay, err := util.ParseFutureTime(delayStr, time.Now())
+		futureTime, err := s.parseDelay(delayStr)
 		if err != nil {
-			return false, false, "", false, errHTTPBadRequestDelayCannotParse
-		} else if delay.Unix() < time.Now().Add(s.config.MinDelay).Unix() {
-			return false, false, "", false, errHTTPBadRequestDelayTooSmall
-		} else if delay.Unix() > time.Now().Add(s.config.MaxDelay).Unix() {
-			return false, false, "", false, errHTTPBadRequestDelayTooLarge
+			return false, false, "", false, err
 		}
-		m.Time = delay.Unix()
+		m.Time = futureTime
 	}
 	actionsStr := readParam(r, "x-actions", "actions", "action")
 	if actionsStr != "" {
@@ -551,6 +550,22 @@ func (s *Server) parsePublishParams(r *http.Request, v *visitor, m *message) (ca
 	return cache, firebase, email, unifiedpush, nil
 }
 
+func readDelayParam(r *http.Request) string {
+	return readParam(r, "x-delay", "delay", "x-at", "at", "x-in", "in")
+}
+
+func (s *Server) parseDelay(delayStr string) (int64, error) {
+	futureTime, err := util.ParseFutureTime(delayStr, time.Now())
+	if err != nil {
+		return 0, errHTTPBadRequestDelayCannotParse
+	} else if futureTime.Unix() < time.Now().Add(s.config.MinDelay).Unix() {
+		return 0, errHTTPBadRequestDelayTooSmall
+	} else if futureTime.Unix() > time.Now().Add(s.config.MaxDelay).Unix() {
+		return 0, errHTTPBadRequestDelayTooLarge
+	}
+	return futureTime.Unix(), nil
+}
+
 // handlePublishBody consumes the PUT/POST body and decides whether the body is an attachment or the message.
 //
 // 1. curl -T somebinarydata.bin "ntfy.sh/mytopic?up=1"
@@ -639,6 +654,46 @@ func (s *Server) handleBodyAsAttachment(r *http.Request, v *visitor, m *message,
 	return nil
 }
 
+func (s *Server) handleUpdate(w http.ResponseWriter, r *http.Request, v *visitor) error {
+	// Parse updatable params
+	parts := strings.Split(r.URL.Path, "/")
+	if len(parts) < 3 {
+		return errHTTPBadRequestTopicInvalid
+	}
+	t := parts[1]
+	selector := parts[2]
+	delayStr := readDelayParam(r)
+	if delayStr == "" {
+		return errHTTPBadRequestDelayExpected
+	}
+	futureTime, err := s.parseDelay(delayStr)
+	if err != nil {
+		return err
+	}
+
+	// Update matching message(s) and print them
+	messages, err := s.messageCache.MessagesScheduledByTagOrID(t, selector)
+	if err != nil {
+		return err
+	} else if len(messages) == 0 {
+		return s.handlePublish(w, r, v) // If no messages found, publish a new one!
+	}
+	for _, m := range messages {
+		m.Time = futureTime
+		if err := s.messageCache.UpdateMessage(m); err != nil {
+			return err
+		}
+	}
+	w.Header().Set("Content-Type", "application/json")
+	w.Header().Set("Access-Control-Allow-Origin", "*") // CORS, allow cross-origin requests
+	for _, m := range messages {
+		if err := json.NewEncoder(w).Encode(m); err != nil {
+			return err
+		}
+	}
+	return nil
+}
+
 func (s *Server) handleSubscribeJSON(w http.ResponseWriter, r *http.Request, v *visitor) error {
 	encoder := func(msg *message) (string, error) {
 		var buf bytes.Buffer