diff --git a/streaming_ws.go b/streaming_ws.go index 022928c..5553a25 100644 --- a/streaming_ws.go +++ b/streaming_ws.go @@ -66,6 +66,7 @@ func (c *WSClient) streamingWS(ctx context.Context, stream, tag string) (chan Ev q := make(chan Event) go func() { + defer close(q) for { err := c.handleWS(ctx, u.String(), q) if err != nil { @@ -85,7 +86,12 @@ func (c *WSClient) handleWS(ctx context.Context, rawurl string, q chan Event) er // End. return err } - defer conn.Close() + + // Close the WebSocket when the context is canceled. + go func() { + <-ctx.Done() + conn.Close() + }() for { select { diff --git a/streaming_ws_test.go b/streaming_ws_test.go index c13a7cc..6ec5608 100644 --- a/streaming_ws_test.go +++ b/streaming_ws_test.go @@ -106,13 +106,12 @@ func wsMock(w http.ResponseWriter, r *http.Request) { func wsTest(t *testing.T, q chan Event, cancel func()) { time.AfterFunc(time.Second, func() { cancel() - close(q) }) events := []Event{} for e := range q { events = append(events, e) } - if len(events) != 4 { + if len(events) != 6 { t.Fatalf("result should be four: %d", len(events)) } if events[0].(*UpdateEvent).Status.Content != "foo" { @@ -127,6 +126,12 @@ func wsTest(t *testing.T, q chan Event, cancel func()) { if errorEvent, ok := events[3].(*ErrorEvent); !ok { t.Fatalf("should be fail: %v", errorEvent.err) } + if errorEvent, ok := events[4].(*ErrorEvent); !ok { + t.Fatalf("should be fail: %v", errorEvent.err) + } + if errorEvent, ok := events[5].(*ErrorEvent); !ok { + t.Fatalf("should be fail: %v", errorEvent.err) + } } func TestStreamingWS(t *testing.T) {