Speed up tests, hopefully fix races
parent
b77920bb4b
commit
000bf27c87
4
Makefile
4
Makefile
|
@ -232,11 +232,11 @@ test: .PHONY
|
||||||
go test -v $(shell go list ./... | grep -vE 'ntfy/(test|examples|tools)')
|
go test -v $(shell go list ./... | grep -vE 'ntfy/(test|examples|tools)')
|
||||||
|
|
||||||
race: .PHONY
|
race: .PHONY
|
||||||
go test -race $(shell go list ./... | grep -vE 'ntfy/(test|examples|tools)')
|
go test -v -race $(shell go list ./... | grep -vE 'ntfy/(test|examples|tools)')
|
||||||
|
|
||||||
coverage:
|
coverage:
|
||||||
mkdir -p build/coverage
|
mkdir -p build/coverage
|
||||||
go test -race -coverprofile=build/coverage/coverage.txt -covermode=atomic $(shell go list ./... | grep -vE 'ntfy/(test|examples|tools)')
|
go test -v -race -coverprofile=build/coverage/coverage.txt -covermode=atomic $(shell go list ./... | grep -vE 'ntfy/(test|examples|tools)')
|
||||||
go tool cover -func build/coverage/coverage.txt
|
go tool cover -func build/coverage/coverage.txt
|
||||||
|
|
||||||
coverage-html:
|
coverage-html:
|
||||||
|
|
|
@ -330,7 +330,7 @@ func createUserManager(c *cli.Context) (*user.Manager, error) {
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, errors.New("if set, auth-default-access must start set to 'read-write', 'read-only', 'write-only' or 'deny-all'")
|
return nil, errors.New("if set, auth-default-access must start set to 'read-write', 'read-only', 'write-only' or 'deny-all'")
|
||||||
}
|
}
|
||||||
return user.NewManager(authFile, authStartupQueries, authDefault)
|
return user.NewManager(authFile, authStartupQueries, authDefault, user.DefaultUserPasswordBcryptCost, user.DefaultUserStatsQueueWriterInterval)
|
||||||
}
|
}
|
||||||
|
|
||||||
func readPasswordAndConfirm(c *cli.Context) (string, error) {
|
func readPasswordAndConfirm(c *cli.Context) (string, error) {
|
||||||
|
|
|
@ -76,6 +76,7 @@ type Config struct {
|
||||||
AuthFile string
|
AuthFile string
|
||||||
AuthStartupQueries string
|
AuthStartupQueries string
|
||||||
AuthDefault user.Permission
|
AuthDefault user.Permission
|
||||||
|
AuthBcryptCost int
|
||||||
AttachmentCacheDir string
|
AttachmentCacheDir string
|
||||||
AttachmentTotalSizeLimit int64
|
AttachmentTotalSizeLimit int64
|
||||||
AttachmentFileSizeLimit int64
|
AttachmentFileSizeLimit int64
|
||||||
|
@ -143,6 +144,7 @@ func NewConfig() *Config {
|
||||||
AuthFile: "",
|
AuthFile: "",
|
||||||
AuthStartupQueries: "",
|
AuthStartupQueries: "",
|
||||||
AuthDefault: user.NewPermission(true, true),
|
AuthDefault: user.NewPermission(true, true),
|
||||||
|
AuthBcryptCost: user.DefaultUserPasswordBcryptCost,
|
||||||
AttachmentCacheDir: "",
|
AttachmentCacheDir: "",
|
||||||
AttachmentTotalSizeLimit: DefaultAttachmentTotalSizeLimit,
|
AttachmentTotalSizeLimit: DefaultAttachmentTotalSizeLimit,
|
||||||
AttachmentFileSizeLimit: DefaultAttachmentFileSizeLimit,
|
AttachmentFileSizeLimit: DefaultAttachmentFileSizeLimit,
|
||||||
|
|
|
@ -676,6 +676,10 @@ func readMessages(rows *sql.Rows) ([]*message, error) {
|
||||||
return messages, nil
|
return messages, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (c *messageCache) Close() error {
|
||||||
|
return c.db.Close()
|
||||||
|
}
|
||||||
|
|
||||||
func setupDB(db *sql.DB, startupQueries string, cacheDuration time.Duration) error {
|
func setupDB(db *sql.DB, startupQueries string, cacheDuration time.Duration) error {
|
||||||
// Run startup queries
|
// Run startup queries
|
||||||
if startupQueries != "" {
|
if startupQueries != "" {
|
||||||
|
|
|
@ -171,7 +171,7 @@ func New(conf *Config) (*Server, error) {
|
||||||
}
|
}
|
||||||
var userManager *user.Manager
|
var userManager *user.Manager
|
||||||
if conf.AuthFile != "" {
|
if conf.AuthFile != "" {
|
||||||
userManager, err = user.NewManager(conf.AuthFile, conf.AuthStartupQueries, conf.AuthDefault)
|
userManager, err = user.NewManager(conf.AuthFile, conf.AuthStartupQueries, conf.AuthDefault, conf.AuthBcryptCost, user.DefaultUserStatsQueueWriterInterval)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
@ -296,9 +296,17 @@ func (s *Server) Stop() {
|
||||||
if s.smtpServer != nil {
|
if s.smtpServer != nil {
|
||||||
s.smtpServer.Close()
|
s.smtpServer.Close()
|
||||||
}
|
}
|
||||||
|
s.closeDatabases()
|
||||||
close(s.closeChan)
|
close(s.closeChan)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (s *Server) closeDatabases() {
|
||||||
|
if s.userManager != nil {
|
||||||
|
s.userManager.Close()
|
||||||
|
}
|
||||||
|
s.messageCache.Close()
|
||||||
|
}
|
||||||
|
|
||||||
func (s *Server) handle(w http.ResponseWriter, r *http.Request) {
|
func (s *Server) handle(w http.ResponseWriter, r *http.Request) {
|
||||||
v, err := s.maybeAuthenticate(r) // Note: Always returns v, even when error is returned
|
v, err := s.maybeAuthenticate(r) // Note: Always returns v, even when error is returned
|
||||||
if err == nil {
|
if err == nil {
|
||||||
|
@ -1567,8 +1575,9 @@ func (s *Server) autorizeTopic(next handleFunc, perm user.Permission) handleFunc
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
u := v.User()
|
||||||
for _, t := range topics {
|
for _, t := range topics {
|
||||||
if err := s.userManager.Authorize(v.user, t.ID, perm); err != nil {
|
if err := s.userManager.Authorize(u, t.ID, perm); err != nil {
|
||||||
log.Info("unauthorized: %s", err.Error())
|
log.Info("unauthorized: %s", err.Error())
|
||||||
return errHTTPForbidden
|
return errHTTPForbidden
|
||||||
}
|
}
|
||||||
|
|
|
@ -16,6 +16,7 @@ func TestAccount_Signup_Success(t *testing.T) {
|
||||||
conf := newTestConfigWithAuthFile(t)
|
conf := newTestConfigWithAuthFile(t)
|
||||||
conf.EnableSignup = true
|
conf.EnableSignup = true
|
||||||
s := newTestServer(t, conf)
|
s := newTestServer(t, conf)
|
||||||
|
defer s.closeDatabases()
|
||||||
|
|
||||||
rr := request(t, s, "POST", "/v1/account", `{"username":"phil", "password":"mypass"}`, nil)
|
rr := request(t, s, "POST", "/v1/account", `{"username":"phil", "password":"mypass"}`, nil)
|
||||||
require.Equal(t, 200, rr.Code)
|
require.Equal(t, 200, rr.Code)
|
||||||
|
@ -41,6 +42,7 @@ func TestAccount_Signup_UserExists(t *testing.T) {
|
||||||
conf := newTestConfigWithAuthFile(t)
|
conf := newTestConfigWithAuthFile(t)
|
||||||
conf.EnableSignup = true
|
conf.EnableSignup = true
|
||||||
s := newTestServer(t, conf)
|
s := newTestServer(t, conf)
|
||||||
|
defer s.closeDatabases()
|
||||||
|
|
||||||
rr := request(t, s, "POST", "/v1/account", `{"username":"phil", "password":"mypass"}`, nil)
|
rr := request(t, s, "POST", "/v1/account", `{"username":"phil", "password":"mypass"}`, nil)
|
||||||
require.Equal(t, 200, rr.Code)
|
require.Equal(t, 200, rr.Code)
|
||||||
|
@ -54,6 +56,7 @@ func TestAccount_Signup_LimitReached(t *testing.T) {
|
||||||
conf := newTestConfigWithAuthFile(t)
|
conf := newTestConfigWithAuthFile(t)
|
||||||
conf.EnableSignup = true
|
conf.EnableSignup = true
|
||||||
s := newTestServer(t, conf)
|
s := newTestServer(t, conf)
|
||||||
|
defer s.closeDatabases()
|
||||||
|
|
||||||
for i := 0; i < 3; i++ {
|
for i := 0; i < 3; i++ {
|
||||||
rr := request(t, s, "POST", "/v1/account", fmt.Sprintf(`{"username":"phil%d", "password":"mypass"}`, i), nil)
|
rr := request(t, s, "POST", "/v1/account", fmt.Sprintf(`{"username":"phil%d", "password":"mypass"}`, i), nil)
|
||||||
|
@ -68,15 +71,18 @@ func TestAccount_Signup_AsUser(t *testing.T) {
|
||||||
conf := newTestConfigWithAuthFile(t)
|
conf := newTestConfigWithAuthFile(t)
|
||||||
conf.EnableSignup = true
|
conf.EnableSignup = true
|
||||||
s := newTestServer(t, conf)
|
s := newTestServer(t, conf)
|
||||||
|
defer s.closeDatabases()
|
||||||
|
|
||||||
|
log.Info("1")
|
||||||
require.Nil(t, s.userManager.AddUser("phil", "phil", user.RoleAdmin))
|
require.Nil(t, s.userManager.AddUser("phil", "phil", user.RoleAdmin))
|
||||||
|
log.Info("2")
|
||||||
require.Nil(t, s.userManager.AddUser("ben", "ben", user.RoleUser))
|
require.Nil(t, s.userManager.AddUser("ben", "ben", user.RoleUser))
|
||||||
|
log.Info("3")
|
||||||
rr := request(t, s, "POST", "/v1/account", `{"username":"emma", "password":"emma"}`, map[string]string{
|
rr := request(t, s, "POST", "/v1/account", `{"username":"emma", "password":"emma"}`, map[string]string{
|
||||||
"Authorization": util.BasicAuth("phil", "phil"),
|
"Authorization": util.BasicAuth("phil", "phil"),
|
||||||
})
|
})
|
||||||
require.Equal(t, 200, rr.Code)
|
require.Equal(t, 200, rr.Code)
|
||||||
|
log.Info("4")
|
||||||
rr = request(t, s, "POST", "/v1/account", `{"username":"marian", "password":"marian"}`, map[string]string{
|
rr = request(t, s, "POST", "/v1/account", `{"username":"marian", "password":"marian"}`, map[string]string{
|
||||||
"Authorization": util.BasicAuth("ben", "ben"),
|
"Authorization": util.BasicAuth("ben", "ben"),
|
||||||
})
|
})
|
||||||
|
@ -87,6 +93,7 @@ func TestAccount_Signup_Disabled(t *testing.T) {
|
||||||
conf := newTestConfigWithAuthFile(t)
|
conf := newTestConfigWithAuthFile(t)
|
||||||
conf.EnableSignup = false
|
conf.EnableSignup = false
|
||||||
s := newTestServer(t, conf)
|
s := newTestServer(t, conf)
|
||||||
|
defer s.closeDatabases()
|
||||||
|
|
||||||
rr := request(t, s, "POST", "/v1/account", `{"username":"phil", "password":"mypass"}`, nil)
|
rr := request(t, s, "POST", "/v1/account", `{"username":"phil", "password":"mypass"}`, nil)
|
||||||
require.Equal(t, 400, rr.Code)
|
require.Equal(t, 400, rr.Code)
|
||||||
|
@ -115,6 +122,7 @@ func TestAccount_Get_Anonymous(t *testing.T) {
|
||||||
conf.AttachmentFileSizeLimit = 512
|
conf.AttachmentFileSizeLimit = 512
|
||||||
s := newTestServer(t, conf)
|
s := newTestServer(t, conf)
|
||||||
s.smtpSender = &testMailer{}
|
s.smtpSender = &testMailer{}
|
||||||
|
defer s.closeDatabases()
|
||||||
|
|
||||||
rr := request(t, s, "GET", "/v1/account", "", nil)
|
rr := request(t, s, "GET", "/v1/account", "", nil)
|
||||||
require.Equal(t, 200, rr.Code)
|
require.Equal(t, 200, rr.Code)
|
||||||
|
@ -149,6 +157,8 @@ func TestAccount_Get_Anonymous(t *testing.T) {
|
||||||
|
|
||||||
func TestAccount_ChangeSettings(t *testing.T) {
|
func TestAccount_ChangeSettings(t *testing.T) {
|
||||||
s := newTestServer(t, newTestConfigWithAuthFile(t))
|
s := newTestServer(t, newTestConfigWithAuthFile(t))
|
||||||
|
defer s.closeDatabases()
|
||||||
|
|
||||||
require.Nil(t, s.userManager.AddUser("phil", "phil", user.RoleUser))
|
require.Nil(t, s.userManager.AddUser("phil", "phil", user.RoleUser))
|
||||||
u, _ := s.userManager.User("phil")
|
u, _ := s.userManager.User("phil")
|
||||||
token, _ := s.userManager.CreateToken(u.ID, "", time.Unix(0, 0))
|
token, _ := s.userManager.CreateToken(u.ID, "", time.Unix(0, 0))
|
||||||
|
@ -176,6 +186,8 @@ func TestAccount_ChangeSettings(t *testing.T) {
|
||||||
|
|
||||||
func TestAccount_Subscription_AddUpdateDelete(t *testing.T) {
|
func TestAccount_Subscription_AddUpdateDelete(t *testing.T) {
|
||||||
s := newTestServer(t, newTestConfigWithAuthFile(t))
|
s := newTestServer(t, newTestConfigWithAuthFile(t))
|
||||||
|
defer s.closeDatabases()
|
||||||
|
|
||||||
require.Nil(t, s.userManager.AddUser("phil", "phil", user.RoleUser))
|
require.Nil(t, s.userManager.AddUser("phil", "phil", user.RoleUser))
|
||||||
|
|
||||||
rr := request(t, s, "POST", "/v1/account/subscription", `{"base_url": "http://abc.com", "topic": "def"}`, map[string]string{
|
rr := request(t, s, "POST", "/v1/account/subscription", `{"base_url": "http://abc.com", "topic": "def"}`, map[string]string{
|
||||||
|
@ -226,6 +238,8 @@ func TestAccount_Subscription_AddUpdateDelete(t *testing.T) {
|
||||||
|
|
||||||
func TestAccount_ChangePassword(t *testing.T) {
|
func TestAccount_ChangePassword(t *testing.T) {
|
||||||
s := newTestServer(t, newTestConfigWithAuthFile(t))
|
s := newTestServer(t, newTestConfigWithAuthFile(t))
|
||||||
|
defer s.closeDatabases()
|
||||||
|
|
||||||
require.Nil(t, s.userManager.AddUser("phil", "phil", user.RoleUser))
|
require.Nil(t, s.userManager.AddUser("phil", "phil", user.RoleUser))
|
||||||
|
|
||||||
rr := request(t, s, "POST", "/v1/account/password", `{"password": "phil", "new_password": "new password"}`, map[string]string{
|
rr := request(t, s, "POST", "/v1/account/password", `{"password": "phil", "new_password": "new password"}`, map[string]string{
|
||||||
|
@ -246,6 +260,7 @@ func TestAccount_ChangePassword(t *testing.T) {
|
||||||
|
|
||||||
func TestAccount_ChangePassword_NoAccount(t *testing.T) {
|
func TestAccount_ChangePassword_NoAccount(t *testing.T) {
|
||||||
s := newTestServer(t, newTestConfigWithAuthFile(t))
|
s := newTestServer(t, newTestConfigWithAuthFile(t))
|
||||||
|
defer s.closeDatabases()
|
||||||
|
|
||||||
rr := request(t, s, "POST", "/v1/account/password", `{"password": "new password"}`, nil)
|
rr := request(t, s, "POST", "/v1/account/password", `{"password": "new password"}`, nil)
|
||||||
require.Equal(t, 401, rr.Code)
|
require.Equal(t, 401, rr.Code)
|
||||||
|
@ -253,6 +268,8 @@ func TestAccount_ChangePassword_NoAccount(t *testing.T) {
|
||||||
|
|
||||||
func TestAccount_ExtendToken(t *testing.T) {
|
func TestAccount_ExtendToken(t *testing.T) {
|
||||||
s := newTestServer(t, newTestConfigWithAuthFile(t))
|
s := newTestServer(t, newTestConfigWithAuthFile(t))
|
||||||
|
defer s.closeDatabases()
|
||||||
|
|
||||||
require.Nil(t, s.userManager.AddUser("phil", "phil", user.RoleUser))
|
require.Nil(t, s.userManager.AddUser("phil", "phil", user.RoleUser))
|
||||||
|
|
||||||
rr := request(t, s, "POST", "/v1/account/token", "", map[string]string{
|
rr := request(t, s, "POST", "/v1/account/token", "", map[string]string{
|
||||||
|
@ -276,6 +293,8 @@ func TestAccount_ExtendToken(t *testing.T) {
|
||||||
|
|
||||||
func TestAccount_ExtendToken_NoTokenProvided(t *testing.T) {
|
func TestAccount_ExtendToken_NoTokenProvided(t *testing.T) {
|
||||||
s := newTestServer(t, newTestConfigWithAuthFile(t))
|
s := newTestServer(t, newTestConfigWithAuthFile(t))
|
||||||
|
defer s.closeDatabases()
|
||||||
|
|
||||||
require.Nil(t, s.userManager.AddUser("phil", "phil", user.RoleUser))
|
require.Nil(t, s.userManager.AddUser("phil", "phil", user.RoleUser))
|
||||||
|
|
||||||
rr := request(t, s, "PATCH", "/v1/account/token", "", map[string]string{
|
rr := request(t, s, "PATCH", "/v1/account/token", "", map[string]string{
|
||||||
|
@ -287,6 +306,8 @@ func TestAccount_ExtendToken_NoTokenProvided(t *testing.T) {
|
||||||
|
|
||||||
func TestAccount_DeleteToken(t *testing.T) {
|
func TestAccount_DeleteToken(t *testing.T) {
|
||||||
s := newTestServer(t, newTestConfigWithAuthFile(t))
|
s := newTestServer(t, newTestConfigWithAuthFile(t))
|
||||||
|
defer s.closeDatabases()
|
||||||
|
|
||||||
require.Nil(t, s.userManager.AddUser("phil", "phil", user.RoleUser))
|
require.Nil(t, s.userManager.AddUser("phil", "phil", user.RoleUser))
|
||||||
|
|
||||||
rr := request(t, s, "POST", "/v1/account/token", "", map[string]string{
|
rr := request(t, s, "POST", "/v1/account/token", "", map[string]string{
|
||||||
|
@ -295,7 +316,6 @@ func TestAccount_DeleteToken(t *testing.T) {
|
||||||
require.Equal(t, 200, rr.Code)
|
require.Equal(t, 200, rr.Code)
|
||||||
token, err := util.UnmarshalJSON[apiAccountTokenResponse](io.NopCloser(rr.Body))
|
token, err := util.UnmarshalJSON[apiAccountTokenResponse](io.NopCloser(rr.Body))
|
||||||
require.Nil(t, err)
|
require.Nil(t, err)
|
||||||
log.Info("token = %#v", token)
|
|
||||||
require.True(t, token.Expires > time.Now().Add(71*time.Hour).Unix())
|
require.True(t, token.Expires > time.Now().Add(71*time.Hour).Unix())
|
||||||
|
|
||||||
// Delete token failure (using basic auth)
|
// Delete token failure (using basic auth)
|
||||||
|
@ -522,6 +542,7 @@ func TestAccount_Reservation_Add_Kills_Other_Subscribers(t *testing.T) {
|
||||||
conf.AuthDefault = user.PermissionReadWrite
|
conf.AuthDefault = user.PermissionReadWrite
|
||||||
conf.EnableSignup = true
|
conf.EnableSignup = true
|
||||||
s := newTestServer(t, conf)
|
s := newTestServer(t, conf)
|
||||||
|
defer s.closeDatabases()
|
||||||
|
|
||||||
// Create user with tier
|
// Create user with tier
|
||||||
rr := request(t, s, "POST", "/v1/account", `{"username":"phil", "password":"mypass"}`, nil)
|
rr := request(t, s, "POST", "/v1/account", `{"username":"phil", "password":"mypass"}`, nil)
|
||||||
|
@ -544,6 +565,7 @@ func TestAccount_Reservation_Add_Kills_Other_Subscribers(t *testing.T) {
|
||||||
require.Equal(t, "open", messages[0].Event)
|
require.Equal(t, "open", messages[0].Event)
|
||||||
require.Equal(t, "message before reservation", messages[1].Message)
|
require.Equal(t, "message before reservation", messages[1].Message)
|
||||||
anonCh <- true
|
anonCh <- true
|
||||||
|
log.Info("Anonymous subscription ended")
|
||||||
}()
|
}()
|
||||||
|
|
||||||
// Subscribe with user
|
// Subscribe with user
|
||||||
|
@ -558,13 +580,14 @@ func TestAccount_Reservation_Add_Kills_Other_Subscribers(t *testing.T) {
|
||||||
require.Equal(t, "message before reservation", messages[1].Message)
|
require.Equal(t, "message before reservation", messages[1].Message)
|
||||||
require.Equal(t, "message after reservation", messages[2].Message)
|
require.Equal(t, "message after reservation", messages[2].Message)
|
||||||
userCh <- true
|
userCh <- true
|
||||||
|
log.Info("User subscription ended")
|
||||||
}()
|
}()
|
||||||
|
|
||||||
// Publish message (before reservation)
|
// Publish message (before reservation)
|
||||||
time.Sleep(700 * time.Millisecond) // Wait for subscribers
|
time.Sleep(time.Second) // Wait for subscribers
|
||||||
rr = request(t, s, "POST", "/mytopic", "message before reservation", nil)
|
rr = request(t, s, "POST", "/mytopic", "message before reservation", nil)
|
||||||
require.Equal(t, 200, rr.Code)
|
require.Equal(t, 200, rr.Code)
|
||||||
time.Sleep(700 * time.Millisecond) // Wait for subscribers to receive message
|
time.Sleep(time.Second) // Wait for subscribers to receive message
|
||||||
|
|
||||||
// Reserve a topic
|
// Reserve a topic
|
||||||
rr = request(t, s, "POST", "/v1/account/reservation", `{"topic": "mytopic", "everyone":"deny-all"}`, map[string]string{
|
rr = request(t, s, "POST", "/v1/account/reservation", `{"topic": "mytopic", "everyone":"deny-all"}`, map[string]string{
|
||||||
|
|
|
@ -6,6 +6,7 @@ import (
|
||||||
"encoding/base64"
|
"encoding/base64"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"golang.org/x/crypto/bcrypt"
|
||||||
"heckel.io/ntfy/user"
|
"heckel.io/ntfy/user"
|
||||||
"io"
|
"io"
|
||||||
"math/rand"
|
"math/rand"
|
||||||
|
@ -893,7 +894,7 @@ func TestServer_DailyMessageQuotaFromDatabase(t *testing.T) {
|
||||||
c := newTestConfigWithAuthFile(t)
|
c := newTestConfigWithAuthFile(t)
|
||||||
s := newTestServer(t, c)
|
s := newTestServer(t, c)
|
||||||
var err error
|
var err error
|
||||||
s.userManager, err = user.NewManagerWithStatsInterval(c.AuthFile, c.AuthStartupQueries, c.AuthDefault, 100*time.Millisecond)
|
s.userManager, err = user.NewManager(c.AuthFile, c.AuthStartupQueries, c.AuthDefault, c.AuthBcryptCost, 100*time.Millisecond)
|
||||||
require.Nil(t, err)
|
require.Nil(t, err)
|
||||||
|
|
||||||
// Create user, and update it with some message and email stats
|
// Create user, and update it with some message and email stats
|
||||||
|
@ -1794,6 +1795,7 @@ func newTestConfig(t *testing.T) *Config {
|
||||||
conf := NewConfig()
|
conf := NewConfig()
|
||||||
conf.BaseURL = "http://127.0.0.1:12345"
|
conf.BaseURL = "http://127.0.0.1:12345"
|
||||||
conf.CacheFile = filepath.Join(t.TempDir(), "cache.db")
|
conf.CacheFile = filepath.Join(t.TempDir(), "cache.db")
|
||||||
|
conf.CacheStartupQueries = "pragma journal_mode = WAL; pragma synchronous = normal; pragma temp_store = memory;"
|
||||||
conf.AttachmentCacheDir = t.TempDir()
|
conf.AttachmentCacheDir = t.TempDir()
|
||||||
return conf
|
return conf
|
||||||
}
|
}
|
||||||
|
@ -1801,6 +1803,8 @@ func newTestConfig(t *testing.T) *Config {
|
||||||
func newTestConfigWithAuthFile(t *testing.T) *Config {
|
func newTestConfigWithAuthFile(t *testing.T) *Config {
|
||||||
conf := newTestConfig(t)
|
conf := newTestConfig(t)
|
||||||
conf.AuthFile = filepath.Join(t.TempDir(), "user.db")
|
conf.AuthFile = filepath.Join(t.TempDir(), "user.db")
|
||||||
|
conf.AuthStartupQueries = "pragma journal_mode = WAL; pragma synchronous = normal; pragma temp_store = memory;"
|
||||||
|
conf.AuthBcryptCost = bcrypt.MinCost // This speeds up tests a lot
|
||||||
return conf
|
return conf
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -22,15 +22,18 @@ const (
|
||||||
syncTopicLength = 16
|
syncTopicLength = 16
|
||||||
userIDPrefix = "u_"
|
userIDPrefix = "u_"
|
||||||
userIDLength = 12
|
userIDLength = 12
|
||||||
userPasswordBcryptCost = 10
|
userAuthIntentionalSlowDownHash = "$2a$10$YFCQvqQDwIIwnJM1xkAYOeih0dg17UVGanaTStnrSzC8NCWxcLDwy" // Cost should match DefaultUserPasswordBcryptCost
|
||||||
userAuthIntentionalSlowDownHash = "$2a$10$YFCQvqQDwIIwnJM1xkAYOeih0dg17UVGanaTStnrSzC8NCWxcLDwy" // Cost should match userPasswordBcryptCost
|
|
||||||
userStatsQueueWriterInterval = 33 * time.Second
|
|
||||||
userHardDeleteAfterDuration = 7 * 24 * time.Hour
|
userHardDeleteAfterDuration = 7 * 24 * time.Hour
|
||||||
tokenPrefix = "tk_"
|
tokenPrefix = "tk_"
|
||||||
tokenLength = 32
|
tokenLength = 32
|
||||||
tokenMaxCount = 10 // Only keep this many tokens in the table per user
|
tokenMaxCount = 10 // Only keep this many tokens in the table per user
|
||||||
)
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
DefaultUserStatsQueueWriterInterval = 33 * time.Second
|
||||||
|
DefaultUserPasswordBcryptCost = 10
|
||||||
|
)
|
||||||
|
|
||||||
var (
|
var (
|
||||||
errNoTokenProvided = errors.New("no token provided")
|
errNoTokenProvided = errors.New("no token provided")
|
||||||
errTopicOwnedByOthers = errors.New("topic owned by others")
|
errTopicOwnedByOthers = errors.New("topic owned by others")
|
||||||
|
@ -296,18 +299,14 @@ type Manager struct {
|
||||||
db *sql.DB
|
db *sql.DB
|
||||||
defaultAccess Permission // Default permission if no ACL matches
|
defaultAccess Permission // Default permission if no ACL matches
|
||||||
statsQueue map[string]*Stats // "Queue" to asynchronously write user stats to the database (UserID -> Stats)
|
statsQueue map[string]*Stats // "Queue" to asynchronously write user stats to the database (UserID -> Stats)
|
||||||
|
bcryptCost int // Makes testing easier
|
||||||
mu sync.Mutex
|
mu sync.Mutex
|
||||||
}
|
}
|
||||||
|
|
||||||
var _ Auther = (*Manager)(nil)
|
var _ Auther = (*Manager)(nil)
|
||||||
|
|
||||||
// NewManager creates a new Manager instance
|
// NewManager creates a new Manager instance
|
||||||
func NewManager(filename, startupQueries string, defaultAccess Permission) (*Manager, error) {
|
func NewManager(filename, startupQueries string, defaultAccess Permission, bcryptCost int, statsWriterInterval time.Duration) (*Manager, error) {
|
||||||
return NewManagerWithStatsInterval(filename, startupQueries, defaultAccess, userStatsQueueWriterInterval)
|
|
||||||
}
|
|
||||||
|
|
||||||
// NewManagerWithStatsInterval creates a new Manager instance
|
|
||||||
func NewManagerWithStatsInterval(filename, startupQueries string, defaultAccess Permission, statsWriterInterval time.Duration) (*Manager, error) {
|
|
||||||
db, err := sql.Open("sqlite3", filename)
|
db, err := sql.Open("sqlite3", filename)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
|
@ -322,6 +321,7 @@ func NewManagerWithStatsInterval(filename, startupQueries string, defaultAccess
|
||||||
db: db,
|
db: db,
|
||||||
defaultAccess: defaultAccess,
|
defaultAccess: defaultAccess,
|
||||||
statsQueue: make(map[string]*Stats),
|
statsQueue: make(map[string]*Stats),
|
||||||
|
bcryptCost: bcryptCost,
|
||||||
}
|
}
|
||||||
go manager.userStatsQueueWriter(statsWriterInterval)
|
go manager.userStatsQueueWriter(statsWriterInterval)
|
||||||
return manager, nil
|
return manager, nil
|
||||||
|
@ -615,7 +615,7 @@ func (a *Manager) AddUser(username, password string, role Role) error {
|
||||||
if !AllowedUsername(username) || !AllowedRole(role) {
|
if !AllowedUsername(username) || !AllowedRole(role) {
|
||||||
return ErrInvalidArgument
|
return ErrInvalidArgument
|
||||||
}
|
}
|
||||||
hash, err := bcrypt.GenerateFromPassword([]byte(password), userPasswordBcryptCost)
|
hash, err := bcrypt.GenerateFromPassword([]byte(password), a.bcryptCost)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
@ -871,7 +871,7 @@ func (a *Manager) ReservationsCount(username string) (int64, error) {
|
||||||
|
|
||||||
// ChangePassword changes a user's password
|
// ChangePassword changes a user's password
|
||||||
func (a *Manager) ChangePassword(username, password string) error {
|
func (a *Manager) ChangePassword(username, password string) error {
|
||||||
hash, err := bcrypt.GenerateFromPassword([]byte(password), userPasswordBcryptCost)
|
hash, err := bcrypt.GenerateFromPassword([]byte(password), a.bcryptCost)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
@ -1144,6 +1144,10 @@ func (a *Manager) readTier(rows *sql.Rows) (*Tier, error) {
|
||||||
}, nil
|
}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (a *Manager) Close() error {
|
||||||
|
return a.db.Close()
|
||||||
|
}
|
||||||
|
|
||||||
func toSQLWildcard(s string) string {
|
func toSQLWildcard(s string) string {
|
||||||
return strings.ReplaceAll(s, "*", "%")
|
return strings.ReplaceAll(s, "*", "%")
|
||||||
}
|
}
|
||||||
|
|
|
@ -3,6 +3,7 @@ package user
|
||||||
import (
|
import (
|
||||||
"database/sql"
|
"database/sql"
|
||||||
"github.com/stretchr/testify/require"
|
"github.com/stretchr/testify/require"
|
||||||
|
"golang.org/x/crypto/bcrypt"
|
||||||
"heckel.io/ntfy/util"
|
"heckel.io/ntfy/util"
|
||||||
"path/filepath"
|
"path/filepath"
|
||||||
"strings"
|
"strings"
|
||||||
|
@ -13,7 +14,7 @@ import (
|
||||||
const minBcryptTimingMillis = int64(50) // Ideally should be >100ms, but this should also run on a Raspberry Pi without massive resources
|
const minBcryptTimingMillis = int64(50) // Ideally should be >100ms, but this should also run on a Raspberry Pi without massive resources
|
||||||
|
|
||||||
func TestManager_FullScenario_Default_DenyAll(t *testing.T) {
|
func TestManager_FullScenario_Default_DenyAll(t *testing.T) {
|
||||||
a := newTestManager(t, PermissionDenyAll)
|
a := newTestManagerFromFile(t, filepath.Join(t.TempDir(), "user.db"), "", PermissionDenyAll, DefaultUserPasswordBcryptCost, DefaultUserStatsQueueWriterInterval)
|
||||||
require.Nil(t, a.AddUser("phil", "phil", RoleAdmin))
|
require.Nil(t, a.AddUser("phil", "phil", RoleAdmin))
|
||||||
require.Nil(t, a.AddUser("ben", "ben", RoleUser))
|
require.Nil(t, a.AddUser("ben", "ben", RoleUser))
|
||||||
require.Nil(t, a.AllowAccess("ben", "mytopic", PermissionReadWrite))
|
require.Nil(t, a.AllowAccess("ben", "mytopic", PermissionReadWrite))
|
||||||
|
@ -98,14 +99,14 @@ func TestManager_AddUser_Invalid(t *testing.T) {
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestManager_AddUser_Timing(t *testing.T) {
|
func TestManager_AddUser_Timing(t *testing.T) {
|
||||||
a := newTestManager(t, PermissionDenyAll)
|
a := newTestManagerFromFile(t, filepath.Join(t.TempDir(), "user.db"), "", PermissionDenyAll, DefaultUserPasswordBcryptCost, DefaultUserStatsQueueWriterInterval)
|
||||||
start := time.Now().UnixMilli()
|
start := time.Now().UnixMilli()
|
||||||
require.Nil(t, a.AddUser("user", "pass", RoleAdmin))
|
require.Nil(t, a.AddUser("user", "pass", RoleAdmin))
|
||||||
require.GreaterOrEqual(t, time.Now().UnixMilli()-start, minBcryptTimingMillis)
|
require.GreaterOrEqual(t, time.Now().UnixMilli()-start, minBcryptTimingMillis)
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestManager_Authenticate_Timing(t *testing.T) {
|
func TestManager_Authenticate_Timing(t *testing.T) {
|
||||||
a := newTestManager(t, PermissionDenyAll)
|
a := newTestManagerFromFile(t, filepath.Join(t.TempDir(), "user.db"), "", PermissionDenyAll, DefaultUserPasswordBcryptCost, DefaultUserStatsQueueWriterInterval)
|
||||||
require.Nil(t, a.AddUser("user", "pass", RoleAdmin))
|
require.Nil(t, a.AddUser("user", "pass", RoleAdmin))
|
||||||
|
|
||||||
// Timing a correct attempt
|
// Timing a correct attempt
|
||||||
|
@ -192,7 +193,7 @@ func TestManager_UserManagement(t *testing.T) {
|
||||||
phil, err := a.User("phil")
|
phil, err := a.User("phil")
|
||||||
require.Nil(t, err)
|
require.Nil(t, err)
|
||||||
require.Equal(t, "phil", phil.Name)
|
require.Equal(t, "phil", phil.Name)
|
||||||
require.True(t, strings.HasPrefix(phil.Hash, "$2a$10$"))
|
require.True(t, strings.HasPrefix(phil.Hash, "$2a$04$")) // Min cost for testing
|
||||||
require.Equal(t, RoleAdmin, phil.Role)
|
require.Equal(t, RoleAdmin, phil.Role)
|
||||||
|
|
||||||
philGrants, err := a.Grants("phil")
|
philGrants, err := a.Grants("phil")
|
||||||
|
@ -202,7 +203,7 @@ func TestManager_UserManagement(t *testing.T) {
|
||||||
ben, err := a.User("ben")
|
ben, err := a.User("ben")
|
||||||
require.Nil(t, err)
|
require.Nil(t, err)
|
||||||
require.Equal(t, "ben", ben.Name)
|
require.Equal(t, "ben", ben.Name)
|
||||||
require.True(t, strings.HasPrefix(ben.Hash, "$2a$10$"))
|
require.True(t, strings.HasPrefix(ben.Hash, "$2a$04$")) // Min cost for testing
|
||||||
require.Equal(t, RoleUser, ben.Role)
|
require.Equal(t, RoleUser, ben.Role)
|
||||||
|
|
||||||
benGrants, err := a.Grants("ben")
|
benGrants, err := a.Grants("ben")
|
||||||
|
@ -551,7 +552,7 @@ func TestManager_Token_MaxCount_AutoDelete(t *testing.T) {
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestManager_EnqueueStats(t *testing.T) {
|
func TestManager_EnqueueStats(t *testing.T) {
|
||||||
a, err := NewManagerWithStatsInterval(filepath.Join(t.TempDir(), "db"), "", PermissionReadWrite, 1500*time.Millisecond)
|
a, err := NewManager(filepath.Join(t.TempDir(), "db"), "", PermissionReadWrite, bcrypt.MinCost, 1500*time.Millisecond)
|
||||||
require.Nil(t, err)
|
require.Nil(t, err)
|
||||||
require.Nil(t, a.AddUser("ben", "ben", RoleUser))
|
require.Nil(t, a.AddUser("ben", "ben", RoleUser))
|
||||||
|
|
||||||
|
@ -581,7 +582,7 @@ func TestManager_EnqueueStats(t *testing.T) {
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestManager_ChangeSettings(t *testing.T) {
|
func TestManager_ChangeSettings(t *testing.T) {
|
||||||
a, err := NewManagerWithStatsInterval(filepath.Join(t.TempDir(), "db"), "", PermissionReadWrite, 1500*time.Millisecond)
|
a, err := NewManager(filepath.Join(t.TempDir(), "db"), "", PermissionReadWrite, bcrypt.MinCost, 1500*time.Millisecond)
|
||||||
require.Nil(t, err)
|
require.Nil(t, err)
|
||||||
require.Nil(t, a.AddUser("ben", "ben", RoleUser))
|
require.Nil(t, a.AddUser("ben", "ben", RoleUser))
|
||||||
|
|
||||||
|
@ -665,7 +666,7 @@ func TestSqliteCache_Migration_From1(t *testing.T) {
|
||||||
require.Nil(t, err)
|
require.Nil(t, err)
|
||||||
|
|
||||||
// Create manager to trigger migration
|
// Create manager to trigger migration
|
||||||
a := newTestManagerFromFile(t, filename, "", PermissionDenyAll, userStatsQueueWriterInterval)
|
a := newTestManagerFromFile(t, filename, "", PermissionDenyAll, bcrypt.MinCost, DefaultUserStatsQueueWriterInterval)
|
||||||
checkSchemaVersion(t, a.db)
|
checkSchemaVersion(t, a.db)
|
||||||
|
|
||||||
users, err := a.Users()
|
users, err := a.Users()
|
||||||
|
@ -720,11 +721,11 @@ func checkSchemaVersion(t *testing.T, db *sql.DB) {
|
||||||
}
|
}
|
||||||
|
|
||||||
func newTestManager(t *testing.T, defaultAccess Permission) *Manager {
|
func newTestManager(t *testing.T, defaultAccess Permission) *Manager {
|
||||||
return newTestManagerFromFile(t, filepath.Join(t.TempDir(), "user.db"), "", defaultAccess, userStatsQueueWriterInterval)
|
return newTestManagerFromFile(t, filepath.Join(t.TempDir(), "user.db"), "", defaultAccess, bcrypt.MinCost, DefaultUserStatsQueueWriterInterval)
|
||||||
}
|
}
|
||||||
|
|
||||||
func newTestManagerFromFile(t *testing.T, filename, startupQueries string, defaultAccess Permission, statsWriterInterval time.Duration) *Manager {
|
func newTestManagerFromFile(t *testing.T, filename, startupQueries string, defaultAccess Permission, bcryptCost int, statsWriterInterval time.Duration) *Manager {
|
||||||
a, err := NewManagerWithStatsInterval(filename, startupQueries, defaultAccess, statsWriterInterval)
|
a, err := NewManager(filename, startupQueries, defaultAccess, bcryptCost, statsWriterInterval)
|
||||||
require.Nil(t, err)
|
require.Nil(t, err)
|
||||||
return a
|
return a
|
||||||
}
|
}
|
||||||
|
|
Loading…
Reference in New Issue