Delete expired attachments based on mod time instead of DB entry to avoid races
This commit is contained in:
		
							parent
							
								
									3e53d8a2c7
								
							
						
					
					
						commit
						10a9aca2a1
					
				
					 6 changed files with 65 additions and 43 deletions
				
			
		|  | @ -40,6 +40,7 @@ Thank you to [@wunter8](https://github.com/wunter8) for proactively picking up s | ||||||
| 
 | 
 | ||||||
| * `ntfy user` commands don't work with `auth_file` but works with `auth-file` ([#344](https://github.com/binwiederhier/ntfy/issues/344), thanks to [@Histalek](https://github.com/Histalek) for reporting) | * `ntfy user` commands don't work with `auth_file` but works with `auth-file` ([#344](https://github.com/binwiederhier/ntfy/issues/344), thanks to [@Histalek](https://github.com/Histalek) for reporting) | ||||||
| * Ignore new draft HTTP `Priority` header  ([#351](https://github.com/binwiederhier/ntfy/issues/351), thanks to [@ksurl](https://github.com/ksurl) for reporting) | * Ignore new draft HTTP `Priority` header  ([#351](https://github.com/binwiederhier/ntfy/issues/351), thanks to [@ksurl](https://github.com/ksurl) for reporting) | ||||||
|  | * Delete expired attachments based on mod time instead of DB entry to avoid races (no ticket)  | ||||||
| 
 | 
 | ||||||
| **Documentation:** | **Documentation:** | ||||||
| 
 | 
 | ||||||
|  |  | ||||||
|  | @ -2,16 +2,18 @@ package server | ||||||
| 
 | 
 | ||||||
| import ( | import ( | ||||||
| 	"errors" | 	"errors" | ||||||
|  | 	"fmt" | ||||||
| 	"heckel.io/ntfy/util" | 	"heckel.io/ntfy/util" | ||||||
| 	"io" | 	"io" | ||||||
| 	"os" | 	"os" | ||||||
| 	"path/filepath" | 	"path/filepath" | ||||||
| 	"regexp" | 	"regexp" | ||||||
| 	"sync" | 	"sync" | ||||||
|  | 	"time" | ||||||
| ) | ) | ||||||
| 
 | 
 | ||||||
| var ( | var ( | ||||||
| 	fileIDRegex      = regexp.MustCompile(`^[-_A-Za-z0-9]+$`) | 	fileIDRegex      = regexp.MustCompile(fmt.Sprintf(`^[-_A-Za-z0-9]{%d}$`, messageIDLength)) | ||||||
| 	errInvalidFileID = errors.New("invalid file ID") | 	errInvalidFileID = errors.New("invalid file ID") | ||||||
| 	errFileExists    = errors.New("file exists") | 	errFileExists    = errors.New("file exists") | ||||||
| ) | ) | ||||||
|  | @ -88,6 +90,25 @@ func (c *fileCache) Remove(ids ...string) error { | ||||||
| 	return nil | 	return nil | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
|  | // Expired returns a list of file IDs for expired files | ||||||
|  | func (c *fileCache) Expired(olderThan time.Time) ([]string, error) { | ||||||
|  | 	entries, err := os.ReadDir(c.dir) | ||||||
|  | 	if err != nil { | ||||||
|  | 		return nil, err | ||||||
|  | 	} | ||||||
|  | 	var ids []string | ||||||
|  | 	for _, e := range entries { | ||||||
|  | 		info, err := e.Info() | ||||||
|  | 		if err != nil { | ||||||
|  | 			continue | ||||||
|  | 		} | ||||||
|  | 		if info.ModTime().Before(olderThan) && fileIDRegex.MatchString(e.Name()) { | ||||||
|  | 			ids = append(ids, e.Name()) | ||||||
|  | 		} | ||||||
|  | 	} | ||||||
|  | 	return ids, nil | ||||||
|  | } | ||||||
|  | 
 | ||||||
| func (c *fileCache) Size() int64 { | func (c *fileCache) Size() int64 { | ||||||
| 	c.mu.Lock() | 	c.mu.Lock() | ||||||
| 	defer c.mu.Unlock() | 	defer c.mu.Unlock() | ||||||
|  |  | ||||||
|  | @ -8,6 +8,7 @@ import ( | ||||||
| 	"os" | 	"os" | ||||||
| 	"strings" | 	"strings" | ||||||
| 	"testing" | 	"testing" | ||||||
|  | 	"time" | ||||||
| ) | ) | ||||||
| 
 | 
 | ||||||
| var ( | var ( | ||||||
|  | @ -16,10 +17,10 @@ var ( | ||||||
| 
 | 
 | ||||||
| func TestFileCache_Write_Success(t *testing.T) { | func TestFileCache_Write_Success(t *testing.T) { | ||||||
| 	dir, c := newTestFileCache(t) | 	dir, c := newTestFileCache(t) | ||||||
| 	size, err := c.Write("abc", strings.NewReader("normal file"), util.NewFixedLimiter(999)) | 	size, err := c.Write("abcdefghijkl", strings.NewReader("normal file"), util.NewFixedLimiter(999)) | ||||||
| 	require.Nil(t, err) | 	require.Nil(t, err) | ||||||
| 	require.Equal(t, int64(11), size) | 	require.Equal(t, int64(11), size) | ||||||
| 	require.Equal(t, "normal file", readFile(t, dir+"/abc")) | 	require.Equal(t, "normal file", readFile(t, dir+"/abcdefghijkl")) | ||||||
| 	require.Equal(t, int64(11), c.Size()) | 	require.Equal(t, int64(11), c.Size()) | ||||||
| 	require.Equal(t, int64(10229), c.Remaining()) | 	require.Equal(t, int64(10229), c.Remaining()) | ||||||
| } | } | ||||||
|  | @ -27,18 +28,18 @@ func TestFileCache_Write_Success(t *testing.T) { | ||||||
| func TestFileCache_Write_Remove_Success(t *testing.T) { | func TestFileCache_Write_Remove_Success(t *testing.T) { | ||||||
| 	dir, c := newTestFileCache(t) // max = 10k (10240), each = 1k (1024) | 	dir, c := newTestFileCache(t) // max = 10k (10240), each = 1k (1024) | ||||||
| 	for i := 0; i < 10; i++ {     // 10x999 = 9990 | 	for i := 0; i < 10; i++ {     // 10x999 = 9990 | ||||||
| 		size, err := c.Write(fmt.Sprintf("abc%d", i), bytes.NewReader(make([]byte, 999))) | 		size, err := c.Write(fmt.Sprintf("abcdefghijk%d", i), bytes.NewReader(make([]byte, 999))) | ||||||
| 		require.Nil(t, err) | 		require.Nil(t, err) | ||||||
| 		require.Equal(t, int64(999), size) | 		require.Equal(t, int64(999), size) | ||||||
| 	} | 	} | ||||||
| 	require.Equal(t, int64(9990), c.Size()) | 	require.Equal(t, int64(9990), c.Size()) | ||||||
| 	require.Equal(t, int64(250), c.Remaining()) | 	require.Equal(t, int64(250), c.Remaining()) | ||||||
| 	require.FileExists(t, dir+"/abc1") | 	require.FileExists(t, dir+"/abcdefghijk1") | ||||||
| 	require.FileExists(t, dir+"/abc5") | 	require.FileExists(t, dir+"/abcdefghijk5") | ||||||
| 
 | 
 | ||||||
| 	require.Nil(t, c.Remove("abc1", "abc5")) | 	require.Nil(t, c.Remove("abcdefghijk1", "abcdefghijk5")) | ||||||
| 	require.NoFileExists(t, dir+"/abc1") | 	require.NoFileExists(t, dir+"/abcdefghijk1") | ||||||
| 	require.NoFileExists(t, dir+"/abc5") | 	require.NoFileExists(t, dir+"/abcdefghijk5") | ||||||
| 	require.Equal(t, int64(7992), c.Size()) | 	require.Equal(t, int64(7992), c.Size()) | ||||||
| 	require.Equal(t, int64(2248), c.Remaining()) | 	require.Equal(t, int64(2248), c.Remaining()) | ||||||
| } | } | ||||||
|  | @ -46,27 +47,50 @@ func TestFileCache_Write_Remove_Success(t *testing.T) { | ||||||
| func TestFileCache_Write_FailedTotalSizeLimit(t *testing.T) { | func TestFileCache_Write_FailedTotalSizeLimit(t *testing.T) { | ||||||
| 	dir, c := newTestFileCache(t) | 	dir, c := newTestFileCache(t) | ||||||
| 	for i := 0; i < 10; i++ { | 	for i := 0; i < 10; i++ { | ||||||
| 		size, err := c.Write(fmt.Sprintf("abc%d", i), bytes.NewReader(oneKilobyteArray)) | 		size, err := c.Write(fmt.Sprintf("abcdefghijk%d", i), bytes.NewReader(oneKilobyteArray)) | ||||||
| 		require.Nil(t, err) | 		require.Nil(t, err) | ||||||
| 		require.Equal(t, int64(1024), size) | 		require.Equal(t, int64(1024), size) | ||||||
| 	} | 	} | ||||||
| 	_, err := c.Write("abc11", bytes.NewReader(oneKilobyteArray)) | 	_, err := c.Write("abcdefghijkX", bytes.NewReader(oneKilobyteArray)) | ||||||
| 	require.Equal(t, util.ErrLimitReached, err) | 	require.Equal(t, util.ErrLimitReached, err) | ||||||
| 	require.NoFileExists(t, dir+"/abc11") | 	require.NoFileExists(t, dir+"/abcdefghijkX") | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| func TestFileCache_Write_FailedFileSizeLimit(t *testing.T) { | func TestFileCache_Write_FailedFileSizeLimit(t *testing.T) { | ||||||
| 	dir, c := newTestFileCache(t) | 	dir, c := newTestFileCache(t) | ||||||
| 	_, err := c.Write("abc", bytes.NewReader(make([]byte, 1025))) | 	_, err := c.Write("abcdefghijkl", bytes.NewReader(make([]byte, 1025))) | ||||||
| 	require.Equal(t, util.ErrLimitReached, err) | 	require.Equal(t, util.ErrLimitReached, err) | ||||||
| 	require.NoFileExists(t, dir+"/abc") | 	require.NoFileExists(t, dir+"/abcdefghijkl") | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| func TestFileCache_Write_FailedAdditionalLimiter(t *testing.T) { | func TestFileCache_Write_FailedAdditionalLimiter(t *testing.T) { | ||||||
| 	dir, c := newTestFileCache(t) | 	dir, c := newTestFileCache(t) | ||||||
| 	_, err := c.Write("abc", bytes.NewReader(make([]byte, 1001)), util.NewFixedLimiter(1000)) | 	_, err := c.Write("abcdefghijkl", bytes.NewReader(make([]byte, 1001)), util.NewFixedLimiter(1000)) | ||||||
| 	require.Equal(t, util.ErrLimitReached, err) | 	require.Equal(t, util.ErrLimitReached, err) | ||||||
| 	require.NoFileExists(t, dir+"/abc") | 	require.NoFileExists(t, dir+"/abcdefghijkl") | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | func TestFileCache_RemoveExpired(t *testing.T) { | ||||||
|  | 	dir, c := newTestFileCache(t) | ||||||
|  | 	_, err := c.Write("abcdefghijkl", bytes.NewReader(make([]byte, 1001))) | ||||||
|  | 	require.Nil(t, err) | ||||||
|  | 	_, err = c.Write("notdeleted12", bytes.NewReader(make([]byte, 1001))) | ||||||
|  | 	require.Nil(t, err) | ||||||
|  | 
 | ||||||
|  | 	modTime := time.Now().Add(-1 * 4 * time.Hour) | ||||||
|  | 	require.Nil(t, os.Chtimes(dir+"/abcdefghijkl", modTime, modTime)) | ||||||
|  | 
 | ||||||
|  | 	olderThan := time.Now().Add(-1 * 3 * time.Hour) | ||||||
|  | 	ids, err := c.Expired(olderThan) | ||||||
|  | 	require.Nil(t, err) | ||||||
|  | 	require.Equal(t, []string{"abcdefghijkl"}, ids) | ||||||
|  | 	require.Nil(t, c.Remove(ids...)) | ||||||
|  | 	require.NoFileExists(t, dir+"/abcdefghijkl") | ||||||
|  | 	require.FileExists(t, dir+"/notdeleted12") | ||||||
|  | 
 | ||||||
|  | 	ids, err = c.Expired(olderThan) | ||||||
|  | 	require.Nil(t, err) | ||||||
|  | 	require.Empty(t, ids) | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| func newTestFileCache(t *testing.T) (dir string, cache *fileCache) { | func newTestFileCache(t *testing.T) (dir string, cache *fileCache) { | ||||||
|  |  | ||||||
|  | @ -85,7 +85,6 @@ const ( | ||||||
| 	selectMessageCountPerTopicQuery = `SELECT topic, COUNT(*) FROM messages GROUP BY topic` | 	selectMessageCountPerTopicQuery = `SELECT topic, COUNT(*) FROM messages GROUP BY topic` | ||||||
| 	selectTopicsQuery               = `SELECT topic FROM messages GROUP BY topic` | 	selectTopicsQuery               = `SELECT topic FROM messages GROUP BY topic` | ||||||
| 	selectAttachmentsSizeQuery      = `SELECT IFNULL(SUM(attachment_size), 0) FROM messages WHERE sender = ? AND attachment_expires >= ?` | 	selectAttachmentsSizeQuery      = `SELECT IFNULL(SUM(attachment_size), 0) FROM messages WHERE sender = ? AND attachment_expires >= ?` | ||||||
| 	selectAttachmentsExpiredQuery   = `SELECT mid FROM messages WHERE attachment_expires > 0 AND attachment_expires < ?` |  | ||||||
| ) | ) | ||||||
| 
 | 
 | ||||||
| // Schema management queries | // Schema management queries | ||||||
|  | @ -409,26 +408,6 @@ func (c *messageCache) AttachmentBytesUsed(sender string) (int64, error) { | ||||||
| 	return size, nil | 	return size, nil | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| func (c *messageCache) AttachmentsExpired() ([]string, error) { |  | ||||||
| 	rows, err := c.db.Query(selectAttachmentsExpiredQuery, time.Now().Unix()) |  | ||||||
| 	if err != nil { |  | ||||||
| 		return nil, err |  | ||||||
| 	} |  | ||||||
| 	defer rows.Close() |  | ||||||
| 	ids := make([]string, 0) |  | ||||||
| 	for rows.Next() { |  | ||||||
| 		var id string |  | ||||||
| 		if err := rows.Scan(&id); err != nil { |  | ||||||
| 			return nil, err |  | ||||||
| 		} |  | ||||||
| 		ids = append(ids, id) |  | ||||||
| 	} |  | ||||||
| 	if err := rows.Err(); err != nil { |  | ||||||
| 		return nil, err |  | ||||||
| 	} |  | ||||||
| 	return ids, nil |  | ||||||
| } |  | ||||||
| 
 |  | ||||||
| func readMessages(rows *sql.Rows) ([]*message, error) { | func readMessages(rows *sql.Rows) ([]*message, error) { | ||||||
| 	defer rows.Close() | 	defer rows.Close() | ||||||
| 	messages := make([]*message, 0) | 	messages := make([]*message, 0) | ||||||
|  |  | ||||||
|  | @ -344,10 +344,6 @@ func testCacheAttachments(t *testing.T, c *messageCache) { | ||||||
| 	size, err = c.AttachmentBytesUsed("5.6.7.8") | 	size, err = c.AttachmentBytesUsed("5.6.7.8") | ||||||
| 	require.Nil(t, err) | 	require.Nil(t, err) | ||||||
| 	require.Equal(t, int64(0), size) | 	require.Equal(t, int64(0), size) | ||||||
| 
 |  | ||||||
| 	ids, err := c.AttachmentsExpired() |  | ||||||
| 	require.Nil(t, err) |  | ||||||
| 	require.Equal(t, []string{"m1"}, ids) |  | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| func TestSqliteCache_Migration_From0(t *testing.T) { | func TestSqliteCache_Migration_From0(t *testing.T) { | ||||||
|  |  | ||||||
|  | @ -1116,8 +1116,9 @@ func (s *Server) updateStatsAndPrune() { | ||||||
| 	log.Debug("Manager: Deleted %d stale visitor(s)", staleVisitors) | 	log.Debug("Manager: Deleted %d stale visitor(s)", staleVisitors) | ||||||
| 
 | 
 | ||||||
| 	// Delete expired attachments | 	// Delete expired attachments | ||||||
| 	if s.fileCache != nil { | 	if s.fileCache != nil && s.config.AttachmentExpiryDuration > 0 { | ||||||
| 		ids, err := s.messageCache.AttachmentsExpired() | 		olderThan := time.Now().Add(-1 * s.config.AttachmentExpiryDuration) | ||||||
|  | 		ids, err := s.fileCache.Expired(olderThan) | ||||||
| 		if err != nil { | 		if err != nil { | ||||||
| 			log.Warn("Error retrieving expired attachments: %s", err.Error()) | 			log.Warn("Error retrieving expired attachments: %s", err.Error()) | ||||||
| 		} else if len(ids) > 0 { | 		} else if len(ids) > 0 { | ||||||
|  |  | ||||||
		Loading…
	
	Add table
		Add a link
		
	
		Reference in a new issue