1
0
mirror of https://github.com/sjwhitworth/golearn.git synced 2025-04-26 13:49:14 +08:00

Use gonum/matrix/mat64, life is easier.

This commit is contained in:
Bert Chang 2014-05-03 00:55:38 +08:00
parent 4fc0ed0477
commit 804cd82cfc
3 changed files with 13 additions and 31 deletions

View File

@ -3,7 +3,7 @@ package pairwise
import (
"math"
mat "github.com/skelterjohn/go.matrix"
"github.com/gonum/matrix/mat64"
)
type Euclidean struct{}
@ -12,23 +12,17 @@ func NewEuclidean() *Euclidean {
return &Euclidean{}
}
func (self *Euclidean) InnerProduct(vectorX *mat.DenseMatrix, vectorY *mat.DenseMatrix) float64 {
CheckDimMatch(vectorX, vectorY)
result := mat.Product(mat.Transpose(vectorX), vectorY).Get(0, 0)
func (self *Euclidean) InnerProduct(vectorX *mat64.Dense, vectorY *mat64.Dense) float64 {
result := vectorX.Dot(vectorY)
return result
}
// We may need to create Metrics / Vector interface for this
func (self *Euclidean) Distance(vectorX *mat.DenseMatrix, vectorY *mat.DenseMatrix) (float64, error) {
difference, err := vectorY.MinusDense(vectorX)
func (self *Euclidean) Distance(vectorX *mat64.Dense, vectorY *mat64.Dense) float64 {
vectorX.Sub(vectorX, vectorY)
if err != nil {
return 0, err
}
result := self.InnerProduct(difference, difference)
return math.Sqrt(result), nil
result := self.InnerProduct(vectorX, vectorX)
return math.Sqrt(result)
}

View File

@ -2,17 +2,17 @@ package pairwise
import (
"testing"
. "github.com/smartystreets/goconvey/convey"
mat "github.com/skelterjohn/go.matrix"
"github.com/gonum/matrix/mat64"
. "github.com/smartystreets/goconvey/convey"
)
func TestEuclidean(t *testing.T) {
euclidean := NewEuclidean()
Convey("Given two vectors", t, func() {
vectorX := mat.MakeDenseMatrix([]float64{1, 2, 3}, 3, 1)
vectorY := mat.MakeDenseMatrix([]float64{2, 4, 5}, 3, 1)
vectorX := mat64.NewDense(3, 1, []float64{1, 2, 3})
vectorY := mat64.NewDense(3, 1, []float64{2, 4, 5})
Convey("When doing inner product", func() {
result := euclidean.InnerProduct(vectorX, vectorY)
@ -20,16 +20,6 @@ func TestEuclidean(t *testing.T) {
Convey("The result should be 25", func() {
So(result, ShouldEqual, 25)
})
Convey("When dimension not match", func() {
vectorZ := mat.MakeDenseMatrix([]float64{3, 4, 5}, 1, 3)
Convey("It should panic with Dimension mismatch", func() {
So(func() { euclidean.InnerProduct(vectorX, vectorZ) }, ShouldPanicWith, "Dimension mismatch")
})
})
})
Convey("When calculating distance", func() {

View File

@ -3,7 +3,7 @@ package pairwise
import (
"math"
mat "github.com/skelterjohn/go.matrix"
"github.com/gonum/matrix/mat64"
)
type RBFKernel struct {
@ -14,9 +14,7 @@ func NewRBFKernel(gamma float64) *RBFKernel {
return &RBFKernel{gamma: gamma}
}
func (self *RBFKernel) InnerProduct(vectorX *mat.DenseMatrix, vectorY *mat.DenseMatrix) (float64, error) {
CheckDimMatch(vectorX, vectorY)
func (self *RBFKernel) InnerProduct(vectorX *mat64.Dense, vectorY *mat64.Dense) (float64, error) {
euclidean := NewEuclidean()
distance, err := euclidean.Distance(vectorX, vectorY)