mirror of
https://github.com/sjwhitworth/golearn.git
synced 2025-04-26 13:49:14 +08:00
110 lines
3.4 KiB
Go
110 lines
3.4 KiB
Go
package trees
|
|
|
|
import (
|
|
"fmt"
|
|
"testing"
|
|
|
|
. "github.com/smartystreets/goconvey/convey"
|
|
)
|
|
|
|
func TestRegressor(t *testing.T) {
|
|
|
|
Convey("Doing a CART Test", t, func() {
|
|
// For Classification Trees:
|
|
|
|
// Is Gini being calculated correctly
|
|
gini, giniMaxLabel := giniImpurity([]int64{1, 0, 0, 1}, []int64{0, 1})
|
|
So(gini, ShouldEqual, 0.5)
|
|
So(giniMaxLabel, ShouldNotBeNil)
|
|
|
|
// Is Entropy being calculated correctly
|
|
entropy, entropyMaxLabel := entropy([]int64{1, 0, 0, 1}, []int64{0, 1})
|
|
So(entropy, ShouldEqual, 1.0)
|
|
So(entropyMaxLabel, ShouldNotBeNil)
|
|
|
|
// Is Data being split into left and right properly
|
|
classifierData := [][]float64{[]float64{1, 3, 6},
|
|
[]float64{1, 2, 3},
|
|
[]float64{1, 9, 6},
|
|
[]float64{1, 11, 1}}
|
|
|
|
classifiery := []int64{0, 1, 0, 0}
|
|
|
|
leftdata, rightdata, lefty, righty := classifierCreateSplit(classifierData, 1, classifiery, 5.0)
|
|
|
|
So(len(leftdata), ShouldEqual, 2)
|
|
So(len(lefty), ShouldEqual, 2)
|
|
So(len(rightdata), ShouldEqual, 2)
|
|
So(len(righty), ShouldEqual, 2)
|
|
|
|
// Is isolating unique values working properly
|
|
So(len(classifierFindUnique([]float64{10, 1, 1})), ShouldEqual, 2)
|
|
|
|
// is data reordered correctly
|
|
orderedData, orderedY := classifierReOrderData(classifierGetFeature(classifierData, 1), classifierData, classifiery)
|
|
fmt.Println(orderedData)
|
|
fmt.Println(orderedY)
|
|
So(orderedData[1][1], ShouldEqual, 3.0)
|
|
So(orderedY[0], ShouldEqual, 1)
|
|
|
|
// Is split being updated properly based on threshold
|
|
leftdata, lefty, rightdata, righty = classifierUpdateSplit(leftdata, lefty, rightdata, righty, 1, 9.5)
|
|
So(len(leftdata), ShouldEqual, 3)
|
|
So(len(rightdata), ShouldEqual, 1)
|
|
|
|
// Is the root Node null when tree is not trained?
|
|
tree := NewDecisionTreeClassifier("gini", -1, []int64{0, 1})
|
|
So(tree.RootNode, ShouldBeNil)
|
|
So(tree.triedSplits, ShouldBeEmpty)
|
|
|
|
// ------------------------------------------
|
|
// For Regression Trees
|
|
|
|
// Is MAE being calculated correctly
|
|
mae, maeMaxLabel := maeImpurity([]float64{1, 3, 5})
|
|
So(mae, ShouldEqual, (4.0 / 3.0))
|
|
So(maeMaxLabel, ShouldNotBeNil)
|
|
|
|
// Is Entropy being calculated correctly
|
|
mse, mseMaxLabel := mseImpurity([]float64{1, 3, 5})
|
|
So(mse, ShouldEqual, (8.0 / 3.0))
|
|
So(mseMaxLabel, ShouldNotBeNil)
|
|
|
|
// Is Data being split into left and right properly
|
|
data := [][]float64{[]float64{1, 3, 6},
|
|
[]float64{1, 2, 3},
|
|
[]float64{1, 9, 6},
|
|
[]float64{1, 11, 1}}
|
|
|
|
y := []float64{1, 2, 3, 4}
|
|
|
|
leftData, rightData, leftY, rightY := regressorCreateSplit(data, 1, y, 5.0)
|
|
|
|
So(len(leftData), ShouldEqual, 2)
|
|
So(len(lefty), ShouldEqual, 2)
|
|
So(len(rightData), ShouldEqual, 2)
|
|
So(len(righty), ShouldEqual, 2)
|
|
|
|
// Is isolating unique values working properly
|
|
So(len(regressorFindUnique([]float64{10, 1, 1})), ShouldEqual, 2)
|
|
|
|
// is data reordered correctly
|
|
regressorOrderedData, regressorOrderedY := regressorReOrderData(regressorGetFeature(data, 1), data, y)
|
|
|
|
So(regressorOrderedData[1][1], ShouldEqual, 3.0)
|
|
So(regressorOrderedY[0], ShouldEqual, 2)
|
|
|
|
// Is split being updated properly based on threshold
|
|
leftData, leftY, rightData, rightY = regressorUpdateSplit(leftData, leftY, rightData, rightY, 1, 9.5)
|
|
So(len(leftData), ShouldEqual, 3)
|
|
So(len(rightData), ShouldEqual, 1)
|
|
|
|
// Is the root Node null when tree is not trained?
|
|
regressorTreetree := NewDecisionTreeRegressor("mae", -1)
|
|
So(regressorTreetree.RootNode, ShouldBeNil)
|
|
So(regressorTreetree.triedSplits, ShouldBeEmpty)
|
|
|
|
})
|
|
|
|
}
|