Firebase quota limit

This commit is contained in:
Philipp Heckel 2022-05-31 20:38:56 -04:00
parent 8a81c8e95b
commit 8283b6be97
9 changed files with 180 additions and 119 deletions

View file

@ -15,6 +15,7 @@ const (
DefaultMaxDelay = 3 * 24 * time.Hour DefaultMaxDelay = 3 * 24 * time.Hour
DefaultFirebaseKeepaliveInterval = 3 * time.Hour // ~control topic (Android), not too frequently to save battery DefaultFirebaseKeepaliveInterval = 3 * time.Hour // ~control topic (Android), not too frequently to save battery
DefaultFirebasePollInterval = 20 * time.Minute // ~poll topic (iOS), max. 2-3 times per hour (see docs) DefaultFirebasePollInterval = 20 * time.Minute // ~poll topic (iOS), max. 2-3 times per hour (see docs)
DefaultFirebaseQuotaLimitPenaltyDuration = 10 * time.Minute
) )
// Defines all global and per-visitor limits // Defines all global and per-visitor limits
@ -69,6 +70,7 @@ type Config struct {
AtSenderInterval time.Duration AtSenderInterval time.Duration
FirebaseKeepaliveInterval time.Duration FirebaseKeepaliveInterval time.Duration
FirebasePollInterval time.Duration FirebasePollInterval time.Duration
FirebaseQuotaLimitPenaltyDuration time.Duration
UpstreamBaseURL string UpstreamBaseURL string
SMTPSenderAddr string SMTPSenderAddr string
SMTPSenderUser string SMTPSenderUser string
@ -121,6 +123,7 @@ func NewConfig() *Config {
AtSenderInterval: DefaultAtSenderInterval, AtSenderInterval: DefaultAtSenderInterval,
FirebaseKeepaliveInterval: DefaultFirebaseKeepaliveInterval, FirebaseKeepaliveInterval: DefaultFirebaseKeepaliveInterval,
FirebasePollInterval: DefaultFirebasePollInterval, FirebasePollInterval: DefaultFirebasePollInterval,
FirebaseQuotaLimitPenaltyDuration: DefaultFirebaseQuotaLimitPenaltyDuration,
TotalTopicLimit: DefaultTotalTopicLimit, TotalTopicLimit: DefaultTotalTopicLimit,
VisitorSubscriptionLimit: DefaultVisitorSubscriptionLimit, VisitorSubscriptionLimit: DefaultVisitorSubscriptionLimit,
VisitorAttachmentTotalSizeLimit: DefaultVisitorAttachmentTotalSizeLimit, VisitorAttachmentTotalSizeLimit: DefaultVisitorAttachmentTotalSizeLimit,

View file

@ -59,6 +59,7 @@ var (
errHTTPTooManyRequestsLimitSubscriptions = &errHTTP{42903, http.StatusTooManyRequests, "limit reached: too many active subscriptions, please be nice", "https://ntfy.sh/docs/publish/#limitations"} errHTTPTooManyRequestsLimitSubscriptions = &errHTTP{42903, http.StatusTooManyRequests, "limit reached: too many active subscriptions, please be nice", "https://ntfy.sh/docs/publish/#limitations"}
errHTTPTooManyRequestsLimitTotalTopics = &errHTTP{42904, http.StatusTooManyRequests, "limit reached: the total number of topics on the server has been reached, please contact the admin", "https://ntfy.sh/docs/publish/#limitations"} errHTTPTooManyRequestsLimitTotalTopics = &errHTTP{42904, http.StatusTooManyRequests, "limit reached: the total number of topics on the server has been reached, please contact the admin", "https://ntfy.sh/docs/publish/#limitations"}
errHTTPTooManyRequestsAttachmentBandwidthLimit = &errHTTP{42905, http.StatusTooManyRequests, "too many requests: daily bandwidth limit reached", "https://ntfy.sh/docs/publish/#limitations"} errHTTPTooManyRequestsAttachmentBandwidthLimit = &errHTTP{42905, http.StatusTooManyRequests, "too many requests: daily bandwidth limit reached", "https://ntfy.sh/docs/publish/#limitations"}
errHTTPTooManyRequestsFirebaseQuotaReached = &errHTTP{42906, http.StatusTooManyRequests, "too many requests: Firebase quota for topic reached", "https://ntfy.sh/docs/publish/#limitations"}
errHTTPInternalError = &errHTTP{50001, http.StatusInternalServerError, "internal server error", ""} errHTTPInternalError = &errHTTP{50001, http.StatusInternalServerError, "internal server error", ""}
errHTTPInternalErrorInvalidFilePath = &errHTTP{50002, http.StatusInternalServerError, "internal server error: invalid file path", ""} errHTTPInternalErrorInvalidFilePath = &errHTTP{50002, http.StatusInternalServerError, "internal server error: invalid file path", ""}
) )

View file

@ -7,13 +7,11 @@ import (
"embed" "embed"
"encoding/base64" "encoding/base64"
"encoding/json" "encoding/json"
"errors"
"fmt" "fmt"
"io" "io"
"log" "log"
"net" "net"
"net/http" "net/http"
"net/http/httptest"
"net/url" "net/url"
"os" "os"
"path" "path"
@ -221,7 +219,7 @@ func (s *Server) Run() error {
} }
s.mu.Unlock() s.mu.Unlock()
go s.runManager() go s.runManager()
go s.runAtSender() go s.runDelayedSender()
go s.runFirebaseKeepaliver() go s.runFirebaseKeepaliver()
return <-errChan return <-errChan
@ -435,7 +433,7 @@ func (s *Server) handlePublish(w http.ResponseWriter, r *http.Request, v *visito
} }
delayed := m.Time > time.Now().Unix() delayed := m.Time > time.Now().Unix()
if !delayed { if !delayed {
if err := t.Publish(m); err != nil { if err := t.Publish(v, m); err != nil {
return err return err
} }
} }
@ -465,7 +463,7 @@ func (s *Server) handlePublish(w http.ResponseWriter, r *http.Request, v *visito
} }
func (s *Server) sendToFirebase(v *visitor, m *message) { func (s *Server) sendToFirebase(v *visitor, m *message) {
if err := s.firebase(m); err != nil { if err := s.firebase(v, m); err != nil {
log.Printf("[%s] FB - Unable to publish to Firebase: %v", v.ip, err.Error()) log.Printf("[%s] FB - Unable to publish to Firebase: %v", v.ip, err.Error())
} }
} }
@ -731,7 +729,7 @@ func (s *Server) handleSubscribeHTTP(w http.ResponseWriter, r *http.Request, v *
return err return err
} }
var wlock sync.Mutex var wlock sync.Mutex
sub := func(msg *message) error { sub := func(v *visitor, msg *message) error {
if !filters.Pass(msg) { if !filters.Pass(msg) {
return nil return nil
} }
@ -752,7 +750,7 @@ func (s *Server) handleSubscribeHTTP(w http.ResponseWriter, r *http.Request, v *
w.Header().Set("Access-Control-Allow-Origin", "*") // CORS, allow cross-origin requests w.Header().Set("Access-Control-Allow-Origin", "*") // CORS, allow cross-origin requests
w.Header().Set("Content-Type", contentType+"; charset=utf-8") // Android/Volley client needs charset! w.Header().Set("Content-Type", contentType+"; charset=utf-8") // Android/Volley client needs charset!
if poll { if poll {
return s.sendOldMessages(topics, since, scheduled, sub) return s.sendOldMessages(topics, since, scheduled, v, sub)
} }
subscriberIDs := make([]int, 0) subscriberIDs := make([]int, 0)
for _, t := range topics { for _, t := range topics {
@ -763,10 +761,10 @@ func (s *Server) handleSubscribeHTTP(w http.ResponseWriter, r *http.Request, v *
topics[i].Unsubscribe(subscriberID) // Order! topics[i].Unsubscribe(subscriberID) // Order!
} }
}() }()
if err := sub(newOpenMessage(topicsStr)); err != nil { // Send out open message if err := sub(v, newOpenMessage(topicsStr)); err != nil { // Send out open message
return err return err
} }
if err := s.sendOldMessages(topics, since, scheduled, sub); err != nil { if err := s.sendOldMessages(topics, since, scheduled, v, sub); err != nil {
return err return err
} }
for { for {
@ -775,7 +773,7 @@ func (s *Server) handleSubscribeHTTP(w http.ResponseWriter, r *http.Request, v *
return nil return nil
case <-time.After(s.config.KeepaliveInterval): case <-time.After(s.config.KeepaliveInterval):
v.Keepalive() v.Keepalive()
if err := sub(newKeepaliveMessage(topicsStr)); err != nil { // Send keepalive message if err := sub(v, newKeepaliveMessage(topicsStr)); err != nil { // Send keepalive message
return err return err
} }
} }
@ -849,7 +847,7 @@ func (s *Server) handleSubscribeWS(w http.ResponseWriter, r *http.Request, v *vi
} }
} }
}) })
sub := func(msg *message) error { sub := func(v *visitor, msg *message) error {
if !filters.Pass(msg) { if !filters.Pass(msg) {
return nil return nil
} }
@ -862,7 +860,7 @@ func (s *Server) handleSubscribeWS(w http.ResponseWriter, r *http.Request, v *vi
} }
w.Header().Set("Access-Control-Allow-Origin", "*") // CORS, allow cross-origin requests w.Header().Set("Access-Control-Allow-Origin", "*") // CORS, allow cross-origin requests
if poll { if poll {
return s.sendOldMessages(topics, since, scheduled, sub) return s.sendOldMessages(topics, since, scheduled, v, sub)
} }
subscriberIDs := make([]int, 0) subscriberIDs := make([]int, 0)
for _, t := range topics { for _, t := range topics {
@ -873,10 +871,10 @@ func (s *Server) handleSubscribeWS(w http.ResponseWriter, r *http.Request, v *vi
topics[i].Unsubscribe(subscriberID) // Order! topics[i].Unsubscribe(subscriberID) // Order!
} }
}() }()
if err := sub(newOpenMessage(topicsStr)); err != nil { // Send out open message if err := sub(v, newOpenMessage(topicsStr)); err != nil { // Send out open message
return err return err
} }
if err := s.sendOldMessages(topics, since, scheduled, sub); err != nil { if err := s.sendOldMessages(topics, since, scheduled, v, sub); err != nil {
return err return err
} }
err = g.Wait() err = g.Wait()
@ -900,7 +898,7 @@ func parseSubscribeParams(r *http.Request) (poll bool, since sinceMarker, schedu
return return
} }
func (s *Server) sendOldMessages(topics []*topic, since sinceMarker, scheduled bool, sub subscriber) error { func (s *Server) sendOldMessages(topics []*topic, since sinceMarker, scheduled bool, v *visitor, sub subscriber) error {
if since.IsNone() { if since.IsNone() {
return nil return nil
} }
@ -910,7 +908,7 @@ func (s *Server) sendOldMessages(topics []*topic, since sinceMarker, scheduled b
return err return err
} }
for _, m := range messages { for _, m := range messages {
if err := sub(m); err != nil { if err := sub(v, m); err != nil {
return err return err
} }
} }
@ -1057,23 +1055,7 @@ func (s *Server) updateStatsAndPrune() {
} }
func (s *Server) runSMTPServer() error { func (s *Server) runSMTPServer() error {
sub := func(m *message) error { s.smtpBackend = newMailBackend(s.config, s.handle)
url := fmt.Sprintf("%s/%s", s.config.BaseURL, m.Topic)
req, err := http.NewRequest("PUT", url, strings.NewReader(m.Message))
if err != nil {
return err
}
if m.Title != "" {
req.Header.Set("Title", m.Title)
}
rr := httptest.NewRecorder()
s.handle(rr, req)
if rr.Code != http.StatusOK {
return errors.New("error: " + rr.Body.String())
}
return nil
}
s.smtpBackend = newMailBackend(s.config, sub)
s.smtpServer = smtp.NewServer(s.smtpBackend) s.smtpServer = smtp.NewServer(s.smtpBackend)
s.smtpServer.Addr = s.config.SMTPServerListen s.smtpServer.Addr = s.config.SMTPServerListen
s.smtpServer.Domain = s.config.SMTPServerDomain s.smtpServer.Domain = s.config.SMTPServerDomain
@ -1096,7 +1078,7 @@ func (s *Server) runManager() {
} }
} }
func (s *Server) runAtSender() { func (s *Server) runDelayedSender() {
for { for {
select { select {
case <-time.After(s.config.AtSenderInterval): case <-time.After(s.config.AtSenderInterval):
@ -1113,14 +1095,15 @@ func (s *Server) runFirebaseKeepaliver() {
if s.firebase == nil { if s.firebase == nil {
return return
} }
v := newVisitor(s.config, s.messageCache, "0.0.0.0")
for { for {
select { select {
case <-time.After(s.config.FirebaseKeepaliveInterval): case <-time.After(s.config.FirebaseKeepaliveInterval):
if err := s.firebase(newKeepaliveMessage(firebaseControlTopic)); err != nil { if err := s.firebase(v, newKeepaliveMessage(firebaseControlTopic)); err != nil {
log.Printf("error sending Firebase keepalive message to %s: %s", firebaseControlTopic, err.Error()) log.Printf("error sending Firebase keepalive message to %s: %s", firebaseControlTopic, err.Error())
} }
case <-time.After(s.config.FirebasePollInterval): case <-time.After(s.config.FirebasePollInterval):
if err := s.firebase(newKeepaliveMessage(firebasePollTopic)); err != nil { if err := s.firebase(v, newKeepaliveMessage(firebasePollTopic)); err != nil {
log.Printf("error sending Firebase keepalive message to %s: %s", firebasePollTopic, err.Error()) log.Printf("error sending Firebase keepalive message to %s: %s", firebasePollTopic, err.Error())
} }
case <-s.closeChan: case <-s.closeChan:
@ -1130,28 +1113,36 @@ func (s *Server) runFirebaseKeepaliver() {
} }
func (s *Server) sendDelayedMessages() error { func (s *Server) sendDelayedMessages() error {
s.mu.Lock()
defer s.mu.Unlock()
messages, err := s.messageCache.MessagesDue() messages, err := s.messageCache.MessagesDue()
if err != nil { if err != nil {
return err return err
} }
for _, m := range messages { for _, m := range messages {
v := s.visitorFromIP("0.0.0.0") // FIXME: get message owner!!
if err := s.sendDelayedMessage(v, m); err != nil {
log.Printf("error sending delayed message: %s", err.Error())
}
}
return nil
}
func (s *Server) sendDelayedMessage(v *visitor, m *message) error {
s.mu.Lock()
defer s.mu.Unlock()
t, ok := s.topics[m.Topic] // If no subscribers, just mark message as published t, ok := s.topics[m.Topic] // If no subscribers, just mark message as published
if ok { if ok {
if err := t.Publish(m); err != nil { if err := t.Publish(v, m); err != nil {
log.Printf("unable to publish message %s to topic %s: %v", m.ID, m.Topic, err.Error()) return fmt.Errorf("unable to publish message %s to topic %s: %v", m.ID, m.Topic, err.Error())
} }
} }
if s.firebase != nil { // Firebase subscribers may not show up in topics map if s.firebase != nil { // Firebase subscribers may not show up in topics map
if err := s.firebase(m); err != nil { if err := s.firebase(v, m); err != nil {
log.Printf("unable to publish to Firebase: %v", err.Error()) return fmt.Errorf("unable to publish to Firebase: %v", err.Error())
} }
} }
if err := s.messageCache.MarkPublished(m); err != nil { if err := s.messageCache.MarkPublished(m); err != nil {
return err return err
} }
}
return nil return nil
} }
@ -1290,8 +1281,6 @@ func extractUserPass(r *http.Request) (username string, password string, ok bool
// visitor creates or retrieves a rate.Limiter for the given visitor. // visitor creates or retrieves a rate.Limiter for the given visitor.
// This function was taken from https://www.alexedwards.net/blog/how-to-rate-limit-http-requests (MIT). // This function was taken from https://www.alexedwards.net/blog/how-to-rate-limit-http-requests (MIT).
func (s *Server) visitor(r *http.Request) *visitor { func (s *Server) visitor(r *http.Request) *visitor {
s.mu.Lock()
defer s.mu.Unlock()
remoteAddr := r.RemoteAddr remoteAddr := r.RemoteAddr
ip, _, err := net.SplitHostPort(remoteAddr) ip, _, err := net.SplitHostPort(remoteAddr)
if err != nil { if err != nil {
@ -1300,6 +1289,12 @@ func (s *Server) visitor(r *http.Request) *visitor {
if s.config.BehindProxy && r.Header.Get("X-Forwarded-For") != "" { if s.config.BehindProxy && r.Header.Get("X-Forwarded-For") != "" {
ip = r.Header.Get("X-Forwarded-For") ip = r.Header.Get("X-Forwarded-For")
} }
return s.visitorFromIP(ip)
}
func (s *Server) visitorFromIP(ip string) *visitor {
s.mu.Lock()
defer s.mu.Unlock()
v, exists := s.visitors[ip] v, exists := s.visitors[ip]
if !exists { if !exists {
s.visitors[ip] = newVisitor(s.config, s.messageCache, ip) s.visitors[ip] = newVisitor(s.config, s.messageCache, ip)

View file

@ -4,6 +4,7 @@ import (
"context" "context"
"encoding/json" "encoding/json"
"fmt" "fmt"
"log"
"strings" "strings"
firebase "firebase.google.com/go/v4" firebase "firebase.google.com/go/v4"
@ -26,12 +27,20 @@ func createFirebaseSubscriber(credentialsFile string, auther auth.Auther) (subsc
if err != nil { if err != nil {
return nil, err return nil, err
} }
return func(m *message) error { return func(v *visitor, m *message) error {
if err := v.FirebaseAllowed(); err != nil {
return errHTTPTooManyRequestsFirebaseQuotaReached
}
fbm, err := toFirebaseMessage(m, auther) fbm, err := toFirebaseMessage(m, auther)
if err != nil { if err != nil {
return err return err
} }
_, err = msg.Send(context.Background(), fbm) _, err = msg.Send(context.Background(), fbm)
if err != nil && messaging.IsQuotaExceeded(err) {
log.Printf("[%s] FB quota exceeded when trying to publish to topic %s, temporarily denying FB access", v.ip, m.Topic)
v.FirebaseTemporarilyDeny()
return errHTTPTooManyRequestsFirebaseQuotaReached
}
return err return err
}, nil }, nil
} }

View file

@ -469,7 +469,8 @@ func TestServer_PublishFirebase(t *testing.T) {
require.NotEmpty(t, msg.ID) require.NotEmpty(t, msg.ID)
// Keepalive message // Keepalive message
require.Nil(t, s.firebase(newKeepaliveMessage(firebaseControlTopic))) v := newVisitor(s.config, s.messageCache, "1.2.3.4")
require.Nil(t, s.firebase(v, newKeepaliveMessage(firebaseControlTopic)))
time.Sleep(500 * time.Millisecond) // Time for sends time.Sleep(500 * time.Millisecond) // Time for sends
} }

View file

@ -3,10 +3,13 @@ package server
import ( import (
"bytes" "bytes"
"errors" "errors"
"fmt"
"github.com/emersion/go-smtp" "github.com/emersion/go-smtp"
"io" "io"
"mime" "mime"
"mime/multipart" "mime/multipart"
"net/http"
"net/http/httptest"
"net/mail" "net/mail"
"strings" "strings"
"sync" "sync"
@ -23,25 +26,25 @@ var (
// smtpBackend implements SMTP server methods. // smtpBackend implements SMTP server methods.
type smtpBackend struct { type smtpBackend struct {
config *Config config *Config
sub subscriber handler func(http.ResponseWriter, *http.Request)
success int64 success int64
failure int64 failure int64
mu sync.Mutex mu sync.Mutex
} }
func newMailBackend(conf *Config, sub subscriber) *smtpBackend { func newMailBackend(conf *Config, handler func(http.ResponseWriter, *http.Request)) *smtpBackend {
return &smtpBackend{ return &smtpBackend{
config: conf, config: conf,
sub: sub, handler: handler,
} }
} }
func (b *smtpBackend) Login(state *smtp.ConnectionState, username, password string) (smtp.Session, error) { func (b *smtpBackend) Login(state *smtp.ConnectionState, username, password string) (smtp.Session, error) {
return &smtpSession{backend: b}, nil return &smtpSession{backend: b, remoteAddr: state.RemoteAddr.String()}, nil
} }
func (b *smtpBackend) AnonymousLogin(state *smtp.ConnectionState) (smtp.Session, error) { func (b *smtpBackend) AnonymousLogin(state *smtp.ConnectionState) (smtp.Session, error) {
return &smtpSession{backend: b}, nil return &smtpSession{backend: b, remoteAddr: state.RemoteAddr.String()}, nil
} }
func (b *smtpBackend) Counts() (success int64, failure int64) { func (b *smtpBackend) Counts() (success int64, failure int64) {
@ -53,6 +56,7 @@ func (b *smtpBackend) Counts() (success int64, failure int64) {
// smtpSession is returned after EHLO. // smtpSession is returned after EHLO.
type smtpSession struct { type smtpSession struct {
backend *smtpBackend backend *smtpBackend
remoteAddr string
topic string topic string
mu sync.Mutex mu sync.Mutex
} }
@ -128,7 +132,7 @@ func (s *smtpSession) Data(r io.Reader) error {
m.Message = m.Title // Flip them, this makes more sense m.Message = m.Title // Flip them, this makes more sense
m.Title = "" m.Title = ""
} }
if err := s.backend.sub(m); err != nil { if err := s.publishMessage(m); err != nil {
return err return err
} }
s.backend.mu.Lock() s.backend.mu.Lock()
@ -138,6 +142,24 @@ func (s *smtpSession) Data(r io.Reader) error {
}) })
} }
func (s *smtpSession) publishMessage(m *message) error {
url := fmt.Sprintf("%s/%s", s.backend.config.BaseURL, m.Topic)
req, err := http.NewRequest("PUT", url, strings.NewReader(m.Message))
req.RemoteAddr = s.remoteAddr // rate limiting!!
if err != nil {
return err
}
if m.Title != "" {
req.Header.Set("Title", m.Title)
}
rr := httptest.NewRecorder()
s.backend.handler(rr, req)
if rr.Code != http.StatusOK {
return errors.New("error: " + rr.Body.String())
}
return nil
}
func (s *smtpSession) Reset() { func (s *smtpSession) Reset() {
s.mu.Lock() s.mu.Lock()
s.topic = "" s.topic = ""

View file

@ -3,6 +3,9 @@ package server
import ( import (
"github.com/emersion/go-smtp" "github.com/emersion/go-smtp"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
"io"
"net"
"net/http"
"strings" "strings"
"testing" "testing"
) )
@ -27,13 +30,12 @@ Content-Type: text/html; charset="UTF-8"
<div dir="ltr">what&#39;s up<br clear="all"><div><br></div></div> <div dir="ltr">what&#39;s up<br clear="all"><div><br></div></div>
--000000000000f3320b05d42915c9--` --000000000000f3320b05d42915c9--`
_, backend := newTestBackend(t, func(m *message) error { _, backend := newTestBackend(t, func(w http.ResponseWriter, r *http.Request) {
require.Equal(t, "mytopic", m.Topic) require.Equal(t, "/mytopic", r.URL.Path)
require.Equal(t, "and one more", m.Title) require.Equal(t, "and one more", r.Header.Get("Title"))
require.Equal(t, "what's up", m.Message) require.Equal(t, "what's up", readAll(t, r.Body))
return nil
}) })
session, _ := backend.AnonymousLogin(nil) session, _ := backend.AnonymousLogin(fakeConnState(t, "1.2.3.4"))
require.Nil(t, session.Mail("phil@example.com", smtp.MailOptions{})) require.Nil(t, session.Mail("phil@example.com", smtp.MailOptions{}))
require.Nil(t, session.Rcpt("ntfy-mytopic@ntfy.sh")) require.Nil(t, session.Rcpt("ntfy-mytopic@ntfy.sh"))
require.Nil(t, session.Data(strings.NewReader(email))) require.Nil(t, session.Data(strings.NewReader(email)))
@ -59,13 +61,12 @@ Content-Type: text/html; charset="UTF-8"
<div dir="ltr"><br></div> <div dir="ltr"><br></div>
--000000000000bcf4a405d429f8d4--` --000000000000bcf4a405d429f8d4--`
_, backend := newTestBackend(t, func(m *message) error { _, backend := newTestBackend(t, func(w http.ResponseWriter, r *http.Request) {
require.Equal(t, "emailtest", m.Topic) require.Equal(t, "/emailtest", r.URL.Path)
require.Equal(t, "", m.Title) // We flipped message and body require.Equal(t, "", r.Header.Get("Title")) // We flipped message and body
require.Equal(t, "This email has a subject but no body", m.Message) require.Equal(t, "This email has a subject but no body", readAll(t, r.Body))
return nil
}) })
session, _ := backend.AnonymousLogin(nil) session, _ := backend.AnonymousLogin(fakeConnState(t, "1.2.3.4"))
require.Nil(t, session.Mail("phil@example.com", smtp.MailOptions{})) require.Nil(t, session.Mail("phil@example.com", smtp.MailOptions{}))
require.Nil(t, session.Rcpt("ntfy-emailtest@ntfy.sh")) require.Nil(t, session.Rcpt("ntfy-emailtest@ntfy.sh"))
require.Nil(t, session.Data(strings.NewReader(email))) require.Nil(t, session.Data(strings.NewReader(email)))
@ -81,14 +82,13 @@ Content-Type: text/plain; charset="UTF-8"
what's up what's up
` `
conf, backend := newTestBackend(t, func(m *message) error { conf, backend := newTestBackend(t, func(w http.ResponseWriter, r *http.Request) {
require.Equal(t, "mytopic", m.Topic) require.Equal(t, "/mytopic", r.URL.Path)
require.Equal(t, "and one more", m.Title) require.Equal(t, "and one more", r.Header.Get("Title"))
require.Equal(t, "what's up", m.Message) require.Equal(t, "what's up", readAll(t, r.Body))
return nil
}) })
conf.SMTPServerAddrPrefix = "" conf.SMTPServerAddrPrefix = ""
session, _ := backend.AnonymousLogin(nil) session, _ := backend.AnonymousLogin(fakeConnState(t, "1.2.3.4"))
require.Nil(t, session.Mail("phil@example.com", smtp.MailOptions{})) require.Nil(t, session.Mail("phil@example.com", smtp.MailOptions{}))
require.Nil(t, session.Rcpt("mytopic@ntfy.sh")) require.Nil(t, session.Rcpt("mytopic@ntfy.sh"))
require.Nil(t, session.Data(strings.NewReader(email))) require.Nil(t, session.Data(strings.NewReader(email)))
@ -99,14 +99,13 @@ func TestSmtpBackend_Plaintext_No_ContentType(t *testing.T) {
what's up what's up
` `
conf, backend := newTestBackend(t, func(m *message) error { conf, backend := newTestBackend(t, func(w http.ResponseWriter, r *http.Request) {
require.Equal(t, "mytopic", m.Topic) require.Equal(t, "/mytopic", r.URL.Path)
require.Equal(t, "Very short mail", m.Title) require.Equal(t, "Very short mail", r.Header.Get("Title"))
require.Equal(t, "what's up", m.Message) require.Equal(t, "what's up", readAll(t, r.Body))
return nil
}) })
conf.SMTPServerAddrPrefix = "" conf.SMTPServerAddrPrefix = ""
session, _ := backend.AnonymousLogin(nil) session, _ := backend.AnonymousLogin(fakeConnState(t, "1.2.3.4"))
require.Nil(t, session.Mail("phil@example.com", smtp.MailOptions{})) require.Nil(t, session.Mail("phil@example.com", smtp.MailOptions{}))
require.Nil(t, session.Rcpt("mytopic@ntfy.sh")) require.Nil(t, session.Rcpt("mytopic@ntfy.sh"))
require.Nil(t, session.Data(strings.NewReader(email))) require.Nil(t, session.Data(strings.NewReader(email)))
@ -121,11 +120,10 @@ Content-Type: text/plain; charset="UTF-8"
what's up what's up
` `
_, backend := newTestBackend(t, func(m *message) error { _, backend := newTestBackend(t, func(w http.ResponseWriter, r *http.Request) {
require.Equal(t, "Three santas 🎅🎅🎅", m.Title) require.Equal(t, "Three santas 🎅🎅🎅", r.Header.Get("Title"))
return nil
}) })
session, _ := backend.AnonymousLogin(nil) session, _ := backend.AnonymousLogin(fakeConnState(t, "1.2.3.4"))
require.Nil(t, session.Mail("phil@example.com", smtp.MailOptions{})) require.Nil(t, session.Mail("phil@example.com", smtp.MailOptions{}))
require.Nil(t, session.Rcpt("ntfy-mytopic@ntfy.sh")) require.Nil(t, session.Rcpt("ntfy-mytopic@ntfy.sh"))
require.Nil(t, session.Data(strings.NewReader(email))) require.Nil(t, session.Data(strings.NewReader(email)))
@ -204,7 +202,7 @@ BBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBB
BBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBB BBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBB
that should do it that should do it
` `
conf, backend := newTestBackend(t, func(m *message) error { conf, backend := newTestBackend(t, func(w http.ResponseWriter, r *http.Request) {
expected := `you know this is a string. expected := `you know this is a string.
it's a long string. it's a long string.
it's supposed to be longer than the max message length it's supposed to be longer than the max message length
@ -266,13 +264,12 @@ AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA
...................................................................... ......................................................................
...................................................................... ......................................................................
and with BBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBB and with BBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBB
BBBBBBBBBBBBBBBBBBBBBBBB` BBBBBBBBBBBBBBBBBBBBBBBBB`
require.Equal(t, 4096, len(expected)) // Sanity check require.Equal(t, 4096, len(expected)) // Sanity check
require.Equal(t, expected, m.Message) require.Equal(t, expected, readAll(t, r.Body))
return nil
}) })
conf.SMTPServerAddrPrefix = "" conf.SMTPServerAddrPrefix = ""
session, _ := backend.AnonymousLogin(nil) session, _ := backend.AnonymousLogin(fakeConnState(t, "1.2.3.4"))
require.Nil(t, session.Mail("phil@example.com", smtp.MailOptions{})) require.Nil(t, session.Mail("phil@example.com", smtp.MailOptions{}))
require.Nil(t, session.Rcpt("mytopic@ntfy.sh")) require.Nil(t, session.Rcpt("mytopic@ntfy.sh"))
require.Nil(t, session.Data(strings.NewReader(email))) require.Nil(t, session.Data(strings.NewReader(email)))
@ -288,21 +285,41 @@ Content-Type: text/SOMETHINGELSE
what's up what's up
` `
conf, backend := newTestBackend(t, func(m *message) error { conf, backend := newTestBackend(t, func(http.ResponseWriter, *http.Request) {
return nil // Nothing.
}) })
conf.SMTPServerAddrPrefix = "" conf.SMTPServerAddrPrefix = ""
session, _ := backend.Login(nil, "user", "pass") session, _ := backend.Login(fakeConnState(t, "1.2.3.4"), "user", "pass")
require.Nil(t, session.Mail("phil@example.com", smtp.MailOptions{})) require.Nil(t, session.Mail("phil@example.com", smtp.MailOptions{}))
require.Nil(t, session.Rcpt("mytopic@ntfy.sh")) require.Nil(t, session.Rcpt("mytopic@ntfy.sh"))
require.Equal(t, errUnsupportedContentType, session.Data(strings.NewReader(email))) require.Equal(t, errUnsupportedContentType, session.Data(strings.NewReader(email)))
} }
func newTestBackend(t *testing.T, sub subscriber) (*Config, *smtpBackend) { func newTestBackend(t *testing.T, handler func(http.ResponseWriter, *http.Request)) (*Config, *smtpBackend) {
conf := newTestConfig(t) conf := newTestConfig(t)
conf.SMTPServerListen = ":25" conf.SMTPServerListen = ":25"
conf.SMTPServerDomain = "ntfy.sh" conf.SMTPServerDomain = "ntfy.sh"
conf.SMTPServerAddrPrefix = "ntfy-" conf.SMTPServerAddrPrefix = "ntfy-"
backend := newMailBackend(conf, sub) backend := newMailBackend(conf, handler)
return conf, backend return conf, backend
} }
func readAll(t *testing.T, rc io.ReadCloser) string {
b, err := io.ReadAll(rc)
if err != nil {
t.Fatal(err)
}
return string(b)
}
func fakeConnState(t *testing.T, remoteAddr string) *smtp.ConnectionState {
ip, err := net.ResolveIPAddr("ip", remoteAddr)
if err != nil {
t.Fatal(err)
}
return &smtp.ConnectionState{
Hostname: "myhostname",
LocalAddr: ip,
RemoteAddr: ip,
}
}

View file

@ -15,7 +15,7 @@ type topic struct {
} }
// subscriber is a function that is called for every new message on a topic // subscriber is a function that is called for every new message on a topic
type subscriber func(msg *message) error type subscriber func(v *visitor, msg *message) error
// newTopic creates a new topic // newTopic creates a new topic
func newTopic(id string) *topic { func newTopic(id string) *topic {
@ -42,12 +42,12 @@ func (t *topic) Unsubscribe(id int) {
} }
// Publish asynchronously publishes to all subscribers // Publish asynchronously publishes to all subscribers
func (t *topic) Publish(m *message) error { func (t *topic) Publish(v *visitor, m *message) error {
go func() { go func() {
t.mu.Lock() t.mu.Lock()
defer t.mu.Unlock() defer t.mu.Unlock()
for _, s := range t.subscribers { for _, s := range t.subscribers {
if err := s(m); err != nil { if err := s(v, m); err != nil {
log.Printf("error publishing message to subscriber") log.Printf("error publishing message to subscriber")
} }
} }

View file

@ -28,6 +28,7 @@ type visitor struct {
emails *rate.Limiter emails *rate.Limiter
subscriptions util.Limiter subscriptions util.Limiter
bandwidth util.Limiter bandwidth util.Limiter
firebase time.Time // Next allowed Firebase message
seen time.Time seen time.Time
mu sync.Mutex mu sync.Mutex
} }
@ -48,14 +49,11 @@ func newVisitor(conf *Config, messageCache *messageCache, ip string) *visitor {
emails: rate.NewLimiter(rate.Every(conf.VisitorEmailLimitReplenish), conf.VisitorEmailLimitBurst), emails: rate.NewLimiter(rate.Every(conf.VisitorEmailLimitReplenish), conf.VisitorEmailLimitBurst),
subscriptions: util.NewFixedLimiter(int64(conf.VisitorSubscriptionLimit)), subscriptions: util.NewFixedLimiter(int64(conf.VisitorSubscriptionLimit)),
bandwidth: util.NewBytesLimiter(conf.VisitorAttachmentDailyBandwidthLimit, 24*time.Hour), bandwidth: util.NewBytesLimiter(conf.VisitorAttachmentDailyBandwidthLimit, 24*time.Hour),
firebase: time.Unix(0, 0),
seen: time.Now(), seen: time.Now(),
} }
} }
func (v *visitor) IP() string {
return v.ip
}
func (v *visitor) RequestAllowed() error { func (v *visitor) RequestAllowed() error {
if !v.requests.Allow() { if !v.requests.Allow() {
return errVisitorLimitReached return errVisitorLimitReached
@ -63,6 +61,21 @@ func (v *visitor) RequestAllowed() error {
return nil return nil
} }
func (v *visitor) FirebaseAllowed() error {
v.mu.Lock()
defer v.mu.Unlock()
if time.Now().Before(v.firebase) {
return errVisitorLimitReached
}
return nil
}
func (v *visitor) FirebaseTemporarilyDeny() {
v.mu.Lock()
defer v.mu.Unlock()
v.firebase = time.Now().Add(v.config.FirebaseQuotaLimitPenaltyDuration)
}
func (v *visitor) EmailAllowed() error { func (v *visitor) EmailAllowed() error {
if !v.emails.Allow() { if !v.emails.Allow() {
return errVisitorLimitReached return errVisitorLimitReached