From 52136030be49d91619a19779d3837abd983337c0 Mon Sep 17 00:00:00 2001 From: Philipp Heckel Date: Mon, 15 Nov 2021 07:56:58 -0500 Subject: [PATCH] Subscribe to more than one topic --- server/server.go | 77 +++++++++++++++++++++++++++++++----------------- 1 file changed, 50 insertions(+), 27 deletions(-) diff --git a/server/server.go b/server/server.go index 1cb185cb..71ab024c 100644 --- a/server/server.go +++ b/server/server.go @@ -78,9 +78,9 @@ const ( var ( topicRegex = regexp.MustCompile(`^/[-_A-Za-z0-9]{1,64}$`) // Regex must match JS & Android app! - jsonRegex = regexp.MustCompile(`^/[-_A-Za-z0-9]{1,64}/json$`) - sseRegex = regexp.MustCompile(`^/[-_A-Za-z0-9]{1,64}/sse$`) - rawRegex = regexp.MustCompile(`^/[-_A-Za-z0-9]{1,64}/raw$`) + jsonRegex = regexp.MustCompile(`^/[-_A-Za-z0-9]{1,64}(,[-_A-Za-z0-9]{1,64})*/json$`) + sseRegex = regexp.MustCompile(`^/[-_A-Za-z0-9]{1,64}(,[-_A-Za-z0-9]{1,64})*/sse$`) + rawRegex = regexp.MustCompile(`^/[-_A-Za-z0-9]{1,64}(,[-_A-Za-z0-9]{1,64})*/raw$`) staticRegex = regexp.MustCompile(`^/static/.+`) @@ -223,7 +223,7 @@ func (s *Server) handleStatic(w http.ResponseWriter, r *http.Request) error { } func (s *Server) handlePublish(w http.ResponseWriter, r *http.Request, v *visitor) error { - t, err := s.topic(r.URL.Path[1:]) + t, err := s.topicFromID(r.URL.Path[1:]) if err != nil { return err } @@ -289,7 +289,9 @@ func (s *Server) handleSubscribe(w http.ResponseWriter, r *http.Request, v *visi return errHTTPTooManyRequests } defer v.RemoveSubscription() - t, err := s.topic(strings.TrimSuffix(r.URL.Path[1:], "/"+format)) // Hack + topicsStr := strings.TrimSuffix(r.URL.Path[1:], "/"+format) // Hack + topicIDs := strings.Split(topicsStr, ",") + topics, err := s.topicsFromIDs(topicIDs...) if err != nil { return err } @@ -314,14 +316,21 @@ func (s *Server) handleSubscribe(w http.ResponseWriter, r *http.Request, v *visi w.Header().Set("Access-Control-Allow-Origin", "*") // CORS, allow cross-origin requests w.Header().Set("Content-Type", contentType+"; charset=utf-8") // Android/Volley client needs charset! if poll { - return s.sendOldMessages(t, since, sub) + return s.sendOldMessages(topics, since, sub) } - subscriberID := t.Subscribe(sub) - defer t.Unsubscribe(subscriberID) - if err := sub(newOpenMessage(t.id)); err != nil { // Send out open message + subscriberIDs := make([]int, 0) + for _, t := range topics { + subscriberIDs = append(subscriberIDs, t.Subscribe(sub)) + } + defer func() { + for i, subscriberID := range subscriberIDs { + topics[i].Unsubscribe(subscriberID) // Order! + } + }() + if err := sub(newOpenMessage(topicsStr)); err != nil { // Send out open message return err } - if err := s.sendOldMessages(t, since, sub); err != nil { + if err := s.sendOldMessages(topics, since, sub); err != nil { return err } for { @@ -330,25 +339,27 @@ func (s *Server) handleSubscribe(w http.ResponseWriter, r *http.Request, v *visi return nil case <-time.After(s.config.KeepaliveInterval): v.Keepalive() - if err := sub(newKeepaliveMessage(t.id)); err != nil { // Send keepalive message + if err := sub(newKeepaliveMessage(topicsStr)); err != nil { // Send keepalive message return err } } } } -func (s *Server) sendOldMessages(t *topic, since sinceTime, sub subscriber) error { +func (s *Server) sendOldMessages(topics []*topic, since sinceTime, sub subscriber) error { if since.IsNone() { return nil } - messages, err := s.cache.Messages(t.id, since) - if err != nil { - return err - } - for _, m := range messages { - if err := sub(m); err != nil { + for _, t := range topics { + messages, err := s.cache.Messages(t.id, since) + if err != nil { return err } + for _, m := range messages { + if err := sub(m); err != nil { + return err + } + } } return nil } @@ -382,19 +393,31 @@ func (s *Server) handleOptions(w http.ResponseWriter, r *http.Request) error { return nil } -func (s *Server) topic(id string) (*topic, error) { +func (s *Server) topicFromID(id string) (*topic, error) { + topics, err := s.topicsFromIDs(id) + if err != nil { + return nil, err + } + return topics[0], nil +} + +func (s *Server) topicsFromIDs(ids... string) ([]*topic, error) { s.mu.Lock() defer s.mu.Unlock() - if _, ok := s.topics[id]; !ok { - if len(s.topics) >= s.config.GlobalTopicLimit { - return nil, errHTTPTooManyRequests - } - s.topics[id] = newTopic(id, time.Now()) - if s.firebase != nil { - s.topics[id].Subscribe(s.firebase) + topics := make([]*topic, 0) + for _, id := range ids { + if _, ok := s.topics[id]; !ok { + if len(s.topics) >= s.config.GlobalTopicLimit { + return nil, errHTTPTooManyRequests + } + s.topics[id] = newTopic(id, time.Now()) + if s.firebase != nil { + s.topics[id].Subscribe(s.firebase) + } } + topics = append(topics, s.topics[id]) } - return s.topics[id], nil + return topics, nil } func (s *Server) updateStatsAndExpire() {