diff --git a/server/server.go b/server/server.go index 5a4b4cbf..7ef59394 100644 --- a/server/server.go +++ b/server/server.go @@ -43,12 +43,19 @@ type Server struct { // errHTTP is a generic HTTP error for any non-200 HTTP error type errHTTP struct { - Code int - Status string + Code int `json:"code,omitempty"` + HTTPCode int `json:"http"` + Message string `json:"error"` + Link string `json:"link,omitempty"` } func (e errHTTP) Error() string { - return fmt.Sprintf("http: %s", e.Status) + return e.Message +} + +func (e errHTTP) JSON() string { + b, _ := json.Marshal(&e) + return string(b) } type indexPage struct { @@ -105,9 +112,22 @@ var ( docsStaticFs embed.FS docsStaticCached = &util.CachingEmbedFS{ModTime: time.Now(), FS: docsStaticFs} - errHTTPBadRequest = &errHTTP{http.StatusBadRequest, http.StatusText(http.StatusBadRequest)} - errHTTPNotFound = &errHTTP{http.StatusNotFound, http.StatusText(http.StatusNotFound)} - errHTTPTooManyRequests = &errHTTP{http.StatusTooManyRequests, http.StatusText(http.StatusTooManyRequests)} + errHTTPNotFound = &errHTTP{40401, http.StatusNotFound, "page not found", ""} + errHTTPTooManyRequestsLimitRequests = &errHTTP{42901, http.StatusTooManyRequests, "limit reached: too many requests, please be nice", "https://ntfy.sh/docs/publish/#limitations"} + errHTTPTooManyRequestsLimitEmails = &errHTTP{42902, http.StatusTooManyRequests, "limit reached: too many emails, please be nice", "https://ntfy.sh/docs/publish/#limitations"} + errHTTPTooManyRequestsLimitSubscriptions = &errHTTP{42903, http.StatusTooManyRequests, "limit reached: too many active subscriptions, please be nice", "https://ntfy.sh/docs/publish/#limitations"} + errHTTPTooManyRequestsLimitGlobalTopics = &errHTTP{42904, http.StatusTooManyRequests, "limit reached: the total number of topics on the server has been reached, please contact the admin", "https://ntfy.sh/docs/publish/#limitations"} + errHTTPBadRequestEmailDisabled = &errHTTP{40001, http.StatusBadRequest, "e-mail notifications are not enabled", "https://ntfy.sh/docs/config/#e-mail-notifications"} + errHTTPBadRequestDelayNoCache = &errHTTP{40002, http.StatusBadRequest, "cannot disable cache for delayed message", ""} + errHTTPBadRequestDelayNoEmail = &errHTTP{40003, http.StatusBadRequest, "delayed e-mail notifications are not supported", ""} + errHTTPBadRequestDelayCannotParse = &errHTTP{40004, http.StatusBadRequest, "invalid delay parameter: unable to parse delay", "https://ntfy.sh/docs/publish/#scheduled-delivery"} + errHTTPBadRequestDelayTooSmall = &errHTTP{40005, http.StatusBadRequest, "invalid delay parameter: too small, please refer to the docs", "https://ntfy.sh/docs/publish/#scheduled-delivery"} + errHTTPBadRequestDelayTooLarge = &errHTTP{40006, http.StatusBadRequest, "invalid delay parameter: too large, please refer to the docs", "https://ntfy.sh/docs/publish/#scheduled-delivery"} + errHTTPBadRequestPriorityInvalid = &errHTTP{40007, http.StatusBadRequest, "invalid priority parameter", "https://ntfy.sh/docs/publish/#message-priority"} + errHTTPBadRequestSinceInvalid = &errHTTP{40008, http.StatusBadRequest, "invalid since parameter", "https://ntfy.sh/docs/subscribe/api/#fetch-cached-messages"} + errHTTPBadRequestTopicInvalid = &errHTTP{40009, http.StatusBadRequest, "invalid topic: path invalid", ""} + errHTTPBadRequestTopicDisallowed = &errHTTP{40010, http.StatusBadRequest, "invalid topic: topic name is disallowed", ""} + errHTTPInternalError = &errHTTP{50001, http.StatusInternalServerError, "internal server error", ""} ) const ( @@ -241,11 +261,16 @@ func (s *Server) Stop() { func (s *Server) handle(w http.ResponseWriter, r *http.Request) { if err := s.handleInternal(w, r); err != nil { - if e, ok := err.(*errHTTP); ok { - s.fail(w, r, e.Code, e) - } else { - s.fail(w, r, http.StatusInternalServerError, err) + var e *errHTTP + var ok bool + if e, ok = err.(*errHTTP); !ok { + e = errHTTPInternalError } + log.Printf("[%s] %s - %d - %s", r.RemoteAddr, r.Method, e.HTTPCode, err.Error()) + w.Header().Set("Content-Type", "application/json") + w.Header().Set("Access-Control-Allow-Origin", "*") // CORS, allow cross-origin requests + w.WriteHeader(e.HTTPCode) + io.WriteString(w, e.JSON()+"\n") } } @@ -315,7 +340,7 @@ func (s *Server) handlePublish(w http.ResponseWriter, r *http.Request, v *visito return err } m := newDefaultMessage(t.ID, strings.TrimSpace(string(b))) - cache, firebase, email, unifiedpush, err := s.parseParams(r, m) + cache, firebase, email, err := s.parseParams(r, m) if err != nil { return err } @@ -329,13 +354,13 @@ func (s *Server) handlePublish(w http.ResponseWriter, r *http.Request, v *visito if email != "" { if err := v.EmailAllowed(); err != nil { - return err + return errHTTPTooManyRequestsLimitEmails } } m.UnifiedPush = unifiedpush if s.mailer == nil && email != "" { - return errHTTPBadRequest + return errHTTPBadRequestEmailDisabled } if m.Message == "" { m.Message = emptyMessageBody @@ -376,11 +401,10 @@ func (s *Server) handlePublish(w http.ResponseWriter, r *http.Request, v *visito return nil } -func (s *Server) parseParams(r *http.Request, m *message) (cache bool, firebase bool, email string, unifiedpush bool, err error) { +func (s *Server) parseParams(r *http.Request, m *message) (cache bool, firebase bool, email string, err error) { cache = readParam(r, "x-cache", "cache") != "no" firebase = readParam(r, "x-firebase", "firebase") != "no" email = readParam(r, "x-email", "x-e-mail", "email", "e-mail", "mail", "e") - unifiedpush = readParam(r, "up", "unifiedpush") == "1" m.Title = readParam(r, "x-title", "title", "t") messageStr := readParam(r, "x-message", "message", "m") if messageStr != "" { @@ -388,7 +412,7 @@ func (s *Server) parseParams(r *http.Request, m *message) (cache bool, firebase } m.Priority, err = util.ParsePriority(readParam(r, "x-priority", "priority", "prio", "p")) if err != nil { - return false, false, "", false, errHTTPBadRequest + return false, false, "", errHTTPBadRequestPriorityInvalid } tagsStr := readParam(r, "x-tags", "tags", "tag", "ta") if tagsStr != "" { @@ -400,22 +424,22 @@ func (s *Server) parseParams(r *http.Request, m *message) (cache bool, firebase delayStr := readParam(r, "x-delay", "delay", "x-at", "at", "x-in", "in") if delayStr != "" { if !cache { - return false, false, "", false, errHTTPBadRequest + return false, false, "", errHTTPBadRequestDelayNoCache } if email != "" { - return false, false, "", false, errHTTPBadRequest // we cannot store the email address (yet) + return false, false, "", errHTTPBadRequestDelayNoEmail // we cannot store the email address (yet) } delay, err := util.ParseFutureTime(delayStr, time.Now()) if err != nil { - return false, false, "", false, errHTTPBadRequest + return false, false, "", errHTTPBadRequestDelayCannotParse } else if delay.Unix() < time.Now().Add(s.config.MinDelay).Unix() { - return false, false, "", false, errHTTPBadRequest + return false, false, "", errHTTPBadRequestDelayTooSmall } else if delay.Unix() > time.Now().Add(s.config.MaxDelay).Unix() { - return false, false, "", false, errHTTPBadRequest + return false, false, "", errHTTPBadRequestDelayTooLarge } m.Time = delay.Unix() } - return cache, firebase, email, unifiedpush, nil + return cache, firebase, email, nil } func readParam(r *http.Request, names ...string) string { @@ -470,8 +494,8 @@ func (s *Server) handleSubscribeRaw(w http.ResponseWriter, r *http.Request, v *v } func (s *Server) handleSubscribe(w http.ResponseWriter, r *http.Request, v *visitor, format string, contentType string, encoder messageEncoder) error { - if err := v.AddSubscription(); err != nil { - return errHTTPTooManyRequests + if err := v.SubscriptionAllowed(); err != nil { + return errHTTPTooManyRequestsLimitSubscriptions } defer v.RemoveSubscription() topicsStr := strings.TrimSuffix(r.URL.Path[1:], "/"+format) // Hack @@ -617,7 +641,7 @@ func parseSince(r *http.Request, poll bool) (sinceTime, error) { } else if d, err := time.ParseDuration(since); err == nil { return sinceTime(time.Now().Add(-1 * d)), nil } - return sinceNoMessages, errHTTPBadRequest + return sinceNoMessages, errHTTPBadRequestSinceInvalid } func (s *Server) handleOptions(w http.ResponseWriter, _ *http.Request) error { @@ -629,7 +653,7 @@ func (s *Server) handleOptions(w http.ResponseWriter, _ *http.Request) error { func (s *Server) topicFromPath(path string) (*topic, error) { parts := strings.Split(path, "/") if len(parts) < 2 { - return nil, errHTTPBadRequest + return nil, errHTTPBadRequestTopicInvalid } topics, err := s.topicsFromIDs(parts[1]) if err != nil { @@ -644,11 +668,11 @@ func (s *Server) topicsFromIDs(ids ...string) ([]*topic, error) { topics := make([]*topic, 0) for _, id := range ids { if util.InStringList(disallowedTopics, id) { - return nil, errHTTPBadRequest + return nil, errHTTPBadRequestTopicDisallowed } if _, ok := s.topics[id]; !ok { if len(s.topics) >= s.config.GlobalTopicLimit { - return nil, errHTTPTooManyRequests + return nil, errHTTPTooManyRequestsLimitGlobalTopics } s.topics[id] = newTopic(id) } @@ -766,7 +790,7 @@ func (s *Server) sendDelayedMessages() error { func (s *Server) withRateLimit(w http.ResponseWriter, r *http.Request, handler func(w http.ResponseWriter, r *http.Request, v *visitor) error) error { v := s.visitor(r) if err := v.RequestAllowed(); err != nil { - return err + return errHTTPTooManyRequestsLimitRequests } return handler(w, r, v) } @@ -798,9 +822,3 @@ func (s *Server) inc(counter *int64) { defer s.mu.Unlock() *counter++ } - -func (s *Server) fail(w http.ResponseWriter, r *http.Request, code int, err error) { - log.Printf("[%s] %s - %d - %s", r.RemoteAddr, r.Method, code, err.Error()) - w.WriteHeader(code) - _, _ = io.WriteString(w, fmt.Sprintf("%s\n", http.StatusText(code))) -} diff --git a/server/server_test.go b/server/server_test.go index c290cca9..cf3fc58c 100644 --- a/server/server_test.go +++ b/server/server_test.go @@ -252,6 +252,7 @@ func TestServer_PublishAtWithCacheError(t *testing.T) { "In": "30 min", }) require.Equal(t, 400, response.Code) + require.Equal(t, errHTTPBadRequestDelayNoCache, toHTTPError(t, response.Body.String())) } func TestServer_PublishAtTooShortDelay(t *testing.T) { @@ -644,6 +645,12 @@ func toMessage(t *testing.T, s string) *message { return &m } +func toHTTPError(t *testing.T, s string) *errHTTP { + var e errHTTP + require.Nil(t, json.NewDecoder(strings.NewReader(s)).Decode(&e)) + return &e +} + func firebaseServiceAccountFile(t *testing.T) string { if os.Getenv("NTFY_TEST_FIREBASE_SERVICE_ACCOUNT_FILE") != "" { return os.Getenv("NTFY_TEST_FIREBASE_SERVICE_ACCOUNT_FILE") diff --git a/server/visitor.go b/server/visitor.go index 0dfa6cef..a1bab367 100644 --- a/server/visitor.go +++ b/server/visitor.go @@ -1,6 +1,7 @@ package server import ( + "errors" "golang.org/x/time/rate" "heckel.io/ntfy/util" "sync" @@ -14,6 +15,10 @@ const ( visitorExpungeAfter = 24 * time.Hour ) +var ( + errVisitorLimitReached = errors.New("limit reached") +) + // visitor represents an API user, and its associated rate.Limiter used for rate limiting type visitor struct { config *Config @@ -42,23 +47,23 @@ func (v *visitor) IP() string { func (v *visitor) RequestAllowed() error { if !v.requests.Allow() { - return errHTTPTooManyRequests + return errVisitorLimitReached } return nil } func (v *visitor) EmailAllowed() error { if !v.emails.Allow() { - return errHTTPTooManyRequests + return errVisitorLimitReached } return nil } -func (v *visitor) AddSubscription() error { +func (v *visitor) SubscriptionAllowed() error { v.mu.Lock() defer v.mu.Unlock() if err := v.subscriptions.Add(1); err != nil { - return errHTTPTooManyRequests + return errVisitorLimitReached } return nil }