Merge branch 'main' into up
This commit is contained in:
		
						commit
						41514cd557
					
				
					 3 changed files with 69 additions and 39 deletions
				
			
		|  | @ -43,12 +43,19 @@ type Server struct { | |||
| 
 | ||||
| // errHTTP is a generic HTTP error for any non-200 HTTP error | ||||
| type errHTTP struct { | ||||
| 	Code   int | ||||
| 	Status string | ||||
| 	Code     int    `json:"code,omitempty"` | ||||
| 	HTTPCode int    `json:"http"` | ||||
| 	Message  string `json:"error"` | ||||
| 	Link     string `json:"link,omitempty"` | ||||
| } | ||||
| 
 | ||||
| func (e errHTTP) Error() string { | ||||
| 	return fmt.Sprintf("http: %s", e.Status) | ||||
| 	return e.Message | ||||
| } | ||||
| 
 | ||||
| func (e errHTTP) JSON() string { | ||||
| 	b, _ := json.Marshal(&e) | ||||
| 	return string(b) | ||||
| } | ||||
| 
 | ||||
| type indexPage struct { | ||||
|  | @ -105,9 +112,22 @@ var ( | |||
| 	docsStaticFs     embed.FS | ||||
| 	docsStaticCached = &util.CachingEmbedFS{ModTime: time.Now(), FS: docsStaticFs} | ||||
| 
 | ||||
| 	errHTTPBadRequest      = &errHTTP{http.StatusBadRequest, http.StatusText(http.StatusBadRequest)} | ||||
| 	errHTTPNotFound        = &errHTTP{http.StatusNotFound, http.StatusText(http.StatusNotFound)} | ||||
| 	errHTTPTooManyRequests = &errHTTP{http.StatusTooManyRequests, http.StatusText(http.StatusTooManyRequests)} | ||||
| 	errHTTPNotFound                          = &errHTTP{40401, http.StatusNotFound, "page not found", ""} | ||||
| 	errHTTPTooManyRequestsLimitRequests      = &errHTTP{42901, http.StatusTooManyRequests, "limit reached: too many requests, please be nice", "https://ntfy.sh/docs/publish/#limitations"} | ||||
| 	errHTTPTooManyRequestsLimitEmails        = &errHTTP{42902, http.StatusTooManyRequests, "limit reached: too many emails, please be nice", "https://ntfy.sh/docs/publish/#limitations"} | ||||
| 	errHTTPTooManyRequestsLimitSubscriptions = &errHTTP{42903, http.StatusTooManyRequests, "limit reached: too many active subscriptions, please be nice", "https://ntfy.sh/docs/publish/#limitations"} | ||||
| 	errHTTPTooManyRequestsLimitGlobalTopics  = &errHTTP{42904, http.StatusTooManyRequests, "limit reached: the total number of topics on the server has been reached, please contact the admin", "https://ntfy.sh/docs/publish/#limitations"} | ||||
| 	errHTTPBadRequestEmailDisabled           = &errHTTP{40001, http.StatusBadRequest, "e-mail notifications are not enabled", "https://ntfy.sh/docs/config/#e-mail-notifications"} | ||||
| 	errHTTPBadRequestDelayNoCache            = &errHTTP{40002, http.StatusBadRequest, "cannot disable cache for delayed message", ""} | ||||
| 	errHTTPBadRequestDelayNoEmail            = &errHTTP{40003, http.StatusBadRequest, "delayed e-mail notifications are not supported", ""} | ||||
| 	errHTTPBadRequestDelayCannotParse        = &errHTTP{40004, http.StatusBadRequest, "invalid delay parameter: unable to parse delay", "https://ntfy.sh/docs/publish/#scheduled-delivery"} | ||||
| 	errHTTPBadRequestDelayTooSmall           = &errHTTP{40005, http.StatusBadRequest, "invalid delay parameter: too small, please refer to the docs", "https://ntfy.sh/docs/publish/#scheduled-delivery"} | ||||
| 	errHTTPBadRequestDelayTooLarge           = &errHTTP{40006, http.StatusBadRequest, "invalid delay parameter: too large, please refer to the docs", "https://ntfy.sh/docs/publish/#scheduled-delivery"} | ||||
| 	errHTTPBadRequestPriorityInvalid         = &errHTTP{40007, http.StatusBadRequest, "invalid priority parameter", "https://ntfy.sh/docs/publish/#message-priority"} | ||||
| 	errHTTPBadRequestSinceInvalid            = &errHTTP{40008, http.StatusBadRequest, "invalid since parameter", "https://ntfy.sh/docs/subscribe/api/#fetch-cached-messages"} | ||||
| 	errHTTPBadRequestTopicInvalid            = &errHTTP{40009, http.StatusBadRequest, "invalid topic: path invalid", ""} | ||||
| 	errHTTPBadRequestTopicDisallowed         = &errHTTP{40010, http.StatusBadRequest, "invalid topic: topic name is disallowed", ""} | ||||
| 	errHTTPInternalError                     = &errHTTP{50001, http.StatusInternalServerError, "internal server error", ""} | ||||
| ) | ||||
| 
 | ||||
| const ( | ||||
|  | @ -241,11 +261,16 @@ func (s *Server) Stop() { | |||
| 
 | ||||
| func (s *Server) handle(w http.ResponseWriter, r *http.Request) { | ||||
| 	if err := s.handleInternal(w, r); err != nil { | ||||
| 		if e, ok := err.(*errHTTP); ok { | ||||
| 			s.fail(w, r, e.Code, e) | ||||
| 		} else { | ||||
| 			s.fail(w, r, http.StatusInternalServerError, err) | ||||
| 		var e *errHTTP | ||||
| 		var ok bool | ||||
| 		if e, ok = err.(*errHTTP); !ok { | ||||
| 			e = errHTTPInternalError | ||||
| 		} | ||||
| 		log.Printf("[%s] %s - %d - %s", r.RemoteAddr, r.Method, e.HTTPCode, err.Error()) | ||||
| 		w.Header().Set("Content-Type", "application/json") | ||||
| 		w.Header().Set("Access-Control-Allow-Origin", "*") // CORS, allow cross-origin requests | ||||
| 		w.WriteHeader(e.HTTPCode) | ||||
| 		io.WriteString(w, e.JSON()+"\n") | ||||
| 	} | ||||
| } | ||||
| 
 | ||||
|  | @ -315,7 +340,7 @@ func (s *Server) handlePublish(w http.ResponseWriter, r *http.Request, v *visito | |||
| 		return err | ||||
| 	} | ||||
| 	m := newDefaultMessage(t.ID, strings.TrimSpace(string(b))) | ||||
| 	cache, firebase, email, unifiedpush, err := s.parseParams(r, m) | ||||
| 	cache, firebase, email, err := s.parseParams(r, m) | ||||
| 	if err != nil { | ||||
| 		return err | ||||
| 	} | ||||
|  | @ -329,13 +354,13 @@ func (s *Server) handlePublish(w http.ResponseWriter, r *http.Request, v *visito | |||
| 
 | ||||
| 	if email != "" { | ||||
| 		if err := v.EmailAllowed(); err != nil { | ||||
| 			return err | ||||
| 			return errHTTPTooManyRequestsLimitEmails | ||||
| 		} | ||||
| 	} | ||||
| 
 | ||||
| 	m.UnifiedPush = unifiedpush | ||||
| 	if s.mailer == nil && email != "" { | ||||
| 		return errHTTPBadRequest | ||||
| 		return errHTTPBadRequestEmailDisabled | ||||
| 	} | ||||
| 	if m.Message == "" { | ||||
| 		m.Message = emptyMessageBody | ||||
|  | @ -376,11 +401,10 @@ func (s *Server) handlePublish(w http.ResponseWriter, r *http.Request, v *visito | |||
| 	return nil | ||||
| } | ||||
| 
 | ||||
| func (s *Server) parseParams(r *http.Request, m *message) (cache bool, firebase bool, email string, unifiedpush bool, err error) { | ||||
| func (s *Server) parseParams(r *http.Request, m *message) (cache bool, firebase bool, email string, err error) { | ||||
| 	cache = readParam(r, "x-cache", "cache") != "no" | ||||
| 	firebase = readParam(r, "x-firebase", "firebase") != "no" | ||||
| 	email = readParam(r, "x-email", "x-e-mail", "email", "e-mail", "mail", "e") | ||||
| 	unifiedpush = readParam(r, "up", "unifiedpush") == "1" | ||||
| 	m.Title = readParam(r, "x-title", "title", "t") | ||||
| 	messageStr := readParam(r, "x-message", "message", "m") | ||||
| 	if messageStr != "" { | ||||
|  | @ -388,7 +412,7 @@ func (s *Server) parseParams(r *http.Request, m *message) (cache bool, firebase | |||
| 	} | ||||
| 	m.Priority, err = util.ParsePriority(readParam(r, "x-priority", "priority", "prio", "p")) | ||||
| 	if err != nil { | ||||
| 		return false, false, "", false, errHTTPBadRequest | ||||
| 		return false, false, "", errHTTPBadRequestPriorityInvalid | ||||
| 	} | ||||
| 	tagsStr := readParam(r, "x-tags", "tags", "tag", "ta") | ||||
| 	if tagsStr != "" { | ||||
|  | @ -400,22 +424,22 @@ func (s *Server) parseParams(r *http.Request, m *message) (cache bool, firebase | |||
| 	delayStr := readParam(r, "x-delay", "delay", "x-at", "at", "x-in", "in") | ||||
| 	if delayStr != "" { | ||||
| 		if !cache { | ||||
| 			return false, false, "", false, errHTTPBadRequest | ||||
| 			return false, false, "", errHTTPBadRequestDelayNoCache | ||||
| 		} | ||||
| 		if email != "" { | ||||
| 			return false, false, "", false, errHTTPBadRequest // we cannot store the email address (yet) | ||||
| 			return false, false, "", errHTTPBadRequestDelayNoEmail // we cannot store the email address (yet) | ||||
| 		} | ||||
| 		delay, err := util.ParseFutureTime(delayStr, time.Now()) | ||||
| 		if err != nil { | ||||
| 			return false, false, "", false, errHTTPBadRequest | ||||
| 			return false, false, "", errHTTPBadRequestDelayCannotParse | ||||
| 		} else if delay.Unix() < time.Now().Add(s.config.MinDelay).Unix() { | ||||
| 			return false, false, "", false, errHTTPBadRequest | ||||
| 			return false, false, "", errHTTPBadRequestDelayTooSmall | ||||
| 		} else if delay.Unix() > time.Now().Add(s.config.MaxDelay).Unix() { | ||||
| 			return false, false, "", false, errHTTPBadRequest | ||||
| 			return false, false, "", errHTTPBadRequestDelayTooLarge | ||||
| 		} | ||||
| 		m.Time = delay.Unix() | ||||
| 	} | ||||
| 	return cache, firebase, email, unifiedpush, nil | ||||
| 	return cache, firebase, email, nil | ||||
| } | ||||
| 
 | ||||
| func readParam(r *http.Request, names ...string) string { | ||||
|  | @ -470,8 +494,8 @@ func (s *Server) handleSubscribeRaw(w http.ResponseWriter, r *http.Request, v *v | |||
| } | ||||
| 
 | ||||
| func (s *Server) handleSubscribe(w http.ResponseWriter, r *http.Request, v *visitor, format string, contentType string, encoder messageEncoder) error { | ||||
| 	if err := v.AddSubscription(); err != nil { | ||||
| 		return errHTTPTooManyRequests | ||||
| 	if err := v.SubscriptionAllowed(); err != nil { | ||||
| 		return errHTTPTooManyRequestsLimitSubscriptions | ||||
| 	} | ||||
| 	defer v.RemoveSubscription() | ||||
| 	topicsStr := strings.TrimSuffix(r.URL.Path[1:], "/"+format) // Hack | ||||
|  | @ -617,7 +641,7 @@ func parseSince(r *http.Request, poll bool) (sinceTime, error) { | |||
| 	} else if d, err := time.ParseDuration(since); err == nil { | ||||
| 		return sinceTime(time.Now().Add(-1 * d)), nil | ||||
| 	} | ||||
| 	return sinceNoMessages, errHTTPBadRequest | ||||
| 	return sinceNoMessages, errHTTPBadRequestSinceInvalid | ||||
| } | ||||
| 
 | ||||
| func (s *Server) handleOptions(w http.ResponseWriter, _ *http.Request) error { | ||||
|  | @ -629,7 +653,7 @@ func (s *Server) handleOptions(w http.ResponseWriter, _ *http.Request) error { | |||
| func (s *Server) topicFromPath(path string) (*topic, error) { | ||||
| 	parts := strings.Split(path, "/") | ||||
| 	if len(parts) < 2 { | ||||
| 		return nil, errHTTPBadRequest | ||||
| 		return nil, errHTTPBadRequestTopicInvalid | ||||
| 	} | ||||
| 	topics, err := s.topicsFromIDs(parts[1]) | ||||
| 	if err != nil { | ||||
|  | @ -644,11 +668,11 @@ func (s *Server) topicsFromIDs(ids ...string) ([]*topic, error) { | |||
| 	topics := make([]*topic, 0) | ||||
| 	for _, id := range ids { | ||||
| 		if util.InStringList(disallowedTopics, id) { | ||||
| 			return nil, errHTTPBadRequest | ||||
| 			return nil, errHTTPBadRequestTopicDisallowed | ||||
| 		} | ||||
| 		if _, ok := s.topics[id]; !ok { | ||||
| 			if len(s.topics) >= s.config.GlobalTopicLimit { | ||||
| 				return nil, errHTTPTooManyRequests | ||||
| 				return nil, errHTTPTooManyRequestsLimitGlobalTopics | ||||
| 			} | ||||
| 			s.topics[id] = newTopic(id) | ||||
| 		} | ||||
|  | @ -766,7 +790,7 @@ func (s *Server) sendDelayedMessages() error { | |||
| func (s *Server) withRateLimit(w http.ResponseWriter, r *http.Request, handler func(w http.ResponseWriter, r *http.Request, v *visitor) error) error { | ||||
| 	v := s.visitor(r) | ||||
| 	if err := v.RequestAllowed(); err != nil { | ||||
| 		return err | ||||
| 		return errHTTPTooManyRequestsLimitRequests | ||||
| 	} | ||||
| 	return handler(w, r, v) | ||||
| } | ||||
|  | @ -798,9 +822,3 @@ func (s *Server) inc(counter *int64) { | |||
| 	defer s.mu.Unlock() | ||||
| 	*counter++ | ||||
| } | ||||
| 
 | ||||
| func (s *Server) fail(w http.ResponseWriter, r *http.Request, code int, err error) { | ||||
| 	log.Printf("[%s] %s - %d - %s", r.RemoteAddr, r.Method, code, err.Error()) | ||||
| 	w.WriteHeader(code) | ||||
| 	_, _ = io.WriteString(w, fmt.Sprintf("%s\n", http.StatusText(code))) | ||||
| } | ||||
|  |  | |||
|  | @ -252,6 +252,7 @@ func TestServer_PublishAtWithCacheError(t *testing.T) { | |||
| 		"In":    "30 min", | ||||
| 	}) | ||||
| 	require.Equal(t, 400, response.Code) | ||||
| 	require.Equal(t, errHTTPBadRequestDelayNoCache, toHTTPError(t, response.Body.String())) | ||||
| } | ||||
| 
 | ||||
| func TestServer_PublishAtTooShortDelay(t *testing.T) { | ||||
|  | @ -644,6 +645,12 @@ func toMessage(t *testing.T, s string) *message { | |||
| 	return &m | ||||
| } | ||||
| 
 | ||||
| func toHTTPError(t *testing.T, s string) *errHTTP { | ||||
| 	var e errHTTP | ||||
| 	require.Nil(t, json.NewDecoder(strings.NewReader(s)).Decode(&e)) | ||||
| 	return &e | ||||
| } | ||||
| 
 | ||||
| func firebaseServiceAccountFile(t *testing.T) string { | ||||
| 	if os.Getenv("NTFY_TEST_FIREBASE_SERVICE_ACCOUNT_FILE") != "" { | ||||
| 		return os.Getenv("NTFY_TEST_FIREBASE_SERVICE_ACCOUNT_FILE") | ||||
|  |  | |||
|  | @ -1,6 +1,7 @@ | |||
| package server | ||||
| 
 | ||||
| import ( | ||||
| 	"errors" | ||||
| 	"golang.org/x/time/rate" | ||||
| 	"heckel.io/ntfy/util" | ||||
| 	"sync" | ||||
|  | @ -14,6 +15,10 @@ const ( | |||
| 	visitorExpungeAfter = 24 * time.Hour | ||||
| ) | ||||
| 
 | ||||
| var ( | ||||
| 	errVisitorLimitReached = errors.New("limit reached") | ||||
| ) | ||||
| 
 | ||||
| // visitor represents an API user, and its associated rate.Limiter used for rate limiting | ||||
| type visitor struct { | ||||
| 	config        *Config | ||||
|  | @ -42,23 +47,23 @@ func (v *visitor) IP() string { | |||
| 
 | ||||
| func (v *visitor) RequestAllowed() error { | ||||
| 	if !v.requests.Allow() { | ||||
| 		return errHTTPTooManyRequests | ||||
| 		return errVisitorLimitReached | ||||
| 	} | ||||
| 	return nil | ||||
| } | ||||
| 
 | ||||
| func (v *visitor) EmailAllowed() error { | ||||
| 	if !v.emails.Allow() { | ||||
| 		return errHTTPTooManyRequests | ||||
| 		return errVisitorLimitReached | ||||
| 	} | ||||
| 	return nil | ||||
| } | ||||
| 
 | ||||
| func (v *visitor) AddSubscription() error { | ||||
| func (v *visitor) SubscriptionAllowed() error { | ||||
| 	v.mu.Lock() | ||||
| 	defer v.mu.Unlock() | ||||
| 	if err := v.subscriptions.Add(1); err != nil { | ||||
| 		return errHTTPTooManyRequests | ||||
| 		return errVisitorLimitReached | ||||
| 	} | ||||
| 	return nil | ||||
| } | ||||
|  |  | |||
		Loading…
	
	Add table
		Add a link
		
	
		Reference in a new issue