mirror of
https://github.com/sjwhitworth/golearn.git
synced 2025-04-28 13:48:56 +08:00
74 lines
1.7 KiB
Go
74 lines
1.7 KiB
Go
package linear_models
|
|
|
|
import (
|
|
"fmt"
|
|
"testing"
|
|
|
|
"github.com/sjwhitworth/golearn/base"
|
|
)
|
|
|
|
func TestNoTrainingData(t *testing.T) {
|
|
lr := NewLinearRegression()
|
|
|
|
rawData, err := base.ParseCSVToInstances("../examples/datasets/exams.csv", true)
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
|
|
_, err = lr.Predict(rawData)
|
|
if err != NoTrainingDataError {
|
|
t.Fatal("failed to error out even if no training data exists")
|
|
}
|
|
}
|
|
|
|
func TestNotEnoughTrainingData(t *testing.T) {
|
|
lr := NewLinearRegression()
|
|
|
|
rawData, err := base.ParseCSVToInstances("../examples/datasets/exam.csv", true)
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
|
|
err = lr.Fit(rawData)
|
|
if err != NotEnoughDataError {
|
|
t.Fatal("failed to error out even though there was not enough data")
|
|
}
|
|
}
|
|
|
|
func TestLinearRegression(t *testing.T) {
|
|
lr := NewLinearRegression()
|
|
|
|
rawData, err := base.ParseCSVToInstances("../examples/datasets/exams.csv", true)
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
|
|
trainData, testData := base.InstancesTrainTestSplit(rawData, 0.1)
|
|
err = lr.Fit(trainData)
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
|
|
predictions, err := lr.Predict(testData)
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
|
|
for i := 0; i < predictions.Rows; i++ {
|
|
fmt.Printf("Expected: %f || Predicted: %f\n", testData.Get(i, testData.ClassIndex), predictions.Get(i, predictions.ClassIndex))
|
|
}
|
|
}
|
|
|
|
func BenchmarkLinearRegressionOneRow(b *testing.B) {
|
|
// Omits error handling in favor of brevity
|
|
trainData, _ := base.ParseCSVToInstances("../examples/datasets/exams.csv", true)
|
|
testData, _ := base.ParseCSVToInstances("../examples/datasets/exam.csv", true)
|
|
lr := NewLinearRegression()
|
|
lr.Fit(trainData)
|
|
|
|
b.ResetTimer()
|
|
for n := 0; n < b.N; n++ {
|
|
lr.Predict(testData)
|
|
}
|
|
}
|