From fc574ac20925b9bbc19e1e082a206896f9051fcd Mon Sep 17 00:00:00 2001 From: 178inaba <178inaba@users.noreply.github.com> Date: Sun, 30 Apr 2017 22:14:02 +0900 Subject: [PATCH] Fix WebSocket close --- streaming_ws.go | 8 +++++++- streaming_ws_test.go | 9 +++++++-- 2 files changed, 14 insertions(+), 3 deletions(-) diff --git a/streaming_ws.go b/streaming_ws.go index 022928c..5553a25 100644 --- a/streaming_ws.go +++ b/streaming_ws.go @@ -66,6 +66,7 @@ func (c *WSClient) streamingWS(ctx context.Context, stream, tag string) (chan Ev q := make(chan Event) go func() { + defer close(q) for { err := c.handleWS(ctx, u.String(), q) if err != nil { @@ -85,7 +86,12 @@ func (c *WSClient) handleWS(ctx context.Context, rawurl string, q chan Event) er // End. return err } - defer conn.Close() + + // Close the WebSocket when the context is canceled. + go func() { + <-ctx.Done() + conn.Close() + }() for { select { diff --git a/streaming_ws_test.go b/streaming_ws_test.go index c13a7cc..6ec5608 100644 --- a/streaming_ws_test.go +++ b/streaming_ws_test.go @@ -106,13 +106,12 @@ func wsMock(w http.ResponseWriter, r *http.Request) { func wsTest(t *testing.T, q chan Event, cancel func()) { time.AfterFunc(time.Second, func() { cancel() - close(q) }) events := []Event{} for e := range q { events = append(events, e) } - if len(events) != 4 { + if len(events) != 6 { t.Fatalf("result should be four: %d", len(events)) } if events[0].(*UpdateEvent).Status.Content != "foo" { @@ -127,6 +126,12 @@ func wsTest(t *testing.T, q chan Event, cancel func()) { if errorEvent, ok := events[3].(*ErrorEvent); !ok { t.Fatalf("should be fail: %v", errorEvent.err) } + if errorEvent, ok := events[4].(*ErrorEvent); !ok { + t.Fatalf("should be fail: %v", errorEvent.err) + } + if errorEvent, ok := events[5].(*ErrorEvent); !ok { + t.Fatalf("should be fail: %v", errorEvent.err) + } } func TestStreamingWS(t *testing.T) {