Associate file downloads with uploader

pull/600/head
binwiederhier 2023-01-29 15:11:26 -05:00
parent 40ba143a63
commit f4c54a1643
4 changed files with 134 additions and 85 deletions

View File

@ -16,6 +16,7 @@ import (
var ( var (
errUnexpectedMessageType = errors.New("unexpected message type") errUnexpectedMessageType = errors.New("unexpected message type")
errMessageNotFound = errors.New("message not found")
) )
// Messages cache // Messages cache
@ -60,6 +61,11 @@ const (
deleteMessageQuery = `DELETE FROM messages WHERE mid = ?` deleteMessageQuery = `DELETE FROM messages WHERE mid = ?`
updateMessagesForTopicExpiryQuery = `UPDATE messages SET expires = ? WHERE topic = ?` updateMessagesForTopicExpiryQuery = `UPDATE messages SET expires = ? WHERE topic = ?`
selectRowIDFromMessageID = `SELECT id FROM messages WHERE mid = ?` // Do not include topic, see #336 and TestServer_PollSinceID_MultipleTopics selectRowIDFromMessageID = `SELECT id FROM messages WHERE mid = ?` // Do not include topic, see #336 and TestServer_PollSinceID_MultipleTopics
selectMessagesByIDQuery = `
SELECT mid, time, expires, topic, message, title, priority, tags, click, icon, actions, attachment_name, attachment_type, attachment_size, attachment_expires, attachment_url, sender, user, encoding
FROM messages
WHERE mid = ?
`
selectMessagesSinceTimeQuery = ` selectMessagesSinceTimeQuery = `
SELECT mid, time, expires, topic, message, title, priority, tags, click, icon, actions, attachment_name, attachment_type, attachment_size, attachment_expires, attachment_url, sender, user, encoding SELECT mid, time, expires, topic, message, title, priority, tags, click, icon, actions, attachment_name, attachment_type, attachment_size, attachment_expires, attachment_url, sender, user, encoding
FROM messages FROM messages
@ -448,6 +454,18 @@ func (c *messageCache) MessagesExpired() ([]string, error) {
return ids, nil return ids, nil
} }
func (c *messageCache) Message(id string) (*message, error) {
rows, err := c.db.Query(selectMessagesByIDQuery, id)
if err != nil {
return nil, err
}
if !rows.Next() {
return nil, errMessageNotFound
}
defer rows.Close()
return readMessage(rows)
}
func (c *messageCache) MarkPublished(m *message) error { func (c *messageCache) MarkPublished(m *message) error {
_, err := c.db.Exec(updateMessagePublishedQuery, m.ID) _, err := c.db.Exec(updateMessagePublishedQuery, m.ID)
return err return err
@ -600,6 +618,19 @@ func readMessages(rows *sql.Rows) ([]*message, error) {
defer rows.Close() defer rows.Close()
messages := make([]*message, 0) messages := make([]*message, 0)
for rows.Next() { for rows.Next() {
m, err := readMessage(rows)
if err != nil {
return nil, err
}
messages = append(messages, m)
}
if err := rows.Err(); err != nil {
return nil, err
}
return messages, nil
}
func readMessage(rows *sql.Rows) (*message, error) {
var timestamp, expires, attachmentSize, attachmentExpires int64 var timestamp, expires, attachmentSize, attachmentExpires int64
var priority int var priority int
var id, topic, msg, title, tagsStr, click, icon, actionsStr, attachmentName, attachmentType, attachmentURL, sender, user, encoding string var id, topic, msg, title, tagsStr, click, icon, actionsStr, attachmentName, attachmentType, attachmentURL, sender, user, encoding string
@ -651,7 +682,7 @@ func readMessages(rows *sql.Rows) ([]*message, error) {
URL: attachmentURL, URL: attachmentURL,
} }
} }
messages = append(messages, &message{ return &message{
ID: id, ID: id,
Time: timestamp, Time: timestamp,
Expires: expires, Expires: expires,
@ -668,12 +699,7 @@ func readMessages(rows *sql.Rows) ([]*message, error) {
Sender: senderIP, // Must parse assuming database must be correct Sender: senderIP, // Must parse assuming database must be correct
User: user, User: user,
Encoding: encoding, Encoding: encoding,
}) }, nil
}
if err := rows.Err(); err != nil {
return nil, err
}
return messages, nil
} }
func (c *messageCache) Close() error { func (c *messageCache) Close() error {

View File

@ -37,7 +37,8 @@ import (
/* /*
- HIGH Rate limiting: Sensitive endpoints (account/login/change-password/...) - HIGH Rate limiting: Sensitive endpoints (account/login/change-password/...)
- HIGH Stripe payment methods - HIGH Docs
- Large uploads for higher tiers (nginx config!)
- MEDIUM: Test new token endpoints & never-expiring token - MEDIUM: Test new token endpoints & never-expiring token
- MEDIUM: Make sure account endpoints make sense for admins - MEDIUM: Make sure account endpoints make sense for admins
- MEDIUM: Reservation (UI): Show "This topic is reserved" error message when trying to reserve a reserved topic (Thorben) - MEDIUM: Reservation (UI): Show "This topic is reserved" error message when trying to reserve a reserved topic (Thorben)
@ -498,6 +499,9 @@ func (s *Server) handleDocs(w http.ResponseWriter, r *http.Request, _ *visitor)
return nil return nil
} }
// handleFile processes the download of attachment files. The method handles GET and HEAD requests against a file.
// Before streaming the file to a client, it locates uploader (m.Sender or m.User) in the message cache, so it
// can associate the download bandwidth with the uploader.
func (s *Server) handleFile(w http.ResponseWriter, r *http.Request, v *visitor) error { func (s *Server) handleFile(w http.ResponseWriter, r *http.Request, v *visitor) error {
if s.config.AttachmentCacheDir == "" { if s.config.AttachmentCacheDir == "" {
return errHTTPInternalError return errHTTPInternalError
@ -512,14 +516,35 @@ func (s *Server) handleFile(w http.ResponseWriter, r *http.Request, v *visitor)
if err != nil { if err != nil {
return errHTTPNotFound return errHTTPNotFound
} }
if r.Method == http.MethodGet { w.Header().Set("Access-Control-Allow-Origin", s.config.AccessControlAllowOrigin) // CORS, allow cross-origin requests
if !v.BandwidthAllowed(stat.Size()) { w.Header().Set("Content-Length", fmt.Sprintf("%d", stat.Size()))
if r.Method == http.MethodHead {
return nil
}
// Find message in database, and associate bandwidth to the uploader user
// This is an easy way to
// - avoid abuse (e.g. 1 uploader, 1k downloaders)
// - and also uses the higher bandwidth limits of a paying user
m, err := s.messageCache.Message(messageID)
if err == errMessageNotFound {
return errHTTPNotFound
} else if err != nil {
return err
}
bandwidthVisitor := v
if s.userManager != nil && m.User != "" {
u, err := s.userManager.UserByID(m.User)
if err != nil {
return err
}
bandwidthVisitor = s.visitor(v.IP(), u)
} else if m.Sender != netip.IPv4Unspecified() {
bandwidthVisitor = s.visitor(m.Sender, nil)
}
if !bandwidthVisitor.BandwidthAllowed(stat.Size()) {
return errHTTPTooManyRequestsLimitAttachmentBandwidth return errHTTPTooManyRequestsLimitAttachmentBandwidth
} }
} // Actually send file
w.Header().Set("Content-Length", fmt.Sprintf("%d", stat.Size()))
w.Header().Set("Access-Control-Allow-Origin", s.config.AccessControlAllowOrigin) // CORS, allow cross-origin requests
if r.Method == http.MethodGet {
f, err := os.Open(file) f, err := os.Open(file)
if err != nil { if err != nil {
return err return err
@ -527,8 +552,6 @@ func (s *Server) handleFile(w http.ResponseWriter, r *http.Request, v *visitor)
defer f.Close() defer f.Close()
_, err = io.Copy(util.NewContentTypeWriter(w, r.URL.Path), f) _, err = io.Copy(util.NewContentTypeWriter(w, r.URL.Path), f)
return err return err
}
return nil
} }
func (s *Server) handleMatrixDiscovery(w http.ResponseWriter) error { func (s *Server) handleMatrixDiscovery(w http.ResponseWriter) error {

View File

@ -80,6 +80,8 @@
# - auth-file is the SQLite user/access database; it is created automatically if it doesn't already exist # - auth-file is the SQLite user/access database; it is created automatically if it doesn't already exist
# - auth-default-access defines the default/fallback access if no access control entry is found; it can be # - auth-default-access defines the default/fallback access if no access control entry is found; it can be
# set to "read-write" (default), "read-only", "write-only" or "deny-all". # set to "read-write" (default), "read-only", "write-only" or "deny-all".
# - auth-startup-queries allows you to run commands when the database is initialized, e.g. to enable
# WAL mode. This is similar to cache-startup-queries. See above for details.
# #
# Debian/RPM package users: # Debian/RPM package users:
# Use /var/lib/ntfy/user.db as user database to avoid permission issues. The package # Use /var/lib/ntfy/user.db as user database to avoid permission issues. The package
@ -91,6 +93,7 @@
# #
# auth-file: <filename> # auth-file: <filename>
# auth-default-access: "read-write" # auth-default-access: "read-write"
# auth-startup-queries:
# If set, the X-Forwarded-For header is used to determine the visitor IP address # If set, the X-Forwarded-For header is used to determine the visitor IP address
# instead of the remote address of the connection. # instead of the remote address of the connection.

View File

@ -1543,6 +1543,7 @@ func TestServer_PublishAttachmentWithTierBasedBandwidthLimit(t *testing.T) {
content := util.RandomString(5000) // > 4096 content := util.RandomString(5000) // > 4096
c := newTestConfigWithAuthFile(t) c := newTestConfigWithAuthFile(t)
c.VisitorAttachmentDailyBandwidthLimit = 1000 // Much lower than tier bandwidth!
s := newTestServer(t, c) s := newTestServer(t, c)
// Create tier with certain limits // Create tier with certain limits
@ -1566,16 +1567,12 @@ func TestServer_PublishAttachmentWithTierBasedBandwidthLimit(t *testing.T) {
msg := toMessage(t, rr.Body.String()) msg := toMessage(t, rr.Body.String())
// Retrieve it (first time succeeds) // Retrieve it (first time succeeds)
rr = request(t, s, "GET", "/file/"+msg.ID, content, map[string]string{ rr = request(t, s, "GET", "/file/"+msg.ID, content, nil) // File downloads do not send auth headers!!
"Authorization": util.BasicAuth("phil", "phil"),
})
require.Equal(t, 200, rr.Code) require.Equal(t, 200, rr.Code)
require.Equal(t, content, rr.Body.String()) require.Equal(t, content, rr.Body.String())
// Retrieve it AGAIN (fails, due to bandwidth limit) // Retrieve it AGAIN (fails, due to bandwidth limit)
rr = request(t, s, "GET", "/file/"+msg.ID, content, map[string]string{ rr = request(t, s, "GET", "/file/"+msg.ID, content, nil)
"Authorization": util.BasicAuth("phil", "phil"),
})
require.Equal(t, 429, rr.Code) require.Equal(t, 429, rr.Code)
} }