1
0
mirror of https://github.com/sjwhitworth/golearn.git synced 2025-04-26 13:49:14 +08:00
golearn/cross_validation/cross_validation_test.go
Nelson Santos 779ad2842e Moved to own package and added concurrency support
Moved the function to the cross_validation package. Also modified the
shuffling function to run concurrently.
2014-05-03 12:01:38 -04:00

62 lines
1.4 KiB
Go

package cross_validation
import (
//. "github.com/smartystreets/goconvey/convey"
mat "github.com/gonum/matrix/mat64"
"math/rand"
"testing"
"time"
)
// Shared fixtures for the tests in this file; all four are populated by init.
var (
// flatValues and flatLabels are the raw float64 slices handed to
// mat.NewDense when building values and labels.
flatValues, flatLabels []float64
// values is the feature matrix and labels the label matrix passed to
// TrainTestSplit by the tests.
values, labels *mat.Dense
)
// init populates the package-level fixtures: flatValues receives the numbers
// 1 through 80 and flatLabels receives 20 random 0/1 entries; both slices are
// then wrapped as a 20x4 and a 20x1 dense matrix respectively.
func init() {
	const (
		rows = 20
		cols = 4
	)

	flatValues = make([]float64, rows*cols)
	flatLabels = make([]float64, rows)

	for idx := range flatValues {
		flatValues[idx] = float64(idx + 1)
		// Every label slot is written once per column of its row, so the
		// last of the four random draws is the one that sticks.
		flatLabels[idx/cols] = float64(rand.Intn(2))
	}

	values = mat.NewDense(rows, cols, flatValues)
	labels = mat.NewDense(rows, 1, flatLabels)
}
// TestTrainTrainTestSplit exercises the shuffling behaviour of TrainTestSplit:
//
//   - two calls with a nil random state, made a second apart, must yield
//     different first matrices;
//   - two calls with the same explicit random state must yield equal first
//     matrices, both for unlabelled data and for data passed with labels.
//
// Every call's error is checked with t.Fatal, because the assertions index
// into the returned slice and would panic on the result of a failed split.
func TestTrainTrainTestSplit(t *testing.T) {
	nolab1, err := TrainTestSplit(4, nil, values)
	if err != nil {
		t.Fatal(err)
	}
	// Make sure the random generator gets a new seed (time).
	time.Sleep(time.Second)
	nolab2, err := TrainTestSplit(4, nil, values)
	if err != nil {
		t.Fatal(err)
	}
	if nolab1[0].Equals(nolab2[0]) {
		t.Errorf("Shuffle with different seed returned same matrix")
	}

	// Same explicit seed twice: the shuffles must agree.
	nolab1, err = TrainTestSplit(4, 1, values)
	if err != nil {
		t.Fatal(err)
	}
	nolab2, err = TrainTestSplit(4, 1, values)
	if err != nil {
		t.Fatal(err)
	}
	// Only the first returned matrix is compared; that does not prove the
	// whole split is identical, but it will do for now.
	if !nolab1[0].Equals(nolab2[0]) {
		t.Errorf("Shuffle with same seed returned different matrix")
	}

	// Same thing for data with labels.
	lab1, err := TrainTestSplit(0.1, 10, values, labels)
	if err != nil {
		t.Fatal(err)
	}
	lab2, err := TrainTestSplit(0.1, 10, values, labels)
	if err != nil {
		t.Fatal(err)
	}
	if !lab1[0].Equals(lab2[0]) {
		t.Errorf("Shuffle with same seed returned different matrix for labelled data")
	}
}