1
0
mirror of https://github.com/mainflux/mainflux.git synced 2025-04-27 13:48:49 +08:00
Manuel Imperiale 9972d1d1a4
MF-1240 - Return to service transport layer only service errors (#1559)
* MF-1240 - Return to service transport layer only service errors

Signed-off-by: Manuel Imperiale <manuel.imperiale@gmail.com>

* Remove unnecessary errors

Signed-off-by: Manuel Imperiale <manuel.imperiale@gmail.com>

* Rm duplicated errors and fix transport

Signed-off-by: Manuel Imperiale <manuel.imperiale@gmail.com>

* Revert http endpoint_test

Signed-off-by: Manuel Imperiale <manuel.imperiale@gmail.com>

* Fix conflict

Signed-off-by: Manuel Imperiale <manuel.imperiale@gmail.com>

Co-authored-by: Dušan Borovčanin <dusan.borovcanin@mainflux.com>
2022-02-14 22:49:23 +01:00

544 lines
12 KiB
Go

// Copyright (c) Mainflux
// SPDX-License-Identifier: Apache-2.0
package postgres
import (
"context"
"database/sql"
"database/sql/driver"
"encoding/json"
"fmt"
"strings"
"github.com/gofrs/uuid"
"github.com/lib/pq"
"github.com/mainflux/mainflux/pkg/errors"
"github.com/mainflux/mainflux/things"
)
// Compile-time assertion that *channelRepository implements
// things.ChannelRepository.
var _ things.ChannelRepository = (*channelRepository)(nil)

// channelRepository is a PostgreSQL implementation of things.ChannelRepository.
type channelRepository struct {
	// db is the database handle that every repository method executes its
	// queries and transactions on.
	db Database
}
// dbConnection holds the named parameters (:channel, :thing, :owner) bound by
// the connection insert/delete queries. Those queries bind the single Owner
// value to both the channel_owner and thing_owner columns.
type dbConnection struct {
	Channel string `db:"channel"`
	Thing   string `db:"thing"`
	Owner   string `db:"owner"`
}
// NewChannelRepository instantiates a PostgreSQL implementation of channel
// repository that runs all of its queries on db.
func NewChannelRepository(db Database) things.ChannelRepository {
	repo := channelRepository{db: db}
	return &repo
}
// Save inserts every given channel inside a single transaction and, on
// success, returns the input channels unchanged.
//
// On the first failed insert the transaction is rolled back and:
//   - a Postgres error whose code name is errInvalid or errTruncation becomes
//     errors.ErrMalformedEntity,
//   - errDuplicate becomes errors.ErrConflict,
//   - anything else is wrapped in errors.ErrCreateEntity.
//
// Failures to begin or commit the transaction are wrapped in
// errors.ErrCreateEntity.
func (cr channelRepository) Save(ctx context.Context, channels ...things.Channel) ([]things.Channel, error) {
	tx, err := cr.db.BeginTxx(ctx, nil)
	if err != nil {
		return nil, errors.Wrap(errors.ErrCreateEntity, err)
	}

	insert := `INSERT INTO channels (id, owner, name, metadata)
VALUES (:id, :owner, :name, :metadata);`

	for i := range channels {
		_, err = tx.NamedExecContext(ctx, insert, toDBChannel(channels[i]))
		if err == nil {
			continue
		}

		tx.Rollback()
		if pqErr, ok := err.(*pq.Error); ok {
			switch pqErr.Code.Name() {
			case errInvalid, errTruncation:
				return []things.Channel{}, errors.ErrMalformedEntity
			case errDuplicate:
				return []things.Channel{}, errors.ErrConflict
			}
		}
		return []things.Channel{}, errors.Wrap(errors.ErrCreateEntity, err)
	}

	if err := tx.Commit(); err != nil {
		return []things.Channel{}, errors.Wrap(errors.ErrCreateEntity, err)
	}
	return channels, nil
}
func (cr channelRepository) Update(ctx context.Context, channel things.Channel) error {
q := `UPDATE channels SET name = :name, metadata = :metadata WHERE owner = :owner AND id = :id;`
dbch := toDBChannel(channel)
res, err := cr.db.NamedExecContext(ctx, q, dbch)
if err != nil {
pqErr, ok := err.(*pq.Error)
if ok {
switch pqErr.Code.Name() {
case errInvalid, errTruncation:
return errors.ErrMalformedEntity
}
}
return errors.Wrap(errors.ErrUpdateEntity, err)
}
cnt, err := res.RowsAffected()
if err != nil {
return errors.Wrap(errors.ErrUpdateEntity, err)
}
if cnt == 0 {
return errors.ErrNotFound
}
return nil
}
// RetrieveByID loads a single channel by its id.
//
// A missing row, or a Postgres error whose code name is errInvalid, is
// reported as errors.ErrNotFound. Any other failure is wrapped in
// errors.ErrViewEntity.
//
// NOTE(review): owner is not part of the WHERE clause, so this lookup does not
// filter by owner; presumably ownership is enforced by the caller — confirm.
func (cr channelRepository) RetrieveByID(ctx context.Context, owner, id string) (things.Channel, error) {
	q := `SELECT name, metadata, owner FROM channels WHERE id = $1;`

	// id is not in the SELECT list, so it is pre-filled on the row struct.
	row := dbChannel{ID: id}
	err := cr.db.QueryRowxContext(ctx, q, id).StructScan(&row)
	if err == nil {
		return toChannel(row), nil
	}

	if err == sql.ErrNoRows {
		return things.Channel{}, errors.ErrNotFound
	}
	if pqErr, ok := err.(*pq.Error); ok && pqErr.Code.Name() == errInvalid {
		return things.Channel{}, errors.ErrNotFound
	}
	return things.Channel{}, errors.Wrap(errors.ErrViewEntity, err)
}
// RetrieveAll returns one page of channels, together with the total number of
// channels that match the same filters regardless of paging.
//
// The optional filters are combined with AND:
//   - the name substring predicate from getNameQuery(pm.Name),
//   - the metadata containment (`@>`) predicate from getMetadataQuery(pm.Metadata),
//   - the predicate returned by getOwnerQuery(pm.FetchSharedThings), if any.
//
// Results are ordered by pm.Order/pm.Dir and paged by pm.Limit/pm.Offset.
// Every failure is wrapped in errors.ErrViewEntity.
func (cr channelRepository) RetrieveAll(ctx context.Context, owner string, pm things.PageMetadata) (things.ChannelsPage, error) {
	nq, name := getNameQuery(pm.Name)
	oq := getOrderQuery(pm.Order)
	dq := getDirQuery(pm.Dir)
	ownerQuery := getOwnerQuery(pm.FetchSharedThings)
	meta, mq, err := getMetadataQuery(pm.Metadata)
	if err != nil {
		return things.ChannelsPage{}, errors.Wrap(errors.ErrViewEntity, err)
	}

	// Keep only the predicates that apply; with none, the WHERE clause is omitted.
	var conditions []string
	for _, c := range []string{mq, nq, ownerQuery} {
		if c != "" {
			conditions = append(conditions, c)
		}
	}
	var whereClause string
	if len(conditions) > 0 {
		whereClause = fmt.Sprintf(" WHERE %s", strings.Join(conditions, " AND "))
	}

	q := fmt.Sprintf(`SELECT id, name, metadata FROM channels
%s ORDER BY %s %s LIMIT :limit OFFSET :offset;`, whereClause, oq, dq)
	params := map[string]interface{}{
		"owner":    owner,
		"limit":    pm.Limit,
		"offset":   pm.Offset,
		"name":     name,
		"metadata": meta,
	}

	rows, err := cr.db.NamedQueryContext(ctx, q, params)
	if err != nil {
		return things.ChannelsPage{}, errors.Wrap(errors.ErrViewEntity, err)
	}
	defer rows.Close()

	items := []things.Channel{}
	for rows.Next() {
		// owner is not in the SELECT list, so it is pre-filled from the argument.
		dbch := dbChannel{Owner: owner}
		if err := rows.StructScan(&dbch); err != nil {
			return things.ChannelsPage{}, errors.Wrap(errors.ErrViewEntity, err)
		}
		items = append(items, toChannel(dbch))
	}
	// rows.Next also returns false when iteration fails. Without this check a
	// broken read would come back as a silently truncated page.
	if err := rows.Err(); err != nil {
		return things.ChannelsPage{}, errors.Wrap(errors.ErrViewEntity, err)
	}

	cq := fmt.Sprintf(`SELECT COUNT(*) FROM channels %s;`, whereClause)
	count, err := total(ctx, cr.db, cq, params)
	if err != nil {
		return things.ChannelsPage{}, errors.Wrap(errors.ErrViewEntity, err)
	}

	return things.ChannelsPage{
		Channels: items,
		PageMetadata: things.PageMetadata{
			Total:  count,
			Offset: pm.Offset,
			Limit:  pm.Limit,
			Order:  pm.Order,
			Dir:    pm.Dir,
		},
	}, nil
}
// RetrieveByThing returns one page of channels owned by owner that are
// connected to the thing thID. When pm.Disconnected is true, it instead returns
// owner's channels that are NOT connected to that thing. Results are ordered by
// pm.Order/pm.Dir on the "ch" table alias and paged by pm.Limit/pm.Offset.
// Total counts all matching channels regardless of paging.
//
// A thID that does not parse as a UUID yields a wrapped errors.ErrNotFound
// without touching the database. Other failures are wrapped in
// errors.ErrViewEntity.
func (cr channelRepository) RetrieveByThing(ctx context.Context, owner, thID string, pm things.PageMetadata) (things.ChannelsPage, error) {
	oq := getConnOrderQuery(pm.Order, "ch")
	dq := getDirQuery(pm.Dir)

	// Verify if UUID format is valid to avoid internal Postgres error
	if _, err := uuid.FromString(thID); err != nil {
		return things.ChannelsPage{}, errors.Wrap(errors.ErrNotFound, err)
	}

	var q, qc string
	switch pm.Disconnected {
	case true:
		q = fmt.Sprintf(`SELECT id, name, metadata
FROM channels ch
WHERE ch.owner = :owner AND ch.id NOT IN
(SELECT id FROM channels ch
INNER JOIN connections conn
ON ch.id = conn.channel_id
WHERE ch.owner = :owner AND conn.thing_id = :thing)
ORDER BY %s %s
LIMIT :limit
OFFSET :offset;`, oq, dq)
		qc = `SELECT COUNT(*)
FROM channels ch
WHERE ch.owner = $1 AND ch.id NOT IN
(SELECT id FROM channels ch
INNER JOIN connections conn
ON ch.id = conn.channel_id
WHERE ch.owner = $1 AND conn.thing_id = $2);`
	default:
		q = fmt.Sprintf(`SELECT id, name, metadata FROM channels ch
INNER JOIN connections conn
ON ch.id = conn.channel_id
WHERE ch.owner = :owner AND conn.thing_id = :thing
ORDER BY %s %s
LIMIT :limit
OFFSET :offset;`, oq, dq)
		qc = `SELECT COUNT(*)
FROM channels ch
INNER JOIN connections conn
ON ch.id = conn.channel_id
WHERE ch.owner = $1 AND conn.thing_id = $2`
	}

	params := map[string]interface{}{
		"owner":  owner,
		"thing":  thID,
		"limit":  pm.Limit,
		"offset": pm.Offset,
	}

	rows, err := cr.db.NamedQueryContext(ctx, q, params)
	if err != nil {
		return things.ChannelsPage{}, errors.Wrap(errors.ErrViewEntity, err)
	}
	defer rows.Close()

	items := []things.Channel{}
	for rows.Next() {
		// owner is not in the SELECT list, so it is pre-filled from the argument.
		dbch := dbChannel{Owner: owner}
		if err := rows.StructScan(&dbch); err != nil {
			return things.ChannelsPage{}, errors.Wrap(errors.ErrViewEntity, err)
		}
		items = append(items, toChannel(dbch))
	}
	// rows.Next also returns false when iteration fails; surface that instead
	// of returning a silently truncated page.
	if err := rows.Err(); err != nil {
		return things.ChannelsPage{}, errors.Wrap(errors.ErrViewEntity, err)
	}

	var count uint64
	if err := cr.db.GetContext(ctx, &count, qc, owner, thID); err != nil {
		return things.ChannelsPage{}, errors.Wrap(errors.ErrViewEntity, err)
	}

	return things.ChannelsPage{
		Channels: items,
		PageMetadata: things.PageMetadata{
			Total:  count,
			Offset: pm.Offset,
			Limit:  pm.Limit,
		},
	}, nil
}
// Remove deletes the channel identified by id that belongs to owner.
//
// Removing a channel that does not exist is not an error: no affected-row check
// is made, so the operation is idempotent. A Postgres error whose code name is
// errInvalid (a malformed id, presumably a non-UUID) is treated the same way,
// since such an id cannot match any row. Any other database failure is
// returned wrapped in errors.ErrRemoveEntity instead of being silently dropped.
func (cr channelRepository) Remove(ctx context.Context, owner, id string) error {
	dbch := dbChannel{
		ID:    id,
		Owner: owner,
	}
	q := `DELETE FROM channels WHERE id = :id AND owner = :owner`
	if _, err := cr.db.NamedExecContext(ctx, q, dbch); err != nil {
		if pqErr, ok := err.(*pq.Error); ok && pqErr.Code.Name() == errInvalid {
			// Nothing can match a malformed id, so there is nothing to remove.
			return nil
		}
		return errors.Wrap(errors.ErrRemoveEntity, err)
	}
	return nil
}
func (cr channelRepository) Connect(ctx context.Context, owner string, chIDs, thIDs []string) error {
tx, err := cr.db.BeginTxx(ctx, nil)
if err != nil {
return errors.Wrap(things.ErrConnect, err)
}
q := `INSERT INTO connections (channel_id, channel_owner, thing_id, thing_owner)
VALUES (:channel, :owner, :thing, :owner);`
for _, chID := range chIDs {
for _, thID := range thIDs {
dbco := dbConnection{
Channel: chID,
Thing: thID,
Owner: owner,
}
_, err := tx.NamedExecContext(ctx, q, dbco)
if err != nil {
tx.Rollback()
pqErr, ok := err.(*pq.Error)
if ok {
switch pqErr.Code.Name() {
case errFK:
return errors.ErrNotFound
case errDuplicate:
return errors.ErrConflict
}
}
return errors.Wrap(things.ErrConnect, err)
}
}
}
if err = tx.Commit(); err != nil {
return errors.Wrap(things.ErrConnect, err)
}
return nil
}
// Disconnect removes, inside one transaction, the connection for every pair in
// the cartesian product of chIDs and thIDs owned by owner.
//
// If any pair is not currently connected, the transaction is rolled back and
// errors.ErrNotFound is returned, so nothing is removed. For a failed delete,
// errFK and errDuplicate Postgres errors are mapped to errors.ErrNotFound and
// errors.ErrConflict. All other failures, including beginning or committing the
// transaction, are wrapped in things.ErrDisconnect.
func (cr channelRepository) Disconnect(ctx context.Context, owner string, chIDs, thIDs []string) error {
	tx, err := cr.db.BeginTxx(ctx, nil)
	if err != nil {
		return errors.Wrap(things.ErrDisconnect, err)
	}

	q := `DELETE FROM connections
WHERE channel_id = :channel AND channel_owner = :owner
AND thing_id = :thing AND thing_owner = :owner`

	for _, chID := range chIDs {
		for _, thID := range thIDs {
			dbco := dbConnection{
				Channel: chID,
				Thing:   thID,
				Owner:   owner,
			}
			res, err := tx.NamedExecContext(ctx, q, dbco)
			if err != nil {
				tx.Rollback()
				pqErr, ok := err.(*pq.Error)
				if ok {
					switch pqErr.Code.Name() {
					case errFK:
						return errors.ErrNotFound
					case errDuplicate:
						return errors.ErrConflict
					}
				}
				return errors.Wrap(things.ErrDisconnect, err)
			}

			cnt, err := res.RowsAffected()
			if err != nil {
				// Every early return must release the transaction; otherwise it
				// (and its pooled connection) stays open.
				tx.Rollback()
				return errors.Wrap(things.ErrDisconnect, err)
			}
			if cnt == 0 {
				tx.Rollback()
				return errors.ErrNotFound
			}
		}
	}

	if err = tx.Commit(); err != nil {
		return errors.Wrap(things.ErrDisconnect, err)
	}
	return nil
}
// HasThing resolves thingKey to a thing ID, then checks via hasThing that the
// thing is connected to channel chanID. On success it returns the thing ID.
//
// A failed key lookup (an unknown key surfaces as sql.ErrNoRows from Scan) is
// wrapped in errors.ErrViewEntity. Errors from hasThing are returned unchanged.
func (cr channelRepository) HasThing(ctx context.Context, chanID, thingKey string) (string, error) {
	const q = `SELECT id FROM things WHERE key = $1`

	var id string
	err := cr.db.QueryRowxContext(ctx, q, thingKey).Scan(&id)
	if err != nil {
		return "", errors.Wrap(errors.ErrViewEntity, err)
	}
	if err = cr.hasThing(ctx, chanID, id); err != nil {
		return "", err
	}
	return id, nil
}
// HasThingByID reports whether thing thingID is connected to channel chanID.
// It returns nil when connected and otherwise passes on hasThing's error.
func (cr channelRepository) HasThingByID(ctx context.Context, chanID, thingID string) error {
	if err := cr.hasThing(ctx, chanID, thingID); err != nil {
		return err
	}
	return nil
}
func (cr channelRepository) hasThing(ctx context.Context, chanID, thingID string) error {
q := `SELECT EXISTS (SELECT 1 FROM connections WHERE channel_id = $1 AND thing_id = $2);`
exists := false
if err := cr.db.QueryRowxContext(ctx, q, chanID, thingID).Scan(&exists); err != nil {
return errors.Wrap(errors.ErrViewEntity, err)
}
if !exists {
return errors.ErrNotFound
}
return nil
}
// dbMetadata type for handling metadata properly in database/sql.
// Its Scan and Value methods convert between the map and a JSON-encoded
// column value (json.Unmarshal / json.Marshal).
type dbMetadata map[string]interface{}
// Scan implements the database/sql scanner interface.
// When value is nil, the metadata is set to nil.
// When value is not a []byte, the metadata is set to an empty map and
// errors.ErrScanMetadata is returned.
// Otherwise value is decoded as JSON and replaces any previous contents of m.
func (m *dbMetadata) Scan(value interface{}) error {
	if value == nil {
		// Assign through the pointer: assigning to m itself would only change
		// the local copy and leave the destination untouched.
		*m = nil
		return nil
	}

	b, ok := value.([]byte)
	if !ok {
		*m = dbMetadata{}
		return errors.ErrScanMetadata
	}

	// json.Unmarshal reuses a non-nil map and keeps its existing entries, so
	// start from nil to avoid merging with data from a previous scan.
	*m = nil
	if err := json.Unmarshal(b, m); err != nil {
		return err
	}
	return nil
}
// Value implements database/sql valuer interface.
// Nil or empty metadata is written as SQL NULL; otherwise the map is written
// as its JSON encoding.
func (m dbMetadata) Value() (driver.Value, error) {
	if len(m) == 0 {
		return nil, nil
	}

	encoded, err := json.Marshal(m)
	if err != nil {
		return nil, err
	}
	return encoded, nil
}
// dbChannel is the database row representation of a channel. The db tags name
// the columns scanned by StructScan and the named parameters (:id, :owner,
// :name, :metadata) bound by the named queries in this file.
type dbChannel struct {
	ID       string     `db:"id"`
	Owner    string     `db:"owner"`
	Name     string     `db:"name"`
	Metadata dbMetadata `db:"metadata"`
}
// toDBChannel copies a domain channel field by field into its database row form.
func toDBChannel(ch things.Channel) dbChannel {
	var row dbChannel
	row.ID = ch.ID
	row.Owner = ch.Owner
	row.Name = ch.Name
	row.Metadata = ch.Metadata
	return row
}
// toChannel copies a database row field by field into a domain channel.
func toChannel(ch dbChannel) things.Channel {
	var out things.Channel
	out.ID = ch.ID
	out.Owner = ch.Owner
	out.Name = ch.Name
	out.Metadata = ch.Metadata
	return out
}
func getNameQuery(name string) (string, string) {
if name == "" {
return "", ""
}
name = fmt.Sprintf(`%%%s%%`, strings.ToLower(name))
nq := `LOWER(name) LIKE :name`
return nq, name
}
// getOrderQuery maps the requested sort field to a column name. Only "name" is
// recognised; every other value falls back to "id".
func getOrderQuery(order string) string {
	if order == "name" {
		return "name"
	}
	return "id"
}
// getConnOrderQuery is getOrderQuery for joined queries: it qualifies the
// chosen column ("name" when order is "name", otherwise "id") with the table
// alias given in level, e.g. "ch.name".
func getConnOrderQuery(order string, level string) string {
	column := "id"
	if order == "name" {
		column = "name"
	}
	return level + "." + column
}
// getDirQuery maps the requested sort direction to SQL. Only the exact string
// "asc" yields "ASC"; anything else (including "") yields "DESC".
func getDirQuery(dir string) string {
	if dir == "asc" {
		return "ASC"
	}
	return "DESC"
}
// getMetadataQuery returns the JSON value to bind to :metadata, the containment
// predicate that uses it, and any JSON encoding error.
//
// For empty metadata it returns the bytes "{}" and an empty predicate, meaning
// no metadata filter applies.
func getMetadataQuery(m things.Metadata) ([]byte, string, error) {
	if len(m) == 0 {
		return []byte("{}"), "", nil
	}

	encoded, err := json.Marshal(m)
	if err != nil {
		return nil, "", err
	}
	return encoded, `metadata @> :metadata`, nil
}
func total(ctx context.Context, db Database, query string, params interface{}) (uint64, error) {
rows, err := db.NamedQueryContext(ctx, query, params)
if err != nil {
return 0, err
}
defer rows.Close()
total := uint64(0)
if rows.Next() {
if err := rows.Scan(&total); err != nil {
return 0, err
}
}
return total, nil
}