Token stuff

pull/584/head
Philipp Heckel 2022-12-03 15:20:59 -05:00
parent d3dfeeccc3
commit d499d20a9c
8 changed files with 194 additions and 64 deletions

View File

@ -6,13 +6,17 @@ import (
"regexp" "regexp"
) )
// Auther is a generic interface to implement password-based authentication and authorization // Auther is a generic interface to implement password and token based authentication and authorization
type Auther interface { type Auther interface {
// Authenticate checks username and password and returns a user if correct. The method // Authenticate checks username and password and returns a user if correct. The method
// returns in constant-ish time, regardless of whether the user exists or the password is // returns in constant-ish time, regardless of whether the user exists or the password is
// correct or incorrect. // correct or incorrect.
Authenticate(username, password string) (*User, error) Authenticate(username, password string) (*User, error)
AuthenticateToken(token string) (*User, error)
GenerateToken(user *User) (string, error)
// Authorize returns nil if the given user has access to the given topic using the desired // Authorize returns nil if the given user has access to the given topic using the desired
// permission. The user param may be nil to signal an anonymous user. // permission. The user param may be nil to signal an anonymous user.
Authorize(user *User, topic string, perm Permission) error Authorize(user *User, topic string, perm Permission) error
@ -56,10 +60,11 @@ type Manager interface {
// User is a struct that represents a user // User is a struct that represents a user
type User struct { type User struct {
Name string Name string
Hash string // password hash (bcrypt) Hash string // password hash (bcrypt)
Role Role Role Role
Grants []Grant Grants []Grant
Language string
} }
// Grant is a struct that represents an access control entry to a topic // Grant is a struct that represents an access control entry to a topic

View File

@ -6,10 +6,12 @@ import (
"fmt" "fmt"
_ "github.com/mattn/go-sqlite3" // SQLite driver _ "github.com/mattn/go-sqlite3" // SQLite driver
"golang.org/x/crypto/bcrypt" "golang.org/x/crypto/bcrypt"
"heckel.io/ntfy/util"
"strings" "strings"
) )
const ( const (
tokenLength = 32
bcryptCost = 10 bcryptCost = 10
intentionalSlowDownHash = "$2a$10$YFCQvqQDwIIwnJM1xkAYOeih0dg17UVGanaTStnrSzC8NCWxcLDwy" // Cost should match bcryptCost intentionalSlowDownHash = "$2a$10$YFCQvqQDwIIwnJM1xkAYOeih0dg17UVGanaTStnrSzC8NCWxcLDwy" // Cost should match bcryptCost
) )
@ -67,7 +69,17 @@ const (
INSERT INTO user (id, user, pass, role) VALUES (1, '*', '', 'anonymous') ON CONFLICT (id) DO NOTHING; INSERT INTO user (id, user, pass, role) VALUES (1, '*', '', 'anonymous') ON CONFLICT (id) DO NOTHING;
COMMIT; COMMIT;
` `
selectUserQuery = `SELECT pass, role FROM user WHERE user = ?` selectUserByNameQuery = `
SELECT user, pass, role, language
FROM user
WHERE user = ?
`
selectUserByTokenQuery = `
SELECT user, pass, role, language
FROM user
JOIN user_token on user.id = user_token.user_id
WHERE token = ?
`
selectTopicPermsQuery = ` selectTopicPermsQuery = `
SELECT read, write SELECT read, write
FROM user_access FROM user_access
@ -90,6 +102,8 @@ const (
deleteAllAccessQuery = `DELETE FROM user_access` deleteAllAccessQuery = `DELETE FROM user_access`
deleteUserAccessQuery = `DELETE FROM user_access WHERE user_id = (SELECT id FROM user WHERE user = ?)` deleteUserAccessQuery = `DELETE FROM user_access WHERE user_id = (SELECT id FROM user WHERE user = ?)`
deleteTopicAccessQuery = `DELETE FROM user_access WHERE user_id = (SELECT id FROM user WHERE user = ?) AND topic = ?` deleteTopicAccessQuery = `DELETE FROM user_access WHERE user_id = (SELECT id FROM user WHERE user = ?) AND topic = ?`
insertTokenQuery = `INSERT INTO user_token (user_id, token, expires) VALUES ((SELECT id FROM user WHERE user = ?), ?, ?)`
) )
// Schema management queries // Schema management queries
@ -126,7 +140,7 @@ func NewSQLiteAuth(filename string, defaultRead, defaultWrite bool) (*SQLiteAuth
}, nil }, nil
} }
// Authenticate checks username and password and returns a user if correct. The method // AuthenticateUser checks username and password and returns a user if correct. The method
// returns in constant-ish time, regardless of whether the user exists or the password is // returns in constant-ish time, regardless of whether the user exists or the password is
// correct or incorrect. // correct or incorrect.
func (a *SQLiteAuth) Authenticate(username, password string) (*User, error) { func (a *SQLiteAuth) Authenticate(username, password string) (*User, error) {
@ -145,6 +159,23 @@ func (a *SQLiteAuth) Authenticate(username, password string) (*User, error) {
return user, nil return user, nil
} }
func (a *SQLiteAuth) AuthenticateToken(token string) (*User, error) {
user, err := a.userByToken(token)
if err != nil {
return nil, ErrUnauthenticated
}
return user, nil
}
func (a *SQLiteAuth) GenerateToken(user *User) (string, error) {
token := util.RandomString(tokenLength)
expires := 1 // FIXME
if _, err := a.db.Exec(insertTokenQuery, user.Name, token, expires); err != nil {
return "", err
}
return token, nil
}
// Authorize returns nil if the given user has access to the given topic using the desired // Authorize returns nil if the given user has access to the given topic using the desired
// permission. The user param may be nil to signal an anonymous user. // permission. The user param may be nil to signal an anonymous user.
func (a *SQLiteAuth) Authorize(user *User, topic string, perm Permission) error { func (a *SQLiteAuth) Authorize(user *User, topic string, perm Permission) error {
@ -255,16 +286,29 @@ func (a *SQLiteAuth) User(username string) (*User, error) {
if username == Everyone { if username == Everyone {
return a.everyoneUser() return a.everyoneUser()
} }
rows, err := a.db.Query(selectUserQuery, username) rows, err := a.db.Query(selectUserByNameQuery, username)
if err != nil { if err != nil {
return nil, err return nil, err
} }
return a.readUser(rows)
}
func (a *SQLiteAuth) userByToken(token string) (*User, error) {
rows, err := a.db.Query(selectUserByTokenQuery, token)
if err != nil {
return nil, err
}
return a.readUser(rows)
}
func (a *SQLiteAuth) readUser(rows *sql.Rows) (*User, error) {
defer rows.Close() defer rows.Close()
var hash, role string var username, hash, role string
var language sql.NullString
if !rows.Next() { if !rows.Next() {
return nil, ErrNotFound return nil, ErrNotFound
} }
if err := rows.Scan(&hash, &role); err != nil { if err := rows.Scan(&username, &hash, &role, &language); err != nil {
return nil, err return nil, err
} else if err := rows.Err(); err != nil { } else if err := rows.Err(); err != nil {
return nil, err return nil, err
@ -274,10 +318,11 @@ func (a *SQLiteAuth) User(username string) (*User, error) {
return nil, err return nil, err
} }
return &User{ return &User{
Name: username, Name: username,
Hash: hash, Hash: hash,
Role: Role(role), Role: Role(role),
Grants: grants, Grants: grants,
Language: language.String,
}, nil }, nil
} }

View File

@ -7,6 +7,7 @@ import (
"embed" "embed"
"encoding/base64" "encoding/base64"
"encoding/json" "encoding/json"
"errors"
"fmt" "fmt"
"io" "io"
"net" "net"
@ -320,23 +321,23 @@ func (s *Server) handleInternal(w http.ResponseWriter, r *http.Request, v *visit
} else if r.Method == http.MethodOptions { } else if r.Method == http.MethodOptions {
return s.ensureWebEnabled(s.handleOptions)(w, r, v) return s.ensureWebEnabled(s.handleOptions)(w, r, v)
} else if (r.Method == http.MethodPut || r.Method == http.MethodPost) && r.URL.Path == "/" { } else if (r.Method == http.MethodPut || r.Method == http.MethodPost) && r.URL.Path == "/" {
return s.limitRequests(s.transformBodyJSON(s.authWrite(s.handlePublish)))(w, r, v) return s.limitRequests(s.transformBodyJSON(s.authorizeTopicWrite(s.handlePublish)))(w, r, v)
} else if r.Method == http.MethodPost && r.URL.Path == matrixPushPath { } else if r.Method == http.MethodPost && r.URL.Path == matrixPushPath {
return s.limitRequests(s.transformMatrixJSON(s.authWrite(s.handlePublishMatrix)))(w, r, v) return s.limitRequests(s.transformMatrixJSON(s.authorizeTopicWrite(s.handlePublishMatrix)))(w, r, v)
} else if (r.Method == http.MethodPut || r.Method == http.MethodPost) && topicPathRegex.MatchString(r.URL.Path) { } else if (r.Method == http.MethodPut || r.Method == http.MethodPost) && topicPathRegex.MatchString(r.URL.Path) {
return s.limitRequests(s.authWrite(s.handlePublish))(w, r, v) return s.limitRequests(s.authorizeTopicWrite(s.handlePublish))(w, r, v)
} else if r.Method == http.MethodGet && publishPathRegex.MatchString(r.URL.Path) { } else if r.Method == http.MethodGet && publishPathRegex.MatchString(r.URL.Path) {
return s.limitRequests(s.authWrite(s.handlePublish))(w, r, v) return s.limitRequests(s.authorizeTopicWrite(s.handlePublish))(w, r, v)
} else if r.Method == http.MethodGet && jsonPathRegex.MatchString(r.URL.Path) { } else if r.Method == http.MethodGet && jsonPathRegex.MatchString(r.URL.Path) {
return s.limitRequests(s.authRead(s.handleSubscribeJSON))(w, r, v) return s.limitRequests(s.authorizeTopicRead(s.handleSubscribeJSON))(w, r, v)
} else if r.Method == http.MethodGet && ssePathRegex.MatchString(r.URL.Path) { } else if r.Method == http.MethodGet && ssePathRegex.MatchString(r.URL.Path) {
return s.limitRequests(s.authRead(s.handleSubscribeSSE))(w, r, v) return s.limitRequests(s.authorizeTopicRead(s.handleSubscribeSSE))(w, r, v)
} else if r.Method == http.MethodGet && rawPathRegex.MatchString(r.URL.Path) { } else if r.Method == http.MethodGet && rawPathRegex.MatchString(r.URL.Path) {
return s.limitRequests(s.authRead(s.handleSubscribeRaw))(w, r, v) return s.limitRequests(s.authorizeTopicRead(s.handleSubscribeRaw))(w, r, v)
} else if r.Method == http.MethodGet && wsPathRegex.MatchString(r.URL.Path) { } else if r.Method == http.MethodGet && wsPathRegex.MatchString(r.URL.Path) {
return s.limitRequests(s.authRead(s.handleSubscribeWS))(w, r, v) return s.limitRequests(s.authorizeTopicRead(s.handleSubscribeWS))(w, r, v)
} else if r.Method == http.MethodGet && authPathRegex.MatchString(r.URL.Path) { } else if r.Method == http.MethodGet && authPathRegex.MatchString(r.URL.Path) {
return s.limitRequests(s.authRead(s.handleTopicAuth))(w, r, v) return s.limitRequests(s.authorizeTopicRead(s.handleTopicAuth))(w, r, v)
} else if r.Method == http.MethodGet && (topicPathRegex.MatchString(r.URL.Path) || externalTopicPathRegex.MatchString(r.URL.Path)) { } else if r.Method == http.MethodGet && (topicPathRegex.MatchString(r.URL.Path) || externalTopicPathRegex.MatchString(r.URL.Path)) {
return s.ensureWebEnabled(s.handleTopic)(w, r, v) return s.ensureWebEnabled(s.handleTopic)(w, r, v)
} }
@ -403,8 +404,6 @@ func (s *Server) handleUserStats(w http.ResponseWriter, r *http.Request, v *visi
return nil return nil
} }
var sessions = make(map[string]*auth.User) // token-> user
type tokenAuthResponse struct { type tokenAuthResponse struct {
Token string `json:"token"` Token string `json:"token"`
} }
@ -414,8 +413,10 @@ func (s *Server) handleUserAuth(w http.ResponseWriter, r *http.Request, v *visit
if v.user == nil { if v.user == nil {
return errHTTPUnauthorized return errHTTPUnauthorized
} }
token := util.RandomString(32) token, err := s.auth.GenerateToken(v.user)
sessions[token] = v.user if err != nil {
return err
}
w.Header().Set("Content-Type", "text/json") w.Header().Set("Content-Type", "text/json")
w.Header().Set("Access-Control-Allow-Origin", "*") // FIXME remove this w.Header().Set("Access-Control-Allow-Origin", "*") // FIXME remove this
response := &tokenAuthResponse{ response := &tokenAuthResponse{
@ -432,35 +433,41 @@ type userSubscriptionResponse struct {
Topic string `json:"topic"` Topic string `json:"topic"`
} }
type userNotificationSettingsResponse struct {
Sound string `json:"sound"`
MinPriority string `json:"min_priority"`
DeleteAfter int `json:"delete_after"`
}
type userPlanResponse struct {
Id int `json:"id"`
Name string `json:"name"`
}
type userAccountResponse struct { type userAccountResponse struct {
Username string `json:"username"` Username string `json:"username"`
Role string `json:"role,omitempty"` Role string `json:"role,omitempty"`
Language string `json:"language,omitempty"` Language string `json:"language,omitempty"`
Plan struct { Plan *userPlanResponse `json:"plan,omitempty"`
Id int `json:"id"` Notification *userNotificationSettingsResponse `json:"notification,omitempty"`
Name string `json:"name"` Subscriptions []*userSubscriptionResponse `json:"subscriptions,omitempty"`
} `json:"plan,omitempty"`
Notification struct {
Sound string `json:"sound"`
MinPriority string `json:"min_priority"`
DeleteAfter int `json:"delete_after"`
} `json:"notification,omitempty"`
Subscriptions []*userSubscriptionResponse `json:"subscriptions,omitempty"`
} }
func (s *Server) handleUserAccount(w http.ResponseWriter, r *http.Request, v *visitor) error { func (s *Server) handleUserAccount(w http.ResponseWriter, r *http.Request, v *visitor) error {
w.Header().Set("Content-Type", "text/json") w.Header().Set("Content-Type", "text/json")
w.Header().Set("Access-Control-Allow-Origin", "*") // FIXME remove this w.Header().Set("Access-Control-Allow-Origin", "*") // FIXME remove this
var response *userAccountResponse response := &userAccountResponse{}
if v.user != nil { if v.user != nil {
response = &userAccountResponse{ response.Username = v.user.Name
Username: v.user.Name, response.Role = string(v.user.Role)
Role: string(v.user.Role), response.Language = v.user.Language
Language: "en_US", response.Notification = &userNotificationSettingsResponse{
Sound: "dadum",
} }
} else { } else {
response = &userAccountResponse{ response = &userAccountResponse{
Username: "anonymous", Username: auth.Everyone,
Role: string(auth.RoleAnonymous),
} }
} }
if err := json.NewEncoder(w).Encode(response); err != nil { if err := json.NewEncoder(w).Encode(response); err != nil {
@ -1453,15 +1460,15 @@ func (s *Server) transformMatrixJSON(next handleFunc) handleFunc {
} }
} }
func (s *Server) authWrite(next handleFunc) handleFunc { func (s *Server) authorizeTopicWrite(next handleFunc) handleFunc {
return s.withAuth(next, auth.PermissionWrite) return s.autorizeTopic(next, auth.PermissionWrite)
} }
func (s *Server) authRead(next handleFunc) handleFunc { func (s *Server) authorizeTopicRead(next handleFunc) handleFunc {
return s.withAuth(next, auth.PermissionRead) return s.autorizeTopic(next, auth.PermissionRead)
} }
func (s *Server) withAuth(next handleFunc, perm auth.Permission) handleFunc { func (s *Server) autorizeTopic(next handleFunc, perm auth.Permission) handleFunc {
return func(w http.ResponseWriter, r *http.Request, v *visitor) error { return func(w http.ResponseWriter, r *http.Request, v *visitor) error {
if s.auth == nil { if s.auth == nil {
return next(w, r, v) return next(w, r, v)
@ -1508,20 +1515,51 @@ func (s *Server) visitor(r *http.Request) (v *visitor, err error) {
visitorID := fmt.Sprintf("ip:%s", ip.String()) visitorID := fmt.Sprintf("ip:%s", ip.String())
var user *auth.User // may stay nil if no auth header! var user *auth.User // may stay nil if no auth header!
username, password, ok := extractUserPass(r) if user, err = s.authenticate(r); err != nil {
if ok { log.Debug("authentication failed: %s", err.Error())
if user, err = s.auth.Authenticate(username, password); err != nil { err = errHTTPUnauthorized // Always return visitor, even when error occurs!
log.Debug("authentication failed: %s", err.Error()) }
err = errHTTPUnauthorized // Always return visitor, even when error occurs! if user != nil {
} else { visitorID = fmt.Sprintf("user:%s", user.Name)
visitorID = fmt.Sprintf("user:%s", user.Name)
}
} }
v = s.visitorFromID(visitorID, ip, user) v = s.visitorFromID(visitorID, ip, user)
v.user = user // Update user -- FIXME this is ugly, do "newVisitorFromUser" instead v.user = user // Update user -- FIXME this is ugly, do "newVisitorFromUser" instead
return v, err // Always return visitor, even when error occurs! return v, err // Always return visitor, even when error occurs!
} }
func (s *Server) authenticate(r *http.Request) (user *auth.User, err error) {
value := r.Header.Get("Authorization")
queryParam := readQueryParam(r, "authorization", "auth")
if queryParam != "" {
a, err := base64.RawURLEncoding.DecodeString(queryParam)
if err != nil {
return nil, err
}
value = string(a)
}
if value == "" {
return nil, nil
}
if strings.HasPrefix(value, "Bearer") {
return s.authenticateBearerAuth(value)
}
return s.authenticateBasicAuth(r, value)
}
func (s *Server) authenticateBasicAuth(r *http.Request, value string) (user *auth.User, err error) {
r.Header.Set("Authorization", value)
username, password, ok := r.BasicAuth()
if !ok {
return nil, errors.New("invalid basic auth")
}
return s.auth.Authenticate(username, password)
}
func (s *Server) authenticateBearerAuth(value string) (user *auth.User, err error) {
token := strings.TrimSpace(strings.TrimPrefix(value, "Bearer"))
return s.auth.AuthenticateToken(token)
}
func (s *Server) visitorFromID(visitorID string, ip netip.Addr, user *auth.User) *visitor { func (s *Server) visitorFromID(visitorID string, ip netip.Addr, user *auth.User) *visitor {
s.mu.Lock() s.mu.Lock()
defer s.mu.Unlock() defer s.mu.Unlock()

View File

@ -18,7 +18,7 @@ type testAuther struct {
Allow bool Allow bool
} }
func (t testAuther) Authenticate(_, _ string) (*auth.User, error) { func (t testAuther) AuthenticateUser(_, _ string) (*auth.User, error) {
return nil, errors.New("not used") return nil, errors.New("not used")
} }

View File

@ -1,11 +1,13 @@
import { import {
fetchLinesIterator, fetchLinesIterator,
maybeWithBasicAuth, maybeWithBasicAuth, maybeWithBearerAuth,
topicShortUrl, topicShortUrl,
topicUrl, topicUrl,
topicUrlAuth, topicUrlAuth,
topicUrlJsonPoll, topicUrlJsonPoll,
topicUrlJsonPollWithSince, userAuthUrl, topicUrlJsonPollWithSince,
userAccountUrl,
userAuthUrl,
userStatsUrl userStatsUrl
} from "./utils"; } from "./utils";
import userManager from "./UserManager"; import userManager from "./UserManager";
@ -144,6 +146,20 @@ class Api {
console.log(`[Api] Stats`, stats); console.log(`[Api] Stats`, stats);
return stats; return stats;
} }
async userAccount(baseUrl, token) {
const url = userAccountUrl(baseUrl);
console.log(`[Api] Fetching user account ${url}`);
const response = await fetch(url, {
headers: maybeWithBearerAuth({}, token)
});
if (response.status !== 200) {
throw new Error(`Unexpected server response ${response.status}`);
}
const account = await response.json();
console.log(`[Api] Account`, account);
return account;
}
} }
const api = new Api(); const api = new Api();

View File

@ -20,6 +20,7 @@ export const topicUrlAuth = (baseUrl, topic) => `${topicUrl(baseUrl, topic)}/aut
export const topicShortUrl = (baseUrl, topic) => shortUrl(topicUrl(baseUrl, topic)); export const topicShortUrl = (baseUrl, topic) => shortUrl(topicUrl(baseUrl, topic));
export const userStatsUrl = (baseUrl) => `${baseUrl}/user/stats`; export const userStatsUrl = (baseUrl) => `${baseUrl}/user/stats`;
export const userAuthUrl = (baseUrl) => `${baseUrl}/user/auth`; export const userAuthUrl = (baseUrl) => `${baseUrl}/user/auth`;
export const userAccountUrl = (baseUrl) => `${baseUrl}/user/account`;
export const shortUrl = (url) => url.replaceAll(/https?:\/\//g, ""); export const shortUrl = (url) => url.replaceAll(/https?:\/\//g, "");
export const expandUrl = (url) => [`https://${url}`, `http://${url}`]; export const expandUrl = (url) => [`https://${url}`, `http://${url}`];
export const expandSecureUrl = (url) => `https://${url}`; export const expandSecureUrl = (url) => `https://${url}`;
@ -95,7 +96,6 @@ export const unmatchedTags = (tags) => {
else return tags.filter(tag => !(tag in emojis)); else return tags.filter(tag => !(tag in emojis));
} }
export const maybeWithBasicAuth = (headers, user) => { export const maybeWithBasicAuth = (headers, user) => {
if (user) { if (user) {
headers['Authorization'] = `Basic ${encodeBase64(`${user.username}:${user.password}`)}`; headers['Authorization'] = `Basic ${encodeBase64(`${user.username}:${user.password}`)}`;
@ -103,6 +103,13 @@ export const maybeWithBasicAuth = (headers, user) => {
return headers; return headers;
} }
export const maybeWithBearerAuth = (headers, token) => {
if (token) {
headers['Authorization'] = `Bearer ${token}`;
}
return headers;
}
export const basicAuth = (username, password) => { export const basicAuth = (username, password) => {
return `Basic ${encodeBase64(`${username}:${password}`)}`; return `Basic ${encodeBase64(`${username}:${password}`)}`;
} }

View File

@ -25,6 +25,10 @@ import "./i18n"; // Translations!
import {Backdrop, CircularProgress} from "@mui/material"; import {Backdrop, CircularProgress} from "@mui/material";
import Home from "./Home"; import Home from "./Home";
import Login from "./Login"; import Login from "./Login";
import i18n from "i18next";
import api from "../app/Api";
import prefs from "../app/Prefs";
import session from "../app/Session";
// TODO races when two tabs are open // TODO races when two tabs are open
// TODO investigate service workers // TODO investigate service workers
@ -81,6 +85,21 @@ const Layout = () => {
useBackgroundProcesses(); useBackgroundProcesses();
useEffect(() => updateTitle(newNotificationsCount), [newNotificationsCount]); useEffect(() => updateTitle(newNotificationsCount), [newNotificationsCount]);
useEffect(() => {
(async () => {
const account = await api.userAccount("http://localhost:2586", session.token());
if (account) {
if (account.language) {
await i18n.changeLanguage(account.language);
}
if (account.notification) {
if (account.notification.sound) {
await prefs.setSound(account.notification.sound);
}
}
}
})();
});
return ( return (
<Box sx={{display: 'flex'}}> <Box sx={{display: 'flex'}}>
<CssBaseline/> <CssBaseline/>

View File

@ -32,7 +32,7 @@ const Login = () => {
email: data.get('email'), email: data.get('email'),
password: data.get('password'), password: data.get('password'),
}); });
const user ={ const user = {
username: data.get('email'), username: data.get('email'),
password: data.get('password'), password: data.get('password'),
} }