From db1a1fec0c5dbae8376d0f6b5047200578fce869 Mon Sep 17 00:00:00 2001
From: binwiederhier <pheckel@datto.com>
Date: Fri, 17 Feb 2023 09:07:57 -0500
Subject: [PATCH] Custom HTTP response writer

---
 server/server.go |  1 +
 server/util.go   | 57 ++++++++++++++++++++++++++++++++++++++++++++++++
 2 files changed, 58 insertions(+)

diff --git a/server/server.go b/server/server.go
index 4588614b..b9e7b17f 100644
--- a/server/server.go
+++ b/server/server.go
@@ -291,6 +291,7 @@ func (s *Server) closeDatabases() {
 
 // handle is the main entry point for all HTTP requests
 func (s *Server) handle(w http.ResponseWriter, r *http.Request) {
+	w = newHTTPResponseWriter(w)     // Avoid logging "superfluous response.WriteHeader call" warning
 	v, err := s.maybeAuthenticate(r) // Note: Always returns v, even when error is returned
 	if err != nil {
 		s.handleError(w, r, v, err)
diff --git a/server/util.go b/server/util.go
index 1141e5d5..99c5e1bb 100644
--- a/server/util.go
+++ b/server/util.go
@@ -1,11 +1,14 @@
 package server
 
 import (
+	"bufio"
 	"heckel.io/ntfy/util"
 	"io"
+	"net"
 	"net/http"
 	"net/netip"
 	"strings"
+	"sync"
 )
 
 func readBoolParam(r *http.Request, defaultValue bool, names ...string) bool {
@@ -85,3 +88,57 @@ func readJSONWithLimit[T any](r io.ReadCloser, limit int, allowEmpty bool) (*T,
 	}
 	return obj, nil
 }
+
+type httpResponseWriter struct {
+	w             http.ResponseWriter
+	headerWritten bool
+	mu            sync.Mutex
+}
+
+type httpResponseWriterWithHijacker struct {
+	httpResponseWriter
+}
+
+var _ http.ResponseWriter = (*httpResponseWriter)(nil)
+var _ http.Flusher = (*httpResponseWriter)(nil)
+var _ http.Hijacker = (*httpResponseWriterWithHijacker)(nil)
+
+func newHTTPResponseWriter(w http.ResponseWriter) http.ResponseWriter {
+	if _, ok := w.(http.Hijacker); ok {
+		return &httpResponseWriterWithHijacker{httpResponseWriter: httpResponseWriter{w: w}}
+	}
+	return &httpResponseWriter{w: w}
+}
+
+func (w *httpResponseWriter) Header() http.Header {
+	return w.w.Header()
+}
+
+func (w *httpResponseWriter) Write(bytes []byte) (int, error) {
+	w.mu.Lock()
+	w.headerWritten = true
+	w.mu.Unlock()
+	return w.w.Write(bytes)
+}
+
+func (w *httpResponseWriter) WriteHeader(statusCode int) {
+	w.mu.Lock()
+	if w.headerWritten {
+		w.mu.Unlock()
+		return
+	}
+	w.headerWritten = true
+	w.mu.Unlock()
+	w.w.WriteHeader(statusCode)
+}
+
+func (w *httpResponseWriter) Flush() {
+	if f, ok := w.w.(http.Flusher); ok {
+		f.Flush()
+	}
+}
+
+func (w *httpResponseWriterWithHijacker) Hijack() (net.Conn, *bufio.ReadWriter, error) {
+	h, _ := w.w.(http.Hijacker)
+	return h.Hijack()
+}