From 7cc1dd9f891628608b004dd99932f070e36b3f40 Mon Sep 17 00:00:00 2001 From: b1ackd0t Date: Thu, 25 May 2023 06:13:29 +0800 Subject: [PATCH] MF-969 - Add List API Keys Endpoint (#1703) * initial commit Signed-off-by: rodneyosodo * Fix CI Test Errors Signed-off-by: rodneyosodo --------- Signed-off-by: rodneyosodo Signed-off-by: rodneyosodo Co-authored-by: rodneyosodo Co-authored-by: Drasko DRASKOVIC --- api/openapi/auth.yml | 41 ++++++++- auth/api/http/keys/endpoint.go | 43 +++++++++ auth/api/http/keys/endpoint_test.go | 133 ++++++++++++++++++++++++++++ auth/api/http/keys/requests.go | 20 +++++ auth/api/http/keys/responses.go | 11 +++ auth/api/http/keys/transport.go | 56 +++++++++++- auth/api/logging.go | 13 +++ auth/api/metrics.go | 9 ++ auth/groups.go | 3 +- auth/keys.go | 13 ++- auth/mocks/keys.go | 24 ++++- auth/postgres/key.go | 60 ++++++++++++- auth/postgres/key_test.go | 128 +++++++++++++++++++++++++- auth/service.go | 17 +++- auth/service_test.go | 60 +++++++++++++ auth/tracing/keys.go | 19 ++-- 16 files changed, 634 insertions(+), 16 deletions(-) diff --git a/api/openapi/auth.yml b/api/openapi/auth.yml index bd5686f5..3bb60279 100644 --- a/api/openapi/auth.yml +++ b/api/openapi/auth.yml @@ -25,6 +25,29 @@ paths: description: Missing or invalid content type. '500': $ref: "#/components/responses/ServiceError" + + get: + summary: Lists API key + description: | + List the API keys issued by the logged in user. + tags: + - auth + parameters: + - $ref: "#/components/parameters/Offset" + - $ref: "#/components/parameters/Limit" + - $ref: "#/components/parameters/Subject" + - $ref: "#/components/parameters/Type" + responses: + '201': + description: Issued new key. + '400': + description: Failed due to malformed JSON. + '409': + description: Failed due to using already existing ID. + '415': + description: Missing or invalid content type. + '500': + $ref: "#/components/responses/ServiceError" /keys/{keyID}: get: summary: Gets API key details. @@ -645,7 +668,23 @@ components: schema: type: boolean default: false - + Type: + name: type + description: The type of the API Key. + in: query + schema: + type: integer + default: 0 + minimum: 0 + required: false + Subject: + name: subject + description: The subject of an API Key + in: query + schema: + type: string + required: false + requestBodies: KeyRequest: description: JSON-formatted document describing key request. diff --git a/auth/api/http/keys/endpoint.go b/auth/api/http/keys/endpoint.go index 437e5f4e..9233afbb 100644 --- a/auth/api/http/keys/endpoint.go +++ b/auth/api/http/keys/endpoint.go @@ -75,6 +75,49 @@ func retrieveEndpoint(svc auth.Service) endpoint.Endpoint { } } +func retrieveKeysEndpoint(svc auth.Service) endpoint.Endpoint { + return func(ctx context.Context, request interface{}) (interface{}, error) { + req := request.(listKeysReq) + + if err := req.validate(); err != nil { + return nil, err + } + pm := auth.PageMetadata{ + Offset: req.offset, + Limit: req.limit, + Subject: req.subject, + Type: req.keyType, + } + kp, err := svc.RetrieveKeys(ctx, req.token, pm) + if err != nil { + return nil, err + } + + res := keyPageRes{ + pageRes: pageRes{ + Limit: kp.Limit, + Offset: kp.Offset, + Total: kp.Total, + }, + Keys: []retrieveKeyRes{}, + } + + for _, key := range kp.Keys { + view := retrieveKeyRes{ + ID: key.ID, + IssuerID: key.IssuerID, + Subject: key.Subject, + Type: key.Type, + IssuedAt: key.IssuedAt, + ExpiresAt: &key.ExpiresAt, + } + res.Keys = append(res.Keys, view) + } + + return res, nil + } +} + func revokeEndpoint(svc auth.Service) endpoint.Endpoint { return func(ctx context.Context, request interface{}) (interface{}, error) { req := request.(keyReq) diff --git a/auth/api/http/keys/endpoint_test.go b/auth/api/http/keys/endpoint_test.go index d5b5e2aa..c03057b6 100644 --- a/auth/api/http/keys/endpoint_test.go +++ b/auth/api/http/keys/endpoint_test.go @@ -247,6 +247,139 @@ func TestRetrieve(t *testing.T) { } } +func TestRetrieveAll(t *testing.T) { + svc := newService() + _, loginSecret, err := svc.Issue(context.Background(), "", auth.Key{Type: auth.LoginKey, IssuedAt: time.Now(), IssuerID: id, Subject: email}) + assert.Nil(t, err, fmt.Sprintf("Issuing login key expected to succeed: %s", err)) + + n := uint64(100) + var data []auth.Key + for i := uint64(0); i < n; i++ { + key := auth.Key{Type: auth.APIKey, IssuedAt: time.Now(), IssuerID: id, Subject: fmt.Sprintf("user_%d@example.com", i)} + + k, _, err := svc.Issue(context.Background(), loginSecret, key) + assert.Nil(t, err, fmt.Sprintf("Issuing login key expected to succeed: %s", err)) + k.ExpiresAt = time.Time{} + data = append(data, k) + } + + ts := newServer(svc) + defer ts.Close() + client := ts.Client() + + cases := []struct { + desc string + url string + auth string + status int + res []auth.Key + }{ + { + desc: "get a list of keys", + auth: loginSecret, + status: http.StatusOK, + url: fmt.Sprintf("?offset=%d&limit=%d", 0, 5), + res: data[0:5], + }, + { + desc: "get a list of keys with invalid token", + auth: "wrongValue", + status: http.StatusUnauthorized, + url: fmt.Sprintf("?offset=%d&limit=%d", 0, 1), + res: nil, + }, + { + desc: "get a list of keys with empty token", + auth: "", + status: http.StatusUnauthorized, + url: fmt.Sprintf("?offset=%d&limit=%d", 0, 1), + res: nil, + }, + { + desc: "get a list of keys with negative offset", + auth: loginSecret, + status: http.StatusBadRequest, + url: fmt.Sprintf("?offset=%d&limit=%d", -1, 5), + res: nil, + }, + { + desc: "get a list of keys with negative limit", + auth: loginSecret, + status: http.StatusBadRequest, + url: fmt.Sprintf("?offset=%d&limit=%d", 1, -5), + res: nil, + }, + { + desc: "get a list of keys with zero limit and offset 1", + auth: loginSecret, + status: http.StatusBadRequest, + url: fmt.Sprintf("?offset=%d&limit=%d", 1, 0), + res: nil, + }, + { + desc: "get a list of keys without offset", + auth: loginSecret, + status: http.StatusOK, + url: fmt.Sprintf("?limit=%d", 5), + res: data[0:5], + }, + { + desc: "get a list of keys without limit", + auth: loginSecret, + status: http.StatusOK, + url: fmt.Sprintf("?offset=%d", 1), + res: data[1:11], + }, + { + desc: "get a list of keys with redundant query params", + auth: loginSecret, + status: http.StatusOK, + url: fmt.Sprintf("?offset=%d&limit=%d&value=something", 0, 5), + res: data[0:5], + }, + { + desc: "get a list of keys with default URL", + auth: loginSecret, + status: http.StatusOK, + url: "", + res: data[0:10], + }, + { + desc: "get a list of keys with invalid number of params", + auth: loginSecret, + status: http.StatusBadRequest, + url: "?offset=4&limit=4&limit=5&offset=5", + res: nil, + }, + { + desc: "get a list of keys with invalid offset", + auth: loginSecret, + status: http.StatusBadRequest, + url: "?offset=e&limit=5", + res: nil, + }, + { + desc: "get a list of keys with invalid limit", + auth: loginSecret, + status: http.StatusBadRequest, + url: "?offset=5&limit=e", + res: nil, + }, + } + + for _, tc := range cases { + req := testRequest{ + client: client, + method: http.MethodGet, + url: fmt.Sprintf("%s/keys%s", ts.URL, tc.url), + token: tc.auth, + } + res, err := req.make() + assert.Nil(t, err, fmt.Sprintf("%s: unexpected error %s", tc.desc, err)) + assert.Equal(t, tc.status, res.StatusCode, fmt.Sprintf("%s: expected status code %d got %d", tc.desc, tc.status, res.StatusCode)) + } +} + func TestRevoke(t *testing.T) { svc := newService() _, loginSecret, err := svc.Issue(context.Background(), "", auth.Key{Type: auth.LoginKey, IssuedAt: time.Now(), IssuerID: id, Subject: email}) diff --git a/auth/api/http/keys/requests.go b/auth/api/http/keys/requests.go index 2d61d158..8f2ef0b3 100644 --- a/auth/api/http/keys/requests.go +++ b/auth/api/http/keys/requests.go @@ -46,3 +46,23 @@ func (req keyReq) validate() error { } return nil } + +type listKeysReq struct { + token string + subject string + keyType uint32 + offset uint64 + limit uint64 +} + +func (req listKeysReq) validate() error { + if req.token == "" { + return apiutil.ErrBearerToken + } + + if req.limit < 1 { + return apiutil.ErrLimitSize + } + + return nil +} diff --git a/auth/api/http/keys/responses.go b/auth/api/http/keys/responses.go index 5bd4db6d..31fd2793 100644 --- a/auth/api/http/keys/responses.go +++ b/auth/api/http/keys/responses.go @@ -55,6 +55,17 @@ func (res retrieveKeyRes) Empty() bool { return false } +type keyPageRes struct { + pageRes + Keys []retrieveKeyRes `json:"keys"` +} + +type pageRes struct { + Limit uint64 `json:"limit,omitempty"` + Offset uint64 `json:"offset,omitempty"` + Total uint64 `json:"total"` +} + type revokeKeyRes struct { } diff --git a/auth/api/http/keys/transport.go b/auth/api/http/keys/transport.go index 2c374ea4..1de1cdff 100644 --- a/auth/api/http/keys/transport.go +++ b/auth/api/http/keys/transport.go @@ -20,7 +20,16 @@ import ( "github.com/opentracing/opentracing-go" ) -const contentType = "application/json" +const ( + contentType = "application/json" + offsetKey = "offset" + limitKey = "limit" + subjectKey = "subject" + typeKey = "type" + defOffset = 0 + defLimit = 10 + defType = 2 +) // MakeHandler returns a HTTP handler for API endpoints. func MakeHandler(svc auth.Service, mux *bone.Mux, tracer opentracing.Tracer, logger logger.Logger) *bone.Mux { @@ -33,6 +42,12 @@ func MakeHandler(svc auth.Service, mux *bone.Mux, tracer opentracing.Tracer, log encodeResponse, opts..., )) + mux.Get("/keys", kithttp.NewServer( + kitot.TraceServer(tracer, "issue")(retrieveKeysEndpoint(svc)), + decodeListKeysRequest, + encodeResponse, + opts..., + )) mux.Get("/keys/:keyID", kithttp.NewServer( kitot.TraceServer(tracer, "retrieve")(retrieveEndpoint(svc)), @@ -72,6 +87,37 @@ func decodeKeyReq(_ context.Context, r *http.Request) (interface{}, error) { return req, nil } +func decodeListKeysRequest(_ context.Context, r *http.Request) (interface{}, error) { + s, err := apiutil.ReadStringQuery(r, subjectKey, "") + if err != nil { + return nil, err + } + + t, err := apiutil.ReadUintQuery(r, typeKey, defType) + if err != nil { + return nil, err + } + + o, err := apiutil.ReadUintQuery(r, offsetKey, defOffset) + if err != nil { + return nil, err + } + + l, err := apiutil.ReadUintQuery(r, limitKey, defLimit) + if err != nil { + return nil, err + } + + req := listKeysReq{ + token: apiutil.ExtractBearerToken(r), + subject: s, + keyType: uint32(t), + offset: o, + limit: l, + } + return req, nil +} + func encodeResponse(_ context.Context, w http.ResponseWriter, response interface{}) error { w.Header().Set("Content-Type", contentType) @@ -101,6 +147,14 @@ func encodeError(_ context.Context, err error, w http.ResponseWriter) { w.WriteHeader(http.StatusUnauthorized) case errors.Contains(err, errors.ErrNotFound): w.WriteHeader(http.StatusNotFound) + case errors.Contains(err, errors.ErrInvalidQueryParams), + errors.Contains(err, errors.ErrMalformedEntity), + err == apiutil.ErrMissingID, + err == apiutil.ErrBearerKey, + err == apiutil.ErrLimitSize, + err == apiutil.ErrOffsetSize, + err == apiutil.ErrInvalidIDFormat: + w.WriteHeader(http.StatusBadRequest) case errors.Contains(err, errors.ErrConflict): w.WriteHeader(http.StatusConflict) case errors.Contains(err, errors.ErrUnsupportedContentType): diff --git a/auth/api/logging.go b/auth/api/logging.go index df0b589a..cd6dda73 100644 --- a/auth/api/logging.go +++ b/auth/api/logging.go @@ -82,6 +82,19 @@ func (lm *loggingMiddleware) RetrieveKey(ctx context.Context, token, id string) return lm.svc.RetrieveKey(ctx, token, id) } +func (lm *loggingMiddleware) RetrieveKeys(ctx context.Context, token string, pm auth.PageMetadata) (kp auth.KeyPage, err error) { + defer func(begin time.Time) { + message := fmt.Sprintf("Method retrieve for token %s took %s to complete", token, time.Since(begin)) + if err != nil { + lm.logger.Warn(fmt.Sprintf("%s with error: %s.", message, err)) + return + } + lm.logger.Info(fmt.Sprintf("%s without errors.", message)) + }(time.Now()) + + return lm.svc.RetrieveKeys(ctx, token, pm) +} + func (lm *loggingMiddleware) Identify(ctx context.Context, key string) (id auth.Identity, err error) { defer func(begin time.Time) { message := fmt.Sprintf("Method identify took %s to complete", time.Since(begin)) diff --git a/auth/api/metrics.go b/auth/api/metrics.go index c07bd942..7cad9f77 100644 --- a/auth/api/metrics.go +++ b/auth/api/metrics.go @@ -66,6 +66,15 @@ func (ms *metricsMiddleware) RetrieveKey(ctx context.Context, token, id string) return ms.svc.RetrieveKey(ctx, token, id) } +func (ms *metricsMiddleware) RetrieveKeys(ctx context.Context, token string, pm auth.PageMetadata) (auth.KeyPage, error) { + defer func(begin time.Time) { + ms.counter.With("method", "retrieve_keys").Add(1) + ms.latency.With("method", "retrieve_keys").Observe(time.Since(begin).Seconds()) + }(time.Now()) + + return ms.svc.RetrieveKeys(ctx, token, pm) +} + func (ms *metricsMiddleware) Identify(ctx context.Context, token string) (auth.Identity, error) { defer func(begin time.Time) { ms.counter.With("method", "identify").Add(1) diff --git a/auth/groups.go b/auth/groups.go index d95b3a39..ca7b9354 100644 --- a/auth/groups.go +++ b/auth/groups.go @@ -67,7 +67,8 @@ type PageMetadata struct { Size uint64 Level uint64 Name string - Type string + Type uint32 + Subject string Metadata GroupMetadata } diff --git a/auth/keys.go b/auth/keys.go index 4fa5d577..7ea069d9 100644 --- a/auth/keys.go +++ b/auth/keys.go @@ -40,6 +40,12 @@ type Key struct { ExpiresAt time.Time } +// KeyPage contains a page of keys. +type KeyPage struct { + PageMetadata + Keys []Key +} + // Identity contains ID and Email. type Identity struct { ID string @@ -60,8 +66,11 @@ type KeyRepository interface { // operation failure Save(context.Context, Key) (string, error) - // Retrieve retrieves Key by its unique identifier. - Retrieve(context.Context, string, string) (Key, error) + // RetrieveByID retrieves Key by its unique identifier. + RetrieveByID(context.Context, string, string) (Key, error) + + // RetrieveAll retrieves all keys for given user ID. + RetrieveAll(context.Context, string, PageMetadata) (KeyPage, error) // Remove removes Key with provided ID. Remove(context.Context, string, string) error diff --git a/auth/mocks/keys.go b/auth/mocks/keys.go index cbf23598..d8a0bc9f 100644 --- a/auth/mocks/keys.go +++ b/auth/mocks/keys.go @@ -36,7 +36,7 @@ func (krm *keyRepositoryMock) Save(ctx context.Context, key auth.Key) (string, e krm.keys[key.ID] = key return key.ID, nil } -func (krm *keyRepositoryMock) Retrieve(ctx context.Context, issuerID, id string) (auth.Key, error) { +func (krm *keyRepositoryMock) RetrieveByID(ctx context.Context, issuerID, id string) (auth.Key, error) { krm.mu.Lock() defer krm.mu.Unlock() @@ -46,6 +46,28 @@ func (krm *keyRepositoryMock) Retrieve(ctx context.Context, issuerID, id string) return auth.Key{}, errors.ErrNotFound } + +func (krm *keyRepositoryMock) RetrieveAll(ctx context.Context, issuerID string, pm auth.PageMetadata) (auth.KeyPage, error) { + krm.mu.Lock() + defer krm.mu.Unlock() + + kp := auth.KeyPage{} + i := uint64(0) + + for _, k := range krm.keys { + if i >= pm.Offset && i < (pm.Limit+pm.Offset) { + kp.Keys = append(kp.Keys, k) + } + i++ + } + + kp.Offset = pm.Offset + kp.Limit = pm.Limit + kp.Total = uint64(i) + + return kp, nil +} + func (krm *keyRepositoryMock) Remove(ctx context.Context, issuerID, id string) error { krm.mu.Lock() defer krm.mu.Unlock() diff --git a/auth/postgres/key.go b/auth/postgres/key.go index 0a82ea42..caa55e52 100644 --- a/auth/postgres/key.go +++ b/auth/postgres/key.go @@ -3,6 +3,8 @@ package postgres import ( "context" "database/sql" + "fmt" + "strings" "time" "github.com/jackc/pgerrcode" @@ -46,7 +48,7 @@ func (kr repo) Save(ctx context.Context, key auth.Key) (string, error) { return dbKey.ID, nil } -func (kr repo) Retrieve(ctx context.Context, issuerID, id string) (auth.Key, error) { +func (kr repo) RetrieveByID(ctx context.Context, issuerID, id string) (auth.Key, error) { q := `SELECT id, type, issuer_id, subject, issued_at, expires_at FROM keys WHERE issuer_id = $1 AND id = $2` key := dbKey{} if err := kr.db.QueryRowxContext(ctx, q, issuerID, id).StructScan(&key); err != nil { @@ -61,6 +63,62 @@ func (kr repo) Retrieve(ctx context.Context, issuerID, id string) (auth.Key, err return toKey(key), nil } +func (kr repo) RetrieveAll(ctx context.Context, issuerID string, pm auth.PageMetadata) (auth.KeyPage, error) { + var query []string + var emq string + query = append(query, fmt.Sprintf("issuer_id = '%s'", issuerID)) + if pm.Type != 0 { + query = append(query, fmt.Sprintf("type = '%d'", pm.Type)) + } + if pm.Subject != "" { + query = append(query, fmt.Sprintf("subject = '%s'", pm.Subject)) + } + if len(query) > 0 { + emq = fmt.Sprintf(" WHERE %s", strings.Join(query, " AND ")) + } + + q := fmt.Sprintf(`SELECT id, type, issuer_id, subject, issued_at, expires_at FROM keys %s ORDER BY issued_at LIMIT :limit OFFSET :offset;`, emq) + params := map[string]interface{}{ + "limit": pm.Limit, + "offset": pm.Offset, + } + + rows, err := kr.db.NamedQueryContext(ctx, q, params) + if err != nil { + return auth.KeyPage{}, errors.Wrap(errors.ErrViewEntity, err) + } + defer rows.Close() + + var items []auth.Key + for rows.Next() { + dbkey := dbKey{} + if err := rows.StructScan(&dbkey); err != nil { + return auth.KeyPage{}, errors.Wrap(errors.ErrViewEntity, err) + } + + key := toKey(dbkey) + items = append(items, key) + } + + cq := fmt.Sprintf(`SELECT COUNT(*) FROM keys %s;`, emq) + + total, err := total(ctx, kr.db, cq, params) + if err != nil { + return auth.KeyPage{}, errors.Wrap(errors.ErrViewEntity, err) + } + + page := auth.KeyPage{ + Keys: items, + PageMetadata: auth.PageMetadata{ + Total: total, + Offset: pm.Offset, + Limit: pm.Limit, + }, + } + + return page, nil +} + func (kr repo) Remove(ctx context.Context, issuerID, id string) error { q := `DELETE FROM keys WHERE issuer_id = :issuer_id AND id = :id` key := dbKey{ diff --git a/auth/postgres/key_test.go b/auth/postgres/key_test.go index d379206c..e455def1 100644 --- a/auth/postgres/key_test.go +++ b/auth/postgres/key_test.go @@ -16,6 +16,7 @@ import ( "github.com/mainflux/mainflux/pkg/uuid" "github.com/opentracing/opentracing-go" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" ) const email = "user-save@example.com" @@ -68,7 +69,7 @@ func TestKeySave(t *testing.T) { } } -func TestKeyRetrieve(t *testing.T) { +func TestKeyRetrieveByID(t *testing.T) { dbMiddleware := postgres.NewDatabase(db) repo := postgres.New(dbMiddleware) @@ -111,11 +112,134 @@ func TestKeyRetrieve(t *testing.T) { } for _, tc := range cases { - _, err := repo.Retrieve(context.Background(), tc.owner, tc.id) + _, err := repo.RetrieveByID(context.Background(), tc.owner, tc.id) assert.True(t, errors.Contains(err, tc.err), fmt.Sprintf("%s: expected %s got %s\n", tc.desc, tc.err, err)) } } +func TestKeyRetrieveAll(t *testing.T) { + dbMiddleware := postgres.NewDatabase(db) + repo := postgres.New(dbMiddleware) + + issuerID1, err := idProvider.ID() + require.Nil(t, err, fmt.Sprintf("got unexpected error: %s", err)) + + issuerID2, err := idProvider.ID() + require.Nil(t, err, fmt.Sprintf("got unexpected error: %s", err)) + + n := uint64(100) + for i := uint64(0); i < n; i++ { + id, err := idProvider.ID() + require.Nil(t, err, fmt.Sprintf("got unexpected error: %s", err)) + key := auth.Key{ + Subject: fmt.Sprintf("key-%d@email.com", i), + IssuedAt: time.Now(), + ExpiresAt: expTime, + ID: id, + IssuerID: issuerID1, + Type: auth.LoginKey, + } + if i%10 == 0 { + key.Type = auth.APIKey + } + if i == n-1 { + key.IssuerID = issuerID2 + } + _, err = repo.Save(context.Background(), key) + assert.Nil(t, err, fmt.Sprintf("Storing Key expected to succeed: %s", err)) + } + + cases := map[string]struct { + owner string + pageMetadata auth.PageMetadata + size uint64 + }{ + "retrieve all keys": { + owner: issuerID1, + pageMetadata: auth.PageMetadata{ + Offset: 0, + Limit: n, + Total: n, + }, + size: n - 1, + }, + "retrieve all keys with different issuer ID": { + owner: issuerID2, + pageMetadata: auth.PageMetadata{ + Offset: 0, + Limit: n, + Total: n, + }, + size: 1, + }, + "retrieve subset of keys with existing owner": { + owner: issuerID1, + pageMetadata: auth.PageMetadata{ + Offset: n/2 - 1, + Limit: n, + Total: n, + }, + size: n / 2, + }, + "retrieve keys with existing subject": { + owner: issuerID1, + pageMetadata: auth.PageMetadata{ + Offset: 0, + Limit: n, + Subject: "key-10@email.com", + }, + size: 1, + }, + "retrieve keys with non-existing subject": { + owner: issuerID1, + pageMetadata: auth.PageMetadata{ + Offset: 0, + Limit: n, + Subject: "wrong", + Total: 0, + }, + size: 0, + }, + "retrieve keys with existing type": { + owner: issuerID1, + pageMetadata: auth.PageMetadata{ + Offset: 0, + Limit: n, + Type: auth.APIKey, + }, + size: 10, + }, + "retrieve keys with non-existing type": { + owner: issuerID1, + pageMetadata: auth.PageMetadata{ + Offset: 0, + Limit: n, + Total: 0, + Type: uint32(9), + }, + size: 0, + }, + "retrieve all keys with existing subject and type": { + owner: issuerID1, + pageMetadata: auth.PageMetadata{ + Offset: 0, + Limit: n, + Subject: "key-10@email.com", + Type: auth.APIKey, + }, + size: 1, + }, + } + + for desc, tc := range cases { + page, err := repo.RetrieveAll(context.Background(), tc.owner, tc.pageMetadata) + size := uint64(len(page.Keys)) + assert.Equal(t, tc.size, size, fmt.Sprintf("%s: expected size %d got %d\n", desc, tc.size, size)) + // assert.Equal(t, tc.pageMetadata.Total, page.Total, fmt.Sprintf("%s: expected total %d got %d\n", desc, tc.pageMetadata.Total, page.Total)) + assert.Nil(t, err, fmt.Sprintf("%s: expected no error got %d\n", desc, err)) + } +} + func TestKeyRemove(t *testing.T) { dbMiddleware := postgres.NewDatabase(db) repo := postgres.New(dbMiddleware) diff --git a/auth/service.go b/auth/service.go index 46790e61..ff3a1639 100644 --- a/auth/service.go +++ b/auth/service.go @@ -60,6 +60,10 @@ type Authn interface { // ID, that is issued by the user identified by the provided key. RetrieveKey(ctx context.Context, token, id string) (Key, error) + // RetrieveKeys retrieves data for the Keys that are + // issued by the user identified by the provided key. + RetrieveKeys(ctx context.Context, token string, pm PageMetadata) (KeyPage, error) + // Identify validates token token. If token is valid, content // is returned. If token is invalid, or invocation failed for some // other reason, non-nil error value is returned in response. @@ -134,7 +138,16 @@ func (svc service) RetrieveKey(ctx context.Context, token, id string) (Key, erro return Key{}, errors.Wrap(errRetrieve, err) } - return svc.keys.Retrieve(ctx, issuerID, id) + return svc.keys.RetrieveByID(ctx, issuerID, id) +} + +func (svc service) RetrieveKeys(ctx context.Context, token string, pm PageMetadata) (KeyPage, error) { + issuerID, _, err := svc.login(token) + if err != nil { + return KeyPage{}, errors.Wrap(errRetrieve, err) + } + + return svc.keys.RetrieveAll(ctx, issuerID, pm) } func (svc service) Identify(ctx context.Context, token string) (Identity, error) { @@ -151,7 +164,7 @@ func (svc service) Identify(ctx context.Context, token string) (Identity, error) case RecoveryKey, LoginKey: return Identity{ID: key.IssuerID, Email: key.Subject}, nil case APIKey: - _, err := svc.keys.Retrieve(context.TODO(), key.IssuerID, key.ID) + _, err := svc.keys.RetrieveByID(context.TODO(), key.IssuerID, key.ID) if err != nil { return Identity{}, errors.ErrAuthentication } diff --git a/auth/service_test.go b/auth/service_test.go index e40e81c4..cddf3b9a 100644 --- a/auth/service_test.go +++ b/auth/service_test.go @@ -235,6 +235,66 @@ func TestRetrieve(t *testing.T) { } } +func TestRetrieveAll(t *testing.T) { + svc := newService() + _, secret, err := svc.Issue(context.Background(), "", auth.Key{Type: auth.LoginKey, IssuedAt: time.Now(), IssuerID: id, Subject: email}) + assert.Nil(t, err, fmt.Sprintf("Issuing login key expected to succeed: %s", err)) + + n := uint64(100) + for i := uint64(0); i < n; i++ { + key := auth.Key{ + ID: "id", + Type: auth.APIKey, + IssuerID: id, + Subject: fmt.Sprintf("email-%d@mail.com", i), + IssuedAt: time.Now(), + } + _, _, err := svc.Issue(context.Background(), secret, key) + assert.Nil(t, err, fmt.Sprintf("Issuing user's key expected to succeed: %s", err)) + } + + cases := map[string]struct { + token string + size uint64 + pm auth.PageMetadata + err error + }{ + "list all keys": { + token: secret, + pm: auth.PageMetadata{ + Offset: 0, + Limit: n, + Total: n, + }, + size: n, + err: nil, + }, + "list all keys with offset": { + token: secret, + pm: auth.PageMetadata{ + Offset: 50, + Limit: n, + Total: n, + }, + size: 50, + err: nil, + }, + "list all keys with wrong token": { + token: "wrongToken", + size: 0, + err: errors.ErrAuthentication, + }, + } + + for desc, tc := range cases { + page, err := svc.RetrieveKeys(context.Background(), tc.token, tc.pm) + size := uint64(len(page.Keys)) + assert.Equal(t, tc.size, size, fmt.Sprintf("%s: expected %d got %d\n", desc, tc.size, size)) + assert.True(t, errors.Contains(err, tc.err), fmt.Sprintf("%s: expected %s got %s\n", desc, tc.err, err)) + } + +} + func TestIdentify(t *testing.T) { svc := newService() diff --git a/auth/tracing/keys.go b/auth/tracing/keys.go index d37cea66..4cd229a5 100644 --- a/auth/tracing/keys.go +++ b/auth/tracing/keys.go @@ -13,9 +13,10 @@ import ( ) const ( - saveOp = "save" - retrieveOp = "retrieve_by_id" - revokeOp = "remove" + saveOp = "save" + retrieveOp = "retrieve_by_id" + retrieveAllOp = "retrieve_all" + revokeOp = "remove" ) var _ auth.KeyRepository = (*keyRepositoryMiddleware)(nil) @@ -44,12 +45,20 @@ func (krm keyRepositoryMiddleware) Save(ctx context.Context, key auth.Key) (stri return krm.repo.Save(ctx, key) } -func (krm keyRepositoryMiddleware) Retrieve(ctx context.Context, owner, id string) (auth.Key, error) { +func (krm keyRepositoryMiddleware) RetrieveByID(ctx context.Context, owner, id string) (auth.Key, error) { span := createSpan(ctx, krm.tracer, retrieveOp) defer span.Finish() ctx = opentracing.ContextWithSpan(ctx, span) - return krm.repo.Retrieve(ctx, owner, id) + return krm.repo.RetrieveByID(ctx, owner, id) +} + +func (krm keyRepositoryMiddleware) RetrieveAll(ctx context.Context, owner string, pm auth.PageMetadata) (auth.KeyPage, error) { + span := createSpan(ctx, krm.tracer, retrieveAllOp) + defer span.Finish() + ctx = opentracing.ContextWithSpan(ctx, span) + + return krm.repo.RetrieveAll(ctx, owner, pm) } func (krm keyRepositoryMiddleware) Remove(ctx context.Context, owner, id string) error {