diff --git a/pkg/messaging/mqtt/docs.go b/pkg/messaging/mqtt/docs.go new file mode 100644 index 00000000..d978eb3f --- /dev/null +++ b/pkg/messaging/mqtt/docs.go @@ -0,0 +1,11 @@ +// Copyright (c) Mainflux +// SPDX-License-Identifier: Apache-2.0 + +// Package mqtt hold the implementation of the Publisher and PubSub +// interfaces for the MQTT messaging system, the internal messaging +// broker of the Mainflux IoT platform. Due to the practical requirements +// implementation Publisher is created alongside PubSub. The reason for +// this is that Subscriber implementation of MQTT brings the burden of +// additional struct fields which are not used by Publisher. Subscriber +// is not implemented separately because PubSub can be used where Subscriber is needed. +package mqtt diff --git a/pkg/messaging/mqtt/publisher.go b/pkg/messaging/mqtt/publisher.go index 9ab7e536..525b90d2 100644 --- a/pkg/messaging/mqtt/publisher.go +++ b/pkg/messaging/mqtt/publisher.go @@ -8,7 +8,7 @@ import ( "time" mqtt "github.com/eclipse/paho.mqtt.golang" - + "github.com/gogo/protobuf/proto" "github.com/mainflux/mainflux/pkg/messaging" ) @@ -36,7 +36,14 @@ func NewPublisher(address string, timeout time.Duration) (messaging.Publisher, e } func (pub publisher) Publish(topic string, msg messaging.Message) error { - token := pub.client.Publish(topic, qos, false, msg.Payload) + if topic == "" { + return ErrEmptyTopic + } + data, err := proto.Marshal(&msg) + if err != nil { + return err + } + token := pub.client.Publish(topic, qos, false, data) if token.Error() != nil { return token.Error() } diff --git a/pkg/messaging/mqtt/pubsub.go b/pkg/messaging/mqtt/pubsub.go index a4790629..3376939b 100644 --- a/pkg/messaging/mqtt/pubsub.go +++ b/pkg/messaging/mqtt/pubsub.go @@ -21,13 +21,29 @@ const ( ) var ( - errConnect = errors.New("failed to connect to MQTT broker") - errSubscribeTimeout = errors.New("failed to subscribe due to timeout reached") - errUnsubscribeTimeout = errors.New("failed to unsubscribe due to timeout reached") - errUnsubscribeDeleteTopic = errors.New("failed to unsubscribe due to deletion of topic") - errNotSubscribed = errors.New("not subscribed") - errEmptyTopic = errors.New("empty topic") - errEmptyID = errors.New("empty ID") + // ErrConnect indicates that connection to MQTT broker failed + ErrConnect = errors.New("failed to connect to MQTT broker") + + // ErrSubscribeTimeout indicates that the subscription failed due to timeout. + ErrSubscribeTimeout = errors.New("failed to subscribe due to timeout reached") + + // ErrUnsubscribeTimeout indicates that unsubscribe failed due to timeout. + ErrUnsubscribeTimeout = errors.New("failed to unsubscribe due to timeout reached") + + // ErrUnsubscribeDeleteTopic indicates that unsubscribe failed because the topic was deleted. + ErrUnsubscribeDeleteTopic = errors.New("failed to unsubscribe due to deletion of topic") + + // ErrNotSubscribed indicates that the topic is not subscribed to. + ErrNotSubscribed = errors.New("not subscribed") + + // ErrEmptyTopic indicates the absence of topic. + ErrEmptyTopic = errors.New("empty topic") + + // ErrEmptyID indicates the absence of ID. + ErrEmptyID = errors.New("empty ID") + + // ErrFailedHandleMessage indicates that the message couldn't be handled. + ErrFailedHandleMessage = errors.New("failed to handle mainflux message") ) var _ messaging.PubSub = (*pubsub)(nil) @@ -35,12 +51,13 @@ var _ messaging.PubSub = (*pubsub)(nil) type subscription struct { client mqtt.Client topics []string + cancel func() error } type pubsub struct { publisher logger log.Logger - mu *sync.RWMutex + mu sync.RWMutex address string timeout time.Duration subscriptions map[string]subscription @@ -52,7 +69,7 @@ func NewPubSub(url, queue string, timeout time.Duration, logger log.Logger) (mes if err != nil { return nil, err } - ret := pubsub{ + ret := &pubsub{ publisher: publisher{ client: client, timeout: timeout, @@ -65,40 +82,25 @@ func NewPubSub(url, queue string, timeout time.Duration, logger log.Logger) (mes return ret, nil } -func (ps pubsub) Subscribe(id, topic string, handler messaging.MessageHandler) error { +func (ps *pubsub) Subscribe(id, topic string, handler messaging.MessageHandler) error { if id == "" { - return errEmptyID + return ErrEmptyID } if topic == "" { - return errEmptyTopic + return ErrEmptyTopic } ps.mu.Lock() defer ps.mu.Unlock() - // Check client ID + s, ok := ps.subscriptions[id] + // If the client exists, check if it's subscribed to the topic and unsubscribe if needed. switch ok { case true: - // Check topic - if ok = s.contains(topic); ok { - // Unlocking, so that Unsubscribe() can access ps.subscriptions - ps.mu.Unlock() - err := ps.Unsubscribe(id, topic) - ps.mu.Lock() // Lock so that deferred unlock handle it - if err != nil { + if ok := s.contains(topic); ok { + if err := s.unsubscribe(topic, ps.timeout); err != nil { return err } - if len(ps.subscriptions) == 0 { - client, err := newClient(ps.address, id, ps.timeout) - if err != nil { - return err - } - s = subscription{ - client: client, - topics: []string{topic}, - } - } } - s.topics = append(s.topics, topic) default: client, err := newClient(ps.address, id, ps.timeout) if err != nil { @@ -106,59 +108,94 @@ func (ps pubsub) Subscribe(id, topic string, handler messaging.MessageHandler) e } s = subscription{ client: client, - topics: []string{topic}, + topics: []string{}, + cancel: handler.Cancel, } } + s.topics = append(s.topics, topic) + ps.subscriptions[id] = s token := s.client.Subscribe(topic, qos, ps.mqttHandler(handler)) if token.Error() != nil { return token.Error() } if ok := token.WaitTimeout(ps.timeout); !ok { - return errSubscribeTimeout + return ErrSubscribeTimeout } return token.Error() } -func (ps pubsub) Unsubscribe(id, topic string) error { +func (ps *pubsub) Unsubscribe(id, topic string) error { if id == "" { - return errEmptyID + return ErrEmptyID } if topic == "" { - return errEmptyTopic + return ErrEmptyTopic } ps.mu.Lock() defer ps.mu.Unlock() - // Check client ID + s, ok := ps.subscriptions[id] - switch ok { - case true: - // Check topic - if ok := s.contains(topic); !ok { - return errNotSubscribed - } - default: - return errNotSubscribed + if !ok || !s.contains(topic) { + return ErrNotSubscribed } + + if err := s.unsubscribe(topic, ps.timeout); err != nil { + return err + } + ps.subscriptions[id] = s + + if len(s.topics) == 0 { + delete(ps.subscriptions, id) + } + return nil +} + +func (s *subscription) unsubscribe(topic string, timeout time.Duration) error { + if s.cancel != nil { + if err := s.cancel(); err != nil { + return err + } + } + token := s.client.Unsubscribe(topic) if token.Error() != nil { return token.Error() } - ok = token.WaitTimeout(ps.timeout) - if !ok { - return errUnsubscribeTimeout + if ok := token.WaitTimeout(timeout); !ok { + return ErrUnsubscribeTimeout } if ok := s.delete(topic); !ok { - return errUnsubscribeDeleteTopic - } - if len(s.topics) == 0 { - delete(ps.subscriptions, id) + return ErrUnsubscribeDeleteTopic } return token.Error() } -func (ps pubsub) mqttHandler(h messaging.MessageHandler) mqtt.MessageHandler { +func newClient(address, id string, timeout time.Duration) (mqtt.Client, error) { + opts := mqtt.NewClientOptions(). + SetUsername(username). + AddBroker(address). + SetClientID(id) + client := mqtt.NewClient(opts) + token := client.Connect() + if token.Error() != nil { + return nil, token.Error() + } + + ok := token.WaitTimeout(timeout) + if !ok { + return nil, ErrConnect + } + + if token.Error() != nil { + return nil, token.Error() + } + + return client, nil +} + +func (ps *pubsub) mqttHandler(h messaging.MessageHandler) mqtt.MessageHandler { return func(c mqtt.Client, m mqtt.Message) { var msg messaging.Message if err := proto.Unmarshal(m.Payload(), &msg); err != nil { @@ -171,34 +208,14 @@ func (ps pubsub) mqttHandler(h messaging.MessageHandler) mqtt.MessageHandler { } } -func newClient(address, id string, timeout time.Duration) (mqtt.Client, error) { - opts := mqtt.NewClientOptions().SetUsername(username).AddBroker(address).SetClientID(id) - client := mqtt.NewClient(opts) - token := client.Connect() - if token.Error() != nil { - return nil, token.Error() - } - - ok := token.WaitTimeout(timeout) - if !ok { - return nil, errConnect - } - - if token.Error() != nil { - return nil, token.Error() - } - - return client, nil +// Contains checks if a topic is present. +func (s subscription) contains(topic string) bool { + return s.indexOf(topic) != -1 } -// contains checks if a topic is present -func (sub subscription) contains(topic string) bool { - return sub.indexOf(topic) != -1 -} - -// Finds the index of an item in the topics -func (sub subscription) indexOf(element string) int { - for k, v := range sub.topics { +// Finds the index of an item in the topics. +func (s subscription) indexOf(element string) int { + for k, v := range s.topics { if element == v { return k } @@ -206,15 +223,15 @@ func (sub subscription) indexOf(element string) int { return -1 } -// Deletes a topic from the slice -func (sub subscription) delete(topic string) bool { - index := sub.indexOf(topic) +// Deletes a topic from the slice. +func (s *subscription) delete(topic string) bool { + index := s.indexOf(topic) if index == -1 { return false } - topics := make([]string, len(sub.topics)-1) - copy(topics[:index], sub.topics[:index]) - copy(topics[index:], sub.topics[index+1:]) - sub.topics = topics + topics := make([]string, len(s.topics)-1) + copy(topics[:index], s.topics[:index]) + copy(topics[index:], s.topics[index+1:]) + s.topics = topics return true } diff --git a/pkg/messaging/mqtt/pubsub_test.go b/pkg/messaging/mqtt/pubsub_test.go new file mode 100644 index 00000000..bb77a227 --- /dev/null +++ b/pkg/messaging/mqtt/pubsub_test.go @@ -0,0 +1,438 @@ +// Copyright (c) Mainflux +// SPDX-License-Identifier: Apache-2.0 + +package mqtt_test + +import ( + "fmt" + "testing" + "time" + + mqtt "github.com/eclipse/paho.mqtt.golang" + "github.com/gogo/protobuf/proto" + "github.com/mainflux/mainflux/pkg/messaging" + mqtt_pubsub "github.com/mainflux/mainflux/pkg/messaging/mqtt" + "github.com/stretchr/testify/assert" +) + +const ( + topic = "topic" + chansPrefix = "channels" + channel = "9b7b1b3f-b1b0-46a8-a717-b8213f9eda3b" + subtopic = "engine" + tokenTimeout = 100 * time.Millisecond +) + +var ( + data = []byte("payload") +) + +func TestPublisher(t *testing.T) { + msgChan := make(chan []byte) + + // Subscribing with topic, and with subtopic, so that we can publish messages. + client, err := newClient(address, "clientID1", brokerTimeout) + assert.Nil(t, err, fmt.Sprintf("got unexpected error: %s", err)) + + token := client.Subscribe(topic, qos, func(c mqtt.Client, m mqtt.Message) { + msgChan <- m.Payload() + }) + if ok := token.WaitTimeout(tokenTimeout); !ok { + assert.Fail(t, fmt.Sprintf("failed to subscribe to topic %s", topic)) + } + assert.Nil(t, token.Error(), fmt.Sprintf("got unexpected error: %s", token.Error())) + + token = client.Subscribe(fmt.Sprintf("%s.%s", topic, subtopic), qos, func(c mqtt.Client, m mqtt.Message) { + msgChan <- m.Payload() + }) + if ok := token.WaitTimeout(tokenTimeout); !ok { + assert.Fail(t, fmt.Sprintf("failed to subscribe to topic %s", fmt.Sprintf("%s.%s", topic, subtopic))) + } + assert.Nil(t, token.Error(), fmt.Sprintf("got unexpected error: %s", token.Error())) + + t.Cleanup(func() { + token := client.Unsubscribe(topic, fmt.Sprintf("%s.%s", topic, subtopic)) + token.WaitTimeout(tokenTimeout) + assert.Nil(t, token.Error(), fmt.Sprintf("got unexpected error: %s", token.Error())) + + client.Disconnect(100) + }) + + // Test publish with an empty topic. + err = pubsub.Publish("", messaging.Message{Payload: data}) + assert.Equal(t, err, mqtt_pubsub.ErrEmptyTopic, fmt.Sprintf("Publish with empty topic: expected: %s, got: %s", mqtt_pubsub.ErrEmptyTopic, err)) + + cases := []struct { + desc string + channel string + subtopic string + payload []byte + }{ + { + desc: "publish message with nil payload", + payload: nil, + }, + { + desc: "publish message with string payload", + payload: data, + }, + { + desc: "publish message with channel", + payload: data, + channel: channel, + }, + { + desc: "publish message with subtopic", + payload: data, + subtopic: subtopic, + }, + { + desc: "publish message with channel and subtopic", + payload: data, + channel: channel, + subtopic: subtopic, + }, + } + for _, tc := range cases { + expectedMsg := messaging.Message{ + Publisher: "clientID11", + Channel: tc.channel, + Subtopic: tc.subtopic, + Payload: tc.payload, + } + err := pubsub.Publish(topic, expectedMsg) + assert.Nil(t, err, fmt.Sprintf("%s: got unexpected error: %s\n", tc.desc, err)) + + data, err := proto.Marshal(&expectedMsg) + assert.Nil(t, err, fmt.Sprintf("%s: failed to serialize protobuf error: %s\n", tc.desc, err)) + + receivedMsg := <-msgChan + assert.Equal(t, data, receivedMsg, fmt.Sprintf("%s: expected %+v got %+v\n", tc.desc, data, receivedMsg)) + } +} + +func TestSubscribe(t *testing.T) { + msgChan := make(chan messaging.Message) + + // Creating client to Publish messages to subscribed topic. + client, err := newClient(address, "mainflux", brokerTimeout) + assert.Nil(t, err, fmt.Sprintf("got unexpected error: %s", err)) + + t.Cleanup(func() { + client.Unsubscribe() + client.Disconnect(100) + }) + + cases := []struct { + desc string + topic string + clientID string + err error + handler messaging.MessageHandler + }{ + { + desc: "Subscribe to a topic with an ID", + topic: topic, + clientID: "clientid1", + err: nil, + handler: handler{false, "clientid1", msgChan}, + }, + { + desc: "Subscribe to the same topic with a different ID", + topic: topic, + clientID: "clientid2", + err: nil, + handler: handler{false, "clientid2", msgChan}, + }, + { + desc: "Subscribe to an already subscribed topic with an ID", + topic: topic, + clientID: "clientid1", + err: nil, + handler: handler{false, "clientid1", msgChan}, + }, + { + desc: "Subscribe to a topic with a subtopic with an ID", + topic: fmt.Sprintf("%s.%s", topic, subtopic), + clientID: "clientid1", + err: nil, + handler: handler{false, "clientid1", msgChan}, + }, + { + desc: "Subscribe to an already subscribed topic with a subtopic with an ID", + topic: fmt.Sprintf("%s.%s", topic, subtopic), + clientID: "clientid1", + err: nil, + handler: handler{false, "clientid1", msgChan}, + }, + { + desc: "Subscribe to an empty topic with an ID", + topic: "", + clientID: "clientid1", + err: mqtt_pubsub.ErrEmptyTopic, + handler: handler{false, "clientid1", msgChan}, + }, + { + desc: "Subscribe to a topic with empty id", + topic: topic, + clientID: "", + err: mqtt_pubsub.ErrEmptyID, + handler: handler{false, "", msgChan}, + }, + } + for _, tc := range cases { + err = pubsub.Subscribe(tc.clientID, tc.topic, tc.handler) + assert.Equal(t, err, tc.err, fmt.Sprintf("%s: expected: %s, but got: %s", tc.desc, err, tc.err)) + + if tc.err == nil { + expectedMsg := messaging.Message{ + Publisher: "clientID1", + Channel: channel, + Subtopic: subtopic, + Payload: data, + } + data, err := proto.Marshal(&expectedMsg) + assert.Nil(t, err, fmt.Sprintf("%s: failed to serialize protobuf error: %s\n", tc.desc, err)) + + token := client.Publish(tc.topic, qos, false, data) + token.WaitTimeout(tokenTimeout) + assert.Nil(t, token.Error(), fmt.Sprintf("got unexpected error: %s", token.Error())) + + receivedMsg := <-msgChan + assert.Equal(t, expectedMsg.Payload, receivedMsg.Payload, fmt.Sprintf("%s: expected %+v got %+v\n", tc.desc, expectedMsg, receivedMsg)) + } + } +} + +func TestPubSub(t *testing.T) { + msgChan := make(chan messaging.Message) + + cases := []struct { + desc string + topic string + clientID string + err error + handler messaging.MessageHandler + }{ + { + desc: "Subscribe to a topic with an ID", + topic: topic, + clientID: "clientid7", + err: nil, + handler: handler{false, "clientid7", msgChan}, + }, + { + desc: "Subscribe to the same topic with a different ID", + topic: topic, + clientID: "clientid8", + err: nil, + handler: handler{false, "clientid8", msgChan}, + }, + { + desc: "Subscribe to a topic with a subtopic with an ID", + topic: fmt.Sprintf("%s.%s", topic, subtopic), + clientID: "clientid7", + err: nil, + handler: handler{false, "clientid7", msgChan}, + }, + { + desc: "Subscribe to an empty topic with an ID", + topic: "", + clientID: "clientid7", + err: mqtt_pubsub.ErrEmptyTopic, + handler: handler{false, "clientid7", msgChan}, + }, + { + desc: "Subscribe to a topic with empty id", + topic: topic, + clientID: "", + err: mqtt_pubsub.ErrEmptyID, + handler: handler{false, "", msgChan}, + }, + } + for _, tc := range cases { + err := pubsub.Subscribe(tc.clientID, tc.topic, tc.handler) + assert.Equal(t, err, tc.err, fmt.Sprintf("%s: expected: %s, but got: %s", tc.desc, err, tc.err)) + + if tc.err == nil { + // Use pubsub to subscribe to a topic, and then publish messages to that topic. + expectedMsg := messaging.Message{ + Publisher: "clientID", + Channel: channel, + Subtopic: subtopic, + Payload: data, + } + + // Publish message, and then receive it on message channel. + err := pubsub.Publish(topic, expectedMsg) + assert.Nil(t, err, fmt.Sprintf("%s: got unexpected error: %s\n", tc.desc, err)) + + receivedMsg := <-msgChan + assert.Equal(t, expectedMsg, receivedMsg, fmt.Sprintf("%s: expected %+v got %+v\n", tc.desc, expectedMsg, receivedMsg)) + } + } +} + +func TestUnsubscribe(t *testing.T) { + msgChan := make(chan messaging.Message) + + cases := []struct { + desc string + topic string + clientID string + err error + subscribe bool // True for subscribe and false for unsubscribe. + handler messaging.MessageHandler + }{ + { + desc: "Subscribe to a topic with an ID", + topic: fmt.Sprintf("%s.%s", chansPrefix, topic), + clientID: "clientid4", + err: nil, + subscribe: true, + handler: handler{false, "clientid4", msgChan}, + }, + { + desc: "Subscribe to the same topic with a different ID", + topic: fmt.Sprintf("%s.%s", chansPrefix, topic), + clientID: "clientid9", + err: nil, + subscribe: true, + handler: handler{false, "clientid9", msgChan}, + }, + { + desc: "Unsubscribe from a topic with an ID", + topic: fmt.Sprintf("%s.%s", chansPrefix, topic), + clientID: "clientid4", + err: nil, + subscribe: false, + handler: handler{false, "clientid4", msgChan}, + }, + { + desc: "Unsubscribe from same topic with different ID", + topic: fmt.Sprintf("%s.%s", chansPrefix, topic), + clientID: "clientid9", + err: nil, + subscribe: false, + handler: handler{false, "clientid9", msgChan}, + }, + { + desc: "Unsubscribe from a non-existent topic with an ID", + topic: "h", + clientID: "clientid4", + err: mqtt_pubsub.ErrNotSubscribed, + subscribe: false, + handler: handler{false, "clientid4", msgChan}, + }, + { + desc: "Unsubscribe from an already unsubscribed topic with an ID", + topic: fmt.Sprintf("%s.%s", chansPrefix, topic), + clientID: "clientid4", + err: mqtt_pubsub.ErrNotSubscribed, + subscribe: false, + handler: handler{false, "clientid4", msgChan}, + }, + { + desc: "Subscribe to a topic with a subtopic with an ID", + topic: fmt.Sprintf("%s.%s.%s", chansPrefix, topic, subtopic), + clientID: "clientidd4", + err: nil, + subscribe: true, + handler: handler{false, "clientidd4", msgChan}, + }, + { + desc: "Unsubscribe from a topic with a subtopic with an ID", + topic: fmt.Sprintf("%s.%s.%s", chansPrefix, topic, subtopic), + clientID: "clientidd4", + err: nil, + subscribe: false, + handler: handler{false, "clientidd4", msgChan}, + }, + { + desc: "Unsubscribe from an already unsubscribed topic with a subtopic with an ID", + topic: fmt.Sprintf("%s.%s.%s", chansPrefix, topic, subtopic), + clientID: "clientid4", + err: mqtt_pubsub.ErrNotSubscribed, + subscribe: false, + handler: handler{false, "clientid4", msgChan}, + }, + { + desc: "Unsubscribe from an empty topic with an ID", + topic: "", + clientID: "clientid4", + err: mqtt_pubsub.ErrEmptyTopic, + subscribe: false, + handler: handler{false, "clientid4", msgChan}, + }, + { + desc: "Unsubscribe from a topic with empty ID", + topic: fmt.Sprintf("%s.%s", chansPrefix, topic), + clientID: "", + err: mqtt_pubsub.ErrEmptyID, + subscribe: false, + handler: handler{false, "", msgChan}, + }, + { + desc: "Subscribe to a new topic with an ID", + topic: fmt.Sprintf("%s.%s", chansPrefix, topic+"2"), + clientID: "clientid55", + err: nil, + subscribe: true, + handler: handler{true, "clientid5", msgChan}, + }, + { + desc: "Unsubscribe from a topic with an ID with failing handler", + topic: fmt.Sprintf("%s.%s", chansPrefix, topic+"2"), + clientID: "clientid55", + err: mqtt_pubsub.ErrFailedHandleMessage, + subscribe: false, + handler: handler{true, "clientid5", msgChan}, + }, + { + desc: "Subscribe to a new topic with subtopic with an ID", + topic: fmt.Sprintf("%s.%s.%s", chansPrefix, topic+"2", subtopic), + clientID: "clientid55", + err: nil, + subscribe: true, + handler: handler{true, "clientid5", msgChan}, + }, + { + desc: "Unsubscribe from a topic with subtopic with an ID with failing handler", + topic: fmt.Sprintf("%s.%s.%s", chansPrefix, topic+"2", subtopic), + clientID: "clientid55", + err: mqtt_pubsub.ErrFailedHandleMessage, + subscribe: false, + handler: handler{true, "clientid5", msgChan}, + }, + } + for _, tc := range cases { + switch tc.subscribe { + case true: + err := pubsub.Subscribe(tc.clientID, tc.topic, tc.handler) + assert.Equal(t, tc.err, err, fmt.Sprintf("%s: expected: %s, but got: %s", tc.desc, tc.err, err)) + default: + err := pubsub.Unsubscribe(tc.clientID, tc.topic) + assert.Equal(t, tc.err, err, fmt.Sprintf("%s: expected: %s, but got: %s", tc.desc, tc.err, err)) + } + } +} + +type handler struct { + fail bool + publisher string + msgChan chan messaging.Message +} + +func (h handler) Handle(msg messaging.Message) error { + if msg.Publisher != h.publisher { + h.msgChan <- msg + } + return nil +} + +func (h handler) Cancel() error { + if h.fail { + return mqtt_pubsub.ErrFailedHandleMessage + } + return nil +} diff --git a/pkg/messaging/mqtt/setup_test.go b/pkg/messaging/mqtt/setup_test.go new file mode 100644 index 00000000..f1757670 --- /dev/null +++ b/pkg/messaging/mqtt/setup_test.go @@ -0,0 +1,115 @@ +// Copyright (c) Mainflux +// SPDX-License-Identifier: Apache-2.0 + +package mqtt_test + +import ( + "fmt" + "log" + "os" + "os/signal" + "syscall" + "testing" + "time" + + mqtt "github.com/eclipse/paho.mqtt.golang" + mainflux_log "github.com/mainflux/mainflux/logger" + "github.com/mainflux/mainflux/pkg/messaging" + mqtt_pubsub "github.com/mainflux/mainflux/pkg/messaging/mqtt" + "github.com/ory/dockertest/v3" +) + +var ( + pubsub messaging.PubSub + logger mainflux_log.Logger + address string +) + +const ( + username = "mainflux-mqtt" + qos = 2 + port = "1883/tcp" + broker = "eclipse-mosquitto" + brokerVersion = "1.6.13" + brokerTimeout = 30 * time.Second + poolMaxWait = 120 * time.Second +) + +func TestMain(m *testing.M) { + pool, err := dockertest.NewPool("") + if err != nil { + log.Fatalf("Could not connect to docker: %s", err) + } + + container, err := pool.Run(broker, brokerVersion, nil) + if err != nil { + log.Fatalf("Could not start container: %s", err) + } + + handleInterrupt(m, pool, container) + + address = fmt.Sprintf("%s:%s", "localhost", container.GetPort(port)) + pool.MaxWait = poolMaxWait + + logger, err = mainflux_log.New(os.Stdout, mainflux_log.Debug.String()) + if err != nil { + log.Fatalf(err.Error()) + } + + if err := pool.Retry(func() error { + pubsub, err = mqtt_pubsub.NewPubSub(address, "mainflux", brokerTimeout, logger) + return err + }); err != nil { + log.Fatalf("Could not connect to docker: %s", err) + } + + code := m.Run() + if err := pool.Purge(container); err != nil { + log.Fatalf("Could not purge container: %s", err) + } + + os.Exit(code) + + defer func() { + err = pubsub.Close() + if err != nil { + log.Fatalf(err.Error()) + } + }() +} + +func handleInterrupt(m *testing.M, pool *dockertest.Pool, container *dockertest.Resource) { + c := make(chan os.Signal, 2) + signal.Notify(c, os.Interrupt, syscall.SIGTERM) + go func() { + <-c + if err := pool.Purge(container); err != nil { + log.Fatalf("Could not purge container: %s", err) + } + os.Exit(0) + }() +} + +func newClient(address, id string, timeout time.Duration) (mqtt.Client, error) { + opts := mqtt.NewClientOptions(). + SetUsername(username). + AddBroker(address). + SetClientID(id) + + client := mqtt.NewClient(opts) + token := client.Connect() + if token.Error() != nil { + return nil, token.Error() + } + + ok := token.WaitTimeout(timeout) + if !ok { + return nil, mqtt_pubsub.ErrConnect + } + + if token.Error() != nil { + return nil, token.Error() + } + + return client, nil +}