From cefe276ce55cb5c29ede1c1fef5393194c1f6e44 Mon Sep 17 00:00:00 2001 From: Philipp Heckel Date: Sat, 8 Jan 2022 12:14:43 -0500 Subject: [PATCH] Tests for fileCache --- go.mod | 2 - go.sum | 5 --- server/file_cache.go | 9 +++-- server/file_cache_test.go | 83 +++++++++++++++++++++++++++++++++++++++ server/server.go | 14 ++++--- util/limit.go | 7 +--- util/util_test.go | 6 +-- 7 files changed, 101 insertions(+), 25 deletions(-) create mode 100644 server/file_cache_test.go diff --git a/go.mod b/go.mod index 6f620033..fad88a46 100644 --- a/go.mod +++ b/go.mod @@ -27,7 +27,6 @@ require ( github.com/cncf/udpa/go v0.0.0-20210930031921-04548b0d99d4 // indirect github.com/cncf/xds/go v0.0.0-20211130200136-a8f946100490 // indirect github.com/davecgh/go-spew v1.1.1 // indirect - github.com/disintegration/imaging v1.6.2 // indirect github.com/emersion/go-sasl v0.0.0-20200509203442-7bfe0ed36a21 // indirect github.com/envoyproxy/go-control-plane v0.10.1 // indirect github.com/envoyproxy/protoc-gen-validate v0.6.2 // indirect @@ -39,7 +38,6 @@ require ( github.com/pmezard/go-difflib v1.0.0 // indirect github.com/russross/blackfriday/v2 v2.1.0 // indirect go.opencensus.io v0.23.0 // indirect - golang.org/x/image v0.0.0-20211028202545-6944b10bf410 // indirect golang.org/x/net v0.0.0-20210813160813-60bc85c4be6d // indirect golang.org/x/sys v0.0.0-20211210111614-af8b64212486 // indirect golang.org/x/text v0.3.7 // indirect diff --git a/go.sum b/go.sum index 07ff72f4..91718f40 100644 --- a/go.sum +++ b/go.sum @@ -89,8 +89,6 @@ github.com/cpuguy83/go-md2man/v2 v2.0.1/go.mod h1:tgQtvFlXSQOSOSIRvRPT7W67SCa46t github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= -github.com/disintegration/imaging v1.6.2 h1:w1LecBlG2Lnp8B3jk5zSuNqd7b4DXhcjwek1ei82L+c= -github.com/disintegration/imaging v1.6.2/go.mod h1:44/5580QXChDfwIclfc/PCwrr44amcmDAg8hxG0Ewe4= github.com/emersion/go-sasl v0.0.0-20200509203442-7bfe0ed36a21 h1:OJyUGMJTzHTd1XQp98QTaHernxMYzRaOasRir9hUlFQ= github.com/emersion/go-sasl v0.0.0-20200509203442-7bfe0ed36a21/go.mod h1:iL2twTeMvZnrg54ZoPDNfJaJaqy0xIQFuBdrLsmspwQ= github.com/emersion/go-smtp v0.15.0 h1:3+hMGMGrqP/lqd7qoxZc1hTU8LY8gHV9RFGWlqSDmP8= @@ -266,9 +264,6 @@ golang.org/x/exp v0.0.0-20200207192155-f17229e696bd/go.mod h1:J/WKrq2StrnmMY6+EH golang.org/x/exp v0.0.0-20200224162631-6cc2880d07d6/go.mod h1:3jZMyOhIsHpP37uCMkUooju7aAi5cS1Q23tOzKc+0MU= golang.org/x/image v0.0.0-20190227222117-0694c2d4d067/go.mod h1:kZ7UVZpmo3dzQBMxlp+ypCbDeSB+sBbTgSJuh5dn5js= golang.org/x/image v0.0.0-20190802002840-cff245a6509b/go.mod h1:FeLwcggjj3mMvU+oOTbSwawSJRM1uh48EjtB4UJZlP0= -golang.org/x/image v0.0.0-20191009234506-e7c1f5e7dbb8/go.mod h1:FeLwcggjj3mMvU+oOTbSwawSJRM1uh48EjtB4UJZlP0= -golang.org/x/image v0.0.0-20211028202545-6944b10bf410 h1:hTftEOvwiOq2+O8k2D5/Q7COC7k5Qcrgc2TFURJYnvQ= -golang.org/x/image v0.0.0-20211028202545-6944b10bf410/go.mod h1:023OzeP/+EPmXeapQh35lcL3II3LrY8Ic+EFFKVhULM= golang.org/x/lint v0.0.0-20181026193005-c67002cb31c3/go.mod h1:UVdnD1Gm6xHRNCYTkRU2/jEulfH38KcIWyp/GAMgvoE= golang.org/x/lint v0.0.0-20190227174305-5b3e6a55c961/go.mod h1:wehouNa3lNwaWXcvxsM5YxQ5yQlVC4a0KAMCusXpPoU= golang.org/x/lint v0.0.0-20190301231843-5614ed5bae6f/go.mod h1:UVdnD1Gm6xHRNCYTkRU2/jEulfH38KcIWyp/GAMgvoE= diff --git a/server/file_cache.go b/server/file_cache.go index 244142cf..04ae76e4 100644 --- a/server/file_cache.go +++ b/server/file_cache.go @@ -4,7 +4,6 @@ import ( "errors" "heckel.io/ntfy/util" "io" - "log" "os" "path/filepath" "regexp" @@ -14,6 +13,7 @@ import ( var ( fileIDRegex = regexp.MustCompile(`^[-_A-Za-z0-9]+$`) errInvalidFileID = errors.New("invalid file ID") + errFileExists = errors.New("file exists") ) type fileCache struct { @@ -45,12 +45,14 @@ func (c *fileCache) Write(id string, in io.Reader, limiters ...*util.Limiter) (i return 0, errInvalidFileID } file := filepath.Join(c.dir, id) + if _, err := os.Stat(file); err == nil { + return 0, errFileExists + } f, err := os.OpenFile(file, os.O_CREATE|os.O_WRONLY|os.O_TRUNC, 0600) if err != nil { return 0, err } defer f.Close() - log.Printf("remaining total: %d", c.Remaining()) limiters = append(limiters, util.NewLimiter(c.Remaining()), util.NewLimiter(c.fileSizeLimit)) limitWriter := util.NewLimitWriter(f, limiters...) size, err := io.Copy(limitWriter, in) @@ -66,10 +68,9 @@ func (c *fileCache) Write(id string, in io.Reader, limiters ...*util.Limiter) (i c.totalSizeCurrent += size c.mu.Unlock() return size, nil - } -func (c *fileCache) Remove(ids []string) error { +func (c *fileCache) Remove(ids ...string) error { var firstErr error for _, id := range ids { if err := c.removeFile(id); err != nil { diff --git a/server/file_cache_test.go b/server/file_cache_test.go new file mode 100644 index 00000000..a0a74085 --- /dev/null +++ b/server/file_cache_test.go @@ -0,0 +1,83 @@ +package server + +import ( + "bytes" + "fmt" + "github.com/stretchr/testify/require" + "heckel.io/ntfy/util" + "os" + "strings" + "testing" +) + +var ( + oneKilobyteArray = make([]byte, 1024) +) + +func TestFileCache_Write_Success(t *testing.T) { + dir, c := newTestFileCache(t) + size, err := c.Write("abc", strings.NewReader("normal file"), util.NewLimiter(999)) + require.Nil(t, err) + require.Equal(t, int64(11), size) + require.Equal(t, "normal file", readFile(t, dir+"/abc")) + require.Equal(t, int64(11), c.Size()) + require.Equal(t, int64(10229), c.Remaining()) +} + +func TestFileCache_Write_Remove_Success(t *testing.T) { + dir, c := newTestFileCache(t) // max = 10k (10240), each = 1k (1024) + for i := 0; i < 10; i++ { // 10x999 = 9990 + size, err := c.Write(fmt.Sprintf("abc%d", i), bytes.NewReader(make([]byte, 999))) + require.Nil(t, err) + require.Equal(t, int64(999), size) + } + require.Equal(t, int64(9990), c.Size()) + require.Equal(t, int64(250), c.Remaining()) + require.FileExists(t, dir+"/abc1") + require.FileExists(t, dir+"/abc5") + + require.Nil(t, c.Remove("abc1", "abc5")) + require.NoFileExists(t, dir+"/abc1") + require.NoFileExists(t, dir+"/abc5") + require.Equal(t, int64(7992), c.Size()) + require.Equal(t, int64(2248), c.Remaining()) +} + +func TestFileCache_Write_FailedTotalSizeLimit(t *testing.T) { + dir, c := newTestFileCache(t) + for i := 0; i < 10; i++ { + size, err := c.Write(fmt.Sprintf("abc%d", i), bytes.NewReader(oneKilobyteArray)) + require.Nil(t, err) + require.Equal(t, int64(1024), size) + } + _, err := c.Write("abc11", bytes.NewReader(oneKilobyteArray)) + require.Equal(t, util.ErrLimitReached, err) + require.NoFileExists(t, dir+"/abc11") +} + +func TestFileCache_Write_FailedFileSizeLimit(t *testing.T) { + dir, c := newTestFileCache(t) + _, err := c.Write("abc", bytes.NewReader(make([]byte, 1025))) + require.Equal(t, util.ErrLimitReached, err) + require.NoFileExists(t, dir+"/abc") +} + +func TestFileCache_Write_FailedAdditionalLimiter(t *testing.T) { + dir, c := newTestFileCache(t) + _, err := c.Write("abc", bytes.NewReader(make([]byte, 1001)), util.NewLimiter(1000)) + require.Equal(t, util.ErrLimitReached, err) + require.NoFileExists(t, dir+"/abc") +} + +func newTestFileCache(t *testing.T) (dir string, cache *fileCache) { + dir = t.TempDir() + cache, err := newFileCache(dir, 10*1024, 1*1024) + require.Nil(t, err) + return dir, cache +} + +func readFile(t *testing.T, f string) string { + b, err := os.ReadFile(f) + require.Nil(t, err) + return string(b) +} diff --git a/server/server.go b/server/server.go index c3ee81fa..6cd46154 100644 --- a/server/server.go +++ b/server/server.go @@ -833,13 +833,15 @@ func (s *Server) updateStatsAndPrune() { } // Delete expired attachments - ids, err := s.cache.AttachmentsExpired() - if err == nil { - if err := s.fileCache.Remove(ids); err != nil { - log.Printf("error while deleting attachments: %s", err.Error()) + if s.fileCache != nil { + ids, err := s.cache.AttachmentsExpired() + if err == nil { + if err := s.fileCache.Remove(ids...); err != nil { + log.Printf("error while deleting attachments: %s", err.Error()) + } + } else { + log.Printf("error retrieving expired attachments: %s", err.Error()) } - } else { - log.Printf("error retrieving expired attachments: %s", err.Error()) } // Prune message cache diff --git a/util/limit.go b/util/limit.go index bac3c155..f0a0c5a3 100644 --- a/util/limit.go +++ b/util/limit.go @@ -24,15 +24,12 @@ func NewLimiter(limit int64) *Limiter { } } -// Add adds n to the limiters internal value, but only if the limit has not been reached. If the limit would be +// Add adds n to the limiters internal value, but only if the limit has not been reached. If the limit was // exceeded after adding n, ErrLimitReached is returned. func (l *Limiter) Add(n int64) error { l.mu.Lock() defer l.mu.Unlock() - if l.limit == 0 { - l.value += n - return nil - } else if l.value+n <= l.limit { + if l.value+n <= l.limit { l.value += n return nil } else { diff --git a/util/util_test.go b/util/util_test.go index f60aa252..45ff3de6 100644 --- a/util/util_test.go +++ b/util/util_test.go @@ -127,7 +127,7 @@ func TestParseSize_10GSuccess(t *testing.T) { if err != nil { t.Fatal(err) } - require.Equal(t, 10*1024*1024*1024, s) + require.Equal(t, int64(10*1024*1024*1024), s) } func TestParseSize_10MUpperCaseSuccess(t *testing.T) { @@ -135,7 +135,7 @@ func TestParseSize_10MUpperCaseSuccess(t *testing.T) { if err != nil { t.Fatal(err) } - require.Equal(t, 10*1024*1024, s) + require.Equal(t, int64(10*1024*1024), s) } func TestParseSize_10kLowerCaseSuccess(t *testing.T) { @@ -143,7 +143,7 @@ func TestParseSize_10kLowerCaseSuccess(t *testing.T) { if err != nil { t.Fatal(err) } - require.Equal(t, 10*1024, s) + require.Equal(t, int64(10*1024), s) } func TestParseSize_FailureInvalid(t *testing.T) {