diff --git a/server/smtp_server.go b/server/smtp_server.go index 3f4b9b68..733457eb 100644 --- a/server/smtp_server.go +++ b/server/smtp_server.go @@ -34,6 +34,9 @@ type smtpBackend struct { mu sync.Mutex } +var _ smtp.Backend = (*smtpBackend)(nil) +var _ smtp.Session = (*smtpSession)(nil) + func newMailBackend(conf *Config, handler func(http.ResponseWriter, *http.Request)) *smtpBackend { return &smtpBackend{ config: conf, @@ -41,14 +44,9 @@ func newMailBackend(conf *Config, handler func(http.ResponseWriter, *http.Reques } } -func (b *smtpBackend) Login(state *smtp.ConnectionState, username, password string) (smtp.Session, error) { - log.Debug("%s Incoming mail, login with user %s", logSMTPPrefix(state), username) - return &smtpSession{backend: b, state: state}, nil -} - -func (b *smtpBackend) AnonymousLogin(state *smtp.ConnectionState) (smtp.Session, error) { - log.Debug("%s Incoming mail, anonymous login", logSMTPPrefix(state)) - return &smtpSession{backend: b, state: state}, nil +func (b *smtpBackend) NewSession(conn *smtp.Conn) (smtp.Session, error) { + log.Debug("%s Incoming mail", logSMTPPrefix(conn)) + return &smtpSession{backend: b, conn: conn}, nil } func (b *smtpBackend) Counts() (total int64, success int64, failure int64) { @@ -60,23 +58,23 @@ func (b *smtpBackend) Counts() (total int64, success int64, failure int64) { // smtpSession is returned after EHLO. type smtpSession struct { backend *smtpBackend - state *smtp.ConnectionState + conn *smtp.Conn topic string mu sync.Mutex } -func (s *smtpSession) AuthPlain(username, password string) error { - log.Debug("%s AUTH PLAIN (with username %s)", logSMTPPrefix(s.state), username) +func (s *smtpSession) AuthPlain(username, _ string) error { + log.Debug("%s AUTH PLAIN (with username %s)", logSMTPPrefix(s.conn), username) return nil } -func (s *smtpSession) Mail(from string, opts smtp.MailOptions) error { - log.Debug("%s MAIL FROM: %s (with options: %#v)", logSMTPPrefix(s.state), from, opts) +func (s *smtpSession) Mail(from string, opts *smtp.MailOptions) error { + log.Debug("%s MAIL FROM: %s (with options: %#v)", logSMTPPrefix(s.conn), from, opts) return nil } func (s *smtpSession) Rcpt(to string) error { - log.Debug("%s RCPT TO: %s", logSMTPPrefix(s.state), to) + log.Debug("%s RCPT TO: %s", logSMTPPrefix(s.conn), to) return s.withFailCount(func() error { conf := s.backend.config addressList, err := mail.ParseAddressList(to) @@ -114,9 +112,9 @@ func (s *smtpSession) Data(r io.Reader) error { return err } if log.IsTrace() { - log.Trace("%s DATA: %s", logSMTPPrefix(s.state), string(b)) + log.Trace("%s DATA: %s", logSMTPPrefix(s.conn), string(b)) } else if log.IsDebug() { - log.Debug("%s DATA: %d byte(s)", logSMTPPrefix(s.state), len(b)) + log.Debug("%s DATA: %d byte(s)", logSMTPPrefix(s.conn), len(b)) } msg, err := mail.ReadMessage(bytes.NewReader(b)) if err != nil { @@ -156,9 +154,9 @@ func (s *smtpSession) Data(r io.Reader) error { func (s *smtpSession) publishMessage(m *message) error { // Extract remote address (for rate limiting) - remoteAddr, _, err := net.SplitHostPort(s.state.RemoteAddr.String()) + remoteAddr, _, err := net.SplitHostPort(s.conn.Conn().RemoteAddr().String()) if err != nil { - remoteAddr = s.state.RemoteAddr.String() + remoteAddr = s.conn.Conn().RemoteAddr().String() } // Call HTTP handler with fake HTTP request @@ -198,7 +196,7 @@ func (s *smtpSession) withFailCount(fn func() error) error { if err != nil { // Almost all of these errors are parse errors, and user input errors. // We do not want to spam the log with WARN messages. - log.Debug("%s Incoming mail error: %s", logSMTPPrefix(s.state), err.Error()) + log.Debug("%s Incoming mail error: %s", logSMTPPrefix(s.conn), err.Error()) s.backend.failure++ } return err diff --git a/server/smtp_server_test.go b/server/smtp_server_test.go index c0de7079..be99d9e9 100644 --- a/server/smtp_server_test.go +++ b/server/smtp_server_test.go @@ -34,8 +34,8 @@ Content-Type: text/html; charset="UTF-8" require.Equal(t, "and one more", r.Header.Get("Title")) require.Equal(t, "what's up", readAll(t, r.Body)) }) - session, _ := backend.AnonymousLogin(fakeConnState(t, "1.2.3.4")) - require.Nil(t, session.Mail("phil@example.com", smtp.MailOptions{})) + session, _ := backend.NewSession(fakeConnState(t, "1.2.3.4")) + require.Nil(t, session.Mail("phil@example.com", &smtp.MailOptions{})) require.Nil(t, session.Rcpt("ntfy-mytopic@ntfy.sh")) require.Nil(t, session.Data(strings.NewReader(email))) } @@ -303,12 +303,12 @@ func newTestBackend(t *testing.T, handler func(http.ResponseWriter, *http.Reques return conf, backend } -func fakeConnState(t *testing.T, remoteAddr string) *smtp.ConnectionState { +func fakeConnState(t *testing.T, remoteAddr string) *smtp.Conn { ip, err := net.ResolveIPAddr("ip", remoteAddr) if err != nil { t.Fatal(err) } - return &smtp.ConnectionState{ + return &smtp.Conn{ Hostname: "myhostname", LocalAddr: ip, RemoteAddr: ip, diff --git a/server/util.go b/server/util.go index 269a9d59..0d21f416 100644 --- a/server/util.go +++ b/server/util.go @@ -57,8 +57,8 @@ func logHTTPPrefix(v *visitor, r *http.Request) string { return fmt.Sprintf("%s HTTP %s %s", v.ip, r.Method, requestURI) } -func logSMTPPrefix(state *smtp.ConnectionState) string { - return fmt.Sprintf("%s/%s SMTP", state.Hostname, state.RemoteAddr.String()) +func logSMTPPrefix(conn *smtp.Conn) string { + return fmt.Sprintf("%s/%s SMTP", conn.Hostname(), conn.Conn().RemoteAddr().String()) } func renderHTTPRequest(r *http.Request) string {