From c1f7bed8d10d42a317f377295b05524b0b77d11a Mon Sep 17 00:00:00 2001 From: Philipp Heckel Date: Tue, 21 Jun 2022 19:45:23 -0400 Subject: [PATCH] Fix tests, lock topic as short as possible --- server/message_cache_test.go | 20 ++++++++++---------- server/server.go | 7 ++++--- server/topic.go | 26 +++++++++++++++++++------- 3 files changed, 33 insertions(+), 20 deletions(-) diff --git a/server/message_cache_test.go b/server/message_cache_test.go index 9132088e..07929e0c 100644 --- a/server/message_cache_test.go +++ b/server/message_cache_test.go @@ -34,9 +34,9 @@ func testCacheMessages(t *testing.T, c *messageCache) { require.Equal(t, errUnexpectedMessageType, c.AddMessage(newOpenMessage("example"))) // These should not be added! // mytopic: count - count, err := c.MessageCounts("mytopic") + counts, err := c.MessageCounts() require.Nil(t, err) - require.Equal(t, 2, count) + require.Equal(t, 2, counts["mytopic"]) // mytopic: since all messages, _ := c.Messages("mytopic", sinceAllMessages, false) @@ -66,18 +66,18 @@ func testCacheMessages(t *testing.T, c *messageCache) { require.Equal(t, "my other message", messages[0].Message) // example: count - count, err = c.MessageCounts("example") + counts, err = c.MessageCounts() require.Nil(t, err) - require.Equal(t, 1, count) + require.Equal(t, 1, counts["example"]) // example: since all messages, _ = c.Messages("example", sinceAllMessages, false) require.Equal(t, "my example message", messages[0].Message) // non-existing: count - count, err = c.MessageCounts("doesnotexist") + counts, err = c.MessageCounts() require.Nil(t, err) - require.Equal(t, 0, count) + require.Equal(t, 0, counts["doesnotexist"]) // non-existing: since all messages, _ = c.Messages("doesnotexist", sinceAllMessages, false) @@ -255,13 +255,13 @@ func testCachePrune(t *testing.T, c *messageCache) { require.Nil(t, c.AddMessage(m3)) require.Nil(t, c.Prune(time.Unix(2, 0))) - count, err := c.MessageCounts("mytopic") + counts, err := c.MessageCounts() require.Nil(t, err) - require.Equal(t, 1, count) + require.Equal(t, 1, counts["mytopic"]) - count, err = c.MessageCounts("another_topic") + counts, err = c.MessageCounts() require.Nil(t, err) - require.Equal(t, 0, count) + require.Equal(t, 0, counts["another_topic"]) messages, err := c.Messages("mytopic", sinceAllMessages, false) require.Nil(t, err) diff --git a/server/server.go b/server/server.go index 2c887cc1..0801ed9c 100644 --- a/server/server.go +++ b/server/server.go @@ -1090,7 +1090,7 @@ func (s *Server) updateStatsAndPrune() { staleVisitors := 0 for ip, v := range s.visitors { if v.Stale() { - log.Debug("Deleting stale visitor %s", v.ip) + log.Trace("Deleting stale visitor %s", v.ip) delete(s.visitors, ip) staleVisitors++ } @@ -1131,13 +1131,14 @@ func (s *Server) updateStatsAndPrune() { messages += count } - // Prune old topics, remove subscriptions without subscribers + // Remove subscriptions without subscribers s.mu.Lock() var subscribers int for _, t := range s.topics { - subs := t.Subscribers() + subs := t.SubscribersCount() msgs, exists := messageCounts[t.ID] if subs == 0 && (!exists || msgs == 0) { + log.Trace("Deleting empty topic %s", t.ID) delete(s.topics, t.ID) continue } diff --git a/server/topic.go b/server/topic.go index 889f1eb7..8ce7953a 100644 --- a/server/topic.go +++ b/server/topic.go @@ -44,11 +44,12 @@ func (t *topic) Unsubscribe(id int) { // Publish asynchronously publishes to all subscribers func (t *topic) Publish(v *visitor, m *message) error { go func() { - t.mu.Lock() - defer t.mu.Unlock() - if len(t.subscribers) > 0 { - log.Debug("%s Forwarding to %d subscriber(s)", logMessagePrefix(v, m), len(t.subscribers)) - for _, s := range t.subscribers { + // We want to lock the topic as short as possible, so we make a shallow copy of the + // subscribers map here. Actually sending out the messages then doesn't have to lock. + subscribers := t.subscribersCopy() + if len(subscribers) > 0 { + log.Debug("%s Forwarding to %d subscriber(s)", logMessagePrefix(v, m), len(subscribers)) + for _, s := range subscribers { if err := s(v, m); err != nil { log.Warn("%s Error forwarding to subscriber", logMessagePrefix(v, m)) } @@ -60,9 +61,20 @@ func (t *topic) Publish(v *visitor, m *message) error { return nil } -// Subscribers returns the number of subscribers to this topic -func (t *topic) Subscribers() int { +// SubscribersCount returns the number of subscribers to this topic +func (t *topic) SubscribersCount() int { t.mu.Lock() defer t.mu.Unlock() return len(t.subscribers) } + +// subscribersCopy returns a shallow copy of the subscribers map +func (t *topic) subscribersCopy() map[int]subscriber { + t.mu.Lock() + defer t.mu.Unlock() + subscribers := make(map[int]subscriber) + for k, v := range t.subscribers { + subscribers[k] = v + } + return subscribers +}