mirror of
https://github.com/sjwhitworth/golearn.git
synced 2025-04-26 13:49:14 +08:00
80 lines
2.1 KiB
Go
80 lines
2.1 KiB
Go
package linear_models
|
|
|
|
import (
|
|
"github.com/sjwhitworth/golearn/base"
|
|
. "github.com/smartystreets/goconvey/convey"
|
|
"strconv"
|
|
"testing"
|
|
)
|
|
|
|
func TestLinearRegression(t *testing.T) {
|
|
Convey("Doing a linear regression", t, func() {
|
|
lr := NewLinearRegression()
|
|
|
|
Convey("With no training data", func() {
|
|
Convey("Predicting", func() {
|
|
testData, err := base.ParseCSVToInstances("../examples/datasets/exams.csv", true)
|
|
So(err, ShouldBeNil)
|
|
|
|
_, err = lr.Predict(testData)
|
|
|
|
Convey("Should result in a NoTrainingDataError", func() {
|
|
So(err, ShouldEqual, NoTrainingDataError)
|
|
})
|
|
|
|
})
|
|
})
|
|
|
|
Convey("With not enough training data", func() {
|
|
trainingDatum, err := base.ParseCSVToInstances("../examples/datasets/exam.csv", true)
|
|
So(err, ShouldBeNil)
|
|
|
|
Convey("Fitting", func() {
|
|
err = lr.Fit(trainingDatum)
|
|
|
|
Convey("Should result in a NotEnoughDataError", func() {
|
|
So(err, ShouldEqual, NotEnoughDataError)
|
|
})
|
|
})
|
|
})
|
|
|
|
Convey("With sufficient training data", func() {
|
|
instances, err := base.ParseCSVToInstances("../examples/datasets/exams.csv", true)
|
|
So(err, ShouldBeNil)
|
|
trainData, testData := base.InstancesTrainTestSplit(instances, 0.1)
|
|
|
|
Convey("Fitting and Predicting", func() {
|
|
err := lr.Fit(trainData)
|
|
So(err, ShouldBeNil)
|
|
|
|
predictions, err := lr.Predict(testData)
|
|
So(err, ShouldBeNil)
|
|
|
|
Convey("It makes reasonable predictions", func() {
|
|
_, rows := predictions.Size()
|
|
|
|
for i := 0; i < rows; i++ {
|
|
actualValue, _ := strconv.ParseFloat(base.GetClass(testData, i), 64)
|
|
expectedValue, _ := strconv.ParseFloat(base.GetClass(predictions, i), 64)
|
|
|
|
So(actualValue, ShouldAlmostEqual, expectedValue, actualValue*0.05)
|
|
}
|
|
})
|
|
})
|
|
})
|
|
})
|
|
}
|
|
|
|
func BenchmarkLinearRegressionOneRow(b *testing.B) {
|
|
// Omits error handling in favor of brevity
|
|
trainData, _ := base.ParseCSVToInstances("../examples/datasets/exams.csv", true)
|
|
testData, _ := base.ParseCSVToInstances("../examples/datasets/exam.csv", true)
|
|
lr := NewLinearRegression()
|
|
lr.Fit(trainData)
|
|
|
|
b.ResetTimer()
|
|
for n := 0; n < b.N; n++ {
|
|
lr.Predict(testData)
|
|
}
|
|
}
|