diff --git a/server/config.go b/server/config.go index 1e2b517c..888eb096 100644 --- a/server/config.go +++ b/server/config.go @@ -116,7 +116,7 @@ func NewConfig() *Config { FirebaseKeyFile: "", CacheFile: "", CacheDuration: DefaultCacheDuration, - CacheBatchSize: 0, + CacheBatchSize: 10, CacheBatchTimeout: 0, AuthFile: "", AuthDefaultRead: true, diff --git a/server/message_cache.go b/server/message_cache.go index 376c7611..48f43ef0 100644 --- a/server/message_cache.go +++ b/server/message_cache.go @@ -195,7 +195,7 @@ const ( type messageCache struct { db *sql.DB - queue *util.BatchingQueue[*message] + queue chan *message nop bool } @@ -208,16 +208,18 @@ func newSqliteCache(filename, startupQueries string, batchSize int, batchTimeout if err := setupCacheDB(db, startupQueries); err != nil { return nil, err } - var queue *util.BatchingQueue[*message] - if batchSize > 0 || batchTimeout > 0 { - queue = util.NewBatchingQueue[*message](batchSize, batchTimeout) - } cache := &messageCache{ - db: db, - queue: queue, - nop: nop, + db: db, + nop: nop, } - go cache.processMessageBatches() + if batchSize > 0 { + // here, batchSize determines the maximum number of unprocessed messages the server will + // buffer before (briefly) blocking. + cache.queue = make(chan *message, batchSize) + } + // here, batchSize indicates the maximum number of messages which will be inserted into + // the database per transaction. + go cache.processMessageBatches(batchSize, batchTimeout) return cache, nil } @@ -242,11 +244,17 @@ func createMemoryFilename() string { return fmt.Sprintf("file:%s?mode=memory&cache=shared", util.RandomString(10)) } -// AddMessage stores a message to the message cache synchronously, or queues it to be stored at a later date asyncronously. -// The message is queued only if "batchSize" or "batchTimeout" are passed to the constructor. +// AddMessage stores a message to the message cache synchronously, or queues it to be stored at a later date asynchronously. 
+// The message is queued only if "batchSize" > 0 is passed to the constructor. func (c *messageCache) AddMessage(m *message) error { + if m.Event != messageEvent { + return errUnexpectedMessageType + } if c.queue != nil { - c.queue.Enqueue(m) + c.queue <- m + return nil + } + if c.nop { return nil } return c.addMessages([]*message{m}) @@ -255,9 +263,6 @@ func (c *messageCache) AddMessage(m *message) error { // addMessages synchronously stores a match of messages. If the database is locked, the transaction waits until // SQLite's busy_timeout is exceeded before erroring out. func (c *messageCache) addMessages(ms []*message) error { - if c.nop { - return nil - } if len(ms) == 0 { return nil } @@ -273,9 +278,6 @@ func (c *messageCache) addMessages(ms []*message) error { } defer stmt.Close() for _, m := range ms { - if m.Event != messageEvent { - return errUnexpectedMessageType - } published := m.Time <= time.Now().Unix() tags := strings.Join(m.Tags, ",") var attachmentName, attachmentType, attachmentURL string @@ -460,13 +462,46 @@ func (c *messageCache) AttachmentBytesUsed(sender string) (int64, error) { return size, nil } -func (c *messageCache) processMessageBatches() { +func (c *messageCache) processMessageBatches(batchSize int, batchTimeout time.Duration) { if c.queue == nil { return } - for messages := range c.queue.Dequeue() { + // initialise the array once to avoid needing to recreate it each iteration + var messagebuffer []*message = make([]*message, batchSize) + var bufferRemainingCapacityOnPreviousIteration int + for { + messages := messagebuffer[:0] + + // To increase the efficiency of database insertions, optionally + // delay processing the incoming message stream for a short period, + // unless the previous batch insertion held the maximum allowable + // number of messages. 
+ if batchTimeout > 0 && bufferRemainingCapacityOnPreviousIteration > 0 { + time.Sleep(batchTimeout) + } + + // Perform a blocking read; until at least one message + // is pending, there is nothing to do. + messages = append(messages, <-c.queue) + + retrieve_messages_from_channel: + for { + select { + case message := <-c.queue: + messages = append(messages, message) + if batchSize == len(messages) { + // no more room in the messagebuffer. + break retrieve_messages_from_channel + } + default: + // no more incoming messages. + break retrieve_messages_from_channel + } + } + bufferRemainingCapacityOnPreviousIteration = batchSize - len(messages) + if err := c.addMessages(messages); err != nil { - log.Error("Cache: %s", err.Error()) + log.Error("processMessageBatches: %s", err.Error()) } } } diff --git a/server/message_cache_test.go b/server/message_cache_test.go index 2dcd7b3e..571b6eb3 100644 --- a/server/message_cache_test.go +++ b/server/message_cache_test.go @@ -16,6 +16,10 @@ var ( exampleIP1234 = netip.MustParseAddr("1.2.3.4") ) +func TestBufferedSqliteCache_Messages(t *testing.T) { + testCacheMessages(t, newBufferedSqliteTestCache(t, 10, 0)) +} + func TestSqliteCache_Messages(t *testing.T) { testCacheMessages(t, newSqliteTestCache(t)) } @@ -24,6 +28,68 @@ func TestMemCache_Messages(t *testing.T) { testCacheMessages(t, newMemTestCache(t)) } +func TestBufferedCacheFlushBehaviour(t *testing.T) { + cooldown := time.Millisecond * 100 + queueSize := 3 + c := newBufferedSqliteTestCache(t, queueSize, cooldown) + + // Add a single message. It should be buffered but not yet processed. + require.Nil(t, c.AddMessage(newDefaultMessage("mytopic", "my example message"))) + counts, err := c.MessageCounts() + require.Nil(t, err) + require.Equal(t, 0, counts["mytopic"]) + + // wait less than cooldown. Because it's the first one, + // it should be processed without delay, so should be visible by now. 
+ time.Sleep(cooldown / 3) + counts, err = c.MessageCounts() + require.Nil(t, err) + require.Equal(t, 1, counts["mytopic"]) + + // Add a second message. It should be buffered but not yet processed, + // even after waiting + require.Nil(t, c.AddMessage(newDefaultMessage("mytopic", "my example message"))) + time.Sleep(cooldown / 3) + counts, err = c.MessageCounts() + require.Nil(t, err) + require.Equal(t, 1, counts["mytopic"]) + + // If we wait a little longer, enough time for the cooldown to expire, the second + // message should be processed + time.Sleep(2 * cooldown / 3) + counts, err = c.MessageCounts() + require.Nil(t, err) + require.Equal(t, 2, counts["mytopic"]) + + // At this point the queue should be empty, and ~1/3 into its cooldown period. + // Attempt to send exactly the number of messages our queue has capacity for + t1 := time.Now() + for i := 0; i < queueSize; i++ { + require.Nil(t, c.AddMessage(newDefaultMessage("mytopic", "my example message"))) + } + // These insertions should not have taken much time at all; they should have completed + // well before the cooldown period ends + require.Less(t, time.Since(t1), cooldown/3) + + // Assert that none of these messages have been processed + counts, err = c.MessageCounts() + require.Nil(t, err) + require.Equal(t, 2, counts["mytopic"]) + + // Add an extra message. Because the buffered queue is at capacity, this should block + // this goroutine until the cooldown period has expired, and at least one of the pending + // messages has been read from the channel. 
+ require.Nil(t, c.AddMessage(newDefaultMessage("mytopic", "my example message"))) + require.Greater(t, time.Since(t1), cooldown/3) + + // Because the channel was full, there should not be a cooldown, and our new message should + // be processed without delay + time.Sleep(cooldown / 3) + counts, err = c.MessageCounts() + require.Nil(t, err) + require.Equal(t, 3+queueSize, counts["mytopic"]) +} + func testCacheMessages(t *testing.T, c *messageCache) { m1 := newDefaultMessage("mytopic", "my message") m1.Time = 1 @@ -39,6 +105,11 @@ func testCacheMessages(t *testing.T, c *messageCache) { require.Equal(t, errUnexpectedMessageType, c.AddMessage(newKeepaliveMessage("mytopic"))) // These should not be added! require.Equal(t, errUnexpectedMessageType, c.AddMessage(newOpenMessage("example"))) // These should not be added! + // If a queue is used, allow time for async processing to occur + if c.queue != nil { + time.Sleep(time.Millisecond * 100) + } + // mytopic: count counts, err := c.MessageCounts() require.Nil(t, err) @@ -161,7 +232,6 @@ func testCacheMessagesTagsPrioAndTitle(t *testing.T, c *messageCache) { m.Priority = 5 m.Title = "some title" require.Nil(t, c.AddMessage(m)) - messages, _ := c.Messages("mytopic", sinceAllMessages, false) require.Equal(t, []string{"tag1", "tag2"}, messages[0].Tags) require.Equal(t, 5, messages[0].Priority) @@ -523,6 +593,14 @@ func TestMemCache_NopCache(t *testing.T) { assert.Empty(t, topics) } +func newBufferedSqliteTestCache(t *testing.T, queueSize int, cooldown time.Duration) *messageCache { + c, err := newSqliteCache(newSqliteTestCacheFile(t), "", queueSize, cooldown, false) + if err != nil { + t.Fatal(err) + } + return c +} + func newSqliteTestCache(t *testing.T) *messageCache { c, err := newSqliteCache(newSqliteTestCacheFile(t), "", 0, 0, false) if err != nil { diff --git a/server/server_test.go b/server/server_test.go index e328cb1b..cd4f266a 100644 --- a/server/server_test.go +++ b/server/server_test.go @@ -1465,6 +1465,7 @@ func 
TestServer_PublishWhileUpdatingStatsWithLotsOfMessages(t *testing.T) { func newTestConfig(t *testing.T) *Config { conf := NewConfig() + conf.CacheBatchSize = 0 conf.BaseURL = "http://127.0.0.1:12345" conf.CacheFile = filepath.Join(t.TempDir(), "cache.db") conf.AttachmentCacheDir = t.TempDir() diff --git a/test/server.go b/test/server.go index 0b9200a6..991a30df 100644 --- a/test/server.go +++ b/test/server.go @@ -2,12 +2,13 @@ package test import ( "fmt" - "heckel.io/ntfy/server" "math/rand" "net/http" "path/filepath" "testing" "time" + + "heckel.io/ntfy/server" ) func init() { @@ -24,6 +25,7 @@ func StartServerWithConfig(t *testing.T, conf *server.Config) (*server.Server, i port := 10000 + rand.Intn(20000) conf.ListenHTTP = fmt.Sprintf(":%d", port) conf.AttachmentCacheDir = t.TempDir() + conf.CacheBatchSize = 0 conf.CacheFile = filepath.Join(t.TempDir(), "cache.db") s, err := server.New(conf) if err != nil { diff --git a/util/batching_queue.go b/util/batching_queue.go deleted file mode 100644 index 85ba9be9..00000000 --- a/util/batching_queue.go +++ /dev/null @@ -1,86 +0,0 @@ -package util - -import ( - "sync" - "time" -) - -// BatchingQueue is a queue that creates batches of the enqueued elements based on a -// max batch size and a batch timeout. -// -// Example: -// -// q := NewBatchingQueue[int](2, 500 * time.Millisecond) -// go func() { -// for batch := range q.Dequeue() { -// fmt.Println(batch) -// } -// }() -// q.Enqueue(1) -// q.Enqueue(2) -// q.Enqueue(3) -// time.Sleep(time.Second) -// -// This example will emit batch [1, 2] immediately (because the batch size is 2), and -// a batch [3] after 500ms. 
-type BatchingQueue[T any] struct { - batchSize int - timeout time.Duration - in []T - out chan []T - mu sync.Mutex -} - -// NewBatchingQueue creates a new BatchingQueue -func NewBatchingQueue[T any](batchSize int, timeout time.Duration) *BatchingQueue[T] { - q := &BatchingQueue[T]{ - batchSize: batchSize, - timeout: timeout, - in: make([]T, 0), - out: make(chan []T), - } - go q.timeoutTicker() - return q -} - -// Enqueue enqueues an element to the queue. If the configured batch size is reached, -// the batch will be emitted immediately. -func (q *BatchingQueue[T]) Enqueue(element T) { - q.mu.Lock() - q.in = append(q.in, element) - var elements []T - if len(q.in) == q.batchSize { - elements = q.dequeueAll() - } - q.mu.Unlock() - if len(elements) > 0 { - q.out <- elements - } -} - -// Dequeue returns a channel emitting batches of elements -func (q *BatchingQueue[T]) Dequeue() <-chan []T { - return q.out -} - -func (q *BatchingQueue[T]) dequeueAll() []T { - elements := make([]T, len(q.in)) - copy(elements, q.in) - q.in = q.in[:0] - return elements -} - -func (q *BatchingQueue[T]) timeoutTicker() { - if q.timeout == 0 { - return - } - ticker := time.NewTicker(q.timeout) - for range ticker.C { - q.mu.Lock() - elements := q.dequeueAll() - q.mu.Unlock() - if len(elements) > 0 { - q.out <- elements - } - } -} diff --git a/util/batching_queue_test.go b/util/batching_queue_test.go deleted file mode 100644 index b3c41a4c..00000000 --- a/util/batching_queue_test.go +++ /dev/null @@ -1,58 +0,0 @@ -package util_test - -import ( - "github.com/stretchr/testify/require" - "heckel.io/ntfy/util" - "math/rand" - "sync" - "testing" - "time" -) - -func TestBatchingQueue_InfTimeout(t *testing.T) { - q := util.NewBatchingQueue[int](25, 1*time.Hour) - batches, total := make([][]int, 0), 0 - var mu sync.Mutex - go func() { - for batch := range q.Dequeue() { - mu.Lock() - batches = append(batches, batch) - total += len(batch) - mu.Unlock() - } - }() - for i := 0; i < 101; i++ { - go 
q.Enqueue(i) - } - time.Sleep(time.Second) - mu.Lock() - require.Equal(t, 100, total) // One is missing, stuck in the last batch! - require.Equal(t, 4, len(batches)) - mu.Unlock() -} - -func TestBatchingQueue_WithTimeout(t *testing.T) { - q := util.NewBatchingQueue[int](25, 100*time.Millisecond) - batches, total := make([][]int, 0), 0 - var mu sync.Mutex - go func() { - for batch := range q.Dequeue() { - mu.Lock() - batches = append(batches, batch) - total += len(batch) - mu.Unlock() - } - }() - for i := 0; i < 101; i++ { - go func(i int) { - time.Sleep(time.Duration(rand.Intn(700)) * time.Millisecond) - q.Enqueue(i) - }(i) - } - time.Sleep(time.Second) - mu.Lock() - require.Equal(t, 101, total) - require.True(t, len(batches) > 4) // 101/25 - require.True(t, len(batches) < 21) - mu.Unlock() -}