diff --git a/server/errors.go b/server/errors.go index 819f972e..a00105c3 100644 --- a/server/errors.go +++ b/server/errors.go @@ -3,9 +3,8 @@ package server import ( "encoding/json" "fmt" - "net/http" - "heckel.io/ntfy/log" + "net/http" ) // errHTTP is a generic HTTP error for any non-200 HTTP error diff --git a/server/server_test.go b/server/server_test.go index 5e2a30a7..eab70018 100644 --- a/server/server_test.go +++ b/server/server_test.go @@ -1942,6 +1942,60 @@ func TestServer_SubscriberRateLimiting(t *testing.T) { require.Equal(t, 429, rr.Code) } +func TestServer_SubscriberRateLimiting_UP_Only(t *testing.T) { + c := newTestConfigWithAuthFile(t) + c.VisitorRequestLimitBurst = 3 + s := newTestServer(t, c) + + // "Register" 5 different UnifiedPush visitors + for i := 0; i < 5; i++ { + subscriberFn := func(r *http.Request) { + r.RemoteAddr = fmt.Sprintf("1.2.3.%d", i+1) + } + rr := request(t, s, "GET", fmt.Sprintf("/upsomething%d/json?poll=1", i), "", nil, subscriberFn) + require.Equal(t, 200, rr.Code) + } + + // Publish 2 messages per topic + for i := 0; i < 5; i++ { + for j := 0; j < 2; j++ { + rr := request(t, s, "PUT", fmt.Sprintf("/upsomething%d?up=1", i), "some message", nil) + require.Equal(t, 200, rr.Code) + } + } +} + +func TestServer_Matrix_SubscriberRateLimiting_UP_Only(t *testing.T) { + c := newTestConfig(t) + c.VisitorRequestLimitBurst = 3 + s := newTestServer(t, c) + + // "Register" 5 different UnifiedPush visitors + for i := 0; i < 5; i++ { + subscriberFn := func(r *http.Request) { + r.RemoteAddr = fmt.Sprintf("1.2.3.%d", i+1) + } + rr := request(t, s, "GET", fmt.Sprintf("/upsomething%d/json?poll=1", i), "", nil, subscriberFn) + require.Equal(t, 200, rr.Code) + } + + // Publish 2 messages per topic + for i := 0; i < 5; i++ { + notification := fmt.Sprintf(`{"notification":{"devices":[{"pushkey":"http://127.0.0.1:12345/upsomething%d?up=1"}]}}`, i) + for j := 0; j < 2; j++ { + response := request(t, s, "POST", "/_matrix/push/v1/notify", notification, nil) + require.Equal(t, 200, response.Code) + require.Equal(t, `{"rejected":[]}`+"\n", response.Body.String()) + } + response := request(t, s, "POST", "/_matrix/push/v1/notify", notification, nil) + require.Equal(t, 429, response.Code, notification) + // FIXME this is because we switched the order of the "limitRequests" handler + // FIXME there should be tests for the 429s on the "/" and "/_matrix.." endpoint + + require.Equal(t, fmt.Sprintf(`{"rejected":["http://127.0.0.1:12345/upsomething%d?up=1"]}`+"\n", i), response.Body.String()) + } +} + func newTestConfig(t *testing.T) *Config { conf := NewConfig() conf.BaseURL = "http://127.0.0.1:12345"