Rename auth package to user; add extendToken feature

pull/584/head
binwiederhier 2022-12-25 11:41:38 -05:00
parent 3aac1b2715
commit d4c7ad4beb
14 changed files with 368 additions and 276 deletions

View File

@ -6,7 +6,7 @@ import (
"errors" "errors"
"fmt" "fmt"
"github.com/urfave/cli/v2" "github.com/urfave/cli/v2"
"heckel.io/ntfy/auth" "heckel.io/ntfy/user"
"heckel.io/ntfy/util" "heckel.io/ntfy/util"
) )
@ -77,7 +77,7 @@ func execUserAccess(c *cli.Context) error {
} }
username := c.Args().Get(0) username := c.Args().Get(0)
if username == userEveryone { if username == userEveryone {
username = auth.Everyone username = user.Everyone
} }
topic := c.Args().Get(1) topic := c.Args().Get(1)
perms := c.Args().Get(2) perms := c.Args().Get(2)
@ -96,16 +96,16 @@ func execUserAccess(c *cli.Context) error {
return changeAccess(c, manager, username, topic, perms) return changeAccess(c, manager, username, topic, perms)
} }
func changeAccess(c *cli.Context, manager auth.Manager, username string, topic string, perms string) error { func changeAccess(c *cli.Context, manager user.Manager, username string, topic string, perms string) error {
if !util.Contains([]string{"", "read-write", "rw", "read-only", "read", "ro", "write-only", "write", "wo", "none", "deny"}, perms) { if !util.Contains([]string{"", "read-write", "rw", "read-only", "read", "ro", "write-only", "write", "wo", "none", "deny"}, perms) {
return errors.New("permission must be one of: read-write, read-only, write-only, or deny (or the aliases: read, ro, write, wo, none)") return errors.New("permission must be one of: read-write, read-only, write-only, or deny (or the aliases: read, ro, write, wo, none)")
} }
read := util.Contains([]string{"read-write", "rw", "read-only", "read", "ro"}, perms) read := util.Contains([]string{"read-write", "rw", "read-only", "read", "ro"}, perms)
write := util.Contains([]string{"read-write", "rw", "write-only", "write", "wo"}, perms) write := util.Contains([]string{"read-write", "rw", "write-only", "write", "wo"}, perms)
user, err := manager.User(username) u, err := manager.User(username)
if err == auth.ErrNotFound { if err == user.ErrNotFound {
return fmt.Errorf("user %s does not exist", username) return fmt.Errorf("user %s does not exist", username)
} else if user.Role == auth.RoleAdmin { } else if u.Role == user.RoleAdmin {
return fmt.Errorf("user %s is an admin user, access control entries have no effect", username) return fmt.Errorf("user %s is an admin user, access control entries have no effect", username)
} }
if err := manager.AllowAccess(username, topic, read, write); err != nil { if err := manager.AllowAccess(username, topic, read, write); err != nil {
@ -123,7 +123,7 @@ func changeAccess(c *cli.Context, manager auth.Manager, username string, topic s
return showUserAccess(c, manager, username) return showUserAccess(c, manager, username)
} }
func resetAccess(c *cli.Context, manager auth.Manager, username, topic string) error { func resetAccess(c *cli.Context, manager user.Manager, username, topic string) error {
if username == "" { if username == "" {
return resetAllAccess(c, manager) return resetAllAccess(c, manager)
} else if topic == "" { } else if topic == "" {
@ -132,7 +132,7 @@ func resetAccess(c *cli.Context, manager auth.Manager, username, topic string) e
return resetUserTopicAccess(c, manager, username, topic) return resetUserTopicAccess(c, manager, username, topic)
} }
func resetAllAccess(c *cli.Context, manager auth.Manager) error { func resetAllAccess(c *cli.Context, manager user.Manager) error {
if err := manager.ResetAccess("", ""); err != nil { if err := manager.ResetAccess("", ""); err != nil {
return err return err
} }
@ -140,7 +140,7 @@ func resetAllAccess(c *cli.Context, manager auth.Manager) error {
return nil return nil
} }
func resetUserAccess(c *cli.Context, manager auth.Manager, username string) error { func resetUserAccess(c *cli.Context, manager user.Manager, username string) error {
if err := manager.ResetAccess(username, ""); err != nil { if err := manager.ResetAccess(username, ""); err != nil {
return err return err
} }
@ -148,7 +148,7 @@ func resetUserAccess(c *cli.Context, manager auth.Manager, username string) erro
return showUserAccess(c, manager, username) return showUserAccess(c, manager, username)
} }
func resetUserTopicAccess(c *cli.Context, manager auth.Manager, username string, topic string) error { func resetUserTopicAccess(c *cli.Context, manager user.Manager, username string, topic string) error {
if err := manager.ResetAccess(username, topic); err != nil { if err := manager.ResetAccess(username, topic); err != nil {
return err return err
} }
@ -156,14 +156,14 @@ func resetUserTopicAccess(c *cli.Context, manager auth.Manager, username string,
return showUserAccess(c, manager, username) return showUserAccess(c, manager, username)
} }
func showAccess(c *cli.Context, manager auth.Manager, username string) error { func showAccess(c *cli.Context, manager user.Manager, username string) error {
if username == "" { if username == "" {
return showAllAccess(c, manager) return showAllAccess(c, manager)
} }
return showUserAccess(c, manager, username) return showUserAccess(c, manager, username)
} }
func showAllAccess(c *cli.Context, manager auth.Manager) error { func showAllAccess(c *cli.Context, manager user.Manager) error {
users, err := manager.Users() users, err := manager.Users()
if err != nil { if err != nil {
return err return err
@ -171,23 +171,23 @@ func showAllAccess(c *cli.Context, manager auth.Manager) error {
return showUsers(c, manager, users) return showUsers(c, manager, users)
} }
func showUserAccess(c *cli.Context, manager auth.Manager, username string) error { func showUserAccess(c *cli.Context, manager user.Manager, username string) error {
users, err := manager.User(username) users, err := manager.User(username)
if err == auth.ErrNotFound { if err == user.ErrNotFound {
return fmt.Errorf("user %s does not exist", username) return fmt.Errorf("user %s does not exist", username)
} else if err != nil { } else if err != nil {
return err return err
} }
return showUsers(c, manager, []*auth.User{users}) return showUsers(c, manager, []*user.User{users})
} }
func showUsers(c *cli.Context, manager auth.Manager, users []*auth.User) error { func showUsers(c *cli.Context, manager user.Manager, users []*user.User) error {
for _, user := range users { for _, u := range users {
fmt.Fprintf(c.App.ErrWriter, "user %s (%s)\n", user.Name, user.Role) fmt.Fprintf(c.App.ErrWriter, "user %s (%s)\n", u.Name, u.Role)
if user.Role == auth.RoleAdmin { if u.Role == user.RoleAdmin {
fmt.Fprintf(c.App.ErrWriter, "- read-write access to all topics (admin role)\n") fmt.Fprintf(c.App.ErrWriter, "- read-write access to all topics (admin role)\n")
} else if len(user.Grants) > 0 { } else if len(u.Grants) > 0 {
for _, grant := range user.Grants { for _, grant := range u.Grants {
if grant.AllowRead && grant.AllowWrite { if grant.AllowRead && grant.AllowWrite {
fmt.Fprintf(c.App.ErrWriter, "- read-write access to topic %s\n", grant.TopicPattern) fmt.Fprintf(c.App.ErrWriter, "- read-write access to topic %s\n", grant.TopicPattern)
} else if grant.AllowRead { } else if grant.AllowRead {
@ -201,7 +201,7 @@ func showUsers(c *cli.Context, manager auth.Manager, users []*auth.User) error {
} else { } else {
fmt.Fprintf(c.App.ErrWriter, "- no topic-specific permissions\n") fmt.Fprintf(c.App.ErrWriter, "- no topic-specific permissions\n")
} }
if user.Name == auth.Everyone { if u.Name == user.Everyone {
defaultRead, defaultWrite := manager.DefaultAccess() defaultRead, defaultWrite := manager.DefaultAccess()
if defaultRead && defaultWrite { if defaultRead && defaultWrite {
fmt.Fprintln(c.App.ErrWriter, "- read-write access to all (other) topics (server config)") fmt.Fprintln(c.App.ErrWriter, "- read-write access to all (other) topics (server config)")

View File

@ -6,12 +6,12 @@ import (
"crypto/subtle" "crypto/subtle"
"errors" "errors"
"fmt" "fmt"
"heckel.io/ntfy/user"
"os" "os"
"strings" "strings"
"github.com/urfave/cli/v2" "github.com/urfave/cli/v2"
"github.com/urfave/cli/v2/altsrc" "github.com/urfave/cli/v2/altsrc"
"heckel.io/ntfy/auth"
"heckel.io/ntfy/util" "heckel.io/ntfy/util"
) )
@ -41,7 +41,7 @@ var cmdUser = &cli.Command{
UsageText: "ntfy user add [--role=admin|user] USERNAME\nNTFY_PASSWORD=... ntfy user add [--role=admin|user] USERNAME", UsageText: "ntfy user add [--role=admin|user] USERNAME\nNTFY_PASSWORD=... ntfy user add [--role=admin|user] USERNAME",
Action: execUserAdd, Action: execUserAdd,
Flags: []cli.Flag{ Flags: []cli.Flag{
&cli.StringFlag{Name: "role", Aliases: []string{"r"}, Value: string(auth.RoleUser), Usage: "user role"}, &cli.StringFlag{Name: "role", Aliases: []string{"r"}, Value: string(user.RoleUser), Usage: "user role"},
}, },
Description: `Add a new user to the ntfy user database. Description: `Add a new user to the ntfy user database.
@ -152,13 +152,13 @@ variable to pass the new password. This is useful if you are creating/updating u
func execUserAdd(c *cli.Context) error { func execUserAdd(c *cli.Context) error {
username := c.Args().Get(0) username := c.Args().Get(0)
role := auth.Role(c.String("role")) role := user.Role(c.String("role"))
password := os.Getenv("NTFY_PASSWORD") password := os.Getenv("NTFY_PASSWORD")
if username == "" { if username == "" {
return errors.New("username expected, type 'ntfy user add --help' for help") return errors.New("username expected, type 'ntfy user add --help' for help")
} else if username == userEveryone { } else if username == userEveryone {
return errors.New("username not allowed") return errors.New("username not allowed")
} else if !auth.AllowedRole(role) { } else if !user.AllowedRole(role) {
return errors.New("role must be either 'user' or 'admin'") return errors.New("role must be either 'user' or 'admin'")
} }
manager, err := createAuthManager(c) manager, err := createAuthManager(c)
@ -194,7 +194,7 @@ func execUserDel(c *cli.Context) error {
if err != nil { if err != nil {
return err return err
} }
if _, err := manager.User(username); err == auth.ErrNotFound { if _, err := manager.User(username); err == user.ErrNotFound {
return fmt.Errorf("user %s does not exist", username) return fmt.Errorf("user %s does not exist", username)
} }
if err := manager.RemoveUser(username); err != nil { if err := manager.RemoveUser(username); err != nil {
@ -216,7 +216,7 @@ func execUserChangePass(c *cli.Context) error {
if err != nil { if err != nil {
return err return err
} }
if _, err := manager.User(username); err == auth.ErrNotFound { if _, err := manager.User(username); err == user.ErrNotFound {
return fmt.Errorf("user %s does not exist", username) return fmt.Errorf("user %s does not exist", username)
} }
if password == "" { if password == "" {
@ -234,8 +234,8 @@ func execUserChangePass(c *cli.Context) error {
func execUserChangeRole(c *cli.Context) error { func execUserChangeRole(c *cli.Context) error {
username := c.Args().Get(0) username := c.Args().Get(0)
role := auth.Role(c.Args().Get(1)) role := user.Role(c.Args().Get(1))
if username == "" || !auth.AllowedRole(role) { if username == "" || !user.AllowedRole(role) {
return errors.New("username and new role expected, type 'ntfy user change-role --help' for help") return errors.New("username and new role expected, type 'ntfy user change-role --help' for help")
} else if username == userEveryone { } else if username == userEveryone {
return errors.New("username not allowed") return errors.New("username not allowed")
@ -244,7 +244,7 @@ func execUserChangeRole(c *cli.Context) error {
if err != nil { if err != nil {
return err return err
} }
if _, err := manager.User(username); err == auth.ErrNotFound { if _, err := manager.User(username); err == user.ErrNotFound {
return fmt.Errorf("user %s does not exist", username) return fmt.Errorf("user %s does not exist", username)
} }
if err := manager.ChangeRole(username, role); err != nil { if err := manager.ChangeRole(username, role); err != nil {
@ -266,7 +266,7 @@ func execUserList(c *cli.Context) error {
return showUsers(c, manager, users) return showUsers(c, manager, users)
} }
func createAuthManager(c *cli.Context) (auth.Manager, error) { func createAuthManager(c *cli.Context) (user.Manager, error) {
authFile := c.String("auth-file") authFile := c.String("auth-file")
authDefaultAccess := c.String("auth-default-access") authDefaultAccess := c.String("auth-default-access")
if authFile == "" { if authFile == "" {
@ -278,7 +278,7 @@ func createAuthManager(c *cli.Context) (auth.Manager, error) {
} }
authDefaultRead := authDefaultAccess == "read-write" || authDefaultAccess == "read-only" authDefaultRead := authDefaultAccess == "read-write" || authDefaultAccess == "read-only"
authDefaultWrite := authDefaultAccess == "read-write" || authDefaultAccess == "write-only" authDefaultWrite := authDefaultAccess == "read-write" || authDefaultAccess == "write-only"
return auth.NewSQLiteAuthManager(authFile, authDefaultRead, authDefaultWrite) return user.NewSQLiteAuthManager(authFile, authDefaultRead, authDefaultWrite)
} }
func readPasswordAndConfirm(c *cli.Context) (string, error) { func readPasswordAndConfirm(c *cli.Context) (string, error) {

View File

@ -54,6 +54,7 @@ var (
errHTTPBadRequestMatrixPushkeyBaseURLMismatch = &errHTTP{40020, http.StatusBadRequest, "invalid request: push key must be prefixed with base URL", "https://ntfy.sh/docs/publish/#matrix-gateway"} errHTTPBadRequestMatrixPushkeyBaseURLMismatch = &errHTTP{40020, http.StatusBadRequest, "invalid request: push key must be prefixed with base URL", "https://ntfy.sh/docs/publish/#matrix-gateway"}
errHTTPBadRequestIconURLInvalid = &errHTTP{40021, http.StatusBadRequest, "invalid request: icon URL is invalid", "https://ntfy.sh/docs/publish/#icons"} errHTTPBadRequestIconURLInvalid = &errHTTP{40021, http.StatusBadRequest, "invalid request: icon URL is invalid", "https://ntfy.sh/docs/publish/#icons"}
errHTTPBadRequestSignupNotEnabled = &errHTTP{40022, http.StatusBadRequest, "invalid request: signup not enabled", "https://ntfy.sh/docs/config"} errHTTPBadRequestSignupNotEnabled = &errHTTP{40022, http.StatusBadRequest, "invalid request: signup not enabled", "https://ntfy.sh/docs/config"}
errHTTPBadRequestNoTokenProvided = &errHTTP{40023, http.StatusBadRequest, "invalid request: no token provided", ""}
errHTTPNotFound = &errHTTP{40401, http.StatusNotFound, "page not found", ""} errHTTPNotFound = &errHTTP{40401, http.StatusNotFound, "page not found", ""}
errHTTPUnauthorized = &errHTTP{40101, http.StatusUnauthorized, "unauthorized", "https://ntfy.sh/docs/publish/#authentication"} errHTTPUnauthorized = &errHTTP{40101, http.StatusUnauthorized, "unauthorized", "https://ntfy.sh/docs/publish/#authentication"}
errHTTPForbidden = &errHTTP{40301, http.StatusForbidden, "forbidden", "https://ntfy.sh/docs/publish/#authentication"} errHTTPForbidden = &errHTTP{40301, http.StatusForbidden, "forbidden", "https://ntfy.sh/docs/publish/#authentication"}

View File

@ -9,6 +9,7 @@ import (
"encoding/json" "encoding/json"
"errors" "errors"
"fmt" "fmt"
"heckel.io/ntfy/user"
"io" "io"
"net" "net"
"net/http" "net/http"
@ -30,17 +31,17 @@ import (
"github.com/emersion/go-smtp" "github.com/emersion/go-smtp"
"github.com/gorilla/websocket" "github.com/gorilla/websocket"
"golang.org/x/sync/errgroup" "golang.org/x/sync/errgroup"
"heckel.io/ntfy/auth"
"heckel.io/ntfy/util" "heckel.io/ntfy/util"
) )
/* /*
TODO TODO
expire tokens
auto-extend tokens from UI
use token auth in "SubscribeDialog" use token auth in "SubscribeDialog"
upload files based on user limit upload files based on user limit
database migration
publishXHR + poll should pick current user, not from userManager publishXHR + poll should pick current user, not from userManager
expire tokens
auto-refresh tokens from UI
reserve topics reserve topics
purge accounts that were not logged into in X purge accounts that were not logged into in X
sync subscription display name sync subscription display name
@ -55,7 +56,11 @@ import (
Polishing: Polishing:
aria-label for everything aria-label for everything
Tests:
- APIs
- CRUD tokens
- Expire tokens
-
*/ */
// Server is the main server, providing the UI and API for ntfy // Server is the main server, providing the UI and API for ntfy
@ -71,7 +76,7 @@ type Server struct {
visitors map[string]*visitor // ip:<ip> or user:<user> visitors map[string]*visitor // ip:<ip> or user:<user>
firebaseClient *firebaseClient firebaseClient *firebaseClient
messages int64 messages int64
auth auth.Manager userManager user.Manager
messageCache *messageCache messageCache *messageCache
fileCache *fileCache fileCache *fileCache
closeChan chan bool closeChan chan bool
@ -159,9 +164,9 @@ func New(conf *Config) (*Server, error) {
return nil, err return nil, err
} }
} }
var auther auth.Manager var auther user.Manager
if conf.AuthFile != "" { if conf.AuthFile != "" {
auther, err = auth.NewSQLiteAuthManager(conf.AuthFile, conf.AuthDefaultRead, conf.AuthDefaultWrite) auther, err = user.NewSQLiteAuthManager(conf.AuthFile, conf.AuthDefaultRead, conf.AuthDefaultWrite)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -181,7 +186,7 @@ func New(conf *Config) (*Server, error) {
firebaseClient: firebaseClient, firebaseClient: firebaseClient,
smtpSender: mailer, smtpSender: mailer,
topics: topics, topics: topics,
auth: auther, userManager: auther,
visitors: make(map[string]*visitor), visitors: make(map[string]*visitor),
}, nil }, nil
} }
@ -342,11 +347,13 @@ func (s *Server) handleInternal(w http.ResponseWriter, r *http.Request, v *visit
return s.handleAccountDelete(w, r, v) return s.handleAccountDelete(w, r, v)
} else if r.Method == http.MethodPost && r.URL.Path == accountPasswordPath { } else if r.Method == http.MethodPost && r.URL.Path == accountPasswordPath {
return s.handleAccountPasswordChange(w, r, v) return s.handleAccountPasswordChange(w, r, v)
} else if r.Method == http.MethodGet && r.URL.Path == accountTokenPath { } else if r.Method == http.MethodPost && r.URL.Path == accountTokenPath {
return s.handleAccountTokenGet(w, r, v) return s.handleAccountTokenIssue(w, r, v)
} else if r.Method == http.MethodPatch && r.URL.Path == accountTokenPath {
return s.handleAccountTokenExtend(w, r, v)
} else if r.Method == http.MethodDelete && r.URL.Path == accountTokenPath { } else if r.Method == http.MethodDelete && r.URL.Path == accountTokenPath {
return s.handleAccountTokenDelete(w, r, v) return s.handleAccountTokenDelete(w, r, v)
} else if r.Method == http.MethodPost && r.URL.Path == accountSettingsPath { } else if r.Method == http.MethodPatch && r.URL.Path == accountSettingsPath {
return s.handleAccountSettingsChange(w, r, v) return s.handleAccountSettingsChange(w, r, v)
} else if r.Method == http.MethodPost && r.URL.Path == accountSubscriptionPath { } else if r.Method == http.MethodPost && r.URL.Path == accountSubscriptionPath {
return s.handleAccountSubscriptionAdd(w, r, v) return s.handleAccountSubscriptionAdd(w, r, v)
@ -557,7 +564,7 @@ func (s *Server) handlePublishWithoutResponse(r *http.Request, v *visitor) (*mes
} }
v.IncrMessages() v.IncrMessages()
if v.user != nil { if v.user != nil {
s.auth.EnqueueUpdateStats(v.user) s.userManager.EnqueueStats(v.user)
} }
s.mu.Lock() s.mu.Lock()
s.messages++ s.messages++
@ -1122,7 +1129,7 @@ func parseSince(r *http.Request, poll bool) (sinceMarker, error) {
} }
func (s *Server) handleOptions(w http.ResponseWriter, _ *http.Request, _ *visitor) error { func (s *Server) handleOptions(w http.ResponseWriter, _ *http.Request, _ *visitor) error {
w.Header().Set("Access-Control-Allow-Methods", "GET, PUT, POST, DELETE") w.Header().Set("Access-Control-Allow-Methods", "GET, PUT, POST, PATCH, DELETE")
w.Header().Set("Access-Control-Allow-Origin", "*") // CORS, allow cross-origin requests w.Header().Set("Access-Control-Allow-Origin", "*") // CORS, allow cross-origin requests
w.Header().Set("Access-Control-Allow-Headers", "*") // CORS, allow auth via JS // FIXME is this terrible? w.Header().Set("Access-Control-Allow-Headers", "*") // CORS, allow auth via JS // FIXME is this terrible?
return nil return nil
@ -1192,6 +1199,11 @@ func (s *Server) updateStatsAndPrune() {
s.mu.Unlock() s.mu.Unlock()
log.Debug("Manager: Deleted %d stale visitor(s)", staleVisitors) log.Debug("Manager: Deleted %d stale visitor(s)", staleVisitors)
// Delete expired user tokens
if err := s.userManager.RemoveExpiredTokens(); err != nil {
log.Warn("Error expiring user tokens: %s", err.Error())
}
// Delete expired attachments // Delete expired attachments
if s.fileCache != nil && s.config.AttachmentExpiryDuration > 0 { if s.fileCache != nil && s.config.AttachmentExpiryDuration > 0 {
olderThan := time.Now().Add(-1 * s.config.AttachmentExpiryDuration) olderThan := time.Now().Add(-1 * s.config.AttachmentExpiryDuration)
@ -1323,7 +1335,7 @@ func (s *Server) sendDelayedMessages() error {
for _, m := range messages { for _, m := range messages {
var v *visitor var v *visitor
if m.User != "" { if m.User != "" {
user, err := s.auth.User(m.User) user, err := s.userManager.User(m.User)
if err != nil { if err != nil {
log.Warn("%s Error sending delayed message: %s", logMessagePrefix(v, m), err.Error()) log.Warn("%s Error sending delayed message: %s", logMessagePrefix(v, m), err.Error())
continue continue
@ -1457,16 +1469,16 @@ func (s *Server) transformMatrixJSON(next handleFunc) handleFunc {
} }
func (s *Server) authorizeTopicWrite(next handleFunc) handleFunc { func (s *Server) authorizeTopicWrite(next handleFunc) handleFunc {
return s.autorizeTopic(next, auth.PermissionWrite) return s.autorizeTopic(next, user.PermissionWrite)
} }
func (s *Server) authorizeTopicRead(next handleFunc) handleFunc { func (s *Server) authorizeTopicRead(next handleFunc) handleFunc {
return s.autorizeTopic(next, auth.PermissionRead) return s.autorizeTopic(next, user.PermissionRead)
} }
func (s *Server) autorizeTopic(next handleFunc, perm auth.Permission) handleFunc { func (s *Server) autorizeTopic(next handleFunc, perm user.Permission) handleFunc {
return func(w http.ResponseWriter, r *http.Request, v *visitor) error { return func(w http.ResponseWriter, r *http.Request, v *visitor) error {
if s.auth == nil { if s.userManager == nil {
return next(w, r, v) return next(w, r, v)
} }
topics, _, err := s.topicsFromPath(r.URL.Path) topics, _, err := s.topicsFromPath(r.URL.Path)
@ -1474,7 +1486,7 @@ func (s *Server) autorizeTopic(next handleFunc, perm auth.Permission) handleFunc
return err return err
} }
for _, t := range topics { for _, t := range topics {
if err := s.auth.Authorize(v.user, t.ID, perm); err != nil { if err := s.userManager.Authorize(v.user, t.ID, perm); err != nil {
log.Info("unauthorized: %s", err.Error()) log.Info("unauthorized: %s", err.Error())
return errHTTPForbidden return errHTTPForbidden
} }
@ -1487,7 +1499,7 @@ func (s *Server) autorizeTopic(next handleFunc, perm auth.Permission) handleFunc
// Note that this function will always return a visitor, even if an error occurs. // Note that this function will always return a visitor, even if an error occurs.
func (s *Server) visitor(r *http.Request) (v *visitor, err error) { func (s *Server) visitor(r *http.Request) (v *visitor, err error) {
ip := extractIPAddress(r, s.config.BehindProxy) ip := extractIPAddress(r, s.config.BehindProxy)
var user *auth.User // may stay nil if no auth header! var user *user.User // may stay nil if no auth header!
if user, err = s.authenticate(r); err != nil { if user, err = s.authenticate(r); err != nil {
log.Debug("authentication failed: %s", err.Error()) log.Debug("authentication failed: %s", err.Error())
err = errHTTPUnauthorized // Always return visitor, even when error occurs! err = errHTTPUnauthorized // Always return visitor, even when error occurs!
@ -1505,7 +1517,7 @@ func (s *Server) visitor(r *http.Request) (v *visitor, err error) {
// The Authorization header can be passed as a header or the ?auth=... query param. The latter is required only to // The Authorization header can be passed as a header or the ?auth=... query param. The latter is required only to
// support the WebSocket JavaScript class, which does not support passing headers during the initial request. The auth // support the WebSocket JavaScript class, which does not support passing headers during the initial request. The auth
// query param is effectively double base64 encoded. Its format is base64(Basic base64(user:pass)). // query param is effectively double base64 encoded. Its format is base64(Basic base64(user:pass)).
func (s *Server) authenticate(r *http.Request) (user *auth.User, err error) { func (s *Server) authenticate(r *http.Request) (user *user.User, err error) {
value := r.Header.Get("Authorization") value := r.Header.Get("Authorization")
queryParam := readQueryParam(r, "authorization", "auth") queryParam := readQueryParam(r, "authorization", "auth")
if queryParam != "" { if queryParam != "" {
@ -1524,21 +1536,21 @@ func (s *Server) authenticate(r *http.Request) (user *auth.User, err error) {
return s.authenticateBasicAuth(r, value) return s.authenticateBasicAuth(r, value)
} }
func (s *Server) authenticateBasicAuth(r *http.Request, value string) (user *auth.User, err error) { func (s *Server) authenticateBasicAuth(r *http.Request, value string) (user *user.User, err error) {
r.Header.Set("Authorization", value) r.Header.Set("Authorization", value)
username, password, ok := r.BasicAuth() username, password, ok := r.BasicAuth()
if !ok { if !ok {
return nil, errors.New("invalid basic auth") return nil, errors.New("invalid basic auth")
} }
return s.auth.Authenticate(username, password) return s.userManager.Authenticate(username, password)
} }
func (s *Server) authenticateBearerAuth(value string) (user *auth.User, err error) { func (s *Server) authenticateBearerAuth(value string) (user *user.User, err error) {
token := strings.TrimSpace(strings.TrimPrefix(value, "Bearer")) token := strings.TrimSpace(strings.TrimPrefix(value, "Bearer"))
return s.auth.AuthenticateToken(token) return s.userManager.AuthenticateToken(token)
} }
func (s *Server) visitorFromID(visitorID string, ip netip.Addr, user *auth.User) *visitor { func (s *Server) visitorFromID(visitorID string, ip netip.Addr, user *user.User) *visitor {
s.mu.Lock() s.mu.Lock()
defer s.mu.Unlock() defer s.mu.Unlock()
v, exists := s.visitors[visitorID] v, exists := s.visitors[visitorID]
@ -1554,6 +1566,6 @@ func (s *Server) visitorFromIP(ip netip.Addr) *visitor {
return s.visitorFromID(fmt.Sprintf("ip:%s", ip.String()), ip, nil) return s.visitorFromID(fmt.Sprintf("ip:%s", ip.String()), ip, nil)
} }
func (s *Server) visitorFromUser(user *auth.User, ip netip.Addr) *visitor { func (s *Server) visitorFromUser(user *user.User, ip netip.Addr) *visitor {
return s.visitorFromID(fmt.Sprintf("user:%s", user.Name), ip, user) return s.visitorFromID(fmt.Sprintf("user:%s", user.Name), ip, user)
} }

View File

@ -3,13 +3,13 @@ package server
import ( import (
"encoding/json" "encoding/json"
"errors" "errors"
"heckel.io/ntfy/auth" "heckel.io/ntfy/user"
"heckel.io/ntfy/util" "heckel.io/ntfy/util"
"net/http" "net/http"
) )
func (s *Server) handleAccountCreate(w http.ResponseWriter, r *http.Request, v *visitor) error { func (s *Server) handleAccountCreate(w http.ResponseWriter, r *http.Request, v *visitor) error {
admin := v.user != nil && v.user.Role == auth.RoleAdmin admin := v.user != nil && v.user.Role == user.RoleAdmin
if !admin { if !admin {
if !s.config.EnableSignup { if !s.config.EnableSignup {
return errHTTPBadRequestSignupNotEnabled return errHTTPBadRequestSignupNotEnabled
@ -26,13 +26,13 @@ func (s *Server) handleAccountCreate(w http.ResponseWriter, r *http.Request, v *
if err := json.NewDecoder(body).Decode(&newAccount); err != nil { if err := json.NewDecoder(body).Decode(&newAccount); err != nil {
return err return err
} }
if existingUser, _ := s.auth.User(newAccount.Username); existingUser != nil { if existingUser, _ := s.userManager.User(newAccount.Username); existingUser != nil {
return errHTTPConflictUserExists return errHTTPConflictUserExists
} }
if v.accountLimiter != nil && !v.accountLimiter.Allow() { if v.accountLimiter != nil && !v.accountLimiter.Allow() {
return errHTTPTooManyRequestsAccountCreateLimit return errHTTPTooManyRequestsAccountCreateLimit
} }
if err := s.auth.AddUser(newAccount.Username, newAccount.Password, auth.RoleUser); err != nil { // TODO this should return a User if err := s.userManager.AddUser(newAccount.Username, newAccount.Password, user.RoleUser); err != nil { // TODO this should return a User
return err return err
} }
w.Header().Set("Content-Type", "application/json") w.Header().Set("Content-Type", "application/json")
@ -84,23 +84,23 @@ func (s *Server) handleAccountGet(w http.ResponseWriter, r *http.Request, v *vis
Code: v.user.Plan.Code, Code: v.user.Plan.Code,
Upgradable: v.user.Plan.Upgradable, Upgradable: v.user.Plan.Upgradable,
} }
} else if v.user.Role == auth.RoleAdmin { } else if v.user.Role == user.RoleAdmin {
response.Plan = &apiAccountPlan{ response.Plan = &apiAccountPlan{
Code: string(auth.PlanUnlimited), Code: string(user.PlanUnlimited),
Upgradable: false, Upgradable: false,
} }
} else { } else {
response.Plan = &apiAccountPlan{ response.Plan = &apiAccountPlan{
Code: string(auth.PlanDefault), Code: string(user.PlanDefault),
Upgradable: true, Upgradable: true,
} }
} }
} else { } else {
response.Username = auth.Everyone response.Username = user.Everyone
response.Role = string(auth.RoleAnonymous) response.Role = string(user.RoleAnonymous)
response.Plan = &apiAccountPlan{ response.Plan = &apiAccountPlan{
Code: string(auth.PlanNone), Code: string(user.PlanNone),
Upgradable: true, Upgradable: true,
} }
} }
@ -114,7 +114,7 @@ func (s *Server) handleAccountDelete(w http.ResponseWriter, r *http.Request, v *
if v.user == nil { if v.user == nil {
return errHTTPUnauthorized return errHTTPUnauthorized
} }
if err := s.auth.RemoveUser(v.user.Name); err != nil { if err := s.userManager.RemoveUser(v.user.Name); err != nil {
return err return err
} }
w.Header().Set("Content-Type", "application/json") w.Header().Set("Content-Type", "application/json")
@ -136,7 +136,7 @@ func (s *Server) handleAccountPasswordChange(w http.ResponseWriter, r *http.Requ
if err := json.NewDecoder(body).Decode(&newPassword); err != nil { if err := json.NewDecoder(body).Decode(&newPassword); err != nil {
return err return err
} }
if err := s.auth.ChangePassword(v.user.Name, newPassword.Password); err != nil { if err := s.userManager.ChangePassword(v.user.Name, newPassword.Password); err != nil {
return err return err
} }
w.Header().Set("Content-Type", "application/json") w.Header().Set("Content-Type", "application/json")
@ -145,19 +145,43 @@ func (s *Server) handleAccountPasswordChange(w http.ResponseWriter, r *http.Requ
return nil return nil
} }
func (s *Server) handleAccountTokenGet(w http.ResponseWriter, r *http.Request, v *visitor) error { func (s *Server) handleAccountTokenIssue(w http.ResponseWriter, r *http.Request, v *visitor) error {
// TODO rate limit // TODO rate limit
if v.user == nil { if v.user == nil {
return errHTTPUnauthorized return errHTTPUnauthorized
} }
token, err := s.auth.CreateToken(v.user) token, err := s.userManager.CreateToken(v.user)
if err != nil { if err != nil {
return err return err
} }
w.Header().Set("Content-Type", "application/json") w.Header().Set("Content-Type", "application/json")
w.Header().Set("Access-Control-Allow-Origin", "*") // FIXME remove this w.Header().Set("Access-Control-Allow-Origin", "*") // FIXME remove this
response := &apiAccountTokenResponse{ response := &apiAccountTokenResponse{
Token: token, Token: token.Value,
Expires: token.Expires,
}
if err := json.NewEncoder(w).Encode(response); err != nil {
return err
}
return nil
}
func (s *Server) handleAccountTokenExtend(w http.ResponseWriter, r *http.Request, v *visitor) error {
// TODO rate limit
if v.user == nil {
return errHTTPUnauthorized
} else if v.user.Token == "" {
return errHTTPBadRequestNoTokenProvided
}
token, err := s.userManager.ExtendToken(v.user)
if err != nil {
return err
}
w.Header().Set("Content-Type", "application/json")
w.Header().Set("Access-Control-Allow-Origin", "*") // FIXME remove this
response := &apiAccountTokenResponse{
Token: token.Value,
Expires: token.Expires,
} }
if err := json.NewEncoder(w).Encode(response); err != nil { if err := json.NewEncoder(w).Encode(response); err != nil {
return err return err
@ -170,7 +194,7 @@ func (s *Server) handleAccountTokenDelete(w http.ResponseWriter, r *http.Request
if v.user == nil || v.user.Token == "" { if v.user == nil || v.user.Token == "" {
return errHTTPUnauthorized return errHTTPUnauthorized
} }
if err := s.auth.RemoveToken(v.user); err != nil { if err := s.userManager.RemoveToken(v.user); err != nil {
return err return err
} }
w.Header().Set("Access-Control-Allow-Origin", "*") // FIXME remove this w.Header().Set("Access-Control-Allow-Origin", "*") // FIXME remove this
@ -188,12 +212,12 @@ func (s *Server) handleAccountSettingsChange(w http.ResponseWriter, r *http.Requ
return err return err
} }
defer r.Body.Close() defer r.Body.Close()
var newPrefs auth.UserPrefs var newPrefs user.Prefs
if err := json.NewDecoder(body).Decode(&newPrefs); err != nil { if err := json.NewDecoder(body).Decode(&newPrefs); err != nil {
return err return err
} }
if v.user.Prefs == nil { if v.user.Prefs == nil {
v.user.Prefs = &auth.UserPrefs{} v.user.Prefs = &user.Prefs{}
} }
prefs := v.user.Prefs prefs := v.user.Prefs
if newPrefs.Language != "" { if newPrefs.Language != "" {
@ -201,7 +225,7 @@ func (s *Server) handleAccountSettingsChange(w http.ResponseWriter, r *http.Requ
} }
if newPrefs.Notification != nil { if newPrefs.Notification != nil {
if prefs.Notification == nil { if prefs.Notification == nil {
prefs.Notification = &auth.UserNotificationPrefs{} prefs.Notification = &user.NotificationPrefs{}
} }
if newPrefs.Notification.DeleteAfter > 0 { if newPrefs.Notification.DeleteAfter > 0 {
prefs.Notification.DeleteAfter = newPrefs.Notification.DeleteAfter prefs.Notification.DeleteAfter = newPrefs.Notification.DeleteAfter
@ -213,7 +237,7 @@ func (s *Server) handleAccountSettingsChange(w http.ResponseWriter, r *http.Requ
prefs.Notification.MinPriority = newPrefs.Notification.MinPriority prefs.Notification.MinPriority = newPrefs.Notification.MinPriority
} }
} }
return s.auth.ChangeSettings(v.user) return s.userManager.ChangeSettings(v.user)
} }
func (s *Server) handleAccountSubscriptionAdd(w http.ResponseWriter, r *http.Request, v *visitor) error { func (s *Server) handleAccountSubscriptionAdd(w http.ResponseWriter, r *http.Request, v *visitor) error {
@ -227,12 +251,12 @@ func (s *Server) handleAccountSubscriptionAdd(w http.ResponseWriter, r *http.Req
return err return err
} }
defer r.Body.Close() defer r.Body.Close()
var newSubscription auth.UserSubscription var newSubscription user.Subscription
if err := json.NewDecoder(body).Decode(&newSubscription); err != nil { if err := json.NewDecoder(body).Decode(&newSubscription); err != nil {
return err return err
} }
if v.user.Prefs == nil { if v.user.Prefs == nil {
v.user.Prefs = &auth.UserPrefs{} v.user.Prefs = &user.Prefs{}
} }
newSubscription.ID = "" // Client cannot set ID newSubscription.ID = "" // Client cannot set ID
for _, subscription := range v.user.Prefs.Subscriptions { for _, subscription := range v.user.Prefs.Subscriptions {
@ -244,7 +268,7 @@ func (s *Server) handleAccountSubscriptionAdd(w http.ResponseWriter, r *http.Req
if newSubscription.ID == "" { if newSubscription.ID == "" {
newSubscription.ID = util.RandomString(16) newSubscription.ID = util.RandomString(16)
v.user.Prefs.Subscriptions = append(v.user.Prefs.Subscriptions, &newSubscription) v.user.Prefs.Subscriptions = append(v.user.Prefs.Subscriptions, &newSubscription)
if err := s.auth.ChangeSettings(v.user); err != nil { if err := s.userManager.ChangeSettings(v.user); err != nil {
return err return err
} }
} }
@ -268,7 +292,7 @@ func (s *Server) handleAccountSubscriptionDelete(w http.ResponseWriter, r *http.
if v.user.Prefs == nil || v.user.Prefs.Subscriptions == nil { if v.user.Prefs == nil || v.user.Prefs.Subscriptions == nil {
return nil return nil
} }
newSubscriptions := make([]*auth.UserSubscription, 0) newSubscriptions := make([]*user.Subscription, 0)
for _, subscription := range v.user.Prefs.Subscriptions { for _, subscription := range v.user.Prefs.Subscriptions {
if subscription.ID != subscriptionID { if subscription.ID != subscriptionID {
newSubscriptions = append(newSubscriptions, subscription) newSubscriptions = append(newSubscriptions, subscription)
@ -276,7 +300,7 @@ func (s *Server) handleAccountSubscriptionDelete(w http.ResponseWriter, r *http.
} }
if len(newSubscriptions) < len(v.user.Prefs.Subscriptions) { if len(newSubscriptions) < len(v.user.Prefs.Subscriptions) {
v.user.Prefs.Subscriptions = newSubscriptions v.user.Prefs.Subscriptions = newSubscriptions
if err := s.auth.ChangeSettings(v.user); err != nil { if err := s.userManager.ChangeSettings(v.user); err != nil {
return err return err
} }
} }

View File

@ -8,8 +8,8 @@ import (
"firebase.google.com/go/v4/messaging" "firebase.google.com/go/v4/messaging"
"fmt" "fmt"
"google.golang.org/api/option" "google.golang.org/api/option"
"heckel.io/ntfy/auth"
"heckel.io/ntfy/log" "heckel.io/ntfy/log"
"heckel.io/ntfy/user"
"heckel.io/ntfy/util" "heckel.io/ntfy/util"
"strings" "strings"
) )
@ -28,10 +28,10 @@ var (
// The actual Firebase implementation is implemented in firebaseSenderImpl, to make it testable. // The actual Firebase implementation is implemented in firebaseSenderImpl, to make it testable.
type firebaseClient struct { type firebaseClient struct {
sender firebaseSender sender firebaseSender
auther auth.Manager auther user.Manager
} }
func newFirebaseClient(sender firebaseSender, auther auth.Manager) *firebaseClient { func newFirebaseClient(sender firebaseSender, auther user.Manager) *firebaseClient {
return &firebaseClient{ return &firebaseClient{
sender: sender, sender: sender,
auther: auther, auther: auther,
@ -112,7 +112,7 @@ func (c *firebaseSenderImpl) Send(m *messaging.Message) error {
// On Android, this will trigger the app to poll the topic and thereby displaying new messages. // On Android, this will trigger the app to poll the topic and thereby displaying new messages.
// - If UpstreamBaseURL is set, messages are forwarded as poll requests to an upstream server and then forwarded // - If UpstreamBaseURL is set, messages are forwarded as poll requests to an upstream server and then forwarded
// to Firebase here. This is mainly for iOS to support self-hosted servers. // to Firebase here. This is mainly for iOS to support self-hosted servers.
func toFirebaseMessage(m *message, auther auth.Manager) (*messaging.Message, error) { func toFirebaseMessage(m *message, auther user.Manager) (*messaging.Message, error) {
var data map[string]string // Mostly matches https://ntfy.sh/docs/subscribe/api/#json-message-format var data map[string]string // Mostly matches https://ntfy.sh/docs/subscribe/api/#json-message-format
var apnsConfig *messaging.APNSConfig var apnsConfig *messaging.APNSConfig
switch m.Event { switch m.Event {
@ -137,7 +137,7 @@ func toFirebaseMessage(m *message, auther auth.Manager) (*messaging.Message, err
case messageEvent: case messageEvent:
allowForward := true allowForward := true
if auther != nil { if auther != nil {
allowForward = auther.Authorize(nil, m.Topic, auth.PermissionRead) == nil allowForward = auther.Authorize(nil, m.Topic, user.PermissionRead) == nil
} }
if allowForward { if allowForward {
data = map[string]string{ data = map[string]string{

View File

@ -11,18 +11,17 @@ import (
"firebase.google.com/go/v4/messaging" "firebase.google.com/go/v4/messaging"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
"heckel.io/ntfy/auth"
) )
type testAuther struct { type testAuther struct {
Allow bool Allow bool
} }
func (t testAuther) AuthenticateUser(_, _ string) (*auth.User, error) { func (t testAuther) AuthenticateUser(_, _ string) (*user.User, error) {
return nil, errors.New("not used") return nil, errors.New("not used")
} }
func (t testAuther) Authorize(_ *auth.User, _ string, _ auth.Permission) error { func (t testAuther) Authorize(_ *user.User, _ string, _ user.Permission) error {
if t.Allow { if t.Allow {
return nil return nil
} }

View File

@ -21,7 +21,6 @@ import (
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
"heckel.io/ntfy/auth"
"heckel.io/ntfy/util" "heckel.io/ntfy/util"
) )
@ -626,8 +625,8 @@ func TestServer_Auth_Success_Admin(t *testing.T) {
c.AuthFile = filepath.Join(t.TempDir(), "user.db") c.AuthFile = filepath.Join(t.TempDir(), "user.db")
s := newTestServer(t, c) s := newTestServer(t, c)
manager := s.auth.(auth.Manager) manager := s.userManager.(user.Manager)
require.Nil(t, manager.AddUser("phil", "phil", auth.RoleAdmin)) require.Nil(t, manager.AddUser("phil", "phil", user.RoleAdmin))
response := request(t, s, "GET", "/mytopic/auth", "", map[string]string{ response := request(t, s, "GET", "/mytopic/auth", "", map[string]string{
"Authorization": basicAuth("phil:phil"), "Authorization": basicAuth("phil:phil"),
@ -643,8 +642,8 @@ func TestServer_Auth_Success_User(t *testing.T) {
c.AuthDefaultWrite = false c.AuthDefaultWrite = false
s := newTestServer(t, c) s := newTestServer(t, c)
manager := s.auth.(auth.Manager) manager := s.userManager.(user.Manager)
require.Nil(t, manager.AddUser("ben", "ben", auth.RoleUser)) require.Nil(t, manager.AddUser("ben", "ben", user.RoleUser))
require.Nil(t, manager.AllowAccess("ben", "mytopic", true, true)) require.Nil(t, manager.AllowAccess("ben", "mytopic", true, true))
response := request(t, s, "GET", "/mytopic/auth", "", map[string]string{ response := request(t, s, "GET", "/mytopic/auth", "", map[string]string{
@ -660,8 +659,8 @@ func TestServer_Auth_Success_User_MultipleTopics(t *testing.T) {
c.AuthDefaultWrite = false c.AuthDefaultWrite = false
s := newTestServer(t, c) s := newTestServer(t, c)
manager := s.auth.(auth.Manager) manager := s.userManager.(user.Manager)
require.Nil(t, manager.AddUser("ben", "ben", auth.RoleUser)) require.Nil(t, manager.AddUser("ben", "ben", user.RoleUser))
require.Nil(t, manager.AllowAccess("ben", "mytopic", true, true)) require.Nil(t, manager.AllowAccess("ben", "mytopic", true, true))
require.Nil(t, manager.AllowAccess("ben", "anothertopic", true, true)) require.Nil(t, manager.AllowAccess("ben", "anothertopic", true, true))
@ -683,8 +682,8 @@ func TestServer_Auth_Fail_InvalidPass(t *testing.T) {
c.AuthDefaultWrite = false c.AuthDefaultWrite = false
s := newTestServer(t, c) s := newTestServer(t, c)
manager := s.auth.(auth.Manager) manager := s.userManager.(user.Manager)
require.Nil(t, manager.AddUser("phil", "phil", auth.RoleAdmin)) require.Nil(t, manager.AddUser("phil", "phil", user.RoleAdmin))
response := request(t, s, "GET", "/mytopic/auth", "", map[string]string{ response := request(t, s, "GET", "/mytopic/auth", "", map[string]string{
"Authorization": basicAuth("phil:INVALID"), "Authorization": basicAuth("phil:INVALID"),
@ -699,8 +698,8 @@ func TestServer_Auth_Fail_Unauthorized(t *testing.T) {
c.AuthDefaultWrite = false c.AuthDefaultWrite = false
s := newTestServer(t, c) s := newTestServer(t, c)
manager := s.auth.(auth.Manager) manager := s.userManager.(user.Manager)
require.Nil(t, manager.AddUser("ben", "ben", auth.RoleUser)) require.Nil(t, manager.AddUser("ben", "ben", user.RoleUser))
require.Nil(t, manager.AllowAccess("ben", "sometopic", true, true)) // Not mytopic! require.Nil(t, manager.AllowAccess("ben", "sometopic", true, true)) // Not mytopic!
response := request(t, s, "GET", "/mytopic/auth", "", map[string]string{ response := request(t, s, "GET", "/mytopic/auth", "", map[string]string{
@ -716,10 +715,10 @@ func TestServer_Auth_Fail_CannotPublish(t *testing.T) {
c.AuthDefaultWrite = true // Open by default c.AuthDefaultWrite = true // Open by default
s := newTestServer(t, c) s := newTestServer(t, c)
manager := s.auth.(auth.Manager) manager := s.userManager.(user.Manager)
require.Nil(t, manager.AddUser("phil", "phil", auth.RoleAdmin)) require.Nil(t, manager.AddUser("phil", "phil", user.RoleAdmin))
require.Nil(t, manager.AllowAccess(auth.Everyone, "private", false, false)) require.Nil(t, manager.AllowAccess(user.Everyone, "private", false, false))
require.Nil(t, manager.AllowAccess(auth.Everyone, "announcements", true, false)) require.Nil(t, manager.AllowAccess(user.Everyone, "announcements", true, false))
response := request(t, s, "PUT", "/mytopic", "test", nil) response := request(t, s, "PUT", "/mytopic", "test", nil)
require.Equal(t, 200, response.Code) require.Equal(t, 200, response.Code)
@ -749,8 +748,8 @@ func TestServer_Auth_ViaQuery(t *testing.T) {
c.AuthDefaultWrite = false c.AuthDefaultWrite = false
s := newTestServer(t, c) s := newTestServer(t, c)
manager := s.auth.(auth.Manager) manager := s.userManager.(user.Manager)
require.Nil(t, manager.AddUser("ben", "some pass", auth.RoleAdmin)) require.Nil(t, manager.AddUser("ben", "some pass", user.RoleAdmin))
u := fmt.Sprintf("/mytopic/json?poll=1&auth=%s", base64.RawURLEncoding.EncodeToString([]byte(basicAuth("ben:some pass")))) u := fmt.Sprintf("/mytopic/json?poll=1&auth=%s", base64.RawURLEncoding.EncodeToString([]byte(basicAuth("ben:some pass"))))
response := request(t, s, "GET", u, "", nil) response := request(t, s, "GET", u, "", nil)

View File

@ -1,7 +1,7 @@
package server package server
import ( import (
"heckel.io/ntfy/auth" "heckel.io/ntfy/user"
"net/http" "net/http"
"net/netip" "net/netip"
"time" "time"
@ -226,7 +226,8 @@ type apiAccountCreateRequest struct {
} }
type apiAccountTokenResponse struct { type apiAccountTokenResponse struct {
Token string `json:"token"` Token string `json:"token"`
Expires int64 `json:"expires"`
} }
type apiAccountPlan struct { type apiAccountPlan struct {
@ -252,12 +253,12 @@ type apiAccountStats struct {
} }
type apiAccountSettingsResponse struct { type apiAccountSettingsResponse struct {
Username string `json:"username"` Username string `json:"username"`
Role string `json:"role,omitempty"` Role string `json:"role,omitempty"`
Language string `json:"language,omitempty"` Language string `json:"language,omitempty"`
Notification *auth.UserNotificationPrefs `json:"notification,omitempty"` Notification *user.NotificationPrefs `json:"notification,omitempty"`
Subscriptions []*auth.UserSubscription `json:"subscriptions,omitempty"` Subscriptions []*user.Subscription `json:"subscriptions,omitempty"`
Plan *apiAccountPlan `json:"plan,omitempty"` Plan *apiAccountPlan `json:"plan,omitempty"`
Limits *apiAccountLimits `json:"limits,omitempty"` Limits *apiAccountLimits `json:"limits,omitempty"`
Stats *apiAccountStats `json:"stats,omitempty"` Stats *apiAccountStats `json:"stats,omitempty"`
} }

View File

@ -2,7 +2,7 @@ package server
import ( import (
"errors" "errors"
"heckel.io/ntfy/auth" "heckel.io/ntfy/user"
"net/netip" "net/netip"
"sync" "sync"
"time" "time"
@ -27,7 +27,7 @@ type visitor struct {
config *Config config *Config
messageCache *messageCache messageCache *messageCache
ip netip.Addr ip netip.Addr
user *auth.User user *user.User
messages int64 // Number of messages sent messages int64 // Number of messages sent
emails int64 // Number of emails sent emails int64 // Number of emails sent
requestLimiter *rate.Limiter // Rate limiter for (almost) all requests (including messages) requestLimiter *rate.Limiter // Rate limiter for (almost) all requests (including messages)
@ -54,7 +54,7 @@ type visitorStats struct {
AttachmentFileSizeLimit int64 AttachmentFileSizeLimit int64
} }
func newVisitor(conf *Config, messageCache *messageCache, ip netip.Addr, user *auth.User) *visitor { func newVisitor(conf *Config, messageCache *messageCache, ip netip.Addr, user *user.User) *visitor {
var requestLimiter, emailsLimiter, accountLimiter *rate.Limiter var requestLimiter, emailsLimiter, accountLimiter *rate.Limiter
var messages, emails int64 var messages, emails int64
if user != nil { if user != nil {
@ -171,7 +171,7 @@ func (v *visitor) Stats() (*visitorStats, error) {
emails := v.emails emails := v.emails
v.mu.Unlock() v.mu.Unlock()
stats := &visitorStats{} stats := &visitorStats{}
if v.user != nil && v.user.Role == auth.RoleAdmin { if v.user != nil && v.user.Role == user.RoleAdmin {
stats.Basis = "role" stats.Basis = "role"
stats.MessagesLimit = 0 stats.MessagesLimit = 0
stats.EmailsLimit = 0 stats.EmailsLimit = 0

View File

@ -1,5 +1,5 @@
// Package auth deals with authentication and authorization against topics // Package auth deals with authentication and authorization against topics
package auth package user
import ( import (
"errors" "errors"
@ -14,10 +14,12 @@ type Manager interface {
Authenticate(username, password string) (*User, error) Authenticate(username, password string) (*User, error)
AuthenticateToken(token string) (*User, error) AuthenticateToken(token string) (*User, error)
CreateToken(user *User) (string, error) CreateToken(user *User) (*Token, error)
ExtendToken(user *User) (*Token, error)
RemoveToken(user *User) error RemoveToken(user *User) error
RemoveExpiredTokens() error
ChangeSettings(user *User) error ChangeSettings(user *User) error
EnqueueUpdateStats(user *User) EnqueueStats(user *User)
// Authorize returns nil if the given user has access to the given topic using the desired // Authorize returns nil if the given user has access to the given topic using the desired
// permission. The user param may be nil to signal an anonymous user. // permission. The user param may be nil to signal an anonymous user.
@ -64,15 +66,20 @@ type User struct {
Token string // Only set if token was used to log in Token string // Only set if token was used to log in
Role Role Role Role
Grants []Grant Grants []Grant
Prefs *UserPrefs Prefs *Prefs
Plan *Plan Plan *Plan
Stats *Stats Stats *Stats
} }
type UserPrefs struct { type Token struct {
Language string `json:"language,omitempty"` Value string
Notification *UserNotificationPrefs `json:"notification,omitempty"` Expires int64
Subscriptions []*UserSubscription `json:"subscriptions,omitempty"` }
type Prefs struct {
Language string `json:"language,omitempty"`
Notification *NotificationPrefs `json:"notification,omitempty"`
Subscriptions []*Subscription `json:"subscriptions,omitempty"`
} }
type PlanCode string type PlanCode string
@ -92,13 +99,13 @@ type Plan struct {
AttachmentTotalSizeLimit int64 `json:"attachment_total_size_limit"` AttachmentTotalSizeLimit int64 `json:"attachment_total_size_limit"`
} }
type UserSubscription struct { type Subscription struct {
ID string `json:"id"` ID string `json:"id"`
BaseURL string `json:"base_url"` BaseURL string `json:"base_url"`
Topic string `json:"topic"` Topic string `json:"topic"`
} }
type UserNotificationPrefs struct { type NotificationPrefs struct {
Sound string `json:"sound,omitempty"` Sound string `json:"sound,omitempty"`
MinPriority int `json:"min_priority,omitempty"` MinPriority int `json:"min_priority,omitempty"`
DeleteAfter int `json:"delete_after,omitempty"` DeleteAfter int `json:"delete_after,omitempty"`

View File

@ -1,4 +1,4 @@
package auth package user
import ( import (
"database/sql" "database/sql"
@ -15,10 +15,11 @@ import (
) )
const ( const (
tokenLength = 32 tokenLength = 32
bcryptCost = 10 bcryptCost = 10
intentionalSlowDownHash = "$2a$10$YFCQvqQDwIIwnJM1xkAYOeih0dg17UVGanaTStnrSzC8NCWxcLDwy" // Cost should match bcryptCost intentionalSlowDownHash = "$2a$10$YFCQvqQDwIIwnJM1xkAYOeih0dg17UVGanaTStnrSzC8NCWxcLDwy" // Cost should match bcryptCost
statsWriterInterval = 10 * time.Second userStatsQueueWriterInterval = 33 * time.Second
userTokenExpiryDuration = 72 * time.Hour
) )
// Manager-related queries // Manager-related queries
@ -106,9 +107,11 @@ const (
deleteUserAccessQuery = `DELETE FROM user_access WHERE user_id = (SELECT id FROM user WHERE user = ?)` deleteUserAccessQuery = `DELETE FROM user_access WHERE user_id = (SELECT id FROM user WHERE user = ?)`
deleteTopicAccessQuery = `DELETE FROM user_access WHERE user_id = (SELECT id FROM user WHERE user = ?) AND topic = ?` deleteTopicAccessQuery = `DELETE FROM user_access WHERE user_id = (SELECT id FROM user WHERE user = ?) AND topic = ?`
insertTokenQuery = `INSERT INTO user_token (user_id, token, expires) VALUES ((SELECT id FROM user WHERE user = ?), ?, ?)` insertTokenQuery = `INSERT INTO user_token (user_id, token, expires) VALUES ((SELECT id FROM user WHERE user = ?), ?, ?)`
deleteTokenQuery = `DELETE FROM user_token WHERE user_id = (SELECT id FROM user WHERE user = ?) AND token = ?` updateTokenExpiryQuery = `UPDATE user_token SET expires = ? WHERE user_id = (SELECT id FROM user WHERE user = ?) AND token = ?`
deleteUserTokenQuery = `DELETE FROM user_token WHERE user_id = (SELECT id FROM user WHERE user = ?)` deleteTokenQuery = `DELETE FROM user_token WHERE user_id = (SELECT id FROM user WHERE user = ?) AND token = ?`
deleteExpiredTokensQuery = `DELETE FROM user_token WHERE expires < ?`
deleteUserTokensQuery = `DELETE FROM user_token WHERE user_id = (SELECT id FROM user WHERE user = ?)`
) )
// Schema management queries // Schema management queries
@ -118,20 +121,20 @@ const (
selectSchemaVersionQuery = `SELECT version FROM schemaVersion WHERE id = 1` selectSchemaVersionQuery = `SELECT version FROM schemaVersion WHERE id = 1`
) )
// SQLiteAuthManager is an implementation of Manager. It stores users and access control list // SQLiteManager is an implementation of Manager. It stores users and access control list
// in a SQLite database. // in a SQLite database.
type SQLiteAuthManager struct { type SQLiteManager struct {
db *sql.DB db *sql.DB
defaultRead bool defaultRead bool
defaultWrite bool defaultWrite bool
statsQueue map[string]*Stats // Username -> Stats statsQueue map[string]*User // Username -> User, for "unimportant" user updates
mu sync.Mutex mu sync.Mutex
} }
var _ Manager = (*SQLiteAuthManager)(nil) var _ Manager = (*SQLiteManager)(nil)
// NewSQLiteAuthManager creates a new SQLiteAuthManager instance // NewSQLiteAuthManager creates a new SQLiteManager instance
func NewSQLiteAuthManager(filename string, defaultRead, defaultWrite bool) (*SQLiteAuthManager, error) { func NewSQLiteAuthManager(filename string, defaultRead, defaultWrite bool) (*SQLiteManager, error) {
db, err := sql.Open("sqlite3", filename) db, err := sql.Open("sqlite3", filename)
if err != nil { if err != nil {
return nil, err return nil, err
@ -139,20 +142,20 @@ func NewSQLiteAuthManager(filename string, defaultRead, defaultWrite bool) (*SQL
if err := setupAuthDB(db); err != nil { if err := setupAuthDB(db); err != nil {
return nil, err return nil, err
} }
manager := &SQLiteAuthManager{ manager := &SQLiteManager{
db: db, db: db,
defaultRead: defaultRead, defaultRead: defaultRead,
defaultWrite: defaultWrite, defaultWrite: defaultWrite,
statsQueue: make(map[string]*Stats), statsQueue: make(map[string]*User),
} }
go manager.statsWriter() go manager.userStatsQueueWriter()
return manager, nil return manager, nil
} }
// Authenticate checks username and password and returns a user if correct. The method // Authenticate checks username and password and returns a user if correct. The method
// returns in constant-ish time, regardless of whether the user exists or the password is // returns in constant-ish time, regardless of whether the user exists or the password is
// correct or incorrect. // correct or incorrect.
func (a *SQLiteAuthManager) Authenticate(username, password string) (*User, error) { func (a *SQLiteManager) Authenticate(username, password string) (*User, error) {
if username == Everyone { if username == Everyone {
return nil, ErrUnauthenticated return nil, ErrUnauthenticated
} }
@ -168,7 +171,7 @@ func (a *SQLiteAuthManager) Authenticate(username, password string) (*User, erro
return user, nil return user, nil
} }
func (a *SQLiteAuthManager) AuthenticateToken(token string) (*User, error) { func (a *SQLiteManager) AuthenticateToken(token string) (*User, error) {
user, err := a.userByToken(token) user, err := a.userByToken(token)
if err != nil { if err != nil {
return nil, ErrUnauthenticated return nil, ErrUnauthenticated
@ -177,16 +180,30 @@ func (a *SQLiteAuthManager) AuthenticateToken(token string) (*User, error) {
return user, nil return user, nil
} }
func (a *SQLiteAuthManager) CreateToken(user *User) (string, error) { func (a *SQLiteManager) CreateToken(user *User) (*Token, error) {
token := util.RandomString(tokenLength) token := util.RandomString(tokenLength)
expires := 1 // FIXME expires := time.Now().Add(userTokenExpiryDuration)
if _, err := a.db.Exec(insertTokenQuery, user.Name, token, expires); err != nil { if _, err := a.db.Exec(insertTokenQuery, user.Name, token, expires.Unix()); err != nil {
return "", err return nil, err
} }
return token, nil return &Token{
Value: token,
Expires: expires.Unix(),
}, nil
} }
func (a *SQLiteAuthManager) RemoveToken(user *User) error { func (a *SQLiteManager) ExtendToken(user *User) (*Token, error) {
newExpires := time.Now().Add(userTokenExpiryDuration)
if _, err := a.db.Exec(updateTokenExpiryQuery, newExpires.Unix(), user.Name, user.Token); err != nil {
return nil, err
}
return &Token{
Value: user.Token,
Expires: newExpires.Unix(),
}, nil
}
func (a *SQLiteManager) RemoveToken(user *User) error {
if user.Token == "" { if user.Token == "" {
return ErrUnauthorized return ErrUnauthorized
} }
@ -196,7 +213,14 @@ func (a *SQLiteAuthManager) RemoveToken(user *User) error {
return nil return nil
} }
func (a *SQLiteAuthManager) ChangeSettings(user *User) error { func (a *SQLiteManager) RemoveExpiredTokens() error {
if _, err := a.db.Exec(deleteExpiredTokensQuery, time.Now().Unix()); err != nil {
return err
}
return nil
}
func (a *SQLiteManager) ChangeSettings(user *User) error {
settings, err := json.Marshal(user.Prefs) settings, err := json.Marshal(user.Prefs)
if err != nil { if err != nil {
return err return err
@ -207,33 +231,40 @@ func (a *SQLiteAuthManager) ChangeSettings(user *User) error {
return nil return nil
} }
func (a *SQLiteAuthManager) EnqueueUpdateStats(user *User) { func (a *SQLiteManager) EnqueueStats(user *User) {
a.mu.Lock() a.mu.Lock()
defer a.mu.Unlock() defer a.mu.Unlock()
a.statsQueue[user.Name] = user.Stats a.statsQueue[user.Name] = user
} }
func (a *SQLiteAuthManager) statsWriter() { func (a *SQLiteManager) userStatsQueueWriter() {
ticker := time.NewTicker(statsWriterInterval) ticker := time.NewTicker(userStatsQueueWriterInterval)
for range ticker.C { for range ticker.C {
if err := a.writeStats(); err != nil { if err := a.writeUserStatsQueue(); err != nil {
log.Warn("UserManager: Writing user stats failed: %s", err.Error()) log.Warn("UserManager: Writing user stats queue failed: %s", err.Error())
} }
} }
} }
func (a *SQLiteAuthManager) writeStats() error { func (a *SQLiteManager) writeUserStatsQueue() error {
a.mu.Lock()
if len(a.statsQueue) == 0 {
a.mu.Unlock()
log.Trace("UserManager: No user stats updates to commit")
return nil
}
statsQueue := a.statsQueue
a.statsQueue = make(map[string]*User)
a.mu.Unlock()
tx, err := a.db.Begin() tx, err := a.db.Begin()
if err != nil { if err != nil {
return err return err
} }
defer tx.Rollback() defer tx.Rollback()
a.mu.Lock() log.Debug("UserManager: Writing user stats queue for %d user(s)", len(statsQueue))
statsQueue := a.statsQueue for username, u := range statsQueue {
a.statsQueue = make(map[string]*Stats) log.Trace("UserManager: Updating stats for user %s: messages=%d, emails=%d", username, u.Stats.Messages, u.Stats.Emails)
a.mu.Unlock() if _, err := tx.Exec(updateUserStatsQuery, u.Stats.Messages, u.Stats.Emails, username); err != nil {
for username, stats := range statsQueue {
if _, err := tx.Exec(updateUserStatsQuery, stats.Messages, stats.Emails, username); err != nil {
return err return err
} }
} }
@ -242,7 +273,7 @@ func (a *SQLiteAuthManager) writeStats() error {
// Authorize returns nil if the given user has access to the given topic using the desired // Authorize returns nil if the given user has access to the given topic using the desired
// permission. The user param may be nil to signal an anonymous user. // permission. The user param may be nil to signal an anonymous user.
func (a *SQLiteAuthManager) Authorize(user *User, topic string, perm Permission) error { func (a *SQLiteManager) Authorize(user *User, topic string, perm Permission) error {
if user != nil && user.Role == RoleAdmin { if user != nil && user.Role == RoleAdmin {
return nil // Admin can do everything return nil // Admin can do everything
} }
@ -270,7 +301,7 @@ func (a *SQLiteAuthManager) Authorize(user *User, topic string, perm Permission)
return a.resolvePerms(read, write, perm) return a.resolvePerms(read, write, perm)
} }
func (a *SQLiteAuthManager) resolvePerms(read, write bool, perm Permission) error { func (a *SQLiteManager) resolvePerms(read, write bool, perm Permission) error {
if perm == PermissionRead && read { if perm == PermissionRead && read {
return nil return nil
} else if perm == PermissionWrite && write { } else if perm == PermissionWrite && write {
@ -281,7 +312,7 @@ func (a *SQLiteAuthManager) resolvePerms(read, write bool, perm Permission) erro
// AddUser adds a user with the given username, password and role. The password should be hashed // AddUser adds a user with the given username, password and role. The password should be hashed
// before it is stored in a persistence layer. // before it is stored in a persistence layer.
func (a *SQLiteAuthManager) AddUser(username, password string, role Role) error { func (a *SQLiteManager) AddUser(username, password string, role Role) error {
if !AllowedUsername(username) || !AllowedRole(role) { if !AllowedUsername(username) || !AllowedRole(role) {
return ErrInvalidArgument return ErrInvalidArgument
} }
@ -297,14 +328,14 @@ func (a *SQLiteAuthManager) AddUser(username, password string, role Role) error
// RemoveUser deletes the user with the given username. The function returns nil on success, even // RemoveUser deletes the user with the given username. The function returns nil on success, even
// if the user did not exist in the first place. // if the user did not exist in the first place.
func (a *SQLiteAuthManager) RemoveUser(username string) error { func (a *SQLiteManager) RemoveUser(username string) error {
if !AllowedUsername(username) { if !AllowedUsername(username) {
return ErrInvalidArgument return ErrInvalidArgument
} }
if _, err := a.db.Exec(deleteUserAccessQuery, username); err != nil { if _, err := a.db.Exec(deleteUserAccessQuery, username); err != nil {
return err return err
} }
if _, err := a.db.Exec(deleteUserTokenQuery, username); err != nil { if _, err := a.db.Exec(deleteUserTokensQuery, username); err != nil {
return err return err
} }
if _, err := a.db.Exec(deleteUserQuery, username); err != nil { if _, err := a.db.Exec(deleteUserQuery, username); err != nil {
@ -314,7 +345,7 @@ func (a *SQLiteAuthManager) RemoveUser(username string) error {
} }
// Users returns a list of users. It always also returns the Everyone user ("*"). // Users returns a list of users. It always also returns the Everyone user ("*").
func (a *SQLiteAuthManager) Users() ([]*User, error) { func (a *SQLiteManager) Users() ([]*User, error) {
rows, err := a.db.Query(selectUsernamesQuery) rows, err := a.db.Query(selectUsernamesQuery)
if err != nil { if err != nil {
return nil, err return nil, err
@ -349,7 +380,7 @@ func (a *SQLiteAuthManager) Users() ([]*User, error) {
// User returns the user with the given username if it exists, or ErrNotFound otherwise. // User returns the user with the given username if it exists, or ErrNotFound otherwise.
// You may also pass Everyone to retrieve the anonymous user and its Grant list. // You may also pass Everyone to retrieve the anonymous user and its Grant list.
func (a *SQLiteAuthManager) User(username string) (*User, error) { func (a *SQLiteManager) User(username string) (*User, error) {
if username == Everyone { if username == Everyone {
return a.everyoneUser() return a.everyoneUser()
} }
@ -360,7 +391,7 @@ func (a *SQLiteAuthManager) User(username string) (*User, error) {
return a.readUser(rows) return a.readUser(rows)
} }
func (a *SQLiteAuthManager) userByToken(token string) (*User, error) { func (a *SQLiteManager) userByToken(token string) (*User, error) {
rows, err := a.db.Query(selectUserByTokenQuery, token) rows, err := a.db.Query(selectUserByTokenQuery, token)
if err != nil { if err != nil {
return nil, err return nil, err
@ -368,7 +399,7 @@ func (a *SQLiteAuthManager) userByToken(token string) (*User, error) {
return a.readUser(rows) return a.readUser(rows)
} }
func (a *SQLiteAuthManager) readUser(rows *sql.Rows) (*User, error) { func (a *SQLiteManager) readUser(rows *sql.Rows) (*User, error) {
defer rows.Close() defer rows.Close()
var username, hash, role string var username, hash, role string
var settings, planCode sql.NullString var settings, planCode sql.NullString
@ -397,7 +428,7 @@ func (a *SQLiteAuthManager) readUser(rows *sql.Rows) (*User, error) {
}, },
} }
if settings.Valid { if settings.Valid {
user.Prefs = &UserPrefs{} user.Prefs = &Prefs{}
if err := json.Unmarshal([]byte(settings.String), user.Prefs); err != nil { if err := json.Unmarshal([]byte(settings.String), user.Prefs); err != nil {
return nil, err return nil, err
} }
@ -415,7 +446,7 @@ func (a *SQLiteAuthManager) readUser(rows *sql.Rows) (*User, error) {
return user, nil return user, nil
} }
func (a *SQLiteAuthManager) everyoneUser() (*User, error) { func (a *SQLiteManager) everyoneUser() (*User, error) {
grants, err := a.readGrants(Everyone) grants, err := a.readGrants(Everyone)
if err != nil { if err != nil {
return nil, err return nil, err
@ -428,7 +459,7 @@ func (a *SQLiteAuthManager) everyoneUser() (*User, error) {
}, nil }, nil
} }
func (a *SQLiteAuthManager) readGrants(username string) ([]Grant, error) { func (a *SQLiteManager) readGrants(username string) ([]Grant, error) {
rows, err := a.db.Query(selectUserAccessQuery, username) rows, err := a.db.Query(selectUserAccessQuery, username)
if err != nil { if err != nil {
return nil, err return nil, err
@ -453,7 +484,7 @@ func (a *SQLiteAuthManager) readGrants(username string) ([]Grant, error) {
} }
// ChangePassword changes a user's password // ChangePassword changes a user's password
func (a *SQLiteAuthManager) ChangePassword(username, password string) error { func (a *SQLiteManager) ChangePassword(username, password string) error {
hash, err := bcrypt.GenerateFromPassword([]byte(password), bcryptCost) hash, err := bcrypt.GenerateFromPassword([]byte(password), bcryptCost)
if err != nil { if err != nil {
return err return err
@ -466,7 +497,7 @@ func (a *SQLiteAuthManager) ChangePassword(username, password string) error {
// ChangeRole changes a user's role. When a role is changed from RoleUser to RoleAdmin, // ChangeRole changes a user's role. When a role is changed from RoleUser to RoleAdmin,
// all existing access control entries (Grant) are removed, since they are no longer needed. // all existing access control entries (Grant) are removed, since they are no longer needed.
func (a *SQLiteAuthManager) ChangeRole(username string, role Role) error { func (a *SQLiteManager) ChangeRole(username string, role Role) error {
if !AllowedUsername(username) || !AllowedRole(role) { if !AllowedUsername(username) || !AllowedRole(role) {
return ErrInvalidArgument return ErrInvalidArgument
} }
@ -483,7 +514,7 @@ func (a *SQLiteAuthManager) ChangeRole(username string, role Role) error {
// AllowAccess adds or updates an entry in th access control list for a specific user. It controls // AllowAccess adds or updates an entry in th access control list for a specific user. It controls
// read/write access to a topic. The parameter topicPattern may include wildcards (*). // read/write access to a topic. The parameter topicPattern may include wildcards (*).
func (a *SQLiteAuthManager) AllowAccess(username string, topicPattern string, read bool, write bool) error { func (a *SQLiteManager) AllowAccess(username string, topicPattern string, read bool, write bool) error {
if (!AllowedUsername(username) && username != Everyone) || !AllowedTopicPattern(topicPattern) { if (!AllowedUsername(username) && username != Everyone) || !AllowedTopicPattern(topicPattern) {
return ErrInvalidArgument return ErrInvalidArgument
} }
@ -495,7 +526,7 @@ func (a *SQLiteAuthManager) AllowAccess(username string, topicPattern string, re
// ResetAccess removes an access control list entry for a specific username/topic, or (if topic is // ResetAccess removes an access control list entry for a specific username/topic, or (if topic is
// empty) for an entire user. The parameter topicPattern may include wildcards (*). // empty) for an entire user. The parameter topicPattern may include wildcards (*).
func (a *SQLiteAuthManager) ResetAccess(username string, topicPattern string) error { func (a *SQLiteManager) ResetAccess(username string, topicPattern string) error {
if !AllowedUsername(username) && username != Everyone && username != "" { if !AllowedUsername(username) && username != Everyone && username != "" {
return ErrInvalidArgument return ErrInvalidArgument
} else if !AllowedTopicPattern(topicPattern) && topicPattern != "" { } else if !AllowedTopicPattern(topicPattern) && topicPattern != "" {
@ -513,7 +544,7 @@ func (a *SQLiteAuthManager) ResetAccess(username string, topicPattern string) er
} }
// DefaultAccess returns the default read/write access if no access control entry matches // DefaultAccess returns the default read/write access if no access control entry matches
func (a *SQLiteAuthManager) DefaultAccess() (read bool, write bool) { func (a *SQLiteManager) DefaultAccess() (read bool, write bool) {
return a.defaultRead, a.defaultWrite return a.defaultRead, a.defaultWrite
} }

View File

@ -1,8 +1,7 @@
package auth_test package user_test
import ( import (
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
"heckel.io/ntfy/auth"
"path/filepath" "path/filepath"
"strings" "strings"
"testing" "testing"
@ -13,29 +12,29 @@ const minBcryptTimingMillis = int64(50) // Ideally should be >100ms, but this sh
func TestSQLiteAuth_FullScenario_Default_DenyAll(t *testing.T) { func TestSQLiteAuth_FullScenario_Default_DenyAll(t *testing.T) {
a := newTestAuth(t, false, false) a := newTestAuth(t, false, false)
require.Nil(t, a.AddUser("phil", "phil", auth.RoleAdmin)) require.Nil(t, a.AddUser("phil", "phil", user.RoleAdmin))
require.Nil(t, a.AddUser("ben", "ben", auth.RoleUser)) require.Nil(t, a.AddUser("ben", "ben", user.RoleUser))
require.Nil(t, a.AllowAccess("ben", "mytopic", true, true)) require.Nil(t, a.AllowAccess("ben", "mytopic", true, true))
require.Nil(t, a.AllowAccess("ben", "readme", true, false)) require.Nil(t, a.AllowAccess("ben", "readme", true, false))
require.Nil(t, a.AllowAccess("ben", "writeme", false, true)) require.Nil(t, a.AllowAccess("ben", "writeme", false, true))
require.Nil(t, a.AllowAccess("ben", "everyonewrite", false, false)) // How unfair! require.Nil(t, a.AllowAccess("ben", "everyonewrite", false, false)) // How unfair!
require.Nil(t, a.AllowAccess(auth.Everyone, "announcements", true, false)) require.Nil(t, a.AllowAccess(user.Everyone, "announcements", true, false))
require.Nil(t, a.AllowAccess(auth.Everyone, "everyonewrite", true, true)) require.Nil(t, a.AllowAccess(user.Everyone, "everyonewrite", true, true))
require.Nil(t, a.AllowAccess(auth.Everyone, "up*", false, true)) // Everyone can write to /up* require.Nil(t, a.AllowAccess(user.Everyone, "up*", false, true)) // Everyone can write to /up*
phil, err := a.Authenticate("phil", "phil") phil, err := a.Authenticate("phil", "phil")
require.Nil(t, err) require.Nil(t, err)
require.Equal(t, "phil", phil.Name) require.Equal(t, "phil", phil.Name)
require.True(t, strings.HasPrefix(phil.Hash, "$2a$10$")) require.True(t, strings.HasPrefix(phil.Hash, "$2a$10$"))
require.Equal(t, auth.RoleAdmin, phil.Role) require.Equal(t, user.RoleAdmin, phil.Role)
require.Equal(t, []auth.Grant{}, phil.Grants) require.Equal(t, []user.Grant{}, phil.Grants)
ben, err := a.Authenticate("ben", "ben") ben, err := a.Authenticate("ben", "ben")
require.Nil(t, err) require.Nil(t, err)
require.Equal(t, "ben", ben.Name) require.Equal(t, "ben", ben.Name)
require.True(t, strings.HasPrefix(ben.Hash, "$2a$10$")) require.True(t, strings.HasPrefix(ben.Hash, "$2a$10$"))
require.Equal(t, auth.RoleUser, ben.Role) require.Equal(t, user.RoleUser, ben.Role)
require.Equal(t, []auth.Grant{ require.Equal(t, []user.Grant{
{"mytopic", true, true}, {"mytopic", true, true},
{"readme", true, false}, {"readme", true, false},
{"writeme", false, true}, {"writeme", false, true},
@ -44,62 +43,62 @@ func TestSQLiteAuth_FullScenario_Default_DenyAll(t *testing.T) {
notben, err := a.Authenticate("ben", "this is wrong") notben, err := a.Authenticate("ben", "this is wrong")
require.Nil(t, notben) require.Nil(t, notben)
require.Equal(t, auth.ErrUnauthenticated, err) require.Equal(t, user.ErrUnauthenticated, err)
// Admin can do everything // Admin can do everything
require.Nil(t, a.Authorize(phil, "sometopic", auth.PermissionWrite)) require.Nil(t, a.Authorize(phil, "sometopic", user.PermissionWrite))
require.Nil(t, a.Authorize(phil, "mytopic", auth.PermissionRead)) require.Nil(t, a.Authorize(phil, "mytopic", user.PermissionRead))
require.Nil(t, a.Authorize(phil, "readme", auth.PermissionWrite)) require.Nil(t, a.Authorize(phil, "readme", user.PermissionWrite))
require.Nil(t, a.Authorize(phil, "writeme", auth.PermissionWrite)) require.Nil(t, a.Authorize(phil, "writeme", user.PermissionWrite))
require.Nil(t, a.Authorize(phil, "announcements", auth.PermissionWrite)) require.Nil(t, a.Authorize(phil, "announcements", user.PermissionWrite))
require.Nil(t, a.Authorize(phil, "everyonewrite", auth.PermissionWrite)) require.Nil(t, a.Authorize(phil, "everyonewrite", user.PermissionWrite))
// User cannot do everything // User cannot do everything
require.Nil(t, a.Authorize(ben, "mytopic", auth.PermissionWrite)) require.Nil(t, a.Authorize(ben, "mytopic", user.PermissionWrite))
require.Nil(t, a.Authorize(ben, "mytopic", auth.PermissionRead)) require.Nil(t, a.Authorize(ben, "mytopic", user.PermissionRead))
require.Nil(t, a.Authorize(ben, "readme", auth.PermissionRead)) require.Nil(t, a.Authorize(ben, "readme", user.PermissionRead))
require.Equal(t, auth.ErrUnauthorized, a.Authorize(ben, "readme", auth.PermissionWrite)) require.Equal(t, user.ErrUnauthorized, a.Authorize(ben, "readme", user.PermissionWrite))
require.Equal(t, auth.ErrUnauthorized, a.Authorize(ben, "writeme", auth.PermissionRead)) require.Equal(t, user.ErrUnauthorized, a.Authorize(ben, "writeme", user.PermissionRead))
require.Nil(t, a.Authorize(ben, "writeme", auth.PermissionWrite)) require.Nil(t, a.Authorize(ben, "writeme", user.PermissionWrite))
require.Nil(t, a.Authorize(ben, "writeme", auth.PermissionWrite)) require.Nil(t, a.Authorize(ben, "writeme", user.PermissionWrite))
require.Equal(t, auth.ErrUnauthorized, a.Authorize(ben, "everyonewrite", auth.PermissionRead)) require.Equal(t, user.ErrUnauthorized, a.Authorize(ben, "everyonewrite", user.PermissionRead))
require.Equal(t, auth.ErrUnauthorized, a.Authorize(ben, "everyonewrite", auth.PermissionWrite)) require.Equal(t, user.ErrUnauthorized, a.Authorize(ben, "everyonewrite", user.PermissionWrite))
require.Nil(t, a.Authorize(ben, "announcements", auth.PermissionRead)) require.Nil(t, a.Authorize(ben, "announcements", user.PermissionRead))
require.Equal(t, auth.ErrUnauthorized, a.Authorize(ben, "announcements", auth.PermissionWrite)) require.Equal(t, user.ErrUnauthorized, a.Authorize(ben, "announcements", user.PermissionWrite))
// Everyone else can do barely anything // Everyone else can do barely anything
require.Equal(t, auth.ErrUnauthorized, a.Authorize(nil, "sometopicnotinthelist", auth.PermissionRead)) require.Equal(t, user.ErrUnauthorized, a.Authorize(nil, "sometopicnotinthelist", user.PermissionRead))
require.Equal(t, auth.ErrUnauthorized, a.Authorize(nil, "sometopicnotinthelist", auth.PermissionWrite)) require.Equal(t, user.ErrUnauthorized, a.Authorize(nil, "sometopicnotinthelist", user.PermissionWrite))
require.Equal(t, auth.ErrUnauthorized, a.Authorize(nil, "mytopic", auth.PermissionRead)) require.Equal(t, user.ErrUnauthorized, a.Authorize(nil, "mytopic", user.PermissionRead))
require.Equal(t, auth.ErrUnauthorized, a.Authorize(nil, "mytopic", auth.PermissionWrite)) require.Equal(t, user.ErrUnauthorized, a.Authorize(nil, "mytopic", user.PermissionWrite))
require.Equal(t, auth.ErrUnauthorized, a.Authorize(nil, "readme", auth.PermissionRead)) require.Equal(t, user.ErrUnauthorized, a.Authorize(nil, "readme", user.PermissionRead))
require.Equal(t, auth.ErrUnauthorized, a.Authorize(nil, "readme", auth.PermissionWrite)) require.Equal(t, user.ErrUnauthorized, a.Authorize(nil, "readme", user.PermissionWrite))
require.Equal(t, auth.ErrUnauthorized, a.Authorize(nil, "writeme", auth.PermissionRead)) require.Equal(t, user.ErrUnauthorized, a.Authorize(nil, "writeme", user.PermissionRead))
require.Equal(t, auth.ErrUnauthorized, a.Authorize(nil, "writeme", auth.PermissionWrite)) require.Equal(t, user.ErrUnauthorized, a.Authorize(nil, "writeme", user.PermissionWrite))
require.Equal(t, auth.ErrUnauthorized, a.Authorize(nil, "announcements", auth.PermissionWrite)) require.Equal(t, user.ErrUnauthorized, a.Authorize(nil, "announcements", user.PermissionWrite))
require.Nil(t, a.Authorize(nil, "announcements", auth.PermissionRead)) require.Nil(t, a.Authorize(nil, "announcements", user.PermissionRead))
require.Nil(t, a.Authorize(nil, "everyonewrite", auth.PermissionRead)) require.Nil(t, a.Authorize(nil, "everyonewrite", user.PermissionRead))
require.Nil(t, a.Authorize(nil, "everyonewrite", auth.PermissionWrite)) require.Nil(t, a.Authorize(nil, "everyonewrite", user.PermissionWrite))
require.Nil(t, a.Authorize(nil, "up1234", auth.PermissionWrite)) // Wildcard permission require.Nil(t, a.Authorize(nil, "up1234", user.PermissionWrite)) // Wildcard permission
require.Nil(t, a.Authorize(nil, "up5678", auth.PermissionWrite)) require.Nil(t, a.Authorize(nil, "up5678", user.PermissionWrite))
} }
func TestSQLiteAuth_AddUser_Invalid(t *testing.T) { func TestSQLiteAuth_AddUser_Invalid(t *testing.T) {
a := newTestAuth(t, false, false) a := newTestAuth(t, false, false)
require.Equal(t, auth.ErrInvalidArgument, a.AddUser(" invalid ", "pass", auth.RoleAdmin)) require.Equal(t, user.ErrInvalidArgument, a.AddUser(" invalid ", "pass", user.RoleAdmin))
require.Equal(t, auth.ErrInvalidArgument, a.AddUser("validuser", "pass", "invalid-role")) require.Equal(t, user.ErrInvalidArgument, a.AddUser("validuser", "pass", "invalid-role"))
} }
func TestSQLiteAuth_AddUser_Timing(t *testing.T) { func TestSQLiteAuth_AddUser_Timing(t *testing.T) {
a := newTestAuth(t, false, false) a := newTestAuth(t, false, false)
start := time.Now().UnixMilli() start := time.Now().UnixMilli()
require.Nil(t, a.AddUser("user", "pass", auth.RoleAdmin)) require.Nil(t, a.AddUser("user", "pass", user.RoleAdmin))
require.GreaterOrEqual(t, time.Now().UnixMilli()-start, minBcryptTimingMillis) require.GreaterOrEqual(t, time.Now().UnixMilli()-start, minBcryptTimingMillis)
} }
func TestSQLiteAuth_Authenticate_Timing(t *testing.T) { func TestSQLiteAuth_Authenticate_Timing(t *testing.T) {
a := newTestAuth(t, false, false) a := newTestAuth(t, false, false)
require.Nil(t, a.AddUser("user", "pass", auth.RoleAdmin)) require.Nil(t, a.AddUser("user", "pass", user.RoleAdmin))
// Timing a correct attempt // Timing a correct attempt
start := time.Now().UnixMilli() start := time.Now().UnixMilli()
@ -110,53 +109,53 @@ func TestSQLiteAuth_Authenticate_Timing(t *testing.T) {
// Timing an incorrect attempt // Timing an incorrect attempt
start = time.Now().UnixMilli() start = time.Now().UnixMilli()
_, err = a.Authenticate("user", "INCORRECT") _, err = a.Authenticate("user", "INCORRECT")
require.Equal(t, auth.ErrUnauthenticated, err) require.Equal(t, user.ErrUnauthenticated, err)
require.GreaterOrEqual(t, time.Now().UnixMilli()-start, minBcryptTimingMillis) require.GreaterOrEqual(t, time.Now().UnixMilli()-start, minBcryptTimingMillis)
// Timing a non-existing user attempt // Timing a non-existing user attempt
start = time.Now().UnixMilli() start = time.Now().UnixMilli()
_, err = a.Authenticate("DOES-NOT-EXIST", "hithere") _, err = a.Authenticate("DOES-NOT-EXIST", "hithere")
require.Equal(t, auth.ErrUnauthenticated, err) require.Equal(t, user.ErrUnauthenticated, err)
require.GreaterOrEqual(t, time.Now().UnixMilli()-start, minBcryptTimingMillis) require.GreaterOrEqual(t, time.Now().UnixMilli()-start, minBcryptTimingMillis)
} }
func TestSQLiteAuth_UserManagement(t *testing.T) { func TestSQLiteAuth_UserManagement(t *testing.T) {
a := newTestAuth(t, false, false) a := newTestAuth(t, false, false)
require.Nil(t, a.AddUser("phil", "phil", auth.RoleAdmin)) require.Nil(t, a.AddUser("phil", "phil", user.RoleAdmin))
require.Nil(t, a.AddUser("ben", "ben", auth.RoleUser)) require.Nil(t, a.AddUser("ben", "ben", user.RoleUser))
require.Nil(t, a.AllowAccess("ben", "mytopic", true, true)) require.Nil(t, a.AllowAccess("ben", "mytopic", true, true))
require.Nil(t, a.AllowAccess("ben", "readme", true, false)) require.Nil(t, a.AllowAccess("ben", "readme", true, false))
require.Nil(t, a.AllowAccess("ben", "writeme", false, true)) require.Nil(t, a.AllowAccess("ben", "writeme", false, true))
require.Nil(t, a.AllowAccess("ben", "everyonewrite", false, false)) // How unfair! require.Nil(t, a.AllowAccess("ben", "everyonewrite", false, false)) // How unfair!
require.Nil(t, a.AllowAccess(auth.Everyone, "announcements", true, false)) require.Nil(t, a.AllowAccess(user.Everyone, "announcements", true, false))
require.Nil(t, a.AllowAccess(auth.Everyone, "everyonewrite", true, true)) require.Nil(t, a.AllowAccess(user.Everyone, "everyonewrite", true, true))
// Query user details // Query user details
phil, err := a.User("phil") phil, err := a.User("phil")
require.Nil(t, err) require.Nil(t, err)
require.Equal(t, "phil", phil.Name) require.Equal(t, "phil", phil.Name)
require.True(t, strings.HasPrefix(phil.Hash, "$2a$10$")) require.True(t, strings.HasPrefix(phil.Hash, "$2a$10$"))
require.Equal(t, auth.RoleAdmin, phil.Role) require.Equal(t, user.RoleAdmin, phil.Role)
require.Equal(t, []auth.Grant{}, phil.Grants) require.Equal(t, []user.Grant{}, phil.Grants)
ben, err := a.User("ben") ben, err := a.User("ben")
require.Nil(t, err) require.Nil(t, err)
require.Equal(t, "ben", ben.Name) require.Equal(t, "ben", ben.Name)
require.True(t, strings.HasPrefix(ben.Hash, "$2a$10$")) require.True(t, strings.HasPrefix(ben.Hash, "$2a$10$"))
require.Equal(t, auth.RoleUser, ben.Role) require.Equal(t, user.RoleUser, ben.Role)
require.Equal(t, []auth.Grant{ require.Equal(t, []user.Grant{
{"mytopic", true, true}, {"mytopic", true, true},
{"readme", true, false}, {"readme", true, false},
{"writeme", false, true}, {"writeme", false, true},
{"everyonewrite", false, false}, {"everyonewrite", false, false},
}, ben.Grants) }, ben.Grants)
everyone, err := a.User(auth.Everyone) everyone, err := a.User(user.Everyone)
require.Nil(t, err) require.Nil(t, err)
require.Equal(t, "*", everyone.Name) require.Equal(t, "*", everyone.Name)
require.Equal(t, "", everyone.Hash) require.Equal(t, "", everyone.Hash)
require.Equal(t, auth.RoleAnonymous, everyone.Role) require.Equal(t, user.RoleAnonymous, everyone.Role)
require.Equal(t, []auth.Grant{ require.Equal(t, []user.Grant{
{"announcements", true, false}, {"announcements", true, false},
{"everyonewrite", true, true}, {"everyonewrite", true, true},
}, everyone.Grants) }, everyone.Grants)
@ -165,22 +164,22 @@ func TestSQLiteAuth_UserManagement(t *testing.T) {
require.Nil(t, a.AllowAccess("ben", "mytopic", true, true)) require.Nil(t, a.AllowAccess("ben", "mytopic", true, true))
require.Nil(t, a.AllowAccess("ben", "readme", true, false)) require.Nil(t, a.AllowAccess("ben", "readme", true, false))
require.Nil(t, a.AllowAccess("ben", "writeme", false, true)) require.Nil(t, a.AllowAccess("ben", "writeme", false, true))
require.Nil(t, a.Authorize(ben, "mytopic", auth.PermissionRead)) require.Nil(t, a.Authorize(ben, "mytopic", user.PermissionRead))
require.Nil(t, a.Authorize(ben, "mytopic", auth.PermissionWrite)) require.Nil(t, a.Authorize(ben, "mytopic", user.PermissionWrite))
require.Nil(t, a.Authorize(ben, "readme", auth.PermissionRead)) require.Nil(t, a.Authorize(ben, "readme", user.PermissionRead))
require.Nil(t, a.Authorize(ben, "writeme", auth.PermissionWrite)) require.Nil(t, a.Authorize(ben, "writeme", user.PermissionWrite))
// Revoke access for "ben" to "mytopic", then check again // Revoke access for "ben" to "mytopic", then check again
require.Nil(t, a.ResetAccess("ben", "mytopic")) require.Nil(t, a.ResetAccess("ben", "mytopic"))
require.Equal(t, auth.ErrUnauthorized, a.Authorize(ben, "mytopic", auth.PermissionWrite)) // Revoked require.Equal(t, user.ErrUnauthorized, a.Authorize(ben, "mytopic", user.PermissionWrite)) // Revoked
require.Equal(t, auth.ErrUnauthorized, a.Authorize(ben, "mytopic", auth.PermissionRead)) // Revoked require.Equal(t, user.ErrUnauthorized, a.Authorize(ben, "mytopic", user.PermissionRead)) // Revoked
require.Nil(t, a.Authorize(ben, "readme", auth.PermissionRead)) // Unchanged require.Nil(t, a.Authorize(ben, "readme", user.PermissionRead)) // Unchanged
require.Nil(t, a.Authorize(ben, "writeme", auth.PermissionWrite)) // Unchanged require.Nil(t, a.Authorize(ben, "writeme", user.PermissionWrite)) // Unchanged
// Revoke rest of the access // Revoke rest of the access
require.Nil(t, a.ResetAccess("ben", "")) require.Nil(t, a.ResetAccess("ben", ""))
require.Equal(t, auth.ErrUnauthorized, a.Authorize(ben, "readme", auth.PermissionRead)) // Revoked require.Equal(t, user.ErrUnauthorized, a.Authorize(ben, "readme", user.PermissionRead)) // Revoked
require.Equal(t, auth.ErrUnauthorized, a.Authorize(ben, "wrtiteme", auth.PermissionWrite)) // Revoked require.Equal(t, user.ErrUnauthorized, a.Authorize(ben, "wrtiteme", user.PermissionWrite)) // Revoked
// User list // User list
users, err := a.Users() users, err := a.Users()
@ -193,7 +192,7 @@ func TestSQLiteAuth_UserManagement(t *testing.T) {
// Remove user // Remove user
require.Nil(t, a.RemoveUser("ben")) require.Nil(t, a.RemoveUser("ben"))
_, err = a.User("ben") _, err = a.User("ben")
require.Equal(t, auth.ErrNotFound, err) require.Equal(t, user.ErrNotFound, err)
users, err = a.Users() users, err = a.Users()
require.Nil(t, err) require.Nil(t, err)
@ -204,40 +203,40 @@ func TestSQLiteAuth_UserManagement(t *testing.T) {
func TestSQLiteAuth_ChangePassword(t *testing.T) { func TestSQLiteAuth_ChangePassword(t *testing.T) {
a := newTestAuth(t, false, false) a := newTestAuth(t, false, false)
require.Nil(t, a.AddUser("phil", "phil", auth.RoleAdmin)) require.Nil(t, a.AddUser("phil", "phil", user.RoleAdmin))
_, err := a.Authenticate("phil", "phil") _, err := a.Authenticate("phil", "phil")
require.Nil(t, err) require.Nil(t, err)
require.Nil(t, a.ChangePassword("phil", "newpass")) require.Nil(t, a.ChangePassword("phil", "newpass"))
_, err = a.Authenticate("phil", "phil") _, err = a.Authenticate("phil", "phil")
require.Equal(t, auth.ErrUnauthenticated, err) require.Equal(t, user.ErrUnauthenticated, err)
_, err = a.Authenticate("phil", "newpass") _, err = a.Authenticate("phil", "newpass")
require.Nil(t, err) require.Nil(t, err)
} }
func TestSQLiteAuth_ChangeRole(t *testing.T) { func TestSQLiteAuth_ChangeRole(t *testing.T) {
a := newTestAuth(t, false, false) a := newTestAuth(t, false, false)
require.Nil(t, a.AddUser("ben", "ben", auth.RoleUser)) require.Nil(t, a.AddUser("ben", "ben", user.RoleUser))
require.Nil(t, a.AllowAccess("ben", "mytopic", true, true)) require.Nil(t, a.AllowAccess("ben", "mytopic", true, true))
require.Nil(t, a.AllowAccess("ben", "readme", true, false)) require.Nil(t, a.AllowAccess("ben", "readme", true, false))
ben, err := a.User("ben") ben, err := a.User("ben")
require.Nil(t, err) require.Nil(t, err)
require.Equal(t, auth.RoleUser, ben.Role) require.Equal(t, user.RoleUser, ben.Role)
require.Equal(t, 2, len(ben.Grants)) require.Equal(t, 2, len(ben.Grants))
require.Nil(t, a.ChangeRole("ben", auth.RoleAdmin)) require.Nil(t, a.ChangeRole("ben", user.RoleAdmin))
ben, err = a.User("ben") ben, err = a.User("ben")
require.Nil(t, err) require.Nil(t, err)
require.Equal(t, auth.RoleAdmin, ben.Role) require.Equal(t, user.RoleAdmin, ben.Role)
require.Equal(t, 0, len(ben.Grants)) require.Equal(t, 0, len(ben.Grants))
} }
func newTestAuth(t *testing.T, defaultRead, defaultWrite bool) *auth.SQLiteAuthManager { func newTestAuth(t *testing.T, defaultRead, defaultWrite bool) *user.SQLiteAuthManager {
filename := filepath.Join(t.TempDir(), "user.db") filename := filepath.Join(t.TempDir(), "user.db")
a, err := auth.NewSQLiteAuthManager(filename, defaultRead, defaultWrite) a, err := user.NewSQLiteAuthManager(filename, defaultRead, defaultWrite)
require.Nil(t, err) require.Nil(t, err)
return a return a
} }

View File

@ -1,14 +1,18 @@
import { import {
accountPasswordUrl,
accountSettingsUrl,
accountSubscriptionSingleUrl,
accountSubscriptionUrl,
accountTokenUrl,
accountUrl,
fetchLinesIterator, fetchLinesIterator,
maybeWithBasicAuth, maybeWithBearerAuth, maybeWithBasicAuth,
maybeWithBearerAuth,
topicShortUrl, topicShortUrl,
topicUrl, topicUrl,
topicUrlAuth, topicUrlAuth,
topicUrlJsonPoll, topicUrlJsonPoll,
topicUrlJsonPollWithSince, topicUrlJsonPollWithSince
accountSettingsUrl,
accountTokenUrl,
userStatsUrl, accountSubscriptionUrl, accountSubscriptionSingleUrl, accountUrl, accountPasswordUrl
} from "./utils"; } from "./utils";
import userManager from "./UserManager"; import userManager from "./UserManager";
@ -74,7 +78,7 @@ class Api {
xhr.setRequestHeader(key, value); xhr.setRequestHeader(key, value);
} }
xhr.upload.addEventListener("progress", onProgress); xhr.upload.addEventListener("progress", onProgress);
xhr.addEventListener('readystatechange', (ev) => { xhr.addEventListener('readystatechange', () => {
if (xhr.readyState === 4 && xhr.status >= 200 && xhr.status <= 299) { if (xhr.readyState === 4 && xhr.status >= 200 && xhr.status <= 299) {
console.log(`[Api] Publish successful (HTTP ${xhr.status})`, xhr.response); console.log(`[Api] Publish successful (HTTP ${xhr.status})`, xhr.response);
resolve(xhr.response); resolve(xhr.response);
@ -123,6 +127,7 @@ class Api {
const url = accountTokenUrl(baseUrl); const url = accountTokenUrl(baseUrl);
console.log(`[Api] Checking auth for ${url}`); console.log(`[Api] Checking auth for ${url}`);
const response = await fetch(url, { const response = await fetch(url, {
method: "POST",
headers: maybeWithBasicAuth({}, user) headers: maybeWithBasicAuth({}, user)
}); });
if (response.status === 401 || response.status === 403) { if (response.status === 401 || response.status === 403) {
@ -218,12 +223,26 @@ class Api {
} }
} }
async extendToken(baseUrl, token) {
const url = accountTokenUrl(baseUrl);
console.log(`[Api] Extending user access token ${url}`);
const response = await fetch(url, {
method: "PATCH",
headers: maybeWithBearerAuth({}, token)
});
if (response.status === 401 || response.status === 403) {
throw new UnauthorizedError();
} else if (response.status !== 200) {
throw new Error(`Unexpected server response ${response.status}`);
}
}
async updateAccountSettings(baseUrl, token, payload) { async updateAccountSettings(baseUrl, token, payload) {
const url = accountSettingsUrl(baseUrl); const url = accountSettingsUrl(baseUrl);
const body = JSON.stringify(payload); const body = JSON.stringify(payload);
console.log(`[Api] Updating user account ${url}: ${body}`); console.log(`[Api] Updating user account ${url}: ${body}`);
const response = await fetch(url, { const response = await fetch(url, {
method: "POST", method: "PATCH",
headers: maybeWithBearerAuth({}, token), headers: maybeWithBearerAuth({}, token),
body: body body: body
}); });