diff --git a/log/event.go b/log/event.go index a8d35c26..284c879e 100644 --- a/log/event.go +++ b/log/event.go @@ -13,88 +13,95 @@ import ( const ( tagField = "tag" errorField = "error" + exitCodeField = "exit_code" timestampFormat = "2006-01-02T15:04:05.999Z07:00" ) // Event represents a single log event type Event struct { - Timestamp string `json:"time"` - Level Level `json:"level"` - Message string `json:"message"` - fields Context + Timestamp string `json:"time"` + Level Level `json:"level"` + Message string `json:"message"` + time time.Time + contexters []Contexter + fields Context } // newEvent creates a new log event +// +// We delay allocations and processing for efficiency, because most log events +// are never actually rendered, so we don't format the time, or allocate a fields map. func newEvent() *Event { - now := time.Now() return &Event{ - Timestamp: now.Format(timestampFormat), - fields: make(Context), + time: time.Now(), } } // Fatal logs the event as FATAL, and exits the program with exit code 1 func (e *Event) Fatal(message string, v ...any) { - e.Field("exit_code", 1).Log(FatalLevel, message, v...) + e.Field(exitCodeField, 1).maybeLog(FatalLevel, message, v...) fmt.Fprintf(os.Stderr, message+"\n", v...) // Always output error to stderr os.Exit(1) } // Error logs the event with log level error func (e *Event) Error(message string, v ...any) { - e.Log(ErrorLevel, message, v...) + e.maybeLog(ErrorLevel, message, v...) } // Warn logs the event with log level warn func (e *Event) Warn(message string, v ...any) { - e.Log(WarnLevel, message, v...) + e.maybeLog(WarnLevel, message, v...) } // Info logs the event with log level info func (e *Event) Info(message string, v ...any) { - e.Log(InfoLevel, message, v...) + e.maybeLog(InfoLevel, message, v...) } // Debug logs the event with log level debug func (e *Event) Debug(message string, v ...any) { - e.Log(DebugLevel, message, v...) + e.maybeLog(DebugLevel, message, v...) } // Trace logs the event with log level trace func (e *Event) Trace(message string, v ...any) { - e.Log(TraceLevel, message, v...) + e.maybeLog(TraceLevel, message, v...) } // Tag adds a "tag" field to the log event func (e *Event) Tag(tag string) *Event { - e.fields[tagField] = tag - return e + return e.Field(tagField, tag) } // Time sets the time field func (e *Event) Time(t time.Time) *Event { - e.Timestamp = t.Format(timestampFormat) + e.time = t return e } // Err adds an "error" field to the log event func (e *Event) Err(err error) *Event { if c, ok := err.(Contexter); ok { - e.Fields(c.Context()) - } else { - e.fields[errorField] = err.Error() + return e.Fields(c.Context()) } - return e + return e.Field(errorField, err.Error()) } // Field adds a custom field and value to the log event func (e *Event) Field(key string, value any) *Event { + if e.fields == nil { + e.fields = make(Context) + } e.fields[key] = value return e } // Fields adds a map of fields to the log event func (e *Event) Fields(fields Context) *Event { + if e.fields == nil { + e.fields = make(Context) + } for k, v := range fields { e.fields[k] = v } @@ -103,22 +110,36 @@ func (e *Event) Fields(fields Context) *Event { // With adds the fields of the given Contexter structs to the log event by calling their With method func (e *Event) With(contexts ...Contexter) *Event { - for _, c := range contexts { - e.Fields(c.Context()) + if e.contexters == nil { + e.contexters = contexts + } else { + e.contexters = append(e.contexters, contexts...) } return e } -// Log logs a message with the given log level -func (e *Event) Log(l Level, message string, v ...any) { +// maybeLog logs the event to the defined output. The event is only logged, if +// either the global log level is >= l, or if the log level in one of the overrides matches +// the level. +// +// If no overrides are defined (default), the Contexter array is not applied unless the event +// is actually logged. If overrides are defined, then Contexters have to be applied in any case +// to determine if they match. This is super complicated, but required for efficiency. +func (e *Event) maybeLog(l Level, message string, v ...any) { + appliedContexters := e.maybeApplyContexters() + if !e.shouldLog(l) { + return + } e.Message = fmt.Sprintf(message, v...) e.Level = l - if e.shouldPrint() { - if CurrentFormat() == JSONFormat { - log.Println(e.JSON()) - } else { - log.Println(e.String()) - } + e.Timestamp = e.time.Format(timestampFormat) + if !appliedContexters { + e.applyContexters() + } + if CurrentFormat() == JSONFormat { + log.Println(e.JSON()) + } else { + log.Println(e.String()) } } @@ -161,14 +182,17 @@ func (e *Event) String() string { return fmt.Sprintf("%s %s (%s)", e.Level.String(), e.Message, strings.Join(fields, ", ")) } -func (e *Event) shouldPrint() bool { - return e.globalLevelWithOverride() <= e.Level +func (e *Event) shouldLog(l Level) bool { + return e.globalLevelWithOverride() <= l } func (e *Event) globalLevelWithOverride() Level { mu.Lock() l, ov := level, overrides mu.Unlock() + if e.fields == nil { + return l + } for field, override := range ov { value, exists := e.fields[field] if exists && value == override.value { @@ -177,3 +201,19 @@ func (e *Event) globalLevelWithOverride() Level { } return l } + +func (e *Event) maybeApplyContexters() bool { + mu.Lock() + hasOverrides := len(overrides) > 0 + mu.Unlock() + if hasOverrides { + e.applyContexters() + } + return hasOverrides // = applied +} + +func (e *Event) applyContexters() { + for _, c := range e.contexters { + e.Fields(c.Context()) + } +} diff --git a/log/log_test.go b/log/log_test.go index 358c6027..edc9ee1b 100644 --- a/log/log_test.go +++ b/log/log_test.go @@ -1,9 +1,8 @@ -package log_test +package log import ( "bytes" "github.com/stretchr/testify/require" - "heckel.io/ntfy/log" "os" "testing" "time" @@ -12,7 +11,7 @@ import ( func TestMain(m *testing.M) { exitCode := m.Run() resetState() - log.SetLevel(log.ErrorLevel) // For other modules! + SetLevel(ErrorLevel) // For other modules! os.Exit(exitCode) } @@ -27,22 +26,21 @@ func TestLog_TagContextFieldFields(t *testing.T) { Message: "some error", } var out bytes.Buffer - log.SetOutput(&out) - log.SetFormat(log.JSONFormat) - log.SetLevelOverride("tag", "stripe", log.DebugLevel) + SetOutput(&out) + SetFormat(JSONFormat) + SetLevelOverride("tag", "stripe", DebugLevel) - log. - Tag("mytag"). + Tag("mytag"). Field("field2", 123). Field("field1", "value1"). Time(time.Unix(123, 999000000).UTC()). Info("hi there %s", "phil") - log. - Tag("not-stripe"). + + Tag("not-stripe"). Debug("this message will not appear") - log. - With(v). - Fields(log.Context{ + + With(v). + Fields(Context{ "stripe_customer_id": "acct_123", "stripe_subscription_id": "sub_123", }). @@ -57,6 +55,76 @@ func TestLog_TagContextFieldFields(t *testing.T) { require.Equal(t, expected, out.String()) } +func TestLog_NoAllocIfNotPrinted(t *testing.T) { + t.Cleanup(resetState) + v := &fakeVisitor{ + UserID: "u_abc", + IP: "1.2.3.4", + } + + var out bytes.Buffer + SetOutput(&out) + SetFormat(JSONFormat) + + // Do not log, do not call contexters (because global level is INFO) + v.contextCalled = false + ev := With(v) + ev.Debug("some message") + require.False(t, v.contextCalled) + require.Equal(t, "", ev.Timestamp) + require.Equal(t, Level(0), ev.Level) + require.Equal(t, "", ev.Message) + require.Nil(t, ev.fields) + + // Logged because info level, contexters called + v.contextCalled = false + ev = With(v).Time(time.Unix(1111, 0).UTC()) + ev.Info("some message") + require.True(t, v.contextCalled) + require.NotNil(t, ev.fields) + require.Equal(t, "1.2.3.4", ev.fields["visitor_ip"]) + + // Not logged, but contexters called, because overrides exist + SetLevel(DebugLevel) + SetLevelOverride("tag", "overridetag", TraceLevel) + v.contextCalled = false + ev = Tag("sometag").Field("field", "value").With(v).Time(time.Unix(123, 0).UTC()) + ev.Trace("some debug message") + require.True(t, v.contextCalled) // If there are overrides, we must call the context to determine the filter fields + require.Equal(t, "", ev.Timestamp) + require.Equal(t, Level(0), ev.Level) + require.Equal(t, "", ev.Message) + require.Equal(t, 4, len(ev.fields)) + require.Equal(t, "value", ev.fields["field"]) + require.Equal(t, "sometag", ev.fields["tag"]) + + // Logged because of override tag, and contexters called + v.contextCalled = false + ev = Tag("overridetag").Field("field", "value").With(v).Time(time.Unix(123, 0).UTC()) + ev.Trace("some trace message") + require.True(t, v.contextCalled) + require.Equal(t, "1970-01-01T00:02:03Z", ev.Timestamp) + require.Equal(t, TraceLevel, ev.Level) + require.Equal(t, "some trace message", ev.Message) + + // Logged because of field override, and contexters called + ResetLevelOverrides() + SetLevelOverride("visitor_ip", "1.2.3.4", TraceLevel) + v.contextCalled = false + ev = With(v).Time(time.Unix(124, 0).UTC()) + ev.Trace("some trace message with override") + require.True(t, v.contextCalled) + require.Equal(t, "1970-01-01T00:02:04Z", ev.Timestamp) + require.Equal(t, TraceLevel, ev.Level) + require.Equal(t, "some trace message with override", ev.Message) + + expected := `{"time":"1970-01-01T00:18:31Z","level":"INFO","message":"some message","user_id":"u_abc","visitor_ip":"1.2.3.4"} +{"time":"1970-01-01T00:02:03Z","level":"TRACE","message":"some trace message","field":"value","tag":"overridetag","user_id":"u_abc","visitor_ip":"1.2.3.4"} +{"time":"1970-01-01T00:02:04Z","level":"TRACE","message":"some trace message with override","user_id":"u_abc","visitor_ip":"1.2.3.4"} +` + require.Equal(t, expected, out.String()) +} + type fakeError struct { Code int Message string @@ -66,28 +134,30 @@ func (e fakeError) Error() string { return e.Message } -func (e fakeError) Context() log.Context { - return log.Context{ +func (e fakeError) Context() Context { + return Context{ "error": e.Message, "error_code": e.Code, } } type fakeVisitor struct { - UserID string - IP string + UserID string + IP string + contextCalled bool } -func (v *fakeVisitor) Context() log.Context { - return log.Context{ +func (v *fakeVisitor) Context() Context { + v.contextCalled = true + return Context{ "user_id": v.UserID, "visitor_ip": v.IP, } } func resetState() { - log.SetLevel(log.DefaultLevel) - log.SetFormat(log.DefaultFormat) - log.SetOutput(log.DefaultOutput) - log.ResetLevelOverrides() + SetLevel(DefaultLevel) + SetFormat(DefaultFormat) + SetOutput(DefaultOutput) + ResetLevelOverrides() } diff --git a/log/types.go b/log/types.go index 75d78179..51f6fef8 100644 --- a/log/types.go +++ b/log/types.go @@ -94,7 +94,6 @@ func ToFormat(s string) Format { // Contexter allows structs to export a key-value pairs in the form of a Context type Contexter interface { - // Context returns the object context as key-value pairs Context() Context } diff --git a/server/server.go b/server/server.go index 81b4b78b..af7f624a 100644 --- a/server/server.go +++ b/server/server.go @@ -556,7 +556,7 @@ func (s *Server) handleFile(w http.ResponseWriter, r *http.Request, v *visitor) return err } bandwidthVisitor = s.visitor(v.IP(), u) - } else if m.Sender != netip.IPv4Unspecified() { + } else if m.Sender.IsValid() { bandwidthVisitor = s.visitor(m.Sender, nil) } if !bandwidthVisitor.BandwidthAllowed(stat.Size()) { @@ -599,6 +599,7 @@ func (s *Server) handlePublishWithoutResponse(r *http.Request, v *visitor) (*mes if m.PollID != "" { m = newPollRequestMessage(t.ID, m.PollID) } + m.Sender = v.IP() m.User = v.MaybeUserID() m.Expires = time.Now().Add(v.Limits().MessageExpiryDuration).Unix() if err := s.handlePublishBody(r, v, m, body, unifiedpush); err != nil { @@ -792,7 +793,6 @@ func (s *Server) parsePublishParams(r *http.Request, v *visitor, m *message) (ca return false, false, "", false, errHTTPBadRequestDelayTooLarge } m.Time = delay.Unix() - m.Sender = v.ip // Important for rate limiting } actionsStr := readParam(r, "x-actions", "actions", "action") if actionsStr != "" { @@ -896,7 +896,6 @@ func (s *Server) handleBodyAsAttachment(r *http.Request, v *visitor, m *message, m.Attachment = &attachment{} } var ext string - m.Sender = v.ip // Important for attachment rate limiting m.Attachment.Expires = attachmentExpiry m.Attachment.Type, ext = util.DetectContentType(body.PeekedBytes, m.Attachment.Name) m.Attachment.URL = fmt.Sprintf("%s/file/%s%s", s.config.BaseURL, m.ID, ext) diff --git a/server/types.go b/server/types.go index 0b034206..c6331359 100644 --- a/server/types.go +++ b/server/types.go @@ -51,7 +51,7 @@ func (m *message) Context() log.Context { "message_topic": m.Topic, "message_body_size": len(m.Message), } - if m.Sender != netip.IPv4Unspecified() { + if m.Sender.IsValid() { fields["message_sender"] = m.Sender.String() } if m.User != "" {