diff --git a/authn/api/grpc/endpoint_test.go b/authn/api/grpc/endpoint_test.go index 1bcee2d0..39f87ef4 100644 --- a/authn/api/grpc/endpoint_test.go +++ b/authn/api/grpc/endpoint_test.go @@ -58,42 +58,50 @@ func TestIssue(t *testing.T) { id string kind uint32 err error + code codes.Code }{ { desc: "issue for user with valid token", id: email, kind: authn.UserKey, err: nil, + code: codes.OK, }, { desc: "issue recovery key", id: email, kind: authn.RecoveryKey, err: nil, + code: codes.OK, }, { desc: "issue API key", id: userKey.Secret, kind: authn.APIKey, err: nil, + code: codes.OK, }, { desc: "issue for invalid key type", id: email, kind: 32, err: status.Error(codes.InvalidArgument, "received invalid token request"), + code: codes.InvalidArgument, }, { desc: "issue for user that exist", id: "", kind: authn.APIKey, err: status.Error(codes.Unauthenticated, "unauthorized access"), + code: codes.Unauthenticated, }, } for _, tc := range cases { _, err := client.Issue(context.Background(), &mainflux.IssueReq{Issuer: tc.id, Type: tc.kind}) - assert.Equal(t, tc.err, err, fmt.Sprintf("%s: expected %s got %s", tc.desc, tc.err, err)) + e, ok := status.FromError(err) + assert.True(t, ok, "gRPC status can't be extracted from the error") + assert.Equal(t, tc.code, e.Code(), fmt.Sprintf("%s: expected %s got %s", tc.desc, tc.code, e.Code())) } } @@ -116,36 +124,43 @@ func TestIdentify(t *testing.T) { token string id string err error + code codes.Code }{ { desc: "identify user with recovery token", token: recoveryKey.Secret, id: email, err: nil, + code: codes.OK, }, { desc: "identify user with API token", token: apiKey.Secret, id: email, err: nil, + code: codes.OK, }, { desc: "identify user with invalid user token", token: "invalid", id: "", err: status.Error(codes.Unauthenticated, "unauthorized access"), + code: codes.Unauthenticated, }, { desc: "identify user that doesn't exist", token: "", id: "", err: status.Error(codes.InvalidArgument, "received invalid token request"), + code: codes.InvalidArgument, }, } for _, tc := range cases { id, err := client.Identify(context.Background(), &mainflux.Token{Value: tc.token}) assert.Equal(t, tc.id, id.GetValue(), fmt.Sprintf("%s: expected %s got %s", tc.desc, tc.id, id.GetValue())) - assert.Equal(t, tc.err, err, fmt.Sprintf("%s: expected %s got %s", tc.desc, tc.err, err)) + e, ok := status.FromError(err) + assert.True(t, ok, "gRPC status can't be extracted from the error") + assert.Equal(t, tc.code, e.Code(), fmt.Sprintf("%s: expected %s got %s", tc.desc, tc.code, e.Code())) } } diff --git a/authn/api/grpc/server.go b/authn/api/grpc/server.go index 54ee5374..efa4e6fe 100644 --- a/authn/api/grpc/server.go +++ b/authn/api/grpc/server.go @@ -8,6 +8,7 @@ import ( kitgrpc "github.com/go-kit/kit/transport/grpc" mainflux "github.com/mainflux/mainflux" "github.com/mainflux/mainflux/authn" + "github.com/mainflux/mainflux/errors" opentracing "github.com/opentracing/opentracing-go" "golang.org/x/net/context" "google.golang.org/grpc/codes" @@ -74,12 +75,14 @@ func encodeIdentifyResponse(_ context.Context, grpcRes interface{}) (interface{} } func encodeError(err error) error { - switch err { - case nil: + switch { + case errors.Contains(err, nil): return nil - case authn.ErrMalformedEntity: + case errors.Contains(err, authn.ErrMalformedEntity): return status.Error(codes.InvalidArgument, "received invalid token request") - case authn.ErrUnauthorizedAccess, authn.ErrKeyExpired: + case errors.Contains(err, authn.ErrUnauthorizedAccess): + return status.Error(codes.Unauthenticated, err.Error()) + case errors.Contains(err, authn.ErrKeyExpired): return status.Error(codes.Unauthenticated, err.Error()) default: return status.Error(codes.Internal, "internal server error") diff --git a/authn/api/http/responses.go b/authn/api/http/responses.go index 24709439..76ff4007 100644 --- a/authn/api/http/responses.go +++ b/authn/api/http/responses.go @@ -48,3 +48,7 @@ func (res revokeKeyRes) Headers() map[string]string { func (res revokeKeyRes) Empty() bool { return true } + +type errorRes struct { + Err string `json:"error"` +} diff --git a/authn/api/http/transport.go b/authn/api/http/transport.go index dee6f944..5141a403 100644 --- a/authn/api/http/transport.go +++ b/authn/api/http/transport.go @@ -6,7 +6,6 @@ package http import ( "context" "encoding/json" - "errors" "io" "net/http" "strings" @@ -16,6 +15,7 @@ import ( "github.com/go-zoo/bone" "github.com/mainflux/mainflux" "github.com/mainflux/mainflux/authn" + "github.com/mainflux/mainflux/errors" "github.com/opentracing/opentracing-go" "github.com/prometheus/client_golang/prometheus/promhttp" ) @@ -67,7 +67,7 @@ func decodeIssue(_ context.Context, r *http.Request) (interface{}, error) { issuer: r.Header.Get("Authorization"), } if err := json.NewDecoder(r.Body).Decode(&req); err != nil { - return nil, err + return nil, errors.Wrap(authn.ErrMalformedEntity, err) } return req, nil @@ -100,28 +100,28 @@ func encodeResponse(_ context.Context, w http.ResponseWriter, response interface } func encodeError(_ context.Context, err error, w http.ResponseWriter) { - w.Header().Set("Content-Type", contentType) - - switch err { - case authn.ErrMalformedEntity: + switch { + case errors.Contains(err, authn.ErrMalformedEntity): w.WriteHeader(http.StatusBadRequest) - case authn.ErrUnauthorizedAccess: + case errors.Contains(err, authn.ErrUnauthorizedAccess): w.WriteHeader(http.StatusForbidden) - case authn.ErrNotFound: + case errors.Contains(err, authn.ErrNotFound): w.WriteHeader(http.StatusNotFound) - case authn.ErrConflict: + case errors.Contains(err, authn.ErrConflict): w.WriteHeader(http.StatusConflict) - case io.EOF, io.ErrUnexpectedEOF: + case errors.Contains(err, io.EOF): w.WriteHeader(http.StatusBadRequest) - case errUnsupportedContentType: + case errors.Contains(err, io.ErrUnexpectedEOF): + w.WriteHeader(http.StatusBadRequest) + case errors.Contains(err, errUnsupportedContentType): w.WriteHeader(http.StatusUnsupportedMediaType) default: - switch err.(type) { - case *json.SyntaxError: - w.WriteHeader(http.StatusBadRequest) - case *json.UnmarshalTypeError: - w.WriteHeader(http.StatusBadRequest) - default: + w.WriteHeader(http.StatusInternalServerError) + } + errorVal, ok := err.(errors.Error) + if ok { + if err := json.NewEncoder(w).Encode(errorRes{Err: errorVal.Msg()}); err != nil { + w.Header().Set("Content-Type", contentType) w.WriteHeader(http.StatusInternalServerError) } } diff --git a/authn/jwt/token_test.go b/authn/jwt/token_test.go index 3e07fd87..57565f05 100644 --- a/authn/jwt/token_test.go +++ b/authn/jwt/token_test.go @@ -10,6 +10,7 @@ import ( "github.com/mainflux/mainflux/authn" "github.com/mainflux/mainflux/authn/jwt" + "github.com/mainflux/mainflux/errors" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) @@ -48,7 +49,7 @@ func TestIssue(t *testing.T) { for _, tc := range cases { _, err := tokenizer.Issue(tc.key) - assert.Equal(t, tc.err, err, fmt.Sprintf("%s expected %s, got %s", tc.desc, tc.err, err)) + assert.True(t, errors.Contains(err, tc.err), fmt.Sprintf("%s expected %s, got %s", tc.desc, tc.err, err)) } } @@ -104,7 +105,7 @@ func TestParse(t *testing.T) { for _, tc := range cases { key, err := tokenizer.Parse(tc.token) - assert.Equal(t, tc.err, err, fmt.Sprintf("%s expected %s, got %s", tc.desc, tc.err, err)) + assert.True(t, errors.Contains(err, tc.err), fmt.Sprintf("%s expected %s, got %s", tc.desc, tc.err, err)) assert.Equal(t, tc.key, key, fmt.Sprintf("%s expected %v, got %v", tc.desc, tc.key, key)) } } diff --git a/authn/jwt/tokenizer.go b/authn/jwt/tokenizer.go index 15d92692..41bef4bf 100644 --- a/authn/jwt/tokenizer.go +++ b/authn/jwt/tokenizer.go @@ -8,6 +8,7 @@ import ( "github.com/dgrijalva/jwt-go" "github.com/mainflux/mainflux/authn" + "github.com/mainflux/mainflux/errors" ) type claims struct { @@ -68,9 +69,9 @@ func (svc tokenizer) Parse(token string) (authn.Key, error) { if c.Type != nil && *c.Type == authn.APIKey { return c.toKey(), nil } - return authn.Key{}, authn.ErrKeyExpired + return authn.Key{}, errors.Wrap(authn.ErrKeyExpired, err) } - return authn.Key{}, authn.ErrUnauthorizedAccess + return authn.Key{}, errors.Wrap(authn.ErrUnauthorizedAccess, err) } return c.toKey(), nil diff --git a/authn/postgres/key.go b/authn/postgres/key.go index 6c121745..988f04d3 100644 --- a/authn/postgres/key.go +++ b/authn/postgres/key.go @@ -7,8 +7,14 @@ import ( "github.com/lib/pq" "github.com/mainflux/mainflux/authn" + "github.com/mainflux/mainflux/errors" ) +var ( + errSave = errors.New("failed to save key in database") + errRetrieve = errors.New("failed to retrieve key from database") + errDelete = errors.New("failed to delete key from database") +) var _ authn.KeyRepository = (*repo)(nil) const ( @@ -37,11 +43,11 @@ func (kr repo) Save(ctx context.Context, key authn.Key) (string, error) { pqErr, ok := err.(*pq.Error) if ok { if pqErr.Code.Name() == errDuplicate { - return "", authn.ErrConflict + return "", errors.Wrap(authn.ErrConflict, pqErr) } } - return "", err + return "", errors.Wrap(errSave, err) } return dbKey.ID, nil @@ -53,10 +59,10 @@ func (kr repo) Retrieve(ctx context.Context, issuer, id string) (authn.Key, erro if err := kr.db.QueryRowxContext(ctx, q, issuer, id).StructScan(&key); err != nil { pqErr, ok := err.(*pq.Error) if err == sql.ErrNoRows || ok && errInvalid == pqErr.Code.Name() { - return authn.Key{}, authn.ErrNotFound + return authn.Key{}, errors.Wrap(authn.ErrNotFound, err) } - return authn.Key{}, err + return authn.Key{}, errors.Wrap(errRetrieve, err) } return toKey(key), nil @@ -69,7 +75,7 @@ func (kr repo) Remove(ctx context.Context, issuer, id string) error { Issuer: issuer, } if _, err := kr.db.NamedExecContext(ctx, q, key); err != nil { - return err + return errors.Wrap(errDelete, err) } return nil diff --git a/authn/postgres/key_test.go b/authn/postgres/key_test.go index 8a405063..e0d8d0b7 100644 --- a/authn/postgres/key_test.go +++ b/authn/postgres/key_test.go @@ -12,6 +12,7 @@ import ( "github.com/mainflux/mainflux/authn" "github.com/mainflux/mainflux/authn/postgres" "github.com/mainflux/mainflux/authn/uuid" + "github.com/mainflux/mainflux/errors" "github.com/opentracing/opentracing-go" "github.com/stretchr/testify/assert" ) @@ -53,7 +54,7 @@ func TestKeySave(t *testing.T) { for _, tc := range cases { _, err := repo.Save(context.Background(), tc.key) - assert.Equal(t, err, tc.err, fmt.Sprintf("%s: expected %s got %s\n", tc.desc, tc.err, err)) + assert.True(t, errors.Contains(err, tc.err), fmt.Sprintf("%s: expected %s got %s\n", tc.desc, tc.err, err)) } } @@ -101,7 +102,7 @@ func TestKeyRetrieve(t *testing.T) { for _, tc := range cases { _, err := repo.Retrieve(context.Background(), tc.issuer, tc.id) - assert.Equal(t, err, tc.err, fmt.Sprintf("%s: expected %s got %s\n", tc.desc, tc.err, err)) + assert.True(t, errors.Contains(err, tc.err), fmt.Sprintf("%s: expected %s got %s\n", tc.desc, tc.err, err)) } } @@ -143,6 +144,6 @@ func TestKeyRemove(t *testing.T) { for _, tc := range cases { err := repo.Remove(context.Background(), tc.issuer, tc.id) - assert.Equal(t, err, tc.err, fmt.Sprintf("%s: expected %s got %s\n", tc.desc, tc.err, err)) + assert.True(t, errors.Contains(err, tc.err), fmt.Sprintf("%s: expected %s got %s\n", tc.desc, tc.err, err)) } } diff --git a/authn/service.go b/authn/service.go index 9121e60b..5aa3f567 100644 --- a/authn/service.go +++ b/authn/service.go @@ -5,8 +5,9 @@ package authn import ( "context" - "errors" "time" + + "github.com/mainflux/mainflux/errors" ) const ( @@ -28,6 +29,12 @@ var ( // ErrConflict indicates that entity already exists. ErrConflict = errors.New("entity already exists") + + errIssueUser = errors.New("failed to issue new user key") + errIssueTmp = errors.New("failed to issue new temporary key") + errRevoke = errors.New("failed to remove key") + errRetrieve = errors.New("failed to retrieve key data") + errIdentify = errors.New("failed to validate token") ) // Service specifies an API that must be fullfiled by the domain service @@ -84,16 +91,18 @@ func (svc service) Issue(ctx context.Context, issuer string, key Key) (Key, erro func (svc service) Revoke(ctx context.Context, issuer, id string) error { email, err := svc.login(issuer) if err != nil { - return err + return errors.Wrap(errRevoke, err) } - - return svc.keys.Remove(ctx, email, id) + if err := svc.keys.Remove(ctx, email, id); err != nil { + return errors.Wrap(errRevoke, err) + } + return nil } func (svc service) Retrieve(ctx context.Context, issuer, id string) (Key, error) { email, err := svc.login(issuer) if err != nil { - return Key{}, err + return Key{}, errors.Wrap(errRetrieve, err) } return svc.keys.Retrieve(ctx, email, id) @@ -102,7 +111,7 @@ func (svc service) Retrieve(ctx context.Context, issuer, id string) (Key, error) func (svc service) Identify(ctx context.Context, token string) (string, error) { c, err := svc.tokenizer.Parse(token) if err != nil { - return "", err + return "", errors.Wrap(errIdentify, err) } switch c.Type { @@ -133,7 +142,7 @@ func (svc service) tmpKey(issuer string, duration time.Duration, key Key) (Key, key.ExpiresAt = key.IssuedAt.Add(duration) val, err := svc.tokenizer.Issue(key) if err != nil { - return Key{}, err + return Key{}, errors.Wrap(errIssueTmp, err) } key.Secret = val @@ -143,24 +152,24 @@ func (svc service) tmpKey(issuer string, duration time.Duration, key Key) (Key, func (svc service) userKey(ctx context.Context, issuer string, key Key) (Key, error) { email, err := svc.login(issuer) if err != nil { - return Key{}, err + return Key{}, errors.Wrap(errIssueUser, err) } key.Issuer = email id, err := svc.idp.ID() if err != nil { - return Key{}, err + return Key{}, errors.Wrap(errIssueUser, err) } key.ID = id value, err := svc.tokenizer.Issue(key) if err != nil { - return Key{}, err + return Key{}, errors.Wrap(errIssueUser, err) } key.Secret = value if _, err := svc.keys.Save(ctx, key); err != nil { - return Key{}, err + return Key{}, errors.Wrap(errIssueUser, err) } return key, nil diff --git a/authn/service_test.go b/authn/service_test.go index b8ff31c6..31c03c40 100644 --- a/authn/service_test.go +++ b/authn/service_test.go @@ -12,6 +12,7 @@ import ( "github.com/mainflux/mainflux/authn" "github.com/mainflux/mainflux/authn/jwt" "github.com/mainflux/mainflux/authn/mocks" + "github.com/mainflux/mainflux/errors" "github.com/stretchr/testify/assert" ) @@ -102,7 +103,7 @@ func TestIssue(t *testing.T) { for _, tc := range cases { _, err := svc.Issue(context.Background(), tc.issuer, tc.key) - assert.Equal(t, err, tc.err, fmt.Sprintf("%s expected %s got %s\n", tc.desc, tc.err, err)) + assert.True(t, errors.Contains(err, tc.err), fmt.Sprintf("%s expected %s got %s\n", tc.desc, tc.err, err)) } } func TestRevoke(t *testing.T) { @@ -144,7 +145,7 @@ func TestRevoke(t *testing.T) { for _, tc := range cases { err := svc.Revoke(context.Background(), tc.issuer, tc.id) - assert.Equal(t, err, tc.err, fmt.Sprintf("%s expected %s got %s\n", tc.desc, tc.err, err)) + assert.True(t, errors.Contains(err, tc.err), fmt.Sprintf("%s expected %s got %s\n", tc.desc, tc.err, err)) } } func TestRetrieve(t *testing.T) { @@ -205,7 +206,7 @@ func TestRetrieve(t *testing.T) { for _, tc := range cases { _, err := svc.Retrieve(context.Background(), tc.issuer, tc.id) - assert.Equal(t, err, tc.err, fmt.Sprintf("%s expected %s got %s\n", tc.desc, tc.err, err)) + assert.True(t, errors.Contains(err, tc.err), fmt.Sprintf("%s expected %s got %s\n", tc.desc, tc.err, err)) } } func TestIdentify(t *testing.T) { @@ -272,7 +273,7 @@ func TestIdentify(t *testing.T) { for _, tc := range cases { id, err := svc.Identify(context.Background(), tc.key) - assert.Equal(t, tc.err, err, fmt.Sprintf("%s expected %s got %s\n", tc.desc, tc.err, err)) + assert.True(t, errors.Contains(err, tc.err), fmt.Sprintf("%s expected %s got %s\n", tc.desc, tc.err, err)) assert.Equal(t, tc.id, id, fmt.Sprintf("%s expected %s got %s\n", tc.desc, tc.id, id)) } } diff --git a/authn/uuid/idp.go b/authn/uuid/idp.go index 785f69cb..f3e38d28 100644 --- a/authn/uuid/idp.go +++ b/authn/uuid/idp.go @@ -7,8 +7,12 @@ package uuid import ( "github.com/gofrs/uuid" "github.com/mainflux/mainflux/authn" + "github.com/mainflux/mainflux/errors" ) +// errGeneratingID indicates error in generating UUID +var errGeneratingID = errors.New("failed to generate uuid") + var _ authn.IdentityProvider = (*uuidIdentityProvider)(nil) type uuidIdentityProvider struct{} @@ -21,7 +25,7 @@ func New() authn.IdentityProvider { func (idp *uuidIdentityProvider) ID() (string, error) { id, err := uuid.NewV4() if err != nil { - return "", err + return "", errors.Wrap(errGeneratingID, err) } return id.String(), nil