diff --git a/cmd/app.go b/cmd/app.go index 975fd2fa..805c5dc7 100644 --- a/cmd/app.go +++ b/cmd/app.go @@ -2,10 +2,12 @@ package cmd import ( + "fmt" "github.com/urfave/cli/v2" "github.com/urfave/cli/v2/altsrc" "heckel.io/ntfy/log" "os" + "regexp" ) const ( @@ -20,8 +22,14 @@ var flagsDefault = []cli.Flag{ &cli.BoolFlag{Name: "trace", EnvVars: []string{"NTFY_TRACE"}, Usage: "enable tracing (very verbose, be careful)"}, &cli.BoolFlag{Name: "no-log-dates", Aliases: []string{"no_log_dates"}, EnvVars: []string{"NTFY_NO_LOG_DATES"}, Usage: "disable the date/time prefix"}, altsrc.NewStringFlag(&cli.StringFlag{Name: "log-level", Aliases: []string{"log_level"}, Value: log.InfoLevel.String(), EnvVars: []string{"NTFY_LOG_LEVEL"}, Usage: "set log level"}), + altsrc.NewStringSliceFlag(&cli.StringSliceFlag{Name: "log-level-overrides", Aliases: []string{"log_level_overrides"}, EnvVars: []string{"NTFY_LOG_LEVEL_OVERRIDES"}, Usage: "set log level overrides"}), + altsrc.NewStringFlag(&cli.StringFlag{Name: "log-format", Aliases: []string{"log_format"}, Value: log.TextFormat.String(), EnvVars: []string{"NTFY_LOG_FORMAT"}, Usage: "set log level"}), } +var ( + logLevelOverrideRegex = regexp.MustCompile(`(?i)^([^=]+)\s*=\s*(\S+)\s*->\s*(TRACE|DEBUG|INFO|WARN|ERROR)$`) +) + // New creates a new CLI application func New() *cli.App { return &cli.App{ @@ -40,15 +48,30 @@ func New() *cli.App { } func initLogFunc(c *cli.Context) error { + log.SetLevel(log.ToLevel(c.String("log-level"))) + log.SetFormat(log.ToFormat(c.String("log-format"))) if c.Bool("trace") { log.SetLevel(log.TraceLevel) } else if c.Bool("debug") { log.SetLevel(log.DebugLevel) - } else { - log.SetLevel(log.ToLevel(c.String("log-level"))) } if c.Bool("no-log-dates") { log.DisableDates() } + if err := applyLogLevelOverrides(c.StringSlice("log-level-overrides")); err != nil { + return err + } + return nil +} + +func applyLogLevelOverrides(rawOverrides []string) error { + for _, override := range rawOverrides { + m := logLevelOverrideRegex.FindStringSubmatch(override) + if len(m) != 4 { + return fmt.Errorf(`invalid log level override "%s", must be "field=value -> loglevel", e.g. "user_id=u_123 -> DEBUG"`, override) + } + field, value, level := m[1], m[2], m[3] + log.SetLevelOverride(field, value, log.ToLevel(level)) + } return nil } diff --git a/cmd/serve.go b/cmd/serve.go index 974ef4bc..3aaafeb8 100644 --- a/cmd/serve.go +++ b/cmd/serve.go @@ -309,9 +309,9 @@ func execServe(c *cli.Context) error { // Run server s, err := server.New(conf) if err != nil { - log.Fatal(err) + log.Fatal(err.Error()) } else if err := s.Run(); err != nil { - log.Fatal(err) + log.Fatal(err.Error()) } log.Info("Exiting.") return nil @@ -338,7 +338,9 @@ func sigHandlerConfigReload(config string) { log.Warn("Hot reload failed: %s", err.Error()) continue } - reloadLogLevel(inputSource) + if err := reloadLogLevel(inputSource); err != nil { + log.Warn("Reloading log level failed: %s", err.Error()) + } } } @@ -367,13 +369,24 @@ func parseIPHostPrefix(host string) (prefixes []netip.Prefix, err error) { return } -func reloadLogLevel(inputSource altsrc.InputSourceContext) { +func reloadLogLevel(inputSource altsrc.InputSourceContext) error { newLevelStr, err := inputSource.String("log-level") if err != nil { - log.Warn("Cannot load log level: %s", err.Error()) - return + return fmt.Errorf("cannot load log level: %s", err.Error()) } - newLevel := log.ToLevel(newLevelStr) - log.SetLevel(newLevel) - log.Info("Log level is %s", newLevel.String()) + overrides, err := inputSource.StringSlice("log-level-overrides") + if err != nil { + return fmt.Errorf("cannot load log level overrides (1): %s", err.Error()) + } + log.ResetLevelOverride() + if err := applyLogLevelOverrides(overrides); err != nil { + return fmt.Errorf("cannot load log level overrides (2): %s", err.Error()) + } + log.SetLevel(log.ToLevel(newLevelStr)) + if len(overrides) > 0 { + log.Info("Log level is %v, %d override(s) in place", strings.ToUpper(newLevelStr), len(overrides)) + } else { + log.Info("Log level is %v", strings.ToUpper(newLevelStr)) + } + return nil } diff --git a/log/event.go b/log/event.go new file mode 100644 index 00000000..81232773 --- /dev/null +++ b/log/event.go @@ -0,0 +1,150 @@ +package log + +import ( + "encoding/json" + "fmt" + "log" + "os" + "sort" + "strings" + "time" +) + +const ( + tagField = "tag" + errorField = "error" +) + +type Event struct { + Time int64 `json:"time"` + Level Level `json:"level"` + Message string `json:"message"` + fields map[string]any +} + +func newEvent() *Event { + return &Event{ + Time: time.Now().UnixMilli(), + fields: make(map[string]any), + } +} + +func (e *Event) Fatal(message string, v ...any) { + e.Log(FatalLevel, message, v...) + os.Exit(1) +} + +func (e *Event) Error(message string, v ...any) { + e.Log(ErrorLevel, message, v...) +} + +func (e *Event) Warn(message string, v ...any) { + e.Log(WarnLevel, message, v...) +} + +func (e *Event) Info(message string, v ...any) { + e.Log(InfoLevel, message, v...) +} + +func (e *Event) Debug(message string, v ...any) { + e.Log(DebugLevel, message, v...) +} + +func (e *Event) Trace(message string, v ...any) { + e.Log(TraceLevel, message, v...) +} + +func (e *Event) Tag(tag string) *Event { + e.fields[tagField] = tag + return e +} + +func (e *Event) Err(err error) *Event { + e.fields[errorField] = err + return e +} + +func (e *Event) Field(key string, value any) *Event { + e.fields[key] = value + return e +} + +func (e *Event) Fields(fields map[string]any) *Event { + for k, v := range fields { + e.fields[k] = v + } + return e +} + +func (e *Event) Context(contexts ...Ctx) *Event { + for _, c := range contexts { + e.Fields(c.Context()) + } + return e +} + +func (e *Event) Log(l Level, message string, v ...any) { + e.Message = fmt.Sprintf(message, v...) + e.Level = l + if e.shouldPrint() { + if CurrentFormat() == JSONFormat { + log.Println(e.JSON()) + } else { + log.Println(e.String()) + } + } +} + +// Loggable returns true if the given log level is lower or equal to the current log level +func (e *Event) Loggable(l Level) bool { + return e.globalLevelWithOverride() <= l +} + +// IsTrace returns true if the current log level is TraceLevel +func (e *Event) IsTrace() bool { + return e.Loggable(TraceLevel) +} + +// IsDebug returns true if the current log level is DebugLevel or below +func (e *Event) IsDebug() bool { + return e.Loggable(DebugLevel) +} + +func (e *Event) JSON() string { + b, _ := json.Marshal(e) + s := string(b) + if len(e.fields) > 0 { + b, _ := json.Marshal(e.fields) + s = fmt.Sprintf("{%s,%s}", s[1:len(s)-1], string(b[1:len(b)-1])) + } + return s +} + +func (e *Event) String() string { + if len(e.fields) == 0 { + return fmt.Sprintf("%s %s", e.Level.String(), e.Message) + } + fields := make([]string, 0) + for k, v := range e.fields { + fields = append(fields, fmt.Sprintf("%s=%v", k, v)) + } + sort.Strings(fields) + return fmt.Sprintf("%s %s (%s)", e.Level.String(), e.Message, strings.Join(fields, ", ")) +} + +func (e *Event) shouldPrint() bool { + return e.globalLevelWithOverride() <= e.Level +} + +func (e *Event) globalLevelWithOverride() Level { + mu.Lock() + l, ov := level, overrides + mu.Unlock() + for field, override := range ov { + value, exists := e.fields[field] + if exists && value == override.value { + return override.level + } + } + return l +} diff --git a/log/log.go b/log/log.go index 4061921d..5eb88035 100644 --- a/log/log.go +++ b/log/log.go @@ -2,71 +2,60 @@ package log import ( "log" - "strings" "sync" ) -// Level is a well-known log level, as defined below -type Level int - -// Well known log levels -const ( - TraceLevel Level = iota - DebugLevel - InfoLevel - WarnLevel - ErrorLevel -) - -func (l Level) String() string { - switch l { - case TraceLevel: - return "TRACE" - case DebugLevel: - return "DEBUG" - case InfoLevel: - return "INFO" - case WarnLevel: - return "WARN" - case ErrorLevel: - return "ERROR" - } - return "unknown" -} - var ( - level = InfoLevel - mu = &sync.Mutex{} + level = InfoLevel + format = TextFormat + overrides = make(map[string]*levelOverride) + mu = &sync.Mutex{} ) -// Trace prints the given message, if the current log level is TRACE -func Trace(message string, v ...any) { - logIf(TraceLevel, message, v...) -} - -// Debug prints the given message, if the current log level is DEBUG or lower -func Debug(message string, v ...any) { - logIf(DebugLevel, message, v...) -} - -// Info prints the given message, if the current log level is INFO or lower -func Info(message string, v ...any) { - logIf(InfoLevel, message, v...) -} - -// Warn prints the given message, if the current log level is WARN or lower -func Warn(message string, v ...any) { - logIf(WarnLevel, message, v...) +// Fatal prints the given message, and exits the program +func Fatal(message string, v ...any) { + newEvent().Fatal(message, v...) } // Error prints the given message, if the current log level is ERROR or lower func Error(message string, v ...any) { - logIf(ErrorLevel, message, v...) + newEvent().Error(message, v...) } -// Fatal prints the given message, and exits the program -func Fatal(v ...any) { - log.Fatalln(v...) +// Warn prints the given message, if the current log level is WARN or lower +func Warn(message string, v ...any) { + newEvent().Warn(message, v...) +} + +// Info prints the given message, if the current log level is INFO or lower +func Info(message string, v ...any) { + newEvent().Info(message, v...) +} + +// Debug prints the given message, if the current log level is DEBUG or lower +func Debug(message string, v ...any) { + newEvent().Debug(message, v...) +} + +// Trace prints the given message, if the current log level is TRACE +func Trace(message string, v ...any) { + newEvent().Trace(message, v...) +} + +func Context(contexts ...Ctx) *Event { + return newEvent().Context(contexts...) +} + +func Field(key string, value any) *Event { + return newEvent().Field(key, value) +} + +func Fields(fields map[string]any) *Event { + return newEvent().Fields(fields) +} + +func Tag(tag string) *Event { + return newEvent().Tag(tag) } // CurrentLevel returns the current log level @@ -83,30 +72,42 @@ func SetLevel(newLevel Level) { level = newLevel } +// SetLevelOverride adds a log override for the given field +func SetLevelOverride(field string, value any, level Level) { + mu.Lock() + defer mu.Unlock() + overrides[field] = &levelOverride{value: value, level: level} +} + +// ResetLevelOverride removes all log level overrides +func ResetLevelOverride() { + mu.Lock() + defer mu.Unlock() + overrides = make(map[string]*levelOverride) +} + +// CurrentFormat returns the current log formt +func CurrentFormat() Format { + mu.Lock() + defer mu.Unlock() + return format +} + +// SetFormat sets a new log format +func SetFormat(newFormat Format) { + mu.Lock() + defer mu.Unlock() + format = newFormat + if newFormat == JSONFormat { + DisableDates() + } +} + // DisableDates disables the date/time prefix func DisableDates() { log.SetFlags(0) } -// ToLevel converts a string to a Level. It returns InfoLevel if the string -// does not match any known log levels. -func ToLevel(s string) Level { - switch strings.ToUpper(s) { - case "TRACE": - return TraceLevel - case "DEBUG": - return DebugLevel - case "INFO": - return InfoLevel - case "WARN", "WARNING": - return WarnLevel - case "ERROR": - return ErrorLevel - default: - return InfoLevel - } -} - // Loggable returns true if the given log level is lower or equal to the current log level func Loggable(l Level) bool { return CurrentLevel() <= l @@ -121,9 +122,3 @@ func IsTrace() bool { func IsDebug() bool { return Loggable(DebugLevel) } - -func logIf(l Level, message string, v ...any) { - if CurrentLevel() <= l { - log.Printf(l.String()+" "+message, v...) - } -} diff --git a/log/log_test.go b/log/log_test.go new file mode 100644 index 00000000..bae46c7f --- /dev/null +++ b/log/log_test.go @@ -0,0 +1,57 @@ +package log_test + +import ( + "heckel.io/ntfy/log" + "net/http" + "testing" +) + +const tagPay = "PAY" + +type visitor struct { + UserID string + IP string +} + +func (v *visitor) Context() map[string]any { + return map[string]any{ + "user_id": v.UserID, + "ip": v.IP, + } +} + +func TestEvent_Info(t *testing.T) { + /* + log-level: INFO, user_id:u_abc=DEBUG + log-level-overrides: + - user_id=u_abc: DEBUG + log-filter = + + */ + v := &visitor{ + UserID: "u_abc", + IP: "1.2.3.4", + } + stripeCtx := log.NewCtx(map[string]any{ + "tag": "pay", + }) + log.SetLevel(log.InfoLevel) + //log.SetFormat(log.JSONFormat) + //log.SetLevelOverride("user_id", "u_abc", log.DebugLevel) + log.SetLevelOverride("tag", "pay", log.DebugLevel) + mlog := log.Field("tag", "manager") + mlog.Field("one", 1).Info("this is one") + mlog.Err(http.ErrHandlerTimeout).Field("two", 2).Info("this is two") + log.Info("somebody did something") + log. + Context(stripeCtx, v). + Fields(map[string]any{ + "tier": "ti_abc", + "user_id": "u_abc", + }). + Debug("Somebody paid something for $%d", 10) + log. + Field("tag", "account"). + Field("user_id", "u_abc"). + Debug("User logged in") +} diff --git a/log/types.go b/log/types.go new file mode 100644 index 00000000..e43b67cf --- /dev/null +++ b/log/types.go @@ -0,0 +1,111 @@ +package log + +import ( + "encoding/json" + "strings" +) + +// Level is a well-known log level, as defined below +type Level int + +// Well known log levels +const ( + TraceLevel Level = iota + DebugLevel + InfoLevel + WarnLevel + ErrorLevel + FatalLevel +) + +func (l Level) String() string { + switch l { + case TraceLevel: + return "TRACE" + case DebugLevel: + return "DEBUG" + case InfoLevel: + return "INFO" + case WarnLevel: + return "WARN" + case ErrorLevel: + return "ERROR" + case FatalLevel: + return "FATAL" + } + return "unknown" +} + +func (l Level) MarshalJSON() ([]byte, error) { + return json.Marshal(l.String()) +} + +// ToLevel converts a string to a Level. It returns InfoLevel if the string +// does not match any known log levels. +func ToLevel(s string) Level { + switch strings.ToUpper(s) { + case "TRACE": + return TraceLevel + case "DEBUG": + return DebugLevel + case "INFO": + return InfoLevel + case "WARN", "WARNING": + return WarnLevel + case "ERROR": + return ErrorLevel + default: + return InfoLevel + } +} + +// Format is a well-known log format +type Format int + +// Log formats +const ( + TextFormat Format = iota + JSONFormat +) + +func (f Format) String() string { + switch f { + case TextFormat: + return "text" + case JSONFormat: + return "json" + } + return "unknown" +} + +// ToFormat converts a string to a Format. It returns TextFormat if the string +// does not match any known log formats. +func ToFormat(s string) Format { + switch strings.ToLower(s) { + case "text": + return TextFormat + case "json": + return JSONFormat + default: + return TextFormat + } +} + +type Ctx interface { + Context() map[string]any +} + +type fieldsCtx map[string]any + +func (f fieldsCtx) Context() map[string]any { + return f +} + +func NewCtx(fields map[string]any) Ctx { + return fieldsCtx(fields) +} + +type levelOverride struct { + value any + level Level +} diff --git a/server/server.go b/server/server.go index f538289d..57499533 100644 --- a/server/server.go +++ b/server/server.go @@ -138,6 +138,19 @@ const ( wsPongWait = 15 * time.Second ) +// Log tags +const ( + tagPublish = "publish" + tagFirebase = "firebase" + tagEmail = "email" // Send email + tagSMTP = "smtp" // Receive email + tagPay = "pay" + tagAccount = "account" + tagManager = "manager" + tagResetter = "resetter" + tagWebsocket = "websocket" +) + // New instantiates a new Server. It creates the cache and adds a Firebase // subscriber (if configured). func New(conf *Config) (*Server, error) { @@ -305,9 +318,9 @@ func (s *Server) closeDatabases() { func (s *Server) handle(w http.ResponseWriter, r *http.Request) { v, err := s.maybeAuthenticate(r) // Note: Always returns v, even when error is returned if err == nil { - log.Debug("%s Dispatching request", logHTTPPrefix(v, r)) + logvr(v, r).Debug("Dispatching request") if log.IsTrace() { - log.Trace("%s Entire request (headers and body):\n%s", logHTTPPrefix(v, r), renderHTTPRequest(r)) + logvr(v, r).Trace("Entire request (headers and body):\n%s", renderHTTPRequest(r)) } err = s.handleInternal(w, r, v) } @@ -315,9 +328,9 @@ func (s *Server) handle(w http.ResponseWriter, r *http.Request) { if websocket.IsWebSocketUpgrade(r) { isNormalError := strings.Contains(err.Error(), "i/o timeout") if isNormalError { - log.Debug("%s WebSocket error (this error is okay, it happens a lot): %s", logHTTPPrefix(v, r), err.Error()) + logvr(v, r).Tag(tagWebsocket).Debug("WebSocket error (this error is okay, it happens a lot): %s", err.Error()) } else { - log.Info("%s WebSocket error: %s", logHTTPPrefix(v, r), err.Error()) + logvr(v, r).Tag(tagWebsocket).Info("WebSocket error: %s", err.Error()) } return // Do not attempt to write to upgraded connection } @@ -331,9 +344,21 @@ func (s *Server) handle(w http.ResponseWriter, r *http.Request) { } isNormalError := httpErr.HTTPCode == http.StatusNotFound || httpErr.HTTPCode == http.StatusBadRequest if isNormalError { - log.Debug("%s Connection closed with HTTP %d (ntfy error %d): %s", logHTTPPrefix(v, r), httpErr.HTTPCode, httpErr.Code, err.Error()) + logvr(v, r). + Fields(map[string]any{ + "error": err, + "error_code": httpErr.Code, + "http_status": httpErr.HTTPCode, + }). + Debug("Connection closed with HTTP %d (ntfy error %d): %s", httpErr.HTTPCode, httpErr.Code, err.Error()) } else { - log.Info("%s Connection closed with HTTP %d (ntfy error %d): %s", logHTTPPrefix(v, r), httpErr.HTTPCode, httpErr.Code, err.Error()) + logvr(v, r). + Fields(map[string]any{ + "error": err, + "error_code": httpErr.Code, + "http_status": httpErr.HTTPCode, + }). + Info("Connection closed with HTTP %d (ntfy error %d): %s", httpErr.HTTPCode, httpErr.Code, err.Error()) } w.Header().Set("Content-Type", "application/json") w.Header().Set("Access-Control-Allow-Origin", s.config.AccessControlAllowOrigin) // CORS, allow cross-origin requests @@ -586,10 +611,20 @@ func (s *Server) handlePublishWithoutResponse(r *http.Request, v *visitor) (*mes m.Message = emptyMessageBody } delayed := m.Time > time.Now().Unix() - log.Debug("%s Received message: event=%s, user=%s, body=%d byte(s), delayed=%t, firebase=%t, cache=%t, up=%t, email=%s", - logMessagePrefix(v, m), m.Event, m.User, len(m.Message), delayed, firebase, cache, unifiedpush, email) + logvrm(v, r, m). + Tag(tagPublish). + Fields(map[string]any{ + "message_delayed": delayed, + "message_firebase": firebase, + "message_unifiedpush": unifiedpush, + "message_email": email, + }). + Debug("Received message") if log.IsTrace() { - log.Trace("%s Message body: %s", logMessagePrefix(v, m), util.MaybeMarshalJSON(m)) + logvrm(v, r, m). + Tag(tagPublish). + Field("message_body", util.MaybeMarshalJSON(m)). + Trace("Message body") } if !delayed { if err := t.Publish(v, m); err != nil { @@ -605,10 +640,10 @@ func (s *Server) handlePublishWithoutResponse(r *http.Request, v *visitor) (*mes go s.forwardPollRequest(v, m) } } else { - log.Debug("%s Message delayed, will process later", logMessagePrefix(v, m)) + logvrm(v, r, m).Tag(tagPublish).Debug("Message delayed, will process later") } if cache { - log.Debug("%s Adding message to cache", logMessagePrefix(v, m)) + logvrm(v, r, m).Tag(tagPublish).Debug("Adding message to cache") if err := s.messageCache.AddMessage(m); err != nil { return nil, err } @@ -640,20 +675,20 @@ func (s *Server) handlePublishMatrix(w http.ResponseWriter, r *http.Request, v * } func (s *Server) sendToFirebase(v *visitor, m *message) { - log.Debug("%s Publishing to Firebase", logMessagePrefix(v, m)) + logvm(v, m).Tag(tagFirebase).Debug("Publishing to Firebase") if err := s.firebaseClient.Send(v, m); err != nil { if err == errFirebaseTemporarilyBanned { - log.Debug("%s Unable to publish to Firebase: %v", logMessagePrefix(v, m), err.Error()) + logvm(v, m).Tag(tagFirebase).Err(err).Debug("Unable to publish to Firebase: %v", err.Error()) } else { - log.Warn("%s Unable to publish to Firebase: %v", logMessagePrefix(v, m), err.Error()) + logvm(v, m).Tag(tagFirebase).Err(err).Warn("Unable to publish to Firebase: %v", err.Error()) } } } func (s *Server) sendEmail(v *visitor, m *message, email string) { - log.Debug("%s Sending email to %s", logMessagePrefix(v, m), email) + logvm(v, m).Tag(tagEmail).Field("email", email).Debug("Sending email to %s", email) if err := s.smtpSender.Send(v, m, email); err != nil { - log.Warn("%s Unable to send email to %s: %v", logMessagePrefix(v, m), email, err.Error()) + logvm(v, m).Tag(tagEmail).Field("email", email).Err(err).Warn("Unable to send email to %s: %v", email, err.Error()) } } @@ -661,10 +696,10 @@ func (s *Server) forwardPollRequest(v *visitor, m *message) { topicURL := fmt.Sprintf("%s/%s", s.config.BaseURL, m.Topic) topicHash := fmt.Sprintf("%x", sha256.Sum256([]byte(topicURL))) forwardURL := fmt.Sprintf("%s/%s", s.config.UpstreamBaseURL, topicHash) - log.Debug("%s Publishing poll request to %s", logMessagePrefix(v, m), forwardURL) + logvm(v, m).Debug("Publishing poll request to %s", forwardURL) req, err := http.NewRequest("POST", forwardURL, strings.NewReader("")) if err != nil { - log.Warn("%s Unable to publish poll request: %v", logMessagePrefix(v, m), err.Error()) + logvm(v, m).Err(err).Warn("Unable to publish poll request") return } req.Header.Set("X-Poll-ID", m.ID) @@ -673,10 +708,10 @@ func (s *Server) forwardPollRequest(v *visitor, m *message) { } response, err := httpClient.Do(req) if err != nil { - log.Warn("%s Unable to publish poll request: %v", logMessagePrefix(v, m), err.Error()) + logvm(v, m).Err(err).Warn("Unable to publish poll request") return } else if response.StatusCode != http.StatusOK { - log.Warn("%s Unable to publish poll request, unexpected HTTP status: %d", logMessagePrefix(v, m), response.StatusCode) + logvm(v, m).Err(err).Warn("Unable to publish poll request, unexpected HTTP status: %d") return } } @@ -924,8 +959,8 @@ func (s *Server) handleSubscribeRaw(w http.ResponseWriter, r *http.Request, v *v } func (s *Server) handleSubscribeHTTP(w http.ResponseWriter, r *http.Request, v *visitor, contentType string, encoder messageEncoder) error { - log.Debug("%s HTTP stream connection opened", logHTTPPrefix(v, r)) - defer log.Debug("%s HTTP stream connection closed", logHTTPPrefix(v, r)) + logvr(v, r).Debug("HTTP stream connection opened") + defer logvr(v, r).Debug("HTTP stream connection closed") if !v.SubscriptionAllowed() { return errHTTPTooManyRequestsLimitSubscriptions } @@ -993,7 +1028,7 @@ func (s *Server) handleSubscribeHTTP(w http.ResponseWriter, r *http.Request, v * case <-r.Context().Done(): return nil case <-time.After(s.config.KeepaliveInterval): - log.Trace("%s Sending keepalive message", logHTTPPrefix(v, r)) + logvr(v, r).Trace("Sending keepalive message") v.Keepalive() if err := sub(v, newKeepaliveMessage(topicsStr)); err != nil { // Send keepalive message return err @@ -1010,8 +1045,8 @@ func (s *Server) handleSubscribeWS(w http.ResponseWriter, r *http.Request, v *vi return errHTTPTooManyRequestsLimitSubscriptions } defer v.RemoveSubscription() - log.Debug("%s WebSocket connection opened", logHTTPPrefix(v, r)) - defer log.Debug("%s WebSocket connection closed", logHTTPPrefix(v, r)) + logvr(v, r).Tag(tagWebsocket).Debug("WebSocket connection opened") + defer logvr(v, r).Tag(tagWebsocket).Debug("WebSocket connection closed") topics, topicsStr, err := s.topicsFromPath(r.URL.Path) if err != nil { return err @@ -1047,7 +1082,7 @@ func (s *Server) handleSubscribeWS(w http.ResponseWriter, r *http.Request, v *vi return err } conn.SetPongHandler(func(appData string) error { - log.Trace("%s Received WebSocket pong", logHTTPPrefix(v, r)) + logvr(v, r).Tag(tagWebsocket).Trace("Received WebSocket pong") return conn.SetReadDeadline(time.Now().Add(pongWait)) }) for { @@ -1069,7 +1104,7 @@ func (s *Server) handleSubscribeWS(w http.ResponseWriter, r *http.Request, v *vi if err := conn.SetWriteDeadline(time.Now().Add(wsWriteWait)); err != nil { return err } - log.Trace("%s Sending WebSocket ping", logHTTPPrefix(v, r)) + logvr(v, r).Tag(tagWebsocket).Trace("Sending WebSocket ping") return conn.WriteMessage(websocket.PingMessage, nil) } for { @@ -1077,7 +1112,7 @@ func (s *Server) handleSubscribeWS(w http.ResponseWriter, r *http.Request, v *vi case <-gctx.Done(): return nil case <-cancelCtx.Done(): - log.Trace("%s Cancel received, closing subscriber connection", logHTTPPrefix(v, r)) + logvr(v, r).Tag(tagWebsocket).Trace("Cancel received, closing subscriber connection") conn.Close() return &websocket.CloseError{Code: websocket.CloseNormalClosure, Text: "subscription was canceled"} case <-time.After(s.config.KeepaliveInterval): @@ -1120,7 +1155,7 @@ func (s *Server) handleSubscribeWS(w http.ResponseWriter, r *http.Request, v *vi } err = g.Wait() if err != nil && websocket.IsCloseError(err, websocket.CloseNormalClosure, websocket.CloseGoingAway, websocket.CloseAbnormalClosure) { - log.Trace("%s WebSocket connection closed: %s", logHTTPPrefix(v, r), err.Error()) + logvr(v, r).Tag(tagWebsocket).Err(err).Trace("WebSocket connection closed") return nil // Normal closures are not errors; note: "1006 (abnormal closure)" is treated as normal, because people disconnect a lot } return err @@ -1251,8 +1286,8 @@ func (s *Server) topicFromID(id string) (*topic, error) { } func (s *Server) execManager() { - log.Debug("Manager: Starting") - defer log.Debug("Manager: Finished") + log.Tag(tagManager).Debug("Starting manager") + defer log.Tag(tagManager).Debug("Finished manager") // WARNING: Make sure to only selectively lock with the mutex, and be aware that this // there is no mutex for the entire function. @@ -1393,13 +1428,13 @@ func (s *Server) runStatsResetter() { for { runAt := util.NextOccurrenceUTC(s.config.VisitorStatsResetTime, time.Now()) timer := time.NewTimer(time.Until(runAt)) - log.Debug("Stats resetter: Waiting until %v to reset visitor stats", runAt) + log.Tag(tagResetter).Debug("Waiting until %v to reset visitor stats", runAt) select { case <-timer.C: - log.Debug("Stats resetter: Running") + log.Tag(tagResetter).Debug("Running stats resetter") s.resetStats() case <-s.closeChan: - log.Debug("Stats resetter: Stopping timer") + log.Tag(tagResetter).Debug("Stopping stats resetter") timer.Stop() return } @@ -1415,7 +1450,7 @@ func (s *Server) resetStats() { } if s.userManager != nil { if err := s.userManager.ResetStats(); err != nil { - log.Warn("Failed to write to database: %s", err.Error()) + log.Tag(tagResetter).Warn("Failed to write to database: %s", err.Error()) } } } @@ -1442,7 +1477,7 @@ func (s *Server) runDelayedSender() { select { case <-time.After(s.config.DelayedSenderInterval): if err := s.sendDelayedMessages(); err != nil { - log.Warn("Error sending delayed messages: %s", err.Error()) + log.Tag(tagPublish).Err(err).Warn("Error sending delayed messages") } case <-s.closeChan: return @@ -1460,20 +1495,20 @@ func (s *Server) sendDelayedMessages() error { if s.userManager != nil && m.User != "" { u, err = s.userManager.User(m.User) if err != nil { - log.Warn("Error sending delayed message %s: %s", m.ID, err.Error()) + log.Context(m).Err(err).Warn("Error sending delayed message") continue } } v := s.visitor(m.Sender, u) if err := s.sendDelayedMessage(v, m); err != nil { - log.Warn("%s Error sending delayed message: %s", logMessagePrefix(v, m), err.Error()) + logvm(v, m).Err(err).Warn("Error sending delayed message") } } return nil } func (s *Server) sendDelayedMessage(v *visitor, m *message) error { - log.Debug("%s Sending delayed message", logMessagePrefix(v, m)) + logvm(v, m).Debug("Sending delayed message") s.mu.Lock() t, ok := s.topics[m.Topic] // If no subscribers, just mark message as published s.mu.Unlock() @@ -1481,7 +1516,7 @@ func (s *Server) sendDelayedMessage(v *visitor, m *message) error { go func() { // We do not rate-limit messages here, since we've rate limited them in the PUT/POST handler if err := t.Publish(v, m); err != nil { - log.Warn("%s Unable to publish message: %v", logMessagePrefix(v, m), err.Error()) + logvm(v, m).Err(err).Warn("Unable to publish message") } }() } @@ -1595,7 +1630,7 @@ func (s *Server) autorizeTopic(next handleFunc, perm user.Permission) handleFunc u := v.User() for _, t := range topics { if err := s.userManager.Authorize(u, t.ID, perm); err != nil { - log.Info("unauthorized: %s", err.Error()) + logvr(v, r).Err(err).Debug("Unauthorized") return errHTTPForbidden } } @@ -1609,7 +1644,7 @@ func (s *Server) maybeAuthenticate(r *http.Request) (v *visitor, err error) { ip := extractIPAddress(r, s.config.BehindProxy) var u *user.User // may stay nil if no auth header! if u, err = s.authenticate(r); err != nil { - log.Debug("authentication failed: %s", err.Error()) + logr(r).Debug("Authentication failed: %s", err.Error()) err = errHTTPUnauthorized // Always return visitor, even when error occurs! } v = s.visitor(ip, u) diff --git a/server/server_account.go b/server/server_account.go index 6ffd1622..b4ad2faf 100644 --- a/server/server_account.go +++ b/server/server_account.go @@ -155,7 +155,7 @@ func (s *Server) handleAccountDelete(w http.ResponseWriter, r *http.Request, v * return errHTTPBadRequestIncorrectPasswordConfirmation } if u.Billing.StripeSubscriptionID != "" { - log.Info("%s Canceling billing subscription %s", logHTTPPrefix(v, r), u.Billing.StripeSubscriptionID) + logvr(v, r).Tag(tagPay).Info("Canceling billing subscription for user %s", u.Name) if _, err := s.stripe.CancelSubscription(u.Billing.StripeSubscriptionID); err != nil { return err } @@ -163,7 +163,7 @@ func (s *Server) handleAccountDelete(w http.ResponseWriter, r *http.Request, v * if err := s.maybeRemoveMessagesAndExcessReservations(logHTTPPrefix(v, r), u, 0); err != nil { return err } - log.Info("%s Marking user %s as deleted", logHTTPPrefix(v, r), u.Name) + logvr(v, r).Tag(tagAccount).Info("Marking user %s as deleted", u.Name) if err := s.userManager.MarkUserRemoved(u); err != nil { return err } @@ -184,6 +184,7 @@ func (s *Server) handleAccountPasswordChange(w http.ResponseWriter, r *http.Requ if err := s.userManager.ChangePassword(u.Name, req.NewPassword); err != nil { return err } + logvr(v, r).Tag(tagAccount).Debug("Changed password for user %s", u.Name) return s.writeJSON(w, newSuccessResponse()) } @@ -201,10 +202,12 @@ func (s *Server) handleAccountTokenCreate(w http.ResponseWriter, r *http.Request if req.Expires != nil { expires = time.Unix(*req.Expires, 0) } - token, err := s.userManager.CreateToken(v.User().ID, label, expires, v.IP()) + u := v.User() + token, err := s.userManager.CreateToken(u.ID, label, expires, v.IP()) if err != nil { return err } + logvr(v, r).Tag(tagAccount).Debug("Created token for user %s", u.Name) response := &apiAccountTokenResponse{ Token: token.Value, Label: token.Label, diff --git a/server/server_firebase.go b/server/server_firebase.go index 20f880fe..a315bed2 100644 --- a/server/server_firebase.go +++ b/server/server_firebase.go @@ -47,11 +47,16 @@ func (c *firebaseClient) Send(v *visitor, m *message) error { return err } if log.IsTrace() { - log.Trace("%s Firebase message: %s", logMessagePrefix(v, m), util.MaybeMarshalJSON(fbm)) + logvm(v, m). + Tag(tagFirebase). + Field("firebase_message", util.MaybeMarshalJSON(fbm)). + Trace("Firebase message") } err = c.sender.Send(fbm) if err == errFirebaseQuotaExceeded { - log.Warn("%s Firebase quota exceeded (likely for topic), temporarily denying Firebase access to visitor", logMessagePrefix(v, m)) + logvm(v, m). + Tag(tagFirebase). + Warn("Firebase quota exceeded (likely for topic), temporarily denying Firebase access to visitor") v.FirebaseTemporarilyDeny() } return err diff --git a/server/smtp_sender.go b/server/smtp_sender.go index 7d6b7519..8971f12d 100644 --- a/server/smtp_sender.go +++ b/server/smtp_sender.go @@ -4,7 +4,6 @@ import ( _ "embed" // required by go:embed "encoding/json" "fmt" - "heckel.io/ntfy/log" "heckel.io/ntfy/util" "mime" "net" @@ -37,8 +36,18 @@ func (s *smtpSender) Send(v *visitor, m *message, to string) error { return err } auth := smtp.PlainAuth("", s.config.SMTPSenderUser, s.config.SMTPSenderPass, host) - log.Debug("%s Sending mail: via=%s, user=%s, pass=***, to=%s", logMessagePrefix(v, m), s.config.SMTPSenderAddr, s.config.SMTPSenderUser, to) - log.Trace("%s Mail body: %s", logMessagePrefix(v, m), message) + logvm(v, m). + Tag(tagEmail). + Fields(map[string]any{ + "email_via": s.config.SMTPSenderAddr, + "email_user": s.config.SMTPSenderUser, + "email_to": to, + }). + Debug("Sending email") + logvm(v, m). + Tag(tagEmail). + Field("email_body", message). + Trace("Email body") return smtp.SendMail(s.config.SMTPSenderAddr, auth, s.config.SMTPSenderFrom, []string{to}, []byte(message)) }) } @@ -54,7 +63,7 @@ func (s *smtpSender) withCount(v *visitor, m *message, fn func() error) error { s.mu.Lock() defer s.mu.Unlock() if err != nil { - log.Debug("%s Sending mail failed: %s", logMessagePrefix(v, m), err.Error()) + logvm(v, m).Err(err).Debug("Sending mail failed") s.failure++ } else { s.success++ diff --git a/server/smtp_server.go b/server/smtp_server.go index 3f4b9b68..52f8f851 100644 --- a/server/smtp_server.go +++ b/server/smtp_server.go @@ -5,7 +5,6 @@ import ( "errors" "fmt" "github.com/emersion/go-smtp" - "heckel.io/ntfy/log" "io" "mime" "mime/multipart" @@ -41,13 +40,13 @@ func newMailBackend(conf *Config, handler func(http.ResponseWriter, *http.Reques } } -func (b *smtpBackend) Login(state *smtp.ConnectionState, username, password string) (smtp.Session, error) { - log.Debug("%s Incoming mail, login with user %s", logSMTPPrefix(state), username) +func (b *smtpBackend) Login(state *smtp.ConnectionState, username, _ string) (smtp.Session, error) { + logem(state).Debug("Incoming mail, login with user %s", username) return &smtpSession{backend: b, state: state}, nil } func (b *smtpBackend) AnonymousLogin(state *smtp.ConnectionState) (smtp.Session, error) { - log.Debug("%s Incoming mail, anonymous login", logSMTPPrefix(state)) + logem(state).Debug("Incoming mail, anonymous login") return &smtpSession{backend: b, state: state}, nil } @@ -66,17 +65,17 @@ type smtpSession struct { } func (s *smtpSession) AuthPlain(username, password string) error { - log.Debug("%s AUTH PLAIN (with username %s)", logSMTPPrefix(s.state), username) + logem(s.state).Debug("AUTH PLAIN (with username %s)", username) return nil } func (s *smtpSession) Mail(from string, opts smtp.MailOptions) error { - log.Debug("%s MAIL FROM: %s (with options: %#v)", logSMTPPrefix(s.state), from, opts) + logem(s.state).Debug("%s MAIL FROM: %s (with options: %#v)", from, opts) return nil } func (s *smtpSession) Rcpt(to string) error { - log.Debug("%s RCPT TO: %s", logSMTPPrefix(s.state), to) + logem(s.state).Debug("RCPT TO: %s", to) return s.withFailCount(func() error { conf := s.backend.config addressList, err := mail.ParseAddressList(to) @@ -113,10 +112,11 @@ func (s *smtpSession) Data(r io.Reader) error { if err != nil { return err } - if log.IsTrace() { - log.Trace("%s DATA: %s", logSMTPPrefix(s.state), string(b)) - } else if log.IsDebug() { - log.Debug("%s DATA: %d byte(s)", logSMTPPrefix(s.state), len(b)) + ev := logem(s.state).Tag(tagSMTP) + if ev.IsTrace() { + ev.Field("smtp_data", string(b)).Trace("DATA") + } else if ev.IsDebug() { + ev.Debug("DATA: %d byte(s)", len(b)) } msg, err := mail.ReadMessage(bytes.NewReader(b)) if err != nil { @@ -198,7 +198,7 @@ func (s *smtpSession) withFailCount(fn func() error) error { if err != nil { // Almost all of these errors are parse errors, and user input errors. // We do not want to spam the log with WARN messages. - log.Debug("%s Incoming mail error: %s", logSMTPPrefix(s.state), err.Error()) + logem(s.state).Err(err).Debug("Incoming mail error") s.backend.failure++ } return err diff --git a/server/topic.go b/server/topic.go index aacf6bea..21ab65b1 100644 --- a/server/topic.go +++ b/server/topic.go @@ -58,18 +58,18 @@ func (t *topic) Publish(v *visitor, m *message) error { // subscribers map here. Actually sending out the messages then doesn't have to lock. subscribers := t.subscribersCopy() if len(subscribers) > 0 { - log.Debug("%s Forwarding to %d subscriber(s)", logMessagePrefix(v, m), len(subscribers)) + logvm(v, m).Tag(tagPublish).Debug("Forwarding to %d subscriber(s)", len(subscribers)) for _, s := range subscribers { // We call the subscriber functions in their own Go routines because they are blocking, and // we don't want individual slow subscribers to be able to block others. go func(s subscriber) { if err := s(v, m); err != nil { - log.Warn("%s Error forwarding to subscriber: %s", logMessagePrefix(v, m), err.Error()) + logvm(v, m).Tag(tagPublish).Err(err).Warn("Error forwarding to subscriber") } }(s.subscriber) } } else { - log.Trace("%s No stream or WebSocket subscribers, not forwarding", logMessagePrefix(v, m)) + logvm(v, m).Tag(tagPublish).Trace("No stream or WebSocket subscribers, not forwarding") } }() return nil diff --git a/server/types.go b/server/types.go index 07aee521..2c818ca6 100644 --- a/server/types.go +++ b/server/types.go @@ -42,6 +42,23 @@ type message struct { User string `json:"-"` // Username of the uploader, used to associated attachments } +func (m *message) Context() map[string]any { + fields := map[string]any{ + "message_id": m.ID, + "message_time": m.Time, + "message_event": m.Event, + "message_topic": m.Topic, + "message_body_size": len(m.Message), + } + if m.Sender != netip.IPv4Unspecified() { + fields["message_sender"] = m.Sender.String() + } + if m.User != "" { + fields["message_user"] = m.User + } + return fields +} + type attachment struct { Name string `json:"name"` Type string `json:"type,omitempty"` diff --git a/server/util.go b/server/util.go index 3e24dacf..2fabf135 100644 --- a/server/util.go +++ b/server/util.go @@ -48,8 +48,44 @@ func readQueryParam(r *http.Request, names ...string) string { return "" } -func logMessagePrefix(v *visitor, m *message) string { - return fmt.Sprintf("%s/%s/%s", v.String(), m.Topic, m.ID) +func logr(r *http.Request) *log.Event { + return log.Fields(logFieldsHTTP(r)) +} + +func logv(v *visitor) *log.Event { + return log.Context(v) +} + +func logvr(v *visitor, r *http.Request) *log.Event { + return logv(v).Fields(logFieldsHTTP(r)) +} + +func logvrm(v *visitor, r *http.Request, m *message) *log.Event { + return logvr(v, r).Context(m) +} + +func logvm(v *visitor, m *message) *log.Event { + return logv(v).Context(m) +} + +func logem(state *smtp.ConnectionState) *log.Event { + return log. + Tag(tagSMTP). + Fields(map[string]any{ + "smtp_hostname": state.Hostname, + "smtp_remote_addr": state.RemoteAddr.String(), + }) +} + +func logFieldsHTTP(r *http.Request) map[string]any { + requestURI := r.RequestURI + if requestURI == "" { + requestURI = r.URL.Path + } + return map[string]any{ + "http_method": r.Method, + "http_path": requestURI, + } } func logHTTPPrefix(v *visitor, r *http.Request) string { @@ -67,10 +103,6 @@ func logStripePrefix(customerID, subscriptionID string) string { return fmt.Sprintf("STRIPE %s", customerID) } -func logSMTPPrefix(state *smtp.ConnectionState) string { - return fmt.Sprintf("SMTP %s/%s", state.Hostname, state.RemoteAddr.String()) -} - func renderHTTPRequest(r *http.Request) string { peekLimit := 4096 lines := fmt.Sprintf("%s %s %s\n", r.Method, r.URL.RequestURI(), r.Proto) diff --git a/server/visitor.go b/server/visitor.go index 88ec91bb..444e576a 100644 --- a/server/visitor.go +++ b/server/visitor.go @@ -150,6 +150,25 @@ func (v *visitor) stringNoLock() string { return v.ip.String() } +func (v *visitor) Context() map[string]any { + v.mu.Lock() + defer v.mu.Unlock() + fields := map[string]any{ + "visitor_ip": v.ip.String(), + } + if v.user != nil { + fields["user_id"] = v.user.ID + fields["user_name"] = v.user.Name + if v.user.Billing.StripeCustomerID != "" { + fields["stripe_customer_id"] = v.user.Billing.StripeCustomerID + } + if v.user.Billing.StripeSubscriptionID != "" { + fields["stripe_subscription_id"] = v.user.Billing.StripeSubscriptionID + } + } + return fields +} + func (v *visitor) RequestAllowed() error { v.mu.Lock() // limiters could be replaced! defer v.mu.Unlock() @@ -254,12 +273,6 @@ func (v *visitor) User() *user.User { return v.user // May be nil } -// Admin returns true if the visitor is a user, and an admin -func (v *visitor) Admin() bool { - u := v.User() - return u != nil && u.Role == user.RoleAdmin -} - // IP returns the visitor IP address func (v *visitor) IP() netip.Addr { v.mu.Lock() @@ -297,7 +310,7 @@ func (v *visitor) MaybeUserID() string { } func (v *visitor) resetLimitersNoLock(messages, emails int64, enqueueUpdate bool) { - log.Debug("%s Resetting limiters for visitor", v.stringNoLock()) + log.Context(v).Debug("%s Resetting limiters for visitor", v.stringNoLock()) limits := v.limitsNoLock() v.requestLimiter = rate.NewLimiter(limits.RequestLimitReplenish, limits.RequestLimitBurst) v.messagesLimiter = util.NewFixedLimiterWithValue(limits.MessageLimit, messages)