diff --git a/things/postgres/channels.go b/things/postgres/channels.go index 516b4ccb..ef34e288 100644 --- a/things/postgres/channels.go +++ b/things/postgres/channels.go @@ -22,6 +22,12 @@ type channelRepository struct { db Database } +type dbConnection struct { + Channel string `db:"channel"` + Thing string `db:"thing"` + Owner string `db:"owner"` +} + // NewChannelRepository instantiates a PostgreSQL implementation of channel // repository. func NewChannelRepository(db Database) things.ChannelRepository { @@ -149,23 +155,11 @@ func (cr channelRepository) RetrieveAll(ctx context.Context, owner string, offse items = append(items, ch) } - cq := "" - if name != "" { - cq = `AND LOWER(name) LIKE $2` - } + cq := fmt.Sprintf(`SELECT COUNT(*) FROM channels WHERE owner = :owner %s%s;`, nq, mq) - q = fmt.Sprintf(`SELECT COUNT(*) FROM channels WHERE owner = $1 %s;`, cq) - - total := uint64(0) - switch name { - case "": - if err := cr.db.GetContext(ctx, &total, q, owner); err != nil { - return things.ChannelsPage{}, err - } - default: - if err := cr.db.GetContext(ctx, &total, q, owner, name); err != nil { - return things.ChannelsPage{}, err - } + total, err := total(ctx, cr.db, cq, params) + if err != nil { + return things.ChannelsPage{}, err } page := things.ChannelsPage{ @@ -438,8 +432,18 @@ func getMetadataQuery(m things.Metadata) ([]byte, string, error) { return mb, mq, nil } -type dbConnection struct { - Channel string `db:"channel"` - Thing string `db:"thing"` - Owner string `db:"owner"` +func total(ctx context.Context, db Database, query string, params map[string]interface{}) (uint64, error) { + rows, err := db.NamedQueryContext(ctx, query, params) + if err != nil { + return 0, err + } + + total := uint64(0) + if rows.Next() { + if err := rows.Scan(&total); err != nil { + return 0, err + } + } + + return total, nil } diff --git a/things/postgres/channels_test.go b/things/postgres/channels_test.go index d80da5a5..75668267 100644 --- a/things/postgres/channels_test.go +++ b/things/postgres/channels_test.go @@ -218,14 +218,22 @@ func TestSingleChannelRetrieval(t *testing.T) { } func TestMultiChannelRetrieval(t *testing.T) { - email := "channel-multi-retrieval@example.com" dbMiddleware := postgres.NewDatabase(db) chanRepo := postgres.NewChannelRepository(dbMiddleware) - channelName := "channel_name" - meta := things.Metadata{} - wrongMeta := things.Metadata{} - meta["name"] = "test-channel" - wrongMeta["wrong"] = "wrong" + + email := "channel-multi-retrieval@example.com" + name := "channel_name" + metadata := things.Metadata{ + "field": "value", + } + wrongMeta := things.Metadata{ + "wrong": "wrong", + } + + offset := uint64(1) + chNameNum := uint64(3) + chMetaNum := uint64(3) + chNameMetaNum := uint64(2) n := uint64(10) for i := uint64(0); i < n; i++ { @@ -237,14 +245,18 @@ func TestMultiChannelRetrieval(t *testing.T) { Owner: email, } - // Create first two Channels with name. - if i < 2 { - ch.Name = channelName + // Create Channels with name. + if i < chNameNum { + ch.Name = name } - - // Create last two Channels with metadata. - if i >= 8 { - ch.Metadata = meta + // Create Channels with metadata. + if i >= chNameNum && i < chNameNum+chMetaNum { + ch.Metadata = metadata + } + // Create Channels with name and metadata. + if i >= n-chNameMetaNum { + ch.Metadata = metadata + ch.Name = name } chanRepo.Save(context.Background(), ch) @@ -282,11 +294,11 @@ func TestMultiChannelRetrieval(t *testing.T) { }, "retrieve channels with existing name": { owner: email, - offset: 1, + offset: offset, limit: n, - name: channelName, - size: 1, - total: 2, + name: name, + size: chNameNum + chNameMetaNum - offset, + total: chNameNum + chNameMetaNum, }, "retrieve all channels with non-existing name": { owner: email, @@ -300,25 +312,33 @@ func TestMultiChannelRetrieval(t *testing.T) { owner: email, offset: 0, limit: n, - size: 2, - total: n, - metadata: meta, + size: chMetaNum + chNameMetaNum, + total: chMetaNum + chNameMetaNum, + metadata: metadata, }, "retrieve all channels with non-existing metadata": { owner: email, offset: 0, limit: n, - size: 0, - total: n, + total: 0, metadata: wrongMeta, }, + "retrieve all channels with existing name and metadata": { + owner: email, + offset: 0, + limit: n, + size: chNameMetaNum, + total: chNameMetaNum, + name: name, + metadata: metadata, + }, } for desc, tc := range cases { page, err := chanRepo.RetrieveAll(context.Background(), tc.owner, tc.offset, tc.limit, tc.name, tc.metadata) size := uint64(len(page.Channels)) - assert.Equal(t, tc.size, size, fmt.Sprintf("%s: expected %d got %d\n", desc, tc.size, size)) - assert.Equal(t, tc.total, page.Total, fmt.Sprintf("%s: expected %d got %d\n", desc, tc.total, page.Total)) + assert.Equal(t, tc.size, size, fmt.Sprintf("%s: expected size %d got %d\n", desc, tc.size, size)) + assert.Equal(t, tc.total, page.Total, fmt.Sprintf("%s: expected total %d got %d\n", desc, tc.total, page.Total)) assert.Nil(t, err, fmt.Sprintf("%s: expected no error got %d\n", desc, err)) } } diff --git a/things/postgres/things.go b/things/postgres/things.go index 401e7502..834d37a7 100644 --- a/things/postgres/things.go +++ b/things/postgres/things.go @@ -218,23 +218,11 @@ func (tr thingRepository) RetrieveAll(ctx context.Context, owner string, offset, items = append(items, th) } - cq := "" - if name != "" { - cq = `AND LOWER(name) LIKE $2` - } + cq := fmt.Sprintf(`SELECT COUNT(*) FROM things WHERE owner = :owner %s%s;`, nq, mq) - q = fmt.Sprintf(`SELECT COUNT(*) FROM things WHERE owner = $1 %s;`, cq) - - total := uint64(0) - switch name { - case "": - if err := tr.db.GetContext(ctx, &total, q, owner); err != nil { - return things.ThingsPage{}, err - } - default: - if err := tr.db.GetContext(ctx, &total, q, owner, name); err != nil { - return things.ThingsPage{}, err - } + total, err := total(ctx, tr.db, cq, params) + if err != nil { + return things.ThingsPage{}, err } page := things.ThingsPage{ diff --git a/things/postgres/things_test.go b/things/postgres/things_test.go index 7565d1a7..a352d22f 100644 --- a/things/postgres/things_test.go +++ b/things/postgres/things_test.go @@ -380,15 +380,24 @@ func TestThingRetrieveByKey(t *testing.T) { } func TestMultiThingRetrieval(t *testing.T) { - email := "thing-multi-retrieval@example.com" - name := "mainflux" - metadata := make(map[string]interface{}) - metadata["serial"] = "123456" - metadata["type"] = "test" - idp := uuid.New() dbMiddleware := postgres.NewDatabase(db) thingRepo := postgres.NewThingRepository(dbMiddleware) + email := "thing-multi-retrieval@example.com" + name := "mainflux" + metadata := things.Metadata{ + "field": "value", + } + wrongMeta := things.Metadata{ + "wrong": "wrong", + } + + idp := uuid.New() + offset := uint64(1) + thNameNum := uint64(3) + thMetaNum := uint64(3) + thNameMetaNum := uint64(2) + n := uint64(10) for i := uint64(0); i < n; i++ { thid, err := idp.ID() @@ -397,14 +406,22 @@ func TestMultiThingRetrieval(t *testing.T) { require.Nil(t, err, fmt.Sprintf("got unexpected error: %s", err)) th := things.Thing{ - Owner: email, - ID: thid, - Key: thkey, - Metadata: metadata, + Owner: email, + ID: thid, + Key: thkey, } - // Create first two Things with name. - if i < 2 { + // Create Things with name. + if i < thNameNum { + th.Name = name + } + // Create Things with metadata. + if i >= thNameNum && i < thNameNum+thMetaNum { + th.Metadata = metadata + } + // Create Things with name and metadata. + if i >= n-thNameMetaNum { + th.Metadata = metadata th.Name = name } @@ -446,8 +463,8 @@ func TestMultiThingRetrieval(t *testing.T) { offset: 1, limit: n, name: name, - size: 1, - total: 2, + size: thNameNum + thNameMetaNum - offset, + total: thNameNum + thNameMetaNum, }, "retrieve things with non-existing name": { owner: email, @@ -457,12 +474,29 @@ func TestMultiThingRetrieval(t *testing.T) { size: 0, total: 0, }, - "retrieve things with metadata": { + "retrieve things with existing metadata": { owner: email, offset: 0, limit: n, - size: n, - total: n, + size: thMetaNum + thNameMetaNum, + total: thMetaNum + thNameMetaNum, + metadata: metadata, + }, + "retrieve things with non-existing metadata": { + owner: email, + offset: 0, + limit: n, + size: 0, + total: 0, + metadata: wrongMeta, + }, + "retrieve all things with existing name and metadata": { + owner: email, + offset: 0, + limit: n, + size: thNameMetaNum, + total: thNameMetaNum, + name: name, metadata: metadata, }, } @@ -470,8 +504,8 @@ func TestMultiThingRetrieval(t *testing.T) { for desc, tc := range cases { page, err := thingRepo.RetrieveAll(context.Background(), tc.owner, tc.offset, tc.limit, tc.name, tc.metadata) size := uint64(len(page.Things)) - assert.Equal(t, tc.size, size, fmt.Sprintf("%s: expected %d got %d\n", desc, tc.size, size)) - assert.Equal(t, tc.total, page.Total, fmt.Sprintf("%s: expected %d got %d\n", desc, tc.total, page.Total)) + assert.Equal(t, tc.size, size, fmt.Sprintf("%s: expected size %d got %d\n", desc, tc.size, size)) + assert.Equal(t, tc.total, page.Total, fmt.Sprintf("%s: expected total %d got %d\n", desc, tc.total, page.Total)) assert.Nil(t, err, fmt.Sprintf("%s: expected no error got %d\n", desc, err)) } }