Make Firebase logic testable, test it
parent
f9284a098a
commit
c80e4e1aa9
|
@ -32,22 +32,22 @@ import (
|
||||||
|
|
||||||
// Server is the main server, providing the UI and API for ntfy
|
// Server is the main server, providing the UI and API for ntfy
|
||||||
type Server struct {
|
type Server struct {
|
||||||
config *Config
|
config *Config
|
||||||
httpServer *http.Server
|
httpServer *http.Server
|
||||||
httpsServer *http.Server
|
httpsServer *http.Server
|
||||||
unixListener net.Listener
|
unixListener net.Listener
|
||||||
smtpServer *smtp.Server
|
smtpServer *smtp.Server
|
||||||
smtpBackend *smtpBackend
|
smtpBackend *smtpBackend
|
||||||
topics map[string]*topic
|
topics map[string]*topic
|
||||||
visitors map[string]*visitor
|
visitors map[string]*visitor
|
||||||
firebase subscriber
|
firebaseClient *firebaseClient
|
||||||
mailer mailer
|
mailer mailer
|
||||||
messages int64
|
messages int64
|
||||||
auth auth.Auther
|
auth auth.Auther
|
||||||
messageCache *messageCache
|
messageCache *messageCache
|
||||||
fileCache *fileCache
|
fileCache *fileCache
|
||||||
closeChan chan bool
|
closeChan chan bool
|
||||||
mu sync.Mutex
|
mu sync.Mutex
|
||||||
}
|
}
|
||||||
|
|
||||||
// handleFunc extends the normal http.HandlerFunc to be able to easily return errors
|
// handleFunc extends the normal http.HandlerFunc to be able to easily return errors
|
||||||
|
@ -134,23 +134,23 @@ func New(conf *Config) (*Server, error) {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
var firebaseSubscriber subscriber
|
var firebaseClient *firebaseClient
|
||||||
if conf.FirebaseKeyFile != "" {
|
if conf.FirebaseKeyFile != "" {
|
||||||
var err error
|
sender, err := newFirebaseSender(conf.FirebaseKeyFile)
|
||||||
firebaseSubscriber, err = createFirebaseSubscriber(conf.FirebaseKeyFile, auther)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
firebaseClient = newFirebaseClient(sender, auther)
|
||||||
}
|
}
|
||||||
return &Server{
|
return &Server{
|
||||||
config: conf,
|
config: conf,
|
||||||
messageCache: messageCache,
|
messageCache: messageCache,
|
||||||
fileCache: fileCache,
|
fileCache: fileCache,
|
||||||
firebase: firebaseSubscriber,
|
firebaseClient: firebaseClient,
|
||||||
mailer: mailer,
|
mailer: mailer,
|
||||||
topics: topics,
|
topics: topics,
|
||||||
auth: auther,
|
auth: auther,
|
||||||
visitors: make(map[string]*visitor),
|
visitors: make(map[string]*visitor),
|
||||||
}, nil
|
}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -437,7 +437,7 @@ func (s *Server) handlePublish(w http.ResponseWriter, r *http.Request, v *visito
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
if s.firebase != nil && firebase && !delayed {
|
if s.firebaseClient != nil && firebase && !delayed {
|
||||||
go s.sendToFirebase(v, m)
|
go s.sendToFirebase(v, m)
|
||||||
}
|
}
|
||||||
if s.mailer != nil && email != "" && !delayed {
|
if s.mailer != nil && email != "" && !delayed {
|
||||||
|
@ -463,7 +463,7 @@ func (s *Server) handlePublish(w http.ResponseWriter, r *http.Request, v *visito
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *Server) sendToFirebase(v *visitor, m *message) {
|
func (s *Server) sendToFirebase(v *visitor, m *message) {
|
||||||
if err := s.firebase(v, m); err != nil {
|
if err := s.firebaseClient.Send(v, m); err != nil {
|
||||||
log.Printf("[%s] FB - Unable to publish to Firebase: %v", v.ip, err.Error())
|
log.Printf("[%s] FB - Unable to publish to Firebase: %v", v.ip, err.Error())
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -1096,20 +1096,16 @@ func (s *Server) runDelayedSender() {
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *Server) runFirebaseKeepaliver() {
|
func (s *Server) runFirebaseKeepaliver() {
|
||||||
if s.firebase == nil {
|
if s.firebaseClient == nil {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
v := newVisitor(s.config, s.messageCache, "0.0.0.0")
|
v := newVisitor(s.config, s.messageCache, "0.0.0.0") // Background process, not a real visitor
|
||||||
for {
|
for {
|
||||||
select {
|
select {
|
||||||
case <-time.After(s.config.FirebaseKeepaliveInterval):
|
case <-time.After(s.config.FirebaseKeepaliveInterval):
|
||||||
if err := s.firebase(v, newKeepaliveMessage(firebaseControlTopic)); err != nil {
|
s.sendToFirebase(v, newKeepaliveMessage(firebaseControlTopic))
|
||||||
log.Printf("error sending Firebase keepalive message to %s: %s", firebaseControlTopic, err.Error())
|
|
||||||
}
|
|
||||||
case <-time.After(s.config.FirebasePollInterval):
|
case <-time.After(s.config.FirebasePollInterval):
|
||||||
if err := s.firebase(v, newKeepaliveMessage(firebasePollTopic)); err != nil {
|
s.sendToFirebase(v, newKeepaliveMessage(firebasePollTopic))
|
||||||
log.Printf("error sending Firebase keepalive message to %s: %s", firebasePollTopic, err.Error())
|
|
||||||
}
|
|
||||||
case <-s.closeChan:
|
case <-s.closeChan:
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
@ -1142,7 +1138,7 @@ func (s *Server) sendDelayedMessage(v *visitor, m *message) error {
|
||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
}
|
}
|
||||||
if s.firebase != nil { // Firebase subscribers may not show up in topics map
|
if s.firebaseClient != nil { // Firebase subscribers may not show up in topics map
|
||||||
go s.sendToFirebase(v, m)
|
go s.sendToFirebase(v, m)
|
||||||
}
|
}
|
||||||
if s.config.UpstreamBaseURL != "" {
|
if s.config.UpstreamBaseURL != "" {
|
||||||
|
|
|
@ -3,6 +3,7 @@ package server
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"log"
|
"log"
|
||||||
"strings"
|
"strings"
|
||||||
|
@ -18,33 +19,75 @@ const (
|
||||||
fcmApnsBodyMessageLimit = 100
|
fcmApnsBodyMessageLimit = 100
|
||||||
)
|
)
|
||||||
|
|
||||||
func createFirebaseSubscriber(credentialsFile string, auther auth.Auther) (subscriber, error) {
|
var (
|
||||||
|
errFirebaseQuotaExceeded = errors.New("Firebase quota exceeded")
|
||||||
|
)
|
||||||
|
|
||||||
|
// firebaseClient is a generic client that formats and sends messages to Firebase.
|
||||||
|
// The actual Firebase implementation is implemented in firebaseSenderImpl, to make it testable.
|
||||||
|
type firebaseClient struct {
|
||||||
|
sender firebaseSender
|
||||||
|
auther auth.Auther
|
||||||
|
}
|
||||||
|
|
||||||
|
func newFirebaseClient(sender firebaseSender, auther auth.Auther) *firebaseClient {
|
||||||
|
return &firebaseClient{
|
||||||
|
sender: sender,
|
||||||
|
auther: auther,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *firebaseClient) Send(v *visitor, m *message) error {
|
||||||
|
if err := v.FirebaseAllowed(); err != nil {
|
||||||
|
return errFirebaseQuotaExceeded
|
||||||
|
}
|
||||||
|
fbm, err := toFirebaseMessage(m, c.auther)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
err = c.sender.Send(fbm)
|
||||||
|
if err == errFirebaseQuotaExceeded {
|
||||||
|
log.Printf("[%s] FB quota exceeded for topic %s, temporarily denying FB access to visitor", v.ip, m.Topic)
|
||||||
|
v.FirebaseTemporarilyDeny()
|
||||||
|
}
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
// firebaseSender is an interface that represents a client that can send to Firebase Cloud Messaging.
|
||||||
|
// In tests, this can be implemented with a mock.
|
||||||
|
type firebaseSender interface {
|
||||||
|
// Send sends a message to Firebase, or returns an error. It returns errFirebaseQuotaExceeded
|
||||||
|
// if a rate limit has reached.
|
||||||
|
Send(m *messaging.Message) error
|
||||||
|
}
|
||||||
|
|
||||||
|
// firebaseSenderImpl is a firebaseSender that actually talks to Firebase
|
||||||
|
type firebaseSenderImpl struct {
|
||||||
|
client *messaging.Client
|
||||||
|
}
|
||||||
|
|
||||||
|
func newFirebaseSender(credentialsFile string) (*firebaseSenderImpl, error) {
|
||||||
fb, err := firebase.NewApp(context.Background(), nil, option.WithCredentialsFile(credentialsFile))
|
fb, err := firebase.NewApp(context.Background(), nil, option.WithCredentialsFile(credentialsFile))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
msg, err := fb.Messaging(context.Background())
|
client, err := fb.Messaging(context.Background())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
return func(v *visitor, m *message) error {
|
return &firebaseSenderImpl{
|
||||||
if err := v.FirebaseAllowed(); err != nil {
|
client: client,
|
||||||
return errHTTPTooManyRequestsFirebaseQuotaReached
|
|
||||||
}
|
|
||||||
fbm, err := toFirebaseMessage(m, auther)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
_, err = msg.Send(context.Background(), fbm)
|
|
||||||
if err != nil && messaging.IsQuotaExceeded(err) {
|
|
||||||
log.Printf("[%s] FB quota exceeded when trying to publish to topic %s, temporarily denying FB access", v.ip, m.Topic)
|
|
||||||
v.FirebaseTemporarilyDeny()
|
|
||||||
return errHTTPTooManyRequestsFirebaseQuotaReached
|
|
||||||
}
|
|
||||||
return err
|
|
||||||
}, nil
|
}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (c *firebaseSenderImpl) Send(m *messaging.Message) error {
|
||||||
|
_, err := c.client.Send(context.Background(), m)
|
||||||
|
if err != nil && messaging.IsQuotaExceeded(err) {
|
||||||
|
return errFirebaseQuotaExceeded
|
||||||
|
}
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
// toFirebaseMessage converts a message to a Firebase message.
|
// toFirebaseMessage converts a message to a Firebase message.
|
||||||
//
|
//
|
||||||
// Normal messages ("message"):
|
// Normal messages ("message"):
|
||||||
|
|
|
@ -26,6 +26,25 @@ func (t testAuther) Authorize(_ *auth.User, _ string, _ auth.Permission) error {
|
||||||
return errors.New("unauthorized")
|
return errors.New("unauthorized")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type testFirebaseSender struct {
|
||||||
|
allowed int
|
||||||
|
messages []*messaging.Message
|
||||||
|
}
|
||||||
|
|
||||||
|
func newTestFirebaseSender(allowed int) *testFirebaseSender {
|
||||||
|
return &testFirebaseSender{
|
||||||
|
allowed: allowed,
|
||||||
|
messages: make([]*messaging.Message, 0),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
func (s *testFirebaseSender) Send(m *messaging.Message) error {
|
||||||
|
if len(s.messages)+1 > s.allowed {
|
||||||
|
return errFirebaseQuotaExceeded
|
||||||
|
}
|
||||||
|
s.messages = append(s.messages, m)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
func TestToFirebaseMessage_Keepalive(t *testing.T) {
|
func TestToFirebaseMessage_Keepalive(t *testing.T) {
|
||||||
m := newKeepaliveMessage("mytopic")
|
m := newKeepaliveMessage("mytopic")
|
||||||
fbm, err := toFirebaseMessage(m, nil)
|
fbm, err := toFirebaseMessage(m, nil)
|
||||||
|
@ -285,3 +304,22 @@ func TestMaybeTruncateFCMMessage_NotTooLong(t *testing.T) {
|
||||||
require.Equal(t, len(serializedOrigFCMMessage), len(serializedNotTruncatedFCMMessage))
|
require.Equal(t, len(serializedOrigFCMMessage), len(serializedNotTruncatedFCMMessage))
|
||||||
require.Equal(t, "", notTruncatedFCMMessage.Data["truncated"])
|
require.Equal(t, "", notTruncatedFCMMessage.Data["truncated"])
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestToFirebaseSender_Abuse(t *testing.T) {
|
||||||
|
sender := &testFirebaseSender{allowed: 2}
|
||||||
|
client := newFirebaseClient(sender, &testAuther{})
|
||||||
|
visitor := newVisitor(newTestConfig(t), newMemTestCache(t), "1.2.3.4")
|
||||||
|
|
||||||
|
require.Nil(t, client.Send(visitor, &message{Topic: "mytopic"}))
|
||||||
|
require.Equal(t, 1, len(sender.messages))
|
||||||
|
|
||||||
|
require.Nil(t, client.Send(visitor, &message{Topic: "mytopic"}))
|
||||||
|
require.Equal(t, 2, len(sender.messages))
|
||||||
|
|
||||||
|
require.Equal(t, errFirebaseQuotaExceeded, client.Send(visitor, &message{Topic: "mytopic"}))
|
||||||
|
require.Equal(t, 2, len(sender.messages))
|
||||||
|
|
||||||
|
sender.messages = make([]*messaging.Message, 0) // Reset to test that time limit is working
|
||||||
|
require.Equal(t, errFirebaseQuotaExceeded, client.Send(visitor, &message{Topic: "mytopic"}))
|
||||||
|
require.Equal(t, 0, len(sender.messages))
|
||||||
|
}
|
||||||
|
|
|
@ -9,7 +9,6 @@ import (
|
||||||
"math/rand"
|
"math/rand"
|
||||||
"net/http"
|
"net/http"
|
||||||
"net/http/httptest"
|
"net/http/httptest"
|
||||||
"os"
|
|
||||||
"path/filepath"
|
"path/filepath"
|
||||||
"strings"
|
"strings"
|
||||||
"sync"
|
"sync"
|
||||||
|
@ -55,6 +54,21 @@ func TestServer_PublishAndPoll(t *testing.T) {
|
||||||
require.Equal(t, "my second message", lines[1]) // \n -> " "
|
require.Equal(t, "my second message", lines[1]) // \n -> " "
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestServer_PublishWithFirebase(t *testing.T) {
|
||||||
|
sender := newTestFirebaseSender(10)
|
||||||
|
s := newTestServer(t, newTestConfig(t))
|
||||||
|
s.firebaseClient = newFirebaseClient(sender, &testAuther{Allow: true})
|
||||||
|
|
||||||
|
response := request(t, s, "PUT", "/mytopic", "my first message", nil)
|
||||||
|
msg1 := toMessage(t, response.Body.String())
|
||||||
|
require.NotEmpty(t, msg1.ID)
|
||||||
|
require.Equal(t, "my first message", msg1.Message)
|
||||||
|
require.Equal(t, 1, len(sender.messages))
|
||||||
|
require.Equal(t, "my first message", sender.messages[0].Data["message"])
|
||||||
|
require.Equal(t, "my first message", sender.messages[0].APNS.Payload.Aps.Alert.Body)
|
||||||
|
require.Equal(t, "my first message", sender.messages[0].APNS.Payload.CustomData["message"])
|
||||||
|
}
|
||||||
|
|
||||||
func TestServer_SubscribeOpenAndKeepalive(t *testing.T) {
|
func TestServer_SubscribeOpenAndKeepalive(t *testing.T) {
|
||||||
c := newTestConfig(t)
|
c := newTestConfig(t)
|
||||||
c.KeepaliveInterval = time.Second
|
c.KeepaliveInterval = time.Second
|
||||||
|
@ -461,27 +475,6 @@ func TestServer_PublishMessageInHeaderWithNewlines(t *testing.T) {
|
||||||
require.Equal(t, "Line 1\nLine 2", msg.Message) // \\n -> \n !
|
require.Equal(t, "Line 1\nLine 2", msg.Message) // \\n -> \n !
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestServer_PublishFirebase(t *testing.T) {
|
|
||||||
// This is unfortunately not much of a test, since it merely fires the messages towards Firebase,
|
|
||||||
// but cannot re-read them. There is no way from Go to read the messages back, or even get an error back.
|
|
||||||
// I tried everything. I already had written the test, and it increases the code coverage, so I'll leave it ... :shrug: ...
|
|
||||||
|
|
||||||
c := newTestConfig(t)
|
|
||||||
c.FirebaseKeyFile = firebaseServiceAccountFile(t) // May skip the test!
|
|
||||||
s := newTestServer(t, c)
|
|
||||||
|
|
||||||
// Normal message
|
|
||||||
response := request(t, s, "PUT", "/mytopic", "This is a message for firebase", nil)
|
|
||||||
msg := toMessage(t, response.Body.String())
|
|
||||||
require.NotEmpty(t, msg.ID)
|
|
||||||
|
|
||||||
// Keepalive message
|
|
||||||
v := newVisitor(s.config, s.messageCache, "1.2.3.4")
|
|
||||||
require.Nil(t, s.firebase(v, newKeepaliveMessage(firebaseControlTopic)))
|
|
||||||
|
|
||||||
time.Sleep(500 * time.Millisecond) // Time for sends
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestServer_PublishInvalidTopic(t *testing.T) {
|
func TestServer_PublishInvalidTopic(t *testing.T) {
|
||||||
s := newTestServer(t, newTestConfig(t))
|
s := newTestServer(t, newTestConfig(t))
|
||||||
s.mailer = &testMailer{}
|
s.mailer = &testMailer{}
|
||||||
|
@ -1341,18 +1334,6 @@ func toHTTPError(t *testing.T, s string) *errHTTP {
|
||||||
return &e
|
return &e
|
||||||
}
|
}
|
||||||
|
|
||||||
func firebaseServiceAccountFile(t *testing.T) string {
|
|
||||||
if os.Getenv("NTFY_TEST_FIREBASE_SERVICE_ACCOUNT_FILE") != "" {
|
|
||||||
return os.Getenv("NTFY_TEST_FIREBASE_SERVICE_ACCOUNT_FILE")
|
|
||||||
} else if os.Getenv("NTFY_TEST_FIREBASE_SERVICE_ACCOUNT") != "" {
|
|
||||||
filename := filepath.Join(t.TempDir(), "firebase.json")
|
|
||||||
require.NotNil(t, os.WriteFile(filename, []byte(os.Getenv("NTFY_TEST_FIREBASE_SERVICE_ACCOUNT")), 0o600))
|
|
||||||
return filename
|
|
||||||
}
|
|
||||||
t.SkipNow()
|
|
||||||
return ""
|
|
||||||
}
|
|
||||||
|
|
||||||
func basicAuth(s string) string {
|
func basicAuth(s string) string {
|
||||||
return fmt.Sprintf("Basic %s", base64.StdEncoding.EncodeToString([]byte(s)))
|
return fmt.Sprintf("Basic %s", base64.StdEncoding.EncodeToString([]byte(s)))
|
||||||
}
|
}
|
||||||
|
|
Loading…
Reference in New Issue