1
0
mirror of https://github.com/mainflux/mainflux.git synced 2025-04-24 13:48:49 +08:00

NOISSUE - Use Insert On Conflict For Policies (#1824)

* Initial Commit: Use Insert On Conflict For Policies

Signed-off-by: rodneyosodo <blackd0t@protonmail.com>

* Invalidate Cache on Adding Policy

Signed-off-by: rodneyosodo <blackd0t@protonmail.com>

---------

Signed-off-by: rodneyosodo <blackd0t@protonmail.com>
This commit is contained in:
b1ackd0t 2023-06-20 15:59:53 +03:00 committed by GitHub
parent 98aa270b14
commit 2673b34c6a
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
10 changed files with 40 additions and 137 deletions

View File

@ -28,8 +28,6 @@ import (
"github.com/stretchr/testify/mock"
)
const addExistingPolicyDesc = "add existing policy"
var utadminPolicy = umocks.SubjectSet{Subject: "things", Relation: []string{"g_add"}}
func newPolicyServer(svc upolicies.Service) *httptest.Server {
@ -75,7 +73,7 @@ func TestCreatePolicy(t *testing.T) {
err: nil,
},
{
desc: addExistingPolicyDesc,
desc: "add existing policy",
policy: sdk.Policy{
Subject: subject,
Object: object,
@ -153,25 +151,15 @@ func TestCreatePolicy(t *testing.T) {
for _, tc := range cases {
repoCall := pRepo.On("CheckAdmin", mock.Anything, mock.Anything).Return(nil)
repoCall1 := pRepo.On("RetrieveAll", mock.Anything, mock.Anything).Return(convertUserPolicyPage(tc.page), nil)
repoCall2 := pRepo.On("Update", mock.Anything, mock.Anything).Return(tc.err)
repoCall3 := pRepo.On("Save", mock.Anything, mock.Anything).Return(tc.err)
repoCall1 := pRepo.On("Save", mock.Anything, mock.Anything).Return(tc.err)
err := mfsdk.CreatePolicy(tc.policy, tc.token)
assert.Equal(t, tc.err, err, fmt.Sprintf("%s: expected error %s, got %s", tc.desc, tc.err, err))
if tc.err == nil {
ok := repoCall.Parent.AssertCalled(t, "RetrieveAll", mock.Anything, mock.Anything)
assert.True(t, ok, fmt.Sprintf("RetrieveAll was not called on %s", tc.desc))
ok = repoCall2.Parent.AssertCalled(t, "Save", mock.Anything, mock.Anything)
ok := repoCall1.Parent.AssertCalled(t, "Save", mock.Anything, mock.Anything)
assert.True(t, ok, fmt.Sprintf("Save was not called on %s", tc.desc))
if tc.desc == addExistingPolicyDesc {
ok = repoCall1.Parent.AssertCalled(t, "Update", mock.Anything, mock.Anything)
assert.True(t, ok, fmt.Sprintf("Update was not called on %s", tc.desc))
}
}
repoCall.Unset()
repoCall1.Unset()
repoCall2.Unset()
repoCall3.Unset()
}
}
@ -315,7 +303,7 @@ func TestAssign(t *testing.T) {
err: nil,
},
{
desc: addExistingPolicyDesc,
desc: "add existing policy",
policy: sdk.Policy{
Subject: subject,
Object: object,
@ -393,25 +381,15 @@ func TestAssign(t *testing.T) {
for _, tc := range cases {
repoCall := pRepo.On("CheckAdmin", mock.Anything, mock.Anything).Return(nil)
repoCall1 := pRepo.On("RetrieveAll", mock.Anything, mock.Anything).Return(convertUserPolicyPage(tc.page), nil)
repoCall2 := pRepo.On("Update", mock.Anything, mock.Anything).Return(tc.err)
repoCall3 := pRepo.On("Save", mock.Anything, mock.Anything).Return(tc.err)
repoCall1 := pRepo.On("Save", mock.Anything, mock.Anything).Return(tc.err)
err := mfsdk.Assign(tc.policy.Actions, tc.policy.Subject, tc.policy.Object, tc.token)
assert.Equal(t, tc.err, err, fmt.Sprintf("%s: expected error %s, got %s", tc.desc, tc.err, err))
if tc.err == nil {
ok := repoCall.Parent.AssertCalled(t, "RetrieveAll", mock.Anything, mock.Anything)
assert.True(t, ok, fmt.Sprintf("RetrieveAll was not called on %s", tc.desc))
ok = repoCall2.Parent.AssertCalled(t, "Save", mock.Anything, mock.Anything)
ok := repoCall1.Parent.AssertCalled(t, "Save", mock.Anything, mock.Anything)
assert.True(t, ok, fmt.Sprintf("Save was not called on %s", tc.desc))
if tc.desc == addExistingPolicyDesc {
ok = repoCall1.Parent.AssertCalled(t, "Update", mock.Anything, mock.Anything)
assert.True(t, ok, fmt.Sprintf("Update was not called on %s", tc.desc))
}
}
repoCall.Unset()
repoCall1.Unset()
repoCall2.Unset()
repoCall3.Unset()
}
}
func TestUpdatePolicy(t *testing.T) {
@ -800,7 +778,7 @@ func TestConnect(t *testing.T) {
err: nil,
},
{
desc: addExistingPolicyDesc,
desc: "add existing policy",
policy: sdk.Policy{
Subject: subject,
Object: object,
@ -877,25 +855,15 @@ func TestConnect(t *testing.T) {
}
for _, tc := range cases {
repoCall := pRepo.On("Retrieve", mock.Anything, mock.Anything).Return(convertThingPolicyPage(tc.page), nil)
repoCall1 := pRepo.On("Update", mock.Anything, mock.Anything).Return(convertThingPolicy(tc.policy), tc.err)
repoCall2 := pRepo.On("Save", mock.Anything, mock.Anything).Return(convertThingPolicy(tc.policy), tc.err)
repoCall := pRepo.On("Save", mock.Anything, mock.Anything).Return(convertThingPolicy(tc.policy), tc.err)
conn := sdk.ConnectionIDs{ChannelIDs: []string{tc.policy.Object}, ThingIDs: []string{tc.policy.Subject}, Actions: tc.policy.Actions}
err := mfsdk.Connect(conn, tc.token)
assert.Equal(t, tc.err, err, fmt.Sprintf("%s: expected error %s, got %s", tc.desc, tc.err, err))
if tc.err == nil {
ok := repoCall.Parent.AssertCalled(t, "Retrieve", mock.Anything, mock.Anything)
assert.True(t, ok, fmt.Sprintf("Retrieve was not called on %s", tc.desc))
ok = repoCall2.Parent.AssertCalled(t, "Save", mock.Anything, mock.Anything)
ok := repoCall.Parent.AssertCalled(t, "Save", mock.Anything, mock.Anything)
assert.True(t, ok, fmt.Sprintf("Save was not called on %s", tc.desc))
if tc.desc == addExistingPolicyDesc {
ok = repoCall1.Parent.AssertCalled(t, "Update", mock.Anything, mock.Anything)
assert.True(t, ok, fmt.Sprintf("Update was not called on %s", tc.desc))
}
}
repoCall.Unset()
repoCall1.Unset()
repoCall2.Unset()
}
}
@ -939,7 +907,7 @@ func TestConnectThing(t *testing.T) {
err: nil,
},
{
desc: addExistingPolicyDesc,
desc: "add existing policy",
policy: sdk.Policy{
Subject: subject,
Object: object,
@ -1016,24 +984,14 @@ func TestConnectThing(t *testing.T) {
}
for _, tc := range cases {
repoCall := pRepo.On("Retrieve", mock.Anything, mock.Anything).Return(convertThingPolicyPage(tc.page), nil)
repoCall1 := pRepo.On("Update", mock.Anything, mock.Anything).Return(convertThingPolicy(tc.policy), tc.err)
repoCall2 := pRepo.On("Save", mock.Anything, mock.Anything).Return(convertThingPolicy(tc.policy), tc.err)
repoCall := pRepo.On("Save", mock.Anything, mock.Anything).Return(convertThingPolicy(tc.policy), tc.err)
err := mfsdk.ConnectThing(tc.policy.Subject, tc.policy.Object, tc.token)
assert.Equal(t, tc.err, err, fmt.Sprintf("%s: expected error %s, got %s", tc.desc, tc.err, err))
if tc.err == nil {
ok := repoCall.Parent.AssertCalled(t, "Retrieve", mock.Anything, mock.Anything)
assert.True(t, ok, fmt.Sprintf("Retrieve was not called on %s", tc.desc))
ok = repoCall2.Parent.AssertCalled(t, "Save", mock.Anything, mock.Anything)
ok := repoCall.Parent.AssertCalled(t, "Save", mock.Anything, mock.Anything)
assert.True(t, ok, fmt.Sprintf("Save was not called on %s", tc.desc))
if tc.desc == addExistingPolicyDesc {
ok = repoCall1.Parent.AssertCalled(t, "Update", mock.Anything, mock.Anything)
assert.True(t, ok, fmt.Sprintf("Update was not called on %s", tc.desc))
}
}
repoCall.Unset()
repoCall1.Unset()
repoCall2.Unset()
}
}

View File

@ -83,14 +83,14 @@ func generateValidToken(t *testing.T, svc clients.Service, cRepo *umocks.Reposit
token, err := svc.IssueToken(context.Background(), client.Credentials.Identity, client.Credentials.Secret)
assert.True(t, errors.Contains(err, nil), fmt.Sprintf("Create token expected nil got %s\n", err))
repoCall.Unset()
return token.AccessToken
}
func generateUUID(t *testing.T) string {
ulid, err := idProvider.ID()
assert.Nil(t, err, fmt.Sprintf("unexpected error: %s", err))
return ulid
}
@ -156,16 +156,6 @@ func convertUserPolicies(cs []sdk.Policy) []upolicies.Policy {
return ccs
}
func convertThingPolicies(cs []sdk.Policy) []tpolicies.Policy {
ccs := []tpolicies.Policy{}
for _, c := range cs {
ccs = append(ccs, convertThingPolicy(c))
}
return ccs
}
func convertUserPolicy(sp sdk.Policy) upolicies.Policy {
return upolicies.Policy{
OwnerID: sp.OwnerID,
@ -218,7 +208,7 @@ func convertClientPage(p sdk.PageMetadata) mfclients.Page {
if err != nil {
return mfclients.Page{}
}
return mfclients.Page{
Status: status,
Total: p.Total,
@ -248,7 +238,7 @@ func convertGroup(g sdk.Group) mfgroups.Group {
if err != nil {
return mfgroups.Group{}
}
return mfgroups.Group{
ID: g.ID,
Owner: g.OwnerID,
@ -288,7 +278,7 @@ func convertClient(c sdk.User) mfclients.Client {
if err != nil {
return mfclients.Client{}
}
return mfclients.Client{
ID: c.ID,
Name: c.Name,
@ -357,17 +347,6 @@ func convertUserPolicyPage(pp sdk.PolicyPage) upolicies.PolicyPage {
}
}
func convertThingPolicyPage(pp sdk.PolicyPage) tpolicies.PolicyPage {
return tpolicies.PolicyPage{
Page: tpolicies.Page{
Limit: pp.Limit,
Total: pp.Total,
Offset: pp.Offset,
},
Policies: convertThingPolicies(pp.Policies),
}
}
func TestMain(m *testing.M) {
exitCode := m.Run()
os.Exit(exitCode)

View File

@ -31,6 +31,8 @@ func NewRepository(db postgres.Database) policies.Repository {
func (pr prepo) Save(ctx context.Context, policy policies.Policy) (policies.Policy, error) {
q := `INSERT INTO policies (owner_id, subject, object, actions, created_at, updated_at, updated_by)
VALUES (:owner_id, :subject, :object, :actions, :created_at, :updated_at, :updated_by)
ON CONFLICT (subject, object) DO UPDATE SET actions = :actions,
updated_at = :updated_at, updated_by = :updated_by
RETURNING owner_id, subject, object, actions, created_at, updated_at, updated_by;`
dbp, err := toDBPolicy(policy)

View File

@ -67,7 +67,7 @@ func TestPoliciesSave(t *testing.T) {
Object: group.ID,
Actions: []string{"c_delete"},
},
err: errors.ErrConflict,
err: nil,
},
}

View File

@ -109,45 +109,26 @@ func (svc service) AddPolicy(ctx context.Context, token string, p Policy) (Polic
if err := p.Validate(); err != nil {
return Policy{}, err
}
pm := Page{Subject: p.Subject, Object: p.Object, Offset: 0, Limit: 1}
page, err := svc.policies.Retrieve(ctx, pm)
if err != nil {
return Policy{}, err
}
// If the policy already exists, replace the actions
if len(page.Policies) == 1 {
if err := svc.checkPolicy(ctx, userID, p); err != nil {
return Policy{}, err
}
p.UpdatedAt = time.Now()
p.UpdatedBy = userID
if err := svc.policyCache.Remove(ctx, p); err != nil {
return Policy{}, err
}
return svc.policies.Update(ctx, p)
}
p.OwnerID = userID
p.CreatedAt = time.Now()
// incase the policy exists, use these for update.
p.UpdatedAt = time.Now()
p.UpdatedBy = userID
if err := svc.policyCache.Remove(ctx, p); err != nil {
return Policy{}, err
}
// If the client is admin, add the policy
if err := svc.checkAdmin(ctx, userID); err == nil {
if err := svc.policyCache.Put(ctx, p); err != nil {
return Policy{}, err
}
return svc.policies.Save(ctx, p)
}
// If the client has `g_add` action on the object or is the owner of the object, add the policy
ar := AccessRequest{Subject: userID, Object: p.Object, Action: "g_add"}
if _, err := svc.policies.EvaluateGroupAccess(ctx, ar); err == nil {
if err := svc.policyCache.Put(ctx, p); err != nil {
return Policy{}, err
}
return svc.policies.Save(ctx, p)
}

View File

@ -131,7 +131,6 @@ func TestAddPolicy(t *testing.T) {
repoCall1 := pRepo.On("EvaluateThingAccess", mock.Anything, mock.Anything).Return(policies.Policy{}, tc.err)
repoCall2 := pRepo.On("Update", context.Background(), tc.policy).Return(tc.err)
repoCall3 := pRepo.On("Save", context.Background(), mock.Anything).Return(tc.policy, tc.err)
repoCall4 := pRepo.On("Retrieve", context.Background(), mock.Anything).Return(tc.page, nil)
_, err := svc.AddPolicy(context.Background(), tc.token, tc.policy)
assert.True(t, errors.Contains(err, tc.err), fmt.Sprintf("%s: expected %s got %s\n", tc.desc, tc.err, err))
if err == nil {
@ -145,7 +144,6 @@ func TestAddPolicy(t *testing.T) {
repoCall2.Unset()
repoCall3.Parent.AssertCalled(t, "Save", context.Background(), mock.Anything)
repoCall3.Unset()
repoCall4.Unset()
}
}

View File

@ -32,7 +32,9 @@ func NewRepository(db postgres.Database) policies.Repository {
func (pr prepo) Save(ctx context.Context, policy policies.Policy) error {
q := `INSERT INTO policies (owner_id, subject, object, actions, created_at)
VALUES (:owner_id, :subject, :object, :actions, :created_at)`
VALUES (:owner_id, :subject, :object, :actions, :created_at)
ON CONFLICT (subject, object) DO UPDATE SET actions = :actions,
updated_at = :updated_at, updated_by = :updated_by`
dbp, err := toDBPolicy(policy)
if err != nil {

View File

@ -73,7 +73,7 @@ func TestPoliciesSave(t *testing.T) {
Object: uid,
Actions: []string{"c_delete"},
},
err: errors.ErrConflict,
err: nil,
},
}

View File

@ -65,26 +65,13 @@ func (svc service) AddPolicy(ctx context.Context, token string, p Policy) error
return err
}
pm := Page{Subject: p.Subject, Object: p.Object, Offset: 0, Limit: 1}
page, err := svc.policies.RetrieveAll(ctx, pm)
if err != nil {
return err
}
// If the policy already exists, replace the actions
if len(page.Policies) == 1 {
if err := svc.checkPolicy(ctx, id, p); err != nil {
return err
}
p.UpdatedAt = time.Now()
p.UpdatedBy = id
return svc.policies.Update(ctx, p)
}
p.OwnerID = id
p.CreatedAt = time.Now()
// incase the policy exists, use these for update.
p.UpdatedAt = time.Now()
p.UpdatedBy = id
// check if the client is admin
if err = svc.policies.CheckAdmin(ctx, id); err == nil {
return svc.policies.Save(ctx, p)

View File

@ -140,9 +140,8 @@ func TestAddPolicy(t *testing.T) {
for _, tc := range cases {
repoCall := pRepo.On("CheckAdmin", context.Background(), mock.Anything).Return(nil)
repoCall1 := pRepo.On("RetrieveAll", context.Background(), mock.Anything).Return(tc.page, nil)
repoCall2 := pRepo.On("Update", context.Background(), mock.Anything).Return(tc.err)
repoCall3 := pRepo.On("Save", context.Background(), mock.Anything).Return(tc.err)
repoCall1 := pRepo.On("Update", context.Background(), mock.Anything).Return(tc.err)
repoCall2 := pRepo.On("Save", context.Background(), mock.Anything).Return(tc.err)
err := svc.AddPolicy(context.Background(), tc.token, tc.policy)
assert.True(t, errors.Contains(err, tc.err), fmt.Sprintf("%s: expected %s got %s\n", tc.desc, tc.err, err))
if err == nil {
@ -152,19 +151,16 @@ func TestAddPolicy(t *testing.T) {
require.Nil(t, err, fmt.Sprintf("checking shared %v policy expected to be succeed: %#v", tc.policy, err))
ok := repoCall.Parent.AssertCalled(t, "CheckAdmin", context.Background(), mock.Anything)
assert.True(t, ok, fmt.Sprintf("CheckAdmin was not called on %s", tc.desc))
ok = repoCall1.Parent.AssertCalled(t, "RetrieveAll", context.Background(), mock.Anything)
assert.True(t, ok, fmt.Sprintf("RetrieveAll was not called on %s", tc.desc))
ok = repoCall3.Parent.AssertCalled(t, "Save", context.Background(), mock.Anything)
ok = repoCall2.Parent.AssertCalled(t, "Save", context.Background(), mock.Anything)
assert.True(t, ok, fmt.Sprintf("Save was not called on %s", tc.desc))
if tc.desc == "add existing policy" {
ok = repoCall2.Parent.AssertCalled(t, "Update", context.Background(), mock.Anything)
ok = repoCall1.Parent.AssertCalled(t, "Update", context.Background(), mock.Anything)
assert.True(t, ok, fmt.Sprintf("Update was not called on %s", tc.desc))
}
}
repoCall.Unset()
repoCall1.Unset()
repoCall2.Unset()
repoCall3.Unset()
}
}