diff --git a/streaming_ws.go b/streaming_ws.go index 022928c..5553a25 100644 --- a/streaming_ws.go +++ b/streaming_ws.go @@ -66,6 +66,7 @@ func (c *WSClient) streamingWS(ctx context.Context, stream, tag string) (chan Ev q := make(chan Event) go func() { + defer close(q) for { err := c.handleWS(ctx, u.String(), q) if err != nil { @@ -85,7 +86,12 @@ func (c *WSClient) handleWS(ctx context.Context, rawurl string, q chan Event) er // End. return err } - defer conn.Close() + + // Close the WebSocket when the context is canceled. + go func() { + <-ctx.Done() + conn.Close() + }() for { select { diff --git a/streaming_ws_test.go b/streaming_ws_test.go index c13a7cc..f52477c 100644 --- a/streaming_ws_test.go +++ b/streaming_ws_test.go @@ -10,20 +10,6 @@ import ( "github.com/gorilla/websocket" ) -func TestStreamingWSPublic(t *testing.T) { - ts := httptest.NewServer(http.HandlerFunc(wsMock)) - defer ts.Close() - - client := NewClient(&Config{Server: ts.URL}).NewWSClient() - ctx, cancel := context.WithCancel(context.Background()) - q, err := client.StreamingWSPublic(ctx, false) - if err != nil { - t.Fatalf("should not be fail: %v", err) - } - - wsTest(t, q, cancel) -} - func TestStreamingWSUser(t *testing.T) { ts := httptest.NewServer(http.HandlerFunc(wsMock)) defer ts.Close() @@ -38,6 +24,20 @@ func TestStreamingWSUser(t *testing.T) { wsTest(t, q, cancel) } +func TestStreamingWSPublic(t *testing.T) { + ts := httptest.NewServer(http.HandlerFunc(wsMock)) + defer ts.Close() + + client := NewClient(&Config{Server: ts.URL}).NewWSClient() + ctx, cancel := context.WithCancel(context.Background()) + q, err := client.StreamingWSPublic(ctx, false) + if err != nil { + t.Fatalf("should not be fail: %v", err) + } + + wsTest(t, q, cancel) +} + func TestStreamingWSHashtag(t *testing.T) { ts := httptest.NewServer(http.HandlerFunc(wsMock)) defer ts.Close() @@ -106,13 +106,12 @@ func wsMock(w http.ResponseWriter, r *http.Request) { func wsTest(t *testing.T, q chan Event, cancel func()) { time.AfterFunc(time.Second, func() { cancel() - close(q) }) events := []Event{} for e := range q { events = append(events, e) } - if len(events) != 4 { + if len(events) != 6 { t.Fatalf("result should be four: %d", len(events)) } if events[0].(*UpdateEvent).Status.Content != "foo" { @@ -127,6 +126,12 @@ func wsTest(t *testing.T, q chan Event, cancel func()) { if errorEvent, ok := events[3].(*ErrorEvent); !ok { t.Fatalf("should be fail: %v", errorEvent.err) } + if errorEvent, ok := events[4].(*ErrorEvent); !ok { + t.Fatalf("should be fail: %v", errorEvent.err) + } + if errorEvent, ok := events[5].(*ErrorEvent); !ok { + t.Fatalf("should be fail: %v", errorEvent.err) + } } func TestStreamingWS(t *testing.T) {