Merge pull request #51 from 178inaba/streaming_ws_close

Fix close of streaming WebSocket
pull/55/head
mattn 2017-05-01 09:13:29 +09:00 committed by GitHub
commit eac43d845c
2 changed files with 28 additions and 17 deletions

View File

@ -66,6 +66,7 @@ func (c *WSClient) streamingWS(ctx context.Context, stream, tag string) (chan Ev
q := make(chan Event) q := make(chan Event)
go func() { go func() {
defer close(q)
for { for {
err := c.handleWS(ctx, u.String(), q) err := c.handleWS(ctx, u.String(), q)
if err != nil { if err != nil {
@ -85,7 +86,12 @@ func (c *WSClient) handleWS(ctx context.Context, rawurl string, q chan Event) er
// End. // End.
return err return err
} }
defer conn.Close()
// Close the WebSocket when the context is canceled.
go func() {
<-ctx.Done()
conn.Close()
}()
for { for {
select { select {

View File

@ -10,20 +10,6 @@ import (
"github.com/gorilla/websocket" "github.com/gorilla/websocket"
) )
func TestStreamingWSPublic(t *testing.T) {
ts := httptest.NewServer(http.HandlerFunc(wsMock))
defer ts.Close()
client := NewClient(&Config{Server: ts.URL}).NewWSClient()
ctx, cancel := context.WithCancel(context.Background())
q, err := client.StreamingWSPublic(ctx, false)
if err != nil {
t.Fatalf("should not be fail: %v", err)
}
wsTest(t, q, cancel)
}
func TestStreamingWSUser(t *testing.T) { func TestStreamingWSUser(t *testing.T) {
ts := httptest.NewServer(http.HandlerFunc(wsMock)) ts := httptest.NewServer(http.HandlerFunc(wsMock))
defer ts.Close() defer ts.Close()
@ -38,6 +24,20 @@ func TestStreamingWSUser(t *testing.T) {
wsTest(t, q, cancel) wsTest(t, q, cancel)
} }
func TestStreamingWSPublic(t *testing.T) {
ts := httptest.NewServer(http.HandlerFunc(wsMock))
defer ts.Close()
client := NewClient(&Config{Server: ts.URL}).NewWSClient()
ctx, cancel := context.WithCancel(context.Background())
q, err := client.StreamingWSPublic(ctx, false)
if err != nil {
t.Fatalf("should not be fail: %v", err)
}
wsTest(t, q, cancel)
}
func TestStreamingWSHashtag(t *testing.T) { func TestStreamingWSHashtag(t *testing.T) {
ts := httptest.NewServer(http.HandlerFunc(wsMock)) ts := httptest.NewServer(http.HandlerFunc(wsMock))
defer ts.Close() defer ts.Close()
@ -106,13 +106,12 @@ func wsMock(w http.ResponseWriter, r *http.Request) {
func wsTest(t *testing.T, q chan Event, cancel func()) { func wsTest(t *testing.T, q chan Event, cancel func()) {
time.AfterFunc(time.Second, func() { time.AfterFunc(time.Second, func() {
cancel() cancel()
close(q)
}) })
events := []Event{} events := []Event{}
for e := range q { for e := range q {
events = append(events, e) events = append(events, e)
} }
if len(events) != 4 { if len(events) != 6 {
t.Fatalf("result should be four: %d", len(events)) t.Fatalf("result should be four: %d", len(events))
} }
if events[0].(*UpdateEvent).Status.Content != "foo" { if events[0].(*UpdateEvent).Status.Content != "foo" {
@ -127,6 +126,12 @@ func wsTest(t *testing.T, q chan Event, cancel func()) {
if errorEvent, ok := events[3].(*ErrorEvent); !ok { if errorEvent, ok := events[3].(*ErrorEvent); !ok {
t.Fatalf("should be fail: %v", errorEvent.err) t.Fatalf("should be fail: %v", errorEvent.err)
} }
if errorEvent, ok := events[4].(*ErrorEvent); !ok {
t.Fatalf("should be fail: %v", errorEvent.err)
}
if errorEvent, ok := events[5].(*ErrorEvent); !ok {
t.Fatalf("should be fail: %v", errorEvent.err)
}
} }
func TestStreamingWS(t *testing.T) { func TestStreamingWS(t *testing.T) {