1
0
mirror of https://github.com/sjwhitworth/golearn.git synced 2025-04-26 13:49:14 +08:00
golearn/clustering/em_test.go
2018-06-16 22:11:59 +08:00

101 lines
2.6 KiB
Go

package clustering
import (
"github.com/sjwhitworth/golearn/base"
. "github.com/smartystreets/goconvey/convey"
"testing"
)
// TestExpectationMaximization exercises the EM clustering model:
// constructor validation, error reporting when there is too little or no
// training data, and the fitted parameters on the gaussian_mixture.csv
// fixture.
//
// GoConvey re-runs the outer Convey function once for every leaf path, so the
// em instance created at the top is fresh in each branch. That is why the
// "With no training data" branch sees an unfitted model even though sibling
// branches call Fit.
func TestExpectationMaximization(t *testing.T) {
	Convey("Doing EM-based clustering", t, func() {
		em, err := NewExpectationMaximization(2)
		So(err, ShouldBeNil)

		// Initialization tests

		// Trying to create NewExpectationMaximization with < 1 component.
		Convey("With less than one component", func() {
			Convey("Creating a new instance", func() {
				_, err := NewExpectationMaximization(0)
				Convey("Should result in a InsufficientComponentsError", func() {
					So(err, ShouldEqual, InsufficientComponentsError)
				})
			})
		})

		// Data tests

		// Trying to Fit with fewer samples than components.
		Convey("With insufficient training data", func() {
			Convey("Fitting", func() {
				testData, err := base.ParseCSVToInstances("./gaussian_mixture_single_obs.csv", false)
				So(err, ShouldBeNil)
				err = em.Fit(testData)
				Convey("Should result in a InsufficientDataError", func() {
					So(err, ShouldEqual, InsufficientDataError)
				})
			})
		})

		// Trying to Predict before having called Fit.
		Convey("With no training data", func() {
			Convey("Predicting", func() {
				testData, err := base.ParseCSVToInstances("./gaussian_mixture.csv", false)
				So(err, ShouldBeNil)
				_, err = em.Predict(testData)
				Convey("Should result in a NoTrainingDataError", func() {
					So(err, ShouldEqual, NoTrainingDataError)
				})
			})
		})

		// Computation tests

		// Test that the fitted parameters are reasonable.
		Convey("With sufficient training data", func() {
			instances, err := base.ParseCSVToInstances("./gaussian_mixture.csv", true)
			So(err, ShouldBeNil)
			Convey("Fitting", func() {
				err = em.Fit(instances)
				So(err, ShouldBeNil)
				firstMean := em.Params.Means.At(0, 0)
				Convey("It converges to reasonable a value", func() {
					So(firstMean, ShouldAlmostEqual, -5.973, .1)
				})
			})
		})

		// End-to-end smoke test of Fit followed by Predict with a
		// single-component model. Every step must succeed without error.
		Convey("Test more code", func() {
			trainData, err := base.ParseCSVToInstances("./gaussian_mixture.csv", false)
			So(err, ShouldBeNil)
			testData, err := base.ParseCSVToInstances("./gaussian_mixture.csv", false)
			So(err, ShouldBeNil)
			single, err := NewExpectationMaximization(1)
			So(err, ShouldBeNil)
			So(single.Fit(trainData), ShouldBeNil)
			_, err = single.Predict(testData)
			So(err, ShouldBeNil)
		})
	})
}
// BenchmarkExpectationMaximizationOneRow times Predict for a two-component
// model fitted to gaussian_mixture.csv. Despite the name, each iteration
// calls Predict on the whole parsed fixture, not on a single row.
// Parsing and fitting happen before b.ResetTimer, so they are excluded from
// the measurement.
func BenchmarkExpectationMaximizationOneRow(b *testing.B) {
	trainData, err := base.ParseCSVToInstances("./gaussian_mixture.csv", false)
	if err != nil {
		b.Fatalf("parsing training data: %v", err)
	}
	testData, err := base.ParseCSVToInstances("./gaussian_mixture.csv", false)
	if err != nil {
		b.Fatalf("parsing test data: %v", err)
	}
	em, err := NewExpectationMaximization(2)
	if err != nil {
		b.Fatalf("creating model: %v", err)
	}
	// If Fit failed unnoticed, the model could stay unfitted. The loop would
	// then only time Predict's NoTrainingDataError path, not real prediction.
	if err := em.Fit(trainData); err != nil {
		b.Fatalf("fitting model: %v", err)
	}
	b.ResetTimer()
	for n := 0; n < b.N; n++ {
		if _, err := em.Predict(testData); err != nil {
			b.Fatalf("predicting: %v", err)
		}
	}
}