Associate file downloads with uploader
parent
40ba143a63
commit
f4c54a1643
|
@ -16,6 +16,7 @@ import (
|
||||||
|
|
||||||
var (
|
var (
|
||||||
errUnexpectedMessageType = errors.New("unexpected message type")
|
errUnexpectedMessageType = errors.New("unexpected message type")
|
||||||
|
errMessageNotFound = errors.New("message not found")
|
||||||
)
|
)
|
||||||
|
|
||||||
// Messages cache
|
// Messages cache
|
||||||
|
@ -60,7 +61,12 @@ const (
|
||||||
deleteMessageQuery = `DELETE FROM messages WHERE mid = ?`
|
deleteMessageQuery = `DELETE FROM messages WHERE mid = ?`
|
||||||
updateMessagesForTopicExpiryQuery = `UPDATE messages SET expires = ? WHERE topic = ?`
|
updateMessagesForTopicExpiryQuery = `UPDATE messages SET expires = ? WHERE topic = ?`
|
||||||
selectRowIDFromMessageID = `SELECT id FROM messages WHERE mid = ?` // Do not include topic, see #336 and TestServer_PollSinceID_MultipleTopics
|
selectRowIDFromMessageID = `SELECT id FROM messages WHERE mid = ?` // Do not include topic, see #336 and TestServer_PollSinceID_MultipleTopics
|
||||||
selectMessagesSinceTimeQuery = `
|
selectMessagesByIDQuery = `
|
||||||
|
SELECT mid, time, expires, topic, message, title, priority, tags, click, icon, actions, attachment_name, attachment_type, attachment_size, attachment_expires, attachment_url, sender, user, encoding
|
||||||
|
FROM messages
|
||||||
|
WHERE mid = ?
|
||||||
|
`
|
||||||
|
selectMessagesSinceTimeQuery = `
|
||||||
SELECT mid, time, expires, topic, message, title, priority, tags, click, icon, actions, attachment_name, attachment_type, attachment_size, attachment_expires, attachment_url, sender, user, encoding
|
SELECT mid, time, expires, topic, message, title, priority, tags, click, icon, actions, attachment_name, attachment_type, attachment_size, attachment_expires, attachment_url, sender, user, encoding
|
||||||
FROM messages
|
FROM messages
|
||||||
WHERE topic = ? AND time >= ? AND published = 1
|
WHERE topic = ? AND time >= ? AND published = 1
|
||||||
|
@ -448,6 +454,18 @@ func (c *messageCache) MessagesExpired() ([]string, error) {
|
||||||
return ids, nil
|
return ids, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (c *messageCache) Message(id string) (*message, error) {
|
||||||
|
rows, err := c.db.Query(selectMessagesByIDQuery, id)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
if !rows.Next() {
|
||||||
|
return nil, errMessageNotFound
|
||||||
|
}
|
||||||
|
defer rows.Close()
|
||||||
|
return readMessage(rows)
|
||||||
|
}
|
||||||
|
|
||||||
func (c *messageCache) MarkPublished(m *message) error {
|
func (c *messageCache) MarkPublished(m *message) error {
|
||||||
_, err := c.db.Exec(updateMessagePublishedQuery, m.ID)
|
_, err := c.db.Exec(updateMessagePublishedQuery, m.ID)
|
||||||
return err
|
return err
|
||||||
|
@ -600,75 +618,11 @@ func readMessages(rows *sql.Rows) ([]*message, error) {
|
||||||
defer rows.Close()
|
defer rows.Close()
|
||||||
messages := make([]*message, 0)
|
messages := make([]*message, 0)
|
||||||
for rows.Next() {
|
for rows.Next() {
|
||||||
var timestamp, expires, attachmentSize, attachmentExpires int64
|
m, err := readMessage(rows)
|
||||||
var priority int
|
|
||||||
var id, topic, msg, title, tagsStr, click, icon, actionsStr, attachmentName, attachmentType, attachmentURL, sender, user, encoding string
|
|
||||||
err := rows.Scan(
|
|
||||||
&id,
|
|
||||||
×tamp,
|
|
||||||
&expires,
|
|
||||||
&topic,
|
|
||||||
&msg,
|
|
||||||
&title,
|
|
||||||
&priority,
|
|
||||||
&tagsStr,
|
|
||||||
&click,
|
|
||||||
&icon,
|
|
||||||
&actionsStr,
|
|
||||||
&attachmentName,
|
|
||||||
&attachmentType,
|
|
||||||
&attachmentSize,
|
|
||||||
&attachmentExpires,
|
|
||||||
&attachmentURL,
|
|
||||||
&sender,
|
|
||||||
&user,
|
|
||||||
&encoding,
|
|
||||||
)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
var tags []string
|
messages = append(messages, m)
|
||||||
if tagsStr != "" {
|
|
||||||
tags = strings.Split(tagsStr, ",")
|
|
||||||
}
|
|
||||||
var actions []*action
|
|
||||||
if actionsStr != "" {
|
|
||||||
if err := json.Unmarshal([]byte(actionsStr), &actions); err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
}
|
|
||||||
senderIP, err := netip.ParseAddr(sender)
|
|
||||||
if err != nil {
|
|
||||||
senderIP = netip.Addr{} // if no IP stored in database, return invalid address
|
|
||||||
}
|
|
||||||
var att *attachment
|
|
||||||
if attachmentName != "" && attachmentURL != "" {
|
|
||||||
att = &attachment{
|
|
||||||
Name: attachmentName,
|
|
||||||
Type: attachmentType,
|
|
||||||
Size: attachmentSize,
|
|
||||||
Expires: attachmentExpires,
|
|
||||||
URL: attachmentURL,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
messages = append(messages, &message{
|
|
||||||
ID: id,
|
|
||||||
Time: timestamp,
|
|
||||||
Expires: expires,
|
|
||||||
Event: messageEvent,
|
|
||||||
Topic: topic,
|
|
||||||
Message: msg,
|
|
||||||
Title: title,
|
|
||||||
Priority: priority,
|
|
||||||
Tags: tags,
|
|
||||||
Click: click,
|
|
||||||
Icon: icon,
|
|
||||||
Actions: actions,
|
|
||||||
Attachment: att,
|
|
||||||
Sender: senderIP, // Must parse assuming database must be correct
|
|
||||||
User: user,
|
|
||||||
Encoding: encoding,
|
|
||||||
})
|
|
||||||
}
|
}
|
||||||
if err := rows.Err(); err != nil {
|
if err := rows.Err(); err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
|
@ -676,6 +630,78 @@ func readMessages(rows *sql.Rows) ([]*message, error) {
|
||||||
return messages, nil
|
return messages, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func readMessage(rows *sql.Rows) (*message, error) {
|
||||||
|
var timestamp, expires, attachmentSize, attachmentExpires int64
|
||||||
|
var priority int
|
||||||
|
var id, topic, msg, title, tagsStr, click, icon, actionsStr, attachmentName, attachmentType, attachmentURL, sender, user, encoding string
|
||||||
|
err := rows.Scan(
|
||||||
|
&id,
|
||||||
|
×tamp,
|
||||||
|
&expires,
|
||||||
|
&topic,
|
||||||
|
&msg,
|
||||||
|
&title,
|
||||||
|
&priority,
|
||||||
|
&tagsStr,
|
||||||
|
&click,
|
||||||
|
&icon,
|
||||||
|
&actionsStr,
|
||||||
|
&attachmentName,
|
||||||
|
&attachmentType,
|
||||||
|
&attachmentSize,
|
||||||
|
&attachmentExpires,
|
||||||
|
&attachmentURL,
|
||||||
|
&sender,
|
||||||
|
&user,
|
||||||
|
&encoding,
|
||||||
|
)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
var tags []string
|
||||||
|
if tagsStr != "" {
|
||||||
|
tags = strings.Split(tagsStr, ",")
|
||||||
|
}
|
||||||
|
var actions []*action
|
||||||
|
if actionsStr != "" {
|
||||||
|
if err := json.Unmarshal([]byte(actionsStr), &actions); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
senderIP, err := netip.ParseAddr(sender)
|
||||||
|
if err != nil {
|
||||||
|
senderIP = netip.Addr{} // if no IP stored in database, return invalid address
|
||||||
|
}
|
||||||
|
var att *attachment
|
||||||
|
if attachmentName != "" && attachmentURL != "" {
|
||||||
|
att = &attachment{
|
||||||
|
Name: attachmentName,
|
||||||
|
Type: attachmentType,
|
||||||
|
Size: attachmentSize,
|
||||||
|
Expires: attachmentExpires,
|
||||||
|
URL: attachmentURL,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return &message{
|
||||||
|
ID: id,
|
||||||
|
Time: timestamp,
|
||||||
|
Expires: expires,
|
||||||
|
Event: messageEvent,
|
||||||
|
Topic: topic,
|
||||||
|
Message: msg,
|
||||||
|
Title: title,
|
||||||
|
Priority: priority,
|
||||||
|
Tags: tags,
|
||||||
|
Click: click,
|
||||||
|
Icon: icon,
|
||||||
|
Actions: actions,
|
||||||
|
Attachment: att,
|
||||||
|
Sender: senderIP, // Must parse assuming database must be correct
|
||||||
|
User: user,
|
||||||
|
Encoding: encoding,
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|
||||||
func (c *messageCache) Close() error {
|
func (c *messageCache) Close() error {
|
||||||
return c.db.Close()
|
return c.db.Close()
|
||||||
}
|
}
|
||||||
|
|
|
@ -37,7 +37,8 @@ import (
|
||||||
/*
|
/*
|
||||||
|
|
||||||
- HIGH Rate limiting: Sensitive endpoints (account/login/change-password/...)
|
- HIGH Rate limiting: Sensitive endpoints (account/login/change-password/...)
|
||||||
- HIGH Stripe payment methods
|
- HIGH Docs
|
||||||
|
- Large uploads for higher tiers (nginx config!)
|
||||||
- MEDIUM: Test new token endpoints & never-expiring token
|
- MEDIUM: Test new token endpoints & never-expiring token
|
||||||
- MEDIUM: Make sure account endpoints make sense for admins
|
- MEDIUM: Make sure account endpoints make sense for admins
|
||||||
- MEDIUM: Reservation (UI): Show "This topic is reserved" error message when trying to reserve a reserved topic (Thorben)
|
- MEDIUM: Reservation (UI): Show "This topic is reserved" error message when trying to reserve a reserved topic (Thorben)
|
||||||
|
@ -498,6 +499,9 @@ func (s *Server) handleDocs(w http.ResponseWriter, r *http.Request, _ *visitor)
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// handleFile processes the download of attachment files. The method handles GET and HEAD requests against a file.
|
||||||
|
// Before streaming the file to a client, it locates uploader (m.Sender or m.User) in the message cache, so it
|
||||||
|
// can associate the download bandwidth with the uploader.
|
||||||
func (s *Server) handleFile(w http.ResponseWriter, r *http.Request, v *visitor) error {
|
func (s *Server) handleFile(w http.ResponseWriter, r *http.Request, v *visitor) error {
|
||||||
if s.config.AttachmentCacheDir == "" {
|
if s.config.AttachmentCacheDir == "" {
|
||||||
return errHTTPInternalError
|
return errHTTPInternalError
|
||||||
|
@ -512,23 +516,42 @@ func (s *Server) handleFile(w http.ResponseWriter, r *http.Request, v *visitor)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return errHTTPNotFound
|
return errHTTPNotFound
|
||||||
}
|
}
|
||||||
if r.Method == http.MethodGet {
|
|
||||||
if !v.BandwidthAllowed(stat.Size()) {
|
|
||||||
return errHTTPTooManyRequestsLimitAttachmentBandwidth
|
|
||||||
}
|
|
||||||
}
|
|
||||||
w.Header().Set("Content-Length", fmt.Sprintf("%d", stat.Size()))
|
|
||||||
w.Header().Set("Access-Control-Allow-Origin", s.config.AccessControlAllowOrigin) // CORS, allow cross-origin requests
|
w.Header().Set("Access-Control-Allow-Origin", s.config.AccessControlAllowOrigin) // CORS, allow cross-origin requests
|
||||||
if r.Method == http.MethodGet {
|
w.Header().Set("Content-Length", fmt.Sprintf("%d", stat.Size()))
|
||||||
f, err := os.Open(file)
|
if r.Method == http.MethodHead {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
// Find message in database, and associate bandwidth to the uploader user
|
||||||
|
// This is an easy way to
|
||||||
|
// - avoid abuse (e.g. 1 uploader, 1k downloaders)
|
||||||
|
// - and also uses the higher bandwidth limits of a paying user
|
||||||
|
m, err := s.messageCache.Message(messageID)
|
||||||
|
if err == errMessageNotFound {
|
||||||
|
return errHTTPNotFound
|
||||||
|
} else if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
bandwidthVisitor := v
|
||||||
|
if s.userManager != nil && m.User != "" {
|
||||||
|
u, err := s.userManager.UserByID(m.User)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
defer f.Close()
|
bandwidthVisitor = s.visitor(v.IP(), u)
|
||||||
_, err = io.Copy(util.NewContentTypeWriter(w, r.URL.Path), f)
|
} else if m.Sender != netip.IPv4Unspecified() {
|
||||||
|
bandwidthVisitor = s.visitor(m.Sender, nil)
|
||||||
|
}
|
||||||
|
if !bandwidthVisitor.BandwidthAllowed(stat.Size()) {
|
||||||
|
return errHTTPTooManyRequestsLimitAttachmentBandwidth
|
||||||
|
}
|
||||||
|
// Actually send file
|
||||||
|
f, err := os.Open(file)
|
||||||
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
return nil
|
defer f.Close()
|
||||||
|
_, err = io.Copy(util.NewContentTypeWriter(w, r.URL.Path), f)
|
||||||
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *Server) handleMatrixDiscovery(w http.ResponseWriter) error {
|
func (s *Server) handleMatrixDiscovery(w http.ResponseWriter) error {
|
||||||
|
|
|
@ -80,6 +80,8 @@
|
||||||
# - auth-file is the SQLite user/access database; it is created automatically if it doesn't already exist
|
# - auth-file is the SQLite user/access database; it is created automatically if it doesn't already exist
|
||||||
# - auth-default-access defines the default/fallback access if no access control entry is found; it can be
|
# - auth-default-access defines the default/fallback access if no access control entry is found; it can be
|
||||||
# set to "read-write" (default), "read-only", "write-only" or "deny-all".
|
# set to "read-write" (default), "read-only", "write-only" or "deny-all".
|
||||||
|
# - auth-startup-queries allows you to run commands when the database is initialized, e.g. to enable
|
||||||
|
# WAL mode. This is similar to cache-startup-queries. See above for details.
|
||||||
#
|
#
|
||||||
# Debian/RPM package users:
|
# Debian/RPM package users:
|
||||||
# Use /var/lib/ntfy/user.db as user database to avoid permission issues. The package
|
# Use /var/lib/ntfy/user.db as user database to avoid permission issues. The package
|
||||||
|
@ -91,6 +93,7 @@
|
||||||
#
|
#
|
||||||
# auth-file: <filename>
|
# auth-file: <filename>
|
||||||
# auth-default-access: "read-write"
|
# auth-default-access: "read-write"
|
||||||
|
# auth-startup-queries:
|
||||||
|
|
||||||
# If set, the X-Forwarded-For header is used to determine the visitor IP address
|
# If set, the X-Forwarded-For header is used to determine the visitor IP address
|
||||||
# instead of the remote address of the connection.
|
# instead of the remote address of the connection.
|
||||||
|
|
|
@ -1543,6 +1543,7 @@ func TestServer_PublishAttachmentWithTierBasedBandwidthLimit(t *testing.T) {
|
||||||
content := util.RandomString(5000) // > 4096
|
content := util.RandomString(5000) // > 4096
|
||||||
|
|
||||||
c := newTestConfigWithAuthFile(t)
|
c := newTestConfigWithAuthFile(t)
|
||||||
|
c.VisitorAttachmentDailyBandwidthLimit = 1000 // Much lower than tier bandwidth!
|
||||||
s := newTestServer(t, c)
|
s := newTestServer(t, c)
|
||||||
|
|
||||||
// Create tier with certain limits
|
// Create tier with certain limits
|
||||||
|
@ -1566,16 +1567,12 @@ func TestServer_PublishAttachmentWithTierBasedBandwidthLimit(t *testing.T) {
|
||||||
msg := toMessage(t, rr.Body.String())
|
msg := toMessage(t, rr.Body.String())
|
||||||
|
|
||||||
// Retrieve it (first time succeeds)
|
// Retrieve it (first time succeeds)
|
||||||
rr = request(t, s, "GET", "/file/"+msg.ID, content, map[string]string{
|
rr = request(t, s, "GET", "/file/"+msg.ID, content, nil) // File downloads do not send auth headers!!
|
||||||
"Authorization": util.BasicAuth("phil", "phil"),
|
|
||||||
})
|
|
||||||
require.Equal(t, 200, rr.Code)
|
require.Equal(t, 200, rr.Code)
|
||||||
require.Equal(t, content, rr.Body.String())
|
require.Equal(t, content, rr.Body.String())
|
||||||
|
|
||||||
// Retrieve it AGAIN (fails, due to bandwidth limit)
|
// Retrieve it AGAIN (fails, due to bandwidth limit)
|
||||||
rr = request(t, s, "GET", "/file/"+msg.ID, content, map[string]string{
|
rr = request(t, s, "GET", "/file/"+msg.ID, content, nil)
|
||||||
"Authorization": util.BasicAuth("phil", "phil"),
|
|
||||||
})
|
|
||||||
require.Equal(t, 429, rr.Code)
|
require.Equal(t, 429, rr.Code)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue