From e1a4a7490541cf360c1a4fc305370e5e4e7b94e7 Mon Sep 17 00:00:00 2001 From: binwiederhier Date: Wed, 8 Feb 2023 15:20:44 -0500 Subject: [PATCH] Auth rate limiter --- cmd/access_test.go | 2 + cmd/serve.go | 1 + cmd/tier_test.go | 2 + cmd/token_test.go | 2 + cmd/user_test.go | 6 +++ log/event.go | 6 ++- server/config.go | 8 ++++ server/errors.go | 1 + server/server.go | 83 +++++++++++++++++++--------------- server/server_account.go | 1 + server/server_firebase.go | 2 +- server/server_middleware.go | 12 +++++ server/server_payments_test.go | 8 ++-- server/server_test.go | 18 ++++++++ server/visitor.go | 59 +++++++++++++++++------- user/manager.go | 1 + 16 files changed, 152 insertions(+), 60 deletions(-) diff --git a/cmd/access_test.go b/cmd/access_test.go index 6582fab0..359beb92 100644 --- a/cmd/access_test.go +++ b/cmd/access_test.go @@ -79,7 +79,9 @@ user * (role: anonymous, tier: none) func runAccessCommand(app *cli.App, conf *server.Config, args ...string) error { userArgs := []string{ "ntfy", + "--log-level=ERROR", "access", + "--config=" + conf.File, // Dummy config file to avoid lookups of real file "--auth-file=" + conf.AuthFile, "--auth-default-access=" + conf.AuthDefault.String(), } diff --git a/cmd/serve.go b/cmd/serve.go index f54b07e7..03240330 100644 --- a/cmd/serve.go +++ b/cmd/serve.go @@ -253,6 +253,7 @@ func execServe(c *cli.Context) error { // Run server conf := server.NewConfig() + conf.File = config conf.BaseURL = baseURL conf.ListenHTTP = listenHTTP conf.ListenHTTPS = listenHTTPS diff --git a/cmd/tier_test.go b/cmd/tier_test.go index f94f3ee9..062d7047 100644 --- a/cmd/tier_test.go +++ b/cmd/tier_test.go @@ -38,7 +38,9 @@ func TestCLI_Tier_AddListChangeDelete(t *testing.T) { func runTierCommand(app *cli.App, conf *server.Config, args ...string) error { userArgs := []string{ "ntfy", + "--log-level=ERROR", "tier", + "--config=" + conf.File, // Dummy config file to avoid lookups of real file "--auth-file=" + conf.AuthFile, "--auth-default-access=" + conf.AuthDefault.String(), } diff --git a/cmd/token_test.go b/cmd/token_test.go index f4bdb448..40d7be7b 100644 --- a/cmd/token_test.go +++ b/cmd/token_test.go @@ -41,7 +41,9 @@ func TestCLI_Token_AddListRemove(t *testing.T) { func runTokenCommand(app *cli.App, conf *server.Config, args ...string) error { userArgs := []string{ "ntfy", + "--log-level=ERROR", "token", + "--config=" + conf.File, // Dummy config file to avoid lookups of real file "--auth-file=" + conf.AuthFile, } return app.Run(append(userArgs, args...)) diff --git a/cmd/user_test.go b/cmd/user_test.go index c8bcd6a6..1149285f 100644 --- a/cmd/user_test.go +++ b/cmd/user_test.go @@ -6,6 +6,7 @@ import ( "heckel.io/ntfy/server" "heckel.io/ntfy/test" "heckel.io/ntfy/user" + "os" "path/filepath" "testing" ) @@ -113,7 +114,10 @@ func TestCLI_User_Delete(t *testing.T) { } func newTestServerWithAuth(t *testing.T) (s *server.Server, conf *server.Config, port int) { + configFile := filepath.Join(t.TempDir(), "server-dummy.yml") + require.Nil(t, os.WriteFile(configFile, []byte(""), 0600)) // Dummy config file to avoid lookup of real server.yml conf = server.NewConfig() + conf.File = configFile conf.AuthFile = filepath.Join(t.TempDir(), "user.db") conf.AuthDefault = user.PermissionDenyAll s, port = test.StartServerWithConfig(t, conf) @@ -123,7 +127,9 @@ func newTestServerWithAuth(t *testing.T) (s *server.Server, conf *server.Config, func runUserCommand(app *cli.App, conf *server.Config, args ...string) error { userArgs := []string{ "ntfy", + "--log-level=ERROR", "user", + "--config=" + conf.File, // Dummy config file to avoid lookups of real file "--auth-file=" + conf.AuthFile, "--auth-default-access=" + conf.AuthDefault.String(), } diff --git a/log/event.go b/log/event.go index 284c879e..a9bfa314 100644 --- a/log/event.go +++ b/log/event.go @@ -82,8 +82,10 @@ func (e *Event) Time(t time.Time) *Event { // Err adds an "error" field to the log event func (e *Event) Err(err error) *Event { - if c, ok := err.(Contexter); ok { - return e.Fields(c.Context()) + if err == nil { + return e + } else if c, ok := err.(Contexter); ok { + return e.With(c) } return e.Field(errorField, err.Error()) } diff --git a/server/config.go b/server/config.go index a1d7c829..d6817ad3 100644 --- a/server/config.go +++ b/server/config.go @@ -49,6 +49,8 @@ const ( DefaultVisitorEmailLimitReplenish = time.Hour DefaultVisitorAccountCreationLimitBurst = 3 DefaultVisitorAccountCreationLimitReplenish = 24 * time.Hour + DefaultVisitorAuthFailureLimitBurst = 10 + DefaultVisitorAuthFailureLimitReplenish = time.Minute DefaultVisitorAttachmentTotalSizeLimit = 100 * 1024 * 1024 // 100 MB DefaultVisitorAttachmentDailyBandwidthLimit = 500 * 1024 * 1024 // 500 MB ) @@ -60,6 +62,7 @@ var ( // Config is the main config struct for the application. Use New to instantiate a default config struct. type Config struct { + File string // Config file, only used for testing BaseURL string ListenHTTP string ListenHTTPS string @@ -113,6 +116,8 @@ type Config struct { VisitorEmailLimitReplenish time.Duration VisitorAccountCreationLimitBurst int VisitorAccountCreationLimitReplenish time.Duration + VisitorAuthFailureLimitBurst int + VisitorAuthFailureLimitReplenish time.Duration VisitorStatsResetTime time.Time // Time of the day at which to reset visitor stats BehindProxy bool StripeSecretKey string @@ -129,6 +134,7 @@ type Config struct { // NewConfig instantiates a default new server config func NewConfig() *Config { return &Config{ + File: "", // Only used for testing BaseURL: "", ListenHTTP: DefaultListenHTTP, ListenHTTPS: "", @@ -182,6 +188,8 @@ func NewConfig() *Config { VisitorEmailLimitReplenish: DefaultVisitorEmailLimitReplenish, VisitorAccountCreationLimitBurst: DefaultVisitorAccountCreationLimitBurst, VisitorAccountCreationLimitReplenish: DefaultVisitorAccountCreationLimitReplenish, + VisitorAuthFailureLimitBurst: DefaultVisitorAuthFailureLimitBurst, + VisitorAuthFailureLimitReplenish: DefaultVisitorAuthFailureLimitReplenish, VisitorStatsResetTime: DefaultVisitorStatsResetTime, BehindProxy: false, StripeSecretKey: "", diff --git a/server/errors.go b/server/errors.go index 654ea15f..5fd6d625 100644 --- a/server/errors.go +++ b/server/errors.go @@ -87,6 +87,7 @@ var ( errHTTPTooManyRequestsLimitAccountCreation = &errHTTP{42906, http.StatusTooManyRequests, "limit reached: too many accounts created", "https://ntfy.sh/docs/publish/#limitations"} // FIXME document limit errHTTPTooManyRequestsLimitReservations = &errHTTP{42907, http.StatusTooManyRequests, "limit reached: too many topic reservations for this user", ""} errHTTPTooManyRequestsLimitMessages = &errHTTP{42908, http.StatusTooManyRequests, "limit reached: daily message quota reached", "https://ntfy.sh/docs/publish/#limitations"} + errHTTPTooManyRequestsLimitAuthFailure = &errHTTP{42909, http.StatusTooManyRequests, "limit reached: too many auth failures", "https://ntfy.sh/docs/publish/#limitations"} // FIXME document limit errHTTPInternalError = &errHTTP{50001, http.StatusInternalServerError, "internal server error", ""} errHTTPInternalErrorInvalidPath = &errHTTP{50002, http.StatusInternalServerError, "internal server error: invalid path", ""} errHTTPInternalErrorMissingBaseURL = &errHTTP{50003, http.StatusInternalServerError, "internal server error: base-url must be be configured for this feature", "https://ntfy.sh/docs/config/"} diff --git a/server/server.go b/server/server.go index af7f624a..a7941c98 100644 --- a/server/server.go +++ b/server/server.go @@ -34,9 +34,9 @@ import ( /* -- HIGH Rate limiting: Sensitive endpoints (account/login/change-password/...) -- HIGH Account limit creation triggers when account is taken! - HIGH Docs + - tiers + - api - HIGH Self-review - MEDIUM: Test for expiring messages after reservation removal - MEDIUM: Test new token endpoints & never-expiring token @@ -1540,18 +1540,6 @@ func (s *Server) sendDelayedMessage(v *visitor, m *message) error { return nil } -func (s *Server) limitRequests(next handleFunc) handleFunc { - return func(w http.ResponseWriter, r *http.Request, v *visitor) error { - if util.ContainsIP(s.config.VisitorRequestExemptIPAddrs, v.ip) { - return next(w, r, v) - } else if err := v.RequestAllowed(); err != nil { - logvr(v, r).Err(err).Trace("Request not allowed by rate limiter") - return errHTTPTooManyRequestsLimitRequests - } - return next(w, r, v) - } -} - // transformBodyJSON peeks the request body, reads the JSON, and converts it to headers // before passing it on to the next handler. This is meant to be used in combination with handlePublish. func (s *Server) transformBodyJSON(next handleFunc) handleFunc { @@ -1648,43 +1636,65 @@ func (s *Server) autorizeTopic(next handleFunc, perm user.Permission) handleFunc } } -// maybeAuthenticate creates or retrieves a rate.Limiter for the given visitor. -// Note that this function will always return a visitor, even if an error occurs. -func (s *Server) maybeAuthenticate(r *http.Request) (v *visitor, err error) { +// maybeAuthenticate reads the "Authorization" header and will try to authenticate the user +// if it is set. +// +// - If the header is not set, an IP-based visitor is returned +// - If the header is set, authenticate will be called to check the username/password (Basic auth), +// or the token (Bearer auth), and read the user from the database +// +// This function will ALWAYS return a visitor, even if an error occurs (e.g. unauthorized), so +// that subsequent logging calls still have a visitor context. +func (s *Server) maybeAuthenticate(r *http.Request) (*visitor, error) { + // Read "Authorization" header value, and exit out early if it's not set ip := extractIPAddress(r, s.config.BehindProxy) - var u *user.User // may stay nil if no auth header! - if u, err = s.authenticate(r); err != nil { - logr(r).Err(err).Debug("Authentication failed: %s", err.Error()) - err = errHTTPUnauthorized // Always return visitor, even when error occurs! + vip := s.visitor(ip, nil) + header, err := readAuthHeader(r) + if err != nil { + return vip, err + } else if header == "" { + return vip, nil + } else if s.userManager == nil { + return vip, errHTTPUnauthorized } - v = s.visitor(ip, u) - v.SetUser(u) // Update visitor user with latest from database! - return v, err // Always return visitor, even when error occurs! + // If we're trying to auth, check the rate limiter first + if !vip.AuthAllowed() { + return vip, errHTTPTooManyRequestsLimitAuthFailure // Always return visitor, even when error occurs! + } + u, err := s.authenticate(r, header) + if err != nil { + vip.AuthFailed() + logr(r).Err(err).Debug("Authentication failed") + return vip, errHTTPUnauthorized // Always return visitor, even when error occurs! + } + // Authentication with user was successful + return s.visitor(ip, u), nil } // authenticate a user based on basic auth username/password (Authorization: Basic ...), or token auth (Authorization: Bearer ...). // The Authorization header can be passed as a header or the ?auth=... query param. The latter is required only to // support the WebSocket JavaScript class, which does not support passing headers during the initial request. The auth -// query param is effectively double base64 encoded. Its format is base64(Basic base64(user:pass)). -func (s *Server) authenticate(r *http.Request) (user *user.User, err error) { +// query param is effectively doubly base64 encoded. Its format is base64(Basic base64(user:pass)). +func (s *Server) authenticate(r *http.Request, header string) (user *user.User, err error) { + if strings.HasPrefix(header, "Bearer") { + return s.authenticateBearerAuth(r, strings.TrimSpace(strings.TrimPrefix(header, "Bearer"))) + } + return s.authenticateBasicAuth(r, header) +} + +// readAuthHeader reads the raw value of the Authorization header, either from the actual HTTP header, +// or from the ?auth... query parameter +func readAuthHeader(r *http.Request) (string, error) { value := strings.TrimSpace(r.Header.Get("Authorization")) queryParam := readQueryParam(r, "authorization", "auth") if queryParam != "" { a, err := base64.RawURLEncoding.DecodeString(queryParam) if err != nil { - return nil, err + return "", err } value = strings.TrimSpace(string(a)) } - if value == "" { - return nil, nil - } else if s.userManager == nil { - return nil, errHTTPUnauthorized - } - if strings.HasPrefix(value, "Bearer") { - return s.authenticateBearerAuth(r, strings.TrimSpace(strings.TrimPrefix(value, "Bearer"))) - } - return s.authenticateBasicAuth(r, value) + return value, nil } func (s *Server) authenticateBasicAuth(r *http.Request, value string) (user *user.User, err error) { @@ -1721,6 +1731,7 @@ func (s *Server) visitor(ip netip.Addr, user *user.User) *visitor { return s.visitors[id] } v.Keepalive() + v.SetUser(user) // Always update with the latest user, may be nil! return v } diff --git a/server/server_account.go b/server/server_account.go index e12aedf6..b4dadc09 100644 --- a/server/server_account.go +++ b/server/server_account.go @@ -41,6 +41,7 @@ func (s *Server) handleAccountCreate(w http.ResponseWriter, r *http.Request, v * if err := s.userManager.AddUser(newAccount.Username, newAccount.Password, user.RoleUser); err != nil { return err } + v.AccountCreated() return s.writeJSON(w, newSuccessResponse()) } diff --git a/server/server_firebase.go b/server/server_firebase.go index 0aa61283..749901e5 100644 --- a/server/server_firebase.go +++ b/server/server_firebase.go @@ -39,7 +39,7 @@ func newFirebaseClient(sender firebaseSender, auther user.Auther) *firebaseClien } func (c *firebaseClient) Send(v *visitor, m *message) error { - if err := v.FirebaseAllowed(); err != nil { + if !v.FirebaseAllowed() { return errFirebaseTemporarilyBanned } fbm, err := toFirebaseMessage(m, c.auther) diff --git a/server/server_middleware.go b/server/server_middleware.go index a3d945bf..684253ad 100644 --- a/server/server_middleware.go +++ b/server/server_middleware.go @@ -1,9 +1,21 @@ package server import ( + "heckel.io/ntfy/util" "net/http" ) +func (s *Server) limitRequests(next handleFunc) handleFunc { + return func(w http.ResponseWriter, r *http.Request, v *visitor) error { + if util.ContainsIP(s.config.VisitorRequestExemptIPAddrs, v.ip) { + return next(w, r, v) + } else if !v.RequestAllowed() { + return errHTTPTooManyRequestsLimitRequests + } + return next(w, r, v) + } +} + func (s *Server) ensureWebEnabled(next handleFunc) handleFunc { return func(w http.ResponseWriter, r *http.Request, v *visitor) error { if !s.config.EnableWeb { diff --git a/server/server_payments_test.go b/server/server_payments_test.go index 7e2f0054..4640a728 100644 --- a/server/server_payments_test.go +++ b/server/server_payments_test.go @@ -374,13 +374,13 @@ func TestPayments_Checkout_Success_And_Increase_Rate_Limits_Reset_Visitor(t *tes var wg sync.WaitGroup for i := 0; i < 209; i++ { wg.Add(1) - go func() { + go func(i int) { + defer wg.Done() rr := request(t, s, "PUT", "/mytopic", "some message", map[string]string{ "Authorization": util.BasicAuth("phil", "phil"), }) - require.Equal(t, 200, rr.Code) - wg.Done() - }() + require.Equal(t, 200, rr.Code, "Failed on %d", i) + }(i) } wg.Wait() rr = request(t, s, "PUT", "/mytopic", "some message", map[string]string{ diff --git a/server/server_test.go b/server/server_test.go index 711d45d8..aa7a2049 100644 --- a/server/server_test.go +++ b/server/server_test.go @@ -733,6 +733,24 @@ func TestServer_Auth_Fail_CannotPublish(t *testing.T) { require.Equal(t, 403, response.Code) // Anonymous read not allowed } +func TestServer_Auth_Fail_Rate_Limiting(t *testing.T) { + c := newTestConfigWithAuthFile(t) + s := newTestServer(t, c) + + for i := 0; i < 10; i++ { + response := request(t, s, "PUT", "/announcements", "test", map[string]string{ + "Authorization": util.BasicAuth("phil", "phil"), + }) + require.Equal(t, 401, response.Code) + } + + response := request(t, s, "PUT", "/announcements", "test", map[string]string{ + "Authorization": util.BasicAuth("phil", "phil"), + }) + require.Equal(t, 429, response.Code) + require.Equal(t, 42909, toHTTPError(t, response.Body.String()).Code) +} + func TestServer_Auth_ViaQuery(t *testing.T) { c := newTestConfigWithAuthFile(t) c.AuthDefault = user.PermissionDenyAll diff --git a/server/visitor.go b/server/visitor.go index 34c598a6..172ef2f9 100644 --- a/server/visitor.go +++ b/server/visitor.go @@ -64,6 +64,7 @@ type visitor struct { subscriptionLimiter *util.FixedLimiter // Fixed limiter for active subscriptions (ongoing connections) bandwidthLimiter *util.RateLimiter // Limiter for attachment bandwidth downloads accountLimiter *rate.Limiter // Rate limiter for account creation, may be nil + authLimiter *rate.Limiter // Limiter for incorrect login attempts firebase time.Time // Next allowed Firebase message seen time.Time // Last seen time of this visitor (needed for removal of stale visitors) mu sync.Mutex @@ -130,6 +131,7 @@ func newVisitor(conf *Config, messageCache *messageCache, userManager *user.Mana emailsLimiter: nil, // Set in resetLimiters bandwidthLimiter: nil, // Set in resetLimiters accountLimiter: nil, // Set in resetLimiters, may be nil + authLimiter: nil, // Set in resetLimiters, may be nil } v.resetLimitersNoLock(messages, emails, false) return v @@ -154,6 +156,10 @@ func (v *visitor) contextNoLock() log.Context { "visitor_request_limiter_limit": v.requestLimiter.Limit(), "visitor_request_limiter_tokens": v.requestLimiter.Tokens(), } + if v.authLimiter != nil { + fields["visitor_auth_limiter_limit"] = v.authLimiter.Limit() + fields["visitor_auth_limiter_tokens"] = v.authLimiter.Tokens() + } if v.user != nil { fields["user_id"] = v.user.ID fields["user_name"] = v.user.Name @@ -182,28 +188,16 @@ func visitorExtendedInfoContext(info *visitorInfo) log.Context { } } -func (v *visitor) RequestAllowed() error { +func (v *visitor) RequestAllowed() bool { v.mu.Lock() // limiters could be replaced! defer v.mu.Unlock() - if !v.requestLimiter.Allow() { - return errVisitorLimitReached - } - return nil + return v.requestLimiter.Allow() } -func (v *visitor) RequestLimiter() *rate.Limiter { - v.mu.Lock() // limiters could be replaced! - defer v.mu.Unlock() - return v.requestLimiter -} - -func (v *visitor) FirebaseAllowed() error { +func (v *visitor) FirebaseAllowed() bool { v.mu.Lock() defer v.mu.Unlock() - if time.Now().Before(v.firebase) { - return errVisitorLimitReached - } - return nil + return !time.Now().Before(v.firebase) } func (v *visitor) FirebaseTemporarilyDeny() { @@ -230,15 +224,44 @@ func (v *visitor) SubscriptionAllowed() bool { return v.subscriptionLimiter.Allow() } +// AuthAllowed returns true if an auth request can be attempted (> 1 token available) +func (v *visitor) AuthAllowed() bool { + v.mu.Lock() // limiters could be replaced! + defer v.mu.Unlock() + if v.authLimiter == nil { + return true + } + return v.authLimiter.Tokens() > 1 +} + +// AuthFailed records an auth failure +func (v *visitor) AuthFailed() { + v.mu.Lock() // limiters could be replaced! + defer v.mu.Unlock() + if v.authLimiter != nil { + v.authLimiter.Allow() + } +} + +// AccountCreationAllowed returns true if a new account can be created func (v *visitor) AccountCreationAllowed() bool { v.mu.Lock() // limiters could be replaced! defer v.mu.Unlock() - if v.accountLimiter == nil || (v.accountLimiter != nil && !v.accountLimiter.Allow()) { + if v.accountLimiter == nil || (v.accountLimiter != nil && v.accountLimiter.Tokens() < 1) { return false } return true } +// AccountCreated decreases the account limiter. This is to be called after an account was created. +func (v *visitor) AccountCreated() { + v.mu.Lock() // limiters could be replaced! + defer v.mu.Unlock() + if v.accountLimiter != nil { + v.accountLimiter.Allow() + } +} + func (v *visitor) BandwidthAllowed(bytes int64) bool { v.mu.Lock() // limiters could be replaced! defer v.mu.Unlock() @@ -336,8 +359,10 @@ func (v *visitor) resetLimitersNoLock(messages, emails int64, enqueueUpdate bool v.bandwidthLimiter = util.NewBytesLimiter(int(limits.AttachmentBandwidthLimit), oneDay) if v.user == nil { v.accountLimiter = rate.NewLimiter(rate.Every(v.config.VisitorAccountCreationLimitReplenish), v.config.VisitorAccountCreationLimitBurst) + v.authLimiter = rate.NewLimiter(rate.Every(v.config.VisitorAuthFailureLimitReplenish), v.config.VisitorAuthFailureLimitBurst) } else { v.accountLimiter = nil // Users cannot create accounts when logged in + v.authLimiter = nil // Users are already logged in, no need to limit requests } if enqueueUpdate && v.user != nil { go v.userManager.EnqueueStats(v.user.ID, &user.Stats{ diff --git a/user/manager.go b/user/manager.go index c9883774..c4cf589a 100644 --- a/user/manager.go +++ b/user/manager.go @@ -372,6 +372,7 @@ func (a *Manager) AuthenticateToken(token string) (*User, error) { } user, err := a.userByToken(token) if err != nil { + log.Tag(tagManager).Field("token", token).Err(err).Trace("Authentication of token failed") return nil, ErrUnauthenticated } user.Token = token