mirror of
https://github.com/sjwhitworth/golearn.git
synced 2025-04-26 13:49:14 +08:00
Moved to own package and added concurrency support
Moved the function to the cross_validation package. Also modified the shuffling function to run concurrently.
This commit is contained in:
parent
030d9844c6
commit
779ad2842e
@ -1,20 +1,24 @@
|
||||
package utilities
|
||||
package cross_validation
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
mat "github.com/gonum/matrix/mat64"
|
||||
"math/rand"
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
func shuffleMatrix(dataset *mat.Dense, numGen *rand.Rand) *mat.Dense {
|
||||
func shuffleMatrix(returnDatasets []*mat.Dense, dataset mat.Matrix, testSize int, seed int64, wg *sync.WaitGroup) {
|
||||
numGen := rand.New(rand.NewSource(seed))
|
||||
|
||||
// We don't want to alter the original dataset.
|
||||
shuffledSet := mat.DenseCopyOf(dataset)
|
||||
rowCount, colCount := shuffledSet.Dims()
|
||||
temp := make([]float64, colCount)
|
||||
|
||||
|
||||
// Fisher–Yates shuffle
|
||||
for i := 0; i < rowCount; i++ {
|
||||
j := numGen.Intn(i + 1)
|
||||
j := numGen.Intn(i+1)
|
||||
if j != i {
|
||||
// Make a "hard" copy to avoid pointer craziness.
|
||||
copy(temp, shuffledSet.RowView(i))
|
||||
@ -22,8 +26,11 @@ func shuffleMatrix(dataset *mat.Dense, numGen *rand.Rand) *mat.Dense {
|
||||
shuffledSet.SetRow(j, temp)
|
||||
}
|
||||
}
|
||||
trainSize := rowCount - testSize
|
||||
returnDatasets[0] = mat.NewDense(trainSize, colCount, shuffledSet.RawMatrix().Data[:trainSize*colCount])
|
||||
returnDatasets[1] = mat.NewDense(testSize, colCount, shuffledSet.RawMatrix().Data[trainSize*colCount:])
|
||||
|
||||
return shuffledSet
|
||||
wg.Done()
|
||||
}
|
||||
|
||||
// TrainTestSplit splits input DenseMatrix into subsets for testing.
|
||||
@ -49,41 +56,41 @@ func TrainTestSplit(size interface{}, randomState interface{}, datasets ...*mat.
|
||||
}
|
||||
}
|
||||
|
||||
var trainSize, testSize int
|
||||
var testSize int
|
||||
switch size := size.(type) {
|
||||
// If size is an integer, treat it as the test data instance count.
|
||||
case int:
|
||||
trainSize = instanceCount - size
|
||||
testSize = size
|
||||
case float64:
|
||||
// If size is a float, treat it as a percentage of the instances to be allocated to the test set.
|
||||
trainSize = int(float64(instanceCount)*(1-size) + 0.5)
|
||||
testSize = int(float64(instanceCount)*size + 0.5)
|
||||
default:
|
||||
return nil, fmt.Errorf("Expected a test instance count (int) or percentage (float64)")
|
||||
}
|
||||
|
||||
var randSeed int64
|
||||
// Create a deterministic shuffle, or a "random" one based on current time.
|
||||
var randSource rand.Source
|
||||
if seed, ok := randomState.(int); ok {
|
||||
randSource = rand.NewSource(int64(seed))
|
||||
randSeed = int64(seed)
|
||||
} else {
|
||||
randSource = rand.NewSource(time.Now().Unix())
|
||||
// Use seconds since epoch as seed
|
||||
randSeed = time.Now().Unix()
|
||||
}
|
||||
numGen := rand.New(randSource)
|
||||
|
||||
// Wait group for goroutine syncronization.
|
||||
wg := new(sync.WaitGroup)
|
||||
wg.Add(dataCount)
|
||||
|
||||
// Return slice will hold training and test data and optional labels matrix.
|
||||
var returnDatasets []*mat.Dense
|
||||
returnDatasets := make([]*mat.Dense, 2*dataCount)
|
||||
|
||||
for _, dataset := range datasets {
|
||||
_, featureCount := dataset.Dims()
|
||||
|
||||
tempMatrix := shuffleMatrix(dataset, numGen)
|
||||
|
||||
// Features count is different on data and labels.
|
||||
returnDatasets = append(returnDatasets, mat.NewDense(trainSize, featureCount, tempMatrix.RawMatrix().Data[:trainSize*featureCount]))
|
||||
returnDatasets = append(returnDatasets, mat.NewDense(testSize, featureCount, tempMatrix.RawMatrix().Data[trainSize*featureCount:]))
|
||||
for i, dataset := range datasets {
|
||||
// Send proper returnDataset slice.
|
||||
// This is needed so goroutine doesn't mess up the expected return order.
|
||||
// Perhaps returning a map is a better solution...
|
||||
go shuffleMatrix(returnDatasets[i:i+2], dataset, testSize, randSeed, wg)
|
||||
}
|
||||
wg.Wait()
|
||||
|
||||
return returnDatasets, nil
|
||||
}
|
61
cross_validation/cross_validation_test.go
Normal file
61
cross_validation/cross_validation_test.go
Normal file
@ -0,0 +1,61 @@
|
||||
package cross_validation
|
||||
|
||||
import (
|
||||
//. "github.com/smartystreets/goconvey/convey"
|
||||
mat "github.com/gonum/matrix/mat64"
|
||||
"math/rand"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
var (
|
||||
flatValues, flatLabels []float64
|
||||
values, labels *mat.Dense
|
||||
)
|
||||
|
||||
func init() {
|
||||
flatValues = make([]float64, 80)
|
||||
flatLabels = make([]float64, 20)
|
||||
|
||||
for i := 0; i < 80; i++ {
|
||||
flatValues[i] = float64(i + 1)
|
||||
// Replaces labels four times per run but who cares.
|
||||
flatLabels[int(i/4)] = float64(rand.Intn(2))
|
||||
}
|
||||
|
||||
values = mat.NewDense(20, 4, flatValues)
|
||||
labels = mat.NewDense(20, 1, flatLabels)
|
||||
}
|
||||
|
||||
func TestTrainTrainTestSplit(t *testing.T) {
|
||||
nolab1, err := TrainTestSplit(4, nil, values)
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
|
||||
// Make sure the random generator gets a new seed (time).
|
||||
time.Sleep(time.Second)
|
||||
|
||||
nolab2, _ := TrainTestSplit(4, nil, values)
|
||||
if nolab1[0].Equals(nolab2[0]) {
|
||||
t.Errorf("Shuffle with different seed returned same matrix")
|
||||
}
|
||||
|
||||
nolab1, _ = TrainTestSplit(4, 1, values)
|
||||
nolab2, _ = TrainTestSplit(4, 1, values)
|
||||
// Comparing the determinants does not guarantee uniqueness, but it will do for now.
|
||||
if !nolab1[0].Equals(nolab2[0]) {
|
||||
t.Errorf("Shuffle with same seed returned different matrix")
|
||||
}
|
||||
|
||||
// Same thing for data with labels.
|
||||
lab1, err := TrainTestSplit(0.1, 10, values, labels)
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
|
||||
lab2, _ := TrainTestSplit(0.1, 10, values, labels)
|
||||
if !lab1[0].Equals(lab2[0]) {
|
||||
t.Errorf("Shuffle with same seed returned different determinants")
|
||||
}
|
||||
}
|
Loading…
x
Reference in New Issue
Block a user