Polishing
parent
8eae44ea61
commit
2329695a47
|
@ -570,14 +570,8 @@ func (s *Server) handleMatrixDiscovery(w http.ResponseWriter) error {
|
|||
}
|
||||
|
||||
func (s *Server) handlePublishWithoutResponse(r *http.Request, v *visitor) (*message, error) {
|
||||
vrate, ok := r.Context().Value(contextRateVisitor).(*visitor)
|
||||
if !ok {
|
||||
return nil, errHTTPInternalError
|
||||
}
|
||||
t, ok := r.Context().Value(contextTopic).(*topic)
|
||||
if !ok {
|
||||
return nil, errHTTPInternalError
|
||||
}
|
||||
t := fromContext[topic](r, contextTopic)
|
||||
vrate := fromContext[visitor](r, contextRateVisitor)
|
||||
if !vrate.MessageAllowed() {
|
||||
return nil, errHTTPTooManyRequestsLimitMessages
|
||||
}
|
||||
|
@ -586,10 +580,13 @@ func (s *Server) handlePublishWithoutResponse(r *http.Request, v *visitor) (*mes
|
|||
return nil, err
|
||||
}
|
||||
m := newDefaultMessage(t.ID, "")
|
||||
cache, firebase, email, unifiedpush, err := s.parsePublishParams(r, vrate, m)
|
||||
cache, firebase, email, unifiedpush, err := s.parsePublishParams(r, m)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if email != "" && !vrate.EmailAllowed() {
|
||||
return nil, errHTTPTooManyRequestsLimitEmails
|
||||
}
|
||||
if m.PollID != "" {
|
||||
m = newPollRequestMessage(t.ID, m.PollID)
|
||||
}
|
||||
|
@ -605,13 +602,15 @@ func (s *Server) handlePublishWithoutResponse(r *http.Request, v *visitor) (*mes
|
|||
m.Message = emptyMessageBody
|
||||
}
|
||||
delayed := m.Time > time.Now().Unix()
|
||||
ev := logvrm(vrate, r, m).
|
||||
ev := logvrm(v, r, m).
|
||||
Tag(tagPublish).
|
||||
Fields(log.Context{
|
||||
"message_delayed": delayed,
|
||||
"message_firebase": firebase,
|
||||
"message_unifiedpush": unifiedpush,
|
||||
"message_email": email,
|
||||
"rate_visitor_ip": vrate.IP().String(),
|
||||
"rate_user_id": vrate.MaybeUserID(),
|
||||
})
|
||||
if ev.IsTrace() {
|
||||
ev.Field("message_body", util.MaybeMarshalJSON(m)).Trace("Received message")
|
||||
|
@ -623,7 +622,7 @@ func (s *Server) handlePublishWithoutResponse(r *http.Request, v *visitor) (*mes
|
|||
return nil, err
|
||||
}
|
||||
if s.firebaseClient != nil && firebase {
|
||||
go s.sendToFirebase(vrate, m)
|
||||
go s.sendToFirebase(v, m)
|
||||
}
|
||||
if s.smtpSender != nil && email != "" {
|
||||
go s.sendEmail(v, m, email)
|
||||
|
@ -708,7 +707,7 @@ func (s *Server) forwardPollRequest(v *visitor, m *message) {
|
|||
}
|
||||
}
|
||||
|
||||
func (s *Server) parsePublishParams(r *http.Request, vrate *visitor, m *message) (cache bool, firebase bool, email string, unifiedpush bool, err error) {
|
||||
func (s *Server) parsePublishParams(r *http.Request, m *message) (cache bool, firebase bool, email string, unifiedpush bool, err error) {
|
||||
cache = readBoolParam(r, true, "x-cache", "cache")
|
||||
firebase = readBoolParam(r, true, "x-firebase", "firebase")
|
||||
m.Title = readParam(r, "x-title", "title", "t")
|
||||
|
@ -747,11 +746,6 @@ func (s *Server) parsePublishParams(r *http.Request, vrate *visitor, m *message)
|
|||
m.Icon = icon
|
||||
}
|
||||
email = readParam(r, "x-email", "x-e-mail", "email", "e-mail", "mail", "e")
|
||||
if email != "" {
|
||||
if !vrate.EmailAllowed() {
|
||||
return false, false, "", false, errHTTPTooManyRequestsLimitEmails
|
||||
}
|
||||
}
|
||||
if s.smtpSender == nil && email != "" {
|
||||
return false, false, "", false, errHTTPBadRequestEmailDisabled
|
||||
}
|
||||
|
@ -993,7 +987,7 @@ func (s *Server) handleSubscribeHTTP(w http.ResponseWriter, r *http.Request, v *
|
|||
defer cancel()
|
||||
subscriberIDs := make([]int, 0)
|
||||
for _, t := range topics {
|
||||
subscriberIDs = append(subscriberIDs, t.Subscribe(sub, v, cancel))
|
||||
subscriberIDs = append(subscriberIDs, t.Subscribe(sub, v.MaybeUserID(), cancel))
|
||||
}
|
||||
defer func() {
|
||||
for i, subscriberID := range subscriberIDs {
|
||||
|
@ -1126,7 +1120,7 @@ func (s *Server) handleSubscribeWS(w http.ResponseWriter, r *http.Request, v *vi
|
|||
}
|
||||
subscriberIDs := make([]int, 0)
|
||||
for _, t := range topics {
|
||||
subscriberIDs = append(subscriberIDs, t.Subscribe(sub, v, cancel))
|
||||
subscriberIDs = append(subscriberIDs, t.Subscribe(sub, v.MaybeUserID(), cancel))
|
||||
}
|
||||
defer func() {
|
||||
for i, subscriberID := range subscriberIDs {
|
||||
|
|
|
@ -3,7 +3,6 @@ package server
|
|||
import (
|
||||
"heckel.io/ntfy/log"
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
|
||||
func (s *Server) execManager() {
|
||||
|
@ -38,16 +37,23 @@ func (s *Server) execManager() {
|
|||
subs := t.SubscribersCount()
|
||||
ev := log.Tag(tagManager)
|
||||
if ev.IsTrace() {
|
||||
expiryMessage := ""
|
||||
if subs == 0 {
|
||||
expiryTime := time.Until(t.expires)
|
||||
expiryMessage = ", expires in " + expiryTime.String()
|
||||
vrate := t.RateVisitor()
|
||||
if vrate != nil {
|
||||
ev.Fields(log.Context{
|
||||
"rate_visitor_ip": vrate.IP(),
|
||||
"rate_visitor_user_id": vrate.MaybeUserID(),
|
||||
})
|
||||
}
|
||||
ev.Trace("- topic %s: %d subscribers%s", t.ID, subs, expiryMessage)
|
||||
ev.
|
||||
Fields(log.Context{
|
||||
"message_topic": t.ID,
|
||||
"message_topic_subscribers": subs,
|
||||
}).
|
||||
Trace("- topic %s: %d subscribers", t.ID, subs)
|
||||
}
|
||||
msgs, exists := messageCounts[t.ID]
|
||||
if t.Stale() && (!exists || msgs == 0) {
|
||||
log.Tag(tagManager).Trace("Deleting empty topic %s", t.ID)
|
||||
log.Tag(tagManager).Field("message_topic", t.ID).Trace("Deleting empty topic %s", t.ID)
|
||||
emptyTopics++
|
||||
delete(s.topics, t.ID)
|
||||
continue
|
||||
|
|
|
@ -2030,7 +2030,40 @@ func TestServer_Matrix_SubscriberRateLimiting_UP_Only(t *testing.T) {
|
|||
}
|
||||
}
|
||||
|
||||
// FIXME add test for rate visitor expiration
|
||||
func TestServer_SubscriberRateLimiting_VisitorExpiration(t *testing.T) {
|
||||
c := newTestConfig(t)
|
||||
c.VisitorRequestLimitBurst = 3
|
||||
s := newTestServer(t, c)
|
||||
|
||||
// "Register" rate visitor
|
||||
subscriberFn := func(r *http.Request) {
|
||||
r.RemoteAddr = "1.2.3.4"
|
||||
}
|
||||
rr := request(t, s, "GET", "/mytopic/json?poll=1", "", map[string]string{
|
||||
"rate-topics": "*",
|
||||
}, subscriberFn)
|
||||
require.Equal(t, 200, rr.Code)
|
||||
require.Equal(t, "1.2.3.4", s.topics["mytopic"].rateVisitor.ip.String())
|
||||
require.Equal(t, s.visitors["ip:1.2.3.4"], s.topics["mytopic"].rateVisitor)
|
||||
|
||||
// Publish message, observe rate visitor tokens being decreased
|
||||
response := request(t, s, "POST", "/mytopic", "some message", nil)
|
||||
require.Equal(t, 200, response.Code)
|
||||
require.Equal(t, int64(0), s.visitors["ip:9.9.9.9"].messagesLimiter.Value())
|
||||
require.Equal(t, int64(1), s.topics["mytopic"].rateVisitor.messagesLimiter.Value())
|
||||
require.Equal(t, s.visitors["ip:1.2.3.4"], s.topics["mytopic"].rateVisitor)
|
||||
|
||||
// Expire visitor
|
||||
s.visitors["ip:1.2.3.4"].seen = time.Now().Add(-1 * 25 * time.Hour)
|
||||
s.pruneVisitors()
|
||||
|
||||
// Publish message again, observe that rateVisitor is not used anymore and is reset
|
||||
response = request(t, s, "POST", "/mytopic", "some message", nil)
|
||||
require.Equal(t, 200, response.Code)
|
||||
require.Equal(t, int64(1), s.visitors["ip:9.9.9.9"].messagesLimiter.Value())
|
||||
require.Nil(t, s.topics["mytopic"].rateVisitor)
|
||||
require.Nil(t, s.visitors["ip:1.2.3.4"])
|
||||
}
|
||||
|
||||
func newTestConfig(t *testing.T) *Config {
|
||||
conf := NewConfig()
|
||||
|
|
|
@ -4,11 +4,6 @@ import (
|
|||
"heckel.io/ntfy/log"
|
||||
"math/rand"
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
const (
|
||||
topicExpiryDuration = 6 * time.Hour
|
||||
)
|
||||
|
||||
// topic represents a channel to which subscribers can subscribe, and publishers
|
||||
|
@ -17,13 +12,12 @@ type topic struct {
|
|||
ID string
|
||||
subscribers map[int]*topicSubscriber
|
||||
rateVisitor *visitor
|
||||
expires time.Time
|
||||
mu sync.RWMutex
|
||||
}
|
||||
|
||||
type topicSubscriber struct {
|
||||
userID string // User ID associated with this subscription, may be empty
|
||||
subscriber subscriber
|
||||
visitor *visitor // User ID associated with this subscription, may be empty
|
||||
cancel func()
|
||||
}
|
||||
|
||||
|
@ -39,12 +33,12 @@ func newTopic(id string) *topic {
|
|||
}
|
||||
|
||||
// Subscribe subscribes to this topic
|
||||
func (t *topic) Subscribe(s subscriber, visitor *visitor, cancel func()) int {
|
||||
func (t *topic) Subscribe(s subscriber, userID string, cancel func()) int {
|
||||
t.mu.Lock()
|
||||
defer t.mu.Unlock()
|
||||
subscriberID := rand.Int()
|
||||
t.subscribers[subscriberID] = &topicSubscriber{
|
||||
visitor: visitor, // May be empty
|
||||
userID: userID, // May be empty
|
||||
subscriber: s,
|
||||
cancel: cancel,
|
||||
}
|
||||
|
@ -54,7 +48,10 @@ func (t *topic) Subscribe(s subscriber, visitor *visitor, cancel func()) int {
|
|||
func (t *topic) Stale() bool {
|
||||
t.mu.Lock()
|
||||
defer t.mu.Unlock()
|
||||
return len(t.subscribers) == 0 && t.expires.Before(time.Now())
|
||||
if t.rateVisitor != nil && !t.rateVisitor.Stale() {
|
||||
return false
|
||||
}
|
||||
return len(t.subscribers) == 0
|
||||
}
|
||||
|
||||
func (t *topic) SetRateVisitor(v *visitor) {
|
||||
|
@ -66,6 +63,9 @@ func (t *topic) SetRateVisitor(v *visitor) {
|
|||
func (t *topic) RateVisitor() *visitor {
|
||||
t.mu.Lock()
|
||||
defer t.mu.Unlock()
|
||||
if t.rateVisitor != nil && t.rateVisitor.Stale() {
|
||||
t.rateVisitor = nil
|
||||
}
|
||||
return t.rateVisitor
|
||||
}
|
||||
|
||||
|
@ -74,9 +74,6 @@ func (t *topic) Unsubscribe(id int) {
|
|||
t.mu.Lock()
|
||||
defer t.mu.Unlock()
|
||||
delete(t.subscribers, id)
|
||||
if len(t.subscribers) == 0 {
|
||||
t.expires = time.Now().Add(topicExpiryDuration)
|
||||
}
|
||||
}
|
||||
|
||||
// Publish asynchronously publishes to all subscribers
|
||||
|
@ -115,9 +112,14 @@ func (t *topic) CancelSubscribers(exceptUserID string) {
|
|||
t.mu.Lock()
|
||||
defer t.mu.Unlock()
|
||||
for _, s := range t.subscribers {
|
||||
if s.visitor.MaybeUserID() != exceptUserID {
|
||||
// TODO: Shouldn't this log the IP for anonymous visitors? It was s.userID before my change.
|
||||
log.Tag(tagSubscribe).Field("topic", t.ID).Debug("Canceling subscriber %s", s.visitor.MaybeUserID())
|
||||
if s.userID != exceptUserID {
|
||||
log.
|
||||
Tag(tagSubscribe).
|
||||
Fields(log.Context{
|
||||
"message_topic": t.ID,
|
||||
"user_id": s.userID,
|
||||
}).
|
||||
Debug("Canceling subscriber %s", s.userID)
|
||||
s.cancel()
|
||||
}
|
||||
}
|
||||
|
@ -130,7 +132,7 @@ func (t *topic) subscribersCopy() map[int]*topicSubscriber {
|
|||
subscribers := make(map[int]*topicSubscriber)
|
||||
for k, sub := range t.subscribers {
|
||||
subscribers[k] = &topicSubscriber{
|
||||
visitor: sub.visitor,
|
||||
userID: sub.userID,
|
||||
subscriber: sub.subscriber,
|
||||
cancel: sub.cancel,
|
||||
}
|
||||
|
|
|
@ -2,6 +2,7 @@ package server
|
|||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"heckel.io/ntfy/util"
|
||||
"io"
|
||||
"net/http"
|
||||
|
@ -105,3 +106,11 @@ func withContext(r *http.Request, ctx map[contextKey]any) *http.Request {
|
|||
}
|
||||
return r.WithContext(c)
|
||||
}
|
||||
|
||||
func fromContext[T any](r *http.Request, key contextKey) *T {
|
||||
t, ok := r.Context().Value(key).(*T)
|
||||
if !ok {
|
||||
panic(fmt.Sprintf("cannot find key %v in request context", key))
|
||||
}
|
||||
return t
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue