From 010ec2eaf92da466640e155dfac5a9611b6a0f33 Mon Sep 17 00:00:00 2001 From: 178inaba <178inaba@users.noreply.github.com> Date: Mon, 24 Apr 2017 13:55:07 +0900 Subject: [PATCH] Add WSClient --- streaming_ws.go | 31 ++++++++++++++++++++----------- streaming_ws_test.go | 20 ++++++++++---------- 2 files changed, 30 insertions(+), 21 deletions(-) diff --git a/streaming_ws.go b/streaming_ws.go index e607be1..915465f 100644 --- a/streaming_ws.go +++ b/streaming_ws.go @@ -9,6 +9,15 @@ import ( "github.com/gorilla/websocket" ) +// WSClient is a WebSocket client. +type WSClient struct { + websocket.Dialer + client *Client +} + +// NewWSClient return WebSocket client. +func (c *Client) NewWSClient() *WSClient { return &WSClient{client: c} } + // Stream is a struct of data that flows in streaming. type Stream struct { Event string `json:"event"` @@ -16,39 +25,39 @@ type Stream struct { } // StreamingWSPublic return channel to read events on public using WebSocket. -func (c *Client) StreamingWSPublic(ctx context.Context) (chan Event, error) { +func (c *WSClient) StreamingWSPublic(ctx context.Context) (chan Event, error) { return c.streamingWS(ctx, "public", "") } // StreamingWSPublicLocal return channel to read events on public local using WebSocket. -func (c *Client) StreamingWSPublicLocal(ctx context.Context) (chan Event, error) { +func (c *WSClient) StreamingWSPublicLocal(ctx context.Context) (chan Event, error) { return c.streamingWS(ctx, "public:local", "") } // StreamingWSUser return channel to read events on home using WebSocket. -func (c *Client) StreamingWSUser(ctx context.Context) (chan Event, error) { +func (c *WSClient) StreamingWSUser(ctx context.Context) (chan Event, error) { return c.streamingWS(ctx, "user", "") } // StreamingWSHashtag return channel to read events on tagged timeline using WebSocket. -func (c *Client) StreamingWSHashtag(ctx context.Context, tag string) (chan Event, error) { +func (c *WSClient) StreamingWSHashtag(ctx context.Context, tag string) (chan Event, error) { return c.streamingWS(ctx, "hashtag", tag) } // StreamingWSHashtagLocal return channel to read events on tagged local timeline using WebSocket. -func (c *Client) StreamingWSHashtagLocal(ctx context.Context, tag string) (chan Event, error) { +func (c *WSClient) StreamingWSHashtagLocal(ctx context.Context, tag string) (chan Event, error) { return c.streamingWS(ctx, "hashtag:local", tag) } -func (c *Client) streamingWS(ctx context.Context, stream, tag string) (chan Event, error) { +func (c *WSClient) streamingWS(ctx context.Context, stream, tag string) (chan Event, error) { params := url.Values{} - params.Set("access_token", c.config.AccessToken) + params.Set("access_token", c.client.config.AccessToken) params.Set("stream", stream) if tag != "" { params.Set("tag", tag) } - u, err := changeWebSocketScheme(c.config.Server) + u, err := changeWebSocketScheme(c.client.config.Server) if err != nil { return nil, err } @@ -68,7 +77,7 @@ func (c *Client) streamingWS(ctx context.Context, stream, tag string) (chan Even return q, nil } -func (c *Client) handleWS(ctx context.Context, rawurl string, q chan Event) error { +func (c *WSClient) handleWS(ctx context.Context, rawurl string, q chan Event) error { conn, err := c.dialRedirect(rawurl) if err != nil { q <- &ErrorEvent{err: err} @@ -122,7 +131,7 @@ func (c *Client) handleWS(ctx context.Context, rawurl string, q chan Event) erro return nil } -func (c *Client) dialRedirect(rawurl string) (conn *websocket.Conn, err error) { +func (c *WSClient) dialRedirect(rawurl string) (conn *websocket.Conn, err error) { for { conn, rawurl, err = c.dial(rawurl) if err != nil { @@ -133,7 +142,7 @@ func (c *Client) dialRedirect(rawurl string) (conn *websocket.Conn, err error) { } } -func (c *Client) dial(rawurl string) (*websocket.Conn, string, error) { +func (c *WSClient) dial(rawurl string) (*websocket.Conn, string, error) { conn, resp, err := c.Dial(rawurl, nil) if err != nil && err != websocket.ErrBadHandshake { return nil, "", err diff --git a/streaming_ws_test.go b/streaming_ws_test.go index f041984..9d5386b 100644 --- a/streaming_ws_test.go +++ b/streaming_ws_test.go @@ -14,7 +14,7 @@ func TestStreamingWSPublic(t *testing.T) { ts := httptest.NewServer(http.HandlerFunc(wsMock)) defer ts.Close() - client := NewClient(&Config{Server: ts.URL}) + client := NewClient(&Config{Server: ts.URL}).NewWSClient() ctx, cancel := context.WithCancel(context.Background()) q, err := client.StreamingWSPublic(ctx) if err != nil { @@ -28,7 +28,7 @@ func TestStreamingWSPublicLocal(t *testing.T) { ts := httptest.NewServer(http.HandlerFunc(wsMock)) defer ts.Close() - client := NewClient(&Config{Server: ts.URL}) + client := NewClient(&Config{Server: ts.URL}).NewWSClient() ctx, cancel := context.WithCancel(context.Background()) q, err := client.StreamingWSPublicLocal(ctx) if err != nil { @@ -42,7 +42,7 @@ func TestStreamingWSUser(t *testing.T) { ts := httptest.NewServer(http.HandlerFunc(wsMock)) defer ts.Close() - client := NewClient(&Config{Server: ts.URL}) + client := NewClient(&Config{Server: ts.URL}).NewWSClient() ctx, cancel := context.WithCancel(context.Background()) q, err := client.StreamingWSUser(ctx) if err != nil { @@ -56,7 +56,7 @@ func TestStreamingWSHashtag(t *testing.T) { ts := httptest.NewServer(http.HandlerFunc(wsMock)) defer ts.Close() - client := NewClient(&Config{Server: ts.URL}) + client := NewClient(&Config{Server: ts.URL}).NewWSClient() ctx, cancel := context.WithCancel(context.Background()) q, err := client.StreamingWSHashtag(ctx, "zzz") if err != nil { @@ -70,7 +70,7 @@ func TestStreamingWSHashtagLocal(t *testing.T) { ts := httptest.NewServer(http.HandlerFunc(wsMock)) defer ts.Close() - client := NewClient(&Config{Server: ts.URL}) + client := NewClient(&Config{Server: ts.URL}).NewWSClient() ctx, cancel := context.WithCancel(context.Background()) q, err := client.StreamingWSHashtagLocal(ctx, "zzz") if err != nil { @@ -155,13 +155,13 @@ func TestStreamingWS(t *testing.T) { ts := httptest.NewServer(http.HandlerFunc(wsMock)) defer ts.Close() - client := NewClient(&Config{Server: ":"}) + client := NewClient(&Config{Server: ":"}).NewWSClient() _, err := client.StreamingWSPublicLocal(context.Background()) if err == nil { t.Fatalf("should be fail: %v", err) } - client = NewClient(&Config{Server: ts.URL}) + client = NewClient(&Config{Server: ts.URL}).NewWSClient() ctx, cancel := context.WithCancel(context.Background()) cancel() q, err := client.StreamingWSPublicLocal(ctx) @@ -198,7 +198,7 @@ func TestHandleWS(t *testing.T) { defer ts.Close() q := make(chan Event) - client := NewClient(&Config{}) + client := NewClient(&Config{}).NewWSClient() go func() { e := <-q @@ -234,7 +234,7 @@ func TestHandleWS(t *testing.T) { } func TestDialRedirect(t *testing.T) { - client := NewClient(&Config{}) + client := NewClient(&Config{}).NewWSClient() _, err := client.dialRedirect(":") if err == nil { t.Fatalf("should be fail: %v", err) @@ -254,7 +254,7 @@ func TestDial(t *testing.T) { })) defer ts.Close() - client := NewClient(&Config{}) + client := NewClient(&Config{}).NewWSClient() _, _, err := client.dial(":") if err == nil { t.Fatalf("should be fail: %v", err)