refactor visitor IPs and allow exempting IP Ranges

Use netip.Addr instead of storing addresses as strings. This requires
conversions at the database level and in tests, but is more memory
efficient otherwise, and facilitates the following.

Parse rate limit exemptions as netip.Prefix. This allows storing IP
ranges in the exemption list. Regular IP addresses (entered explicitly
or resolved from hostnames) are IPV4/32, denoting a range of one
address.
This commit is contained in:
Karmanyaah Malhotra 2022-10-05 15:42:07 -05:00
parent e0ad926ce9
commit c2382d29a1
12 changed files with 106 additions and 42 deletions

View file

@ -5,16 +5,18 @@ package cmd
import ( import (
"errors" "errors"
"fmt" "fmt"
"heckel.io/ntfy/log"
"io/fs" "io/fs"
"math" "math"
"net" "net"
"net/netip"
"os" "os"
"os/signal" "os/signal"
"strings" "strings"
"syscall" "syscall"
"time" "time"
"heckel.io/ntfy/log"
"github.com/urfave/cli/v2" "github.com/urfave/cli/v2"
"github.com/urfave/cli/v2/altsrc" "github.com/urfave/cli/v2/altsrc"
"heckel.io/ntfy/server" "heckel.io/ntfy/server"
@ -208,15 +210,15 @@ func execServe(c *cli.Context) error {
} }
// Resolve hosts // Resolve hosts
visitorRequestLimitExemptIPs := make([]string, 0) visitorRequestLimitExemptIPs := make([]netip.Prefix, 0)
for _, host := range visitorRequestLimitExemptHosts { for _, host := range visitorRequestLimitExemptHosts {
ips, err := net.LookupIP(host) ips, err := parseIPHostPrefix(host)
if err != nil { if err != nil {
log.Warn("cannot resolve host %s: %s, ignoring visitor request exemption", host, err.Error()) log.Warn("cannot resolve host %s: %s, ignoring visitor request exemption", host, err.Error())
continue continue
} }
for _, ip := range ips { for _, ip := range ips {
visitorRequestLimitExemptIPs = append(visitorRequestLimitExemptIPs, ip.String()) visitorRequestLimitExemptIPs = append(visitorRequestLimitExemptIPs, ip)
} }
} }
@ -303,6 +305,33 @@ func sigHandlerConfigReload(config string) {
} }
} }
func parseIPHostPrefix(host string) (prefixes []netip.Prefix, err error) {
//try parsing as prefix
prefix, err := netip.ParsePrefix(host)
if err == nil {
prefixes = append(prefixes, prefix.Masked()) // masked and canonical for easy of debugging, shouldn't matter
return prefixes, nil // success
}
// not a prefix, parse as host or IP
// LookupHost forwards through if it's an IP
ips, err := net.LookupHost(host)
if err == nil {
for _, i := range ips {
ip, err := netip.ParseAddr(i)
if err == nil {
prefix, err := ip.Prefix(ip.BitLen())
if err != nil {
return prefixes, errors.New(fmt.Sprint("ip", ip, " successfully parsed as IP but unable to turn into prefix. THIS SHOULD NEVER HAPPEN. err:", err.Error()))
}
prefixes = append(prefixes, prefix.Masked()) //also masked canonical ip
}
}
}
return
}
func reloadLogLevel(inputSource altsrc.InputSourceContext) { func reloadLogLevel(inputSource altsrc.InputSourceContext) {
newLevelStr, err := inputSource.String("log-level") newLevelStr, err := inputSource.String("log-level")
if err != nil { if err != nil {

View file

@ -2,6 +2,7 @@ package server
import ( import (
"io/fs" "io/fs"
"net/netip"
"time" "time"
) )
@ -92,7 +93,7 @@ type Config struct {
VisitorAttachmentDailyBandwidthLimit int VisitorAttachmentDailyBandwidthLimit int
VisitorRequestLimitBurst int VisitorRequestLimitBurst int
VisitorRequestLimitReplenish time.Duration VisitorRequestLimitReplenish time.Duration
VisitorRequestExemptIPAddrs []string VisitorRequestExemptIPAddrs []netip.Prefix
VisitorEmailLimitBurst int VisitorEmailLimitBurst int
VisitorEmailLimitReplenish time.Duration VisitorEmailLimitReplenish time.Duration
BehindProxy bool BehindProxy bool
@ -135,7 +136,7 @@ func NewConfig() *Config {
VisitorAttachmentDailyBandwidthLimit: DefaultVisitorAttachmentDailyBandwidthLimit, VisitorAttachmentDailyBandwidthLimit: DefaultVisitorAttachmentDailyBandwidthLimit,
VisitorRequestLimitBurst: DefaultVisitorRequestLimitBurst, VisitorRequestLimitBurst: DefaultVisitorRequestLimitBurst,
VisitorRequestLimitReplenish: DefaultVisitorRequestLimitReplenish, VisitorRequestLimitReplenish: DefaultVisitorRequestLimitReplenish,
VisitorRequestExemptIPAddrs: make([]string, 0), VisitorRequestExemptIPAddrs: make([]netip.Prefix, 0),
VisitorEmailLimitBurst: DefaultVisitorEmailLimitBurst, VisitorEmailLimitBurst: DefaultVisitorEmailLimitBurst,
VisitorEmailLimitReplenish: DefaultVisitorEmailLimitReplenish, VisitorEmailLimitReplenish: DefaultVisitorEmailLimitReplenish,
BehindProxy: false, BehindProxy: false,

View file

@ -5,11 +5,13 @@ import (
"encoding/json" "encoding/json"
"errors" "errors"
"fmt" "fmt"
"net/netip"
"strings"
"time"
_ "github.com/mattn/go-sqlite3" // SQLite driver _ "github.com/mattn/go-sqlite3" // SQLite driver
"heckel.io/ntfy/log" "heckel.io/ntfy/log"
"heckel.io/ntfy/util" "heckel.io/ntfy/util"
"strings"
"time"
) )
var ( var (
@ -279,7 +281,7 @@ func (c *messageCache) addMessages(ms []*message) error {
attachmentSize, attachmentSize,
attachmentExpires, attachmentExpires,
attachmentURL, attachmentURL,
m.Sender, m.Sender.String(),
m.Encoding, m.Encoding,
published, published,
) )
@ -477,7 +479,7 @@ func readMessages(rows *sql.Rows) ([]*message, error) {
Icon: icon, Icon: icon,
Actions: actions, Actions: actions,
Attachment: att, Attachment: att,
Sender: sender, Sender: netip.MustParseAddr(sender), // Must parse assuming database must be correct
Encoding: encoding, Encoding: encoding,
}) })
} }

View file

@ -3,11 +3,13 @@ package server
import ( import (
"database/sql" "database/sql"
"fmt" "fmt"
"github.com/stretchr/testify/assert" "net/netip"
"github.com/stretchr/testify/require"
"path/filepath" "path/filepath"
"testing" "testing"
"time" "time"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
) )
func TestSqliteCache_Messages(t *testing.T) { func TestSqliteCache_Messages(t *testing.T) {
@ -281,7 +283,7 @@ func testCacheAttachments(t *testing.T, c *messageCache) {
expires1 := time.Now().Add(-4 * time.Hour).Unix() expires1 := time.Now().Add(-4 * time.Hour).Unix()
m := newDefaultMessage("mytopic", "flower for you") m := newDefaultMessage("mytopic", "flower for you")
m.ID = "m1" m.ID = "m1"
m.Sender = "1.2.3.4" m.Sender = netip.MustParseAddr("1.2.3.4")
m.Attachment = &attachment{ m.Attachment = &attachment{
Name: "flower.jpg", Name: "flower.jpg",
Type: "image/jpeg", Type: "image/jpeg",
@ -294,7 +296,7 @@ func testCacheAttachments(t *testing.T, c *messageCache) {
expires2 := time.Now().Add(2 * time.Hour).Unix() // Future expires2 := time.Now().Add(2 * time.Hour).Unix() // Future
m = newDefaultMessage("mytopic", "sending you a car") m = newDefaultMessage("mytopic", "sending you a car")
m.ID = "m2" m.ID = "m2"
m.Sender = "1.2.3.4" m.Sender = netip.MustParseAddr("1.2.3.4")
m.Attachment = &attachment{ m.Attachment = &attachment{
Name: "car.jpg", Name: "car.jpg",
Type: "image/jpeg", Type: "image/jpeg",
@ -307,7 +309,7 @@ func testCacheAttachments(t *testing.T, c *messageCache) {
expires3 := time.Now().Add(1 * time.Hour).Unix() // Future expires3 := time.Now().Add(1 * time.Hour).Unix() // Future
m = newDefaultMessage("another-topic", "sending you another car") m = newDefaultMessage("another-topic", "sending you another car")
m.ID = "m3" m.ID = "m3"
m.Sender = "1.2.3.4" m.Sender = netip.MustParseAddr("1.2.3.4")
m.Attachment = &attachment{ m.Attachment = &attachment{
Name: "another-car.jpg", Name: "another-car.jpg",
Type: "image/jpeg", Type: "image/jpeg",

View file

@ -11,6 +11,7 @@ import (
"io" "io"
"net" "net"
"net/http" "net/http"
"net/netip"
"net/url" "net/url"
"os" "os"
"path" "path"
@ -42,7 +43,7 @@ type Server struct {
smtpServerBackend *smtpBackend smtpServerBackend *smtpBackend
smtpSender mailer smtpSender mailer
topics map[string]*topic topics map[string]*topic
visitors map[string]*visitor visitors map[netip.Addr]*visitor
firebaseClient *firebaseClient firebaseClient *firebaseClient
messages int64 messages int64
auth auth.Auther auth auth.Auther
@ -150,7 +151,7 @@ func New(conf *Config) (*Server, error) {
smtpSender: mailer, smtpSender: mailer,
topics: topics, topics: topics,
auth: auther, auth: auther,
visitors: make(map[string]*visitor), visitors: make(map[netip.Addr]*visitor),
}, nil }, nil
} }
@ -642,8 +643,8 @@ func (s *Server) parsePublishParams(r *http.Request, v *visitor, m *message) (ca
return false, false, "", false, errHTTPBadRequestDelayTooLarge return false, false, "", false, errHTTPBadRequestDelayTooLarge
} }
m.Time = delay.Unix() m.Time = delay.Unix()
m.Sender = v.ip // Important for rate limiting
} }
m.Sender = v.ip // Important for rate limiting
actionsStr := readParam(r, "x-actions", "actions", "action") actionsStr := readParam(r, "x-actions", "actions", "action")
if actionsStr != "" { if actionsStr != "" {
m.Actions, err = parseActions(actionsStr) m.Actions, err = parseActions(actionsStr)
@ -1219,7 +1220,7 @@ func (s *Server) runFirebaseKeepaliver() {
if s.firebaseClient == nil { if s.firebaseClient == nil {
return return
} }
v := newVisitor(s.config, s.messageCache, "0.0.0.0") // Background process, not a real visitor v := newVisitor(s.config, s.messageCache, netip.MustParseAddr("0.0.0.0")) // Background process, not a real visitor
for { for {
select { select {
case <-time.After(s.config.FirebaseKeepaliveInterval): case <-time.After(s.config.FirebaseKeepaliveInterval):
@ -1286,7 +1287,7 @@ func (s *Server) sendDelayedMessage(v *visitor, m *message) error {
func (s *Server) limitRequests(next handleFunc) handleFunc { func (s *Server) limitRequests(next handleFunc) handleFunc {
return func(w http.ResponseWriter, r *http.Request, v *visitor) error { return func(w http.ResponseWriter, r *http.Request, v *visitor) error {
if util.Contains(s.config.VisitorRequestExemptIPAddrs, v.ip) { if util.ContainsContains(s.config.VisitorRequestExemptIPAddrs, v.ip) {
return next(w, r, v) return next(w, r, v)
} else if err := v.RequestAllowed(); err != nil { } else if err := v.RequestAllowed(); err != nil {
return errHTTPTooManyRequestsLimitRequests return errHTTPTooManyRequestsLimitRequests
@ -1436,21 +1437,29 @@ func extractUserPass(r *http.Request) (username string, password string, ok bool
// This function was taken from https://www.alexedwards.net/blog/how-to-rate-limit-http-requests (MIT). // This function was taken from https://www.alexedwards.net/blog/how-to-rate-limit-http-requests (MIT).
func (s *Server) visitor(r *http.Request) *visitor { func (s *Server) visitor(r *http.Request) *visitor {
remoteAddr := r.RemoteAddr remoteAddr := r.RemoteAddr
ip, _, err := net.SplitHostPort(remoteAddr) ipport, err := netip.ParseAddrPort(remoteAddr)
ip := ipport.Addr()
if err != nil { if err != nil {
ip = remoteAddr // This should not happen in real life; only in tests. ip = netip.MustParseAddr(remoteAddr) // This should not happen in real life; only in tests. So, using MustParse, which panics on error.
} }
if s.config.BehindProxy && strings.TrimSpace(r.Header.Get("X-Forwarded-For")) != "" { if s.config.BehindProxy && strings.TrimSpace(r.Header.Get("X-Forwarded-For")) != "" {
// X-Forwarded-For can contain multiple addresses (see #328). If we are behind a proxy, // X-Forwarded-For can contain multiple addresses (see #328). If we are behind a proxy,
// only the right-most address can be trusted (as this is the one added by our proxy server). // only the right-most address can be trusted (as this is the one added by our proxy server).
// See https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/X-Forwarded-For for details. // See https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/X-Forwarded-For for details.
ips := util.SplitNoEmpty(r.Header.Get("X-Forwarded-For"), ",") ips := util.SplitNoEmpty(r.Header.Get("X-Forwarded-For"), ",")
ip = strings.TrimSpace(util.LastString(ips, remoteAddr)) myip, err := netip.ParseAddr(strings.TrimSpace(util.LastString(ips, remoteAddr)))
if err != nil {
log.Error("Invalid IP Address Received from proxy in X-Forwarded-For header. This should NEVER happen, your proxy is seriously broken: ", ip, err)
// fall back to regular remote address if x forwarded for is damaged
} else {
ip = myip
}
} }
return s.visitorFromIP(ip) return s.visitorFromIP(ip)
} }
func (s *Server) visitorFromIP(ip string) *visitor { func (s *Server) visitorFromIP(ip netip.Addr) *visitor {
s.mu.Lock() s.mu.Lock()
defer s.mu.Unlock() defer s.mu.Unlock()
v, exists := s.visitors[ip] v, exists := s.visitors[ip]

View file

@ -3,13 +3,15 @@ package server
import ( import (
"encoding/json" "encoding/json"
"errors" "errors"
"firebase.google.com/go/v4/messaging"
"fmt" "fmt"
"github.com/stretchr/testify/require" "net/netip"
"heckel.io/ntfy/auth"
"strings" "strings"
"sync" "sync"
"testing" "testing"
"firebase.google.com/go/v4/messaging"
"github.com/stretchr/testify/require"
"heckel.io/ntfy/auth"
) )
type testAuther struct { type testAuther struct {
@ -322,7 +324,7 @@ func TestMaybeTruncateFCMMessage_NotTooLong(t *testing.T) {
func TestToFirebaseSender_Abuse(t *testing.T) { func TestToFirebaseSender_Abuse(t *testing.T) {
sender := &testFirebaseSender{allowed: 2} sender := &testFirebaseSender{allowed: 2}
client := newFirebaseClient(sender, &testAuther{}) client := newFirebaseClient(sender, &testAuther{})
visitor := newVisitor(newTestConfig(t), newMemTestCache(t), "1.2.3.4") visitor := newVisitor(newTestConfig(t), newMemTestCache(t), netip.MustParseAddr("1.2.3.4"))
require.Nil(t, client.Send(visitor, &message{Topic: "mytopic"})) require.Nil(t, client.Send(visitor, &message{Topic: "mytopic"}))
require.Equal(t, 1, len(sender.Messages())) require.Equal(t, 1, len(sender.Messages()))

View file

@ -1,11 +1,13 @@
package server package server
import ( import (
"github.com/stretchr/testify/require"
"net/http" "net/http"
"net/http/httptest" "net/http/httptest"
"net/netip"
"strings" "strings"
"testing" "testing"
"github.com/stretchr/testify/require"
) )
func TestMatrix_NewRequestFromMatrixJSON_Success(t *testing.T) { func TestMatrix_NewRequestFromMatrixJSON_Success(t *testing.T) {
@ -70,7 +72,7 @@ func TestMatrix_WriteMatrixDiscoveryResponse(t *testing.T) {
func TestMatrix_WriteMatrixError(t *testing.T) { func TestMatrix_WriteMatrixError(t *testing.T) {
w := httptest.NewRecorder() w := httptest.NewRecorder()
r, _ := http.NewRequest("POST", "http://ntfy.example.com/_matrix/push/v1/notify", nil) r, _ := http.NewRequest("POST", "http://ntfy.example.com/_matrix/push/v1/notify", nil)
v := newVisitor(newTestConfig(t), nil, "1.2.3.4") v := newVisitor(newTestConfig(t), nil, netip.MustParseAddr("1.2.3.4"))
require.Nil(t, writeMatrixError(w, r, v, &errMatrix{"https://ntfy.example.com/upABCDEFGHI?up=1", errHTTPBadRequestMatrixPushkeyBaseURLMismatch})) require.Nil(t, writeMatrixError(w, r, v, &errMatrix{"https://ntfy.example.com/upABCDEFGHI?up=1", errHTTPBadRequestMatrixPushkeyBaseURLMismatch}))
require.Equal(t, 200, w.Result().StatusCode) require.Equal(t, 200, w.Result().StatusCode)
require.Equal(t, `{"rejected":["https://ntfy.example.com/upABCDEFGHI?up=1"]}`+"\n", w.Body.String()) require.Equal(t, `{"rejected":["https://ntfy.example.com/upABCDEFGHI?up=1"]}`+"\n", w.Body.String())

View file

@ -6,18 +6,20 @@ import (
"encoding/base64" "encoding/base64"
"encoding/json" "encoding/json"
"fmt" "fmt"
"github.com/stretchr/testify/assert"
"io" "io"
"log" "log"
"math/rand" "math/rand"
"net/http" "net/http"
"net/http/httptest" "net/http/httptest"
"net/netip"
"path/filepath" "path/filepath"
"strings" "strings"
"sync" "sync"
"testing" "testing"
"time" "time"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
"heckel.io/ntfy/auth" "heckel.io/ntfy/auth"
"heckel.io/ntfy/util" "heckel.io/ntfy/util"
@ -814,7 +816,7 @@ func TestServer_PublishTooRequests_Defaults(t *testing.T) {
func TestServer_PublishTooRequests_Defaults_ExemptHosts(t *testing.T) { func TestServer_PublishTooRequests_Defaults_ExemptHosts(t *testing.T) {
c := newTestConfig(t) c := newTestConfig(t)
c.VisitorRequestExemptIPAddrs = []string{"9.9.9.9"} // see request() c.VisitorRequestExemptIPAddrs = []netip.Prefix{netip.MustParsePrefix("9.9.9.9/32")} // see request()
s := newTestServer(t, c) s := newTestServer(t, c)
for i := 0; i < 65; i++ { // > 60 for i := 0; i < 65; i++ { // > 60
response := request(t, s, "PUT", "/mytopic", fmt.Sprintf("message %d", i), nil) response := request(t, s, "PUT", "/mytopic", fmt.Sprintf("message %d", i), nil)

View file

@ -32,7 +32,7 @@ func (s *smtpSender) Send(v *visitor, m *message, to string) error {
if err != nil { if err != nil {
return err return err
} }
message, err := formatMail(s.config.BaseURL, v.ip, s.config.SMTPSenderFrom, to, m) message, err := formatMail(s.config.BaseURL, v.ip.String(), s.config.SMTPSenderFrom, to, m)
if err != nil { if err != nil {
return err return err
} }

View file

@ -1,9 +1,11 @@
package server package server
import ( import (
"heckel.io/ntfy/util"
"net/http" "net/http"
"net/netip"
"time" "time"
"heckel.io/ntfy/util"
) )
// List of possible events // List of possible events
@ -33,7 +35,7 @@ type message struct {
Actions []*action `json:"actions,omitempty"` Actions []*action `json:"actions,omitempty"`
Attachment *attachment `json:"attachment,omitempty"` Attachment *attachment `json:"attachment,omitempty"`
PollID string `json:"poll_id,omitempty"` PollID string `json:"poll_id,omitempty"`
Sender string `json:"-"` // IP address of uploader, used for rate limiting Sender netip.Addr `json:"-"` // IP address of uploader, used for rate limiting
Encoding string `json:"encoding,omitempty"` // empty for raw UTF-8, or "base64" for encoded bytes Encoding string `json:"encoding,omitempty"` // empty for raw UTF-8, or "base64" for encoded bytes
} }

View file

@ -2,10 +2,12 @@ package server
import ( import (
"errors" "errors"
"golang.org/x/time/rate" "net/netip"
"heckel.io/ntfy/util"
"sync" "sync"
"time" "time"
"golang.org/x/time/rate"
"heckel.io/ntfy/util"
) )
const ( const (
@ -23,7 +25,7 @@ var (
type visitor struct { type visitor struct {
config *Config config *Config
messageCache *messageCache messageCache *messageCache
ip string ip netip.Addr
requests *rate.Limiter requests *rate.Limiter
emails *rate.Limiter emails *rate.Limiter
subscriptions util.Limiter subscriptions util.Limiter
@ -40,7 +42,7 @@ type visitorStats struct {
VisitorAttachmentBytesRemaining int64 `json:"visitorAttachmentBytesRemaining"` VisitorAttachmentBytesRemaining int64 `json:"visitorAttachmentBytesRemaining"`
} }
func newVisitor(conf *Config, messageCache *messageCache, ip string) *visitor { func newVisitor(conf *Config, messageCache *messageCache, ip netip.Addr) *visitor {
return &visitor{ return &visitor{
config: conf, config: conf,
messageCache: messageCache, messageCache: messageCache,
@ -115,7 +117,7 @@ func (v *visitor) Stale() bool {
} }
func (v *visitor) Stats() (*visitorStats, error) { func (v *visitor) Stats() (*visitorStats, error) {
attachmentsBytesUsed, err := v.messageCache.AttachmentBytesUsed(v.ip) attachmentsBytesUsed, err := v.messageCache.AttachmentBytesUsed(v.ip.String())
if err != nil { if err != nil {
return nil, err return nil, err
} }

View file

@ -5,8 +5,6 @@ import (
"encoding/json" "encoding/json"
"errors" "errors"
"fmt" "fmt"
"github.com/gabriel-vasile/mimetype"
"golang.org/x/term"
"io" "io"
"math/rand" "math/rand"
"os" "os"
@ -15,6 +13,9 @@ import (
"strings" "strings"
"sync" "sync"
"time" "time"
"github.com/gabriel-vasile/mimetype"
"golang.org/x/term"
) )
const ( const (
@ -45,6 +46,16 @@ func Contains[T comparable](haystack []T, needle T) bool {
return false return false
} }
// ContainsContains returns true if any element of haystack .Contains(needle).
func ContainsContains[T interface{ Contains(U) bool }, U any](haystack []T, needle U) bool {
for _, s := range haystack {
if s.Contains(needle) {
return true
}
}
return false
}
// ContainsAll returns true if all needles are contained in haystack // ContainsAll returns true if all needles are contained in haystack
func ContainsAll[T comparable](haystack []T, needles []T) bool { func ContainsAll[T comparable](haystack []T, needles []T) bool {
matches := 0 matches := 0