Polishing
This commit is contained in:
		
							parent
							
								
									8eae44ea61
								
							
						
					
					
						commit
						2329695a47
					
				
					 5 changed files with 88 additions and 44 deletions
				
			
		|  | @ -570,14 +570,8 @@ func (s *Server) handleMatrixDiscovery(w http.ResponseWriter) error { | |||
| } | ||||
| 
 | ||||
| func (s *Server) handlePublishWithoutResponse(r *http.Request, v *visitor) (*message, error) { | ||||
| 	vrate, ok := r.Context().Value(contextRateVisitor).(*visitor) | ||||
| 	if !ok { | ||||
| 		return nil, errHTTPInternalError | ||||
| 	} | ||||
| 	t, ok := r.Context().Value(contextTopic).(*topic) | ||||
| 	if !ok { | ||||
| 		return nil, errHTTPInternalError | ||||
| 	} | ||||
| 	t := fromContext[topic](r, contextTopic) | ||||
| 	vrate := fromContext[visitor](r, contextRateVisitor) | ||||
| 	if !vrate.MessageAllowed() { | ||||
| 		return nil, errHTTPTooManyRequestsLimitMessages | ||||
| 	} | ||||
|  | @ -586,10 +580,13 @@ func (s *Server) handlePublishWithoutResponse(r *http.Request, v *visitor) (*mes | |||
| 		return nil, err | ||||
| 	} | ||||
| 	m := newDefaultMessage(t.ID, "") | ||||
| 	cache, firebase, email, unifiedpush, err := s.parsePublishParams(r, vrate, m) | ||||
| 	cache, firebase, email, unifiedpush, err := s.parsePublishParams(r, m) | ||||
| 	if err != nil { | ||||
| 		return nil, err | ||||
| 	} | ||||
| 	if email != "" && !vrate.EmailAllowed() { | ||||
| 		return nil, errHTTPTooManyRequestsLimitEmails | ||||
| 	} | ||||
| 	if m.PollID != "" { | ||||
| 		m = newPollRequestMessage(t.ID, m.PollID) | ||||
| 	} | ||||
|  | @ -605,13 +602,15 @@ func (s *Server) handlePublishWithoutResponse(r *http.Request, v *visitor) (*mes | |||
| 		m.Message = emptyMessageBody | ||||
| 	} | ||||
| 	delayed := m.Time > time.Now().Unix() | ||||
| 	ev := logvrm(vrate, r, m). | ||||
| 	ev := logvrm(v, r, m). | ||||
| 		Tag(tagPublish). | ||||
| 		Fields(log.Context{ | ||||
| 			"message_delayed":     delayed, | ||||
| 			"message_firebase":    firebase, | ||||
| 			"message_unifiedpush": unifiedpush, | ||||
| 			"message_email":       email, | ||||
| 			"rate_visitor_ip":     vrate.IP().String(), | ||||
| 			"rate_user_id":        vrate.MaybeUserID(), | ||||
| 		}) | ||||
| 	if ev.IsTrace() { | ||||
| 		ev.Field("message_body", util.MaybeMarshalJSON(m)).Trace("Received message") | ||||
|  | @ -623,7 +622,7 @@ func (s *Server) handlePublishWithoutResponse(r *http.Request, v *visitor) (*mes | |||
| 			return nil, err | ||||
| 		} | ||||
| 		if s.firebaseClient != nil && firebase { | ||||
| 			go s.sendToFirebase(vrate, m) | ||||
| 			go s.sendToFirebase(v, m) | ||||
| 		} | ||||
| 		if s.smtpSender != nil && email != "" { | ||||
| 			go s.sendEmail(v, m, email) | ||||
|  | @ -708,7 +707,7 @@ func (s *Server) forwardPollRequest(v *visitor, m *message) { | |||
| 	} | ||||
| } | ||||
| 
 | ||||
| func (s *Server) parsePublishParams(r *http.Request, vrate *visitor, m *message) (cache bool, firebase bool, email string, unifiedpush bool, err error) { | ||||
| func (s *Server) parsePublishParams(r *http.Request, m *message) (cache bool, firebase bool, email string, unifiedpush bool, err error) { | ||||
| 	cache = readBoolParam(r, true, "x-cache", "cache") | ||||
| 	firebase = readBoolParam(r, true, "x-firebase", "firebase") | ||||
| 	m.Title = readParam(r, "x-title", "title", "t") | ||||
|  | @ -747,11 +746,6 @@ func (s *Server) parsePublishParams(r *http.Request, vrate *visitor, m *message) | |||
| 		m.Icon = icon | ||||
| 	} | ||||
| 	email = readParam(r, "x-email", "x-e-mail", "email", "e-mail", "mail", "e") | ||||
| 	if email != "" { | ||||
| 		if !vrate.EmailAllowed() { | ||||
| 			return false, false, "", false, errHTTPTooManyRequestsLimitEmails | ||||
| 		} | ||||
| 	} | ||||
| 	if s.smtpSender == nil && email != "" { | ||||
| 		return false, false, "", false, errHTTPBadRequestEmailDisabled | ||||
| 	} | ||||
|  | @ -993,7 +987,7 @@ func (s *Server) handleSubscribeHTTP(w http.ResponseWriter, r *http.Request, v * | |||
| 	defer cancel() | ||||
| 	subscriberIDs := make([]int, 0) | ||||
| 	for _, t := range topics { | ||||
| 		subscriberIDs = append(subscriberIDs, t.Subscribe(sub, v, cancel)) | ||||
| 		subscriberIDs = append(subscriberIDs, t.Subscribe(sub, v.MaybeUserID(), cancel)) | ||||
| 	} | ||||
| 	defer func() { | ||||
| 		for i, subscriberID := range subscriberIDs { | ||||
|  | @ -1126,7 +1120,7 @@ func (s *Server) handleSubscribeWS(w http.ResponseWriter, r *http.Request, v *vi | |||
| 	} | ||||
| 	subscriberIDs := make([]int, 0) | ||||
| 	for _, t := range topics { | ||||
| 		subscriberIDs = append(subscriberIDs, t.Subscribe(sub, v, cancel)) | ||||
| 		subscriberIDs = append(subscriberIDs, t.Subscribe(sub, v.MaybeUserID(), cancel)) | ||||
| 	} | ||||
| 	defer func() { | ||||
| 		for i, subscriberID := range subscriberIDs { | ||||
|  |  | |||
|  | @ -3,7 +3,6 @@ package server | |||
| import ( | ||||
| 	"heckel.io/ntfy/log" | ||||
| 	"strings" | ||||
| 	"time" | ||||
| ) | ||||
| 
 | ||||
| func (s *Server) execManager() { | ||||
|  | @ -38,16 +37,23 @@ func (s *Server) execManager() { | |||
| 				subs := t.SubscribersCount() | ||||
| 				ev := log.Tag(tagManager) | ||||
| 				if ev.IsTrace() { | ||||
| 					expiryMessage := "" | ||||
| 					if subs == 0 { | ||||
| 						expiryTime := time.Until(t.expires) | ||||
| 						expiryMessage = ", expires in " + expiryTime.String() | ||||
| 					vrate := t.RateVisitor() | ||||
| 					if vrate != nil { | ||||
| 						ev.Fields(log.Context{ | ||||
| 							"rate_visitor_ip":      vrate.IP(), | ||||
| 							"rate_visitor_user_id": vrate.MaybeUserID(), | ||||
| 						}) | ||||
| 					} | ||||
| 					ev.Trace("- topic %s: %d subscribers%s", t.ID, subs, expiryMessage) | ||||
| 					ev. | ||||
| 						Fields(log.Context{ | ||||
| 							"message_topic":             t.ID, | ||||
| 							"message_topic_subscribers": subs, | ||||
| 						}). | ||||
| 						Trace("- topic %s: %d subscribers", t.ID, subs) | ||||
| 				} | ||||
| 				msgs, exists := messageCounts[t.ID] | ||||
| 				if t.Stale() && (!exists || msgs == 0) { | ||||
| 					log.Tag(tagManager).Trace("Deleting empty topic %s", t.ID) | ||||
| 					log.Tag(tagManager).Field("message_topic", t.ID).Trace("Deleting empty topic %s", t.ID) | ||||
| 					emptyTopics++ | ||||
| 					delete(s.topics, t.ID) | ||||
| 					continue | ||||
|  |  | |||
|  | @ -2030,7 +2030,40 @@ func TestServer_Matrix_SubscriberRateLimiting_UP_Only(t *testing.T) { | |||
| 	} | ||||
| } | ||||
| 
 | ||||
| // FIXME add test for rate visitor expiration | ||||
| func TestServer_SubscriberRateLimiting_VisitorExpiration(t *testing.T) { | ||||
| 	c := newTestConfig(t) | ||||
| 	c.VisitorRequestLimitBurst = 3 | ||||
| 	s := newTestServer(t, c) | ||||
| 
 | ||||
| 	// "Register" rate visitor | ||||
| 	subscriberFn := func(r *http.Request) { | ||||
| 		r.RemoteAddr = "1.2.3.4" | ||||
| 	} | ||||
| 	rr := request(t, s, "GET", "/mytopic/json?poll=1", "", map[string]string{ | ||||
| 		"rate-topics": "*", | ||||
| 	}, subscriberFn) | ||||
| 	require.Equal(t, 200, rr.Code) | ||||
| 	require.Equal(t, "1.2.3.4", s.topics["mytopic"].rateVisitor.ip.String()) | ||||
| 	require.Equal(t, s.visitors["ip:1.2.3.4"], s.topics["mytopic"].rateVisitor) | ||||
| 
 | ||||
| 	// Publish message, observe rate visitor tokens being decreased | ||||
| 	response := request(t, s, "POST", "/mytopic", "some message", nil) | ||||
| 	require.Equal(t, 200, response.Code) | ||||
| 	require.Equal(t, int64(0), s.visitors["ip:9.9.9.9"].messagesLimiter.Value()) | ||||
| 	require.Equal(t, int64(1), s.topics["mytopic"].rateVisitor.messagesLimiter.Value()) | ||||
| 	require.Equal(t, s.visitors["ip:1.2.3.4"], s.topics["mytopic"].rateVisitor) | ||||
| 
 | ||||
| 	// Expire visitor | ||||
| 	s.visitors["ip:1.2.3.4"].seen = time.Now().Add(-1 * 25 * time.Hour) | ||||
| 	s.pruneVisitors() | ||||
| 
 | ||||
| 	// Publish message again, observe that rateVisitor is not used anymore and is reset | ||||
| 	response = request(t, s, "POST", "/mytopic", "some message", nil) | ||||
| 	require.Equal(t, 200, response.Code) | ||||
| 	require.Equal(t, int64(1), s.visitors["ip:9.9.9.9"].messagesLimiter.Value()) | ||||
| 	require.Nil(t, s.topics["mytopic"].rateVisitor) | ||||
| 	require.Nil(t, s.visitors["ip:1.2.3.4"]) | ||||
| } | ||||
| 
 | ||||
| func newTestConfig(t *testing.T) *Config { | ||||
| 	conf := NewConfig() | ||||
|  |  | |||
|  | @ -4,11 +4,6 @@ import ( | |||
| 	"heckel.io/ntfy/log" | ||||
| 	"math/rand" | ||||
| 	"sync" | ||||
| 	"time" | ||||
| ) | ||||
| 
 | ||||
| const ( | ||||
| 	topicExpiryDuration = 6 * time.Hour | ||||
| ) | ||||
| 
 | ||||
| // topic represents a channel to which subscribers can subscribe, and publishers | ||||
|  | @ -17,13 +12,12 @@ type topic struct { | |||
| 	ID          string | ||||
| 	subscribers map[int]*topicSubscriber | ||||
| 	rateVisitor *visitor | ||||
| 	expires     time.Time | ||||
| 	mu          sync.RWMutex | ||||
| } | ||||
| 
 | ||||
| type topicSubscriber struct { | ||||
| 	userID     string // User ID associated with this subscription, may be empty | ||||
| 	subscriber subscriber | ||||
| 	visitor    *visitor // User ID associated with this subscription, may be empty | ||||
| 	cancel     func() | ||||
| } | ||||
| 
 | ||||
|  | @ -39,12 +33,12 @@ func newTopic(id string) *topic { | |||
| } | ||||
| 
 | ||||
| // Subscribe subscribes to this topic | ||||
| func (t *topic) Subscribe(s subscriber, visitor *visitor, cancel func()) int { | ||||
| func (t *topic) Subscribe(s subscriber, userID string, cancel func()) int { | ||||
| 	t.mu.Lock() | ||||
| 	defer t.mu.Unlock() | ||||
| 	subscriberID := rand.Int() | ||||
| 	t.subscribers[subscriberID] = &topicSubscriber{ | ||||
| 		visitor:    visitor, // May be empty | ||||
| 		userID:     userID, // May be empty | ||||
| 		subscriber: s, | ||||
| 		cancel:     cancel, | ||||
| 	} | ||||
|  | @ -54,7 +48,10 @@ func (t *topic) Subscribe(s subscriber, visitor *visitor, cancel func()) int { | |||
| func (t *topic) Stale() bool { | ||||
| 	t.mu.Lock() | ||||
| 	defer t.mu.Unlock() | ||||
| 	return len(t.subscribers) == 0 && t.expires.Before(time.Now()) | ||||
| 	if t.rateVisitor != nil && !t.rateVisitor.Stale() { | ||||
| 		return false | ||||
| 	} | ||||
| 	return len(t.subscribers) == 0 | ||||
| } | ||||
| 
 | ||||
| func (t *topic) SetRateVisitor(v *visitor) { | ||||
|  | @ -66,6 +63,9 @@ func (t *topic) SetRateVisitor(v *visitor) { | |||
| func (t *topic) RateVisitor() *visitor { | ||||
| 	t.mu.Lock() | ||||
| 	defer t.mu.Unlock() | ||||
| 	if t.rateVisitor != nil && t.rateVisitor.Stale() { | ||||
| 		t.rateVisitor = nil | ||||
| 	} | ||||
| 	return t.rateVisitor | ||||
| } | ||||
| 
 | ||||
|  | @ -74,9 +74,6 @@ func (t *topic) Unsubscribe(id int) { | |||
| 	t.mu.Lock() | ||||
| 	defer t.mu.Unlock() | ||||
| 	delete(t.subscribers, id) | ||||
| 	if len(t.subscribers) == 0 { | ||||
| 		t.expires = time.Now().Add(topicExpiryDuration) | ||||
| 	} | ||||
| } | ||||
| 
 | ||||
| // Publish asynchronously publishes to all subscribers | ||||
|  | @ -115,9 +112,14 @@ func (t *topic) CancelSubscribers(exceptUserID string) { | |||
| 	t.mu.Lock() | ||||
| 	defer t.mu.Unlock() | ||||
| 	for _, s := range t.subscribers { | ||||
| 		if s.visitor.MaybeUserID() != exceptUserID { | ||||
| 			// TODO: Shouldn't this log the IP for anonymous visitors? It was s.userID before my change. | ||||
| 			log.Tag(tagSubscribe).Field("topic", t.ID).Debug("Canceling subscriber %s", s.visitor.MaybeUserID()) | ||||
| 		if s.userID != exceptUserID { | ||||
| 			log. | ||||
| 				Tag(tagSubscribe). | ||||
| 				Fields(log.Context{ | ||||
| 					"message_topic": t.ID, | ||||
| 					"user_id":       s.userID, | ||||
| 				}). | ||||
| 				Debug("Canceling subscriber %s", s.userID) | ||||
| 			s.cancel() | ||||
| 		} | ||||
| 	} | ||||
|  | @ -130,7 +132,7 @@ func (t *topic) subscribersCopy() map[int]*topicSubscriber { | |||
| 	subscribers := make(map[int]*topicSubscriber) | ||||
| 	for k, sub := range t.subscribers { | ||||
| 		subscribers[k] = &topicSubscriber{ | ||||
| 			visitor:    sub.visitor, | ||||
| 			userID:     sub.userID, | ||||
| 			subscriber: sub.subscriber, | ||||
| 			cancel:     sub.cancel, | ||||
| 		} | ||||
|  |  | |||
|  | @ -2,6 +2,7 @@ package server | |||
| 
 | ||||
| import ( | ||||
| 	"context" | ||||
| 	"fmt" | ||||
| 	"heckel.io/ntfy/util" | ||||
| 	"io" | ||||
| 	"net/http" | ||||
|  | @ -105,3 +106,11 @@ func withContext(r *http.Request, ctx map[contextKey]any) *http.Request { | |||
| 	} | ||||
| 	return r.WithContext(c) | ||||
| } | ||||
| 
 | ||||
| func fromContext[T any](r *http.Request, key contextKey) *T { | ||||
| 	t, ok := r.Context().Value(key).(*T) | ||||
| 	if !ok { | ||||
| 		panic(fmt.Sprintf("cannot find key %v in request context", key)) | ||||
| 	} | ||||
| 	return t | ||||
| } | ||||
|  |  | |||
		Loading…
	
	Add table
		Add a link
		
	
		Reference in a new issue