WIP: Attachments

pull/82/head
Philipp Heckel 2022-01-02 23:56:12 +01:00
parent 09515f26df
commit eb5b86ffe2
13 changed files with 444 additions and 46 deletions

1
.gitignore vendored
View File

@ -3,4 +3,5 @@ build/
.idea/ .idea/
server/docs/ server/docs/
tools/fbsend/fbsend tools/fbsend/fbsend
playground/
*.iml *.iml

View File

@ -20,6 +20,7 @@ var flagsServe = []cli.Flag{
altsrc.NewStringFlag(&cli.StringFlag{Name: "firebase-key-file", Aliases: []string{"F"}, EnvVars: []string{"NTFY_FIREBASE_KEY_FILE"}, Usage: "Firebase credentials file; if set additionally publish to FCM topic"}), altsrc.NewStringFlag(&cli.StringFlag{Name: "firebase-key-file", Aliases: []string{"F"}, EnvVars: []string{"NTFY_FIREBASE_KEY_FILE"}, Usage: "Firebase credentials file; if set additionally publish to FCM topic"}),
altsrc.NewStringFlag(&cli.StringFlag{Name: "cache-file", Aliases: []string{"C"}, EnvVars: []string{"NTFY_CACHE_FILE"}, Usage: "cache file used for message caching"}), altsrc.NewStringFlag(&cli.StringFlag{Name: "cache-file", Aliases: []string{"C"}, EnvVars: []string{"NTFY_CACHE_FILE"}, Usage: "cache file used for message caching"}),
altsrc.NewDurationFlag(&cli.DurationFlag{Name: "cache-duration", Aliases: []string{"b"}, EnvVars: []string{"NTFY_CACHE_DURATION"}, Value: server.DefaultCacheDuration, Usage: "buffer messages for this time to allow `since` requests"}), altsrc.NewDurationFlag(&cli.DurationFlag{Name: "cache-duration", Aliases: []string{"b"}, EnvVars: []string{"NTFY_CACHE_DURATION"}, Value: server.DefaultCacheDuration, Usage: "buffer messages for this time to allow `since` requests"}),
altsrc.NewStringFlag(&cli.StringFlag{Name: "attachment-cache-dir", EnvVars: []string{"NTFY_ATTACHMENT_CACHE_DIR"}, Usage: "cache directory for attached files"}),
altsrc.NewDurationFlag(&cli.DurationFlag{Name: "keepalive-interval", Aliases: []string{"k"}, EnvVars: []string{"NTFY_KEEPALIVE_INTERVAL"}, Value: server.DefaultKeepaliveInterval, Usage: "interval of keepalive messages"}), altsrc.NewDurationFlag(&cli.DurationFlag{Name: "keepalive-interval", Aliases: []string{"k"}, EnvVars: []string{"NTFY_KEEPALIVE_INTERVAL"}, Value: server.DefaultKeepaliveInterval, Usage: "interval of keepalive messages"}),
altsrc.NewDurationFlag(&cli.DurationFlag{Name: "manager-interval", Aliases: []string{"m"}, EnvVars: []string{"NTFY_MANAGER_INTERVAL"}, Value: server.DefaultManagerInterval, Usage: "interval of for message pruning and stats printing"}), altsrc.NewDurationFlag(&cli.DurationFlag{Name: "manager-interval", Aliases: []string{"m"}, EnvVars: []string{"NTFY_MANAGER_INTERVAL"}, Value: server.DefaultManagerInterval, Usage: "interval of for message pruning and stats printing"}),
altsrc.NewStringFlag(&cli.StringFlag{Name: "smtp-sender-addr", EnvVars: []string{"NTFY_SMTP_SENDER_ADDR"}, Usage: "SMTP server address (host:port) for outgoing emails"}), altsrc.NewStringFlag(&cli.StringFlag{Name: "smtp-sender-addr", EnvVars: []string{"NTFY_SMTP_SENDER_ADDR"}, Usage: "SMTP server address (host:port) for outgoing emails"}),
@ -69,6 +70,7 @@ func execServe(c *cli.Context) error {
firebaseKeyFile := c.String("firebase-key-file") firebaseKeyFile := c.String("firebase-key-file")
cacheFile := c.String("cache-file") cacheFile := c.String("cache-file")
cacheDuration := c.Duration("cache-duration") cacheDuration := c.Duration("cache-duration")
attachmentCacheDir := c.String("attachment-cache-dir")
keepaliveInterval := c.Duration("keepalive-interval") keepaliveInterval := c.Duration("keepalive-interval")
managerInterval := c.Duration("manager-interval") managerInterval := c.Duration("manager-interval")
smtpSenderAddr := c.String("smtp-sender-addr") smtpSenderAddr := c.String("smtp-sender-addr")
@ -117,6 +119,7 @@ func execServe(c *cli.Context) error {
conf.FirebaseKeyFile = firebaseKeyFile conf.FirebaseKeyFile = firebaseKeyFile
conf.CacheFile = cacheFile conf.CacheFile = cacheFile
conf.CacheDuration = cacheDuration conf.CacheDuration = cacheDuration
conf.AttachmentCacheDir = attachmentCacheDir
conf.KeepaliveInterval = keepaliveInterval conf.KeepaliveInterval = keepaliveInterval
conf.ManagerInterval = managerInterval conf.ManagerInterval = managerInterval
conf.SMTPSenderAddr = smtpSenderAddr conf.SMTPSenderAddr = smtpSenderAddr
@ -126,7 +129,7 @@ func execServe(c *cli.Context) error {
conf.SMTPServerListen = smtpServerListen conf.SMTPServerListen = smtpServerListen
conf.SMTPServerDomain = smtpServerDomain conf.SMTPServerDomain = smtpServerDomain
conf.SMTPServerAddrPrefix = smtpServerAddrPrefix conf.SMTPServerAddrPrefix = smtpServerAddrPrefix
conf.GlobalTopicLimit = globalTopicLimit conf.TotalTopicLimit = globalTopicLimit
conf.VisitorSubscriptionLimit = visitorSubscriptionLimit conf.VisitorSubscriptionLimit = visitorSubscriptionLimit
conf.VisitorRequestLimitBurst = visitorRequestLimitBurst conf.VisitorRequestLimitBurst = visitorRequestLimitBurst
conf.VisitorRequestLimitReplenish = visitorRequestLimitReplenish conf.VisitorRequestLimitReplenish = visitorRequestLimitReplenish

View File

@ -886,3 +886,4 @@ and can be passed as **HTTP headers** or **query parameters in the URL**. They a
| `X-Email` | `X-E-Mail`, `Email`, `E-Mail`, `mail`, `e` | E-mail address for [e-mail notifications](#e-mail-notifications) | | `X-Email` | `X-E-Mail`, `Email`, `E-Mail`, `mail`, `e` | E-mail address for [e-mail notifications](#e-mail-notifications) |
| `X-Cache` | `Cache` | Allows disabling [message caching](#message-caching) | | `X-Cache` | `Cache` | Allows disabling [message caching](#message-caching) |
| `X-Firebase` | `Firebase` | Allows disabling [sending to Firebase](#disable-firebase) | | `X-Firebase` | `Firebase` | Allows disabling [sending to Firebase](#disable-firebase) |
| `X-UnifiedPush` | `UnifiedPush`, `up` | XXXXXXXXXXXXXXXX |

View File

@ -14,6 +14,7 @@ const (
DefaultMinDelay = 10 * time.Second DefaultMinDelay = 10 * time.Second
DefaultMaxDelay = 3 * 24 * time.Hour DefaultMaxDelay = 3 * 24 * time.Hour
DefaultMessageLimit = 4096 DefaultMessageLimit = 4096
DefaultAttachmentSizeLimit = 5 * 1024 * 1024
DefaultFirebaseKeepaliveInterval = 3 * time.Hour // Not too frequently to save battery DefaultFirebaseKeepaliveInterval = 3 * time.Hour // Not too frequently to save battery
) )
@ -41,6 +42,8 @@ type Config struct {
FirebaseKeyFile string FirebaseKeyFile string
CacheFile string CacheFile string
CacheDuration time.Duration CacheDuration time.Duration
AttachmentCacheDir string
AttachmentSizeLimit int64
KeepaliveInterval time.Duration KeepaliveInterval time.Duration
ManagerInterval time.Duration ManagerInterval time.Duration
AtSenderInterval time.Duration AtSenderInterval time.Duration
@ -55,7 +58,8 @@ type Config struct {
MessageLimit int MessageLimit int
MinDelay time.Duration MinDelay time.Duration
MaxDelay time.Duration MaxDelay time.Duration
GlobalTopicLimit int TotalTopicLimit int
TotalAttachmentSizeLimit int64
VisitorRequestLimitBurst int VisitorRequestLimitBurst int
VisitorRequestLimitReplenish time.Duration VisitorRequestLimitReplenish time.Duration
VisitorEmailLimitBurst int VisitorEmailLimitBurst int
@ -75,6 +79,8 @@ func NewConfig() *Config {
FirebaseKeyFile: "", FirebaseKeyFile: "",
CacheFile: "", CacheFile: "",
CacheDuration: DefaultCacheDuration, CacheDuration: DefaultCacheDuration,
AttachmentCacheDir: "",
AttachmentSizeLimit: DefaultAttachmentSizeLimit,
KeepaliveInterval: DefaultKeepaliveInterval, KeepaliveInterval: DefaultKeepaliveInterval,
ManagerInterval: DefaultManagerInterval, ManagerInterval: DefaultManagerInterval,
MessageLimit: DefaultMessageLimit, MessageLimit: DefaultMessageLimit,
@ -82,7 +88,7 @@ func NewConfig() *Config {
MaxDelay: DefaultMaxDelay, MaxDelay: DefaultMaxDelay,
AtSenderInterval: DefaultAtSenderInterval, AtSenderInterval: DefaultAtSenderInterval,
FirebaseKeepaliveInterval: DefaultFirebaseKeepaliveInterval, FirebaseKeepaliveInterval: DefaultFirebaseKeepaliveInterval,
GlobalTopicLimit: DefaultGlobalTopicLimit, TotalTopicLimit: DefaultGlobalTopicLimit,
VisitorRequestLimitBurst: DefaultVisitorRequestLimitBurst, VisitorRequestLimitBurst: DefaultVisitorRequestLimitBurst,
VisitorRequestLimitReplenish: DefaultVisitorRequestLimitReplenish, VisitorRequestLimitReplenish: DefaultVisitorRequestLimitReplenish,
VisitorEmailLimitBurst: DefaultVisitorEmailLimitBurst, VisitorEmailLimitBurst: DefaultVisitorEmailLimitBurst,

View File

@ -26,6 +26,13 @@ type message struct {
Tags []string `json:"tags,omitempty"` Tags []string `json:"tags,omitempty"`
Title string `json:"title,omitempty"` Title string `json:"title,omitempty"`
Message string `json:"message,omitempty"` Message string `json:"message,omitempty"`
Attachment *attachment `json:"attachment,omitempty"`
}
type attachment struct {
Name string `json:"name"`
Type string `json:"type"`
URL string `json:"url"`
} }
// messageEncoder is a function that knows how to encode a message // messageEncoder is a function that knows how to encode a message

View File

@ -15,14 +15,18 @@ import (
"html/template" "html/template"
"io" "io"
"log" "log"
"mime"
"net" "net"
"net/http" "net/http"
"net/http/httptest" "net/http/httptest"
"os"
"path/filepath"
"regexp" "regexp"
"strconv" "strconv"
"strings" "strings"
"sync" "sync"
"time" "time"
"unicode/utf8"
) )
// TODO add "max messages in a topic" limit // TODO add "max messages in a topic" limit
@ -96,7 +100,8 @@ var (
staticRegex = regexp.MustCompile(`^/static/.+`) staticRegex = regexp.MustCompile(`^/static/.+`)
docsRegex = regexp.MustCompile(`^/docs(|/.*)$`) docsRegex = regexp.MustCompile(`^/docs(|/.*)$`)
disallowedTopics = []string{"docs", "static"} fileRegex = regexp.MustCompile(`^/file/([-_A-Za-z0-9]{1,64})(?:\.[A-Za-z0-9]{1,16})?$`)
disallowedTopics = []string{"docs", "static", "file"}
templateFnMap = template.FuncMap{ templateFnMap = template.FuncMap{
"durationToHuman": util.DurationToHuman, "durationToHuman": util.DurationToHuman,
@ -132,7 +137,11 @@ var (
errHTTPBadRequestSinceInvalid = &errHTTP{40008, http.StatusBadRequest, "invalid since parameter", "https://ntfy.sh/docs/subscribe/api/#fetch-cached-messages"} errHTTPBadRequestSinceInvalid = &errHTTP{40008, http.StatusBadRequest, "invalid since parameter", "https://ntfy.sh/docs/subscribe/api/#fetch-cached-messages"}
errHTTPBadRequestTopicInvalid = &errHTTP{40009, http.StatusBadRequest, "invalid topic: path invalid", ""} errHTTPBadRequestTopicInvalid = &errHTTP{40009, http.StatusBadRequest, "invalid topic: path invalid", ""}
errHTTPBadRequestTopicDisallowed = &errHTTP{40010, http.StatusBadRequest, "invalid topic: topic name is disallowed", ""} errHTTPBadRequestTopicDisallowed = &errHTTP{40010, http.StatusBadRequest, "invalid topic: topic name is disallowed", ""}
errHTTPBadRequestAttachmentsDisallowed = &errHTTP{40011, http.StatusBadRequest, "attachments disallowed", ""}
errHTTPBadRequestAttachmentsPublishDisallowed = &errHTTP{40011, http.StatusBadRequest, "invalid message: invalid encoding or too large, and attachments are not allowed", ""}
errHTTPBadRequestMessageTooLarge = &errHTTP{40013, http.StatusBadRequest, "invalid message: too large", ""}
errHTTPInternalError = &errHTTP{50001, http.StatusInternalServerError, "internal server error", ""} errHTTPInternalError = &errHTTP{50001, http.StatusInternalServerError, "internal server error", ""}
errHTTPInternalErrorInvalidFilePath = &errHTTP{50002, http.StatusInternalServerError, "internal server error: invalid file path", ""}
) )
const ( const (
@ -163,6 +172,11 @@ func New(conf *Config) (*Server, error) {
if err != nil { if err != nil {
return nil, err return nil, err
} }
if conf.AttachmentCacheDir != "" {
if err := os.MkdirAll(conf.AttachmentCacheDir, 0700); err != nil {
return nil, err
}
}
return &Server{ return &Server{
config: conf, config: conf,
cache: cache, cache: cache,
@ -302,6 +316,8 @@ func (s *Server) handleInternal(w http.ResponseWriter, r *http.Request) error {
return s.handleStatic(w, r) return s.handleStatic(w, r)
} else if r.Method == http.MethodGet && docsRegex.MatchString(r.URL.Path) { } else if r.Method == http.MethodGet && docsRegex.MatchString(r.URL.Path) {
return s.handleDocs(w, r) return s.handleDocs(w, r)
} else if r.Method == http.MethodGet && fileRegex.MatchString(r.URL.Path) {
return s.handleFile(w, r)
} else if r.Method == http.MethodOptions { } else if r.Method == http.MethodOptions {
return s.handleOptions(w, r) return s.handleOptions(w, r)
} else if r.Method == http.MethodGet && topicPathRegex.MatchString(r.URL.Path) { } else if r.Method == http.MethodGet && topicPathRegex.MatchString(r.URL.Path) {
@ -357,17 +373,45 @@ func (s *Server) handleDocs(w http.ResponseWriter, r *http.Request) error {
return nil return nil
} }
func (s *Server) handleFile(w http.ResponseWriter, r *http.Request) error {
if s.config.AttachmentCacheDir == "" {
return errHTTPBadRequestAttachmentsDisallowed
}
matches := fileRegex.FindStringSubmatch(r.URL.Path)
if len(matches) != 2 {
return errHTTPInternalErrorInvalidFilePath
}
messageID := matches[1]
file := filepath.Join(s.config.AttachmentCacheDir, messageID)
stat, err := os.Stat(file)
if err != nil {
return errHTTPNotFound
}
w.Header().Set("Length", fmt.Sprintf("%d", stat.Size()))
f, err := os.Open(file)
if err != nil {
return err
}
defer f.Close()
_, err = io.Copy(util.NewContentTypeWriter(w), f)
return err
}
func (s *Server) handlePublish(w http.ResponseWriter, r *http.Request, v *visitor) error { func (s *Server) handlePublish(w http.ResponseWriter, r *http.Request, v *visitor) error {
t, err := s.topicFromPath(r.URL.Path) t, err := s.topicFromPath(r.URL.Path)
if err != nil { if err != nil {
return err return err
} }
reader := io.LimitReader(r.Body, int64(s.config.MessageLimit)) body, err := util.Peak(r.Body, s.config.MessageLimit)
b, err := io.ReadAll(reader)
if err != nil { if err != nil {
return err return err
} }
m := newDefaultMessage(t.ID, strings.TrimSpace(string(b))) m := newDefaultMessage(t.ID, "")
if !body.LimitReached && utf8.Valid(body.PeakedBytes) {
m.Message = strings.TrimSpace(string(body.PeakedBytes))
} else if err := s.writeAttachment(v, m, body); err != nil {
return err
}
cache, firebase, email, err := s.parsePublishParams(r, m) cache, firebase, email, err := s.parsePublishParams(r, m)
if err != nil { if err != nil {
return err return err
@ -478,6 +522,48 @@ func readParam(r *http.Request, names ...string) string {
return "" return ""
} }
func (s *Server) writeAttachment(v *visitor, m *message, body *util.PeakedReadCloser) error {
if s.config.AttachmentCacheDir == "" || !util.FileExists(s.config.AttachmentCacheDir) {
return errHTTPBadRequestAttachmentsPublishDisallowed
}
contentType := http.DetectContentType(body.PeakedBytes)
exts, err := mime.ExtensionsByType(contentType)
if err != nil {
return err
}
ext := ".bin"
if len(exts) > 0 {
ext = exts[0]
}
filename := fmt.Sprintf("attachment%s", ext)
file := filepath.Join(s.config.AttachmentCacheDir, m.ID)
f, err := os.OpenFile(file, os.O_CREATE|os.O_WRONLY|os.O_TRUNC, 0600)
if err != nil {
return err
}
defer f.Close()
fileSizeLimiter := util.NewLimiter(s.config.AttachmentSizeLimit)
limitWriter := util.NewLimitWriter(f, fileSizeLimiter)
if _, err := io.Copy(limitWriter, body); err != nil {
os.Remove(file)
if err == util.ErrLimitReached {
return errHTTPBadRequestMessageTooLarge
}
return err
}
if err := f.Close(); err != nil {
os.Remove(file)
return err
}
m.Message = fmt.Sprintf("You received a file: %s", filename)
m.Attachment = &attachment{
Name: filename,
Type: contentType,
URL: fmt.Sprintf("%s/file/%s%s", s.config.BaseURL, m.ID, ext),
}
return nil
}
func (s *Server) handleSubscribeJSON(w http.ResponseWriter, r *http.Request, v *visitor) error { func (s *Server) handleSubscribeJSON(w http.ResponseWriter, r *http.Request, v *visitor) error {
encoder := func(msg *message) (string, error) { encoder := func(msg *message) (string, error) {
var buf bytes.Buffer var buf bytes.Buffer
@ -691,7 +777,7 @@ func (s *Server) topicsFromIDs(ids ...string) ([]*topic, error) {
return nil, errHTTPBadRequestTopicDisallowed return nil, errHTTPBadRequestTopicDisallowed
} }
if _, ok := s.topics[id]; !ok { if _, ok := s.topics[id]; !ok {
if len(s.topics) >= s.config.GlobalTopicLimit { if len(s.topics) >= s.config.TotalTopicLimit {
return nil, errHTTPTooManyRequestsLimitGlobalTopics return nil, errHTTPTooManyRequestsLimitGlobalTopics
} }
s.topics[id] = newTopic(id) s.topics[id] = newTopic(id)

View File

@ -165,17 +165,8 @@ func TestServer_PublishLargeMessage(t *testing.T) {
s := newTestServer(t, newTestConfig(t)) s := newTestServer(t, newTestConfig(t))
body := strings.Repeat("this is a large message", 5000) body := strings.Repeat("this is a large message", 5000)
truncated := body[0:4096]
response := request(t, s, "PUT", "/mytopic", body, nil) response := request(t, s, "PUT", "/mytopic", body, nil)
msg := toMessage(t, response.Body.String()) require.Equal(t, 400, response.Code)
require.NotEmpty(t, msg.ID)
require.Equal(t, truncated, msg.Message)
require.Equal(t, 4096, len(msg.Message))
response = request(t, s, "GET", "/mytopic/json?poll=1", "", nil)
messages := toMessages(t, response.Body.String())
require.Equal(t, 1, len(messages))
require.Equal(t, truncated, messages[0].Message)
} }
func TestServer_PublishPriority(t *testing.T) { func TestServer_PublishPriority(t *testing.T) {

View File

@ -0,0 +1,41 @@
package util
import (
"net/http"
"strings"
)
// ContentTypeWriter is an implementation of http.ResponseWriter that will detect the content type and set the
// Content-Type and (optionally) Content-Disposition headers accordingly.
//
// It will always set a Content-Type based on http.DetectContentType, but will never send the "text/html"
// content type.
type ContentTypeWriter struct {
w http.ResponseWriter
sniffed bool
}
// NewContentTypeWriter creates a new ContentTypeWriter
func NewContentTypeWriter(w http.ResponseWriter) *ContentTypeWriter {
return &ContentTypeWriter{w, false}
}
func (w *ContentTypeWriter) Write(p []byte) (n int, err error) {
if w.sniffed {
return w.w.Write(p)
}
// Detect and set Content-Type header
// Fix content types that we don't want to inline-render in the browser. In particular,
// we don't want to render HTML in the browser for security reasons.
contentType := http.DetectContentType(p)
if strings.HasPrefix(contentType, "text/html") {
contentType = strings.ReplaceAll(contentType, "text/html", "text/plain")
} else if contentType == "application/octet-stream" {
contentType = "" // Reset to let downstream http.ResponseWriter take care of it
}
if contentType != "" {
w.w.Header().Set("Content-Type", contentType)
}
w.sniffed = true
return w.w.Write(p)
}

View File

@ -0,0 +1,50 @@
package util
import (
"crypto/rand"
"github.com/stretchr/testify/require"
"net/http/httptest"
"testing"
)
func TestSniffWriter_WriteHTML(t *testing.T) {
rr := httptest.NewRecorder()
sw := NewContentTypeWriter(rr)
sw.Write([]byte("<script>alert('hi')</script>"))
require.Equal(t, "text/plain; charset=utf-8", rr.Header().Get("Content-Type"))
}
func TestSniffWriter_WriteTwoWriteCalls(t *testing.T) {
rr := httptest.NewRecorder()
sw := NewContentTypeWriter(rr)
sw.Write([]byte{0x25, 0x50, 0x44, 0x46, 0x2d, 0x11, 0x22, 0x33})
sw.Write([]byte("<script>alert('hi')</script>"))
require.Equal(t, "application/pdf", rr.Header().Get("Content-Type"))
}
func TestSniffWriter_NoSniffWriterWriteHTML(t *testing.T) {
// This test just makes sure that without the sniff-w, we would get text/html
rr := httptest.NewRecorder()
rr.Write([]byte("<script>alert('hi')</script>"))
require.Equal(t, "text/html; charset=utf-8", rr.Header().Get("Content-Type"))
}
func TestSniffWriter_WriteHTMLSplitIntoTwoWrites(t *testing.T) {
// This test shows how splitting the HTML into two Write() calls will still yield text/plain
rr := httptest.NewRecorder()
sw := NewContentTypeWriter(rr)
sw.Write([]byte("<scr"))
sw.Write([]byte("ipt>alert('hi')</script>"))
require.Equal(t, "text/plain; charset=utf-8", rr.Header().Get("Content-Type"))
}
func TestSniffWriter_WriteUnknownMimeType(t *testing.T) {
rr := httptest.NewRecorder()
sw := NewContentTypeWriter(rr)
randomBytes := make([]byte, 199)
rand.Read(randomBytes)
sw.Write(randomBytes)
require.Equal(t, "application/octet-stream", rr.Header().Get("Content-Type"))
}

View File

@ -2,6 +2,7 @@ package util
import ( import (
"errors" "errors"
"io"
"sync" "sync"
) )
@ -58,3 +59,43 @@ func (l *Limiter) Value() int64 {
defer l.mu.Unlock() defer l.mu.Unlock()
return l.value return l.value
} }
// Limit returns the defined limit
func (l *Limiter) Limit() int64 {
return l.limit
}
// LimitWriter implements an io.Writer that will pass through all Write calls to the underlying
// writer w until any of the limiter's limit is reached, at which point a Write will return ErrLimitReached.
// Each limiter's value is increased with every write.
type LimitWriter struct {
w io.Writer
written int64
limiters []*Limiter
mu sync.Mutex
}
// NewLimitWriter creates a new LimitWriter
func NewLimitWriter(w io.Writer, limiters ...*Limiter) *LimitWriter {
return &LimitWriter{
w: w,
limiters: limiters,
}
}
// Write passes through all writes to the underlying writer until any of the given limiter's limit is reached
func (w *LimitWriter) Write(p []byte) (n int, err error) {
w.mu.Lock()
defer w.mu.Unlock()
for i := 0; i < len(w.limiters); i++ {
if err := w.limiters[i].Add(int64(len(p))); err != nil {
for j := i - 1; j >= 0; j-- {
w.limiters[j].Sub(int64(len(p)))
}
return 0, ErrLimitReached
}
}
n, err = w.w.Write(p)
w.written += int64(n)
return
}

View File

@ -1,6 +1,7 @@
package util package util
import ( import (
"bytes"
"testing" "testing"
) )
@ -17,14 +18,68 @@ func TestLimiter_Add(t *testing.T) {
} }
} }
func TestLimiter_AddSub(t *testing.T) { func TestLimiter_AddSet(t *testing.T) {
l := NewLimiter(10) l := NewLimiter(10)
l.Add(5) l.Add(5)
if l.Value() != 5 { if l.Value() != 5 {
t.Fatalf("expected value to be %d, got %d", 5, l.Value()) t.Fatalf("expected value to be %d, got %d", 5, l.Value())
} }
l.Sub(2) l.Set(7)
if l.Value() != 3 { if l.Value() != 7 {
t.Fatalf("expected value to be %d, got %d", 3, l.Value()) t.Fatalf("expected value to be %d, got %d", 7, l.Value())
}
}
func TestLimitWriter_WriteNoLimiter(t *testing.T) {
var buf bytes.Buffer
lw := NewLimitWriter(&buf)
if _, err := lw.Write(make([]byte, 10)); err != nil {
t.Fatal(err)
}
if _, err := lw.Write(make([]byte, 1)); err != nil {
t.Fatal(err)
}
if buf.Len() != 11 {
t.Fatalf("expected buffer length to be %d, got %d", 11, buf.Len())
}
}
func TestLimitWriter_WriteOneLimiter(t *testing.T) {
var buf bytes.Buffer
l := NewLimiter(10)
lw := NewLimitWriter(&buf, l)
if _, err := lw.Write(make([]byte, 10)); err != nil {
t.Fatal(err)
}
if _, err := lw.Write(make([]byte, 1)); err != ErrLimitReached {
t.Fatalf("expected ErrLimitReached, got %#v", err)
}
if buf.Len() != 10 {
t.Fatalf("expected buffer length to be %d, got %d", 10, buf.Len())
}
if l.Value() != 10 {
t.Fatalf("expected limiter value to be %d, got %d", 10, l.Value())
}
}
func TestLimitWriter_WriteTwoLimiters(t *testing.T) {
var buf bytes.Buffer
l1 := NewLimiter(11)
l2 := NewLimiter(9)
lw := NewLimitWriter(&buf, l1, l2)
if _, err := lw.Write(make([]byte, 8)); err != nil {
t.Fatal(err)
}
if _, err := lw.Write(make([]byte, 2)); err != ErrLimitReached {
t.Fatalf("expected ErrLimitReached, got %#v", err)
}
if buf.Len() != 8 {
t.Fatalf("expected buffer length to be %d, got %d", 8, buf.Len())
}
if l1.Value() != 8 {
t.Fatalf("expected limiter 1 value to be %d, got %d", 8, l1.Value())
}
if l2.Value() != 8 {
t.Fatalf("expected limiter 2 value to be %d, got %d", 8, l2.Value())
} }
} }

61
util/peak.go 100644
View File

@ -0,0 +1,61 @@
package util
import (
"bytes"
"io"
"strings"
)
// PeakedReadCloser is a ReadCloser that allows peaking into a stream and buffering it in memory.
// It can be instantiated using the Peak function. After a stream has been peaked, it can still be fully
// read by reading the PeakedReadCloser. It first drained from the memory buffer, and then from the remaining
// underlying reader.
type PeakedReadCloser struct {
PeakedBytes []byte
LimitReached bool
peaked io.Reader
underlying io.ReadCloser
closed bool
}
// Peak reads the underlying ReadCloser into memory up until the limit and returns a PeakedReadCloser
func Peak(underlying io.ReadCloser, limit int) (*PeakedReadCloser, error) {
if underlying == nil {
underlying = io.NopCloser(strings.NewReader(""))
}
peaked := make([]byte, limit)
read, err := io.ReadFull(underlying, peaked)
if err != nil && err != io.ErrUnexpectedEOF && err != io.EOF {
return nil, err
}
return &PeakedReadCloser{
PeakedBytes: peaked[:read],
LimitReached: read == limit,
underlying: underlying,
peaked: bytes.NewReader(peaked[:read]),
closed: false,
}, nil
}
// Read reads from the peaked bytes and then from the underlying stream
func (r *PeakedReadCloser) Read(p []byte) (n int, err error) {
if r.closed {
return 0, io.EOF
}
n, err = r.peaked.Read(p)
if err == io.EOF {
return r.underlying.Read(p)
} else if err != nil {
return 0, err
}
return
}
// Close closes the underlying stream
func (r *PeakedReadCloser) Close() error {
if r.closed {
return io.EOF
}
r.closed = true
return r.underlying.Close()
}

55
util/peak_test.go 100644
View File

@ -0,0 +1,55 @@
package util
import (
"github.com/stretchr/testify/require"
"io"
"strings"
"testing"
)
func TestPeak_LimitReached(t *testing.T) {
underlying := io.NopCloser(strings.NewReader("1234567890"))
peaked, err := Peak(underlying, 5)
if err != nil {
t.Fatal(err)
}
require.Equal(t, []byte("12345"), peaked.PeakedBytes)
require.Equal(t, true, peaked.LimitReached)
all, err := io.ReadAll(peaked)
if err != nil {
t.Fatal(err)
}
require.Equal(t, []byte("1234567890"), all)
require.Equal(t, []byte("12345"), peaked.PeakedBytes)
require.Equal(t, true, peaked.LimitReached)
}
func TestPeak_LimitNotReached(t *testing.T) {
underlying := io.NopCloser(strings.NewReader("1234567890"))
peaked, err := Peak(underlying, 15)
if err != nil {
t.Fatal(err)
}
all, err := io.ReadAll(peaked)
if err != nil {
t.Fatal(err)
}
require.Equal(t, []byte("1234567890"), all)
require.Equal(t, []byte("1234567890"), peaked.PeakedBytes)
require.Equal(t, false, peaked.LimitReached)
}
func TestPeak_Nil(t *testing.T) {
peaked, err := Peak(nil, 15)
if err != nil {
t.Fatal(err)
}
all, err := io.ReadAll(peaked)
if err != nil {
t.Fatal(err)
}
require.Equal(t, []byte(""), all)
require.Equal(t, []byte(""), peaked.PeakedBytes)
require.Equal(t, false, peaked.LimitReached)
}