Working test
parent
29340e7e24
commit
21b27b5dbe
|
@ -571,7 +571,7 @@ func (s *Server) handleMatrixDiscovery(w http.ResponseWriter) error {
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *Server) handlePublishWithoutResponse(r *http.Request, v *visitor) (*message, error) {
|
func (s *Server) handlePublishWithoutResponse(r *http.Request, v *visitor) (*message, error) {
|
||||||
vRate, ok := r.Context().Value("vRate").(*visitor)
|
vrate, ok := r.Context().Value("vRate").(*visitor)
|
||||||
if !ok {
|
if !ok {
|
||||||
return nil, errHTTPInternalError
|
return nil, errHTTPInternalError
|
||||||
}
|
}
|
||||||
|
@ -579,8 +579,7 @@ func (s *Server) handlePublishWithoutResponse(r *http.Request, v *visitor) (*mes
|
||||||
if !ok {
|
if !ok {
|
||||||
return nil, errHTTPInternalError
|
return nil, errHTTPInternalError
|
||||||
}
|
}
|
||||||
|
if !vrate.MessageAllowed() {
|
||||||
if !vRate.MessageAllowed() {
|
|
||||||
return nil, errHTTPTooManyRequestsLimitMessages
|
return nil, errHTTPTooManyRequestsLimitMessages
|
||||||
}
|
}
|
||||||
body, err := util.Peek(r.Body, s.config.MessageLimit)
|
body, err := util.Peek(r.Body, s.config.MessageLimit)
|
||||||
|
@ -588,7 +587,7 @@ func (s *Server) handlePublishWithoutResponse(r *http.Request, v *visitor) (*mes
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
m := newDefaultMessage(t.ID, "")
|
m := newDefaultMessage(t.ID, "")
|
||||||
cache, firebase, email, unifiedpush, err := s.parsePublishParams(r, vRate, m)
|
cache, firebase, email, unifiedpush, err := s.parsePublishParams(r, vrate, m)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
@ -607,7 +606,7 @@ func (s *Server) handlePublishWithoutResponse(r *http.Request, v *visitor) (*mes
|
||||||
m.Message = emptyMessageBody
|
m.Message = emptyMessageBody
|
||||||
}
|
}
|
||||||
delayed := m.Time > time.Now().Unix()
|
delayed := m.Time > time.Now().Unix()
|
||||||
ev := logvrm(vRate, r, m).
|
ev := logvrm(vrate, r, m).
|
||||||
Tag(tagPublish).
|
Tag(tagPublish).
|
||||||
Fields(log.Context{
|
Fields(log.Context{
|
||||||
"message_delayed": delayed,
|
"message_delayed": delayed,
|
||||||
|
@ -625,7 +624,7 @@ func (s *Server) handlePublishWithoutResponse(r *http.Request, v *visitor) (*mes
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
if s.firebaseClient != nil && firebase {
|
if s.firebaseClient != nil && firebase {
|
||||||
go s.sendToFirebase(vRate, m)
|
go s.sendToFirebase(vrate, m)
|
||||||
}
|
}
|
||||||
if s.smtpSender != nil && email != "" {
|
if s.smtpSender != nil && email != "" {
|
||||||
go s.sendEmail(v, m, email)
|
go s.sendEmail(v, m, email)
|
||||||
|
@ -657,7 +656,6 @@ func (s *Server) handlePublish(w http.ResponseWriter, r *http.Request, v *visito
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
return s.writeJSON(w, m)
|
return s.writeJSON(w, m)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -766,7 +764,7 @@ func (s *Server) parsePublishParams(r *http.Request, vRate *visitor, m *message)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return false, false, "", false, errHTTPBadRequestPriorityInvalid
|
return false, false, "", false, errHTTPBadRequestPriorityInvalid
|
||||||
}
|
}
|
||||||
m.Tags = readCommaSeperatedParam(r, "x-tags", "tags", "tag", "ta")
|
m.Tags = readCommaSeparatedParam(r, "x-tags", "tags", "tag", "ta")
|
||||||
delayStr := readParam(r, "x-delay", "delay", "x-at", "at", "x-in", "in")
|
delayStr := readParam(r, "x-delay", "delay", "x-at", "at", "x-in", "in")
|
||||||
if delayStr != "" {
|
if delayStr != "" {
|
||||||
if !cache {
|
if !cache {
|
||||||
|
@ -986,6 +984,12 @@ func (s *Server) handleSubscribeHTTP(w http.ResponseWriter, r *http.Request, v *
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
for _, t := range topics {
|
||||||
|
subscriberRateLimited := util.Contains(subscriberRateTopics, t.ID) || strings.HasPrefix(t.ID, unifiedPushTopicPrefix) // temporarily do prefix as well
|
||||||
|
if subscriberRateLimited {
|
||||||
|
t.SetRateVisitor(v)
|
||||||
|
}
|
||||||
|
}
|
||||||
w.Header().Set("Access-Control-Allow-Origin", s.config.AccessControlAllowOrigin) // CORS, allow cross-origin requests
|
w.Header().Set("Access-Control-Allow-Origin", s.config.AccessControlAllowOrigin) // CORS, allow cross-origin requests
|
||||||
w.Header().Set("Content-Type", contentType+"; charset=utf-8") // Android/Volley client needs charset!
|
w.Header().Set("Content-Type", contentType+"; charset=utf-8") // Android/Volley client needs charset!
|
||||||
if poll {
|
if poll {
|
||||||
|
@ -995,8 +999,7 @@ func (s *Server) handleSubscribeHTTP(w http.ResponseWriter, r *http.Request, v *
|
||||||
defer cancel()
|
defer cancel()
|
||||||
subscriberIDs := make([]int, 0)
|
subscriberIDs := make([]int, 0)
|
||||||
for _, t := range topics {
|
for _, t := range topics {
|
||||||
subscriberRateLimited := util.Contains(subscriberRateTopics, t.ID) || strings.HasPrefix(t.ID, unifiedPushTopicPrefix) // temporarily do prefix as well
|
subscriberIDs = append(subscriberIDs, t.Subscribe(sub, v, cancel))
|
||||||
subscriberIDs = append(subscriberIDs, t.Subscribe(sub, v, cancel, subscriberRateLimited))
|
|
||||||
}
|
}
|
||||||
defer func() {
|
defer func() {
|
||||||
for i, subscriberID := range subscriberIDs {
|
for i, subscriberID := range subscriberIDs {
|
||||||
|
@ -1122,14 +1125,19 @@ func (s *Server) handleSubscribeWS(w http.ResponseWriter, r *http.Request, v *vi
|
||||||
}
|
}
|
||||||
return conn.WriteJSON(msg)
|
return conn.WriteJSON(msg)
|
||||||
}
|
}
|
||||||
|
for _, t := range topics {
|
||||||
|
subscriberRateLimited := util.Contains(subscriberRateTopics, t.ID) || strings.HasPrefix(t.ID, unifiedPushTopicPrefix) // temporarily do prefix as well
|
||||||
|
if subscriberRateLimited {
|
||||||
|
t.SetRateVisitor(v)
|
||||||
|
}
|
||||||
|
}
|
||||||
w.Header().Set("Access-Control-Allow-Origin", s.config.AccessControlAllowOrigin) // CORS, allow cross-origin requests
|
w.Header().Set("Access-Control-Allow-Origin", s.config.AccessControlAllowOrigin) // CORS, allow cross-origin requests
|
||||||
if poll {
|
if poll {
|
||||||
return s.sendOldMessages(topics, since, scheduled, v, sub)
|
return s.sendOldMessages(topics, since, scheduled, v, sub)
|
||||||
}
|
}
|
||||||
subscriberIDs := make([]int, 0)
|
subscriberIDs := make([]int, 0)
|
||||||
for _, t := range topics {
|
for _, t := range topics {
|
||||||
subscriberRateLimited := util.Contains(subscriberRateTopics, t.ID) || strings.HasPrefix(t.ID, unifiedPushTopicPrefix) // temporarily do prefix as well
|
subscriberIDs = append(subscriberIDs, t.Subscribe(sub, v, cancel))
|
||||||
subscriberIDs = append(subscriberIDs, t.Subscribe(sub, v, cancel, subscriberRateLimited))
|
|
||||||
}
|
}
|
||||||
defer func() {
|
defer func() {
|
||||||
for i, subscriberID := range subscriberIDs {
|
for i, subscriberID := range subscriberIDs {
|
||||||
|
@ -1161,7 +1169,7 @@ func parseSubscribeParams(r *http.Request) (poll bool, since sinceMarker, schedu
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
subscriberTopics = readCommaSeperatedParam(r, "subscriber-rate-limit-topics", "x-subscriber-rate-limit-topics", "srlt")
|
subscriberTopics = readCommaSeparatedParam(r, "subscriber-rate-limit-topics", "x-subscriber-rate-limit-topics", "srlt")
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -26,8 +26,8 @@ func (s *Server) limitRequestsWithTopic(next handleFunc) handleFunc {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
vrate := v
|
vrate := v
|
||||||
if topicCountsAgainst := t.Billee(); topicCountsAgainst != nil {
|
if rateVisitor := t.RateVisitor(); rateVisitor != nil {
|
||||||
vrate = topicCountsAgainst
|
vrate = rateVisitor
|
||||||
}
|
}
|
||||||
r = r.WithContext(context.WithValue(context.WithValue(r.Context(), "vRate", vrate), "topic", t))
|
r = r.WithContext(context.WithValue(context.WithValue(r.Context(), "vRate", vrate), "topic", t))
|
||||||
|
|
||||||
|
|
|
@ -1894,15 +1894,17 @@ func TestServer_SubscriberRateLimiting(t *testing.T) {
|
||||||
c.VisitorRequestLimitBurst = 3
|
c.VisitorRequestLimitBurst = 3
|
||||||
s := newTestServer(t, c)
|
s := newTestServer(t, c)
|
||||||
|
|
||||||
|
// "Register" visitor 1.2.3.4 to topic "subscriber1topic" as a rate limit visitor
|
||||||
subscriber1Fn := func(r *http.Request) {
|
subscriber1Fn := func(r *http.Request) {
|
||||||
r.RemoteAddr = "1.2.3.4"
|
r.RemoteAddr = "1.2.3.4"
|
||||||
}
|
}
|
||||||
rr := request(t, s, "GET", "/subscriber1topic/json?poll=1", "", map[string]string{
|
rr := request(t, s, "GET", "/subscriber1topic/json?poll=1", "", map[string]string{
|
||||||
"Subscriber-Rate-Limit-Topics": "mytopic1",
|
"Subscriber-Rate-Limit-Topics": "subscriber1topic",
|
||||||
}, subscriber1Fn)
|
}, subscriber1Fn)
|
||||||
require.Equal(t, 200, rr.Code)
|
require.Equal(t, 200, rr.Code)
|
||||||
require.Equal(t, "", rr.Body.String())
|
require.Equal(t, "", rr.Body.String())
|
||||||
|
|
||||||
|
// "Register" visitor 8.7.7.1 to topic "upSUB2topic" as a rate limit visitor (implicitly via topic name)
|
||||||
subscriber2Fn := func(r *http.Request) {
|
subscriber2Fn := func(r *http.Request) {
|
||||||
r.RemoteAddr = "8.7.7.1"
|
r.RemoteAddr = "8.7.7.1"
|
||||||
}
|
}
|
||||||
|
@ -1910,20 +1912,28 @@ func TestServer_SubscriberRateLimiting(t *testing.T) {
|
||||||
require.Equal(t, 200, rr.Code)
|
require.Equal(t, 200, rr.Code)
|
||||||
require.Equal(t, "", rr.Body.String())
|
require.Equal(t, "", rr.Body.String())
|
||||||
|
|
||||||
for i := 0; i < 3; i++ {
|
// Publish 2 messages to "subscriber1topic" as visitor 9.9.9.9. It'd be 3 normally, but the
|
||||||
|
// GET request before is also counted towards the request limiter.
|
||||||
|
for i := 0; i < 2; i++ {
|
||||||
rr := request(t, s, "PUT", "/subscriber1topic", "some message", nil)
|
rr := request(t, s, "PUT", "/subscriber1topic", "some message", nil)
|
||||||
require.Equal(t, 200, rr.Code)
|
require.Equal(t, 200, rr.Code)
|
||||||
}
|
}
|
||||||
rr = request(t, s, "PUT", "/subscriber1topic", "some message", nil)
|
rr = request(t, s, "PUT", "/subscriber1topic", "some message", nil)
|
||||||
require.Equal(t, 429, rr.Code)
|
require.Equal(t, 429, rr.Code)
|
||||||
|
|
||||||
for i := 0; i < 3; i++ {
|
// Publish another 2 messages to "upSUB2topic" as visitor 9.9.9.9
|
||||||
|
for i := 0; i < 2; i++ {
|
||||||
rr := request(t, s, "PUT", "/upSUB2topic", "some message", nil)
|
rr := request(t, s, "PUT", "/upSUB2topic", "some message", nil)
|
||||||
require.Equal(t, 200, rr.Code) // If we fail here, handlePublish is using the wrong visitor!
|
require.Equal(t, 200, rr.Code) // If we fail here, handlePublish is using the wrong visitor!
|
||||||
}
|
}
|
||||||
rr = request(t, s, "PUT", "/upSUB2topic", "some message", nil)
|
rr = request(t, s, "PUT", "/upSUB2topic", "some message", nil)
|
||||||
require.Equal(t, 429, rr.Code)
|
require.Equal(t, 429, rr.Code)
|
||||||
|
|
||||||
|
// Hurray! At this point, visitor 9.9.9.9 has published 4 messages, even though
|
||||||
|
// VisitorRequestLimitBurst is 3. That means it's working.
|
||||||
|
|
||||||
|
// Now let's confirm that so far we haven't used up any of visitor 9.9.9.9's request limiter
|
||||||
|
// by publishing another 3 requests from it.
|
||||||
for i := 0; i < 3; i++ {
|
for i := 0; i < 3; i++ {
|
||||||
rr := request(t, s, "PUT", "/some-other-topic", "some message", nil)
|
rr := request(t, s, "PUT", "/some-other-topic", "some message", nil)
|
||||||
require.Equal(t, 200, rr.Code)
|
require.Equal(t, 200, rr.Code)
|
||||||
|
@ -1959,18 +1969,18 @@ func newTestServer(t *testing.T, config *Config) *Server {
|
||||||
|
|
||||||
func request(t *testing.T, s *Server, method, url, body string, headers map[string]string, fn ...func(r *http.Request)) *httptest.ResponseRecorder {
|
func request(t *testing.T, s *Server, method, url, body string, headers map[string]string, fn ...func(r *http.Request)) *httptest.ResponseRecorder {
|
||||||
rr := httptest.NewRecorder()
|
rr := httptest.NewRecorder()
|
||||||
req, err := http.NewRequest(method, url, strings.NewReader(body))
|
r, err := http.NewRequest(method, url, strings.NewReader(body))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
req.RemoteAddr = "9.9.9.9" // Used for tests
|
r.RemoteAddr = "9.9.9.9" // Used for tests
|
||||||
for k, v := range headers {
|
for k, v := range headers {
|
||||||
req.Header.Set(k, v)
|
r.Header.Set(k, v)
|
||||||
}
|
}
|
||||||
for _, f := range fn {
|
for _, f := range fn {
|
||||||
f(req)
|
f(r)
|
||||||
}
|
}
|
||||||
s.handle(rr, req)
|
s.handle(rr, r)
|
||||||
return rr
|
return rr
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -19,10 +19,9 @@ type topic struct {
|
||||||
}
|
}
|
||||||
|
|
||||||
type topicSubscriber struct {
|
type topicSubscriber struct {
|
||||||
subscriber subscriber
|
subscriber subscriber
|
||||||
visitor *visitor // User ID associated with this subscription, may be empty
|
visitor *visitor // User ID associated with this subscription, may be empty
|
||||||
cancel func()
|
cancel func()
|
||||||
subscriberRateLimit bool
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// subscriber is a function that is called for every new message on a topic
|
// subscriber is a function that is called for every new message on a topic
|
||||||
|
@ -37,39 +36,40 @@ func newTopic(id string) *topic {
|
||||||
}
|
}
|
||||||
|
|
||||||
// Subscribe subscribes to this topic
|
// Subscribe subscribes to this topic
|
||||||
func (t *topic) Subscribe(s subscriber, visitor *visitor, cancel func(), subscriberRateLimit bool) int {
|
func (t *topic) Subscribe(s subscriber, visitor *visitor, cancel func()) int {
|
||||||
t.mu.Lock()
|
t.mu.Lock()
|
||||||
defer t.mu.Unlock()
|
defer t.mu.Unlock()
|
||||||
subscriberID := rand.Int()
|
subscriberID := rand.Int()
|
||||||
t.subscribers[subscriberID] = &topicSubscriber{
|
t.subscribers[subscriberID] = &topicSubscriber{
|
||||||
visitor: visitor, // May be empty
|
visitor: visitor, // May be empty
|
||||||
subscriber: s,
|
subscriber: s,
|
||||||
cancel: cancel,
|
cancel: cancel,
|
||||||
subscriberRateLimit: subscriberRateLimit,
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// if no subscriber is already handling the rate limit
|
|
||||||
if t.rateVisitor == nil && subscriberRateLimit {
|
|
||||||
t.rateVisitor = visitor
|
|
||||||
t.rateVisitorExpires = time.Time{}
|
|
||||||
}
|
|
||||||
|
|
||||||
return subscriberID
|
return subscriberID
|
||||||
}
|
}
|
||||||
|
|
||||||
func (t *topic) Stale() bool {
|
func (t *topic) Stale() bool {
|
||||||
t.mu.Lock()
|
t.mu.Lock()
|
||||||
defer t.mu.Unlock()
|
defer t.mu.Unlock()
|
||||||
// if Time is initialized (not the zero value) and the expiry time has passed
|
if t.rateVisitorExpires.Before(time.Now()) {
|
||||||
if !t.rateVisitorExpires.IsZero() && t.rateVisitorExpires.Before(time.Now()) {
|
|
||||||
t.rateVisitor = nil
|
t.rateVisitor = nil
|
||||||
}
|
}
|
||||||
return len(t.subscribers) == 0 && t.rateVisitor == nil
|
return len(t.subscribers) == 0 && t.rateVisitor == nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (t *topic) Billee() *visitor {
|
func (t *topic) SetRateVisitor(v *visitor) {
|
||||||
t.mu.RLock()
|
t.mu.Lock()
|
||||||
defer t.mu.RUnlock()
|
defer t.mu.Unlock()
|
||||||
|
t.rateVisitor = v
|
||||||
|
t.rateVisitorExpires = time.Now().Add(rateVisitorExpiryDuration)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (t *topic) RateVisitor() *visitor {
|
||||||
|
t.mu.Lock()
|
||||||
|
defer t.mu.Unlock()
|
||||||
|
if t.rateVisitorExpires.Before(time.Now()) {
|
||||||
|
t.rateVisitor = nil
|
||||||
|
}
|
||||||
return t.rateVisitor
|
return t.rateVisitor
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -77,24 +77,7 @@ func (t *topic) Billee() *visitor {
|
||||||
func (t *topic) Unsubscribe(id int) {
|
func (t *topic) Unsubscribe(id int) {
|
||||||
t.mu.Lock()
|
t.mu.Lock()
|
||||||
defer t.mu.Unlock()
|
defer t.mu.Unlock()
|
||||||
|
|
||||||
deletingSub := t.subscribers[id]
|
|
||||||
delete(t.subscribers, id)
|
delete(t.subscribers, id)
|
||||||
|
|
||||||
// look for an active subscriber (in random order) that wants to handle the rate limit
|
|
||||||
for _, v := range t.subscribers {
|
|
||||||
if v.subscriberRateLimit {
|
|
||||||
t.rateVisitor = v.visitor
|
|
||||||
t.rateVisitorExpires = time.Time{}
|
|
||||||
return
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// if no active subscriber is found, count it towards the leaving subscriber
|
|
||||||
if deletingSub.subscriberRateLimit {
|
|
||||||
t.rateVisitor = deletingSub.visitor
|
|
||||||
t.rateVisitorExpires = time.Now().Add(rateVisitorExpiryDuration)
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Publish asynchronously publishes to all subscribers
|
// Publish asynchronously publishes to all subscribers
|
||||||
|
|
|
@ -16,7 +16,7 @@ func readBoolParam(r *http.Request, defaultValue bool, names ...string) bool {
|
||||||
return value == "1" || value == "yes" || value == "true"
|
return value == "1" || value == "yes" || value == "true"
|
||||||
}
|
}
|
||||||
|
|
||||||
func readCommaSeperatedParam(r *http.Request, names ...string) (params []string) {
|
func readCommaSeparatedParam(r *http.Request, names ...string) (params []string) {
|
||||||
paramStr := readParam(r, names...)
|
paramStr := readParam(r, names...)
|
||||||
if paramStr != "" {
|
if paramStr != "" {
|
||||||
params = make([]string, 0)
|
params = make([]string, 0)
|
||||||
|
|
Loading…
Reference in New Issue