Fix previous fix
This commit is contained in:
		
							parent
							
								
									dc8932cd95
								
							
						
					
					
						commit
						f58c1e4c84
					
				
					 5 changed files with 38 additions and 43 deletions
				
			
		|  | @ -11,23 +11,25 @@ import ( | ||||||
| 	"heckel.io/ntfy/util" | 	"heckel.io/ntfy/util" | ||||||
| 	"io" | 	"io" | ||||||
| 	"net/http" | 	"net/http" | ||||||
|  | 	"regexp" | ||||||
| 	"strings" | 	"strings" | ||||||
| 	"sync" | 	"sync" | ||||||
| 	"time" | 	"time" | ||||||
| ) | ) | ||||||
| 
 | 
 | ||||||
| // Event type constants |  | ||||||
| const ( | const ( | ||||||
| 	MessageEvent     = "message" | 	// MessageEvent identifies a message event | ||||||
| 	KeepaliveEvent   = "keepalive" | 	MessageEvent = "message" | ||||||
| 	OpenEvent        = "open" |  | ||||||
| 	PollRequestEvent = "poll_request" |  | ||||||
| ) | ) | ||||||
| 
 | 
 | ||||||
| const ( | const ( | ||||||
| 	maxResponseBytes = 4096 | 	maxResponseBytes = 4096 | ||||||
| ) | ) | ||||||
| 
 | 
 | ||||||
|  | var ( | ||||||
|  | 	topicRegex = regexp.MustCompile(`^[-_A-Za-z0-9]{1,64}$`) // Same as in server/server.go | ||||||
|  | ) | ||||||
|  | 
 | ||||||
| // Client is the ntfy client that can be used to publish and subscribe to ntfy topics | // Client is the ntfy client that can be used to publish and subscribe to ntfy topics | ||||||
| type Client struct { | type Client struct { | ||||||
| 	Messages      chan *Message | 	Messages      chan *Message | ||||||
|  | @ -96,7 +98,10 @@ func (c *Client) Publish(topic, message string, options ...PublishOption) (*Mess | ||||||
| // To pass title, priority and tags, check out WithTitle, WithPriority, WithTagsList, WithDelay, WithNoCache, | // To pass title, priority and tags, check out WithTitle, WithPriority, WithTagsList, WithDelay, WithNoCache, | ||||||
| // WithNoFirebase, and the generic WithHeader. | // WithNoFirebase, and the generic WithHeader. | ||||||
| func (c *Client) PublishReader(topic string, body io.Reader, options ...PublishOption) (*Message, error) { | func (c *Client) PublishReader(topic string, body io.Reader, options ...PublishOption) (*Message, error) { | ||||||
| 	topicURL := c.expandTopicURL(topic) | 	topicURL, err := c.expandTopicURL(topic) | ||||||
|  | 	if err != nil { | ||||||
|  | 		return nil, err | ||||||
|  | 	} | ||||||
| 	req, err := http.NewRequest("POST", topicURL, body) | 	req, err := http.NewRequest("POST", topicURL, body) | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		return nil, err | 		return nil, err | ||||||
|  | @ -136,11 +141,14 @@ func (c *Client) PublishReader(topic string, body io.Reader, options ...PublishO | ||||||
| // By default, all messages will be returned, but you can change this behavior using a SubscribeOption. | // By default, all messages will be returned, but you can change this behavior using a SubscribeOption. | ||||||
| // See WithSince, WithSinceAll, WithSinceUnixTime, WithScheduled, and the generic WithQueryParam. | // See WithSince, WithSinceAll, WithSinceUnixTime, WithScheduled, and the generic WithQueryParam. | ||||||
| func (c *Client) Poll(topic string, options ...SubscribeOption) ([]*Message, error) { | func (c *Client) Poll(topic string, options ...SubscribeOption) ([]*Message, error) { | ||||||
|  | 	topicURL, err := c.expandTopicURL(topic) | ||||||
|  | 	if err != nil { | ||||||
|  | 		return nil, err | ||||||
|  | 	} | ||||||
| 	ctx := context.Background() | 	ctx := context.Background() | ||||||
| 	messages := make([]*Message, 0) | 	messages := make([]*Message, 0) | ||||||
| 	msgChan := make(chan *Message) | 	msgChan := make(chan *Message) | ||||||
| 	errChan := make(chan error) | 	errChan := make(chan error) | ||||||
| 	topicURL := c.expandTopicURL(topic) |  | ||||||
| 	log.Debug("%s Polling from topic", util.ShortTopicURL(topicURL)) | 	log.Debug("%s Polling from topic", util.ShortTopicURL(topicURL)) | ||||||
| 	options = append(options, WithPoll()) | 	options = append(options, WithPoll()) | ||||||
| 	go func() { | 	go func() { | ||||||
|  | @ -169,15 +177,18 @@ func (c *Client) Poll(topic string, options ...SubscribeOption) ([]*Message, err | ||||||
| // Example: | // Example: | ||||||
| // | // | ||||||
| //	c := client.New(client.NewConfig()) | //	c := client.New(client.NewConfig()) | ||||||
| //	subscriptionID := c.Subscribe("mytopic") | //	subscriptionID, _ := c.Subscribe("mytopic") | ||||||
| //	for m := range c.Messages { | //	for m := range c.Messages { | ||||||
| //	  fmt.Printf("New message: %s", m.Message) | //	  fmt.Printf("New message: %s", m.Message) | ||||||
| //	} | //	} | ||||||
| func (c *Client) Subscribe(topic string, options ...SubscribeOption) string { | func (c *Client) Subscribe(topic string, options ...SubscribeOption) (string, error) { | ||||||
|  | 	topicURL, err := c.expandTopicURL(topic) | ||||||
|  | 	if err != nil { | ||||||
|  | 		return "", err | ||||||
|  | 	} | ||||||
| 	c.mu.Lock() | 	c.mu.Lock() | ||||||
| 	defer c.mu.Unlock() | 	defer c.mu.Unlock() | ||||||
| 	subscriptionID := util.RandomString(10) | 	subscriptionID := util.RandomString(10) | ||||||
| 	topicURL := c.expandTopicURL(topic) |  | ||||||
| 	log.Debug("%s Subscribing to topic", util.ShortTopicURL(topicURL)) | 	log.Debug("%s Subscribing to topic", util.ShortTopicURL(topicURL)) | ||||||
| 	ctx, cancel := context.WithCancel(context.Background()) | 	ctx, cancel := context.WithCancel(context.Background()) | ||||||
| 	c.subscriptions[subscriptionID] = &subscription{ | 	c.subscriptions[subscriptionID] = &subscription{ | ||||||
|  | @ -186,7 +197,7 @@ func (c *Client) Subscribe(topic string, options ...SubscribeOption) string { | ||||||
| 		cancel:   cancel, | 		cancel:   cancel, | ||||||
| 	} | 	} | ||||||
| 	go handleSubscribeConnLoop(ctx, c.Messages, topicURL, subscriptionID, options...) | 	go handleSubscribeConnLoop(ctx, c.Messages, topicURL, subscriptionID, options...) | ||||||
| 	return subscriptionID | 	return subscriptionID, nil | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| // Unsubscribe unsubscribes from a topic that has been previously subscribed to using the unique | // Unsubscribe unsubscribes from a topic that has been previously subscribed to using the unique | ||||||
|  | @ -202,31 +213,16 @@ func (c *Client) Unsubscribe(subscriptionID string) { | ||||||
| 	sub.cancel() | 	sub.cancel() | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| // UnsubscribeAll unsubscribes from a topic that has been previously subscribed with Subscribe. | func (c *Client) expandTopicURL(topic string) (string, error) { | ||||||
| // If there are multiple subscriptions matching the topic, all of them are unsubscribed from. |  | ||||||
| // |  | ||||||
| // A topic can be either a full URL (e.g. https://myhost.lan/mytopic), a short URL which is then prepended https:// |  | ||||||
| // (e.g. myhost.lan -> https://myhost.lan), or a short name which is expanded using the default host in the |  | ||||||
| // config (e.g. mytopic -> https://ntfy.sh/mytopic). |  | ||||||
| func (c *Client) UnsubscribeAll(topic string) { |  | ||||||
| 	c.mu.Lock() |  | ||||||
| 	defer c.mu.Unlock() |  | ||||||
| 	topicURL := c.expandTopicURL(topic) |  | ||||||
| 	for _, sub := range c.subscriptions { |  | ||||||
| 		if sub.topicURL == topicURL { |  | ||||||
| 			delete(c.subscriptions, sub.ID) |  | ||||||
| 			sub.cancel() |  | ||||||
| 		} |  | ||||||
| 	} |  | ||||||
| } |  | ||||||
| 
 |  | ||||||
| func (c *Client) expandTopicURL(topic string) string { |  | ||||||
| 	if strings.HasPrefix(topic, "http://") || strings.HasPrefix(topic, "https://") { | 	if strings.HasPrefix(topic, "http://") || strings.HasPrefix(topic, "https://") { | ||||||
| 		return topic | 		return topic, nil | ||||||
| 	} else if strings.Contains(topic, "/") { | 	} else if strings.Contains(topic, "/") { | ||||||
| 		return fmt.Sprintf("https://%s", topic) | 		return fmt.Sprintf("https://%s", topic), nil | ||||||
| 	} | 	} | ||||||
| 	return fmt.Sprintf("%s/%s", c.config.DefaultHost, topic) | 	if !topicRegex.MatchString(topic) { | ||||||
|  | 		return "", fmt.Errorf("invalid topic name: %s", topic) | ||||||
|  | 	} | ||||||
|  | 	return fmt.Sprintf("%s/%s", c.config.DefaultHost, topic), nil | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| func handleSubscribeConnLoop(ctx context.Context, msgChan chan *Message, topicURL, subcriptionID string, options ...SubscribeOption) { | func handleSubscribeConnLoop(ctx context.Context, msgChan chan *Message, topicURL, subcriptionID string, options ...SubscribeOption) { | ||||||
|  |  | ||||||
|  | @ -21,7 +21,7 @@ func TestClient_Publish_Subscribe(t *testing.T) { | ||||||
| 	defer test.StopServer(t, s, port) | 	defer test.StopServer(t, s, port) | ||||||
| 	c := client.New(newTestConfig(port)) | 	c := client.New(newTestConfig(port)) | ||||||
| 
 | 
 | ||||||
| 	subscriptionID := c.Subscribe("mytopic") | 	subscriptionID, _ := c.Subscribe("mytopic") | ||||||
| 	time.Sleep(time.Second) | 	time.Sleep(time.Second) | ||||||
| 
 | 
 | ||||||
| 	msg, err := c.Publish("mytopic", "some message") | 	msg, err := c.Publish("mytopic", "some message") | ||||||
|  |  | ||||||
|  | @ -29,7 +29,6 @@ var flagsDefault = []cli.Flag{ | ||||||
| 
 | 
 | ||||||
| var ( | var ( | ||||||
| 	logLevelOverrideRegex = regexp.MustCompile(`(?i)^([^=\s]+)(?:\s*=\s*(\S+))?\s*->\s*(TRACE|DEBUG|INFO|WARN|ERROR)$`) | 	logLevelOverrideRegex = regexp.MustCompile(`(?i)^([^=\s]+)(?:\s*=\s*(\S+))?\s*->\s*(TRACE|DEBUG|INFO|WARN|ERROR)$`) | ||||||
| 	topicRegex            = regexp.MustCompile(`^[-_A-Za-z0-9]{1,64}$`) // Same as in server/server.go |  | ||||||
| ) | ) | ||||||
| 
 | 
 | ||||||
| // New creates a new CLI application | // New creates a new CLI application | ||||||
|  |  | ||||||
|  | @ -249,10 +249,6 @@ func parseTopicMessageCommand(c *cli.Context) (topic string, message string, com | ||||||
| 	if c.String("message") != "" { | 	if c.String("message") != "" { | ||||||
| 		message = c.String("message") | 		message = c.String("message") | ||||||
| 	} | 	} | ||||||
| 	if !topicRegex.MatchString(topic) { |  | ||||||
| 		err = fmt.Errorf("topic %s contains invalid characters", topic) |  | ||||||
| 		return |  | ||||||
| 	} |  | ||||||
| 	return | 	return | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
|  |  | ||||||
|  | @ -108,8 +108,6 @@ func execSubscribe(c *cli.Context) error { | ||||||
| 	// Checks | 	// Checks | ||||||
| 	if user != "" && token != "" { | 	if user != "" && token != "" { | ||||||
| 		return errors.New("cannot set both --user and --token") | 		return errors.New("cannot set both --user and --token") | ||||||
| 	} else if !topicRegex.MatchString(topic) { |  | ||||||
| 		return fmt.Errorf("topic %s contains invalid characters", topic) |  | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
| 	if !fromConfig { | 	if !fromConfig { | ||||||
|  | @ -196,7 +194,10 @@ func doSubscribe(c *cli.Context, cl *client.Client, conf *client.Config, topic, | ||||||
| 			topicOptions = append(topicOptions, auth) | 			topicOptions = append(topicOptions, auth) | ||||||
| 		} | 		} | ||||||
| 
 | 
 | ||||||
| 		subscriptionID := cl.Subscribe(s.Topic, topicOptions...) | 		subscriptionID, err := cl.Subscribe(s.Topic, topicOptions...) | ||||||
|  | 		if err != nil { | ||||||
|  | 			return err | ||||||
|  | 		} | ||||||
| 		if s.Command != "" { | 		if s.Command != "" { | ||||||
| 			cmds[subscriptionID] = s.Command | 			cmds[subscriptionID] = s.Command | ||||||
| 		} else if conf.DefaultCommand != "" { | 		} else if conf.DefaultCommand != "" { | ||||||
|  | @ -206,7 +207,10 @@ func doSubscribe(c *cli.Context, cl *client.Client, conf *client.Config, topic, | ||||||
| 		} | 		} | ||||||
| 	} | 	} | ||||||
| 	if topic != "" { | 	if topic != "" { | ||||||
| 		subscriptionID := cl.Subscribe(topic, options...) | 		subscriptionID, err := cl.Subscribe(topic, options...) | ||||||
|  | 		if err != nil { | ||||||
|  | 			return err | ||||||
|  | 		} | ||||||
| 		cmds[subscriptionID] = command | 		cmds[subscriptionID] = command | ||||||
| 	} | 	} | ||||||
| 	for m := range cl.Messages { | 	for m := range cl.Messages { | ||||||
|  |  | ||||||
		Loading…
	
	Add table
		Add a link
		
	
		Reference in a new issue