Firebase quota limit

This commit is contained in:
Philipp Heckel 2022-05-31 20:38:56 -04:00
parent 8a81c8e95b
commit 8283b6be97
9 changed files with 180 additions and 119 deletions

View file

@ -6,15 +6,16 @@ import (
// Defines default config settings (excluding limits, see below)
const (
DefaultListenHTTP = ":80"
DefaultCacheDuration = 12 * time.Hour
DefaultKeepaliveInterval = 45 * time.Second // Not too frequently to save battery (Android read timeout used to be 77s!)
DefaultManagerInterval = time.Minute
DefaultAtSenderInterval = 10 * time.Second
DefaultMinDelay = 10 * time.Second
DefaultMaxDelay = 3 * 24 * time.Hour
DefaultFirebaseKeepaliveInterval = 3 * time.Hour // ~control topic (Android), not too frequently to save battery
DefaultFirebasePollInterval = 20 * time.Minute // ~poll topic (iOS), max. 2-3 times per hour (see docs)
DefaultListenHTTP = ":80"
DefaultCacheDuration = 12 * time.Hour
DefaultKeepaliveInterval = 45 * time.Second // Not too frequently to save battery (Android read timeout used to be 77s!)
DefaultManagerInterval = time.Minute
DefaultAtSenderInterval = 10 * time.Second
DefaultMinDelay = 10 * time.Second
DefaultMaxDelay = 3 * 24 * time.Hour
DefaultFirebaseKeepaliveInterval = 3 * time.Hour // ~control topic (Android), not too frequently to save battery
DefaultFirebasePollInterval = 20 * time.Minute // ~poll topic (iOS), max. 2-3 times per hour (see docs)
DefaultFirebaseQuotaLimitPenaltyDuration = 10 * time.Minute
)
// Defines all global and per-visitor limits
@ -69,6 +70,7 @@ type Config struct {
AtSenderInterval time.Duration
FirebaseKeepaliveInterval time.Duration
FirebasePollInterval time.Duration
FirebaseQuotaLimitPenaltyDuration time.Duration
UpstreamBaseURL string
SMTPSenderAddr string
SMTPSenderUser string
@ -121,6 +123,7 @@ func NewConfig() *Config {
AtSenderInterval: DefaultAtSenderInterval,
FirebaseKeepaliveInterval: DefaultFirebaseKeepaliveInterval,
FirebasePollInterval: DefaultFirebasePollInterval,
FirebaseQuotaLimitPenaltyDuration: DefaultFirebaseQuotaLimitPenaltyDuration,
TotalTopicLimit: DefaultTotalTopicLimit,
VisitorSubscriptionLimit: DefaultVisitorSubscriptionLimit,
VisitorAttachmentTotalSizeLimit: DefaultVisitorAttachmentTotalSizeLimit,

View file

@ -59,6 +59,7 @@ var (
errHTTPTooManyRequestsLimitSubscriptions = &errHTTP{42903, http.StatusTooManyRequests, "limit reached: too many active subscriptions, please be nice", "https://ntfy.sh/docs/publish/#limitations"}
errHTTPTooManyRequestsLimitTotalTopics = &errHTTP{42904, http.StatusTooManyRequests, "limit reached: the total number of topics on the server has been reached, please contact the admin", "https://ntfy.sh/docs/publish/#limitations"}
errHTTPTooManyRequestsAttachmentBandwidthLimit = &errHTTP{42905, http.StatusTooManyRequests, "too many requests: daily bandwidth limit reached", "https://ntfy.sh/docs/publish/#limitations"}
errHTTPTooManyRequestsFirebaseQuotaReached = &errHTTP{42906, http.StatusTooManyRequests, "too many requests: Firebase quota for topic reached", "https://ntfy.sh/docs/publish/#limitations"}
errHTTPInternalError = &errHTTP{50001, http.StatusInternalServerError, "internal server error", ""}
errHTTPInternalErrorInvalidFilePath = &errHTTP{50002, http.StatusInternalServerError, "internal server error: invalid file path", ""}
)

View file

@ -7,13 +7,11 @@ import (
"embed"
"encoding/base64"
"encoding/json"
"errors"
"fmt"
"io"
"log"
"net"
"net/http"
"net/http/httptest"
"net/url"
"os"
"path"
@ -221,7 +219,7 @@ func (s *Server) Run() error {
}
s.mu.Unlock()
go s.runManager()
go s.runAtSender()
go s.runDelayedSender()
go s.runFirebaseKeepaliver()
return <-errChan
@ -435,7 +433,7 @@ func (s *Server) handlePublish(w http.ResponseWriter, r *http.Request, v *visito
}
delayed := m.Time > time.Now().Unix()
if !delayed {
if err := t.Publish(m); err != nil {
if err := t.Publish(v, m); err != nil {
return err
}
}
@ -465,7 +463,7 @@ func (s *Server) handlePublish(w http.ResponseWriter, r *http.Request, v *visito
}
func (s *Server) sendToFirebase(v *visitor, m *message) {
if err := s.firebase(m); err != nil {
if err := s.firebase(v, m); err != nil {
log.Printf("[%s] FB - Unable to publish to Firebase: %v", v.ip, err.Error())
}
}
@ -731,7 +729,7 @@ func (s *Server) handleSubscribeHTTP(w http.ResponseWriter, r *http.Request, v *
return err
}
var wlock sync.Mutex
sub := func(msg *message) error {
sub := func(v *visitor, msg *message) error {
if !filters.Pass(msg) {
return nil
}
@ -752,7 +750,7 @@ func (s *Server) handleSubscribeHTTP(w http.ResponseWriter, r *http.Request, v *
w.Header().Set("Access-Control-Allow-Origin", "*") // CORS, allow cross-origin requests
w.Header().Set("Content-Type", contentType+"; charset=utf-8") // Android/Volley client needs charset!
if poll {
return s.sendOldMessages(topics, since, scheduled, sub)
return s.sendOldMessages(topics, since, scheduled, v, sub)
}
subscriberIDs := make([]int, 0)
for _, t := range topics {
@ -763,10 +761,10 @@ func (s *Server) handleSubscribeHTTP(w http.ResponseWriter, r *http.Request, v *
topics[i].Unsubscribe(subscriberID) // Order!
}
}()
if err := sub(newOpenMessage(topicsStr)); err != nil { // Send out open message
if err := sub(v, newOpenMessage(topicsStr)); err != nil { // Send out open message
return err
}
if err := s.sendOldMessages(topics, since, scheduled, sub); err != nil {
if err := s.sendOldMessages(topics, since, scheduled, v, sub); err != nil {
return err
}
for {
@ -775,7 +773,7 @@ func (s *Server) handleSubscribeHTTP(w http.ResponseWriter, r *http.Request, v *
return nil
case <-time.After(s.config.KeepaliveInterval):
v.Keepalive()
if err := sub(newKeepaliveMessage(topicsStr)); err != nil { // Send keepalive message
if err := sub(v, newKeepaliveMessage(topicsStr)); err != nil { // Send keepalive message
return err
}
}
@ -849,7 +847,7 @@ func (s *Server) handleSubscribeWS(w http.ResponseWriter, r *http.Request, v *vi
}
}
})
sub := func(msg *message) error {
sub := func(v *visitor, msg *message) error {
if !filters.Pass(msg) {
return nil
}
@ -862,7 +860,7 @@ func (s *Server) handleSubscribeWS(w http.ResponseWriter, r *http.Request, v *vi
}
w.Header().Set("Access-Control-Allow-Origin", "*") // CORS, allow cross-origin requests
if poll {
return s.sendOldMessages(topics, since, scheduled, sub)
return s.sendOldMessages(topics, since, scheduled, v, sub)
}
subscriberIDs := make([]int, 0)
for _, t := range topics {
@ -873,10 +871,10 @@ func (s *Server) handleSubscribeWS(w http.ResponseWriter, r *http.Request, v *vi
topics[i].Unsubscribe(subscriberID) // Order!
}
}()
if err := sub(newOpenMessage(topicsStr)); err != nil { // Send out open message
if err := sub(v, newOpenMessage(topicsStr)); err != nil { // Send out open message
return err
}
if err := s.sendOldMessages(topics, since, scheduled, sub); err != nil {
if err := s.sendOldMessages(topics, since, scheduled, v, sub); err != nil {
return err
}
err = g.Wait()
@ -900,7 +898,7 @@ func parseSubscribeParams(r *http.Request) (poll bool, since sinceMarker, schedu
return
}
func (s *Server) sendOldMessages(topics []*topic, since sinceMarker, scheduled bool, sub subscriber) error {
func (s *Server) sendOldMessages(topics []*topic, since sinceMarker, scheduled bool, v *visitor, sub subscriber) error {
if since.IsNone() {
return nil
}
@ -910,7 +908,7 @@ func (s *Server) sendOldMessages(topics []*topic, since sinceMarker, scheduled b
return err
}
for _, m := range messages {
if err := sub(m); err != nil {
if err := sub(v, m); err != nil {
return err
}
}
@ -1057,23 +1055,7 @@ func (s *Server) updateStatsAndPrune() {
}
func (s *Server) runSMTPServer() error {
sub := func(m *message) error {
url := fmt.Sprintf("%s/%s", s.config.BaseURL, m.Topic)
req, err := http.NewRequest("PUT", url, strings.NewReader(m.Message))
if err != nil {
return err
}
if m.Title != "" {
req.Header.Set("Title", m.Title)
}
rr := httptest.NewRecorder()
s.handle(rr, req)
if rr.Code != http.StatusOK {
return errors.New("error: " + rr.Body.String())
}
return nil
}
s.smtpBackend = newMailBackend(s.config, sub)
s.smtpBackend = newMailBackend(s.config, s.handle)
s.smtpServer = smtp.NewServer(s.smtpBackend)
s.smtpServer.Addr = s.config.SMTPServerListen
s.smtpServer.Domain = s.config.SMTPServerDomain
@ -1096,7 +1078,7 @@ func (s *Server) runManager() {
}
}
func (s *Server) runAtSender() {
func (s *Server) runDelayedSender() {
for {
select {
case <-time.After(s.config.AtSenderInterval):
@ -1113,14 +1095,15 @@ func (s *Server) runFirebaseKeepaliver() {
if s.firebase == nil {
return
}
v := newVisitor(s.config, s.messageCache, "0.0.0.0")
for {
select {
case <-time.After(s.config.FirebaseKeepaliveInterval):
if err := s.firebase(newKeepaliveMessage(firebaseControlTopic)); err != nil {
if err := s.firebase(v, newKeepaliveMessage(firebaseControlTopic)); err != nil {
log.Printf("error sending Firebase keepalive message to %s: %s", firebaseControlTopic, err.Error())
}
case <-time.After(s.config.FirebasePollInterval):
if err := s.firebase(newKeepaliveMessage(firebasePollTopic)); err != nil {
if err := s.firebase(v, newKeepaliveMessage(firebasePollTopic)); err != nil {
log.Printf("error sending Firebase keepalive message to %s: %s", firebasePollTopic, err.Error())
}
case <-s.closeChan:
@ -1130,28 +1113,36 @@ func (s *Server) runFirebaseKeepaliver() {
}
func (s *Server) sendDelayedMessages() error {
s.mu.Lock()
defer s.mu.Unlock()
messages, err := s.messageCache.MessagesDue()
if err != nil {
return err
}
for _, m := range messages {
t, ok := s.topics[m.Topic] // If no subscribers, just mark message as published
if ok {
if err := t.Publish(m); err != nil {
log.Printf("unable to publish message %s to topic %s: %v", m.ID, m.Topic, err.Error())
}
v := s.visitorFromIP("0.0.0.0") // FIXME: get message owner!!
if err := s.sendDelayedMessage(v, m); err != nil {
log.Printf("error sending delayed message: %s", err.Error())
}
if s.firebase != nil { // Firebase subscribers may not show up in topics map
if err := s.firebase(m); err != nil {
log.Printf("unable to publish to Firebase: %v", err.Error())
}
}
return nil
}
func (s *Server) sendDelayedMessage(v *visitor, m *message) error {
s.mu.Lock()
defer s.mu.Unlock()
t, ok := s.topics[m.Topic] // If no subscribers, just mark message as published
if ok {
if err := t.Publish(v, m); err != nil {
return fmt.Errorf("unable to publish message %s to topic %s: %v", m.ID, m.Topic, err.Error())
}
if err := s.messageCache.MarkPublished(m); err != nil {
return err
}
if s.firebase != nil { // Firebase subscribers may not show up in topics map
if err := s.firebase(v, m); err != nil {
return fmt.Errorf("unable to publish to Firebase: %v", err.Error())
}
}
if err := s.messageCache.MarkPublished(m); err != nil {
return err
}
return nil
}
@ -1290,8 +1281,6 @@ func extractUserPass(r *http.Request) (username string, password string, ok bool
// visitor creates or retrieves a rate.Limiter for the given visitor.
// This function was taken from https://www.alexedwards.net/blog/how-to-rate-limit-http-requests (MIT).
func (s *Server) visitor(r *http.Request) *visitor {
s.mu.Lock()
defer s.mu.Unlock()
remoteAddr := r.RemoteAddr
ip, _, err := net.SplitHostPort(remoteAddr)
if err != nil {
@ -1300,6 +1289,12 @@ func (s *Server) visitor(r *http.Request) *visitor {
if s.config.BehindProxy && r.Header.Get("X-Forwarded-For") != "" {
ip = r.Header.Get("X-Forwarded-For")
}
return s.visitorFromIP(ip)
}
func (s *Server) visitorFromIP(ip string) *visitor {
s.mu.Lock()
defer s.mu.Unlock()
v, exists := s.visitors[ip]
if !exists {
s.visitors[ip] = newVisitor(s.config, s.messageCache, ip)

View file

@ -4,6 +4,7 @@ import (
"context"
"encoding/json"
"fmt"
"log"
"strings"
firebase "firebase.google.com/go/v4"
@ -26,12 +27,20 @@ func createFirebaseSubscriber(credentialsFile string, auther auth.Auther) (subsc
if err != nil {
return nil, err
}
return func(m *message) error {
return func(v *visitor, m *message) error {
if err := v.FirebaseAllowed(); err != nil {
return errHTTPTooManyRequestsFirebaseQuotaReached
}
fbm, err := toFirebaseMessage(m, auther)
if err != nil {
return err
}
_, err = msg.Send(context.Background(), fbm)
if err != nil && messaging.IsQuotaExceeded(err) {
log.Printf("[%s] FB quota exceeded when trying to publish to topic %s, temporarily denying FB access", v.ip, m.Topic)
v.FirebaseTemporarilyDeny()
return errHTTPTooManyRequestsFirebaseQuotaReached
}
return err
}, nil
}

View file

@ -469,7 +469,8 @@ func TestServer_PublishFirebase(t *testing.T) {
require.NotEmpty(t, msg.ID)
// Keepalive message
require.Nil(t, s.firebase(newKeepaliveMessage(firebaseControlTopic)))
v := newVisitor(s.config, s.messageCache, "1.2.3.4")
require.Nil(t, s.firebase(v, newKeepaliveMessage(firebaseControlTopic)))
time.Sleep(500 * time.Millisecond) // Time for sends
}

View file

@ -3,10 +3,13 @@ package server
import (
"bytes"
"errors"
"fmt"
"github.com/emersion/go-smtp"
"io"
"mime"
"mime/multipart"
"net/http"
"net/http/httptest"
"net/mail"
"strings"
"sync"
@ -23,25 +26,25 @@ var (
// smtpBackend implements SMTP server methods.
type smtpBackend struct {
config *Config
sub subscriber
handler func(http.ResponseWriter, *http.Request)
success int64
failure int64
mu sync.Mutex
}
func newMailBackend(conf *Config, sub subscriber) *smtpBackend {
func newMailBackend(conf *Config, handler func(http.ResponseWriter, *http.Request)) *smtpBackend {
return &smtpBackend{
config: conf,
sub: sub,
config: conf,
handler: handler,
}
}
func (b *smtpBackend) Login(state *smtp.ConnectionState, username, password string) (smtp.Session, error) {
return &smtpSession{backend: b}, nil
return &smtpSession{backend: b, remoteAddr: state.RemoteAddr.String()}, nil
}
func (b *smtpBackend) AnonymousLogin(state *smtp.ConnectionState) (smtp.Session, error) {
return &smtpSession{backend: b}, nil
return &smtpSession{backend: b, remoteAddr: state.RemoteAddr.String()}, nil
}
func (b *smtpBackend) Counts() (success int64, failure int64) {
@ -52,9 +55,10 @@ func (b *smtpBackend) Counts() (success int64, failure int64) {
// smtpSession is returned after EHLO.
type smtpSession struct {
backend *smtpBackend
topic string
mu sync.Mutex
backend *smtpBackend
remoteAddr string
topic string
mu sync.Mutex
}
func (s *smtpSession) AuthPlain(username, password string) error {
@ -128,7 +132,7 @@ func (s *smtpSession) Data(r io.Reader) error {
m.Message = m.Title // Flip them, this makes more sense
m.Title = ""
}
if err := s.backend.sub(m); err != nil {
if err := s.publishMessage(m); err != nil {
return err
}
s.backend.mu.Lock()
@ -138,6 +142,24 @@ func (s *smtpSession) Data(r io.Reader) error {
})
}
func (s *smtpSession) publishMessage(m *message) error {
url := fmt.Sprintf("%s/%s", s.backend.config.BaseURL, m.Topic)
req, err := http.NewRequest("PUT", url, strings.NewReader(m.Message))
req.RemoteAddr = s.remoteAddr // rate limiting!!
if err != nil {
return err
}
if m.Title != "" {
req.Header.Set("Title", m.Title)
}
rr := httptest.NewRecorder()
s.backend.handler(rr, req)
if rr.Code != http.StatusOK {
return errors.New("error: " + rr.Body.String())
}
return nil
}
func (s *smtpSession) Reset() {
s.mu.Lock()
s.topic = ""

View file

@ -3,6 +3,9 @@ package server
import (
"github.com/emersion/go-smtp"
"github.com/stretchr/testify/require"
"io"
"net"
"net/http"
"strings"
"testing"
)
@ -27,13 +30,12 @@ Content-Type: text/html; charset="UTF-8"
<div dir="ltr">what&#39;s up<br clear="all"><div><br></div></div>
--000000000000f3320b05d42915c9--`
_, backend := newTestBackend(t, func(m *message) error {
require.Equal(t, "mytopic", m.Topic)
require.Equal(t, "and one more", m.Title)
require.Equal(t, "what's up", m.Message)
return nil
_, backend := newTestBackend(t, func(w http.ResponseWriter, r *http.Request) {
require.Equal(t, "/mytopic", r.URL.Path)
require.Equal(t, "and one more", r.Header.Get("Title"))
require.Equal(t, "what's up", readAll(t, r.Body))
})
session, _ := backend.AnonymousLogin(nil)
session, _ := backend.AnonymousLogin(fakeConnState(t, "1.2.3.4"))
require.Nil(t, session.Mail("phil@example.com", smtp.MailOptions{}))
require.Nil(t, session.Rcpt("ntfy-mytopic@ntfy.sh"))
require.Nil(t, session.Data(strings.NewReader(email)))
@ -59,13 +61,12 @@ Content-Type: text/html; charset="UTF-8"
<div dir="ltr"><br></div>
--000000000000bcf4a405d429f8d4--`
_, backend := newTestBackend(t, func(m *message) error {
require.Equal(t, "emailtest", m.Topic)
require.Equal(t, "", m.Title) // We flipped message and body
require.Equal(t, "This email has a subject but no body", m.Message)
return nil
_, backend := newTestBackend(t, func(w http.ResponseWriter, r *http.Request) {
require.Equal(t, "/emailtest", r.URL.Path)
require.Equal(t, "", r.Header.Get("Title")) // We flipped message and body
require.Equal(t, "This email has a subject but no body", readAll(t, r.Body))
})
session, _ := backend.AnonymousLogin(nil)
session, _ := backend.AnonymousLogin(fakeConnState(t, "1.2.3.4"))
require.Nil(t, session.Mail("phil@example.com", smtp.MailOptions{}))
require.Nil(t, session.Rcpt("ntfy-emailtest@ntfy.sh"))
require.Nil(t, session.Data(strings.NewReader(email)))
@ -81,14 +82,13 @@ Content-Type: text/plain; charset="UTF-8"
what's up
`
conf, backend := newTestBackend(t, func(m *message) error {
require.Equal(t, "mytopic", m.Topic)
require.Equal(t, "and one more", m.Title)
require.Equal(t, "what's up", m.Message)
return nil
conf, backend := newTestBackend(t, func(w http.ResponseWriter, r *http.Request) {
require.Equal(t, "/mytopic", r.URL.Path)
require.Equal(t, "and one more", r.Header.Get("Title"))
require.Equal(t, "what's up", readAll(t, r.Body))
})
conf.SMTPServerAddrPrefix = ""
session, _ := backend.AnonymousLogin(nil)
session, _ := backend.AnonymousLogin(fakeConnState(t, "1.2.3.4"))
require.Nil(t, session.Mail("phil@example.com", smtp.MailOptions{}))
require.Nil(t, session.Rcpt("mytopic@ntfy.sh"))
require.Nil(t, session.Data(strings.NewReader(email)))
@ -99,14 +99,13 @@ func TestSmtpBackend_Plaintext_No_ContentType(t *testing.T) {
what's up
`
conf, backend := newTestBackend(t, func(m *message) error {
require.Equal(t, "mytopic", m.Topic)
require.Equal(t, "Very short mail", m.Title)
require.Equal(t, "what's up", m.Message)
return nil
conf, backend := newTestBackend(t, func(w http.ResponseWriter, r *http.Request) {
require.Equal(t, "/mytopic", r.URL.Path)
require.Equal(t, "Very short mail", r.Header.Get("Title"))
require.Equal(t, "what's up", readAll(t, r.Body))
})
conf.SMTPServerAddrPrefix = ""
session, _ := backend.AnonymousLogin(nil)
session, _ := backend.AnonymousLogin(fakeConnState(t, "1.2.3.4"))
require.Nil(t, session.Mail("phil@example.com", smtp.MailOptions{}))
require.Nil(t, session.Rcpt("mytopic@ntfy.sh"))
require.Nil(t, session.Data(strings.NewReader(email)))
@ -121,11 +120,10 @@ Content-Type: text/plain; charset="UTF-8"
what's up
`
_, backend := newTestBackend(t, func(m *message) error {
require.Equal(t, "Three santas 🎅🎅🎅", m.Title)
return nil
_, backend := newTestBackend(t, func(w http.ResponseWriter, r *http.Request) {
require.Equal(t, "Three santas 🎅🎅🎅", r.Header.Get("Title"))
})
session, _ := backend.AnonymousLogin(nil)
session, _ := backend.AnonymousLogin(fakeConnState(t, "1.2.3.4"))
require.Nil(t, session.Mail("phil@example.com", smtp.MailOptions{}))
require.Nil(t, session.Rcpt("ntfy-mytopic@ntfy.sh"))
require.Nil(t, session.Data(strings.NewReader(email)))
@ -140,7 +138,7 @@ To: mytopic@ntfy.sh
Content-Type: text/plain; charset="UTF-8"
you know this is a string.
it's a long string.
it's a long string.
it's supposed to be longer than the max message length
which is 4096 bytes,
it used to be 512 bytes, but I increased that for the UnifiedPush support
@ -204,9 +202,9 @@ BBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBB
BBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBB
that should do it
`
conf, backend := newTestBackend(t, func(m *message) error {
conf, backend := newTestBackend(t, func(w http.ResponseWriter, r *http.Request) {
expected := `you know this is a string.
it's a long string.
it's a long string.
it's supposed to be longer than the max message length
which is 4096 bytes,
it used to be 512 bytes, but I increased that for the UnifiedPush support
@ -266,13 +264,12 @@ AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA
......................................................................
......................................................................
and with BBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBB
BBBBBBBBBBBBBBBBBBBBBBBB`
BBBBBBBBBBBBBBBBBBBBBBBBB`
require.Equal(t, 4096, len(expected)) // Sanity check
require.Equal(t, expected, m.Message)
return nil
require.Equal(t, expected, readAll(t, r.Body))
})
conf.SMTPServerAddrPrefix = ""
session, _ := backend.AnonymousLogin(nil)
session, _ := backend.AnonymousLogin(fakeConnState(t, "1.2.3.4"))
require.Nil(t, session.Mail("phil@example.com", smtp.MailOptions{}))
require.Nil(t, session.Rcpt("mytopic@ntfy.sh"))
require.Nil(t, session.Data(strings.NewReader(email)))
@ -288,21 +285,41 @@ Content-Type: text/SOMETHINGELSE
what's up
`
conf, backend := newTestBackend(t, func(m *message) error {
return nil
conf, backend := newTestBackend(t, func(http.ResponseWriter, *http.Request) {
// Nothing.
})
conf.SMTPServerAddrPrefix = ""
session, _ := backend.Login(nil, "user", "pass")
session, _ := backend.Login(fakeConnState(t, "1.2.3.4"), "user", "pass")
require.Nil(t, session.Mail("phil@example.com", smtp.MailOptions{}))
require.Nil(t, session.Rcpt("mytopic@ntfy.sh"))
require.Equal(t, errUnsupportedContentType, session.Data(strings.NewReader(email)))
}
func newTestBackend(t *testing.T, sub subscriber) (*Config, *smtpBackend) {
func newTestBackend(t *testing.T, handler func(http.ResponseWriter, *http.Request)) (*Config, *smtpBackend) {
conf := newTestConfig(t)
conf.SMTPServerListen = ":25"
conf.SMTPServerDomain = "ntfy.sh"
conf.SMTPServerAddrPrefix = "ntfy-"
backend := newMailBackend(conf, sub)
backend := newMailBackend(conf, handler)
return conf, backend
}
func readAll(t *testing.T, rc io.ReadCloser) string {
b, err := io.ReadAll(rc)
if err != nil {
t.Fatal(err)
}
return string(b)
}
func fakeConnState(t *testing.T, remoteAddr string) *smtp.ConnectionState {
ip, err := net.ResolveIPAddr("ip", remoteAddr)
if err != nil {
t.Fatal(err)
}
return &smtp.ConnectionState{
Hostname: "myhostname",
LocalAddr: ip,
RemoteAddr: ip,
}
}

View file

@ -15,7 +15,7 @@ type topic struct {
}
// subscriber is a function that is called for every new message on a topic
type subscriber func(msg *message) error
type subscriber func(v *visitor, msg *message) error
// newTopic creates a new topic
func newTopic(id string) *topic {
@ -42,12 +42,12 @@ func (t *topic) Unsubscribe(id int) {
}
// Publish asynchronously publishes to all subscribers
func (t *topic) Publish(m *message) error {
func (t *topic) Publish(v *visitor, m *message) error {
go func() {
t.mu.Lock()
defer t.mu.Unlock()
for _, s := range t.subscribers {
if err := s(m); err != nil {
if err := s(v, m); err != nil {
log.Printf("error publishing message to subscriber")
}
}

View file

@ -28,6 +28,7 @@ type visitor struct {
emails *rate.Limiter
subscriptions util.Limiter
bandwidth util.Limiter
firebase time.Time // Next allowed Firebase message
seen time.Time
mu sync.Mutex
}
@ -48,14 +49,11 @@ func newVisitor(conf *Config, messageCache *messageCache, ip string) *visitor {
emails: rate.NewLimiter(rate.Every(conf.VisitorEmailLimitReplenish), conf.VisitorEmailLimitBurst),
subscriptions: util.NewFixedLimiter(int64(conf.VisitorSubscriptionLimit)),
bandwidth: util.NewBytesLimiter(conf.VisitorAttachmentDailyBandwidthLimit, 24*time.Hour),
firebase: time.Unix(0, 0),
seen: time.Now(),
}
}
func (v *visitor) IP() string {
return v.ip
}
func (v *visitor) RequestAllowed() error {
if !v.requests.Allow() {
return errVisitorLimitReached
@ -63,6 +61,21 @@ func (v *visitor) RequestAllowed() error {
return nil
}
func (v *visitor) FirebaseAllowed() error {
v.mu.Lock()
defer v.mu.Unlock()
if time.Now().Before(v.firebase) {
return errVisitorLimitReached
}
return nil
}
func (v *visitor) FirebaseTemporarilyDeny() {
v.mu.Lock()
defer v.mu.Unlock()
v.firebase = time.Now().Add(v.config.FirebaseQuotaLimitPenaltyDuration)
}
func (v *visitor) EmailAllowed() error {
if !v.emails.Allow() {
return errVisitorLimitReached