diff --git a/mastodon.go b/mastodon.go index c704e10..fadf7c7 100644 --- a/mastodon.go +++ b/mastodon.go @@ -24,11 +24,17 @@ type Config struct { AccessToken string } +type WriterResetter interface { + io.Writer + Reset() +} + // Client is a API client for mastodon. type Client struct { http.Client - Config *Config - UserAgent string + Config *Config + UserAgent string + JSONWriter io.Writer } func (c *Client) doAPI(ctx context.Context, method string, uri string, params interface{}, res interface{}, pg *Pagination) error { @@ -125,7 +131,15 @@ func (c *Client) doAPI(ctx context.Context, method string, uri string, params in *pg = *pg2 } } - return json.NewDecoder(resp.Body).Decode(&res) + + if c.JSONWriter != nil { + if resetter, ok := c.JSONWriter.(WriterResetter); ok { + resetter.Reset() + } + return json.NewDecoder(io.TeeReader(resp.Body, c.JSONWriter)).Decode(&res) + } else { + return json.NewDecoder(resp.Body).Decode(&res) + } } // NewClient returns a new mastodon API client. diff --git a/status_test.go b/status_test.go index 2cac4b8..c520b82 100644 --- a/status_test.go +++ b/status_test.go @@ -1,12 +1,14 @@ package mastodon import ( + "bytes" "context" "fmt" "io/ioutil" "net/http" "net/http/httptest" "os" + "strings" "testing" ) @@ -37,6 +39,110 @@ func TestGetFavourites(t *testing.T) { } } +func TestGetFavouritesSavedJSONTwice(t *testing.T) { + ourJSON := `[{"content": "foo"}, {"content": "bar"}]` + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + fmt.Fprintln(w, ourJSON) + })) + defer ts.Close() + + client := NewClient(&Config{ + Server: ts.URL, + ClientID: "foo", + ClientSecret: "bar", + AccessToken: "zoo", + }) + + var buf bytes.Buffer + client.JSONWriter = &buf + + favs, err := client.GetFavourites(context.Background(), nil) + if err != nil { + t.Fatalf("should not be fail: %v", err) + } + if len(favs) != 2 { + t.Fatalf("result should be two: %d", len(favs)) + } + if favs[0].Content != "foo" { + t.Fatalf("want %q but %q", "foo", favs[0].Content) + } + if favs[1].Content != "bar" { + t.Fatalf("want %q but %q", "bar", favs[1].Content) + } + + // We get a trailing `\n` from the API which we need to trim + // off before we compare it with our literal above. + theirJSON := strings.TrimSpace(string(buf.Bytes())) + + if theirJSON != ourJSON { + t.Fatalf("want %q but %q", ourJSON, theirJSON) + } + + // Now we call the API again to see if we get the same or doubled JSON. + + favs, err = client.GetFavourites(context.Background(), nil) + if err != nil { + t.Fatalf("should not be fail: %v", err) + } + if len(favs) != 2 { + t.Fatalf("result should be two: %d", len(favs)) + } + if favs[0].Content != "foo" { + t.Fatalf("want %q but %q", "foo", favs[0].Content) + } + if favs[1].Content != "bar" { + t.Fatalf("want %q but %q", "bar", favs[1].Content) + } + + // We get a trailing `\n` from the API which we need to trim + // off before we compare it with our literal above. + theirJSON = strings.TrimSpace(string(buf.Bytes())) + + if theirJSON != ourJSON { + t.Fatalf("want %q but %q", ourJSON, theirJSON) + } +} + +func TestGetFavouritesSavedJSON(t *testing.T) { + ourJSON := `[{"content": "foo"}, {"content": "bar"}]` + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + fmt.Fprintln(w, ourJSON) + })) + defer ts.Close() + + client := NewClient(&Config{ + Server: ts.URL, + ClientID: "foo", + ClientSecret: "bar", + AccessToken: "zoo", + }) + + var buf bytes.Buffer + client.JSONWriter = &buf + + favs, err := client.GetFavourites(context.Background(), nil) + if err != nil { + t.Fatalf("should not be fail: %v", err) + } + if len(favs) != 2 { + t.Fatalf("result should be two: %d", len(favs)) + } + if favs[0].Content != "foo" { + t.Fatalf("want %q but %q", "foo", favs[0].Content) + } + if favs[1].Content != "bar" { + t.Fatalf("want %q but %q", "bar", favs[1].Content) + } + + // We get a trailing `\n` from the API which we need to trim + // off before we compare it with our literal above. + theirJSON := strings.TrimSpace(string(buf.Bytes())) + + if theirJSON != ourJSON { + t.Fatalf("want %q but %q", ourJSON, theirJSON) + } +} + func TestGetBookmarks(t *testing.T) { ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { fmt.Fprintln(w, `[{"content": "foo"}, {"content": "bar"}]`)