Make Firebase logic testable, test it
parent
f9284a098a
commit
c80e4e1aa9
|
@ -32,22 +32,22 @@ import (
|
|||
|
||||
// Server is the main server, providing the UI and API for ntfy
|
||||
type Server struct {
|
||||
config *Config
|
||||
httpServer *http.Server
|
||||
httpsServer *http.Server
|
||||
unixListener net.Listener
|
||||
smtpServer *smtp.Server
|
||||
smtpBackend *smtpBackend
|
||||
topics map[string]*topic
|
||||
visitors map[string]*visitor
|
||||
firebase subscriber
|
||||
mailer mailer
|
||||
messages int64
|
||||
auth auth.Auther
|
||||
messageCache *messageCache
|
||||
fileCache *fileCache
|
||||
closeChan chan bool
|
||||
mu sync.Mutex
|
||||
config *Config
|
||||
httpServer *http.Server
|
||||
httpsServer *http.Server
|
||||
unixListener net.Listener
|
||||
smtpServer *smtp.Server
|
||||
smtpBackend *smtpBackend
|
||||
topics map[string]*topic
|
||||
visitors map[string]*visitor
|
||||
firebaseClient *firebaseClient
|
||||
mailer mailer
|
||||
messages int64
|
||||
auth auth.Auther
|
||||
messageCache *messageCache
|
||||
fileCache *fileCache
|
||||
closeChan chan bool
|
||||
mu sync.Mutex
|
||||
}
|
||||
|
||||
// handleFunc extends the normal http.HandlerFunc to be able to easily return errors
|
||||
|
@ -134,23 +134,23 @@ func New(conf *Config) (*Server, error) {
|
|||
return nil, err
|
||||
}
|
||||
}
|
||||
var firebaseSubscriber subscriber
|
||||
var firebaseClient *firebaseClient
|
||||
if conf.FirebaseKeyFile != "" {
|
||||
var err error
|
||||
firebaseSubscriber, err = createFirebaseSubscriber(conf.FirebaseKeyFile, auther)
|
||||
sender, err := newFirebaseSender(conf.FirebaseKeyFile)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
firebaseClient = newFirebaseClient(sender, auther)
|
||||
}
|
||||
return &Server{
|
||||
config: conf,
|
||||
messageCache: messageCache,
|
||||
fileCache: fileCache,
|
||||
firebase: firebaseSubscriber,
|
||||
mailer: mailer,
|
||||
topics: topics,
|
||||
auth: auther,
|
||||
visitors: make(map[string]*visitor),
|
||||
config: conf,
|
||||
messageCache: messageCache,
|
||||
fileCache: fileCache,
|
||||
firebaseClient: firebaseClient,
|
||||
mailer: mailer,
|
||||
topics: topics,
|
||||
auth: auther,
|
||||
visitors: make(map[string]*visitor),
|
||||
}, nil
|
||||
}
|
||||
|
||||
|
@ -437,7 +437,7 @@ func (s *Server) handlePublish(w http.ResponseWriter, r *http.Request, v *visito
|
|||
return err
|
||||
}
|
||||
}
|
||||
if s.firebase != nil && firebase && !delayed {
|
||||
if s.firebaseClient != nil && firebase && !delayed {
|
||||
go s.sendToFirebase(v, m)
|
||||
}
|
||||
if s.mailer != nil && email != "" && !delayed {
|
||||
|
@ -463,7 +463,7 @@ func (s *Server) handlePublish(w http.ResponseWriter, r *http.Request, v *visito
|
|||
}
|
||||
|
||||
func (s *Server) sendToFirebase(v *visitor, m *message) {
|
||||
if err := s.firebase(v, m); err != nil {
|
||||
if err := s.firebaseClient.Send(v, m); err != nil {
|
||||
log.Printf("[%s] FB - Unable to publish to Firebase: %v", v.ip, err.Error())
|
||||
}
|
||||
}
|
||||
|
@ -1096,20 +1096,16 @@ func (s *Server) runDelayedSender() {
|
|||
}
|
||||
|
||||
func (s *Server) runFirebaseKeepaliver() {
|
||||
if s.firebase == nil {
|
||||
if s.firebaseClient == nil {
|
||||
return
|
||||
}
|
||||
v := newVisitor(s.config, s.messageCache, "0.0.0.0")
|
||||
v := newVisitor(s.config, s.messageCache, "0.0.0.0") // Background process, not a real visitor
|
||||
for {
|
||||
select {
|
||||
case <-time.After(s.config.FirebaseKeepaliveInterval):
|
||||
if err := s.firebase(v, newKeepaliveMessage(firebaseControlTopic)); err != nil {
|
||||
log.Printf("error sending Firebase keepalive message to %s: %s", firebaseControlTopic, err.Error())
|
||||
}
|
||||
s.sendToFirebase(v, newKeepaliveMessage(firebaseControlTopic))
|
||||
case <-time.After(s.config.FirebasePollInterval):
|
||||
if err := s.firebase(v, newKeepaliveMessage(firebasePollTopic)); err != nil {
|
||||
log.Printf("error sending Firebase keepalive message to %s: %s", firebasePollTopic, err.Error())
|
||||
}
|
||||
s.sendToFirebase(v, newKeepaliveMessage(firebasePollTopic))
|
||||
case <-s.closeChan:
|
||||
return
|
||||
}
|
||||
|
@ -1142,7 +1138,7 @@ func (s *Server) sendDelayedMessage(v *visitor, m *message) error {
|
|||
}
|
||||
}()
|
||||
}
|
||||
if s.firebase != nil { // Firebase subscribers may not show up in topics map
|
||||
if s.firebaseClient != nil { // Firebase subscribers may not show up in topics map
|
||||
go s.sendToFirebase(v, m)
|
||||
}
|
||||
if s.config.UpstreamBaseURL != "" {
|
||||
|
|
|
@ -3,6 +3,7 @@ package server
|
|||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"log"
|
||||
"strings"
|
||||
|
@ -18,33 +19,75 @@ const (
|
|||
fcmApnsBodyMessageLimit = 100
|
||||
)
|
||||
|
||||
func createFirebaseSubscriber(credentialsFile string, auther auth.Auther) (subscriber, error) {
|
||||
var (
|
||||
errFirebaseQuotaExceeded = errors.New("Firebase quota exceeded")
|
||||
)
|
||||
|
||||
// firebaseClient is a generic client that formats and sends messages to Firebase.
|
||||
// The actual Firebase implementation is implemented in firebaseSenderImpl, to make it testable.
|
||||
type firebaseClient struct {
|
||||
sender firebaseSender
|
||||
auther auth.Auther
|
||||
}
|
||||
|
||||
func newFirebaseClient(sender firebaseSender, auther auth.Auther) *firebaseClient {
|
||||
return &firebaseClient{
|
||||
sender: sender,
|
||||
auther: auther,
|
||||
}
|
||||
}
|
||||
|
||||
func (c *firebaseClient) Send(v *visitor, m *message) error {
|
||||
if err := v.FirebaseAllowed(); err != nil {
|
||||
return errFirebaseQuotaExceeded
|
||||
}
|
||||
fbm, err := toFirebaseMessage(m, c.auther)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
err = c.sender.Send(fbm)
|
||||
if err == errFirebaseQuotaExceeded {
|
||||
log.Printf("[%s] FB quota exceeded for topic %s, temporarily denying FB access to visitor", v.ip, m.Topic)
|
||||
v.FirebaseTemporarilyDeny()
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
// firebaseSender is an interface that represents a client that can send to Firebase Cloud Messaging.
|
||||
// In tests, this can be implemented with a mock.
|
||||
type firebaseSender interface {
|
||||
// Send sends a message to Firebase, or returns an error. It returns errFirebaseQuotaExceeded
|
||||
// if a rate limit has reached.
|
||||
Send(m *messaging.Message) error
|
||||
}
|
||||
|
||||
// firebaseSenderImpl is a firebaseSender that actually talks to Firebase
|
||||
type firebaseSenderImpl struct {
|
||||
client *messaging.Client
|
||||
}
|
||||
|
||||
func newFirebaseSender(credentialsFile string) (*firebaseSenderImpl, error) {
|
||||
fb, err := firebase.NewApp(context.Background(), nil, option.WithCredentialsFile(credentialsFile))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
msg, err := fb.Messaging(context.Background())
|
||||
client, err := fb.Messaging(context.Background())
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return func(v *visitor, m *message) error {
|
||||
if err := v.FirebaseAllowed(); err != nil {
|
||||
return errHTTPTooManyRequestsFirebaseQuotaReached
|
||||
}
|
||||
fbm, err := toFirebaseMessage(m, auther)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
_, err = msg.Send(context.Background(), fbm)
|
||||
if err != nil && messaging.IsQuotaExceeded(err) {
|
||||
log.Printf("[%s] FB quota exceeded when trying to publish to topic %s, temporarily denying FB access", v.ip, m.Topic)
|
||||
v.FirebaseTemporarilyDeny()
|
||||
return errHTTPTooManyRequestsFirebaseQuotaReached
|
||||
}
|
||||
return err
|
||||
return &firebaseSenderImpl{
|
||||
client: client,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (c *firebaseSenderImpl) Send(m *messaging.Message) error {
|
||||
_, err := c.client.Send(context.Background(), m)
|
||||
if err != nil && messaging.IsQuotaExceeded(err) {
|
||||
return errFirebaseQuotaExceeded
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
// toFirebaseMessage converts a message to a Firebase message.
|
||||
//
|
||||
// Normal messages ("message"):
|
||||
|
|
|
@ -26,6 +26,25 @@ func (t testAuther) Authorize(_ *auth.User, _ string, _ auth.Permission) error {
|
|||
return errors.New("unauthorized")
|
||||
}
|
||||
|
||||
type testFirebaseSender struct {
|
||||
allowed int
|
||||
messages []*messaging.Message
|
||||
}
|
||||
|
||||
func newTestFirebaseSender(allowed int) *testFirebaseSender {
|
||||
return &testFirebaseSender{
|
||||
allowed: allowed,
|
||||
messages: make([]*messaging.Message, 0),
|
||||
}
|
||||
}
|
||||
func (s *testFirebaseSender) Send(m *messaging.Message) error {
|
||||
if len(s.messages)+1 > s.allowed {
|
||||
return errFirebaseQuotaExceeded
|
||||
}
|
||||
s.messages = append(s.messages, m)
|
||||
return nil
|
||||
}
|
||||
|
||||
func TestToFirebaseMessage_Keepalive(t *testing.T) {
|
||||
m := newKeepaliveMessage("mytopic")
|
||||
fbm, err := toFirebaseMessage(m, nil)
|
||||
|
@ -285,3 +304,22 @@ func TestMaybeTruncateFCMMessage_NotTooLong(t *testing.T) {
|
|||
require.Equal(t, len(serializedOrigFCMMessage), len(serializedNotTruncatedFCMMessage))
|
||||
require.Equal(t, "", notTruncatedFCMMessage.Data["truncated"])
|
||||
}
|
||||
|
||||
func TestToFirebaseSender_Abuse(t *testing.T) {
|
||||
sender := &testFirebaseSender{allowed: 2}
|
||||
client := newFirebaseClient(sender, &testAuther{})
|
||||
visitor := newVisitor(newTestConfig(t), newMemTestCache(t), "1.2.3.4")
|
||||
|
||||
require.Nil(t, client.Send(visitor, &message{Topic: "mytopic"}))
|
||||
require.Equal(t, 1, len(sender.messages))
|
||||
|
||||
require.Nil(t, client.Send(visitor, &message{Topic: "mytopic"}))
|
||||
require.Equal(t, 2, len(sender.messages))
|
||||
|
||||
require.Equal(t, errFirebaseQuotaExceeded, client.Send(visitor, &message{Topic: "mytopic"}))
|
||||
require.Equal(t, 2, len(sender.messages))
|
||||
|
||||
sender.messages = make([]*messaging.Message, 0) // Reset to test that time limit is working
|
||||
require.Equal(t, errFirebaseQuotaExceeded, client.Send(visitor, &message{Topic: "mytopic"}))
|
||||
require.Equal(t, 0, len(sender.messages))
|
||||
}
|
||||
|
|
|
@ -9,7 +9,6 @@ import (
|
|||
"math/rand"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"sync"
|
||||
|
@ -55,6 +54,21 @@ func TestServer_PublishAndPoll(t *testing.T) {
|
|||
require.Equal(t, "my second message", lines[1]) // \n -> " "
|
||||
}
|
||||
|
||||
func TestServer_PublishWithFirebase(t *testing.T) {
|
||||
sender := newTestFirebaseSender(10)
|
||||
s := newTestServer(t, newTestConfig(t))
|
||||
s.firebaseClient = newFirebaseClient(sender, &testAuther{Allow: true})
|
||||
|
||||
response := request(t, s, "PUT", "/mytopic", "my first message", nil)
|
||||
msg1 := toMessage(t, response.Body.String())
|
||||
require.NotEmpty(t, msg1.ID)
|
||||
require.Equal(t, "my first message", msg1.Message)
|
||||
require.Equal(t, 1, len(sender.messages))
|
||||
require.Equal(t, "my first message", sender.messages[0].Data["message"])
|
||||
require.Equal(t, "my first message", sender.messages[0].APNS.Payload.Aps.Alert.Body)
|
||||
require.Equal(t, "my first message", sender.messages[0].APNS.Payload.CustomData["message"])
|
||||
}
|
||||
|
||||
func TestServer_SubscribeOpenAndKeepalive(t *testing.T) {
|
||||
c := newTestConfig(t)
|
||||
c.KeepaliveInterval = time.Second
|
||||
|
@ -461,27 +475,6 @@ func TestServer_PublishMessageInHeaderWithNewlines(t *testing.T) {
|
|||
require.Equal(t, "Line 1\nLine 2", msg.Message) // \\n -> \n !
|
||||
}
|
||||
|
||||
func TestServer_PublishFirebase(t *testing.T) {
|
||||
// This is unfortunately not much of a test, since it merely fires the messages towards Firebase,
|
||||
// but cannot re-read them. There is no way from Go to read the messages back, or even get an error back.
|
||||
// I tried everything. I already had written the test, and it increases the code coverage, so I'll leave it ... :shrug: ...
|
||||
|
||||
c := newTestConfig(t)
|
||||
c.FirebaseKeyFile = firebaseServiceAccountFile(t) // May skip the test!
|
||||
s := newTestServer(t, c)
|
||||
|
||||
// Normal message
|
||||
response := request(t, s, "PUT", "/mytopic", "This is a message for firebase", nil)
|
||||
msg := toMessage(t, response.Body.String())
|
||||
require.NotEmpty(t, msg.ID)
|
||||
|
||||
// Keepalive message
|
||||
v := newVisitor(s.config, s.messageCache, "1.2.3.4")
|
||||
require.Nil(t, s.firebase(v, newKeepaliveMessage(firebaseControlTopic)))
|
||||
|
||||
time.Sleep(500 * time.Millisecond) // Time for sends
|
||||
}
|
||||
|
||||
func TestServer_PublishInvalidTopic(t *testing.T) {
|
||||
s := newTestServer(t, newTestConfig(t))
|
||||
s.mailer = &testMailer{}
|
||||
|
@ -1341,18 +1334,6 @@ func toHTTPError(t *testing.T, s string) *errHTTP {
|
|||
return &e
|
||||
}
|
||||
|
||||
func firebaseServiceAccountFile(t *testing.T) string {
|
||||
if os.Getenv("NTFY_TEST_FIREBASE_SERVICE_ACCOUNT_FILE") != "" {
|
||||
return os.Getenv("NTFY_TEST_FIREBASE_SERVICE_ACCOUNT_FILE")
|
||||
} else if os.Getenv("NTFY_TEST_FIREBASE_SERVICE_ACCOUNT") != "" {
|
||||
filename := filepath.Join(t.TempDir(), "firebase.json")
|
||||
require.NotNil(t, os.WriteFile(filename, []byte(os.Getenv("NTFY_TEST_FIREBASE_SERVICE_ACCOUNT")), 0o600))
|
||||
return filename
|
||||
}
|
||||
t.SkipNow()
|
||||
return ""
|
||||
}
|
||||
|
||||
func basicAuth(s string) string {
|
||||
return fmt.Sprintf("Basic %s", base64.StdEncoding.EncodeToString([]byte(s)))
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue