diff --git a/mastodon.go b/mastodon.go index 385efcb..fadf7c7 100644 --- a/mastodon.go +++ b/mastodon.go @@ -24,6 +24,11 @@ type Config struct { AccessToken string } +type WriterResetter interface { + io.Writer + Reset() +} + // Client is a API client for mastodon. type Client struct { http.Client @@ -128,6 +133,9 @@ func (c *Client) doAPI(ctx context.Context, method string, uri string, params in } if c.JSONWriter != nil { + if resetter, ok := c.JSONWriter.(WriterResetter); ok { + resetter.Reset() + } return json.NewDecoder(io.TeeReader(resp.Body, c.JSONWriter)).Decode(&res) } else { return json.NewDecoder(resp.Body).Decode(&res) diff --git a/status_test.go b/status_test.go index a54b931..c520b82 100644 --- a/status_test.go +++ b/status_test.go @@ -39,6 +39,70 @@ func TestGetFavourites(t *testing.T) { } } +func TestGetFavouritesSavedJSONTwice(t *testing.T) { + ourJSON := `[{"content": "foo"}, {"content": "bar"}]` + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + fmt.Fprintln(w, ourJSON) + })) + defer ts.Close() + + client := NewClient(&Config{ + Server: ts.URL, + ClientID: "foo", + ClientSecret: "bar", + AccessToken: "zoo", + }) + + var buf bytes.Buffer + client.JSONWriter = &buf + + favs, err := client.GetFavourites(context.Background(), nil) + if err != nil { + t.Fatalf("should not be fail: %v", err) + } + if len(favs) != 2 { + t.Fatalf("result should be two: %d", len(favs)) + } + if favs[0].Content != "foo" { + t.Fatalf("want %q but %q", "foo", favs[0].Content) + } + if favs[1].Content != "bar" { + t.Fatalf("want %q but %q", "bar", favs[1].Content) + } + + // We get a trailing `\n` from the API which we need to trim + // off before we compare it with our literal above. + theirJSON := strings.TrimSpace(string(buf.Bytes())) + + if theirJSON != ourJSON { + t.Fatalf("want %q but %q", ourJSON, theirJSON) + } + + // Now we call the API again to see if we get the same or doubled JSON. + + favs, err = client.GetFavourites(context.Background(), nil) + if err != nil { + t.Fatalf("should not be fail: %v", err) + } + if len(favs) != 2 { + t.Fatalf("result should be two: %d", len(favs)) + } + if favs[0].Content != "foo" { + t.Fatalf("want %q but %q", "foo", favs[0].Content) + } + if favs[1].Content != "bar" { + t.Fatalf("want %q but %q", "bar", favs[1].Content) + } + + // We get a trailing `\n` from the API which we need to trim + // off before we compare it with our literal above. + theirJSON = strings.TrimSpace(string(buf.Bytes())) + + if theirJSON != ourJSON { + t.Fatalf("want %q but %q", ourJSON, theirJSON) + } +} + func TestGetFavouritesSavedJSON(t *testing.T) { ourJSON := `[{"content": "foo"}, {"content": "bar"}]` ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {