WIP: Auth in 80 lines of code :-)
parent
aab705f4a4
commit
2181227a6e
|
@ -0,0 +1,31 @@
|
|||
package server
|
||||
|
||||
/*
|
||||
sqlite> create table user (id int auto increment, user text, password text not null);
|
||||
sqlite> create table user_topic (user_id int not null, topic text not null, allow_write int, allow_read int);
|
||||
sqlite> create table topic (topic text primary key, allow_anonymous_write int, allow_anonymous_read int);
|
||||
*/
|
||||
|
||||
const (
|
||||
permRead = 1
|
||||
permWrite = 2
|
||||
)
|
||||
|
||||
type auther interface {
|
||||
Authenticate(user, pass string) bool
|
||||
Authorize(user, topic string, perm int) bool
|
||||
}
|
||||
|
||||
type memAuther struct {
|
||||
}
|
||||
|
||||
func (m memAuther) Authenticate(user, pass string) bool {
|
||||
return user == "phil" && pass == "phil"
|
||||
}
|
||||
|
||||
func (m memAuther) Authorize(user, topic string, perm int) bool {
|
||||
if perm == permRead {
|
||||
return true
|
||||
}
|
||||
return user == "phil" && topic == "mytopic"
|
||||
}
|
|
@ -40,6 +40,7 @@ var (
|
|||
errHTTPBadRequestAttachmentsExpiryBeforeDelivery = &errHTTP{40015, http.StatusBadRequest, "invalid request: attachment expiry before delayed delivery date", ""}
|
||||
errHTTPBadRequestWebSocketsUpgradeHeaderMissing = &errHTTP{40016, http.StatusBadRequest, "invalid request: client not using the websocket protocol", ""}
|
||||
errHTTPNotFound = &errHTTP{40401, http.StatusNotFound, "page not found", ""}
|
||||
errHTTPUnauthorized = &errHTTP{40101, http.StatusUnauthorized, "unauthorized", ""}
|
||||
errHTTPTooManyRequestsLimitRequests = &errHTTP{42901, http.StatusTooManyRequests, "limit reached: too many requests, please be nice", "https://ntfy.sh/docs/publish/#limitations"}
|
||||
errHTTPTooManyRequestsLimitEmails = &errHTTP{42902, http.StatusTooManyRequests, "limit reached: too many emails, please be nice", "https://ntfy.sh/docs/publish/#limitations"}
|
||||
errHTTPTooManyRequestsLimitSubscriptions = &errHTTP{42903, http.StatusTooManyRequests, "limit reached: too many active subscriptions, please be nice", "https://ntfy.sh/docs/publish/#limitations"}
|
||||
|
|
|
@ -46,6 +46,7 @@ type Server struct {
|
|||
firebase subscriber
|
||||
mailer mailer
|
||||
messages int64
|
||||
auther auther
|
||||
cache cache
|
||||
fileCache *fileCache
|
||||
closeChan chan bool
|
||||
|
@ -57,6 +58,9 @@ type indexPage struct {
|
|||
CacheDuration time.Duration
|
||||
}
|
||||
|
||||
// handleFunc extends the normal http.HandlerFunc to be able to easily return errors
|
||||
type handleFunc func(http.ResponseWriter, *http.Request, *visitor) error
|
||||
|
||||
var (
|
||||
topicRegex = regexp.MustCompile(`^[-_A-Za-z0-9]{1,64}$`) // No /!
|
||||
topicPathRegex = regexp.MustCompile(`^/[-_A-Za-z0-9]{1,64}$`) // Regex must match JS & Android app!
|
||||
|
@ -144,6 +148,7 @@ func New(conf *Config) (*Server, error) {
|
|||
firebase: firebaseSubscriber,
|
||||
mailer: mailer,
|
||||
topics: topics,
|
||||
auther: &memAuther{},
|
||||
visitors: make(map[string]*visitor),
|
||||
}, nil
|
||||
}
|
||||
|
@ -312,6 +317,7 @@ func (s *Server) handle(w http.ResponseWriter, r *http.Request) {
|
|||
}
|
||||
|
||||
func (s *Server) handleInternal(w http.ResponseWriter, r *http.Request) error {
|
||||
v := s.visitor(r)
|
||||
if r.Method == http.MethodGet && r.URL.Path == "/" {
|
||||
return s.handleHome(w, r)
|
||||
} else if r.Method == http.MethodGet && r.URL.Path == "/example.html" {
|
||||
|
@ -323,23 +329,23 @@ func (s *Server) handleInternal(w http.ResponseWriter, r *http.Request) error {
|
|||
} else if r.Method == http.MethodGet && docsRegex.MatchString(r.URL.Path) {
|
||||
return s.handleDocs(w, r)
|
||||
} else if r.Method == http.MethodGet && fileRegex.MatchString(r.URL.Path) && s.config.AttachmentCacheDir != "" {
|
||||
return s.withRateLimit(w, r, s.handleFile)
|
||||
return s.limitRequests(s.handleFile)(w, r, v)
|
||||
} else if r.Method == http.MethodOptions {
|
||||
return s.handleOptions(w, r)
|
||||
} else if r.Method == http.MethodGet && topicPathRegex.MatchString(r.URL.Path) {
|
||||
return s.handleTopic(w, r)
|
||||
} else if (r.Method == http.MethodPut || r.Method == http.MethodPost) && topicPathRegex.MatchString(r.URL.Path) {
|
||||
return s.withRateLimit(w, r, s.handlePublish)
|
||||
return s.limitRequests(s.authWrite(s.handlePublish))(w, r, v)
|
||||
} else if r.Method == http.MethodGet && publishPathRegex.MatchString(r.URL.Path) {
|
||||
return s.withRateLimit(w, r, s.handlePublish)
|
||||
return s.limitRequests(s.authWrite(s.handlePublish))(w, r, v)
|
||||
} else if r.Method == http.MethodGet && jsonPathRegex.MatchString(r.URL.Path) {
|
||||
return s.withRateLimit(w, r, s.handleSubscribeJSON)
|
||||
return s.limitRequests(s.authRead(s.handleSubscribeJSON))(w, r, v)
|
||||
} else if r.Method == http.MethodGet && ssePathRegex.MatchString(r.URL.Path) {
|
||||
return s.withRateLimit(w, r, s.handleSubscribeSSE)
|
||||
return s.limitRequests(s.authRead(s.handleSubscribeSSE))(w, r, v)
|
||||
} else if r.Method == http.MethodGet && rawPathRegex.MatchString(r.URL.Path) {
|
||||
return s.withRateLimit(w, r, s.handleSubscribeRaw)
|
||||
return s.limitRequests(s.authRead(s.handleSubscribeRaw))(w, r, v)
|
||||
} else if r.Method == http.MethodGet && wsPathRegex.MatchString(r.URL.Path) {
|
||||
return s.withRateLimit(w, r, s.handleSubscribeWS)
|
||||
return s.limitRequests(s.authRead(s.handleSubscribeWS))(w, r, v)
|
||||
}
|
||||
return errHTTPNotFound
|
||||
}
|
||||
|
@ -1094,12 +1100,45 @@ func (s *Server) sendDelayedMessages() error {
|
|||
return nil
|
||||
}
|
||||
|
||||
func (s *Server) withRateLimit(w http.ResponseWriter, r *http.Request, handler func(w http.ResponseWriter, r *http.Request, v *visitor) error) error {
|
||||
v := s.visitor(r)
|
||||
if err := v.RequestAllowed(); err != nil {
|
||||
return errHTTPTooManyRequestsLimitRequests
|
||||
func (s *Server) limitRequests(next handleFunc) handleFunc {
|
||||
return func(w http.ResponseWriter, r *http.Request, v *visitor) error {
|
||||
if err := v.RequestAllowed(); err != nil {
|
||||
return errHTTPTooManyRequestsLimitRequests
|
||||
}
|
||||
return next(w, r, v)
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Server) authWrite(next handleFunc) handleFunc {
|
||||
return s.withAuth(next, permWrite)
|
||||
}
|
||||
|
||||
func (s *Server) authRead(next handleFunc) handleFunc {
|
||||
return s.withAuth(next, permRead)
|
||||
}
|
||||
|
||||
func (s *Server) withAuth(next handleFunc, perm int) handleFunc {
|
||||
return func(w http.ResponseWriter, r *http.Request, v *visitor) error {
|
||||
if s.auther == nil {
|
||||
return next(w, r, v)
|
||||
}
|
||||
t, err := s.topicFromPath(r.URL.Path)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
user, pass, ok := r.BasicAuth()
|
||||
if ok {
|
||||
if !s.auther.Authenticate(user, pass) {
|
||||
return errHTTPUnauthorized
|
||||
}
|
||||
} else {
|
||||
user = "" // Just in case
|
||||
}
|
||||
if !s.auther.Authorize(user, t.ID, perm) {
|
||||
return errHTTPUnauthorized
|
||||
}
|
||||
return next(w, r, v)
|
||||
}
|
||||
return handler(w, r, v)
|
||||
}
|
||||
|
||||
// visitor creates or retrieves a rate.Limiter for the given visitor.
|
||||
|
|
Loading…
Reference in New Issue