diff --git a/bootstrap/mocks/things.go b/bootstrap/mocks/things.go index 893d98e8..02ff7c75 100644 --- a/bootstrap/mocks/things.go +++ b/bootstrap/mocks/things.go @@ -90,35 +90,38 @@ func (svc *mainfluxThings) Connect(_ context.Context, owner string, chIDs, thIDs return nil } -func (svc *mainfluxThings) Disconnect(_ context.Context, owner, chanID, thingID string) error { +func (svc *mainfluxThings) Disconnect(_ context.Context, owner string, chIDs, thIDs []string) error { svc.mu.Lock() defer svc.mu.Unlock() userID, err := svc.auth.Identify(context.Background(), &mainflux.Token{Value: owner}) - if err != nil || svc.channels[chanID].Owner != userID.Email { + if err != nil { return things.ErrUnauthorizedAccess } - ids := svc.connections[chanID] - i := 0 - for _, t := range ids { - if t == thingID { - break + for _, chID := range chIDs { + if svc.channels[chID].Owner != userID.Email { + return things.ErrUnauthorizedAccess } - i++ - } - if i == len(ids) { - return things.ErrNotFound - } + ids := svc.connections[chID] + var count int + var newConns []string + for _, thID := range thIDs { + for _, id := range ids { + if id == thID { + count++ + continue + } + newConns = append(newConns, id) + } - var tmp []string - if i != len(ids)-2 { - tmp = ids[i+1:] + if len(newConns)-len(ids) != count { + return things.ErrNotFound + } + svc.connections[chID] = newConns + } } - ids = append(ids[:i], tmp...) - svc.connections[chanID] = ids - return nil } diff --git a/things/api/logging.go b/things/api/logging.go index 0952e044..84acfbaf 100644 --- a/things/api/logging.go +++ b/things/api/logging.go @@ -214,9 +214,9 @@ func (lm *loggingMiddleware) Connect(ctx context.Context, token string, chIDs, t return lm.svc.Connect(ctx, token, chIDs, thIDs) } -func (lm *loggingMiddleware) Disconnect(ctx context.Context, token, chanID, thingID string) (err error) { +func (lm *loggingMiddleware) Disconnect(ctx context.Context, token string, chIDs, thIDs []string) (err error) { defer func(begin time.Time) { - message := fmt.Sprintf("Method disconnect for token %s, channel %s and thing %s took %s to complete", token, chanID, thingID, time.Since(begin)) + message := fmt.Sprintf("Method disconnect for token %s, channels %v and things %v took %s to complete", token, chIDs, thIDs, time.Since(begin)) if err != nil { lm.logger.Warn(fmt.Sprintf("%s with error: %s.", message, err)) return @@ -224,7 +224,7 @@ func (lm *loggingMiddleware) Disconnect(ctx context.Context, token, chanID, thin lm.logger.Info(fmt.Sprintf("%s without errors.", message)) }(time.Now()) - return lm.svc.Disconnect(ctx, token, chanID, thingID) + return lm.svc.Disconnect(ctx, token, chIDs, thIDs) } func (lm *loggingMiddleware) CanAccessByKey(ctx context.Context, id, key string) (thing string, err error) { @@ -289,5 +289,5 @@ func (lm *loggingMiddleware) ListMembers(ctx context.Context, token, groupID str lm.logger.Info(fmt.Sprintf("%s without errors.", message)) }(time.Now()) - return lm.svc.ListMembers(ctx, token, groupID, pm) + return lm.svc.ListMembers(ctx, token, groupID, pm) } diff --git a/things/api/metrics.go b/things/api/metrics.go index e1204082..0f1aabc7 100644 --- a/things/api/metrics.go +++ b/things/api/metrics.go @@ -156,13 +156,13 @@ func (ms *metricsMiddleware) Connect(ctx context.Context, token string, chIDs, t return ms.svc.Connect(ctx, token, chIDs, thIDs) } -func (ms *metricsMiddleware) Disconnect(ctx context.Context, token, chanID, thingID string) error { +func (ms *metricsMiddleware) Disconnect(ctx context.Context, token string, chIDs, thIDs []string) error { defer func(begin time.Time) { ms.counter.With("method", "disconnect").Add(1) ms.latency.With("method", "disconnect").Observe(time.Since(begin).Seconds()) }(time.Now()) - return ms.svc.Disconnect(ctx, token, chanID, thingID) + return ms.svc.Disconnect(ctx, token, chIDs, thIDs) } func (ms *metricsMiddleware) CanAccessByKey(ctx context.Context, id, key string) (string, error) { @@ -207,5 +207,5 @@ func (ms *metricsMiddleware) ListMembers(ctx context.Context, token, groupID str ms.latency.With("method", "list_members").Observe(time.Since(begin).Seconds()) }(time.Now()) - return ms.svc.ListMembers(ctx, token, groupID, pm) + return ms.svc.ListMembers(ctx, token, groupID, pm) } diff --git a/things/api/things/http/endpoint.go b/things/api/things/http/endpoint.go index 990527fa..f38bf6df 100644 --- a/things/api/things/http/endpoint.go +++ b/things/api/things/http/endpoint.go @@ -444,9 +444,9 @@ func removeChannelEndpoint(svc things.Service) endpoint.Endpoint { } } -func connectEndpoint(svc things.Service) endpoint.Endpoint { +func connectThingEndpoint(svc things.Service) endpoint.Endpoint { return func(ctx context.Context, request interface{}) (interface{}, error) { - cr := request.(connectionReq) + cr := request.(connectThingReq) if err := cr.validate(); err != nil { return nil, err @@ -456,13 +456,13 @@ func connectEndpoint(svc things.Service) endpoint.Endpoint { return nil, err } - return connectionRes{}, nil + return connectThingRes{}, nil } } -func createConnectionsEndpoint(svc things.Service) endpoint.Endpoint { +func connectEndpoint(svc things.Service) endpoint.Endpoint { return func(ctx context.Context, request interface{}) (interface{}, error) { - cr := request.(createConnectionsReq) + cr := request.(connectReq) if err := cr.validate(); err != nil { return nil, err @@ -472,23 +472,38 @@ func createConnectionsEndpoint(svc things.Service) endpoint.Endpoint { return nil, err } - return createConnectionsRes{}, nil + return connectRes{}, nil } } func disconnectEndpoint(svc things.Service) endpoint.Endpoint { return func(ctx context.Context, request interface{}) (interface{}, error) { - cr := request.(connectionReq) - + cr := request.(connectReq) if err := cr.validate(); err != nil { return nil, err } - if err := svc.Disconnect(ctx, cr.token, cr.chanID, cr.thingID); err != nil { + if err := svc.Disconnect(ctx, cr.token, cr.ChannelIDs, cr.ThingIDs); err != nil { return nil, err } - return disconnectionRes{}, nil + return disconnectRes{}, nil + } +} + +func disconnectThingEndpoint(svc things.Service) endpoint.Endpoint { + return func(ctx context.Context, request interface{}) (interface{}, error) { + req := request.(connectThingReq) + + if err := req.validate(); err != nil { + return nil, err + } + + if err := svc.Disconnect(ctx, req.token, []string{req.chanID}, []string{req.thingID}); err != nil { + return nil, err + } + + return disconnectThingRes{}, nil } } diff --git a/things/api/things/http/endpoint_test.go b/things/api/things/http/endpoint_test.go index 4b01a78b..2f9bf3e1 100644 --- a/things/api/things/http/endpoint_test.go +++ b/things/api/things/http/endpoint_test.go @@ -2407,6 +2407,198 @@ func TestCreateConnections(t *testing.T) { } } +func TestDisconnectList(t *testing.T) { + otherToken := "other_token" + otherEmail := "other_user@example.com" + svc := newService(map[string]string{ + token: email, + otherToken: otherEmail, + }) + ts := newServer(svc) + defer ts.Close() + + ths, err := svc.CreateThings(context.Background(), token, thing) + require.Nil(t, err, fmt.Sprintf("unexpected error: %s\n", err)) + thIDs := []string{} + for _, th := range ths { + thIDs = append(thIDs, th.ID) + } + + chs, err := svc.CreateChannels(context.Background(), token, channel) + require.Nil(t, err, fmt.Sprintf("unexpected error: %s\n", err)) + chIDs1 := []string{} + for _, ch := range chs { + chIDs1 = append(chIDs1, ch.ID) + } + + chs, err = svc.CreateChannels(context.Background(), otherToken, channel) + require.Nil(t, err, fmt.Sprintf("unexpected error: %s\n", err)) + chIDs2 := []string{} + for _, ch := range chs { + chIDs2 = append(chIDs2, ch.ID) + } + + err = svc.Connect(context.Background(), token, chIDs1, thIDs) + require.Nil(t, err, fmt.Sprintf("unexpected error: %s\n", err)) + + cases := []struct { + desc string + channelIDs []string + thingIDs []string + auth string + contentType string + body string + status int + }{ + { + desc: "disconnect existing things from existing channels", + channelIDs: chIDs1, + thingIDs: thIDs, + auth: token, + contentType: contentType, + status: http.StatusOK, + }, + { + desc: "disconnect existing things from non-existent channels", + channelIDs: []string{strconv.FormatUint(wrongID, 10)}, + thingIDs: thIDs, + auth: token, + contentType: contentType, + status: http.StatusNotFound, + }, + { + desc: "disconnect non-existing things from existing channels", + channelIDs: chIDs1, + thingIDs: []string{strconv.FormatUint(wrongID, 10)}, + auth: token, + contentType: contentType, + status: http.StatusNotFound, + }, + { + desc: "disconnect existing things from channel with invalid id", + channelIDs: []string{"invalid"}, + thingIDs: thIDs, + auth: token, + contentType: contentType, + status: http.StatusNotFound, + }, + { + desc: "disconnect things with invalid id from existing channels", + channelIDs: chIDs1, + thingIDs: []string{"invalid"}, + auth: token, + contentType: contentType, + status: http.StatusNotFound, + }, + { + desc: "disconnect existing things from empty channel ids", + channelIDs: []string{""}, + thingIDs: thIDs, + auth: token, + contentType: contentType, + status: http.StatusBadRequest, + }, + { + desc: "disconnect empty things id from existing channels", + channelIDs: chIDs1, + thingIDs: []string{""}, + auth: token, + contentType: contentType, + status: http.StatusBadRequest, + }, + { + desc: "disconnect existing things from existing channels with invalid token", + channelIDs: chIDs1, + thingIDs: thIDs, + auth: wrongValue, + contentType: contentType, + status: http.StatusUnauthorized, + }, + { + desc: "disconnect existing things from existing channels with empty token", + channelIDs: chIDs1, + thingIDs: thIDs, + auth: "", + contentType: contentType, + status: http.StatusUnauthorized, + }, + { + desc: "disconnect things from channels of other user", + channelIDs: chIDs2, + thingIDs: thIDs, + auth: token, + contentType: contentType, + status: http.StatusNotFound, + }, + { + desc: "disconnect with invalid content type", + channelIDs: chIDs2, + thingIDs: thIDs, + auth: token, + contentType: "invalid", + status: http.StatusUnsupportedMediaType, + }, + { + desc: "disconnect with invalid JSON", + auth: token, + contentType: contentType, + status: http.StatusBadRequest, + body: "{", + }, + { + desc: "disconnect valid thing ids from empty channel ids", + channelIDs: []string{}, + thingIDs: thIDs, + auth: token, + contentType: contentType, + status: http.StatusBadRequest, + }, + { + desc: "disconnect empty thing ids from valid channel ids", + channelIDs: chIDs1, + thingIDs: []string{}, + auth: token, + contentType: contentType, + status: http.StatusBadRequest, + }, + { + desc: "disconnect empty thing ids from empty channel ids", + channelIDs: []string{}, + thingIDs: []string{}, + auth: token, + contentType: contentType, + status: http.StatusBadRequest, + }, + } + + for _, tc := range cases { + data := struct { + ChannelIDs []string `json:"channel_ids"` + ThingIDs []string `json:"thing_ids"` + }{ + tc.channelIDs, + tc.thingIDs, + } + body := toJSON(data) + + if tc.body != "" { + body = tc.body + } + + req := testRequest{ + client: ts.Client(), + method: http.MethodDelete, + url: fmt.Sprintf("%s/disconnect", ts.URL), + contentType: tc.contentType, + token: tc.auth, + body: strings.NewReader(body), + } + res, err := req.make() + assert.Nil(t, err, fmt.Sprintf("%s: unexpected error %s", tc.desc, err)) + assert.Equal(t, tc.status, res.StatusCode, fmt.Sprintf("%s: expected status code %d got %d", tc.desc, tc.status, res.StatusCode)) + } +} + func TestDisconnnect(t *testing.T) { otherToken := "other_token" otherEmail := "other_user@example.com" diff --git a/things/api/things/http/requests.go b/things/api/things/http/requests.go index 44596498..745bb06e 100644 --- a/things/api/things/http/requests.go +++ b/things/api/things/http/requests.go @@ -248,13 +248,13 @@ func (req listByConnectionReq) validate() error { return nil } -type connectionReq struct { +type connectThingReq struct { token string chanID string thingID string } -func (req connectionReq) validate() error { +func (req connectThingReq) validate() error { if req.token == "" { return things.ErrUnauthorizedAccess } @@ -266,13 +266,13 @@ func (req connectionReq) validate() error { return nil } -type createConnectionsReq struct { +type connectReq struct { token string ChannelIDs []string `json:"channel_ids,omitempty"` ThingIDs []string `json:"thing_ids,omitempty"` } -func (req createConnectionsReq) validate() error { +func (req connectReq) validate() error { if req.token == "" { return things.ErrUnauthorizedAccess } diff --git a/things/api/things/http/responses.go b/things/api/things/http/responses.go index 3228dfa3..0a7309d6 100644 --- a/things/api/things/http/responses.go +++ b/things/api/things/http/responses.go @@ -18,8 +18,10 @@ var ( _ mainflux.Response = (*channelRes)(nil) _ mainflux.Response = (*viewChannelRes)(nil) _ mainflux.Response = (*channelsPageRes)(nil) - _ mainflux.Response = (*connectionRes)(nil) - _ mainflux.Response = (*disconnectionRes)(nil) + _ mainflux.Response = (*connectThingRes)(nil) + _ mainflux.Response = (*connectRes)(nil) + _ mainflux.Response = (*disconnectThingRes)(nil) + _ mainflux.Response = (*disconnectRes)(nil) ) type removeRes struct{} @@ -213,47 +215,61 @@ func (res channelsPageRes) Empty() bool { return false } -type connectionRes struct{} +type connectThingRes struct{} -func (res connectionRes) Code() int { +func (res connectThingRes) Code() int { return http.StatusOK } -func (res connectionRes) Headers() map[string]string { +func (res connectThingRes) Headers() map[string]string { return map[string]string{ "Warning-Deprecated": "This endpoint will be depreciated in v1.0.0. It will be replaced with the bulk endpoint found at /connect.", } } -func (res connectionRes) Empty() bool { +func (res connectThingRes) Empty() bool { return true } -type createConnectionsRes struct{} +type connectRes struct{} -func (res createConnectionsRes) Code() int { +func (res connectRes) Code() int { return http.StatusOK } -func (res createConnectionsRes) Headers() map[string]string { +func (res connectRes) Headers() map[string]string { return map[string]string{} } -func (res createConnectionsRes) Empty() bool { +func (res connectRes) Empty() bool { return true } -type disconnectionRes struct{} +type disconnectRes struct{} -func (res disconnectionRes) Code() int { - return http.StatusNoContent +func (res disconnectRes) Code() int { + return http.StatusOK } -func (res disconnectionRes) Headers() map[string]string { +func (res disconnectRes) Headers() map[string]string { return map[string]string{} } -func (res disconnectionRes) Empty() bool { +func (res disconnectRes) Empty() bool { + return true +} + +type disconnectThingRes struct{} + +func (res disconnectThingRes) Code() int { + return http.StatusNoContent +} + +func (res disconnectThingRes) Headers() map[string]string { + return map[string]string{} +} + +func (res disconnectThingRes) Empty() bool { return true } diff --git a/things/api/things/http/transport.go b/things/api/things/http/transport.go index a9324408..e3211ee8 100644 --- a/things/api/things/http/transport.go +++ b/things/api/things/http/transport.go @@ -155,23 +155,30 @@ func MakeHandler(tracer opentracing.Tracer, svc things.Service) http.Handler { opts..., )) - r.Put("/channels/:chanId/things/:thingId", kithttp.NewServer( + r.Post("/connect", kithttp.NewServer( kitot.TraceServer(tracer, "connect")(connectEndpoint(svc)), - decodeConnection, + decodeConnectList, encodeResponse, opts..., )) - r.Post("/connect", kithttp.NewServer( - kitot.TraceServer(tracer, "create_connections")(createConnectionsEndpoint(svc)), - decodeCreateConnections, + r.Delete("/disconnect", kithttp.NewServer( + kitot.TraceServer(tracer, "disconnect")(disconnectEndpoint(svc)), + decodeConnectList, + encodeResponse, + opts..., + )) + + r.Put("/channels/:chanId/things/:thingId", kithttp.NewServer( + kitot.TraceServer(tracer, "connect_thing")(connectThingEndpoint(svc)), + decodeConnectThing, encodeResponse, opts..., )) r.Delete("/channels/:chanId/things/:thingId", kithttp.NewServer( - kitot.TraceServer(tracer, "disconnect")(disconnectEndpoint(svc)), - decodeConnection, + kitot.TraceServer(tracer, "disconnect_thing")(disconnectThingEndpoint(svc)), + decodeConnectThing, encodeResponse, opts..., )) @@ -395,8 +402,8 @@ func decodeListByConnection(_ context.Context, r *http.Request) (interface{}, er return req, nil } -func decodeConnection(_ context.Context, r *http.Request) (interface{}, error) { - req := connectionReq{ +func decodeConnectThing(_ context.Context, r *http.Request) (interface{}, error) { + req := connectThingReq{ token: r.Header.Get("Authorization"), chanID: bone.GetValue(r, "chanId"), thingID: bone.GetValue(r, "thingId"), @@ -405,12 +412,12 @@ func decodeConnection(_ context.Context, r *http.Request) (interface{}, error) { return req, nil } -func decodeCreateConnections(_ context.Context, r *http.Request) (interface{}, error) { +func decodeConnectList(_ context.Context, r *http.Request) (interface{}, error) { if !strings.Contains(r.Header.Get("Content-Type"), contentType) { return nil, errors.ErrUnsupportedContentType } - req := createConnectionsReq{token: r.Header.Get("Authorization")} + req := connectReq{token: r.Header.Get("Authorization")} if err := json.NewDecoder(r.Body).Decode(&req); err != nil { return nil, errors.Wrap(things.ErrMalformedEntity, err) } diff --git a/things/channels.go b/things/channels.go index 72e57483..a8ca01ab 100644 --- a/things/channels.go +++ b/things/channels.go @@ -8,7 +8,7 @@ import ( ) // Channel represents a Mainflux "communication group". This group contains the -// things that can exchange messages between eachother. +// things that can exchange messages between each other. type Channel struct { ID string Owner string @@ -49,12 +49,12 @@ type ChannelRepository interface { // by the specified user. Remove(ctx context.Context, owner, id string) error - // Connect adds things to the channel's list of connected things. + // Connect adds things to the channels list of connected things. Connect(ctx context.Context, owner string, chIDs, thIDs []string) error - // Disconnect removes thing from the channel's list of connected + // Disconnect removes things from the channels list of connected // things. - Disconnect(ctx context.Context, owner, chanID, thingID string) error + Disconnect(ctx context.Context, owner string, chIDs, thIDs []string) error // HasThing determines whether the thing with the provided access key, is // "connected" to the specified channel. If that's the case, it returns diff --git a/things/mocks/channels.go b/things/mocks/channels.go index 01bc52fe..f459eb5f 100644 --- a/things/mocks/channels.go +++ b/things/mocks/channels.go @@ -27,7 +27,7 @@ type channelRepositoryMock struct { mu sync.Mutex counter uint64 channels map[string]things.Channel - tconns chan Connection // used for syncronization with thing repo + tconns chan Connection // used for synchronization with thing repo cconns map[string]map[string]things.Channel // used to track connections things things.ThingRepository } @@ -216,21 +216,26 @@ func (crm *channelRepositoryMock) Connect(_ context.Context, owner string, chIDs return nil } -func (crm *channelRepositoryMock) Disconnect(_ context.Context, owner, chanID, thingID string) error { - if _, ok := crm.cconns[thingID]; !ok { - return things.ErrNotFound +func (crm *channelRepositoryMock) Disconnect(_ context.Context, owner string, chIDs, thIDs []string) error { + for _, chID := range chIDs { + for _, thID := range thIDs { + if _, ok := crm.cconns[thID]; !ok { + return things.ErrNotFound + } + + if _, ok := crm.cconns[thID][chID]; !ok { + return things.ErrNotFound + } + + crm.tconns <- Connection{ + chanID: chID, + thing: things.Thing{ID: thID, Owner: owner}, + connected: false, + } + delete(crm.cconns[thID], chID) + } } - if _, ok := crm.cconns[thingID][chanID]; !ok { - return things.ErrNotFound - } - - crm.tconns <- Connection{ - chanID: chanID, - thing: things.Thing{ID: thingID, Owner: owner}, - connected: false, - } - delete(crm.cconns[thingID], chanID) return nil } diff --git a/things/postgres/channels.go b/things/postgres/channels.go index 8c1450a6..a582a7e1 100644 --- a/things/postgres/channels.go +++ b/things/postgres/channels.go @@ -314,29 +314,52 @@ func (cr channelRepository) Connect(ctx context.Context, owner string, chIDs, th return nil } -func (cr channelRepository) Disconnect(ctx context.Context, owner, chanID, thingID string) error { +func (cr channelRepository) Disconnect(ctx context.Context, owner string, chIDs, thIDs []string) error { + tx, err := cr.db.BeginTxx(ctx, nil) + if err != nil { + return errors.Wrap(things.ErrConnect, err) + } + q := `DELETE FROM connections WHERE channel_id = :channel AND channel_owner = :owner AND thing_id = :thing AND thing_owner = :owner` - conn := dbConnection{ - Channel: chanID, - Thing: thingID, - Owner: owner, + for _, chID := range chIDs { + for _, thID := range thIDs { + dbco := dbConnection{ + Channel: chID, + Thing: thID, + Owner: owner, + } + + res, err := tx.NamedExecContext(ctx, q, dbco) + if err != nil { + tx.Rollback() + pqErr, ok := err.(*pq.Error) + if ok { + switch pqErr.Code.Name() { + case errFK: + return things.ErrNotFound + case errDuplicate: + return things.ErrConflict + } + } + return errors.Wrap(things.ErrDisconnect, err) + } + + cnt, err := res.RowsAffected() + if err != nil { + return errors.Wrap(things.ErrDisconnect, err) + } + + if cnt == 0 { + return things.ErrNotFound + } + } } - res, err := cr.db.NamedExecContext(ctx, q, conn) - if err != nil { - return errors.Wrap(things.ErrDisconnect, err) - } - - cnt, err := res.RowsAffected() - if err != nil { - return errors.Wrap(things.ErrDisconnect, err) - } - - if cnt == 0 { - return things.ErrNotFound + if err = tx.Commit(); err != nil { + return errors.Wrap(things.ErrConnect, err) } return nil diff --git a/things/postgres/channels_test.go b/things/postgres/channels_test.go index 0b0843b5..3fec61b2 100644 --- a/things/postgres/channels_test.go +++ b/things/postgres/channels_test.go @@ -723,7 +723,7 @@ func TestDisconnect(t *testing.T) { } for _, tc := range cases { - err := chanRepo.Disconnect(context.Background(), tc.owner, tc.chID, tc.thID) + err := chanRepo.Disconnect(context.Background(), tc.owner, []string{tc.chID}, []string{tc.thID}) assert.True(t, errors.Contains(err, tc.err), fmt.Sprintf("%s: expected %s got %s\n", tc.desc, tc.err, err)) } } diff --git a/things/redis/streams.go b/things/redis/streams.go index 7f35f0c0..675146f8 100644 --- a/things/redis/streams.go +++ b/things/redis/streams.go @@ -209,21 +209,25 @@ func (es eventStore) Connect(ctx context.Context, token string, chIDs, thIDs []s return nil } -func (es eventStore) Disconnect(ctx context.Context, token, chanID, thingID string) error { - if err := es.svc.Disconnect(ctx, token, chanID, thingID); err != nil { +func (es eventStore) Disconnect(ctx context.Context, token string, chIDs, thIDs []string) error { + if err := es.svc.Disconnect(ctx, token, chIDs, thIDs); err != nil { return err } - event := disconnectThingEvent{ - chanID: chanID, - thingID: thingID, + for _, chID := range chIDs { + for _, thID := range thIDs { + event := disconnectThingEvent{ + chanID: chID, + thingID: thID, + } + record := &redis.XAddArgs{ + Stream: streamID, + MaxLenApprox: streamLen, + Values: event.Encode(), + } + es.client.XAdd(ctx, record).Err() + } } - record := &redis.XAddArgs{ - Stream: streamID, - MaxLenApprox: streamLen, - Values: event.Encode(), - } - es.client.XAdd(ctx, record).Err() return nil } diff --git a/things/redis/streams_test.go b/things/redis/streams_test.go index 98cc4287..a5d49d8e 100644 --- a/things/redis/streams_test.go +++ b/things/redis/streams_test.go @@ -629,7 +629,7 @@ func TestDisconnectEvent(t *testing.T) { lastID := "0" for _, tc := range cases { - err := svc.Disconnect(context.Background(), tc.key, tc.chanID, tc.thingID) + err := svc.Disconnect(context.Background(), tc.key, []string{tc.chanID}, []string{tc.thingID}) assert.Equal(t, tc.err, err, fmt.Sprintf("%s: expected %s got %s\n", tc.desc, tc.err, err)) streams := redisClient.XRead(context.Background(), &r.XReadArgs{ diff --git a/things/service.go b/things/service.go index 17e3e725..ab1a6b00 100644 --- a/things/service.go +++ b/things/service.go @@ -97,12 +97,12 @@ type Service interface { // belongs to the user identified by the provided key. RemoveChannel(ctx context.Context, token, id string) error - // Connect adds things to the channel's list of connected things. + // Connect adds things to the channels list of connected things. Connect(ctx context.Context, token string, chIDs, thIDs []string) error - // Disconnect removes thing from the channel's list of connected + // Disconnect removes things from the channels list of connected // things. - Disconnect(ctx context.Context, token, chanID, thingID string) error + Disconnect(ctx context.Context, token string, chIDs, thIDs []string) error // CanAccessByKey determines whether the channel can be accessed using the // provided key and returns thing's id if access is allowed. @@ -323,17 +323,21 @@ func (ts *thingsService) Connect(ctx context.Context, token string, chIDs, thIDs return ts.channels.Connect(ctx, res.GetEmail(), chIDs, thIDs) } -func (ts *thingsService) Disconnect(ctx context.Context, token, chanID, thingID string) error { +func (ts *thingsService) Disconnect(ctx context.Context, token string, chIDs, thIDs []string) error { res, err := ts.auth.Identify(ctx, &mainflux.Token{Value: token}) if err != nil { return errors.Wrap(ErrUnauthorizedAccess, err) } - if err := ts.channelCache.Disconnect(ctx, chanID, thingID); err != nil { - return err + for _, chID := range chIDs { + for _, thID := range thIDs { + if err := ts.channelCache.Disconnect(ctx, chID, thID); err != nil { + return err + } + } } - return ts.channels.Disconnect(ctx, res.GetEmail(), chanID, thingID) + return ts.channels.Disconnect(ctx, res.GetEmail(), chIDs, thIDs) } func (ts *thingsService) CanAccessByKey(ctx context.Context, chanID, thingKey string) (string, error) { diff --git a/things/service_test.go b/things/service_test.go index 6a988553..8a77c6db 100644 --- a/things/service_test.go +++ b/things/service_test.go @@ -1121,7 +1121,7 @@ func TestDisconnect(t *testing.T) { } for _, tc := range cases { - err := svc.Disconnect(context.Background(), tc.token, tc.chanID, tc.thingID) + err := svc.Disconnect(context.Background(), tc.token, []string{tc.chanID}, []string{tc.thingID}) assert.True(t, errors.Contains(err, tc.err), fmt.Sprintf("%s: expected %s got %s\n", tc.desc, tc.err, err)) } diff --git a/things/things.go b/things/things.go index 730a9f68..66b3cc8d 100644 --- a/things/things.go +++ b/things/things.go @@ -30,7 +30,7 @@ var ( ErrEntityConnected = errors.New("check thing-channel connection in database error") ) -// Metadata to be used for mainflux thing or channel for customized +// Metadata to be used for Mainflux thing or channel for customized // describing of particular thing or channel. type Metadata map[string]interface{} diff --git a/things/tracing/channels.go b/things/tracing/channels.go index 5e3555d0..6cf4e8e6 100644 --- a/things/tracing/channels.go +++ b/things/tracing/channels.go @@ -98,12 +98,12 @@ func (crm channelRepositoryMiddleware) Connect(ctx context.Context, owner string return crm.repo.Connect(ctx, owner, chIDs, thIDs) } -func (crm channelRepositoryMiddleware) Disconnect(ctx context.Context, owner, chanID, thingID string) error { +func (crm channelRepositoryMiddleware) Disconnect(ctx context.Context, owner string, chIDs, thIDs []string) error { span := createSpan(ctx, crm.tracer, disconnectOp) defer span.Finish() ctx = opentracing.ContextWithSpan(ctx, span) - return crm.repo.Disconnect(ctx, owner, chanID, thingID) + return crm.repo.Disconnect(ctx, owner, chIDs, thIDs) } func (crm channelRepositoryMiddleware) HasThing(ctx context.Context, chanID, key string) (string, error) {