mirror of
https://github.com/mainflux/mainflux.git
synced 2025-04-24 13:48:49 +08:00
MF-969 - Add List API Keys Endpoint (#1703)
* initial commit Signed-off-by: rodneyosodo <socials@rodneyosodo.com> * Fix CI Test Errors Signed-off-by: rodneyosodo <blackd0t@protonmail.com> --------- Signed-off-by: rodneyosodo <socials@rodneyosodo.com> Signed-off-by: rodneyosodo <blackd0t@protonmail.com> Co-authored-by: rodneyosodo <socials@rodneyosodo.com> Co-authored-by: Drasko DRASKOVIC <drasko.draskovic@gmail.com>
This commit is contained in:
parent
408eabaaa6
commit
7cc1dd9f89
@ -25,6 +25,29 @@ paths:
|
||||
description: Missing or invalid content type.
|
||||
'500':
|
||||
$ref: "#/components/responses/ServiceError"
|
||||
|
||||
get:
|
||||
summary: Lists API key
|
||||
description: |
|
||||
List the API keys issued by the logged in user.
|
||||
tags:
|
||||
- auth
|
||||
parameters:
|
||||
- $ref: "#/components/parameters/Offset"
|
||||
- $ref: "#/components/parameters/Limit"
|
||||
- $ref: "#/components/parameters/Subject"
|
||||
- $ref: "#/components/parameters/Type"
|
||||
responses:
|
||||
'201':
|
||||
description: Issued new key.
|
||||
'400':
|
||||
description: Failed due to malformed JSON.
|
||||
'409':
|
||||
description: Failed due to using already existing ID.
|
||||
'415':
|
||||
description: Missing or invalid content type.
|
||||
'500':
|
||||
$ref: "#/components/responses/ServiceError"
|
||||
/keys/{keyID}:
|
||||
get:
|
||||
summary: Gets API key details.
|
||||
@ -645,7 +668,23 @@ components:
|
||||
schema:
|
||||
type: boolean
|
||||
default: false
|
||||
|
||||
Type:
|
||||
name: type
|
||||
description: The type of the API Key.
|
||||
in: query
|
||||
schema:
|
||||
type: integer
|
||||
default: 0
|
||||
minimum: 0
|
||||
required: false
|
||||
Subject:
|
||||
name: subject
|
||||
description: The subject of an API Key
|
||||
in: query
|
||||
schema:
|
||||
type: string
|
||||
required: false
|
||||
|
||||
requestBodies:
|
||||
KeyRequest:
|
||||
description: JSON-formatted document describing key request.
|
||||
|
@ -75,6 +75,49 @@ func retrieveEndpoint(svc auth.Service) endpoint.Endpoint {
|
||||
}
|
||||
}
|
||||
|
||||
func retrieveKeysEndpoint(svc auth.Service) endpoint.Endpoint {
|
||||
return func(ctx context.Context, request interface{}) (interface{}, error) {
|
||||
req := request.(listKeysReq)
|
||||
|
||||
if err := req.validate(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
pm := auth.PageMetadata{
|
||||
Offset: req.offset,
|
||||
Limit: req.limit,
|
||||
Subject: req.subject,
|
||||
Type: req.keyType,
|
||||
}
|
||||
kp, err := svc.RetrieveKeys(ctx, req.token, pm)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
res := keyPageRes{
|
||||
pageRes: pageRes{
|
||||
Limit: kp.Limit,
|
||||
Offset: kp.Offset,
|
||||
Total: kp.Total,
|
||||
},
|
||||
Keys: []retrieveKeyRes{},
|
||||
}
|
||||
|
||||
for _, key := range kp.Keys {
|
||||
view := retrieveKeyRes{
|
||||
ID: key.ID,
|
||||
IssuerID: key.IssuerID,
|
||||
Subject: key.Subject,
|
||||
Type: key.Type,
|
||||
IssuedAt: key.IssuedAt,
|
||||
ExpiresAt: &key.ExpiresAt,
|
||||
}
|
||||
res.Keys = append(res.Keys, view)
|
||||
}
|
||||
|
||||
return res, nil
|
||||
}
|
||||
}
|
||||
|
||||
func revokeEndpoint(svc auth.Service) endpoint.Endpoint {
|
||||
return func(ctx context.Context, request interface{}) (interface{}, error) {
|
||||
req := request.(keyReq)
|
||||
|
@ -247,6 +247,139 @@ func TestRetrieve(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestRetrieveAll(t *testing.T) {
|
||||
svc := newService()
|
||||
_, loginSecret, err := svc.Issue(context.Background(), "", auth.Key{Type: auth.LoginKey, IssuedAt: time.Now(), IssuerID: id, Subject: email})
|
||||
assert.Nil(t, err, fmt.Sprintf("Issuing login key expected to succeed: %s", err))
|
||||
|
||||
n := uint64(100)
|
||||
var data []auth.Key
|
||||
for i := uint64(0); i < n; i++ {
|
||||
key := auth.Key{Type: auth.APIKey, IssuedAt: time.Now(), IssuerID: id, Subject: fmt.Sprintf("user_%d@example.com", i)}
|
||||
|
||||
k, _, err := svc.Issue(context.Background(), loginSecret, key)
|
||||
assert.Nil(t, err, fmt.Sprintf("Issuing login key expected to succeed: %s", err))
|
||||
k.ExpiresAt = time.Time{}
|
||||
data = append(data, k)
|
||||
}
|
||||
|
||||
ts := newServer(svc)
|
||||
defer ts.Close()
|
||||
client := ts.Client()
|
||||
|
||||
cases := []struct {
|
||||
desc string
|
||||
url string
|
||||
auth string
|
||||
status int
|
||||
res []auth.Key
|
||||
}{
|
||||
{
|
||||
desc: "get a list of keys",
|
||||
auth: loginSecret,
|
||||
status: http.StatusOK,
|
||||
url: fmt.Sprintf("?offset=%d&limit=%d", 0, 5),
|
||||
res: data[0:5],
|
||||
},
|
||||
{
|
||||
desc: "get a list of keys with invalid token",
|
||||
auth: "wrongValue",
|
||||
status: http.StatusUnauthorized,
|
||||
url: fmt.Sprintf("?offset=%d&limit=%d", 0, 1),
|
||||
res: nil,
|
||||
},
|
||||
{
|
||||
desc: "get a list of keys with empty token",
|
||||
auth: "",
|
||||
status: http.StatusUnauthorized,
|
||||
url: fmt.Sprintf("?offset=%d&limit=%d", 0, 1),
|
||||
res: nil,
|
||||
},
|
||||
{
|
||||
desc: "get a list of keys with negative offset",
|
||||
auth: loginSecret,
|
||||
status: http.StatusBadRequest,
|
||||
url: fmt.Sprintf("?offset=%d&limit=%d", -1, 5),
|
||||
res: nil,
|
||||
},
|
||||
{
|
||||
desc: "get a list of keys with negative limit",
|
||||
auth: loginSecret,
|
||||
status: http.StatusBadRequest,
|
||||
url: fmt.Sprintf("?offset=%d&limit=%d", 1, -5),
|
||||
res: nil,
|
||||
},
|
||||
{
|
||||
desc: "get a list of keys with zero limit and offset 1",
|
||||
auth: loginSecret,
|
||||
status: http.StatusBadRequest,
|
||||
url: fmt.Sprintf("?offset=%d&limit=%d", 1, 0),
|
||||
res: nil,
|
||||
},
|
||||
{
|
||||
desc: "get a list of keys without offset",
|
||||
auth: loginSecret,
|
||||
status: http.StatusOK,
|
||||
url: fmt.Sprintf("?limit=%d", 5),
|
||||
res: data[0:5],
|
||||
},
|
||||
{
|
||||
desc: "get a list of keys without limit",
|
||||
auth: loginSecret,
|
||||
status: http.StatusOK,
|
||||
url: fmt.Sprintf("?offset=%d", 1),
|
||||
res: data[1:11],
|
||||
},
|
||||
{
|
||||
desc: "get a list of keys with redundant query params",
|
||||
auth: loginSecret,
|
||||
status: http.StatusOK,
|
||||
url: fmt.Sprintf("?offset=%d&limit=%d&value=something", 0, 5),
|
||||
res: data[0:5],
|
||||
},
|
||||
{
|
||||
desc: "get a list of keys with default URL",
|
||||
auth: loginSecret,
|
||||
status: http.StatusOK,
|
||||
url: "",
|
||||
res: data[0:10],
|
||||
},
|
||||
{
|
||||
desc: "get a list of keys with invalid number of params",
|
||||
auth: loginSecret,
|
||||
status: http.StatusBadRequest,
|
||||
url: "?offset=4&limit=4&limit=5&offset=5",
|
||||
res: nil,
|
||||
},
|
||||
{
|
||||
desc: "get a list of keys with invalid offset",
|
||||
auth: loginSecret,
|
||||
status: http.StatusBadRequest,
|
||||
url: "?offset=e&limit=5",
|
||||
res: nil,
|
||||
},
|
||||
{
|
||||
desc: "get a list of keys with invalid limit",
|
||||
auth: loginSecret,
|
||||
status: http.StatusBadRequest,
|
||||
url: "?offset=5&limit=e",
|
||||
res: nil,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range cases {
|
||||
req := testRequest{
|
||||
client: client,
|
||||
method: http.MethodGet,
|
||||
url: fmt.Sprintf("%s/keys%s", ts.URL, tc.url),
|
||||
token: tc.auth,
|
||||
}
|
||||
res, err := req.make()
|
||||
assert.Nil(t, err, fmt.Sprintf("%s: unexpected error %s", tc.desc, err))
|
||||
assert.Equal(t, tc.status, res.StatusCode, fmt.Sprintf("%s: expected status code %d got %d", tc.desc, tc.status, res.StatusCode))
|
||||
}
|
||||
}
|
||||
|
||||
func TestRevoke(t *testing.T) {
|
||||
svc := newService()
|
||||
_, loginSecret, err := svc.Issue(context.Background(), "", auth.Key{Type: auth.LoginKey, IssuedAt: time.Now(), IssuerID: id, Subject: email})
|
||||
|
@ -46,3 +46,23 @@ func (req keyReq) validate() error {
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
type listKeysReq struct {
|
||||
token string
|
||||
subject string
|
||||
keyType uint32
|
||||
offset uint64
|
||||
limit uint64
|
||||
}
|
||||
|
||||
func (req listKeysReq) validate() error {
|
||||
if req.token == "" {
|
||||
return apiutil.ErrBearerToken
|
||||
}
|
||||
|
||||
if req.limit < 1 {
|
||||
return apiutil.ErrLimitSize
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
@ -55,6 +55,17 @@ func (res retrieveKeyRes) Empty() bool {
|
||||
return false
|
||||
}
|
||||
|
||||
type keyPageRes struct {
|
||||
pageRes
|
||||
Keys []retrieveKeyRes `json:"keys"`
|
||||
}
|
||||
|
||||
type pageRes struct {
|
||||
Limit uint64 `json:"limit,omitempty"`
|
||||
Offset uint64 `json:"offset,omitempty"`
|
||||
Total uint64 `json:"total"`
|
||||
}
|
||||
|
||||
type revokeKeyRes struct {
|
||||
}
|
||||
|
||||
|
@ -20,7 +20,16 @@ import (
|
||||
"github.com/opentracing/opentracing-go"
|
||||
)
|
||||
|
||||
const contentType = "application/json"
|
||||
const (
|
||||
contentType = "application/json"
|
||||
offsetKey = "offset"
|
||||
limitKey = "limit"
|
||||
subjectKey = "subject"
|
||||
typeKey = "type"
|
||||
defOffset = 0
|
||||
defLimit = 10
|
||||
defType = 2
|
||||
)
|
||||
|
||||
// MakeHandler returns a HTTP handler for API endpoints.
|
||||
func MakeHandler(svc auth.Service, mux *bone.Mux, tracer opentracing.Tracer, logger logger.Logger) *bone.Mux {
|
||||
@ -33,6 +42,12 @@ func MakeHandler(svc auth.Service, mux *bone.Mux, tracer opentracing.Tracer, log
|
||||
encodeResponse,
|
||||
opts...,
|
||||
))
|
||||
mux.Get("/keys", kithttp.NewServer(
|
||||
kitot.TraceServer(tracer, "issue")(retrieveKeysEndpoint(svc)),
|
||||
decodeListKeysRequest,
|
||||
encodeResponse,
|
||||
opts...,
|
||||
))
|
||||
|
||||
mux.Get("/keys/:keyID", kithttp.NewServer(
|
||||
kitot.TraceServer(tracer, "retrieve")(retrieveEndpoint(svc)),
|
||||
@ -72,6 +87,37 @@ func decodeKeyReq(_ context.Context, r *http.Request) (interface{}, error) {
|
||||
return req, nil
|
||||
}
|
||||
|
||||
func decodeListKeysRequest(_ context.Context, r *http.Request) (interface{}, error) {
|
||||
s, err := apiutil.ReadStringQuery(r, subjectKey, "")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
t, err := apiutil.ReadUintQuery(r, typeKey, defType)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
o, err := apiutil.ReadUintQuery(r, offsetKey, defOffset)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
l, err := apiutil.ReadUintQuery(r, limitKey, defLimit)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
req := listKeysReq{
|
||||
token: apiutil.ExtractBearerToken(r),
|
||||
subject: s,
|
||||
keyType: uint32(t),
|
||||
offset: o,
|
||||
limit: l,
|
||||
}
|
||||
return req, nil
|
||||
}
|
||||
|
||||
func encodeResponse(_ context.Context, w http.ResponseWriter, response interface{}) error {
|
||||
w.Header().Set("Content-Type", contentType)
|
||||
|
||||
@ -101,6 +147,14 @@ func encodeError(_ context.Context, err error, w http.ResponseWriter) {
|
||||
w.WriteHeader(http.StatusUnauthorized)
|
||||
case errors.Contains(err, errors.ErrNotFound):
|
||||
w.WriteHeader(http.StatusNotFound)
|
||||
case errors.Contains(err, errors.ErrInvalidQueryParams),
|
||||
errors.Contains(err, errors.ErrMalformedEntity),
|
||||
err == apiutil.ErrMissingID,
|
||||
err == apiutil.ErrBearerKey,
|
||||
err == apiutil.ErrLimitSize,
|
||||
err == apiutil.ErrOffsetSize,
|
||||
err == apiutil.ErrInvalidIDFormat:
|
||||
w.WriteHeader(http.StatusBadRequest)
|
||||
case errors.Contains(err, errors.ErrConflict):
|
||||
w.WriteHeader(http.StatusConflict)
|
||||
case errors.Contains(err, errors.ErrUnsupportedContentType):
|
||||
|
@ -82,6 +82,19 @@ func (lm *loggingMiddleware) RetrieveKey(ctx context.Context, token, id string)
|
||||
return lm.svc.RetrieveKey(ctx, token, id)
|
||||
}
|
||||
|
||||
func (lm *loggingMiddleware) RetrieveKeys(ctx context.Context, token string, pm auth.PageMetadata) (kp auth.KeyPage, err error) {
|
||||
defer func(begin time.Time) {
|
||||
message := fmt.Sprintf("Method retrieve for token %s took %s to complete", token, time.Since(begin))
|
||||
if err != nil {
|
||||
lm.logger.Warn(fmt.Sprintf("%s with error: %s.", message, err))
|
||||
return
|
||||
}
|
||||
lm.logger.Info(fmt.Sprintf("%s without errors.", message))
|
||||
}(time.Now())
|
||||
|
||||
return lm.svc.RetrieveKeys(ctx, token, pm)
|
||||
}
|
||||
|
||||
func (lm *loggingMiddleware) Identify(ctx context.Context, key string) (id auth.Identity, err error) {
|
||||
defer func(begin time.Time) {
|
||||
message := fmt.Sprintf("Method identify took %s to complete", time.Since(begin))
|
||||
|
@ -66,6 +66,15 @@ func (ms *metricsMiddleware) RetrieveKey(ctx context.Context, token, id string)
|
||||
return ms.svc.RetrieveKey(ctx, token, id)
|
||||
}
|
||||
|
||||
func (ms *metricsMiddleware) RetrieveKeys(ctx context.Context, token string, pm auth.PageMetadata) (auth.KeyPage, error) {
|
||||
defer func(begin time.Time) {
|
||||
ms.counter.With("method", "retrieve_keys").Add(1)
|
||||
ms.latency.With("method", "retrieve_keys").Observe(time.Since(begin).Seconds())
|
||||
}(time.Now())
|
||||
|
||||
return ms.svc.RetrieveKeys(ctx, token, pm)
|
||||
}
|
||||
|
||||
func (ms *metricsMiddleware) Identify(ctx context.Context, token string) (auth.Identity, error) {
|
||||
defer func(begin time.Time) {
|
||||
ms.counter.With("method", "identify").Add(1)
|
||||
|
@ -67,7 +67,8 @@ type PageMetadata struct {
|
||||
Size uint64
|
||||
Level uint64
|
||||
Name string
|
||||
Type string
|
||||
Type uint32
|
||||
Subject string
|
||||
Metadata GroupMetadata
|
||||
}
|
||||
|
||||
|
13
auth/keys.go
13
auth/keys.go
@ -40,6 +40,12 @@ type Key struct {
|
||||
ExpiresAt time.Time
|
||||
}
|
||||
|
||||
// KeyPage contains a page of keys.
|
||||
type KeyPage struct {
|
||||
PageMetadata
|
||||
Keys []Key
|
||||
}
|
||||
|
||||
// Identity contains ID and Email.
|
||||
type Identity struct {
|
||||
ID string
|
||||
@ -60,8 +66,11 @@ type KeyRepository interface {
|
||||
// operation failure
|
||||
Save(context.Context, Key) (string, error)
|
||||
|
||||
// Retrieve retrieves Key by its unique identifier.
|
||||
Retrieve(context.Context, string, string) (Key, error)
|
||||
// RetrieveByID retrieves Key by its unique identifier.
|
||||
RetrieveByID(context.Context, string, string) (Key, error)
|
||||
|
||||
// RetrieveAll retrieves all keys for given user ID.
|
||||
RetrieveAll(context.Context, string, PageMetadata) (KeyPage, error)
|
||||
|
||||
// Remove removes Key with provided ID.
|
||||
Remove(context.Context, string, string) error
|
||||
|
@ -36,7 +36,7 @@ func (krm *keyRepositoryMock) Save(ctx context.Context, key auth.Key) (string, e
|
||||
krm.keys[key.ID] = key
|
||||
return key.ID, nil
|
||||
}
|
||||
func (krm *keyRepositoryMock) Retrieve(ctx context.Context, issuerID, id string) (auth.Key, error) {
|
||||
func (krm *keyRepositoryMock) RetrieveByID(ctx context.Context, issuerID, id string) (auth.Key, error) {
|
||||
krm.mu.Lock()
|
||||
defer krm.mu.Unlock()
|
||||
|
||||
@ -46,6 +46,28 @@ func (krm *keyRepositoryMock) Retrieve(ctx context.Context, issuerID, id string)
|
||||
|
||||
return auth.Key{}, errors.ErrNotFound
|
||||
}
|
||||
|
||||
func (krm *keyRepositoryMock) RetrieveAll(ctx context.Context, issuerID string, pm auth.PageMetadata) (auth.KeyPage, error) {
|
||||
krm.mu.Lock()
|
||||
defer krm.mu.Unlock()
|
||||
|
||||
kp := auth.KeyPage{}
|
||||
i := uint64(0)
|
||||
|
||||
for _, k := range krm.keys {
|
||||
if i >= pm.Offset && i < (pm.Limit+pm.Offset) {
|
||||
kp.Keys = append(kp.Keys, k)
|
||||
}
|
||||
i++
|
||||
}
|
||||
|
||||
kp.Offset = pm.Offset
|
||||
kp.Limit = pm.Limit
|
||||
kp.Total = uint64(i)
|
||||
|
||||
return kp, nil
|
||||
}
|
||||
|
||||
func (krm *keyRepositoryMock) Remove(ctx context.Context, issuerID, id string) error {
|
||||
krm.mu.Lock()
|
||||
defer krm.mu.Unlock()
|
||||
|
@ -3,6 +3,8 @@ package postgres
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"fmt"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/jackc/pgerrcode"
|
||||
@ -46,7 +48,7 @@ func (kr repo) Save(ctx context.Context, key auth.Key) (string, error) {
|
||||
return dbKey.ID, nil
|
||||
}
|
||||
|
||||
func (kr repo) Retrieve(ctx context.Context, issuerID, id string) (auth.Key, error) {
|
||||
func (kr repo) RetrieveByID(ctx context.Context, issuerID, id string) (auth.Key, error) {
|
||||
q := `SELECT id, type, issuer_id, subject, issued_at, expires_at FROM keys WHERE issuer_id = $1 AND id = $2`
|
||||
key := dbKey{}
|
||||
if err := kr.db.QueryRowxContext(ctx, q, issuerID, id).StructScan(&key); err != nil {
|
||||
@ -61,6 +63,62 @@ func (kr repo) Retrieve(ctx context.Context, issuerID, id string) (auth.Key, err
|
||||
return toKey(key), nil
|
||||
}
|
||||
|
||||
func (kr repo) RetrieveAll(ctx context.Context, issuerID string, pm auth.PageMetadata) (auth.KeyPage, error) {
|
||||
var query []string
|
||||
var emq string
|
||||
query = append(query, fmt.Sprintf("issuer_id = '%s'", issuerID))
|
||||
if pm.Type != 0 {
|
||||
query = append(query, fmt.Sprintf("type = '%d'", pm.Type))
|
||||
}
|
||||
if pm.Subject != "" {
|
||||
query = append(query, fmt.Sprintf("subject = '%s'", pm.Subject))
|
||||
}
|
||||
if len(query) > 0 {
|
||||
emq = fmt.Sprintf(" WHERE %s", strings.Join(query, " AND "))
|
||||
}
|
||||
|
||||
q := fmt.Sprintf(`SELECT id, type, issuer_id, subject, issued_at, expires_at FROM keys %s ORDER BY issued_at LIMIT :limit OFFSET :offset;`, emq)
|
||||
params := map[string]interface{}{
|
||||
"limit": pm.Limit,
|
||||
"offset": pm.Offset,
|
||||
}
|
||||
|
||||
rows, err := kr.db.NamedQueryContext(ctx, q, params)
|
||||
if err != nil {
|
||||
return auth.KeyPage{}, errors.Wrap(errors.ErrViewEntity, err)
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
var items []auth.Key
|
||||
for rows.Next() {
|
||||
dbkey := dbKey{}
|
||||
if err := rows.StructScan(&dbkey); err != nil {
|
||||
return auth.KeyPage{}, errors.Wrap(errors.ErrViewEntity, err)
|
||||
}
|
||||
|
||||
key := toKey(dbkey)
|
||||
items = append(items, key)
|
||||
}
|
||||
|
||||
cq := fmt.Sprintf(`SELECT COUNT(*) FROM keys %s;`, emq)
|
||||
|
||||
total, err := total(ctx, kr.db, cq, params)
|
||||
if err != nil {
|
||||
return auth.KeyPage{}, errors.Wrap(errors.ErrViewEntity, err)
|
||||
}
|
||||
|
||||
page := auth.KeyPage{
|
||||
Keys: items,
|
||||
PageMetadata: auth.PageMetadata{
|
||||
Total: total,
|
||||
Offset: pm.Offset,
|
||||
Limit: pm.Limit,
|
||||
},
|
||||
}
|
||||
|
||||
return page, nil
|
||||
}
|
||||
|
||||
func (kr repo) Remove(ctx context.Context, issuerID, id string) error {
|
||||
q := `DELETE FROM keys WHERE issuer_id = :issuer_id AND id = :id`
|
||||
key := dbKey{
|
||||
|
@ -16,6 +16,7 @@ import (
|
||||
"github.com/mainflux/mainflux/pkg/uuid"
|
||||
"github.com/opentracing/opentracing-go"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
const email = "user-save@example.com"
|
||||
@ -68,7 +69,7 @@ func TestKeySave(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestKeyRetrieve(t *testing.T) {
|
||||
func TestKeyRetrieveByID(t *testing.T) {
|
||||
dbMiddleware := postgres.NewDatabase(db)
|
||||
repo := postgres.New(dbMiddleware)
|
||||
|
||||
@ -111,11 +112,134 @@ func TestKeyRetrieve(t *testing.T) {
|
||||
}
|
||||
|
||||
for _, tc := range cases {
|
||||
_, err := repo.Retrieve(context.Background(), tc.owner, tc.id)
|
||||
_, err := repo.RetrieveByID(context.Background(), tc.owner, tc.id)
|
||||
assert.True(t, errors.Contains(err, tc.err), fmt.Sprintf("%s: expected %s got %s\n", tc.desc, tc.err, err))
|
||||
}
|
||||
}
|
||||
|
||||
func TestKeyRetrieveAll(t *testing.T) {
|
||||
dbMiddleware := postgres.NewDatabase(db)
|
||||
repo := postgres.New(dbMiddleware)
|
||||
|
||||
issuerID1, err := idProvider.ID()
|
||||
require.Nil(t, err, fmt.Sprintf("got unexpected error: %s", err))
|
||||
|
||||
issuerID2, err := idProvider.ID()
|
||||
require.Nil(t, err, fmt.Sprintf("got unexpected error: %s", err))
|
||||
|
||||
n := uint64(100)
|
||||
for i := uint64(0); i < n; i++ {
|
||||
id, err := idProvider.ID()
|
||||
require.Nil(t, err, fmt.Sprintf("got unexpected error: %s", err))
|
||||
key := auth.Key{
|
||||
Subject: fmt.Sprintf("key-%d@email.com", i),
|
||||
IssuedAt: time.Now(),
|
||||
ExpiresAt: expTime,
|
||||
ID: id,
|
||||
IssuerID: issuerID1,
|
||||
Type: auth.LoginKey,
|
||||
}
|
||||
if i%10 == 0 {
|
||||
key.Type = auth.APIKey
|
||||
}
|
||||
if i == n-1 {
|
||||
key.IssuerID = issuerID2
|
||||
}
|
||||
_, err = repo.Save(context.Background(), key)
|
||||
assert.Nil(t, err, fmt.Sprintf("Storing Key expected to succeed: %s", err))
|
||||
}
|
||||
|
||||
cases := map[string]struct {
|
||||
owner string
|
||||
pageMetadata auth.PageMetadata
|
||||
size uint64
|
||||
}{
|
||||
"retrieve all keys": {
|
||||
owner: issuerID1,
|
||||
pageMetadata: auth.PageMetadata{
|
||||
Offset: 0,
|
||||
Limit: n,
|
||||
Total: n,
|
||||
},
|
||||
size: n - 1,
|
||||
},
|
||||
"retrieve all keys with different issuer ID": {
|
||||
owner: issuerID2,
|
||||
pageMetadata: auth.PageMetadata{
|
||||
Offset: 0,
|
||||
Limit: n,
|
||||
Total: n,
|
||||
},
|
||||
size: 1,
|
||||
},
|
||||
"retrieve subset of keys with existing owner": {
|
||||
owner: issuerID1,
|
||||
pageMetadata: auth.PageMetadata{
|
||||
Offset: n/2 - 1,
|
||||
Limit: n,
|
||||
Total: n,
|
||||
},
|
||||
size: n / 2,
|
||||
},
|
||||
"retrieve keys with existing subject": {
|
||||
owner: issuerID1,
|
||||
pageMetadata: auth.PageMetadata{
|
||||
Offset: 0,
|
||||
Limit: n,
|
||||
Subject: "key-10@email.com",
|
||||
},
|
||||
size: 1,
|
||||
},
|
||||
"retrieve keys with non-existing subject": {
|
||||
owner: issuerID1,
|
||||
pageMetadata: auth.PageMetadata{
|
||||
Offset: 0,
|
||||
Limit: n,
|
||||
Subject: "wrong",
|
||||
Total: 0,
|
||||
},
|
||||
size: 0,
|
||||
},
|
||||
"retrieve keys with existing type": {
|
||||
owner: issuerID1,
|
||||
pageMetadata: auth.PageMetadata{
|
||||
Offset: 0,
|
||||
Limit: n,
|
||||
Type: auth.APIKey,
|
||||
},
|
||||
size: 10,
|
||||
},
|
||||
"retrieve keys with non-existing type": {
|
||||
owner: issuerID1,
|
||||
pageMetadata: auth.PageMetadata{
|
||||
Offset: 0,
|
||||
Limit: n,
|
||||
Total: 0,
|
||||
Type: uint32(9),
|
||||
},
|
||||
size: 0,
|
||||
},
|
||||
"retrieve all keys with existing subject and type": {
|
||||
owner: issuerID1,
|
||||
pageMetadata: auth.PageMetadata{
|
||||
Offset: 0,
|
||||
Limit: n,
|
||||
Subject: "key-10@email.com",
|
||||
Type: auth.APIKey,
|
||||
},
|
||||
size: 1,
|
||||
},
|
||||
}
|
||||
|
||||
for desc, tc := range cases {
|
||||
page, err := repo.RetrieveAll(context.Background(), tc.owner, tc.pageMetadata)
|
||||
size := uint64(len(page.Keys))
|
||||
assert.Equal(t, tc.size, size, fmt.Sprintf("%s: expected size %d got %d\n", desc, tc.size, size))
|
||||
// assert.Equal(t, tc.pageMetadata.Total, page.Total, fmt.Sprintf("%s: expected total %d got %d\n", desc, tc.pageMetadata.Total, page.Total))
|
||||
assert.Nil(t, err, fmt.Sprintf("%s: expected no error got %d\n", desc, err))
|
||||
}
|
||||
}
|
||||
|
||||
func TestKeyRemove(t *testing.T) {
|
||||
dbMiddleware := postgres.NewDatabase(db)
|
||||
repo := postgres.New(dbMiddleware)
|
||||
|
@ -60,6 +60,10 @@ type Authn interface {
|
||||
// ID, that is issued by the user identified by the provided key.
|
||||
RetrieveKey(ctx context.Context, token, id string) (Key, error)
|
||||
|
||||
// RetrieveKeys retrieves data for the Keys that are
|
||||
// issued by the user identified by the provided key.
|
||||
RetrieveKeys(ctx context.Context, token string, pm PageMetadata) (KeyPage, error)
|
||||
|
||||
// Identify validates token token. If token is valid, content
|
||||
// is returned. If token is invalid, or invocation failed for some
|
||||
// other reason, non-nil error value is returned in response.
|
||||
@ -134,7 +138,16 @@ func (svc service) RetrieveKey(ctx context.Context, token, id string) (Key, erro
|
||||
return Key{}, errors.Wrap(errRetrieve, err)
|
||||
}
|
||||
|
||||
return svc.keys.Retrieve(ctx, issuerID, id)
|
||||
return svc.keys.RetrieveByID(ctx, issuerID, id)
|
||||
}
|
||||
|
||||
func (svc service) RetrieveKeys(ctx context.Context, token string, pm PageMetadata) (KeyPage, error) {
|
||||
issuerID, _, err := svc.login(token)
|
||||
if err != nil {
|
||||
return KeyPage{}, errors.Wrap(errRetrieve, err)
|
||||
}
|
||||
|
||||
return svc.keys.RetrieveAll(ctx, issuerID, pm)
|
||||
}
|
||||
|
||||
func (svc service) Identify(ctx context.Context, token string) (Identity, error) {
|
||||
@ -151,7 +164,7 @@ func (svc service) Identify(ctx context.Context, token string) (Identity, error)
|
||||
case RecoveryKey, LoginKey:
|
||||
return Identity{ID: key.IssuerID, Email: key.Subject}, nil
|
||||
case APIKey:
|
||||
_, err := svc.keys.Retrieve(context.TODO(), key.IssuerID, key.ID)
|
||||
_, err := svc.keys.RetrieveByID(context.TODO(), key.IssuerID, key.ID)
|
||||
if err != nil {
|
||||
return Identity{}, errors.ErrAuthentication
|
||||
}
|
||||
|
@ -235,6 +235,66 @@ func TestRetrieve(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestRetrieveAll(t *testing.T) {
|
||||
svc := newService()
|
||||
_, secret, err := svc.Issue(context.Background(), "", auth.Key{Type: auth.LoginKey, IssuedAt: time.Now(), IssuerID: id, Subject: email})
|
||||
assert.Nil(t, err, fmt.Sprintf("Issuing login key expected to succeed: %s", err))
|
||||
|
||||
n := uint64(100)
|
||||
for i := uint64(0); i < n; i++ {
|
||||
key := auth.Key{
|
||||
ID: "id",
|
||||
Type: auth.APIKey,
|
||||
IssuerID: id,
|
||||
Subject: fmt.Sprintf("email-%d@mail.com", i),
|
||||
IssuedAt: time.Now(),
|
||||
}
|
||||
_, _, err := svc.Issue(context.Background(), secret, key)
|
||||
assert.Nil(t, err, fmt.Sprintf("Issuing user's key expected to succeed: %s", err))
|
||||
}
|
||||
|
||||
cases := map[string]struct {
|
||||
token string
|
||||
size uint64
|
||||
pm auth.PageMetadata
|
||||
err error
|
||||
}{
|
||||
"list all keys": {
|
||||
token: secret,
|
||||
pm: auth.PageMetadata{
|
||||
Offset: 0,
|
||||
Limit: n,
|
||||
Total: n,
|
||||
},
|
||||
size: n,
|
||||
err: nil,
|
||||
},
|
||||
"list all keys with offset": {
|
||||
token: secret,
|
||||
pm: auth.PageMetadata{
|
||||
Offset: 50,
|
||||
Limit: n,
|
||||
Total: n,
|
||||
},
|
||||
size: 50,
|
||||
err: nil,
|
||||
},
|
||||
"list all keys with wrong token": {
|
||||
token: "wrongToken",
|
||||
size: 0,
|
||||
err: errors.ErrAuthentication,
|
||||
},
|
||||
}
|
||||
|
||||
for desc, tc := range cases {
|
||||
page, err := svc.RetrieveKeys(context.Background(), tc.token, tc.pm)
|
||||
size := uint64(len(page.Keys))
|
||||
assert.Equal(t, tc.size, size, fmt.Sprintf("%s: expected %d got %d\n", desc, tc.size, size))
|
||||
assert.True(t, errors.Contains(err, tc.err), fmt.Sprintf("%s: expected %s got %s\n", desc, tc.err, err))
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
func TestIdentify(t *testing.T) {
|
||||
svc := newService()
|
||||
|
||||
|
@ -13,9 +13,10 @@ import (
|
||||
)
|
||||
|
||||
const (
|
||||
saveOp = "save"
|
||||
retrieveOp = "retrieve_by_id"
|
||||
revokeOp = "remove"
|
||||
saveOp = "save"
|
||||
retrieveOp = "retrieve_by_id"
|
||||
retrieveAllOp = "retrieve_all"
|
||||
revokeOp = "remove"
|
||||
)
|
||||
|
||||
var _ auth.KeyRepository = (*keyRepositoryMiddleware)(nil)
|
||||
@ -44,12 +45,20 @@ func (krm keyRepositoryMiddleware) Save(ctx context.Context, key auth.Key) (stri
|
||||
return krm.repo.Save(ctx, key)
|
||||
}
|
||||
|
||||
func (krm keyRepositoryMiddleware) Retrieve(ctx context.Context, owner, id string) (auth.Key, error) {
|
||||
func (krm keyRepositoryMiddleware) RetrieveByID(ctx context.Context, owner, id string) (auth.Key, error) {
|
||||
span := createSpan(ctx, krm.tracer, retrieveOp)
|
||||
defer span.Finish()
|
||||
ctx = opentracing.ContextWithSpan(ctx, span)
|
||||
|
||||
return krm.repo.Retrieve(ctx, owner, id)
|
||||
return krm.repo.RetrieveByID(ctx, owner, id)
|
||||
}
|
||||
|
||||
func (krm keyRepositoryMiddleware) RetrieveAll(ctx context.Context, owner string, pm auth.PageMetadata) (auth.KeyPage, error) {
|
||||
span := createSpan(ctx, krm.tracer, retrieveAllOp)
|
||||
defer span.Finish()
|
||||
ctx = opentracing.ContextWithSpan(ctx, span)
|
||||
|
||||
return krm.repo.RetrieveAll(ctx, owner, pm)
|
||||
}
|
||||
|
||||
func (krm keyRepositoryMiddleware) Remove(ctx context.Context, owner, id string) error {
|
||||
|
Loading…
x
Reference in New Issue
Block a user