Add firebase support
This commit is contained in:
		
							parent
							
								
									4677e724ee
								
							
						
					
					
						commit
						b145e693a5
					
				
					 7 changed files with 293 additions and 161 deletions
				
			
		
							
								
								
									
										27
									
								
								cmd/app.go
									
										
									
									
									
								
							
							
						
						
									
										27
									
								
								cmd/app.go
									
										
									
									
									
								
							|  | @ -8,8 +8,10 @@ import ( | |||
| 	"github.com/urfave/cli/v2/altsrc" | ||||
| 	"heckel.io/ntfy/config" | ||||
| 	"heckel.io/ntfy/server" | ||||
| 	"heckel.io/ntfy/util" | ||||
| 	"log" | ||||
| 	"os" | ||||
| 	"time" | ||||
| ) | ||||
| 
 | ||||
| // New creates a new CLI application | ||||
|  | @ -18,7 +20,9 @@ func New() *cli.App { | |||
| 		&cli.StringFlag{Name: "config", Aliases: []string{"c"}, EnvVars: []string{"NTFY_CONFIG_FILE"}, Value: "/etc/ntfy/config.yml", DefaultText: "/etc/ntfy/config.yml", Usage: "config file"}, | ||||
| 		altsrc.NewStringFlag(&cli.StringFlag{Name: "listen-http", Aliases: []string{"l"}, EnvVars: []string{"NTFY_LISTEN_HTTP"}, Value: config.DefaultListenHTTP, Usage: "ip:port used to as listen address"}), | ||||
| 		altsrc.NewStringFlag(&cli.StringFlag{Name: "firebase-key-file", Aliases: []string{"F"}, EnvVars: []string{"NTFY_FIREBASE_KEY_FILE"}, Usage: "Firebase credentials file; if set additionally publish to FCM topic"}), | ||||
| 		altsrc.NewDurationFlag(&cli.DurationFlag{Name: "message-buffer-duration", Aliases: []string{"b"}, EnvVars: []string{"NTFY_MESSAGE_BUFFER_DURATION"}, Value: config.DefaultMessageBufferDuration, Usage: "buffer messages in memory for this time to allow `since` requests"}), | ||||
| 		altsrc.NewDurationFlag(&cli.DurationFlag{Name: "keepalive-interval", Aliases: []string{"k"}, EnvVars: []string{"NTFY_KEEPALIVE_INTERVAL"}, Value: config.DefaultKeepaliveInterval, Usage: "default interval of keepalive messages"}), | ||||
| 		altsrc.NewDurationFlag(&cli.DurationFlag{Name: "manager-interval", Aliases: []string{"m"}, EnvVars: []string{"NTFY_MANAGER_INTERVAL"}, Value: config.DefaultManagerInterval, Usage: "default interval of for message pruning and stats printing"}), | ||||
| 	} | ||||
| 	return &cli.App{ | ||||
| 		Name:                   "ntfy", | ||||
|  | @ -41,17 +45,27 @@ func execRun(c *cli.Context) error { | |||
| 	// Read all the options | ||||
| 	listenHTTP := c.String("listen-http") | ||||
| 	firebaseKeyFile := c.String("firebase-key-file") | ||||
| 	messageBufferDuration := c.Duration("message-buffer-duration") | ||||
| 	keepaliveInterval := c.Duration("keepalive-interval") | ||||
| 	managerInterval := c.Duration("manager-interval") | ||||
| 
 | ||||
| 	// Check values | ||||
| 	if firebaseKeyFile != "" && !fileExists(firebaseKeyFile) { | ||||
| 	if firebaseKeyFile != "" && !util.FileExists(firebaseKeyFile) { | ||||
| 		return errors.New("if set, FCM key file must exist") | ||||
| 	} else if keepaliveInterval < 5*time.Second { | ||||
| 		return errors.New("keepalive interval cannot be lower than five seconds") | ||||
| 	} else if managerInterval < 5*time.Second { | ||||
| 		return errors.New("manager interval cannot be lower than five seconds") | ||||
| 	} else if messageBufferDuration < managerInterval { | ||||
| 		return errors.New("message buffer duration cannot be lower than manager interval") | ||||
| 	} | ||||
| 
 | ||||
| 	// Run main bot, can be killed by signal | ||||
| 	// Run server | ||||
| 	conf := config.New(listenHTTP) | ||||
| 	conf.FirebaseKeyFile = firebaseKeyFile | ||||
| 	conf.MessageBufferDuration = messageBufferDuration | ||||
| 	conf.KeepaliveInterval = keepaliveInterval | ||||
| 	conf.ManagerInterval = managerInterval | ||||
| 	s, err := server.New(conf) | ||||
| 	if err != nil { | ||||
| 		log.Fatalln(err) | ||||
|  | @ -68,9 +82,9 @@ func execRun(c *cli.Context) error { | |||
| func initConfigFileInputSource(configFlag string, flags []cli.Flag) cli.BeforeFunc { | ||||
| 	return func(context *cli.Context) error { | ||||
| 		configFile := context.String(configFlag) | ||||
| 		if context.IsSet(configFlag) && !fileExists(configFile) { | ||||
| 		if context.IsSet(configFlag) && !util.FileExists(configFile) { | ||||
| 			return fmt.Errorf("config file %s does not exist", configFile) | ||||
| 		} else if !context.IsSet(configFlag) && !fileExists(configFile) { | ||||
| 		} else if !context.IsSet(configFlag) && !util.FileExists(configFile) { | ||||
| 			return nil | ||||
| 		} | ||||
| 		inputSource, err := altsrc.NewYamlSourceFromFile(configFile) | ||||
|  | @ -80,8 +94,3 @@ func initConfigFileInputSource(configFlag string, flags []cli.Flag) cli.BeforeFu | |||
| 		return altsrc.ApplyInputSourceValues(context, inputSource, flags) | ||||
| 	} | ||||
| } | ||||
| 
 | ||||
| func fileExists(filename string) bool { | ||||
| 	stat, _ := os.Stat(filename) | ||||
| 	return stat != nil | ||||
| } | ||||
|  |  | |||
|  | @ -8,9 +8,10 @@ import ( | |||
| 
 | ||||
| // Defines default config settings | ||||
| const ( | ||||
| 	DefaultListenHTTP        = ":80" | ||||
| 	DefaultKeepaliveInterval = 30 * time.Second | ||||
| 	defaultManagerInterval   = time.Minute | ||||
| 	DefaultListenHTTP            = ":80" | ||||
| 	DefaultMessageBufferDuration = 12 * time.Hour | ||||
| 	DefaultKeepaliveInterval     = 30 * time.Second | ||||
| 	DefaultManagerInterval       = time.Minute | ||||
| ) | ||||
| 
 | ||||
| // Defines the max number of requests, here: | ||||
|  | @ -22,22 +23,24 @@ var ( | |||
| 
 | ||||
| // Config is the main config struct for the application. Use New to instantiate a default config struct. | ||||
| type Config struct { | ||||
| 	ListenHTTP        string | ||||
| 	Limit             rate.Limit | ||||
| 	LimitBurst        int | ||||
| 	FirebaseKeyFile   string | ||||
| 	KeepaliveInterval time.Duration | ||||
| 	ManagerInterval   time.Duration | ||||
| 	ListenHTTP            string | ||||
| 	FirebaseKeyFile       string | ||||
| 	MessageBufferDuration time.Duration | ||||
| 	KeepaliveInterval     time.Duration | ||||
| 	ManagerInterval       time.Duration | ||||
| 	Limit                 rate.Limit | ||||
| 	LimitBurst            int | ||||
| } | ||||
| 
 | ||||
| // New instantiates a default new config | ||||
| func New(listenHTTP string) *Config { | ||||
| 	return &Config{ | ||||
| 		ListenHTTP:        listenHTTP, | ||||
| 		Limit:             defaultLimit, | ||||
| 		LimitBurst:        defaultLimitBurst, | ||||
| 		FirebaseKeyFile:   "", | ||||
| 		KeepaliveInterval: DefaultKeepaliveInterval, | ||||
| 		ManagerInterval:   defaultManagerInterval, | ||||
| 		ListenHTTP:            listenHTTP, | ||||
| 		FirebaseKeyFile:       "", | ||||
| 		MessageBufferDuration: DefaultMessageBufferDuration, | ||||
| 		KeepaliveInterval:     DefaultKeepaliveInterval, | ||||
| 		ManagerInterval:       DefaultManagerInterval, | ||||
| 		Limit:                 defaultLimit, | ||||
| 		LimitBurst:            defaultLimitBurst, | ||||
| 	} | ||||
| } | ||||
|  |  | |||
|  | @ -38,14 +38,31 @@ | |||
|     </p> | ||||
|     <p id="error"></p> | ||||
| 
 | ||||
|     <h2>Publishing messages</h2> | ||||
|     <p> | ||||
|         Publishing messages can be done via PUT or POST using. Topics are created on the fly by subscribing or publishing to them. | ||||
|         Because there is no sign-up, <b>the topic is essentially a password</b>, so pick something that's not easily guessable. | ||||
|     </p> | ||||
|     <p class="smallMarginBottom"> | ||||
|         Here's an example showing how to publish a message using <tt>curl</tt>: | ||||
|     </p> | ||||
|     <code> | ||||
|         curl -d "long process is done" ntfy.sh/mytopic | ||||
|     </code> | ||||
|     <p class="smallMarginBottom"> | ||||
|         Here's an example in JS with <tt>fetch()</tt> (see <a href="https://github.com/binwiederhier/ntfy/tree/main/examples">full example</a>): | ||||
|     </p> | ||||
|     <code> | ||||
|         fetch('https://ntfy.sh/mytopic', {<br/> | ||||
|           method: 'POST', // PUT works too<br/> | ||||
|           body: 'Hello from the other side.'<br/> | ||||
|         }) | ||||
|     </code> | ||||
| 
 | ||||
|     <h2>Subscribe to a topic</h2> | ||||
|     <p> | ||||
|         Topics are created on the fly by subscribing to them. You can create and subscribe to a topic either in this web UI, or in | ||||
|         your own app by subscribing to an <a href="https://developer.mozilla.org/en-US/docs/Web/API/EventSource">EventSource</a>, | ||||
|         a JSON feed, or raw feed. | ||||
|     </p> | ||||
|     <p> | ||||
|         Because there is no sign-up, <b>the topic is essentially a password</b>, so pick something that's not easily guessable. | ||||
|         You can create and subscribe to a topic either in this web UI, or in your own app by subscribing to an | ||||
|         <a href="https://developer.mozilla.org/en-US/docs/Web/API/EventSource">EventSource</a>, a JSON feed, or raw feed. | ||||
|     </p> | ||||
| 
 | ||||
|     <h3>Subscribe via web</h3> | ||||
|  | @ -66,7 +83,7 @@ | |||
| 
 | ||||
|     <h3>Subscribe via your app, or via the CLI</h3> | ||||
|     <p class="smallMarginBottom"> | ||||
|         Using <a href="https://developer.mozilla.org/en-US/docs/Web/API/EventSource">EventSource</a>, you can consume | ||||
|         Using <a href="https://developer.mozilla.org/en-US/docs/Web/API/EventSource">EventSource</a> in JS, you can consume | ||||
|         notifications like this (see <a href="https://github.com/binwiederhier/ntfy/tree/main/examples">full example</a>): | ||||
|     </p> | ||||
|     <code> | ||||
|  | @ -76,30 +93,29 @@ | |||
|         }; | ||||
|     </code> | ||||
|     <p class="smallMarginBottom"> | ||||
|         Or you can use <tt>curl</tt> or any other HTTP library. Here's an example for the <tt>/json</tt> endpoint, | ||||
|         which prints one JSON message per line (keepalive and open messages have an "event" field): | ||||
|     </p> | ||||
|     <code> | ||||
|         $ curl -s ntfy.sh/mytopic/json<br/> | ||||
|         {"time":1635359841,"event":"open"}<br/> | ||||
|         {"time":1635359844,"message":"This is a notification"}<br/> | ||||
|         {"time":1635359851,"event":"keepalive"} | ||||
|     </code> | ||||
|     <p class="smallMarginBottom"> | ||||
|         Using the <tt>/sse</tt> endpoint (SSE, server-sent events stream): | ||||
|         You can also use the same <tt>/sse</tt> endpoint via <tt>curl</tt> or any other HTTP library: | ||||
|     </p> | ||||
|     <code> | ||||
|         $ curl -s ntfy.sh/mytopic/sse<br/> | ||||
|         event: open<br/> | ||||
|         data: {"time":1635359796,"event":"open"}<br/><br/> | ||||
|         data: {"id":"weSj9RtNkj","time":1635528898,"event":"open","topic":"mytopic"}<br/><br/> | ||||
| 
 | ||||
|         data: {"time":1635359803,"message":"This is a notification"}<br/><br/> | ||||
|         data: {"id":"p0M5y6gcCY","time":1635528909,"event":"message","topic":"mytopic","message":"Hi!"}<br/><br/> | ||||
| 
 | ||||
|         event: keepalive<br/> | ||||
|         data: {"time":1635359806,"event":"keepalive"} | ||||
|         data: {"id":"VNxNIg5fpt","time":1635528928,"event":"keepalive","topic":"test"} | ||||
|     </code> | ||||
|     <p class="smallMarginBottom"> | ||||
|         Using the <tt>/raw</tt> endpoint (empty lines are keepalive messages): | ||||
|         To consume JSON instead, use the <tt>/json</tt> endpoint, which prints one message per line: | ||||
|     </p> | ||||
|     <code> | ||||
|         $ curl -s ntfy.sh/mytopic/json<br/> | ||||
|         {"id":"SLiKI64DOt","time":1635528757,"event":"open","topic":"mytopic"}<br/> | ||||
|         {"id":"hwQ2YpKdmg","time":1635528741,"event":"message","topic":"mytopic","message":"Hi!"}<br/> | ||||
|         {"id":"DGUDShMCsc","time":1635528787,"event":"keepalive","topic":"mytopic"} | ||||
|     </code> | ||||
|     <p class="smallMarginBottom"> | ||||
|         Or use the <tt>/raw</tt> endpoint if you need something super simple (empty lines are keepalive messages): | ||||
|     </p> | ||||
|     <code> | ||||
|         $ curl -s ntfy.sh/mytopic/raw<br/> | ||||
|  | @ -107,27 +123,25 @@ | |||
|         This is a notification | ||||
|     </code> | ||||
| 
 | ||||
|     <h2>Publishing messages</h2> | ||||
|     <h3>Message buffering and polling</h3> | ||||
|     <p class="smallMarginBottom"> | ||||
|         Publishing messages can be done via PUT or POST using. Here's an example using <tt>curl</tt>: | ||||
|         Messages are buffered in memory for a few hours to account for network interruptions of subscribers. | ||||
|         You can read back what you missed by using the <tt>since=...</tt> query parameter. It takes either a | ||||
|         duration (e.g. <tt>10m</tt> or <tt>30s</tt>) or a Unix timestamp (e.g. <tt>1635528757</tt>): | ||||
|     </p> | ||||
|     <code> | ||||
|         curl -d "long process is done" ntfy.sh/mytopic | ||||
|         $ curl -s "ntfy.sh/mytopic/json?since=10m"<br/> | ||||
|         # Same output as above, but includes messages from up to 10 minutes ago | ||||
|     </code> | ||||
|     <p class="smallMarginBottom"> | ||||
|         Here's an example in JS with <tt>fetch()</tt> (see <a href="https://github.com/binwiederhier/ntfy/tree/main/examples">full example</a>): | ||||
|         You can also just poll for messages if you don't like the long-standing connection using the <tt>poll=1</tt> | ||||
|         query parameter. The connection will end after all available messages have been read. This parameter has to be | ||||
|         combined with <tt>since=</tt>. | ||||
|     </p> | ||||
|     <code> | ||||
|         fetch('https://ntfy.sh/mytopic', {<br/> | ||||
|           method: 'POST', // PUT works too<br/> | ||||
|           body: 'Hello from the other side.'<br/> | ||||
|         }) | ||||
|         $ curl -s "ntfy.sh/mytopic/json?poll=1&since=10m"<br/> | ||||
|         # Returns messages from up to 10 minutes ago and ends the connection | ||||
|     </code> | ||||
|     <p> | ||||
|         Messages published to a non-existing topic or a topic without subscribers will not be delivered later. | ||||
|         There is (currently) no buffering of any kind. If you're not listening, the message won't be delivered. | ||||
|     </p> | ||||
| 
 | ||||
|     <h2>FAQ</h2> | ||||
|     <p> | ||||
|         <b>Isn't this like ...?</b><br/> | ||||
|  |  | |||
|  | @ -1,18 +1,27 @@ | |||
| package server | ||||
| 
 | ||||
| import "time" | ||||
| import ( | ||||
| 	"heckel.io/ntfy/util" | ||||
| 	"time" | ||||
| ) | ||||
| 
 | ||||
| // List of possible events | ||||
| const ( | ||||
| 	openEvent      = "open" | ||||
| 	keepaliveEvent = "keepalive" | ||||
| 	messageEvent = "message" | ||||
| 	messageEvent   = "message" | ||||
| ) | ||||
| 
 | ||||
| const ( | ||||
| 	messageIDLength = 10 | ||||
| ) | ||||
| 
 | ||||
| // message represents a message published to a topic | ||||
| type message struct { | ||||
| 	Time    int64  `json:"time"`            // Unix time in seconds | ||||
| 	Event   string `json:"event,omitempty"` // One of the above | ||||
| 	ID      string `json:"id"`    // Random message ID | ||||
| 	Time    int64  `json:"time"`  // Unix time in seconds | ||||
| 	Event   string `json:"event"` // One of the above | ||||
| 	Topic   string `json:"topic"` | ||||
| 	Message string `json:"message,omitempty"` | ||||
| } | ||||
| 
 | ||||
|  | @ -20,25 +29,27 @@ type message struct { | |||
| type messageEncoder func(msg *message) (string, error) | ||||
| 
 | ||||
| // newMessage creates a new message with the current timestamp | ||||
| func newMessage(event string, msg string) *message { | ||||
| func newMessage(event, topic, msg string) *message { | ||||
| 	return &message{ | ||||
| 		ID:      util.RandomString(messageIDLength), | ||||
| 		Time:    time.Now().Unix(), | ||||
| 		Event:   event, | ||||
| 		Topic:   topic, | ||||
| 		Message: msg, | ||||
| 	} | ||||
| } | ||||
| 
 | ||||
| // newOpenMessage is a convenience method to create an open message | ||||
| func newOpenMessage() *message { | ||||
| 	return newMessage(openEvent, "") | ||||
| func newOpenMessage(topic string) *message { | ||||
| 	return newMessage(openEvent, topic, "") | ||||
| } | ||||
| 
 | ||||
| // newKeepaliveMessage is a convenience method to create a keepalive message | ||||
| func newKeepaliveMessage() *message { | ||||
| 	return newMessage(keepaliveEvent, "") | ||||
| func newKeepaliveMessage(topic string) *message { | ||||
| 	return newMessage(keepaliveEvent, topic, "") | ||||
| } | ||||
| 
 | ||||
| // newDefaultMessage is a convenience method to create a notification message | ||||
| func newDefaultMessage(msg string) *message { | ||||
| 	return newMessage(messageEvent, msg) | ||||
| func newDefaultMessage(topic, msg string) *message { | ||||
| 	return newMessage(messageEvent, topic, msg) | ||||
| } | ||||
|  |  | |||
							
								
								
									
										209
									
								
								server/server.go
									
										
									
									
									
								
							
							
						
						
									
										209
									
								
								server/server.go
									
										
									
									
									
								
							|  | @ -17,17 +17,23 @@ import ( | |||
| 	"net" | ||||
| 	"net/http" | ||||
| 	"regexp" | ||||
| 	"strconv" | ||||
| 	"strings" | ||||
| 	"sync" | ||||
| 	"time" | ||||
| ) | ||||
| 
 | ||||
| // TODO add "max connections open" limit | ||||
| // TODO add "max messages in a topic" limit | ||||
| // TODO add "max topics" limit | ||||
| 
 | ||||
| // Server is the main server | ||||
| type Server struct { | ||||
| 	config   *config.Config | ||||
| 	topics   map[string]*topic | ||||
| 	visitors map[string]*visitor | ||||
| 	firebase *messaging.Client | ||||
| 	firebase subscriber | ||||
| 	messages int64 | ||||
| 	mu       sync.Mutex | ||||
| } | ||||
| 
 | ||||
|  | @ -53,10 +59,11 @@ const ( | |||
| ) | ||||
| 
 | ||||
| var ( | ||||
| 	topicRegex  = regexp.MustCompile(`^/[^/]+$`) | ||||
| 	jsonRegex   = regexp.MustCompile(`^/[^/]+/json$`) | ||||
| 	sseRegex    = regexp.MustCompile(`^/[^/]+/sse$`) | ||||
| 	rawRegex    = regexp.MustCompile(`^/[^/]+/raw$`) | ||||
| 	topicRegex = regexp.MustCompile(`^/[^/]+$`) | ||||
| 	jsonRegex  = regexp.MustCompile(`^/[^/]+/json$`) | ||||
| 	sseRegex   = regexp.MustCompile(`^/[^/]+/sse$`) | ||||
| 	rawRegex   = regexp.MustCompile(`^/[^/]+/raw$`) | ||||
| 
 | ||||
| 	staticRegex = regexp.MustCompile(`^/static/.+`) | ||||
| 
 | ||||
| 	//go:embed "index.html" | ||||
|  | @ -65,30 +72,57 @@ var ( | |||
| 	//go:embed static | ||||
| 	webStaticFs embed.FS | ||||
| 
 | ||||
| 	errHTTPBadRequest      = &errHTTP{http.StatusBadRequest, http.StatusText(http.StatusBadRequest)} | ||||
| 	errHTTPNotFound        = &errHTTP{http.StatusNotFound, http.StatusText(http.StatusNotFound)} | ||||
| 	errHTTPTooManyRequests = &errHTTP{http.StatusTooManyRequests, http.StatusText(http.StatusTooManyRequests)} | ||||
| ) | ||||
| 
 | ||||
| func New(conf *config.Config) (*Server, error) { | ||||
| 	var fcm *messaging.Client | ||||
| 	var firebaseSubscriber subscriber | ||||
| 	if conf.FirebaseKeyFile != "" { | ||||
| 		fb, err := firebase.NewApp(context.Background(), nil, option.WithCredentialsFile(conf.FirebaseKeyFile)) | ||||
| 		if err != nil { | ||||
| 			return nil, err | ||||
| 		} | ||||
| 		fcm, err = fb.Messaging(context.Background()) | ||||
| 		var err error | ||||
| 		firebaseSubscriber, err = createFirebaseSubscriber(conf) | ||||
| 		if err != nil { | ||||
| 			return nil, err | ||||
| 		} | ||||
| 	} | ||||
| 	return &Server{ | ||||
| 		config:   conf, | ||||
| 		firebase: fcm, | ||||
| 		firebase: firebaseSubscriber, | ||||
| 		topics:   make(map[string]*topic), | ||||
| 		visitors: make(map[string]*visitor), | ||||
| 	}, nil | ||||
| } | ||||
| 
 | ||||
| func createFirebaseSubscriber(conf *config.Config) (subscriber, error) { | ||||
| 	fb, err := firebase.NewApp(context.Background(), nil, option.WithCredentialsFile(conf.FirebaseKeyFile)) | ||||
| 	if err != nil { | ||||
| 		return nil, err | ||||
| 	} | ||||
| 	msg, err := fb.Messaging(context.Background()) | ||||
| 	if err != nil { | ||||
| 		return nil, err | ||||
| 	} | ||||
| 	return func(m *message) error { | ||||
| 		_, err := msg.Send(context.Background(), &messaging.Message{ | ||||
| 			Data: map[string]string{ | ||||
| 				"id":      m.ID, | ||||
| 				"time":    fmt.Sprintf("%d", m.Time), | ||||
| 				"event": m.Event, | ||||
| 				"topic":   m.Topic, | ||||
| 				"message": m.Message, | ||||
| 			}, | ||||
| 			Notification: &messaging.Notification{ | ||||
| 				Title:    m.Topic, // FIXME convert to ntfy.sh/$topic instead | ||||
| 				Body:     m.Message, | ||||
| 				ImageURL: "", | ||||
| 			}, | ||||
| 			Topic: m.Topic, | ||||
| 		}) | ||||
| 		return err | ||||
| 	}, nil | ||||
| } | ||||
| 
 | ||||
| func (s *Server) Run() error { | ||||
| 	go func() { | ||||
| 		ticker := time.NewTicker(s.config.ManagerInterval) | ||||
|  | @ -106,28 +140,6 @@ func (s *Server) listenAndServe() error { | |||
| 	return http.ListenAndServe(s.config.ListenHTTP, nil) | ||||
| } | ||||
| 
 | ||||
| func (s *Server) updateStatsAndExpire() { | ||||
| 	s.mu.Lock() | ||||
| 	defer s.mu.Unlock() | ||||
| 
 | ||||
| 	// Expire visitors from rate visitors map | ||||
| 	for ip, v := range s.visitors { | ||||
| 		if time.Since(v.seen) > visitorExpungeAfter { | ||||
| 			delete(s.visitors, ip) | ||||
| 		} | ||||
| 	} | ||||
| 
 | ||||
| 	// Print stats | ||||
| 	var subscribers, messages int | ||||
| 	for _, t := range s.topics { | ||||
| 		subs, msgs := t.Stats() | ||||
| 		subscribers += subs | ||||
| 		messages += msgs | ||||
| 	} | ||||
| 	log.Printf("Stats: %d topic(s), %d subscriber(s), %d message(s) sent, %d visitor(s)", | ||||
| 		len(s.topics), subscribers, messages, len(s.visitors)) | ||||
| } | ||||
| 
 | ||||
| func (s *Server) handle(w http.ResponseWriter, r *http.Request) { | ||||
| 	if err := s.handleInternal(w, r); err != nil { | ||||
| 		if e, ok := err.(*errHTTP); ok { | ||||
|  | @ -147,14 +159,14 @@ func (s *Server) handleInternal(w http.ResponseWriter, r *http.Request) error { | |||
| 		return s.handleHome(w, r) | ||||
| 	} else if r.Method == http.MethodGet && staticRegex.MatchString(r.URL.Path) { | ||||
| 		return s.handleStatic(w, r) | ||||
| 	} else if (r.Method == http.MethodPut || r.Method == http.MethodPost) && topicRegex.MatchString(r.URL.Path) { | ||||
| 		return s.handlePublish(w, r) | ||||
| 	} else if r.Method == http.MethodGet && jsonRegex.MatchString(r.URL.Path) { | ||||
| 		return s.handleSubscribeJSON(w, r) | ||||
| 	} else if r.Method == http.MethodGet && sseRegex.MatchString(r.URL.Path) { | ||||
| 		return s.handleSubscribeSSE(w, r) | ||||
| 	} else if r.Method == http.MethodGet && rawRegex.MatchString(r.URL.Path) { | ||||
| 		return s.handleSubscribeRaw(w, r) | ||||
| 	} else if (r.Method == http.MethodPut || r.Method == http.MethodPost) && topicRegex.MatchString(r.URL.Path) { | ||||
| 		return s.handlePublishHTTP(w, r) | ||||
| 	} else if r.Method == http.MethodOptions { | ||||
| 		return s.handleOptions(w, r) | ||||
| 	} | ||||
|  | @ -166,42 +178,28 @@ func (s *Server) handleHome(w http.ResponseWriter, r *http.Request) error { | |||
| 	return err | ||||
| } | ||||
| 
 | ||||
| func (s *Server) handlePublishHTTP(w http.ResponseWriter, r *http.Request) error { | ||||
| 	t, err := s.topic(r.URL.Path[1:]) | ||||
| 	if err != nil { | ||||
| 		return err | ||||
| 	} | ||||
| func (s *Server) handleStatic(w http.ResponseWriter, r *http.Request) error { | ||||
| 	http.FileServer(http.FS(webStaticFs)).ServeHTTP(w, r) | ||||
| 	return nil | ||||
| } | ||||
| 
 | ||||
| func (s *Server) handlePublish(w http.ResponseWriter, r *http.Request) error { | ||||
| 	t := s.createTopic(r.URL.Path[1:]) | ||||
| 	reader := io.LimitReader(r.Body, messageLimit) | ||||
| 	b, err := io.ReadAll(reader) | ||||
| 	if err != nil { | ||||
| 		return err | ||||
| 	} | ||||
| 	if err := t.Publish(newDefaultMessage(string(b))); err != nil { | ||||
| 		return err | ||||
| 	} | ||||
| 	if err := s.maybePublishFirebase(t.id, string(b)); err != nil { | ||||
| 	if err := t.Publish(newDefaultMessage(t.id, string(b))); err != nil { | ||||
| 		return err | ||||
| 	} | ||||
| 	w.Header().Set("Access-Control-Allow-Origin", "*") // CORS, allow cross-origin requests | ||||
| 	s.mu.Lock() | ||||
| 	s.messages++ | ||||
| 	s.mu.Unlock() | ||||
| 	return nil | ||||
| } | ||||
| 
 | ||||
| func (s *Server) maybePublishFirebase(topic, message string) error { | ||||
| 	_, err := s.firebase.Send(context.Background(), &messaging.Message{ | ||||
| 		Data: map[string]string{ | ||||
| 			"topic":   topic, | ||||
| 			"message": message, | ||||
| 		}, | ||||
| 		Notification: &messaging.Notification{ | ||||
| 			Title:    "ntfy.sh/" + topic, | ||||
| 			Body:     message, | ||||
| 			ImageURL: "", | ||||
| 		}, | ||||
| 		Topic: topic, | ||||
| 	}) | ||||
| 	return err | ||||
| } | ||||
| 
 | ||||
| func (s *Server) handleSubscribeJSON(w http.ResponseWriter, r *http.Request) error { | ||||
| 	encoder := func(msg *message) (string, error) { | ||||
| 		var buf bytes.Buffer | ||||
|  | @ -239,6 +237,11 @@ func (s *Server) handleSubscribeRaw(w http.ResponseWriter, r *http.Request) erro | |||
| 
 | ||||
| func (s *Server) handleSubscribe(w http.ResponseWriter, r *http.Request, format string, contentType string, encoder messageEncoder) error { | ||||
| 	t := s.createTopic(strings.TrimSuffix(r.URL.Path[1:], "/"+format)) // Hack | ||||
| 	since, err := parseSince(r) | ||||
| 	if err != nil { | ||||
| 		return err | ||||
| 	} | ||||
| 	poll := r.URL.Query().Has("poll") | ||||
| 	sub := func(msg *message) error { | ||||
| 		m, err := encoder(msg) | ||||
| 		if err != nil { | ||||
|  | @ -252,11 +255,17 @@ func (s *Server) handleSubscribe(w http.ResponseWriter, r *http.Request, format | |||
| 		} | ||||
| 		return nil | ||||
| 	} | ||||
| 	subscriberID := t.Subscribe(sub) | ||||
| 	defer s.unsubscribe(t, subscriberID) | ||||
| 	w.Header().Set("Access-Control-Allow-Origin", "*") // CORS, allow cross-origin requests | ||||
| 	w.Header().Set("Content-Type", contentType) | ||||
| 	if err := sub(newOpenMessage()); err != nil { // Send out open message | ||||
| 	if poll { | ||||
| 		return sendOldMessages(t, since, sub) | ||||
| 	} | ||||
| 	subscriberID := t.Subscribe(sub) | ||||
| 	defer t.Unsubscribe(subscriberID) | ||||
| 	if err := sub(newOpenMessage(t.id)); err != nil { // Send out open message | ||||
| 		return err | ||||
| 	} | ||||
| 	if err := sendOldMessages(t, since, sub); err != nil { | ||||
| 		return err | ||||
| 	} | ||||
| 	for { | ||||
|  | @ -266,49 +275,85 @@ func (s *Server) handleSubscribe(w http.ResponseWriter, r *http.Request, format | |||
| 		case <-r.Context().Done(): | ||||
| 			return nil | ||||
| 		case <-time.After(s.config.KeepaliveInterval): | ||||
| 			if err := sub(newKeepaliveMessage()); err != nil { // Send keepalive message | ||||
| 			if err := sub(newKeepaliveMessage(t.id)); err != nil { // Send keepalive message | ||||
| 				return err | ||||
| 			} | ||||
| 		} | ||||
| 	} | ||||
| } | ||||
| 
 | ||||
| func sendOldMessages(t *topic, since time.Time, sub subscriber) error { | ||||
| 	if since.IsZero() { | ||||
| 		return nil | ||||
| 	} | ||||
| 	for _, m := range t.Messages(since) { | ||||
| 		if err := sub(m); err != nil { | ||||
| 			return err | ||||
| 		} | ||||
| 	} | ||||
| 	return nil | ||||
| } | ||||
| 
 | ||||
| func parseSince(r *http.Request) (time.Time, error) { | ||||
| 	if !r.URL.Query().Has("since") { | ||||
| 		return time.Time{}, nil | ||||
| 	} | ||||
| 	if since, err := strconv.ParseInt(r.URL.Query().Get("since"), 10, 64); err == nil { | ||||
| 		return time.Unix(since, 0), nil | ||||
| 	} | ||||
| 	if d, err := time.ParseDuration(r.URL.Query().Get("since")); err == nil { | ||||
| 		return time.Now().Add(-1 * d), nil | ||||
| 	} | ||||
| 	return time.Time{}, errHTTPBadRequest | ||||
| } | ||||
| 
 | ||||
| func (s *Server) handleOptions(w http.ResponseWriter, r *http.Request) error { | ||||
| 	w.Header().Set("Access-Control-Allow-Origin", "*") // CORS, allow cross-origin requests | ||||
| 	w.Header().Set("Access-Control-Allow-Methods", "GET, PUT, POST") | ||||
| 	return nil | ||||
| } | ||||
| 
 | ||||
| func (s *Server) handleStatic(w http.ResponseWriter, r *http.Request) error { | ||||
| 	http.FileServer(http.FS(webStaticFs)).ServeHTTP(w, r) | ||||
| 	return nil | ||||
| } | ||||
| 
 | ||||
| func (s *Server) createTopic(id string) *topic { | ||||
| 	s.mu.Lock() | ||||
| 	defer s.mu.Unlock() | ||||
| 	if _, ok := s.topics[id]; !ok { | ||||
| 		s.topics[id] = newTopic(id) | ||||
| 		if s.firebase != nil { | ||||
| 			s.topics[id].Subscribe(s.firebase) | ||||
| 		} | ||||
| 	} | ||||
| 	return s.topics[id] | ||||
| } | ||||
| 
 | ||||
| func (s *Server) topic(topicID string) (*topic, error) { | ||||
| func (s *Server) updateStatsAndExpire() { | ||||
| 	s.mu.Lock() | ||||
| 	defer s.mu.Unlock() | ||||
| 	c, ok := s.topics[topicID] | ||||
| 	if !ok { | ||||
| 		return nil, errHTTPNotFound | ||||
| 	} | ||||
| 	return c, nil | ||||
| } | ||||
| 
 | ||||
| func (s *Server) unsubscribe(t *topic, subscriberID int) { | ||||
| 	s.mu.Lock() | ||||
| 	defer s.mu.Unlock() | ||||
| 	if subscribers := t.Unsubscribe(subscriberID); subscribers == 0 { | ||||
| 		delete(s.topics, t.id) | ||||
| 	// Expire visitors from rate visitors map | ||||
| 	for ip, v := range s.visitors { | ||||
| 		if time.Since(v.seen) > visitorExpungeAfter { | ||||
| 			delete(s.visitors, ip) | ||||
| 		} | ||||
| 	} | ||||
| 
 | ||||
| 	// Prune old messages, remove topics without subscribers | ||||
| 	for _, t := range s.topics { | ||||
| 		t.Prune(s.config.MessageBufferDuration) | ||||
| 		subs, msgs := t.Stats() | ||||
| 		if msgs == 0 && (subs == 0 || (s.firebase != nil && subs == 1)) { | ||||
| 			delete(s.topics, t.id) | ||||
| 		} | ||||
| 	} | ||||
| 
 | ||||
| 	// Print stats | ||||
| 	var subscribers, messages int | ||||
| 	for _, t := range s.topics { | ||||
| 		subs, msgs := t.Stats() | ||||
| 		subscribers += subs | ||||
| 		messages += msgs | ||||
| 	} | ||||
| 	log.Printf("Stats: %d message(s) published, %d topic(s) active, %d subscriber(s), %d message(s) buffered, %d visitor(s)", | ||||
| 		s.messages, len(s.topics), subscribers, messages, len(s.visitors)) | ||||
| } | ||||
| 
 | ||||
| // visitor creates or retrieves a rate.Limiter for the given visitor. | ||||
|  |  | |||
|  | @ -2,7 +2,6 @@ package server | |||
| 
 | ||||
| import ( | ||||
| 	"context" | ||||
| 	"errors" | ||||
| 	"log" | ||||
| 	"math/rand" | ||||
| 	"sync" | ||||
|  | @ -14,7 +13,7 @@ import ( | |||
| type topic struct { | ||||
| 	id          string | ||||
| 	subscribers map[int]subscriber | ||||
| 	messages    int | ||||
| 	messages    []*message | ||||
| 	last        time.Time | ||||
| 	ctx         context.Context | ||||
| 	cancel      context.CancelFunc | ||||
|  | @ -45,21 +44,17 @@ func (t *topic) Subscribe(s subscriber) int { | |||
| 	return subscriberID | ||||
| } | ||||
| 
 | ||||
| func (t *topic) Unsubscribe(id int) int { | ||||
| func (t *topic) Unsubscribe(id int) { | ||||
| 	t.mu.Lock() | ||||
| 	defer t.mu.Unlock() | ||||
| 	delete(t.subscribers, id) | ||||
| 	return len(t.subscribers) | ||||
| } | ||||
| 
 | ||||
| func (t *topic) Publish(m *message) error { | ||||
| 	t.mu.Lock() | ||||
| 	defer t.mu.Unlock() | ||||
| 	if len(t.subscribers) == 0 { | ||||
| 		return errors.New("no subscribers") | ||||
| 	} | ||||
| 	t.last = time.Now() | ||||
| 	t.messages++ | ||||
| 	t.messages = append(t.messages, m) | ||||
| 	for _, s := range t.subscribers { | ||||
| 		if err := s(m); err != nil { | ||||
| 			log.Printf("error publishing message to subscriber") | ||||
|  | @ -68,10 +63,36 @@ func (t *topic) Publish(m *message) error { | |||
| 	return nil | ||||
| } | ||||
| 
 | ||||
| func (t *topic) Messages(since time.Time) []*message { | ||||
| 	t.mu.Lock() | ||||
| 	defer t.mu.Unlock() | ||||
| 	messages := make([]*message, 0) // copy! | ||||
| 	for _, m := range t.messages { | ||||
| 		msgTime := time.Unix(m.Time, 0) | ||||
| 		if msgTime == since || msgTime.After(since) { | ||||
| 			messages = append(messages, m) | ||||
| 		} | ||||
| 	} | ||||
| 	return messages | ||||
| } | ||||
| 
 | ||||
| func (t *topic) Prune(keep time.Duration) { | ||||
| 	t.mu.Lock() | ||||
| 	defer t.mu.Unlock() | ||||
| 	for i, m := range t.messages { | ||||
| 		msgTime := time.Unix(m.Time, 0) | ||||
| 		if time.Since(msgTime) < keep { | ||||
| 			t.messages = t.messages[i:] | ||||
| 			return | ||||
| 		} | ||||
| 	} | ||||
| 	t.messages = make([]*message, 0) | ||||
| } | ||||
| 
 | ||||
| func (t *topic) Stats() (subscribers int, messages int) { | ||||
| 	t.mu.Lock() | ||||
| 	defer t.mu.Unlock() | ||||
| 	return len(t.subscribers), t.messages | ||||
| 	return len(t.subscribers), len(t.messages) | ||||
| } | ||||
| 
 | ||||
| func (t *topic) Close() { | ||||
|  |  | |||
							
								
								
									
										29
									
								
								util/util.go
									
										
									
									
									
										Normal file
									
								
							
							
						
						
									
										29
									
								
								util/util.go
									
										
									
									
									
										Normal file
									
								
							|  | @ -0,0 +1,29 @@ | |||
| package util | ||||
| 
 | ||||
| import ( | ||||
| 	"math/rand" | ||||
| 	"os" | ||||
| 	"time" | ||||
| ) | ||||
| 
 | ||||
| const ( | ||||
| 	randomStringCharset = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789" | ||||
| ) | ||||
| 
 | ||||
| var ( | ||||
| 	random = rand.New(rand.NewSource(time.Now().UnixNano())) | ||||
| ) | ||||
| 
 | ||||
| func FileExists(filename string) bool { | ||||
| 	stat, _ := os.Stat(filename) | ||||
| 	return stat != nil | ||||
| } | ||||
| 
 | ||||
| // RandomString returns a random string with a given length | ||||
| func RandomString(length int) string { | ||||
| 	b := make([]byte, length) | ||||
| 	for i := range b { | ||||
| 		b[i] = randomStringCharset[random.Intn(len(randomStringCharset))] | ||||
| 	} | ||||
| 	return string(b) | ||||
| } | ||||
		Loading…
	
	Add table
		Add a link
		
	
		Reference in a new issue