Email rate limiting + tests

This commit is contained in:
Philipp Heckel 2021-12-24 00:03:04 +01:00
parent 873c57b3d8
commit 7280ae1ebc
7 changed files with 183 additions and 34 deletions

View file

@ -1,4 +1,3 @@
// Package cmd provides the ntfy CLI application
package cmd
import (
@ -22,10 +21,16 @@ var flagsServe = []cli.Flag{
altsrc.NewDurationFlag(&cli.DurationFlag{Name: "cache-duration", Aliases: []string{"b"}, EnvVars: []string{"NTFY_CACHE_DURATION"}, Value: server.DefaultCacheDuration, Usage: "buffer messages for this time to allow `since` requests"}),
altsrc.NewDurationFlag(&cli.DurationFlag{Name: "keepalive-interval", Aliases: []string{"k"}, EnvVars: []string{"NTFY_KEEPALIVE_INTERVAL"}, Value: server.DefaultKeepaliveInterval, Usage: "interval of keepalive messages"}),
altsrc.NewDurationFlag(&cli.DurationFlag{Name: "manager-interval", Aliases: []string{"m"}, EnvVars: []string{"NTFY_MANAGER_INTERVAL"}, Value: server.DefaultManagerInterval, Usage: "interval of for message pruning and stats printing"}),
altsrc.NewStringFlag(&cli.StringFlag{Name: "smtp-addr", EnvVars: []string{"NTFY_SMTP_ADDR"}, Usage: "SMTP address (host:port) to allow email sending"}),
altsrc.NewStringFlag(&cli.StringFlag{Name: "smtp-user", EnvVars: []string{"NTFY_SMTP_USER"}, Usage: "SMTP user (if e-mail sending is enabled)"}),
altsrc.NewStringFlag(&cli.StringFlag{Name: "smtp-pass", EnvVars: []string{"NTFY_SMTP_PASS"}, Usage: "SMTP password (if e-mail sending is enabled)"}),
altsrc.NewStringFlag(&cli.StringFlag{Name: "smtp-from", EnvVars: []string{"NTFY_SMTP_FROM"}, Usage: "SMTP sender address (if e-mail sending is enabled)"}),
altsrc.NewIntFlag(&cli.IntFlag{Name: "global-topic-limit", Aliases: []string{"T"}, EnvVars: []string{"NTFY_GLOBAL_TOPIC_LIMIT"}, Value: server.DefaultGlobalTopicLimit, Usage: "total number of topics allowed"}),
altsrc.NewIntFlag(&cli.IntFlag{Name: "visitor-subscription-limit", Aliases: []string{"V"}, EnvVars: []string{"NTFY_VISITOR_SUBSCRIPTION_LIMIT"}, Value: server.DefaultVisitorSubscriptionLimit, Usage: "number of subscriptions per visitor"}),
altsrc.NewIntFlag(&cli.IntFlag{Name: "visitor-request-limit-burst", Aliases: []string{"B"}, EnvVars: []string{"NTFY_VISITOR_REQUEST_LIMIT_BURST"}, Value: server.DefaultVisitorRequestLimitBurst, Usage: "initial limit of requests per visitor"}),
altsrc.NewDurationFlag(&cli.DurationFlag{Name: "visitor-request-limit-replenish", Aliases: []string{"R"}, EnvVars: []string{"NTFY_VISITOR_REQUEST_LIMIT_REPLENISH"}, Value: server.DefaultVisitorRequestLimitReplenish, Usage: "interval at which burst limit is replenished (one per x)"}),
altsrc.NewIntFlag(&cli.IntFlag{Name: "visitor-subscription-limit", EnvVars: []string{"NTFY_VISITOR_SUBSCRIPTION_LIMIT"}, Value: server.DefaultVisitorSubscriptionLimit, Usage: "number of subscriptions per visitor"}),
altsrc.NewIntFlag(&cli.IntFlag{Name: "visitor-request-limit-burst", EnvVars: []string{"NTFY_VISITOR_REQUEST_LIMIT_BURST"}, Value: server.DefaultVisitorRequestLimitBurst, Usage: "initial limit of requests per visitor"}),
altsrc.NewDurationFlag(&cli.DurationFlag{Name: "visitor-request-limit-replenish", EnvVars: []string{"NTFY_VISITOR_REQUEST_LIMIT_REPLENISH"}, Value: server.DefaultVisitorRequestLimitReplenish, Usage: "interval at which burst limit is replenished (one per x)"}),
altsrc.NewIntFlag(&cli.IntFlag{Name: "visitor-email-limit-burst", EnvVars: []string{"NTFY_VISITOR_EMAIL_LIMIT_BURST"}, Value: server.DefaultVisitorEmailLimitBurst, Usage: "initial limit of e-mails per visitor"}),
altsrc.NewDurationFlag(&cli.DurationFlag{Name: "visitor-email-limit-replenish", EnvVars: []string{"NTFY_VISITOR_EMAIL_LIMIT_REPLENISH"}, Value: server.DefaultVisitorEmailLimitReplenish, Usage: "interval at which burst limit is replenished (one per x)"}),
altsrc.NewBoolFlag(&cli.BoolFlag{Name: "behind-proxy", Aliases: []string{"P"}, EnvVars: []string{"NTFY_BEHIND_PROXY"}, Value: false, Usage: "if set, use X-Forwarded-For header to determine visitor IP address (for rate limiting)"}),
}
@ -61,10 +66,16 @@ func execServe(c *cli.Context) error {
cacheDuration := c.Duration("cache-duration")
keepaliveInterval := c.Duration("keepalive-interval")
managerInterval := c.Duration("manager-interval")
smtpAddr := c.String("smtp-addr")
smtpUser := c.String("smtp-user")
smtpPass := c.String("smtp-pass")
smtpFrom := c.String("smtp-from")
globalTopicLimit := c.Int("global-topic-limit")
visitorSubscriptionLimit := c.Int("visitor-subscription-limit")
visitorRequestLimitBurst := c.Int("visitor-request-limit-burst")
visitorRequestLimitReplenish := c.Duration("visitor-request-limit-replenish")
visitorEmailLimitBurst := c.Int("visitor-email-limit-burst")
visitorEmailLimitReplenish := c.Duration("visitor-email-limit-replenish")
behindProxy := c.Bool("behind-proxy")
// Check values
@ -82,6 +93,8 @@ func execServe(c *cli.Context) error {
return errors.New("if set, certificate file must exist")
} else if listenHTTPS != "" && (keyFile == "" || certFile == "") {
return errors.New("if listen-https is set, both key-file and cert-file must be set")
} else if smtpAddr != "" && (smtpUser == "" || smtpPass == "" || smtpFrom == "") {
return errors.New("if smtp-addr is set, smtp-user, smtp-pass and smtp-from must also be set")
}
// Run server
@ -95,11 +108,16 @@ func execServe(c *cli.Context) error {
conf.CacheDuration = cacheDuration
conf.KeepaliveInterval = keepaliveInterval
conf.ManagerInterval = managerInterval
//XXXXXXXXX
conf.SMTPAddr = smtpAddr
conf.SMTPUser = smtpUser
conf.SMTPPass = smtpPass
conf.SMTPFrom = smtpFrom
conf.GlobalTopicLimit = globalTopicLimit
conf.VisitorSubscriptionLimit = visitorSubscriptionLimit
conf.VisitorRequestLimitBurst = visitorRequestLimitBurst
conf.VisitorRequestLimitReplenish = visitorRequestLimitReplenish
conf.VisitorEmailLimitBurst = visitorEmailLimitBurst
conf.VisitorEmailLimitReplenish = visitorEmailLimitReplenish
conf.BehindProxy = behindProxy
s, err := server.New(conf)
if err != nil {

View file

@ -20,11 +20,14 @@ const (
// Defines all the limits
// - global topic limit: max number of topics overall
// - per visitor request limit: max number of PUT/GET/.. requests (here: 60 requests bucket, replenished at a rate of one per 10 seconds)
// - per visitor email limit: max number of emails (here: 16 email bucket, replenished at a rate of one per hour)
// - per visitor subscription limit: max number of subscriptions (active HTTP connections) per per-visitor/IP
const (
DefaultGlobalTopicLimit = 5000
DefaultVisitorRequestLimitBurst = 60
DefaultVisitorRequestLimitReplenish = 10 * time.Second
DefaultVisitorEmailLimitBurst = 16
DefaultVisitorEmailLimitReplenish = time.Hour
DefaultVisitorSubscriptionLimit = 30
)
@ -51,6 +54,8 @@ type Config struct {
GlobalTopicLimit int
VisitorRequestLimitBurst int
VisitorRequestLimitReplenish time.Duration
VisitorEmailLimitBurst int
VisitorEmailLimitReplenish time.Duration
VisitorSubscriptionLimit int
BehindProxy bool
}
@ -75,6 +80,8 @@ func NewConfig() *Config {
GlobalTopicLimit: DefaultGlobalTopicLimit,
VisitorRequestLimitBurst: DefaultVisitorRequestLimitBurst,
VisitorRequestLimitReplenish: DefaultVisitorRequestLimitReplenish,
VisitorEmailLimitBurst: DefaultVisitorEmailLimitBurst,
VisitorEmailLimitReplenish: DefaultVisitorEmailLimitReplenish,
VisitorSubscriptionLimit: DefaultVisitorSubscriptionLimit,
BehindProxy: false,
}

35
server/mailer.go Normal file
View file

@ -0,0 +1,35 @@
package server
import (
"fmt"
"net"
"net/smtp"
"strings"
)
type mailer interface {
Send(to string, m *message) error
}
type smtpMailer struct {
config *Config
}
func (s *smtpMailer) Send(to string, m *message) error {
host, _, err := net.SplitHostPort(s.config.SMTPAddr)
if err != nil {
return err
}
subject := m.Title
if subject == "" {
subject = m.Message
}
subject += " - " + m.Topic
subject = strings.ReplaceAll(strings.ReplaceAll(subject, "\r", ""), "\n", " ")
msg := []byte(fmt.Sprintf("From: %s\r\n"+
"To: %s\r\n"+
"Subject: %s\r\n\r\n"+
"%s\r\n", s.config.SMTPFrom, to, subject, m.Message))
auth := smtp.PlainAuth("", s.config.SMTPUser, s.config.SMTPPass, host)
return smtp.SendMail(s.config.SMTPAddr, auth, s.config.SMTPFrom, []string{to}, msg)
}

View file

@ -3,7 +3,7 @@ package server
import (
"bytes"
"context"
"embed" // required for go:embed
"embed"
"encoding/json"
firebase "firebase.google.com/go"
"firebase.google.com/go/messaging"
@ -15,7 +15,6 @@ import (
"log"
"net"
"net/http"
"net/smtp"
"regexp"
"strconv"
"strings"
@ -34,6 +33,7 @@ type Server struct {
topics map[string]*topic
visitors map[string]*visitor
firebase subscriber
mailer mailer
messages int64
cache cache
closeChan chan bool
@ -111,6 +111,7 @@ var (
const (
firebaseControlTopic = "~control" // See Android if changed
emptyMessageBody = "triggered"
)
// New instantiates a new Server. It creates the cache and adds a Firebase
@ -124,6 +125,10 @@ func New(conf *Config) (*Server, error) {
return nil, err
}
}
var mailer mailer
if conf.SMTPAddr != "" {
mailer = &smtpMailer{config: conf}
}
cache, err := createCache(conf)
if err != nil {
return nil, err
@ -136,6 +141,7 @@ func New(conf *Config) (*Server, error) {
config: conf,
cache: cache,
firebase: firebaseSubscriber,
mailer: mailer,
topics: topics,
visitors: make(map[string]*visitor),
}, nil
@ -189,23 +195,6 @@ func createFirebaseSubscriber(conf *Config) (subscriber, error) {
}, nil
}
func (s *Server) sendMail(to string, m *message) error {
host, _, err := net.SplitHostPort(s.config.SMTPAddr)
if err != nil {
return err
}
subject := m.Title
if subject == "" {
subject = m.Message
}
msg := []byte(fmt.Sprintf("From: %s\r\n"+
"To: %s\r\n"+
"Subject: %s\r\n\r\n"+
"%s\r\n", s.config.SMTPFrom, to, subject, m.Message))
auth := smtp.PlainAuth("", s.config.SMTPUser, s.config.SMTPPass, host)
return smtp.SendMail(s.config.SMTPAddr, auth, s.config.SMTPFrom, []string{to}, msg)
}
// Run executes the main server. It listens on HTTP (+ HTTPS, if configured), and starts
// a manager go routine to print stats and prune messages.
func (s *Server) Run() error {
@ -314,7 +303,7 @@ func (s *Server) handleDocs(w http.ResponseWriter, r *http.Request) error {
return nil
}
func (s *Server) handlePublish(w http.ResponseWriter, r *http.Request, _ *visitor) error {
func (s *Server) handlePublish(w http.ResponseWriter, r *http.Request, v *visitor) error {
t, err := s.topicFromPath(r.URL.Path)
if err != nil {
return err
@ -329,8 +318,16 @@ func (s *Server) handlePublish(w http.ResponseWriter, r *http.Request, _ *visito
if err != nil {
return err
}
if email != "" {
if err := v.EmailAllowed(); err != nil {
return err
}
}
if s.mailer == nil && email != "" {
return errHTTPBadRequest
}
if m.Message == "" {
m.Message = "triggered"
m.Message = emptyMessageBody
}
delayed := m.Time > time.Now().Unix()
if !delayed {
@ -345,9 +342,9 @@ func (s *Server) handlePublish(w http.ResponseWriter, r *http.Request, _ *visito
}
}()
}
if s.config.SMTPAddr != "" && email != "" && !delayed {
if s.mailer != nil && email != "" && !delayed {
go func() {
if err := s.sendMail(email, m); err != nil {
if err := s.mailer.Send(email, m); err != nil {
log.Printf("Unable to send email: %v", err.Error())
}
}()
@ -369,7 +366,7 @@ func (s *Server) handlePublish(w http.ResponseWriter, r *http.Request, _ *visito
func (s *Server) parseParams(r *http.Request, m *message) (cache bool, firebase bool, email string, err error) {
cache = readParam(r, "x-cache", "cache") != "no"
firebase = readParam(r, "x-firebase", "firebase") != "no"
email = readParam(r, "x-email", "email", "mail", "e")
email = readParam(r, "x-email", "x-e-mail", "email", "e-mail", "mail", "e")
m.Title = readParam(r, "x-title", "title", "t")
messageStr := readParam(r, "x-message", "message", "m")
if messageStr != "" {
@ -391,6 +388,9 @@ func (s *Server) parseParams(r *http.Request, m *message) (cache bool, firebase
if !cache {
return false, false, "", errHTTPBadRequest
}
if email != "" {
return false, false, "", errHTTPBadRequest // we cannot store the email address (yet)
}
delay, err := util.ParseFutureTime(delayStr, time.Now())
if err != nil {
return false, false, "", errHTTPBadRequest
@ -740,7 +740,7 @@ func (s *Server) sendDelayedMessages() error {
log.Printf("unable to publish to Firebase: %v", err.Error())
}
}
// FIXME delayed email
// TODO delayed email sending
}
if err := s.cache.MarkPublished(m); err != nil {
return err

View file

@ -61,6 +61,13 @@
# visitor-request-limit-burst: 60
# visitor-request-limit-replenish: 10s
# Rate limiting: Allowed emails per visitor:
# - visitor-email-limit-burst is the initial bucket of emails each visitor has
# - visitor-email-limit-replenish is the rate at which the bucket is refilled
#
# visitor-email-limit-burst: 16
# visitor-email-limit-replenish: 1h
# If set, the X-Forwarded-For header is used to determine the visitor IP address
# instead of the remote address of the connection.
#

View file

@ -508,6 +508,76 @@ func TestServer_Curl_Publish_Poll(t *testing.T) {
}
*/
type testMailer struct {
count int
}
func (t *testMailer) Send(to string, m *message) error {
t.count++
return nil
}
func TestServer_PublishTooManyEmails_Defaults(t *testing.T) {
s := newTestServer(t, newTestConfig(t))
s.mailer = &testMailer{}
for i := 0; i < 16; i++ {
response := request(t, s, "PUT", "/mytopic", fmt.Sprintf("message %d", i), map[string]string{
"E-Mail": "test@example.com",
})
require.Equal(t, 200, response.Code)
}
response := request(t, s, "PUT", "/mytopic", "one too many", map[string]string{
"E-Mail": "test@example.com",
})
require.Equal(t, 429, response.Code)
}
func TestServer_PublishTooManyEmails_Replenish(t *testing.T) {
c := newTestConfig(t)
c.VisitorEmailLimitReplenish = 500 * time.Millisecond
s := newTestServer(t, c)
s.mailer = &testMailer{}
for i := 0; i < 16; i++ {
response := request(t, s, "PUT", "/mytopic", fmt.Sprintf("message %d", i), map[string]string{
"E-Mail": "test@example.com",
})
require.Equal(t, 200, response.Code)
}
response := request(t, s, "PUT", "/mytopic", "one too many", map[string]string{
"E-Mail": "test@example.com",
})
require.Equal(t, 429, response.Code)
time.Sleep(510 * time.Millisecond)
response = request(t, s, "PUT", "/mytopic", "this should be okay again too many", map[string]string{
"E-Mail": "test@example.com",
})
require.Equal(t, 200, response.Code)
response = request(t, s, "PUT", "/mytopic", "and bad again", map[string]string{
"E-Mail": "test@example.com",
})
require.Equal(t, 429, response.Code)
}
func TestServer_PublishDelayedEmail_Fail(t *testing.T) {
s := newTestServer(t, newTestConfig(t))
s.mailer = &testMailer{}
response := request(t, s, "PUT", "/mytopic", "fail", map[string]string{
"E-Mail": "test@example.com",
"Delay": "20 min",
})
require.Equal(t, 400, response.Code)
}
func TestServer_PublishEmailNoMailer_Fail(t *testing.T) {
s := newTestServer(t, newTestConfig(t))
response := request(t, s, "PUT", "/mytopic", "fail", map[string]string{
"E-Mail": "test@example.com",
})
require.Equal(t, 400, response.Code)
}
func newTestConfig(t *testing.T) *Config {
conf := NewConfig()
conf.CacheFile = filepath.Join(t.TempDir(), "cache.db")

View file

@ -8,13 +8,17 @@ import (
)
const (
visitorExpungeAfter = 30 * time.Minute
// visitorExpungeAfter defines how long a visitor is active before it is removed from memory. This number
// has to be very high to prevent e-mail abuse, but it doesn't really affect the other limits anyway, since
// they are replenished faster (typically).
visitorExpungeAfter = 24 * time.Hour
)
// visitor represents an API user, and its associated rate.Limiter used for rate limiting
type visitor struct {
config *Config
limiter *rate.Limiter
requests *rate.Limiter
emails *rate.Limiter
subscriptions *util.Limiter
seen time.Time
mu sync.Mutex
@ -23,14 +27,22 @@ type visitor struct {
func newVisitor(conf *Config) *visitor {
return &visitor{
config: conf,
limiter: rate.NewLimiter(rate.Every(conf.VisitorRequestLimitReplenish), conf.VisitorRequestLimitBurst),
requests: rate.NewLimiter(rate.Every(conf.VisitorRequestLimitReplenish), conf.VisitorRequestLimitBurst),
emails: rate.NewLimiter(rate.Every(conf.VisitorEmailLimitReplenish), conf.VisitorEmailLimitBurst),
subscriptions: util.NewLimiter(int64(conf.VisitorSubscriptionLimit)),
seen: time.Now(),
}
}
func (v *visitor) RequestAllowed() error {
if !v.limiter.Allow() {
if !v.requests.Allow() {
return errHTTPTooManyRequests
}
return nil
}
func (v *visitor) EmailAllowed() error {
if !v.emails.Allow() {
return errHTTPTooManyRequests
}
return nil