diff --git a/api/openapi/users.yml b/api/openapi/users.yml index 9ca9f2df..9133de39 100644 --- a/api/openapi/users.yml +++ b/api/openapi/users.yml @@ -41,6 +41,7 @@ paths: - $ref: "#/components/parameters/Limit" - $ref: "#/components/parameters/Offset" - $ref: "#/components/parameters/Metadata" + - $ref: "#/components/parameters/Status" responses: '200': $ref: "#/components/responses/UsersPageRes" @@ -215,6 +216,46 @@ paths: description: Missing or invalid content type. '500': $ref: "#/components/responses/ServiceError" + /users/{userId}/enable: + post: + summary: Enables a user account + description: | + Enables a disabled user account for a given user ID. + tags: + - users + parameters: + - $ref: "#/components/parameters/UserId" + responses: + '200': + description: User enabled. + '400': + description: Failed due to malformed JSON. + '404': + description: Failed due to non existing user. + '401': + description: Missing or invalid access token provided. + '500': + $ref: "#/components/responses/ServiceError" + /users/{userId}/disable: + post: + summary: Disables a user account + description: | + Disables a user account for a given user ID. + tags: + - users + parameters: + - $ref: "#/components/parameters/UserId" + responses: + '200': + description: User disabled. + '400': + description: Failed due to malformed JSON. + '404': + description: Failed due to non existing user. + '401': + description: Missing or invalid access token provided. + '500': + $ref: "#/components/responses/ServiceError" /health: get: summary: Retrieves service health check info. @@ -318,7 +359,7 @@ components: type: string minimum: 0 required: false - UserID: + UserId: name: userId description: Unique user identifier. in: path @@ -353,7 +394,14 @@ components: default: 0 minimum: 0 required: false - + Status: + name: status + description: User account status. + in: query + schema: + type: string + default: enabled + required: false requestBodies: UserCreateReq: description: JSON-formatted document describing the new user to be registered diff --git a/cli/users.go b/cli/users.go index 928aea17..484faff6 100644 --- a/cli/users.go +++ b/cli/users.go @@ -58,6 +58,7 @@ var cmdUsers = []cobra.Command{ Offset: uint64(Offset), Limit: uint64(Limit), Metadata: metadata, + Status: Status, } if args[0] == "all" { l, err := sdk.Users(args[1], pageMetadata) @@ -140,6 +141,42 @@ var cmdUsers = []cobra.Command{ return } + logOK() + }, + }, + { + Use: "enable ", + Short: "Change user status to enabled", + Long: `Change user status to enabled`, + Run: func(cmd *cobra.Command, args []string) { + if len(args) != 2 { + logUsage(cmd.Use) + return + } + + if err := sdk.EnableUser(args[0], args[1]); err != nil { + logError(err) + return + } + + logOK() + }, + }, + { + Use: "disable ", + Short: "Change user status to disabled", + Long: `Change user status to disabled`, + Run: func(cmd *cobra.Command, args []string) { + if len(args) != 2 { + logUsage(cmd.Use) + return + } + + if err := sdk.DisableUser(args[0], args[1]); err != nil { + logError(err) + return + } + logOK() }, }, @@ -148,7 +185,7 @@ var cmdUsers = []cobra.Command{ // NewUsersCmd returns users command. func NewUsersCmd() *cobra.Command { cmd := cobra.Command{ - Use: "users [create | get | update | token | password]", + Use: "users [create | get | update | token | password | enable | disable]", Short: "Users management", Long: `Users management: create accounts and tokens"`, } diff --git a/cli/utils.go b/cli/utils.go index 167ddea8..50d1a041 100644 --- a/cli/utils.go +++ b/cli/utils.go @@ -22,6 +22,8 @@ var ( Email string = "" // Metadata query parameter Metadata string = "" + // Status query parameter + Status string = "" // ConfigPath config path parameter ConfigPath string = "" // RawOutput raw output mode diff --git a/cmd/cli/main.go b/cmd/cli/main.go index 59de4f48..639f4d1a 100644 --- a/cmd/cli/main.go +++ b/cmd/cli/main.go @@ -186,6 +186,14 @@ func main() { "Metadata query parameter", ) + rootCmd.PersistentFlags().StringVarP( + &cli.Status, + "status", + "S", + "", + "Status query parameter", + ) + if err := rootCmd.Execute(); err != nil { log.Fatal(err) } diff --git a/internal/apiutil/errors.go b/internal/apiutil/errors.go index 24d3afce..35811dfd 100644 --- a/internal/apiutil/errors.go +++ b/internal/apiutil/errors.go @@ -30,6 +30,9 @@ var ( // ErrEmailSize indicates that email size exceeds the max. ErrEmailSize = errors.New("invalid email size") + // ErrInvalidStatus indicates an invalid user account status. + ErrInvalidStatus = errors.New("invalid user account status") + // ErrLimitSize indicates that an invalid limit. ErrLimitSize = errors.New("invalid limit size") diff --git a/pkg/sdk/go/sdk.go b/pkg/sdk/go/sdk.go index 83962ad1..bc201bd5 100644 --- a/pkg/sdk/go/sdk.go +++ b/pkg/sdk/go/sdk.go @@ -98,6 +98,7 @@ type PageMetadata struct { Name string `json:"name,omitempty"` Type string `json:"type,omitempty"` Metadata map[string]interface{} `json:"metadata,omitempty"` + Status string `json:"status,omitempty"` } // Group represents mainflux users group. @@ -153,6 +154,12 @@ type SDK interface { // UpdatePassword updates user password. UpdatePassword(oldPass, newPass, token string) error + // EnableUser changes the status of the user to enabled. + EnableUser(id, token string) error + + // DisableUser changes the status of the user to disabled. + DisableUser(id, token string) error + // CreateThing registers new thing and returns its id. CreateThing(thing Thing, token string) (string, error) @@ -389,6 +396,9 @@ func (pm PageMetadata) query() (string, error) { if pm.Type != "" { q.Add("type", pm.Type) } + if pm.Status != "" { + q.Add("status", pm.Status) + } if pm.Metadata != nil { md, err := json.Marshal(pm.Metadata) if err != nil { diff --git a/pkg/sdk/go/users.go b/pkg/sdk/go/users.go index 4063d877..04b0b8dd 100644 --- a/pkg/sdk/go/users.go +++ b/pkg/sdk/go/users.go @@ -189,3 +189,43 @@ func (sdk mfSDK) UpdatePassword(oldPass, newPass, token string) error { return nil } + +func (sdk mfSDK) EnableUser(id, token string) error { + url := fmt.Sprintf("%s/%s/%s/enable", sdk.usersURL, usersEndpoint, id) + + req, err := http.NewRequest(http.MethodPost, url, nil) + if err != nil { + return err + } + + resp, err := sdk.sendRequest(req, token, string(CTJSON)) + if err != nil { + return err + } + + if resp.StatusCode != http.StatusNoContent { + return errors.Wrap(ErrFailedRemoval, errors.New(resp.Status)) + } + + return nil +} + +func (sdk mfSDK) DisableUser(id, token string) error { + url := fmt.Sprintf("%s/%s/%s/disable", sdk.usersURL, usersEndpoint, id) + + req, err := http.NewRequest(http.MethodPost, url, nil) + if err != nil { + return err + } + + resp, err := sdk.sendRequest(req, token, string(CTJSON)) + if err != nil { + return err + } + + if resp.StatusCode != http.StatusNoContent { + return errors.Wrap(ErrFailedRemoval, errors.New(resp.Status)) + } + + return nil +} diff --git a/pkg/sdk/go/users_test.go b/pkg/sdk/go/users_test.go index 37d500e0..12ab1a8a 100644 --- a/pkg/sdk/go/users_test.go +++ b/pkg/sdk/go/users_test.go @@ -176,7 +176,7 @@ func TestUser(t *testing.T) { desc: "get non-existent user", userID: "43", token: usertoken, - err: createError(sdk.ErrFailedFetch, http.StatusUnauthorized), + err: createError(sdk.ErrFailedFetch, http.StatusNotFound), response: sdk.User{}, }, diff --git a/users/api/endpoint.go b/users/api/endpoint.go index 8fba5b8f..59ef11d6 100644 --- a/users/api/endpoint.go +++ b/users/api/endpoint.go @@ -81,7 +81,7 @@ func viewUserEndpoint(svc users.Service) endpoint.Endpoint { return nil, err } - u, err := svc.ViewUser(ctx, req.token, req.userID) + u, err := svc.ViewUser(ctx, req.token, req.id) if err != nil { return nil, err } @@ -118,7 +118,14 @@ func listUsersEndpoint(svc users.Service) endpoint.Endpoint { if err := req.validate(); err != nil { return users.UserPage{}, err } - up, err := svc.ListUsers(ctx, req.token, req.offset, req.limit, req.email, req.metadata) + pm := users.PageMetadata{ + Offset: req.offset, + Limit: req.limit, + Email: req.email, + Status: req.status, + Metadata: req.metadata, + } + up, err := svc.ListUsers(ctx, req.token, pm) if err != nil { return users.UserPage{}, err } @@ -179,7 +186,13 @@ func listMembersEndpoint(svc users.Service) endpoint.Endpoint { return userPageRes{}, errors.Wrap(errors.ErrMalformedEntity, err) } - page, err := svc.ListMembers(ctx, req.token, req.groupID, req.offset, req.limit, req.metadata) + pm := users.PageMetadata{ + Offset: req.offset, + Limit: req.limit, + Status: req.status, + Metadata: req.metadata, + } + page, err := svc.ListMembers(ctx, req.token, req.id, pm) if err != nil { return userPageRes{}, err } @@ -188,6 +201,32 @@ func listMembersEndpoint(svc users.Service) endpoint.Endpoint { } } +func enableUserEndpoint(svc users.Service) endpoint.Endpoint { + return func(ctx context.Context, request interface{}) (interface{}, error) { + req := request.(changeUserStatusReq) + if err := req.validate(); err != nil { + return nil, err + } + if err := svc.EnableUser(ctx, req.token, req.id); err != nil { + return nil, err + } + return deleteRes{}, nil + } +} + +func disableUserEndpoint(svc users.Service) endpoint.Endpoint { + return func(ctx context.Context, request interface{}) (interface{}, error) { + req := request.(changeUserStatusReq) + if err := req.validate(); err != nil { + return nil, err + } + if err := svc.DisableUser(ctx, req.token, req.id); err != nil { + return nil, err + } + return deleteRes{}, nil + } +} + func buildUsersResponse(up users.UserPage) userPageRes { res := userPageRes{ pageRes: pageRes{ diff --git a/users/api/logging.go b/users/api/logging.go index 8fdd6251..718f108f 100644 --- a/users/api/logging.go +++ b/users/api/logging.go @@ -79,7 +79,7 @@ func (lm *loggingMiddleware) ViewProfile(ctx context.Context, token string) (u u return lm.svc.ViewProfile(ctx, token) } -func (lm *loggingMiddleware) ListUsers(ctx context.Context, token string, offset, limit uint64, email string, um users.Metadata) (e users.UserPage, err error) { +func (lm *loggingMiddleware) ListUsers(ctx context.Context, token string, pm users.PageMetadata) (e users.UserPage, err error) { defer func(begin time.Time) { message := fmt.Sprintf("Method list_users for token %s took %s to complete", token, time.Since(begin)) if err != nil { @@ -89,7 +89,7 @@ func (lm *loggingMiddleware) ListUsers(ctx context.Context, token string, offset lm.logger.Info(fmt.Sprintf("%s without errors.", message)) }(time.Now()) - return lm.svc.ListUsers(ctx, token, offset, limit, email, um) + return lm.svc.ListUsers(ctx, token, pm) } func (lm *loggingMiddleware) UpdateUser(ctx context.Context, token string, u users.User) (err error) { @@ -157,7 +157,7 @@ func (lm *loggingMiddleware) SendPasswordReset(ctx context.Context, host, email, return lm.svc.SendPasswordReset(ctx, host, email, token) } -func (lm *loggingMiddleware) ListMembers(ctx context.Context, token, groupID string, offset, limit uint64, m users.Metadata) (mp users.UserPage, err error) { +func (lm *loggingMiddleware) ListMembers(ctx context.Context, token, groupID string, pm users.PageMetadata) (mp users.UserPage, err error) { defer func(begin time.Time) { message := fmt.Sprintf("Method list_members for group %s took %s to complete", groupID, time.Since(begin)) if err != nil { @@ -167,5 +167,31 @@ func (lm *loggingMiddleware) ListMembers(ctx context.Context, token, groupID str lm.logger.Info(fmt.Sprintf("%s without errors.", message)) }(time.Now()) - return lm.svc.ListMembers(ctx, token, groupID, offset, limit, m) + return lm.svc.ListMembers(ctx, token, groupID, pm) +} + +func (lm *loggingMiddleware) EnableUser(ctx context.Context, token string, id string) (err error) { + defer func(begin time.Time) { + message := fmt.Sprintf("Method enable_user for user %s took %s to complete", id, time.Since(begin)) + if err != nil { + lm.logger.Warn(fmt.Sprintf("%s with error: %s.", message, err)) + return + } + lm.logger.Info(fmt.Sprintf("%s without errors.", message)) + }(time.Now()) + + return lm.svc.EnableUser(ctx, token, id) +} + +func (lm *loggingMiddleware) DisableUser(ctx context.Context, token string, id string) (err error) { + defer func(begin time.Time) { + message := fmt.Sprintf("Method disable_user for user %s took %s to complete", id, time.Since(begin)) + if err != nil { + lm.logger.Warn(fmt.Sprintf("%s with error: %s.", message, err)) + return + } + lm.logger.Info(fmt.Sprintf("%s without errors.", message)) + }(time.Now()) + + return lm.svc.DisableUser(ctx, token, id) } diff --git a/users/api/metrics.go b/users/api/metrics.go index 8865e68d..02d8558b 100644 --- a/users/api/metrics.go +++ b/users/api/metrics.go @@ -66,13 +66,13 @@ func (ms *metricsMiddleware) ViewProfile(ctx context.Context, token string) (use return ms.svc.ViewProfile(ctx, token) } -func (ms *metricsMiddleware) ListUsers(ctx context.Context, token string, offset, limit uint64, email string, um users.Metadata) (users.UserPage, error) { +func (ms *metricsMiddleware) ListUsers(ctx context.Context, token string, pm users.PageMetadata) (users.UserPage, error) { defer func(begin time.Time) { ms.counter.With("method", "list_users").Add(1) ms.latency.With("method", "list_users").Observe(time.Since(begin).Seconds()) }(time.Now()) - return ms.svc.ListUsers(ctx, token, offset, limit, email, um) + return ms.svc.ListUsers(ctx, token, pm) } func (ms *metricsMiddleware) UpdateUser(ctx context.Context, token string, u users.User) (err error) { @@ -120,11 +120,29 @@ func (ms *metricsMiddleware) SendPasswordReset(ctx context.Context, host, email, return ms.svc.SendPasswordReset(ctx, host, email, token) } -func (ms *metricsMiddleware) ListMembers(ctx context.Context, token, groupID string, offset, limit uint64, gm users.Metadata) (users.UserPage, error) { +func (ms *metricsMiddleware) ListMembers(ctx context.Context, token, groupID string, pm users.PageMetadata) (users.UserPage, error) { defer func(begin time.Time) { ms.counter.With("method", "list_members").Add(1) ms.latency.With("method", "list_members").Observe(time.Since(begin).Seconds()) }(time.Now()) - return ms.svc.ListMembers(ctx, token, groupID, offset, limit, gm) + return ms.svc.ListMembers(ctx, token, groupID, pm) +} + +func (ms *metricsMiddleware) EnableUser(ctx context.Context, token string, id string) (err error) { + defer func(begin time.Time) { + ms.counter.With("method", "enable_user").Add(1) + ms.latency.With("method", "enable_user").Observe(time.Since(begin).Seconds()) + }(time.Now()) + + return ms.svc.EnableUser(ctx, token, id) +} + +func (ms *metricsMiddleware) DisableUser(ctx context.Context, token string, id string) (err error) { + defer func(begin time.Time) { + ms.counter.With("method", "disable_user").Add(1) + ms.latency.With("method", "disable_user").Observe(time.Since(begin).Seconds()) + }(time.Now()) + + return ms.svc.DisableUser(ctx, token, id) } diff --git a/users/api/requests.go b/users/api/requests.go index edbee25e..e3e2fcba 100644 --- a/users/api/requests.go +++ b/users/api/requests.go @@ -31,8 +31,8 @@ func (req createUserReq) validate() error { } type viewUserReq struct { - token string - userID string + token string + id string } func (req viewUserReq) validate() error { @@ -44,6 +44,7 @@ func (req viewUserReq) validate() error { type listUsersReq struct { token string + status string offset uint64 limit uint64 email string @@ -62,6 +63,11 @@ func (req listUsersReq) validate() error { if len(req.email) > maxEmailSize { return apiutil.ErrEmailSize } + if req.status != users.AllStatusKey && + req.status != users.EnabledStatusKey && + req.status != users.DisabledStatusKey { + return apiutil.ErrInvalidStatus + } return nil } @@ -139,10 +145,11 @@ func (req passwChangeReq) validate() error { type listMemberGroupReq struct { token string + status string offset uint64 limit uint64 metadata users.Metadata - groupID string + id string } func (req listMemberGroupReq) validate() error { @@ -150,9 +157,28 @@ func (req listMemberGroupReq) validate() error { return apiutil.ErrBearerToken } - if req.groupID == "" { + if req.id == "" { + return apiutil.ErrMissingID + } + if req.status != users.AllStatusKey && + req.status != users.EnabledStatusKey && + req.status != users.DisabledStatusKey { + return apiutil.ErrInvalidStatus + } + return nil +} + +type changeUserStatusReq struct { + token string + id string +} + +func (req changeUserStatusReq) validate() error { + if req.token == "" { + return apiutil.ErrBearerToken + } + if req.id == "" { return apiutil.ErrMissingID } - return nil } diff --git a/users/api/transport.go b/users/api/transport.go index ac2550b9..d4216fb5 100644 --- a/users/api/transport.go +++ b/users/api/transport.go @@ -28,6 +28,7 @@ const ( limitKey = "limit" emailKey = "email" metadataKey = "metadata" + statusKey = "status" defOffset = 0 defLimit = 10 ) @@ -54,7 +55,7 @@ func MakeHandler(svc users.Service, tracer opentracing.Tracer, logger logger.Log opts..., )) - mux.Get("/users/:userID", kithttp.NewServer( + mux.Get("/users/:id", kithttp.NewServer( kitot.TraceServer(tracer, "view_user")(viewUserEndpoint(svc)), decodeViewUser, encodeResponse, @@ -96,7 +97,7 @@ func MakeHandler(svc users.Service, tracer opentracing.Tracer, logger logger.Log opts..., )) - mux.Get("/groups/:groupId", kithttp.NewServer( + mux.Get("/groups/:id", kithttp.NewServer( kitot.TraceServer(tracer, "list_members")(listMembersEndpoint(svc)), decodeListMembersRequest, encodeResponse, @@ -110,6 +111,20 @@ func MakeHandler(svc users.Service, tracer opentracing.Tracer, logger logger.Log opts..., )) + mux.Post("/users/:id/enable", kithttp.NewServer( + kitot.TraceServer(tracer, "enable_user")(enableUserEndpoint(svc)), + decodeChangeUserStatus, + encodeResponse, + opts..., + )) + + mux.Post("/users/:id/disable", kithttp.NewServer( + kitot.TraceServer(tracer, "disable_user")(disableUserEndpoint(svc)), + decodeChangeUserStatus, + encodeResponse, + opts..., + )) + mux.GetFunc("/health", mainflux.Health("users")) mux.Handle("/metrics", promhttp.Handler()) @@ -118,8 +133,8 @@ func MakeHandler(svc users.Service, tracer opentracing.Tracer, logger logger.Log func decodeViewUser(_ context.Context, r *http.Request) (interface{}, error) { req := viewUserReq{ - token: apiutil.ExtractBearerToken(r), - userID: bone.GetValue(r, "userID"), + token: apiutil.ExtractBearerToken(r), + id: bone.GetValue(r, "id"), } return req, nil @@ -152,8 +167,13 @@ func decodeListUsers(_ context.Context, r *http.Request) (interface{}, error) { return nil, err } + s, err := apiutil.ReadStringQuery(r, statusKey, users.EnabledStatusKey) + if err != nil { + return nil, err + } req := listUsersReq{ token: apiutil.ExtractBearerToken(r), + status: s, offset: o, limit: l, email: e, @@ -258,10 +278,15 @@ func decodeListMembersRequest(_ context.Context, r *http.Request) (interface{}, if err != nil { return nil, err } + s, err := apiutil.ReadStringQuery(r, statusKey, users.EnabledStatusKey) + if err != nil { + return nil, err + } req := listMemberGroupReq{ token: apiutil.ExtractBearerToken(r), - groupID: bone.GetValue(r, "groupId"), + status: s, + id: bone.GetValue(r, "id"), offset: o, limit: l, metadata: m, @@ -269,6 +294,15 @@ func decodeListMembersRequest(_ context.Context, r *http.Request) (interface{}, return req, nil } +func decodeChangeUserStatus(_ context.Context, r *http.Request) (interface{}, error) { + req := changeUserStatusReq{ + token: apiutil.ExtractBearerToken(r), + id: bone.GetValue(r, "id"), + } + + return req, nil +} + func encodeResponse(_ context.Context, w http.ResponseWriter, response interface{}) error { if ar, ok := response.(mainflux.Response); ok { for k, v := range ar.Headers() { diff --git a/users/mocks/users.go b/users/mocks/users.go index b838f11a..837eb2ab 100644 --- a/users/mocks/users.go +++ b/users/mocks/users.go @@ -91,7 +91,7 @@ func (urm *userRepositoryMock) RetrieveByID(ctx context.Context, id string) (use return val, nil } -func (urm *userRepositoryMock) RetrieveAll(ctx context.Context, offset, limit uint64, ids []string, email string, um users.Metadata) (users.UserPage, error) { +func (urm *userRepositoryMock) RetrieveAll(ctx context.Context, status string, offset, limit uint64, ids []string, email string, um users.Metadata) (users.UserPage, error) { urm.mu.Lock() defer urm.mu.Unlock() @@ -110,6 +110,20 @@ func (urm *userRepositoryMock) RetrieveAll(ctx context.Context, offset, limit ui return up, nil } + if status == users.EnabledStatusKey || status == users.DisabledStatusKey { + for _, u := range sortUsers(urm.users) { + if i >= offset && i < (limit+offset) { + if status == u.Status { + up.Users = append(up.Users, u) + } + } + i++ + } + up.Offset = offset + up.Limit = limit + up.Total = uint64(i) + return up, nil + } for _, u := range sortUsers(urm.users) { if i >= offset && i < (limit+offset) { up.Users = append(up.Users, u) @@ -134,6 +148,19 @@ func (urm *userRepositoryMock) UpdatePassword(_ context.Context, token, password return nil } +func (urm *userRepositoryMock) ChangeStatus(ctx context.Context, id, status string) error { + urm.mu.Lock() + defer urm.mu.Unlock() + + user, ok := urm.usersByID[id] + if !ok { + return errors.ErrNotFound + } + user.Status = status + urm.usersByID[id] = user + urm.users[user.Email] = user + return nil +} func sortUsers(us map[string]users.User) []users.User { users := []users.User{} ids := make([]string, 0, len(us)) diff --git a/users/postgres/init.go b/users/postgres/init.go index 821c802a..0854a728 100644 --- a/users/postgres/init.go +++ b/users/postgres/init.go @@ -78,6 +78,14 @@ func migrateDB(db *sqlx.DB) error { `ALTER TABLE IF EXISTS users ADD PRIMARY KEY (id)`, }, }, + { + Id: "users_5", + Up: []string{ + `CREATE TYPE USER_STATUS AS ENUM ('enabled', 'disabled');`, + `ALTER TABLE IF EXISTS users ADD COLUMN IF NOT EXISTS + status USER_STATUS NOT NULL DEFAULT 'enabled'`, + }, + }, }, } diff --git a/users/postgres/users.go b/users/postgres/users.go index 5017bd45..808865b0 100644 --- a/users/postgres/users.go +++ b/users/postgres/users.go @@ -38,7 +38,7 @@ func NewUserRepo(db Database) users.UserRepository { } func (ur userRepository) Save(ctx context.Context, user users.User) (string, error) { - q := `INSERT INTO users (email, password, id, metadata) VALUES (:email, :password, :id, :metadata) RETURNING id` + q := `INSERT INTO users (email, password, id, metadata, status) VALUES (:email, :password, :id, :metadata, :status) RETURNING id` if user.ID == "" || user.Email == "" { return "", errors.ErrMalformedEntity } @@ -72,7 +72,7 @@ func (ur userRepository) Save(ctx context.Context, user users.User) (string, err } func (ur userRepository) Update(ctx context.Context, user users.User) error { - q := `UPDATE users SET(email, password, metadata) VALUES (:email, :password, :metadata) WHERE email = :email` + q := `UPDATE users SET(email, password, metadata, status) VALUES (:email, :password, :metadata, :status) WHERE email = :email;` dbu, err := toDBUser(user) if err != nil { @@ -87,7 +87,7 @@ func (ur userRepository) Update(ctx context.Context, user users.User) error { } func (ur userRepository) UpdateUser(ctx context.Context, user users.User) error { - q := `UPDATE users SET metadata = :metadata WHERE email = :email` + q := `UPDATE users SET metadata = :metadata WHERE email = :email AND status = 'enabled'` dbu, err := toDBUser(user) if err != nil { @@ -102,7 +102,7 @@ func (ur userRepository) UpdateUser(ctx context.Context, user users.User) error } func (ur userRepository) RetrieveByEmail(ctx context.Context, email string) (users.User, error) { - q := `SELECT id, password, metadata FROM users WHERE email = $1` + q := `SELECT id, password, metadata FROM users WHERE email = $1 AND status = 'enabled'` dbu := dbUser{ Email: email, @@ -137,7 +137,7 @@ func (ur userRepository) RetrieveByID(ctx context.Context, id string) (users.Use return toUser(dbu) } -func (ur userRepository) RetrieveAll(ctx context.Context, offset, limit uint64, userIDs []string, email string, um users.Metadata) (users.UserPage, error) { +func (ur userRepository) RetrieveAll(ctx context.Context, status string, offset, limit uint64, userIDs []string, email string, um users.Metadata) (users.UserPage, error) { eq, ep, err := createEmailQuery("", email) if err != nil { return users.UserPage{}, errors.Wrap(errors.ErrViewEntity, err) @@ -147,6 +147,10 @@ func (ur userRepository) RetrieveAll(ctx context.Context, offset, limit uint64, if err != nil { return users.UserPage{}, errors.Wrap(errors.ErrViewEntity, err) } + aq := fmt.Sprintf("status = '%s'", status) + if status == users.AllStatusKey { + aq = "" + } var query []string var emq string @@ -156,6 +160,9 @@ func (ur userRepository) RetrieveAll(ctx context.Context, offset, limit uint64, if mq != "" { query = append(query, mq) } + if aq != "" { + query = append(query, aq) + } if len(userIDs) > 0 { query = append(query, fmt.Sprintf("id IN ('%s')", strings.Join(userIDs, "','"))) @@ -213,7 +220,7 @@ func (ur userRepository) RetrieveAll(ctx context.Context, offset, limit uint64, } func (ur userRepository) UpdatePassword(ctx context.Context, email, password string) error { - q := `UPDATE users SET password = :password WHERE email = :email` + q := `UPDATE users SET password = :password WHERE status = 'enabled' AND email = :email` db := dbUser{ Email: email, @@ -227,12 +234,27 @@ func (ur userRepository) UpdatePassword(ctx context.Context, email, password str return nil } +func (ur userRepository) ChangeStatus(ctx context.Context, id, status string) error { + q := fmt.Sprintf(`UPDATE users SET status = '%s' WHERE id = :id`, status) + + dbu := dbUser{ + ID: id, + } + + if _, err := ur.db.NamedExecContext(ctx, q, dbu); err != nil { + return errors.Wrap(errors.ErrUpdateEntity, err) + } + + return nil +} + type dbUser struct { ID string `db:"id"` Email string `db:"email"` Password string `db:"password"` Metadata []byte `db:"metadata"` Groups []auth.Group `db:"groups"` + Status string `db:"status"` } func toDBUser(u users.User) (dbUser, error) { @@ -250,6 +272,7 @@ func toDBUser(u users.User) (dbUser, error) { Email: u.Email, Password: u.Password, Metadata: data, + Status: u.Status, }, nil } @@ -281,6 +304,7 @@ func toUser(dbu dbUser) (users.User, error) { Email: dbu.Email, Password: dbu.Password, Metadata: metadata, + Status: dbu.Status, }, nil } diff --git a/users/postgres/users_test.go b/users/postgres/users_test.go index c2b1cd58..5afca583 100644 --- a/users/postgres/users_test.go +++ b/users/postgres/users_test.go @@ -35,6 +35,7 @@ func TestUserSave(t *testing.T) { ID: uid, Email: email, Password: "pass", + Status: users.EnabledStatusKey, }, err: nil, }, @@ -44,9 +45,20 @@ func TestUserSave(t *testing.T) { ID: uid, Email: email, Password: "pass", + Status: users.EnabledStatusKey, }, err: errors.ErrConflict, }, + { + desc: "invalid user status", + user: users.User{ + ID: uid, + Email: email, + Password: "pass", + Status: "invalid", + }, + err: errors.ErrMalformedEntity, + }, } dbMiddleware := postgres.NewDatabase(db) @@ -71,6 +83,7 @@ func TestSingleUserRetrieval(t *testing.T) { ID: uid, Email: email, Password: "pass", + Status: users.EnabledStatusKey, } _, err = repo.Save(context.Background(), user) @@ -113,6 +126,7 @@ func TestRetrieveAll(t *testing.T) { ID: uid, Email: email, Password: "pass", + Status: users.EnabledStatusKey, } if i < metaNum { user.Metadata = meta @@ -218,7 +232,7 @@ func TestRetrieveAll(t *testing.T) { }, } for desc, tc := range cases { - page, err := userRepo.RetrieveAll(context.Background(), tc.offset, tc.limit, tc.ids, tc.email, tc.metadata) + page, err := userRepo.RetrieveAll(context.Background(), users.EnabledStatusKey, tc.offset, tc.limit, tc.ids, tc.email, tc.metadata) size := uint64(len(page.Users)) assert.Equal(t, tc.size, size, fmt.Sprintf("%s: expected size %d got %d\n", desc, tc.size, size)) assert.Nil(t, err, fmt.Sprintf("%s: expected no error got %d\n", desc, err)) diff --git a/users/service.go b/users/service.go index 115f1ad9..335d9075 100644 --- a/users/service.go +++ b/users/service.go @@ -9,6 +9,7 @@ import ( "github.com/mainflux/mainflux" "github.com/mainflux/mainflux/auth" + "github.com/mainflux/mainflux/internal/apiutil" "github.com/mainflux/mainflux/pkg/errors" ) @@ -16,6 +17,9 @@ const ( memberRelationKey = "member" authoritiesObjKey = "authorities" usersObjKey = "users" + EnabledStatusKey = "enabled" + DisabledStatusKey = "disabled" + AllStatusKey = "all" ) var ( @@ -31,6 +35,12 @@ var ( // ErrPasswordFormat indicates weak password. ErrPasswordFormat = errors.New("password does not meet the requirements") + + // ErrAlreadyEnabledUser indicates the user is already enabled. + ErrAlreadyEnabledUser = errors.New("the user is already enabled") + + // ErrAlreadyDisabledUser indicates the user is already disabled. + ErrAlreadyDisabledUser = errors.New("the user is already disabled") ) // Service specifies an API that must be fullfiled by the domain service @@ -53,7 +63,7 @@ type Service interface { ViewProfile(ctx context.Context, token string) (User, error) // ListUsers retrieves users list for a valid admin token. - ListUsers(ctx context.Context, token string, offset, limit uint64, email string, meta Metadata) (UserPage, error) + ListUsers(ctx context.Context, token string, pm PageMetadata) (UserPage, error) // UpdateUser updates the user metadata. UpdateUser(ctx context.Context, token string, user User) error @@ -73,15 +83,23 @@ type Service interface { SendPasswordReset(ctx context.Context, host, email, token string) error // ListMembers retrieves everything that is assigned to a group identified by groupID. - ListMembers(ctx context.Context, token, groupID string, offset, limit uint64, meta Metadata) (UserPage, error) + ListMembers(ctx context.Context, token, groupID string, pm PageMetadata) (UserPage, error) + + // EnableUser logically enableds the user identified with the provided ID + EnableUser(ctx context.Context, token, id string) error + + // DisableUser logically disables the user identified with the provided ID + DisableUser(ctx context.Context, token, id string) error } // PageMetadata contains page metadata that helps navigation. type PageMetadata struct { - Total uint64 - Offset uint64 - Limit uint64 - Email string + Total uint64 + Offset uint64 + Limit uint64 + Email string + Status string + Metadata Metadata } // GroupPage contains a page of groups. @@ -146,6 +164,16 @@ func (svc usersService) Register(ctx context.Context, token string, user User) ( return "", errors.Wrap(errors.ErrMalformedEntity, err) } user.Password = hash + if user.Status == "" { + user.Status = EnabledStatusKey + } + + if user.Status != AllStatusKey && + user.Status != EnabledStatusKey && + user.Status != DisabledStatusKey { + return "", apiutil.ErrInvalidStatus + } + uid, err = svc.users.Save(ctx, user) if err != nil { return "", err @@ -181,14 +209,13 @@ func (svc usersService) Login(ctx context.Context, user User) (string, error) { } func (svc usersService) ViewUser(ctx context.Context, token, id string) (User, error) { - _, err := svc.identify(ctx, token) - if err != nil { + if _, err := svc.identify(ctx, token); err != nil { return User{}, err } dbUser, err := svc.users.RetrieveByID(ctx, id) if err != nil { - return User{}, errors.Wrap(errors.ErrAuthentication, err) + return User{}, errors.Wrap(errors.ErrNotFound, err) } return User{ @@ -196,6 +223,7 @@ func (svc usersService) ViewUser(ctx context.Context, token, id string) (User, e Email: dbUser.Email, Password: "", Metadata: dbUser.Metadata, + Status: dbUser.Status, }, nil } @@ -217,7 +245,7 @@ func (svc usersService) ViewProfile(ctx context.Context, token string) (User, er }, nil } -func (svc usersService) ListUsers(ctx context.Context, token string, offset, limit uint64, email string, m Metadata) (UserPage, error) { +func (svc usersService) ListUsers(ctx context.Context, token string, pm PageMetadata) (UserPage, error) { id, err := svc.identify(ctx, token) if err != nil { return UserPage{}, err @@ -226,7 +254,7 @@ func (svc usersService) ListUsers(ctx context.Context, token string, offset, lim if err := svc.authorize(ctx, id.id, "authorities", "member"); err != nil { return UserPage{}, errors.Wrap(errors.ErrAuthentication, err) } - return svc.users.RetrieveAll(ctx, offset, limit, nil, email, m) + return svc.users.RetrieveAll(ctx, pm.Status, pm.Offset, pm.Limit, nil, pm.Email, pm.Metadata) } func (svc usersService) UpdateUser(ctx context.Context, token string, u User) error { @@ -307,12 +335,12 @@ func (svc usersService) SendPasswordReset(_ context.Context, host, email, token return svc.email.SendPasswordReset(to, host, token) } -func (svc usersService) ListMembers(ctx context.Context, token, groupID string, offset, limit uint64, m Metadata) (UserPage, error) { +func (svc usersService) ListMembers(ctx context.Context, token, groupID string, pm PageMetadata) (UserPage, error) { if _, err := svc.identify(ctx, token); err != nil { return UserPage{}, err } - userIDs, err := svc.members(ctx, token, groupID, offset, limit) + userIDs, err := svc.members(ctx, token, groupID, pm.Offset, pm.Limit) if err != nil { return UserPage{}, err } @@ -322,13 +350,46 @@ func (svc usersService) ListMembers(ctx context.Context, token, groupID string, Users: []User{}, PageMetadata: PageMetadata{ Total: 0, - Offset: offset, - Limit: limit, + Offset: pm.Offset, + Limit: pm.Limit, }, }, nil } - return svc.users.RetrieveAll(ctx, offset, limit, userIDs, "", m) + return svc.users.RetrieveAll(ctx, pm.Status, pm.Offset, pm.Limit, userIDs, pm.Email, pm.Metadata) +} + +func (svc usersService) EnableUser(ctx context.Context, token, id string) error { + if err := svc.changeStatus(ctx, token, id, EnabledStatusKey); err != nil { + return err + } + return nil +} + +func (svc usersService) DisableUser(ctx context.Context, token, id string) error { + if err := svc.changeStatus(ctx, token, id, DisabledStatusKey); err != nil { + return err + } + return nil +} + +func (svc usersService) changeStatus(ctx context.Context, token, id, status string) error { + if _, err := svc.identify(ctx, token); err != nil { + return err + } + + dbUser, err := svc.users.RetrieveByID(ctx, id) + if err != nil { + return errors.Wrap(errors.ErrNotFound, err) + } + if dbUser.Status == status { + if status == DisabledStatusKey { + return ErrAlreadyDisabledUser + } + return ErrAlreadyEnabledUser + } + + return svc.users.ChangeStatus(ctx, id, status) } // Auth helpers diff --git a/users/service_test.go b/users/service_test.go index a14167a8..ae85d5d1 100644 --- a/users/service_test.go +++ b/users/service_test.go @@ -168,7 +168,7 @@ func TestViewUser(t *testing.T) { user: users.User{}, token: token, userID: "", - err: errors.ErrAuthentication, + err: errors.ErrNotFound, }, } @@ -260,7 +260,13 @@ func TestListUsers(t *testing.T) { } for desc, tc := range cases { - page, err := svc.ListUsers(context.Background(), tc.token, tc.offset, tc.limit, tc.email, nil) + pm := users.PageMetadata{ + Offset: tc.offset, + Limit: tc.limit, + Email: tc.email, + Status: "all", + } + page, err := svc.ListUsers(context.Background(), tc.token, pm) size := uint64(len(page.Users)) assert.Equal(t, tc.size, size, fmt.Sprintf("%s: expected size %d got %d\n", desc, tc.size, size)) assert.True(t, errors.Contains(err, tc.err), fmt.Sprintf("%s: expected %s got %s\n", desc, tc.err, err)) @@ -390,3 +396,108 @@ func TestSendPasswordReset(t *testing.T) { } } + +func TestDisableUser(t *testing.T) { + enabledUser1 := users.User{Email: "user1@example.com", Password: "password"} + enabledUser2 := users.User{Email: "user2@example.com", Password: "password", Status: "enabled"} + disabledUser1 := users.User{Email: "user3@example.com", Password: "password", Status: "disabled"} + + svc := newService() + + id, err := svc.Register(context.Background(), user.Email, user) + require.Nil(t, err, fmt.Sprintf("unexpected error: %s", err)) + user.ID = id + user.Status = "enabled" + token, err := svc.Login(context.Background(), user) + require.Nil(t, err, fmt.Sprintf("unexpected error: %s", err)) + + id, err = svc.Register(context.Background(), token, enabledUser1) + require.Nil(t, err, fmt.Sprintf("register enabledUser1 error: %s", err)) + enabledUser1.ID = id + enabledUser1.Status = "enabled" + + id, err = svc.Register(context.Background(), token, enabledUser2) + require.Nil(t, err, fmt.Sprintf("register enabledUser2 error: %s", err)) + enabledUser2.ID = id + enabledUser2.Status = "disabled" + + id, err = svc.Register(context.Background(), token, disabledUser1) + require.Nil(t, err, fmt.Sprintf("register disabledUser1 error: %s", err)) + disabledUser1.ID = id + disabledUser1.Status = "disabled" + + cases := []struct { + desc string + id string + token string + err error + }{ + { + desc: "disable user with wrong credentials", + id: enabledUser2.ID, + token: "", + err: errors.ErrAuthentication, + }, + { + desc: "disable existing user", + id: enabledUser2.ID, + token: token, + err: nil, + }, + { + desc: "disable disabled user", + id: enabledUser2.ID, + token: token, + err: users.ErrAlreadyDisabledUser, + }, + { + desc: "disable non-existing user", + id: "", + token: token, + err: errors.ErrNotFound, + }, + } + + for _, tc := range cases { + err := svc.DisableUser(context.Background(), tc.token, tc.id) + assert.True(t, errors.Contains(err, tc.err), fmt.Sprintf("%s: expected %s got %s\n", tc.desc, tc.err, err)) + } + + _, err = svc.Login(context.Background(), enabledUser2) + assert.True(t, errors.Contains(err, errors.ErrNotFound), fmt.Sprintf("Login disabled user: expected %s got %s\n", errors.ErrNotFound, err)) + + cases2 := map[string]struct { + status string + size uint64 + response []users.User + }{ + "list enabled users": { + status: "enabled", + size: 2, + response: []users.User{enabledUser1, user}, + }, + "list disabled users": { + status: "disabled", + size: 2, + response: []users.User{enabledUser2, disabledUser1}, + }, + "list enabled and disabled users": { + status: "all", + size: 4, + response: []users.User{enabledUser1, enabledUser2, disabledUser1, user}, + }, + } + + for desc, tc := range cases2 { + pm := users.PageMetadata{ + Offset: 0, + Limit: 100, + Status: tc.status, + } + page, err := svc.ListUsers(context.Background(), token, pm) + require.Nil(t, err, fmt.Sprintf("unexpected error: %s", err)) + size := uint64(len(page.Users)) + assert.Equal(t, tc.size, size, fmt.Sprintf("%s: expected size %d got %d\n", desc, tc.size, size)) + assert.ElementsMatch(t, tc.response, page.Users, fmt.Sprintf("%s: expected %s got %s\n", desc, tc.response, page.Users)) + } +} diff --git a/users/tracing/users.go b/users/tracing/users.go index e8282ca9..a4cedcdd 100644 --- a/users/tracing/users.go +++ b/users/tracing/users.go @@ -75,12 +75,20 @@ func (urm userRepositoryMiddleware) UpdatePassword(ctx context.Context, email, p return urm.repo.UpdatePassword(ctx, email, password) } -func (urm userRepositoryMiddleware) RetrieveAll(ctx context.Context, offset, limit uint64, ids []string, email string, um users.Metadata) (users.UserPage, error) { +func (urm userRepositoryMiddleware) RetrieveAll(ctx context.Context, status string, offset, limit uint64, ids []string, email string, um users.Metadata) (users.UserPage, error) { span := createSpan(ctx, urm.tracer, members) defer span.Finish() ctx = opentracing.ContextWithSpan(ctx, span) - return urm.repo.RetrieveAll(ctx, offset, limit, ids, email, um) + return urm.repo.RetrieveAll(ctx, status, offset, limit, ids, email, um) +} + +func (urm userRepositoryMiddleware) ChangeStatus(ctx context.Context, id, status string) error { + span := createSpan(ctx, urm.tracer, members) + defer span.Finish() + ctx = opentracing.ContextWithSpan(ctx, span) + + return urm.repo.ChangeStatus(ctx, id, status) } func createSpan(ctx context.Context, tracer opentracing.Tracer, opName string) opentracing.Span { diff --git a/users/users.go b/users/users.go index fe7797e6..3264b6ff 100644 --- a/users/users.go +++ b/users/users.go @@ -39,6 +39,7 @@ type User struct { Email string Password string Metadata Metadata + Status string } // Validate returns an error if user representation is invalid. @@ -65,10 +66,13 @@ type UserRepository interface { RetrieveByID(ctx context.Context, id string) (User, error) // RetrieveAll retrieves all users for given array of userIDs. - RetrieveAll(ctx context.Context, offset, limit uint64, userIDs []string, email string, m Metadata) (UserPage, error) + RetrieveAll(ctx context.Context, status string, offset, limit uint64, userIDs []string, email string, m Metadata) (UserPage, error) // UpdatePassword updates password for user with given email UpdatePassword(ctx context.Context, email, password string) error + + // ChangeStatus changes users status to enabled or disabled + ChangeStatus(ctx context.Context, id, status string) error } func isEmail(email string) bool {