diff --git a/cmd/mstdn/cmd_stream.go b/cmd/mstdn/cmd_stream.go index 80747c4..dac8865 100644 --- a/cmd/mstdn/cmd_stream.go +++ b/cmd/mstdn/cmd_stream.go @@ -64,13 +64,13 @@ func cmdStream(c *cli.Context) error { t := c.String("type") if t == "public" { - q, err = client.StreamingPublic(ctx) + q, err = client.StreamingPublic(ctx, false) } else if t == "" || t == "public/local" { - q, err = client.StreamingPublicLocal(ctx) + q, err = client.StreamingPublic(ctx, true) } else if strings.HasPrefix(t, "user:") { - q, err = client.StreamingUser(ctx, t[5:]) + q, err = client.StreamingUser(ctx) } else if strings.HasPrefix(t, "hashtag:") { - q, err = client.StreamingHashtag(ctx, t[8:]) + q, err = client.StreamingHashtag(ctx, t[8:], false) } else { return errors.New("invalid type") } diff --git a/status.go b/status.go index cf2f036..4c28b33 100644 --- a/status.go +++ b/status.go @@ -173,9 +173,14 @@ func (c *Client) GetTimelinePublic(ctx context.Context, isLocal bool) ([]*Status } // GetTimelineHashtag return statuses from tagged timeline. -func (c *Client) GetTimelineHashtag(ctx context.Context, tag string) ([]*Status, error) { +func (c *Client) GetTimelineHashtag(ctx context.Context, tag string, isLocal bool) ([]*Status, error) { + params := url.Values{} + if isLocal { + params.Set("local", "t") + } + var statuses []*Status - err := c.doAPI(ctx, http.MethodGet, fmt.Sprintf("/api/v1/timelines/tag/%s", (&url.URL{Path: tag}).EscapedPath()), nil, &statuses, nil) + err := c.doAPI(ctx, http.MethodGet, fmt.Sprintf("/api/v1/timelines/tag/%s", url.PathEscape(tag)), params, &statuses, nil) if err != nil { return nil, err } diff --git a/status_test.go b/status_test.go index a511afe..bfc4a41 100644 --- a/status_test.go +++ b/status_test.go @@ -374,11 +374,11 @@ func TestGetTimelineHashtag(t *testing.T) { ClientSecret: "bar", AccessToken: "zoo", }) - _, err := client.GetTimelineHashtag(context.Background(), "notfound") + _, err := client.GetTimelineHashtag(context.Background(), "notfound", false) if err == nil { t.Fatalf("should be fail: %v", err) } - tags, err := client.GetTimelineHashtag(context.Background(), "zzz") + tags, err := client.GetTimelineHashtag(context.Background(), "zzz", true) if err != nil { t.Fatalf("should not be fail: %v", err) } diff --git a/streaming.go b/streaming.go index 0bff3e3..897be4c 100644 --- a/streaming.go +++ b/streaming.go @@ -86,7 +86,15 @@ func (c *Client) streaming(ctx context.Context, p string, params url.Values) (ch if err != nil { return nil, err } - u.Path = path.Join(u.Path, "/api/v1/streaming/"+p) + u.Path = path.Join(u.Path, "/api/v1/streaming", p) + u.RawQuery = params.Encode() + + req, err := http.NewRequest(http.MethodGet, u.String(), nil) + if err != nil { + return nil, err + } + req = req.WithContext(ctx) + req.Header.Set("Authorization", "Bearer "+c.config.AccessToken) var resp *http.Response @@ -95,18 +103,9 @@ func (c *Client) streaming(ctx context.Context, p string, params url.Values) (ch defer ctx.Done() for { - var in io.Reader - if params != nil { - in = strings.NewReader(params.Encode()) - } - req, err := http.NewRequest(http.MethodGet, u.String(), in) - if err == nil { - req = req.WithContext(ctx) - req.Header.Set("Authorization", "Bearer "+c.config.AccessToken) - resp, err = c.Do(req) - if resp != nil && resp.StatusCode != http.StatusOK { - err = parseAPIError("bad request", resp) - } + resp, err = c.Do(req) + if resp != nil && resp.StatusCode != http.StatusOK { + err = parseAPIError("bad request", resp) } if err == nil { err = handleReader(ctx, q, resp.Body) @@ -129,28 +128,30 @@ func (c *Client) streaming(ctx context.Context, p string, params url.Values) (ch return q, nil } -// StreamingPublic return channel to read events on public. -func (c *Client) StreamingPublic(ctx context.Context) (chan Event, error) { - params := url.Values{} - return c.streaming(ctx, "public", params) -} - -// StreamingPublicLocal return channel to read events on public. -func (c *Client) StreamingPublicLocal(ctx context.Context) (chan Event, error) { - params := url.Values{} - return c.streaming(ctx, "public/local", params) -} - // StreamingUser return channel to read events on home. -func (c *Client) StreamingUser(ctx context.Context, user string) (chan Event, error) { - params := url.Values{} - params.Set("user", user) - return c.streaming(ctx, "user", params) +func (c *Client) StreamingUser(ctx context.Context) (chan Event, error) { + return c.streaming(ctx, "user", nil) +} + +// StreamingPublic return channel to read events on public. +func (c *Client) StreamingPublic(ctx context.Context, isLocal bool) (chan Event, error) { + p := "public" + if isLocal { + p = path.Join(p, "local") + } + + return c.streaming(ctx, p, nil) } // StreamingHashtag return channel to read events on tagged timeline. -func (c *Client) StreamingHashtag(ctx context.Context, tag string) (chan Event, error) { +func (c *Client) StreamingHashtag(ctx context.Context, tag string, isLocal bool) (chan Event, error) { params := url.Values{} params.Set("tag", tag) - return c.streaming(ctx, "hashtag", params) + + p := "hashtag" + if isLocal { + p = path.Join(p, "local") + } + + return c.streaming(ctx, p, params) } diff --git a/streaming_test.go b/streaming_test.go index 838d61d..6e6153d 100644 --- a/streaming_test.go +++ b/streaming_test.go @@ -38,7 +38,7 @@ data: {"content": "bar"} AccessToken: "zoo", }) ctx, cancel := context.WithCancel(context.Background()) - q, err := client.StreamingPublic(ctx) + q, err := client.StreamingPublic(ctx, false) if err != nil { t.Fatalf("should not be fail: %v", err) } diff --git a/streaming_ws.go b/streaming_ws.go index 915465f..022928c 100644 --- a/streaming_ws.go +++ b/streaming_ws.go @@ -24,29 +24,29 @@ type Stream struct { Payload interface{} `json:"payload"` } -// StreamingWSPublic return channel to read events on public using WebSocket. -func (c *WSClient) StreamingWSPublic(ctx context.Context) (chan Event, error) { - return c.streamingWS(ctx, "public", "") -} - -// StreamingWSPublicLocal return channel to read events on public local using WebSocket. -func (c *WSClient) StreamingWSPublicLocal(ctx context.Context) (chan Event, error) { - return c.streamingWS(ctx, "public:local", "") -} - // StreamingWSUser return channel to read events on home using WebSocket. func (c *WSClient) StreamingWSUser(ctx context.Context) (chan Event, error) { return c.streamingWS(ctx, "user", "") } -// StreamingWSHashtag return channel to read events on tagged timeline using WebSocket. -func (c *WSClient) StreamingWSHashtag(ctx context.Context, tag string) (chan Event, error) { - return c.streamingWS(ctx, "hashtag", tag) +// StreamingWSPublic return channel to read events on public using WebSocket. +func (c *WSClient) StreamingWSPublic(ctx context.Context, isLocal bool) (chan Event, error) { + s := "public" + if isLocal { + s += ":local" + } + + return c.streamingWS(ctx, s, "") } -// StreamingWSHashtagLocal return channel to read events on tagged local timeline using WebSocket. -func (c *WSClient) StreamingWSHashtagLocal(ctx context.Context, tag string) (chan Event, error) { - return c.streamingWS(ctx, "hashtag:local", tag) +// StreamingWSHashtag return channel to read events on tagged timeline using WebSocket. +func (c *WSClient) StreamingWSHashtag(ctx context.Context, tag string, isLocal bool) (chan Event, error) { + s := "hashtag" + if isLocal { + s += ":local" + } + + return c.streamingWS(ctx, s, tag) } func (c *WSClient) streamingWS(ctx context.Context, stream, tag string) (chan Event, error) { diff --git a/streaming_ws_test.go b/streaming_ws_test.go index 5456a5d..c13a7cc 100644 --- a/streaming_ws_test.go +++ b/streaming_ws_test.go @@ -16,21 +16,7 @@ func TestStreamingWSPublic(t *testing.T) { client := NewClient(&Config{Server: ts.URL}).NewWSClient() ctx, cancel := context.WithCancel(context.Background()) - q, err := client.StreamingWSPublic(ctx) - if err != nil { - t.Fatalf("should not be fail: %v", err) - } - - wsTest(t, q, cancel) -} - -func TestStreamingWSPublicLocal(t *testing.T) { - ts := httptest.NewServer(http.HandlerFunc(wsMock)) - defer ts.Close() - - client := NewClient(&Config{Server: ts.URL}).NewWSClient() - ctx, cancel := context.WithCancel(context.Background()) - q, err := client.StreamingWSPublicLocal(ctx) + q, err := client.StreamingWSPublic(ctx, false) if err != nil { t.Fatalf("should not be fail: %v", err) } @@ -58,25 +44,17 @@ func TestStreamingWSHashtag(t *testing.T) { client := NewClient(&Config{Server: ts.URL}).NewWSClient() ctx, cancel := context.WithCancel(context.Background()) - q, err := client.StreamingWSHashtag(ctx, "zzz") + q, err := client.StreamingWSHashtag(ctx, "zzz", true) if err != nil { t.Fatalf("should not be fail: %v", err) } - wsTest(t, q, cancel) -} -func TestStreamingWSHashtagLocal(t *testing.T) { - ts := httptest.NewServer(http.HandlerFunc(wsMock)) - defer ts.Close() - - client := NewClient(&Config{Server: ts.URL}).NewWSClient() - ctx, cancel := context.WithCancel(context.Background()) - q, err := client.StreamingWSHashtagLocal(ctx, "zzz") + ctx, cancel = context.WithCancel(context.Background()) + q, err = client.StreamingWSHashtag(ctx, "zzz", false) if err != nil { t.Fatalf("should not be fail: %v", err) } - wsTest(t, q, cancel) } @@ -135,7 +113,7 @@ func wsTest(t *testing.T, q chan Event, cancel func()) { events = append(events, e) } if len(events) != 4 { - t.Fatalf("result should be two: %d", len(events)) + t.Fatalf("result should be four: %d", len(events)) } if events[0].(*UpdateEvent).Status.Content != "foo" { t.Fatalf("want %q but %q", "foo", events[0].(*UpdateEvent).Status.Content) @@ -156,7 +134,7 @@ func TestStreamingWS(t *testing.T) { defer ts.Close() client := NewClient(&Config{Server: ":"}).NewWSClient() - _, err := client.StreamingWSPublicLocal(context.Background()) + _, err := client.StreamingWSPublic(context.Background(), true) if err == nil { t.Fatalf("should be fail: %v", err) } @@ -164,7 +142,7 @@ func TestStreamingWS(t *testing.T) { client = NewClient(&Config{Server: ts.URL}).NewWSClient() ctx, cancel := context.WithCancel(context.Background()) cancel() - q, err := client.StreamingWSPublicLocal(ctx) + q, err := client.StreamingWSPublic(ctx, true) if err != nil { t.Fatalf("should not be fail: %v", err) }