Tests, client tests WIP

pull/60/head
Philipp Heckel 2021-12-22 14:17:50 +01:00
parent 68d881291c
commit 6a7e9071b6
7 changed files with 104 additions and 30 deletions

View File

@ -0,0 +1,42 @@
package client_test
import (
"github.com/stretchr/testify/require"
"heckel.io/ntfy/client"
"heckel.io/ntfy/server"
"net/http"
"testing"
"time"
)
func TestClient_Publish(t *testing.T) {
s := startTestServer(t)
defer s.Stop()
c := client.New(newTestConfig())
time.Sleep(time.Second) // FIXME Wait for port up
_, err := c.Publish("mytopic", "some message")
require.Nil(t, err)
}
func newTestConfig() *client.Config {
c := client.NewConfig()
c.DefaultHost = "http://127.0.0.1:12345"
return c
}
func startTestServer(t *testing.T) *server.Server {
conf := server.NewConfig()
conf.ListenHTTP = ":12345"
s, err := server.New(conf)
if err != nil {
t.Fatal(err)
}
go func() {
if err := s.Run(); err != nil && err != http.ErrServerClosed {
panic(err) // 'go vet' complains about 't.Fatal(err)'
}
}()
return s
}

View File

@ -1,7 +1,8 @@
package client package client_test
import ( import (
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
"heckel.io/ntfy/client"
"os" "os"
"path/filepath" "path/filepath"
"testing" "testing"
@ -21,7 +22,7 @@ subscribe:
priority: high,urgent priority: high,urgent
`), 0600)) `), 0600))
conf, err := LoadConfig(filename) conf, err := client.LoadConfig(filename)
require.Nil(t, err) require.Nil(t, err)
require.Equal(t, "http://localhost", conf.DefaultHost) require.Equal(t, "http://localhost", conf.DefaultHost)
require.Equal(t, 3, len(conf.Subscribe)) require.Equal(t, 3, len(conf.Subscribe))

View File

@ -85,7 +85,8 @@ func execServe(c *cli.Context) error {
} }
// Run server // Run server
conf := server.NewConfig(listenHTTP) conf := server.NewConfig()
conf.ListenHTTP = listenHTTP
conf.ListenHTTPS = listenHTTPS conf.ListenHTTPS = listenHTTPS
conf.KeyFile = keyFile conf.KeyFile = keyFile
conf.CertFile = certFile conf.CertFile = certFile

View File

@ -52,9 +52,9 @@ type Config struct {
} }
// NewConfig instantiates a default new server config // NewConfig instantiates a default new server config
func NewConfig(listenHTTP string) *Config { func NewConfig() *Config {
return &Config{ return &Config{
ListenHTTP: listenHTTP, ListenHTTP: DefaultListenHTTP,
ListenHTTPS: "", ListenHTTPS: "",
KeyFile: "", KeyFile: "",
CertFile: "", CertFile: "",

View File

@ -7,6 +7,7 @@ import (
) )
func TestConfig_New(t *testing.T) { func TestConfig_New(t *testing.T) {
c := server.NewConfig(":1234") c := server.NewConfig()
assert.Equal(t, ":1234", c.ListenHTTP) assert.Equal(t, ":80", c.ListenHTTP)
assert.Equal(t, server.DefaultKeepaliveInterval, c.KeepaliveInterval)
} }

View File

@ -27,13 +27,16 @@ import (
// Server is the main server, providing the UI and API for ntfy // Server is the main server, providing the UI and API for ntfy
type Server struct { type Server struct {
config *Config config *Config
topics map[string]*topic httpServer *http.Server
visitors map[string]*visitor httpsServer *http.Server
firebase subscriber topics map[string]*topic
messages int64 visitors map[string]*visitor
cache cache firebase subscriber
mu sync.Mutex messages int64
cache cache
closeChan chan bool
mu sync.Mutex
} }
// errHTTP is a generic HTTP error for any non-200 HTTP error // errHTTP is a generic HTTP error for any non-200 HTTP error
@ -198,17 +201,35 @@ func (s *Server) Run() error {
log.Printf("Listening on %s", listenStr) log.Printf("Listening on %s", listenStr)
http.HandleFunc("/", s.handle) http.HandleFunc("/", s.handle)
errChan := make(chan error) errChan := make(chan error)
s.mu.Lock()
s.closeChan = make(chan bool)
s.httpServer = &http.Server{Addr: s.config.ListenHTTP}
go func() { go func() {
errChan <- http.ListenAndServe(s.config.ListenHTTP, nil) errChan <- s.httpServer.ListenAndServe()
}() }()
if s.config.ListenHTTPS != "" { if s.config.ListenHTTPS != "" {
s.httpsServer = &http.Server{Addr: s.config.ListenHTTP}
go func() { go func() {
errChan <- http.ListenAndServeTLS(s.config.ListenHTTPS, s.config.CertFile, s.config.KeyFile, nil) errChan <- s.httpsServer.ListenAndServeTLS(s.config.CertFile, s.config.KeyFile)
}() }()
} }
s.mu.Unlock()
return <-errChan return <-errChan
} }
// Stop stops HTTP (+HTTPS) server and all managers
func (s *Server) Stop() {
s.mu.Lock()
defer s.mu.Unlock()
if s.httpServer != nil {
s.httpServer.Close()
}
if s.httpsServer != nil {
s.httpsServer.Close()
}
close(s.closeChan)
}
func (s *Server) handle(w http.ResponseWriter, r *http.Request) { func (s *Server) handle(w http.ResponseWriter, r *http.Request) {
if err := s.handleInternal(w, r); err != nil { if err := s.handleInternal(w, r); err != nil {
if e, ok := err.(*errHTTP); ok { if e, ok := err.(*errHTTP); ok {
@ -635,21 +656,25 @@ func (s *Server) updateStatsAndPrune() {
} }
func (s *Server) runManager() { func (s *Server) runManager() {
func() { for {
ticker := time.NewTicker(s.config.ManagerInterval) select {
for { case <-time.After(s.config.ManagerInterval):
<-ticker.C
s.updateStatsAndPrune() s.updateStatsAndPrune()
case <-s.closeChan:
return
} }
}() }
} }
func (s *Server) runAtSender() { func (s *Server) runAtSender() {
ticker := time.NewTicker(s.config.AtSenderInterval)
for { for {
<-ticker.C select {
if err := s.sendDelayedMessages(); err != nil { case <-time.After(s.config.AtSenderInterval):
log.Printf("error sending scheduled messages: %s", err.Error()) if err := s.sendDelayedMessages(); err != nil {
log.Printf("error sending scheduled messages: %s", err.Error())
}
case <-s.closeChan:
return
} }
} }
} }
@ -658,14 +683,18 @@ func (s *Server) runFirebaseKeepliver() {
if s.firebase == nil { if s.firebase == nil {
return return
} }
ticker := time.NewTicker(s.config.FirebaseKeepaliveInterval)
for { for {
<-ticker.C select {
if err := s.firebase(newKeepaliveMessage(firebaseControlTopic)); err != nil { case <-time.After(s.config.FirebaseKeepaliveInterval):
log.Printf("error sending Firebase keepalive message: %s", err.Error()) if err := s.firebase(newKeepaliveMessage(firebaseControlTopic)); err != nil {
log.Printf("error sending Firebase keepalive message: %s", err.Error())
}
case <-s.closeChan:
return
} }
} }
} }
func (s *Server) sendDelayedMessages() error { func (s *Server) sendDelayedMessages() error {
s.mu.Lock() s.mu.Lock()
defer s.mu.Unlock() defer s.mu.Unlock()

View File

@ -488,7 +488,7 @@ func TestServer_SubscribeWithQueryFilters(t *testing.T) {
} }
func newTestConfig(t *testing.T) *Config { func newTestConfig(t *testing.T) *Config {
conf := NewConfig(":80") conf := NewConfig()
conf.CacheFile = filepath.Join(t.TempDir(), "cache.db") conf.CacheFile = filepath.Join(t.TempDir(), "cache.db")
return conf return conf
} }