commit
07b5d9a9df
|
@ -82,7 +82,7 @@ const (
|
||||||
`
|
`
|
||||||
updateMessagePublishedQuery = `UPDATE messages SET published = 1 WHERE mid = ?`
|
updateMessagePublishedQuery = `UPDATE messages SET published = 1 WHERE mid = ?`
|
||||||
selectMessagesCountQuery = `SELECT COUNT(*) FROM messages`
|
selectMessagesCountQuery = `SELECT COUNT(*) FROM messages`
|
||||||
selectMessageCountForTopicQuery = `SELECT COUNT(*) FROM messages WHERE topic = ?`
|
selectMessageCountPerTopicQuery = `SELECT topic, COUNT(*) FROM messages GROUP BY topic`
|
||||||
selectTopicsQuery = `SELECT topic FROM messages GROUP BY topic`
|
selectTopicsQuery = `SELECT topic FROM messages GROUP BY topic`
|
||||||
selectAttachmentsSizeQuery = `SELECT IFNULL(SUM(attachment_size), 0) FROM messages WHERE sender = ? AND attachment_expires >= ?`
|
selectAttachmentsSizeQuery = `SELECT IFNULL(SUM(attachment_size), 0) FROM messages WHERE sender = ? AND attachment_expires >= ?`
|
||||||
selectAttachmentsExpiredQuery = `SELECT mid FROM messages WHERE attachment_expires > 0 AND attachment_expires < ?`
|
selectAttachmentsExpiredQuery = `SELECT mid FROM messages WHERE attachment_expires > 0 AND attachment_expires < ?`
|
||||||
|
@ -332,22 +332,24 @@ func (c *messageCache) MarkPublished(m *message) error {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *messageCache) MessageCount(topic string) (int, error) {
|
func (c *messageCache) MessageCounts() (map[string]int, error) {
|
||||||
rows, err := c.db.Query(selectMessageCountForTopicQuery, topic)
|
rows, err := c.db.Query(selectMessageCountPerTopicQuery)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return 0, err
|
return nil, err
|
||||||
}
|
}
|
||||||
defer rows.Close()
|
defer rows.Close()
|
||||||
|
var topic string
|
||||||
var count int
|
var count int
|
||||||
if !rows.Next() {
|
counts := make(map[string]int)
|
||||||
return 0, errors.New("no rows found")
|
for rows.Next() {
|
||||||
}
|
if err := rows.Scan(&topic, &count); err != nil {
|
||||||
if err := rows.Scan(&count); err != nil {
|
return nil, err
|
||||||
return 0, err
|
|
||||||
} else if err := rows.Err(); err != nil {
|
} else if err := rows.Err(); err != nil {
|
||||||
return 0, err
|
return nil, err
|
||||||
}
|
}
|
||||||
return count, nil
|
counts[topic] = count
|
||||||
|
}
|
||||||
|
return counts, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *messageCache) Topics() (map[string]*topic, error) {
|
func (c *messageCache) Topics() (map[string]*topic, error) {
|
||||||
|
|
|
@ -34,9 +34,9 @@ func testCacheMessages(t *testing.T, c *messageCache) {
|
||||||
require.Equal(t, errUnexpectedMessageType, c.AddMessage(newOpenMessage("example"))) // These should not be added!
|
require.Equal(t, errUnexpectedMessageType, c.AddMessage(newOpenMessage("example"))) // These should not be added!
|
||||||
|
|
||||||
// mytopic: count
|
// mytopic: count
|
||||||
count, err := c.MessageCount("mytopic")
|
counts, err := c.MessageCounts()
|
||||||
require.Nil(t, err)
|
require.Nil(t, err)
|
||||||
require.Equal(t, 2, count)
|
require.Equal(t, 2, counts["mytopic"])
|
||||||
|
|
||||||
// mytopic: since all
|
// mytopic: since all
|
||||||
messages, _ := c.Messages("mytopic", sinceAllMessages, false)
|
messages, _ := c.Messages("mytopic", sinceAllMessages, false)
|
||||||
|
@ -66,18 +66,18 @@ func testCacheMessages(t *testing.T, c *messageCache) {
|
||||||
require.Equal(t, "my other message", messages[0].Message)
|
require.Equal(t, "my other message", messages[0].Message)
|
||||||
|
|
||||||
// example: count
|
// example: count
|
||||||
count, err = c.MessageCount("example")
|
counts, err = c.MessageCounts()
|
||||||
require.Nil(t, err)
|
require.Nil(t, err)
|
||||||
require.Equal(t, 1, count)
|
require.Equal(t, 1, counts["example"])
|
||||||
|
|
||||||
// example: since all
|
// example: since all
|
||||||
messages, _ = c.Messages("example", sinceAllMessages, false)
|
messages, _ = c.Messages("example", sinceAllMessages, false)
|
||||||
require.Equal(t, "my example message", messages[0].Message)
|
require.Equal(t, "my example message", messages[0].Message)
|
||||||
|
|
||||||
// non-existing: count
|
// non-existing: count
|
||||||
count, err = c.MessageCount("doesnotexist")
|
counts, err = c.MessageCounts()
|
||||||
require.Nil(t, err)
|
require.Nil(t, err)
|
||||||
require.Equal(t, 0, count)
|
require.Equal(t, 0, counts["doesnotexist"])
|
||||||
|
|
||||||
// non-existing: since all
|
// non-existing: since all
|
||||||
messages, _ = c.Messages("doesnotexist", sinceAllMessages, false)
|
messages, _ = c.Messages("doesnotexist", sinceAllMessages, false)
|
||||||
|
@ -255,13 +255,13 @@ func testCachePrune(t *testing.T, c *messageCache) {
|
||||||
require.Nil(t, c.AddMessage(m3))
|
require.Nil(t, c.AddMessage(m3))
|
||||||
require.Nil(t, c.Prune(time.Unix(2, 0)))
|
require.Nil(t, c.Prune(time.Unix(2, 0)))
|
||||||
|
|
||||||
count, err := c.MessageCount("mytopic")
|
counts, err := c.MessageCounts()
|
||||||
require.Nil(t, err)
|
require.Nil(t, err)
|
||||||
require.Equal(t, 1, count)
|
require.Equal(t, 1, counts["mytopic"])
|
||||||
|
|
||||||
count, err = c.MessageCount("another_topic")
|
counts, err = c.MessageCounts()
|
||||||
require.Nil(t, err)
|
require.Nil(t, err)
|
||||||
require.Equal(t, 0, count)
|
require.Equal(t, 0, counts["another_topic"])
|
||||||
|
|
||||||
messages, err := c.Messages("mytopic", sinceAllMessages, false)
|
messages, err := c.Messages("mytopic", sinceAllMessages, false)
|
||||||
require.Nil(t, err)
|
require.Nil(t, err)
|
||||||
|
|
|
@ -798,6 +798,13 @@ func (s *Server) handleSubscribeHTTP(w http.ResponseWriter, r *http.Request, v *
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
var wlock sync.Mutex
|
var wlock sync.Mutex
|
||||||
|
defer func() {
|
||||||
|
// Hack: This is the fix for a horrible data race that I have not been able to figure out in quite some time.
|
||||||
|
// It appears to be happening when the Go HTTP code reads from the socket when closing the request (i.e. AFTER
|
||||||
|
// this function returns), and causes a data race with the ResponseWriter. Locking wlock here silences the
|
||||||
|
// data race detector. See https://github.com/binwiederhier/ntfy/issues/338#issuecomment-1163425889.
|
||||||
|
wlock.TryLock()
|
||||||
|
}()
|
||||||
sub := func(v *visitor, msg *message) error {
|
sub := func(v *visitor, msg *message) error {
|
||||||
if !filters.Pass(msg) {
|
if !filters.Pass(msg) {
|
||||||
return nil
|
return nil
|
||||||
|
@ -1080,18 +1087,23 @@ func (s *Server) topicsFromIDs(ids ...string) ([]*topic, error) {
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *Server) updateStatsAndPrune() {
|
func (s *Server) updateStatsAndPrune() {
|
||||||
s.mu.Lock()
|
log.Debug("Manager: Starting")
|
||||||
defer s.mu.Unlock()
|
defer log.Debug("Manager: Finished")
|
||||||
|
|
||||||
|
// WARNING: Make sure to only selectively lock with the mutex, and be aware that this
|
||||||
|
// there is no mutex for the entire function.
|
||||||
|
|
||||||
// Expire visitors from rate visitors map
|
// Expire visitors from rate visitors map
|
||||||
|
s.mu.Lock()
|
||||||
staleVisitors := 0
|
staleVisitors := 0
|
||||||
for ip, v := range s.visitors {
|
for ip, v := range s.visitors {
|
||||||
if v.Stale() {
|
if v.Stale() {
|
||||||
log.Debug("Deleting stale visitor %s", v.ip)
|
log.Trace("Deleting stale visitor %s", v.ip)
|
||||||
delete(s.visitors, ip)
|
delete(s.visitors, ip)
|
||||||
staleVisitors++
|
staleVisitors++
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
s.mu.Unlock()
|
||||||
log.Debug("Manager: Deleted %d stale visitor(s)", staleVisitors)
|
log.Debug("Manager: Deleted %d stale visitor(s)", staleVisitors)
|
||||||
|
|
||||||
// Delete expired attachments
|
// Delete expired attachments
|
||||||
|
@ -1116,22 +1128,31 @@ func (s *Server) updateStatsAndPrune() {
|
||||||
log.Warn("Manager: Error pruning cache: %s", err.Error())
|
log.Warn("Manager: Error pruning cache: %s", err.Error())
|
||||||
}
|
}
|
||||||
|
|
||||||
// Prune old topics, remove subscriptions without subscribers
|
// Message count per topic
|
||||||
var subscribers, messages int
|
var messages int
|
||||||
for _, t := range s.topics {
|
messageCounts, err := s.messageCache.MessageCounts()
|
||||||
subs := t.Subscribers()
|
|
||||||
msgs, err := s.messageCache.MessageCount(t.ID)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Warn("Manager: Cannot get stats for topic %s: %s", t.ID, err.Error())
|
log.Warn("Manager: Cannot get message counts: %s", err.Error())
|
||||||
continue
|
messageCounts = make(map[string]int) // Empty, so we can continue
|
||||||
}
|
}
|
||||||
if msgs == 0 && subs == 0 {
|
for _, count := range messageCounts {
|
||||||
|
messages += count
|
||||||
|
}
|
||||||
|
|
||||||
|
// Remove subscriptions without subscribers
|
||||||
|
s.mu.Lock()
|
||||||
|
var subscribers int
|
||||||
|
for _, t := range s.topics {
|
||||||
|
subs := t.SubscribersCount()
|
||||||
|
msgs, exists := messageCounts[t.ID]
|
||||||
|
if subs == 0 && (!exists || msgs == 0) {
|
||||||
|
log.Trace("Deleting empty topic %s", t.ID)
|
||||||
delete(s.topics, t.ID)
|
delete(s.topics, t.ID)
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
subscribers += subs
|
subscribers += subs
|
||||||
messages += msgs
|
|
||||||
}
|
}
|
||||||
|
s.mu.Unlock()
|
||||||
|
|
||||||
// Mail stats
|
// Mail stats
|
||||||
var receivedMailTotal, receivedMailSuccess, receivedMailFailure int64
|
var receivedMailTotal, receivedMailSuccess, receivedMailFailure int64
|
||||||
|
@ -1219,10 +1240,10 @@ func (s *Server) sendDelayedMessages() error {
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *Server) sendDelayedMessage(v *visitor, m *message) error {
|
func (s *Server) sendDelayedMessage(v *visitor, m *message) error {
|
||||||
s.mu.Lock()
|
|
||||||
defer s.mu.Unlock()
|
|
||||||
log.Debug("%s Sending delayed message", logMessagePrefix(v, m))
|
log.Debug("%s Sending delayed message", logMessagePrefix(v, m))
|
||||||
|
s.mu.Lock()
|
||||||
t, ok := s.topics[m.Topic] // If no subscribers, just mark message as published
|
t, ok := s.topics[m.Topic] // If no subscribers, just mark message as published
|
||||||
|
s.mu.Unlock()
|
||||||
if ok {
|
if ok {
|
||||||
go func() {
|
go func() {
|
||||||
// We do not rate-limit messages here, since we've rate limited them in the PUT/POST handler
|
// We do not rate-limit messages here, since we've rate limited them in the PUT/POST handler
|
||||||
|
|
|
@ -44,14 +44,19 @@ func (t *topic) Unsubscribe(id int) {
|
||||||
// Publish asynchronously publishes to all subscribers
|
// Publish asynchronously publishes to all subscribers
|
||||||
func (t *topic) Publish(v *visitor, m *message) error {
|
func (t *topic) Publish(v *visitor, m *message) error {
|
||||||
go func() {
|
go func() {
|
||||||
t.mu.Lock()
|
// We want to lock the topic as short as possible, so we make a shallow copy of the
|
||||||
defer t.mu.Unlock()
|
// subscribers map here. Actually sending out the messages then doesn't have to lock.
|
||||||
if len(t.subscribers) > 0 {
|
subscribers := t.subscribersCopy()
|
||||||
log.Debug("%s Forwarding to %d subscriber(s)", logMessagePrefix(v, m), len(t.subscribers))
|
if len(subscribers) > 0 {
|
||||||
for _, s := range t.subscribers {
|
log.Debug("%s Forwarding to %d subscriber(s)", logMessagePrefix(v, m), len(subscribers))
|
||||||
|
for _, s := range subscribers {
|
||||||
|
// We call the subscriber functions in their own Go routines because they are blocking, and
|
||||||
|
// we don't want individual slow subscribers to be able to block others.
|
||||||
|
go func(s subscriber) {
|
||||||
if err := s(v, m); err != nil {
|
if err := s(v, m); err != nil {
|
||||||
log.Warn("%s Error forwarding to subscriber", logMessagePrefix(v, m))
|
log.Warn("%s Error forwarding to subscriber", logMessagePrefix(v, m))
|
||||||
}
|
}
|
||||||
|
}(s)
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
log.Trace("%s No stream or WebSocket subscribers, not forwarding", logMessagePrefix(v, m))
|
log.Trace("%s No stream or WebSocket subscribers, not forwarding", logMessagePrefix(v, m))
|
||||||
|
@ -60,9 +65,20 @@ func (t *topic) Publish(v *visitor, m *message) error {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// Subscribers returns the number of subscribers to this topic
|
// SubscribersCount returns the number of subscribers to this topic
|
||||||
func (t *topic) Subscribers() int {
|
func (t *topic) SubscribersCount() int {
|
||||||
t.mu.Lock()
|
t.mu.Lock()
|
||||||
defer t.mu.Unlock()
|
defer t.mu.Unlock()
|
||||||
return len(t.subscribers)
|
return len(t.subscribers)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// subscribersCopy returns a shallow copy of the subscribers map
|
||||||
|
func (t *topic) subscribersCopy() map[int]subscriber {
|
||||||
|
t.mu.Lock()
|
||||||
|
defer t.mu.Unlock()
|
||||||
|
subscribers := make(map[int]subscriber)
|
||||||
|
for k, v := range t.subscribers {
|
||||||
|
subscribers[k] = v
|
||||||
|
}
|
||||||
|
return subscribers
|
||||||
|
}
|
||||||
|
|
Loading…
Reference in New Issue