diff --git a/server/server.go b/server/server.go index 961b6ca4..2536be27 100644 --- a/server/server.go +++ b/server/server.go @@ -56,7 +56,6 @@ import ( - "mute" setting - figure out what settings are "web" or "phone" Tests: - - /access endpoints - visitor with/without user Refactor: - rename TopicsLimit -> ReservationsLimit diff --git a/server/server_account.go b/server/server_account.go index b436bb26..c9499488 100644 --- a/server/server_account.go +++ b/server/server_account.go @@ -332,23 +332,29 @@ func (s *Server) handleAccountAccessAdd(w http.ResponseWriter, r *http.Request, if !topicRegex.MatchString(req.Topic) { return errHTTPBadRequestTopicInvalid } - if v.user.Plan == nil { - return errors.New("no plan") // FIXME there should always be a plan! - } - reservations, err := s.userManager.ReservationsCount(v.user.Name) - if err != nil { - return err - } else if reservations >= v.user.Plan.TopicsLimit { - return errHTTPTooManyRequestsLimitReservations // FIXME test this - } - if err := s.userManager.CheckAllowAccess(v.user.Name, req.Topic); err != nil { - return errHTTPConflictTopicReserved - } - owner, username := v.user.Name, v.user.Name everyone, err := user.ParsePermission(req.Everyone) if err != nil { return errHTTPBadRequestPermissionInvalid } + if v.user.Plan == nil { + return errors.New("no plan") // FIXME there should always be a plan! + } + if err := s.userManager.CheckAllowAccess(v.user.Name, req.Topic); err != nil { + return errHTTPConflictTopicReserved + } + hasReservation, err := s.userManager.HasReservation(v.user.Name, req.Topic) + if err != nil { + return err + } + if !hasReservation { + reservations, err := s.userManager.ReservationsCount(v.user.Name) + if err != nil { + return err + } else if reservations >= v.user.Plan.TopicsLimit { + return errHTTPTooManyRequestsLimitReservations + } + } + owner, username := v.user.Name, v.user.Name if err := s.userManager.AllowAccess(owner, username, req.Topic, true, true); err != nil { return err } @@ -369,18 +375,10 @@ func (s *Server) handleAccountAccessDelete(w http.ResponseWriter, r *http.Reques if !topicRegex.MatchString(topic) { return errHTTPBadRequestTopicInvalid } - reservations, err := s.userManager.Reservations(v.user.Name) // FIXME replace with HasReservation + authorized, err := s.userManager.HasReservation(v.user.Name, topic) if err != nil { return err - } - authorized := false - for _, r := range reservations { - if r.Topic == topic { - authorized = true - break - } - } - if !authorized { + } else if !authorized { return errHTTPUnauthorized } if err := s.userManager.ResetAccess(v.user.Name, topic); err != nil { diff --git a/server/server_account_test.go b/server/server_account_test.go index 147a1297..55b051f8 100644 --- a/server/server_account_test.go +++ b/server/server_account_test.go @@ -1,6 +1,7 @@ package server import ( + "database/sql" "fmt" "github.com/stretchr/testify/require" "heckel.io/ntfy/user" @@ -11,7 +12,7 @@ import ( ) func TestAccount_Signup_Success(t *testing.T) { - conf := newTestConfigWithUsers(t) + conf := newTestConfigWithAuthFile(t) conf.EnableSignup = true s := newTestServer(t, conf) @@ -36,7 +37,7 @@ func TestAccount_Signup_Success(t *testing.T) { } func TestAccount_Signup_UserExists(t *testing.T) { - conf := newTestConfigWithUsers(t) + conf := newTestConfigWithAuthFile(t) conf.EnableSignup = true s := newTestServer(t, conf) @@ -49,7 +50,7 @@ func TestAccount_Signup_UserExists(t *testing.T) { } func TestAccount_Signup_LimitReached(t *testing.T) { - conf := newTestConfigWithUsers(t) + conf := newTestConfigWithAuthFile(t) conf.EnableSignup = true s := newTestServer(t, conf) @@ -63,7 +64,7 @@ func TestAccount_Signup_LimitReached(t *testing.T) { } func TestAccount_Signup_AsUser(t *testing.T) { - conf := newTestConfigWithUsers(t) + conf := newTestConfigWithAuthFile(t) conf.EnableSignup = true s := newTestServer(t, conf) @@ -82,7 +83,7 @@ func TestAccount_Signup_AsUser(t *testing.T) { } func TestAccount_Signup_Disabled(t *testing.T) { - conf := newTestConfigWithUsers(t) + conf := newTestConfigWithAuthFile(t) conf.EnableSignup = false s := newTestServer(t, conf) @@ -92,7 +93,7 @@ func TestAccount_Signup_Disabled(t *testing.T) { } func TestAccount_Get_Anonymous(t *testing.T) { - conf := newTestConfigWithUsers(t) + conf := newTestConfigWithAuthFile(t) conf.VisitorRequestLimitReplenish = 86 * time.Second conf.VisitorEmailLimitReplenish = time.Hour conf.VisitorAttachmentTotalSizeLimit = 5123 @@ -132,7 +133,7 @@ func TestAccount_Get_Anonymous(t *testing.T) { } func TestAccount_ChangeSettings(t *testing.T) { - s := newTestServer(t, newTestConfigWithUsers(t)) + s := newTestServer(t, newTestConfigWithAuthFile(t)) require.Nil(t, s.userManager.AddUser("phil", "phil", user.RoleUser)) user, _ := s.userManager.User("phil") token, _ := s.userManager.CreateToken(user) @@ -159,7 +160,7 @@ func TestAccount_ChangeSettings(t *testing.T) { } func TestAccount_Subscription_AddUpdateDelete(t *testing.T) { - s := newTestServer(t, newTestConfigWithUsers(t)) + s := newTestServer(t, newTestConfigWithAuthFile(t)) require.Nil(t, s.userManager.AddUser("phil", "phil", user.RoleUser)) rr := request(t, s, "POST", "/v1/account/subscription", `{"base_url": "http://abc.com", "topic": "def"}`, map[string]string{ @@ -209,7 +210,7 @@ func TestAccount_Subscription_AddUpdateDelete(t *testing.T) { } func TestAccount_ChangePassword(t *testing.T) { - s := newTestServer(t, newTestConfigWithUsers(t)) + s := newTestServer(t, newTestConfigWithAuthFile(t)) require.Nil(t, s.userManager.AddUser("phil", "phil", user.RoleUser)) rr := request(t, s, "POST", "/v1/account/password", `{"password": "new password"}`, map[string]string{ @@ -229,14 +230,14 @@ func TestAccount_ChangePassword(t *testing.T) { } func TestAccount_ChangePassword_NoAccount(t *testing.T) { - s := newTestServer(t, newTestConfigWithUsers(t)) + s := newTestServer(t, newTestConfigWithAuthFile(t)) rr := request(t, s, "POST", "/v1/account/password", `{"password": "new password"}`, nil) require.Equal(t, 401, rr.Code) } func TestAccount_ExtendToken(t *testing.T) { - s := newTestServer(t, newTestConfigWithUsers(t)) + s := newTestServer(t, newTestConfigWithAuthFile(t)) require.Nil(t, s.userManager.AddUser("phil", "phil", user.RoleUser)) rr := request(t, s, "POST", "/v1/account/token", "", map[string]string{ @@ -259,7 +260,7 @@ func TestAccount_ExtendToken(t *testing.T) { } func TestAccount_ExtendToken_NoTokenProvided(t *testing.T) { - s := newTestServer(t, newTestConfigWithUsers(t)) + s := newTestServer(t, newTestConfigWithAuthFile(t)) require.Nil(t, s.userManager.AddUser("phil", "phil", user.RoleUser)) rr := request(t, s, "PATCH", "/v1/account/token", "", map[string]string{ @@ -270,7 +271,7 @@ func TestAccount_ExtendToken_NoTokenProvided(t *testing.T) { } func TestAccount_DeleteToken(t *testing.T) { - s := newTestServer(t, newTestConfigWithUsers(t)) + s := newTestServer(t, newTestConfigWithAuthFile(t)) require.Nil(t, s.userManager.AddUser("phil", "phil", user.RoleUser)) rr := request(t, s, "POST", "/v1/account/token", "", map[string]string{ @@ -307,7 +308,7 @@ func TestAccount_DeleteToken(t *testing.T) { } func TestAccount_Delete_Success(t *testing.T) { - conf := newTestConfigWithUsers(t) + conf := newTestConfigWithAuthFile(t) conf.EnableSignup = true s := newTestServer(t, conf) @@ -331,7 +332,7 @@ func TestAccount_Delete_Success(t *testing.T) { } func TestAccount_Delete_Not_Allowed(t *testing.T) { - conf := newTestConfigWithUsers(t) + conf := newTestConfigWithAuthFile(t) conf.EnableSignup = true s := newTestServer(t, conf) @@ -341,3 +342,139 @@ func TestAccount_Delete_Not_Allowed(t *testing.T) { rr = request(t, s, "DELETE", "/v1/account", "", nil) require.Equal(t, 401, rr.Code) } + +func TestAccount_Reservation_Add_User_No_Plan_Failure(t *testing.T) { + conf := newTestConfigWithAuthFile(t) + conf.EnableSignup = true + s := newTestServer(t, conf) + + rr := request(t, s, "POST", "/v1/account", `{"username":"phil", "password":"mypass"}`, nil) + require.Equal(t, 200, rr.Code) + + rr = request(t, s, "POST", "/v1/account/access", `{"everyone":"deny-all"}`, map[string]string{ + "Authorization": util.BasicAuth("phil", "mypass"), + }) + require.Equal(t, 401, rr.Code) +} + +func TestAccount_Reservation_Add_Admin_Success(t *testing.T) { + conf := newTestConfigWithAuthFile(t) + conf.EnableSignup = true + s := newTestServer(t, conf) + require.Nil(t, s.userManager.AddUser("phil", "adminpass", user.RoleAdmin)) + + rr := request(t, s, "POST", "/v1/account/access", `{"everyone":"deny-all"}`, map[string]string{ + "Authorization": util.BasicAuth("phil", "adminpass"), + }) + require.Equal(t, 200, rr.Code) +} + +func TestAccount_Reservation_Add_Remove_User_With_Plan_Success(t *testing.T) { + conf := newTestConfigWithAuthFile(t) + conf.EnableSignup = true + s := newTestServer(t, conf) + + // Create user + rr := request(t, s, "POST", "/v1/account", `{"username":"phil", "password":"mypass"}`, nil) + require.Equal(t, 200, rr.Code) + + // Create a plan (hack!) + db, err := sql.Open("sqlite3", conf.AuthFile) + require.Nil(t, err) + + _, err = db.Exec(` + INSERT INTO plan (id, code, messages_limit, emails_limit, attachment_file_size_limit, attachment_total_size_limit, topics_limit) + VALUES (1, 'testplan', 10, 10, 10, 10, 2); + + UPDATE user SET plan_id = 1 WHERE user = 'phil'; + `) + require.Nil(t, err) + + // Reserve two topics + rr = request(t, s, "POST", "/v1/account/access", `{"topic": "mytopic", "everyone":"deny-all"}`, map[string]string{ + "Authorization": util.BasicAuth("phil", "mypass"), + }) + require.Equal(t, 200, rr.Code) + + rr = request(t, s, "POST", "/v1/account/access", `{"topic": "another", "everyone":"read-only"}`, map[string]string{ + "Authorization": util.BasicAuth("phil", "mypass"), + }) + require.Equal(t, 200, rr.Code) + + // Trying to reserve a third should fail + rr = request(t, s, "POST", "/v1/account/access", `{"topic": "yet-another", "everyone":"deny-all"}`, map[string]string{ + "Authorization": util.BasicAuth("phil", "mypass"), + }) + require.Equal(t, 429, rr.Code) + + // Modify existing should still work + rr = request(t, s, "POST", "/v1/account/access", `{"topic": "another", "everyone":"write-only"}`, map[string]string{ + "Authorization": util.BasicAuth("phil", "mypass"), + }) + require.Equal(t, 200, rr.Code) + + // Check account result + rr = request(t, s, "GET", "/v1/account", "", map[string]string{ + "Authorization": util.BasicAuth("phil", "mypass"), + }) + require.Equal(t, 200, rr.Code) + account, _ := util.UnmarshalJSON[apiAccountResponse](io.NopCloser(rr.Body)) + require.Equal(t, 2, len(account.Reservations)) + require.Equal(t, "another", account.Reservations[0].Topic) + require.Equal(t, "write-only", account.Reservations[0].Everyone) + require.Equal(t, "mytopic", account.Reservations[1].Topic) + require.Equal(t, "deny-all", account.Reservations[1].Everyone) + + // Delete and re-check + rr = request(t, s, "DELETE", "/v1/account/access/another", "", map[string]string{ + "Authorization": util.BasicAuth("phil", "mypass"), + }) + require.Equal(t, 200, rr.Code) + + rr = request(t, s, "GET", "/v1/account", "", map[string]string{ + "Authorization": util.BasicAuth("phil", "mypass"), + }) + require.Equal(t, 200, rr.Code) + account, _ = util.UnmarshalJSON[apiAccountResponse](io.NopCloser(rr.Body)) + require.Equal(t, 1, len(account.Reservations)) + require.Equal(t, "mytopic", account.Reservations[0].Topic) +} + +func TestAccount_Reservation_Add_Access_By_Anonymous_Fails(t *testing.T) { + conf := newTestConfigWithAuthFile(t) + conf.AuthDefault = user.PermissionReadWrite + conf.EnableSignup = true + s := newTestServer(t, conf) + + // Create user + rr := request(t, s, "POST", "/v1/account", `{"username":"phil", "password":"mypass"}`, nil) + require.Equal(t, 200, rr.Code) + + // Create a plan (hack!) + db, err := sql.Open("sqlite3", conf.AuthFile) + require.Nil(t, err) + + _, err = db.Exec(` + INSERT INTO plan (id, code, messages_limit, emails_limit, attachment_file_size_limit, attachment_total_size_limit, topics_limit) + VALUES (1, 'testplan', 10, 10, 10, 10, 2); + + UPDATE user SET plan_id = 1 WHERE user = 'phil'; + `) + require.Nil(t, err) + + // Reserve a topic + rr = request(t, s, "POST", "/v1/account/access", `{"topic": "mytopic", "everyone":"deny-all"}`, map[string]string{ + "Authorization": util.BasicAuth("phil", "mypass"), + }) + require.Equal(t, 200, rr.Code) + + // Publish a message + rr = request(t, s, "POST", "/mytopic", `Howdy`, map[string]string{ + "Authorization": util.BasicAuth("phil", "mypass"), + }) + require.Equal(t, 200, rr.Code) + + // Publish a message (as anonymous) + rr = request(t, s, "POST", "/mytopic", `Howdy`, nil) + require.Equal(t, 403, rr.Code) +} diff --git a/server/server_test.go b/server/server_test.go index 5ce41260..7ba85571 100644 --- a/server/server_test.go +++ b/server/server_test.go @@ -1440,7 +1440,7 @@ func newTestConfig(t *testing.T) *Config { return conf } -func newTestConfigWithUsers(t *testing.T) *Config { +func newTestConfigWithAuthFile(t *testing.T) *Config { conf := newTestConfig(t) conf.AuthFile = filepath.Join(t.TempDir(), "user.db") return conf diff --git a/user/manager.go b/user/manager.go index ab9928cb..29280fb3 100644 --- a/user/manager.go +++ b/user/manager.go @@ -145,6 +145,13 @@ const ( FROM user_access WHERE user_id = owner_user_id AND owner_user_id = (SELECT id FROM user WHERE user = ?) ` + selectUserHasReservationQuery = ` + SELECT COUNT(*) + FROM user_access + WHERE user_id = owner_user_id + AND owner_user_id = (SELECT id FROM user WHERE user = ?) + AND topic = ? + ` selectOtherAccessCountQuery = ` SELECT COUNT(*) FROM user_access @@ -604,6 +611,23 @@ func (a *Manager) Reservations(username string) ([]Reservation, error) { return reservations, nil } +// HasReservation returns true if the given topic access is owned by the user +func (a *Manager) HasReservation(username, topic string) (bool, error) { + rows, err := a.db.Query(selectUserHasReservationQuery, username, topic) + if err != nil { + return false, err + } + defer rows.Close() + if !rows.Next() { + return false, errNoRows + } + var count int64 + if err := rows.Scan(&count); err != nil { + return false, err + } + return count > 0, nil +} + // ReservationsCount returns the number of reservations owned by this user func (a *Manager) ReservationsCount(username string) (int64, error) { rows, err := a.db.Query(selectUserReservationsCountQuery, username)