From 4b9d40464c6216922ef28f639787890c2e3e2846 Mon Sep 17 00:00:00 2001 From: binwiederhier Date: Mon, 2 Jan 2023 21:12:42 -0500 Subject: [PATCH] Replace read/write flags with Permission --- cmd/access.go | 14 ++--- cmd/access_test.go | 2 +- cmd/serve.go | 12 ++-- cmd/user.go | 10 ++-- cmd/user_test.go | 20 +------ server/config.go | 7 +-- server/errors.go | 5 +- server/server.go | 2 +- server/server_account.go | 24 +++----- server/server_test.go | 18 ++---- user/manager.go | 41 ++++++------- user/manager_test.go | 121 ++++++++++++++++++++++++--------------- user/types.go | 70 ++++++++++++++++++---- 13 files changed, 194 insertions(+), 152 deletions(-) diff --git a/cmd/access.go b/cmd/access.go index 88f666c9..c304acd5 100644 --- a/cmd/access.go +++ b/cmd/access.go @@ -192,11 +192,11 @@ func showUsers(c *cli.Context, manager *user.Manager, users []*user.User) error fmt.Fprintf(c.App.ErrWriter, "- read-write access to all topics (admin role)\n") } else if len(grants) > 0 { for _, grant := range grants { - if grant.AllowRead && grant.AllowWrite { + if grant.Allow.IsReadWrite() { fmt.Fprintf(c.App.ErrWriter, "- read-write access to topic %s\n", grant.TopicPattern) - } else if grant.AllowRead { + } else if grant.Allow.IsRead() { fmt.Fprintf(c.App.ErrWriter, "- read-only access to topic %s\n", grant.TopicPattern) - } else if grant.AllowWrite { + } else if grant.Allow.IsWrite() { fmt.Fprintf(c.App.ErrWriter, "- write-only access to topic %s\n", grant.TopicPattern) } else { fmt.Fprintf(c.App.ErrWriter, "- no access to topic %s\n", grant.TopicPattern) @@ -206,12 +206,12 @@ func showUsers(c *cli.Context, manager *user.Manager, users []*user.User) error fmt.Fprintf(c.App.ErrWriter, "- no topic-specific permissions\n") } if u.Name == user.Everyone { - defaultRead, defaultWrite := manager.DefaultAccess() - if defaultRead && defaultWrite { + access := manager.DefaultAccess() + if access.IsReadWrite() { fmt.Fprintln(c.App.ErrWriter, "- read-write access to all (other) topics (server config)") - } else if defaultRead { + } else if access.IsRead() { fmt.Fprintln(c.App.ErrWriter, "- read-only access to all (other) topics (server config)") - } else if defaultWrite { + } else if access.IsWrite() { fmt.Fprintln(c.App.ErrWriter, "- write-only access to all (other) topics (server config)") } else { fmt.Fprintln(c.App.ErrWriter, "- no access to any (other) topics (server config)") diff --git a/cmd/access_test.go b/cmd/access_test.go index 67159d43..6e3c5ba3 100644 --- a/cmd/access_test.go +++ b/cmd/access_test.go @@ -81,7 +81,7 @@ func runAccessCommand(app *cli.App, conf *server.Config, args ...string) error { "ntfy", "access", "--auth-file=" + conf.AuthFile, - "--auth-default-access=" + confToDefaultAccess(conf), + "--auth-default-access=" + conf.AuthDefault.String(), } return app.Run(append(userArgs, args...)) } diff --git a/cmd/serve.go b/cmd/serve.go index e82c3bf7..196b8918 100644 --- a/cmd/serve.go +++ b/cmd/serve.go @@ -5,6 +5,7 @@ package cmd import ( "errors" "fmt" + "heckel.io/ntfy/user" "io/fs" "math" "net" @@ -171,8 +172,6 @@ func execServe(c *cli.Context) error { return errors.New("if set, base-url must start with http:// or https://") } else if baseURL != "" && strings.HasSuffix(baseURL, "/") { return errors.New("if set, base-url must not end with a slash (/)") - } else if !util.Contains([]string{"read-write", "read-only", "write-only", "deny-all"}, authDefaultAccess) { - return errors.New("if set, auth-default-access must start set to 'read-write', 'read-only', 'write-only' or 'deny-all'") } else if !util.Contains([]string{"app", "home", "disable"}, webRoot) { return errors.New("if set, web-root must be 'home' or 'app'") } else if upstreamBaseURL != "" && !strings.HasPrefix(upstreamBaseURL, "http://") && !strings.HasPrefix(upstreamBaseURL, "https://") { @@ -189,8 +188,10 @@ func execServe(c *cli.Context) error { enableWeb := webRoot != "disable" // Default auth permissions - authDefaultRead := authDefaultAccess == "read-write" || authDefaultAccess == "read-only" - authDefaultWrite := authDefaultAccess == "read-write" || authDefaultAccess == "write-only" + authDefault, err := user.ParsePermission(authDefaultAccess) + if err != nil { + return errors.New("if set, auth-default-access must start set to 'read-write', 'read-only', 'write-only' or 'deny-all'") + } // Special case: Unset default if listenHTTP == "-" { @@ -244,8 +245,7 @@ func execServe(c *cli.Context) error { conf.CacheBatchSize = cacheBatchSize conf.CacheBatchTimeout = cacheBatchTimeout conf.AuthFile = authFile - conf.AuthDefaultRead = authDefaultRead - conf.AuthDefaultWrite = authDefaultWrite + conf.AuthDefault = authDefault conf.AttachmentCacheDir = attachmentCacheDir conf.AttachmentTotalSizeLimit = attachmentTotalSizeLimit conf.AttachmentFileSizeLimit = attachmentFileSizeLimit diff --git a/cmd/user.go b/cmd/user.go index 3562e305..ee50de62 100644 --- a/cmd/user.go +++ b/cmd/user.go @@ -273,12 +273,12 @@ func createUserManager(c *cli.Context) (*user.Manager, error) { return nil, errors.New("option auth-file not set; auth is unconfigured for this server") } else if !util.FileExists(authFile) { return nil, errors.New("auth-file does not exist; please start the server at least once to create it") - } else if !util.Contains([]string{"read-write", "read-only", "write-only", "deny-all"}, authDefaultAccess) { - return nil, errors.New("if set, auth-default-access must start set to 'read-write', 'read-only' or 'deny-all'") } - authDefaultRead := authDefaultAccess == "read-write" || authDefaultAccess == "read-only" - authDefaultWrite := authDefaultAccess == "read-write" || authDefaultAccess == "write-only" - return user.NewManager(authFile, authDefaultRead, authDefaultWrite) + authDefault, err := user.ParsePermission(authDefaultAccess) + if err != nil { + return nil, errors.New("if set, auth-default-access must start set to 'read-write', 'read-only', 'write-only' or 'deny-all'") + } + return user.NewManager(authFile, authDefault) } func readPasswordAndConfirm(c *cli.Context) (string, error) { diff --git a/cmd/user_test.go b/cmd/user_test.go index 666cb422..c8bcd6a6 100644 --- a/cmd/user_test.go +++ b/cmd/user_test.go @@ -5,6 +5,7 @@ import ( "github.com/urfave/cli/v2" "heckel.io/ntfy/server" "heckel.io/ntfy/test" + "heckel.io/ntfy/user" "path/filepath" "testing" ) @@ -114,8 +115,7 @@ func TestCLI_User_Delete(t *testing.T) { func newTestServerWithAuth(t *testing.T) (s *server.Server, conf *server.Config, port int) { conf = server.NewConfig() conf.AuthFile = filepath.Join(t.TempDir(), "user.db") - conf.AuthDefaultRead = false - conf.AuthDefaultWrite = false + conf.AuthDefault = user.PermissionDenyAll s, port = test.StartServerWithConfig(t, conf) return } @@ -125,21 +125,7 @@ func runUserCommand(app *cli.App, conf *server.Config, args ...string) error { "ntfy", "user", "--auth-file=" + conf.AuthFile, - "--auth-default-access=" + confToDefaultAccess(conf), + "--auth-default-access=" + conf.AuthDefault.String(), } return app.Run(append(userArgs, args...)) } - -func confToDefaultAccess(conf *server.Config) string { - var defaultAccess string - if conf.AuthDefaultRead && conf.AuthDefaultWrite { - defaultAccess = "read-write" - } else if conf.AuthDefaultRead && !conf.AuthDefaultWrite { - defaultAccess = "read-only" - } else if !conf.AuthDefaultRead && conf.AuthDefaultWrite { - defaultAccess = "write-only" - } else if !conf.AuthDefaultRead && !conf.AuthDefaultWrite { - defaultAccess = "deny-all" - } - return defaultAccess -} diff --git a/server/config.go b/server/config.go index abc53a50..c3f93560 100644 --- a/server/config.go +++ b/server/config.go @@ -1,6 +1,7 @@ package server import ( + "heckel.io/ntfy/user" "io/fs" "net/netip" "time" @@ -66,8 +67,7 @@ type Config struct { CacheBatchSize int CacheBatchTimeout time.Duration AuthFile string - AuthDefaultRead bool - AuthDefaultWrite bool + AuthDefault user.Permission AttachmentCacheDir string AttachmentTotalSizeLimit int64 AttachmentFileSizeLimit int64 @@ -127,8 +127,7 @@ func NewConfig() *Config { CacheBatchSize: 0, CacheBatchTimeout: 0, AuthFile: "", - AuthDefaultRead: true, - AuthDefaultWrite: true, + AuthDefault: user.NewPermission(true, true), AttachmentCacheDir: "", AttachmentTotalSizeLimit: DefaultAttachmentTotalSizeLimit, AttachmentFileSizeLimit: DefaultAttachmentFileSizeLimit, diff --git a/server/errors.go b/server/errors.go index a1d8bcb8..e136287d 100644 --- a/server/errors.go +++ b/server/errors.go @@ -41,8 +41,8 @@ var ( errHTTPBadRequestDelayTooLarge = &errHTTP{40006, http.StatusBadRequest, "invalid delay parameter: too large, please refer to the docs", "https://ntfy.sh/docs/publish/#scheduled-delivery"} errHTTPBadRequestPriorityInvalid = &errHTTP{40007, http.StatusBadRequest, "invalid priority parameter", "https://ntfy.sh/docs/publish/#message-priority"} errHTTPBadRequestSinceInvalid = &errHTTP{40008, http.StatusBadRequest, "invalid since parameter", "https://ntfy.sh/docs/subscribe/api/#fetch-cached-messages"} - errHTTPBadRequestTopicInvalid = &errHTTP{40009, http.StatusBadRequest, "invalid topic: topic invalid", ""} - errHTTPBadRequestTopicDisallowed = &errHTTP{40010, http.StatusBadRequest, "invalid topic: topic name is disallowed", ""} + errHTTPBadRequestTopicInvalid = &errHTTP{40009, http.StatusBadRequest, "invalid request: topic invalid", ""} + errHTTPBadRequestTopicDisallowed = &errHTTP{40010, http.StatusBadRequest, "invalid request: topic name is disallowed", ""} errHTTPBadRequestMessageNotUTF8 = &errHTTP{40011, http.StatusBadRequest, "invalid message: message must be UTF-8 encoded", ""} errHTTPBadRequestAttachmentURLInvalid = &errHTTP{40013, http.StatusBadRequest, "invalid request: attachment URL is invalid", "https://ntfy.sh/docs/publish/#attachments"} errHTTPBadRequestAttachmentsDisallowed = &errHTTP{40014, http.StatusBadRequest, "invalid request: attachments not allowed", "https://ntfy.sh/docs/config/#attachments"} @@ -56,6 +56,7 @@ var ( errHTTPBadRequestSignupNotEnabled = &errHTTP{40022, http.StatusBadRequest, "invalid request: signup not enabled", "https://ntfy.sh/docs/config"} errHTTPBadRequestNoTokenProvided = &errHTTP{40023, http.StatusBadRequest, "invalid request: no token provided", ""} errHTTPBadRequestJSONInvalid = &errHTTP{40024, http.StatusBadRequest, "invalid request: request body must be valid JSON", ""} + errHTTPBadRequestPermissionInvalid = &errHTTP{40025, http.StatusBadRequest, "invalid request: incorrect permission string", ""} errHTTPNotFound = &errHTTP{40401, http.StatusNotFound, "page not found", ""} errHTTPUnauthorized = &errHTTP{40101, http.StatusUnauthorized, "unauthorized", "https://ntfy.sh/docs/publish/#authentication"} errHTTPForbidden = &errHTTP{40301, http.StatusForbidden, "forbidden", "https://ntfy.sh/docs/publish/#authentication"} diff --git a/server/server.go b/server/server.go index 032879b6..b4a79f0e 100644 --- a/server/server.go +++ b/server/server.go @@ -165,7 +165,7 @@ func New(conf *Config) (*Server, error) { } var userManager *user.Manager if conf.AuthFile != "" { - userManager, err = user.NewManager(conf.AuthFile, conf.AuthDefaultRead, conf.AuthDefaultWrite) + userManager, err = user.NewManager(conf.AuthFile, conf.AuthDefault) if err != nil { return nil, err } diff --git a/server/server_account.go b/server/server_account.go index c6d226be..ce8a2921 100644 --- a/server/server_account.go +++ b/server/server_account.go @@ -101,19 +101,9 @@ func (s *Server) handleAccountGet(w http.ResponseWriter, _ *http.Request, v *vis if len(reservations) > 0 { response.Reservations = make([]*apiAccountReservation, 0) for _, r := range reservations { - var everyone string - if r.AllowEveryoneRead && r.AllowEveryoneWrite { - everyone = "read-write" - } else if r.AllowEveryoneRead && !r.AllowEveryoneWrite { - everyone = "read-only" - } else if !r.AllowEveryoneRead && r.AllowEveryoneWrite { - everyone = "write-only" - } else { - everyone = "deny-all" - } response.Reservations = append(response.Reservations, &apiAccountReservation{ - Topic: r.TopicPattern, - Everyone: everyone, + Topic: r.Topic, + Everyone: r.Everyone.String(), }) } } @@ -345,12 +335,14 @@ func (s *Server) handleAccountAccessAdd(w http.ResponseWriter, r *http.Request, return errHTTPConflictTopicReserved } owner, username := v.user.Name, v.user.Name - everyoneRead := util.Contains([]string{"read-write", "rw", "read-only", "read", "ro"}, req.Everyone) - everyoneWrite := util.Contains([]string{"read-write", "rw", "write-only", "write", "wo"}, req.Everyone) + everyone, err := user.ParsePermission(req.Everyone) + if err != nil { + return errHTTPBadRequestPermissionInvalid + } if err := s.userManager.AllowAccess(owner, username, req.Topic, true, true); err != nil { return err } - if err := s.userManager.AllowAccess(owner, user.Everyone, req.Topic, everyoneRead, everyoneWrite); err != nil { + if err := s.userManager.AllowAccess(owner, user.Everyone, req.Topic, everyone.IsRead(), everyone.IsWrite()); err != nil { return err } w.Header().Set("Content-Type", "application/json") @@ -373,7 +365,7 @@ func (s *Server) handleAccountAccessDelete(w http.ResponseWriter, r *http.Reques } authorized := false for _, r := range reservations { - if r.TopicPattern == topic { + if r.Topic == topic { authorized = true break } diff --git a/server/server_test.go b/server/server_test.go index bf731b9c..5ce41260 100644 --- a/server/server_test.go +++ b/server/server_test.go @@ -638,8 +638,7 @@ func TestServer_Auth_Success_Admin(t *testing.T) { func TestServer_Auth_Success_User(t *testing.T) { c := newTestConfig(t) c.AuthFile = filepath.Join(t.TempDir(), "user.db") - c.AuthDefaultRead = false - c.AuthDefaultWrite = false + c.AuthDefault = user.PermissionDenyAll s := newTestServer(t, c) require.Nil(t, s.userManager.AddUser("ben", "ben", user.RoleUser)) @@ -654,8 +653,7 @@ func TestServer_Auth_Success_User(t *testing.T) { func TestServer_Auth_Success_User_MultipleTopics(t *testing.T) { c := newTestConfig(t) c.AuthFile = filepath.Join(t.TempDir(), "user.db") - c.AuthDefaultRead = false - c.AuthDefaultWrite = false + c.AuthDefault = user.PermissionDenyAll s := newTestServer(t, c) require.Nil(t, s.userManager.AddUser("ben", "ben", user.RoleUser)) @@ -676,8 +674,7 @@ func TestServer_Auth_Success_User_MultipleTopics(t *testing.T) { func TestServer_Auth_Fail_InvalidPass(t *testing.T) { c := newTestConfig(t) c.AuthFile = filepath.Join(t.TempDir(), "user.db") - c.AuthDefaultRead = false - c.AuthDefaultWrite = false + c.AuthDefault = user.PermissionDenyAll s := newTestServer(t, c) require.Nil(t, s.userManager.AddUser("phil", "phil", user.RoleAdmin)) @@ -691,8 +688,7 @@ func TestServer_Auth_Fail_InvalidPass(t *testing.T) { func TestServer_Auth_Fail_Unauthorized(t *testing.T) { c := newTestConfig(t) c.AuthFile = filepath.Join(t.TempDir(), "user.db") - c.AuthDefaultRead = false - c.AuthDefaultWrite = false + c.AuthDefault = user.PermissionDenyAll s := newTestServer(t, c) require.Nil(t, s.userManager.AddUser("ben", "ben", user.RoleUser)) @@ -707,8 +703,7 @@ func TestServer_Auth_Fail_Unauthorized(t *testing.T) { func TestServer_Auth_Fail_CannotPublish(t *testing.T) { c := newTestConfig(t) c.AuthFile = filepath.Join(t.TempDir(), "user.db") - c.AuthDefaultRead = true // Open by default - c.AuthDefaultWrite = true // Open by default + c.AuthDefault = user.PermissionReadWrite // Open by default s := newTestServer(t, c) require.Nil(t, s.userManager.AddUser("phil", "phil", user.RoleAdmin)) @@ -739,8 +734,7 @@ func TestServer_Auth_Fail_CannotPublish(t *testing.T) { func TestServer_Auth_ViaQuery(t *testing.T) { c := newTestConfig(t) c.AuthFile = filepath.Join(t.TempDir(), "user.db") - c.AuthDefaultRead = false - c.AuthDefaultWrite = false + c.AuthDefault = user.PermissionDenyAll s := newTestServer(t, c) require.Nil(t, s.userManager.AddUser("ben", "some pass", user.RoleAdmin)) diff --git a/user/manager.go b/user/manager.go index 41d42ec0..3dd77a92 100644 --- a/user/manager.go +++ b/user/manager.go @@ -182,8 +182,7 @@ const ( // in a SQLite database. type Manager struct { db *sql.DB - defaultRead bool // Default read permission if no ACL matches - defaultWrite bool // Default write permission if no ACL matches + defaultAccess Permission // Default permission if no ACL matches statsQueue map[string]*User // Username -> User, for "unimportant" user updates tokenExpiryInterval time.Duration // Duration after which tokens expire, and by which tokens are extended mu sync.Mutex @@ -192,12 +191,12 @@ type Manager struct { var _ Auther = (*Manager)(nil) // NewManager creates a new Manager instance -func NewManager(filename string, defaultRead, defaultWrite bool) (*Manager, error) { - return newManager(filename, defaultRead, defaultWrite, userTokenExpiryDuration, userStatsQueueWriterInterval) +func NewManager(filename string, defaultAccess Permission) (*Manager, error) { + return newManager(filename, defaultAccess, userTokenExpiryDuration, userStatsQueueWriterInterval) } // NewManager creates a new Manager instance -func newManager(filename string, defaultRead, defaultWrite bool, tokenExpiryDuration, statsWriterInterval time.Duration) (*Manager, error) { +func newManager(filename string, defaultAccess Permission, tokenExpiryDuration, statsWriterInterval time.Duration) (*Manager, error) { db, err := sql.Open("sqlite3", filename) if err != nil { return nil, err @@ -207,8 +206,7 @@ func newManager(filename string, defaultRead, defaultWrite bool, tokenExpiryDura } manager := &Manager{ db: db, - defaultRead: defaultRead, - defaultWrite: defaultWrite, + defaultAccess: defaultAccess, statsQueue: make(map[string]*User), tokenExpiryInterval: tokenExpiryDuration, } @@ -368,7 +366,7 @@ func (a *Manager) Authorize(user *User, topic string, perm Permission) error { } defer rows.Close() if !rows.Next() { - return a.resolvePerms(a.defaultRead, a.defaultWrite, perm) + return a.resolvePerms(a.defaultAccess, perm) } var read, write bool if err := rows.Scan(&read, &write); err != nil { @@ -376,13 +374,13 @@ func (a *Manager) Authorize(user *User, topic string, perm Permission) error { } else if err := rows.Err(); err != nil { return err } - return a.resolvePerms(read, write, perm) + return a.resolvePerms(NewPermission(read, write), perm) } -func (a *Manager) resolvePerms(read, write bool, perm Permission) error { - if perm == PermissionRead && read { +func (a *Manager) resolvePerms(base, perm Permission) error { + if perm == PermissionRead && base.IsRead() { return nil - } else if perm == PermissionWrite && write { + } else if perm == PermissionWrite && base.IsWrite() { return nil } return ErrUnauthorized @@ -534,8 +532,7 @@ func (a *Manager) Grants(username string) ([]Grant, error) { } grants = append(grants, Grant{ TopicPattern: fromSQLWildcard(topic), - AllowRead: read, - AllowWrite: write, + Allow: NewPermission(read, write), }) } return grants, nil @@ -551,19 +548,17 @@ func (a *Manager) Reservations(username string) ([]Reservation, error) { reservations := make([]Reservation, 0) for rows.Next() { var topic string - var read, write bool + var ownerRead, ownerWrite bool var everyoneRead, everyoneWrite sql.NullBool - if err := rows.Scan(&topic, &read, &write, &everyoneRead, &everyoneWrite); err != nil { + if err := rows.Scan(&topic, &ownerRead, &ownerWrite, &everyoneRead, &everyoneWrite); err != nil { return nil, err } else if err := rows.Err(); err != nil { return nil, err } reservations = append(reservations, Reservation{ - TopicPattern: topic, - AllowRead: read, - AllowWrite: write, - AllowEveryoneRead: everyoneRead.Bool, // false if null - AllowEveryoneWrite: everyoneWrite.Bool, // false if null + Topic: topic, + Owner: NewPermission(ownerRead, ownerWrite), + Everyone: NewPermission(everyoneRead.Bool, everyoneWrite.Bool), // false if null }) } return reservations, nil @@ -659,8 +654,8 @@ func (a *Manager) ResetAccess(username string, topicPattern string) error { } // DefaultAccess returns the default read/write access if no access control entry matches -func (a *Manager) DefaultAccess() (read bool, write bool) { - return a.defaultRead, a.defaultWrite +func (a *Manager) DefaultAccess() Permission { + return a.defaultAccess } func toSQLWildcard(s string) string { diff --git a/user/manager_test.go b/user/manager_test.go index 9c2153bd..88453722 100644 --- a/user/manager_test.go +++ b/user/manager_test.go @@ -12,7 +12,7 @@ import ( const minBcryptTimingMillis = int64(50) // Ideally should be >100ms, but this should also run on a Raspberry Pi without massive resources func TestManager_FullScenario_Default_DenyAll(t *testing.T) { - a := newTestManager(t, false, false) + a := newTestManager(t, PermissionDenyAll) require.Nil(t, a.AddUser("phil", "phil", RoleAdmin)) require.Nil(t, a.AddUser("ben", "ben", RoleUser)) require.Nil(t, a.AllowAccess("", "ben", "mytopic", true, true)) @@ -28,19 +28,25 @@ func TestManager_FullScenario_Default_DenyAll(t *testing.T) { require.Equal(t, "phil", phil.Name) require.True(t, strings.HasPrefix(phil.Hash, "$2a$10$")) require.Equal(t, RoleAdmin, phil.Role) - require.Equal(t, []Grant{}, phil.Grants) + + philGrants, err := a.Grants("phil") + require.Nil(t, err) + require.Equal(t, []Grant{}, philGrants) ben, err := a.Authenticate("ben", "ben") require.Nil(t, err) require.Equal(t, "ben", ben.Name) require.True(t, strings.HasPrefix(ben.Hash, "$2a$10$")) require.Equal(t, RoleUser, ben.Role) + + benGrants, err := a.Grants("ben") + require.Nil(t, err) require.Equal(t, []Grant{ - {"mytopic", true, true, false}, - {"writeme", false, true, false}, - {"readme", true, false, false}, - {"everyonewrite", false, false, false}, - }, ben.Grants) + {"mytopic", PermissionReadWrite}, + {"writeme", PermissionWrite}, + {"readme", PermissionRead}, + {"everyonewrite", PermissionDenyAll}, + }, benGrants) notben, err := a.Authenticate("ben", "this is wrong") require.Nil(t, notben) @@ -85,20 +91,20 @@ func TestManager_FullScenario_Default_DenyAll(t *testing.T) { } func TestManager_AddUser_Invalid(t *testing.T) { - a := newTestManager(t, false, false) + a := newTestManager(t, PermissionDenyAll) require.Equal(t, ErrInvalidArgument, a.AddUser(" invalid ", "pass", RoleAdmin)) require.Equal(t, ErrInvalidArgument, a.AddUser("validuser", "pass", "invalid-role")) } func TestManager_AddUser_Timing(t *testing.T) { - a := newTestManager(t, false, false) + a := newTestManager(t, PermissionDenyAll) start := time.Now().UnixMilli() require.Nil(t, a.AddUser("user", "pass", RoleAdmin)) require.GreaterOrEqual(t, time.Now().UnixMilli()-start, minBcryptTimingMillis) } func TestManager_Authenticate_Timing(t *testing.T) { - a := newTestManager(t, false, false) + a := newTestManager(t, PermissionDenyAll) require.Nil(t, a.AddUser("user", "pass", RoleAdmin)) // Timing a correct attempt @@ -121,7 +127,7 @@ func TestManager_Authenticate_Timing(t *testing.T) { } func TestManager_UserManagement(t *testing.T) { - a := newTestManager(t, false, false) + a := newTestManager(t, PermissionDenyAll) require.Nil(t, a.AddUser("phil", "phil", RoleAdmin)) require.Nil(t, a.AddUser("ben", "ben", RoleUser)) require.Nil(t, a.AllowAccess("", "ben", "mytopic", true, true)) @@ -137,29 +143,38 @@ func TestManager_UserManagement(t *testing.T) { require.Equal(t, "phil", phil.Name) require.True(t, strings.HasPrefix(phil.Hash, "$2a$10$")) require.Equal(t, RoleAdmin, phil.Role) - require.Equal(t, []Grant{}, phil.Grants) + + philGrants, err := a.Grants("phil") + require.Nil(t, err) + require.Equal(t, []Grant{}, philGrants) ben, err := a.User("ben") require.Nil(t, err) require.Equal(t, "ben", ben.Name) require.True(t, strings.HasPrefix(ben.Hash, "$2a$10$")) require.Equal(t, RoleUser, ben.Role) + + benGrants, err := a.Grants("ben") + require.Nil(t, err) require.Equal(t, []Grant{ - {"mytopic", true, true, false}, - {"writeme", false, true, false}, - {"readme", true, false, false}, - {"everyonewrite", false, false, false}, - }, ben.Grants) + {"mytopic", PermissionReadWrite}, + {"writeme", PermissionWrite}, + {"readme", PermissionRead}, + {"everyonewrite", PermissionDenyAll}, + }, benGrants) everyone, err := a.User(Everyone) require.Nil(t, err) require.Equal(t, "*", everyone.Name) require.Equal(t, "", everyone.Hash) require.Equal(t, RoleAnonymous, everyone.Role) + + everyoneGrants, err := a.Grants(Everyone) + require.Nil(t, err) require.Equal(t, []Grant{ - {"everyonewrite", true, true, false}, - {"announcements", true, false, false}, - }, everyone.Grants) + {"everyonewrite", PermissionReadWrite}, + {"announcements", PermissionRead}, + }, everyoneGrants) // Ben: Before revoking require.Nil(t, a.AllowAccess("", "ben", "mytopic", true, true)) // Overwrite! @@ -203,7 +218,7 @@ func TestManager_UserManagement(t *testing.T) { } func TestManager_ChangePassword(t *testing.T) { - a := newTestManager(t, false, false) + a := newTestManager(t, PermissionDenyAll) require.Nil(t, a.AddUser("phil", "phil", RoleAdmin)) _, err := a.Authenticate("phil", "phil") @@ -217,7 +232,7 @@ func TestManager_ChangePassword(t *testing.T) { } func TestManager_ChangeRole(t *testing.T) { - a := newTestManager(t, false, false) + a := newTestManager(t, PermissionDenyAll) require.Nil(t, a.AddUser("ben", "ben", RoleUser)) require.Nil(t, a.AllowAccess("", "ben", "mytopic", true, true)) require.Nil(t, a.AllowAccess("", "ben", "readme", true, false)) @@ -225,18 +240,24 @@ func TestManager_ChangeRole(t *testing.T) { ben, err := a.User("ben") require.Nil(t, err) require.Equal(t, RoleUser, ben.Role) - require.Equal(t, 2, len(ben.Grants)) + + benGrants, err := a.Grants("ben") + require.Nil(t, err) + require.Equal(t, 2, len(benGrants)) require.Nil(t, a.ChangeRole("ben", RoleAdmin)) ben, err = a.User("ben") require.Nil(t, err) require.Equal(t, RoleAdmin, ben.Role) - require.Equal(t, 0, len(ben.Grants)) + + benGrants, err = a.Grants("ben") + require.Nil(t, err) + require.Equal(t, 0, len(benGrants)) } func TestManager_Token_Valid(t *testing.T) { - a := newTestManager(t, false, false) + a := newTestManager(t, PermissionDenyAll) require.Nil(t, a.AddUser("ben", "ben", RoleUser)) u, err := a.User("ben") @@ -261,7 +282,7 @@ func TestManager_Token_Valid(t *testing.T) { } func TestManager_Token_Invalid(t *testing.T) { - a := newTestManager(t, false, false) + a := newTestManager(t, PermissionDenyAll) require.Nil(t, a.AddUser("ben", "ben", RoleUser)) u, err := a.AuthenticateToken(strings.Repeat("x", 32)) // 32 == token length @@ -274,7 +295,7 @@ func TestManager_Token_Invalid(t *testing.T) { } func TestManager_Token_Expire(t *testing.T) { - a := newTestManager(t, false, false) + a := newTestManager(t, PermissionDenyAll) require.Nil(t, a.AddUser("ben", "ben", RoleUser)) u, err := a.User("ben") @@ -322,7 +343,7 @@ func TestManager_Token_Expire(t *testing.T) { } func TestManager_Token_Extend(t *testing.T) { - a := newTestManager(t, false, false) + a := newTestManager(t, PermissionDenyAll) require.Nil(t, a.AddUser("ben", "ben", RoleUser)) // Try to extend token for user without token @@ -349,7 +370,7 @@ func TestManager_Token_Extend(t *testing.T) { } func TestManager_EnqueueStats(t *testing.T) { - a, err := newManager(filepath.Join(t.TempDir(), "db"), true, true, time.Hour, 1500*time.Millisecond) + a, err := newManager(filepath.Join(t.TempDir(), "db"), PermissionReadWrite, time.Hour, 1500*time.Millisecond) require.Nil(t, err) require.Nil(t, a.AddUser("ben", "ben", RoleUser)) @@ -379,7 +400,7 @@ func TestManager_EnqueueStats(t *testing.T) { } func TestManager_ChangeSettings(t *testing.T) { - a, err := newManager(filepath.Join(t.TempDir(), "db"), true, true, time.Hour, 1500*time.Millisecond) + a, err := newManager(filepath.Join(t.TempDir(), "db"), PermissionReadWrite, time.Hour, 1500*time.Millisecond) require.Nil(t, err) require.Nil(t, a.AddUser("ben", "ben", RoleUser)) @@ -461,7 +482,7 @@ func TestSqliteCache_Migration_From1(t *testing.T) { require.Nil(t, err) // Create manager to trigger migration - a := newTestManagerFromFile(t, filename, false, false, userTokenExpiryDuration, userStatsQueueWriterInterval) + a := newTestManagerFromFile(t, filename, PermissionDenyAll, userTokenExpiryDuration, userStatsQueueWriterInterval) checkSchemaVersion(t, a.db) users, err := a.Users() @@ -469,26 +490,32 @@ func TestSqliteCache_Migration_From1(t *testing.T) { require.Equal(t, 3, len(users)) phil, ben, everyone := users[0], users[1], users[2] + philGrants, err := a.Grants("phil") + require.Nil(t, err) + + benGrants, err := a.Grants("ben") + require.Nil(t, err) + + everyoneGrants, err := a.Grants(Everyone) + require.Nil(t, err) + require.Equal(t, "phil", phil.Name) require.Equal(t, RoleAdmin, phil.Role) - require.Equal(t, 0, len(phil.Grants)) + require.Equal(t, 0, len(philGrants)) require.Equal(t, "ben", ben.Name) require.Equal(t, RoleUser, ben.Role) - require.Equal(t, 2, len(ben.Grants)) - require.Equal(t, "stats", ben.Grants[0].TopicPattern) - require.Equal(t, true, ben.Grants[0].AllowRead) - require.Equal(t, true, ben.Grants[0].AllowWrite) - require.Equal(t, "secret", ben.Grants[1].TopicPattern) - require.Equal(t, true, ben.Grants[1].AllowRead) - require.Equal(t, false, ben.Grants[1].AllowWrite) + require.Equal(t, 2, len(benGrants)) + require.Equal(t, "stats", benGrants[0].TopicPattern) + require.Equal(t, PermissionReadWrite, benGrants[0].Allow) + require.Equal(t, "secret", benGrants[1].TopicPattern) + require.Equal(t, PermissionRead, benGrants[1].Allow) require.Equal(t, Everyone, everyone.Name) require.Equal(t, RoleAnonymous, everyone.Role) - require.Equal(t, 1, len(everyone.Grants)) - require.Equal(t, "stats", everyone.Grants[0].TopicPattern) - require.Equal(t, true, everyone.Grants[0].AllowRead) - require.Equal(t, false, everyone.Grants[0].AllowWrite) + require.Equal(t, 1, len(everyoneGrants)) + require.Equal(t, "stats", everyoneGrants[0].TopicPattern) + require.Equal(t, PermissionRead, everyoneGrants[0].Allow) } func checkSchemaVersion(t *testing.T, db *sql.DB) { @@ -502,12 +529,12 @@ func checkSchemaVersion(t *testing.T, db *sql.DB) { require.Nil(t, rows.Close()) } -func newTestManager(t *testing.T, defaultRead, defaultWrite bool) *Manager { - return newTestManagerFromFile(t, filepath.Join(t.TempDir(), "user.db"), defaultRead, defaultWrite, userTokenExpiryDuration, userStatsQueueWriterInterval) +func newTestManager(t *testing.T, defaultAccess Permission) *Manager { + return newTestManagerFromFile(t, filepath.Join(t.TempDir(), "user.db"), defaultAccess, userTokenExpiryDuration, userStatsQueueWriterInterval) } -func newTestManagerFromFile(t *testing.T, filename string, defaultRead, defaultWrite bool, tokenExpiryDuration, statsWriterInterval time.Duration) *Manager { - a, err := newManager(filename, defaultRead, defaultWrite, tokenExpiryDuration, statsWriterInterval) +func newTestManagerFromFile(t *testing.T, filename string, defaultAccess Permission, tokenExpiryDuration, statsWriterInterval time.Duration) *Manager { + a, err := newManager(filename, defaultAccess, tokenExpiryDuration, statsWriterInterval) require.Nil(t, err) return a } diff --git a/user/types.go b/user/types.go index 03e67acd..ac5f8380 100644 --- a/user/types.go +++ b/user/types.go @@ -85,31 +85,79 @@ type Stats struct { Emails int64 } -// Grant is a struct that represents an access control entry to a topic +// Grant is a struct that represents an access control entry to a topic by a user type Grant struct { TopicPattern string // May include wildcard (*) - AllowRead bool - AllowWrite bool + Allow Permission } // Reservation is a struct that represents the ownership over a topic by a user type Reservation struct { - TopicPattern string - AllowRead bool - AllowWrite bool - AllowEveryoneRead bool - AllowEveryoneWrite bool + Topic string + Owner Permission + Everyone Permission } // Permission represents a read or write permission to a topic -type Permission int +type Permission uint8 // Permissions to a topic const ( - PermissionRead = Permission(1) - PermissionWrite = Permission(2) + PermissionDenyAll Permission = iota + PermissionRead + PermissionWrite + PermissionReadWrite // 3! ) +func NewPermission(read, write bool) Permission { + p := uint8(0) + if read { + p |= uint8(PermissionRead) + } + if write { + p |= uint8(PermissionWrite) + } + return Permission(p) +} + +func ParsePermission(s string) (Permission, error) { + switch s { + case "read-write", "rw": + return NewPermission(true, true), nil + case "read-only", "read", "ro": + return NewPermission(true, false), nil + case "write-only", "write", "wo": + return NewPermission(false, true), nil + case "deny-all", "deny", "none": + return NewPermission(false, false), nil + default: + return NewPermission(false, false), errors.New("invalid permission") + } +} + +func (p Permission) IsRead() bool { + return p&PermissionRead != 0 +} + +func (p Permission) IsWrite() bool { + return p&PermissionWrite != 0 +} + +func (p Permission) IsReadWrite() bool { + return p.IsRead() && p.IsWrite() +} + +func (p Permission) String() string { + if p.IsReadWrite() { + return "read-write" + } else if p.IsRead() { + return "read-only" + } else if p.IsWrite() { + return "write-only" + } + return "deny-all" +} + // Role represents a user's role, either admin or regular user type Role string