diff --git a/cmd/serve.go b/cmd/serve.go index 3cc01143..be772e6b 100644 --- a/cmd/serve.go +++ b/cmd/serve.go @@ -306,32 +306,31 @@ func sigHandlerConfigReload(config string) { } func parseIPHostPrefix(host string) (prefixes []netip.Prefix, err error) { - //try parsing as prefix - prefix, err := netip.ParsePrefix(host) - if err == nil { - prefixes = append(prefixes, prefix.Masked()) // masked and canonical for easy of debugging, shouldn't matter - return prefixes, nil // success - } + //try parsing as prefix + prefix, err := netip.ParsePrefix(host) + if err == nil { + prefixes = append(prefixes, prefix.Masked()) // Masked returns the prefix in its canonical form, the same for every ip in the range. This exists for ease of debugging. For example, 10.1.2.3/16 is 10.1.0.0/16. + return prefixes, nil // success + } - // not a prefix, parse as host or IP - // LookupHost forwards through if it's an IP - ips, err := net.LookupHost(host) - if err == nil { - for _, i := range ips { - ip, err := netip.ParseAddr(i) - if err == nil { - prefix, err := ip.Prefix(ip.BitLen()) - if err != nil { - return prefixes, errors.New(fmt.Sprint("ip", ip, " successfully parsed as IP but unable to turn into prefix. THIS SHOULD NEVER HAPPEN. err:", err.Error())) - } - prefixes = append(prefixes, prefix.Masked()) //also masked canonical ip - } - } - } - return + // not a prefix, parse as host or IP + // LookupHost forwards through if it's an IP + ips, err := net.LookupHost(host) + if err == nil { + for _, i := range ips { + ip, err := netip.ParseAddr(i) + if err == nil { + prefix, err := ip.Prefix(ip.BitLen()) + if err != nil { + return prefixes, errors.New(fmt.Sprint("ip", ip, " successfully parsed as IP but unable to turn into prefix. THIS SHOULD NEVER HAPPEN. err:", err.Error())) + } + prefixes = append(prefixes, prefix.Masked()) //also masked canonical ip + } + } + } + return } - func reloadLogLevel(inputSource altsrc.InputSourceContext) { newLevelStr, err := inputSource.String("log-level") if err != nil { diff --git a/server/message_cache.go b/server/message_cache.go index 4845a918..f4433399 100644 --- a/server/message_cache.go +++ b/server/message_cache.go @@ -456,6 +456,11 @@ func readMessages(rows *sql.Rows) ([]*message, error) { return nil, err } } + senderIP, err := netip.ParseAddr(sender) + if err != nil { + senderIP = netip.IPv4Unspecified() // if no IP stored in database, 0.0.0.0 + } + var att *attachment if attachmentName != "" && attachmentURL != "" { att = &attachment{ @@ -479,7 +484,7 @@ func readMessages(rows *sql.Rows) ([]*message, error) { Icon: icon, Actions: actions, Attachment: att, - Sender: netip.MustParseAddr(sender), // Must parse assuming database must be correct + Sender: senderIP, // Must parse assuming database must be correct Encoding: encoding, }) } diff --git a/server/server.go b/server/server.go index 0b9cb21a..6b801739 100644 --- a/server/server.go +++ b/server/server.go @@ -643,8 +643,8 @@ func (s *Server) parsePublishParams(r *http.Request, v *visitor, m *message) (ca return false, false, "", false, errHTTPBadRequestDelayTooLarge } m.Time = delay.Unix() + m.Sender = v.ip // Important for rate limiting } - m.Sender = v.ip // Important for rate limiting actionsStr := readParam(r, "x-actions", "actions", "action") if actionsStr != "" { m.Actions, err = parseActions(actionsStr) @@ -1220,7 +1220,7 @@ func (s *Server) runFirebaseKeepaliver() { if s.firebaseClient == nil { return } - v := newVisitor(s.config, s.messageCache, netip.MustParseAddr("0.0.0.0")) // Background process, not a real visitor + v := newVisitor(s.config, s.messageCache, netip.IPv4Unspecified()) // Background process, not a real visitor, uses IP 0.0.0.0 for { select { case <-time.After(s.config.FirebaseKeepaliveInterval): @@ -1287,7 +1287,7 @@ func (s *Server) sendDelayedMessage(v *visitor, m *message) error { func (s *Server) limitRequests(next handleFunc) handleFunc { return func(w http.ResponseWriter, r *http.Request, v *visitor) error { - if util.ContainsContains(s.config.VisitorRequestExemptIPAddrs, v.ip) { + if util.ContainsIP(s.config.VisitorRequestExemptIPAddrs, v.ip) { return next(w, r, v) } else if err := v.RequestAllowed(); err != nil { return errHTTPTooManyRequestsLimitRequests @@ -1449,8 +1449,8 @@ func (s *Server) visitor(r *http.Request) *visitor { ips := util.SplitNoEmpty(r.Header.Get("X-Forwarded-For"), ",") myip, err := netip.ParseAddr(strings.TrimSpace(util.LastString(ips, remoteAddr))) if err != nil { - log.Error("Invalid IP Address Received from proxy in X-Forwarded-For header. This should NEVER happen, your proxy is seriously broken: ", ip, err) - // fall back to regular remote address if x forwarded for is damaged + log.Error("invalid IP address %s received in X-Forwarded-For header: %s", ip, err.Error()) + // fall back to regular remote address if X-Forwarded-For is damaged } else { ip = myip } diff --git a/util/util.go b/util/util.go index de4b908f..86566edf 100644 --- a/util/util.go +++ b/util/util.go @@ -7,6 +7,7 @@ import ( "fmt" "io" "math/rand" + "net/netip" "os" "regexp" "strconv" @@ -46,8 +47,8 @@ func Contains[T comparable](haystack []T, needle T) bool { return false } -// ContainsContains returns true if any element of haystack .Contains(needle). -func ContainsContains[T interface{ Contains(U) bool }, U any](haystack []T, needle U) bool { +// ContainsIP returns true if any one of the of prefixes contains the ip. +func ContainsIP(haystack []netip.Prefix, needle netip.Addr) bool { for _, s := range haystack { if s.Contains(needle) { return true