Merge branch 'main' into logging

pull/284/head
Philipp Heckel 2022-05-31 23:39:11 -04:00
commit a04cf5fcb6
20 changed files with 917 additions and 715 deletions

39
.github/workflows/build.yaml vendored 100644
View File

@ -0,0 +1,39 @@
name: build
on: [push, pull_request]
jobs:
build:
runs-on: ubuntu-latest
steps:
-
name: Install Go
uses: actions/setup-go@v2
with:
go-version: '1.18.x'
-
name: Install node
uses: actions/setup-node@v2
with:
node-version: '16'
-
name: Checkout code
uses: actions/checkout@v2
-
name: Cache Go and npm modules
uses: actions/cache@v3
with:
path: |
~/go/pkg/mod
~/go/bin
~/.npm
web/node_modules
key: ${{ runner.os }}-ntfy-${{ hashFiles('**/go.sum', '**/package.lock') }}
restore-keys: ${{ runner.os }}-ntfy-
-
name: Install dependencies
run: make build-deps-ubuntu
-
name: Build all the things
run: make build
-
name: Print build results and checksums
run: make cli-build-results

View File

@ -1,72 +0,0 @@
# For most projects, this workflow file will not need changing; you simply need
# to commit it to your repository.
#
# You may wish to alter this file to override the set of languages analyzed,
# or to provide custom queries or build logic.
#
# ******** NOTE ********
# We have attempted to detect the languages in your repository. Please check
# the `language` matrix defined below to confirm you have the correct set of
# supported CodeQL languages.
#
name: "CodeQL"
on:
push:
branches: [ main ]
pull_request:
# The branches below must be a subset of the branches above
branches: [ main ]
schedule:
- cron: '21 10 * * 5'
jobs:
analyze:
name: Analyze
runs-on: ubuntu-latest
permissions:
actions: read
contents: read
security-events: write
strategy:
fail-fast: false
matrix:
language: [ 'go', 'javascript' ]
# CodeQL supports [ 'cpp', 'csharp', 'go', 'java', 'javascript', 'python', 'ruby' ]
# Learn more about CodeQL language support at https://aka.ms/codeql-docs/language-support
steps:
- name: Checkout repository
uses: actions/checkout@v3
# Initializes the CodeQL tools for scanning.
- name: Initialize CodeQL
uses: github/codeql-action/init@v2
with:
languages: ${{ matrix.language }}
# If you wish to specify custom queries, you can do so here or in a config file.
# By default, queries listed here will override any specified in a config file.
# Prefix the list here with "+" to use these queries and those in the config file.
# Details on CodeQL's query packs refer to : https://docs.github.com/en/code-security/code-scanning/automatically-scanning-your-code-for-vulnerabilities-and-errors/configuring-code-scanning#using-queries-in-ql-packs
# queries: security-extended,security-and-quality
# Autobuild attempts to build any compiled languages (C/C++, C#, or Java).
# If this step fails, then you should remove it and run the build manually (see below)
- name: Autobuild
uses: github/codeql-action/autobuild@v2
# Command-line programs to run using the OS shell.
# 📚 See https://docs.github.com/en/actions/using-workflows/workflow-syntax-for-github-actions#jobsjob_idstepsrun
# If the Autobuild fails above, remove it and uncomment the following three lines.
# modify them (or add more) to build your code if your project, please refer to the EXAMPLE below for guidance.
# - run: |
# echo "Run, Build Application using script"
# ./location_of_script_within_repo/buildscript.sh
- name: Perform CodeQL Analysis
uses: github/codeql-action/analyze@v2

50
.github/workflows/release.yaml vendored 100644
View File

@ -0,0 +1,50 @@
name: release
on:
push:
tags:
- 'v[0-9]+.[0-9]+.[0-9]+'
jobs:
release:
runs-on: ubuntu-latest
steps:
-
name: Install Go
uses: actions/setup-go@v2
with:
go-version: '1.18.x'
-
name: Install node
uses: actions/setup-node@v2
with:
node-version: '16'
-
name: Checkout code
uses: actions/checkout@v2
-
name: Cache Go and npm modules
uses: actions/cache@v3
with:
path: |
~/go/pkg/mod
~/go/bin
~/.npm
web/node_modules
key: ${{ runner.os }}-ntfy-${{ hashFiles('**/go.sum', '**/package.lock') }}
restore-keys: ${{ runner.os }}-ntfy-
-
name: Docker login
uses: docker/login-action@v2
with:
username: ${{ github.repository_owner }}
password: ${{ secrets.DOCKER_HUB_TOKEN }}
-
name: Install dependencies
run: make build-deps-ubuntu
-
name: Build and publish
run: make release
env:
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
-
name: Print build results and checksums
run: make cli-build-results

View File

@ -4,25 +4,45 @@ jobs:
test: test:
runs-on: ubuntu-latest runs-on: ubuntu-latest
steps: steps:
- name: Install Go -
name: Install Go
uses: actions/setup-go@v2 uses: actions/setup-go@v2
with: with:
go-version: '1.17.x' go-version: '1.18.x'
- name: Install node -
name: Install node
uses: actions/setup-node@v2 uses: actions/setup-node@v2
with: with:
node-version: '16' node-version: '16'
- name: Checkout code -
name: Checkout code
uses: actions/checkout@v2 uses: actions/checkout@v2
- name: Install dependencies -
run: sudo apt update && sudo apt install -y python3-pip curl name: Cache Go and npm modules
- name: Build docs (required for tests) uses: actions/cache@v3
with:
path: |
~/go/pkg/mod
~/go/bin
~/.npm
web/node_modules
key: ${{ runner.os }}-ntfy-${{ hashFiles('**/go.sum', '**/package.lock') }}
restore-keys: ${{ runner.os }}-ntfy-
-
name: Install dependencies
run: make build-deps-ubuntu
-
name: Build docs (required for tests)
run: make docs run: make docs
- name: Build web app (required for tests) -
name: Build web app (required for tests)
run: make web run: make web
- name: Run tests, formatting, vetting and linting -
name: Run tests, formatting, vetting and linting
run: make check run: make check
- name: Run coverage -
name: Run coverage
run: make coverage run: make coverage
- name: Upload coverage to codecov.io -
name: Upload coverage to codecov.io
run: make coverage-upload run: make coverage-upload

View File

@ -157,6 +157,7 @@ universal_binaries:
- -
id: ntfy_darwin_all id: ntfy_darwin_all
replace: true replace: true
name_template: ntfy
checksum: checksum:
name_template: 'checksums.txt' name_template: 'checksums.txt'
snapshot: snapshot:

View File

@ -79,6 +79,18 @@ build: web docs cli
update: web-deps-update cli-deps-update docs-deps-update update: web-deps-update cli-deps-update docs-deps-update
docker pull alpine docker pull alpine
# Ubuntu-specific
build-deps-ubuntu:
sudo apt update
sudo apt install -y \
curl \
gcc-aarch64-linux-gnu \
gcc-arm-linux-gnueabi \
upx \
jq
which pip3 || sudo apt install -y python3-pip
# Documentation # Documentation
docs: docs-deps docs-build docs: docs-deps docs-build
@ -114,28 +126,29 @@ web-deps:
web-deps-update: web-deps-update:
cd web && npm update cd web && npm update
# Main server/client build # Main server/client build
cli: cli-deps cli: cli-deps
goreleaser build --snapshot --rm-dist --debug goreleaser build --snapshot --rm-dist
cli-linux-amd64: cli-deps-static-sites cli-linux-amd64: cli-deps-static-sites
goreleaser build --snapshot --rm-dist --debug --id ntfy_linux_amd64 goreleaser build --snapshot --rm-dist --id ntfy_linux_amd64
cli-linux-armv6: cli-deps-static-sites cli-deps-gcc-armv6-armv7 cli-linux-armv6: cli-deps-static-sites cli-deps-gcc-armv6-armv7
goreleaser build --snapshot --rm-dist --debug --id ntfy_linux_armv6 goreleaser build --snapshot --rm-dist --id ntfy_linux_armv6
cli-linux-armv7: cli-deps-static-sites cli-deps-gcc-armv6-armv7 cli-linux-armv7: cli-deps-static-sites cli-deps-gcc-armv6-armv7
goreleaser build --snapshot --rm-dist --debug --id ntfy_linux_armv7 goreleaser build --snapshot --rm-dist --id ntfy_linux_armv7
cli-linux-arm64: cli-deps-static-sites cli-deps-gcc-arm64 cli-linux-arm64: cli-deps-static-sites cli-deps-gcc-arm64
goreleaser build --snapshot --rm-dist --debug --id ntfy_linux_arm64 goreleaser build --snapshot --rm-dist --id ntfy_linux_arm64
cli-windows-amd64: cli-deps-static-sites cli-windows-amd64: cli-deps-static-sites
goreleaser build --snapshot --rm-dist --debug --id ntfy_windows_amd64 goreleaser build --snapshot --rm-dist --id ntfy_windows_amd64
cli-darwin-all: cli-deps-static-sites cli-darwin-all: cli-deps-static-sites
goreleaser build --snapshot --rm-dist --debug --id ntfy_darwin_all goreleaser build --snapshot --rm-dist --id ntfy_darwin_all
cli-linux-server: cli-deps-static-sites cli-linux-server: cli-deps-static-sites
# This is a target to build the CLI (including the server) manually. # This is a target to build the CLI (including the server) manually.
@ -177,6 +190,7 @@ cli-deps-static-sites:
cli-deps-all: cli-deps-all:
which upx || { echo "ERROR: upx not installed. On Ubuntu, run: apt install upx"; exit 1; } which upx || { echo "ERROR: upx not installed. On Ubuntu, run: apt install upx"; exit 1; }
go install github.com/goreleaser/goreleaser@latest
cli-deps-gcc-armv6-armv7: cli-deps-gcc-armv6-armv7:
which arm-linux-gnueabi-gcc || { echo "ERROR: ARMv6/ARMv7 cross compiler not installed. On Ubuntu, run: apt install gcc-arm-linux-gnueabi"; exit 1; } which arm-linux-gnueabi-gcc || { echo "ERROR: ARMv6/ARMv7 cross compiler not installed. On Ubuntu, run: apt install gcc-arm-linux-gnueabi"; exit 1; }
@ -187,6 +201,18 @@ cli-deps-gcc-arm64:
cli-deps-update: cli-deps-update:
go get -u go get -u
go install honnef.co/go/tools/cmd/staticcheck@latest go install honnef.co/go/tools/cmd/staticcheck@latest
go install golang.org/x/lint/golint@latest
go install github.com/goreleaser/goreleaser@latest
cli-build-results:
cat dist/config.yaml
[ -f dist/artifacts.json ] && cat dist/artifacts.json | jq . || true
[ -f dist/metadata.json ] && cat dist/metadata.json | jq . || true
[ -f dist/checksums.txt ] && cat dist/checksums.txt || true
find dist -maxdepth 2 -type f \
\( -name '*.deb' -or -name '*.rpm' -or -name '*.zip' -or -name '*.tar.gz' -or -name 'ntfy' \) \
-and -not -path 'dist/goreleaserdocker*' \
-exec sha256sum {} \;
# Test/check targets # Test/check targets
@ -238,13 +264,13 @@ staticcheck: .PHONY
# Releasing targets # Releasing targets
release: clean update cli-deps release-check-tags docs web check release: clean update cli-deps release-checks docs web check
goreleaser release --rm-dist --debug goreleaser release --rm-dist
release-snapshot: clean update cli-deps docs web check release-snapshot: clean update cli-deps docs web check
goreleaser release --snapshot --skip-publish --rm-dist --debug goreleaser release --snapshot --skip-publish --rm-dist
release-check-tags: release-checks:
$(eval LATEST_TAG := $(shell git describe --abbrev=0 --tags | cut -c2-)) $(eval LATEST_TAG := $(shell git describe --abbrev=0 --tags | cut -c2-))
if ! grep -q $(LATEST_TAG) docs/install.md; then\ if ! grep -q $(LATEST_TAG) docs/install.md; then\
echo "ERROR: Must update docs/install.md with latest tag first.";\ echo "ERROR: Must update docs/install.md with latest tag first.";\
@ -254,6 +280,10 @@ release-check-tags:
echo "ERROR: Must update docs/releases.md with latest tag first.";\ echo "ERROR: Must update docs/releases.md with latest tag first.";\
exit 1;\ exit 1;\
fi fi
if [ -n "$(shell git status -s)" ]; then\
echo "ERROR: Git repository is in an unclean state.";\
exit 1;\
fi
# Installing targets # Installing targets

View File

@ -4,13 +4,39 @@ and the [ntfy Android app](https://github.com/binwiederhier/ntfy-android/release
<!-- <!--
## ntfy iOS app v1.1 (UNRELEASED) ## ntfy Android app v1.14.0 (UNRELEASED)
**Additional translations:**
* Italian (thanks to [@Genio2003](https://hosted.weblate.org/user/Genio2003/))
## ntfy server v1.25.0 (UNRELEASED)
**Bugs**:
* Respect Firebase "quota exceeded" response for topics, block Firebase publishing for user for 10min ([#289](https://github.com/binwiederhier/ntfy/issues/289))
**Maintenance:**
* Upgrade Firebase Admin SDK to 4.x ([#274](https://github.com/binwiederhier/ntfy/issues/274))
* CI: Build from pipeline instead of locally ([#36](https://github.com/binwiederhier/ntfy/issues/36))
**Documentation**:
* [Examples](examples.md) for [Home Assistant](https://www.home-assistant.io/) ([#282](https://github.com/binwiederhier/ntfy/pull/282), thanks to [@poblabs](https://github.com/poblabs))
-->
## ntfy iOS app v1.1
Released May 31, 2022
In this release of the iOS app, we add message priorities (mapped to iOS interruption levels), tags and emojis, In this release of the iOS app, we add message priorities (mapped to iOS interruption levels), tags and emojis,
action buttons to open websites or perform HTTP requests (in the notification and the detail view), a custom click action buttons to open websites or perform HTTP requests (in the notification and the detail view), a custom click
action when the notification is tapped, and various other fixes. action when the notification is tapped, and various other fixes.
It also adds support for self-hosted servers (albeit not supporting auth yet). The selfhosted server needs to be It also adds support for self-hosted servers (albeit not supporting auth yet). The self-hosted server needs to be
configured to forward poll requests to upstream ntfy.sh for push notifications to work (see [iOS push notifications](https://ntfy.sh/docs/config/#ios-instant-notifications) configured to forward poll requests to upstream ntfy.sh for push notifications to work (see [iOS push notifications](https://ntfy.sh/docs/config/#ios-instant-notifications)
for details). for details).
@ -29,26 +55,6 @@ for details).
* iOS UI not always updating properly ([#267](https://github.com/binwiederhier/ntfy/issues/267)) * iOS UI not always updating properly ([#267](https://github.com/binwiederhier/ntfy/issues/267))
## ntfy Android app v1.14.0 (UNRELEASED)
**Additional translations:**
* Italian (thanks to [@Genio2003](https://hosted.weblate.org/user/Genio2003/))
## ntfy server v1.25.0 (UNRELEASED)
**Maintenance:**
* Upgrade Firebase Admin SDK to 4.x ([#274](https://github.com/binwiederhier/ntfy/issues/274))
**Documentation**:
* [Examples](examples.md) for [Home Assistant](https://www.home-assistant.io/) ([#282](https://github.com/binwiederhier/ntfy/pull/282), thanks to [@poblabs](https://github.com/poblabs))
-->
## ntfy server v1.24.0 ## ntfy server v1.24.0
Released May 28, 2022 Released May 28, 2022

View File

@ -6,15 +6,16 @@ import (
// Defines default config settings (excluding limits, see below) // Defines default config settings (excluding limits, see below)
const ( const (
DefaultListenHTTP = ":80" DefaultListenHTTP = ":80"
DefaultCacheDuration = 12 * time.Hour DefaultCacheDuration = 12 * time.Hour
DefaultKeepaliveInterval = 45 * time.Second // Not too frequently to save battery (Android read timeout used to be 77s!) DefaultKeepaliveInterval = 45 * time.Second // Not too frequently to save battery (Android read timeout used to be 77s!)
DefaultManagerInterval = time.Minute DefaultManagerInterval = time.Minute
DefaultAtSenderInterval = 10 * time.Second DefaultDelayedSenderInterval = 10 * time.Second
DefaultMinDelay = 10 * time.Second DefaultMinDelay = 10 * time.Second
DefaultMaxDelay = 3 * 24 * time.Hour DefaultMaxDelay = 3 * 24 * time.Hour
DefaultFirebaseKeepaliveInterval = 3 * time.Hour // ~control topic (Android), not too frequently to save battery DefaultFirebaseKeepaliveInterval = 3 * time.Hour // ~control topic (Android), not too frequently to save battery
DefaultFirebasePollInterval = 20 * time.Minute // ~poll topic (iOS), max. 2-3 times per hour (see docs) DefaultFirebasePollInterval = 20 * time.Minute // ~poll topic (iOS), max. 2-3 times per hour (see docs)
DefaultFirebaseQuotaExceededPenaltyDuration = 10 * time.Minute // Time that over-users are locked out of Firebase if it returns "quota exceeded"
) )
// Defines all global and per-visitor limits // Defines all global and per-visitor limits
@ -66,9 +67,10 @@ type Config struct {
KeepaliveInterval time.Duration KeepaliveInterval time.Duration
ManagerInterval time.Duration ManagerInterval time.Duration
WebRootIsApp bool WebRootIsApp bool
AtSenderInterval time.Duration DelayedSenderInterval time.Duration
FirebaseKeepaliveInterval time.Duration FirebaseKeepaliveInterval time.Duration
FirebasePollInterval time.Duration FirebasePollInterval time.Duration
FirebaseQuotaExceededPenaltyDuration time.Duration
UpstreamBaseURL string UpstreamBaseURL string
SMTPSenderAddr string SMTPSenderAddr string
SMTPSenderUser string SMTPSenderUser string
@ -118,9 +120,10 @@ func NewConfig() *Config {
MessageLimit: DefaultMessageLengthLimit, MessageLimit: DefaultMessageLengthLimit,
MinDelay: DefaultMinDelay, MinDelay: DefaultMinDelay,
MaxDelay: DefaultMaxDelay, MaxDelay: DefaultMaxDelay,
AtSenderInterval: DefaultAtSenderInterval, DelayedSenderInterval: DefaultDelayedSenderInterval,
FirebaseKeepaliveInterval: DefaultFirebaseKeepaliveInterval, FirebaseKeepaliveInterval: DefaultFirebaseKeepaliveInterval,
FirebasePollInterval: DefaultFirebasePollInterval, FirebasePollInterval: DefaultFirebasePollInterval,
FirebaseQuotaExceededPenaltyDuration: DefaultFirebaseQuotaExceededPenaltyDuration,
TotalTopicLimit: DefaultTotalTopicLimit, TotalTopicLimit: DefaultTotalTopicLimit,
VisitorSubscriptionLimit: DefaultVisitorSubscriptionLimit, VisitorSubscriptionLimit: DefaultVisitorSubscriptionLimit,
VisitorAttachmentTotalSizeLimit: DefaultVisitorAttachmentTotalSizeLimit, VisitorAttachmentTotalSizeLimit: DefaultVisitorAttachmentTotalSizeLimit,

View File

@ -36,7 +36,7 @@ const (
attachment_size INT NOT NULL, attachment_size INT NOT NULL,
attachment_expires INT NOT NULL, attachment_expires INT NOT NULL,
attachment_url TEXT NOT NULL, attachment_url TEXT NOT NULL,
attachment_owner TEXT NOT NULL, sender TEXT NOT NULL,
encoding TEXT NOT NULL, encoding TEXT NOT NULL,
published INT NOT NULL published INT NOT NULL
); );
@ -45,37 +45,37 @@ const (
COMMIT; COMMIT;
` `
insertMessageQuery = ` insertMessageQuery = `
INSERT INTO messages (mid, time, topic, message, title, priority, tags, click, actions, attachment_name, attachment_type, attachment_size, attachment_expires, attachment_url, attachment_owner, encoding, published) INSERT INTO messages (mid, time, topic, message, title, priority, tags, click, actions, attachment_name, attachment_type, attachment_size, attachment_expires, attachment_url, sender, encoding, published)
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
` `
pruneMessagesQuery = `DELETE FROM messages WHERE time < ? AND published = 1` pruneMessagesQuery = `DELETE FROM messages WHERE time < ? AND published = 1`
selectRowIDFromMessageID = `SELECT id FROM messages WHERE topic = ? AND mid = ?` selectRowIDFromMessageID = `SELECT id FROM messages WHERE topic = ? AND mid = ?`
selectMessagesSinceTimeQuery = ` selectMessagesSinceTimeQuery = `
SELECT mid, time, topic, message, title, priority, tags, click, actions, attachment_name, attachment_type, attachment_size, attachment_expires, attachment_url, attachment_owner, encoding SELECT mid, time, topic, message, title, priority, tags, click, actions, attachment_name, attachment_type, attachment_size, attachment_expires, attachment_url, sender, encoding
FROM messages FROM messages
WHERE topic = ? AND time >= ? AND published = 1 WHERE topic = ? AND time >= ? AND published = 1
ORDER BY time, id ORDER BY time, id
` `
selectMessagesSinceTimeIncludeScheduledQuery = ` selectMessagesSinceTimeIncludeScheduledQuery = `
SELECT mid, time, topic, message, title, priority, tags, click, actions, attachment_name, attachment_type, attachment_size, attachment_expires, attachment_url, attachment_owner, encoding SELECT mid, time, topic, message, title, priority, tags, click, actions, attachment_name, attachment_type, attachment_size, attachment_expires, attachment_url, sender, encoding
FROM messages FROM messages
WHERE topic = ? AND time >= ? WHERE topic = ? AND time >= ?
ORDER BY time, id ORDER BY time, id
` `
selectMessagesSinceIDQuery = ` selectMessagesSinceIDQuery = `
SELECT mid, time, topic, message, title, priority, tags, click, actions, attachment_name, attachment_type, attachment_size, attachment_expires, attachment_url, attachment_owner, encoding SELECT mid, time, topic, message, title, priority, tags, click, actions, attachment_name, attachment_type, attachment_size, attachment_expires, attachment_url, sender, encoding
FROM messages FROM messages
WHERE topic = ? AND id > ? AND published = 1 WHERE topic = ? AND id > ? AND published = 1
ORDER BY time, id ORDER BY time, id
` `
selectMessagesSinceIDIncludeScheduledQuery = ` selectMessagesSinceIDIncludeScheduledQuery = `
SELECT mid, time, topic, message, title, priority, tags, click, actions, attachment_name, attachment_type, attachment_size, attachment_expires, attachment_url, attachment_owner, encoding SELECT mid, time, topic, message, title, priority, tags, click, actions, attachment_name, attachment_type, attachment_size, attachment_expires, attachment_url, sender, encoding
FROM messages FROM messages
WHERE topic = ? AND (id > ? OR published = 0) WHERE topic = ? AND (id > ? OR published = 0)
ORDER BY time, id ORDER BY time, id
` `
selectMessagesDueQuery = ` selectMessagesDueQuery = `
SELECT mid, time, topic, message, title, priority, tags, click, actions, attachment_name, attachment_type, attachment_size, attachment_expires, attachment_url, attachment_owner, encoding SELECT mid, time, topic, message, title, priority, tags, click, actions, attachment_name, attachment_type, attachment_size, attachment_expires, attachment_url, sender, encoding
FROM messages FROM messages
WHERE time <= ? AND published = 0 WHERE time <= ? AND published = 0
ORDER BY time, id ORDER BY time, id
@ -84,13 +84,13 @@ const (
selectMessagesCountQuery = `SELECT COUNT(*) FROM messages` selectMessagesCountQuery = `SELECT COUNT(*) FROM messages`
selectMessageCountForTopicQuery = `SELECT COUNT(*) FROM messages WHERE topic = ?` selectMessageCountForTopicQuery = `SELECT COUNT(*) FROM messages WHERE topic = ?`
selectTopicsQuery = `SELECT topic FROM messages GROUP BY topic` selectTopicsQuery = `SELECT topic FROM messages GROUP BY topic`
selectAttachmentsSizeQuery = `SELECT IFNULL(SUM(attachment_size), 0) FROM messages WHERE attachment_owner = ? AND attachment_expires >= ?` selectAttachmentsSizeQuery = `SELECT IFNULL(SUM(attachment_size), 0) FROM messages WHERE sender = ? AND attachment_expires >= ?`
selectAttachmentsExpiredQuery = `SELECT mid FROM messages WHERE attachment_expires > 0 AND attachment_expires < ?` selectAttachmentsExpiredQuery = `SELECT mid FROM messages WHERE attachment_expires > 0 AND attachment_expires < ?`
) )
// Schema management queries // Schema management queries
const ( const (
currentSchemaVersion = 6 currentSchemaVersion = 7
createSchemaVersionTableQuery = ` createSchemaVersionTableQuery = `
CREATE TABLE IF NOT EXISTS schemaVersion ( CREATE TABLE IF NOT EXISTS schemaVersion (
id INT PRIMARY KEY, id INT PRIMARY KEY,
@ -173,6 +173,11 @@ const (
migrate5To6AlterMessagesTableQuery = ` migrate5To6AlterMessagesTableQuery = `
ALTER TABLE messages ADD COLUMN actions TEXT NOT NULL DEFAULT(''); ALTER TABLE messages ADD COLUMN actions TEXT NOT NULL DEFAULT('');
` `
// 6 -> 7
migrate6To7AlterMessagesTableQuery = `
ALTER TABLE messages RENAME COLUMN attachment_owner TO sender;
`
) )
type messageCache struct { type messageCache struct {
@ -225,7 +230,7 @@ func (c *messageCache) AddMessage(m *message) error {
} }
published := m.Time <= time.Now().Unix() published := m.Time <= time.Now().Unix()
tags := strings.Join(m.Tags, ",") tags := strings.Join(m.Tags, ",")
var attachmentName, attachmentType, attachmentURL, attachmentOwner string var attachmentName, attachmentType, attachmentURL string
var attachmentSize, attachmentExpires int64 var attachmentSize, attachmentExpires int64
if m.Attachment != nil { if m.Attachment != nil {
attachmentName = m.Attachment.Name attachmentName = m.Attachment.Name
@ -233,7 +238,6 @@ func (c *messageCache) AddMessage(m *message) error {
attachmentSize = m.Attachment.Size attachmentSize = m.Attachment.Size
attachmentExpires = m.Attachment.Expires attachmentExpires = m.Attachment.Expires
attachmentURL = m.Attachment.URL attachmentURL = m.Attachment.URL
attachmentOwner = m.Attachment.Owner
} }
var actionsStr string var actionsStr string
if len(m.Actions) > 0 { if len(m.Actions) > 0 {
@ -259,7 +263,7 @@ func (c *messageCache) AddMessage(m *message) error {
attachmentSize, attachmentSize,
attachmentExpires, attachmentExpires,
attachmentURL, attachmentURL,
attachmentOwner, m.Sender,
m.Encoding, m.Encoding,
published, published,
) )
@ -371,8 +375,8 @@ func (c *messageCache) Prune(olderThan time.Time) error {
return err return err
} }
func (c *messageCache) AttachmentBytesUsed(owner string) (int64, error) { func (c *messageCache) AttachmentBytesUsed(sender string) (int64, error) {
rows, err := c.db.Query(selectAttachmentsSizeQuery, owner, time.Now().Unix()) rows, err := c.db.Query(selectAttachmentsSizeQuery, sender, time.Now().Unix())
if err != nil { if err != nil {
return 0, err return 0, err
} }
@ -415,7 +419,7 @@ func readMessages(rows *sql.Rows) ([]*message, error) {
for rows.Next() { for rows.Next() {
var timestamp, attachmentSize, attachmentExpires int64 var timestamp, attachmentSize, attachmentExpires int64
var priority int var priority int
var id, topic, msg, title, tagsStr, click, actionsStr, attachmentName, attachmentType, attachmentURL, attachmentOwner, encoding string var id, topic, msg, title, tagsStr, click, actionsStr, attachmentName, attachmentType, attachmentURL, sender, encoding string
err := rows.Scan( err := rows.Scan(
&id, &id,
&timestamp, &timestamp,
@ -431,7 +435,7 @@ func readMessages(rows *sql.Rows) ([]*message, error) {
&attachmentSize, &attachmentSize,
&attachmentExpires, &attachmentExpires,
&attachmentURL, &attachmentURL,
&attachmentOwner, &sender,
&encoding, &encoding,
) )
if err != nil { if err != nil {
@ -455,7 +459,6 @@ func readMessages(rows *sql.Rows) ([]*message, error) {
Size: attachmentSize, Size: attachmentSize,
Expires: attachmentExpires, Expires: attachmentExpires,
URL: attachmentURL, URL: attachmentURL,
Owner: attachmentOwner,
} }
} }
messages = append(messages, &message{ messages = append(messages, &message{
@ -470,6 +473,7 @@ func readMessages(rows *sql.Rows) ([]*message, error) {
Click: click, Click: click,
Actions: actions, Actions: actions,
Attachment: att, Attachment: att,
Sender: sender,
Encoding: encoding, Encoding: encoding,
}) })
} }
@ -516,6 +520,8 @@ func setupCacheDB(db *sql.DB) error {
return migrateFrom4(db) return migrateFrom4(db)
} else if schemaVersion == 5 { } else if schemaVersion == 5 {
return migrateFrom5(db) return migrateFrom5(db)
} else if schemaVersion == 6 {
return migrateFrom6(db)
} }
return fmt.Errorf("unexpected schema version found: %d", schemaVersion) return fmt.Errorf("unexpected schema version found: %d", schemaVersion)
} }
@ -599,5 +605,16 @@ func migrateFrom5(db *sql.DB) error {
if _, err := db.Exec(updateSchemaVersion, 6); err != nil { if _, err := db.Exec(updateSchemaVersion, 6); err != nil {
return err return err
} }
return migrateFrom6(db)
}
func migrateFrom6(db *sql.DB) error {
log.Print("Migrating cache database schema: from 6 to 7")
if _, err := db.Exec(migrate6To7AlterMessagesTableQuery); err != nil {
return err
}
if _, err := db.Exec(updateSchemaVersion, 7); err != nil {
return err
}
return nil // Update this when a new version is added return nil // Update this when a new version is added
} }

View File

@ -281,39 +281,39 @@ func testCacheAttachments(t *testing.T, c *messageCache) {
expires1 := time.Now().Add(-4 * time.Hour).Unix() expires1 := time.Now().Add(-4 * time.Hour).Unix()
m := newDefaultMessage("mytopic", "flower for you") m := newDefaultMessage("mytopic", "flower for you")
m.ID = "m1" m.ID = "m1"
m.Sender = "1.2.3.4"
m.Attachment = &attachment{ m.Attachment = &attachment{
Name: "flower.jpg", Name: "flower.jpg",
Type: "image/jpeg", Type: "image/jpeg",
Size: 5000, Size: 5000,
Expires: expires1, Expires: expires1,
URL: "https://ntfy.sh/file/AbDeFgJhal.jpg", URL: "https://ntfy.sh/file/AbDeFgJhal.jpg",
Owner: "1.2.3.4",
} }
require.Nil(t, c.AddMessage(m)) require.Nil(t, c.AddMessage(m))
expires2 := time.Now().Add(2 * time.Hour).Unix() // Future expires2 := time.Now().Add(2 * time.Hour).Unix() // Future
m = newDefaultMessage("mytopic", "sending you a car") m = newDefaultMessage("mytopic", "sending you a car")
m.ID = "m2" m.ID = "m2"
m.Sender = "1.2.3.4"
m.Attachment = &attachment{ m.Attachment = &attachment{
Name: "car.jpg", Name: "car.jpg",
Type: "image/jpeg", Type: "image/jpeg",
Size: 10000, Size: 10000,
Expires: expires2, Expires: expires2,
URL: "https://ntfy.sh/file/aCaRURL.jpg", URL: "https://ntfy.sh/file/aCaRURL.jpg",
Owner: "1.2.3.4",
} }
require.Nil(t, c.AddMessage(m)) require.Nil(t, c.AddMessage(m))
expires3 := time.Now().Add(1 * time.Hour).Unix() // Future expires3 := time.Now().Add(1 * time.Hour).Unix() // Future
m = newDefaultMessage("another-topic", "sending you another car") m = newDefaultMessage("another-topic", "sending you another car")
m.ID = "m3" m.ID = "m3"
m.Sender = "1.2.3.4"
m.Attachment = &attachment{ m.Attachment = &attachment{
Name: "another-car.jpg", Name: "another-car.jpg",
Type: "image/jpeg", Type: "image/jpeg",
Size: 20000, Size: 20000,
Expires: expires3, Expires: expires3,
URL: "https://ntfy.sh/file/zakaDHFW.jpg", URL: "https://ntfy.sh/file/zakaDHFW.jpg",
Owner: "1.2.3.4",
} }
require.Nil(t, c.AddMessage(m)) require.Nil(t, c.AddMessage(m))
@ -327,7 +327,7 @@ func testCacheAttachments(t *testing.T, c *messageCache) {
require.Equal(t, int64(5000), messages[0].Attachment.Size) require.Equal(t, int64(5000), messages[0].Attachment.Size)
require.Equal(t, expires1, messages[0].Attachment.Expires) require.Equal(t, expires1, messages[0].Attachment.Expires)
require.Equal(t, "https://ntfy.sh/file/AbDeFgJhal.jpg", messages[0].Attachment.URL) require.Equal(t, "https://ntfy.sh/file/AbDeFgJhal.jpg", messages[0].Attachment.URL)
require.Equal(t, "1.2.3.4", messages[0].Attachment.Owner) require.Equal(t, "1.2.3.4", messages[0].Sender)
require.Equal(t, "sending you a car", messages[1].Message) require.Equal(t, "sending you a car", messages[1].Message)
require.Equal(t, "car.jpg", messages[1].Attachment.Name) require.Equal(t, "car.jpg", messages[1].Attachment.Name)
@ -335,7 +335,7 @@ func testCacheAttachments(t *testing.T, c *messageCache) {
require.Equal(t, int64(10000), messages[1].Attachment.Size) require.Equal(t, int64(10000), messages[1].Attachment.Size)
require.Equal(t, expires2, messages[1].Attachment.Expires) require.Equal(t, expires2, messages[1].Attachment.Expires)
require.Equal(t, "https://ntfy.sh/file/aCaRURL.jpg", messages[1].Attachment.URL) require.Equal(t, "https://ntfy.sh/file/aCaRURL.jpg", messages[1].Attachment.URL)
require.Equal(t, "1.2.3.4", messages[1].Attachment.Owner) require.Equal(t, "1.2.3.4", messages[1].Sender)
size, err := c.AttachmentBytesUsed("1.2.3.4") size, err := c.AttachmentBytesUsed("1.2.3.4")
require.Nil(t, err) require.Nil(t, err)

View File

@ -7,13 +7,11 @@ import (
"embed" "embed"
"encoding/base64" "encoding/base64"
"encoding/json" "encoding/json"
"errors"
"fmt" "fmt"
"heckel.io/ntfy/log" "heckel.io/ntfy/log"
"io" "io"
"net" "net"
"net/http" "net/http"
"net/http/httptest"
"net/url" "net/url"
"os" "os"
"path" "path"
@ -34,22 +32,22 @@ import (
// Server is the main server, providing the UI and API for ntfy // Server is the main server, providing the UI and API for ntfy
type Server struct { type Server struct {
config *Config config *Config
httpServer *http.Server httpServer *http.Server
httpsServer *http.Server httpsServer *http.Server
unixListener net.Listener unixListener net.Listener
smtpServer *smtp.Server smtpServer *smtp.Server
smtpBackend *smtpBackend smtpBackend *smtpBackend
topics map[string]*topic topics map[string]*topic
visitors map[string]*visitor visitors map[string]*visitor
firebase subscriber firebaseClient *firebaseClient
mailer mailer mailer mailer
messages int64 messages int64
auth auth.Auther auth auth.Auther
messageCache *messageCache messageCache *messageCache
fileCache *fileCache fileCache *fileCache
closeChan chan bool closeChan chan bool
mu sync.Mutex mu sync.Mutex
} }
// handleFunc extends the normal http.HandlerFunc to be able to easily return errors // handleFunc extends the normal http.HandlerFunc to be able to easily return errors
@ -136,23 +134,23 @@ func New(conf *Config) (*Server, error) {
return nil, err return nil, err
} }
} }
var firebaseSubscriber subscriber var firebaseClient *firebaseClient
if conf.FirebaseKeyFile != "" { if conf.FirebaseKeyFile != "" {
var err error sender, err := newFirebaseSender(conf.FirebaseKeyFile)
firebaseSubscriber, err = createFirebaseSubscriber(conf.FirebaseKeyFile, auther)
if err != nil { if err != nil {
return nil, err return nil, err
} }
firebaseClient = newFirebaseClient(sender, auther)
} }
return &Server{ return &Server{
config: conf, config: conf,
messageCache: messageCache, messageCache: messageCache,
fileCache: fileCache, fileCache: fileCache,
firebase: firebaseSubscriber, firebaseClient: firebaseClient,
mailer: mailer, mailer: mailer,
topics: topics, topics: topics,
auth: auther, auth: auther,
visitors: make(map[string]*visitor), visitors: make(map[string]*visitor),
}, nil }, nil
} }
@ -221,7 +219,7 @@ func (s *Server) Run() error {
} }
s.mu.Unlock() s.mu.Unlock()
go s.runManager() go s.runManager()
go s.runDelaySender() go s.runDelayedSender()
go s.runFirebaseKeepaliver() go s.runFirebaseKeepaliver()
return <-errChan return <-errChan
@ -439,17 +437,17 @@ func (s *Server) handlePublish(w http.ResponseWriter, r *http.Request, v *visito
log.Debug("[%s] %s %s: ev=%s, body=%d bytes, delayed=%t, fb=%t, cache=%t, up=%t, email=%s", log.Debug("[%s] %s %s: ev=%s, body=%d bytes, delayed=%t, fb=%t, cache=%t, up=%t, email=%s",
v.ip, r.Method, r.URL.Path, m.Event, len(body.PeekedBytes), delayed, firebase, cache, unifiedpush, email) v.ip, r.Method, r.URL.Path, m.Event, len(body.PeekedBytes), delayed, firebase, cache, unifiedpush, email)
if !delayed { if !delayed {
if err := t.Publish(m); err != nil { if err := t.Publish(v, m); err != nil {
return err return err
} }
} }
if s.firebase != nil && firebase && !delayed { if s.firebaseClient != nil && firebase && !delayed {
go s.sendToFirebase(v, m) go s.sendToFirebase(v, m)
} }
if s.mailer != nil && email != "" && !delayed { if s.mailer != nil && email != "" && !delayed {
go s.sendEmail(v, m, email) go s.sendEmail(v, m, email)
} }
if s.config.UpstreamBaseURL != "" { if s.config.UpstreamBaseURL != "" && !delayed {
go s.forwardPollRequest(v, m) go s.forwardPollRequest(v, m)
} }
if cache { if cache {
@ -469,7 +467,7 @@ func (s *Server) handlePublish(w http.ResponseWriter, r *http.Request, v *visito
} }
func (s *Server) sendToFirebase(v *visitor, m *message) { func (s *Server) sendToFirebase(v *visitor, m *message) {
if err := s.firebase(m); err != nil { if err := s.firebaseClient.Send(v, m); err != nil {
log.Warn("[%s] FB - Unable to publish to Firebase: %v", v.ip, err.Error()) log.Warn("[%s] FB - Unable to publish to Firebase: %v", v.ip, err.Error())
} }
} }
@ -490,7 +488,10 @@ func (s *Server) forwardPollRequest(v *visitor, m *message) {
return return
} }
req.Header.Set("X-Poll-ID", m.ID) req.Header.Set("X-Poll-ID", m.ID)
response, err := http.DefaultClient.Do(req) var httpClient = &http.Client{
Timeout: time.Second * 10,
}
response, err := httpClient.Do(req)
if err != nil { if err != nil {
log.Warn("[%s] FWD - Unable to forward poll request: %v", v.ip, err.Error()) log.Warn("[%s] FWD - Unable to forward poll request: %v", v.ip, err.Error())
return return
@ -572,6 +573,7 @@ func (s *Server) parsePublishParams(r *http.Request, v *visitor, m *message) (ca
return false, false, "", false, errHTTPBadRequestDelayTooLarge return false, false, "", false, errHTTPBadRequestDelayTooLarge
} }
m.Time = delay.Unix() m.Time = delay.Unix()
m.Sender = v.ip // Important for rate limiting
} }
actionsStr := readParam(r, "x-actions", "actions", "action") actionsStr := readParam(r, "x-actions", "actions", "action")
if actionsStr != "" { if actionsStr != "" {
@ -667,7 +669,7 @@ func (s *Server) handleBodyAsAttachment(r *http.Request, v *visitor, m *message,
m.Attachment = &attachment{} m.Attachment = &attachment{}
} }
var ext string var ext string
m.Attachment.Owner = v.ip // Important for attachment rate limiting m.Sender = v.ip // Important for attachment rate limiting
m.Attachment.Expires = time.Now().Add(s.config.AttachmentExpiryDuration).Unix() m.Attachment.Expires = time.Now().Add(s.config.AttachmentExpiryDuration).Unix()
m.Attachment.Type, ext = util.DetectContentType(body.PeekedBytes, m.Attachment.Name) m.Attachment.Type, ext = util.DetectContentType(body.PeekedBytes, m.Attachment.Name)
m.Attachment.URL = fmt.Sprintf("%s/file/%s%s", s.config.BaseURL, m.ID, ext) m.Attachment.URL = fmt.Sprintf("%s/file/%s%s", s.config.BaseURL, m.ID, ext)
@ -735,7 +737,7 @@ func (s *Server) handleSubscribeHTTP(w http.ResponseWriter, r *http.Request, v *
return err return err
} }
var wlock sync.Mutex var wlock sync.Mutex
sub := func(msg *message) error { sub := func(v *visitor, msg *message) error {
if !filters.Pass(msg) { if !filters.Pass(msg) {
return nil return nil
} }
@ -756,7 +758,7 @@ func (s *Server) handleSubscribeHTTP(w http.ResponseWriter, r *http.Request, v *
w.Header().Set("Access-Control-Allow-Origin", "*") // CORS, allow cross-origin requests w.Header().Set("Access-Control-Allow-Origin", "*") // CORS, allow cross-origin requests
w.Header().Set("Content-Type", contentType+"; charset=utf-8") // Android/Volley client needs charset! w.Header().Set("Content-Type", contentType+"; charset=utf-8") // Android/Volley client needs charset!
if poll { if poll {
return s.sendOldMessages(topics, since, scheduled, sub) return s.sendOldMessages(topics, since, scheduled, v, sub)
} }
subscriberIDs := make([]int, 0) subscriberIDs := make([]int, 0)
for _, t := range topics { for _, t := range topics {
@ -767,10 +769,10 @@ func (s *Server) handleSubscribeHTTP(w http.ResponseWriter, r *http.Request, v *
topics[i].Unsubscribe(subscriberID) // Order! topics[i].Unsubscribe(subscriberID) // Order!
} }
}() }()
if err := sub(newOpenMessage(topicsStr)); err != nil { // Send out open message if err := sub(v, newOpenMessage(topicsStr)); err != nil { // Send out open message
return err return err
} }
if err := s.sendOldMessages(topics, since, scheduled, sub); err != nil { if err := s.sendOldMessages(topics, since, scheduled, v, sub); err != nil {
return err return err
} }
for { for {
@ -779,7 +781,7 @@ func (s *Server) handleSubscribeHTTP(w http.ResponseWriter, r *http.Request, v *
return nil return nil
case <-time.After(s.config.KeepaliveInterval): case <-time.After(s.config.KeepaliveInterval):
v.Keepalive() v.Keepalive()
if err := sub(newKeepaliveMessage(topicsStr)); err != nil { // Send keepalive message if err := sub(v, newKeepaliveMessage(topicsStr)); err != nil { // Send keepalive message
return err return err
} }
} }
@ -853,7 +855,7 @@ func (s *Server) handleSubscribeWS(w http.ResponseWriter, r *http.Request, v *vi
} }
} }
}) })
sub := func(msg *message) error { sub := func(v *visitor, msg *message) error {
if !filters.Pass(msg) { if !filters.Pass(msg) {
return nil return nil
} }
@ -866,7 +868,7 @@ func (s *Server) handleSubscribeWS(w http.ResponseWriter, r *http.Request, v *vi
} }
w.Header().Set("Access-Control-Allow-Origin", "*") // CORS, allow cross-origin requests w.Header().Set("Access-Control-Allow-Origin", "*") // CORS, allow cross-origin requests
if poll { if poll {
return s.sendOldMessages(topics, since, scheduled, sub) return s.sendOldMessages(topics, since, scheduled, v, sub)
} }
subscriberIDs := make([]int, 0) subscriberIDs := make([]int, 0)
for _, t := range topics { for _, t := range topics {
@ -877,10 +879,10 @@ func (s *Server) handleSubscribeWS(w http.ResponseWriter, r *http.Request, v *vi
topics[i].Unsubscribe(subscriberID) // Order! topics[i].Unsubscribe(subscriberID) // Order!
} }
}() }()
if err := sub(newOpenMessage(topicsStr)); err != nil { // Send out open message if err := sub(v, newOpenMessage(topicsStr)); err != nil { // Send out open message
return err return err
} }
if err := s.sendOldMessages(topics, since, scheduled, sub); err != nil { if err := s.sendOldMessages(topics, since, scheduled, v, sub); err != nil {
return err return err
} }
err = g.Wait() err = g.Wait()
@ -904,7 +906,7 @@ func parseSubscribeParams(r *http.Request) (poll bool, since sinceMarker, schedu
return return
} }
func (s *Server) sendOldMessages(topics []*topic, since sinceMarker, scheduled bool, sub subscriber) error { func (s *Server) sendOldMessages(topics []*topic, since sinceMarker, scheduled bool, v *visitor, sub subscriber) error {
if since.IsNone() { if since.IsNone() {
return nil return nil
} }
@ -914,7 +916,7 @@ func (s *Server) sendOldMessages(topics []*topic, since sinceMarker, scheduled b
return err return err
} }
for _, m := range messages { for _, m := range messages {
if err := sub(m); err != nil { if err := sub(v, m); err != nil {
return err return err
} }
} }
@ -1061,23 +1063,7 @@ func (s *Server) updateStatsAndPrune() {
} }
func (s *Server) runSMTPServer() error { func (s *Server) runSMTPServer() error {
sub := func(m *message) error { s.smtpBackend = newMailBackend(s.config, s.handle)
url := fmt.Sprintf("%s/%s", s.config.BaseURL, m.Topic)
req, err := http.NewRequest("PUT", url, strings.NewReader(m.Message))
if err != nil {
return err
}
if m.Title != "" {
req.Header.Set("Title", m.Title)
}
rr := httptest.NewRecorder()
s.handle(rr, req)
if rr.Code != http.StatusOK {
return errors.New("error: " + rr.Body.String())
}
return nil
}
s.smtpBackend = newMailBackend(s.config, sub)
s.smtpServer = smtp.NewServer(s.smtpBackend) s.smtpServer = smtp.NewServer(s.smtpBackend)
s.smtpServer.Addr = s.config.SMTPServerListen s.smtpServer.Addr = s.config.SMTPServerListen
s.smtpServer.Domain = s.config.SMTPServerDomain s.smtpServer.Domain = s.config.SMTPServerDomain
@ -1100,10 +1086,10 @@ func (s *Server) runManager() {
} }
} }
func (s *Server) runDelaySender() { func (s *Server) runDelayedSender() {
for { for {
select { select {
case <-time.After(s.config.AtSenderInterval): case <-time.After(s.config.DelayedSenderInterval):
if err := s.sendDelayedMessages(); err != nil { if err := s.sendDelayedMessages(); err != nil {
log.Warn("error sending scheduled messages: %s", err.Error()) log.Warn("error sending scheduled messages: %s", err.Error())
} }
@ -1114,19 +1100,16 @@ func (s *Server) runDelaySender() {
} }
func (s *Server) runFirebaseKeepaliver() { func (s *Server) runFirebaseKeepaliver() {
if s.firebase == nil { if s.firebaseClient == nil {
return return
} }
v := newVisitor(s.config, s.messageCache, "0.0.0.0") // Background process, not a real visitor
for { for {
select { select {
case <-time.After(s.config.FirebaseKeepaliveInterval): case <-time.After(s.config.FirebaseKeepaliveInterval):
if err := s.firebase(newKeepaliveMessage(firebaseControlTopic)); err != nil { s.sendToFirebase(v, newKeepaliveMessage(firebaseControlTopic))
log.Info("error sending Firebase keepalive message to %s: %s", firebaseControlTopic, err.Error())
}
case <-time.After(s.config.FirebasePollInterval): case <-time.After(s.config.FirebasePollInterval):
if err := s.firebase(newKeepaliveMessage(firebasePollTopic)); err != nil { s.sendToFirebase(v, newKeepaliveMessage(firebasePollTopic))
log.Info("error sending Firebase keepalive message to %s: %s", firebasePollTopic, err.Error())
}
case <-s.closeChan: case <-s.closeChan:
return return
} }
@ -1134,27 +1117,39 @@ func (s *Server) runFirebaseKeepaliver() {
} }
func (s *Server) sendDelayedMessages() error { func (s *Server) sendDelayedMessages() error {
s.mu.Lock()
defer s.mu.Unlock()
messages, err := s.messageCache.MessagesDue() messages, err := s.messageCache.MessagesDue()
if err != nil { if err != nil {
return err return err
} }
for _, m := range messages { for _, m := range messages {
t, ok := s.topics[m.Topic] // If no subscribers, just mark message as published v := s.visitorFromIP(m.Sender)
if ok { if err := s.sendDelayedMessage(v, m); err != nil {
if err := t.Publish(m); err != nil { log.Warn("error sending delayed message: %s", err.Error())
log.Info("unable to publish message %s to topic %s: %v", m.ID, m.Topic, err.Error()) }
}
return nil
}
func (s *Server) sendDelayedMessage(v *visitor, m *message) error {
s.mu.Lock()
defer s.mu.Unlock()
t, ok := s.topics[m.Topic] // If no subscribers, just mark message as published
if ok {
go func() {
// We do not rate-limit messages here, since we've rate limited them in the PUT/POST handler
if err := t.Publish(v, m); err != nil {
log.Warn("unable to publish message %s to topic %s: %v", m.ID, m.Topic, err.Error())
} }
} }()
if s.firebase != nil { // Firebase subscribers may not show up in topics map }
if err := s.firebase(m); err != nil { if s.firebaseClient != nil { // Firebase subscribers may not show up in topics map
log.Info("unable to publish to Firebase: %v", err.Error()) go s.sendToFirebase(v, m)
} }
} if s.config.UpstreamBaseURL != "" {
if err := s.messageCache.MarkPublished(m); err != nil { go s.forwardPollRequest(v, m)
return err }
} if err := s.messageCache.MarkPublished(m); err != nil {
return err
} }
return nil return nil
} }
@ -1294,8 +1289,6 @@ func extractUserPass(r *http.Request) (username string, password string, ok bool
// visitor creates or retrieves a rate.Limiter for the given visitor. // visitor creates or retrieves a rate.Limiter for the given visitor.
// This function was taken from https://www.alexedwards.net/blog/how-to-rate-limit-http-requests (MIT). // This function was taken from https://www.alexedwards.net/blog/how-to-rate-limit-http-requests (MIT).
func (s *Server) visitor(r *http.Request) *visitor { func (s *Server) visitor(r *http.Request) *visitor {
s.mu.Lock()
defer s.mu.Unlock()
remoteAddr := r.RemoteAddr remoteAddr := r.RemoteAddr
ip, _, err := net.SplitHostPort(remoteAddr) ip, _, err := net.SplitHostPort(remoteAddr)
if err != nil { if err != nil {
@ -1304,6 +1297,12 @@ func (s *Server) visitor(r *http.Request) *visitor {
if s.config.BehindProxy && r.Header.Get("X-Forwarded-For") != "" { if s.config.BehindProxy && r.Header.Get("X-Forwarded-For") != "" {
ip = r.Header.Get("X-Forwarded-For") ip = r.Header.Get("X-Forwarded-For")
} }
return s.visitorFromIP(ip)
}
func (s *Server) visitorFromIP(ip string) *visitor {
s.mu.Lock()
defer s.mu.Unlock()
v, exists := s.visitors[ip] v, exists := s.visitors[ip]
if !exists { if !exists {
s.visitors[ip] = newVisitor(s.config, s.messageCache, ip) s.visitors[ip] = newVisitor(s.config, s.messageCache, ip)

View File

@ -3,7 +3,9 @@ package server
import ( import (
"context" "context"
"encoding/json" "encoding/json"
"errors"
"fmt" "fmt"
"log"
"strings" "strings"
firebase "firebase.google.com/go/v4" firebase "firebase.google.com/go/v4"
@ -17,25 +19,75 @@ const (
fcmApnsBodyMessageLimit = 100 fcmApnsBodyMessageLimit = 100
) )
func createFirebaseSubscriber(credentialsFile string, auther auth.Auther) (subscriber, error) { var (
errFirebaseQuotaExceeded = errors.New("quota exceeded for Firebase messages to topic")
)
// firebaseClient is a generic client that formats and sends messages to Firebase.
// The actual Firebase implementation is implemented in firebaseSenderImpl, to make it testable.
type firebaseClient struct {
sender firebaseSender
auther auth.Auther
}
func newFirebaseClient(sender firebaseSender, auther auth.Auther) *firebaseClient {
return &firebaseClient{
sender: sender,
auther: auther,
}
}
func (c *firebaseClient) Send(v *visitor, m *message) error {
if err := v.FirebaseAllowed(); err != nil {
return errFirebaseQuotaExceeded
}
fbm, err := toFirebaseMessage(m, c.auther)
if err != nil {
return err
}
err = c.sender.Send(fbm)
if err == errFirebaseQuotaExceeded {
log.Printf("[%s] FB quota exceeded for topic %s, temporarily denying FB access to visitor", v.ip, m.Topic)
v.FirebaseTemporarilyDeny()
}
return err
}
// firebaseSender is an interface that represents a client that can send to Firebase Cloud Messaging.
// In tests, this can be implemented with a mock.
type firebaseSender interface {
// Send sends a message to Firebase, or returns an error. It returns errFirebaseQuotaExceeded
// if a rate limit has reached.
Send(m *messaging.Message) error
}
// firebaseSenderImpl is a firebaseSender that actually talks to Firebase
type firebaseSenderImpl struct {
client *messaging.Client
}
func newFirebaseSender(credentialsFile string) (*firebaseSenderImpl, error) {
fb, err := firebase.NewApp(context.Background(), nil, option.WithCredentialsFile(credentialsFile)) fb, err := firebase.NewApp(context.Background(), nil, option.WithCredentialsFile(credentialsFile))
if err != nil { if err != nil {
return nil, err return nil, err
} }
msg, err := fb.Messaging(context.Background()) client, err := fb.Messaging(context.Background())
if err != nil { if err != nil {
return nil, err return nil, err
} }
return func(m *message) error { return &firebaseSenderImpl{
fbm, err := toFirebaseMessage(m, auther) client: client,
if err != nil {
return err
}
_, err = msg.Send(context.Background(), fbm)
return err
}, nil }, nil
} }
func (c *firebaseSenderImpl) Send(m *messaging.Message) error {
_, err := c.client.Send(context.Background(), m)
if err != nil && messaging.IsQuotaExceeded(err) {
return errFirebaseQuotaExceeded
}
return err
}
// toFirebaseMessage converts a message to a Firebase message. // toFirebaseMessage converts a message to a Firebase message.
// //
// Normal messages ("message"): // Normal messages ("message"):

View File

@ -26,6 +26,25 @@ func (t testAuther) Authorize(_ *auth.User, _ string, _ auth.Permission) error {
return errors.New("unauthorized") return errors.New("unauthorized")
} }
type testFirebaseSender struct {
allowed int
messages []*messaging.Message
}
func newTestFirebaseSender(allowed int) *testFirebaseSender {
return &testFirebaseSender{
allowed: allowed,
messages: make([]*messaging.Message, 0),
}
}
func (s *testFirebaseSender) Send(m *messaging.Message) error {
if len(s.messages)+1 > s.allowed {
return errFirebaseQuotaExceeded
}
s.messages = append(s.messages, m)
return nil
}
func TestToFirebaseMessage_Keepalive(t *testing.T) { func TestToFirebaseMessage_Keepalive(t *testing.T) {
m := newKeepaliveMessage("mytopic") m := newKeepaliveMessage("mytopic")
fbm, err := toFirebaseMessage(m, nil) fbm, err := toFirebaseMessage(m, nil)
@ -119,7 +138,6 @@ func TestToFirebaseMessage_Message_Normal_Allowed(t *testing.T) {
Size: 12345, Size: 12345,
Expires: 98765543, Expires: 98765543,
URL: "https://example.com/file.jpg", URL: "https://example.com/file.jpg",
Owner: "some-owner",
} }
fbm, err := toFirebaseMessage(m, &testAuther{Allow: true}) fbm, err := toFirebaseMessage(m, &testAuther{Allow: true})
require.Nil(t, err) require.Nil(t, err)
@ -286,3 +304,22 @@ func TestMaybeTruncateFCMMessage_NotTooLong(t *testing.T) {
require.Equal(t, len(serializedOrigFCMMessage), len(serializedNotTruncatedFCMMessage)) require.Equal(t, len(serializedOrigFCMMessage), len(serializedNotTruncatedFCMMessage))
require.Equal(t, "", notTruncatedFCMMessage.Data["truncated"]) require.Equal(t, "", notTruncatedFCMMessage.Data["truncated"])
} }
func TestToFirebaseSender_Abuse(t *testing.T) {
sender := &testFirebaseSender{allowed: 2}
client := newFirebaseClient(sender, &testAuther{})
visitor := newVisitor(newTestConfig(t), newMemTestCache(t), "1.2.3.4")
require.Nil(t, client.Send(visitor, &message{Topic: "mytopic"}))
require.Equal(t, 1, len(sender.messages))
require.Nil(t, client.Send(visitor, &message{Topic: "mytopic"}))
require.Equal(t, 2, len(sender.messages))
require.Equal(t, errFirebaseQuotaExceeded, client.Send(visitor, &message{Topic: "mytopic"}))
require.Equal(t, 2, len(sender.messages))
sender.messages = make([]*messaging.Message, 0) // Reset to test that time limit is working
require.Equal(t, errFirebaseQuotaExceeded, client.Send(visitor, &message{Topic: "mytopic"}))
require.Equal(t, 0, len(sender.messages))
}

View File

@ -9,7 +9,6 @@ import (
"math/rand" "math/rand"
"net/http" "net/http"
"net/http/httptest" "net/http/httptest"
"os"
"path/filepath" "path/filepath"
"strings" "strings"
"sync" "sync"
@ -55,6 +54,21 @@ func TestServer_PublishAndPoll(t *testing.T) {
require.Equal(t, "my second message", lines[1]) // \n -> " " require.Equal(t, "my second message", lines[1]) // \n -> " "
} }
func TestServer_PublishWithFirebase(t *testing.T) {
sender := newTestFirebaseSender(10)
s := newTestServer(t, newTestConfig(t))
s.firebaseClient = newFirebaseClient(sender, &testAuther{Allow: true})
response := request(t, s, "PUT", "/mytopic", "my first message", nil)
msg1 := toMessage(t, response.Body.String())
require.NotEmpty(t, msg1.ID)
require.Equal(t, "my first message", msg1.Message)
require.Equal(t, 1, len(sender.messages))
require.Equal(t, "my first message", sender.messages[0].Data["message"])
require.Equal(t, "my first message", sender.messages[0].APNS.Payload.Aps.Alert.Body)
require.Equal(t, "my first message", sender.messages[0].APNS.Payload.CustomData["message"])
}
func TestServer_SubscribeOpenAndKeepalive(t *testing.T) { func TestServer_SubscribeOpenAndKeepalive(t *testing.T) {
c := newTestConfig(t) c := newTestConfig(t)
c.KeepaliveInterval = time.Second c.KeepaliveInterval = time.Second
@ -264,7 +278,7 @@ func TestServer_PublishNoCache(t *testing.T) {
func TestServer_PublishAt(t *testing.T) { func TestServer_PublishAt(t *testing.T) {
c := newTestConfig(t) c := newTestConfig(t)
c.MinDelay = time.Second c.MinDelay = time.Second
c.AtSenderInterval = 100 * time.Millisecond c.DelayedSenderInterval = 100 * time.Millisecond
s := newTestServer(t, c) s := newTestServer(t, c)
response := request(t, s, "PUT", "/mytopic", "a message", map[string]string{ response := request(t, s, "PUT", "/mytopic", "a message", map[string]string{
@ -283,6 +297,13 @@ func TestServer_PublishAt(t *testing.T) {
messages = toMessages(t, response.Body.String()) messages = toMessages(t, response.Body.String())
require.Equal(t, 1, len(messages)) require.Equal(t, 1, len(messages))
require.Equal(t, "a message", messages[0].Message) require.Equal(t, "a message", messages[0].Message)
require.Equal(t, "", messages[0].Sender) // Never return the sender!
messages, err := s.messageCache.Messages("mytopic", sinceAllMessages, true)
require.Nil(t, err)
require.Equal(t, 1, len(messages))
require.Equal(t, "a message", messages[0].Message)
require.Equal(t, "9.9.9.9", messages[0].Sender) // It's stored in the DB though!
} }
func TestServer_PublishAtWithCacheError(t *testing.T) { func TestServer_PublishAtWithCacheError(t *testing.T) {
@ -454,26 +475,6 @@ func TestServer_PublishMessageInHeaderWithNewlines(t *testing.T) {
require.Equal(t, "Line 1\nLine 2", msg.Message) // \\n -> \n ! require.Equal(t, "Line 1\nLine 2", msg.Message) // \\n -> \n !
} }
func TestServer_PublishFirebase(t *testing.T) {
// This is unfortunately not much of a test, since it merely fires the messages towards Firebase,
// but cannot re-read them. There is no way from Go to read the messages back, or even get an error back.
// I tried everything. I already had written the test, and it increases the code coverage, so I'll leave it ... :shrug: ...
c := newTestConfig(t)
c.FirebaseKeyFile = firebaseServiceAccountFile(t) // May skip the test!
s := newTestServer(t, c)
// Normal message
response := request(t, s, "PUT", "/mytopic", "This is a message for firebase", nil)
msg := toMessage(t, response.Body.String())
require.NotEmpty(t, msg.ID)
// Keepalive message
require.Nil(t, s.firebase(newKeepaliveMessage(firebaseControlTopic)))
time.Sleep(500 * time.Millisecond) // Time for sends
}
func TestServer_PublishInvalidTopic(t *testing.T) { func TestServer_PublishInvalidTopic(t *testing.T) {
s := newTestServer(t, newTestConfig(t)) s := newTestServer(t, newTestConfig(t))
s.mailer = &testMailer{} s.mailer = &testMailer{}
@ -1018,7 +1019,7 @@ func TestServer_PublishAttachment(t *testing.T) {
require.Equal(t, int64(5000), msg.Attachment.Size) require.Equal(t, int64(5000), msg.Attachment.Size)
require.GreaterOrEqual(t, msg.Attachment.Expires, time.Now().Add(179*time.Minute).Unix()) // Almost 3 hours require.GreaterOrEqual(t, msg.Attachment.Expires, time.Now().Add(179*time.Minute).Unix()) // Almost 3 hours
require.Contains(t, msg.Attachment.URL, "http://127.0.0.1:12345/file/") require.Contains(t, msg.Attachment.URL, "http://127.0.0.1:12345/file/")
require.Equal(t, "", msg.Attachment.Owner) // Should never be returned require.Equal(t, "", msg.Sender) // Should never be returned
require.FileExists(t, filepath.Join(s.config.AttachmentCacheDir, msg.ID)) require.FileExists(t, filepath.Join(s.config.AttachmentCacheDir, msg.ID))
path := strings.TrimPrefix(msg.Attachment.URL, "http://127.0.0.1:12345") path := strings.TrimPrefix(msg.Attachment.URL, "http://127.0.0.1:12345")
@ -1047,7 +1048,7 @@ func TestServer_PublishAttachmentShortWithFilename(t *testing.T) {
require.Equal(t, int64(21), msg.Attachment.Size) require.Equal(t, int64(21), msg.Attachment.Size)
require.GreaterOrEqual(t, msg.Attachment.Expires, time.Now().Add(3*time.Hour).Unix()) require.GreaterOrEqual(t, msg.Attachment.Expires, time.Now().Add(3*time.Hour).Unix())
require.Contains(t, msg.Attachment.URL, "http://127.0.0.1:12345/file/") require.Contains(t, msg.Attachment.URL, "http://127.0.0.1:12345/file/")
require.Equal(t, "", msg.Attachment.Owner) // Should never be returned require.Equal(t, "", msg.Sender) // Should never be returned
require.FileExists(t, filepath.Join(s.config.AttachmentCacheDir, msg.ID)) require.FileExists(t, filepath.Join(s.config.AttachmentCacheDir, msg.ID))
path := strings.TrimPrefix(msg.Attachment.URL, "http://127.0.0.1:12345") path := strings.TrimPrefix(msg.Attachment.URL, "http://127.0.0.1:12345")
@ -1074,7 +1075,7 @@ func TestServer_PublishAttachmentExternalWithoutFilename(t *testing.T) {
require.Equal(t, "", msg.Attachment.Type) require.Equal(t, "", msg.Attachment.Type)
require.Equal(t, int64(0), msg.Attachment.Size) require.Equal(t, int64(0), msg.Attachment.Size)
require.Equal(t, int64(0), msg.Attachment.Expires) require.Equal(t, int64(0), msg.Attachment.Expires)
require.Equal(t, "", msg.Attachment.Owner) require.Equal(t, "", msg.Sender)
// Slightly unrelated cross-test: make sure we don't add an owner for external attachments // Slightly unrelated cross-test: make sure we don't add an owner for external attachments
size, err := s.messageCache.AttachmentBytesUsed("127.0.0.1") size, err := s.messageCache.AttachmentBytesUsed("127.0.0.1")
@ -1095,7 +1096,7 @@ func TestServer_PublishAttachmentExternalWithFilename(t *testing.T) {
require.Equal(t, "", msg.Attachment.Type) require.Equal(t, "", msg.Attachment.Type)
require.Equal(t, int64(0), msg.Attachment.Size) require.Equal(t, int64(0), msg.Attachment.Size)
require.Equal(t, int64(0), msg.Attachment.Expires) require.Equal(t, int64(0), msg.Attachment.Expires)
require.Equal(t, "", msg.Attachment.Owner) require.Equal(t, "", msg.Sender)
} }
func TestServer_PublishAttachmentBadURL(t *testing.T) { func TestServer_PublishAttachmentBadURL(t *testing.T) {
@ -1333,18 +1334,6 @@ func toHTTPError(t *testing.T, s string) *errHTTP {
return &e return &e
} }
func firebaseServiceAccountFile(t *testing.T) string {
if os.Getenv("NTFY_TEST_FIREBASE_SERVICE_ACCOUNT_FILE") != "" {
return os.Getenv("NTFY_TEST_FIREBASE_SERVICE_ACCOUNT_FILE")
} else if os.Getenv("NTFY_TEST_FIREBASE_SERVICE_ACCOUNT") != "" {
filename := filepath.Join(t.TempDir(), "firebase.json")
require.NotNil(t, os.WriteFile(filename, []byte(os.Getenv("NTFY_TEST_FIREBASE_SERVICE_ACCOUNT")), 0o600))
return filename
}
t.SkipNow()
return ""
}
func basicAuth(s string) string { func basicAuth(s string) string {
return fmt.Sprintf("Basic %s", base64.StdEncoding.EncodeToString([]byte(s))) return fmt.Sprintf("Basic %s", base64.StdEncoding.EncodeToString([]byte(s)))
} }

View File

@ -3,10 +3,13 @@ package server
import ( import (
"bytes" "bytes"
"errors" "errors"
"fmt"
"github.com/emersion/go-smtp" "github.com/emersion/go-smtp"
"io" "io"
"mime" "mime"
"mime/multipart" "mime/multipart"
"net/http"
"net/http/httptest"
"net/mail" "net/mail"
"strings" "strings"
"sync" "sync"
@ -23,25 +26,25 @@ var (
// smtpBackend implements SMTP server methods. // smtpBackend implements SMTP server methods.
type smtpBackend struct { type smtpBackend struct {
config *Config config *Config
sub subscriber handler func(http.ResponseWriter, *http.Request)
success int64 success int64
failure int64 failure int64
mu sync.Mutex mu sync.Mutex
} }
func newMailBackend(conf *Config, sub subscriber) *smtpBackend { func newMailBackend(conf *Config, handler func(http.ResponseWriter, *http.Request)) *smtpBackend {
return &smtpBackend{ return &smtpBackend{
config: conf, config: conf,
sub: sub, handler: handler,
} }
} }
func (b *smtpBackend) Login(state *smtp.ConnectionState, username, password string) (smtp.Session, error) { func (b *smtpBackend) Login(state *smtp.ConnectionState, username, password string) (smtp.Session, error) {
return &smtpSession{backend: b}, nil return &smtpSession{backend: b, remoteAddr: state.RemoteAddr.String()}, nil
} }
func (b *smtpBackend) AnonymousLogin(state *smtp.ConnectionState) (smtp.Session, error) { func (b *smtpBackend) AnonymousLogin(state *smtp.ConnectionState) (smtp.Session, error) {
return &smtpSession{backend: b}, nil return &smtpSession{backend: b, remoteAddr: state.RemoteAddr.String()}, nil
} }
func (b *smtpBackend) Counts() (success int64, failure int64) { func (b *smtpBackend) Counts() (success int64, failure int64) {
@ -52,9 +55,10 @@ func (b *smtpBackend) Counts() (success int64, failure int64) {
// smtpSession is returned after EHLO. // smtpSession is returned after EHLO.
type smtpSession struct { type smtpSession struct {
backend *smtpBackend backend *smtpBackend
topic string remoteAddr string
mu sync.Mutex topic string
mu sync.Mutex
} }
func (s *smtpSession) AuthPlain(username, password string) error { func (s *smtpSession) AuthPlain(username, password string) error {
@ -128,7 +132,7 @@ func (s *smtpSession) Data(r io.Reader) error {
m.Message = m.Title // Flip them, this makes more sense m.Message = m.Title // Flip them, this makes more sense
m.Title = "" m.Title = ""
} }
if err := s.backend.sub(m); err != nil { if err := s.publishMessage(m); err != nil {
return err return err
} }
s.backend.mu.Lock() s.backend.mu.Lock()
@ -138,6 +142,25 @@ func (s *smtpSession) Data(r io.Reader) error {
}) })
} }
func (s *smtpSession) publishMessage(m *message) error {
url := fmt.Sprintf("%s/%s", s.backend.config.BaseURL, m.Topic)
req, err := http.NewRequest("PUT", url, strings.NewReader(m.Message))
req.RemoteAddr = s.remoteAddr // rate limiting!!
req.Header.Set("X-Forwarded-For", s.remoteAddr)
if err != nil {
return err
}
if m.Title != "" {
req.Header.Set("Title", m.Title)
}
rr := httptest.NewRecorder()
s.backend.handler(rr, req)
if rr.Code != http.StatusOK {
return errors.New("error: " + rr.Body.String())
}
return nil
}
func (s *smtpSession) Reset() { func (s *smtpSession) Reset() {
s.mu.Lock() s.mu.Lock()
s.topic = "" s.topic = ""

View File

@ -3,6 +3,9 @@ package server
import ( import (
"github.com/emersion/go-smtp" "github.com/emersion/go-smtp"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
"io"
"net"
"net/http"
"strings" "strings"
"testing" "testing"
) )
@ -27,13 +30,12 @@ Content-Type: text/html; charset="UTF-8"
<div dir="ltr">what&#39;s up<br clear="all"><div><br></div></div> <div dir="ltr">what&#39;s up<br clear="all"><div><br></div></div>
--000000000000f3320b05d42915c9--` --000000000000f3320b05d42915c9--`
_, backend := newTestBackend(t, func(m *message) error { _, backend := newTestBackend(t, func(w http.ResponseWriter, r *http.Request) {
require.Equal(t, "mytopic", m.Topic) require.Equal(t, "/mytopic", r.URL.Path)
require.Equal(t, "and one more", m.Title) require.Equal(t, "and one more", r.Header.Get("Title"))
require.Equal(t, "what's up", m.Message) require.Equal(t, "what's up", readAll(t, r.Body))
return nil
}) })
session, _ := backend.AnonymousLogin(nil) session, _ := backend.AnonymousLogin(fakeConnState(t, "1.2.3.4"))
require.Nil(t, session.Mail("phil@example.com", smtp.MailOptions{})) require.Nil(t, session.Mail("phil@example.com", smtp.MailOptions{}))
require.Nil(t, session.Rcpt("ntfy-mytopic@ntfy.sh")) require.Nil(t, session.Rcpt("ntfy-mytopic@ntfy.sh"))
require.Nil(t, session.Data(strings.NewReader(email))) require.Nil(t, session.Data(strings.NewReader(email)))
@ -59,13 +61,12 @@ Content-Type: text/html; charset="UTF-8"
<div dir="ltr"><br></div> <div dir="ltr"><br></div>
--000000000000bcf4a405d429f8d4--` --000000000000bcf4a405d429f8d4--`
_, backend := newTestBackend(t, func(m *message) error { _, backend := newTestBackend(t, func(w http.ResponseWriter, r *http.Request) {
require.Equal(t, "emailtest", m.Topic) require.Equal(t, "/emailtest", r.URL.Path)
require.Equal(t, "", m.Title) // We flipped message and body require.Equal(t, "", r.Header.Get("Title")) // We flipped message and body
require.Equal(t, "This email has a subject but no body", m.Message) require.Equal(t, "This email has a subject but no body", readAll(t, r.Body))
return nil
}) })
session, _ := backend.AnonymousLogin(nil) session, _ := backend.AnonymousLogin(fakeConnState(t, "1.2.3.4"))
require.Nil(t, session.Mail("phil@example.com", smtp.MailOptions{})) require.Nil(t, session.Mail("phil@example.com", smtp.MailOptions{}))
require.Nil(t, session.Rcpt("ntfy-emailtest@ntfy.sh")) require.Nil(t, session.Rcpt("ntfy-emailtest@ntfy.sh"))
require.Nil(t, session.Data(strings.NewReader(email))) require.Nil(t, session.Data(strings.NewReader(email)))
@ -81,14 +82,13 @@ Content-Type: text/plain; charset="UTF-8"
what's up what's up
` `
conf, backend := newTestBackend(t, func(m *message) error { conf, backend := newTestBackend(t, func(w http.ResponseWriter, r *http.Request) {
require.Equal(t, "mytopic", m.Topic) require.Equal(t, "/mytopic", r.URL.Path)
require.Equal(t, "and one more", m.Title) require.Equal(t, "and one more", r.Header.Get("Title"))
require.Equal(t, "what's up", m.Message) require.Equal(t, "what's up", readAll(t, r.Body))
return nil
}) })
conf.SMTPServerAddrPrefix = "" conf.SMTPServerAddrPrefix = ""
session, _ := backend.AnonymousLogin(nil) session, _ := backend.AnonymousLogin(fakeConnState(t, "1.2.3.4"))
require.Nil(t, session.Mail("phil@example.com", smtp.MailOptions{})) require.Nil(t, session.Mail("phil@example.com", smtp.MailOptions{}))
require.Nil(t, session.Rcpt("mytopic@ntfy.sh")) require.Nil(t, session.Rcpt("mytopic@ntfy.sh"))
require.Nil(t, session.Data(strings.NewReader(email))) require.Nil(t, session.Data(strings.NewReader(email)))
@ -99,14 +99,13 @@ func TestSmtpBackend_Plaintext_No_ContentType(t *testing.T) {
what's up what's up
` `
conf, backend := newTestBackend(t, func(m *message) error { conf, backend := newTestBackend(t, func(w http.ResponseWriter, r *http.Request) {
require.Equal(t, "mytopic", m.Topic) require.Equal(t, "/mytopic", r.URL.Path)
require.Equal(t, "Very short mail", m.Title) require.Equal(t, "Very short mail", r.Header.Get("Title"))
require.Equal(t, "what's up", m.Message) require.Equal(t, "what's up", readAll(t, r.Body))
return nil
}) })
conf.SMTPServerAddrPrefix = "" conf.SMTPServerAddrPrefix = ""
session, _ := backend.AnonymousLogin(nil) session, _ := backend.AnonymousLogin(fakeConnState(t, "1.2.3.4"))
require.Nil(t, session.Mail("phil@example.com", smtp.MailOptions{})) require.Nil(t, session.Mail("phil@example.com", smtp.MailOptions{}))
require.Nil(t, session.Rcpt("mytopic@ntfy.sh")) require.Nil(t, session.Rcpt("mytopic@ntfy.sh"))
require.Nil(t, session.Data(strings.NewReader(email))) require.Nil(t, session.Data(strings.NewReader(email)))
@ -121,11 +120,10 @@ Content-Type: text/plain; charset="UTF-8"
what's up what's up
` `
_, backend := newTestBackend(t, func(m *message) error { _, backend := newTestBackend(t, func(w http.ResponseWriter, r *http.Request) {
require.Equal(t, "Three santas 🎅🎅🎅", m.Title) require.Equal(t, "Three santas 🎅🎅🎅", r.Header.Get("Title"))
return nil
}) })
session, _ := backend.AnonymousLogin(nil) session, _ := backend.AnonymousLogin(fakeConnState(t, "1.2.3.4"))
require.Nil(t, session.Mail("phil@example.com", smtp.MailOptions{})) require.Nil(t, session.Mail("phil@example.com", smtp.MailOptions{}))
require.Nil(t, session.Rcpt("ntfy-mytopic@ntfy.sh")) require.Nil(t, session.Rcpt("ntfy-mytopic@ntfy.sh"))
require.Nil(t, session.Data(strings.NewReader(email))) require.Nil(t, session.Data(strings.NewReader(email)))
@ -204,7 +202,7 @@ BBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBB
BBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBB BBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBB
that should do it that should do it
` `
conf, backend := newTestBackend(t, func(m *message) error { conf, backend := newTestBackend(t, func(w http.ResponseWriter, r *http.Request) {
expected := `you know this is a string. expected := `you know this is a string.
it's a long string. it's a long string.
it's supposed to be longer than the max message length it's supposed to be longer than the max message length
@ -266,13 +264,12 @@ AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA
...................................................................... ......................................................................
...................................................................... ......................................................................
and with BBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBB and with BBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBB
BBBBBBBBBBBBBBBBBBBBBBBB` BBBBBBBBBBBBBBBBBBBBBBBBB`
require.Equal(t, 4096, len(expected)) // Sanity check require.Equal(t, 4096, len(expected)) // Sanity check
require.Equal(t, expected, m.Message) require.Equal(t, expected, readAll(t, r.Body))
return nil
}) })
conf.SMTPServerAddrPrefix = "" conf.SMTPServerAddrPrefix = ""
session, _ := backend.AnonymousLogin(nil) session, _ := backend.AnonymousLogin(fakeConnState(t, "1.2.3.4"))
require.Nil(t, session.Mail("phil@example.com", smtp.MailOptions{})) require.Nil(t, session.Mail("phil@example.com", smtp.MailOptions{}))
require.Nil(t, session.Rcpt("mytopic@ntfy.sh")) require.Nil(t, session.Rcpt("mytopic@ntfy.sh"))
require.Nil(t, session.Data(strings.NewReader(email))) require.Nil(t, session.Data(strings.NewReader(email)))
@ -288,21 +285,41 @@ Content-Type: text/SOMETHINGELSE
what's up what's up
` `
conf, backend := newTestBackend(t, func(m *message) error { conf, backend := newTestBackend(t, func(http.ResponseWriter, *http.Request) {
return nil // Nothing.
}) })
conf.SMTPServerAddrPrefix = "" conf.SMTPServerAddrPrefix = ""
session, _ := backend.Login(nil, "user", "pass") session, _ := backend.Login(fakeConnState(t, "1.2.3.4"), "user", "pass")
require.Nil(t, session.Mail("phil@example.com", smtp.MailOptions{})) require.Nil(t, session.Mail("phil@example.com", smtp.MailOptions{}))
require.Nil(t, session.Rcpt("mytopic@ntfy.sh")) require.Nil(t, session.Rcpt("mytopic@ntfy.sh"))
require.Equal(t, errUnsupportedContentType, session.Data(strings.NewReader(email))) require.Equal(t, errUnsupportedContentType, session.Data(strings.NewReader(email)))
} }
func newTestBackend(t *testing.T, sub subscriber) (*Config, *smtpBackend) { func newTestBackend(t *testing.T, handler func(http.ResponseWriter, *http.Request)) (*Config, *smtpBackend) {
conf := newTestConfig(t) conf := newTestConfig(t)
conf.SMTPServerListen = ":25" conf.SMTPServerListen = ":25"
conf.SMTPServerDomain = "ntfy.sh" conf.SMTPServerDomain = "ntfy.sh"
conf.SMTPServerAddrPrefix = "ntfy-" conf.SMTPServerAddrPrefix = "ntfy-"
backend := newMailBackend(conf, sub) backend := newMailBackend(conf, handler)
return conf, backend return conf, backend
} }
func readAll(t *testing.T, rc io.ReadCloser) string {
b, err := io.ReadAll(rc)
if err != nil {
t.Fatal(err)
}
return string(b)
}
func fakeConnState(t *testing.T, remoteAddr string) *smtp.ConnectionState {
ip, err := net.ResolveIPAddr("ip", remoteAddr)
if err != nil {
t.Fatal(err)
}
return &smtp.ConnectionState{
Hostname: "myhostname",
LocalAddr: ip,
RemoteAddr: ip,
}
}

View File

@ -15,7 +15,7 @@ type topic struct {
} }
// subscriber is a function that is called for every new message on a topic // subscriber is a function that is called for every new message on a topic
type subscriber func(msg *message) error type subscriber func(v *visitor, msg *message) error
// newTopic creates a new topic // newTopic creates a new topic
func newTopic(id string) *topic { func newTopic(id string) *topic {
@ -42,12 +42,12 @@ func (t *topic) Unsubscribe(id int) {
} }
// Publish asynchronously publishes to all subscribers // Publish asynchronously publishes to all subscribers
func (t *topic) Publish(m *message) error { func (t *topic) Publish(v *visitor, m *message) error {
go func() { go func() {
t.mu.Lock() t.mu.Lock()
defer t.mu.Unlock() defer t.mu.Unlock()
for _, s := range t.subscribers { for _, s := range t.subscribers {
if err := s(m); err != nil { if err := s(v, m); err != nil {
log.Printf("error publishing message to subscriber") log.Printf("error publishing message to subscriber")
} }
} }

View File

@ -32,6 +32,7 @@ type message struct {
Actions []*action `json:"actions,omitempty"` Actions []*action `json:"actions,omitempty"`
Attachment *attachment `json:"attachment,omitempty"` Attachment *attachment `json:"attachment,omitempty"`
PollID string `json:"poll_id,omitempty"` PollID string `json:"poll_id,omitempty"`
Sender string `json:"-"` // IP address of uploader, used for rate limiting
Encoding string `json:"encoding,omitempty"` // empty for raw UTF-8, or "base64" for encoded bytes Encoding string `json:"encoding,omitempty"` // empty for raw UTF-8, or "base64" for encoded bytes
} }
@ -41,7 +42,6 @@ type attachment struct {
Size int64 `json:"size,omitempty"` Size int64 `json:"size,omitempty"`
Expires int64 `json:"expires,omitempty"` Expires int64 `json:"expires,omitempty"`
URL string `json:"url"` URL string `json:"url"`
Owner string `json:"-"` // IP address of uploader, used for rate limiting
} }
type action struct { type action struct {

View File

@ -28,6 +28,7 @@ type visitor struct {
emails *rate.Limiter emails *rate.Limiter
subscriptions util.Limiter subscriptions util.Limiter
bandwidth util.Limiter bandwidth util.Limiter
firebase time.Time // Next allowed Firebase message
seen time.Time seen time.Time
mu sync.Mutex mu sync.Mutex
} }
@ -48,14 +49,11 @@ func newVisitor(conf *Config, messageCache *messageCache, ip string) *visitor {
emails: rate.NewLimiter(rate.Every(conf.VisitorEmailLimitReplenish), conf.VisitorEmailLimitBurst), emails: rate.NewLimiter(rate.Every(conf.VisitorEmailLimitReplenish), conf.VisitorEmailLimitBurst),
subscriptions: util.NewFixedLimiter(int64(conf.VisitorSubscriptionLimit)), subscriptions: util.NewFixedLimiter(int64(conf.VisitorSubscriptionLimit)),
bandwidth: util.NewBytesLimiter(conf.VisitorAttachmentDailyBandwidthLimit, 24*time.Hour), bandwidth: util.NewBytesLimiter(conf.VisitorAttachmentDailyBandwidthLimit, 24*time.Hour),
firebase: time.Unix(0, 0),
seen: time.Now(), seen: time.Now(),
} }
} }
func (v *visitor) IP() string {
return v.ip
}
func (v *visitor) RequestAllowed() error { func (v *visitor) RequestAllowed() error {
if !v.requests.Allow() { if !v.requests.Allow() {
return errVisitorLimitReached return errVisitorLimitReached
@ -63,6 +61,21 @@ func (v *visitor) RequestAllowed() error {
return nil return nil
} }
func (v *visitor) FirebaseAllowed() error {
v.mu.Lock()
defer v.mu.Unlock()
if time.Now().Before(v.firebase) {
return errVisitorLimitReached
}
return nil
}
func (v *visitor) FirebaseTemporarilyDeny() {
v.mu.Lock()
defer v.mu.Unlock()
v.firebase = time.Now().Add(v.config.FirebaseQuotaExceededPenaltyDuration)
}
func (v *visitor) EmailAllowed() error { func (v *visitor) EmailAllowed() error {
if !v.emails.Allow() { if !v.emails.Allow() {
return errVisitorLimitReached return errVisitorLimitReached

716
web/package-lock.json generated

File diff suppressed because it is too large Load Diff