From 2673b34c6a67949c7007040bc088da6798e9982e Mon Sep 17 00:00:00 2001 From: b1ackd0t Date: Tue, 20 Jun 2023 15:59:53 +0300 Subject: [PATCH] NOISSUE - Use Insert On Conflict For Policies (#1824) * Initial Commit: Use Insert On Conflict For Policies Signed-off-by: rodneyosodo * Invalidate Cache on Adding Policy Signed-off-by: rodneyosodo --------- Signed-off-by: rodneyosodo --- pkg/sdk/go/policies_test.go | 66 +++++------------------ pkg/sdk/go/setup_test.go | 31 ++--------- things/policies/postgres/policies.go | 2 + things/policies/postgres/policies_test.go | 2 +- things/policies/service.go | 35 +++--------- things/policies/service_test.go | 2 - users/policies/postgres/policies.go | 4 +- users/policies/postgres/policies_test.go | 2 +- users/policies/service.go | 21 ++------ users/policies/service_test.go | 12 ++--- 10 files changed, 40 insertions(+), 137 deletions(-) diff --git a/pkg/sdk/go/policies_test.go b/pkg/sdk/go/policies_test.go index 6d1efffc..2dc7cd31 100644 --- a/pkg/sdk/go/policies_test.go +++ b/pkg/sdk/go/policies_test.go @@ -28,8 +28,6 @@ import ( "github.com/stretchr/testify/mock" ) -const addExistingPolicyDesc = "add existing policy" - var utadminPolicy = umocks.SubjectSet{Subject: "things", Relation: []string{"g_add"}} func newPolicyServer(svc upolicies.Service) *httptest.Server { @@ -75,7 +73,7 @@ func TestCreatePolicy(t *testing.T) { err: nil, }, { - desc: addExistingPolicyDesc, + desc: "add existing policy", policy: sdk.Policy{ Subject: subject, Object: object, @@ -153,25 +151,15 @@ func TestCreatePolicy(t *testing.T) { for _, tc := range cases { repoCall := pRepo.On("CheckAdmin", mock.Anything, mock.Anything).Return(nil) - repoCall1 := pRepo.On("RetrieveAll", mock.Anything, mock.Anything).Return(convertUserPolicyPage(tc.page), nil) - repoCall2 := pRepo.On("Update", mock.Anything, mock.Anything).Return(tc.err) - repoCall3 := pRepo.On("Save", mock.Anything, mock.Anything).Return(tc.err) + repoCall1 := pRepo.On("Save", mock.Anything, mock.Anything).Return(tc.err) err := mfsdk.CreatePolicy(tc.policy, tc.token) assert.Equal(t, tc.err, err, fmt.Sprintf("%s: expected error %s, got %s", tc.desc, tc.err, err)) if tc.err == nil { - ok := repoCall.Parent.AssertCalled(t, "RetrieveAll", mock.Anything, mock.Anything) - assert.True(t, ok, fmt.Sprintf("RetrieveAll was not called on %s", tc.desc)) - ok = repoCall2.Parent.AssertCalled(t, "Save", mock.Anything, mock.Anything) + ok := repoCall1.Parent.AssertCalled(t, "Save", mock.Anything, mock.Anything) assert.True(t, ok, fmt.Sprintf("Save was not called on %s", tc.desc)) - if tc.desc == addExistingPolicyDesc { - ok = repoCall1.Parent.AssertCalled(t, "Update", mock.Anything, mock.Anything) - assert.True(t, ok, fmt.Sprintf("Update was not called on %s", tc.desc)) - } } repoCall.Unset() repoCall1.Unset() - repoCall2.Unset() - repoCall3.Unset() } } @@ -315,7 +303,7 @@ func TestAssign(t *testing.T) { err: nil, }, { - desc: addExistingPolicyDesc, + desc: "add existing policy", policy: sdk.Policy{ Subject: subject, Object: object, @@ -393,25 +381,15 @@ func TestAssign(t *testing.T) { for _, tc := range cases { repoCall := pRepo.On("CheckAdmin", mock.Anything, mock.Anything).Return(nil) - repoCall1 := pRepo.On("RetrieveAll", mock.Anything, mock.Anything).Return(convertUserPolicyPage(tc.page), nil) - repoCall2 := pRepo.On("Update", mock.Anything, mock.Anything).Return(tc.err) - repoCall3 := pRepo.On("Save", mock.Anything, mock.Anything).Return(tc.err) + repoCall1 := pRepo.On("Save", mock.Anything, mock.Anything).Return(tc.err) err := mfsdk.Assign(tc.policy.Actions, tc.policy.Subject, tc.policy.Object, tc.token) assert.Equal(t, tc.err, err, fmt.Sprintf("%s: expected error %s, got %s", tc.desc, tc.err, err)) if tc.err == nil { - ok := repoCall.Parent.AssertCalled(t, "RetrieveAll", mock.Anything, mock.Anything) - assert.True(t, ok, fmt.Sprintf("RetrieveAll was not called on %s", tc.desc)) - ok = repoCall2.Parent.AssertCalled(t, "Save", mock.Anything, mock.Anything) + ok := repoCall1.Parent.AssertCalled(t, "Save", mock.Anything, mock.Anything) assert.True(t, ok, fmt.Sprintf("Save was not called on %s", tc.desc)) - if tc.desc == addExistingPolicyDesc { - ok = repoCall1.Parent.AssertCalled(t, "Update", mock.Anything, mock.Anything) - assert.True(t, ok, fmt.Sprintf("Update was not called on %s", tc.desc)) - } } repoCall.Unset() repoCall1.Unset() - repoCall2.Unset() - repoCall3.Unset() } } func TestUpdatePolicy(t *testing.T) { @@ -800,7 +778,7 @@ func TestConnect(t *testing.T) { err: nil, }, { - desc: addExistingPolicyDesc, + desc: "add existing policy", policy: sdk.Policy{ Subject: subject, Object: object, @@ -877,25 +855,15 @@ func TestConnect(t *testing.T) { } for _, tc := range cases { - repoCall := pRepo.On("Retrieve", mock.Anything, mock.Anything).Return(convertThingPolicyPage(tc.page), nil) - repoCall1 := pRepo.On("Update", mock.Anything, mock.Anything).Return(convertThingPolicy(tc.policy), tc.err) - repoCall2 := pRepo.On("Save", mock.Anything, mock.Anything).Return(convertThingPolicy(tc.policy), tc.err) + repoCall := pRepo.On("Save", mock.Anything, mock.Anything).Return(convertThingPolicy(tc.policy), tc.err) conn := sdk.ConnectionIDs{ChannelIDs: []string{tc.policy.Object}, ThingIDs: []string{tc.policy.Subject}, Actions: tc.policy.Actions} err := mfsdk.Connect(conn, tc.token) assert.Equal(t, tc.err, err, fmt.Sprintf("%s: expected error %s, got %s", tc.desc, tc.err, err)) if tc.err == nil { - ok := repoCall.Parent.AssertCalled(t, "Retrieve", mock.Anything, mock.Anything) - assert.True(t, ok, fmt.Sprintf("Retrieve was not called on %s", tc.desc)) - ok = repoCall2.Parent.AssertCalled(t, "Save", mock.Anything, mock.Anything) + ok := repoCall.Parent.AssertCalled(t, "Save", mock.Anything, mock.Anything) assert.True(t, ok, fmt.Sprintf("Save was not called on %s", tc.desc)) - if tc.desc == addExistingPolicyDesc { - ok = repoCall1.Parent.AssertCalled(t, "Update", mock.Anything, mock.Anything) - assert.True(t, ok, fmt.Sprintf("Update was not called on %s", tc.desc)) - } } repoCall.Unset() - repoCall1.Unset() - repoCall2.Unset() } } @@ -939,7 +907,7 @@ func TestConnectThing(t *testing.T) { err: nil, }, { - desc: addExistingPolicyDesc, + desc: "add existing policy", policy: sdk.Policy{ Subject: subject, Object: object, @@ -1016,24 +984,14 @@ func TestConnectThing(t *testing.T) { } for _, tc := range cases { - repoCall := pRepo.On("Retrieve", mock.Anything, mock.Anything).Return(convertThingPolicyPage(tc.page), nil) - repoCall1 := pRepo.On("Update", mock.Anything, mock.Anything).Return(convertThingPolicy(tc.policy), tc.err) - repoCall2 := pRepo.On("Save", mock.Anything, mock.Anything).Return(convertThingPolicy(tc.policy), tc.err) + repoCall := pRepo.On("Save", mock.Anything, mock.Anything).Return(convertThingPolicy(tc.policy), tc.err) err := mfsdk.ConnectThing(tc.policy.Subject, tc.policy.Object, tc.token) assert.Equal(t, tc.err, err, fmt.Sprintf("%s: expected error %s, got %s", tc.desc, tc.err, err)) if tc.err == nil { - ok := repoCall.Parent.AssertCalled(t, "Retrieve", mock.Anything, mock.Anything) - assert.True(t, ok, fmt.Sprintf("Retrieve was not called on %s", tc.desc)) - ok = repoCall2.Parent.AssertCalled(t, "Save", mock.Anything, mock.Anything) + ok := repoCall.Parent.AssertCalled(t, "Save", mock.Anything, mock.Anything) assert.True(t, ok, fmt.Sprintf("Save was not called on %s", tc.desc)) - if tc.desc == addExistingPolicyDesc { - ok = repoCall1.Parent.AssertCalled(t, "Update", mock.Anything, mock.Anything) - assert.True(t, ok, fmt.Sprintf("Update was not called on %s", tc.desc)) - } } repoCall.Unset() - repoCall1.Unset() - repoCall2.Unset() } } diff --git a/pkg/sdk/go/setup_test.go b/pkg/sdk/go/setup_test.go index fcc46ff0..e1e70385 100644 --- a/pkg/sdk/go/setup_test.go +++ b/pkg/sdk/go/setup_test.go @@ -83,14 +83,14 @@ func generateValidToken(t *testing.T, svc clients.Service, cRepo *umocks.Reposit token, err := svc.IssueToken(context.Background(), client.Credentials.Identity, client.Credentials.Secret) assert.True(t, errors.Contains(err, nil), fmt.Sprintf("Create token expected nil got %s\n", err)) repoCall.Unset() - + return token.AccessToken } func generateUUID(t *testing.T) string { ulid, err := idProvider.ID() assert.Nil(t, err, fmt.Sprintf("unexpected error: %s", err)) - + return ulid } @@ -156,16 +156,6 @@ func convertUserPolicies(cs []sdk.Policy) []upolicies.Policy { return ccs } -func convertThingPolicies(cs []sdk.Policy) []tpolicies.Policy { - ccs := []tpolicies.Policy{} - - for _, c := range cs { - ccs = append(ccs, convertThingPolicy(c)) - } - - return ccs -} - func convertUserPolicy(sp sdk.Policy) upolicies.Policy { return upolicies.Policy{ OwnerID: sp.OwnerID, @@ -218,7 +208,7 @@ func convertClientPage(p sdk.PageMetadata) mfclients.Page { if err != nil { return mfclients.Page{} } - + return mfclients.Page{ Status: status, Total: p.Total, @@ -248,7 +238,7 @@ func convertGroup(g sdk.Group) mfgroups.Group { if err != nil { return mfgroups.Group{} } - + return mfgroups.Group{ ID: g.ID, Owner: g.OwnerID, @@ -288,7 +278,7 @@ func convertClient(c sdk.User) mfclients.Client { if err != nil { return mfclients.Client{} } - + return mfclients.Client{ ID: c.ID, Name: c.Name, @@ -357,17 +347,6 @@ func convertUserPolicyPage(pp sdk.PolicyPage) upolicies.PolicyPage { } } -func convertThingPolicyPage(pp sdk.PolicyPage) tpolicies.PolicyPage { - return tpolicies.PolicyPage{ - Page: tpolicies.Page{ - Limit: pp.Limit, - Total: pp.Total, - Offset: pp.Offset, - }, - Policies: convertThingPolicies(pp.Policies), - } -} - func TestMain(m *testing.M) { exitCode := m.Run() os.Exit(exitCode) diff --git a/things/policies/postgres/policies.go b/things/policies/postgres/policies.go index a4683128..82d5c0c6 100644 --- a/things/policies/postgres/policies.go +++ b/things/policies/postgres/policies.go @@ -31,6 +31,8 @@ func NewRepository(db postgres.Database) policies.Repository { func (pr prepo) Save(ctx context.Context, policy policies.Policy) (policies.Policy, error) { q := `INSERT INTO policies (owner_id, subject, object, actions, created_at, updated_at, updated_by) VALUES (:owner_id, :subject, :object, :actions, :created_at, :updated_at, :updated_by) + ON CONFLICT (subject, object) DO UPDATE SET actions = :actions, + updated_at = :updated_at, updated_by = :updated_by RETURNING owner_id, subject, object, actions, created_at, updated_at, updated_by;` dbp, err := toDBPolicy(policy) diff --git a/things/policies/postgres/policies_test.go b/things/policies/postgres/policies_test.go index f6928e9e..1f0975dd 100644 --- a/things/policies/postgres/policies_test.go +++ b/things/policies/postgres/policies_test.go @@ -67,7 +67,7 @@ func TestPoliciesSave(t *testing.T) { Object: group.ID, Actions: []string{"c_delete"}, }, - err: errors.ErrConflict, + err: nil, }, } diff --git a/things/policies/service.go b/things/policies/service.go index 0a3479c8..15d6d0d8 100644 --- a/things/policies/service.go +++ b/things/policies/service.go @@ -109,45 +109,26 @@ func (svc service) AddPolicy(ctx context.Context, token string, p Policy) (Polic if err := p.Validate(); err != nil { return Policy{}, err } - pm := Page{Subject: p.Subject, Object: p.Object, Offset: 0, Limit: 1} - page, err := svc.policies.Retrieve(ctx, pm) - if err != nil { - return Policy{}, err - } - - // If the policy already exists, replace the actions - if len(page.Policies) == 1 { - if err := svc.checkPolicy(ctx, userID, p); err != nil { - return Policy{}, err - } - - p.UpdatedAt = time.Now() - p.UpdatedBy = userID - - if err := svc.policyCache.Remove(ctx, p); err != nil { - return Policy{}, err - } - - return svc.policies.Update(ctx, p) - } p.OwnerID = userID p.CreatedAt = time.Now() + // incase the policy exists, use these for update. + p.UpdatedAt = time.Now() + p.UpdatedBy = userID + + if err := svc.policyCache.Remove(ctx, p); err != nil { + return Policy{}, err + } + // If the client is admin, add the policy if err := svc.checkAdmin(ctx, userID); err == nil { - if err := svc.policyCache.Put(ctx, p); err != nil { - return Policy{}, err - } return svc.policies.Save(ctx, p) } // If the client has `g_add` action on the object or is the owner of the object, add the policy ar := AccessRequest{Subject: userID, Object: p.Object, Action: "g_add"} if _, err := svc.policies.EvaluateGroupAccess(ctx, ar); err == nil { - if err := svc.policyCache.Put(ctx, p); err != nil { - return Policy{}, err - } return svc.policies.Save(ctx, p) } diff --git a/things/policies/service_test.go b/things/policies/service_test.go index 951fc63b..ed8b7e03 100644 --- a/things/policies/service_test.go +++ b/things/policies/service_test.go @@ -131,7 +131,6 @@ func TestAddPolicy(t *testing.T) { repoCall1 := pRepo.On("EvaluateThingAccess", mock.Anything, mock.Anything).Return(policies.Policy{}, tc.err) repoCall2 := pRepo.On("Update", context.Background(), tc.policy).Return(tc.err) repoCall3 := pRepo.On("Save", context.Background(), mock.Anything).Return(tc.policy, tc.err) - repoCall4 := pRepo.On("Retrieve", context.Background(), mock.Anything).Return(tc.page, nil) _, err := svc.AddPolicy(context.Background(), tc.token, tc.policy) assert.True(t, errors.Contains(err, tc.err), fmt.Sprintf("%s: expected %s got %s\n", tc.desc, tc.err, err)) if err == nil { @@ -145,7 +144,6 @@ func TestAddPolicy(t *testing.T) { repoCall2.Unset() repoCall3.Parent.AssertCalled(t, "Save", context.Background(), mock.Anything) repoCall3.Unset() - repoCall4.Unset() } } diff --git a/users/policies/postgres/policies.go b/users/policies/postgres/policies.go index 78a17da8..5dadf8ee 100644 --- a/users/policies/postgres/policies.go +++ b/users/policies/postgres/policies.go @@ -32,7 +32,9 @@ func NewRepository(db postgres.Database) policies.Repository { func (pr prepo) Save(ctx context.Context, policy policies.Policy) error { q := `INSERT INTO policies (owner_id, subject, object, actions, created_at) - VALUES (:owner_id, :subject, :object, :actions, :created_at)` + VALUES (:owner_id, :subject, :object, :actions, :created_at) + ON CONFLICT (subject, object) DO UPDATE SET actions = :actions, + updated_at = :updated_at, updated_by = :updated_by` dbp, err := toDBPolicy(policy) if err != nil { diff --git a/users/policies/postgres/policies_test.go b/users/policies/postgres/policies_test.go index b18cede9..a0ae9b77 100644 --- a/users/policies/postgres/policies_test.go +++ b/users/policies/postgres/policies_test.go @@ -73,7 +73,7 @@ func TestPoliciesSave(t *testing.T) { Object: uid, Actions: []string{"c_delete"}, }, - err: errors.ErrConflict, + err: nil, }, } diff --git a/users/policies/service.go b/users/policies/service.go index e0f0ef03..375bcb75 100644 --- a/users/policies/service.go +++ b/users/policies/service.go @@ -65,26 +65,13 @@ func (svc service) AddPolicy(ctx context.Context, token string, p Policy) error return err } - pm := Page{Subject: p.Subject, Object: p.Object, Offset: 0, Limit: 1} - page, err := svc.policies.RetrieveAll(ctx, pm) - if err != nil { - return err - } - - // If the policy already exists, replace the actions - if len(page.Policies) == 1 { - if err := svc.checkPolicy(ctx, id, p); err != nil { - return err - } - - p.UpdatedAt = time.Now() - p.UpdatedBy = id - return svc.policies.Update(ctx, p) - } - p.OwnerID = id p.CreatedAt = time.Now() + // incase the policy exists, use these for update. + p.UpdatedAt = time.Now() + p.UpdatedBy = id + // check if the client is admin if err = svc.policies.CheckAdmin(ctx, id); err == nil { return svc.policies.Save(ctx, p) diff --git a/users/policies/service_test.go b/users/policies/service_test.go index 0cb882e3..4009e645 100644 --- a/users/policies/service_test.go +++ b/users/policies/service_test.go @@ -140,9 +140,8 @@ func TestAddPolicy(t *testing.T) { for _, tc := range cases { repoCall := pRepo.On("CheckAdmin", context.Background(), mock.Anything).Return(nil) - repoCall1 := pRepo.On("RetrieveAll", context.Background(), mock.Anything).Return(tc.page, nil) - repoCall2 := pRepo.On("Update", context.Background(), mock.Anything).Return(tc.err) - repoCall3 := pRepo.On("Save", context.Background(), mock.Anything).Return(tc.err) + repoCall1 := pRepo.On("Update", context.Background(), mock.Anything).Return(tc.err) + repoCall2 := pRepo.On("Save", context.Background(), mock.Anything).Return(tc.err) err := svc.AddPolicy(context.Background(), tc.token, tc.policy) assert.True(t, errors.Contains(err, tc.err), fmt.Sprintf("%s: expected %s got %s\n", tc.desc, tc.err, err)) if err == nil { @@ -152,19 +151,16 @@ func TestAddPolicy(t *testing.T) { require.Nil(t, err, fmt.Sprintf("checking shared %v policy expected to be succeed: %#v", tc.policy, err)) ok := repoCall.Parent.AssertCalled(t, "CheckAdmin", context.Background(), mock.Anything) assert.True(t, ok, fmt.Sprintf("CheckAdmin was not called on %s", tc.desc)) - ok = repoCall1.Parent.AssertCalled(t, "RetrieveAll", context.Background(), mock.Anything) - assert.True(t, ok, fmt.Sprintf("RetrieveAll was not called on %s", tc.desc)) - ok = repoCall3.Parent.AssertCalled(t, "Save", context.Background(), mock.Anything) + ok = repoCall2.Parent.AssertCalled(t, "Save", context.Background(), mock.Anything) assert.True(t, ok, fmt.Sprintf("Save was not called on %s", tc.desc)) if tc.desc == "add existing policy" { - ok = repoCall2.Parent.AssertCalled(t, "Update", context.Background(), mock.Anything) + ok = repoCall1.Parent.AssertCalled(t, "Update", context.Background(), mock.Anything) assert.True(t, ok, fmt.Sprintf("Update was not called on %s", tc.desc)) } } repoCall.Unset() repoCall1.Unset() repoCall2.Unset() - repoCall3.Unset() } }