Logging WIP

This commit is contained in:
binwiederhier 2023-02-04 21:26:01 -05:00
parent a6641980c2
commit 5d6051c490
11 changed files with 108 additions and 124 deletions

View file

@ -26,6 +26,8 @@ func TestCLI_Access_Grant_And_Publish(t *testing.T) {
stdin.WriteString("philpass\nphilpass\nbenpass\nbenpass")
require.Nil(t, runUserCommand(app, conf, "add", "--role=admin", "phil"))
require.Nil(t, runUserCommand(app, conf, "add", "ben"))
app, stdin, _, _ = newTestApp()
require.Nil(t, runAccessCommand(app, conf, "ben", "announcements", "rw"))
require.Nil(t, runAccessCommand(app, conf, "ben", "sometopic", "read"))
require.Nil(t, runAccessCommand(app, conf, "everyone", "announcements", "read"))

View file

@ -76,7 +76,7 @@ func (e *Event) Fields(fields map[string]any) *Event {
return e
}
func (e *Event) Context(contexts ...Ctx) *Event {
func (e *Event) Context(contexts ...Contexter) *Event {
for _, c := range contexts {
e.Fields(c.Context())
}

View file

@ -42,7 +42,7 @@ func Trace(message string, v ...any) {
newEvent().Trace(message, v...)
}
func Context(contexts ...Ctx) *Event {
func Context(contexts ...Contexter) *Event {
return newEvent().Context(contexts...)
}

View file

@ -91,7 +91,7 @@ func ToFormat(s string) Format {
}
}
type Ctx interface {
type Contexter interface {
Context() map[string]any
}
@ -101,7 +101,7 @@ func (f fieldsCtx) Context() map[string]any {
return f
}
func NewCtx(fields map[string]any) Ctx {
func NewCtx(fields map[string]any) Contexter {
return fieldsCtx(fields)
}

View file

@ -149,6 +149,7 @@ const (
tagManager = "manager"
tagResetter = "resetter"
tagWebsocket = "websocket"
tagMatrix = "matrix"
)
// New instantiates a new Server. It creates the cache and adds a Firebase
@ -328,9 +329,9 @@ func (s *Server) handle(w http.ResponseWriter, r *http.Request) {
if websocket.IsWebSocketUpgrade(r) {
isNormalError := strings.Contains(err.Error(), "i/o timeout")
if isNormalError {
logvr(v, r).Tag(tagWebsocket).Debug("WebSocket error (this error is okay, it happens a lot): %s", err.Error())
logvr(v, r).Tag(tagWebsocket).Err(err).Debug("WebSocket error (this error is okay, it happens a lot): %s", err.Error())
} else {
logvr(v, r).Tag(tagWebsocket).Info("WebSocket error: %s", err.Error())
logvr(v, r).Tag(tagWebsocket).Err(err).Info("WebSocket error: %s", err.Error())
}
return // Do not attempt to write to upgraded connection
}
@ -711,7 +712,7 @@ func (s *Server) forwardPollRequest(v *visitor, m *message) {
logvm(v, m).Err(err).Warn("Unable to publish poll request")
return
} else if response.StatusCode != http.StatusOK {
logvm(v, m).Err(err).Warn("Unable to publish poll request, unexpected HTTP status: %d")
logvm(v, m).Err(err).Warn("Unable to publish poll request, unexpected HTTP status: %d", response.StatusCode)
return
}
}
@ -1537,6 +1538,7 @@ func (s *Server) limitRequests(next handleFunc) handleFunc {
if util.ContainsIP(s.config.VisitorRequestExemptIPAddrs, v.ip) {
return next(w, r, v)
} else if err := v.RequestAllowed(); err != nil {
logvr(v, r).Err(err).Fields(requestLimiterFields(v.RequestLimiter())).Trace("Request not allowed by rate limiter")
return errHTTPTooManyRequestsLimitRequests
}
return next(w, r, v)
@ -1601,6 +1603,7 @@ func (s *Server) transformMatrixJSON(next handleFunc) handleFunc {
return func(w http.ResponseWriter, r *http.Request, v *visitor) error {
newRequest, err := newRequestFromMatrixJSON(r, s.config.BaseURL, s.config.MessageLimit)
if err != nil {
logvr(v, r).Tag(tagMatrix).Err(err).Trace("Invalid Matrix request")
return err
}
if err := next(w, newRequest, v); err != nil {
@ -1630,7 +1633,7 @@ func (s *Server) autorizeTopic(next handleFunc, perm user.Permission) handleFunc
u := v.User()
for _, t := range topics {
if err := s.userManager.Authorize(u, t.ID, perm); err != nil {
logvr(v, r).Err(err).Debug("Unauthorized")
logvr(v, r).Err(err).Field("message_topic", t.ID).Debug("Access to topic %s not authorized", t.ID)
return errHTTPForbidden
}
}
@ -1644,7 +1647,7 @@ func (s *Server) maybeAuthenticate(r *http.Request) (v *visitor, err error) {
ip := extractIPAddress(r, s.config.BehindProxy)
var u *user.User // may stay nil if no auth header!
if u, err = s.authenticate(r); err != nil {
logr(r).Debug("Authentication failed: %s", err.Error())
logr(r).Err(err).Debug("Authentication failed: %s", err.Error())
err = errHTTPUnauthorized // Always return visitor, even when error occurs!
}
v = s.visitor(ip, u)

View file

@ -160,7 +160,7 @@ func (s *Server) handleAccountDelete(w http.ResponseWriter, r *http.Request, v *
return err
}
}
if err := s.maybeRemoveMessagesAndExcessReservations(logHTTPPrefix(v, r), u, 0); err != nil {
if err := s.maybeRemoveMessagesAndExcessReservations(r, v, u, 0); err != nil {
return err
}
logvr(v, r).Tag(tagAccount).Info("Marking user %s as deleted", u.Name)
@ -462,18 +462,19 @@ func (s *Server) handleAccountReservationDelete(w http.ResponseWriter, r *http.R
// maybeRemoveMessagesAndExcessReservations deletes topic reservations for the given user (if too many for tier),
// and marks associated messages for the topics as deleted. This also eventually deletes attachments.
// The process relies on the manager to perform the actual deletions (see runManager).
func (s *Server) maybeRemoveMessagesAndExcessReservations(logPrefix string, u *user.User, reservationsLimit int64) error {
func (s *Server) maybeRemoveMessagesAndExcessReservations(r *http.Request, v *visitor, u *user.User, reservationsLimit int64) error {
reservations, err := s.userManager.Reservations(u.Name)
if err != nil {
return err
} else if int64(len(reservations)) <= reservationsLimit {
logvr(v, r).Tag(tagAccount).Debug("No excess reservations to remove")
return nil
}
topics := make([]string, 0)
for i := int64(len(reservations)) - 1; i >= reservationsLimit; i-- {
topics = append(topics, reservations[i].Topic)
}
log.Info("%s Removing excess reservations for topics %s", logPrefix, strings.Join(topics, ", "))
logvr(v, r).Tag(tagAccount).Info("Removing excess reservations for topics %s", strings.Join(topics, ", "))
if err := s.userManager.RemoveReservations(u.Name, topics...); err != nil {
return err
}

View file

@ -4,7 +4,6 @@ import (
"bytes"
"encoding/json"
"fmt"
"heckel.io/ntfy/log"
"heckel.io/ntfy/util"
"io"
"net/http"
@ -147,7 +146,7 @@ func writeMatrixDiscoveryResponse(w http.ResponseWriter) error {
// writeMatrixError logs and writes the errMatrix to the given http.ResponseWriter as a matrixResponse
func writeMatrixError(w http.ResponseWriter, r *http.Request, v *visitor, err *errMatrix) error {
log.Debug("%s Matrix gateway error: %s", logHTTPPrefix(v, r), err.Error())
logvr(v, r).Tag(tagMatrix).Err(err).Debug("Matrix gateway error")
return writeMatrixResponse(w, err.pushKey)
}

View file

@ -2,7 +2,6 @@ package server
import (
"bytes"
"encoding/json"
"errors"
"fmt"
"github.com/stripe/stripe-go/v74"
@ -121,7 +120,13 @@ func (s *Server) handleAccountBillingSubscriptionCreate(w http.ResponseWriter, r
} else if tier.StripePriceID == "" {
return errNotAPaidTier
}
log.Info("%s Creating Stripe checkout flow", logHTTPPrefix(v, r))
logvr(v, r).
Tag(tagPay).
Fields(map[string]any{
"tier": tier,
"stripe_price_id": tier.StripePriceID,
}).
Info("Creating Stripe checkout flow")
var stripeCustomerID *string
if u.Billing.StripeCustomerID != "" {
stripeCustomerID = &u.Billing.StripeCustomerID
@ -190,6 +195,18 @@ func (s *Server) handleAccountBillingSubscriptionCreateSuccess(w http.ResponseWr
return err
}
v.SetUser(u)
logvr(v, r).
Tag(tagPay).
Fields(map[string]any{
"tier_id": tier.ID,
"tier_name": tier.Name,
"stripe_price_id": tier.StripePriceID,
"stripe_customer_id": sess.Customer.ID,
"stripe_subscription_id": sub.ID,
"stripe_subscription_status": string(sub.Status),
"stripe_subscription_paid_until": sub.CurrentPeriodEnd,
}).
Info("Stripe checkout flow succeeded, updating user tier and subscription")
customerParams := &stripe.CustomerParams{
Params: stripe.Params{
Metadata: map[string]string{
@ -201,7 +218,7 @@ func (s *Server) handleAccountBillingSubscriptionCreateSuccess(w http.ResponseWr
if _, err := s.stripe.UpdateCustomer(sess.Customer.ID, customerParams); err != nil {
return err
}
if err := s.updateSubscriptionAndTier(logHTTPPrefix(v, r), u, tier, sess.Customer.ID, sub.ID, string(sub.Status), sub.CurrentPeriodEnd, sub.CancelAt); err != nil {
if err := s.updateSubscriptionAndTier(r, v, u, tier, sess.Customer.ID, sub.ID, string(sub.Status), sub.CurrentPeriodEnd, sub.CancelAt); err != nil {
return err
}
http.Redirect(w, r, s.config.BaseURL+accountPath, http.StatusSeeOther)
@ -223,7 +240,15 @@ func (s *Server) handleAccountBillingSubscriptionUpdate(w http.ResponseWriter, r
if err != nil {
return err
}
log.Info("%s Changing billing tier to %s (price %s) for subscription %s", logHTTPPrefix(v, r), tier.Code, tier.StripePriceID, u.Billing.StripeSubscriptionID)
logvr(v, r).
Tag(tagPay).
Fields(map[string]any{
"new_tier_id": tier.ID,
"new_tier_name": tier.Name,
"new_tier_stripe_price_id": tier.StripePriceID,
// Other stripe_* fields filled by visitor context
}).
Info("Changing Stripe subscription and billing tier to %s/%s (price %s)", tier.ID, tier.Name, tier.StripePriceID)
sub, err := s.stripe.GetSubscription(u.Billing.StripeSubscriptionID)
if err != nil {
return err
@ -250,8 +275,8 @@ func (s *Server) handleAccountBillingSubscriptionUpdate(w http.ResponseWriter, r
// handleAccountBillingSubscriptionDelete facilitates downgrading a paid user to a tier-less user,
// and cancelling the Stripe subscription entirely
func (s *Server) handleAccountBillingSubscriptionDelete(w http.ResponseWriter, r *http.Request, v *visitor) error {
logvr(v, r).Tag(tagPay).Info("Deleting Stripe subscription")
u := v.User()
log.Info("%s Deleting billing subscription %s", logHTTPPrefix(v, r), u.Billing.StripeSubscriptionID)
if u.Billing.StripeSubscriptionID != "" {
params := &stripe.SubscriptionParams{
CancelAtPeriodEnd: stripe.Bool(true),
@ -267,11 +292,11 @@ func (s *Server) handleAccountBillingSubscriptionDelete(w http.ResponseWriter, r
// handleAccountBillingPortalSessionCreate creates a session to the customer billing portal, and returns the
// redirect URL. The billing portal allows customers to change their payment methods, and cancel the subscription.
func (s *Server) handleAccountBillingPortalSessionCreate(w http.ResponseWriter, r *http.Request, v *visitor) error {
logvr(v, r).Tag(tagPay).Info("Creating Stripe billing portal session")
u := v.User()
if u.Billing.StripeCustomerID == "" {
return errHTTPBadRequestNotAPaidUser
}
log.Info("%s Creating billing portal session", logHTTPPrefix(v, r))
params := &stripe.BillingPortalSessionParams{
Customer: stripe.String(u.Billing.StripeCustomerID),
ReturnURL: stripe.String(s.config.BaseURL),
@ -289,7 +314,7 @@ func (s *Server) handleAccountBillingPortalSessionCreate(w http.ResponseWriter,
// handleAccountBillingWebhook handles incoming Stripe webhooks. It mainly keeps the local user database in sync
// with the Stripe view of the world. This endpoint is authorized via the Stripe webhook secret. Note that the
// visitor (v) in this endpoint is the Stripe API, so we don't have u available.
func (s *Server) handleAccountBillingWebhook(_ http.ResponseWriter, r *http.Request, _ *visitor) error {
func (s *Server) handleAccountBillingWebhook(_ http.ResponseWriter, r *http.Request, v *visitor) error {
stripeSignature := r.Header.Get("Stripe-Signature")
if stripeSignature == "" {
return errHTTPBadRequestBillingRequestInvalid
@ -308,74 +333,105 @@ func (s *Server) handleAccountBillingWebhook(_ http.ResponseWriter, r *http.Requ
}
switch event.Type {
case "customer.subscription.updated":
return s.handleAccountBillingWebhookSubscriptionUpdated(event.Data.Raw)
return s.handleAccountBillingWebhookSubscriptionUpdated(r, v, event)
case "customer.subscription.deleted":
return s.handleAccountBillingWebhookSubscriptionDeleted(event.Data.Raw)
return s.handleAccountBillingWebhookSubscriptionDeleted(r, v, event)
default:
log.Warn("STRIPE Unhandled webhook event %s received", event.Type)
logvr(v, r).
Tag(tagPay).
Field("stripe_webhook_type", event.Type).
Warn("Unhandled Stripe webhook event %s received", event.Type)
return nil
}
}
func (s *Server) handleAccountBillingWebhookSubscriptionUpdated(event json.RawMessage) error {
ev, err := util.UnmarshalJSON[apiStripeSubscriptionUpdatedEvent](io.NopCloser(bytes.NewReader(event)))
func (s *Server) handleAccountBillingWebhookSubscriptionUpdated(r *http.Request, v *visitor, event stripe.Event) error {
ev, err := util.UnmarshalJSON[apiStripeSubscriptionUpdatedEvent](io.NopCloser(bytes.NewReader(event.Data.Raw)))
if err != nil {
return err
} else if ev.ID == "" || ev.Customer == "" || ev.Status == "" || ev.CurrentPeriodEnd == 0 || ev.Items == nil || len(ev.Items.Data) != 1 || ev.Items.Data[0].Price == nil || ev.Items.Data[0].Price.ID == "" {
return errHTTPBadRequestBillingRequestInvalid
}
subscriptionID, priceID := ev.ID, ev.Items.Data[0].Price.ID
log.Info("%s Updating subscription to status %s, with price %s", logStripePrefix(ev.Customer, ev.ID), ev.Status, priceID)
logvr(v, r).
Tag(tagPay).
Fields(map[string]any{
"stripe_webhook_type": event.Type,
"stripe_customer_id": ev.Customer,
"stripe_subscription_id": ev.ID,
"stripe_subscription_status": ev.Status,
"stripe_subscription_paid_until": ev.CurrentPeriodEnd,
"stripe_subscription_cancel_at": ev.CancelAt,
"stripe_price_id": priceID,
}).
Info("Updating subscription to status %s, with price %s", ev.Status, priceID)
userFn := func() (*user.User, error) {
return s.userManager.UserByStripeCustomer(ev.Customer)
}
// We retry the user retrieval function, because during the Stripe checkout, there a race between the browser
// checkout success redirect (see handleAccountBillingSubscriptionCreateSuccess), and this webhook. The checkout
// success call is the one that updates the user with the Stripe customer ID.
u, err := util.Retry[user.User](userFn, retryUserDelays...)
if err != nil {
return err
}
v.SetUser(u)
tier, err := s.userManager.TierByStripePrice(priceID)
if err != nil {
return err
}
if err := s.updateSubscriptionAndTier(logStripePrefix(ev.Customer, ev.ID), u, tier, ev.Customer, subscriptionID, ev.Status, ev.CurrentPeriodEnd, ev.CancelAt); err != nil {
if err := s.updateSubscriptionAndTier(r, v, u, tier, ev.Customer, subscriptionID, ev.Status, ev.CurrentPeriodEnd, ev.CancelAt); err != nil {
return err
}
s.publishSyncEventAsync(s.visitor(netip.IPv4Unspecified(), u))
return nil
}
func (s *Server) handleAccountBillingWebhookSubscriptionDeleted(event json.RawMessage) error {
ev, err := util.UnmarshalJSON[apiStripeSubscriptionDeletedEvent](io.NopCloser(bytes.NewReader(event)))
func (s *Server) handleAccountBillingWebhookSubscriptionDeleted(r *http.Request, v *visitor, event stripe.Event) error {
ev, err := util.UnmarshalJSON[apiStripeSubscriptionDeletedEvent](io.NopCloser(bytes.NewReader(event.Data.Raw)))
if err != nil {
return err
} else if ev.Customer == "" {
return errHTTPBadRequestBillingRequestInvalid
}
log.Info("%s Subscription deleted, downgrading to unpaid tier", logStripePrefix(ev.Customer, ev.ID))
u, err := s.userManager.UserByStripeCustomer(ev.Customer)
if err != nil {
return err
}
if err := s.updateSubscriptionAndTier(logStripePrefix(ev.Customer, ev.ID), u, nil, ev.Customer, "", "", 0, 0); err != nil {
v.SetUser(u)
logvr(v, r).
Tag(tagPay).
Field("stripe_webhook_type", event.Type).
Info("Subscription deleted, downgrading to unpaid tier")
if err := s.updateSubscriptionAndTier(r, v, u, nil, ev.Customer, "", "", 0, 0); err != nil {
return err
}
s.publishSyncEventAsync(s.visitor(netip.IPv4Unspecified(), u))
return nil
}
func (s *Server) updateSubscriptionAndTier(logPrefix string, u *user.User, tier *user.Tier, customerID, subscriptionID, status string, paidUntil, cancelAt int64) error {
func (s *Server) updateSubscriptionAndTier(r *http.Request, v *visitor, u *user.User, tier *user.Tier, customerID, subscriptionID, status string, paidUntil, cancelAt int64) error {
reservationsLimit := visitorDefaultReservationsLimit
if tier != nil {
reservationsLimit = tier.ReservationLimit
}
if err := s.maybeRemoveMessagesAndExcessReservations(logPrefix, u, reservationsLimit); err != nil {
if err := s.maybeRemoveMessagesAndExcessReservations(r, v, u, reservationsLimit); err != nil {
return err
}
if tier == nil {
if tier == nil && u.Tier != nil {
logvr(v, r).Tag(tagPay).Info("Resetting tier for user %s", u.Name)
if err := s.userManager.ResetTier(u.Name); err != nil {
return err
}
} else {
} else if tier != nil && u.TierID() != tier.ID {
logvr(v, r).
Tag(tagPay).
Fields(map[string]any{
"new_tier_id": tier.ID,
"new_tier_name": tier.Name,
"new_tier_stripe_price_id": tier.StripePriceID,
}).
Info("Changing tier to tier %s (%s) for user %s", tier.ID, tier.Name, u.Name)
if err := s.userManager.ChangeTier(u.Name, tier.Code); err != nil {
return err
}

View file

@ -70,7 +70,7 @@ func (s *smtpSession) AuthPlain(username, password string) error {
}
func (s *smtpSession) Mail(from string, opts smtp.MailOptions) error {
logem(s.state).Debug("%s MAIL FROM: %s (with options: %#v)", from, opts)
logem(s.state).Debug("MAIL FROM: %s (with options: %#v)", from, opts)
return nil
}

View file

@ -1,15 +1,12 @@
package server
import (
"fmt"
"github.com/emersion/go-smtp"
"heckel.io/ntfy/log"
"heckel.io/ntfy/util"
"io"
"net/http"
"net/netip"
"strings"
"unicode/utf8"
)
func readBoolParam(r *http.Request, defaultValue bool, names ...string) bool {
@ -48,90 +45,6 @@ func readQueryParam(r *http.Request, names ...string) string {
return ""
}
func logr(r *http.Request) *log.Event {
return log.Fields(logFieldsHTTP(r))
}
func logv(v *visitor) *log.Event {
return log.Context(v)
}
func logvr(v *visitor, r *http.Request) *log.Event {
return logv(v).Fields(logFieldsHTTP(r))
}
func logvrm(v *visitor, r *http.Request, m *message) *log.Event {
return logvr(v, r).Context(m)
}
func logvm(v *visitor, m *message) *log.Event {
return logv(v).Context(m)
}
func logem(state *smtp.ConnectionState) *log.Event {
return log.
Tag(tagSMTP).
Fields(map[string]any{
"smtp_hostname": state.Hostname,
"smtp_remote_addr": state.RemoteAddr.String(),
})
}
func logFieldsHTTP(r *http.Request) map[string]any {
requestURI := r.RequestURI
if requestURI == "" {
requestURI = r.URL.Path
}
return map[string]any{
"http_method": r.Method,
"http_path": requestURI,
}
}
func logHTTPPrefix(v *visitor, r *http.Request) string {
requestURI := r.RequestURI
if requestURI == "" {
requestURI = r.URL.Path
}
return fmt.Sprintf("HTTP %s %s %s", v.String(), r.Method, requestURI)
}
func logStripePrefix(customerID, subscriptionID string) string {
if subscriptionID != "" {
return fmt.Sprintf("STRIPE %s/%s", customerID, subscriptionID)
}
return fmt.Sprintf("STRIPE %s", customerID)
}
func renderHTTPRequest(r *http.Request) string {
peekLimit := 4096
lines := fmt.Sprintf("%s %s %s\n", r.Method, r.URL.RequestURI(), r.Proto)
for key, values := range r.Header {
for _, value := range values {
lines += fmt.Sprintf("%s: %s\n", key, value)
}
}
lines += "\n"
body, err := util.Peek(r.Body, peekLimit)
if err != nil {
lines = fmt.Sprintf("(could not read body: %s)\n", err.Error())
} else if utf8.Valid(body.PeekedBytes) {
lines += string(body.PeekedBytes)
if body.LimitReached {
lines += fmt.Sprintf(" ... (peeked %d bytes)", peekLimit)
}
lines += "\n"
} else {
if body.LimitReached {
lines += fmt.Sprintf("(peeked bytes not UTF-8, peek limit of %d bytes reached, hex: %x ...)\n", peekLimit, body.PeekedBytes)
} else {
lines += fmt.Sprintf("(peeked bytes not UTF-8, %d bytes, hex: %x)\n", len(body.PeekedBytes), body.PeekedBytes)
}
}
r.Body = body // Important: Reset body, so it can be re-read
return strings.TrimSpace(lines)
}
func extractIPAddress(r *http.Request, behindProxy bool) netip.Addr {
remoteAddr := r.RemoteAddr
addrPort, err := netip.ParseAddrPort(remoteAddr)

View file

@ -159,6 +159,10 @@ func (v *visitor) Context() map[string]any {
if v.user != nil {
fields["user_id"] = v.user.ID
fields["user_name"] = v.user.Name
if v.user.Tier != nil {
fields["tier_id"] = v.user.Tier.ID
fields["tier_name"] = v.user.Tier.Name
}
if v.user.Billing.StripeCustomerID != "" {
fields["stripe_customer_id"] = v.user.Billing.StripeCustomerID
}
@ -178,6 +182,12 @@ func (v *visitor) RequestAllowed() error {
return nil
}
func (v *visitor) RequestLimiter() *rate.Limiter {
v.mu.Lock() // limiters could be replaced!
defer v.mu.Unlock()
return v.requestLimiter
}
func (v *visitor) FirebaseAllowed() error {
v.mu.Lock()
defer v.mu.Unlock()