mirror of
https://github.com/mainflux/mainflux.git
synced 2025-04-26 13:48:53 +08:00
341 lines
8.2 KiB
Go
341 lines
8.2 KiB
Go
![]() |
// Copyright (c) Mainflux
|
||
|
// SPDX-License-Identifier: Apache-2.0
|
||
|
|
||
|
package certs_test
|
||
|
|
||
|
import (
|
||
|
"context"
|
||
|
"crypto/tls"
|
||
|
"crypto/x509"
|
||
|
"encoding/pem"
|
||
|
"fmt"
|
||
|
"io/ioutil"
|
||
|
"net/http/httptest"
|
||
|
"os"
|
||
|
"strconv"
|
||
|
"strings"
|
||
|
"testing"
|
||
|
"time"
|
||
|
|
||
|
"github.com/mainflux/mainflux"
|
||
|
bsmocks "github.com/mainflux/mainflux/bootstrap/mocks"
|
||
|
"github.com/mainflux/mainflux/certs"
|
||
|
"github.com/mainflux/mainflux/certs/mocks"
|
||
|
"github.com/mainflux/mainflux/pkg/errors"
|
||
|
mfsdk "github.com/mainflux/mainflux/pkg/sdk/go"
|
||
|
"github.com/mainflux/mainflux/things"
|
||
|
httpapi "github.com/mainflux/mainflux/things/api/things/http"
|
||
|
thmocks "github.com/mainflux/mainflux/things/mocks"
|
||
|
"github.com/opentracing/opentracing-go/mocktracer"
|
||
|
"github.com/stretchr/testify/assert"
|
||
|
"github.com/stretchr/testify/require"
|
||
|
)
|
||
|
|
||
|
const (
	// Request fixtures: credentials, the single mocked thing and the
	// parameters passed to IssueCert / ListCerts.
	wrongValue = "wrong-value"
	email      = "user@example.com"
	token      = "token"
	thingsNum  = 1
	thingKey   = "thingKey"
	thingID    = "1"
	daysValid  = "1h" // despite the name, a duration-style string
	keyBits    = 2048
	key        = "rsa"
	certNum    = 10

	// Service configuration values consumed by newService.
	cfgLogLevel     = "error"
	cfgClientTLS    = false
	cfgServerCert   = ""
	cfgServerKey    = ""
	cfgBaseURL      = "http://localhost"
	cfgThingsPrefix = ""
	cfgJaegerURL    = ""
	cfgAuthURL      = "localhost:8181"
	cfgAuthTimeout  = "1s"

	// CA files read by loadCertificates, plus the signing parameters handed
	// to certs.Config and the mock PKI agent.
	caPath            = "../docker/ssl/certs/ca.crt"
	caKeyPath         = "../docker/ssl/certs/ca.key"
	cfgSignHoursValid = "24h"
	cfgSignRSABits    = 2048
)
|
||
|
|
||
|
func newService(tokens map[string]string) (certs.Service, error) {
|
||
|
users := bsmocks.NewUsersService(map[string]string{token: email})
|
||
|
server := newThingsServer(newThingsService(users))
|
||
|
|
||
|
auth := thmocks.NewAuthService(tokens)
|
||
|
config := mfsdk.Config{
|
||
|
BaseURL: server.URL,
|
||
|
}
|
||
|
|
||
|
sdk := mfsdk.NewSDK(config)
|
||
|
repo := mocks.NewCertsRepository()
|
||
|
|
||
|
tlsCert, caCert, err := loadCertificates(caPath, caKeyPath)
|
||
|
if err != nil {
|
||
|
return nil, err
|
||
|
}
|
||
|
|
||
|
authTimeout, err := time.ParseDuration(cfgAuthTimeout)
|
||
|
if err != nil {
|
||
|
return nil, err
|
||
|
}
|
||
|
|
||
|
c := certs.Config{
|
||
|
LogLevel: cfgLogLevel,
|
||
|
ClientTLS: cfgClientTLS,
|
||
|
ServerCert: cfgServerCert,
|
||
|
ServerKey: cfgServerKey,
|
||
|
BaseURL: cfgBaseURL,
|
||
|
ThingsPrefix: cfgThingsPrefix,
|
||
|
JaegerURL: cfgJaegerURL,
|
||
|
AuthURL: cfgAuthURL,
|
||
|
SignTLSCert: tlsCert,
|
||
|
SignX509Cert: caCert,
|
||
|
SignHoursValid: cfgSignHoursValid,
|
||
|
SignRSABits: cfgSignRSABits,
|
||
|
}
|
||
|
|
||
|
pki := mocks.NewPkiAgent(tlsCert, caCert, cfgSignRSABits, cfgSignHoursValid, authTimeout)
|
||
|
|
||
|
return certs.New(auth, repo, sdk, c, pki), nil
|
||
|
}
|
||
|
|
||
|
func newThingsService(auth mainflux.AuthServiceClient) things.Service {
|
||
|
ths := make(map[string]things.Thing, thingsNum)
|
||
|
for i := 0; i < thingsNum; i++ {
|
||
|
id := strconv.Itoa(i + 1)
|
||
|
ths[id] = things.Thing{
|
||
|
ID: id,
|
||
|
Key: thingKey,
|
||
|
Owner: email,
|
||
|
}
|
||
|
}
|
||
|
|
||
|
return bsmocks.NewThingsService(ths, map[string]things.Channel{}, auth)
|
||
|
}
|
||
|
|
||
|
func TestIssueCert(t *testing.T) {
|
||
|
svc, err := newService(map[string]string{token: email})
|
||
|
require.Nil(t, err, fmt.Sprintf("unexpected service creation error: %s\n", err))
|
||
|
|
||
|
cases := []struct {
|
||
|
token string
|
||
|
desc string
|
||
|
thingID string
|
||
|
daysValid string
|
||
|
key string
|
||
|
keyBits int
|
||
|
err error
|
||
|
}{
|
||
|
{
|
||
|
desc: "issue new cert",
|
||
|
token: token,
|
||
|
thingID: thingID,
|
||
|
daysValid: daysValid,
|
||
|
key: key,
|
||
|
keyBits: 2048,
|
||
|
err: nil,
|
||
|
},
|
||
|
{
|
||
|
desc: "issue new cert for non existing thing id",
|
||
|
token: token,
|
||
|
thingID: "2",
|
||
|
daysValid: daysValid,
|
||
|
key: key,
|
||
|
keyBits: 2048,
|
||
|
err: certs.ErrFailedCertCreation,
|
||
|
},
|
||
|
{
|
||
|
desc: "issue new cert for non existing thing id",
|
||
|
token: wrongValue,
|
||
|
thingID: thingID,
|
||
|
daysValid: daysValid,
|
||
|
key: key,
|
||
|
keyBits: 2048,
|
||
|
err: certs.ErrUnauthorizedAccess,
|
||
|
},
|
||
|
{
|
||
|
desc: "issue new cert for bad key bits",
|
||
|
token: token,
|
||
|
thingID: thingID,
|
||
|
daysValid: daysValid,
|
||
|
key: key,
|
||
|
keyBits: -2,
|
||
|
err: certs.ErrFailedCertCreation,
|
||
|
},
|
||
|
{
|
||
|
desc: "issue new cert for bad key bits",
|
||
|
token: token,
|
||
|
thingID: thingID,
|
||
|
daysValid: daysValid,
|
||
|
key: key,
|
||
|
keyBits: -2,
|
||
|
err: certs.ErrFailedCertCreation,
|
||
|
},
|
||
|
}
|
||
|
|
||
|
for _, tc := range cases {
|
||
|
c, err := svc.IssueCert(context.Background(), tc.token, tc.thingID, tc.daysValid, tc.keyBits, tc.key)
|
||
|
assert.True(t, errors.Contains(err, tc.err), fmt.Sprintf("%s: expected %s got %s\n", tc.desc, tc.err, err))
|
||
|
cert, _ := readCert([]byte(c.ClientCert))
|
||
|
if cert != nil {
|
||
|
assert.True(t, strings.Contains(cert.Subject.CommonName, thingKey), fmt.Sprintf("%s: expected %s got %s\n", tc.desc, tc.err, err))
|
||
|
}
|
||
|
}
|
||
|
|
||
|
}
|
||
|
|
||
|
func TestRevokeCert(t *testing.T) {
|
||
|
svc, err := newService(map[string]string{token: email})
|
||
|
require.Nil(t, err, fmt.Sprintf("unexpected service creation error: %s\n", err))
|
||
|
|
||
|
_, err = svc.IssueCert(context.Background(), token, thingID, daysValid, keyBits, key)
|
||
|
require.Nil(t, err, fmt.Sprintf("unexpected service creation error: %s\n", err))
|
||
|
|
||
|
cases := []struct {
|
||
|
token string
|
||
|
desc string
|
||
|
thingID string
|
||
|
err error
|
||
|
}{
|
||
|
{
|
||
|
desc: "revoke cert",
|
||
|
token: token,
|
||
|
thingID: thingID,
|
||
|
err: nil,
|
||
|
},
|
||
|
{
|
||
|
desc: "revoke cert for invalid token",
|
||
|
token: wrongValue,
|
||
|
thingID: thingID,
|
||
|
err: certs.ErrUnauthorizedAccess,
|
||
|
},
|
||
|
{
|
||
|
desc: "revoke cert for invalid thing id",
|
||
|
token: token,
|
||
|
thingID: "2",
|
||
|
err: certs.ErrFailedCertRevocation,
|
||
|
},
|
||
|
}
|
||
|
|
||
|
for _, tc := range cases {
|
||
|
_, err := svc.RevokeCert(context.Background(), tc.token, tc.thingID)
|
||
|
assert.True(t, errors.Contains(err, tc.err), fmt.Sprintf("%s: expected %s got %s\n", tc.desc, tc.err, err))
|
||
|
}
|
||
|
|
||
|
}
|
||
|
|
||
|
func TestListCerts(t *testing.T) {
|
||
|
svc, err := newService(map[string]string{token: email})
|
||
|
require.Nil(t, err, fmt.Sprintf("unexpected service creation error: %s\n", err))
|
||
|
|
||
|
for i := 0; i < certNum; i++ {
|
||
|
_, err = svc.IssueCert(context.Background(), token, thingID, daysValid, keyBits, key)
|
||
|
require.Nil(t, err, fmt.Sprintf("unexpected cert creation error: %s\n", err))
|
||
|
}
|
||
|
|
||
|
cases := []struct {
|
||
|
token string
|
||
|
desc string
|
||
|
thingID string
|
||
|
offset uint64
|
||
|
limit uint64
|
||
|
size uint64
|
||
|
err error
|
||
|
}{
|
||
|
{
|
||
|
desc: "list all certs with valid token",
|
||
|
token: token,
|
||
|
thingID: thingID,
|
||
|
offset: 0,
|
||
|
limit: certNum,
|
||
|
size: certNum,
|
||
|
err: nil,
|
||
|
},
|
||
|
{
|
||
|
desc: "list all certs with invalid token",
|
||
|
token: wrongValue,
|
||
|
thingID: thingID,
|
||
|
offset: 0,
|
||
|
limit: certNum,
|
||
|
size: 0,
|
||
|
err: certs.ErrUnauthorizedAccess,
|
||
|
},
|
||
|
{
|
||
|
desc: "list half certs with invalid token",
|
||
|
token: token,
|
||
|
thingID: thingID,
|
||
|
offset: certNum / 2,
|
||
|
limit: certNum,
|
||
|
size: certNum / 2,
|
||
|
err: nil,
|
||
|
},
|
||
|
{
|
||
|
desc: "list last certs with invalid token",
|
||
|
token: token,
|
||
|
thingID: thingID,
|
||
|
offset: certNum - 1,
|
||
|
limit: certNum,
|
||
|
size: 1,
|
||
|
err: nil,
|
||
|
},
|
||
|
}
|
||
|
|
||
|
for _, tc := range cases {
|
||
|
page, err := svc.ListCerts(context.Background(), tc.token, tc.thingID, tc.offset, tc.limit)
|
||
|
size := uint64(len(page.Certs))
|
||
|
assert.Equal(t, tc.size, size, fmt.Sprintf("%s: expected %d got %d\n", tc.desc, tc.size, size))
|
||
|
assert.True(t, errors.Contains(err, tc.err), fmt.Sprintf("%s: expected %s got %s\n", tc.desc, tc.err, err))
|
||
|
}
|
||
|
|
||
|
}
|
||
|
|
||
|
func newThingsServer(svc things.Service) *httptest.Server {
|
||
|
mux := httpapi.MakeHandler(mocktracer.New(), svc)
|
||
|
return httptest.NewServer(mux)
|
||
|
}
|
||
|
|
||
|
func loadCertificates(caPath, caKeyPath string) (tls.Certificate, *x509.Certificate, error) {
|
||
|
var tlsCert tls.Certificate
|
||
|
var caCert *x509.Certificate
|
||
|
|
||
|
if caPath == "" || caKeyPath == "" {
|
||
|
return tlsCert, caCert, nil
|
||
|
}
|
||
|
|
||
|
if _, err := os.Stat(caPath); os.IsNotExist(err) {
|
||
|
return tlsCert, caCert, err
|
||
|
}
|
||
|
|
||
|
if _, err := os.Stat(caKeyPath); os.IsNotExist(err) {
|
||
|
return tlsCert, caCert, err
|
||
|
}
|
||
|
|
||
|
tlsCert, err := tls.LoadX509KeyPair(caPath, caKeyPath)
|
||
|
if err != nil {
|
||
|
return tlsCert, caCert, errors.Wrap(err, err)
|
||
|
}
|
||
|
|
||
|
b, err := ioutil.ReadFile(caPath)
|
||
|
if err != nil {
|
||
|
return tlsCert, caCert, err
|
||
|
}
|
||
|
|
||
|
caCert, err = readCert(b)
|
||
|
if err != nil {
|
||
|
return tlsCert, caCert, errors.Wrap(err, err)
|
||
|
}
|
||
|
|
||
|
return tlsCert, caCert, nil
|
||
|
}
|
||
|
|
||
|
// readCert decodes the first PEM block in b, whatever its type, and parses its
// bytes as a DER-encoded X.509 certificate. Any data after that block is
// ignored. If b contains no PEM block, it returns an error.
func readCert(b []byte) (*x509.Certificate, error) {
	if block, _ := pem.Decode(b); block != nil {
		return x509.ParseCertificate(block.Bytes)
	}
	return nil, errors.New("failed to decode PEM data")
}
|