From 98c1ab9e86ead6c8eefb8e928c1e3c20590b5766 Mon Sep 17 00:00:00 2001 From: Philipp Heckel Date: Wed, 8 Dec 2021 22:57:31 -0500 Subject: [PATCH] Finish cache tests --- server/cache.go | 2 +- server/cache_mem.go | 26 ++++++++++++---------- server/cache_mem_test.go | 9 ++++++++ server/cache_sqlite.go | 11 +++++----- server/cache_sqlite_test.go | 8 +++++++ server/cache_test.go | 44 +++++++++++++++++++++++++++++++++++++ server/server.go | 21 +++++++++--------- server/topic.go | 11 +++------- 8 files changed, 95 insertions(+), 37 deletions(-) diff --git a/server/cache.go b/server/cache.go index 5e76f5d0..b3557910 100644 --- a/server/cache.go +++ b/server/cache.go @@ -17,5 +17,5 @@ type cache interface { Messages(topic string, since sinceTime) ([]*message, error) MessageCount(topic string) (int, error) Topics() (map[string]*topic, error) - Prune(keep time.Duration) error + Prune(olderThan time.Time) error } diff --git a/server/cache_mem.go b/server/cache_mem.go index d524ecc2..0edcd488 100644 --- a/server/cache_mem.go +++ b/server/cache_mem.go @@ -57,26 +57,30 @@ func (s *memCache) MessageCount(topic string) (int, error) { } func (s *memCache) Topics() (map[string]*topic, error) { - // Hack since we know when this is called there are no messages! - return make(map[string]*topic), nil + s.mu.Lock() + defer s.mu.Unlock() + topics := make(map[string]*topic) + for topic := range s.messages { + topics[topic] = newTopic(topic) + } + return topics, nil } -func (s *memCache) Prune(keep time.Duration) error { +func (s *memCache) Prune(olderThan time.Time) error { s.mu.Lock() defer s.mu.Unlock() for topic := range s.messages { - s.pruneTopic(topic, keep) + s.pruneTopic(topic, olderThan) } return nil } -func (s *memCache) pruneTopic(topic string, keep time.Duration) { - for i, m := range s.messages[topic] { - msgTime := time.Unix(m.Time, 0) - if time.Since(msgTime) < keep { - s.messages[topic] = s.messages[topic][i:] - return +func (s *memCache) pruneTopic(topic string, olderThan time.Time) { + messages := make([]*message, 0) + for _, m := range s.messages[topic] { + if m.Time >= olderThan.Unix() { + messages = append(messages, m) } } - s.messages[topic] = make([]*message, 0) // all messages expired + s.messages[topic] = messages } diff --git a/server/cache_mem_test.go b/server/cache_mem_test.go index 8c591b40..0878fe80 100644 --- a/server/cache_mem_test.go +++ b/server/cache_mem_test.go @@ -7,6 +7,15 @@ import ( func TestMemCache_Messages(t *testing.T) { testCacheMessages(t, newMemCache()) } + +func TestMemCache_Topics(t *testing.T) { + testCacheTopics(t, newMemCache()) +} + func TestMemCache_MessagesTagsPrioAndTitle(t *testing.T) { testCacheMessagesTagsPrioAndTitle(t, newMemCache()) } + +func TestMemCache_Prune(t *testing.T) { + testCachePrune(t, newMemCache()) +} diff --git a/server/cache_sqlite.go b/server/cache_sqlite.go index 6c53d6f2..3c3564de 100644 --- a/server/cache_sqlite.go +++ b/server/cache_sqlite.go @@ -36,7 +36,7 @@ const ( ` selectMessagesCountQuery = `SELECT COUNT(*) FROM messages` selectMessageCountForTopicQuery = `SELECT COUNT(*) FROM messages WHERE topic = ?` - selectTopicsQuery = `SELECT topic, MAX(time) FROM messages GROUP BY topic` + selectTopicsQuery = `SELECT topic FROM messages GROUP BY topic` ) // Schema management queries @@ -153,11 +153,10 @@ func (c *sqliteCache) Topics() (map[string]*topic, error) { topics := make(map[string]*topic) for rows.Next() { var id string - var last int64 - if err := rows.Scan(&id, &last); err != nil { + if err := rows.Scan(&id); err != nil { return nil, err } - topics[id] = newTopic(id, time.Unix(last, 0)) + topics[id] = newTopic(id) } if err := rows.Err(); err != nil { return nil, err @@ -165,8 +164,8 @@ func (c *sqliteCache) Topics() (map[string]*topic, error) { return topics, nil } -func (c *sqliteCache) Prune(keep time.Duration) error { - _, err := c.db.Exec(pruneMessagesQuery, time.Now().Add(-1*keep).Unix()) +func (c *sqliteCache) Prune(olderThan time.Time) error { + _, err := c.db.Exec(pruneMessagesQuery, olderThan.Unix()) return err } diff --git a/server/cache_sqlite_test.go b/server/cache_sqlite_test.go index 214f7219..4e7d2f1e 100644 --- a/server/cache_sqlite_test.go +++ b/server/cache_sqlite_test.go @@ -9,10 +9,18 @@ func TestSqliteCache_AddMessage(t *testing.T) { testCacheMessages(t, newSqliteTestCache(t)) } +func TestSqliteCache_Topics(t *testing.T) { + testCacheTopics(t, newSqliteTestCache(t)) +} + func TestSqliteCache_MessagesTagsPrioAndTitle(t *testing.T) { testCacheMessagesTagsPrioAndTitle(t, newSqliteTestCache(t)) } +func TestSqliteCache_Prune(t *testing.T) { + testCachePrune(t, newSqliteTestCache(t)) +} + func newSqliteTestCache(t *testing.T) cache { filename := filepath.Join(t.TempDir(), "cache.db") c, err := newSqliteCache(filename) diff --git a/server/cache_test.go b/server/cache_test.go index fdf87d53..ab65b062 100644 --- a/server/cache_test.go +++ b/server/cache_test.go @@ -65,6 +65,50 @@ func testCacheMessages(t *testing.T, c cache) { assert.Empty(t, messages) } +func testCacheTopics(t *testing.T, c cache) { + assert.Nil(t, c.AddMessage(newDefaultMessage("topic1", "my example message"))) + assert.Nil(t, c.AddMessage(newDefaultMessage("topic2", "message 1"))) + assert.Nil(t, c.AddMessage(newDefaultMessage("topic2", "message 2"))) + assert.Nil(t, c.AddMessage(newDefaultMessage("topic2", "message 3"))) + + topics, err := c.Topics() + if err != nil { + t.Fatal(err) + } + assert.Equal(t, 2, len(topics)) + assert.Equal(t, "topic1", topics["topic1"].ID) + assert.Equal(t, "topic2", topics["topic2"].ID) +} + +func testCachePrune(t *testing.T, c cache) { + m1 := newDefaultMessage("mytopic", "my message") + m1.Time = 1 + + m2 := newDefaultMessage("mytopic", "my other message") + m2.Time = 2 + + m3 := newDefaultMessage("another_topic", "and another one") + m3.Time = 1 + + assert.Nil(t, c.AddMessage(m1)) + assert.Nil(t, c.AddMessage(m2)) + assert.Nil(t, c.AddMessage(m3)) + assert.Nil(t, c.Prune(time.Unix(2, 0))) + + count, err := c.MessageCount("mytopic") + assert.Nil(t, err) + assert.Equal(t, 1, count) + + count, err = c.MessageCount("another_topic") + assert.Nil(t, err) + assert.Equal(t, 0, count) + + messages, err := c.Messages("mytopic", sinceAllMessages) + assert.Nil(t, err) + assert.Equal(t, 1, len(messages)) + assert.Equal(t, "my other message", messages[0].Message) +} + func testCacheMessagesTagsPrioAndTitle(t *testing.T, c cache) { m := newDefaultMessage("mytopic", "some message") m.Tags = []string{"tag1", "tag2"} diff --git a/server/server.go b/server/server.go index c8db3fc5..0edaab98 100644 --- a/server/server.go +++ b/server/server.go @@ -274,7 +274,7 @@ func (s *Server) handlePublish(w http.ResponseWriter, r *http.Request, _ *visito if err != nil { return err } - m := newDefaultMessage(t.id, string(b)) + m := newDefaultMessage(t.ID, string(b)) if m.Message == "" { return errHTTPBadRequest } @@ -442,7 +442,7 @@ func (s *Server) sendOldMessages(topics []*topic, since sinceTime, sub subscribe return nil } for _, t := range topics { - messages, err := s.cache.Messages(t.id, since) + messages, err := s.cache.Messages(t.ID, since) if err != nil { return err } @@ -468,11 +468,9 @@ func parseSince(r *http.Request) (sinceTime, error) { } if r.URL.Query().Get("since") == "all" { return sinceAllMessages, nil - } - if s, err := strconv.ParseInt(r.URL.Query().Get("since"), 10, 64); err == nil { + } else if s, err := strconv.ParseInt(r.URL.Query().Get("since"), 10, 64); err == nil { return sinceTime(time.Unix(s, 0)), nil - } - if d, err := time.ParseDuration(r.URL.Query().Get("since")); err == nil { + } else if d, err := time.ParseDuration(r.URL.Query().Get("since")); err == nil { return sinceTime(time.Now().Add(-1 * d)), nil } return sinceNoMessages, errHTTPBadRequest @@ -504,7 +502,7 @@ func (s *Server) topicsFromIDs(ids ...string) ([]*topic, error) { if len(s.topics) >= s.config.GlobalTopicLimit { return nil, errHTTPTooManyRequests } - s.topics[id] = newTopic(id, time.Now()) + s.topics[id] = newTopic(id) if s.firebase != nil { s.topics[id].Subscribe(s.firebase) } @@ -526,7 +524,8 @@ func (s *Server) updateStatsAndExpire() { } // Prune cache - if err := s.cache.Prune(s.config.CacheDuration); err != nil { + olderThan := time.Now().Add(-1 * s.config.CacheDuration) + if err := s.cache.Prune(olderThan); err != nil { log.Printf("error pruning cache: %s", err.Error()) } @@ -534,13 +533,13 @@ func (s *Server) updateStatsAndExpire() { var subscribers, messages int for _, t := range s.topics { subs := t.Subscribers() - msgs, err := s.cache.MessageCount(t.id) + msgs, err := s.cache.MessageCount(t.ID) if err != nil { - log.Printf("cannot get stats for topic %s: %s", t.id, err.Error()) + log.Printf("cannot get stats for topic %s: %s", t.ID, err.Error()) continue } if msgs == 0 && (subs == 0 || (s.firebase != nil && subs == 1)) { // Firebase is a subscriber! - delete(s.topics, t.id) + delete(s.topics, t.ID) continue } subscribers += subs diff --git a/server/topic.go b/server/topic.go index fb0ecac4..9badd7bd 100644 --- a/server/topic.go +++ b/server/topic.go @@ -4,14 +4,12 @@ import ( "log" "math/rand" "sync" - "time" ) // topic represents a channel to which subscribers can subscribe, and publishers // can publish a message type topic struct { - id string - last time.Time + ID string subscribers map[int]subscriber mu sync.Mutex } @@ -20,10 +18,9 @@ type topic struct { type subscriber func(msg *message) error // newTopic creates a new topic -func newTopic(id string, last time.Time) *topic { +func newTopic(id string) *topic { return &topic{ - id: id, - last: last, + ID: id, subscribers: make(map[int]subscriber), } } @@ -34,7 +31,6 @@ func (t *topic) Subscribe(s subscriber) int { defer t.mu.Unlock() subscriberID := rand.Int() t.subscribers[subscriberID] = s - t.last = time.Now() return subscriberID } @@ -50,7 +46,6 @@ func (t *topic) Publish(m *message) error { go func() { t.mu.Lock() defer t.mu.Unlock() - t.last = time.Now() for _, s := range t.subscribers { if err := s(m); err != nil { log.Printf("error publishing message to subscriber")