From bce71cb196b769ee96dc7df577abe4a2a8d98b9f Mon Sep 17 00:00:00 2001 From: binwiederhier Date: Mon, 23 Jan 2023 14:05:41 -0500 Subject: [PATCH] Kill existing subscribers when topic is reserved --- server/server.go | 64 +++++++++++++++++++++----------- server/server_account.go | 14 ++++--- server/server_account_test.go | 69 +++++++++++++++++++++++++++++++++++ server/topic.go | 46 ++++++++++++++++++----- server/visitor.go | 12 ++++++ 5 files changed, 169 insertions(+), 36 deletions(-) diff --git a/server/server.go b/server/server.go index c6b94a2b..03bb9c25 100644 --- a/server/server.go +++ b/server/server.go @@ -38,11 +38,13 @@ import ( TODO -- -- Reservation: Kill existing subscribers when topic is reserved (deadcade) - Rate limiting: Sensitive endpoints (account/login/change-password/...) - Reservation (UI): Show "This topic is reserved" error message when trying to reserve a reserved topic (Thorben) - Reservation (UI): Ask for confirmation when removing reservation (deadcade) - Reservation icons (UI) +- reservation table delete button: dialog "keep or delete messages?" +- UI: Flickering upgrade banner when logging in +- JS constants races: - v.user --> see publishSyncEventAsync() test @@ -63,11 +65,6 @@ Limits & rate limiting: Make sure account endpoints make sense for admins -UI: -- -- reservation table delete button: dialog "keep or delete messages?" -- flicker of upgrade banner -- JS constants Sync: - sync problems with "deleteAfter=0" and "displayName=" @@ -359,7 +356,7 @@ func (s *Server) handle(w http.ResponseWriter, r *http.Request) { log.Info("%s Connection closed with HTTP %d (ntfy error %d): %s", logHTTPPrefix(v, r), httpErr.HTTPCode, httpErr.Code, err.Error()) } w.Header().Set("Content-Type", "application/json") - w.Header().Set("Access-Control-Allow-Origin", "*") // CORS, allow cross-origin requests + w.Header().Set("Access-Control-Allow-Origin", s.config.AccessControlAllowOrigin) // CORS, allow cross-origin requests w.WriteHeader(httpErr.HTTPCode) io.WriteString(w, httpErr.JSON()+"\n") } @@ -461,7 +458,7 @@ func (s *Server) handleTopic(w http.ResponseWriter, r *http.Request, v *visitor) unifiedpush := readBoolParam(r, false, "x-unifiedpush", "unifiedpush", "up") // see PUT/POST too! if unifiedpush { w.Header().Set("Content-Type", "application/json") - w.Header().Set("Access-Control-Allow-Origin", "*") // CORS, allow cross-origin requests + w.Header().Set("Access-Control-Allow-Origin", s.config.AccessControlAllowOrigin) // CORS, allow cross-origin requests _, err := io.WriteString(w, `{"unifiedpush":{"version":1}}`+"\n") return err } @@ -538,7 +535,7 @@ func (s *Server) handleFile(w http.ResponseWriter, r *http.Request, v *visitor) } } w.Header().Set("Content-Length", fmt.Sprintf("%d", stat.Size())) - w.Header().Set("Access-Control-Allow-Origin", "*") // CORS, allow cross-origin requests + w.Header().Set("Access-Control-Allow-Origin", s.config.AccessControlAllowOrigin) // CORS, allow cross-origin requests if r.Method == http.MethodGet { f, err := os.Open(file) if err != nil { @@ -969,14 +966,16 @@ func (s *Server) handleSubscribeHTTP(w http.ResponseWriter, r *http.Request, v * } return nil } - w.Header().Set("Access-Control-Allow-Origin", "*") // CORS, allow cross-origin requests - w.Header().Set("Content-Type", contentType+"; charset=utf-8") // Android/Volley client needs charset! + w.Header().Set("Access-Control-Allow-Origin", s.config.AccessControlAllowOrigin) // CORS, allow cross-origin requests + w.Header().Set("Content-Type", contentType+"; charset=utf-8") // Android/Volley client needs charset! if poll { return s.sendOldMessages(topics, since, scheduled, v, sub) } + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() subscriberIDs := make([]int, 0) for _, t := range topics { - subscriberIDs = append(subscriberIDs, t.Subscribe(sub)) + subscriberIDs = append(subscriberIDs, t.Subscribe(sub, v.MaybeUserID(), cancel)) } defer func() { for i, subscriberID := range subscriberIDs { @@ -991,6 +990,8 @@ func (s *Server) handleSubscribeHTTP(w http.ResponseWriter, r *http.Request, v * } for { select { + case <-ctx.Done(): + return nil case <-r.Context().Done(): return nil case <-time.After(s.config.KeepaliveInterval): @@ -1033,8 +1034,20 @@ func (s *Server) handleSubscribeWS(w http.ResponseWriter, r *http.Request, v *vi return err } defer conn.Close() + + // Subscription connections can be canceled externally, see topic.CancelSubscribers + subscriberContext, cancel := context.WithCancel(context.Background()) + defer cancel() + + // Use errgroup to run WebSocket reader and writer in Go routines var wlock sync.Mutex - g, ctx := errgroup.WithContext(context.Background()) + g, gctx := errgroup.WithContext(context.Background()) + g.Go(func() error { + <-subscriberContext.Done() + log.Trace("%s Cancel received, closing subscriber connection", logHTTPPrefix(v, r)) + conn.Close() + return &websocket.CloseError{Code: websocket.CloseNormalClosure, Text: "subscription was canceled"} + }) g.Go(func() error { pongWait := s.config.KeepaliveInterval + wsPongWait conn.SetReadLimit(wsReadLimit) @@ -1050,6 +1063,11 @@ func (s *Server) handleSubscribeWS(w http.ResponseWriter, r *http.Request, v *vi if err != nil { return err } + select { + case <-gctx.Done(): + return nil + default: + } } }) g.Go(func() error { @@ -1064,7 +1082,7 @@ func (s *Server) handleSubscribeWS(w http.ResponseWriter, r *http.Request, v *vi } for { select { - case <-ctx.Done(): + case <-gctx.Done(): return nil case <-time.After(s.config.KeepaliveInterval): v.Keepalive() @@ -1085,13 +1103,13 @@ func (s *Server) handleSubscribeWS(w http.ResponseWriter, r *http.Request, v *vi } return conn.WriteJSON(msg) } - w.Header().Set("Access-Control-Allow-Origin", "*") // CORS, allow cross-origin requests + w.Header().Set("Access-Control-Allow-Origin", s.config.AccessControlAllowOrigin) // CORS, allow cross-origin requests if poll { return s.sendOldMessages(topics, since, scheduled, v, sub) } subscriberIDs := make([]int, 0) for _, t := range topics { - subscriberIDs = append(subscriberIDs, t.Subscribe(sub)) + subscriberIDs = append(subscriberIDs, t.Subscribe(sub, v.MaybeUserID(), cancel)) } defer func() { for i, subscriberID := range subscriberIDs { @@ -1193,11 +1211,7 @@ func (s *Server) topicFromPath(path string) (*topic, error) { if len(parts) < 2 { return nil, errHTTPBadRequestTopicInvalid } - topics, err := s.topicsFromIDs(parts[1]) - if err != nil { - return nil, err - } - return topics[0], nil + return s.topicFromID(parts[1]) } func (s *Server) topicsFromPath(path string) ([]*topic, string, error) { @@ -1232,6 +1246,14 @@ func (s *Server) topicsFromIDs(ids ...string) ([]*topic, error) { return topics, nil } +func (s *Server) topicFromID(id string) (*topic, error) { + topics, err := s.topicsFromIDs(id) + if err != nil { + return nil, err + } + return topics[0], nil +} + func (s *Server) execManager() { log.Debug("Manager: Starting") defer log.Debug("Manager: Finished") diff --git a/server/server_account.go b/server/server_account.go index 6fce9f16..a08ed0bb 100644 --- a/server/server_account.go +++ b/server/server_account.go @@ -2,7 +2,6 @@ package server import ( "encoding/json" - "errors" "heckel.io/ntfy/log" "heckel.io/ntfy/user" "heckel.io/ntfy/util" @@ -331,6 +330,7 @@ func (s *Server) handleAccountReservationAdd(w http.ResponseWriter, r *http.Requ if v.user.Tier == nil { return errHTTPUnauthorized } + // CHeck if we are allowed to reserve this topic if err := s.userManager.CheckAllowAccess(v.user.Name, req.Topic); err != nil { return errHTTPConflictTopicReserved } @@ -346,9 +346,16 @@ func (s *Server) handleAccountReservationAdd(w http.ResponseWriter, r *http.Requ return errHTTPTooManyRequestsLimitReservations } } + // Actually add the reservation if err := s.userManager.AddReservation(v.user.Name, req.Topic, everyone); err != nil { return err } + // Kill existing subscribers + t, err := s.topicFromID(req.Topic) + if err != nil { + return err + } + t.CancelSubscribers(v.user.ID) return s.writeJSON(w, newSuccessResponse()) } @@ -402,13 +409,10 @@ func (s *Server) publishSyncEvent(v *visitor) error { return nil } log.Trace("Publishing sync event to user %s's sync topic %s", v.user.Name, v.user.SyncTopic) - topics, err := s.topicsFromIDs(v.user.SyncTopic) + syncTopic, err := s.topicFromID(v.user.SyncTopic) if err != nil { return err - } else if len(topics) == 0 { - return errors.New("cannot retrieve sync topic") } - syncTopic := topics[0] messageBytes, err := json.Marshal(&apiAccountSyncTopicResponse{Event: syncTopicAccountSyncEvent}) if err != nil { return err diff --git a/server/server_account_test.go b/server/server_account_test.go index 690dadd9..73610869 100644 --- a/server/server_account_test.go +++ b/server/server_account_test.go @@ -496,3 +496,72 @@ func TestAccount_Reservation_PublishByAnonymousFails(t *testing.T) { rr = request(t, s, "POST", "/mytopic", `Howdy`, nil) require.Equal(t, 403, rr.Code) } + +func TestAccount_Reservation_Add_Kills_Other_Subscribers(t *testing.T) { + conf := newTestConfigWithAuthFile(t) + conf.AuthDefault = user.PermissionReadWrite + conf.EnableSignup = true + s := newTestServer(t, conf) + + // Create user with tier + rr := request(t, s, "POST", "/v1/account", `{"username":"phil", "password":"mypass"}`, nil) + require.Equal(t, 200, rr.Code) + + require.Nil(t, s.userManager.CreateTier(&user.Tier{ + Code: "pro", + MessagesLimit: 20, + ReservationsLimit: 2, + })) + require.Nil(t, s.userManager.ChangeTier("phil", "pro")) + + // Subscribe anonymously + anonCh, userCh := make(chan bool), make(chan bool) + go func() { + rr := request(t, s, "GET", "/mytopic/json", ``, nil) + require.Equal(t, 200, rr.Code) + messages := toMessages(t, rr.Body.String()) + require.Equal(t, 2, len(messages)) // This is the meat. We should NOT receive the second message! + require.Equal(t, "open", messages[0].Event) + require.Equal(t, "message before reservation", messages[1].Message) + anonCh <- true + }() + + // Subscribe with user + go func() { + rr := request(t, s, "GET", "/mytopic/json", ``, map[string]string{ + "Authorization": util.BasicAuth("phil", "mypass"), + }) + require.Equal(t, 200, rr.Code) + messages := toMessages(t, rr.Body.String()) + require.Equal(t, 3, len(messages)) + require.Equal(t, "open", messages[0].Event) + require.Equal(t, "message before reservation", messages[1].Message) + require.Equal(t, "message after reservation", messages[2].Message) + userCh <- true + }() + + // Publish message (before reservation) + time.Sleep(700 * time.Millisecond) // Wait for subscribers + rr = request(t, s, "POST", "/mytopic", "message before reservation", nil) + require.Equal(t, 200, rr.Code) + time.Sleep(700 * time.Millisecond) // Wait for subscribers to receive message + + // Reserve a topic + rr = request(t, s, "POST", "/v1/account/reservation", `{"topic": "mytopic", "everyone":"deny-all"}`, map[string]string{ + "Authorization": util.BasicAuth("phil", "mypass"), + }) + require.Equal(t, 200, rr.Code) + + // Everyone but phil should be killed + <-anonCh + + // Publish a message + rr = request(t, s, "POST", "/mytopic", "message after reservation", map[string]string{ + "Authorization": util.BasicAuth("phil", "mypass"), + }) + require.Equal(t, 200, rr.Code) + + // Kill user Go routine + s.topics["mytopic"].CancelSubscribers("") + <-userCh +} diff --git a/server/topic.go b/server/topic.go index 3bc74736..95a93c29 100644 --- a/server/topic.go +++ b/server/topic.go @@ -10,10 +10,16 @@ import ( // can publish a message type topic struct { ID string - subscribers map[int]subscriber + subscribers map[int]*topicSubscriber mu sync.Mutex } +type topicSubscriber struct { + userID string // User ID associated with this subscription, may be empty + subscriber subscriber + cancel func() +} + // subscriber is a function that is called for every new message on a topic type subscriber func(v *visitor, msg *message) error @@ -21,16 +27,20 @@ type subscriber func(v *visitor, msg *message) error func newTopic(id string) *topic { return &topic{ ID: id, - subscribers: make(map[int]subscriber), + subscribers: make(map[int]*topicSubscriber), } } // Subscribe subscribes to this topic -func (t *topic) Subscribe(s subscriber) int { +func (t *topic) Subscribe(s subscriber, userID string, cancel func()) int { t.mu.Lock() defer t.mu.Unlock() subscriberID := rand.Int() - t.subscribers[subscriberID] = s + t.subscribers[subscriberID] = &topicSubscriber{ + userID: userID, // May be empty + subscriber: s, + cancel: cancel, + } return subscriberID } @@ -56,7 +66,7 @@ func (t *topic) Publish(v *visitor, m *message) error { if err := s(v, m); err != nil { log.Warn("%s Error forwarding to subscriber", logMessagePrefix(v, m)) } - }(s) + }(s.subscriber) } } else { log.Trace("%s No stream or WebSocket subscribers, not forwarding", logMessagePrefix(v, m)) @@ -72,13 +82,29 @@ func (t *topic) SubscribersCount() int { return len(t.subscribers) } -// subscribersCopy returns a shallow copy of the subscribers map -func (t *topic) subscribersCopy() map[int]subscriber { +// CancelSubscribers calls the cancel function for all subscribers, forcing +func (t *topic) CancelSubscribers(exceptUserID string) { t.mu.Lock() defer t.mu.Unlock() - subscribers := make(map[int]subscriber) - for k, v := range t.subscribers { - subscribers[k] = v + for _, s := range t.subscribers { + if s.userID != exceptUserID { + log.Trace("Canceling subscriber %s", s.userID) + s.cancel() + } + } +} + +// subscribersCopy returns a shallow copy of the subscribers map +func (t *topic) subscribersCopy() map[int]*topicSubscriber { + t.mu.Lock() + defer t.mu.Unlock() + subscribers := make(map[int]*topicSubscriber) + for k, sub := range t.subscribers { + subscribers[k] = &topicSubscriber{ + userID: sub.userID, + subscriber: sub.subscriber, + cancel: sub.cancel, + } } return subscribers } diff --git a/server/visitor.go b/server/visitor.go index 77ed9460..16192404 100644 --- a/server/visitor.go +++ b/server/visitor.go @@ -228,12 +228,24 @@ func (v *visitor) ResetStats() { } } +// SetUser sets the visitors user to the given value func (v *visitor) SetUser(u *user.User) { v.mu.Lock() defer v.mu.Unlock() v.user = u } +// MaybeUserID returns the user ID of the visitor (if any). If this is an anonymous visitor, +// an empty string is returned. +func (v *visitor) MaybeUserID() string { + v.mu.Lock() + defer v.mu.Unlock() + if v.user != nil { + return v.user.ID + } + return "" +} + func (v *visitor) Limits() *visitorLimits { v.mu.Lock() defer v.mu.Unlock()