Kill existing subscribers when topic is reserved
This commit is contained in:
		
							parent
							
								
									e82a2e518c
								
							
						
					
					
						commit
						bce71cb196
					
				
					 5 changed files with 169 additions and 36 deletions
				
			
		|  | @ -38,11 +38,13 @@ import ( | |||
| TODO | ||||
| -- | ||||
| 
 | ||||
| - Reservation: Kill existing subscribers when topic is reserved (deadcade) | ||||
| - Rate limiting: Sensitive endpoints (account/login/change-password/...) | ||||
| - Reservation (UI): Show "This topic is reserved" error message when trying to reserve a reserved topic (Thorben) | ||||
| - Reservation (UI): Ask for confirmation when removing reservation (deadcade) | ||||
| - Reservation icons (UI) | ||||
| - reservation table delete button: dialog "keep or delete messages?" | ||||
| - UI: Flickering upgrade banner when logging in | ||||
| - JS constants | ||||
| 
 | ||||
| races: | ||||
| - v.user --> see publishSyncEventAsync() test | ||||
|  | @ -63,11 +65,6 @@ Limits & rate limiting: | |||
| Make sure account endpoints make sense for admins | ||||
| 
 | ||||
| 
 | ||||
| UI: | ||||
| - | ||||
| - reservation table delete button: dialog "keep or delete messages?" | ||||
| - flicker of upgrade banner | ||||
| - JS constants | ||||
| Sync: | ||||
| 	- sync problems with "deleteAfter=0" and "displayName=" | ||||
| 
 | ||||
|  | @ -359,7 +356,7 @@ func (s *Server) handle(w http.ResponseWriter, r *http.Request) { | |||
| 			log.Info("%s Connection closed with HTTP %d (ntfy error %d): %s", logHTTPPrefix(v, r), httpErr.HTTPCode, httpErr.Code, err.Error()) | ||||
| 		} | ||||
| 		w.Header().Set("Content-Type", "application/json") | ||||
| 		w.Header().Set("Access-Control-Allow-Origin", "*") // CORS, allow cross-origin requests | ||||
| 		w.Header().Set("Access-Control-Allow-Origin", s.config.AccessControlAllowOrigin) // CORS, allow cross-origin requests | ||||
| 		w.WriteHeader(httpErr.HTTPCode) | ||||
| 		io.WriteString(w, httpErr.JSON()+"\n") | ||||
| 	} | ||||
|  | @ -461,7 +458,7 @@ func (s *Server) handleTopic(w http.ResponseWriter, r *http.Request, v *visitor) | |||
| 	unifiedpush := readBoolParam(r, false, "x-unifiedpush", "unifiedpush", "up") // see PUT/POST too! | ||||
| 	if unifiedpush { | ||||
| 		w.Header().Set("Content-Type", "application/json") | ||||
| 		w.Header().Set("Access-Control-Allow-Origin", "*") // CORS, allow cross-origin requests | ||||
| 		w.Header().Set("Access-Control-Allow-Origin", s.config.AccessControlAllowOrigin) // CORS, allow cross-origin requests | ||||
| 		_, err := io.WriteString(w, `{"unifiedpush":{"version":1}}`+"\n") | ||||
| 		return err | ||||
| 	} | ||||
|  | @ -538,7 +535,7 @@ func (s *Server) handleFile(w http.ResponseWriter, r *http.Request, v *visitor) | |||
| 		} | ||||
| 	} | ||||
| 	w.Header().Set("Content-Length", fmt.Sprintf("%d", stat.Size())) | ||||
| 	w.Header().Set("Access-Control-Allow-Origin", "*") // CORS, allow cross-origin requests | ||||
| 	w.Header().Set("Access-Control-Allow-Origin", s.config.AccessControlAllowOrigin) // CORS, allow cross-origin requests | ||||
| 	if r.Method == http.MethodGet { | ||||
| 		f, err := os.Open(file) | ||||
| 		if err != nil { | ||||
|  | @ -969,14 +966,16 @@ func (s *Server) handleSubscribeHTTP(w http.ResponseWriter, r *http.Request, v * | |||
| 		} | ||||
| 		return nil | ||||
| 	} | ||||
| 	w.Header().Set("Access-Control-Allow-Origin", "*")            // CORS, allow cross-origin requests | ||||
| 	w.Header().Set("Access-Control-Allow-Origin", s.config.AccessControlAllowOrigin) // CORS, allow cross-origin requests | ||||
| 	w.Header().Set("Content-Type", contentType+"; charset=utf-8")                    // Android/Volley client needs charset! | ||||
| 	if poll { | ||||
| 		return s.sendOldMessages(topics, since, scheduled, v, sub) | ||||
| 	} | ||||
| 	ctx, cancel := context.WithCancel(context.Background()) | ||||
| 	defer cancel() | ||||
| 	subscriberIDs := make([]int, 0) | ||||
| 	for _, t := range topics { | ||||
| 		subscriberIDs = append(subscriberIDs, t.Subscribe(sub)) | ||||
| 		subscriberIDs = append(subscriberIDs, t.Subscribe(sub, v.MaybeUserID(), cancel)) | ||||
| 	} | ||||
| 	defer func() { | ||||
| 		for i, subscriberID := range subscriberIDs { | ||||
|  | @ -991,6 +990,8 @@ func (s *Server) handleSubscribeHTTP(w http.ResponseWriter, r *http.Request, v * | |||
| 	} | ||||
| 	for { | ||||
| 		select { | ||||
| 		case <-ctx.Done(): | ||||
| 			return nil | ||||
| 		case <-r.Context().Done(): | ||||
| 			return nil | ||||
| 		case <-time.After(s.config.KeepaliveInterval): | ||||
|  | @ -1033,8 +1034,20 @@ func (s *Server) handleSubscribeWS(w http.ResponseWriter, r *http.Request, v *vi | |||
| 		return err | ||||
| 	} | ||||
| 	defer conn.Close() | ||||
| 
 | ||||
| 	// Subscription connections can be canceled externally, see topic.CancelSubscribers | ||||
| 	subscriberContext, cancel := context.WithCancel(context.Background()) | ||||
| 	defer cancel() | ||||
| 
 | ||||
| 	// Use errgroup to run WebSocket reader and writer in Go routines | ||||
| 	var wlock sync.Mutex | ||||
| 	g, ctx := errgroup.WithContext(context.Background()) | ||||
| 	g, gctx := errgroup.WithContext(context.Background()) | ||||
| 	g.Go(func() error { | ||||
| 		<-subscriberContext.Done() | ||||
| 		log.Trace("%s Cancel received, closing subscriber connection", logHTTPPrefix(v, r)) | ||||
| 		conn.Close() | ||||
| 		return &websocket.CloseError{Code: websocket.CloseNormalClosure, Text: "subscription was canceled"} | ||||
| 	}) | ||||
| 	g.Go(func() error { | ||||
| 		pongWait := s.config.KeepaliveInterval + wsPongWait | ||||
| 		conn.SetReadLimit(wsReadLimit) | ||||
|  | @ -1050,6 +1063,11 @@ func (s *Server) handleSubscribeWS(w http.ResponseWriter, r *http.Request, v *vi | |||
| 			if err != nil { | ||||
| 				return err | ||||
| 			} | ||||
| 			select { | ||||
| 			case <-gctx.Done(): | ||||
| 				return nil | ||||
| 			default: | ||||
| 			} | ||||
| 		} | ||||
| 	}) | ||||
| 	g.Go(func() error { | ||||
|  | @ -1064,7 +1082,7 @@ func (s *Server) handleSubscribeWS(w http.ResponseWriter, r *http.Request, v *vi | |||
| 		} | ||||
| 		for { | ||||
| 			select { | ||||
| 			case <-ctx.Done(): | ||||
| 			case <-gctx.Done(): | ||||
| 				return nil | ||||
| 			case <-time.After(s.config.KeepaliveInterval): | ||||
| 				v.Keepalive() | ||||
|  | @ -1085,13 +1103,13 @@ func (s *Server) handleSubscribeWS(w http.ResponseWriter, r *http.Request, v *vi | |||
| 		} | ||||
| 		return conn.WriteJSON(msg) | ||||
| 	} | ||||
| 	w.Header().Set("Access-Control-Allow-Origin", "*") // CORS, allow cross-origin requests | ||||
| 	w.Header().Set("Access-Control-Allow-Origin", s.config.AccessControlAllowOrigin) // CORS, allow cross-origin requests | ||||
| 	if poll { | ||||
| 		return s.sendOldMessages(topics, since, scheduled, v, sub) | ||||
| 	} | ||||
| 	subscriberIDs := make([]int, 0) | ||||
| 	for _, t := range topics { | ||||
| 		subscriberIDs = append(subscriberIDs, t.Subscribe(sub)) | ||||
| 		subscriberIDs = append(subscriberIDs, t.Subscribe(sub, v.MaybeUserID(), cancel)) | ||||
| 	} | ||||
| 	defer func() { | ||||
| 		for i, subscriberID := range subscriberIDs { | ||||
|  | @ -1193,11 +1211,7 @@ func (s *Server) topicFromPath(path string) (*topic, error) { | |||
| 	if len(parts) < 2 { | ||||
| 		return nil, errHTTPBadRequestTopicInvalid | ||||
| 	} | ||||
| 	topics, err := s.topicsFromIDs(parts[1]) | ||||
| 	if err != nil { | ||||
| 		return nil, err | ||||
| 	} | ||||
| 	return topics[0], nil | ||||
| 	return s.topicFromID(parts[1]) | ||||
| } | ||||
| 
 | ||||
| func (s *Server) topicsFromPath(path string) ([]*topic, string, error) { | ||||
|  | @ -1232,6 +1246,14 @@ func (s *Server) topicsFromIDs(ids ...string) ([]*topic, error) { | |||
| 	return topics, nil | ||||
| } | ||||
| 
 | ||||
| func (s *Server) topicFromID(id string) (*topic, error) { | ||||
| 	topics, err := s.topicsFromIDs(id) | ||||
| 	if err != nil { | ||||
| 		return nil, err | ||||
| 	} | ||||
| 	return topics[0], nil | ||||
| } | ||||
| 
 | ||||
| func (s *Server) execManager() { | ||||
| 	log.Debug("Manager: Starting") | ||||
| 	defer log.Debug("Manager: Finished") | ||||
|  |  | |||
|  | @ -2,7 +2,6 @@ package server | |||
| 
 | ||||
| import ( | ||||
| 	"encoding/json" | ||||
| 	"errors" | ||||
| 	"heckel.io/ntfy/log" | ||||
| 	"heckel.io/ntfy/user" | ||||
| 	"heckel.io/ntfy/util" | ||||
|  | @ -331,6 +330,7 @@ func (s *Server) handleAccountReservationAdd(w http.ResponseWriter, r *http.Requ | |||
| 	if v.user.Tier == nil { | ||||
| 		return errHTTPUnauthorized | ||||
| 	} | ||||
| 	// CHeck if we are allowed to reserve this topic | ||||
| 	if err := s.userManager.CheckAllowAccess(v.user.Name, req.Topic); err != nil { | ||||
| 		return errHTTPConflictTopicReserved | ||||
| 	} | ||||
|  | @ -346,9 +346,16 @@ func (s *Server) handleAccountReservationAdd(w http.ResponseWriter, r *http.Requ | |||
| 			return errHTTPTooManyRequestsLimitReservations | ||||
| 		} | ||||
| 	} | ||||
| 	// Actually add the reservation | ||||
| 	if err := s.userManager.AddReservation(v.user.Name, req.Topic, everyone); err != nil { | ||||
| 		return err | ||||
| 	} | ||||
| 	// Kill existing subscribers | ||||
| 	t, err := s.topicFromID(req.Topic) | ||||
| 	if err != nil { | ||||
| 		return err | ||||
| 	} | ||||
| 	t.CancelSubscribers(v.user.ID) | ||||
| 	return s.writeJSON(w, newSuccessResponse()) | ||||
| } | ||||
| 
 | ||||
|  | @ -402,13 +409,10 @@ func (s *Server) publishSyncEvent(v *visitor) error { | |||
| 		return nil | ||||
| 	} | ||||
| 	log.Trace("Publishing sync event to user %s's sync topic %s", v.user.Name, v.user.SyncTopic) | ||||
| 	topics, err := s.topicsFromIDs(v.user.SyncTopic) | ||||
| 	syncTopic, err := s.topicFromID(v.user.SyncTopic) | ||||
| 	if err != nil { | ||||
| 		return err | ||||
| 	} else if len(topics) == 0 { | ||||
| 		return errors.New("cannot retrieve sync topic") | ||||
| 	} | ||||
| 	syncTopic := topics[0] | ||||
| 	messageBytes, err := json.Marshal(&apiAccountSyncTopicResponse{Event: syncTopicAccountSyncEvent}) | ||||
| 	if err != nil { | ||||
| 		return err | ||||
|  |  | |||
|  | @ -496,3 +496,72 @@ func TestAccount_Reservation_PublishByAnonymousFails(t *testing.T) { | |||
| 	rr = request(t, s, "POST", "/mytopic", `Howdy`, nil) | ||||
| 	require.Equal(t, 403, rr.Code) | ||||
| } | ||||
| 
 | ||||
| func TestAccount_Reservation_Add_Kills_Other_Subscribers(t *testing.T) { | ||||
| 	conf := newTestConfigWithAuthFile(t) | ||||
| 	conf.AuthDefault = user.PermissionReadWrite | ||||
| 	conf.EnableSignup = true | ||||
| 	s := newTestServer(t, conf) | ||||
| 
 | ||||
| 	// Create user with tier | ||||
| 	rr := request(t, s, "POST", "/v1/account", `{"username":"phil", "password":"mypass"}`, nil) | ||||
| 	require.Equal(t, 200, rr.Code) | ||||
| 
 | ||||
| 	require.Nil(t, s.userManager.CreateTier(&user.Tier{ | ||||
| 		Code:              "pro", | ||||
| 		MessagesLimit:     20, | ||||
| 		ReservationsLimit: 2, | ||||
| 	})) | ||||
| 	require.Nil(t, s.userManager.ChangeTier("phil", "pro")) | ||||
| 
 | ||||
| 	// Subscribe anonymously | ||||
| 	anonCh, userCh := make(chan bool), make(chan bool) | ||||
| 	go func() { | ||||
| 		rr := request(t, s, "GET", "/mytopic/json", ``, nil) | ||||
| 		require.Equal(t, 200, rr.Code) | ||||
| 		messages := toMessages(t, rr.Body.String()) | ||||
| 		require.Equal(t, 2, len(messages)) // This is the meat. We should NOT receive the second message! | ||||
| 		require.Equal(t, "open", messages[0].Event) | ||||
| 		require.Equal(t, "message before reservation", messages[1].Message) | ||||
| 		anonCh <- true | ||||
| 	}() | ||||
| 
 | ||||
| 	// Subscribe with user | ||||
| 	go func() { | ||||
| 		rr := request(t, s, "GET", "/mytopic/json", ``, map[string]string{ | ||||
| 			"Authorization": util.BasicAuth("phil", "mypass"), | ||||
| 		}) | ||||
| 		require.Equal(t, 200, rr.Code) | ||||
| 		messages := toMessages(t, rr.Body.String()) | ||||
| 		require.Equal(t, 3, len(messages)) | ||||
| 		require.Equal(t, "open", messages[0].Event) | ||||
| 		require.Equal(t, "message before reservation", messages[1].Message) | ||||
| 		require.Equal(t, "message after reservation", messages[2].Message) | ||||
| 		userCh <- true | ||||
| 	}() | ||||
| 
 | ||||
| 	// Publish message (before reservation) | ||||
| 	time.Sleep(700 * time.Millisecond) // Wait for subscribers | ||||
| 	rr = request(t, s, "POST", "/mytopic", "message before reservation", nil) | ||||
| 	require.Equal(t, 200, rr.Code) | ||||
| 	time.Sleep(700 * time.Millisecond) // Wait for subscribers to receive message | ||||
| 
 | ||||
| 	// Reserve a topic | ||||
| 	rr = request(t, s, "POST", "/v1/account/reservation", `{"topic": "mytopic", "everyone":"deny-all"}`, map[string]string{ | ||||
| 		"Authorization": util.BasicAuth("phil", "mypass"), | ||||
| 	}) | ||||
| 	require.Equal(t, 200, rr.Code) | ||||
| 
 | ||||
| 	// Everyone but phil should be killed | ||||
| 	<-anonCh | ||||
| 
 | ||||
| 	// Publish a message | ||||
| 	rr = request(t, s, "POST", "/mytopic", "message after reservation", map[string]string{ | ||||
| 		"Authorization": util.BasicAuth("phil", "mypass"), | ||||
| 	}) | ||||
| 	require.Equal(t, 200, rr.Code) | ||||
| 
 | ||||
| 	// Kill user Go routine | ||||
| 	s.topics["mytopic"].CancelSubscribers("<invalid>") | ||||
| 	<-userCh | ||||
| } | ||||
|  |  | |||
|  | @ -10,10 +10,16 @@ import ( | |||
| // can publish a message | ||||
| type topic struct { | ||||
| 	ID          string | ||||
| 	subscribers map[int]subscriber | ||||
| 	subscribers map[int]*topicSubscriber | ||||
| 	mu          sync.Mutex | ||||
| } | ||||
| 
 | ||||
| type topicSubscriber struct { | ||||
| 	userID     string // User ID associated with this subscription, may be empty | ||||
| 	subscriber subscriber | ||||
| 	cancel     func() | ||||
| } | ||||
| 
 | ||||
| // subscriber is a function that is called for every new message on a topic | ||||
| type subscriber func(v *visitor, msg *message) error | ||||
| 
 | ||||
|  | @ -21,16 +27,20 @@ type subscriber func(v *visitor, msg *message) error | |||
| func newTopic(id string) *topic { | ||||
| 	return &topic{ | ||||
| 		ID:          id, | ||||
| 		subscribers: make(map[int]subscriber), | ||||
| 		subscribers: make(map[int]*topicSubscriber), | ||||
| 	} | ||||
| } | ||||
| 
 | ||||
| // Subscribe subscribes to this topic | ||||
| func (t *topic) Subscribe(s subscriber) int { | ||||
| func (t *topic) Subscribe(s subscriber, userID string, cancel func()) int { | ||||
| 	t.mu.Lock() | ||||
| 	defer t.mu.Unlock() | ||||
| 	subscriberID := rand.Int() | ||||
| 	t.subscribers[subscriberID] = s | ||||
| 	t.subscribers[subscriberID] = &topicSubscriber{ | ||||
| 		userID:     userID, // May be empty | ||||
| 		subscriber: s, | ||||
| 		cancel:     cancel, | ||||
| 	} | ||||
| 	return subscriberID | ||||
| } | ||||
| 
 | ||||
|  | @ -56,7 +66,7 @@ func (t *topic) Publish(v *visitor, m *message) error { | |||
| 					if err := s(v, m); err != nil { | ||||
| 						log.Warn("%s Error forwarding to subscriber", logMessagePrefix(v, m)) | ||||
| 					} | ||||
| 				}(s) | ||||
| 				}(s.subscriber) | ||||
| 			} | ||||
| 		} else { | ||||
| 			log.Trace("%s No stream or WebSocket subscribers, not forwarding", logMessagePrefix(v, m)) | ||||
|  | @ -72,13 +82,29 @@ func (t *topic) SubscribersCount() int { | |||
| 	return len(t.subscribers) | ||||
| } | ||||
| 
 | ||||
| // subscribersCopy returns a shallow copy of the subscribers map | ||||
| func (t *topic) subscribersCopy() map[int]subscriber { | ||||
| // CancelSubscribers calls the cancel function for all subscribers, forcing | ||||
| func (t *topic) CancelSubscribers(exceptUserID string) { | ||||
| 	t.mu.Lock() | ||||
| 	defer t.mu.Unlock() | ||||
| 	subscribers := make(map[int]subscriber) | ||||
| 	for k, v := range t.subscribers { | ||||
| 		subscribers[k] = v | ||||
| 	for _, s := range t.subscribers { | ||||
| 		if s.userID != exceptUserID { | ||||
| 			log.Trace("Canceling subscriber %s", s.userID) | ||||
| 			s.cancel() | ||||
| 		} | ||||
| 	} | ||||
| } | ||||
| 
 | ||||
| // subscribersCopy returns a shallow copy of the subscribers map | ||||
| func (t *topic) subscribersCopy() map[int]*topicSubscriber { | ||||
| 	t.mu.Lock() | ||||
| 	defer t.mu.Unlock() | ||||
| 	subscribers := make(map[int]*topicSubscriber) | ||||
| 	for k, sub := range t.subscribers { | ||||
| 		subscribers[k] = &topicSubscriber{ | ||||
| 			userID:     sub.userID, | ||||
| 			subscriber: sub.subscriber, | ||||
| 			cancel:     sub.cancel, | ||||
| 		} | ||||
| 	} | ||||
| 	return subscribers | ||||
| } | ||||
|  |  | |||
|  | @ -228,12 +228,24 @@ func (v *visitor) ResetStats() { | |||
| 	} | ||||
| } | ||||
| 
 | ||||
| // SetUser sets the visitors user to the given value | ||||
| func (v *visitor) SetUser(u *user.User) { | ||||
| 	v.mu.Lock() | ||||
| 	defer v.mu.Unlock() | ||||
| 	v.user = u | ||||
| } | ||||
| 
 | ||||
| // MaybeUserID returns the user ID of the visitor (if any). If this is an anonymous visitor, | ||||
| // an empty string is returned. | ||||
| func (v *visitor) MaybeUserID() string { | ||||
| 	v.mu.Lock() | ||||
| 	defer v.mu.Unlock() | ||||
| 	if v.user != nil { | ||||
| 		return v.user.ID | ||||
| 	} | ||||
| 	return "" | ||||
| } | ||||
| 
 | ||||
| func (v *visitor) Limits() *visitorLimits { | ||||
| 	v.mu.Lock() | ||||
| 	defer v.mu.Unlock() | ||||
|  |  | |||
		Loading…
	
	Add table
		Add a link
		
	
		Reference in a new issue