From bfc3983d06178c769ee0018dc62a953a24f40a82 Mon Sep 17 00:00:00 2001 From: binwiederhier Date: Fri, 24 Feb 2023 14:45:30 -0500 Subject: [PATCH] Only set rate visitor if allowed --- server/server.go | 72 ++++++++++++++++++++++++++++++++++--------- server/server_test.go | 68 +++++++++++++++++++++++++++++++++++++++- server/visitor.go | 1 + user/manager.go | 27 +++++++++++++++- 4 files changed, 151 insertions(+), 17 deletions(-) diff --git a/server/server.go b/server/server.go index cdd85c6b..81e8e952 100644 --- a/server/server.go +++ b/server/server.go @@ -112,7 +112,6 @@ const ( encodingBase64 = "base64" // Used mainly for binary UnifiedPush messages jsonBodyBytesLimit = 16384 unifiedPushTopicPrefix = "up" // Temporarily, we rate limit all "up*" topics based on the subscriber - rateTopicsWildcard = "*" // Allows defining all topics in the request subscriber-rate-limited topics ) // WebSocket constants @@ -977,7 +976,9 @@ func (s *Server) handleSubscribeHTTP(w http.ResponseWriter, r *http.Request, v * } return nil } - registerRateVisitors(topics, rateTopics, v) + if err := s.maybeSetRateVisitors(r, v, topics, rateTopics); err != nil { + return err + } w.Header().Set("Access-Control-Allow-Origin", s.config.AccessControlAllowOrigin) // CORS, allow cross-origin requests w.Header().Set("Content-Type", contentType+"; charset=utf-8") // Android/Volley client needs charset! if poll { @@ -1113,7 +1114,9 @@ func (s *Server) handleSubscribeWS(w http.ResponseWriter, r *http.Request, v *vi } return conn.WriteJSON(msg) } - registerRateVisitors(topics, rateTopics, v) + if err := s.maybeSetRateVisitors(r, v, topics, rateTopics); err != nil { + return err + } w.Header().Set("Access-Control-Allow-Origin", s.config.AccessControlAllowOrigin) // CORS, allow cross-origin requests if poll { return s.sendOldMessages(topics, since, scheduled, v, sub) @@ -1156,23 +1159,62 @@ func parseSubscribeParams(r *http.Request) (poll bool, since sinceMarker, schedu return } -// registerRateVisitors sets the rate visitor on a topic, indicating that all messages published to that topic -// will be rate limited against the rate visitor instead of the publishing visitor. +// maybeSetRateVisitors sets the rate visitor on a topic (v.SetRateVisitor), indicating that all messages published +// to that topic will be rate limited against the rate visitor instead of the publishing visitor. +// +// Setting the rate visitor is ony allowed if +// - auth-file is not set (everything is open by default) +// - the topic is reserved, and v.user is the owner +// - the topic is not reserved, and v.user has write access // // Note: This TEMPORARILY also registers all topics starting with "up" (= UnifiedPush). This is to ease the transition // until the Android app will send the "Rate-Topics" header. -func registerRateVisitors(topics []*topic, rateTopics []string, v *visitor) { - if len(rateTopics) == 1 && rateTopics[0] == rateTopicsWildcard { - for _, t := range topics { - t.SetRateVisitor(v) - } - } else { - for _, t := range topics { - if util.Contains(rateTopics, t.ID) || strings.HasPrefix(t.ID, unifiedPushTopicPrefix) { - t.SetRateVisitor(v) - } +func (s *Server) maybeSetRateVisitors(r *http.Request, v *visitor, topics []*topic, rateTopics []string) error { + // Make a list of topics that we'll actually set the RateVisitor on + eligibleRateTopics := make([]*topic, 0) + for _, t := range topics { + if strings.HasPrefix(t.ID, unifiedPushTopicPrefix) || util.Contains(rateTopics, t.ID) { + eligibleRateTopics = append(eligibleRateTopics, t) } } + if len(eligibleRateTopics) == 0 { + return nil + } + + // If access controls are turned off, v has access to everything, and we can set the rate visitor + if s.userManager == nil { + return s.setRateVisitors(r, v, eligibleRateTopics) + } + + // If access controls are enabled, only set rate visitor if + // - topic is reserved, and v.user is the owner + // - topic is not reserved, and v.user has write access + writableRateTopics := make([]*topic, 0) + for _, t := range topics { + ownerUserID, err := s.userManager.ReservationOwner(t.ID) + if err != nil { + return err + } + if ownerUserID == "" { + if err := s.userManager.Authorize(v.User(), t.ID, user.PermissionWrite); err == nil { + writableRateTopics = append(writableRateTopics, t) + } + } else if ownerUserID == v.MaybeUserID() { + writableRateTopics = append(writableRateTopics, t) + } + } + return s.setRateVisitors(r, v, writableRateTopics) +} + +func (s *Server) setRateVisitors(r *http.Request, v *visitor, rateTopics []*topic) error { + for _, t := range rateTopics { + logvr(v, r). + Tag(tagSubscribe). + Field("message_topic", t.ID). + Debug("Setting visitor as rate visitor for topic %s", t.ID) + t.SetRateVisitor(v) + } + return nil } // sendOldMessages selects old messages from the messageCache and calls sub for each of them. It uses since as the diff --git a/server/server_test.go b/server/server_test.go index 802e8742..24525849 100644 --- a/server/server_test.go +++ b/server/server_test.go @@ -2040,7 +2040,7 @@ func TestServer_SubscriberRateLimiting_VisitorExpiration(t *testing.T) { r.RemoteAddr = "1.2.3.4" } rr := request(t, s, "GET", "/mytopic/json?poll=1", "", map[string]string{ - "rate-topics": "*", + "rate-topics": "mytopic", }, subscriberFn) require.Equal(t, 200, rr.Code) require.Equal(t, "1.2.3.4", s.topics["mytopic"].rateVisitor.ip.String()) @@ -2065,6 +2065,72 @@ func TestServer_SubscriberRateLimiting_VisitorExpiration(t *testing.T) { require.Nil(t, s.visitors["ip:1.2.3.4"]) } +func TestServer_SubscriberRateLimiting_ProtectedTopics(t *testing.T) { + c := newTestConfigWithAuthFile(t) + c.AuthDefault = user.PermissionDenyAll + s := newTestServer(t, c) + + // Create some ACLs + require.Nil(t, s.userManager.AddTier(&user.Tier{ + Code: "test", + MessageLimit: 5, + })) + require.Nil(t, s.userManager.AddUser("ben", "ben", user.RoleUser)) + require.Nil(t, s.userManager.ChangeTier("ben", "test")) + require.Nil(t, s.userManager.AllowAccess("ben", "announcements", user.PermissionReadWrite)) + require.Nil(t, s.userManager.AllowAccess(user.Everyone, "announcements", user.PermissionRead)) + require.Nil(t, s.userManager.AllowAccess(user.Everyone, "public_topic", user.PermissionReadWrite)) + + require.Nil(t, s.userManager.AddUser("phil", "phil", user.RoleUser)) + require.Nil(t, s.userManager.ChangeTier("phil", "test")) + require.Nil(t, s.userManager.AddReservation("phil", "reserved-for-phil", user.PermissionReadWrite)) + + // Set rate visitor as user "phil" on topic + // - "reserved-for-phil": Allowed, because I am the owner + // - "public_topic": Allowed, because it has read-write permissions for everyone + // - "announcements": NOT allowed, because it has read-only permissions for everyone + rr := request(t, s, "GET", "/reserved-for-phil,public_topic,announcements/json?poll=1", "", map[string]string{ + "Authorization": util.BasicAuth("phil", "phil"), + "Rate-Topics": "reserved-for-phil,public_topic,announcements", + }) + require.Equal(t, 200, rr.Code) + require.Equal(t, "phil", s.topics["reserved-for-phil"].rateVisitor.user.Name) + require.Equal(t, "phil", s.topics["public_topic"].rateVisitor.user.Name) + require.Nil(t, s.topics["announcements"].rateVisitor) + + // Set rate visitor as user "ben" on topic + // - "reserved-for-phil": NOT allowed, because I am not the owner + // - "public_topic": Allowed, because it has read-write permissions for everyone + // - "announcements": Allowed, because I have read-write permissions + rr = request(t, s, "GET", "/reserved-for-phil,public_topic,announcements/json?poll=1", "", map[string]string{ + "Authorization": util.BasicAuth("ben", "ben"), + "Rate-Topics": "reserved-for-phil,public_topic,announcements", + }) + require.Equal(t, 200, rr.Code) + require.Equal(t, "phil", s.topics["reserved-for-phil"].rateVisitor.user.Name) + require.Equal(t, "ben", s.topics["public_topic"].rateVisitor.user.Name) + require.Equal(t, "ben", s.topics["announcements"].rateVisitor.user.Name) +} + +func TestServer_SubscriberRateLimiting_ProtectedTopics_WithDefaultReadWrite(t *testing.T) { + c := newTestConfigWithAuthFile(t) + c.AuthDefault = user.PermissionReadWrite + s := newTestServer(t, c) + + // Create some ACLs + require.Nil(t, s.userManager.AllowAccess(user.Everyone, "announcements", user.PermissionRead)) + + // Set rate visitor as ip:1.2.3.4 on topic + // - "up1234": Allowed, because no ACLs and nobody owns the topic + // - "announcements": NOT allowed, because it has read-only permissions for everyone + rr := request(t, s, "GET", "/up1234,announcements/json?poll=1", "", nil, func(r *http.Request) { + r.RemoteAddr = "1.2.3.4" + }) + require.Equal(t, 200, rr.Code) + require.Equal(t, "1.2.3.4", s.topics["up1234"].rateVisitor.ip.String()) + require.Nil(t, s.topics["announcements"].rateVisitor) +} + func newTestConfig(t *testing.T) *Config { conf := NewConfig() conf.BaseURL = "http://127.0.0.1:12345" diff --git a/server/visitor.go b/server/visitor.go index b96563df..80bac46f 100644 --- a/server/visitor.go +++ b/server/visitor.go @@ -141,6 +141,7 @@ func (v *visitor) Context() log.Context { func (v *visitor) contextNoLock() log.Context { info := v.infoLightNoLock() fields := log.Context{ + "visitor_id": visitorID(v.ip, v.user), "visitor_ip": v.ip.String(), "visitor_messages": info.Stats.Messages, "visitor_messages_limit": info.Limits.MessageLimit, diff --git a/user/manager.go b/user/manager.go index 58a8f4c7..7fe115d9 100644 --- a/user/manager.go +++ b/user/manager.go @@ -201,7 +201,14 @@ const ( selectUserReservationsCountQuery = ` SELECT COUNT(*) FROM user_access - WHERE user_id = owner_user_id AND owner_user_id = (SELECT id FROM user WHERE user = ?) + WHERE user_id = owner_user_id + AND owner_user_id = (SELECT id FROM user WHERE user = ?) + ` + selectUserReservationsOwnerQuery = ` + SELECT owner_user_id + FROM user_access + WHERE topic = ? + AND user_id = owner_user_id ` selectUserHasReservationQuery = ` SELECT COUNT(*) @@ -1025,6 +1032,24 @@ func (a *Manager) ReservationsCount(username string) (int64, error) { return count, nil } +// ReservationOwner returns user ID of the user that owns this topic, or an +// empty string if it's not owned by anyone +func (a *Manager) ReservationOwner(topic string) (string, error) { + rows, err := a.db.Query(selectUserReservationsOwnerQuery, topic) + if err != nil { + return "", err + } + defer rows.Close() + if !rows.Next() { + return "", nil + } + var ownerUserID string + if err := rows.Scan(&ownerUserID); err != nil { + return "", err + } + return ownerUserID, nil +} + // ChangePassword changes a user's password func (a *Manager) ChangePassword(username, password string) error { hash, err := bcrypt.GenerateFromPassword([]byte(password), a.bcryptCost)