1
0
mirror of https://github.com/mainflux/mainflux.git synced 2025-04-26 13:48:53 +08:00

NOISSUE - Fix channel closing bug in WebSocket adapter (#309)

* Remove unnecessary error checks in ws adapter

Signed-off-by: Aleksandar Novakovic <aleksandar.novakovic@mainflux.com>

* Fix WebSocket adapter channel closing bug

Signed-off-by: Aleksandar Novakovic <aleksandar.novakovic@mainflux.com>
This commit is contained in:
Aleksandar Novaković 2018-05-28 12:47:20 +02:00 committed by Nikola Marčetić
parent ef3627f4ee
commit 1d75268ffa
8 changed files with 84 additions and 62 deletions

View File

@ -4,6 +4,7 @@ package ws
import (
"errors"
"sync"
"github.com/mainflux/mainflux"
broker "github.com/nats-io/go-nats"
@ -25,17 +26,44 @@ type Service interface {
mainflux.MessagePublisher
// Subscribes to channel with specified id.
Subscribe(uint64, Channel) error
Subscribe(uint64, *Channel) error
}
// Channel is used for receiving and sending messages.
type Channel struct {
Messages chan mainflux.RawMessage
Closed chan bool
closed bool
mutex sync.Mutex
}
// NewChannel instantiates empty channel.
func NewChannel() *Channel {
return &Channel{
Messages: make(chan mainflux.RawMessage),
Closed: make(chan bool),
closed: false,
mutex: sync.Mutex{},
}
}
// Send method send message over Messages channel.
func (channel *Channel) Send(msg mainflux.RawMessage) {
channel.mutex.Lock()
defer channel.mutex.Unlock()
if !channel.closed {
channel.Messages <- msg
}
}
// Close channel and stop message transfer.
func (channel Channel) Close() {
func (channel *Channel) Close() {
channel.mutex.Lock()
defer channel.mutex.Unlock()
channel.closed = true
channel.Closed <- true
close(channel.Messages)
close(channel.Closed)
}
@ -63,7 +91,7 @@ func (as *adapterService) Publish(msg mainflux.RawMessage) error {
return nil
}
func (as *adapterService) Subscribe(chanID uint64, channel Channel) error {
func (as *adapterService) Subscribe(chanID uint64, channel *Channel) error {
if err := as.pubsub.Subscribe(chanID, channel); err != nil {
return ErrFailedSubscription
}

View File

@ -18,24 +18,22 @@ const (
protocol = "ws"
)
var (
msg = mainflux.RawMessage{
Channel: chanID,
Publisher: pubID,
Protocol: protocol,
Payload: []byte(`[{"n":"current","t":-5,"v":1.2}]`),
}
channel = ws.Channel{make(chan mainflux.RawMessage), make(chan bool)}
)
var msg = mainflux.RawMessage{
Channel: chanID,
Publisher: pubID,
Protocol: protocol,
Payload: []byte(`[{"n":"current","t":-5,"v":1.2}]`),
}
func newService() ws.Service {
subs := map[uint64]ws.Channel{chanID: channel}
func newService(channel *ws.Channel) ws.Service {
subs := map[uint64]*ws.Channel{chanID: channel}
pubsub := mocks.NewService(subs, broker.ErrInvalidMsg)
return ws.New(pubsub)
}
func TestPublish(t *testing.T) {
svc := newService()
channel := ws.NewChannel()
svc := newService(channel)
cases := []struct {
desc string
@ -49,8 +47,8 @@ func TestPublish(t *testing.T) {
for _, tc := range cases {
// Check if message was sent.
go func(desc string, tcMsg mainflux.RawMessage) {
msg := <-channel.Messages
assert.Equal(t, tcMsg, msg, fmt.Sprintf("%s: expected %v got %v\n", desc, tcMsg, msg))
receivedMsg := <-channel.Messages
assert.Equal(t, tcMsg, receivedMsg, fmt.Sprintf("%s: expected %v got %v\n", desc, tcMsg, receivedMsg))
}(tc.desc, tc.msg)
// Check if publish succeeded.
@ -60,12 +58,13 @@ func TestPublish(t *testing.T) {
}
func TestSubscribe(t *testing.T) {
svc := newService()
channel := ws.NewChannel()
svc := newService(channel)
cases := []struct {
desc string
chanID uint64
channel ws.Channel
channel *ws.Channel
err error
}{
{"subscription to valid channel", chanID, channel, nil},
@ -78,11 +77,21 @@ func TestSubscribe(t *testing.T) {
}
}
func TestClose(t *testing.T) {
channel := ws.Channel{make(chan mainflux.RawMessage), make(chan bool)}
channel.Close()
_, closed := <-channel.Closed
_, messagesClosed := <-channel.Messages
assert.False(t, closed, "channel closed stayed open")
assert.False(t, messagesClosed, "channel messages stayed open")
func TestSend(t *testing.T) {
channel := ws.NewChannel()
go func(channel *ws.Channel) {
receivedMsg := <-channel.Messages
assert.Equal(t, msg, receivedMsg, fmt.Sprintf("send message to channel: expected %v got %v\n", msg, receivedMsg))
}(channel)
channel.Send(msg)
}
func TestClose(t *testing.T) {
channel := ws.NewChannel()
go func() {
closed := <-channel.Closed
assert.True(t, closed, "channel closed stayed open")
}()
channel.Close()
}

View File

@ -36,7 +36,7 @@ func (lm *loggingMiddleware) Publish(msg mainflux.RawMessage) (err error) {
return lm.svc.Publish(msg)
}
func (lm *loggingMiddleware) Subscribe(chanID uint64, channel ws.Channel) (err error) {
func (lm *loggingMiddleware) Subscribe(chanID uint64, channel *ws.Channel) (err error) {
defer func(begin time.Time) {
message := fmt.Sprintf("Method subscribe to channel %d took %s to complete", chanID, time.Since(begin))
if err != nil {

View File

@ -36,6 +36,6 @@ func (mm *metricsMiddleware) Publish(msg mainflux.RawMessage) error {
return mm.svc.Publish(msg)
}
func (mm *metricsMiddleware) Subscribe(chanID uint64, channel ws.Channel) error {
func (mm *metricsMiddleware) Subscribe(chanID uint64, channel *ws.Channel) error {
return mm.svc.Subscribe(chanID, channel)
}

View File

@ -50,30 +50,18 @@ func MakeHandler(svc ws.Service, cc mainflux.ThingsServiceClient, l log.Logger)
func handshake(svc ws.Service) http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
sub, err := authorize(r)
if err == errNotFound {
logger.Warn(fmt.Sprintf("Invalid channel id: %s", err))
w.WriteHeader(http.StatusNotFound)
return
}
if err != nil {
switch err {
case errNotFound:
logger.Warn(fmt.Sprintf("Invalid channel id: %s", err))
w.WriteHeader(http.StatusNotFound)
return
case things.ErrUnauthorizedAccess:
w.WriteHeader(http.StatusForbidden)
return
default:
logger.Warn(fmt.Sprintf("Failed to authorize: %s", err))
e, ok := status.FromError(err)
if ok {
switch e.Code() {
case codes.PermissionDenied:
w.WriteHeader(http.StatusForbidden)
default:
w.WriteHeader(http.StatusServiceUnavailable)
}
return
}
w.WriteHeader(http.StatusForbidden)
w.WriteHeader(http.StatusServiceUnavailable)
return
}
}
@ -86,9 +74,7 @@ func handshake(svc ws.Service) http.HandlerFunc {
}
sub.conn = conn
// Subscribe to channel
channel := ws.Channel{make(chan mainflux.RawMessage), make(chan bool)}
sub.channel = channel
sub.channel = ws.NewChannel()
if err := svc.Subscribe(sub.chanID, sub.channel); err != nil {
logger.Warn(fmt.Sprintf("Failed to subscribe to NATS subject: %s", err))
conn.Close()
@ -122,6 +108,10 @@ func authorize(r *http.Request) (subscription, error) {
id, err := auth.CanAccess(ctx, &mainflux.AccessReq{Token: authKey, ChanID: chanID})
if err != nil {
e, ok := status.FromError(err)
if ok && e.Code() == codes.PermissionDenied {
return subscription{}, things.ErrUnauthorizedAccess
}
return subscription{}, err
}
@ -137,14 +127,14 @@ type subscription struct {
pubID uint64
chanID uint64
conn *websocket.Conn
channel ws.Channel
channel *ws.Channel
}
func (sub subscription) broadcast(svc ws.Service) {
for {
_, payload, err := sub.conn.ReadMessage()
if websocket.IsUnexpectedCloseError(err) {
sub.channel.Closed <- true
sub.channel.Close()
return
}
if err != nil {

View File

@ -26,11 +26,11 @@ const (
var (
msg = []byte(`[{"n":"current","t":-5,"v":1.2}]`)
channel = ws.Channel{make(chan mainflux.RawMessage), make(chan bool)}
channel = ws.NewChannel()
)
func newService() ws.Service {
subs := map[uint64]ws.Channel{chanID: channel}
subs := map[uint64]*ws.Channel{chanID: channel}
pubsub := mocks.NewService(subs, broker.ErrConnectionClosed)
return ws.New(pubsub)
}

View File

@ -10,13 +10,13 @@ import (
var _ ws.Service = (*mockService)(nil)
type mockService struct {
subscriptions map[uint64]ws.Channel
subscriptions map[uint64]*ws.Channel
pubError error
mutex sync.Mutex
}
// NewService returns mock message publisher.
func NewService(subs map[uint64]ws.Channel, pubError error) ws.Service {
func NewService(subs map[uint64]*ws.Channel, pubError error) ws.Service {
return &mockService{subs, pubError, sync.Mutex{}}
}
@ -30,7 +30,7 @@ func (svc *mockService) Publish(msg mainflux.RawMessage) error {
return nil
}
func (svc *mockService) Subscribe(chanID uint64, channel ws.Channel) error {
func (svc *mockService) Subscribe(chanID uint64, channel *ws.Channel) error {
svc.mutex.Lock()
defer svc.mutex.Unlock()
if _, ok := svc.subscriptions[chanID]; !ok {

View File

@ -47,7 +47,7 @@ func (pubsub *natsPubSub) Publish(msg mainflux.RawMessage) error {
return pubsub.nc.Publish(fmt.Sprintf("%s.%d", prefix, msg.Channel), data)
}
func (pubsub *natsPubSub) Subscribe(chanID uint64, channel ws.Channel) error {
func (pubsub *natsPubSub) Subscribe(chanID uint64, channel *ws.Channel) error {
var sub *broker.Subscription
sub, err := pubsub.nc.Subscribe(fmt.Sprintf("%s.%d", prefix, chanID), func(msg *broker.Msg) {
if msg == nil {
@ -59,19 +59,14 @@ func (pubsub *natsPubSub) Subscribe(chanID uint64, channel ws.Channel) error {
return
}
// Prevents sending message to closed channel
select {
case channel.Messages <- rawMsg:
case <-channel.Closed:
sub.Unsubscribe()
}
// Sends message to messages channel
channel.Send(rawMsg)
})
// Check if subscription should be closed
go func() {
<-channel.Closed
sub.Unsubscribe()
channel.Close()
}()
return err