mirror of
https://github.com/sjwhitworth/golearn.git
synced 2025-04-26 13:49:14 +08:00
Avoid renaming packages on import
This commit is contained in:
parent
478b5055c7
commit
529b3bcaa5
@ -6,10 +6,8 @@ import (
|
||||
"bytes"
|
||||
"encoding/gob"
|
||||
"io/ioutil"
|
||||
)
|
||||
|
||||
import (
|
||||
mat64 "github.com/gonum/matrix/mat64"
|
||||
"github.com/gonum/matrix/mat64"
|
||||
)
|
||||
|
||||
// An Estimator is object that can ingest some data and train on it.
|
||||
|
@ -2,17 +2,17 @@ package cross_validation
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
mat "github.com/gonum/matrix/mat64"
|
||||
"github.com/gonum/matrix/mat64"
|
||||
"math/rand"
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
func shuffleMatrix(returnDatasets []*mat.Dense, dataset mat.Matrix, testSize int, seed int64, wg *sync.WaitGroup) {
|
||||
func shuffleMatrix(returnDatasets []*mat64.Dense, dataset mat64.Matrix, testSize int, seed int64, wg *sync.WaitGroup) {
|
||||
numGen := rand.New(rand.NewSource(seed))
|
||||
|
||||
// We don't want to alter the original dataset.
|
||||
shuffledSet := mat.DenseCopyOf(dataset)
|
||||
shuffledSet := mat64.DenseCopyOf(dataset)
|
||||
rowCount, colCount := shuffledSet.Dims()
|
||||
temp := make([]float64, colCount)
|
||||
|
||||
@ -27,8 +27,8 @@ func shuffleMatrix(returnDatasets []*mat.Dense, dataset mat.Matrix, testSize int
|
||||
}
|
||||
}
|
||||
trainSize := rowCount - testSize
|
||||
returnDatasets[0] = mat.NewDense(trainSize, colCount, shuffledSet.RawMatrix().Data[:trainSize*colCount])
|
||||
returnDatasets[1] = mat.NewDense(testSize, colCount, shuffledSet.RawMatrix().Data[trainSize*colCount:])
|
||||
returnDatasets[0] = mat64.NewDense(trainSize, colCount, shuffledSet.RawMatrix().Data[:trainSize*colCount])
|
||||
returnDatasets[1] = mat64.NewDense(testSize, colCount, shuffledSet.RawMatrix().Data[trainSize*colCount:])
|
||||
|
||||
wg.Done()
|
||||
}
|
||||
@ -36,7 +36,7 @@ func shuffleMatrix(returnDatasets []*mat.Dense, dataset mat.Matrix, testSize int
|
||||
// TrainTestSplit splits input DenseMatrix into subsets for testing.
|
||||
// The function expects a test size number (int) or percentage (float64), and a random state or nil to get "random" shuffle.
|
||||
// It returns a list containing the train-test split and an error status.
|
||||
func TrainTestSplit(size interface{}, randomState interface{}, datasets ...*mat.Dense) ([]*mat.Dense, error) {
|
||||
func TrainTestSplit(size interface{}, randomState interface{}, datasets ...*mat64.Dense) ([]*mat64.Dense, error) {
|
||||
// Get number of instances (rows).
|
||||
instanceCount, _ := datasets[0].Dims()
|
||||
|
||||
@ -82,7 +82,7 @@ func TrainTestSplit(size interface{}, randomState interface{}, datasets ...*mat.
|
||||
wg.Add(dataCount)
|
||||
|
||||
// Return slice will hold training and test data and optional labels matrix.
|
||||
returnDatasets := make([]*mat.Dense, 2*dataCount)
|
||||
returnDatasets := make([]*mat64.Dense, 2*dataCount)
|
||||
|
||||
for i, dataset := range datasets {
|
||||
// Send proper returnDataset slice.
|
||||
|
@ -2,7 +2,7 @@ package cross_validation
|
||||
|
||||
import (
|
||||
//. "github.com/smartystreets/goconvey/convey"
|
||||
mat "github.com/gonum/matrix/mat64"
|
||||
"github.com/gonum/matrix/mat64"
|
||||
"math/rand"
|
||||
"testing"
|
||||
"time"
|
||||
@ -10,7 +10,7 @@ import (
|
||||
|
||||
var (
|
||||
flatValues, flatLabels []float64
|
||||
values, labels *mat.Dense
|
||||
values, labels *mat64.Dense
|
||||
)
|
||||
|
||||
func init() {
|
||||
@ -23,8 +23,8 @@ func init() {
|
||||
flatLabels[int(i/4)] = float64(rand.Intn(2))
|
||||
}
|
||||
|
||||
values = mat.NewDense(20, 4, flatValues)
|
||||
labels = mat.NewDense(20, 1, flatLabels)
|
||||
values = mat64.NewDense(20, 4, flatValues)
|
||||
labels = mat64.NewDense(20, 1, flatLabels)
|
||||
}
|
||||
|
||||
func TestTrainTrainTestSplit(t *testing.T) {
|
||||
|
@ -2,7 +2,7 @@ package ensemble
|
||||
|
||||
import (
|
||||
"github.com/sjwhitworth/golearn/base"
|
||||
eval "github.com/sjwhitworth/golearn/evaluation"
|
||||
"github.com/sjwhitworth/golearn/evaluation"
|
||||
"github.com/sjwhitworth/golearn/filters"
|
||||
"testing"
|
||||
)
|
||||
@ -25,9 +25,9 @@ func TestRandomForest1(t *testing.T) {
|
||||
rf := NewRandomForest(10, 3)
|
||||
rf.Fit(trainData)
|
||||
predictions := rf.Predict(testData)
|
||||
confusionMat, err := eval.GetConfusionMatrix(testData, predictions)
|
||||
confusionMat, err := evaluation.GetConfusionMatrix(testData, predictions)
|
||||
if err != nil {
|
||||
t.Fatalf("Unable to get confusion matrix: %s", err.Error())
|
||||
}
|
||||
_ = eval.GetSummary(confusionMat)
|
||||
_ = evaluation.GetSummary(confusionMat)
|
||||
}
|
||||
|
@ -6,7 +6,7 @@ import (
|
||||
"fmt"
|
||||
"github.com/sjwhitworth/golearn/base"
|
||||
"github.com/sjwhitworth/golearn/ensemble"
|
||||
eval "github.com/sjwhitworth/golearn/evaluation"
|
||||
"github.com/sjwhitworth/golearn/evaluation"
|
||||
"github.com/sjwhitworth/golearn/filters"
|
||||
"github.com/sjwhitworth/golearn/trees"
|
||||
"math/rand"
|
||||
@ -50,11 +50,11 @@ func main() {
|
||||
|
||||
// Evaluate
|
||||
fmt.Println("ID3 Performance")
|
||||
cf, err := eval.GetConfusionMatrix(testData, predictions)
|
||||
cf, err := evaluation.GetConfusionMatrix(testData, predictions)
|
||||
if err != nil {
|
||||
panic(fmt.Sprintf("Unable to get confusion matrix: %s", err.Error()))
|
||||
}
|
||||
fmt.Println(eval.GetSummary(cf))
|
||||
fmt.Println(evaluation.GetSummary(cf))
|
||||
|
||||
//
|
||||
// Next up, Random Trees
|
||||
@ -65,11 +65,11 @@ func main() {
|
||||
tree.Fit(testData)
|
||||
predictions = tree.Predict(testData)
|
||||
fmt.Println("RandomTree Performance")
|
||||
cf, err = eval.GetConfusionMatrix(testData, predictions)
|
||||
cf, err = evaluation.GetConfusionMatrix(testData, predictions)
|
||||
if err != nil {
|
||||
panic(fmt.Sprintf("Unable to get confusion matrix: %s", err.Error()))
|
||||
}
|
||||
fmt.Println(eval.GetSummary(cf))
|
||||
fmt.Println(evaluation.GetSummary(cf))
|
||||
|
||||
//
|
||||
// Finally, Random Forests
|
||||
@ -78,9 +78,9 @@ func main() {
|
||||
tree.Fit(trainData)
|
||||
predictions = tree.Predict(testData)
|
||||
fmt.Println("RandomForest Performance")
|
||||
cf, err = eval.GetConfusionMatrix(testData, predictions)
|
||||
cf, err = evaluation.GetConfusionMatrix(testData, predictions)
|
||||
if err != nil {
|
||||
panic(fmt.Sprintf("Unable to get confusion matrix: %s", err.Error()))
|
||||
}
|
||||
fmt.Println(eval.GetSummary(cf))
|
||||
fmt.Println(evaluation.GetSummary(cf))
|
||||
}
|
||||
|
26
knn/knn.go
26
knn/knn.go
@ -6,8 +6,8 @@ package knn
|
||||
import (
|
||||
"github.com/gonum/matrix/mat64"
|
||||
"github.com/sjwhitworth/golearn/base"
|
||||
pairwiseMetrics "github.com/sjwhitworth/golearn/metrics/pairwise"
|
||||
util "github.com/sjwhitworth/golearn/utilities"
|
||||
"github.com/sjwhitworth/golearn/metrics/pairwise"
|
||||
"github.com/sjwhitworth/golearn/utilities"
|
||||
)
|
||||
|
||||
// A KNNClassifier consists of a data matrix, associated labels in the same order as the matrix, and a distance function.
|
||||
@ -36,12 +36,12 @@ func (KNN *KNNClassifier) Fit(trainingData base.FixedDataGrid) {
|
||||
func (KNN *KNNClassifier) Predict(what base.FixedDataGrid) base.FixedDataGrid {
|
||||
|
||||
// Check what distance function we are using
|
||||
var distanceFunc pairwiseMetrics.PairwiseDistanceFunc
|
||||
var distanceFunc pairwise.PairwiseDistanceFunc
|
||||
switch KNN.DistanceFunc {
|
||||
case "euclidean":
|
||||
distanceFunc = pairwiseMetrics.NewEuclidean()
|
||||
distanceFunc = pairwise.NewEuclidean()
|
||||
case "manhattan":
|
||||
distanceFunc = pairwiseMetrics.NewManhattan()
|
||||
distanceFunc = pairwise.NewManhattan()
|
||||
default:
|
||||
panic("unsupported distance function")
|
||||
|
||||
@ -85,7 +85,7 @@ func (KNN *KNNClassifier) Predict(what base.FixedDataGrid) base.FixedDataGrid {
|
||||
predRowBuf[i] = base.UnpackBytesToFloat(predRow[i])
|
||||
}
|
||||
|
||||
predMat := util.FloatsToMatrix(predRowBuf)
|
||||
predMat := utilities.FloatsToMatrix(predRowBuf)
|
||||
|
||||
// Find the closest match in the training data
|
||||
KNN.TrainingData.MapOverRows(trainAttrSpecs, func(trainRow [][]byte, srcRowNo int) (bool, error) {
|
||||
@ -96,12 +96,12 @@ func (KNN *KNNClassifier) Predict(what base.FixedDataGrid) base.FixedDataGrid {
|
||||
}
|
||||
|
||||
// Compute the distance
|
||||
trainMat := util.FloatsToMatrix(trainRowBuf)
|
||||
trainMat := utilities.FloatsToMatrix(trainRowBuf)
|
||||
distances[srcRowNo] = distanceFunc.Distance(predMat, trainMat)
|
||||
return true, nil
|
||||
})
|
||||
|
||||
sorted := util.SortIntMap(distances)
|
||||
sorted := utilities.SortIntMap(distances)
|
||||
values := sorted[:KNN.NearestNeighbours]
|
||||
|
||||
// Reset maxMap
|
||||
@ -167,24 +167,24 @@ func (KNN *KNNRegressor) Predict(vector *mat64.Dense, K int) float64 {
|
||||
labels := make([]float64, 0)
|
||||
|
||||
// Check what distance function we are using
|
||||
var distanceFunc pairwiseMetrics.PairwiseDistanceFunc
|
||||
var distanceFunc pairwise.PairwiseDistanceFunc
|
||||
switch KNN.DistanceFunc {
|
||||
case "euclidean":
|
||||
distanceFunc = pairwiseMetrics.NewEuclidean()
|
||||
distanceFunc = pairwise.NewEuclidean()
|
||||
case "manhattan":
|
||||
distanceFunc = pairwiseMetrics.NewManhattan()
|
||||
distanceFunc = pairwise.NewManhattan()
|
||||
default:
|
||||
panic("unsupported distance function")
|
||||
}
|
||||
|
||||
for i := 0; i < rows; i++ {
|
||||
row := KNN.Data.RowView(i)
|
||||
rowMat := util.FloatsToMatrix(row)
|
||||
rowMat := utilities.FloatsToMatrix(row)
|
||||
distance := distanceFunc.Distance(rowMat, vector)
|
||||
rownumbers[i] = distance
|
||||
}
|
||||
|
||||
sorted := util.SortIntMap(rownumbers)
|
||||
sorted := utilities.SortIntMap(rownumbers)
|
||||
values := sorted[:K]
|
||||
|
||||
var sum float64
|
||||
|
@ -2,7 +2,7 @@ package meta
|
||||
|
||||
import (
|
||||
"github.com/sjwhitworth/golearn/base"
|
||||
eval "github.com/sjwhitworth/golearn/evaluation"
|
||||
"github.com/sjwhitworth/golearn/evaluation"
|
||||
"github.com/sjwhitworth/golearn/filters"
|
||||
"github.com/sjwhitworth/golearn/trees"
|
||||
"math/rand"
|
||||
@ -84,9 +84,9 @@ func TestRandomForest1(t *testing.T) {
|
||||
|
||||
rf.Fit(trainDataf)
|
||||
predictions := rf.Predict(testDataf)
|
||||
confusionMat, err := eval.GetConfusionMatrix(testDataf, predictions)
|
||||
confusionMat, err := evaluation.GetConfusionMatrix(testDataf, predictions)
|
||||
if err != nil {
|
||||
t.Fatalf("Unable to get confusion matrix: %s", err.Error())
|
||||
}
|
||||
_ = eval.GetSummary(confusionMat)
|
||||
_ = evaluation.GetSummary(confusionMat)
|
||||
}
|
||||
|
@ -4,7 +4,7 @@ import (
|
||||
"bytes"
|
||||
"fmt"
|
||||
"github.com/sjwhitworth/golearn/base"
|
||||
eval "github.com/sjwhitworth/golearn/evaluation"
|
||||
"github.com/sjwhitworth/golearn/evaluation"
|
||||
"sort"
|
||||
)
|
||||
|
||||
@ -153,8 +153,8 @@ func (d *DecisionTreeNode) String() string {
|
||||
|
||||
// computeAccuracy is a helper method for Prune()
|
||||
func computeAccuracy(predictions base.FixedDataGrid, from base.FixedDataGrid) float64 {
|
||||
cf, _ := eval.GetConfusionMatrix(from, predictions)
|
||||
return eval.GetAccuracy(cf)
|
||||
cf, _ := evaluation.GetConfusionMatrix(from, predictions)
|
||||
return evaluation.GetAccuracy(cf)
|
||||
}
|
||||
|
||||
// Prune eliminates branches which hurt accuracy
|
||||
|
@ -2,7 +2,7 @@ package trees
|
||||
|
||||
import (
|
||||
"github.com/sjwhitworth/golearn/base"
|
||||
eval "github.com/sjwhitworth/golearn/evaluation"
|
||||
"github.com/sjwhitworth/golearn/evaluation"
|
||||
"github.com/sjwhitworth/golearn/filters"
|
||||
"math"
|
||||
"testing"
|
||||
@ -48,11 +48,11 @@ func TestRandomTreeClassification(t *testing.T) {
|
||||
root := InferID3Tree(trainDataF, r)
|
||||
|
||||
predictions := root.Predict(testDataF)
|
||||
confusionMat, err := eval.GetConfusionMatrix(testDataF, predictions)
|
||||
confusionMat, err := evaluation.GetConfusionMatrix(testDataF, predictions)
|
||||
if err != nil {
|
||||
t.Fatalf("Unable to get confusion matrix: %s", err.Error())
|
||||
}
|
||||
_ = eval.GetSummary(confusionMat)
|
||||
_ = evaluation.GetSummary(confusionMat)
|
||||
}
|
||||
|
||||
func TestRandomTreeClassification2(t *testing.T) {
|
||||
@ -74,11 +74,11 @@ func TestRandomTreeClassification2(t *testing.T) {
|
||||
root.Fit(trainDataF)
|
||||
|
||||
predictions := root.Predict(testDataF)
|
||||
confusionMat, err := eval.GetConfusionMatrix(testDataF, predictions)
|
||||
confusionMat, err := evaluation.GetConfusionMatrix(testDataF, predictions)
|
||||
if err != nil {
|
||||
t.Fatalf("Unable to get confusion matrix: %s", err.Error())
|
||||
}
|
||||
_ = eval.GetSummary(confusionMat)
|
||||
_ = evaluation.GetSummary(confusionMat)
|
||||
}
|
||||
|
||||
func TestPruning(t *testing.T) {
|
||||
@ -102,11 +102,11 @@ func TestPruning(t *testing.T) {
|
||||
root.Prune(fittestData)
|
||||
|
||||
predictions := root.Predict(testDataF)
|
||||
confusionMat, err := eval.GetConfusionMatrix(testDataF, predictions)
|
||||
confusionMat, err := evaluation.GetConfusionMatrix(testDataF, predictions)
|
||||
if err != nil {
|
||||
t.Fatalf("Unable to get confusion matrix: %s", err.Error())
|
||||
}
|
||||
_ = eval.GetSummary(confusionMat)
|
||||
_ = evaluation.GetSummary(confusionMat)
|
||||
}
|
||||
|
||||
func TestInformationGain(t *testing.T) {
|
||||
@ -196,11 +196,11 @@ func TestID3Classification(t *testing.T) {
|
||||
root := InferID3Tree(trainData, rule)
|
||||
|
||||
predictions := root.Predict(testData)
|
||||
confusionMat, err := eval.GetConfusionMatrix(testData, predictions)
|
||||
confusionMat, err := evaluation.GetConfusionMatrix(testData, predictions)
|
||||
if err != nil {
|
||||
t.Fatalf("Unable to get confusion matrix: %s", err.Error())
|
||||
}
|
||||
_ = eval.GetSummary(confusionMat)
|
||||
_ = evaluation.GetSummary(confusionMat)
|
||||
}
|
||||
|
||||
func TestID3(t *testing.T) {
|
||||
|
Loading…
x
Reference in New Issue
Block a user