From c2382d29a1c5e28b4682047f1eeadc78c644d07f Mon Sep 17 00:00:00 2001 From: Karmanyaah Malhotra Date: Wed, 5 Oct 2022 15:42:07 -0500 Subject: [PATCH] refactor visitor IPs and allow exempting IP Ranges Use netip.Addr instead of storing addresses as strings. This requires conversions at the database level and in tests, but is more memory efficient otherwise, and facilitates the following. Parse rate limit exemptions as netip.Prefix. This allows storing IP ranges in the exemption list. Regular IP addresses (entered explicitly or resolved from hostnames) are IPV4/32, denoting a range of one address. --- cmd/serve.go | 37 ++++++++++++++++++++++++++++++---- server/config.go | 5 +++-- server/message_cache.go | 10 +++++---- server/message_cache_test.go | 12 ++++++----- server/server.go | 27 ++++++++++++++++--------- server/server_firebase_test.go | 10 +++++---- server/server_matrix_test.go | 6 ++++-- server/server_test.go | 6 ++++-- server/smtp_sender.go | 2 +- server/types.go | 6 ++++-- server/visitor.go | 12 ++++++----- util/util.go | 15 ++++++++++++-- 12 files changed, 106 insertions(+), 42 deletions(-) diff --git a/cmd/serve.go b/cmd/serve.go index 952c426e..3cc01143 100644 --- a/cmd/serve.go +++ b/cmd/serve.go @@ -5,16 +5,18 @@ package cmd import ( "errors" "fmt" - "heckel.io/ntfy/log" "io/fs" "math" "net" + "net/netip" "os" "os/signal" "strings" "syscall" "time" + "heckel.io/ntfy/log" + "github.com/urfave/cli/v2" "github.com/urfave/cli/v2/altsrc" "heckel.io/ntfy/server" @@ -208,15 +210,15 @@ func execServe(c *cli.Context) error { } // Resolve hosts - visitorRequestLimitExemptIPs := make([]string, 0) + visitorRequestLimitExemptIPs := make([]netip.Prefix, 0) for _, host := range visitorRequestLimitExemptHosts { - ips, err := net.LookupIP(host) + ips, err := parseIPHostPrefix(host) if err != nil { log.Warn("cannot resolve host %s: %s, ignoring visitor request exemption", host, err.Error()) continue } for _, ip := range ips { - visitorRequestLimitExemptIPs = append(visitorRequestLimitExemptIPs, ip.String()) + visitorRequestLimitExemptIPs = append(visitorRequestLimitExemptIPs, ip) } } @@ -303,6 +305,33 @@ func sigHandlerConfigReload(config string) { } } +func parseIPHostPrefix(host string) (prefixes []netip.Prefix, err error) { + //try parsing as prefix + prefix, err := netip.ParsePrefix(host) + if err == nil { + prefixes = append(prefixes, prefix.Masked()) // masked and canonical for easy of debugging, shouldn't matter + return prefixes, nil // success + } + + // not a prefix, parse as host or IP + // LookupHost forwards through if it's an IP + ips, err := net.LookupHost(host) + if err == nil { + for _, i := range ips { + ip, err := netip.ParseAddr(i) + if err == nil { + prefix, err := ip.Prefix(ip.BitLen()) + if err != nil { + return prefixes, errors.New(fmt.Sprint("ip", ip, " successfully parsed as IP but unable to turn into prefix. THIS SHOULD NEVER HAPPEN. err:", err.Error())) + } + prefixes = append(prefixes, prefix.Masked()) //also masked canonical ip + } + } + } + return +} + + func reloadLogLevel(inputSource altsrc.InputSourceContext) { newLevelStr, err := inputSource.String("log-level") if err != nil { diff --git a/server/config.go b/server/config.go index e117da88..d8fd429e 100644 --- a/server/config.go +++ b/server/config.go @@ -2,6 +2,7 @@ package server import ( "io/fs" + "net/netip" "time" ) @@ -92,7 +93,7 @@ type Config struct { VisitorAttachmentDailyBandwidthLimit int VisitorRequestLimitBurst int VisitorRequestLimitReplenish time.Duration - VisitorRequestExemptIPAddrs []string + VisitorRequestExemptIPAddrs []netip.Prefix VisitorEmailLimitBurst int VisitorEmailLimitReplenish time.Duration BehindProxy bool @@ -135,7 +136,7 @@ func NewConfig() *Config { VisitorAttachmentDailyBandwidthLimit: DefaultVisitorAttachmentDailyBandwidthLimit, VisitorRequestLimitBurst: DefaultVisitorRequestLimitBurst, VisitorRequestLimitReplenish: DefaultVisitorRequestLimitReplenish, - VisitorRequestExemptIPAddrs: make([]string, 0), + VisitorRequestExemptIPAddrs: make([]netip.Prefix, 0), VisitorEmailLimitBurst: DefaultVisitorEmailLimitBurst, VisitorEmailLimitReplenish: DefaultVisitorEmailLimitReplenish, BehindProxy: false, diff --git a/server/message_cache.go b/server/message_cache.go index a2f49e75..4845a918 100644 --- a/server/message_cache.go +++ b/server/message_cache.go @@ -5,11 +5,13 @@ import ( "encoding/json" "errors" "fmt" + "net/netip" + "strings" + "time" + _ "github.com/mattn/go-sqlite3" // SQLite driver "heckel.io/ntfy/log" "heckel.io/ntfy/util" - "strings" - "time" ) var ( @@ -279,7 +281,7 @@ func (c *messageCache) addMessages(ms []*message) error { attachmentSize, attachmentExpires, attachmentURL, - m.Sender, + m.Sender.String(), m.Encoding, published, ) @@ -477,7 +479,7 @@ func readMessages(rows *sql.Rows) ([]*message, error) { Icon: icon, Actions: actions, Attachment: att, - Sender: sender, + Sender: netip.MustParseAddr(sender), // Must parse assuming database must be correct Encoding: encoding, }) } diff --git a/server/message_cache_test.go b/server/message_cache_test.go index 23c080d4..ea9580a5 100644 --- a/server/message_cache_test.go +++ b/server/message_cache_test.go @@ -3,11 +3,13 @@ package server import ( "database/sql" "fmt" - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" + "net/netip" "path/filepath" "testing" "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" ) func TestSqliteCache_Messages(t *testing.T) { @@ -281,7 +283,7 @@ func testCacheAttachments(t *testing.T, c *messageCache) { expires1 := time.Now().Add(-4 * time.Hour).Unix() m := newDefaultMessage("mytopic", "flower for you") m.ID = "m1" - m.Sender = "1.2.3.4" + m.Sender = netip.MustParseAddr("1.2.3.4") m.Attachment = &attachment{ Name: "flower.jpg", Type: "image/jpeg", @@ -294,7 +296,7 @@ func testCacheAttachments(t *testing.T, c *messageCache) { expires2 := time.Now().Add(2 * time.Hour).Unix() // Future m = newDefaultMessage("mytopic", "sending you a car") m.ID = "m2" - m.Sender = "1.2.3.4" + m.Sender = netip.MustParseAddr("1.2.3.4") m.Attachment = &attachment{ Name: "car.jpg", Type: "image/jpeg", @@ -307,7 +309,7 @@ func testCacheAttachments(t *testing.T, c *messageCache) { expires3 := time.Now().Add(1 * time.Hour).Unix() // Future m = newDefaultMessage("another-topic", "sending you another car") m.ID = "m3" - m.Sender = "1.2.3.4" + m.Sender = netip.MustParseAddr("1.2.3.4") m.Attachment = &attachment{ Name: "another-car.jpg", Type: "image/jpeg", diff --git a/server/server.go b/server/server.go index 276e56fa..0b9cb21a 100644 --- a/server/server.go +++ b/server/server.go @@ -11,6 +11,7 @@ import ( "io" "net" "net/http" + "net/netip" "net/url" "os" "path" @@ -42,7 +43,7 @@ type Server struct { smtpServerBackend *smtpBackend smtpSender mailer topics map[string]*topic - visitors map[string]*visitor + visitors map[netip.Addr]*visitor firebaseClient *firebaseClient messages int64 auth auth.Auther @@ -150,7 +151,7 @@ func New(conf *Config) (*Server, error) { smtpSender: mailer, topics: topics, auth: auther, - visitors: make(map[string]*visitor), + visitors: make(map[netip.Addr]*visitor), }, nil } @@ -642,8 +643,8 @@ func (s *Server) parsePublishParams(r *http.Request, v *visitor, m *message) (ca return false, false, "", false, errHTTPBadRequestDelayTooLarge } m.Time = delay.Unix() - m.Sender = v.ip // Important for rate limiting } + m.Sender = v.ip // Important for rate limiting actionsStr := readParam(r, "x-actions", "actions", "action") if actionsStr != "" { m.Actions, err = parseActions(actionsStr) @@ -1219,7 +1220,7 @@ func (s *Server) runFirebaseKeepaliver() { if s.firebaseClient == nil { return } - v := newVisitor(s.config, s.messageCache, "0.0.0.0") // Background process, not a real visitor + v := newVisitor(s.config, s.messageCache, netip.MustParseAddr("0.0.0.0")) // Background process, not a real visitor for { select { case <-time.After(s.config.FirebaseKeepaliveInterval): @@ -1286,7 +1287,7 @@ func (s *Server) sendDelayedMessage(v *visitor, m *message) error { func (s *Server) limitRequests(next handleFunc) handleFunc { return func(w http.ResponseWriter, r *http.Request, v *visitor) error { - if util.Contains(s.config.VisitorRequestExemptIPAddrs, v.ip) { + if util.ContainsContains(s.config.VisitorRequestExemptIPAddrs, v.ip) { return next(w, r, v) } else if err := v.RequestAllowed(); err != nil { return errHTTPTooManyRequestsLimitRequests @@ -1436,21 +1437,29 @@ func extractUserPass(r *http.Request) (username string, password string, ok bool // This function was taken from https://www.alexedwards.net/blog/how-to-rate-limit-http-requests (MIT). func (s *Server) visitor(r *http.Request) *visitor { remoteAddr := r.RemoteAddr - ip, _, err := net.SplitHostPort(remoteAddr) + ipport, err := netip.ParseAddrPort(remoteAddr) + ip := ipport.Addr() if err != nil { - ip = remoteAddr // This should not happen in real life; only in tests. + ip = netip.MustParseAddr(remoteAddr) // This should not happen in real life; only in tests. So, using MustParse, which panics on error. } if s.config.BehindProxy && strings.TrimSpace(r.Header.Get("X-Forwarded-For")) != "" { // X-Forwarded-For can contain multiple addresses (see #328). If we are behind a proxy, // only the right-most address can be trusted (as this is the one added by our proxy server). // See https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/X-Forwarded-For for details. ips := util.SplitNoEmpty(r.Header.Get("X-Forwarded-For"), ",") - ip = strings.TrimSpace(util.LastString(ips, remoteAddr)) + myip, err := netip.ParseAddr(strings.TrimSpace(util.LastString(ips, remoteAddr))) + if err != nil { + log.Error("Invalid IP Address Received from proxy in X-Forwarded-For header. This should NEVER happen, your proxy is seriously broken: ", ip, err) + // fall back to regular remote address if x forwarded for is damaged + } else { + ip = myip + } + } return s.visitorFromIP(ip) } -func (s *Server) visitorFromIP(ip string) *visitor { +func (s *Server) visitorFromIP(ip netip.Addr) *visitor { s.mu.Lock() defer s.mu.Unlock() v, exists := s.visitors[ip] diff --git a/server/server_firebase_test.go b/server/server_firebase_test.go index 3e034c06..36fd8b51 100644 --- a/server/server_firebase_test.go +++ b/server/server_firebase_test.go @@ -3,13 +3,15 @@ package server import ( "encoding/json" "errors" - "firebase.google.com/go/v4/messaging" "fmt" - "github.com/stretchr/testify/require" - "heckel.io/ntfy/auth" + "net/netip" "strings" "sync" "testing" + + "firebase.google.com/go/v4/messaging" + "github.com/stretchr/testify/require" + "heckel.io/ntfy/auth" ) type testAuther struct { @@ -322,7 +324,7 @@ func TestMaybeTruncateFCMMessage_NotTooLong(t *testing.T) { func TestToFirebaseSender_Abuse(t *testing.T) { sender := &testFirebaseSender{allowed: 2} client := newFirebaseClient(sender, &testAuther{}) - visitor := newVisitor(newTestConfig(t), newMemTestCache(t), "1.2.3.4") + visitor := newVisitor(newTestConfig(t), newMemTestCache(t), netip.MustParseAddr("1.2.3.4")) require.Nil(t, client.Send(visitor, &message{Topic: "mytopic"})) require.Equal(t, 1, len(sender.Messages())) diff --git a/server/server_matrix_test.go b/server/server_matrix_test.go index b2f9b1d5..4b5a66c4 100644 --- a/server/server_matrix_test.go +++ b/server/server_matrix_test.go @@ -1,11 +1,13 @@ package server import ( - "github.com/stretchr/testify/require" "net/http" "net/http/httptest" + "net/netip" "strings" "testing" + + "github.com/stretchr/testify/require" ) func TestMatrix_NewRequestFromMatrixJSON_Success(t *testing.T) { @@ -70,7 +72,7 @@ func TestMatrix_WriteMatrixDiscoveryResponse(t *testing.T) { func TestMatrix_WriteMatrixError(t *testing.T) { w := httptest.NewRecorder() r, _ := http.NewRequest("POST", "http://ntfy.example.com/_matrix/push/v1/notify", nil) - v := newVisitor(newTestConfig(t), nil, "1.2.3.4") + v := newVisitor(newTestConfig(t), nil, netip.MustParseAddr("1.2.3.4")) require.Nil(t, writeMatrixError(w, r, v, &errMatrix{"https://ntfy.example.com/upABCDEFGHI?up=1", errHTTPBadRequestMatrixPushkeyBaseURLMismatch})) require.Equal(t, 200, w.Result().StatusCode) require.Equal(t, `{"rejected":["https://ntfy.example.com/upABCDEFGHI?up=1"]}`+"\n", w.Body.String()) diff --git a/server/server_test.go b/server/server_test.go index ea3495d6..5a3dcc8d 100644 --- a/server/server_test.go +++ b/server/server_test.go @@ -6,18 +6,20 @@ import ( "encoding/base64" "encoding/json" "fmt" - "github.com/stretchr/testify/assert" "io" "log" "math/rand" "net/http" "net/http/httptest" + "net/netip" "path/filepath" "strings" "sync" "testing" "time" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" "heckel.io/ntfy/auth" "heckel.io/ntfy/util" @@ -814,7 +816,7 @@ func TestServer_PublishTooRequests_Defaults(t *testing.T) { func TestServer_PublishTooRequests_Defaults_ExemptHosts(t *testing.T) { c := newTestConfig(t) - c.VisitorRequestExemptIPAddrs = []string{"9.9.9.9"} // see request() + c.VisitorRequestExemptIPAddrs = []netip.Prefix{netip.MustParsePrefix("9.9.9.9/32")} // see request() s := newTestServer(t, c) for i := 0; i < 65; i++ { // > 60 response := request(t, s, "PUT", "/mytopic", fmt.Sprintf("message %d", i), nil) diff --git a/server/smtp_sender.go b/server/smtp_sender.go index ecefd9c2..7d6b7519 100644 --- a/server/smtp_sender.go +++ b/server/smtp_sender.go @@ -32,7 +32,7 @@ func (s *smtpSender) Send(v *visitor, m *message, to string) error { if err != nil { return err } - message, err := formatMail(s.config.BaseURL, v.ip, s.config.SMTPSenderFrom, to, m) + message, err := formatMail(s.config.BaseURL, v.ip.String(), s.config.SMTPSenderFrom, to, m) if err != nil { return err } diff --git a/server/types.go b/server/types.go index b217b9db..ce57c9b5 100644 --- a/server/types.go +++ b/server/types.go @@ -1,9 +1,11 @@ package server import ( - "heckel.io/ntfy/util" "net/http" + "net/netip" "time" + + "heckel.io/ntfy/util" ) // List of possible events @@ -33,7 +35,7 @@ type message struct { Actions []*action `json:"actions,omitempty"` Attachment *attachment `json:"attachment,omitempty"` PollID string `json:"poll_id,omitempty"` - Sender string `json:"-"` // IP address of uploader, used for rate limiting + Sender netip.Addr `json:"-"` // IP address of uploader, used for rate limiting Encoding string `json:"encoding,omitempty"` // empty for raw UTF-8, or "base64" for encoded bytes } diff --git a/server/visitor.go b/server/visitor.go index 5a8e186b..cd120c43 100644 --- a/server/visitor.go +++ b/server/visitor.go @@ -2,10 +2,12 @@ package server import ( "errors" - "golang.org/x/time/rate" - "heckel.io/ntfy/util" + "net/netip" "sync" "time" + + "golang.org/x/time/rate" + "heckel.io/ntfy/util" ) const ( @@ -23,7 +25,7 @@ var ( type visitor struct { config *Config messageCache *messageCache - ip string + ip netip.Addr requests *rate.Limiter emails *rate.Limiter subscriptions util.Limiter @@ -40,7 +42,7 @@ type visitorStats struct { VisitorAttachmentBytesRemaining int64 `json:"visitorAttachmentBytesRemaining"` } -func newVisitor(conf *Config, messageCache *messageCache, ip string) *visitor { +func newVisitor(conf *Config, messageCache *messageCache, ip netip.Addr) *visitor { return &visitor{ config: conf, messageCache: messageCache, @@ -115,7 +117,7 @@ func (v *visitor) Stale() bool { } func (v *visitor) Stats() (*visitorStats, error) { - attachmentsBytesUsed, err := v.messageCache.AttachmentBytesUsed(v.ip) + attachmentsBytesUsed, err := v.messageCache.AttachmentBytesUsed(v.ip.String()) if err != nil { return nil, err } diff --git a/util/util.go b/util/util.go index 05079180..de4b908f 100644 --- a/util/util.go +++ b/util/util.go @@ -5,8 +5,6 @@ import ( "encoding/json" "errors" "fmt" - "github.com/gabriel-vasile/mimetype" - "golang.org/x/term" "io" "math/rand" "os" @@ -15,6 +13,9 @@ import ( "strings" "sync" "time" + + "github.com/gabriel-vasile/mimetype" + "golang.org/x/term" ) const ( @@ -45,6 +46,16 @@ func Contains[T comparable](haystack []T, needle T) bool { return false } +// ContainsContains returns true if any element of haystack .Contains(needle). +func ContainsContains[T interface{ Contains(U) bool }, U any](haystack []T, needle U) bool { + for _, s := range haystack { + if s.Contains(needle) { + return true + } + } + return false +} + // ContainsAll returns true if all needles are contained in haystack func ContainsAll[T comparable](haystack []T, needles []T) bool { matches := 0