From 1d75268ffad7ece4a94d5beedf5be05aa1ffe10b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Aleksandar=20Novakovi=C4=87?= Date: Mon, 28 May 2018 12:47:20 +0200 Subject: [PATCH] NOISSUE - Fix channel closing bug in WebSocket adapter (#309) * Remove unnecessary error checks in ws adapter Signed-off-by: Aleksandar Novakovic * Fix WebSocket adapter channel closing bug Signed-off-by: Aleksandar Novakovic --- ws/adapter.go | 34 ++++++++++++++++++++++--- ws/adapter_test.go | 55 +++++++++++++++++++++++----------------- ws/api/logging.go | 2 +- ws/api/metrics.go | 2 +- ws/api/transport.go | 32 ++++++++--------------- ws/api/transport_test.go | 4 +-- ws/mocks/messages.go | 6 ++--- ws/nats/publisher.go | 11 +++----- 8 files changed, 84 insertions(+), 62 deletions(-) diff --git a/ws/adapter.go b/ws/adapter.go index 60d015c2..ef0b5699 100644 --- a/ws/adapter.go +++ b/ws/adapter.go @@ -4,6 +4,7 @@ package ws import ( "errors" + "sync" "github.com/mainflux/mainflux" broker "github.com/nats-io/go-nats" @@ -25,17 +26,44 @@ type Service interface { mainflux.MessagePublisher // Subscribes to channel with specified id. - Subscribe(uint64, Channel) error + Subscribe(uint64, *Channel) error } // Channel is used for receiving and sending messages. type Channel struct { Messages chan mainflux.RawMessage Closed chan bool + closed bool + mutex sync.Mutex +} + +// NewChannel instantiates empty channel. +func NewChannel() *Channel { + return &Channel{ + Messages: make(chan mainflux.RawMessage), + Closed: make(chan bool), + closed: false, + mutex: sync.Mutex{}, + } +} + +// Send method send message over Messages channel. +func (channel *Channel) Send(msg mainflux.RawMessage) { + channel.mutex.Lock() + defer channel.mutex.Unlock() + + if !channel.closed { + channel.Messages <- msg + } } // Close channel and stop message transfer. -func (channel Channel) Close() { +func (channel *Channel) Close() { + channel.mutex.Lock() + defer channel.mutex.Unlock() + + channel.closed = true + channel.Closed <- true close(channel.Messages) close(channel.Closed) } @@ -63,7 +91,7 @@ func (as *adapterService) Publish(msg mainflux.RawMessage) error { return nil } -func (as *adapterService) Subscribe(chanID uint64, channel Channel) error { +func (as *adapterService) Subscribe(chanID uint64, channel *Channel) error { if err := as.pubsub.Subscribe(chanID, channel); err != nil { return ErrFailedSubscription } diff --git a/ws/adapter_test.go b/ws/adapter_test.go index 1f5de148..3b52cdd5 100644 --- a/ws/adapter_test.go +++ b/ws/adapter_test.go @@ -18,24 +18,22 @@ const ( protocol = "ws" ) -var ( - msg = mainflux.RawMessage{ - Channel: chanID, - Publisher: pubID, - Protocol: protocol, - Payload: []byte(`[{"n":"current","t":-5,"v":1.2}]`), - } - channel = ws.Channel{make(chan mainflux.RawMessage), make(chan bool)} -) +var msg = mainflux.RawMessage{ + Channel: chanID, + Publisher: pubID, + Protocol: protocol, + Payload: []byte(`[{"n":"current","t":-5,"v":1.2}]`), +} -func newService() ws.Service { - subs := map[uint64]ws.Channel{chanID: channel} +func newService(channel *ws.Channel) ws.Service { + subs := map[uint64]*ws.Channel{chanID: channel} pubsub := mocks.NewService(subs, broker.ErrInvalidMsg) return ws.New(pubsub) } func TestPublish(t *testing.T) { - svc := newService() + channel := ws.NewChannel() + svc := newService(channel) cases := []struct { desc string @@ -49,8 +47,8 @@ func TestPublish(t *testing.T) { for _, tc := range cases { // Check if message was sent. go func(desc string, tcMsg mainflux.RawMessage) { - msg := <-channel.Messages - assert.Equal(t, tcMsg, msg, fmt.Sprintf("%s: expected %v got %v\n", desc, tcMsg, msg)) + receivedMsg := <-channel.Messages + assert.Equal(t, tcMsg, receivedMsg, fmt.Sprintf("%s: expected %v got %v\n", desc, tcMsg, receivedMsg)) }(tc.desc, tc.msg) // Check if publish succeeded. @@ -60,12 +58,13 @@ func TestPublish(t *testing.T) { } func TestSubscribe(t *testing.T) { - svc := newService() + channel := ws.NewChannel() + svc := newService(channel) cases := []struct { desc string chanID uint64 - channel ws.Channel + channel *ws.Channel err error }{ {"subscription to valid channel", chanID, channel, nil}, @@ -78,11 +77,21 @@ func TestSubscribe(t *testing.T) { } } -func TestClose(t *testing.T) { - channel := ws.Channel{make(chan mainflux.RawMessage), make(chan bool)} - channel.Close() - _, closed := <-channel.Closed - _, messagesClosed := <-channel.Messages - assert.False(t, closed, "channel closed stayed open") - assert.False(t, messagesClosed, "channel messages stayed open") +func TestSend(t *testing.T) { + channel := ws.NewChannel() + go func(channel *ws.Channel) { + receivedMsg := <-channel.Messages + assert.Equal(t, msg, receivedMsg, fmt.Sprintf("send message to channel: expected %v got %v\n", msg, receivedMsg)) + }(channel) + + channel.Send(msg) +} + +func TestClose(t *testing.T) { + channel := ws.NewChannel() + go func() { + closed := <-channel.Closed + assert.True(t, closed, "channel closed stayed open") + }() + channel.Close() } diff --git a/ws/api/logging.go b/ws/api/logging.go index abea81dc..d56557f0 100644 --- a/ws/api/logging.go +++ b/ws/api/logging.go @@ -36,7 +36,7 @@ func (lm *loggingMiddleware) Publish(msg mainflux.RawMessage) (err error) { return lm.svc.Publish(msg) } -func (lm *loggingMiddleware) Subscribe(chanID uint64, channel ws.Channel) (err error) { +func (lm *loggingMiddleware) Subscribe(chanID uint64, channel *ws.Channel) (err error) { defer func(begin time.Time) { message := fmt.Sprintf("Method subscribe to channel %d took %s to complete", chanID, time.Since(begin)) if err != nil { diff --git a/ws/api/metrics.go b/ws/api/metrics.go index 4c02d089..8269bb69 100644 --- a/ws/api/metrics.go +++ b/ws/api/metrics.go @@ -36,6 +36,6 @@ func (mm *metricsMiddleware) Publish(msg mainflux.RawMessage) error { return mm.svc.Publish(msg) } -func (mm *metricsMiddleware) Subscribe(chanID uint64, channel ws.Channel) error { +func (mm *metricsMiddleware) Subscribe(chanID uint64, channel *ws.Channel) error { return mm.svc.Subscribe(chanID, channel) } diff --git a/ws/api/transport.go b/ws/api/transport.go index fd2db598..85451c2a 100644 --- a/ws/api/transport.go +++ b/ws/api/transport.go @@ -50,30 +50,18 @@ func MakeHandler(svc ws.Service, cc mainflux.ThingsServiceClient, l log.Logger) func handshake(svc ws.Service) http.HandlerFunc { return func(w http.ResponseWriter, r *http.Request) { sub, err := authorize(r) - if err == errNotFound { - logger.Warn(fmt.Sprintf("Invalid channel id: %s", err)) - w.WriteHeader(http.StatusNotFound) - return - } if err != nil { switch err { case errNotFound: logger.Warn(fmt.Sprintf("Invalid channel id: %s", err)) w.WriteHeader(http.StatusNotFound) return + case things.ErrUnauthorizedAccess: + w.WriteHeader(http.StatusForbidden) + return default: logger.Warn(fmt.Sprintf("Failed to authorize: %s", err)) - e, ok := status.FromError(err) - if ok { - switch e.Code() { - case codes.PermissionDenied: - w.WriteHeader(http.StatusForbidden) - default: - w.WriteHeader(http.StatusServiceUnavailable) - } - return - } - w.WriteHeader(http.StatusForbidden) + w.WriteHeader(http.StatusServiceUnavailable) return } } @@ -86,9 +74,7 @@ func handshake(svc ws.Service) http.HandlerFunc { } sub.conn = conn - // Subscribe to channel - channel := ws.Channel{make(chan mainflux.RawMessage), make(chan bool)} - sub.channel = channel + sub.channel = ws.NewChannel() if err := svc.Subscribe(sub.chanID, sub.channel); err != nil { logger.Warn(fmt.Sprintf("Failed to subscribe to NATS subject: %s", err)) conn.Close() @@ -122,6 +108,10 @@ func authorize(r *http.Request) (subscription, error) { id, err := auth.CanAccess(ctx, &mainflux.AccessReq{Token: authKey, ChanID: chanID}) if err != nil { + e, ok := status.FromError(err) + if ok && e.Code() == codes.PermissionDenied { + return subscription{}, things.ErrUnauthorizedAccess + } return subscription{}, err } @@ -137,14 +127,14 @@ type subscription struct { pubID uint64 chanID uint64 conn *websocket.Conn - channel ws.Channel + channel *ws.Channel } func (sub subscription) broadcast(svc ws.Service) { for { _, payload, err := sub.conn.ReadMessage() if websocket.IsUnexpectedCloseError(err) { - sub.channel.Closed <- true + sub.channel.Close() return } if err != nil { diff --git a/ws/api/transport_test.go b/ws/api/transport_test.go index 4dfd7535..c66034ac 100644 --- a/ws/api/transport_test.go +++ b/ws/api/transport_test.go @@ -26,11 +26,11 @@ const ( var ( msg = []byte(`[{"n":"current","t":-5,"v":1.2}]`) - channel = ws.Channel{make(chan mainflux.RawMessage), make(chan bool)} + channel = ws.NewChannel() ) func newService() ws.Service { - subs := map[uint64]ws.Channel{chanID: channel} + subs := map[uint64]*ws.Channel{chanID: channel} pubsub := mocks.NewService(subs, broker.ErrConnectionClosed) return ws.New(pubsub) } diff --git a/ws/mocks/messages.go b/ws/mocks/messages.go index 1d600718..960ff6c0 100644 --- a/ws/mocks/messages.go +++ b/ws/mocks/messages.go @@ -10,13 +10,13 @@ import ( var _ ws.Service = (*mockService)(nil) type mockService struct { - subscriptions map[uint64]ws.Channel + subscriptions map[uint64]*ws.Channel pubError error mutex sync.Mutex } // NewService returns mock message publisher. -func NewService(subs map[uint64]ws.Channel, pubError error) ws.Service { +func NewService(subs map[uint64]*ws.Channel, pubError error) ws.Service { return &mockService{subs, pubError, sync.Mutex{}} } @@ -30,7 +30,7 @@ func (svc *mockService) Publish(msg mainflux.RawMessage) error { return nil } -func (svc *mockService) Subscribe(chanID uint64, channel ws.Channel) error { +func (svc *mockService) Subscribe(chanID uint64, channel *ws.Channel) error { svc.mutex.Lock() defer svc.mutex.Unlock() if _, ok := svc.subscriptions[chanID]; !ok { diff --git a/ws/nats/publisher.go b/ws/nats/publisher.go index da4f46ce..81d11087 100644 --- a/ws/nats/publisher.go +++ b/ws/nats/publisher.go @@ -47,7 +47,7 @@ func (pubsub *natsPubSub) Publish(msg mainflux.RawMessage) error { return pubsub.nc.Publish(fmt.Sprintf("%s.%d", prefix, msg.Channel), data) } -func (pubsub *natsPubSub) Subscribe(chanID uint64, channel ws.Channel) error { +func (pubsub *natsPubSub) Subscribe(chanID uint64, channel *ws.Channel) error { var sub *broker.Subscription sub, err := pubsub.nc.Subscribe(fmt.Sprintf("%s.%d", prefix, chanID), func(msg *broker.Msg) { if msg == nil { @@ -59,19 +59,14 @@ func (pubsub *natsPubSub) Subscribe(chanID uint64, channel ws.Channel) error { return } - // Prevents sending message to closed channel - select { - case channel.Messages <- rawMsg: - case <-channel.Closed: - sub.Unsubscribe() - } + // Sends message to messages channel + channel.Send(rawMsg) }) // Check if subscription should be closed go func() { <-channel.Closed sub.Unsubscribe() - channel.Close() }() return err