Add /auth endpoint and tests
parent
89957e7058
commit
e61a0c2f78
|
@ -69,6 +69,7 @@ var (
|
||||||
ssePathRegex = regexp.MustCompile(`^/[-_A-Za-z0-9]{1,64}(,[-_A-Za-z0-9]{1,64})*/sse$`)
|
ssePathRegex = regexp.MustCompile(`^/[-_A-Za-z0-9]{1,64}(,[-_A-Za-z0-9]{1,64})*/sse$`)
|
||||||
rawPathRegex = regexp.MustCompile(`^/[-_A-Za-z0-9]{1,64}(,[-_A-Za-z0-9]{1,64})*/raw$`)
|
rawPathRegex = regexp.MustCompile(`^/[-_A-Za-z0-9]{1,64}(,[-_A-Za-z0-9]{1,64})*/raw$`)
|
||||||
wsPathRegex = regexp.MustCompile(`^/[-_A-Za-z0-9]{1,64}(,[-_A-Za-z0-9]{1,64})*/ws$`)
|
wsPathRegex = regexp.MustCompile(`^/[-_A-Za-z0-9]{1,64}(,[-_A-Za-z0-9]{1,64})*/ws$`)
|
||||||
|
authPathRegex = regexp.MustCompile(`^/[-_A-Za-z0-9]{1,64}(,[-_A-Za-z0-9]{1,64})*/auth$`)
|
||||||
publishPathRegex = regexp.MustCompile(`^/[-_A-Za-z0-9]{1,64}(,[-_A-Za-z0-9]{1,64})*/(publish|send|trigger)$`)
|
publishPathRegex = regexp.MustCompile(`^/[-_A-Za-z0-9]{1,64}(,[-_A-Za-z0-9]{1,64})*/(publish|send|trigger)$`)
|
||||||
|
|
||||||
staticRegex = regexp.MustCompile(`^/static/.+`)
|
staticRegex = regexp.MustCompile(`^/static/.+`)
|
||||||
|
@ -331,7 +332,7 @@ func (s *Server) handleInternal(w http.ResponseWriter, r *http.Request) error {
|
||||||
} else if r.Method == http.MethodGet && r.URL.Path == "/example.html" {
|
} else if r.Method == http.MethodGet && r.URL.Path == "/example.html" {
|
||||||
return s.handleExample(w, r)
|
return s.handleExample(w, r)
|
||||||
} else if r.Method == http.MethodHead && r.URL.Path == "/" {
|
} else if r.Method == http.MethodHead && r.URL.Path == "/" {
|
||||||
return s.handleEmpty(w, r)
|
return s.handleEmpty(w, r, v)
|
||||||
} else if r.Method == http.MethodGet && staticRegex.MatchString(r.URL.Path) {
|
} else if r.Method == http.MethodGet && staticRegex.MatchString(r.URL.Path) {
|
||||||
return s.handleStatic(w, r)
|
return s.handleStatic(w, r)
|
||||||
} else if r.Method == http.MethodGet && docsRegex.MatchString(r.URL.Path) {
|
} else if r.Method == http.MethodGet && docsRegex.MatchString(r.URL.Path) {
|
||||||
|
@ -354,6 +355,8 @@ func (s *Server) handleInternal(w http.ResponseWriter, r *http.Request) error {
|
||||||
return s.limitRequests(s.authRead(s.handleSubscribeRaw))(w, r, v)
|
return s.limitRequests(s.authRead(s.handleSubscribeRaw))(w, r, v)
|
||||||
} else if r.Method == http.MethodGet && wsPathRegex.MatchString(r.URL.Path) {
|
} else if r.Method == http.MethodGet && wsPathRegex.MatchString(r.URL.Path) {
|
||||||
return s.limitRequests(s.authRead(s.handleSubscribeWS))(w, r, v)
|
return s.limitRequests(s.authRead(s.handleSubscribeWS))(w, r, v)
|
||||||
|
} else if r.Method == http.MethodGet && authPathRegex.MatchString(r.URL.Path) {
|
||||||
|
return s.limitRequests(s.authRead(s.handleTopicAuth))(w, r, v)
|
||||||
}
|
}
|
||||||
return errHTTPNotFound
|
return errHTTPNotFound
|
||||||
}
|
}
|
||||||
|
@ -376,10 +379,17 @@ func (s *Server) handleTopic(w http.ResponseWriter, r *http.Request) error {
|
||||||
return s.handleHome(w, r)
|
return s.handleHome(w, r)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *Server) handleEmpty(_ http.ResponseWriter, _ *http.Request) error {
|
func (s *Server) handleEmpty(_ http.ResponseWriter, _ *http.Request, _ *visitor) error {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (s *Server) handleTopicAuth(w http.ResponseWriter, _ *http.Request, _ *visitor) error {
|
||||||
|
w.Header().Set("Content-Type", "application/json")
|
||||||
|
w.Header().Set("Access-Control-Allow-Origin", "*") // CORS, allow cross-origin requests
|
||||||
|
_, err := io.WriteString(w, `{"success":true}`+"\n")
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
func (s *Server) handleExample(w http.ResponseWriter, _ *http.Request) error {
|
func (s *Server) handleExample(w http.ResponseWriter, _ *http.Request) error {
|
||||||
_, err := io.WriteString(w, exampleSource)
|
_, err := io.WriteString(w, exampleSource)
|
||||||
return err
|
return err
|
||||||
|
|
|
@ -7,6 +7,7 @@ import (
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"fmt"
|
"fmt"
|
||||||
"github.com/stretchr/testify/require"
|
"github.com/stretchr/testify/require"
|
||||||
|
"heckel.io/ntfy/auth"
|
||||||
"heckel.io/ntfy/util"
|
"heckel.io/ntfy/util"
|
||||||
"math/rand"
|
"math/rand"
|
||||||
"net/http"
|
"net/http"
|
||||||
|
@ -524,6 +525,104 @@ func TestServer_SubscribeWithQueryFilters(t *testing.T) {
|
||||||
require.Equal(t, keepaliveEvent, messages[2].Event)
|
require.Equal(t, keepaliveEvent, messages[2].Event)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestServer_Auth_Success_Admin(t *testing.T) {
|
||||||
|
c := newTestConfig(t)
|
||||||
|
c.AuthFile = filepath.Join(t.TempDir(), "user.db")
|
||||||
|
s := newTestServer(t, c)
|
||||||
|
|
||||||
|
manager := s.auth.(auth.Manager)
|
||||||
|
require.Nil(t, manager.AddUser("phil", "phil", auth.RoleAdmin))
|
||||||
|
|
||||||
|
response := request(t, s, "GET", "/mytopic/auth", "", map[string]string{
|
||||||
|
"Authorization": basicAuth("phil:phil"),
|
||||||
|
})
|
||||||
|
require.Equal(t, 200, response.Code)
|
||||||
|
require.Equal(t, `{"success":true}`+"\n", response.Body.String())
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestServer_Auth_Success_User(t *testing.T) {
|
||||||
|
c := newTestConfig(t)
|
||||||
|
c.AuthFile = filepath.Join(t.TempDir(), "user.db")
|
||||||
|
c.AuthDefaultRead = false
|
||||||
|
c.AuthDefaultWrite = false
|
||||||
|
s := newTestServer(t, c)
|
||||||
|
|
||||||
|
manager := s.auth.(auth.Manager)
|
||||||
|
require.Nil(t, manager.AddUser("ben", "ben", auth.RoleUser))
|
||||||
|
require.Nil(t, manager.AllowAccess("ben", "mytopic", true, true)) // Not mytopic!
|
||||||
|
|
||||||
|
response := request(t, s, "GET", "/mytopic/auth", "", map[string]string{
|
||||||
|
"Authorization": basicAuth("ben:ben"),
|
||||||
|
})
|
||||||
|
require.Equal(t, 200, response.Code)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestServer_Auth_Fail_InvalidPass(t *testing.T) {
|
||||||
|
c := newTestConfig(t)
|
||||||
|
c.AuthFile = filepath.Join(t.TempDir(), "user.db")
|
||||||
|
c.AuthDefaultRead = false
|
||||||
|
c.AuthDefaultWrite = false
|
||||||
|
s := newTestServer(t, c)
|
||||||
|
|
||||||
|
manager := s.auth.(auth.Manager)
|
||||||
|
require.Nil(t, manager.AddUser("phil", "phil", auth.RoleAdmin))
|
||||||
|
|
||||||
|
response := request(t, s, "GET", "/mytopic/auth", "", map[string]string{
|
||||||
|
"Authorization": basicAuth("phil:INVALID"),
|
||||||
|
})
|
||||||
|
require.Equal(t, 401, response.Code)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestServer_Auth_Fail_Unauthorized(t *testing.T) {
|
||||||
|
c := newTestConfig(t)
|
||||||
|
c.AuthFile = filepath.Join(t.TempDir(), "user.db")
|
||||||
|
c.AuthDefaultRead = false
|
||||||
|
c.AuthDefaultWrite = false
|
||||||
|
s := newTestServer(t, c)
|
||||||
|
|
||||||
|
manager := s.auth.(auth.Manager)
|
||||||
|
require.Nil(t, manager.AddUser("ben", "ben", auth.RoleUser))
|
||||||
|
require.Nil(t, manager.AllowAccess("ben", "sometopic", true, true)) // Not mytopic!
|
||||||
|
|
||||||
|
response := request(t, s, "GET", "/mytopic/auth", "", map[string]string{
|
||||||
|
"Authorization": basicAuth("ben:ben"),
|
||||||
|
})
|
||||||
|
require.Equal(t, 403, response.Code)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestServer_Auth_Fail_CannotPublish(t *testing.T) {
|
||||||
|
c := newTestConfig(t)
|
||||||
|
c.AuthFile = filepath.Join(t.TempDir(), "user.db")
|
||||||
|
c.AuthDefaultRead = true // Open by default
|
||||||
|
c.AuthDefaultWrite = true // Open by default
|
||||||
|
s := newTestServer(t, c)
|
||||||
|
|
||||||
|
manager := s.auth.(auth.Manager)
|
||||||
|
require.Nil(t, manager.AddUser("phil", "phil", auth.RoleAdmin))
|
||||||
|
require.Nil(t, manager.AllowAccess(auth.Everyone, "private", false, false))
|
||||||
|
require.Nil(t, manager.AllowAccess(auth.Everyone, "announcements", true, false))
|
||||||
|
|
||||||
|
response := request(t, s, "PUT", "/mytopic", "test", nil)
|
||||||
|
require.Equal(t, 200, response.Code)
|
||||||
|
|
||||||
|
response = request(t, s, "GET", "/mytopic/json?poll=1", "", nil)
|
||||||
|
require.Equal(t, 200, response.Code)
|
||||||
|
|
||||||
|
response = request(t, s, "PUT", "/announcements", "test", nil)
|
||||||
|
require.Equal(t, 403, response.Code) // Cannot write as anonymous
|
||||||
|
|
||||||
|
response = request(t, s, "PUT", "/announcements", "test", map[string]string{
|
||||||
|
"Authorization": basicAuth("phil:phil"),
|
||||||
|
})
|
||||||
|
require.Equal(t, 200, response.Code)
|
||||||
|
|
||||||
|
response = request(t, s, "GET", "/announcements/json?poll=1", "", nil)
|
||||||
|
require.Equal(t, 200, response.Code) // Anonymous read allowed
|
||||||
|
|
||||||
|
response = request(t, s, "GET", "/private/json?poll=1", "", nil)
|
||||||
|
require.Equal(t, 403, response.Code) // Anonymous read not allowed
|
||||||
|
}
|
||||||
|
|
||||||
/*
|
/*
|
||||||
func TestServer_Curl_Publish_Poll(t *testing.T) {
|
func TestServer_Curl_Publish_Poll(t *testing.T) {
|
||||||
s, port := test.StartServer(t)
|
s, port := test.StartServer(t)
|
||||||
|
@ -988,3 +1087,7 @@ func firebaseServiceAccountFile(t *testing.T) string {
|
||||||
t.SkipNow()
|
t.SkipNow()
|
||||||
return ""
|
return ""
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func basicAuth(s string) string {
|
||||||
|
return fmt.Sprintf("Basic %s", base64.StdEncoding.EncodeToString([]byte(s)))
|
||||||
|
}
|
||||||
|
|
Loading…
Reference in New Issue