Merge branch 'ip-range-exempt'
commit
cbc912d1e3
37
cmd/serve.go
37
cmd/serve.go
|
@ -5,16 +5,18 @@ package cmd
|
||||||
import (
|
import (
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"heckel.io/ntfy/log"
|
|
||||||
"io/fs"
|
"io/fs"
|
||||||
"math"
|
"math"
|
||||||
"net"
|
"net"
|
||||||
|
"net/netip"
|
||||||
"os"
|
"os"
|
||||||
"os/signal"
|
"os/signal"
|
||||||
"strings"
|
"strings"
|
||||||
"syscall"
|
"syscall"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
"heckel.io/ntfy/log"
|
||||||
|
|
||||||
"github.com/urfave/cli/v2"
|
"github.com/urfave/cli/v2"
|
||||||
"github.com/urfave/cli/v2/altsrc"
|
"github.com/urfave/cli/v2/altsrc"
|
||||||
"heckel.io/ntfy/server"
|
"heckel.io/ntfy/server"
|
||||||
|
@ -208,16 +210,14 @@ func execServe(c *cli.Context) error {
|
||||||
}
|
}
|
||||||
|
|
||||||
// Resolve hosts
|
// Resolve hosts
|
||||||
visitorRequestLimitExemptIPs := make([]string, 0)
|
visitorRequestLimitExemptIPs := make([]netip.Prefix, 0)
|
||||||
for _, host := range visitorRequestLimitExemptHosts {
|
for _, host := range visitorRequestLimitExemptHosts {
|
||||||
ips, err := net.LookupIP(host)
|
ips, err := parseIPHostPrefix(host)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Warn("cannot resolve host %s: %s, ignoring visitor request exemption", host, err.Error())
|
log.Warn("cannot resolve host %s: %s, ignoring visitor request exemption", host, err.Error())
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
for _, ip := range ips {
|
visitorRequestLimitExemptIPs = append(visitorRequestLimitExemptIPs, ips...)
|
||||||
visitorRequestLimitExemptIPs = append(visitorRequestLimitExemptIPs, ip.String())
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Run server
|
// Run server
|
||||||
|
@ -303,6 +303,31 @@ func sigHandlerConfigReload(config string) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func parseIPHostPrefix(host string) (prefixes []netip.Prefix, err error) {
|
||||||
|
// Try parsing as prefix, e.g. 10.0.1.0/24
|
||||||
|
prefix, err := netip.ParsePrefix(host)
|
||||||
|
if err == nil {
|
||||||
|
prefixes = append(prefixes, prefix.Masked())
|
||||||
|
return prefixes, nil
|
||||||
|
}
|
||||||
|
// Not a prefix, parse as host or IP (LookupHost passes through an IP as is)
|
||||||
|
ips, err := net.LookupHost(host)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
for _, ipStr := range ips {
|
||||||
|
ip, err := netip.ParseAddr(ipStr)
|
||||||
|
if err == nil {
|
||||||
|
prefix, err := ip.Prefix(ip.BitLen())
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("%s successfully parsed but unable to make prefix: %s", ip.String(), err.Error())
|
||||||
|
}
|
||||||
|
prefixes = append(prefixes, prefix.Masked())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
func reloadLogLevel(inputSource altsrc.InputSourceContext) {
|
func reloadLogLevel(inputSource altsrc.InputSourceContext) {
|
||||||
newLevelStr, err := inputSource.String("log-level")
|
newLevelStr, err := inputSource.String("log-level")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
|
|
@ -2,17 +2,19 @@ package cmd
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"fmt"
|
"fmt"
|
||||||
"github.com/gorilla/websocket"
|
|
||||||
"github.com/stretchr/testify/require"
|
|
||||||
"heckel.io/ntfy/client"
|
|
||||||
"heckel.io/ntfy/test"
|
|
||||||
"heckel.io/ntfy/util"
|
|
||||||
"math/rand"
|
"math/rand"
|
||||||
"os"
|
"os"
|
||||||
"os/exec"
|
"os/exec"
|
||||||
"path/filepath"
|
"path/filepath"
|
||||||
"testing"
|
"testing"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
"github.com/gorilla/websocket"
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
"heckel.io/ntfy/client"
|
||||||
|
"heckel.io/ntfy/test"
|
||||||
|
"heckel.io/ntfy/util"
|
||||||
)
|
)
|
||||||
|
|
||||||
func init() {
|
func init() {
|
||||||
|
@ -70,6 +72,22 @@ func TestCLI_Serve_WebSocket(t *testing.T) {
|
||||||
require.Equal(t, "mytopic", m.Topic)
|
require.Equal(t, "mytopic", m.Topic)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestIP_Host_Parsing(t *testing.T) {
|
||||||
|
cases := map[string]string{
|
||||||
|
"1.1.1.1": "1.1.1.1/32",
|
||||||
|
"fd00::1234": "fd00::1234/128",
|
||||||
|
"192.168.0.3/24": "192.168.0.0/24",
|
||||||
|
"10.1.2.3/8": "10.0.0.0/8",
|
||||||
|
"201:be93::4a6/21": "201:b800::/21",
|
||||||
|
}
|
||||||
|
for q, expectedAnswer := range cases {
|
||||||
|
ips, err := parseIPHostPrefix(q)
|
||||||
|
require.Nil(t, err)
|
||||||
|
assert.Equal(t, 1, len(ips))
|
||||||
|
assert.Equal(t, expectedAnswer, ips[0].String())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func newEmptyFile(t *testing.T) string {
|
func newEmptyFile(t *testing.T) string {
|
||||||
filename := filepath.Join(t.TempDir(), "empty")
|
filename := filepath.Join(t.TempDir(), "empty")
|
||||||
require.Nil(t, os.WriteFile(filename, []byte{}, 0600))
|
require.Nil(t, os.WriteFile(filename, []byte{}, 0600))
|
||||||
|
|
|
@ -2,6 +2,7 @@ package server
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"io/fs"
|
"io/fs"
|
||||||
|
"net/netip"
|
||||||
"time"
|
"time"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -92,7 +93,7 @@ type Config struct {
|
||||||
VisitorAttachmentDailyBandwidthLimit int
|
VisitorAttachmentDailyBandwidthLimit int
|
||||||
VisitorRequestLimitBurst int
|
VisitorRequestLimitBurst int
|
||||||
VisitorRequestLimitReplenish time.Duration
|
VisitorRequestLimitReplenish time.Duration
|
||||||
VisitorRequestExemptIPAddrs []string
|
VisitorRequestExemptIPAddrs []netip.Prefix
|
||||||
VisitorEmailLimitBurst int
|
VisitorEmailLimitBurst int
|
||||||
VisitorEmailLimitReplenish time.Duration
|
VisitorEmailLimitReplenish time.Duration
|
||||||
BehindProxy bool
|
BehindProxy bool
|
||||||
|
@ -135,7 +136,7 @@ func NewConfig() *Config {
|
||||||
VisitorAttachmentDailyBandwidthLimit: DefaultVisitorAttachmentDailyBandwidthLimit,
|
VisitorAttachmentDailyBandwidthLimit: DefaultVisitorAttachmentDailyBandwidthLimit,
|
||||||
VisitorRequestLimitBurst: DefaultVisitorRequestLimitBurst,
|
VisitorRequestLimitBurst: DefaultVisitorRequestLimitBurst,
|
||||||
VisitorRequestLimitReplenish: DefaultVisitorRequestLimitReplenish,
|
VisitorRequestLimitReplenish: DefaultVisitorRequestLimitReplenish,
|
||||||
VisitorRequestExemptIPAddrs: make([]string, 0),
|
VisitorRequestExemptIPAddrs: make([]netip.Prefix, 0),
|
||||||
VisitorEmailLimitBurst: DefaultVisitorEmailLimitBurst,
|
VisitorEmailLimitBurst: DefaultVisitorEmailLimitBurst,
|
||||||
VisitorEmailLimitReplenish: DefaultVisitorEmailLimitReplenish,
|
VisitorEmailLimitReplenish: DefaultVisitorEmailLimitReplenish,
|
||||||
BehindProxy: false,
|
BehindProxy: false,
|
||||||
|
|
|
@ -5,11 +5,13 @@ import (
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"net/netip"
|
||||||
|
"strings"
|
||||||
|
"time"
|
||||||
|
|
||||||
_ "github.com/mattn/go-sqlite3" // SQLite driver
|
_ "github.com/mattn/go-sqlite3" // SQLite driver
|
||||||
"heckel.io/ntfy/log"
|
"heckel.io/ntfy/log"
|
||||||
"heckel.io/ntfy/util"
|
"heckel.io/ntfy/util"
|
||||||
"strings"
|
|
||||||
"time"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
var (
|
var (
|
||||||
|
@ -279,7 +281,7 @@ func (c *messageCache) addMessages(ms []*message) error {
|
||||||
attachmentSize,
|
attachmentSize,
|
||||||
attachmentExpires,
|
attachmentExpires,
|
||||||
attachmentURL,
|
attachmentURL,
|
||||||
m.Sender,
|
m.Sender.String(),
|
||||||
m.Encoding,
|
m.Encoding,
|
||||||
published,
|
published,
|
||||||
)
|
)
|
||||||
|
@ -454,6 +456,11 @@ func readMessages(rows *sql.Rows) ([]*message, error) {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
senderIP, err := netip.ParseAddr(sender)
|
||||||
|
if err != nil {
|
||||||
|
senderIP = netip.IPv4Unspecified() // if no IP stored in database, 0.0.0.0
|
||||||
|
}
|
||||||
|
|
||||||
var att *attachment
|
var att *attachment
|
||||||
if attachmentName != "" && attachmentURL != "" {
|
if attachmentName != "" && attachmentURL != "" {
|
||||||
att = &attachment{
|
att = &attachment{
|
||||||
|
@ -477,7 +484,7 @@ func readMessages(rows *sql.Rows) ([]*message, error) {
|
||||||
Icon: icon,
|
Icon: icon,
|
||||||
Actions: actions,
|
Actions: actions,
|
||||||
Attachment: att,
|
Attachment: att,
|
||||||
Sender: sender,
|
Sender: senderIP, // Must parse assuming database must be correct
|
||||||
Encoding: encoding,
|
Encoding: encoding,
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
|
@ -3,11 +3,17 @@ package server
|
||||||
import (
|
import (
|
||||||
"database/sql"
|
"database/sql"
|
||||||
"fmt"
|
"fmt"
|
||||||
"github.com/stretchr/testify/assert"
|
"net/netip"
|
||||||
"github.com/stretchr/testify/require"
|
|
||||||
"path/filepath"
|
"path/filepath"
|
||||||
"testing"
|
"testing"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
)
|
||||||
|
|
||||||
|
var (
|
||||||
|
exampleIP1234 = netip.MustParseAddr("1.2.3.4")
|
||||||
)
|
)
|
||||||
|
|
||||||
func TestSqliteCache_Messages(t *testing.T) {
|
func TestSqliteCache_Messages(t *testing.T) {
|
||||||
|
@ -281,7 +287,7 @@ func testCacheAttachments(t *testing.T, c *messageCache) {
|
||||||
expires1 := time.Now().Add(-4 * time.Hour).Unix()
|
expires1 := time.Now().Add(-4 * time.Hour).Unix()
|
||||||
m := newDefaultMessage("mytopic", "flower for you")
|
m := newDefaultMessage("mytopic", "flower for you")
|
||||||
m.ID = "m1"
|
m.ID = "m1"
|
||||||
m.Sender = "1.2.3.4"
|
m.Sender = exampleIP1234
|
||||||
m.Attachment = &attachment{
|
m.Attachment = &attachment{
|
||||||
Name: "flower.jpg",
|
Name: "flower.jpg",
|
||||||
Type: "image/jpeg",
|
Type: "image/jpeg",
|
||||||
|
@ -294,7 +300,7 @@ func testCacheAttachments(t *testing.T, c *messageCache) {
|
||||||
expires2 := time.Now().Add(2 * time.Hour).Unix() // Future
|
expires2 := time.Now().Add(2 * time.Hour).Unix() // Future
|
||||||
m = newDefaultMessage("mytopic", "sending you a car")
|
m = newDefaultMessage("mytopic", "sending you a car")
|
||||||
m.ID = "m2"
|
m.ID = "m2"
|
||||||
m.Sender = "1.2.3.4"
|
m.Sender = exampleIP1234
|
||||||
m.Attachment = &attachment{
|
m.Attachment = &attachment{
|
||||||
Name: "car.jpg",
|
Name: "car.jpg",
|
||||||
Type: "image/jpeg",
|
Type: "image/jpeg",
|
||||||
|
@ -307,7 +313,7 @@ func testCacheAttachments(t *testing.T, c *messageCache) {
|
||||||
expires3 := time.Now().Add(1 * time.Hour).Unix() // Future
|
expires3 := time.Now().Add(1 * time.Hour).Unix() // Future
|
||||||
m = newDefaultMessage("another-topic", "sending you another car")
|
m = newDefaultMessage("another-topic", "sending you another car")
|
||||||
m.ID = "m3"
|
m.ID = "m3"
|
||||||
m.Sender = "1.2.3.4"
|
m.Sender = exampleIP1234
|
||||||
m.Attachment = &attachment{
|
m.Attachment = &attachment{
|
||||||
Name: "another-car.jpg",
|
Name: "another-car.jpg",
|
||||||
Type: "image/jpeg",
|
Type: "image/jpeg",
|
||||||
|
@ -327,7 +333,7 @@ func testCacheAttachments(t *testing.T, c *messageCache) {
|
||||||
require.Equal(t, int64(5000), messages[0].Attachment.Size)
|
require.Equal(t, int64(5000), messages[0].Attachment.Size)
|
||||||
require.Equal(t, expires1, messages[0].Attachment.Expires)
|
require.Equal(t, expires1, messages[0].Attachment.Expires)
|
||||||
require.Equal(t, "https://ntfy.sh/file/AbDeFgJhal.jpg", messages[0].Attachment.URL)
|
require.Equal(t, "https://ntfy.sh/file/AbDeFgJhal.jpg", messages[0].Attachment.URL)
|
||||||
require.Equal(t, "1.2.3.4", messages[0].Sender)
|
require.Equal(t, "1.2.3.4", messages[0].Sender.String())
|
||||||
|
|
||||||
require.Equal(t, "sending you a car", messages[1].Message)
|
require.Equal(t, "sending you a car", messages[1].Message)
|
||||||
require.Equal(t, "car.jpg", messages[1].Attachment.Name)
|
require.Equal(t, "car.jpg", messages[1].Attachment.Name)
|
||||||
|
@ -335,7 +341,7 @@ func testCacheAttachments(t *testing.T, c *messageCache) {
|
||||||
require.Equal(t, int64(10000), messages[1].Attachment.Size)
|
require.Equal(t, int64(10000), messages[1].Attachment.Size)
|
||||||
require.Equal(t, expires2, messages[1].Attachment.Expires)
|
require.Equal(t, expires2, messages[1].Attachment.Expires)
|
||||||
require.Equal(t, "https://ntfy.sh/file/aCaRURL.jpg", messages[1].Attachment.URL)
|
require.Equal(t, "https://ntfy.sh/file/aCaRURL.jpg", messages[1].Attachment.URL)
|
||||||
require.Equal(t, "1.2.3.4", messages[1].Sender)
|
require.Equal(t, "1.2.3.4", messages[1].Sender.String())
|
||||||
|
|
||||||
size, err := c.AttachmentBytesUsed("1.2.3.4")
|
size, err := c.AttachmentBytesUsed("1.2.3.4")
|
||||||
require.Nil(t, err)
|
require.Nil(t, err)
|
||||||
|
|
|
@ -11,6 +11,7 @@ import (
|
||||||
"io"
|
"io"
|
||||||
"net"
|
"net"
|
||||||
"net/http"
|
"net/http"
|
||||||
|
"net/netip"
|
||||||
"net/url"
|
"net/url"
|
||||||
"os"
|
"os"
|
||||||
"path"
|
"path"
|
||||||
|
@ -42,7 +43,7 @@ type Server struct {
|
||||||
smtpServerBackend *smtpBackend
|
smtpServerBackend *smtpBackend
|
||||||
smtpSender mailer
|
smtpSender mailer
|
||||||
topics map[string]*topic
|
topics map[string]*topic
|
||||||
visitors map[string]*visitor
|
visitors map[netip.Addr]*visitor
|
||||||
firebaseClient *firebaseClient
|
firebaseClient *firebaseClient
|
||||||
messages int64
|
messages int64
|
||||||
auth auth.Auther
|
auth auth.Auther
|
||||||
|
@ -150,7 +151,7 @@ func New(conf *Config) (*Server, error) {
|
||||||
smtpSender: mailer,
|
smtpSender: mailer,
|
||||||
topics: topics,
|
topics: topics,
|
||||||
auth: auther,
|
auth: auther,
|
||||||
visitors: make(map[string]*visitor),
|
visitors: make(map[netip.Addr]*visitor),
|
||||||
}, nil
|
}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -1219,7 +1220,7 @@ func (s *Server) runFirebaseKeepaliver() {
|
||||||
if s.firebaseClient == nil {
|
if s.firebaseClient == nil {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
v := newVisitor(s.config, s.messageCache, "0.0.0.0") // Background process, not a real visitor
|
v := newVisitor(s.config, s.messageCache, netip.IPv4Unspecified()) // Background process, not a real visitor, uses IP 0.0.0.0
|
||||||
for {
|
for {
|
||||||
select {
|
select {
|
||||||
case <-time.After(s.config.FirebaseKeepaliveInterval):
|
case <-time.After(s.config.FirebaseKeepaliveInterval):
|
||||||
|
@ -1286,7 +1287,7 @@ func (s *Server) sendDelayedMessage(v *visitor, m *message) error {
|
||||||
|
|
||||||
func (s *Server) limitRequests(next handleFunc) handleFunc {
|
func (s *Server) limitRequests(next handleFunc) handleFunc {
|
||||||
return func(w http.ResponseWriter, r *http.Request, v *visitor) error {
|
return func(w http.ResponseWriter, r *http.Request, v *visitor) error {
|
||||||
if util.Contains(s.config.VisitorRequestExemptIPAddrs, v.ip) {
|
if util.ContainsIP(s.config.VisitorRequestExemptIPAddrs, v.ip) {
|
||||||
return next(w, r, v)
|
return next(w, r, v)
|
||||||
} else if err := v.RequestAllowed(); err != nil {
|
} else if err := v.RequestAllowed(); err != nil {
|
||||||
return errHTTPTooManyRequestsLimitRequests
|
return errHTTPTooManyRequestsLimitRequests
|
||||||
|
@ -1436,21 +1437,33 @@ func extractUserPass(r *http.Request) (username string, password string, ok bool
|
||||||
// This function was taken from https://www.alexedwards.net/blog/how-to-rate-limit-http-requests (MIT).
|
// This function was taken from https://www.alexedwards.net/blog/how-to-rate-limit-http-requests (MIT).
|
||||||
func (s *Server) visitor(r *http.Request) *visitor {
|
func (s *Server) visitor(r *http.Request) *visitor {
|
||||||
remoteAddr := r.RemoteAddr
|
remoteAddr := r.RemoteAddr
|
||||||
ip, _, err := net.SplitHostPort(remoteAddr)
|
addrPort, err := netip.ParseAddrPort(remoteAddr)
|
||||||
|
ip := addrPort.Addr()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
ip = remoteAddr // This should not happen in real life; only in tests.
|
// This should not happen in real life; only in tests. So, using falling back to 0.0.0.0 if address unspecified
|
||||||
|
ip, err = netip.ParseAddr(remoteAddr)
|
||||||
|
if err != nil {
|
||||||
|
ip = netip.IPv4Unspecified()
|
||||||
|
log.Warn("unable to parse IP (%s), new visitor with unspecified IP (0.0.0.0) created %s", remoteAddr, err)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
if s.config.BehindProxy && strings.TrimSpace(r.Header.Get("X-Forwarded-For")) != "" {
|
if s.config.BehindProxy && strings.TrimSpace(r.Header.Get("X-Forwarded-For")) != "" {
|
||||||
// X-Forwarded-For can contain multiple addresses (see #328). If we are behind a proxy,
|
// X-Forwarded-For can contain multiple addresses (see #328). If we are behind a proxy,
|
||||||
// only the right-most address can be trusted (as this is the one added by our proxy server).
|
// only the right-most address can be trusted (as this is the one added by our proxy server).
|
||||||
// See https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/X-Forwarded-For for details.
|
// See https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/X-Forwarded-For for details.
|
||||||
ips := util.SplitNoEmpty(r.Header.Get("X-Forwarded-For"), ",")
|
ips := util.SplitNoEmpty(r.Header.Get("X-Forwarded-For"), ",")
|
||||||
ip = strings.TrimSpace(util.LastString(ips, remoteAddr))
|
realIP, err := netip.ParseAddr(strings.TrimSpace(util.LastString(ips, remoteAddr)))
|
||||||
|
if err != nil {
|
||||||
|
log.Error("invalid IP address %s received in X-Forwarded-For header: %s", ip, err.Error())
|
||||||
|
// Fall back to regular remote address if X-Forwarded-For is damaged
|
||||||
|
} else {
|
||||||
|
ip = realIP
|
||||||
|
}
|
||||||
}
|
}
|
||||||
return s.visitorFromIP(ip)
|
return s.visitorFromIP(ip)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *Server) visitorFromIP(ip string) *visitor {
|
func (s *Server) visitorFromIP(ip netip.Addr) *visitor {
|
||||||
s.mu.Lock()
|
s.mu.Lock()
|
||||||
defer s.mu.Unlock()
|
defer s.mu.Unlock()
|
||||||
v, exists := s.visitors[ip]
|
v, exists := s.visitors[ip]
|
||||||
|
|
|
@ -3,13 +3,15 @@ package server
|
||||||
import (
|
import (
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"errors"
|
"errors"
|
||||||
"firebase.google.com/go/v4/messaging"
|
|
||||||
"fmt"
|
"fmt"
|
||||||
"github.com/stretchr/testify/require"
|
"net/netip"
|
||||||
"heckel.io/ntfy/auth"
|
|
||||||
"strings"
|
"strings"
|
||||||
"sync"
|
"sync"
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
|
"firebase.google.com/go/v4/messaging"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
"heckel.io/ntfy/auth"
|
||||||
)
|
)
|
||||||
|
|
||||||
type testAuther struct {
|
type testAuther struct {
|
||||||
|
@ -322,7 +324,7 @@ func TestMaybeTruncateFCMMessage_NotTooLong(t *testing.T) {
|
||||||
func TestToFirebaseSender_Abuse(t *testing.T) {
|
func TestToFirebaseSender_Abuse(t *testing.T) {
|
||||||
sender := &testFirebaseSender{allowed: 2}
|
sender := &testFirebaseSender{allowed: 2}
|
||||||
client := newFirebaseClient(sender, &testAuther{})
|
client := newFirebaseClient(sender, &testAuther{})
|
||||||
visitor := newVisitor(newTestConfig(t), newMemTestCache(t), "1.2.3.4")
|
visitor := newVisitor(newTestConfig(t), newMemTestCache(t), netip.MustParseAddr("1.2.3.4"))
|
||||||
|
|
||||||
require.Nil(t, client.Send(visitor, &message{Topic: "mytopic"}))
|
require.Nil(t, client.Send(visitor, &message{Topic: "mytopic"}))
|
||||||
require.Equal(t, 1, len(sender.Messages()))
|
require.Equal(t, 1, len(sender.Messages()))
|
||||||
|
|
|
@ -1,11 +1,13 @@
|
||||||
package server
|
package server
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"github.com/stretchr/testify/require"
|
|
||||||
"net/http"
|
"net/http"
|
||||||
"net/http/httptest"
|
"net/http/httptest"
|
||||||
|
"net/netip"
|
||||||
"strings"
|
"strings"
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
)
|
)
|
||||||
|
|
||||||
func TestMatrix_NewRequestFromMatrixJSON_Success(t *testing.T) {
|
func TestMatrix_NewRequestFromMatrixJSON_Success(t *testing.T) {
|
||||||
|
@ -70,7 +72,7 @@ func TestMatrix_WriteMatrixDiscoveryResponse(t *testing.T) {
|
||||||
func TestMatrix_WriteMatrixError(t *testing.T) {
|
func TestMatrix_WriteMatrixError(t *testing.T) {
|
||||||
w := httptest.NewRecorder()
|
w := httptest.NewRecorder()
|
||||||
r, _ := http.NewRequest("POST", "http://ntfy.example.com/_matrix/push/v1/notify", nil)
|
r, _ := http.NewRequest("POST", "http://ntfy.example.com/_matrix/push/v1/notify", nil)
|
||||||
v := newVisitor(newTestConfig(t), nil, "1.2.3.4")
|
v := newVisitor(newTestConfig(t), nil, netip.MustParseAddr("1.2.3.4"))
|
||||||
require.Nil(t, writeMatrixError(w, r, v, &errMatrix{"https://ntfy.example.com/upABCDEFGHI?up=1", errHTTPBadRequestMatrixPushkeyBaseURLMismatch}))
|
require.Nil(t, writeMatrixError(w, r, v, &errMatrix{"https://ntfy.example.com/upABCDEFGHI?up=1", errHTTPBadRequestMatrixPushkeyBaseURLMismatch}))
|
||||||
require.Equal(t, 200, w.Result().StatusCode)
|
require.Equal(t, 200, w.Result().StatusCode)
|
||||||
require.Equal(t, `{"rejected":["https://ntfy.example.com/upABCDEFGHI?up=1"]}`+"\n", w.Body.String())
|
require.Equal(t, `{"rejected":["https://ntfy.example.com/upABCDEFGHI?up=1"]}`+"\n", w.Body.String())
|
||||||
|
|
|
@ -6,18 +6,20 @@ import (
|
||||||
"encoding/base64"
|
"encoding/base64"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"fmt"
|
"fmt"
|
||||||
"github.com/stretchr/testify/assert"
|
|
||||||
"io"
|
"io"
|
||||||
"log"
|
"log"
|
||||||
"math/rand"
|
"math/rand"
|
||||||
"net/http"
|
"net/http"
|
||||||
"net/http/httptest"
|
"net/http/httptest"
|
||||||
|
"net/netip"
|
||||||
"path/filepath"
|
"path/filepath"
|
||||||
"strings"
|
"strings"
|
||||||
"sync"
|
"sync"
|
||||||
"testing"
|
"testing"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
|
||||||
"github.com/stretchr/testify/require"
|
"github.com/stretchr/testify/require"
|
||||||
"heckel.io/ntfy/auth"
|
"heckel.io/ntfy/auth"
|
||||||
"heckel.io/ntfy/util"
|
"heckel.io/ntfy/util"
|
||||||
|
@ -292,13 +294,13 @@ func TestServer_PublishAt(t *testing.T) {
|
||||||
messages = toMessages(t, response.Body.String())
|
messages = toMessages(t, response.Body.String())
|
||||||
require.Equal(t, 1, len(messages))
|
require.Equal(t, 1, len(messages))
|
||||||
require.Equal(t, "a message", messages[0].Message)
|
require.Equal(t, "a message", messages[0].Message)
|
||||||
require.Equal(t, "", messages[0].Sender) // Never return the sender!
|
require.Equal(t, netip.Addr{}, messages[0].Sender) // Never return the sender!
|
||||||
|
|
||||||
messages, err := s.messageCache.Messages("mytopic", sinceAllMessages, true)
|
messages, err := s.messageCache.Messages("mytopic", sinceAllMessages, true)
|
||||||
require.Nil(t, err)
|
require.Nil(t, err)
|
||||||
require.Equal(t, 1, len(messages))
|
require.Equal(t, 1, len(messages))
|
||||||
require.Equal(t, "a message", messages[0].Message)
|
require.Equal(t, "a message", messages[0].Message)
|
||||||
require.Equal(t, "9.9.9.9", messages[0].Sender) // It's stored in the DB though!
|
require.Equal(t, "9.9.9.9", messages[0].Sender.String()) // It's stored in the DB though!
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestServer_PublishAtWithCacheError(t *testing.T) {
|
func TestServer_PublishAtWithCacheError(t *testing.T) {
|
||||||
|
@ -814,7 +816,7 @@ func TestServer_PublishTooRequests_Defaults(t *testing.T) {
|
||||||
|
|
||||||
func TestServer_PublishTooRequests_Defaults_ExemptHosts(t *testing.T) {
|
func TestServer_PublishTooRequests_Defaults_ExemptHosts(t *testing.T) {
|
||||||
c := newTestConfig(t)
|
c := newTestConfig(t)
|
||||||
c.VisitorRequestExemptIPAddrs = []string{"9.9.9.9"} // see request()
|
c.VisitorRequestExemptIPAddrs = []netip.Prefix{netip.MustParsePrefix("9.9.9.9/32")} // see request()
|
||||||
s := newTestServer(t, c)
|
s := newTestServer(t, c)
|
||||||
for i := 0; i < 65; i++ { // > 60
|
for i := 0; i < 65; i++ { // > 60
|
||||||
response := request(t, s, "PUT", "/mytopic", fmt.Sprintf("message %d", i), nil)
|
response := request(t, s, "PUT", "/mytopic", fmt.Sprintf("message %d", i), nil)
|
||||||
|
@ -1132,7 +1134,7 @@ func TestServer_PublishAttachment(t *testing.T) {
|
||||||
require.Equal(t, int64(5000), msg.Attachment.Size)
|
require.Equal(t, int64(5000), msg.Attachment.Size)
|
||||||
require.GreaterOrEqual(t, msg.Attachment.Expires, time.Now().Add(179*time.Minute).Unix()) // Almost 3 hours
|
require.GreaterOrEqual(t, msg.Attachment.Expires, time.Now().Add(179*time.Minute).Unix()) // Almost 3 hours
|
||||||
require.Contains(t, msg.Attachment.URL, "http://127.0.0.1:12345/file/")
|
require.Contains(t, msg.Attachment.URL, "http://127.0.0.1:12345/file/")
|
||||||
require.Equal(t, "", msg.Sender) // Should never be returned
|
require.Equal(t, netip.Addr{}, msg.Sender) // Should never be returned
|
||||||
require.FileExists(t, filepath.Join(s.config.AttachmentCacheDir, msg.ID))
|
require.FileExists(t, filepath.Join(s.config.AttachmentCacheDir, msg.ID))
|
||||||
|
|
||||||
// GET
|
// GET
|
||||||
|
@ -1168,7 +1170,7 @@ func TestServer_PublishAttachmentShortWithFilename(t *testing.T) {
|
||||||
require.Equal(t, int64(21), msg.Attachment.Size)
|
require.Equal(t, int64(21), msg.Attachment.Size)
|
||||||
require.GreaterOrEqual(t, msg.Attachment.Expires, time.Now().Add(3*time.Hour).Unix())
|
require.GreaterOrEqual(t, msg.Attachment.Expires, time.Now().Add(3*time.Hour).Unix())
|
||||||
require.Contains(t, msg.Attachment.URL, "http://127.0.0.1:12345/file/")
|
require.Contains(t, msg.Attachment.URL, "http://127.0.0.1:12345/file/")
|
||||||
require.Equal(t, "", msg.Sender) // Should never be returned
|
require.Equal(t, netip.Addr{}, msg.Sender) // Should never be returned
|
||||||
require.FileExists(t, filepath.Join(s.config.AttachmentCacheDir, msg.ID))
|
require.FileExists(t, filepath.Join(s.config.AttachmentCacheDir, msg.ID))
|
||||||
|
|
||||||
path := strings.TrimPrefix(msg.Attachment.URL, "http://127.0.0.1:12345")
|
path := strings.TrimPrefix(msg.Attachment.URL, "http://127.0.0.1:12345")
|
||||||
|
@ -1195,7 +1197,7 @@ func TestServer_PublishAttachmentExternalWithoutFilename(t *testing.T) {
|
||||||
require.Equal(t, "", msg.Attachment.Type)
|
require.Equal(t, "", msg.Attachment.Type)
|
||||||
require.Equal(t, int64(0), msg.Attachment.Size)
|
require.Equal(t, int64(0), msg.Attachment.Size)
|
||||||
require.Equal(t, int64(0), msg.Attachment.Expires)
|
require.Equal(t, int64(0), msg.Attachment.Expires)
|
||||||
require.Equal(t, "", msg.Sender)
|
require.Equal(t, netip.Addr{}, msg.Sender)
|
||||||
|
|
||||||
// Slightly unrelated cross-test: make sure we don't add an owner for external attachments
|
// Slightly unrelated cross-test: make sure we don't add an owner for external attachments
|
||||||
size, err := s.messageCache.AttachmentBytesUsed("127.0.0.1")
|
size, err := s.messageCache.AttachmentBytesUsed("127.0.0.1")
|
||||||
|
@ -1216,7 +1218,7 @@ func TestServer_PublishAttachmentExternalWithFilename(t *testing.T) {
|
||||||
require.Equal(t, "", msg.Attachment.Type)
|
require.Equal(t, "", msg.Attachment.Type)
|
||||||
require.Equal(t, int64(0), msg.Attachment.Size)
|
require.Equal(t, int64(0), msg.Attachment.Size)
|
||||||
require.Equal(t, int64(0), msg.Attachment.Expires)
|
require.Equal(t, int64(0), msg.Attachment.Expires)
|
||||||
require.Equal(t, "", msg.Sender)
|
require.Equal(t, netip.Addr{}, msg.Sender)
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestServer_PublishAttachmentBadURL(t *testing.T) {
|
func TestServer_PublishAttachmentBadURL(t *testing.T) {
|
||||||
|
@ -1391,7 +1393,7 @@ func TestServer_Visitor_XForwardedFor_None(t *testing.T) {
|
||||||
r.RemoteAddr = "8.9.10.11"
|
r.RemoteAddr = "8.9.10.11"
|
||||||
r.Header.Set("X-Forwarded-For", " ") // Spaces, not empty!
|
r.Header.Set("X-Forwarded-For", " ") // Spaces, not empty!
|
||||||
v := s.visitor(r)
|
v := s.visitor(r)
|
||||||
require.Equal(t, "8.9.10.11", v.ip)
|
require.Equal(t, "8.9.10.11", v.ip.String())
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestServer_Visitor_XForwardedFor_Single(t *testing.T) {
|
func TestServer_Visitor_XForwardedFor_Single(t *testing.T) {
|
||||||
|
@ -1402,7 +1404,7 @@ func TestServer_Visitor_XForwardedFor_Single(t *testing.T) {
|
||||||
r.RemoteAddr = "8.9.10.11"
|
r.RemoteAddr = "8.9.10.11"
|
||||||
r.Header.Set("X-Forwarded-For", "1.1.1.1")
|
r.Header.Set("X-Forwarded-For", "1.1.1.1")
|
||||||
v := s.visitor(r)
|
v := s.visitor(r)
|
||||||
require.Equal(t, "1.1.1.1", v.ip)
|
require.Equal(t, "1.1.1.1", v.ip.String())
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestServer_Visitor_XForwardedFor_Multiple(t *testing.T) {
|
func TestServer_Visitor_XForwardedFor_Multiple(t *testing.T) {
|
||||||
|
@ -1413,7 +1415,7 @@ func TestServer_Visitor_XForwardedFor_Multiple(t *testing.T) {
|
||||||
r.RemoteAddr = "8.9.10.11"
|
r.RemoteAddr = "8.9.10.11"
|
||||||
r.Header.Set("X-Forwarded-For", "1.2.3.4 , 2.4.4.2,234.5.2.1 ")
|
r.Header.Set("X-Forwarded-For", "1.2.3.4 , 2.4.4.2,234.5.2.1 ")
|
||||||
v := s.visitor(r)
|
v := s.visitor(r)
|
||||||
require.Equal(t, "234.5.2.1", v.ip)
|
require.Equal(t, "234.5.2.1", v.ip.String())
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestServer_PublishWhileUpdatingStatsWithLotsOfMessages(t *testing.T) {
|
func TestServer_PublishWhileUpdatingStatsWithLotsOfMessages(t *testing.T) {
|
||||||
|
|
|
@ -32,7 +32,7 @@ func (s *smtpSender) Send(v *visitor, m *message, to string) error {
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
message, err := formatMail(s.config.BaseURL, v.ip, s.config.SMTPSenderFrom, to, m)
|
message, err := formatMail(s.config.BaseURL, v.ip.String(), s.config.SMTPSenderFrom, to, m)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
|
@ -1,9 +1,11 @@
|
||||||
package server
|
package server
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"heckel.io/ntfy/util"
|
|
||||||
"net/http"
|
"net/http"
|
||||||
|
"net/netip"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
"heckel.io/ntfy/util"
|
||||||
)
|
)
|
||||||
|
|
||||||
// List of possible events
|
// List of possible events
|
||||||
|
@ -33,7 +35,7 @@ type message struct {
|
||||||
Actions []*action `json:"actions,omitempty"`
|
Actions []*action `json:"actions,omitempty"`
|
||||||
Attachment *attachment `json:"attachment,omitempty"`
|
Attachment *attachment `json:"attachment,omitempty"`
|
||||||
PollID string `json:"poll_id,omitempty"`
|
PollID string `json:"poll_id,omitempty"`
|
||||||
Sender string `json:"-"` // IP address of uploader, used for rate limiting
|
Sender netip.Addr `json:"-"` // IP address of uploader, used for rate limiting
|
||||||
Encoding string `json:"encoding,omitempty"` // empty for raw UTF-8, or "base64" for encoded bytes
|
Encoding string `json:"encoding,omitempty"` // empty for raw UTF-8, or "base64" for encoded bytes
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -2,10 +2,12 @@ package server
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"errors"
|
"errors"
|
||||||
"golang.org/x/time/rate"
|
"net/netip"
|
||||||
"heckel.io/ntfy/util"
|
|
||||||
"sync"
|
"sync"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
"golang.org/x/time/rate"
|
||||||
|
"heckel.io/ntfy/util"
|
||||||
)
|
)
|
||||||
|
|
||||||
const (
|
const (
|
||||||
|
@ -23,7 +25,7 @@ var (
|
||||||
type visitor struct {
|
type visitor struct {
|
||||||
config *Config
|
config *Config
|
||||||
messageCache *messageCache
|
messageCache *messageCache
|
||||||
ip string
|
ip netip.Addr
|
||||||
requests *rate.Limiter
|
requests *rate.Limiter
|
||||||
emails *rate.Limiter
|
emails *rate.Limiter
|
||||||
subscriptions util.Limiter
|
subscriptions util.Limiter
|
||||||
|
@ -40,7 +42,7 @@ type visitorStats struct {
|
||||||
VisitorAttachmentBytesRemaining int64 `json:"visitorAttachmentBytesRemaining"`
|
VisitorAttachmentBytesRemaining int64 `json:"visitorAttachmentBytesRemaining"`
|
||||||
}
|
}
|
||||||
|
|
||||||
func newVisitor(conf *Config, messageCache *messageCache, ip string) *visitor {
|
func newVisitor(conf *Config, messageCache *messageCache, ip netip.Addr) *visitor {
|
||||||
return &visitor{
|
return &visitor{
|
||||||
config: conf,
|
config: conf,
|
||||||
messageCache: messageCache,
|
messageCache: messageCache,
|
||||||
|
@ -115,7 +117,7 @@ func (v *visitor) Stale() bool {
|
||||||
}
|
}
|
||||||
|
|
||||||
func (v *visitor) Stats() (*visitorStats, error) {
|
func (v *visitor) Stats() (*visitorStats, error) {
|
||||||
attachmentsBytesUsed, err := v.messageCache.AttachmentBytesUsed(v.ip)
|
attachmentsBytesUsed, err := v.messageCache.AttachmentBytesUsed(v.ip.String())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
16
util/util.go
16
util/util.go
|
@ -5,16 +5,18 @@ import (
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"github.com/gabriel-vasile/mimetype"
|
|
||||||
"golang.org/x/term"
|
|
||||||
"io"
|
"io"
|
||||||
"math/rand"
|
"math/rand"
|
||||||
|
"net/netip"
|
||||||
"os"
|
"os"
|
||||||
"regexp"
|
"regexp"
|
||||||
"strconv"
|
"strconv"
|
||||||
"strings"
|
"strings"
|
||||||
"sync"
|
"sync"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
"github.com/gabriel-vasile/mimetype"
|
||||||
|
"golang.org/x/term"
|
||||||
)
|
)
|
||||||
|
|
||||||
const (
|
const (
|
||||||
|
@ -45,6 +47,16 @@ func Contains[T comparable](haystack []T, needle T) bool {
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// ContainsIP returns true if any one of the of prefixes contains the ip.
|
||||||
|
func ContainsIP(haystack []netip.Prefix, needle netip.Addr) bool {
|
||||||
|
for _, s := range haystack {
|
||||||
|
if s.Contains(needle) {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
// ContainsAll returns true if all needles are contained in haystack
|
// ContainsAll returns true if all needles are contained in haystack
|
||||||
func ContainsAll[T comparable](haystack []T, needles []T) bool {
|
func ContainsAll[T comparable](haystack []T, needles []T) bool {
|
||||||
matches := 0
|
matches := 0
|
||||||
|
|
|
@ -1,10 +1,12 @@
|
||||||
package util
|
package util
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"github.com/stretchr/testify/require"
|
"net/netip"
|
||||||
"os"
|
"os"
|
||||||
"path/filepath"
|
"path/filepath"
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
)
|
)
|
||||||
|
|
||||||
func TestRandomString(t *testing.T) {
|
func TestRandomString(t *testing.T) {
|
||||||
|
@ -42,6 +44,13 @@ func TestContains(t *testing.T) {
|
||||||
require.False(t, Contains(s, 3))
|
require.False(t, Contains(s, 3))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestContainsIP(t *testing.T) {
|
||||||
|
require.True(t, ContainsIP([]netip.Prefix{netip.MustParsePrefix("fd00::/8"), netip.MustParsePrefix("1.1.0.0/16")}, netip.MustParseAddr("1.1.1.1")))
|
||||||
|
require.True(t, ContainsIP([]netip.Prefix{netip.MustParsePrefix("fd00::/8"), netip.MustParsePrefix("1.1.0.0/16")}, netip.MustParseAddr("fd12:1234:5678::9876")))
|
||||||
|
require.False(t, ContainsIP([]netip.Prefix{netip.MustParsePrefix("fd00::/8"), netip.MustParsePrefix("1.1.0.0/16")}, netip.MustParseAddr("1.2.0.1")))
|
||||||
|
require.False(t, ContainsIP([]netip.Prefix{netip.MustParsePrefix("fd00::/8"), netip.MustParsePrefix("1.1.0.0/16")}, netip.MustParseAddr("fc00::1")))
|
||||||
|
}
|
||||||
|
|
||||||
func TestSplitNoEmpty(t *testing.T) {
|
func TestSplitNoEmpty(t *testing.T) {
|
||||||
require.Equal(t, []string{}, SplitNoEmpty("", ","))
|
require.Equal(t, []string{}, SplitNoEmpty("", ","))
|
||||||
require.Equal(t, []string{}, SplitNoEmpty(",,,", ","))
|
require.Equal(t, []string{}, SplitNoEmpty(",,,", ","))
|
||||||
|
|
Loading…
Reference in New Issue