1
0
mirror of https://github.com/mainflux/mainflux.git synced 2025-04-28 13:48:49 +08:00

MF-1389 - Add /disconnect endpoint in Things service (#1433)

* MF-1389 - Add /disconnect endpoint in Things service

Signed-off-by: Burak Sekili <buraksekili@gmail.com>

* Add db transaction in Postgres' Disconnect

Signed-off-by: Burak Sekili <buraksekili@gmail.com>

* Reformat things mock and things http api

Signed-off-by: Burak Sekili <buraksekili@gmail.com>

* Update naming of /disconnect endpoint decoder

Signed-off-by: Burak Sekili <buraksekili@gmail.com>

* Update naming for /connect endpoint

Signed-off-by: Burak Sekili <buraksekili@gmail.com>

Update naming based on new endpoint

Signed-off-by: Burak Sekili <buraksekili@gmail.com>

* Update disconnect endpoint test case descriptions

Signed-off-by: Burak Sekili <buraksekili@gmail.com>
This commit is contained in:
Burak Sekili 2021-07-10 01:59:12 +03:00 committed by GitHub
parent bb072b8ad2
commit 2cfff01979
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
18 changed files with 393 additions and 124 deletions

View File

@ -90,35 +90,38 @@ func (svc *mainfluxThings) Connect(_ context.Context, owner string, chIDs, thIDs
return nil return nil
} }
func (svc *mainfluxThings) Disconnect(_ context.Context, owner, chanID, thingID string) error { func (svc *mainfluxThings) Disconnect(_ context.Context, owner string, chIDs, thIDs []string) error {
svc.mu.Lock() svc.mu.Lock()
defer svc.mu.Unlock() defer svc.mu.Unlock()
userID, err := svc.auth.Identify(context.Background(), &mainflux.Token{Value: owner}) userID, err := svc.auth.Identify(context.Background(), &mainflux.Token{Value: owner})
if err != nil || svc.channels[chanID].Owner != userID.Email { if err != nil {
return things.ErrUnauthorizedAccess return things.ErrUnauthorizedAccess
} }
ids := svc.connections[chanID] for _, chID := range chIDs {
i := 0 if svc.channels[chID].Owner != userID.Email {
for _, t := range ids { return things.ErrUnauthorizedAccess
if t == thingID {
break
} }
i++
}
if i == len(ids) { ids := svc.connections[chID]
return things.ErrNotFound var count int
} var newConns []string
for _, thID := range thIDs {
for _, id := range ids {
if id == thID {
count++
continue
}
newConns = append(newConns, id)
}
var tmp []string if len(newConns)-len(ids) != count {
if i != len(ids)-2 { return things.ErrNotFound
tmp = ids[i+1:] }
svc.connections[chID] = newConns
}
} }
ids = append(ids[:i], tmp...)
svc.connections[chanID] = ids
return nil return nil
} }

View File

@ -214,9 +214,9 @@ func (lm *loggingMiddleware) Connect(ctx context.Context, token string, chIDs, t
return lm.svc.Connect(ctx, token, chIDs, thIDs) return lm.svc.Connect(ctx, token, chIDs, thIDs)
} }
func (lm *loggingMiddleware) Disconnect(ctx context.Context, token, chanID, thingID string) (err error) { func (lm *loggingMiddleware) Disconnect(ctx context.Context, token string, chIDs, thIDs []string) (err error) {
defer func(begin time.Time) { defer func(begin time.Time) {
message := fmt.Sprintf("Method disconnect for token %s, channel %s and thing %s took %s to complete", token, chanID, thingID, time.Since(begin)) message := fmt.Sprintf("Method disconnect for token %s, channels %v and things %v took %s to complete", token, chIDs, thIDs, time.Since(begin))
if err != nil { if err != nil {
lm.logger.Warn(fmt.Sprintf("%s with error: %s.", message, err)) lm.logger.Warn(fmt.Sprintf("%s with error: %s.", message, err))
return return
@ -224,7 +224,7 @@ func (lm *loggingMiddleware) Disconnect(ctx context.Context, token, chanID, thin
lm.logger.Info(fmt.Sprintf("%s without errors.", message)) lm.logger.Info(fmt.Sprintf("%s without errors.", message))
}(time.Now()) }(time.Now())
return lm.svc.Disconnect(ctx, token, chanID, thingID) return lm.svc.Disconnect(ctx, token, chIDs, thIDs)
} }
func (lm *loggingMiddleware) CanAccessByKey(ctx context.Context, id, key string) (thing string, err error) { func (lm *loggingMiddleware) CanAccessByKey(ctx context.Context, id, key string) (thing string, err error) {
@ -289,5 +289,5 @@ func (lm *loggingMiddleware) ListMembers(ctx context.Context, token, groupID str
lm.logger.Info(fmt.Sprintf("%s without errors.", message)) lm.logger.Info(fmt.Sprintf("%s without errors.", message))
}(time.Now()) }(time.Now())
return lm.svc.ListMembers(ctx, token, groupID, pm) return lm.svc.ListMembers(ctx, token, groupID, pm)
} }

View File

@ -156,13 +156,13 @@ func (ms *metricsMiddleware) Connect(ctx context.Context, token string, chIDs, t
return ms.svc.Connect(ctx, token, chIDs, thIDs) return ms.svc.Connect(ctx, token, chIDs, thIDs)
} }
func (ms *metricsMiddleware) Disconnect(ctx context.Context, token, chanID, thingID string) error { func (ms *metricsMiddleware) Disconnect(ctx context.Context, token string, chIDs, thIDs []string) error {
defer func(begin time.Time) { defer func(begin time.Time) {
ms.counter.With("method", "disconnect").Add(1) ms.counter.With("method", "disconnect").Add(1)
ms.latency.With("method", "disconnect").Observe(time.Since(begin).Seconds()) ms.latency.With("method", "disconnect").Observe(time.Since(begin).Seconds())
}(time.Now()) }(time.Now())
return ms.svc.Disconnect(ctx, token, chanID, thingID) return ms.svc.Disconnect(ctx, token, chIDs, thIDs)
} }
func (ms *metricsMiddleware) CanAccessByKey(ctx context.Context, id, key string) (string, error) { func (ms *metricsMiddleware) CanAccessByKey(ctx context.Context, id, key string) (string, error) {
@ -207,5 +207,5 @@ func (ms *metricsMiddleware) ListMembers(ctx context.Context, token, groupID str
ms.latency.With("method", "list_members").Observe(time.Since(begin).Seconds()) ms.latency.With("method", "list_members").Observe(time.Since(begin).Seconds())
}(time.Now()) }(time.Now())
return ms.svc.ListMembers(ctx, token, groupID, pm) return ms.svc.ListMembers(ctx, token, groupID, pm)
} }

View File

@ -444,9 +444,9 @@ func removeChannelEndpoint(svc things.Service) endpoint.Endpoint {
} }
} }
func connectEndpoint(svc things.Service) endpoint.Endpoint { func connectThingEndpoint(svc things.Service) endpoint.Endpoint {
return func(ctx context.Context, request interface{}) (interface{}, error) { return func(ctx context.Context, request interface{}) (interface{}, error) {
cr := request.(connectionReq) cr := request.(connectThingReq)
if err := cr.validate(); err != nil { if err := cr.validate(); err != nil {
return nil, err return nil, err
@ -456,13 +456,13 @@ func connectEndpoint(svc things.Service) endpoint.Endpoint {
return nil, err return nil, err
} }
return connectionRes{}, nil return connectThingRes{}, nil
} }
} }
func createConnectionsEndpoint(svc things.Service) endpoint.Endpoint { func connectEndpoint(svc things.Service) endpoint.Endpoint {
return func(ctx context.Context, request interface{}) (interface{}, error) { return func(ctx context.Context, request interface{}) (interface{}, error) {
cr := request.(createConnectionsReq) cr := request.(connectReq)
if err := cr.validate(); err != nil { if err := cr.validate(); err != nil {
return nil, err return nil, err
@ -472,23 +472,38 @@ func createConnectionsEndpoint(svc things.Service) endpoint.Endpoint {
return nil, err return nil, err
} }
return createConnectionsRes{}, nil return connectRes{}, nil
} }
} }
func disconnectEndpoint(svc things.Service) endpoint.Endpoint { func disconnectEndpoint(svc things.Service) endpoint.Endpoint {
return func(ctx context.Context, request interface{}) (interface{}, error) { return func(ctx context.Context, request interface{}) (interface{}, error) {
cr := request.(connectionReq) cr := request.(connectReq)
if err := cr.validate(); err != nil { if err := cr.validate(); err != nil {
return nil, err return nil, err
} }
if err := svc.Disconnect(ctx, cr.token, cr.chanID, cr.thingID); err != nil { if err := svc.Disconnect(ctx, cr.token, cr.ChannelIDs, cr.ThingIDs); err != nil {
return nil, err return nil, err
} }
return disconnectionRes{}, nil return disconnectRes{}, nil
}
}
func disconnectThingEndpoint(svc things.Service) endpoint.Endpoint {
return func(ctx context.Context, request interface{}) (interface{}, error) {
req := request.(connectThingReq)
if err := req.validate(); err != nil {
return nil, err
}
if err := svc.Disconnect(ctx, req.token, []string{req.chanID}, []string{req.thingID}); err != nil {
return nil, err
}
return disconnectThingRes{}, nil
} }
} }

View File

@ -2407,6 +2407,198 @@ func TestCreateConnections(t *testing.T) {
} }
} }
func TestDisconnectList(t *testing.T) {
otherToken := "other_token"
otherEmail := "other_user@example.com"
svc := newService(map[string]string{
token: email,
otherToken: otherEmail,
})
ts := newServer(svc)
defer ts.Close()
ths, err := svc.CreateThings(context.Background(), token, thing)
require.Nil(t, err, fmt.Sprintf("unexpected error: %s\n", err))
thIDs := []string{}
for _, th := range ths {
thIDs = append(thIDs, th.ID)
}
chs, err := svc.CreateChannels(context.Background(), token, channel)
require.Nil(t, err, fmt.Sprintf("unexpected error: %s\n", err))
chIDs1 := []string{}
for _, ch := range chs {
chIDs1 = append(chIDs1, ch.ID)
}
chs, err = svc.CreateChannels(context.Background(), otherToken, channel)
require.Nil(t, err, fmt.Sprintf("unexpected error: %s\n", err))
chIDs2 := []string{}
for _, ch := range chs {
chIDs2 = append(chIDs2, ch.ID)
}
err = svc.Connect(context.Background(), token, chIDs1, thIDs)
require.Nil(t, err, fmt.Sprintf("unexpected error: %s\n", err))
cases := []struct {
desc string
channelIDs []string
thingIDs []string
auth string
contentType string
body string
status int
}{
{
desc: "disconnect existing things from existing channels",
channelIDs: chIDs1,
thingIDs: thIDs,
auth: token,
contentType: contentType,
status: http.StatusOK,
},
{
desc: "disconnect existing things from non-existent channels",
channelIDs: []string{strconv.FormatUint(wrongID, 10)},
thingIDs: thIDs,
auth: token,
contentType: contentType,
status: http.StatusNotFound,
},
{
desc: "disconnect non-existing things from existing channels",
channelIDs: chIDs1,
thingIDs: []string{strconv.FormatUint(wrongID, 10)},
auth: token,
contentType: contentType,
status: http.StatusNotFound,
},
{
desc: "disconnect existing things from channel with invalid id",
channelIDs: []string{"invalid"},
thingIDs: thIDs,
auth: token,
contentType: contentType,
status: http.StatusNotFound,
},
{
desc: "disconnect things with invalid id from existing channels",
channelIDs: chIDs1,
thingIDs: []string{"invalid"},
auth: token,
contentType: contentType,
status: http.StatusNotFound,
},
{
desc: "disconnect existing things from empty channel ids",
channelIDs: []string{""},
thingIDs: thIDs,
auth: token,
contentType: contentType,
status: http.StatusBadRequest,
},
{
desc: "disconnect empty things id from existing channels",
channelIDs: chIDs1,
thingIDs: []string{""},
auth: token,
contentType: contentType,
status: http.StatusBadRequest,
},
{
desc: "disconnect existing things from existing channels with invalid token",
channelIDs: chIDs1,
thingIDs: thIDs,
auth: wrongValue,
contentType: contentType,
status: http.StatusUnauthorized,
},
{
desc: "disconnect existing things from existing channels with empty token",
channelIDs: chIDs1,
thingIDs: thIDs,
auth: "",
contentType: contentType,
status: http.StatusUnauthorized,
},
{
desc: "disconnect things from channels of other user",
channelIDs: chIDs2,
thingIDs: thIDs,
auth: token,
contentType: contentType,
status: http.StatusNotFound,
},
{
desc: "disconnect with invalid content type",
channelIDs: chIDs2,
thingIDs: thIDs,
auth: token,
contentType: "invalid",
status: http.StatusUnsupportedMediaType,
},
{
desc: "disconnect with invalid JSON",
auth: token,
contentType: contentType,
status: http.StatusBadRequest,
body: "{",
},
{
desc: "disconnect valid thing ids from empty channel ids",
channelIDs: []string{},
thingIDs: thIDs,
auth: token,
contentType: contentType,
status: http.StatusBadRequest,
},
{
desc: "disconnect empty thing ids from valid channel ids",
channelIDs: chIDs1,
thingIDs: []string{},
auth: token,
contentType: contentType,
status: http.StatusBadRequest,
},
{
desc: "disconnect empty thing ids from empty channel ids",
channelIDs: []string{},
thingIDs: []string{},
auth: token,
contentType: contentType,
status: http.StatusBadRequest,
},
}
for _, tc := range cases {
data := struct {
ChannelIDs []string `json:"channel_ids"`
ThingIDs []string `json:"thing_ids"`
}{
tc.channelIDs,
tc.thingIDs,
}
body := toJSON(data)
if tc.body != "" {
body = tc.body
}
req := testRequest{
client: ts.Client(),
method: http.MethodDelete,
url: fmt.Sprintf("%s/disconnect", ts.URL),
contentType: tc.contentType,
token: tc.auth,
body: strings.NewReader(body),
}
res, err := req.make()
assert.Nil(t, err, fmt.Sprintf("%s: unexpected error %s", tc.desc, err))
assert.Equal(t, tc.status, res.StatusCode, fmt.Sprintf("%s: expected status code %d got %d", tc.desc, tc.status, res.StatusCode))
}
}
func TestDisconnnect(t *testing.T) { func TestDisconnnect(t *testing.T) {
otherToken := "other_token" otherToken := "other_token"
otherEmail := "other_user@example.com" otherEmail := "other_user@example.com"

View File

@ -248,13 +248,13 @@ func (req listByConnectionReq) validate() error {
return nil return nil
} }
type connectionReq struct { type connectThingReq struct {
token string token string
chanID string chanID string
thingID string thingID string
} }
func (req connectionReq) validate() error { func (req connectThingReq) validate() error {
if req.token == "" { if req.token == "" {
return things.ErrUnauthorizedAccess return things.ErrUnauthorizedAccess
} }
@ -266,13 +266,13 @@ func (req connectionReq) validate() error {
return nil return nil
} }
type createConnectionsReq struct { type connectReq struct {
token string token string
ChannelIDs []string `json:"channel_ids,omitempty"` ChannelIDs []string `json:"channel_ids,omitempty"`
ThingIDs []string `json:"thing_ids,omitempty"` ThingIDs []string `json:"thing_ids,omitempty"`
} }
func (req createConnectionsReq) validate() error { func (req connectReq) validate() error {
if req.token == "" { if req.token == "" {
return things.ErrUnauthorizedAccess return things.ErrUnauthorizedAccess
} }

View File

@ -18,8 +18,10 @@ var (
_ mainflux.Response = (*channelRes)(nil) _ mainflux.Response = (*channelRes)(nil)
_ mainflux.Response = (*viewChannelRes)(nil) _ mainflux.Response = (*viewChannelRes)(nil)
_ mainflux.Response = (*channelsPageRes)(nil) _ mainflux.Response = (*channelsPageRes)(nil)
_ mainflux.Response = (*connectionRes)(nil) _ mainflux.Response = (*connectThingRes)(nil)
_ mainflux.Response = (*disconnectionRes)(nil) _ mainflux.Response = (*connectRes)(nil)
_ mainflux.Response = (*disconnectThingRes)(nil)
_ mainflux.Response = (*disconnectRes)(nil)
) )
type removeRes struct{} type removeRes struct{}
@ -213,47 +215,61 @@ func (res channelsPageRes) Empty() bool {
return false return false
} }
type connectionRes struct{} type connectThingRes struct{}
func (res connectionRes) Code() int { func (res connectThingRes) Code() int {
return http.StatusOK return http.StatusOK
} }
func (res connectionRes) Headers() map[string]string { func (res connectThingRes) Headers() map[string]string {
return map[string]string{ return map[string]string{
"Warning-Deprecated": "This endpoint will be depreciated in v1.0.0. It will be replaced with the bulk endpoint found at /connect.", "Warning-Deprecated": "This endpoint will be depreciated in v1.0.0. It will be replaced with the bulk endpoint found at /connect.",
} }
} }
func (res connectionRes) Empty() bool { func (res connectThingRes) Empty() bool {
return true return true
} }
type createConnectionsRes struct{} type connectRes struct{}
func (res createConnectionsRes) Code() int { func (res connectRes) Code() int {
return http.StatusOK return http.StatusOK
} }
func (res createConnectionsRes) Headers() map[string]string { func (res connectRes) Headers() map[string]string {
return map[string]string{} return map[string]string{}
} }
func (res createConnectionsRes) Empty() bool { func (res connectRes) Empty() bool {
return true return true
} }
type disconnectionRes struct{} type disconnectRes struct{}
func (res disconnectionRes) Code() int { func (res disconnectRes) Code() int {
return http.StatusNoContent return http.StatusOK
} }
func (res disconnectionRes) Headers() map[string]string { func (res disconnectRes) Headers() map[string]string {
return map[string]string{} return map[string]string{}
} }
func (res disconnectionRes) Empty() bool { func (res disconnectRes) Empty() bool {
return true
}
type disconnectThingRes struct{}
func (res disconnectThingRes) Code() int {
return http.StatusNoContent
}
func (res disconnectThingRes) Headers() map[string]string {
return map[string]string{}
}
func (res disconnectThingRes) Empty() bool {
return true return true
} }

View File

@ -155,23 +155,30 @@ func MakeHandler(tracer opentracing.Tracer, svc things.Service) http.Handler {
opts..., opts...,
)) ))
r.Put("/channels/:chanId/things/:thingId", kithttp.NewServer( r.Post("/connect", kithttp.NewServer(
kitot.TraceServer(tracer, "connect")(connectEndpoint(svc)), kitot.TraceServer(tracer, "connect")(connectEndpoint(svc)),
decodeConnection, decodeConnectList,
encodeResponse, encodeResponse,
opts..., opts...,
)) ))
r.Post("/connect", kithttp.NewServer( r.Delete("/disconnect", kithttp.NewServer(
kitot.TraceServer(tracer, "create_connections")(createConnectionsEndpoint(svc)), kitot.TraceServer(tracer, "disconnect")(disconnectEndpoint(svc)),
decodeCreateConnections, decodeConnectList,
encodeResponse,
opts...,
))
r.Put("/channels/:chanId/things/:thingId", kithttp.NewServer(
kitot.TraceServer(tracer, "connect_thing")(connectThingEndpoint(svc)),
decodeConnectThing,
encodeResponse, encodeResponse,
opts..., opts...,
)) ))
r.Delete("/channels/:chanId/things/:thingId", kithttp.NewServer( r.Delete("/channels/:chanId/things/:thingId", kithttp.NewServer(
kitot.TraceServer(tracer, "disconnect")(disconnectEndpoint(svc)), kitot.TraceServer(tracer, "disconnect_thing")(disconnectThingEndpoint(svc)),
decodeConnection, decodeConnectThing,
encodeResponse, encodeResponse,
opts..., opts...,
)) ))
@ -395,8 +402,8 @@ func decodeListByConnection(_ context.Context, r *http.Request) (interface{}, er
return req, nil return req, nil
} }
func decodeConnection(_ context.Context, r *http.Request) (interface{}, error) { func decodeConnectThing(_ context.Context, r *http.Request) (interface{}, error) {
req := connectionReq{ req := connectThingReq{
token: r.Header.Get("Authorization"), token: r.Header.Get("Authorization"),
chanID: bone.GetValue(r, "chanId"), chanID: bone.GetValue(r, "chanId"),
thingID: bone.GetValue(r, "thingId"), thingID: bone.GetValue(r, "thingId"),
@ -405,12 +412,12 @@ func decodeConnection(_ context.Context, r *http.Request) (interface{}, error) {
return req, nil return req, nil
} }
func decodeCreateConnections(_ context.Context, r *http.Request) (interface{}, error) { func decodeConnectList(_ context.Context, r *http.Request) (interface{}, error) {
if !strings.Contains(r.Header.Get("Content-Type"), contentType) { if !strings.Contains(r.Header.Get("Content-Type"), contentType) {
return nil, errors.ErrUnsupportedContentType return nil, errors.ErrUnsupportedContentType
} }
req := createConnectionsReq{token: r.Header.Get("Authorization")} req := connectReq{token: r.Header.Get("Authorization")}
if err := json.NewDecoder(r.Body).Decode(&req); err != nil { if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
return nil, errors.Wrap(things.ErrMalformedEntity, err) return nil, errors.Wrap(things.ErrMalformedEntity, err)
} }

View File

@ -8,7 +8,7 @@ import (
) )
// Channel represents a Mainflux "communication group". This group contains the // Channel represents a Mainflux "communication group". This group contains the
// things that can exchange messages between eachother. // things that can exchange messages between each other.
type Channel struct { type Channel struct {
ID string ID string
Owner string Owner string
@ -49,12 +49,12 @@ type ChannelRepository interface {
// by the specified user. // by the specified user.
Remove(ctx context.Context, owner, id string) error Remove(ctx context.Context, owner, id string) error
// Connect adds things to the channel's list of connected things. // Connect adds things to the channels list of connected things.
Connect(ctx context.Context, owner string, chIDs, thIDs []string) error Connect(ctx context.Context, owner string, chIDs, thIDs []string) error
// Disconnect removes thing from the channel's list of connected // Disconnect removes things from the channels list of connected
// things. // things.
Disconnect(ctx context.Context, owner, chanID, thingID string) error Disconnect(ctx context.Context, owner string, chIDs, thIDs []string) error
// HasThing determines whether the thing with the provided access key, is // HasThing determines whether the thing with the provided access key, is
// "connected" to the specified channel. If that's the case, it returns // "connected" to the specified channel. If that's the case, it returns

View File

@ -27,7 +27,7 @@ type channelRepositoryMock struct {
mu sync.Mutex mu sync.Mutex
counter uint64 counter uint64
channels map[string]things.Channel channels map[string]things.Channel
tconns chan Connection // used for syncronization with thing repo tconns chan Connection // used for synchronization with thing repo
cconns map[string]map[string]things.Channel // used to track connections cconns map[string]map[string]things.Channel // used to track connections
things things.ThingRepository things things.ThingRepository
} }
@ -216,21 +216,26 @@ func (crm *channelRepositoryMock) Connect(_ context.Context, owner string, chIDs
return nil return nil
} }
func (crm *channelRepositoryMock) Disconnect(_ context.Context, owner, chanID, thingID string) error { func (crm *channelRepositoryMock) Disconnect(_ context.Context, owner string, chIDs, thIDs []string) error {
if _, ok := crm.cconns[thingID]; !ok { for _, chID := range chIDs {
return things.ErrNotFound for _, thID := range thIDs {
if _, ok := crm.cconns[thID]; !ok {
return things.ErrNotFound
}
if _, ok := crm.cconns[thID][chID]; !ok {
return things.ErrNotFound
}
crm.tconns <- Connection{
chanID: chID,
thing: things.Thing{ID: thID, Owner: owner},
connected: false,
}
delete(crm.cconns[thID], chID)
}
} }
if _, ok := crm.cconns[thingID][chanID]; !ok {
return things.ErrNotFound
}
crm.tconns <- Connection{
chanID: chanID,
thing: things.Thing{ID: thingID, Owner: owner},
connected: false,
}
delete(crm.cconns[thingID], chanID)
return nil return nil
} }

View File

@ -314,29 +314,52 @@ func (cr channelRepository) Connect(ctx context.Context, owner string, chIDs, th
return nil return nil
} }
func (cr channelRepository) Disconnect(ctx context.Context, owner, chanID, thingID string) error { func (cr channelRepository) Disconnect(ctx context.Context, owner string, chIDs, thIDs []string) error {
tx, err := cr.db.BeginTxx(ctx, nil)
if err != nil {
return errors.Wrap(things.ErrConnect, err)
}
q := `DELETE FROM connections q := `DELETE FROM connections
WHERE channel_id = :channel AND channel_owner = :owner WHERE channel_id = :channel AND channel_owner = :owner
AND thing_id = :thing AND thing_owner = :owner` AND thing_id = :thing AND thing_owner = :owner`
conn := dbConnection{ for _, chID := range chIDs {
Channel: chanID, for _, thID := range thIDs {
Thing: thingID, dbco := dbConnection{
Owner: owner, Channel: chID,
Thing: thID,
Owner: owner,
}
res, err := tx.NamedExecContext(ctx, q, dbco)
if err != nil {
tx.Rollback()
pqErr, ok := err.(*pq.Error)
if ok {
switch pqErr.Code.Name() {
case errFK:
return things.ErrNotFound
case errDuplicate:
return things.ErrConflict
}
}
return errors.Wrap(things.ErrDisconnect, err)
}
cnt, err := res.RowsAffected()
if err != nil {
return errors.Wrap(things.ErrDisconnect, err)
}
if cnt == 0 {
return things.ErrNotFound
}
}
} }
res, err := cr.db.NamedExecContext(ctx, q, conn) if err = tx.Commit(); err != nil {
if err != nil { return errors.Wrap(things.ErrConnect, err)
return errors.Wrap(things.ErrDisconnect, err)
}
cnt, err := res.RowsAffected()
if err != nil {
return errors.Wrap(things.ErrDisconnect, err)
}
if cnt == 0 {
return things.ErrNotFound
} }
return nil return nil

View File

@ -723,7 +723,7 @@ func TestDisconnect(t *testing.T) {
} }
for _, tc := range cases { for _, tc := range cases {
err := chanRepo.Disconnect(context.Background(), tc.owner, tc.chID, tc.thID) err := chanRepo.Disconnect(context.Background(), tc.owner, []string{tc.chID}, []string{tc.thID})
assert.True(t, errors.Contains(err, tc.err), fmt.Sprintf("%s: expected %s got %s\n", tc.desc, tc.err, err)) assert.True(t, errors.Contains(err, tc.err), fmt.Sprintf("%s: expected %s got %s\n", tc.desc, tc.err, err))
} }
} }

View File

@ -209,21 +209,25 @@ func (es eventStore) Connect(ctx context.Context, token string, chIDs, thIDs []s
return nil return nil
} }
func (es eventStore) Disconnect(ctx context.Context, token, chanID, thingID string) error { func (es eventStore) Disconnect(ctx context.Context, token string, chIDs, thIDs []string) error {
if err := es.svc.Disconnect(ctx, token, chanID, thingID); err != nil { if err := es.svc.Disconnect(ctx, token, chIDs, thIDs); err != nil {
return err return err
} }
event := disconnectThingEvent{ for _, chID := range chIDs {
chanID: chanID, for _, thID := range thIDs {
thingID: thingID, event := disconnectThingEvent{
chanID: chID,
thingID: thID,
}
record := &redis.XAddArgs{
Stream: streamID,
MaxLenApprox: streamLen,
Values: event.Encode(),
}
es.client.XAdd(ctx, record).Err()
}
} }
record := &redis.XAddArgs{
Stream: streamID,
MaxLenApprox: streamLen,
Values: event.Encode(),
}
es.client.XAdd(ctx, record).Err()
return nil return nil
} }

View File

@ -629,7 +629,7 @@ func TestDisconnectEvent(t *testing.T) {
lastID := "0" lastID := "0"
for _, tc := range cases { for _, tc := range cases {
err := svc.Disconnect(context.Background(), tc.key, tc.chanID, tc.thingID) err := svc.Disconnect(context.Background(), tc.key, []string{tc.chanID}, []string{tc.thingID})
assert.Equal(t, tc.err, err, fmt.Sprintf("%s: expected %s got %s\n", tc.desc, tc.err, err)) assert.Equal(t, tc.err, err, fmt.Sprintf("%s: expected %s got %s\n", tc.desc, tc.err, err))
streams := redisClient.XRead(context.Background(), &r.XReadArgs{ streams := redisClient.XRead(context.Background(), &r.XReadArgs{

View File

@ -97,12 +97,12 @@ type Service interface {
// belongs to the user identified by the provided key. // belongs to the user identified by the provided key.
RemoveChannel(ctx context.Context, token, id string) error RemoveChannel(ctx context.Context, token, id string) error
// Connect adds things to the channel's list of connected things. // Connect adds things to the channels list of connected things.
Connect(ctx context.Context, token string, chIDs, thIDs []string) error Connect(ctx context.Context, token string, chIDs, thIDs []string) error
// Disconnect removes thing from the channel's list of connected // Disconnect removes things from the channels list of connected
// things. // things.
Disconnect(ctx context.Context, token, chanID, thingID string) error Disconnect(ctx context.Context, token string, chIDs, thIDs []string) error
// CanAccessByKey determines whether the channel can be accessed using the // CanAccessByKey determines whether the channel can be accessed using the
// provided key and returns thing's id if access is allowed. // provided key and returns thing's id if access is allowed.
@ -323,17 +323,21 @@ func (ts *thingsService) Connect(ctx context.Context, token string, chIDs, thIDs
return ts.channels.Connect(ctx, res.GetEmail(), chIDs, thIDs) return ts.channels.Connect(ctx, res.GetEmail(), chIDs, thIDs)
} }
func (ts *thingsService) Disconnect(ctx context.Context, token, chanID, thingID string) error { func (ts *thingsService) Disconnect(ctx context.Context, token string, chIDs, thIDs []string) error {
res, err := ts.auth.Identify(ctx, &mainflux.Token{Value: token}) res, err := ts.auth.Identify(ctx, &mainflux.Token{Value: token})
if err != nil { if err != nil {
return errors.Wrap(ErrUnauthorizedAccess, err) return errors.Wrap(ErrUnauthorizedAccess, err)
} }
if err := ts.channelCache.Disconnect(ctx, chanID, thingID); err != nil { for _, chID := range chIDs {
return err for _, thID := range thIDs {
if err := ts.channelCache.Disconnect(ctx, chID, thID); err != nil {
return err
}
}
} }
return ts.channels.Disconnect(ctx, res.GetEmail(), chanID, thingID) return ts.channels.Disconnect(ctx, res.GetEmail(), chIDs, thIDs)
} }
func (ts *thingsService) CanAccessByKey(ctx context.Context, chanID, thingKey string) (string, error) { func (ts *thingsService) CanAccessByKey(ctx context.Context, chanID, thingKey string) (string, error) {

View File

@ -1121,7 +1121,7 @@ func TestDisconnect(t *testing.T) {
} }
for _, tc := range cases { for _, tc := range cases {
err := svc.Disconnect(context.Background(), tc.token, tc.chanID, tc.thingID) err := svc.Disconnect(context.Background(), tc.token, []string{tc.chanID}, []string{tc.thingID})
assert.True(t, errors.Contains(err, tc.err), fmt.Sprintf("%s: expected %s got %s\n", tc.desc, tc.err, err)) assert.True(t, errors.Contains(err, tc.err), fmt.Sprintf("%s: expected %s got %s\n", tc.desc, tc.err, err))
} }

View File

@ -30,7 +30,7 @@ var (
ErrEntityConnected = errors.New("check thing-channel connection in database error") ErrEntityConnected = errors.New("check thing-channel connection in database error")
) )
// Metadata to be used for mainflux thing or channel for customized // Metadata to be used for Mainflux thing or channel for customized
// describing of particular thing or channel. // describing of particular thing or channel.
type Metadata map[string]interface{} type Metadata map[string]interface{}

View File

@ -98,12 +98,12 @@ func (crm channelRepositoryMiddleware) Connect(ctx context.Context, owner string
return crm.repo.Connect(ctx, owner, chIDs, thIDs) return crm.repo.Connect(ctx, owner, chIDs, thIDs)
} }
func (crm channelRepositoryMiddleware) Disconnect(ctx context.Context, owner, chanID, thingID string) error { func (crm channelRepositoryMiddleware) Disconnect(ctx context.Context, owner string, chIDs, thIDs []string) error {
span := createSpan(ctx, crm.tracer, disconnectOp) span := createSpan(ctx, crm.tracer, disconnectOp)
defer span.Finish() defer span.Finish()
ctx = opentracing.ContextWithSpan(ctx, span) ctx = opentracing.ContextWithSpan(ctx, span)
return crm.repo.Disconnect(ctx, owner, chanID, thingID) return crm.repo.Disconnect(ctx, owner, chIDs, thIDs)
} }
func (crm channelRepositoryMiddleware) HasThing(ctx context.Context, chanID, key string) (string, error) { func (crm channelRepositoryMiddleware) HasThing(ctx context.Context, chanID, key string) (string, error) {