2014-05-03 12:01:38 -04:00
|
|
|
|
package cross_validation
|
2014-05-02 20:50:14 -04:00
|
|
|
|
|
|
|
|
|
import (
|
|
|
|
|
"fmt"
|
2014-05-02 22:38:15 -04:00
|
|
|
|
mat "github.com/gonum/matrix/mat64"
|
2014-05-02 20:50:14 -04:00
|
|
|
|
"math/rand"
|
2014-05-03 12:01:38 -04:00
|
|
|
|
"sync"
|
2014-05-02 20:50:14 -04:00
|
|
|
|
"time"
|
|
|
|
|
)
|
|
|
|
|
|
2014-05-03 12:01:38 -04:00
|
|
|
|
func shuffleMatrix(returnDatasets []*mat.Dense, dataset mat.Matrix, testSize int, seed int64, wg *sync.WaitGroup) {
|
|
|
|
|
numGen := rand.New(rand.NewSource(seed))
|
|
|
|
|
|
|
|
|
|
// We don't want to alter the original dataset.
|
2014-05-02 22:38:15 -04:00
|
|
|
|
shuffledSet := mat.DenseCopyOf(dataset)
|
|
|
|
|
rowCount, colCount := shuffledSet.Dims()
|
|
|
|
|
temp := make([]float64, colCount)
|
2014-05-03 12:01:38 -04:00
|
|
|
|
|
2014-05-02 22:38:15 -04:00
|
|
|
|
// Fisher–Yates shuffle
|
|
|
|
|
for i := 0; i < rowCount; i++ {
|
2014-05-03 12:01:38 -04:00
|
|
|
|
j := numGen.Intn(i+1)
|
2014-05-02 22:38:15 -04:00
|
|
|
|
if j != i {
|
|
|
|
|
// Make a "hard" copy to avoid pointer craziness.
|
|
|
|
|
copy(temp, shuffledSet.RowView(i))
|
|
|
|
|
shuffledSet.SetRow(i, shuffledSet.RowView(j))
|
|
|
|
|
shuffledSet.SetRow(j, temp)
|
|
|
|
|
}
|
2014-05-02 20:50:14 -04:00
|
|
|
|
}
|
2014-05-03 12:01:38 -04:00
|
|
|
|
trainSize := rowCount - testSize
|
|
|
|
|
returnDatasets[0] = mat.NewDense(trainSize, colCount, shuffledSet.RawMatrix().Data[:trainSize*colCount])
|
|
|
|
|
returnDatasets[1] = mat.NewDense(testSize, colCount, shuffledSet.RawMatrix().Data[trainSize*colCount:])
|
2014-05-02 20:50:14 -04:00
|
|
|
|
|
2014-05-03 12:01:38 -04:00
|
|
|
|
wg.Done()
|
2014-05-02 20:50:14 -04:00
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// TrainTestSplit splits input DenseMatrix into subsets for testing.
|
|
|
|
|
// The function expects a test size number (int) or percentage (float64), and a random state or nil to get "random" shuffle.
|
|
|
|
|
// It returns a list containing the train-test split and an error status.
|
2014-05-02 22:38:15 -04:00
|
|
|
|
func TrainTestSplit(size interface{}, randomState interface{}, datasets ...*mat.Dense) ([]*mat.Dense, error) {
|
2014-05-02 20:50:14 -04:00
|
|
|
|
// Get number of instances (rows).
|
2014-05-02 22:38:15 -04:00
|
|
|
|
instanceCount, _ := datasets[0].Dims()
|
2014-05-02 20:50:14 -04:00
|
|
|
|
|
|
|
|
|
// Input should be one or two matrices.
|
|
|
|
|
dataCount := len(datasets)
|
|
|
|
|
if dataCount > 2 {
|
|
|
|
|
return nil, fmt.Errorf("Expected 1 or 2 datasets, got %d\n", dataCount)
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
if dataCount == 2 {
|
|
|
|
|
// Test for consistency.
|
2014-05-02 22:38:15 -04:00
|
|
|
|
labelCount, labelFeatures := datasets[1].Dims()
|
|
|
|
|
if labelCount != instanceCount {
|
2014-05-02 20:50:14 -04:00
|
|
|
|
return nil, fmt.Errorf("Data and labels must have the same number of instances")
|
2014-05-02 22:38:15 -04:00
|
|
|
|
} else if labelFeatures != 1 {
|
2014-05-02 20:50:14 -04:00
|
|
|
|
return nil, fmt.Errorf("Label matrix must have single feature")
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
2014-05-03 12:01:38 -04:00
|
|
|
|
var testSize int
|
2014-05-02 20:50:14 -04:00
|
|
|
|
switch size := size.(type) {
|
|
|
|
|
// If size is an integer, treat it as the test data instance count.
|
|
|
|
|
case int:
|
|
|
|
|
testSize = size
|
|
|
|
|
case float64:
|
|
|
|
|
// If size is a float, treat it as a percentage of the instances to be allocated to the test set.
|
|
|
|
|
testSize = int(float64(instanceCount)*size + 0.5)
|
|
|
|
|
default:
|
|
|
|
|
return nil, fmt.Errorf("Expected a test instance count (int) or percentage (float64)")
|
|
|
|
|
}
|
|
|
|
|
|
2014-05-03 12:01:38 -04:00
|
|
|
|
var randSeed int64
|
2014-05-02 20:50:14 -04:00
|
|
|
|
// Create a deterministic shuffle, or a "random" one based on current time.
|
|
|
|
|
if seed, ok := randomState.(int); ok {
|
2014-05-03 12:01:38 -04:00
|
|
|
|
randSeed = int64(seed)
|
2014-05-02 20:50:14 -04:00
|
|
|
|
} else {
|
2014-05-03 12:01:38 -04:00
|
|
|
|
// Use seconds since epoch as seed
|
|
|
|
|
randSeed = time.Now().Unix()
|
2014-05-02 20:50:14 -04:00
|
|
|
|
}
|
|
|
|
|
|
2014-05-03 12:01:38 -04:00
|
|
|
|
// Wait group for goroutine syncronization.
|
|
|
|
|
wg := new(sync.WaitGroup)
|
|
|
|
|
wg.Add(dataCount)
|
2014-05-02 20:50:14 -04:00
|
|
|
|
|
2014-05-03 12:01:38 -04:00
|
|
|
|
// Return slice will hold training and test data and optional labels matrix.
|
|
|
|
|
returnDatasets := make([]*mat.Dense, 2*dataCount)
|
2014-05-02 22:38:15 -04:00
|
|
|
|
|
2014-05-03 12:01:38 -04:00
|
|
|
|
for i, dataset := range datasets {
|
|
|
|
|
// Send proper returnDataset slice.
|
|
|
|
|
// This is needed so goroutine doesn't mess up the expected return order.
|
|
|
|
|
// Perhaps returning a map is a better solution...
|
|
|
|
|
go shuffleMatrix(returnDatasets[i:i+2], dataset, testSize, randSeed, wg)
|
2014-05-02 20:50:14 -04:00
|
|
|
|
}
|
2014-05-03 12:01:38 -04:00
|
|
|
|
wg.Wait()
|
2014-05-02 20:50:14 -04:00
|
|
|
|
|
|
|
|
|
return returnDatasets, nil
|
|
|
|
|
}
|