limitRequestsWithTopic
This commit is contained in:
		
							parent
							
								
									0e4044b747
								
							
						
					
					
						commit
						ce7d447f16
					
				
					 2 changed files with 37 additions and 15 deletions
				
			
		|  | @ -437,13 +437,14 @@ func (s *Server) handleInternal(w http.ResponseWriter, r *http.Request, v *visit | ||||||
| 	} else if r.Method == http.MethodOptions { | 	} else if r.Method == http.MethodOptions { | ||||||
| 		return s.limitRequests(s.handleOptions)(w, r, v) // Should work even if the web app is not enabled, see #598 | 		return s.limitRequests(s.handleOptions)(w, r, v) // Should work even if the web app is not enabled, see #598 | ||||||
| 	} else if (r.Method == http.MethodPut || r.Method == http.MethodPost) && r.URL.Path == "/" { | 	} else if (r.Method == http.MethodPut || r.Method == http.MethodPost) && r.URL.Path == "/" { | ||||||
| 		return s.limitRequests(s.transformBodyJSON(s.authorizeTopicWrite(s.handlePublish)))(w, r, v) | 		// So I don't *really* have to switch this order, since this is unrelated to UP; But, making this and matrix inconsistent is just calling for confusion, no? | ||||||
|  | 		return s.transformBodyJSON(s.limitRequestsWithTopic(s.authorizeTopicWrite(s.handlePublish)))(w, r, v) | ||||||
| 	} else if r.Method == http.MethodPost && r.URL.Path == matrixPushPath { | 	} else if r.Method == http.MethodPost && r.URL.Path == matrixPushPath { | ||||||
| 		return s.limitRequests(s.transformMatrixJSON(s.authorizeTopicWrite(s.handlePublishMatrix)))(w, r, v) | 		return s.transformMatrixJSON(s.limitRequestsWithTopic(s.authorizeTopicWrite(s.handlePublishMatrix)))(w, r, v) | ||||||
| 	} else if (r.Method == http.MethodPut || r.Method == http.MethodPost) && topicPathRegex.MatchString(r.URL.Path) { | 	} else if (r.Method == http.MethodPut || r.Method == http.MethodPost) && topicPathRegex.MatchString(r.URL.Path) { | ||||||
| 		return s.limitRequests(s.authorizeTopicWrite(s.handlePublish))(w, r, v) | 		return s.limitRequestsWithTopic(s.authorizeTopicWrite(s.handlePublish))(w, r, v) | ||||||
| 	} else if r.Method == http.MethodGet && publishPathRegex.MatchString(r.URL.Path) { | 	} else if r.Method == http.MethodGet && publishPathRegex.MatchString(r.URL.Path) { | ||||||
| 		return s.limitRequests(s.authorizeTopicWrite(s.handlePublish))(w, r, v) | 		return s.limitRequestsWithTopic(s.authorizeTopicWrite(s.handlePublish))(w, r, v) | ||||||
| 	} else if r.Method == http.MethodGet && jsonPathRegex.MatchString(r.URL.Path) { | 	} else if r.Method == http.MethodGet && jsonPathRegex.MatchString(r.URL.Path) { | ||||||
| 		return s.limitRequests(s.authorizeTopicRead(s.handleSubscribeJSON))(w, r, v) | 		return s.limitRequests(s.authorizeTopicRead(s.handleSubscribeJSON))(w, r, v) | ||||||
| 	} else if r.Method == http.MethodGet && ssePathRegex.MatchString(r.URL.Path) { | 	} else if r.Method == http.MethodGet && ssePathRegex.MatchString(r.URL.Path) { | ||||||
|  | @ -602,21 +603,18 @@ func (s *Server) handleMatrixDiscovery(w http.ResponseWriter) error { | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| func (s *Server) handlePublishWithoutResponse(r *http.Request, v *visitor) (*message, error) { | func (s *Server) handlePublishWithoutResponse(r *http.Request, v *visitor) (*message, error) { | ||||||
| 	t, err := s.topicFromPath(r.URL.Path) | 	vRate, ok := r.Context().Value("vRate").(*visitor) | ||||||
| 	if err != nil { | 	if !ok { | ||||||
| 		return nil, err | 		return nil, errHTTPInternalError | ||||||
| 	} | 	} | ||||||
| 	vRate := v | 	t, ok := r.Context().Value("topic").(*topic) | ||||||
| 	if topicCountsAgainst := t.Billee(); topicCountsAgainst != nil { | 	if !ok { | ||||||
| 		vRate = topicCountsAgainst | 		return nil, errHTTPInternalError | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
| 	if !vRate.MessageAllowed() { | 	if !vRate.MessageAllowed() { | ||||||
| 		vRate = v |  | ||||||
| 		if !v.MessageAllowed() { |  | ||||||
| 		return nil, errHTTPTooManyRequestsLimitMessages | 		return nil, errHTTPTooManyRequestsLimitMessages | ||||||
| 	} | 	} | ||||||
| 	} |  | ||||||
| 	body, err := util.Peek(r.Body, s.config.MessageLimit) | 	body, err := util.Peek(r.Body, s.config.MessageLimit) | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		return nil, err | 		return nil, err | ||||||
|  |  | ||||||
|  | @ -1,8 +1,10 @@ | ||||||
| package server | package server | ||||||
| 
 | 
 | ||||||
| import ( | import ( | ||||||
| 	"heckel.io/ntfy/util" | 	"context" | ||||||
| 	"net/http" | 	"net/http" | ||||||
|  | 
 | ||||||
|  | 	"heckel.io/ntfy/util" | ||||||
| ) | ) | ||||||
| 
 | 
 | ||||||
| func (s *Server) limitRequests(next handleFunc) handleFunc { | func (s *Server) limitRequests(next handleFunc) handleFunc { | ||||||
|  | @ -16,6 +18,28 @@ func (s *Server) limitRequests(next handleFunc) handleFunc { | ||||||
| 	} | 	} | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
|  | // limitRequestsWithTopic limits requests with a topic and stores the rate-limiting-subscriber and topic into request.Context | ||||||
|  | func (s *Server) limitRequestsWithTopic(next handleFunc) handleFunc { | ||||||
|  | 	return func(w http.ResponseWriter, r *http.Request, v *visitor) error { | ||||||
|  | 		t, err := s.topicFromPath(r.URL.Path) | ||||||
|  | 		if err != nil { | ||||||
|  | 			return err | ||||||
|  | 		} | ||||||
|  | 		vRate := v | ||||||
|  | 		if topicCountsAgainst := t.Billee(); topicCountsAgainst != nil { | ||||||
|  | 			vRate = topicCountsAgainst | ||||||
|  | 		} | ||||||
|  | 		r.WithContext(context.WithValue(context.WithValue(r.Context(), "vRate", vRate), "topic", t)) | ||||||
|  | 
 | ||||||
|  | 		if util.ContainsIP(s.config.VisitorRequestExemptIPAddrs, v.ip) { | ||||||
|  | 			return next(w, r, v) | ||||||
|  | 		} else if !vRate.RequestAllowed() { | ||||||
|  | 			return errHTTPTooManyRequestsLimitRequests | ||||||
|  | 		} | ||||||
|  | 		return next(w, r, v) | ||||||
|  | 	} | ||||||
|  | } | ||||||
|  | 
 | ||||||
| func (s *Server) ensureWebEnabled(next handleFunc) handleFunc { | func (s *Server) ensureWebEnabled(next handleFunc) handleFunc { | ||||||
| 	return func(w http.ResponseWriter, r *http.Request, v *visitor) error { | 	return func(w http.ResponseWriter, r *http.Request, v *visitor) error { | ||||||
| 		if !s.config.EnableWeb { | 		if !s.config.EnableWeb { | ||||||
|  |  | ||||||
		Loading…
	
	Add table
		Add a link
		
	
		Reference in a new issue