diff --git a/README.md b/README.md
new file mode 100644
index 00000000..93232bc8
--- /dev/null
+++ b/README.md
@@ -0,0 +1,10 @@
+
+
+echo "mychan:long process is done" | nc -N ntfy.sh 9999
+curl -d "long process is done" ntfy.sh/mychan
+ publish on channel
+
+curl ntfy.sh/mychan
+ subscribe to channel
+
+ntfy.sh/mychan/ws
diff --git a/go.mod b/go.mod
index e3f7ce5a..ee3af69f 100644
--- a/go.mod
+++ b/go.mod
@@ -1,3 +1,5 @@
module heckel.io/notifyme
go 1.16
+
+require github.com/gorilla/websocket v1.4.2 // indirect
diff --git a/go.sum b/go.sum
index e69de29b..85efffd9 100644
--- a/go.sum
+++ b/go.sum
@@ -0,0 +1,2 @@
+github.com/gorilla/websocket v1.4.2 h1:+/TMaTYc4QFitKJxsQ7Yye35DkWvkdLcvGKqM+x0Ufc=
+github.com/gorilla/websocket v1.4.2/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE=
diff --git a/main.go b/main.go
index 973a8ac8..43af89f2 100644
--- a/main.go
+++ b/main.go
@@ -1,138 +1,13 @@
package main
import (
- "context"
- "encoding/json"
- "errors"
- "io"
+ "heckel.io/notifyme/server"
"log"
- "math/rand"
- "net/http"
- "sync"
- "time"
)
-type Message struct {
- Time int64 `json:"time"`
- Message string `json:"message"`
-}
-
-type Channel struct {
- id string
- listeners map[int]listener
- last time.Time
- ctx context.Context
- mu sync.Mutex
-}
-
-type Server struct {
- channels map[string]*Channel
- mu sync.Mutex
-}
-
-type listener func(msg *Message)
-
func main() {
- s := &Server{
- channels: make(map[string]*Channel),
- }
- go func() {
- for {
- time.Sleep(5 * time.Second)
- s.mu.Lock()
- log.Printf("channels: %d", len(s.channels))
- s.mu.Unlock()
- }
- }()
- http.HandleFunc("/", s.handle)
- if err := http.ListenAndServe(":9997", nil); err != nil {
+ s := server.New()
+ if err := s.Run(); err != nil {
log.Fatalln(err)
}
}
-
-func (s *Server) handle(w http.ResponseWriter, r *http.Request) {
- if err := s.handleInternal(w, r); err != nil {
- w.WriteHeader(http.StatusInternalServerError)
- _, _ = io.WriteString(w, err.Error())
- }
-}
-
-func (s *Server) handleInternal(w http.ResponseWriter, r *http.Request) error {
- if len(r.URL.Path) == 0 {
- return errors.New("invalid path")
- }
- channel := s.channel(r.URL.Path[1:])
- switch r.Method {
- case http.MethodGet:
- return s.handleGET(w, r, channel)
- case http.MethodPut:
- return s.handlePUT(w, r, channel)
- default:
- return errors.New("invalid method")
- }
-}
-
-func (s *Server) handleGET(w http.ResponseWriter, r *http.Request, ch *Channel) error {
- fl, ok := w.(http.Flusher)
- if !ok {
- return errors.New("not a flusher")
- }
- listenerID := rand.Int()
- l := func (msg *Message) {
- json.NewEncoder(w).Encode(&msg)
- fl.Flush()
- }
- ch.mu.Lock()
- ch.listeners[listenerID] = l
- ch.last = time.Now()
- ch.mu.Unlock()
- select {
- case <-ch.ctx.Done():
- case <-r.Context().Done():
- }
- ch.mu.Lock()
- delete(ch.listeners, listenerID)
- if len(ch.listeners) == 0 {
- s.mu.Lock()
- delete(s.channels, ch.id)
- s.mu.Unlock()
- }
- ch.mu.Unlock()
- return nil
-}
-
-func (s *Server) handlePUT(w http.ResponseWriter, r *http.Request, ch *Channel) error {
- ch.mu.Lock()
- defer ch.mu.Unlock()
- if len(ch.listeners) == 0 {
- return errors.New("no listeners")
- }
- defer r.Body.Close()
- ch.last = time.Now()
- msg, _ := io.ReadAll(r.Body)
- for _, l := range ch.listeners {
- l(&Message{
- Time: time.Now().UnixMilli(),
- Message: string(msg),
- })
- }
- return nil
-}
-
-func (s *Server) channel(channelID string) *Channel {
- s.mu.Lock()
- defer s.mu.Unlock()
- c, ok := s.channels[channelID]
- if !ok {
- ctx, _ := context.WithCancel(context.Background()) // FIXME
- c = &Channel{
- id: channelID,
- listeners: make(map[int]listener),
- last: time.Now(),
- ctx: ctx,
- mu: sync.Mutex{},
- }
- s.channels[channelID] = c
- }
- return c
-}
diff --git a/server/index.html b/server/index.html
new file mode 100644
index 00000000..9ab3b9b9
--- /dev/null
+++ b/server/index.html
@@ -0,0 +1,79 @@
+
+
+
+ ntfy.sh
+
+
+
+ntfy.sh
+
+Topics:
+
+
+
+
+
+
+
+
+
+
+
diff --git a/server/server.go b/server/server.go
new file mode 100644
index 00000000..3b839b82
--- /dev/null
+++ b/server/server.go
@@ -0,0 +1,185 @@
+package server
+
+import (
+ "bytes"
+ _ "embed" // required for go:embed
+ "encoding/json"
+ "errors"
+ "github.com/gorilla/websocket"
+ "io"
+ "log"
+ "net/http"
+ "regexp"
+ "strings"
+ "sync"
+ "time"
+)
+
+type Server struct {
+ topics map[string]*topic
+ mu sync.Mutex
+}
+
+type message struct {
+ Time int64 `json:"time"`
+ Message string `json:"message"`
+}
+
+const (
+ messageLimit = 1024
+)
+
+var (
+ topicRegex = regexp.MustCompile(`^/[^/]+$`)
+ wsRegex = regexp.MustCompile(`^/[^/]+/ws$`)
+ jsonRegex = regexp.MustCompile(`^/[^/]+/json$`)
+ wsUpgrader = websocket.Upgrader{
+ ReadBufferSize: messageLimit,
+ WriteBufferSize: messageLimit,
+ }
+
+ //go:embed "index.html"
+ indexSource string
+)
+
+func New() *Server {
+ return &Server{
+ topics: make(map[string]*topic),
+ }
+}
+
+func (s *Server) Run() error {
+ go func() {
+ for {
+ time.Sleep(5 * time.Second)
+ s.mu.Lock()
+ log.Printf("topics: %d", len(s.topics))
+ for _, t := range s.topics {
+ t.mu.Lock()
+ log.Printf("- %s: %d subscriber(s), %d message(s) sent, last active = %s",
+ t.id, len(t.subscribers), t.messages, t.last.String())
+ t.mu.Unlock()
+ }
+ // TODO kill dead topics
+ s.mu.Unlock()
+ }
+ }()
+ log.Printf("Listening on :9997")
+ http.HandleFunc("/", s.handle)
+ return http.ListenAndServe(":9997", nil)
+}
+
+func (s *Server) handle(w http.ResponseWriter, r *http.Request) {
+ if err := s.handleInternal(w, r); err != nil {
+ w.WriteHeader(http.StatusInternalServerError)
+ _, _ = io.WriteString(w, err.Error())
+ }
+}
+
+func (s *Server) handleInternal(w http.ResponseWriter, r *http.Request) error {
+ if r.Method == http.MethodGet && r.URL.Path == "/" {
+ return s.handleHome(w, r)
+ } else if r.Method == http.MethodGet && wsRegex.MatchString(r.URL.Path) {
+ return s.handleSubscribeWS(w, r)
+ } else if r.Method == http.MethodGet && jsonRegex.MatchString(r.URL.Path) {
+ return s.handleSubscribeHTTP(w, r)
+ } else if (r.Method == http.MethodPut || r.Method == http.MethodPost) && topicRegex.MatchString(r.URL.Path) {
+ return s.handlePublishHTTP(w, r)
+ }
+ http.NotFound(w, r)
+ return nil
+}
+
+func (s *Server) handleHome(w http.ResponseWriter, r *http.Request) error {
+ _, err := io.WriteString(w, indexSource)
+ return err
+}
+
+func (s *Server) handlePublishHTTP(w http.ResponseWriter, r *http.Request) error {
+ t, err := s.topic(r.URL.Path[1:])
+ if err != nil {
+ return err
+ }
+ reader := io.LimitReader(r.Body, messageLimit)
+ b, err := io.ReadAll(reader)
+ if err != nil {
+ return err
+ }
+ msg := &message{
+ Time: time.Now().UnixMilli(),
+ Message: string(b),
+ }
+ return t.Publish(msg)
+}
+
+func (s *Server) handleSubscribeHTTP(w http.ResponseWriter, r *http.Request) error {
+ t := s.createTopic(strings.TrimSuffix(r.URL.Path[1:], "/json")) // Hack
+ subscriberID := t.Subscribe(func (msg *message) error {
+ if err := json.NewEncoder(w).Encode(&msg); err != nil {
+ return err
+ }
+ if fl, ok := w.(http.Flusher); ok {
+ fl.Flush()
+ }
+ return nil
+ })
+ defer t.Unsubscribe(subscriberID)
+ select {
+ case <-t.ctx.Done():
+ case <-r.Context().Done():
+ }
+ return nil
+}
+
+func (s *Server) handleSubscribeWS(w http.ResponseWriter, r *http.Request) error {
+ conn, err := wsUpgrader.Upgrade(w, r, nil)
+ if err != nil {
+ return err
+ }
+ t := s.createTopic(strings.TrimSuffix(r.URL.Path[1:], "/ws")) // Hack
+ t.Subscribe(func (msg *message) error {
+ var buf bytes.Buffer
+ if err := json.NewEncoder(&buf).Encode(&msg); err != nil {
+ return err
+ }
+ defer conn.Close()
+ /*conn.SetWriteDeadline(time.Now().Add(writeWait))
+ if !ok {
+ // The hub closed the channel.
+ c.conn.WriteMessage(websocket.CloseMessage, []byte{})
+ return
+ }*/
+
+ w, err := conn.NextWriter(websocket.TextMessage)
+ if err != nil {
+ return err
+ }
+ if _, err := w.Write([]byte(msg.Message)); err != nil {
+ return err
+ }
+ if err := w.Close(); err != nil {
+ return err
+ }
+ return nil
+ })
+ return nil
+}
+
+func (s *Server) createTopic(id string) *topic {
+ s.mu.Lock()
+ defer s.mu.Unlock()
+ if _, ok := s.topics[id]; !ok {
+ s.topics[id] = newTopic(id)
+ }
+ return s.topics[id]
+}
+
+func (s *Server) topic(topicID string) (*topic, error) {
+ s.mu.Lock()
+ defer s.mu.Unlock()
+ c, ok := s.topics[topicID]
+ if !ok {
+ return nil, errors.New("topic does not exist")
+ }
+ return c, nil
+}
diff --git a/server/topic.go b/server/topic.go
new file mode 100644
index 00000000..65dd7415
--- /dev/null
+++ b/server/topic.go
@@ -0,0 +1,68 @@
+package server
+
+import (
+ "context"
+ "errors"
+ "log"
+ "math/rand"
+ "sync"
+ "time"
+)
+
+type topic struct {
+ id string
+ subscribers map[int]subscriber
+ messages int
+ last time.Time
+ ctx context.Context
+ cancel context.CancelFunc
+ mu sync.Mutex
+}
+
+type subscriber func(msg *message) error
+
+func newTopic(id string) *topic {
+ ctx, cancel := context.WithCancel(context.Background())
+ return &topic{
+ id: id,
+ subscribers: make(map[int]subscriber),
+ last: time.Now(),
+ ctx: ctx,
+ cancel: cancel,
+ }
+}
+
+func (t *topic) Subscribe(s subscriber) int {
+ t.mu.Lock()
+ defer t.mu.Unlock()
+ subscriberID := rand.Int()
+ t.subscribers[subscriberID] = s
+ t.last = time.Now()
+ return subscriberID
+}
+
+func (t *topic) Unsubscribe(id int) {
+ t.mu.Lock()
+ defer t.mu.Unlock()
+ delete(t.subscribers, id)
+}
+
+func (t *topic) Publish(m *message) error {
+ t.mu.Lock()
+ defer t.mu.Unlock()
+ if len(t.subscribers) == 0 {
+ return errors.New("no subscribers")
+ }
+ t.last = time.Now()
+ t.messages++
+ for _, s := range t.subscribers {
+ if err := s(m); err != nil {
+ log.Printf("error publishing message to subscriber x")
+ }
+ }
+ return nil
+}
+
+func (t *topic) Close() {
+ t.cancel()
+}