Plan-based message and attachment expiry

pull/584/head
binwiederhier 2023-01-07 09:34:02 -05:00
parent ac4042ca04
commit a54a11db88
13 changed files with 222 additions and 122 deletions

View File

@ -57,6 +57,7 @@ var (
errHTTPBadRequestNoTokenProvided = &errHTTP{40023, http.StatusBadRequest, "invalid request: no token provided", ""} errHTTPBadRequestNoTokenProvided = &errHTTP{40023, http.StatusBadRequest, "invalid request: no token provided", ""}
errHTTPBadRequestJSONInvalid = &errHTTP{40024, http.StatusBadRequest, "invalid request: request body must be valid JSON", ""} errHTTPBadRequestJSONInvalid = &errHTTP{40024, http.StatusBadRequest, "invalid request: request body must be valid JSON", ""}
errHTTPBadRequestPermissionInvalid = &errHTTP{40025, http.StatusBadRequest, "invalid request: incorrect permission string", ""} errHTTPBadRequestPermissionInvalid = &errHTTP{40025, http.StatusBadRequest, "invalid request: incorrect permission string", ""}
errHTTPBadRequestMakesNoSenseForAdmin = &errHTTP{40026, http.StatusBadRequest, "invalid request: this makes no sense for admins", ""}
errHTTPNotFound = &errHTTP{40401, http.StatusNotFound, "page not found", ""} errHTTPNotFound = &errHTTP{40401, http.StatusNotFound, "page not found", ""}
errHTTPUnauthorized = &errHTTP{40101, http.StatusUnauthorized, "unauthorized", "https://ntfy.sh/docs/publish/#authentication"} errHTTPUnauthorized = &errHTTP{40101, http.StatusUnauthorized, "unauthorized", "https://ntfy.sh/docs/publish/#authentication"}
errHTTPForbidden = &errHTTP{40301, http.StatusForbidden, "forbidden", "https://ntfy.sh/docs/publish/#authentication"} errHTTPForbidden = &errHTTP{40301, http.StatusForbidden, "forbidden", "https://ntfy.sh/docs/publish/#authentication"}

View File

@ -3,13 +3,13 @@ package server
import ( import (
"errors" "errors"
"fmt" "fmt"
"heckel.io/ntfy/log"
"heckel.io/ntfy/util" "heckel.io/ntfy/util"
"io" "io"
"os" "os"
"path/filepath" "path/filepath"
"regexp" "regexp"
"sync" "sync"
"time"
) )
var ( var (
@ -75,8 +75,11 @@ func (c *fileCache) Remove(ids ...string) error {
if !fileIDRegex.MatchString(id) { if !fileIDRegex.MatchString(id) {
return errInvalidFileID return errInvalidFileID
} }
log.Debug("File Cache: Deleting attachment %s", id)
file := filepath.Join(c.dir, id) file := filepath.Join(c.dir, id)
_ = os.Remove(file) // Best effort delete if err := os.Remove(file); err != nil {
log.Debug("File Cache: Error deleting attachment %s: %s", id, err.Error())
}
} }
size, err := dirSize(c.dir) size, err := dirSize(c.dir)
if err != nil { if err != nil {
@ -88,25 +91,6 @@ func (c *fileCache) Remove(ids ...string) error {
return nil return nil
} }
// Expired returns a list of file IDs for expired files
func (c *fileCache) Expired(olderThan time.Time) ([]string, error) {
entries, err := os.ReadDir(c.dir)
if err != nil {
return nil, err
}
var ids []string
for _, e := range entries {
info, err := e.Info()
if err != nil {
continue
}
if info.ModTime().Before(olderThan) && fileIDRegex.MatchString(e.Name()) {
ids = append(ids, e.Name())
}
}
return ids, nil
}
func (c *fileCache) Size() int64 { func (c *fileCache) Size() int64 {
c.mu.Lock() c.mu.Lock()
defer c.mu.Unlock() defer c.mu.Unlock()

View File

@ -8,7 +8,6 @@ import (
"os" "os"
"strings" "strings"
"testing" "testing"
"time"
) )
var ( var (
@ -63,29 +62,6 @@ func TestFileCache_Write_FailedAdditionalLimiter(t *testing.T) {
require.NoFileExists(t, dir+"/abcdefghijkl") require.NoFileExists(t, dir+"/abcdefghijkl")
} }
func TestFileCache_RemoveExpired(t *testing.T) {
dir, c := newTestFileCache(t)
_, err := c.Write("abcdefghijkl", bytes.NewReader(make([]byte, 1001)))
require.Nil(t, err)
_, err = c.Write("notdeleted12", bytes.NewReader(make([]byte, 1001)))
require.Nil(t, err)
modTime := time.Now().Add(-1 * 4 * time.Hour)
require.Nil(t, os.Chtimes(dir+"/abcdefghijkl", modTime, modTime))
olderThan := time.Now().Add(-1 * 3 * time.Hour)
ids, err := c.Expired(olderThan)
require.Nil(t, err)
require.Equal(t, []string{"abcdefghijkl"}, ids)
require.Nil(t, c.Remove(ids...))
require.NoFileExists(t, dir+"/abcdefghijkl")
require.FileExists(t, dir+"/notdeleted12")
ids, err = c.Expired(olderThan)
require.Nil(t, err)
require.Empty(t, ids)
}
func newTestFileCache(t *testing.T) (dir string, cache *fileCache) { func newTestFileCache(t *testing.T) (dir string, cache *fileCache) {
dir = t.TempDir() dir = t.TempDir()
cache, err := newFileCache(dir, 10*1024) cache, err := newFileCache(dir, 10*1024)

View File

@ -26,6 +26,7 @@ const (
id INTEGER PRIMARY KEY AUTOINCREMENT, id INTEGER PRIMARY KEY AUTOINCREMENT,
mid TEXT NOT NULL, mid TEXT NOT NULL,
time INT NOT NULL, time INT NOT NULL,
expires INT NOT NULL,
topic TEXT NOT NULL, topic TEXT NOT NULL,
message TEXT NOT NULL, message TEXT NOT NULL,
title TEXT NOT NULL, title TEXT NOT NULL,
@ -39,6 +40,7 @@ const (
attachment_size INT NOT NULL, attachment_size INT NOT NULL,
attachment_expires INT NOT NULL, attachment_expires INT NOT NULL,
attachment_url TEXT NOT NULL, attachment_url TEXT NOT NULL,
attachment_deleted INT NOT NULL,
sender TEXT NOT NULL, sender TEXT NOT NULL,
user TEXT NOT NULL, user TEXT NOT NULL,
encoding TEXT NOT NULL, encoding TEXT NOT NULL,
@ -47,48 +49,59 @@ const (
CREATE INDEX IF NOT EXISTS idx_mid ON messages (mid); CREATE INDEX IF NOT EXISTS idx_mid ON messages (mid);
CREATE INDEX IF NOT EXISTS idx_time ON messages (time); CREATE INDEX IF NOT EXISTS idx_time ON messages (time);
CREATE INDEX IF NOT EXISTS idx_topic ON messages (topic); CREATE INDEX IF NOT EXISTS idx_topic ON messages (topic);
CREATE INDEX IF NOT EXISTS idx_expires ON messages (expires);
CREATE INDEX IF NOT EXISTS idx_attachment_expires ON messages (attachment_expires);
COMMIT; COMMIT;
` `
insertMessageQuery = ` insertMessageQuery = `
INSERT INTO messages (mid, time, topic, message, title, priority, tags, click, icon, actions, attachment_name, attachment_type, attachment_size, attachment_expires, attachment_url, sender, user, encoding, published) INSERT INTO messages (mid, time, expires, topic, message, title, priority, tags, click, icon, actions, attachment_name, attachment_type, attachment_size, attachment_expires, attachment_url, attachment_deleted, sender, user, encoding, published)
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
` `
pruneMessagesQuery = `DELETE FROM messages WHERE time < ? AND published = 1` deleteMessageQuery = `DELETE FROM messages WHERE mid = ?`
selectRowIDFromMessageID = `SELECT id FROM messages WHERE mid = ?` // Do not include topic, see #336 and TestServer_PollSinceID_MultipleTopics selectRowIDFromMessageID = `SELECT id FROM messages WHERE mid = ?` // Do not include topic, see #336 and TestServer_PollSinceID_MultipleTopics
selectMessagesSinceTimeQuery = ` selectMessagesSinceTimeQuery = `
SELECT mid, time, topic, message, title, priority, tags, click, icon, actions, attachment_name, attachment_type, attachment_size, attachment_expires, attachment_url, sender, user, encoding SELECT mid, time, expires, topic, message, title, priority, tags, click, icon, actions, attachment_name, attachment_type, attachment_size, attachment_expires, attachment_url, sender, user, encoding
FROM messages FROM messages
WHERE topic = ? AND time >= ? AND published = 1 WHERE topic = ? AND time >= ? AND published = 1
ORDER BY time, id ORDER BY time, id
` `
selectMessagesSinceTimeIncludeScheduledQuery = ` selectMessagesSinceTimeIncludeScheduledQuery = `
SELECT mid, time, topic, message, title, priority, tags, click, icon, actions, attachment_name, attachment_type, attachment_size, attachment_expires, attachment_url, sender, user, encoding SELECT mid, time, expires, topic, message, title, priority, tags, click, icon, actions, attachment_name, attachment_type, attachment_size, attachment_expires, attachment_url, sender, user, encoding
FROM messages FROM messages
WHERE topic = ? AND time >= ? WHERE topic = ? AND time >= ?
ORDER BY time, id ORDER BY time, id
` `
selectMessagesSinceIDQuery = ` selectMessagesSinceIDQuery = `
SELECT mid, time, topic, message, title, priority, tags, click, icon, actions, attachment_name, attachment_type, attachment_size, attachment_expires, attachment_url, sender, user, encoding SELECT mid, time, expires, topic, message, title, priority, tags, click, icon, actions, attachment_name, attachment_type, attachment_size, attachment_expires, attachment_url, sender, user, encoding
FROM messages FROM messages
WHERE topic = ? AND id > ? AND published = 1 WHERE topic = ? AND id > ? AND published = 1
ORDER BY time, id ORDER BY time, id
` `
selectMessagesSinceIDIncludeScheduledQuery = ` selectMessagesSinceIDIncludeScheduledQuery = `
SELECT mid, time, topic, message, title, priority, tags, click, icon, actions, attachment_name, attachment_type, attachment_size, attachment_expires, attachment_url, sender, user, encoding SELECT mid, time, expires, topic, message, title, priority, tags, click, icon, actions, attachment_name, attachment_type, attachment_size, attachment_expires, attachment_url, sender, user, encoding
FROM messages FROM messages
WHERE topic = ? AND (id > ? OR published = 0) WHERE topic = ? AND (id > ? OR published = 0)
ORDER BY time, id ORDER BY time, id
` `
selectMessagesDueQuery = ` selectMessagesDueQuery = `
SELECT mid, time, topic, message, title, priority, tags, click, icon, actions, attachment_name, attachment_type, attachment_size, attachment_expires, attachment_url, sender, user, encoding SELECT mid, time, expires, topic, message, title, priority, tags, click, icon, actions, attachment_name, attachment_type, attachment_size, attachment_expires, attachment_url, sender, user, encoding
FROM messages FROM messages
WHERE time <= ? AND published = 0 WHERE time <= ? AND published = 0
ORDER BY time, id ORDER BY time, id
` `
updateMessagePublishedQuery = `UPDATE messages SET published = 1 WHERE mid = ?` selectMessagesExpiredQuery = `
selectMessagesCountQuery = `SELECT COUNT(*) FROM messages` SELECT mid, time, expires, topic, message, title, priority, tags, click, icon, actions, attachment_name, attachment_type, attachment_size, attachment_expires, attachment_url, sender, user, encoding
selectMessageCountPerTopicQuery = `SELECT topic, COUNT(*) FROM messages GROUP BY topic` FROM messages
selectTopicsQuery = `SELECT topic FROM messages GROUP BY topic` WHERE expires <= ? AND published = 1
ORDER BY time, id
`
updateMessagePublishedQuery = `UPDATE messages SET published = 1 WHERE mid = ?`
selectMessagesCountQuery = `SELECT COUNT(*) FROM messages`
selectMessageCountPerTopicQuery = `SELECT topic, COUNT(*) FROM messages GROUP BY topic`
selectTopicsQuery = `SELECT topic FROM messages GROUP BY topic`
updateAttachmentDeleted = `UPDATE messages SET attachment_deleted = 1 WHERE mid = ?`
selectAttachmentsExpiredQuery = `SELECT mid FROM messages WHERE attachment_expires <= ? AND attachment_deleted = 0`
selectAttachmentsSizeBySenderQuery = `SELECT IFNULL(SUM(attachment_size), 0) FROM messages WHERE sender = ? AND attachment_expires >= ?` selectAttachmentsSizeBySenderQuery = `SELECT IFNULL(SUM(attachment_size), 0) FROM messages WHERE sender = ? AND attachment_expires >= ?`
selectAttachmentsSizeByUserQuery = `SELECT IFNULL(SUM(attachment_size), 0) FROM messages WHERE user = ? AND attachment_expires >= ?` selectAttachmentsSizeByUserQuery = `SELECT IFNULL(SUM(attachment_size), 0) FROM messages WHERE user = ? AND attachment_expires >= ?`
) )
@ -197,6 +210,10 @@ const (
// 9 -> 10 // 9 -> 10
migrate9To10AlterMessagesTableQuery = ` migrate9To10AlterMessagesTableQuery = `
ALTER TABLE messages ADD COLUMN user TEXT NOT NULL DEFAULT(''); ALTER TABLE messages ADD COLUMN user TEXT NOT NULL DEFAULT('');
ALTER TABLE messages ADD COLUMN attachment_deleted INT NOT NULL DEFAULT('0');
ALTER TABLE messages ADD COLUMN expires INT NOT NULL DEFAULT('0');
CREATE INDEX IF NOT EXISTS idx_expires ON messages (expires);
CREATE INDEX IF NOT EXISTS idx_attachment_expires ON messages (attachment_expires);
` `
) )
@ -286,7 +303,7 @@ func (c *messageCache) addMessages(ms []*message) error {
published := m.Time <= time.Now().Unix() published := m.Time <= time.Now().Unix()
tags := strings.Join(m.Tags, ",") tags := strings.Join(m.Tags, ",")
var attachmentName, attachmentType, attachmentURL string var attachmentName, attachmentType, attachmentURL string
var attachmentSize, attachmentExpires int64 var attachmentSize, attachmentExpires, attachmentDeleted int64
if m.Attachment != nil { if m.Attachment != nil {
attachmentName = m.Attachment.Name attachmentName = m.Attachment.Name
attachmentType = m.Attachment.Type attachmentType = m.Attachment.Type
@ -309,6 +326,7 @@ func (c *messageCache) addMessages(ms []*message) error {
_, err := stmt.Exec( _, err := stmt.Exec(
m.ID, m.ID,
m.Time, m.Time,
m.Expires,
m.Topic, m.Topic,
m.Message, m.Message,
m.Title, m.Title,
@ -322,6 +340,7 @@ func (c *messageCache) addMessages(ms []*message) error {
attachmentSize, attachmentSize,
attachmentExpires, attachmentExpires,
attachmentURL, attachmentURL,
attachmentDeleted, // Always zero
sender, sender,
m.User, m.User,
m.Encoding, m.Encoding,
@ -332,10 +351,10 @@ func (c *messageCache) addMessages(ms []*message) error {
} }
} }
if err := tx.Commit(); err != nil { if err := tx.Commit(); err != nil {
log.Error("Cache: Writing %d message(s) failed (took %v)", len(ms), time.Since(start)) log.Error("Message Cache: Writing %d message(s) failed (took %v)", len(ms), time.Since(start))
return err return err
} }
log.Debug("Cache: Wrote %d message(s) in %v", len(ms), time.Since(start)) log.Debug("Message Cache: Wrote %d message(s) in %v", len(ms), time.Since(start))
return nil return nil
} }
@ -396,6 +415,14 @@ func (c *messageCache) MessagesDue() ([]*message, error) {
return readMessages(rows) return readMessages(rows)
} }
func (c *messageCache) MessagesExpired() ([]*message, error) {
rows, err := c.db.Query(selectMessagesExpiredQuery, time.Now().Unix())
if err != nil {
return nil, err
}
return readMessages(rows)
}
func (c *messageCache) MarkPublished(m *message) error { func (c *messageCache) MarkPublished(m *message) error {
_, err := c.db.Exec(updateMessagePublishedQuery, m.ID) _, err := c.db.Exec(updateMessagePublishedQuery, m.ID)
return err return err
@ -441,13 +468,52 @@ func (c *messageCache) Topics() (map[string]*topic, error) {
return topics, nil return topics, nil
} }
func (c *messageCache) Prune(olderThan time.Time) error { func (c *messageCache) DeleteMessages(ids ...string) error {
start := time.Now() tx, err := c.db.Begin()
if _, err := c.db.Exec(pruneMessagesQuery, olderThan.Unix()); err != nil { if err != nil {
log.Warn("Cache: Pruning failed (after %v): %s", time.Since(start), err.Error()) return err
} }
log.Debug("Cache: Pruning successful (took %v)", time.Since(start)) defer tx.Rollback()
return nil for _, id := range ids {
if _, err := tx.Exec(deleteMessageQuery, id); err != nil {
return err
}
}
return tx.Commit()
}
func (c *messageCache) AttachmentsExpired() ([]string, error) {
rows, err := c.db.Query(selectAttachmentsExpiredQuery, time.Now().Unix())
if err != nil {
return nil, err
}
defer rows.Close()
ids := make([]string, 0)
for rows.Next() {
var id string
if err := rows.Scan(&id); err != nil {
return nil, err
}
ids = append(ids, id)
}
if err := rows.Err(); err != nil {
return nil, err
}
return ids, nil
}
func (c *messageCache) MarkAttachmentsDeleted(ids []string) error {
tx, err := c.db.Begin()
if err != nil {
return err
}
defer tx.Rollback()
for _, id := range ids {
if _, err := tx.Exec(updateAttachmentDeleted, id); err != nil {
return err
}
}
return tx.Commit()
} }
func (c *messageCache) AttachmentBytesUsedBySender(sender string) (int64, error) { func (c *messageCache) AttachmentBytesUsedBySender(sender string) (int64, error) {
@ -486,7 +552,7 @@ func (c *messageCache) processMessageBatches() {
} }
for messages := range c.queue.Dequeue() { for messages := range c.queue.Dequeue() {
if err := c.addMessages(messages); err != nil { if err := c.addMessages(messages); err != nil {
log.Error("Cache: %s", err.Error()) log.Error("Message Cache: %s", err.Error())
} }
} }
} }
@ -495,12 +561,13 @@ func readMessages(rows *sql.Rows) ([]*message, error) {
defer rows.Close() defer rows.Close()
messages := make([]*message, 0) messages := make([]*message, 0)
for rows.Next() { for rows.Next() {
var timestamp, attachmentSize, attachmentExpires int64 var timestamp, expires, attachmentSize, attachmentExpires int64
var priority int var priority int
var id, topic, msg, title, tagsStr, click, icon, actionsStr, attachmentName, attachmentType, attachmentURL, sender, user, encoding string var id, topic, msg, title, tagsStr, click, icon, actionsStr, attachmentName, attachmentType, attachmentURL, sender, user, encoding string
err := rows.Scan( err := rows.Scan(
&id, &id,
&timestamp, &timestamp,
&expires,
&topic, &topic,
&msg, &msg,
&title, &title,
@ -548,6 +615,7 @@ func readMessages(rows *sql.Rows) ([]*message, error) {
messages = append(messages, &message{ messages = append(messages, &message{
ID: id, ID: id,
Time: timestamp, Time: timestamp,
Expires: expires,
Event: messageEvent, Event: messageEvent,
Topic: topic, Topic: topic,
Message: msg, Message: msg,
@ -742,10 +810,19 @@ func migrateFrom8(db *sql.DB) error {
func migrateFrom9(db *sql.DB) error { func migrateFrom9(db *sql.DB) error {
log.Info("Migrating cache database schema: from 9 to 10") log.Info("Migrating cache database schema: from 9 to 10")
if _, err := db.Exec(migrate9To10AlterMessagesTableQuery); err != nil { tx, err := db.Begin()
if err != nil {
return err return err
} }
if _, err := db.Exec(updateSchemaVersion, 10); err != nil { defer tx.Rollback()
if _, err := tx.Exec(migrate9To10AlterMessagesTableQuery); err != nil {
return err
}
// FIXME add logic to set "expires" column
if _, err := tx.Exec(updateSchemaVersion, 10); err != nil {
return err
}
if err := tx.Commit(); err != nil {
return err return err
} }
return nil // Update this when a new version is added return nil // Update this when a new version is added

View File

@ -247,26 +247,40 @@ func TestMemCache_Prune(t *testing.T) {
} }
func testCachePrune(t *testing.T, c *messageCache) { func testCachePrune(t *testing.T, c *messageCache) {
now := time.Now().Unix()
m1 := newDefaultMessage("mytopic", "my message") m1 := newDefaultMessage("mytopic", "my message")
m1.Time = 1 m1.Time = now - 10
m1.Expires = now - 5
m2 := newDefaultMessage("mytopic", "my other message") m2 := newDefaultMessage("mytopic", "my other message")
m2.Time = 2 m2.Time = now - 5
m2.Expires = now + 5 // In the future
m3 := newDefaultMessage("another_topic", "and another one") m3 := newDefaultMessage("another_topic", "and another one")
m3.Time = 1 m3.Time = now - 12
m3.Expires = now - 2
require.Nil(t, c.AddMessage(m1)) require.Nil(t, c.AddMessage(m1))
require.Nil(t, c.AddMessage(m2)) require.Nil(t, c.AddMessage(m2))
require.Nil(t, c.AddMessage(m3)) require.Nil(t, c.AddMessage(m3))
require.Nil(t, c.Prune(time.Unix(2, 0)))
counts, err := c.MessageCounts() counts, err := c.MessageCounts()
require.Nil(t, err) require.Nil(t, err)
require.Equal(t, 1, counts["mytopic"]) require.Equal(t, 2, counts["mytopic"])
require.Equal(t, 1, counts["another_topic"])
expiredMessages, err := c.MessagesExpired()
require.Nil(t, err)
ids := make([]string, 0)
for _, m := range expiredMessages {
ids = append(ids, m.ID)
}
require.Nil(t, c.DeleteMessages(ids...))
counts, err = c.MessageCounts() counts, err = c.MessageCounts()
require.Nil(t, err) require.Nil(t, err)
require.Equal(t, 1, counts["mytopic"])
require.Equal(t, 0, counts["another_topic"]) require.Equal(t, 0, counts["another_topic"])
messages, err := c.Messages("mytopic", sinceAllMessages, false) messages, err := c.Messages("mytopic", sinceAllMessages, false)

View File

@ -37,9 +37,6 @@ import (
/* /*
TODO TODO
limits & rate limiting: limits & rate limiting:
message cache duration
Keep 10000 messages or keep X days?
Attachment expiration based on plan
login/account endpoints login/account endpoints
plan: plan:
weirdness with admin and "default" account weirdness with admin and "default" account
@ -57,6 +54,8 @@ import (
- figure out what settings are "web" or "phone" - figure out what settings are "web" or "phone"
Tests: Tests:
- visitor with/without user - visitor with/without user
- plan-based message expiry
- plan-based attachment expiry
Refactor: Refactor:
- rename TopicsLimit -> ReservationsLimit - rename TopicsLimit -> ReservationsLimit
- rename /access -> /reservation - rename /access -> /reservation
@ -544,6 +543,11 @@ func (s *Server) handlePublishWithoutResponse(r *http.Request, v *visitor) (*mes
if v.user != nil { if v.user != nil {
m.User = v.user.Name m.User = v.user.Name
} }
if v.user != nil && v.user.Plan != nil {
m.Expires = time.Now().Unix() + v.user.Plan.MessagesExpiryDuration
} else {
m.Expires = time.Now().Add(s.config.CacheDuration).Unix()
}
if err := s.handlePublishBody(r, v, m, body, unifiedpush); err != nil { if err := s.handlePublishBody(r, v, m, body, unifiedpush); err != nil {
return nil, err return nil, err
} }
@ -815,7 +819,15 @@ func (s *Server) handleBodyAsTextMessage(m *message, body *util.PeekedReadCloser
func (s *Server) handleBodyAsAttachment(r *http.Request, v *visitor, m *message, body *util.PeekedReadCloser) error { func (s *Server) handleBodyAsAttachment(r *http.Request, v *visitor, m *message, body *util.PeekedReadCloser) error {
if s.fileCache == nil || s.config.BaseURL == "" || s.config.AttachmentCacheDir == "" { if s.fileCache == nil || s.config.BaseURL == "" || s.config.AttachmentCacheDir == "" {
return errHTTPBadRequestAttachmentsDisallowed return errHTTPBadRequestAttachmentsDisallowed
} else if m.Time > time.Now().Add(s.config.AttachmentExpiryDuration).Unix() { }
var attachmentExpiryDuration time.Duration
if v.user != nil && v.user.Plan != nil {
attachmentExpiryDuration = time.Duration(v.user.Plan.AttachmentExpiryDuration) * time.Second
} else {
attachmentExpiryDuration = s.config.AttachmentExpiryDuration
}
attachmentExpiry := time.Now().Add(attachmentExpiryDuration).Unix()
if m.Time > attachmentExpiry {
return errHTTPBadRequestAttachmentsExpiryBeforeDelivery return errHTTPBadRequestAttachmentsExpiryBeforeDelivery
} }
stats, err := v.Info() stats, err := v.Info()
@ -834,7 +846,7 @@ func (s *Server) handleBodyAsAttachment(r *http.Request, v *visitor, m *message,
} }
var ext string var ext string
m.Sender = v.ip // Important for attachment rate limiting m.Sender = v.ip // Important for attachment rate limiting
m.Attachment.Expires = time.Now().Add(s.config.AttachmentExpiryDuration).Unix() m.Attachment.Expires = attachmentExpiry
m.Attachment.Type, ext = util.DetectContentType(body.PeekedBytes, m.Attachment.Name) m.Attachment.Type, ext = util.DetectContentType(body.PeekedBytes, m.Attachment.Name)
m.Attachment.URL = fmt.Sprintf("%s/file/%s%s", s.config.BaseURL, m.ID, ext) m.Attachment.URL = fmt.Sprintf("%s/file/%s%s", s.config.BaseURL, m.ID, ext)
if m.Attachment.Name == "" { if m.Attachment.Name == "" {
@ -1224,26 +1236,40 @@ func (s *Server) execManager() {
} }
// Delete expired attachments // Delete expired attachments
if s.fileCache != nil && s.config.AttachmentExpiryDuration > 0 { if s.fileCache != nil {
olderThan := time.Now().Add(-1 * s.config.AttachmentExpiryDuration) ids, err := s.messageCache.AttachmentsExpired()
ids, err := s.fileCache.Expired(olderThan)
if err != nil { if err != nil {
log.Warn("Error retrieving expired attachments: %s", err.Error()) log.Warn("Error retrieving expired attachments: %s", err.Error())
} else if len(ids) > 0 { } else if len(ids) > 0 {
log.Debug("Manager: Deleting expired attachments: %v", ids)
if err := s.fileCache.Remove(ids...); err != nil { if err := s.fileCache.Remove(ids...); err != nil {
log.Warn("Error deleting attachments: %s", err.Error()) log.Warn("Error deleting attachments: %s", err.Error())
} }
if err := s.messageCache.MarkAttachmentsDeleted(ids); err != nil {
log.Warn("Error marking attachments deleted: %s", err.Error())
}
} else { } else {
log.Debug("Manager: No expired attachments to delete") log.Debug("Manager: No expired attachments to delete")
} }
} }
// Prune message cache // DeleteMessages message cache
olderThan := time.Now().Add(-1 * s.config.CacheDuration) log.Debug("Manager: Pruning messages")
log.Debug("Manager: Pruning messages older than %s", olderThan.Format("2006-01-02 15:04:05")) expiredMessages, err := s.messageCache.MessagesExpired()
if err := s.messageCache.Prune(olderThan); err != nil { if err != nil {
log.Warn("Manager: Error pruning cache: %s", err.Error()) log.Warn("Manager: Error retrieving expired messages: %s", err.Error())
} else if len(expiredMessages) > 0 {
ids := make([]string, 0)
for _, m := range expiredMessages {
ids = append(ids, m.ID)
}
if err := s.fileCache.Remove(ids...); err != nil {
log.Warn("Manager: Error deleting attachments for expired messages: %s", err.Error())
}
if err := s.messageCache.DeleteMessages(ids...); err != nil {
log.Warn("Manager: Error marking attachments deleted: %s", err.Error())
}
} else {
log.Debug("Manager: No expired messages to delete")
} }
// Message count per topic // Message count per topic

View File

@ -2,7 +2,6 @@ package server
import ( import (
"encoding/json" "encoding/json"
"errors"
"heckel.io/ntfy/user" "heckel.io/ntfy/user"
"heckel.io/ntfy/util" "heckel.io/ntfy/util"
"net/http" "net/http"
@ -57,12 +56,14 @@ func (s *Server) handleAccountGet(w http.ResponseWriter, _ *http.Request, v *vis
AttachmentTotalSizeRemaining: stats.AttachmentTotalSizeRemaining, AttachmentTotalSizeRemaining: stats.AttachmentTotalSizeRemaining,
}, },
Limits: &apiAccountLimits{ Limits: &apiAccountLimits{
Basis: stats.Basis, Basis: stats.Basis,
Messages: stats.MessagesLimit, Messages: stats.MessagesLimit,
Emails: stats.EmailsLimit, MessagesExpiryDuration: stats.MessagesExpiryDuration,
Topics: stats.TopicsLimit, Emails: stats.EmailsLimit,
AttachmentTotalSize: stats.AttachmentTotalSizeLimit, Topics: stats.TopicsLimit,
AttachmentFileSize: stats.AttachmentFileSizeLimit, AttachmentTotalSize: stats.AttachmentTotalSizeLimit,
AttachmentFileSize: stats.AttachmentFileSizeLimit,
AttachmentExpiryDuration: stats.AttachmentExpiryDuration,
}, },
} }
if v.user != nil { if v.user != nil {
@ -325,6 +326,9 @@ func (s *Server) handleAccountSubscriptionDelete(w http.ResponseWriter, r *http.
} }
func (s *Server) handleAccountAccessAdd(w http.ResponseWriter, r *http.Request, v *visitor) error { func (s *Server) handleAccountAccessAdd(w http.ResponseWriter, r *http.Request, v *visitor) error {
if v.user != nil && v.user.Role == user.RoleAdmin {
return errHTTPBadRequestMakesNoSenseForAdmin
}
req, err := readJSONWithLimit[apiAccountAccessRequest](r.Body, jsonBodyBytesLimit) req, err := readJSONWithLimit[apiAccountAccessRequest](r.Body, jsonBodyBytesLimit)
if err != nil { if err != nil {
return err return err
@ -337,7 +341,7 @@ func (s *Server) handleAccountAccessAdd(w http.ResponseWriter, r *http.Request,
return errHTTPBadRequestPermissionInvalid return errHTTPBadRequestPermissionInvalid
} }
if v.user.Plan == nil { if v.user.Plan == nil {
return errors.New("no plan") // FIXME there should always be a plan! return errHTTPUnauthorized // FIXME there should always be a plan!
} }
if err := s.userManager.CheckAllowAccess(v.user.Name, req.Topic); err != nil { if err := s.userManager.CheckAllowAccess(v.user.Name, req.Topic); err != nil {
return errHTTPConflictTopicReserved return errHTTPConflictTopicReserved

View File

@ -351,7 +351,7 @@ func TestAccount_Reservation_Add_User_No_Plan_Failure(t *testing.T) {
rr := request(t, s, "POST", "/v1/account", `{"username":"phil", "password":"mypass"}`, nil) rr := request(t, s, "POST", "/v1/account", `{"username":"phil", "password":"mypass"}`, nil)
require.Equal(t, 200, rr.Code) require.Equal(t, 200, rr.Code)
rr = request(t, s, "POST", "/v1/account/access", `{"everyone":"deny-all"}`, map[string]string{ rr = request(t, s, "POST", "/v1/account/access", `{"topic":"mytopic", "everyone":"deny-all"}`, map[string]string{
"Authorization": util.BasicAuth("phil", "mypass"), "Authorization": util.BasicAuth("phil", "mypass"),
}) })
require.Equal(t, 401, rr.Code) require.Equal(t, 401, rr.Code)
@ -363,10 +363,11 @@ func TestAccount_Reservation_Add_Admin_Success(t *testing.T) {
s := newTestServer(t, conf) s := newTestServer(t, conf)
require.Nil(t, s.userManager.AddUser("phil", "adminpass", user.RoleAdmin)) require.Nil(t, s.userManager.AddUser("phil", "adminpass", user.RoleAdmin))
rr := request(t, s, "POST", "/v1/account/access", `{"everyone":"deny-all"}`, map[string]string{ rr := request(t, s, "POST", "/v1/account/access", `{"topic":"mytopic","everyone":"deny-all"}`, map[string]string{
"Authorization": util.BasicAuth("phil", "adminpass"), "Authorization": util.BasicAuth("phil", "adminpass"),
}) })
require.Equal(t, 200, rr.Code) require.Equal(t, 400, rr.Code)
require.Equal(t, 40026, toHTTPError(t, rr.Body.String()).Code)
} }
func TestAccount_Reservation_Add_Remove_User_With_Plan_Success(t *testing.T) { func TestAccount_Reservation_Add_Remove_User_With_Plan_Success(t *testing.T) {
@ -383,8 +384,8 @@ func TestAccount_Reservation_Add_Remove_User_With_Plan_Success(t *testing.T) {
require.Nil(t, err) require.Nil(t, err)
_, err = db.Exec(` _, err = db.Exec(`
INSERT INTO plan (id, code, messages_limit, emails_limit, attachment_file_size_limit, attachment_total_size_limit, topics_limit) INSERT INTO plan (id, code, messages_limit, messages_expiry_duration, emails_limit, attachment_file_size_limit, attachment_total_size_limit, attachment_expiry_duration, topics_limit)
VALUES (1, 'testplan', 10, 10, 10, 10, 2); VALUES (1, 'testplan', 10, 86400, 10, 10, 10, 10800, 2);
UPDATE user SET plan_id = 1 WHERE user = 'phil'; UPDATE user SET plan_id = 1 WHERE user = 'phil';
`) `)
@ -455,8 +456,8 @@ func TestAccount_Reservation_Add_Access_By_Anonymous_Fails(t *testing.T) {
require.Nil(t, err) require.Nil(t, err)
_, err = db.Exec(` _, err = db.Exec(`
INSERT INTO plan (id, code, messages_limit, emails_limit, attachment_file_size_limit, attachment_total_size_limit, topics_limit) INSERT INTO plan (id, code, messages_limit, messages_expiry_duration, emails_limit, attachment_file_size_limit, attachment_total_size_limit, attachment_expiry_duration, topics_limit)
VALUES (1, 'testplan', 10, 10, 10, 10, 2); VALUES (1, 'testplan', 10, 86400, 10, 10, 10, 10800, 2);
UPDATE user SET plan_id = 1 WHERE user = 'phil'; UPDATE user SET plan_id = 1 WHERE user = 'phil';
`) `)

View File

@ -1271,7 +1271,7 @@ func TestServer_PublishAttachmentAndPrune(t *testing.T) {
require.Equal(t, 200, response.Code) require.Equal(t, 200, response.Code)
require.Equal(t, content, response.Body.String()) require.Equal(t, content, response.Body.String())
// Prune and makes sure it's gone // DeleteMessages and makes sure it's gone
time.Sleep(time.Second) // Sigh ... time.Sleep(time.Second) // Sigh ...
s.execManager() s.execManager()
require.NoFileExists(t, file) require.NoFileExists(t, file)

View File

@ -23,9 +23,10 @@ const (
// message represents a message published to a topic // message represents a message published to a topic
type message struct { type message struct {
ID string `json:"id"` // Random message ID ID string `json:"id"` // Random message ID
Time int64 `json:"time"` // Unix time in seconds Time int64 `json:"time"` // Unix time in seconds
Event string `json:"event"` // One of the above Expires int64 `json:"expires"` // Unix time in seconds
Event string `json:"event"` // One of the above
Topic string `json:"topic"` Topic string `json:"topic"`
Title string `json:"title,omitempty"` Title string `json:"title,omitempty"`
Message string `json:"message,omitempty"` Message string `json:"message,omitempty"`
@ -240,12 +241,14 @@ type apiAccountPlan struct {
} }
type apiAccountLimits struct { type apiAccountLimits struct {
Basis string `json:"basis"` // "ip", "role" or "plan" Basis string `json:"basis"` // "ip", "role" or "plan"
Messages int64 `json:"messages"` Messages int64 `json:"messages"`
Emails int64 `json:"emails"` MessagesExpiryDuration int64 `json:"messages_expiry_duration"`
Topics int64 `json:"topics"` Emails int64 `json:"emails"`
AttachmentTotalSize int64 `json:"attachment_total_size"` Topics int64 `json:"topics"`
AttachmentFileSize int64 `json:"attachment_file_size"` AttachmentTotalSize int64 `json:"attachment_total_size"`
AttachmentFileSize int64 `json:"attachment_file_size"`
AttachmentExpiryDuration int64 `json:"attachment_expiry_duration"`
} }
type apiAccountStats struct { type apiAccountStats struct {

View File

@ -46,6 +46,7 @@ type visitorInfo struct {
Messages int64 Messages int64
MessagesLimit int64 MessagesLimit int64
MessagesRemaining int64 MessagesRemaining int64
MessagesExpiryDuration int64
Emails int64 Emails int64
EmailsLimit int64 EmailsLimit int64
EmailsRemaining int64 EmailsRemaining int64
@ -56,6 +57,7 @@ type visitorInfo struct {
AttachmentTotalSizeLimit int64 AttachmentTotalSizeLimit int64
AttachmentTotalSizeRemaining int64 AttachmentTotalSizeRemaining int64
AttachmentFileSizeLimit int64 AttachmentFileSizeLimit int64
AttachmentExpiryDuration int64
} }
func newVisitor(conf *Config, messageCache *messageCache, userManager *user.Manager, ip netip.Addr, user *user.User) *visitor { func newVisitor(conf *Config, messageCache *messageCache, userManager *user.Manager, ip netip.Addr, user *user.User) *visitor {
@ -179,20 +181,26 @@ func (v *visitor) Info() (*visitorInfo, error) {
if v.user != nil && v.user.Role == user.RoleAdmin { if v.user != nil && v.user.Role == user.RoleAdmin {
info.Basis = "role" info.Basis = "role"
// All limits are zero! // All limits are zero!
info.MessagesExpiryDuration = 24 * 3600 // FIXME this is awful. Should be from the Unlimited plan
info.AttachmentExpiryDuration = 24 * 3600 // FIXME this is awful. Should be from the Unlimited plan
} else if v.user != nil && v.user.Plan != nil { } else if v.user != nil && v.user.Plan != nil {
info.Basis = "plan" info.Basis = "plan"
info.MessagesLimit = v.user.Plan.MessagesLimit info.MessagesLimit = v.user.Plan.MessagesLimit
info.MessagesExpiryDuration = v.user.Plan.MessagesExpiryDuration
info.EmailsLimit = v.user.Plan.EmailsLimit info.EmailsLimit = v.user.Plan.EmailsLimit
info.TopicsLimit = v.user.Plan.TopicsLimit info.TopicsLimit = v.user.Plan.TopicsLimit
info.AttachmentTotalSizeLimit = v.user.Plan.AttachmentTotalSizeLimit info.AttachmentTotalSizeLimit = v.user.Plan.AttachmentTotalSizeLimit
info.AttachmentFileSizeLimit = v.user.Plan.AttachmentFileSizeLimit info.AttachmentFileSizeLimit = v.user.Plan.AttachmentFileSizeLimit
info.AttachmentExpiryDuration = v.user.Plan.AttachmentExpiryDuration
} else { } else {
info.Basis = "ip" info.Basis = "ip"
info.MessagesLimit = replenishDurationToDailyLimit(v.config.VisitorRequestLimitReplenish) info.MessagesLimit = replenishDurationToDailyLimit(v.config.VisitorRequestLimitReplenish)
info.MessagesExpiryDuration = int64(v.config.CacheDuration.Seconds())
info.EmailsLimit = replenishDurationToDailyLimit(v.config.VisitorEmailLimitReplenish) info.EmailsLimit = replenishDurationToDailyLimit(v.config.VisitorEmailLimitReplenish)
info.TopicsLimit = 0 // FIXME info.TopicsLimit = 0 // FIXME
info.AttachmentTotalSizeLimit = v.config.VisitorAttachmentTotalSizeLimit info.AttachmentTotalSizeLimit = v.config.VisitorAttachmentTotalSizeLimit
info.AttachmentFileSizeLimit = v.config.AttachmentFileSizeLimit info.AttachmentFileSizeLimit = v.config.AttachmentFileSizeLimit
info.AttachmentExpiryDuration = int64(v.config.AttachmentExpiryDuration.Seconds())
} }
var attachmentsBytesUsed int64 // FIXME Maybe move this to endpoint? var attachmentsBytesUsed int64 // FIXME Maybe move this to endpoint?
var err error var err error

View File

@ -36,10 +36,12 @@ const (
id INT NOT NULL, id INT NOT NULL,
code TEXT NOT NULL, code TEXT NOT NULL,
messages_limit INT NOT NULL, messages_limit INT NOT NULL,
messages_expiry_duration INT NOT NULL,
emails_limit INT NOT NULL, emails_limit INT NOT NULL,
topics_limit INT NOT NULL, topics_limit INT NOT NULL,
attachment_file_size_limit INT NOT NULL, attachment_file_size_limit INT NOT NULL,
attachment_total_size_limit INT NOT NULL, attachment_total_size_limit INT NOT NULL,
attachment_expiry_duration INT NOT NULL,
PRIMARY KEY (id) PRIMARY KEY (id)
); );
CREATE TABLE IF NOT EXISTS user ( CREATE TABLE IF NOT EXISTS user (
@ -83,13 +85,13 @@ const (
` `
selectUserByNameQuery = ` selectUserByNameQuery = `
SELECT u.user, u.pass, u.role, u.messages, u.emails, u.settings, p.code, p.messages_limit, p.emails_limit, p.topics_limit, p.attachment_file_size_limit, p.attachment_total_size_limit SELECT u.user, u.pass, u.role, u.messages, u.emails, u.settings, p.code, p.messages_limit, p.messages_expiry_duration, p.emails_limit, p.topics_limit, p.attachment_file_size_limit, p.attachment_total_size_limit, p.attachment_expiry_duration
FROM user u FROM user u
LEFT JOIN plan p on p.id = u.plan_id LEFT JOIN plan p on p.id = u.plan_id
WHERE user = ? WHERE user = ?
` `
selectUserByTokenQuery = ` selectUserByTokenQuery = `
SELECT u.user, u.pass, u.role, u.messages, u.emails, u.settings, p.code, p.messages_limit, p.emails_limit, p.topics_limit, p.attachment_file_size_limit, p.attachment_total_size_limit SELECT u.user, u.pass, u.role, u.messages, u.emails, u.settings, p.code, p.messages_limit, p.messages_expiry_duration, p.emails_limit, p.topics_limit, p.attachment_file_size_limit, p.attachment_total_size_limit, p.attachment_expiry_duration
FROM user u FROM user u
JOIN user_token t on u.id = t.user_id JOIN user_token t on u.id = t.user_id
LEFT JOIN plan p on p.id = u.plan_id LEFT JOIN plan p on p.id = u.plan_id
@ -375,7 +377,7 @@ func (a *Manager) userStatsQueueWriter(interval time.Duration) {
ticker := time.NewTicker(interval) ticker := time.NewTicker(interval)
for range ticker.C { for range ticker.C {
if err := a.writeUserStatsQueue(); err != nil { if err := a.writeUserStatsQueue(); err != nil {
log.Warn("UserManager: Writing user stats queue failed: %s", err.Error()) log.Warn("User Manager: Writing user stats queue failed: %s", err.Error())
} }
} }
} }
@ -384,7 +386,7 @@ func (a *Manager) writeUserStatsQueue() error {
a.mu.Lock() a.mu.Lock()
if len(a.statsQueue) == 0 { if len(a.statsQueue) == 0 {
a.mu.Unlock() a.mu.Unlock()
log.Trace("UserManager: No user stats updates to commit") log.Trace("User Manager: No user stats updates to commit")
return nil return nil
} }
statsQueue := a.statsQueue statsQueue := a.statsQueue
@ -395,9 +397,9 @@ func (a *Manager) writeUserStatsQueue() error {
return err return err
} }
defer tx.Rollback() defer tx.Rollback()
log.Debug("UserManager: Writing user stats queue for %d user(s)", len(statsQueue)) log.Debug("User Manager: Writing user stats queue for %d user(s)", len(statsQueue))
for username, u := range statsQueue { for username, u := range statsQueue {
log.Trace("UserManager: Updating stats for user %s: messages=%d, emails=%d", username, u.Stats.Messages, u.Stats.Emails) log.Trace("User Manager: Updating stats for user %s: messages=%d, emails=%d", username, u.Stats.Messages, u.Stats.Emails)
if _, err := tx.Exec(updateUserStatsQuery, u.Stats.Messages, u.Stats.Emails, username); err != nil { if _, err := tx.Exec(updateUserStatsQuery, u.Stats.Messages, u.Stats.Emails, username); err != nil {
return err return err
} }
@ -523,11 +525,11 @@ func (a *Manager) readUser(rows *sql.Rows) (*User, error) {
var username, hash, role string var username, hash, role string
var settings, planCode sql.NullString var settings, planCode sql.NullString
var messages, emails int64 var messages, emails int64
var messagesLimit, emailsLimit, topicsLimit, attachmentFileSizeLimit, attachmentTotalSizeLimit sql.NullInt64 var messagesLimit, messagesExpiryDuration, emailsLimit, topicsLimit, attachmentFileSizeLimit, attachmentTotalSizeLimit, attachmentExpiryDuration sql.NullInt64
if !rows.Next() { if !rows.Next() {
return nil, ErrNotFound return nil, ErrNotFound
} }
if err := rows.Scan(&username, &hash, &role, &messages, &emails, &settings, &planCode, &messagesLimit, &emailsLimit, &topicsLimit, &attachmentFileSizeLimit, &attachmentTotalSizeLimit); err != nil { if err := rows.Scan(&username, &hash, &role, &messages, &emails, &settings, &planCode, &messagesLimit, &messagesExpiryDuration, &emailsLimit, &topicsLimit, &attachmentFileSizeLimit, &attachmentTotalSizeLimit, &attachmentExpiryDuration); err != nil {
return nil, err return nil, err
} else if err := rows.Err(); err != nil { } else if err := rows.Err(); err != nil {
return nil, err return nil, err
@ -552,10 +554,12 @@ func (a *Manager) readUser(rows *sql.Rows) (*User, error) {
Code: planCode.String, Code: planCode.String,
Upgradeable: false, Upgradeable: false,
MessagesLimit: messagesLimit.Int64, MessagesLimit: messagesLimit.Int64,
MessagesExpiryDuration: messagesExpiryDuration.Int64,
EmailsLimit: emailsLimit.Int64, EmailsLimit: emailsLimit.Int64,
TopicsLimit: topicsLimit.Int64, TopicsLimit: topicsLimit.Int64,
AttachmentFileSizeLimit: attachmentFileSizeLimit.Int64, AttachmentFileSizeLimit: attachmentFileSizeLimit.Int64,
AttachmentTotalSizeLimit: attachmentTotalSizeLimit.Int64, AttachmentTotalSizeLimit: attachmentTotalSizeLimit.Int64,
AttachmentExpiryDuration: attachmentExpiryDuration.Int64,
} }
} }
return user, nil return user, nil

View File

@ -58,10 +58,12 @@ type Plan struct {
Code string `json:"name"` Code string `json:"name"`
Upgradeable bool `json:"upgradeable"` Upgradeable bool `json:"upgradeable"`
MessagesLimit int64 `json:"messages_limit"` MessagesLimit int64 `json:"messages_limit"`
MessagesExpiryDuration int64 `json:"messages_expiry_duration"`
EmailsLimit int64 `json:"emails_limit"` EmailsLimit int64 `json:"emails_limit"`
TopicsLimit int64 `json:"topics_limit"` TopicsLimit int64 `json:"topics_limit"`
AttachmentFileSizeLimit int64 `json:"attachment_file_size_limit"` AttachmentFileSizeLimit int64 `json:"attachment_file_size_limit"`
AttachmentTotalSizeLimit int64 `json:"attachment_total_size_limit"` AttachmentTotalSizeLimit int64 `json:"attachment_total_size_limit"`
AttachmentExpiryDuration int64 `json:"attachment_expiry_seconds"`
} }
// Subscription represents a user's topic subscription // Subscription represents a user's topic subscription