Merge branch 'ip-range-exempt'

pull/432/head
Philipp Heckel 2022-10-08 17:58:21 -04:00
commit cbc912d1e3
14 changed files with 161 additions and 60 deletions

View File

@ -5,16 +5,18 @@ package cmd
import ( import (
"errors" "errors"
"fmt" "fmt"
"heckel.io/ntfy/log"
"io/fs" "io/fs"
"math" "math"
"net" "net"
"net/netip"
"os" "os"
"os/signal" "os/signal"
"strings" "strings"
"syscall" "syscall"
"time" "time"
"heckel.io/ntfy/log"
"github.com/urfave/cli/v2" "github.com/urfave/cli/v2"
"github.com/urfave/cli/v2/altsrc" "github.com/urfave/cli/v2/altsrc"
"heckel.io/ntfy/server" "heckel.io/ntfy/server"
@ -208,16 +210,14 @@ func execServe(c *cli.Context) error {
} }
// Resolve hosts // Resolve hosts
visitorRequestLimitExemptIPs := make([]string, 0) visitorRequestLimitExemptIPs := make([]netip.Prefix, 0)
for _, host := range visitorRequestLimitExemptHosts { for _, host := range visitorRequestLimitExemptHosts {
ips, err := net.LookupIP(host) ips, err := parseIPHostPrefix(host)
if err != nil { if err != nil {
log.Warn("cannot resolve host %s: %s, ignoring visitor request exemption", host, err.Error()) log.Warn("cannot resolve host %s: %s, ignoring visitor request exemption", host, err.Error())
continue continue
} }
for _, ip := range ips { visitorRequestLimitExemptIPs = append(visitorRequestLimitExemptIPs, ips...)
visitorRequestLimitExemptIPs = append(visitorRequestLimitExemptIPs, ip.String())
}
} }
// Run server // Run server
@ -303,6 +303,31 @@ func sigHandlerConfigReload(config string) {
} }
} }
func parseIPHostPrefix(host string) (prefixes []netip.Prefix, err error) {
// Try parsing as prefix, e.g. 10.0.1.0/24
prefix, err := netip.ParsePrefix(host)
if err == nil {
prefixes = append(prefixes, prefix.Masked())
return prefixes, nil
}
// Not a prefix, parse as host or IP (LookupHost passes through an IP as is)
ips, err := net.LookupHost(host)
if err != nil {
return nil, err
}
for _, ipStr := range ips {
ip, err := netip.ParseAddr(ipStr)
if err == nil {
prefix, err := ip.Prefix(ip.BitLen())
if err != nil {
return nil, fmt.Errorf("%s successfully parsed but unable to make prefix: %s", ip.String(), err.Error())
}
prefixes = append(prefixes, prefix.Masked())
}
}
return
}
func reloadLogLevel(inputSource altsrc.InputSourceContext) { func reloadLogLevel(inputSource altsrc.InputSourceContext) {
newLevelStr, err := inputSource.String("log-level") newLevelStr, err := inputSource.String("log-level")
if err != nil { if err != nil {

View File

@ -2,17 +2,19 @@ package cmd
import ( import (
"fmt" "fmt"
"github.com/gorilla/websocket"
"github.com/stretchr/testify/require"
"heckel.io/ntfy/client"
"heckel.io/ntfy/test"
"heckel.io/ntfy/util"
"math/rand" "math/rand"
"os" "os"
"os/exec" "os/exec"
"path/filepath" "path/filepath"
"testing" "testing"
"time" "time"
"github.com/gorilla/websocket"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"heckel.io/ntfy/client"
"heckel.io/ntfy/test"
"heckel.io/ntfy/util"
) )
func init() { func init() {
@ -70,6 +72,22 @@ func TestCLI_Serve_WebSocket(t *testing.T) {
require.Equal(t, "mytopic", m.Topic) require.Equal(t, "mytopic", m.Topic)
} }
func TestIP_Host_Parsing(t *testing.T) {
cases := map[string]string{
"1.1.1.1": "1.1.1.1/32",
"fd00::1234": "fd00::1234/128",
"192.168.0.3/24": "192.168.0.0/24",
"10.1.2.3/8": "10.0.0.0/8",
"201:be93::4a6/21": "201:b800::/21",
}
for q, expectedAnswer := range cases {
ips, err := parseIPHostPrefix(q)
require.Nil(t, err)
assert.Equal(t, 1, len(ips))
assert.Equal(t, expectedAnswer, ips[0].String())
}
}
func newEmptyFile(t *testing.T) string { func newEmptyFile(t *testing.T) string {
filename := filepath.Join(t.TempDir(), "empty") filename := filepath.Join(t.TempDir(), "empty")
require.Nil(t, os.WriteFile(filename, []byte{}, 0600)) require.Nil(t, os.WriteFile(filename, []byte{}, 0600))

View File

@ -2,6 +2,7 @@ package server
import ( import (
"io/fs" "io/fs"
"net/netip"
"time" "time"
) )
@ -92,7 +93,7 @@ type Config struct {
VisitorAttachmentDailyBandwidthLimit int VisitorAttachmentDailyBandwidthLimit int
VisitorRequestLimitBurst int VisitorRequestLimitBurst int
VisitorRequestLimitReplenish time.Duration VisitorRequestLimitReplenish time.Duration
VisitorRequestExemptIPAddrs []string VisitorRequestExemptIPAddrs []netip.Prefix
VisitorEmailLimitBurst int VisitorEmailLimitBurst int
VisitorEmailLimitReplenish time.Duration VisitorEmailLimitReplenish time.Duration
BehindProxy bool BehindProxy bool
@ -135,7 +136,7 @@ func NewConfig() *Config {
VisitorAttachmentDailyBandwidthLimit: DefaultVisitorAttachmentDailyBandwidthLimit, VisitorAttachmentDailyBandwidthLimit: DefaultVisitorAttachmentDailyBandwidthLimit,
VisitorRequestLimitBurst: DefaultVisitorRequestLimitBurst, VisitorRequestLimitBurst: DefaultVisitorRequestLimitBurst,
VisitorRequestLimitReplenish: DefaultVisitorRequestLimitReplenish, VisitorRequestLimitReplenish: DefaultVisitorRequestLimitReplenish,
VisitorRequestExemptIPAddrs: make([]string, 0), VisitorRequestExemptIPAddrs: make([]netip.Prefix, 0),
VisitorEmailLimitBurst: DefaultVisitorEmailLimitBurst, VisitorEmailLimitBurst: DefaultVisitorEmailLimitBurst,
VisitorEmailLimitReplenish: DefaultVisitorEmailLimitReplenish, VisitorEmailLimitReplenish: DefaultVisitorEmailLimitReplenish,
BehindProxy: false, BehindProxy: false,

View File

@ -5,11 +5,13 @@ import (
"encoding/json" "encoding/json"
"errors" "errors"
"fmt" "fmt"
"net/netip"
"strings"
"time"
_ "github.com/mattn/go-sqlite3" // SQLite driver _ "github.com/mattn/go-sqlite3" // SQLite driver
"heckel.io/ntfy/log" "heckel.io/ntfy/log"
"heckel.io/ntfy/util" "heckel.io/ntfy/util"
"strings"
"time"
) )
var ( var (
@ -279,7 +281,7 @@ func (c *messageCache) addMessages(ms []*message) error {
attachmentSize, attachmentSize,
attachmentExpires, attachmentExpires,
attachmentURL, attachmentURL,
m.Sender, m.Sender.String(),
m.Encoding, m.Encoding,
published, published,
) )
@ -454,6 +456,11 @@ func readMessages(rows *sql.Rows) ([]*message, error) {
return nil, err return nil, err
} }
} }
senderIP, err := netip.ParseAddr(sender)
if err != nil {
senderIP = netip.IPv4Unspecified() // if no IP stored in database, 0.0.0.0
}
var att *attachment var att *attachment
if attachmentName != "" && attachmentURL != "" { if attachmentName != "" && attachmentURL != "" {
att = &attachment{ att = &attachment{
@ -477,7 +484,7 @@ func readMessages(rows *sql.Rows) ([]*message, error) {
Icon: icon, Icon: icon,
Actions: actions, Actions: actions,
Attachment: att, Attachment: att,
Sender: sender, Sender: senderIP, // Must parse assuming database must be correct
Encoding: encoding, Encoding: encoding,
}) })
} }

View File

@ -3,11 +3,17 @@ package server
import ( import (
"database/sql" "database/sql"
"fmt" "fmt"
"github.com/stretchr/testify/assert" "net/netip"
"github.com/stretchr/testify/require"
"path/filepath" "path/filepath"
"testing" "testing"
"time" "time"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
var (
exampleIP1234 = netip.MustParseAddr("1.2.3.4")
) )
func TestSqliteCache_Messages(t *testing.T) { func TestSqliteCache_Messages(t *testing.T) {
@ -281,7 +287,7 @@ func testCacheAttachments(t *testing.T, c *messageCache) {
expires1 := time.Now().Add(-4 * time.Hour).Unix() expires1 := time.Now().Add(-4 * time.Hour).Unix()
m := newDefaultMessage("mytopic", "flower for you") m := newDefaultMessage("mytopic", "flower for you")
m.ID = "m1" m.ID = "m1"
m.Sender = "1.2.3.4" m.Sender = exampleIP1234
m.Attachment = &attachment{ m.Attachment = &attachment{
Name: "flower.jpg", Name: "flower.jpg",
Type: "image/jpeg", Type: "image/jpeg",
@ -294,7 +300,7 @@ func testCacheAttachments(t *testing.T, c *messageCache) {
expires2 := time.Now().Add(2 * time.Hour).Unix() // Future expires2 := time.Now().Add(2 * time.Hour).Unix() // Future
m = newDefaultMessage("mytopic", "sending you a car") m = newDefaultMessage("mytopic", "sending you a car")
m.ID = "m2" m.ID = "m2"
m.Sender = "1.2.3.4" m.Sender = exampleIP1234
m.Attachment = &attachment{ m.Attachment = &attachment{
Name: "car.jpg", Name: "car.jpg",
Type: "image/jpeg", Type: "image/jpeg",
@ -307,7 +313,7 @@ func testCacheAttachments(t *testing.T, c *messageCache) {
expires3 := time.Now().Add(1 * time.Hour).Unix() // Future expires3 := time.Now().Add(1 * time.Hour).Unix() // Future
m = newDefaultMessage("another-topic", "sending you another car") m = newDefaultMessage("another-topic", "sending you another car")
m.ID = "m3" m.ID = "m3"
m.Sender = "1.2.3.4" m.Sender = exampleIP1234
m.Attachment = &attachment{ m.Attachment = &attachment{
Name: "another-car.jpg", Name: "another-car.jpg",
Type: "image/jpeg", Type: "image/jpeg",
@ -327,7 +333,7 @@ func testCacheAttachments(t *testing.T, c *messageCache) {
require.Equal(t, int64(5000), messages[0].Attachment.Size) require.Equal(t, int64(5000), messages[0].Attachment.Size)
require.Equal(t, expires1, messages[0].Attachment.Expires) require.Equal(t, expires1, messages[0].Attachment.Expires)
require.Equal(t, "https://ntfy.sh/file/AbDeFgJhal.jpg", messages[0].Attachment.URL) require.Equal(t, "https://ntfy.sh/file/AbDeFgJhal.jpg", messages[0].Attachment.URL)
require.Equal(t, "1.2.3.4", messages[0].Sender) require.Equal(t, "1.2.3.4", messages[0].Sender.String())
require.Equal(t, "sending you a car", messages[1].Message) require.Equal(t, "sending you a car", messages[1].Message)
require.Equal(t, "car.jpg", messages[1].Attachment.Name) require.Equal(t, "car.jpg", messages[1].Attachment.Name)
@ -335,7 +341,7 @@ func testCacheAttachments(t *testing.T, c *messageCache) {
require.Equal(t, int64(10000), messages[1].Attachment.Size) require.Equal(t, int64(10000), messages[1].Attachment.Size)
require.Equal(t, expires2, messages[1].Attachment.Expires) require.Equal(t, expires2, messages[1].Attachment.Expires)
require.Equal(t, "https://ntfy.sh/file/aCaRURL.jpg", messages[1].Attachment.URL) require.Equal(t, "https://ntfy.sh/file/aCaRURL.jpg", messages[1].Attachment.URL)
require.Equal(t, "1.2.3.4", messages[1].Sender) require.Equal(t, "1.2.3.4", messages[1].Sender.String())
size, err := c.AttachmentBytesUsed("1.2.3.4") size, err := c.AttachmentBytesUsed("1.2.3.4")
require.Nil(t, err) require.Nil(t, err)

View File

@ -11,6 +11,7 @@ import (
"io" "io"
"net" "net"
"net/http" "net/http"
"net/netip"
"net/url" "net/url"
"os" "os"
"path" "path"
@ -42,7 +43,7 @@ type Server struct {
smtpServerBackend *smtpBackend smtpServerBackend *smtpBackend
smtpSender mailer smtpSender mailer
topics map[string]*topic topics map[string]*topic
visitors map[string]*visitor visitors map[netip.Addr]*visitor
firebaseClient *firebaseClient firebaseClient *firebaseClient
messages int64 messages int64
auth auth.Auther auth auth.Auther
@ -150,7 +151,7 @@ func New(conf *Config) (*Server, error) {
smtpSender: mailer, smtpSender: mailer,
topics: topics, topics: topics,
auth: auther, auth: auther,
visitors: make(map[string]*visitor), visitors: make(map[netip.Addr]*visitor),
}, nil }, nil
} }
@ -1219,7 +1220,7 @@ func (s *Server) runFirebaseKeepaliver() {
if s.firebaseClient == nil { if s.firebaseClient == nil {
return return
} }
v := newVisitor(s.config, s.messageCache, "0.0.0.0") // Background process, not a real visitor v := newVisitor(s.config, s.messageCache, netip.IPv4Unspecified()) // Background process, not a real visitor, uses IP 0.0.0.0
for { for {
select { select {
case <-time.After(s.config.FirebaseKeepaliveInterval): case <-time.After(s.config.FirebaseKeepaliveInterval):
@ -1286,7 +1287,7 @@ func (s *Server) sendDelayedMessage(v *visitor, m *message) error {
func (s *Server) limitRequests(next handleFunc) handleFunc { func (s *Server) limitRequests(next handleFunc) handleFunc {
return func(w http.ResponseWriter, r *http.Request, v *visitor) error { return func(w http.ResponseWriter, r *http.Request, v *visitor) error {
if util.Contains(s.config.VisitorRequestExemptIPAddrs, v.ip) { if util.ContainsIP(s.config.VisitorRequestExemptIPAddrs, v.ip) {
return next(w, r, v) return next(w, r, v)
} else if err := v.RequestAllowed(); err != nil { } else if err := v.RequestAllowed(); err != nil {
return errHTTPTooManyRequestsLimitRequests return errHTTPTooManyRequestsLimitRequests
@ -1436,21 +1437,33 @@ func extractUserPass(r *http.Request) (username string, password string, ok bool
// This function was taken from https://www.alexedwards.net/blog/how-to-rate-limit-http-requests (MIT). // This function was taken from https://www.alexedwards.net/blog/how-to-rate-limit-http-requests (MIT).
func (s *Server) visitor(r *http.Request) *visitor { func (s *Server) visitor(r *http.Request) *visitor {
remoteAddr := r.RemoteAddr remoteAddr := r.RemoteAddr
ip, _, err := net.SplitHostPort(remoteAddr) addrPort, err := netip.ParseAddrPort(remoteAddr)
ip := addrPort.Addr()
if err != nil { if err != nil {
ip = remoteAddr // This should not happen in real life; only in tests. // This should not happen in real life; only in tests. So, using falling back to 0.0.0.0 if address unspecified
ip, err = netip.ParseAddr(remoteAddr)
if err != nil {
ip = netip.IPv4Unspecified()
log.Warn("unable to parse IP (%s), new visitor with unspecified IP (0.0.0.0) created %s", remoteAddr, err)
}
} }
if s.config.BehindProxy && strings.TrimSpace(r.Header.Get("X-Forwarded-For")) != "" { if s.config.BehindProxy && strings.TrimSpace(r.Header.Get("X-Forwarded-For")) != "" {
// X-Forwarded-For can contain multiple addresses (see #328). If we are behind a proxy, // X-Forwarded-For can contain multiple addresses (see #328). If we are behind a proxy,
// only the right-most address can be trusted (as this is the one added by our proxy server). // only the right-most address can be trusted (as this is the one added by our proxy server).
// See https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/X-Forwarded-For for details. // See https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/X-Forwarded-For for details.
ips := util.SplitNoEmpty(r.Header.Get("X-Forwarded-For"), ",") ips := util.SplitNoEmpty(r.Header.Get("X-Forwarded-For"), ",")
ip = strings.TrimSpace(util.LastString(ips, remoteAddr)) realIP, err := netip.ParseAddr(strings.TrimSpace(util.LastString(ips, remoteAddr)))
if err != nil {
log.Error("invalid IP address %s received in X-Forwarded-For header: %s", ip, err.Error())
// Fall back to regular remote address if X-Forwarded-For is damaged
} else {
ip = realIP
}
} }
return s.visitorFromIP(ip) return s.visitorFromIP(ip)
} }
func (s *Server) visitorFromIP(ip string) *visitor { func (s *Server) visitorFromIP(ip netip.Addr) *visitor {
s.mu.Lock() s.mu.Lock()
defer s.mu.Unlock() defer s.mu.Unlock()
v, exists := s.visitors[ip] v, exists := s.visitors[ip]

View File

@ -3,13 +3,15 @@ package server
import ( import (
"encoding/json" "encoding/json"
"errors" "errors"
"firebase.google.com/go/v4/messaging"
"fmt" "fmt"
"github.com/stretchr/testify/require" "net/netip"
"heckel.io/ntfy/auth"
"strings" "strings"
"sync" "sync"
"testing" "testing"
"firebase.google.com/go/v4/messaging"
"github.com/stretchr/testify/require"
"heckel.io/ntfy/auth"
) )
type testAuther struct { type testAuther struct {
@ -322,7 +324,7 @@ func TestMaybeTruncateFCMMessage_NotTooLong(t *testing.T) {
func TestToFirebaseSender_Abuse(t *testing.T) { func TestToFirebaseSender_Abuse(t *testing.T) {
sender := &testFirebaseSender{allowed: 2} sender := &testFirebaseSender{allowed: 2}
client := newFirebaseClient(sender, &testAuther{}) client := newFirebaseClient(sender, &testAuther{})
visitor := newVisitor(newTestConfig(t), newMemTestCache(t), "1.2.3.4") visitor := newVisitor(newTestConfig(t), newMemTestCache(t), netip.MustParseAddr("1.2.3.4"))
require.Nil(t, client.Send(visitor, &message{Topic: "mytopic"})) require.Nil(t, client.Send(visitor, &message{Topic: "mytopic"}))
require.Equal(t, 1, len(sender.Messages())) require.Equal(t, 1, len(sender.Messages()))

View File

@ -1,11 +1,13 @@
package server package server
import ( import (
"github.com/stretchr/testify/require"
"net/http" "net/http"
"net/http/httptest" "net/http/httptest"
"net/netip"
"strings" "strings"
"testing" "testing"
"github.com/stretchr/testify/require"
) )
func TestMatrix_NewRequestFromMatrixJSON_Success(t *testing.T) { func TestMatrix_NewRequestFromMatrixJSON_Success(t *testing.T) {
@ -70,7 +72,7 @@ func TestMatrix_WriteMatrixDiscoveryResponse(t *testing.T) {
func TestMatrix_WriteMatrixError(t *testing.T) { func TestMatrix_WriteMatrixError(t *testing.T) {
w := httptest.NewRecorder() w := httptest.NewRecorder()
r, _ := http.NewRequest("POST", "http://ntfy.example.com/_matrix/push/v1/notify", nil) r, _ := http.NewRequest("POST", "http://ntfy.example.com/_matrix/push/v1/notify", nil)
v := newVisitor(newTestConfig(t), nil, "1.2.3.4") v := newVisitor(newTestConfig(t), nil, netip.MustParseAddr("1.2.3.4"))
require.Nil(t, writeMatrixError(w, r, v, &errMatrix{"https://ntfy.example.com/upABCDEFGHI?up=1", errHTTPBadRequestMatrixPushkeyBaseURLMismatch})) require.Nil(t, writeMatrixError(w, r, v, &errMatrix{"https://ntfy.example.com/upABCDEFGHI?up=1", errHTTPBadRequestMatrixPushkeyBaseURLMismatch}))
require.Equal(t, 200, w.Result().StatusCode) require.Equal(t, 200, w.Result().StatusCode)
require.Equal(t, `{"rejected":["https://ntfy.example.com/upABCDEFGHI?up=1"]}`+"\n", w.Body.String()) require.Equal(t, `{"rejected":["https://ntfy.example.com/upABCDEFGHI?up=1"]}`+"\n", w.Body.String())

View File

@ -6,18 +6,20 @@ import (
"encoding/base64" "encoding/base64"
"encoding/json" "encoding/json"
"fmt" "fmt"
"github.com/stretchr/testify/assert"
"io" "io"
"log" "log"
"math/rand" "math/rand"
"net/http" "net/http"
"net/http/httptest" "net/http/httptest"
"net/netip"
"path/filepath" "path/filepath"
"strings" "strings"
"sync" "sync"
"testing" "testing"
"time" "time"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
"heckel.io/ntfy/auth" "heckel.io/ntfy/auth"
"heckel.io/ntfy/util" "heckel.io/ntfy/util"
@ -292,13 +294,13 @@ func TestServer_PublishAt(t *testing.T) {
messages = toMessages(t, response.Body.String()) messages = toMessages(t, response.Body.String())
require.Equal(t, 1, len(messages)) require.Equal(t, 1, len(messages))
require.Equal(t, "a message", messages[0].Message) require.Equal(t, "a message", messages[0].Message)
require.Equal(t, "", messages[0].Sender) // Never return the sender! require.Equal(t, netip.Addr{}, messages[0].Sender) // Never return the sender!
messages, err := s.messageCache.Messages("mytopic", sinceAllMessages, true) messages, err := s.messageCache.Messages("mytopic", sinceAllMessages, true)
require.Nil(t, err) require.Nil(t, err)
require.Equal(t, 1, len(messages)) require.Equal(t, 1, len(messages))
require.Equal(t, "a message", messages[0].Message) require.Equal(t, "a message", messages[0].Message)
require.Equal(t, "9.9.9.9", messages[0].Sender) // It's stored in the DB though! require.Equal(t, "9.9.9.9", messages[0].Sender.String()) // It's stored in the DB though!
} }
func TestServer_PublishAtWithCacheError(t *testing.T) { func TestServer_PublishAtWithCacheError(t *testing.T) {
@ -814,7 +816,7 @@ func TestServer_PublishTooRequests_Defaults(t *testing.T) {
func TestServer_PublishTooRequests_Defaults_ExemptHosts(t *testing.T) { func TestServer_PublishTooRequests_Defaults_ExemptHosts(t *testing.T) {
c := newTestConfig(t) c := newTestConfig(t)
c.VisitorRequestExemptIPAddrs = []string{"9.9.9.9"} // see request() c.VisitorRequestExemptIPAddrs = []netip.Prefix{netip.MustParsePrefix("9.9.9.9/32")} // see request()
s := newTestServer(t, c) s := newTestServer(t, c)
for i := 0; i < 65; i++ { // > 60 for i := 0; i < 65; i++ { // > 60
response := request(t, s, "PUT", "/mytopic", fmt.Sprintf("message %d", i), nil) response := request(t, s, "PUT", "/mytopic", fmt.Sprintf("message %d", i), nil)
@ -1132,7 +1134,7 @@ func TestServer_PublishAttachment(t *testing.T) {
require.Equal(t, int64(5000), msg.Attachment.Size) require.Equal(t, int64(5000), msg.Attachment.Size)
require.GreaterOrEqual(t, msg.Attachment.Expires, time.Now().Add(179*time.Minute).Unix()) // Almost 3 hours require.GreaterOrEqual(t, msg.Attachment.Expires, time.Now().Add(179*time.Minute).Unix()) // Almost 3 hours
require.Contains(t, msg.Attachment.URL, "http://127.0.0.1:12345/file/") require.Contains(t, msg.Attachment.URL, "http://127.0.0.1:12345/file/")
require.Equal(t, "", msg.Sender) // Should never be returned require.Equal(t, netip.Addr{}, msg.Sender) // Should never be returned
require.FileExists(t, filepath.Join(s.config.AttachmentCacheDir, msg.ID)) require.FileExists(t, filepath.Join(s.config.AttachmentCacheDir, msg.ID))
// GET // GET
@ -1168,7 +1170,7 @@ func TestServer_PublishAttachmentShortWithFilename(t *testing.T) {
require.Equal(t, int64(21), msg.Attachment.Size) require.Equal(t, int64(21), msg.Attachment.Size)
require.GreaterOrEqual(t, msg.Attachment.Expires, time.Now().Add(3*time.Hour).Unix()) require.GreaterOrEqual(t, msg.Attachment.Expires, time.Now().Add(3*time.Hour).Unix())
require.Contains(t, msg.Attachment.URL, "http://127.0.0.1:12345/file/") require.Contains(t, msg.Attachment.URL, "http://127.0.0.1:12345/file/")
require.Equal(t, "", msg.Sender) // Should never be returned require.Equal(t, netip.Addr{}, msg.Sender) // Should never be returned
require.FileExists(t, filepath.Join(s.config.AttachmentCacheDir, msg.ID)) require.FileExists(t, filepath.Join(s.config.AttachmentCacheDir, msg.ID))
path := strings.TrimPrefix(msg.Attachment.URL, "http://127.0.0.1:12345") path := strings.TrimPrefix(msg.Attachment.URL, "http://127.0.0.1:12345")
@ -1195,7 +1197,7 @@ func TestServer_PublishAttachmentExternalWithoutFilename(t *testing.T) {
require.Equal(t, "", msg.Attachment.Type) require.Equal(t, "", msg.Attachment.Type)
require.Equal(t, int64(0), msg.Attachment.Size) require.Equal(t, int64(0), msg.Attachment.Size)
require.Equal(t, int64(0), msg.Attachment.Expires) require.Equal(t, int64(0), msg.Attachment.Expires)
require.Equal(t, "", msg.Sender) require.Equal(t, netip.Addr{}, msg.Sender)
// Slightly unrelated cross-test: make sure we don't add an owner for external attachments // Slightly unrelated cross-test: make sure we don't add an owner for external attachments
size, err := s.messageCache.AttachmentBytesUsed("127.0.0.1") size, err := s.messageCache.AttachmentBytesUsed("127.0.0.1")
@ -1216,7 +1218,7 @@ func TestServer_PublishAttachmentExternalWithFilename(t *testing.T) {
require.Equal(t, "", msg.Attachment.Type) require.Equal(t, "", msg.Attachment.Type)
require.Equal(t, int64(0), msg.Attachment.Size) require.Equal(t, int64(0), msg.Attachment.Size)
require.Equal(t, int64(0), msg.Attachment.Expires) require.Equal(t, int64(0), msg.Attachment.Expires)
require.Equal(t, "", msg.Sender) require.Equal(t, netip.Addr{}, msg.Sender)
} }
func TestServer_PublishAttachmentBadURL(t *testing.T) { func TestServer_PublishAttachmentBadURL(t *testing.T) {
@ -1391,7 +1393,7 @@ func TestServer_Visitor_XForwardedFor_None(t *testing.T) {
r.RemoteAddr = "8.9.10.11" r.RemoteAddr = "8.9.10.11"
r.Header.Set("X-Forwarded-For", " ") // Spaces, not empty! r.Header.Set("X-Forwarded-For", " ") // Spaces, not empty!
v := s.visitor(r) v := s.visitor(r)
require.Equal(t, "8.9.10.11", v.ip) require.Equal(t, "8.9.10.11", v.ip.String())
} }
func TestServer_Visitor_XForwardedFor_Single(t *testing.T) { func TestServer_Visitor_XForwardedFor_Single(t *testing.T) {
@ -1402,7 +1404,7 @@ func TestServer_Visitor_XForwardedFor_Single(t *testing.T) {
r.RemoteAddr = "8.9.10.11" r.RemoteAddr = "8.9.10.11"
r.Header.Set("X-Forwarded-For", "1.1.1.1") r.Header.Set("X-Forwarded-For", "1.1.1.1")
v := s.visitor(r) v := s.visitor(r)
require.Equal(t, "1.1.1.1", v.ip) require.Equal(t, "1.1.1.1", v.ip.String())
} }
func TestServer_Visitor_XForwardedFor_Multiple(t *testing.T) { func TestServer_Visitor_XForwardedFor_Multiple(t *testing.T) {
@ -1413,7 +1415,7 @@ func TestServer_Visitor_XForwardedFor_Multiple(t *testing.T) {
r.RemoteAddr = "8.9.10.11" r.RemoteAddr = "8.9.10.11"
r.Header.Set("X-Forwarded-For", "1.2.3.4 , 2.4.4.2,234.5.2.1 ") r.Header.Set("X-Forwarded-For", "1.2.3.4 , 2.4.4.2,234.5.2.1 ")
v := s.visitor(r) v := s.visitor(r)
require.Equal(t, "234.5.2.1", v.ip) require.Equal(t, "234.5.2.1", v.ip.String())
} }
func TestServer_PublishWhileUpdatingStatsWithLotsOfMessages(t *testing.T) { func TestServer_PublishWhileUpdatingStatsWithLotsOfMessages(t *testing.T) {

View File

@ -32,7 +32,7 @@ func (s *smtpSender) Send(v *visitor, m *message, to string) error {
if err != nil { if err != nil {
return err return err
} }
message, err := formatMail(s.config.BaseURL, v.ip, s.config.SMTPSenderFrom, to, m) message, err := formatMail(s.config.BaseURL, v.ip.String(), s.config.SMTPSenderFrom, to, m)
if err != nil { if err != nil {
return err return err
} }

View File

@ -1,9 +1,11 @@
package server package server
import ( import (
"heckel.io/ntfy/util"
"net/http" "net/http"
"net/netip"
"time" "time"
"heckel.io/ntfy/util"
) )
// List of possible events // List of possible events
@ -33,7 +35,7 @@ type message struct {
Actions []*action `json:"actions,omitempty"` Actions []*action `json:"actions,omitempty"`
Attachment *attachment `json:"attachment,omitempty"` Attachment *attachment `json:"attachment,omitempty"`
PollID string `json:"poll_id,omitempty"` PollID string `json:"poll_id,omitempty"`
Sender string `json:"-"` // IP address of uploader, used for rate limiting Sender netip.Addr `json:"-"` // IP address of uploader, used for rate limiting
Encoding string `json:"encoding,omitempty"` // empty for raw UTF-8, or "base64" for encoded bytes Encoding string `json:"encoding,omitempty"` // empty for raw UTF-8, or "base64" for encoded bytes
} }

View File

@ -2,10 +2,12 @@ package server
import ( import (
"errors" "errors"
"golang.org/x/time/rate" "net/netip"
"heckel.io/ntfy/util"
"sync" "sync"
"time" "time"
"golang.org/x/time/rate"
"heckel.io/ntfy/util"
) )
const ( const (
@ -23,7 +25,7 @@ var (
type visitor struct { type visitor struct {
config *Config config *Config
messageCache *messageCache messageCache *messageCache
ip string ip netip.Addr
requests *rate.Limiter requests *rate.Limiter
emails *rate.Limiter emails *rate.Limiter
subscriptions util.Limiter subscriptions util.Limiter
@ -40,7 +42,7 @@ type visitorStats struct {
VisitorAttachmentBytesRemaining int64 `json:"visitorAttachmentBytesRemaining"` VisitorAttachmentBytesRemaining int64 `json:"visitorAttachmentBytesRemaining"`
} }
func newVisitor(conf *Config, messageCache *messageCache, ip string) *visitor { func newVisitor(conf *Config, messageCache *messageCache, ip netip.Addr) *visitor {
return &visitor{ return &visitor{
config: conf, config: conf,
messageCache: messageCache, messageCache: messageCache,
@ -115,7 +117,7 @@ func (v *visitor) Stale() bool {
} }
func (v *visitor) Stats() (*visitorStats, error) { func (v *visitor) Stats() (*visitorStats, error) {
attachmentsBytesUsed, err := v.messageCache.AttachmentBytesUsed(v.ip) attachmentsBytesUsed, err := v.messageCache.AttachmentBytesUsed(v.ip.String())
if err != nil { if err != nil {
return nil, err return nil, err
} }

View File

@ -5,16 +5,18 @@ import (
"encoding/json" "encoding/json"
"errors" "errors"
"fmt" "fmt"
"github.com/gabriel-vasile/mimetype"
"golang.org/x/term"
"io" "io"
"math/rand" "math/rand"
"net/netip"
"os" "os"
"regexp" "regexp"
"strconv" "strconv"
"strings" "strings"
"sync" "sync"
"time" "time"
"github.com/gabriel-vasile/mimetype"
"golang.org/x/term"
) )
const ( const (
@ -45,6 +47,16 @@ func Contains[T comparable](haystack []T, needle T) bool {
return false return false
} }
// ContainsIP returns true if any one of the of prefixes contains the ip.
func ContainsIP(haystack []netip.Prefix, needle netip.Addr) bool {
for _, s := range haystack {
if s.Contains(needle) {
return true
}
}
return false
}
// ContainsAll returns true if all needles are contained in haystack // ContainsAll returns true if all needles are contained in haystack
func ContainsAll[T comparable](haystack []T, needles []T) bool { func ContainsAll[T comparable](haystack []T, needles []T) bool {
matches := 0 matches := 0

View File

@ -1,10 +1,12 @@
package util package util
import ( import (
"github.com/stretchr/testify/require" "net/netip"
"os" "os"
"path/filepath" "path/filepath"
"testing" "testing"
"github.com/stretchr/testify/require"
) )
func TestRandomString(t *testing.T) { func TestRandomString(t *testing.T) {
@ -42,6 +44,13 @@ func TestContains(t *testing.T) {
require.False(t, Contains(s, 3)) require.False(t, Contains(s, 3))
} }
func TestContainsIP(t *testing.T) {
require.True(t, ContainsIP([]netip.Prefix{netip.MustParsePrefix("fd00::/8"), netip.MustParsePrefix("1.1.0.0/16")}, netip.MustParseAddr("1.1.1.1")))
require.True(t, ContainsIP([]netip.Prefix{netip.MustParsePrefix("fd00::/8"), netip.MustParsePrefix("1.1.0.0/16")}, netip.MustParseAddr("fd12:1234:5678::9876")))
require.False(t, ContainsIP([]netip.Prefix{netip.MustParsePrefix("fd00::/8"), netip.MustParsePrefix("1.1.0.0/16")}, netip.MustParseAddr("1.2.0.1")))
require.False(t, ContainsIP([]netip.Prefix{netip.MustParsePrefix("fd00::/8"), netip.MustParsePrefix("1.1.0.0/16")}, netip.MustParseAddr("fc00::1")))
}
func TestSplitNoEmpty(t *testing.T) { func TestSplitNoEmpty(t *testing.T) {
require.Equal(t, []string{}, SplitNoEmpty("", ",")) require.Equal(t, []string{}, SplitNoEmpty("", ","))
require.Equal(t, []string{}, SplitNoEmpty(",,,", ",")) require.Equal(t, []string{}, SplitNoEmpty(",,,", ","))