mirror of
https://github.com/mainflux/mainflux.git
synced 2025-04-26 13:48:53 +08:00
NOISSUE - Fix channel closing bug in WebSocket adapter (#309)
* Remove unnecessary error checks in ws adapter Signed-off-by: Aleksandar Novakovic <aleksandar.novakovic@mainflux.com> * Fix WebSocket adapter channel closing bug Signed-off-by: Aleksandar Novakovic <aleksandar.novakovic@mainflux.com>
This commit is contained in:
parent
ef3627f4ee
commit
1d75268ffa
@ -4,6 +4,7 @@ package ws
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"sync"
|
||||
|
||||
"github.com/mainflux/mainflux"
|
||||
broker "github.com/nats-io/go-nats"
|
||||
@ -25,17 +26,44 @@ type Service interface {
|
||||
mainflux.MessagePublisher
|
||||
|
||||
// Subscribes to channel with specified id.
|
||||
Subscribe(uint64, Channel) error
|
||||
Subscribe(uint64, *Channel) error
|
||||
}
|
||||
|
||||
// Channel is used for receiving and sending messages.
|
||||
type Channel struct {
|
||||
Messages chan mainflux.RawMessage
|
||||
Closed chan bool
|
||||
closed bool
|
||||
mutex sync.Mutex
|
||||
}
|
||||
|
||||
// NewChannel instantiates empty channel.
|
||||
func NewChannel() *Channel {
|
||||
return &Channel{
|
||||
Messages: make(chan mainflux.RawMessage),
|
||||
Closed: make(chan bool),
|
||||
closed: false,
|
||||
mutex: sync.Mutex{},
|
||||
}
|
||||
}
|
||||
|
||||
// Send method send message over Messages channel.
|
||||
func (channel *Channel) Send(msg mainflux.RawMessage) {
|
||||
channel.mutex.Lock()
|
||||
defer channel.mutex.Unlock()
|
||||
|
||||
if !channel.closed {
|
||||
channel.Messages <- msg
|
||||
}
|
||||
}
|
||||
|
||||
// Close channel and stop message transfer.
|
||||
func (channel Channel) Close() {
|
||||
func (channel *Channel) Close() {
|
||||
channel.mutex.Lock()
|
||||
defer channel.mutex.Unlock()
|
||||
|
||||
channel.closed = true
|
||||
channel.Closed <- true
|
||||
close(channel.Messages)
|
||||
close(channel.Closed)
|
||||
}
|
||||
@ -63,7 +91,7 @@ func (as *adapterService) Publish(msg mainflux.RawMessage) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (as *adapterService) Subscribe(chanID uint64, channel Channel) error {
|
||||
func (as *adapterService) Subscribe(chanID uint64, channel *Channel) error {
|
||||
if err := as.pubsub.Subscribe(chanID, channel); err != nil {
|
||||
return ErrFailedSubscription
|
||||
}
|
||||
|
@ -18,24 +18,22 @@ const (
|
||||
protocol = "ws"
|
||||
)
|
||||
|
||||
var (
|
||||
msg = mainflux.RawMessage{
|
||||
Channel: chanID,
|
||||
Publisher: pubID,
|
||||
Protocol: protocol,
|
||||
Payload: []byte(`[{"n":"current","t":-5,"v":1.2}]`),
|
||||
}
|
||||
channel = ws.Channel{make(chan mainflux.RawMessage), make(chan bool)}
|
||||
)
|
||||
var msg = mainflux.RawMessage{
|
||||
Channel: chanID,
|
||||
Publisher: pubID,
|
||||
Protocol: protocol,
|
||||
Payload: []byte(`[{"n":"current","t":-5,"v":1.2}]`),
|
||||
}
|
||||
|
||||
func newService() ws.Service {
|
||||
subs := map[uint64]ws.Channel{chanID: channel}
|
||||
func newService(channel *ws.Channel) ws.Service {
|
||||
subs := map[uint64]*ws.Channel{chanID: channel}
|
||||
pubsub := mocks.NewService(subs, broker.ErrInvalidMsg)
|
||||
return ws.New(pubsub)
|
||||
}
|
||||
|
||||
func TestPublish(t *testing.T) {
|
||||
svc := newService()
|
||||
channel := ws.NewChannel()
|
||||
svc := newService(channel)
|
||||
|
||||
cases := []struct {
|
||||
desc string
|
||||
@ -49,8 +47,8 @@ func TestPublish(t *testing.T) {
|
||||
for _, tc := range cases {
|
||||
// Check if message was sent.
|
||||
go func(desc string, tcMsg mainflux.RawMessage) {
|
||||
msg := <-channel.Messages
|
||||
assert.Equal(t, tcMsg, msg, fmt.Sprintf("%s: expected %v got %v\n", desc, tcMsg, msg))
|
||||
receivedMsg := <-channel.Messages
|
||||
assert.Equal(t, tcMsg, receivedMsg, fmt.Sprintf("%s: expected %v got %v\n", desc, tcMsg, receivedMsg))
|
||||
}(tc.desc, tc.msg)
|
||||
|
||||
// Check if publish succeeded.
|
||||
@ -60,12 +58,13 @@ func TestPublish(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestSubscribe(t *testing.T) {
|
||||
svc := newService()
|
||||
channel := ws.NewChannel()
|
||||
svc := newService(channel)
|
||||
|
||||
cases := []struct {
|
||||
desc string
|
||||
chanID uint64
|
||||
channel ws.Channel
|
||||
channel *ws.Channel
|
||||
err error
|
||||
}{
|
||||
{"subscription to valid channel", chanID, channel, nil},
|
||||
@ -78,11 +77,21 @@ func TestSubscribe(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestClose(t *testing.T) {
|
||||
channel := ws.Channel{make(chan mainflux.RawMessage), make(chan bool)}
|
||||
channel.Close()
|
||||
_, closed := <-channel.Closed
|
||||
_, messagesClosed := <-channel.Messages
|
||||
assert.False(t, closed, "channel closed stayed open")
|
||||
assert.False(t, messagesClosed, "channel messages stayed open")
|
||||
func TestSend(t *testing.T) {
|
||||
channel := ws.NewChannel()
|
||||
go func(channel *ws.Channel) {
|
||||
receivedMsg := <-channel.Messages
|
||||
assert.Equal(t, msg, receivedMsg, fmt.Sprintf("send message to channel: expected %v got %v\n", msg, receivedMsg))
|
||||
}(channel)
|
||||
|
||||
channel.Send(msg)
|
||||
}
|
||||
|
||||
func TestClose(t *testing.T) {
|
||||
channel := ws.NewChannel()
|
||||
go func() {
|
||||
closed := <-channel.Closed
|
||||
assert.True(t, closed, "channel closed stayed open")
|
||||
}()
|
||||
channel.Close()
|
||||
}
|
||||
|
@ -36,7 +36,7 @@ func (lm *loggingMiddleware) Publish(msg mainflux.RawMessage) (err error) {
|
||||
return lm.svc.Publish(msg)
|
||||
}
|
||||
|
||||
func (lm *loggingMiddleware) Subscribe(chanID uint64, channel ws.Channel) (err error) {
|
||||
func (lm *loggingMiddleware) Subscribe(chanID uint64, channel *ws.Channel) (err error) {
|
||||
defer func(begin time.Time) {
|
||||
message := fmt.Sprintf("Method subscribe to channel %d took %s to complete", chanID, time.Since(begin))
|
||||
if err != nil {
|
||||
|
@ -36,6 +36,6 @@ func (mm *metricsMiddleware) Publish(msg mainflux.RawMessage) error {
|
||||
return mm.svc.Publish(msg)
|
||||
}
|
||||
|
||||
func (mm *metricsMiddleware) Subscribe(chanID uint64, channel ws.Channel) error {
|
||||
func (mm *metricsMiddleware) Subscribe(chanID uint64, channel *ws.Channel) error {
|
||||
return mm.svc.Subscribe(chanID, channel)
|
||||
}
|
||||
|
@ -50,30 +50,18 @@ func MakeHandler(svc ws.Service, cc mainflux.ThingsServiceClient, l log.Logger)
|
||||
func handshake(svc ws.Service) http.HandlerFunc {
|
||||
return func(w http.ResponseWriter, r *http.Request) {
|
||||
sub, err := authorize(r)
|
||||
if err == errNotFound {
|
||||
logger.Warn(fmt.Sprintf("Invalid channel id: %s", err))
|
||||
w.WriteHeader(http.StatusNotFound)
|
||||
return
|
||||
}
|
||||
if err != nil {
|
||||
switch err {
|
||||
case errNotFound:
|
||||
logger.Warn(fmt.Sprintf("Invalid channel id: %s", err))
|
||||
w.WriteHeader(http.StatusNotFound)
|
||||
return
|
||||
case things.ErrUnauthorizedAccess:
|
||||
w.WriteHeader(http.StatusForbidden)
|
||||
return
|
||||
default:
|
||||
logger.Warn(fmt.Sprintf("Failed to authorize: %s", err))
|
||||
e, ok := status.FromError(err)
|
||||
if ok {
|
||||
switch e.Code() {
|
||||
case codes.PermissionDenied:
|
||||
w.WriteHeader(http.StatusForbidden)
|
||||
default:
|
||||
w.WriteHeader(http.StatusServiceUnavailable)
|
||||
}
|
||||
return
|
||||
}
|
||||
w.WriteHeader(http.StatusForbidden)
|
||||
w.WriteHeader(http.StatusServiceUnavailable)
|
||||
return
|
||||
}
|
||||
}
|
||||
@ -86,9 +74,7 @@ func handshake(svc ws.Service) http.HandlerFunc {
|
||||
}
|
||||
sub.conn = conn
|
||||
|
||||
// Subscribe to channel
|
||||
channel := ws.Channel{make(chan mainflux.RawMessage), make(chan bool)}
|
||||
sub.channel = channel
|
||||
sub.channel = ws.NewChannel()
|
||||
if err := svc.Subscribe(sub.chanID, sub.channel); err != nil {
|
||||
logger.Warn(fmt.Sprintf("Failed to subscribe to NATS subject: %s", err))
|
||||
conn.Close()
|
||||
@ -122,6 +108,10 @@ func authorize(r *http.Request) (subscription, error) {
|
||||
|
||||
id, err := auth.CanAccess(ctx, &mainflux.AccessReq{Token: authKey, ChanID: chanID})
|
||||
if err != nil {
|
||||
e, ok := status.FromError(err)
|
||||
if ok && e.Code() == codes.PermissionDenied {
|
||||
return subscription{}, things.ErrUnauthorizedAccess
|
||||
}
|
||||
return subscription{}, err
|
||||
}
|
||||
|
||||
@ -137,14 +127,14 @@ type subscription struct {
|
||||
pubID uint64
|
||||
chanID uint64
|
||||
conn *websocket.Conn
|
||||
channel ws.Channel
|
||||
channel *ws.Channel
|
||||
}
|
||||
|
||||
func (sub subscription) broadcast(svc ws.Service) {
|
||||
for {
|
||||
_, payload, err := sub.conn.ReadMessage()
|
||||
if websocket.IsUnexpectedCloseError(err) {
|
||||
sub.channel.Closed <- true
|
||||
sub.channel.Close()
|
||||
return
|
||||
}
|
||||
if err != nil {
|
||||
|
@ -26,11 +26,11 @@ const (
|
||||
|
||||
var (
|
||||
msg = []byte(`[{"n":"current","t":-5,"v":1.2}]`)
|
||||
channel = ws.Channel{make(chan mainflux.RawMessage), make(chan bool)}
|
||||
channel = ws.NewChannel()
|
||||
)
|
||||
|
||||
func newService() ws.Service {
|
||||
subs := map[uint64]ws.Channel{chanID: channel}
|
||||
subs := map[uint64]*ws.Channel{chanID: channel}
|
||||
pubsub := mocks.NewService(subs, broker.ErrConnectionClosed)
|
||||
return ws.New(pubsub)
|
||||
}
|
||||
|
@ -10,13 +10,13 @@ import (
|
||||
var _ ws.Service = (*mockService)(nil)
|
||||
|
||||
type mockService struct {
|
||||
subscriptions map[uint64]ws.Channel
|
||||
subscriptions map[uint64]*ws.Channel
|
||||
pubError error
|
||||
mutex sync.Mutex
|
||||
}
|
||||
|
||||
// NewService returns mock message publisher.
|
||||
func NewService(subs map[uint64]ws.Channel, pubError error) ws.Service {
|
||||
func NewService(subs map[uint64]*ws.Channel, pubError error) ws.Service {
|
||||
return &mockService{subs, pubError, sync.Mutex{}}
|
||||
}
|
||||
|
||||
@ -30,7 +30,7 @@ func (svc *mockService) Publish(msg mainflux.RawMessage) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (svc *mockService) Subscribe(chanID uint64, channel ws.Channel) error {
|
||||
func (svc *mockService) Subscribe(chanID uint64, channel *ws.Channel) error {
|
||||
svc.mutex.Lock()
|
||||
defer svc.mutex.Unlock()
|
||||
if _, ok := svc.subscriptions[chanID]; !ok {
|
||||
|
@ -47,7 +47,7 @@ func (pubsub *natsPubSub) Publish(msg mainflux.RawMessage) error {
|
||||
return pubsub.nc.Publish(fmt.Sprintf("%s.%d", prefix, msg.Channel), data)
|
||||
}
|
||||
|
||||
func (pubsub *natsPubSub) Subscribe(chanID uint64, channel ws.Channel) error {
|
||||
func (pubsub *natsPubSub) Subscribe(chanID uint64, channel *ws.Channel) error {
|
||||
var sub *broker.Subscription
|
||||
sub, err := pubsub.nc.Subscribe(fmt.Sprintf("%s.%d", prefix, chanID), func(msg *broker.Msg) {
|
||||
if msg == nil {
|
||||
@ -59,19 +59,14 @@ func (pubsub *natsPubSub) Subscribe(chanID uint64, channel ws.Channel) error {
|
||||
return
|
||||
}
|
||||
|
||||
// Prevents sending message to closed channel
|
||||
select {
|
||||
case channel.Messages <- rawMsg:
|
||||
case <-channel.Closed:
|
||||
sub.Unsubscribe()
|
||||
}
|
||||
// Sends message to messages channel
|
||||
channel.Send(rawMsg)
|
||||
})
|
||||
|
||||
// Check if subscription should be closed
|
||||
go func() {
|
||||
<-channel.Closed
|
||||
sub.Unsubscribe()
|
||||
channel.Close()
|
||||
}()
|
||||
|
||||
return err
|
||||
|
Loading…
x
Reference in New Issue
Block a user