// Mirror of https://github.com/sjwhitworth/golearn.git
// (synced 2025-04-28 13:48:56 +08:00). Original file: 76 lines, 1.7 KiB, Go.
package linear_models
|
|
|
|
import (
|
|
"fmt"
|
|
"testing"
|
|
|
|
"github.com/sjwhitworth/golearn/base"
|
|
)
|
|
|
|
func TestNoTrainingData(t *testing.T) {
|
|
lr := NewLinearRegression()
|
|
|
|
rawData, err := base.ParseCSVToInstances("../examples/datasets/exams.csv", true)
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
|
|
_, err = lr.Predict(rawData)
|
|
if err != NoTrainingDataError {
|
|
t.Fatal("failed to error out even if no training data exists")
|
|
}
|
|
}
|
|
|
|
func TestNotEnoughTrainingData(t *testing.T) {
|
|
lr := NewLinearRegression()
|
|
|
|
rawData, err := base.ParseCSVToInstances("../examples/datasets/exam.csv", true)
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
|
|
err = lr.Fit(rawData)
|
|
if err != NotEnoughDataError {
|
|
t.Fatal("failed to error out even though there was not enough data")
|
|
}
|
|
}
|
|
|
|
func TestLinearRegression(t *testing.T) {
|
|
lr := NewLinearRegression()
|
|
|
|
rawData, err := base.ParseCSVToInstances("../examples/datasets/exams.csv", true)
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
|
|
trainData, testData := base.InstancesTrainTestSplit(rawData, 0.1)
|
|
err = lr.Fit(trainData)
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
|
|
predictions, err := lr.Predict(testData)
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
|
|
_, rows := predictions.Size()
|
|
|
|
for i := 0; i < rows; i++ {
|
|
fmt.Printf("Expected: %s || Predicted: %s\n", base.GetClass(testData, i), base.GetClass(predictions, i))
|
|
}
|
|
}
|
|
|
|
func BenchmarkLinearRegressionOneRow(b *testing.B) {
|
|
// Omits error handling in favor of brevity
|
|
trainData, _ := base.ParseCSVToInstances("../examples/datasets/exams.csv", true)
|
|
testData, _ := base.ParseCSVToInstances("../examples/datasets/exam.csv", true)
|
|
lr := NewLinearRegression()
|
|
lr.Fit(trainData)
|
|
|
|
b.ResetTimer()
|
|
for n := 0; n < b.N; n++ {
|
|
lr.Predict(testData)
|
|
}
|
|
}
|