diff --git a/cmd/mstdn/cmd_stream.go b/cmd/mstdn/cmd_stream.go index dac8865..f2ca59c 100644 --- a/cmd/mstdn/cmd_stream.go +++ b/cmd/mstdn/cmd_stream.go @@ -80,7 +80,6 @@ func cmdStream(c *cli.Context) error { go func() { <-sc cancel() - close(q) }() c.App.Metadata["signal"] = sc diff --git a/status.go b/status.go index 4c28b33..dbbed03 100644 --- a/status.go +++ b/status.go @@ -187,6 +187,23 @@ func (c *Client) GetTimelineHashtag(ctx context.Context, tag string, isLocal boo return statuses, nil } +// GetTimelineMedia return statuses from media timeline. +// NOTE: This is an experimental feature of pawoo.net. +func (c *Client) GetTimelineMedia(ctx context.Context, isLocal bool) ([]*Status, error) { + params := url.Values{} + params.Set("media", "t") + if isLocal { + params.Set("local", "t") + } + + var statuses []*Status + err := c.doAPI(ctx, http.MethodGet, "/api/v1/timelines/public", params, &statuses, nil) + if err != nil { + return nil, err + } + return statuses, nil +} + // PostStatus post the toot. func (c *Client) PostStatus(ctx context.Context, toot *Toot) (*Status, error) { params := url.Values{} diff --git a/status_test.go b/status_test.go index bfc4a41..bc4eed6 100644 --- a/status_test.go +++ b/status_test.go @@ -393,6 +393,42 @@ func TestGetTimelineHashtag(t *testing.T) { } } +func TestGetTimelineMedia(t *testing.T) { + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.URL.Query().Get("local") == "" { + http.Error(w, http.StatusText(http.StatusNotFound), http.StatusNotFound) + return + } + fmt.Fprintln(w, `[{"content": "zzz"},{"content": "yyy"}]`) + return + })) + defer ts.Close() + + client := NewClient(&Config{ + Server: ts.URL, + ClientID: "foo", + ClientSecret: "bar", + AccessToken: "zoo", + }) + _, err := client.GetTimelineMedia(context.Background(), false) + if err == nil { + t.Fatalf("should be fail: %v", err) + } + tags, err := client.GetTimelineMedia(context.Background(), true) + if err != nil { + t.Fatalf("should not be fail: %v", err) + } + if len(tags) != 2 { + t.Fatalf("should have %q entries but %q", "2", len(tags)) + } + if tags[0].Content != "zzz" { + t.Fatalf("want %q but %q", "zzz", tags[0].Content) + } + if tags[1].Content != "yyy" { + t.Fatalf("want %q but %q", "zzz", tags[1].Content) + } +} + func TestDeleteStatus(t *testing.T) { ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { if r.URL.Path != "/api/v1/statuses/1234567" { diff --git a/streaming.go b/streaming.go index 897be4c..060c9e7 100644 --- a/streaming.go +++ b/streaming.go @@ -8,8 +8,8 @@ import ( "net/http" "net/url" "path" + "strconv" "strings" - "time" ) // UpdateEvent is struct for passing status event to app. @@ -42,8 +42,8 @@ type Event interface { event() } -func handleReader(ctx context.Context, q chan Event, r io.Reader) error { - name := "" +func handleReader(q chan Event, r io.Reader) error { + var name string s := bufio.NewScanner(r) for s.Scan() { line := s.Text() @@ -55,30 +55,33 @@ func handleReader(ctx context.Context, q chan Event, r io.Reader) error { case "event": name = strings.TrimSpace(token[1]) case "data": + var err error switch name { case "update": var status Status - err := json.Unmarshal([]byte(token[1]), &status) + err = json.Unmarshal([]byte(token[1]), &status) if err == nil { q <- &UpdateEvent{&status} } case "notification": var notification Notification - err := json.Unmarshal([]byte(token[1]), ¬ification) + err = json.Unmarshal([]byte(token[1]), ¬ification) if err == nil { q <- &NotificationEvent{¬ification} } case "delete": var id int64 - err := json.Unmarshal([]byte(token[1]), &id) + id, err = strconv.ParseInt(strings.TrimSpace(token[1]), 10, 64) if err == nil { q <- &DeleteEvent{id} } } - default: + if err != nil { + q <- &ErrorEvent{err} + } } } - return ctx.Err() + return s.Err() } func (c *Client) streaming(ctx context.Context, p string, params url.Values) (chan Event, error) { @@ -96,38 +99,42 @@ func (c *Client) streaming(ctx context.Context, p string, params url.Values) (ch req = req.WithContext(ctx) req.Header.Set("Authorization", "Bearer "+c.config.AccessToken) - var resp *http.Response - - q := make(chan Event, 10) + q := make(chan Event) go func() { - defer ctx.Done() - + defer close(q) for { - resp, err = c.Do(req) - if resp != nil && resp.StatusCode != http.StatusOK { - err = parseAPIError("bad request", resp) + select { + case <-ctx.Done(): + q <- &ErrorEvent{ctx.Err()} + return + default: } - if err == nil { - err = handleReader(ctx, q, resp.Body) - if err == nil { - break - } - } else { - q <- &ErrorEvent{err} - } - resp.Body.Close() - time.Sleep(3 * time.Second) - } - }() - go func() { - <-ctx.Done() - if resp != nil && resp.Body != nil { - resp.Body.Close() + + c.doStreaming(req, q) } }() return q, nil } +func (c *Client) doStreaming(req *http.Request, q chan Event) { + resp, err := c.Do(req) + if err != nil { + q <- &ErrorEvent{err} + return + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + q <- &ErrorEvent{parseAPIError("bad request", resp)} + return + } + + err = handleReader(q, resp.Body) + if err != nil { + q <- &ErrorEvent{err} + } +} + // StreamingUser return channel to read events on home. func (c *Client) StreamingUser(ctx context.Context) (chan Event, error) { return c.streaming(ctx, "user", nil) diff --git a/streaming_test.go b/streaming_test.go index 6e6153d..9f5b715 100644 --- a/streaming_test.go +++ b/streaming_test.go @@ -5,13 +5,207 @@ import ( "fmt" "net/http" "net/http/httptest" + "strings" "testing" "time" ) -func TestStreamingPublic(t *testing.T) { +func TestHandleReader(t *testing.T) { + q := make(chan Event) + r := strings.NewReader(` +event: update +data: {content: error} +event: update +data: {"content": "foo"} +event: notification +data: {"type": "mention"} +event: delete +data: 1234567 +:thump + `) + go func() { + defer close(q) + err := handleReader(q, r) + if err != nil { + t.Fatalf("should not be fail: %v", err) + } + }() + var passUpdate, passNotification, passDelete, passError bool + for e := range q { + switch event := e.(type) { + case *UpdateEvent: + passUpdate = true + if event.Status.Content != "foo" { + t.Fatalf("want %q but %q", "foo", event.Status.Content) + } + case *NotificationEvent: + passNotification = true + if event.Notification.Type != "mention" { + t.Fatalf("want %q but %q", "mention", event.Notification.Type) + } + case *DeleteEvent: + passDelete = true + if event.ID != 1234567 { + t.Fatalf("want %d but %d", 1234567, event.ID) + } + case *ErrorEvent: + passError = true + if event.err == nil { + t.Fatalf("should be fail: %v", event.err) + } + } + } + if !passUpdate || !passNotification || !passDelete || !passError { + t.Fatalf("have not passed through somewhere: "+ + "update %t, notification %t, delete %t, error %t", + passUpdate, passNotification, passDelete, passError) + } +} + +func TestStreaming(t *testing.T) { + var isEnd bool + canErr := true ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - if r.URL.Path != "/api/v1/streaming/public" { + if isEnd { + return + } else if canErr { + canErr = false + http.Error(w, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError) + return + } + f := w.(http.Flusher) + fmt.Fprintln(w, ` +event: update +data: {"content": "foo"} + `) + f.Flush() + isEnd = true + })) + defer ts.Close() + + c := NewClient(&Config{Server: ":"}) + _, err := c.streaming(context.Background(), "", nil) + if err == nil { + t.Fatalf("should be fail: %v", err) + } + + c = NewClient(&Config{Server: ts.URL}) + ctx, cancel := context.WithCancel(context.Background()) + time.AfterFunc(time.Second, cancel) + q, err := c.streaming(ctx, "", nil) + if err != nil { + t.Fatalf("should not be fail: %v", err) + } + var cnt int + var passError, passUpdate bool + for e := range q { + switch event := e.(type) { + case *ErrorEvent: + passError = true + if event.err == nil { + t.Fatalf("should be fail: %v", event.err) + } + case *UpdateEvent: + cnt++ + passUpdate = true + if event.Status.Content != "foo" { + t.Fatalf("want %q but %q", "foo", event.Status.Content) + } + } + } + if cnt != 1 { + t.Fatalf("result should be one: %d", cnt) + } + if !passError || !passUpdate { + t.Fatalf("have not passed through somewhere: error %t, update %t", passError, passUpdate) + } +} + +func TestDoStreaming(t *testing.T) { + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.(http.Flusher).Flush() + time.Sleep(time.Second) + })) + defer ts.Close() + + c := NewClient(&Config{Server: ts.URL}) + + req, err := http.NewRequest(http.MethodGet, ts.URL, nil) + if err != nil { + t.Fatalf("should not be fail: %v", err) + } + ctx, cancel := context.WithCancel(context.Background()) + time.AfterFunc(time.Millisecond, cancel) + req = req.WithContext(ctx) + + q := make(chan Event) + go func() { + defer close(q) + c.doStreaming(req, q) + if err != nil { + t.Fatalf("should not be fail: %v", err) + } + }() + var passError bool + for e := range q { + if event, ok := e.(*ErrorEvent); ok { + passError = true + if event.err == nil { + t.Fatalf("should be fail: %v", event.err) + } + } + } + if !passError { + t.Fatalf("have not passed through: error %t", passError) + } +} + +func TestStreamingUser(t *testing.T) { + var isEnd bool + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if isEnd { + return + } else if r.URL.Path != "/api/v1/streaming/user" { + http.Error(w, http.StatusText(http.StatusNotFound), http.StatusNotFound) + return + } + f, _ := w.(http.Flusher) + fmt.Fprintln(w, ` +event: update +data: {"content": "foo"} + `) + f.Flush() + isEnd = true + })) + defer ts.Close() + + c := NewClient(&Config{Server: ts.URL}) + ctx, cancel := context.WithCancel(context.Background()) + time.AfterFunc(time.Second, cancel) + q, err := c.StreamingUser(ctx) + if err != nil { + t.Fatalf("should not be fail: %v", err) + } + events := []Event{} + for e := range q { + if _, ok := e.(*ErrorEvent); !ok { + events = append(events, e) + } + } + if len(events) != 1 { + t.Fatalf("result should be one: %d", len(events)) + } + if events[0].(*UpdateEvent).Status.Content != "foo" { + t.Fatalf("want %q but %q", "foo", events[0].(*UpdateEvent).Status.Content) + } +} + +func TestStreamingPublic(t *testing.T) { + var isEnd bool + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if isEnd { + return + } else if r.URL.Path != "/api/v1/streaming/public/local" { http.Error(w, http.StatusText(http.StatusNotFound), http.StatusNotFound) return } @@ -27,7 +221,7 @@ event: update data: {"content": "bar"} `) f.Flush() - return + isEnd = true })) defer ts.Close() @@ -38,17 +232,16 @@ data: {"content": "bar"} AccessToken: "zoo", }) ctx, cancel := context.WithCancel(context.Background()) - q, err := client.StreamingPublic(ctx, false) + q, err := client.StreamingPublic(ctx, true) if err != nil { t.Fatalf("should not be fail: %v", err) } - time.AfterFunc(3*time.Second, func() { - cancel() - close(q) - }) + time.AfterFunc(time.Second, cancel) events := []Event{} for e := range q { - events = append(events, e) + if _, ok := e.(*ErrorEvent); !ok { + events = append(events, e) + } } if len(events) != 2 { t.Fatalf("result should be two: %d", len(events)) @@ -60,3 +253,43 @@ data: {"content": "bar"} t.Fatalf("want %q but %q", "bar", events[1].(*UpdateEvent).Status.Content) } } + +func TestStreamingHashtag(t *testing.T) { + var isEnd bool + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if isEnd { + return + } else if r.URL.Path != "/api/v1/streaming/hashtag/local" { + http.Error(w, http.StatusText(http.StatusNotFound), http.StatusNotFound) + return + } + f, _ := w.(http.Flusher) + fmt.Fprintln(w, ` +event: update +data: {"content": "foo"} + `) + f.Flush() + isEnd = true + })) + defer ts.Close() + + client := NewClient(&Config{Server: ts.URL}) + ctx, cancel := context.WithCancel(context.Background()) + time.AfterFunc(time.Second, cancel) + q, err := client.StreamingHashtag(ctx, "hashtag", true) + if err != nil { + t.Fatalf("should not be fail: %v", err) + } + events := []Event{} + for e := range q { + if _, ok := e.(*ErrorEvent); !ok { + events = append(events, e) + } + } + if len(events) != 1 { + t.Fatalf("result should be one: %d", len(events)) + } + if events[0].(*UpdateEvent).Status.Content != "foo" { + t.Fatalf("want %q but %q", "foo", events[0].(*UpdateEvent).Status.Content) + } +} diff --git a/streaming_ws.go b/streaming_ws.go index 022928c..5553a25 100644 --- a/streaming_ws.go +++ b/streaming_ws.go @@ -66,6 +66,7 @@ func (c *WSClient) streamingWS(ctx context.Context, stream, tag string) (chan Ev q := make(chan Event) go func() { + defer close(q) for { err := c.handleWS(ctx, u.String(), q) if err != nil { @@ -85,7 +86,12 @@ func (c *WSClient) handleWS(ctx context.Context, rawurl string, q chan Event) er // End. return err } - defer conn.Close() + + // Close the WebSocket when the context is canceled. + go func() { + <-ctx.Done() + conn.Close() + }() for { select { diff --git a/streaming_ws_test.go b/streaming_ws_test.go index c13a7cc..f52477c 100644 --- a/streaming_ws_test.go +++ b/streaming_ws_test.go @@ -10,20 +10,6 @@ import ( "github.com/gorilla/websocket" ) -func TestStreamingWSPublic(t *testing.T) { - ts := httptest.NewServer(http.HandlerFunc(wsMock)) - defer ts.Close() - - client := NewClient(&Config{Server: ts.URL}).NewWSClient() - ctx, cancel := context.WithCancel(context.Background()) - q, err := client.StreamingWSPublic(ctx, false) - if err != nil { - t.Fatalf("should not be fail: %v", err) - } - - wsTest(t, q, cancel) -} - func TestStreamingWSUser(t *testing.T) { ts := httptest.NewServer(http.HandlerFunc(wsMock)) defer ts.Close() @@ -38,6 +24,20 @@ func TestStreamingWSUser(t *testing.T) { wsTest(t, q, cancel) } +func TestStreamingWSPublic(t *testing.T) { + ts := httptest.NewServer(http.HandlerFunc(wsMock)) + defer ts.Close() + + client := NewClient(&Config{Server: ts.URL}).NewWSClient() + ctx, cancel := context.WithCancel(context.Background()) + q, err := client.StreamingWSPublic(ctx, false) + if err != nil { + t.Fatalf("should not be fail: %v", err) + } + + wsTest(t, q, cancel) +} + func TestStreamingWSHashtag(t *testing.T) { ts := httptest.NewServer(http.HandlerFunc(wsMock)) defer ts.Close() @@ -106,13 +106,12 @@ func wsMock(w http.ResponseWriter, r *http.Request) { func wsTest(t *testing.T, q chan Event, cancel func()) { time.AfterFunc(time.Second, func() { cancel() - close(q) }) events := []Event{} for e := range q { events = append(events, e) } - if len(events) != 4 { + if len(events) != 6 { t.Fatalf("result should be four: %d", len(events)) } if events[0].(*UpdateEvent).Status.Content != "foo" { @@ -127,6 +126,12 @@ func wsTest(t *testing.T, q chan Event, cancel func()) { if errorEvent, ok := events[3].(*ErrorEvent); !ok { t.Fatalf("should be fail: %v", errorEvent.err) } + if errorEvent, ok := events[4].(*ErrorEvent); !ok { + t.Fatalf("should be fail: %v", errorEvent.err) + } + if errorEvent, ok := events[5].(*ErrorEvent); !ok { + t.Fatalf("should be fail: %v", errorEvent.err) + } } func TestStreamingWS(t *testing.T) {