Only set rate visitor if allowed
This commit is contained in:
		
							parent
							
								
									2329695a47
								
							
						
					
					
						commit
						bfc3983d06
					
				
					 4 changed files with 151 additions and 17 deletions
				
			
		|  | @ -112,7 +112,6 @@ const ( | ||||||
| 	encodingBase64           = "base64"                  // Used mainly for binary UnifiedPush messages | 	encodingBase64           = "base64"                  // Used mainly for binary UnifiedPush messages | ||||||
| 	jsonBodyBytesLimit       = 16384 | 	jsonBodyBytesLimit       = 16384 | ||||||
| 	unifiedPushTopicPrefix   = "up" // Temporarily, we rate limit all "up*" topics based on the subscriber | 	unifiedPushTopicPrefix   = "up" // Temporarily, we rate limit all "up*" topics based on the subscriber | ||||||
| 	rateTopicsWildcard       = "*"  // Allows defining all topics in the request subscriber-rate-limited topics |  | ||||||
| ) | ) | ||||||
| 
 | 
 | ||||||
| // WebSocket constants | // WebSocket constants | ||||||
|  | @ -977,7 +976,9 @@ func (s *Server) handleSubscribeHTTP(w http.ResponseWriter, r *http.Request, v * | ||||||
| 		} | 		} | ||||||
| 		return nil | 		return nil | ||||||
| 	} | 	} | ||||||
| 	registerRateVisitors(topics, rateTopics, v) | 	if err := s.maybeSetRateVisitors(r, v, topics, rateTopics); err != nil { | ||||||
|  | 		return err | ||||||
|  | 	} | ||||||
| 	w.Header().Set("Access-Control-Allow-Origin", s.config.AccessControlAllowOrigin) // CORS, allow cross-origin requests | 	w.Header().Set("Access-Control-Allow-Origin", s.config.AccessControlAllowOrigin) // CORS, allow cross-origin requests | ||||||
| 	w.Header().Set("Content-Type", contentType+"; charset=utf-8")                    // Android/Volley client needs charset! | 	w.Header().Set("Content-Type", contentType+"; charset=utf-8")                    // Android/Volley client needs charset! | ||||||
| 	if poll { | 	if poll { | ||||||
|  | @ -1113,7 +1114,9 @@ func (s *Server) handleSubscribeWS(w http.ResponseWriter, r *http.Request, v *vi | ||||||
| 		} | 		} | ||||||
| 		return conn.WriteJSON(msg) | 		return conn.WriteJSON(msg) | ||||||
| 	} | 	} | ||||||
| 	registerRateVisitors(topics, rateTopics, v) | 	if err := s.maybeSetRateVisitors(r, v, topics, rateTopics); err != nil { | ||||||
|  | 		return err | ||||||
|  | 	} | ||||||
| 	w.Header().Set("Access-Control-Allow-Origin", s.config.AccessControlAllowOrigin) // CORS, allow cross-origin requests | 	w.Header().Set("Access-Control-Allow-Origin", s.config.AccessControlAllowOrigin) // CORS, allow cross-origin requests | ||||||
| 	if poll { | 	if poll { | ||||||
| 		return s.sendOldMessages(topics, since, scheduled, v, sub) | 		return s.sendOldMessages(topics, since, scheduled, v, sub) | ||||||
|  | @ -1156,23 +1159,62 @@ func parseSubscribeParams(r *http.Request) (poll bool, since sinceMarker, schedu | ||||||
| 	return | 	return | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| // registerRateVisitors sets the rate visitor on a topic, indicating that all messages published to that topic | // maybeSetRateVisitors sets the rate visitor on a topic (v.SetRateVisitor), indicating that all messages published | ||||||
| // will be rate limited against the rate visitor instead of the publishing visitor. | // to that topic will be rate limited against the rate visitor instead of the publishing visitor. | ||||||
|  | // | ||||||
|  | // Setting the rate visitor is ony allowed if | ||||||
|  | // - auth-file is not set (everything is open by default) | ||||||
|  | // - the topic is reserved, and v.user is the owner | ||||||
|  | // - the topic is not reserved, and v.user has write access | ||||||
| // | // | ||||||
| // Note: This TEMPORARILY also registers all topics starting with "up" (= UnifiedPush). This is to ease the transition | // Note: This TEMPORARILY also registers all topics starting with "up" (= UnifiedPush). This is to ease the transition | ||||||
| // until the Android app will send the "Rate-Topics" header. | // until the Android app will send the "Rate-Topics" header. | ||||||
| func registerRateVisitors(topics []*topic, rateTopics []string, v *visitor) { | func (s *Server) maybeSetRateVisitors(r *http.Request, v *visitor, topics []*topic, rateTopics []string) error { | ||||||
| 	if len(rateTopics) == 1 && rateTopics[0] == rateTopicsWildcard { | 	// Make a list of topics that we'll actually set the RateVisitor on | ||||||
|  | 	eligibleRateTopics := make([]*topic, 0) | ||||||
| 	for _, t := range topics { | 	for _, t := range topics { | ||||||
|  | 		if strings.HasPrefix(t.ID, unifiedPushTopicPrefix) || util.Contains(rateTopics, t.ID) { | ||||||
|  | 			eligibleRateTopics = append(eligibleRateTopics, t) | ||||||
|  | 		} | ||||||
|  | 	} | ||||||
|  | 	if len(eligibleRateTopics) == 0 { | ||||||
|  | 		return nil | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	// If access controls are turned off, v has access to everything, and we can set the rate visitor | ||||||
|  | 	if s.userManager == nil { | ||||||
|  | 		return s.setRateVisitors(r, v, eligibleRateTopics) | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	// If access controls are enabled, only set rate visitor if | ||||||
|  | 	// - topic is reserved, and v.user is the owner | ||||||
|  | 	// - topic is not reserved, and v.user has write access | ||||||
|  | 	writableRateTopics := make([]*topic, 0) | ||||||
|  | 	for _, t := range topics { | ||||||
|  | 		ownerUserID, err := s.userManager.ReservationOwner(t.ID) | ||||||
|  | 		if err != nil { | ||||||
|  | 			return err | ||||||
|  | 		} | ||||||
|  | 		if ownerUserID == "" { | ||||||
|  | 			if err := s.userManager.Authorize(v.User(), t.ID, user.PermissionWrite); err == nil { | ||||||
|  | 				writableRateTopics = append(writableRateTopics, t) | ||||||
|  | 			} | ||||||
|  | 		} else if ownerUserID == v.MaybeUserID() { | ||||||
|  | 			writableRateTopics = append(writableRateTopics, t) | ||||||
|  | 		} | ||||||
|  | 	} | ||||||
|  | 	return s.setRateVisitors(r, v, writableRateTopics) | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | func (s *Server) setRateVisitors(r *http.Request, v *visitor, rateTopics []*topic) error { | ||||||
|  | 	for _, t := range rateTopics { | ||||||
|  | 		logvr(v, r). | ||||||
|  | 			Tag(tagSubscribe). | ||||||
|  | 			Field("message_topic", t.ID). | ||||||
|  | 			Debug("Setting visitor as rate visitor for topic %s", t.ID) | ||||||
| 		t.SetRateVisitor(v) | 		t.SetRateVisitor(v) | ||||||
| 	} | 	} | ||||||
| 	} else { | 	return nil | ||||||
| 		for _, t := range topics { |  | ||||||
| 			if util.Contains(rateTopics, t.ID) || strings.HasPrefix(t.ID, unifiedPushTopicPrefix) { |  | ||||||
| 				t.SetRateVisitor(v) |  | ||||||
| 			} |  | ||||||
| 		} |  | ||||||
| 	} |  | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| // sendOldMessages selects old messages from the messageCache and calls sub for each of them. It uses since as the | // sendOldMessages selects old messages from the messageCache and calls sub for each of them. It uses since as the | ||||||
|  |  | ||||||
|  | @ -2040,7 +2040,7 @@ func TestServer_SubscriberRateLimiting_VisitorExpiration(t *testing.T) { | ||||||
| 		r.RemoteAddr = "1.2.3.4" | 		r.RemoteAddr = "1.2.3.4" | ||||||
| 	} | 	} | ||||||
| 	rr := request(t, s, "GET", "/mytopic/json?poll=1", "", map[string]string{ | 	rr := request(t, s, "GET", "/mytopic/json?poll=1", "", map[string]string{ | ||||||
| 		"rate-topics": "*", | 		"rate-topics": "mytopic", | ||||||
| 	}, subscriberFn) | 	}, subscriberFn) | ||||||
| 	require.Equal(t, 200, rr.Code) | 	require.Equal(t, 200, rr.Code) | ||||||
| 	require.Equal(t, "1.2.3.4", s.topics["mytopic"].rateVisitor.ip.String()) | 	require.Equal(t, "1.2.3.4", s.topics["mytopic"].rateVisitor.ip.String()) | ||||||
|  | @ -2065,6 +2065,72 @@ func TestServer_SubscriberRateLimiting_VisitorExpiration(t *testing.T) { | ||||||
| 	require.Nil(t, s.visitors["ip:1.2.3.4"]) | 	require.Nil(t, s.visitors["ip:1.2.3.4"]) | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
|  | func TestServer_SubscriberRateLimiting_ProtectedTopics(t *testing.T) { | ||||||
|  | 	c := newTestConfigWithAuthFile(t) | ||||||
|  | 	c.AuthDefault = user.PermissionDenyAll | ||||||
|  | 	s := newTestServer(t, c) | ||||||
|  | 
 | ||||||
|  | 	// Create some ACLs | ||||||
|  | 	require.Nil(t, s.userManager.AddTier(&user.Tier{ | ||||||
|  | 		Code:         "test", | ||||||
|  | 		MessageLimit: 5, | ||||||
|  | 	})) | ||||||
|  | 	require.Nil(t, s.userManager.AddUser("ben", "ben", user.RoleUser)) | ||||||
|  | 	require.Nil(t, s.userManager.ChangeTier("ben", "test")) | ||||||
|  | 	require.Nil(t, s.userManager.AllowAccess("ben", "announcements", user.PermissionReadWrite)) | ||||||
|  | 	require.Nil(t, s.userManager.AllowAccess(user.Everyone, "announcements", user.PermissionRead)) | ||||||
|  | 	require.Nil(t, s.userManager.AllowAccess(user.Everyone, "public_topic", user.PermissionReadWrite)) | ||||||
|  | 
 | ||||||
|  | 	require.Nil(t, s.userManager.AddUser("phil", "phil", user.RoleUser)) | ||||||
|  | 	require.Nil(t, s.userManager.ChangeTier("phil", "test")) | ||||||
|  | 	require.Nil(t, s.userManager.AddReservation("phil", "reserved-for-phil", user.PermissionReadWrite)) | ||||||
|  | 
 | ||||||
|  | 	// Set rate visitor as user "phil" on topic | ||||||
|  | 	// - "reserved-for-phil": Allowed, because I am the owner | ||||||
|  | 	// - "public_topic": Allowed, because it has read-write permissions for everyone | ||||||
|  | 	// - "announcements": NOT allowed, because it has read-only permissions for everyone | ||||||
|  | 	rr := request(t, s, "GET", "/reserved-for-phil,public_topic,announcements/json?poll=1", "", map[string]string{ | ||||||
|  | 		"Authorization": util.BasicAuth("phil", "phil"), | ||||||
|  | 		"Rate-Topics":   "reserved-for-phil,public_topic,announcements", | ||||||
|  | 	}) | ||||||
|  | 	require.Equal(t, 200, rr.Code) | ||||||
|  | 	require.Equal(t, "phil", s.topics["reserved-for-phil"].rateVisitor.user.Name) | ||||||
|  | 	require.Equal(t, "phil", s.topics["public_topic"].rateVisitor.user.Name) | ||||||
|  | 	require.Nil(t, s.topics["announcements"].rateVisitor) | ||||||
|  | 
 | ||||||
|  | 	// Set rate visitor as user "ben" on topic | ||||||
|  | 	// - "reserved-for-phil": NOT allowed, because I am not the owner | ||||||
|  | 	// - "public_topic": Allowed, because it has read-write permissions for everyone | ||||||
|  | 	// - "announcements": Allowed, because I have read-write permissions | ||||||
|  | 	rr = request(t, s, "GET", "/reserved-for-phil,public_topic,announcements/json?poll=1", "", map[string]string{ | ||||||
|  | 		"Authorization": util.BasicAuth("ben", "ben"), | ||||||
|  | 		"Rate-Topics":   "reserved-for-phil,public_topic,announcements", | ||||||
|  | 	}) | ||||||
|  | 	require.Equal(t, 200, rr.Code) | ||||||
|  | 	require.Equal(t, "phil", s.topics["reserved-for-phil"].rateVisitor.user.Name) | ||||||
|  | 	require.Equal(t, "ben", s.topics["public_topic"].rateVisitor.user.Name) | ||||||
|  | 	require.Equal(t, "ben", s.topics["announcements"].rateVisitor.user.Name) | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | func TestServer_SubscriberRateLimiting_ProtectedTopics_WithDefaultReadWrite(t *testing.T) { | ||||||
|  | 	c := newTestConfigWithAuthFile(t) | ||||||
|  | 	c.AuthDefault = user.PermissionReadWrite | ||||||
|  | 	s := newTestServer(t, c) | ||||||
|  | 
 | ||||||
|  | 	// Create some ACLs | ||||||
|  | 	require.Nil(t, s.userManager.AllowAccess(user.Everyone, "announcements", user.PermissionRead)) | ||||||
|  | 
 | ||||||
|  | 	// Set rate visitor as ip:1.2.3.4 on topic | ||||||
|  | 	// - "up1234": Allowed, because no ACLs and nobody owns the topic | ||||||
|  | 	// - "announcements": NOT allowed, because it has read-only permissions for everyone | ||||||
|  | 	rr := request(t, s, "GET", "/up1234,announcements/json?poll=1", "", nil, func(r *http.Request) { | ||||||
|  | 		r.RemoteAddr = "1.2.3.4" | ||||||
|  | 	}) | ||||||
|  | 	require.Equal(t, 200, rr.Code) | ||||||
|  | 	require.Equal(t, "1.2.3.4", s.topics["up1234"].rateVisitor.ip.String()) | ||||||
|  | 	require.Nil(t, s.topics["announcements"].rateVisitor) | ||||||
|  | } | ||||||
|  | 
 | ||||||
| func newTestConfig(t *testing.T) *Config { | func newTestConfig(t *testing.T) *Config { | ||||||
| 	conf := NewConfig() | 	conf := NewConfig() | ||||||
| 	conf.BaseURL = "http://127.0.0.1:12345" | 	conf.BaseURL = "http://127.0.0.1:12345" | ||||||
|  |  | ||||||
|  | @ -141,6 +141,7 @@ func (v *visitor) Context() log.Context { | ||||||
| func (v *visitor) contextNoLock() log.Context { | func (v *visitor) contextNoLock() log.Context { | ||||||
| 	info := v.infoLightNoLock() | 	info := v.infoLightNoLock() | ||||||
| 	fields := log.Context{ | 	fields := log.Context{ | ||||||
|  | 		"visitor_id":                     visitorID(v.ip, v.user), | ||||||
| 		"visitor_ip":                     v.ip.String(), | 		"visitor_ip":                     v.ip.String(), | ||||||
| 		"visitor_messages":               info.Stats.Messages, | 		"visitor_messages":               info.Stats.Messages, | ||||||
| 		"visitor_messages_limit":         info.Limits.MessageLimit, | 		"visitor_messages_limit":         info.Limits.MessageLimit, | ||||||
|  |  | ||||||
|  | @ -201,7 +201,14 @@ const ( | ||||||
| 	selectUserReservationsCountQuery = ` | 	selectUserReservationsCountQuery = ` | ||||||
| 		SELECT COUNT(*) | 		SELECT COUNT(*) | ||||||
| 		FROM user_access | 		FROM user_access | ||||||
| 		WHERE user_id = owner_user_id AND owner_user_id = (SELECT id FROM user WHERE user = ?) | 		WHERE user_id = owner_user_id  | ||||||
|  | 		  AND owner_user_id = (SELECT id FROM user WHERE user = ?) | ||||||
|  | 	` | ||||||
|  | 	selectUserReservationsOwnerQuery = ` | ||||||
|  | 		SELECT owner_user_id | ||||||
|  | 		FROM user_access | ||||||
|  | 		WHERE topic = ? | ||||||
|  | 		  AND user_id = owner_user_id | ||||||
| 	` | 	` | ||||||
| 	selectUserHasReservationQuery = ` | 	selectUserHasReservationQuery = ` | ||||||
| 		SELECT COUNT(*) | 		SELECT COUNT(*) | ||||||
|  | @ -1025,6 +1032,24 @@ func (a *Manager) ReservationsCount(username string) (int64, error) { | ||||||
| 	return count, nil | 	return count, nil | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
|  | // ReservationOwner returns user ID of the user that owns this topic, or an | ||||||
|  | // empty string if it's not owned by anyone | ||||||
|  | func (a *Manager) ReservationOwner(topic string) (string, error) { | ||||||
|  | 	rows, err := a.db.Query(selectUserReservationsOwnerQuery, topic) | ||||||
|  | 	if err != nil { | ||||||
|  | 		return "", err | ||||||
|  | 	} | ||||||
|  | 	defer rows.Close() | ||||||
|  | 	if !rows.Next() { | ||||||
|  | 		return "", nil | ||||||
|  | 	} | ||||||
|  | 	var ownerUserID string | ||||||
|  | 	if err := rows.Scan(&ownerUserID); err != nil { | ||||||
|  | 		return "", err | ||||||
|  | 	} | ||||||
|  | 	return ownerUserID, nil | ||||||
|  | } | ||||||
|  | 
 | ||||||
| // ChangePassword changes a user's password | // ChangePassword changes a user's password | ||||||
| func (a *Manager) ChangePassword(username, password string) error { | func (a *Manager) ChangePassword(username, password string) error { | ||||||
| 	hash, err := bcrypt.GenerateFromPassword([]byte(password), a.bcryptCost) | 	hash, err := bcrypt.GenerateFromPassword([]byte(password), a.bcryptCost) | ||||||
|  |  | ||||||
		Loading…
	
	Add table
		Add a link
		
	
		Reference in a new issue