From 5e84b57bf33e328a8c93f635333d4a90ddc77076 Mon Sep 17 00:00:00 2001 From: Yasuhiro Matsumoto Date: Mon, 17 Apr 2017 11:10:29 +0900 Subject: [PATCH 1/2] breaking compatibility changes. take context for first arguments. --- accounts.go | 77 ++++++++++++++++++----------------- accounts_test.go | 25 ++++++------ apps.go | 3 +- apps_test.go | 3 +- cmd/mstdn/cmd_account.go | 3 +- cmd/mstdn/cmd_followers.go | 5 ++- cmd/mstdn/cmd_instance.go | 3 +- cmd/mstdn/cmd_notification.go | 3 +- cmd/mstdn/cmd_search.go | 3 +- cmd/mstdn/cmd_timeline.go | 3 +- cmd/mstdn/cmd_toot.go | 3 +- cmd/mstdn/main.go | 3 +- instance.go | 9 ++-- mastodon.go | 46 ++++++++++++++------- mastodon_test.go | 21 +++++----- notification.go | 13 +++--- status.go | 61 +++++++++++++-------------- status_test.go | 31 +++++++------- 18 files changed, 176 insertions(+), 139 deletions(-) diff --git a/accounts.go b/accounts.go index 8238791..f5d933b 100644 --- a/accounts.go +++ b/accounts.go @@ -1,6 +1,7 @@ package mastodon import ( + "context" "fmt" "net/http" "net/url" @@ -27,9 +28,9 @@ type Account struct { } // GetAccount return Account. -func (c *Client) GetAccount(id int) (*Account, error) { +func (c *Client) GetAccount(ctx context.Context, id int) (*Account, error) { var account Account - err := c.doAPI(http.MethodGet, fmt.Sprintf("/api/v1/accounts/%d", id), nil, &account) + err := c.doAPI(ctx, http.MethodGet, fmt.Sprintf("/api/v1/accounts/%d", id), nil, &account) if err != nil { return nil, err } @@ -37,9 +38,9 @@ func (c *Client) GetAccount(id int) (*Account, error) { } // GetAccountCurrentUser return Account of current user. -func (c *Client) GetAccountCurrentUser() (*Account, error) { +func (c *Client) GetAccountCurrentUser(ctx context.Context) (*Account, error) { var account Account - err := c.doAPI(http.MethodGet, "/api/v1/accounts/verify_credentials", nil, &account) + err := c.doAPI(ctx, http.MethodGet, "/api/v1/accounts/verify_credentials", nil, &account) if err != nil { return nil, err } @@ -59,7 +60,7 @@ type Profile struct { } // AccountUpdate updates the information of the current user. -func (c *Client) AccountUpdate(profile *Profile) (*Account, error) { +func (c *Client) AccountUpdate(ctx context.Context, profile *Profile) (*Account, error) { params := url.Values{} if profile.DisplayName != nil { params.Set("display_name", *profile.DisplayName) @@ -75,7 +76,7 @@ func (c *Client) AccountUpdate(profile *Profile) (*Account, error) { } var account Account - err := c.doAPI(http.MethodPatch, "/api/v1/accounts/update_credentials", params, &account) + err := c.doAPI(ctx, http.MethodPatch, "/api/v1/accounts/update_credentials", params, &account) if err != nil { return nil, err } @@ -83,9 +84,9 @@ func (c *Client) AccountUpdate(profile *Profile) (*Account, error) { } // GetAccountStatuses return statuses by specified accuont. -func (c *Client) GetAccountStatuses(id int64) ([]*Status, error) { +func (c *Client) GetAccountStatuses(ctx context.Context, id int64) ([]*Status, error) { var statuses []*Status - err := c.doAPI(http.MethodGet, fmt.Sprintf("/api/v1/accounts/%d/statuses", id), nil, &statuses) + err := c.doAPI(ctx, http.MethodGet, fmt.Sprintf("/api/v1/accounts/%d/statuses", id), nil, &statuses) if err != nil { return nil, err } @@ -93,9 +94,9 @@ func (c *Client) GetAccountStatuses(id int64) ([]*Status, error) { } // GetAccountFollowers return followers list. -func (c *Client) GetAccountFollowers(id int64) ([]*Account, error) { +func (c *Client) GetAccountFollowers(ctx context.Context, id int64) ([]*Account, error) { var accounts []*Account - err := c.doAPI(http.MethodGet, fmt.Sprintf("/api/v1/accounts/%d/followers", id), nil, &accounts) + err := c.doAPI(ctx, http.MethodGet, fmt.Sprintf("/api/v1/accounts/%d/followers", id), nil, &accounts) if err != nil { return nil, err } @@ -103,9 +104,9 @@ func (c *Client) GetAccountFollowers(id int64) ([]*Account, error) { } // GetAccountFollowing return following list. -func (c *Client) GetAccountFollowing(id int64) ([]*Account, error) { +func (c *Client) GetAccountFollowing(ctx context.Context, id int64) ([]*Account, error) { var accounts []*Account - err := c.doAPI(http.MethodGet, fmt.Sprintf("/api/v1/accounts/%d/following", id), nil, &accounts) + err := c.doAPI(ctx, http.MethodGet, fmt.Sprintf("/api/v1/accounts/%d/following", id), nil, &accounts) if err != nil { return nil, err } @@ -113,9 +114,9 @@ func (c *Client) GetAccountFollowing(id int64) ([]*Account, error) { } // GetBlocks return block list. -func (c *Client) GetBlocks() ([]*Account, error) { +func (c *Client) GetBlocks(ctx context.Context) ([]*Account, error) { var accounts []*Account - err := c.doAPI(http.MethodGet, "/api/v1/blocks", nil, &accounts) + err := c.doAPI(ctx, http.MethodGet, "/api/v1/blocks", nil, &accounts) if err != nil { return nil, err } @@ -133,9 +134,9 @@ type Relationship struct { } // AccountFollow follow the account. -func (c *Client) AccountFollow(id int64) (*Relationship, error) { +func (c *Client) AccountFollow(ctx context.Context, id int64) (*Relationship, error) { var relationship Relationship - err := c.doAPI(http.MethodPost, fmt.Sprintf("/api/v1/accounts/%d/follow", id), nil, &relationship) + err := c.doAPI(ctx, http.MethodPost, fmt.Sprintf("/api/v1/accounts/%d/follow", id), nil, &relationship) if err != nil { return nil, err } @@ -143,9 +144,9 @@ func (c *Client) AccountFollow(id int64) (*Relationship, error) { } // AccountUnfollow unfollow the account. -func (c *Client) AccountUnfollow(id int64) (*Relationship, error) { +func (c *Client) AccountUnfollow(ctx context.Context, id int64) (*Relationship, error) { var relationship Relationship - err := c.doAPI(http.MethodPost, fmt.Sprintf("/api/v1/accounts/%d/unfollow", id), nil, &relationship) + err := c.doAPI(ctx, http.MethodPost, fmt.Sprintf("/api/v1/accounts/%d/unfollow", id), nil, &relationship) if err != nil { return nil, err } @@ -153,9 +154,9 @@ func (c *Client) AccountUnfollow(id int64) (*Relationship, error) { } // AccountBlock block the account. -func (c *Client) AccountBlock(id int64) (*Relationship, error) { +func (c *Client) AccountBlock(ctx context.Context, id int64) (*Relationship, error) { var relationship Relationship - err := c.doAPI(http.MethodPost, fmt.Sprintf("/api/v1/accounts/%d/block", id), nil, &relationship) + err := c.doAPI(ctx, http.MethodPost, fmt.Sprintf("/api/v1/accounts/%d/block", id), nil, &relationship) if err != nil { return nil, err } @@ -163,9 +164,9 @@ func (c *Client) AccountBlock(id int64) (*Relationship, error) { } // AccountUnblock unblock the account. -func (c *Client) AccountUnblock(id int64) (*Relationship, error) { +func (c *Client) AccountUnblock(ctx context.Context, id int64) (*Relationship, error) { var relationship Relationship - err := c.doAPI(http.MethodPost, fmt.Sprintf("/api/v1/accounts/%d/unblock", id), nil, &relationship) + err := c.doAPI(ctx, http.MethodPost, fmt.Sprintf("/api/v1/accounts/%d/unblock", id), nil, &relationship) if err != nil { return nil, err } @@ -173,9 +174,9 @@ func (c *Client) AccountUnblock(id int64) (*Relationship, error) { } // AccountMute mute the account. -func (c *Client) AccountMute(id int64) (*Relationship, error) { +func (c *Client) AccountMute(ctx context.Context, id int64) (*Relationship, error) { var relationship Relationship - err := c.doAPI(http.MethodPost, fmt.Sprintf("/api/v1/accounts/%d/mute", id), nil, &relationship) + err := c.doAPI(ctx, http.MethodPost, fmt.Sprintf("/api/v1/accounts/%d/mute", id), nil, &relationship) if err != nil { return nil, err } @@ -183,9 +184,9 @@ func (c *Client) AccountMute(id int64) (*Relationship, error) { } // AccountUnmute unmute the account. -func (c *Client) AccountUnmute(id int64) (*Relationship, error) { +func (c *Client) AccountUnmute(ctx context.Context, id int64) (*Relationship, error) { var relationship Relationship - err := c.doAPI(http.MethodPost, fmt.Sprintf("/api/v1/accounts/%d/unmute", id), nil, &relationship) + err := c.doAPI(ctx, http.MethodPost, fmt.Sprintf("/api/v1/accounts/%d/unmute", id), nil, &relationship) if err != nil { return nil, err } @@ -193,12 +194,12 @@ func (c *Client) AccountUnmute(id int64) (*Relationship, error) { } // GetAccountRelationship return relationship for the account. -func (c *Client) GetAccountRelationship(id int64) ([]*Relationship, error) { +func (c *Client) GetAccountRelationship(ctx context.Context, id int64) ([]*Relationship, error) { params := url.Values{} params.Set("id", fmt.Sprint(id)) var relationships []*Relationship - err := c.doAPI(http.MethodGet, "/api/v1/accounts/relationship", params, &relationships) + err := c.doAPI(ctx, http.MethodGet, "/api/v1/accounts/relationship", params, &relationships) if err != nil { return nil, err } @@ -206,13 +207,13 @@ func (c *Client) GetAccountRelationship(id int64) ([]*Relationship, error) { } // AccountsSearch search accounts by query. -func (c *Client) AccountsSearch(q string, limit int64) ([]*Account, error) { +func (c *Client) AccountsSearch(ctx context.Context, q string, limit int64) ([]*Account, error) { params := url.Values{} params.Set("q", q) params.Set("limit", fmt.Sprint(limit)) var accounts []*Account - err := c.doAPI(http.MethodGet, "/api/v1/accounts/search", params, &accounts) + err := c.doAPI(ctx, http.MethodGet, "/api/v1/accounts/search", params, &accounts) if err != nil { return nil, err } @@ -220,12 +221,12 @@ func (c *Client) AccountsSearch(q string, limit int64) ([]*Account, error) { } // FollowRemoteUser send follow-request. -func (c *Client) FollowRemoteUser(uri string) (*Account, error) { +func (c *Client) FollowRemoteUser(ctx context.Context, uri string) (*Account, error) { params := url.Values{} params.Set("uri", uri) var account Account - err := c.doAPI(http.MethodPost, "/api/v1/follows", params, &account) + err := c.doAPI(ctx, http.MethodPost, "/api/v1/follows", params, &account) if err != nil { return nil, err } @@ -233,9 +234,9 @@ func (c *Client) FollowRemoteUser(uri string) (*Account, error) { } // GetFollowRequests return follow-requests. -func (c *Client) GetFollowRequests() ([]*Account, error) { +func (c *Client) GetFollowRequests(ctx context.Context) ([]*Account, error) { var accounts []*Account - err := c.doAPI(http.MethodGet, "/api/v1/follow_requests", nil, &accounts) + err := c.doAPI(ctx, http.MethodGet, "/api/v1/follow_requests", nil, &accounts) if err != nil { return nil, err } @@ -243,11 +244,11 @@ func (c *Client) GetFollowRequests() ([]*Account, error) { } // FollowRequestAuthorize is authorize the follow request of user with id. -func (c *Client) FollowRequestAuthorize(id int64) error { - return c.doAPI(http.MethodPost, fmt.Sprintf("/api/v1/follow_requests/%d/authorize", id), nil, nil) +func (c *Client) FollowRequestAuthorize(ctx context.Context, id int64) error { + return c.doAPI(ctx, http.MethodPost, fmt.Sprintf("/api/v1/follow_requests/%d/authorize", id), nil, nil) } // FollowRequestReject is rejects the follow request of user with id. -func (c *Client) FollowRequestReject(id int64) error { - return c.doAPI(http.MethodPost, fmt.Sprintf("/api/v1/follow_requests/%d/reject", id), nil, nil) +func (c *Client) FollowRequestReject(ctx context.Context, id int64) error { + return c.doAPI(ctx, http.MethodPost, fmt.Sprintf("/api/v1/follow_requests/%d/reject", id), nil, nil) } diff --git a/accounts_test.go b/accounts_test.go index b891225..e5d3497 100644 --- a/accounts_test.go +++ b/accounts_test.go @@ -1,6 +1,7 @@ package mastodon import ( + "context" "fmt" "net/http" "net/http/httptest" @@ -20,7 +21,7 @@ func TestAccountUpdate(t *testing.T) { ClientSecret: "bar", AccessToken: "zoo", }) - a, err := client.AccountUpdate(&Profile{ + a, err := client.AccountUpdate(context.Background(), &Profile{ DisplayName: String("display_name"), Note: String("note"), Avatar: "data:image/png;base64,iVBORw0KGgoAAAANSUhEUgAAAUoAAADrCAYAAAA...", @@ -47,7 +48,7 @@ func TestGetBlocks(t *testing.T) { ClientSecret: "bar", AccessToken: "zoo", }) - bl, err := client.GetBlocks() + bl, err := client.GetBlocks(context.Background()) if err != nil { t.Fatalf("should not be fail: %v", err) } @@ -79,11 +80,11 @@ func TestAccountFollow(t *testing.T) { ClientSecret: "bar", AccessToken: "zoo", }) - rel, err := client.AccountFollow(123) + rel, err := client.AccountFollow(context.Background(), 123) if err == nil { t.Fatalf("should be fail: %v", err) } - rel, err = client.AccountFollow(1234567) + rel, err = client.AccountFollow(context.Background(), 1234567) if err != nil { t.Fatalf("should not be fail: %v", err) } @@ -112,11 +113,11 @@ func TestAccountUnfollow(t *testing.T) { ClientSecret: "bar", AccessToken: "zoo", }) - rel, err := client.AccountUnfollow(123) + rel, err := client.AccountUnfollow(context.Background(), 123) if err == nil { t.Fatalf("should be fail: %v", err) } - rel, err = client.AccountUnfollow(1234567) + rel, err = client.AccountUnfollow(context.Background(), 1234567) if err != nil { t.Fatalf("should not be fail: %v", err) } @@ -147,11 +148,11 @@ func TestGetFollowRequests(t *testing.T) { ClientSecret: "bar", AccessToken: "zoo", }) - _, err := client.GetFollowRequests() + _, err := client.GetFollowRequests(context.Background()) if err == nil { t.Fatalf("should be fail: %v", err) } - fReqs, err := client.GetFollowRequests() + fReqs, err := client.GetFollowRequests(context.Background()) if err != nil { t.Fatalf("should not be fail: %v", err) } @@ -180,11 +181,11 @@ func TestFollowRequestAuthorize(t *testing.T) { ClientSecret: "bar", AccessToken: "zoo", }) - err := client.FollowRequestAuthorize(123) + err := client.FollowRequestAuthorize(context.Background(), 123) if err == nil { t.Fatalf("should be fail: %v", err) } - err = client.FollowRequestAuthorize(1234567) + err = client.FollowRequestAuthorize(context.Background(), 1234567) if err != nil { t.Fatalf("should not be fail: %v", err) } @@ -204,11 +205,11 @@ func TestFollowRequestReject(t *testing.T) { ClientSecret: "bar", AccessToken: "zoo", }) - err := client.FollowRequestReject(123) + err := client.FollowRequestReject(context.Background(), 123) if err == nil { t.Fatalf("should be fail: %v", err) } - err = client.FollowRequestReject(1234567) + err = client.FollowRequestReject(context.Background(), 1234567) if err != nil { t.Fatalf("should not be fail: %v", err) } diff --git a/apps.go b/apps.go index 4f2e4a4..1dd50a5 100644 --- a/apps.go +++ b/apps.go @@ -1,6 +1,7 @@ package mastodon import ( + "context" "encoding/json" "fmt" "net/http" @@ -34,7 +35,7 @@ type Application struct { } // RegisterApp returns the mastodon application. -func RegisterApp(appConfig *AppConfig) (*Application, error) { +func RegisterApp(ctx context.Context, appConfig *AppConfig) (*Application, error) { params := url.Values{} params.Set("client_name", appConfig.ClientName) if appConfig.RedirectURIs == "" { diff --git a/apps_test.go b/apps_test.go index 64be57b..177440e 100644 --- a/apps_test.go +++ b/apps_test.go @@ -1,6 +1,7 @@ package mastodon import ( + "context" "fmt" "net/http" "net/http/httptest" @@ -26,7 +27,7 @@ func TestRegisterApp(t *testing.T) { })) defer ts.Close() - app, err := RegisterApp(&AppConfig{ + app, err := RegisterApp(context.Background(), &AppConfig{ Server: ts.URL, Scopes: "read write follow", }) diff --git a/cmd/mstdn/cmd_account.go b/cmd/mstdn/cmd_account.go index e7caa2e..ceff781 100644 --- a/cmd/mstdn/cmd_account.go +++ b/cmd/mstdn/cmd_account.go @@ -1,6 +1,7 @@ package main import ( + "context" "fmt" "github.com/mattn/go-mastodon" @@ -9,7 +10,7 @@ import ( func cmdAccount(c *cli.Context) error { client := c.App.Metadata["client"].(*mastodon.Client) - account, err := client.GetAccountCurrentUser() + account, err := client.GetAccountCurrentUser(context.Background()) if err != nil { return err } diff --git a/cmd/mstdn/cmd_followers.go b/cmd/mstdn/cmd_followers.go index aab07ef..20fc7f2 100644 --- a/cmd/mstdn/cmd_followers.go +++ b/cmd/mstdn/cmd_followers.go @@ -1,6 +1,7 @@ package main import ( + "context" "fmt" "github.com/mattn/go-mastodon" @@ -9,11 +10,11 @@ import ( func cmdFollowers(c *cli.Context) error { client := c.App.Metadata["client"].(*mastodon.Client) - account, err := client.GetAccountCurrentUser() + account, err := client.GetAccountCurrentUser(context.Background()) if err != nil { return err } - followers, err := client.GetAccountFollowers(account.ID) + followers, err := client.GetAccountFollowers(context.Background(), account.ID) if err != nil { return err } diff --git a/cmd/mstdn/cmd_instance.go b/cmd/mstdn/cmd_instance.go index 752eef4..b287efc 100644 --- a/cmd/mstdn/cmd_instance.go +++ b/cmd/mstdn/cmd_instance.go @@ -1,6 +1,7 @@ package main import ( + "context" "fmt" "github.com/mattn/go-mastodon" @@ -9,7 +10,7 @@ import ( func cmdInstance(c *cli.Context) error { client := c.App.Metadata["client"].(*mastodon.Client) - instance, err := client.GetInstance() + instance, err := client.GetInstance(context.Background()) if err != nil { return err } diff --git a/cmd/mstdn/cmd_notification.go b/cmd/mstdn/cmd_notification.go index 3b6e5a7..6936751 100644 --- a/cmd/mstdn/cmd_notification.go +++ b/cmd/mstdn/cmd_notification.go @@ -1,6 +1,7 @@ package main import ( + "context" "fmt" "github.com/fatih/color" @@ -10,7 +11,7 @@ import ( func cmdNotification(c *cli.Context) error { client := c.App.Metadata["client"].(*mastodon.Client) - notifications, err := client.GetNotifications() + notifications, err := client.GetNotifications(context.Background()) if err != nil { return err } diff --git a/cmd/mstdn/cmd_search.go b/cmd/mstdn/cmd_search.go index 344bbdd..5cd754b 100644 --- a/cmd/mstdn/cmd_search.go +++ b/cmd/mstdn/cmd_search.go @@ -1,6 +1,7 @@ package main import ( + "context" "errors" "fmt" @@ -14,7 +15,7 @@ func cmdSearch(c *cli.Context) error { } client := c.App.Metadata["client"].(*mastodon.Client) - results, err := client.Search(argstr(c), false) + results, err := client.Search(context.Background(), argstr(c), false) if err != nil { return err } diff --git a/cmd/mstdn/cmd_timeline.go b/cmd/mstdn/cmd_timeline.go index 6bb2999..26f2142 100644 --- a/cmd/mstdn/cmd_timeline.go +++ b/cmd/mstdn/cmd_timeline.go @@ -1,6 +1,7 @@ package main import ( + "context" "fmt" "github.com/fatih/color" @@ -10,7 +11,7 @@ import ( func cmdTimeline(c *cli.Context) error { client := c.App.Metadata["client"].(*mastodon.Client) - timeline, err := client.GetTimelineHome() + timeline, err := client.GetTimelineHome(context.Background()) if err != nil { return err } diff --git a/cmd/mstdn/cmd_toot.go b/cmd/mstdn/cmd_toot.go index 747c35b..9a82534 100644 --- a/cmd/mstdn/cmd_toot.go +++ b/cmd/mstdn/cmd_toot.go @@ -1,6 +1,7 @@ package main import ( + "context" "errors" "log" @@ -24,7 +25,7 @@ func cmdToot(c *cli.Context) error { toot = argstr(c) } client := c.App.Metadata["client"].(*mastodon.Client) - _, err := client.PostStatus(&mastodon.Toot{ + _, err := client.PostStatus(context.Background(), &mastodon.Toot{ Status: toot, }) return err diff --git a/cmd/mstdn/main.go b/cmd/mstdn/main.go index 1903c34..a3a3d15 100644 --- a/cmd/mstdn/main.go +++ b/cmd/mstdn/main.go @@ -3,6 +3,7 @@ package main import ( "bufio" "bytes" + "context" "encoding/json" "fmt" "io/ioutil" @@ -128,7 +129,7 @@ func authenticate(client *mastodon.Client, config *mastodon.Config, file string) if err != nil { return err } - err = client.Authenticate(email, password) + err = client.Authenticate(context.Background(), email, password) if err != nil { return err } diff --git a/instance.go b/instance.go index 20be51f..b84b77b 100644 --- a/instance.go +++ b/instance.go @@ -1,6 +1,9 @@ package mastodon -import "net/http" +import ( + "context" + "net/http" +) // Instance hold information for mastodon instance. type Instance struct { @@ -11,9 +14,9 @@ type Instance struct { } // GetInstance return Instance. -func (c *Client) GetInstance() (*Instance, error) { +func (c *Client) GetInstance(ctx context.Context) (*Instance, error) { var instance Instance - err := c.doAPI(http.MethodGet, "/api/v1/instance", nil, &instance) + err := c.doAPI(ctx, http.MethodGet, "/api/v1/instance", nil, &instance) if err != nil { return nil, err } diff --git a/mastodon.go b/mastodon.go index 19b452e..b9da000 100644 --- a/mastodon.go +++ b/mastodon.go @@ -1,6 +1,7 @@ package mastodon import ( + "context" "encoding/json" "fmt" "net/http" @@ -23,14 +24,30 @@ type Client struct { config *Config } -func (c *Client) doAPI(method string, uri string, params url.Values, res interface{}) error { +func httpDo(ctx context.Context, req *http.Request, f func(*http.Response, error) error) error { + tr := &http.Transport{} + client := &http.Client{Transport: tr} + c := make(chan error, 1) + go func() { + c <- f(client.Do(req)) + }() + select { + case <-ctx.Done(): + tr.CancelRequest(req) + <-c + return ctx.Err() + case err := <-c: + return err + } +} + +func (c *Client) doAPI(ctx context.Context, method string, uri string, params url.Values, res interface{}) error { u, err := url.Parse(c.config.Server) if err != nil { return err } u.Path = path.Join(u.Path, uri) - var resp *http.Response req, err := http.NewRequest(method, u.String(), strings.NewReader(params.Encode())) if err != nil { return err @@ -39,19 +56,20 @@ func (c *Client) doAPI(method string, uri string, params url.Values, res interfa if params != nil { req.Header.Set("Content-Type", "application/x-www-form-urlencoded") } - resp, err = c.Do(req) - if err != nil { - return err - } - defer resp.Body.Close() - if resp.StatusCode != http.StatusOK { - return fmt.Errorf("bad request: %v", resp.Status) - } else if res == nil { - return nil - } + return httpDo(ctx, req, func(resp *http.Response, err error) error { + if err != nil { + return err + } + defer resp.Body.Close() - return json.NewDecoder(resp.Body).Decode(&res) + if resp.StatusCode != http.StatusOK { + return fmt.Errorf("bad request: %v", resp.Status) + } else if res == nil { + return nil + } + return json.NewDecoder(resp.Body).Decode(&res) + }) } // NewClient return new mastodon API client. @@ -63,7 +81,7 @@ func NewClient(config *Config) *Client { } // Authenticate get access-token to the API. -func (c *Client) Authenticate(username, password string) error { +func (c *Client) Authenticate(ctx context.Context, username, password string) error { params := url.Values{} params.Set("client_id", c.config.ClientID) params.Set("client_secret", c.config.ClientSecret) diff --git a/mastodon_test.go b/mastodon_test.go index 1fbcf64..843df4c 100644 --- a/mastodon_test.go +++ b/mastodon_test.go @@ -1,6 +1,7 @@ package mastodon import ( + "context" "fmt" "io" "net/http" @@ -24,7 +25,7 @@ func TestAuthenticate(t *testing.T) { ClientID: "foo", ClientSecret: "bar", }) - err := client.Authenticate("invalid", "user") + err := client.Authenticate(context.Background(), "invalid", "user") if err == nil { t.Fatalf("should be fail: %v", err) } @@ -34,7 +35,7 @@ func TestAuthenticate(t *testing.T) { ClientID: "foo", ClientSecret: "bar", }) - err = client.Authenticate("valid", "user") + err = client.Authenticate(context.Background(), "valid", "user") if err != nil { t.Fatalf("should not be fail: %v", err) } @@ -56,7 +57,7 @@ func TestPostStatus(t *testing.T) { ClientID: "foo", ClientSecret: "bar", }) - _, err := client.PostStatus(&Toot{ + _, err := client.PostStatus(context.Background(), &Toot{ Status: "foobar", }) if err == nil { @@ -69,7 +70,7 @@ func TestPostStatus(t *testing.T) { ClientSecret: "bar", AccessToken: "zoo", }) - _, err = client.PostStatus(&Toot{ + _, err = client.PostStatus(context.Background(), &Toot{ Status: "foobar", }) if err != nil { @@ -89,7 +90,7 @@ func TestGetTimelineHome(t *testing.T) { ClientID: "foo", ClientSecret: "bar", }) - _, err := client.PostStatus(&Toot{ + _, err := client.PostStatus(context.Background(), &Toot{ Status: "foobar", }) if err == nil { @@ -102,7 +103,7 @@ func TestGetTimelineHome(t *testing.T) { ClientSecret: "bar", AccessToken: "zoo", }) - tl, err := client.GetTimelineHome() + tl, err := client.GetTimelineHome(context.Background()) if err != nil { t.Fatalf("should not be fail: %v", err) } @@ -142,11 +143,11 @@ func TestGetAccount(t *testing.T) { ClientSecret: "bar", AccessToken: "zoo", }) - a, err := client.GetAccount(1) + a, err := client.GetAccount(context.Background(), 1) if err == nil { t.Fatalf("should not be fail: %v", err) } - a, err = client.GetAccount(1234567) + a, err = client.GetAccount(context.Background(), 1234567) if err != nil { t.Fatalf("should not be fail: %v", err) } @@ -172,11 +173,11 @@ func TestGetAccountFollowing(t *testing.T) { ClientSecret: "bar", AccessToken: "zoo", }) - fl, err := client.GetAccountFollowing(123) + fl, err := client.GetAccountFollowing(context.Background(), 123) if err == nil { t.Fatalf("should not be fail: %v", err) } - fl, err = client.GetAccountFollowing(1234567) + fl, err = client.GetAccountFollowing(context.Background(), 1234567) if err != nil { t.Fatalf("should not be fail: %v", err) } diff --git a/notification.go b/notification.go index b8762c0..8fd098d 100644 --- a/notification.go +++ b/notification.go @@ -1,6 +1,7 @@ package mastodon import ( + "context" "fmt" "net/http" "time" @@ -16,9 +17,9 @@ type Notification struct { } // GetNotifications return notifications. -func (c *Client) GetNotifications() ([]*Notification, error) { +func (c *Client) GetNotifications(ctx context.Context) ([]*Notification, error) { var notifications []*Notification - err := c.doAPI(http.MethodGet, "/api/v1/notifications", nil, ¬ifications) + err := c.doAPI(ctx, http.MethodGet, "/api/v1/notifications", nil, ¬ifications) if err != nil { return nil, err } @@ -26,9 +27,9 @@ func (c *Client) GetNotifications() ([]*Notification, error) { } // GetNotification return notifications. -func (c *Client) GetNotification(id int64) (*Notification, error) { +func (c *Client) GetNotification(ctx context.Context, id int64) (*Notification, error) { var notification Notification - err := c.doAPI(http.MethodGet, fmt.Sprintf("/api/v1/notifications/%d", id), nil, ¬ification) + err := c.doAPI(ctx, http.MethodGet, fmt.Sprintf("/api/v1/notifications/%d", id), nil, ¬ification) if err != nil { return nil, err } @@ -36,6 +37,6 @@ func (c *Client) GetNotification(id int64) (*Notification, error) { } // ClearNotifications clear notifications. -func (c *Client) ClearNotifications() error { - return c.doAPI(http.MethodPost, "/api/v1/notifications/clear", nil, nil) +func (c *Client) ClearNotifications(ctx context.Context) error { + return c.doAPI(ctx, http.MethodPost, "/api/v1/notifications/clear", nil, nil) } diff --git a/status.go b/status.go index ae20794..104de1f 100644 --- a/status.go +++ b/status.go @@ -1,6 +1,7 @@ package mastodon import ( + "context" "fmt" "net/http" "net/url" @@ -46,9 +47,9 @@ type Card struct { } // GetFavourites return the favorite list of the current user. -func (c *Client) GetFavourites() ([]*Status, error) { +func (c *Client) GetFavourites(ctx context.Context) ([]*Status, error) { var statuses []*Status - err := c.doAPI(http.MethodGet, "/api/v1/favourites", nil, &statuses) + err := c.doAPI(ctx, http.MethodGet, "/api/v1/favourites", nil, &statuses) if err != nil { return nil, err } @@ -56,9 +57,9 @@ func (c *Client) GetFavourites() ([]*Status, error) { } // GetStatus return status specified by id. -func (c *Client) GetStatus(id int64) (*Status, error) { +func (c *Client) GetStatus(ctx context.Context, id int64) (*Status, error) { var status Status - err := c.doAPI(http.MethodGet, fmt.Sprintf("/api/v1/statuses/%d", id), nil, &status) + err := c.doAPI(ctx, http.MethodGet, fmt.Sprintf("/api/v1/statuses/%d", id), nil, &status) if err != nil { return nil, err } @@ -66,9 +67,9 @@ func (c *Client) GetStatus(id int64) (*Status, error) { } // GetStatusContext return status specified by id. -func (c *Client) GetStatusContext(id int64) (*Context, error) { +func (c *Client) GetStatusContext(ctx context.Context, id int64) (*Context, error) { var context Context - err := c.doAPI(http.MethodGet, fmt.Sprintf("/api/v1/statuses/%d/context", id), nil, &context) + err := c.doAPI(ctx, http.MethodGet, fmt.Sprintf("/api/v1/statuses/%d/context", id), nil, &context) if err != nil { return nil, err } @@ -76,9 +77,9 @@ func (c *Client) GetStatusContext(id int64) (*Context, error) { } // GetStatusCard return status specified by id. -func (c *Client) GetStatusCard(id int64) (*Card, error) { +func (c *Client) GetStatusCard(ctx context.Context, id int64) (*Card, error) { var card Card - err := c.doAPI(http.MethodGet, fmt.Sprintf("/api/v1/statuses/%d/card", id), nil, &card) + err := c.doAPI(ctx, http.MethodGet, fmt.Sprintf("/api/v1/statuses/%d/card", id), nil, &card) if err != nil { return nil, err } @@ -86,9 +87,9 @@ func (c *Client) GetStatusCard(id int64) (*Card, error) { } // GetRebloggedBy returns the account list of the user who reblogged the toot of id. -func (c *Client) GetRebloggedBy(id int64) ([]*Account, error) { +func (c *Client) GetRebloggedBy(ctx context.Context, id int64) ([]*Account, error) { var accounts []*Account - err := c.doAPI(http.MethodGet, fmt.Sprintf("/api/v1/statuses/%d/reblogged_by", id), nil, &accounts) + err := c.doAPI(ctx, http.MethodGet, fmt.Sprintf("/api/v1/statuses/%d/reblogged_by", id), nil, &accounts) if err != nil { return nil, err } @@ -96,9 +97,9 @@ func (c *Client) GetRebloggedBy(id int64) ([]*Account, error) { } // GetFavouritedBy returns the account list of the user who liked the toot of id. -func (c *Client) GetFavouritedBy(id int64) ([]*Account, error) { +func (c *Client) GetFavouritedBy(ctx context.Context, id int64) ([]*Account, error) { var accounts []*Account - err := c.doAPI(http.MethodGet, fmt.Sprintf("/api/v1/statuses/%d/favourited_by", id), nil, &accounts) + err := c.doAPI(ctx, http.MethodGet, fmt.Sprintf("/api/v1/statuses/%d/favourited_by", id), nil, &accounts) if err != nil { return nil, err } @@ -106,9 +107,9 @@ func (c *Client) GetFavouritedBy(id int64) ([]*Account, error) { } // Reblog is reblog the toot of id and return status of reblog. -func (c *Client) Reblog(id int64) (*Status, error) { +func (c *Client) Reblog(ctx context.Context, id int64) (*Status, error) { var status Status - err := c.doAPI(http.MethodPost, fmt.Sprintf("/api/v1/statuses/%d/reblog", id), nil, &status) + err := c.doAPI(ctx, http.MethodPost, fmt.Sprintf("/api/v1/statuses/%d/reblog", id), nil, &status) if err != nil { return nil, err } @@ -116,9 +117,9 @@ func (c *Client) Reblog(id int64) (*Status, error) { } // Unreblog is unreblog the toot of id and return status of the original toot. -func (c *Client) Unreblog(id int64) (*Status, error) { +func (c *Client) Unreblog(ctx context.Context, id int64) (*Status, error) { var status Status - err := c.doAPI(http.MethodPost, fmt.Sprintf("/api/v1/statuses/%d/unreblog", id), nil, &status) + err := c.doAPI(ctx, http.MethodPost, fmt.Sprintf("/api/v1/statuses/%d/unreblog", id), nil, &status) if err != nil { return nil, err } @@ -126,9 +127,9 @@ func (c *Client) Unreblog(id int64) (*Status, error) { } // Favourite is favourite the toot of id and return status of the favourite toot. -func (c *Client) Favourite(id int64) (*Status, error) { +func (c *Client) Favourite(ctx context.Context, id int64) (*Status, error) { var status Status - err := c.doAPI(http.MethodPost, fmt.Sprintf("/api/v1/statuses/%d/favourite", id), nil, &status) + err := c.doAPI(ctx, http.MethodPost, fmt.Sprintf("/api/v1/statuses/%d/favourite", id), nil, &status) if err != nil { return nil, err } @@ -136,9 +137,9 @@ func (c *Client) Favourite(id int64) (*Status, error) { } // Unfavourite is unfavourite the toot of id and return status of the unfavourite toot. -func (c *Client) Unfavourite(id int64) (*Status, error) { +func (c *Client) Unfavourite(ctx context.Context, id int64) (*Status, error) { var status Status - err := c.doAPI(http.MethodPost, fmt.Sprintf("/api/v1/statuses/%d/unfavourite", id), nil, &status) + err := c.doAPI(ctx, http.MethodPost, fmt.Sprintf("/api/v1/statuses/%d/unfavourite", id), nil, &status) if err != nil { return nil, err } @@ -146,9 +147,9 @@ func (c *Client) Unfavourite(id int64) (*Status, error) { } // GetTimelineHome return statuses from home timeline. -func (c *Client) GetTimelineHome() ([]*Status, error) { +func (c *Client) GetTimelineHome(ctx context.Context) ([]*Status, error) { var statuses []*Status - err := c.doAPI(http.MethodGet, "/api/v1/timelines/home", nil, &statuses) + err := c.doAPI(ctx, http.MethodGet, "/api/v1/timelines/home", nil, &statuses) if err != nil { return nil, err } @@ -156,9 +157,9 @@ func (c *Client) GetTimelineHome() ([]*Status, error) { } // GetTimelineHashtag return statuses from tagged timeline. -func (c *Client) GetTimelineHashtag(tag string) ([]*Status, error) { +func (c *Client) GetTimelineHashtag(ctx context.Context, tag string) ([]*Status, error) { var statuses []*Status - err := c.doAPI(http.MethodGet, fmt.Sprintf("/api/v1/timelines/tag/%s", (&url.URL{Path: tag}).EscapedPath()), nil, &statuses) + err := c.doAPI(ctx, http.MethodGet, fmt.Sprintf("/api/v1/timelines/tag/%s", (&url.URL{Path: tag}).EscapedPath()), nil, &statuses) if err != nil { return nil, err } @@ -166,7 +167,7 @@ func (c *Client) GetTimelineHashtag(tag string) ([]*Status, error) { } // PostStatus post the toot. -func (c *Client) PostStatus(toot *Toot) (*Status, error) { +func (c *Client) PostStatus(ctx context.Context, toot *Toot) (*Status, error) { params := url.Values{} params.Set("status", toot.Status) if toot.InReplyToID > 0 { @@ -176,7 +177,7 @@ func (c *Client) PostStatus(toot *Toot) (*Status, error) { //params.Set("visibility", "public") var status Status - err := c.doAPI(http.MethodPost, "/api/v1/statuses", params, &status) + err := c.doAPI(ctx, http.MethodPost, "/api/v1/statuses", params, &status) if err != nil { return nil, err } @@ -184,17 +185,17 @@ func (c *Client) PostStatus(toot *Toot) (*Status, error) { } // DeleteStatus delete the toot. -func (c *Client) DeleteStatus(id int64) error { - return c.doAPI(http.MethodDelete, fmt.Sprintf("/api/v1/statuses/%d", id), nil, nil) +func (c *Client) DeleteStatus(ctx context.Context, id int64) error { + return c.doAPI(ctx, http.MethodDelete, fmt.Sprintf("/api/v1/statuses/%d", id), nil, nil) } // Search search content with query. -func (c *Client) Search(q string, resolve bool) (*Results, error) { +func (c *Client) Search(ctx context.Context, q string, resolve bool) (*Results, error) { params := url.Values{} params.Set("q", q) params.Set("resolve", fmt.Sprint(resolve)) var results Results - err := c.doAPI(http.MethodGet, "/api/v1/search", params, &results) + err := c.doAPI(ctx, http.MethodGet, "/api/v1/search", params, &results) if err != nil { return nil, err } diff --git a/status_test.go b/status_test.go index 93d03e2..115082c 100644 --- a/status_test.go +++ b/status_test.go @@ -1,6 +1,7 @@ package mastodon import ( + "context" "fmt" "net/http" "net/http/httptest" @@ -20,7 +21,7 @@ func TestGetFavourites(t *testing.T) { ClientSecret: "bar", AccessToken: "zoo", }) - favs, err := client.GetFavourites() + favs, err := client.GetFavourites(context.Background()) if err != nil { t.Fatalf("should not be fail: %v", err) } @@ -52,11 +53,11 @@ func TestGetStatus(t *testing.T) { ClientSecret: "bar", AccessToken: "zoo", }) - _, err := client.GetStatus(123) + _, err := client.GetStatus(context.Background(), 123) if err == nil { t.Fatalf("should be fail: %v", err) } - status, err := client.GetStatus(1234567) + status, err := client.GetStatus(context.Background(), 1234567) if err != nil { t.Fatalf("should not be fail: %v", err) } @@ -82,11 +83,11 @@ func TestGetRebloggedBy(t *testing.T) { ClientSecret: "bar", AccessToken: "zoo", }) - _, err := client.GetRebloggedBy(123) + _, err := client.GetRebloggedBy(context.Background(), 123) if err == nil { t.Fatalf("should be fail: %v", err) } - rbs, err := client.GetRebloggedBy(1234567) + rbs, err := client.GetRebloggedBy(context.Background(), 1234567) if err != nil { t.Fatalf("should not be fail: %v", err) } @@ -118,11 +119,11 @@ func TestGetFavouritedBy(t *testing.T) { ClientSecret: "bar", AccessToken: "zoo", }) - _, err := client.GetFavouritedBy(123) + _, err := client.GetFavouritedBy(context.Background(), 123) if err == nil { t.Fatalf("should be fail: %v", err) } - fbs, err := client.GetFavouritedBy(1234567) + fbs, err := client.GetFavouritedBy(context.Background(), 1234567) if err != nil { t.Fatalf("should not be fail: %v", err) } @@ -154,11 +155,11 @@ func TestReblog(t *testing.T) { ClientSecret: "bar", AccessToken: "zoo", }) - _, err := client.Reblog(123) + _, err := client.Reblog(context.Background(), 123) if err == nil { t.Fatalf("should be fail: %v", err) } - status, err := client.Reblog(1234567) + status, err := client.Reblog(context.Background(), 1234567) if err != nil { t.Fatalf("should not be fail: %v", err) } @@ -184,11 +185,11 @@ func TestUnreblog(t *testing.T) { ClientSecret: "bar", AccessToken: "zoo", }) - _, err := client.Unreblog(123) + _, err := client.Unreblog(context.Background(), 123) if err == nil { t.Fatalf("should be fail: %v", err) } - status, err := client.Unreblog(1234567) + status, err := client.Unreblog(context.Background(), 1234567) if err != nil { t.Fatalf("should not be fail: %v", err) } @@ -214,11 +215,11 @@ func TestFavourite(t *testing.T) { ClientSecret: "bar", AccessToken: "zoo", }) - _, err := client.Favourite(123) + _, err := client.Favourite(context.Background(), 123) if err == nil { t.Fatalf("should be fail: %v", err) } - status, err := client.Favourite(1234567) + status, err := client.Favourite(context.Background(), 1234567) if err != nil { t.Fatalf("should not be fail: %v", err) } @@ -244,11 +245,11 @@ func TestUnfavourite(t *testing.T) { ClientSecret: "bar", AccessToken: "zoo", }) - _, err := client.Unfavourite(123) + _, err := client.Unfavourite(context.Background(), 123) if err == nil { t.Fatalf("should be fail: %v", err) } - status, err := client.Unfavourite(1234567) + status, err := client.Unfavourite(context.Background(), 1234567) if err != nil { t.Fatalf("should not be fail: %v", err) } From 6fe43e545a2ff89169a1e21fc3887e449719279e Mon Sep 17 00:00:00 2001 From: Yasuhiro Matsumoto Date: Mon, 17 Apr 2017 12:25:20 +0900 Subject: [PATCH 2/2] Use Request.WithContext --- mastodon.go | 41 ++++++++++++----------------------------- 1 file changed, 12 insertions(+), 29 deletions(-) diff --git a/mastodon.go b/mastodon.go index b9da000..277e269 100644 --- a/mastodon.go +++ b/mastodon.go @@ -24,23 +24,6 @@ type Client struct { config *Config } -func httpDo(ctx context.Context, req *http.Request, f func(*http.Response, error) error) error { - tr := &http.Transport{} - client := &http.Client{Transport: tr} - c := make(chan error, 1) - go func() { - c <- f(client.Do(req)) - }() - select { - case <-ctx.Done(): - tr.CancelRequest(req) - <-c - return ctx.Err() - case err := <-c: - return err - } -} - func (c *Client) doAPI(ctx context.Context, method string, uri string, params url.Values, res interface{}) error { u, err := url.Parse(c.config.Server) if err != nil { @@ -52,24 +35,24 @@ func (c *Client) doAPI(ctx context.Context, method string, uri string, params ur if err != nil { return err } + req.WithContext(ctx) req.Header.Set("Authorization", "Bearer "+c.config.AccessToken) if params != nil { req.Header.Set("Content-Type", "application/x-www-form-urlencoded") } - return httpDo(ctx, req, func(resp *http.Response, err error) error { - if err != nil { - return err - } - defer resp.Body.Close() + resp, err := c.Do(req) + if err != nil { + return err + } + defer resp.Body.Close() - if resp.StatusCode != http.StatusOK { - return fmt.Errorf("bad request: %v", resp.Status) - } else if res == nil { - return nil - } - return json.NewDecoder(resp.Body).Decode(&res) - }) + if resp.StatusCode != http.StatusOK { + return fmt.Errorf("bad request: %v", resp.Status) + } else if res == nil { + return nil + } + return json.NewDecoder(resp.Body).Decode(&res) } // NewClient return new mastodon API client.