Mirror of https://github.com/sjwhitworth/golearn.git (synced 2025-04-26 13:49:14 +08:00)
Adding cart_test.go
parent 91a27e3ca0
commit ef751e62c4
109	trees/cart_test.go	Normal file
@@ -0,0 +1,109 @@
package trees

import (
	"fmt"
	"testing"

	. "github.com/smartystreets/goconvey/convey"
)

func TestRegressor(t *testing.T) {

	Convey("Doing a CART Test", t, func() {
		// For Classification Trees:

		// Is Gini being calculated correctly
		gini, giniMaxLabel := giniImpurity([]int64{1, 0, 0, 1}, []int64{0, 1})
		So(gini, ShouldEqual, 0.5)
		So(giniMaxLabel, ShouldNotBeNil)
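		// With labels [1, 0, 0, 1] each of the two classes has probability 0.5,
		// so the Gini impurity is 1 - (0.5*0.5 + 0.5*0.5) = 0.5.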

		// Is Entropy being calculated correctly
		entropy, entropyMaxLabel := entropy([]int64{1, 0, 0, 1}, []int64{0, 1})
		So(entropy, ShouldEqual, 1.0)
		So(entropyMaxLabel, ShouldNotBeNil)
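		// Two equally likely classes give entropy
		// -(0.5*log2(0.5) + 0.5*log2(0.5)) = 1.0.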

		// Is Data being split into left and right properly
		classifierData := [][]float64{[]float64{1, 3, 6},
			[]float64{1, 2, 3},
			[]float64{1, 9, 6},
			[]float64{1, 11, 1}}

		classifiery := []int64{0, 1, 0, 0}

		leftdata, rightdata, lefty, righty := classifierCreateSplit(classifierData, 1, classifiery, 5.0)

		So(len(leftdata), ShouldEqual, 2)
		So(len(lefty), ShouldEqual, 2)
		So(len(rightdata), ShouldEqual, 2)
		So(len(righty), ShouldEqual, 2)
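		// Feature 1 takes the values 3, 2, 9 and 11, so a threshold of 5.0
		// leaves two rows on each side of the split.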

		// Is isolating unique values working properly
		So(len(classifierFindUnique([]float64{10, 1, 1})), ShouldEqual, 2)

		// Is data reordered correctly
		orderedData, orderedY := classifierReOrderData(classifierGetFeature(classifierData, 1), classifierData, classifiery)
		fmt.Println(orderedData)
		fmt.Println(orderedY)
		So(orderedData[1][1], ShouldEqual, 3.0)
		So(orderedY[0], ShouldEqual, 1)
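		// Sorting by feature 1 (values 2, 3, 9, 11) moves the second row to the
		// front: the first reordered label is classifiery[1] = 1, and
		// orderedData[1] is the original first row {1, 3, 6}.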

		// Is split being updated properly based on threshold
		leftdata, lefty, rightdata, righty = classifierUpdateSplit(leftdata, lefty, rightdata, righty, 1, 9.5)
		So(len(leftdata), ShouldEqual, 3)
		So(len(rightdata), ShouldEqual, 1)
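		// Raising the threshold to 9.5 moves the row with feature value 9 from
		// the right partition to the left, leaving three rows left and one right.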

		// Is the root node nil when the tree is not trained?
		tree := NewDecisionTreeClassifier("gini", -1, []int64{0, 1})
		So(tree.RootNode, ShouldBeNil)
		So(tree.triedSplits, ShouldBeEmpty)

		// ------------------------------------------
		// For Regression Trees

		// Is MAE being calculated correctly
		mae, maeMaxLabel := maeImpurity([]float64{1, 3, 5})
		So(mae, ShouldEqual, (4.0 / 3.0))
		So(maeMaxLabel, ShouldNotBeNil)
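		// Deviations from the central value 3 are 2, 0 and 2, so the
		// mean absolute error is (2 + 0 + 2) / 3 = 4/3.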

		// Is MSE being calculated correctly
		mse, mseMaxLabel := mseImpurity([]float64{1, 3, 5})
		So(mse, ShouldEqual, (8.0 / 3.0))
		So(mseMaxLabel, ShouldNotBeNil)
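		// Squared deviations from the mean 3 are 4, 0 and 4, giving MSE = 8/3.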

		// Is Data being split into left and right properly
		data := [][]float64{[]float64{1, 3, 6},
			[]float64{1, 2, 3},
			[]float64{1, 9, 6},
			[]float64{1, 11, 1}}

		y := []float64{1, 2, 3, 4}

		leftData, rightData, leftY, rightY := regressorCreateSplit(data, 1, y, 5.0)

		So(len(leftData), ShouldEqual, 2)
		So(len(leftY), ShouldEqual, 2)
		So(len(rightData), ShouldEqual, 2)
		So(len(rightY), ShouldEqual, 2)
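		// Same feature values as the classifier case, so the 5.0 threshold again
		// yields two rows per side; the assertions check the regressor's own
		// leftY/rightY slices rather than the classifier's.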

		// Is isolating unique values working properly
		So(len(regressorFindUnique([]float64{10, 1, 1})), ShouldEqual, 2)

		// Is data reordered correctly
		regressorOrderedData, regressorOrderedY := regressorReOrderData(regressorGetFeature(data, 1), data, y)

		So(regressorOrderedData[1][1], ShouldEqual, 3.0)
		So(regressorOrderedY[0], ShouldEqual, 2)
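		// As with the classifier, sorting by feature 1 puts the second row first,
		// so the first reordered target is y[1] = 2.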

		// Is split being updated properly based on threshold
		leftData, leftY, rightData, rightY = regressorUpdateSplit(leftData, leftY, rightData, rightY, 1, 9.5)
		So(len(leftData), ShouldEqual, 3)
		So(len(rightData), ShouldEqual, 1)

		// Is the root node nil when the tree is not trained?
		regressorTree := NewDecisionTreeRegressor("mae", -1)
		So(regressorTree.RootNode, ShouldBeNil)
		So(regressorTree.triedSplits, ShouldBeEmpty)
	})
}