From 64e71edf9500965a65fe12a3dffc7eadca5a8032 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Aleksandar=20Novakovi=C4=87?= Date: Mon, 21 May 2018 23:12:21 +0200 Subject: [PATCH] NOISSUE - Add mutex to WebSocket service mock (#294) --- ws/api/transport_test.go | 3 +-- ws/mocks/messages.go | 13 ++++++++++--- 2 files changed, 11 insertions(+), 5 deletions(-) diff --git a/ws/api/transport_test.go b/ws/api/transport_test.go index 2ffdf84d..4dfd7535 100644 --- a/ws/api/transport_test.go +++ b/ws/api/transport_test.go @@ -60,8 +60,7 @@ func handshake(tsURL string, chanID uint64, token string, addHeader bool) (*webs header.Add("Authorization", token) } url := makeURL(tsURL, chanID, token, addHeader) - conn, resp, err := websocket.DefaultDialer.Dial(url, header) - return conn, resp, err + return websocket.DefaultDialer.Dial(url, header) } func TestHandshake(t *testing.T) { diff --git a/ws/mocks/messages.go b/ws/mocks/messages.go index e014d1a0..41044f3a 100644 --- a/ws/mocks/messages.go +++ b/ws/mocks/messages.go @@ -1,6 +1,8 @@ package mocks import ( + "sync" + "github.com/mainflux/mainflux" "github.com/mainflux/mainflux/ws" ) @@ -10,25 +12,30 @@ var _ ws.Service = (*mockService)(nil) type mockService struct { subscriptions map[uint64]ws.Channel pubError error + mutex sync.Mutex } // NewService returns mock message publisher. func NewService(subs map[uint64]ws.Channel, pubError error) ws.Service { - return mockService{subs, pubError} + return &mockService{subs, pubError, sync.Mutex{}} } -func (svc mockService) Publish(msg mainflux.RawMessage) error { +func (svc *mockService) Publish(msg mainflux.RawMessage) error { if len(msg.Payload) == 0 { return svc.pubError } + svc.mutex.Lock() svc.subscriptions[msg.Channel].Messages <- msg + svc.mutex.Unlock() return nil } -func (svc mockService) Subscribe(chanID uint64, channel ws.Channel) error { +func (svc *mockService) Subscribe(chanID uint64, channel ws.Channel) error { if _, ok := svc.subscriptions[chanID]; !ok { return ws.ErrFailedSubscription } + svc.mutex.Lock() svc.subscriptions[chanID] = channel + svc.mutex.Unlock() return nil }