Associate file downloads with uploader

pull/600/head
binwiederhier 2023-01-29 15:11:26 -05:00
parent 40ba143a63
commit f4c54a1643
4 changed files with 134 additions and 85 deletions

View File

@ -16,6 +16,7 @@ import (
var (
errUnexpectedMessageType = errors.New("unexpected message type")
errMessageNotFound = errors.New("message not found")
)
// Messages cache
@ -60,6 +61,11 @@ const (
deleteMessageQuery = `DELETE FROM messages WHERE mid = ?`
updateMessagesForTopicExpiryQuery = `UPDATE messages SET expires = ? WHERE topic = ?`
selectRowIDFromMessageID = `SELECT id FROM messages WHERE mid = ?` // Do not include topic, see #336 and TestServer_PollSinceID_MultipleTopics
selectMessagesByIDQuery = `
SELECT mid, time, expires, topic, message, title, priority, tags, click, icon, actions, attachment_name, attachment_type, attachment_size, attachment_expires, attachment_url, sender, user, encoding
FROM messages
WHERE mid = ?
`
selectMessagesSinceTimeQuery = `
SELECT mid, time, expires, topic, message, title, priority, tags, click, icon, actions, attachment_name, attachment_type, attachment_size, attachment_expires, attachment_url, sender, user, encoding
FROM messages
@ -448,6 +454,18 @@ func (c *messageCache) MessagesExpired() ([]string, error) {
return ids, nil
}
func (c *messageCache) Message(id string) (*message, error) {
rows, err := c.db.Query(selectMessagesByIDQuery, id)
if err != nil {
return nil, err
}
if !rows.Next() {
return nil, errMessageNotFound
}
defer rows.Close()
return readMessage(rows)
}
func (c *messageCache) MarkPublished(m *message) error {
_, err := c.db.Exec(updateMessagePublishedQuery, m.ID)
return err
@ -600,6 +618,19 @@ func readMessages(rows *sql.Rows) ([]*message, error) {
defer rows.Close()
messages := make([]*message, 0)
for rows.Next() {
m, err := readMessage(rows)
if err != nil {
return nil, err
}
messages = append(messages, m)
}
if err := rows.Err(); err != nil {
return nil, err
}
return messages, nil
}
func readMessage(rows *sql.Rows) (*message, error) {
var timestamp, expires, attachmentSize, attachmentExpires int64
var priority int
var id, topic, msg, title, tagsStr, click, icon, actionsStr, attachmentName, attachmentType, attachmentURL, sender, user, encoding string
@ -651,7 +682,7 @@ func readMessages(rows *sql.Rows) ([]*message, error) {
URL: attachmentURL,
}
}
messages = append(messages, &message{
return &message{
ID: id,
Time: timestamp,
Expires: expires,
@ -668,12 +699,7 @@ func readMessages(rows *sql.Rows) ([]*message, error) {
Sender: senderIP, // Must parse assuming database must be correct
User: user,
Encoding: encoding,
})
}
if err := rows.Err(); err != nil {
return nil, err
}
return messages, nil
}, nil
}
func (c *messageCache) Close() error {

View File

@ -37,7 +37,8 @@ import (
/*
- HIGH Rate limiting: Sensitive endpoints (account/login/change-password/...)
- HIGH Stripe payment methods
- HIGH Docs
- Large uploads for higher tiers (nginx config!)
- MEDIUM: Test new token endpoints & never-expiring token
- MEDIUM: Make sure account endpoints make sense for admins
- MEDIUM: Reservation (UI): Show "This topic is reserved" error message when trying to reserve a reserved topic (Thorben)
@ -498,6 +499,9 @@ func (s *Server) handleDocs(w http.ResponseWriter, r *http.Request, _ *visitor)
return nil
}
// handleFile processes the download of attachment files. The method handles GET and HEAD requests against a file.
// Before streaming the file to a client, it locates uploader (m.Sender or m.User) in the message cache, so it
// can associate the download bandwidth with the uploader.
func (s *Server) handleFile(w http.ResponseWriter, r *http.Request, v *visitor) error {
if s.config.AttachmentCacheDir == "" {
return errHTTPInternalError
@ -512,14 +516,35 @@ func (s *Server) handleFile(w http.ResponseWriter, r *http.Request, v *visitor)
if err != nil {
return errHTTPNotFound
}
if r.Method == http.MethodGet {
if !v.BandwidthAllowed(stat.Size()) {
w.Header().Set("Access-Control-Allow-Origin", s.config.AccessControlAllowOrigin) // CORS, allow cross-origin requests
w.Header().Set("Content-Length", fmt.Sprintf("%d", stat.Size()))
if r.Method == http.MethodHead {
return nil
}
// Find message in database, and associate bandwidth to the uploader user
// This is an easy way to
// - avoid abuse (e.g. 1 uploader, 1k downloaders)
// - and also uses the higher bandwidth limits of a paying user
m, err := s.messageCache.Message(messageID)
if err == errMessageNotFound {
return errHTTPNotFound
} else if err != nil {
return err
}
bandwidthVisitor := v
if s.userManager != nil && m.User != "" {
u, err := s.userManager.UserByID(m.User)
if err != nil {
return err
}
bandwidthVisitor = s.visitor(v.IP(), u)
} else if m.Sender != netip.IPv4Unspecified() {
bandwidthVisitor = s.visitor(m.Sender, nil)
}
if !bandwidthVisitor.BandwidthAllowed(stat.Size()) {
return errHTTPTooManyRequestsLimitAttachmentBandwidth
}
}
w.Header().Set("Content-Length", fmt.Sprintf("%d", stat.Size()))
w.Header().Set("Access-Control-Allow-Origin", s.config.AccessControlAllowOrigin) // CORS, allow cross-origin requests
if r.Method == http.MethodGet {
// Actually send file
f, err := os.Open(file)
if err != nil {
return err
@ -528,8 +553,6 @@ func (s *Server) handleFile(w http.ResponseWriter, r *http.Request, v *visitor)
_, err = io.Copy(util.NewContentTypeWriter(w, r.URL.Path), f)
return err
}
return nil
}
func (s *Server) handleMatrixDiscovery(w http.ResponseWriter) error {
if s.config.BaseURL == "" {

View File

@ -80,6 +80,8 @@
# - auth-file is the SQLite user/access database; it is created automatically if it doesn't already exist
# - auth-default-access defines the default/fallback access if no access control entry is found; it can be
# set to "read-write" (default), "read-only", "write-only" or "deny-all".
# - auth-startup-queries allows you to run commands when the database is initialized, e.g. to enable
# WAL mode. This is similar to cache-startup-queries. See above for details.
#
# Debian/RPM package users:
# Use /var/lib/ntfy/user.db as user database to avoid permission issues. The package
@ -91,6 +93,7 @@
#
# auth-file: <filename>
# auth-default-access: "read-write"
# auth-startup-queries:
# If set, the X-Forwarded-For header is used to determine the visitor IP address
# instead of the remote address of the connection.

View File

@ -1543,6 +1543,7 @@ func TestServer_PublishAttachmentWithTierBasedBandwidthLimit(t *testing.T) {
content := util.RandomString(5000) // > 4096
c := newTestConfigWithAuthFile(t)
c.VisitorAttachmentDailyBandwidthLimit = 1000 // Much lower than tier bandwidth!
s := newTestServer(t, c)
// Create tier with certain limits
@ -1566,16 +1567,12 @@ func TestServer_PublishAttachmentWithTierBasedBandwidthLimit(t *testing.T) {
msg := toMessage(t, rr.Body.String())
// Retrieve it (first time succeeds)
rr = request(t, s, "GET", "/file/"+msg.ID, content, map[string]string{
"Authorization": util.BasicAuth("phil", "phil"),
})
rr = request(t, s, "GET", "/file/"+msg.ID, content, nil) // File downloads do not send auth headers!!
require.Equal(t, 200, rr.Code)
require.Equal(t, content, rr.Body.String())
// Retrieve it AGAIN (fails, due to bandwidth limit)
rr = request(t, s, "GET", "/file/"+msg.ID, content, map[string]string{
"Authorization": util.BasicAuth("phil", "phil"),
})
rr = request(t, s, "GET", "/file/"+msg.ID, content, nil)
require.Equal(t, 429, rr.Code)
}