mirror of
https://github.com/mainflux/mainflux.git
synced 2025-04-24 13:48:49 +08:00
NOISSUE - Use Insert On Conflict For Policies (#1824)
* Initial Commit: Use Insert On Conflict For Policies Signed-off-by: rodneyosodo <blackd0t@protonmail.com> * Invalidate Cache on Adding Policy Signed-off-by: rodneyosodo <blackd0t@protonmail.com> --------- Signed-off-by: rodneyosodo <blackd0t@protonmail.com>
This commit is contained in:
parent
98aa270b14
commit
2673b34c6a
@ -28,8 +28,6 @@ import (
|
||||
"github.com/stretchr/testify/mock"
|
||||
)
|
||||
|
||||
const addExistingPolicyDesc = "add existing policy"
|
||||
|
||||
var utadminPolicy = umocks.SubjectSet{Subject: "things", Relation: []string{"g_add"}}
|
||||
|
||||
func newPolicyServer(svc upolicies.Service) *httptest.Server {
|
||||
@ -75,7 +73,7 @@ func TestCreatePolicy(t *testing.T) {
|
||||
err: nil,
|
||||
},
|
||||
{
|
||||
desc: addExistingPolicyDesc,
|
||||
desc: "add existing policy",
|
||||
policy: sdk.Policy{
|
||||
Subject: subject,
|
||||
Object: object,
|
||||
@ -153,25 +151,15 @@ func TestCreatePolicy(t *testing.T) {
|
||||
|
||||
for _, tc := range cases {
|
||||
repoCall := pRepo.On("CheckAdmin", mock.Anything, mock.Anything).Return(nil)
|
||||
repoCall1 := pRepo.On("RetrieveAll", mock.Anything, mock.Anything).Return(convertUserPolicyPage(tc.page), nil)
|
||||
repoCall2 := pRepo.On("Update", mock.Anything, mock.Anything).Return(tc.err)
|
||||
repoCall3 := pRepo.On("Save", mock.Anything, mock.Anything).Return(tc.err)
|
||||
repoCall1 := pRepo.On("Save", mock.Anything, mock.Anything).Return(tc.err)
|
||||
err := mfsdk.CreatePolicy(tc.policy, tc.token)
|
||||
assert.Equal(t, tc.err, err, fmt.Sprintf("%s: expected error %s, got %s", tc.desc, tc.err, err))
|
||||
if tc.err == nil {
|
||||
ok := repoCall.Parent.AssertCalled(t, "RetrieveAll", mock.Anything, mock.Anything)
|
||||
assert.True(t, ok, fmt.Sprintf("RetrieveAll was not called on %s", tc.desc))
|
||||
ok = repoCall2.Parent.AssertCalled(t, "Save", mock.Anything, mock.Anything)
|
||||
ok := repoCall1.Parent.AssertCalled(t, "Save", mock.Anything, mock.Anything)
|
||||
assert.True(t, ok, fmt.Sprintf("Save was not called on %s", tc.desc))
|
||||
if tc.desc == addExistingPolicyDesc {
|
||||
ok = repoCall1.Parent.AssertCalled(t, "Update", mock.Anything, mock.Anything)
|
||||
assert.True(t, ok, fmt.Sprintf("Update was not called on %s", tc.desc))
|
||||
}
|
||||
}
|
||||
repoCall.Unset()
|
||||
repoCall1.Unset()
|
||||
repoCall2.Unset()
|
||||
repoCall3.Unset()
|
||||
}
|
||||
}
|
||||
|
||||
@ -315,7 +303,7 @@ func TestAssign(t *testing.T) {
|
||||
err: nil,
|
||||
},
|
||||
{
|
||||
desc: addExistingPolicyDesc,
|
||||
desc: "add existing policy",
|
||||
policy: sdk.Policy{
|
||||
Subject: subject,
|
||||
Object: object,
|
||||
@ -393,25 +381,15 @@ func TestAssign(t *testing.T) {
|
||||
|
||||
for _, tc := range cases {
|
||||
repoCall := pRepo.On("CheckAdmin", mock.Anything, mock.Anything).Return(nil)
|
||||
repoCall1 := pRepo.On("RetrieveAll", mock.Anything, mock.Anything).Return(convertUserPolicyPage(tc.page), nil)
|
||||
repoCall2 := pRepo.On("Update", mock.Anything, mock.Anything).Return(tc.err)
|
||||
repoCall3 := pRepo.On("Save", mock.Anything, mock.Anything).Return(tc.err)
|
||||
repoCall1 := pRepo.On("Save", mock.Anything, mock.Anything).Return(tc.err)
|
||||
err := mfsdk.Assign(tc.policy.Actions, tc.policy.Subject, tc.policy.Object, tc.token)
|
||||
assert.Equal(t, tc.err, err, fmt.Sprintf("%s: expected error %s, got %s", tc.desc, tc.err, err))
|
||||
if tc.err == nil {
|
||||
ok := repoCall.Parent.AssertCalled(t, "RetrieveAll", mock.Anything, mock.Anything)
|
||||
assert.True(t, ok, fmt.Sprintf("RetrieveAll was not called on %s", tc.desc))
|
||||
ok = repoCall2.Parent.AssertCalled(t, "Save", mock.Anything, mock.Anything)
|
||||
ok := repoCall1.Parent.AssertCalled(t, "Save", mock.Anything, mock.Anything)
|
||||
assert.True(t, ok, fmt.Sprintf("Save was not called on %s", tc.desc))
|
||||
if tc.desc == addExistingPolicyDesc {
|
||||
ok = repoCall1.Parent.AssertCalled(t, "Update", mock.Anything, mock.Anything)
|
||||
assert.True(t, ok, fmt.Sprintf("Update was not called on %s", tc.desc))
|
||||
}
|
||||
}
|
||||
repoCall.Unset()
|
||||
repoCall1.Unset()
|
||||
repoCall2.Unset()
|
||||
repoCall3.Unset()
|
||||
}
|
||||
}
|
||||
func TestUpdatePolicy(t *testing.T) {
|
||||
@ -800,7 +778,7 @@ func TestConnect(t *testing.T) {
|
||||
err: nil,
|
||||
},
|
||||
{
|
||||
desc: addExistingPolicyDesc,
|
||||
desc: "add existing policy",
|
||||
policy: sdk.Policy{
|
||||
Subject: subject,
|
||||
Object: object,
|
||||
@ -877,25 +855,15 @@ func TestConnect(t *testing.T) {
|
||||
}
|
||||
|
||||
for _, tc := range cases {
|
||||
repoCall := pRepo.On("Retrieve", mock.Anything, mock.Anything).Return(convertThingPolicyPage(tc.page), nil)
|
||||
repoCall1 := pRepo.On("Update", mock.Anything, mock.Anything).Return(convertThingPolicy(tc.policy), tc.err)
|
||||
repoCall2 := pRepo.On("Save", mock.Anything, mock.Anything).Return(convertThingPolicy(tc.policy), tc.err)
|
||||
repoCall := pRepo.On("Save", mock.Anything, mock.Anything).Return(convertThingPolicy(tc.policy), tc.err)
|
||||
conn := sdk.ConnectionIDs{ChannelIDs: []string{tc.policy.Object}, ThingIDs: []string{tc.policy.Subject}, Actions: tc.policy.Actions}
|
||||
err := mfsdk.Connect(conn, tc.token)
|
||||
assert.Equal(t, tc.err, err, fmt.Sprintf("%s: expected error %s, got %s", tc.desc, tc.err, err))
|
||||
if tc.err == nil {
|
||||
ok := repoCall.Parent.AssertCalled(t, "Retrieve", mock.Anything, mock.Anything)
|
||||
assert.True(t, ok, fmt.Sprintf("Retrieve was not called on %s", tc.desc))
|
||||
ok = repoCall2.Parent.AssertCalled(t, "Save", mock.Anything, mock.Anything)
|
||||
ok := repoCall.Parent.AssertCalled(t, "Save", mock.Anything, mock.Anything)
|
||||
assert.True(t, ok, fmt.Sprintf("Save was not called on %s", tc.desc))
|
||||
if tc.desc == addExistingPolicyDesc {
|
||||
ok = repoCall1.Parent.AssertCalled(t, "Update", mock.Anything, mock.Anything)
|
||||
assert.True(t, ok, fmt.Sprintf("Update was not called on %s", tc.desc))
|
||||
}
|
||||
}
|
||||
repoCall.Unset()
|
||||
repoCall1.Unset()
|
||||
repoCall2.Unset()
|
||||
}
|
||||
}
|
||||
|
||||
@ -939,7 +907,7 @@ func TestConnectThing(t *testing.T) {
|
||||
err: nil,
|
||||
},
|
||||
{
|
||||
desc: addExistingPolicyDesc,
|
||||
desc: "add existing policy",
|
||||
policy: sdk.Policy{
|
||||
Subject: subject,
|
||||
Object: object,
|
||||
@ -1016,24 +984,14 @@ func TestConnectThing(t *testing.T) {
|
||||
}
|
||||
|
||||
for _, tc := range cases {
|
||||
repoCall := pRepo.On("Retrieve", mock.Anything, mock.Anything).Return(convertThingPolicyPage(tc.page), nil)
|
||||
repoCall1 := pRepo.On("Update", mock.Anything, mock.Anything).Return(convertThingPolicy(tc.policy), tc.err)
|
||||
repoCall2 := pRepo.On("Save", mock.Anything, mock.Anything).Return(convertThingPolicy(tc.policy), tc.err)
|
||||
repoCall := pRepo.On("Save", mock.Anything, mock.Anything).Return(convertThingPolicy(tc.policy), tc.err)
|
||||
err := mfsdk.ConnectThing(tc.policy.Subject, tc.policy.Object, tc.token)
|
||||
assert.Equal(t, tc.err, err, fmt.Sprintf("%s: expected error %s, got %s", tc.desc, tc.err, err))
|
||||
if tc.err == nil {
|
||||
ok := repoCall.Parent.AssertCalled(t, "Retrieve", mock.Anything, mock.Anything)
|
||||
assert.True(t, ok, fmt.Sprintf("Retrieve was not called on %s", tc.desc))
|
||||
ok = repoCall2.Parent.AssertCalled(t, "Save", mock.Anything, mock.Anything)
|
||||
ok := repoCall.Parent.AssertCalled(t, "Save", mock.Anything, mock.Anything)
|
||||
assert.True(t, ok, fmt.Sprintf("Save was not called on %s", tc.desc))
|
||||
if tc.desc == addExistingPolicyDesc {
|
||||
ok = repoCall1.Parent.AssertCalled(t, "Update", mock.Anything, mock.Anything)
|
||||
assert.True(t, ok, fmt.Sprintf("Update was not called on %s", tc.desc))
|
||||
}
|
||||
}
|
||||
repoCall.Unset()
|
||||
repoCall1.Unset()
|
||||
repoCall2.Unset()
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -83,14 +83,14 @@ func generateValidToken(t *testing.T, svc clients.Service, cRepo *umocks.Reposit
|
||||
token, err := svc.IssueToken(context.Background(), client.Credentials.Identity, client.Credentials.Secret)
|
||||
assert.True(t, errors.Contains(err, nil), fmt.Sprintf("Create token expected nil got %s\n", err))
|
||||
repoCall.Unset()
|
||||
|
||||
|
||||
return token.AccessToken
|
||||
}
|
||||
|
||||
func generateUUID(t *testing.T) string {
|
||||
ulid, err := idProvider.ID()
|
||||
assert.Nil(t, err, fmt.Sprintf("unexpected error: %s", err))
|
||||
|
||||
|
||||
return ulid
|
||||
}
|
||||
|
||||
@ -156,16 +156,6 @@ func convertUserPolicies(cs []sdk.Policy) []upolicies.Policy {
|
||||
return ccs
|
||||
}
|
||||
|
||||
func convertThingPolicies(cs []sdk.Policy) []tpolicies.Policy {
|
||||
ccs := []tpolicies.Policy{}
|
||||
|
||||
for _, c := range cs {
|
||||
ccs = append(ccs, convertThingPolicy(c))
|
||||
}
|
||||
|
||||
return ccs
|
||||
}
|
||||
|
||||
func convertUserPolicy(sp sdk.Policy) upolicies.Policy {
|
||||
return upolicies.Policy{
|
||||
OwnerID: sp.OwnerID,
|
||||
@ -218,7 +208,7 @@ func convertClientPage(p sdk.PageMetadata) mfclients.Page {
|
||||
if err != nil {
|
||||
return mfclients.Page{}
|
||||
}
|
||||
|
||||
|
||||
return mfclients.Page{
|
||||
Status: status,
|
||||
Total: p.Total,
|
||||
@ -248,7 +238,7 @@ func convertGroup(g sdk.Group) mfgroups.Group {
|
||||
if err != nil {
|
||||
return mfgroups.Group{}
|
||||
}
|
||||
|
||||
|
||||
return mfgroups.Group{
|
||||
ID: g.ID,
|
||||
Owner: g.OwnerID,
|
||||
@ -288,7 +278,7 @@ func convertClient(c sdk.User) mfclients.Client {
|
||||
if err != nil {
|
||||
return mfclients.Client{}
|
||||
}
|
||||
|
||||
|
||||
return mfclients.Client{
|
||||
ID: c.ID,
|
||||
Name: c.Name,
|
||||
@ -357,17 +347,6 @@ func convertUserPolicyPage(pp sdk.PolicyPage) upolicies.PolicyPage {
|
||||
}
|
||||
}
|
||||
|
||||
func convertThingPolicyPage(pp sdk.PolicyPage) tpolicies.PolicyPage {
|
||||
return tpolicies.PolicyPage{
|
||||
Page: tpolicies.Page{
|
||||
Limit: pp.Limit,
|
||||
Total: pp.Total,
|
||||
Offset: pp.Offset,
|
||||
},
|
||||
Policies: convertThingPolicies(pp.Policies),
|
||||
}
|
||||
}
|
||||
|
||||
func TestMain(m *testing.M) {
|
||||
exitCode := m.Run()
|
||||
os.Exit(exitCode)
|
||||
|
@ -31,6 +31,8 @@ func NewRepository(db postgres.Database) policies.Repository {
|
||||
func (pr prepo) Save(ctx context.Context, policy policies.Policy) (policies.Policy, error) {
|
||||
q := `INSERT INTO policies (owner_id, subject, object, actions, created_at, updated_at, updated_by)
|
||||
VALUES (:owner_id, :subject, :object, :actions, :created_at, :updated_at, :updated_by)
|
||||
ON CONFLICT (subject, object) DO UPDATE SET actions = :actions,
|
||||
updated_at = :updated_at, updated_by = :updated_by
|
||||
RETURNING owner_id, subject, object, actions, created_at, updated_at, updated_by;`
|
||||
|
||||
dbp, err := toDBPolicy(policy)
|
||||
|
@ -67,7 +67,7 @@ func TestPoliciesSave(t *testing.T) {
|
||||
Object: group.ID,
|
||||
Actions: []string{"c_delete"},
|
||||
},
|
||||
err: errors.ErrConflict,
|
||||
err: nil,
|
||||
},
|
||||
}
|
||||
|
||||
|
@ -109,45 +109,26 @@ func (svc service) AddPolicy(ctx context.Context, token string, p Policy) (Polic
|
||||
if err := p.Validate(); err != nil {
|
||||
return Policy{}, err
|
||||
}
|
||||
pm := Page{Subject: p.Subject, Object: p.Object, Offset: 0, Limit: 1}
|
||||
page, err := svc.policies.Retrieve(ctx, pm)
|
||||
if err != nil {
|
||||
return Policy{}, err
|
||||
}
|
||||
|
||||
// If the policy already exists, replace the actions
|
||||
if len(page.Policies) == 1 {
|
||||
if err := svc.checkPolicy(ctx, userID, p); err != nil {
|
||||
return Policy{}, err
|
||||
}
|
||||
|
||||
p.UpdatedAt = time.Now()
|
||||
p.UpdatedBy = userID
|
||||
|
||||
if err := svc.policyCache.Remove(ctx, p); err != nil {
|
||||
return Policy{}, err
|
||||
}
|
||||
|
||||
return svc.policies.Update(ctx, p)
|
||||
}
|
||||
|
||||
p.OwnerID = userID
|
||||
p.CreatedAt = time.Now()
|
||||
|
||||
// incase the policy exists, use these for update.
|
||||
p.UpdatedAt = time.Now()
|
||||
p.UpdatedBy = userID
|
||||
|
||||
if err := svc.policyCache.Remove(ctx, p); err != nil {
|
||||
return Policy{}, err
|
||||
}
|
||||
|
||||
// If the client is admin, add the policy
|
||||
if err := svc.checkAdmin(ctx, userID); err == nil {
|
||||
if err := svc.policyCache.Put(ctx, p); err != nil {
|
||||
return Policy{}, err
|
||||
}
|
||||
return svc.policies.Save(ctx, p)
|
||||
}
|
||||
|
||||
// If the client has `g_add` action on the object or is the owner of the object, add the policy
|
||||
ar := AccessRequest{Subject: userID, Object: p.Object, Action: "g_add"}
|
||||
if _, err := svc.policies.EvaluateGroupAccess(ctx, ar); err == nil {
|
||||
if err := svc.policyCache.Put(ctx, p); err != nil {
|
||||
return Policy{}, err
|
||||
}
|
||||
return svc.policies.Save(ctx, p)
|
||||
}
|
||||
|
||||
|
@ -131,7 +131,6 @@ func TestAddPolicy(t *testing.T) {
|
||||
repoCall1 := pRepo.On("EvaluateThingAccess", mock.Anything, mock.Anything).Return(policies.Policy{}, tc.err)
|
||||
repoCall2 := pRepo.On("Update", context.Background(), tc.policy).Return(tc.err)
|
||||
repoCall3 := pRepo.On("Save", context.Background(), mock.Anything).Return(tc.policy, tc.err)
|
||||
repoCall4 := pRepo.On("Retrieve", context.Background(), mock.Anything).Return(tc.page, nil)
|
||||
_, err := svc.AddPolicy(context.Background(), tc.token, tc.policy)
|
||||
assert.True(t, errors.Contains(err, tc.err), fmt.Sprintf("%s: expected %s got %s\n", tc.desc, tc.err, err))
|
||||
if err == nil {
|
||||
@ -145,7 +144,6 @@ func TestAddPolicy(t *testing.T) {
|
||||
repoCall2.Unset()
|
||||
repoCall3.Parent.AssertCalled(t, "Save", context.Background(), mock.Anything)
|
||||
repoCall3.Unset()
|
||||
repoCall4.Unset()
|
||||
}
|
||||
|
||||
}
|
||||
|
@ -32,7 +32,9 @@ func NewRepository(db postgres.Database) policies.Repository {
|
||||
|
||||
func (pr prepo) Save(ctx context.Context, policy policies.Policy) error {
|
||||
q := `INSERT INTO policies (owner_id, subject, object, actions, created_at)
|
||||
VALUES (:owner_id, :subject, :object, :actions, :created_at)`
|
||||
VALUES (:owner_id, :subject, :object, :actions, :created_at)
|
||||
ON CONFLICT (subject, object) DO UPDATE SET actions = :actions,
|
||||
updated_at = :updated_at, updated_by = :updated_by`
|
||||
|
||||
dbp, err := toDBPolicy(policy)
|
||||
if err != nil {
|
||||
|
@ -73,7 +73,7 @@ func TestPoliciesSave(t *testing.T) {
|
||||
Object: uid,
|
||||
Actions: []string{"c_delete"},
|
||||
},
|
||||
err: errors.ErrConflict,
|
||||
err: nil,
|
||||
},
|
||||
}
|
||||
|
||||
|
@ -65,26 +65,13 @@ func (svc service) AddPolicy(ctx context.Context, token string, p Policy) error
|
||||
return err
|
||||
}
|
||||
|
||||
pm := Page{Subject: p.Subject, Object: p.Object, Offset: 0, Limit: 1}
|
||||
page, err := svc.policies.RetrieveAll(ctx, pm)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// If the policy already exists, replace the actions
|
||||
if len(page.Policies) == 1 {
|
||||
if err := svc.checkPolicy(ctx, id, p); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
p.UpdatedAt = time.Now()
|
||||
p.UpdatedBy = id
|
||||
return svc.policies.Update(ctx, p)
|
||||
}
|
||||
|
||||
p.OwnerID = id
|
||||
p.CreatedAt = time.Now()
|
||||
|
||||
// incase the policy exists, use these for update.
|
||||
p.UpdatedAt = time.Now()
|
||||
p.UpdatedBy = id
|
||||
|
||||
// check if the client is admin
|
||||
if err = svc.policies.CheckAdmin(ctx, id); err == nil {
|
||||
return svc.policies.Save(ctx, p)
|
||||
|
@ -140,9 +140,8 @@ func TestAddPolicy(t *testing.T) {
|
||||
|
||||
for _, tc := range cases {
|
||||
repoCall := pRepo.On("CheckAdmin", context.Background(), mock.Anything).Return(nil)
|
||||
repoCall1 := pRepo.On("RetrieveAll", context.Background(), mock.Anything).Return(tc.page, nil)
|
||||
repoCall2 := pRepo.On("Update", context.Background(), mock.Anything).Return(tc.err)
|
||||
repoCall3 := pRepo.On("Save", context.Background(), mock.Anything).Return(tc.err)
|
||||
repoCall1 := pRepo.On("Update", context.Background(), mock.Anything).Return(tc.err)
|
||||
repoCall2 := pRepo.On("Save", context.Background(), mock.Anything).Return(tc.err)
|
||||
err := svc.AddPolicy(context.Background(), tc.token, tc.policy)
|
||||
assert.True(t, errors.Contains(err, tc.err), fmt.Sprintf("%s: expected %s got %s\n", tc.desc, tc.err, err))
|
||||
if err == nil {
|
||||
@ -152,19 +151,16 @@ func TestAddPolicy(t *testing.T) {
|
||||
require.Nil(t, err, fmt.Sprintf("checking shared %v policy expected to be succeed: %#v", tc.policy, err))
|
||||
ok := repoCall.Parent.AssertCalled(t, "CheckAdmin", context.Background(), mock.Anything)
|
||||
assert.True(t, ok, fmt.Sprintf("CheckAdmin was not called on %s", tc.desc))
|
||||
ok = repoCall1.Parent.AssertCalled(t, "RetrieveAll", context.Background(), mock.Anything)
|
||||
assert.True(t, ok, fmt.Sprintf("RetrieveAll was not called on %s", tc.desc))
|
||||
ok = repoCall3.Parent.AssertCalled(t, "Save", context.Background(), mock.Anything)
|
||||
ok = repoCall2.Parent.AssertCalled(t, "Save", context.Background(), mock.Anything)
|
||||
assert.True(t, ok, fmt.Sprintf("Save was not called on %s", tc.desc))
|
||||
if tc.desc == "add existing policy" {
|
||||
ok = repoCall2.Parent.AssertCalled(t, "Update", context.Background(), mock.Anything)
|
||||
ok = repoCall1.Parent.AssertCalled(t, "Update", context.Background(), mock.Anything)
|
||||
assert.True(t, ok, fmt.Sprintf("Update was not called on %s", tc.desc))
|
||||
}
|
||||
}
|
||||
repoCall.Unset()
|
||||
repoCall1.Unset()
|
||||
repoCall2.Unset()
|
||||
repoCall3.Unset()
|
||||
}
|
||||
|
||||
}
|
||||
|
Loading…
x
Reference in New Issue
Block a user