Delete expired attachments based on mod time instead of DB entry to avoid races
This commit is contained in:
		
							parent
							
								
									3e53d8a2c7
								
							
						
					
					
						commit
						10a9aca2a1
					
				
					 6 changed files with 65 additions and 43 deletions
				
			
		| 
						 | 
					@ -40,6 +40,7 @@ Thank you to [@wunter8](https://github.com/wunter8) for proactively picking up s
 | 
				
			||||||
 | 
					
 | 
				
			||||||
* `ntfy user` commands don't work with `auth_file` but works with `auth-file` ([#344](https://github.com/binwiederhier/ntfy/issues/344), thanks to [@Histalek](https://github.com/Histalek) for reporting)
 | 
					* `ntfy user` commands don't work with `auth_file` but works with `auth-file` ([#344](https://github.com/binwiederhier/ntfy/issues/344), thanks to [@Histalek](https://github.com/Histalek) for reporting)
 | 
				
			||||||
* Ignore new draft HTTP `Priority` header  ([#351](https://github.com/binwiederhier/ntfy/issues/351), thanks to [@ksurl](https://github.com/ksurl) for reporting)
 | 
					* Ignore new draft HTTP `Priority` header  ([#351](https://github.com/binwiederhier/ntfy/issues/351), thanks to [@ksurl](https://github.com/ksurl) for reporting)
 | 
				
			||||||
 | 
					* Delete expired attachments based on mod time instead of DB entry to avoid races (no ticket) 
 | 
				
			||||||
 | 
					
 | 
				
			||||||
**Documentation:**
 | 
					**Documentation:**
 | 
				
			||||||
 | 
					
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -2,16 +2,18 @@ package server
 | 
				
			||||||
 | 
					
 | 
				
			||||||
import (
 | 
					import (
 | 
				
			||||||
	"errors"
 | 
						"errors"
 | 
				
			||||||
 | 
						"fmt"
 | 
				
			||||||
	"heckel.io/ntfy/util"
 | 
						"heckel.io/ntfy/util"
 | 
				
			||||||
	"io"
 | 
						"io"
 | 
				
			||||||
	"os"
 | 
						"os"
 | 
				
			||||||
	"path/filepath"
 | 
						"path/filepath"
 | 
				
			||||||
	"regexp"
 | 
						"regexp"
 | 
				
			||||||
	"sync"
 | 
						"sync"
 | 
				
			||||||
 | 
						"time"
 | 
				
			||||||
)
 | 
					)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
var (
 | 
					var (
 | 
				
			||||||
	fileIDRegex      = regexp.MustCompile(`^[-_A-Za-z0-9]+$`)
 | 
						fileIDRegex      = regexp.MustCompile(fmt.Sprintf(`^[-_A-Za-z0-9]{%d}$`, messageIDLength))
 | 
				
			||||||
	errInvalidFileID = errors.New("invalid file ID")
 | 
						errInvalidFileID = errors.New("invalid file ID")
 | 
				
			||||||
	errFileExists    = errors.New("file exists")
 | 
						errFileExists    = errors.New("file exists")
 | 
				
			||||||
)
 | 
					)
 | 
				
			||||||
| 
						 | 
					@ -88,6 +90,25 @@ func (c *fileCache) Remove(ids ...string) error {
 | 
				
			||||||
	return nil
 | 
						return nil
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					// Expired returns a list of file IDs for expired files
 | 
				
			||||||
 | 
					func (c *fileCache) Expired(olderThan time.Time) ([]string, error) {
 | 
				
			||||||
 | 
						entries, err := os.ReadDir(c.dir)
 | 
				
			||||||
 | 
						if err != nil {
 | 
				
			||||||
 | 
							return nil, err
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						var ids []string
 | 
				
			||||||
 | 
						for _, e := range entries {
 | 
				
			||||||
 | 
							info, err := e.Info()
 | 
				
			||||||
 | 
							if err != nil {
 | 
				
			||||||
 | 
								continue
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
							if info.ModTime().Before(olderThan) && fileIDRegex.MatchString(e.Name()) {
 | 
				
			||||||
 | 
								ids = append(ids, e.Name())
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						return ids, nil
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
func (c *fileCache) Size() int64 {
 | 
					func (c *fileCache) Size() int64 {
 | 
				
			||||||
	c.mu.Lock()
 | 
						c.mu.Lock()
 | 
				
			||||||
	defer c.mu.Unlock()
 | 
						defer c.mu.Unlock()
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -8,6 +8,7 @@ import (
 | 
				
			||||||
	"os"
 | 
						"os"
 | 
				
			||||||
	"strings"
 | 
						"strings"
 | 
				
			||||||
	"testing"
 | 
						"testing"
 | 
				
			||||||
 | 
						"time"
 | 
				
			||||||
)
 | 
					)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
var (
 | 
					var (
 | 
				
			||||||
| 
						 | 
					@ -16,10 +17,10 @@ var (
 | 
				
			||||||
 | 
					
 | 
				
			||||||
func TestFileCache_Write_Success(t *testing.T) {
 | 
					func TestFileCache_Write_Success(t *testing.T) {
 | 
				
			||||||
	dir, c := newTestFileCache(t)
 | 
						dir, c := newTestFileCache(t)
 | 
				
			||||||
	size, err := c.Write("abc", strings.NewReader("normal file"), util.NewFixedLimiter(999))
 | 
						size, err := c.Write("abcdefghijkl", strings.NewReader("normal file"), util.NewFixedLimiter(999))
 | 
				
			||||||
	require.Nil(t, err)
 | 
						require.Nil(t, err)
 | 
				
			||||||
	require.Equal(t, int64(11), size)
 | 
						require.Equal(t, int64(11), size)
 | 
				
			||||||
	require.Equal(t, "normal file", readFile(t, dir+"/abc"))
 | 
						require.Equal(t, "normal file", readFile(t, dir+"/abcdefghijkl"))
 | 
				
			||||||
	require.Equal(t, int64(11), c.Size())
 | 
						require.Equal(t, int64(11), c.Size())
 | 
				
			||||||
	require.Equal(t, int64(10229), c.Remaining())
 | 
						require.Equal(t, int64(10229), c.Remaining())
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
| 
						 | 
					@ -27,18 +28,18 @@ func TestFileCache_Write_Success(t *testing.T) {
 | 
				
			||||||
func TestFileCache_Write_Remove_Success(t *testing.T) {
 | 
					func TestFileCache_Write_Remove_Success(t *testing.T) {
 | 
				
			||||||
	dir, c := newTestFileCache(t) // max = 10k (10240), each = 1k (1024)
 | 
						dir, c := newTestFileCache(t) // max = 10k (10240), each = 1k (1024)
 | 
				
			||||||
	for i := 0; i < 10; i++ {     // 10x999 = 9990
 | 
						for i := 0; i < 10; i++ {     // 10x999 = 9990
 | 
				
			||||||
		size, err := c.Write(fmt.Sprintf("abc%d", i), bytes.NewReader(make([]byte, 999)))
 | 
							size, err := c.Write(fmt.Sprintf("abcdefghijk%d", i), bytes.NewReader(make([]byte, 999)))
 | 
				
			||||||
		require.Nil(t, err)
 | 
							require.Nil(t, err)
 | 
				
			||||||
		require.Equal(t, int64(999), size)
 | 
							require.Equal(t, int64(999), size)
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
	require.Equal(t, int64(9990), c.Size())
 | 
						require.Equal(t, int64(9990), c.Size())
 | 
				
			||||||
	require.Equal(t, int64(250), c.Remaining())
 | 
						require.Equal(t, int64(250), c.Remaining())
 | 
				
			||||||
	require.FileExists(t, dir+"/abc1")
 | 
						require.FileExists(t, dir+"/abcdefghijk1")
 | 
				
			||||||
	require.FileExists(t, dir+"/abc5")
 | 
						require.FileExists(t, dir+"/abcdefghijk5")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	require.Nil(t, c.Remove("abc1", "abc5"))
 | 
						require.Nil(t, c.Remove("abcdefghijk1", "abcdefghijk5"))
 | 
				
			||||||
	require.NoFileExists(t, dir+"/abc1")
 | 
						require.NoFileExists(t, dir+"/abcdefghijk1")
 | 
				
			||||||
	require.NoFileExists(t, dir+"/abc5")
 | 
						require.NoFileExists(t, dir+"/abcdefghijk5")
 | 
				
			||||||
	require.Equal(t, int64(7992), c.Size())
 | 
						require.Equal(t, int64(7992), c.Size())
 | 
				
			||||||
	require.Equal(t, int64(2248), c.Remaining())
 | 
						require.Equal(t, int64(2248), c.Remaining())
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
| 
						 | 
					@ -46,27 +47,50 @@ func TestFileCache_Write_Remove_Success(t *testing.T) {
 | 
				
			||||||
func TestFileCache_Write_FailedTotalSizeLimit(t *testing.T) {
 | 
					func TestFileCache_Write_FailedTotalSizeLimit(t *testing.T) {
 | 
				
			||||||
	dir, c := newTestFileCache(t)
 | 
						dir, c := newTestFileCache(t)
 | 
				
			||||||
	for i := 0; i < 10; i++ {
 | 
						for i := 0; i < 10; i++ {
 | 
				
			||||||
		size, err := c.Write(fmt.Sprintf("abc%d", i), bytes.NewReader(oneKilobyteArray))
 | 
							size, err := c.Write(fmt.Sprintf("abcdefghijk%d", i), bytes.NewReader(oneKilobyteArray))
 | 
				
			||||||
		require.Nil(t, err)
 | 
							require.Nil(t, err)
 | 
				
			||||||
		require.Equal(t, int64(1024), size)
 | 
							require.Equal(t, int64(1024), size)
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
	_, err := c.Write("abc11", bytes.NewReader(oneKilobyteArray))
 | 
						_, err := c.Write("abcdefghijkX", bytes.NewReader(oneKilobyteArray))
 | 
				
			||||||
	require.Equal(t, util.ErrLimitReached, err)
 | 
						require.Equal(t, util.ErrLimitReached, err)
 | 
				
			||||||
	require.NoFileExists(t, dir+"/abc11")
 | 
						require.NoFileExists(t, dir+"/abcdefghijkX")
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
func TestFileCache_Write_FailedFileSizeLimit(t *testing.T) {
 | 
					func TestFileCache_Write_FailedFileSizeLimit(t *testing.T) {
 | 
				
			||||||
	dir, c := newTestFileCache(t)
 | 
						dir, c := newTestFileCache(t)
 | 
				
			||||||
	_, err := c.Write("abc", bytes.NewReader(make([]byte, 1025)))
 | 
						_, err := c.Write("abcdefghijkl", bytes.NewReader(make([]byte, 1025)))
 | 
				
			||||||
	require.Equal(t, util.ErrLimitReached, err)
 | 
						require.Equal(t, util.ErrLimitReached, err)
 | 
				
			||||||
	require.NoFileExists(t, dir+"/abc")
 | 
						require.NoFileExists(t, dir+"/abcdefghijkl")
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
func TestFileCache_Write_FailedAdditionalLimiter(t *testing.T) {
 | 
					func TestFileCache_Write_FailedAdditionalLimiter(t *testing.T) {
 | 
				
			||||||
	dir, c := newTestFileCache(t)
 | 
						dir, c := newTestFileCache(t)
 | 
				
			||||||
	_, err := c.Write("abc", bytes.NewReader(make([]byte, 1001)), util.NewFixedLimiter(1000))
 | 
						_, err := c.Write("abcdefghijkl", bytes.NewReader(make([]byte, 1001)), util.NewFixedLimiter(1000))
 | 
				
			||||||
	require.Equal(t, util.ErrLimitReached, err)
 | 
						require.Equal(t, util.ErrLimitReached, err)
 | 
				
			||||||
	require.NoFileExists(t, dir+"/abc")
 | 
						require.NoFileExists(t, dir+"/abcdefghijkl")
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func TestFileCache_RemoveExpired(t *testing.T) {
 | 
				
			||||||
 | 
						dir, c := newTestFileCache(t)
 | 
				
			||||||
 | 
						_, err := c.Write("abcdefghijkl", bytes.NewReader(make([]byte, 1001)))
 | 
				
			||||||
 | 
						require.Nil(t, err)
 | 
				
			||||||
 | 
						_, err = c.Write("notdeleted12", bytes.NewReader(make([]byte, 1001)))
 | 
				
			||||||
 | 
						require.Nil(t, err)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						modTime := time.Now().Add(-1 * 4 * time.Hour)
 | 
				
			||||||
 | 
						require.Nil(t, os.Chtimes(dir+"/abcdefghijkl", modTime, modTime))
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						olderThan := time.Now().Add(-1 * 3 * time.Hour)
 | 
				
			||||||
 | 
						ids, err := c.Expired(olderThan)
 | 
				
			||||||
 | 
						require.Nil(t, err)
 | 
				
			||||||
 | 
						require.Equal(t, []string{"abcdefghijkl"}, ids)
 | 
				
			||||||
 | 
						require.Nil(t, c.Remove(ids...))
 | 
				
			||||||
 | 
						require.NoFileExists(t, dir+"/abcdefghijkl")
 | 
				
			||||||
 | 
						require.FileExists(t, dir+"/notdeleted12")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						ids, err = c.Expired(olderThan)
 | 
				
			||||||
 | 
						require.Nil(t, err)
 | 
				
			||||||
 | 
						require.Empty(t, ids)
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
func newTestFileCache(t *testing.T) (dir string, cache *fileCache) {
 | 
					func newTestFileCache(t *testing.T) (dir string, cache *fileCache) {
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -85,7 +85,6 @@ const (
 | 
				
			||||||
	selectMessageCountPerTopicQuery = `SELECT topic, COUNT(*) FROM messages GROUP BY topic`
 | 
						selectMessageCountPerTopicQuery = `SELECT topic, COUNT(*) FROM messages GROUP BY topic`
 | 
				
			||||||
	selectTopicsQuery               = `SELECT topic FROM messages GROUP BY topic`
 | 
						selectTopicsQuery               = `SELECT topic FROM messages GROUP BY topic`
 | 
				
			||||||
	selectAttachmentsSizeQuery      = `SELECT IFNULL(SUM(attachment_size), 0) FROM messages WHERE sender = ? AND attachment_expires >= ?`
 | 
						selectAttachmentsSizeQuery      = `SELECT IFNULL(SUM(attachment_size), 0) FROM messages WHERE sender = ? AND attachment_expires >= ?`
 | 
				
			||||||
	selectAttachmentsExpiredQuery   = `SELECT mid FROM messages WHERE attachment_expires > 0 AND attachment_expires < ?`
 | 
					 | 
				
			||||||
)
 | 
					)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
// Schema management queries
 | 
					// Schema management queries
 | 
				
			||||||
| 
						 | 
					@ -409,26 +408,6 @@ func (c *messageCache) AttachmentBytesUsed(sender string) (int64, error) {
 | 
				
			||||||
	return size, nil
 | 
						return size, nil
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
func (c *messageCache) AttachmentsExpired() ([]string, error) {
 | 
					 | 
				
			||||||
	rows, err := c.db.Query(selectAttachmentsExpiredQuery, time.Now().Unix())
 | 
					 | 
				
			||||||
	if err != nil {
 | 
					 | 
				
			||||||
		return nil, err
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
	defer rows.Close()
 | 
					 | 
				
			||||||
	ids := make([]string, 0)
 | 
					 | 
				
			||||||
	for rows.Next() {
 | 
					 | 
				
			||||||
		var id string
 | 
					 | 
				
			||||||
		if err := rows.Scan(&id); err != nil {
 | 
					 | 
				
			||||||
			return nil, err
 | 
					 | 
				
			||||||
		}
 | 
					 | 
				
			||||||
		ids = append(ids, id)
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
	if err := rows.Err(); err != nil {
 | 
					 | 
				
			||||||
		return nil, err
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
	return ids, nil
 | 
					 | 
				
			||||||
}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
func readMessages(rows *sql.Rows) ([]*message, error) {
 | 
					func readMessages(rows *sql.Rows) ([]*message, error) {
 | 
				
			||||||
	defer rows.Close()
 | 
						defer rows.Close()
 | 
				
			||||||
	messages := make([]*message, 0)
 | 
						messages := make([]*message, 0)
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -344,10 +344,6 @@ func testCacheAttachments(t *testing.T, c *messageCache) {
 | 
				
			||||||
	size, err = c.AttachmentBytesUsed("5.6.7.8")
 | 
						size, err = c.AttachmentBytesUsed("5.6.7.8")
 | 
				
			||||||
	require.Nil(t, err)
 | 
						require.Nil(t, err)
 | 
				
			||||||
	require.Equal(t, int64(0), size)
 | 
						require.Equal(t, int64(0), size)
 | 
				
			||||||
 | 
					 | 
				
			||||||
	ids, err := c.AttachmentsExpired()
 | 
					 | 
				
			||||||
	require.Nil(t, err)
 | 
					 | 
				
			||||||
	require.Equal(t, []string{"m1"}, ids)
 | 
					 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
func TestSqliteCache_Migration_From0(t *testing.T) {
 | 
					func TestSqliteCache_Migration_From0(t *testing.T) {
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -1116,8 +1116,9 @@ func (s *Server) updateStatsAndPrune() {
 | 
				
			||||||
	log.Debug("Manager: Deleted %d stale visitor(s)", staleVisitors)
 | 
						log.Debug("Manager: Deleted %d stale visitor(s)", staleVisitors)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	// Delete expired attachments
 | 
						// Delete expired attachments
 | 
				
			||||||
	if s.fileCache != nil {
 | 
						if s.fileCache != nil && s.config.AttachmentExpiryDuration > 0 {
 | 
				
			||||||
		ids, err := s.messageCache.AttachmentsExpired()
 | 
							olderThan := time.Now().Add(-1 * s.config.AttachmentExpiryDuration)
 | 
				
			||||||
 | 
							ids, err := s.fileCache.Expired(olderThan)
 | 
				
			||||||
		if err != nil {
 | 
							if err != nil {
 | 
				
			||||||
			log.Warn("Error retrieving expired attachments: %s", err.Error())
 | 
								log.Warn("Error retrieving expired attachments: %s", err.Error())
 | 
				
			||||||
		} else if len(ids) > 0 {
 | 
							} else if len(ids) > 0 {
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
		Loading…
	
	Add table
		Add a link
		
	
		Reference in a new issue