From ef751e62c484badf66a142053e7c7b55eb5e38f2 Mon Sep 17 00:00:00 2001 From: Ayush Date: Mon, 27 Jul 2020 17:08:44 +0530 Subject: [PATCH] Adding cart_test.go --- trees/cart_test.go | 109 +++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 109 insertions(+) create mode 100644 trees/cart_test.go diff --git a/trees/cart_test.go b/trees/cart_test.go new file mode 100644 index 0000000..047392a --- /dev/null +++ b/trees/cart_test.go @@ -0,0 +1,109 @@ +package trees + +import ( + "fmt" + "testing" + + . "github.com/smartystreets/goconvey/convey" +) + +func TestRegressor(t *testing.T) { + + Convey("Doing a CART Test", t, func() { + // For Classification Trees: + + // Is Gini being calculated correctly + gini, giniMaxLabel := giniImpurity([]int64{1, 0, 0, 1}, []int64{0, 1}) + So(gini, ShouldEqual, 0.5) + So(giniMaxLabel, ShouldNotBeNil) + + // Is Entropy being calculated correctly + entropy, entropyMaxLabel := entropy([]int64{1, 0, 0, 1}, []int64{0, 1}) + So(entropy, ShouldEqual, 1.0) + So(entropyMaxLabel, ShouldNotBeNil) + + // Is Data being split into left and right properly + classifierData := [][]float64{[]float64{1, 3, 6}, + []float64{1, 2, 3}, + []float64{1, 9, 6}, + []float64{1, 11, 1}} + + classifiery := []int64{0, 1, 0, 0} + + leftdata, rightdata, lefty, righty := classifierCreateSplit(classifierData, 1, classifiery, 5.0) + + So(len(leftdata), ShouldEqual, 2) + So(len(lefty), ShouldEqual, 2) + So(len(rightdata), ShouldEqual, 2) + So(len(righty), ShouldEqual, 2) + + // Is isolating unique values working properly + So(len(classifierFindUnique([]float64{10, 1, 1})), ShouldEqual, 2) + + // is data reordered correctly + orderedData, orderedY := classifierReOrderData(classifierGetFeature(classifierData, 1), classifierData, classifiery) + fmt.Println(orderedData) + fmt.Println(orderedY) + So(orderedData[1][1], ShouldEqual, 3.0) + So(orderedY[0], ShouldEqual, 1) + + // Is split being updated properly based on threshold + leftdata, lefty, rightdata, righty = classifierUpdateSplit(leftdata, lefty, rightdata, righty, 1, 9.5) + So(len(leftdata), ShouldEqual, 3) + So(len(rightdata), ShouldEqual, 1) + + // Is the root Node null when tree is not trained? + tree := NewDecisionTreeClassifier("gini", -1, []int64{0, 1}) + So(tree.RootNode, ShouldBeNil) + So(tree.triedSplits, ShouldBeEmpty) + + // ------------------------------------------ + // For Regression Trees + + // Is MAE being calculated correctly + mae, maeMaxLabel := maeImpurity([]float64{1, 3, 5}) + So(mae, ShouldEqual, (4.0 / 3.0)) + So(maeMaxLabel, ShouldNotBeNil) + + // Is Entropy being calculated correctly + mse, mseMaxLabel := mseImpurity([]float64{1, 3, 5}) + So(mse, ShouldEqual, (8.0 / 3.0)) + So(mseMaxLabel, ShouldNotBeNil) + + // Is Data being split into left and right properly + data := [][]float64{[]float64{1, 3, 6}, + []float64{1, 2, 3}, + []float64{1, 9, 6}, + []float64{1, 11, 1}} + + y := []float64{1, 2, 3, 4} + + leftData, rightData, leftY, rightY := regressorCreateSplit(data, 1, y, 5.0) + + So(len(leftData), ShouldEqual, 2) + So(len(lefty), ShouldEqual, 2) + So(len(rightData), ShouldEqual, 2) + So(len(righty), ShouldEqual, 2) + + // Is isolating unique values working properly + So(len(regressorFindUnique([]float64{10, 1, 1})), ShouldEqual, 2) + + // is data reordered correctly + regressorOrderedData, regressorOrderedY := regressorReOrderData(regressorGetFeature(data, 1), data, y) + + So(regressorOrderedData[1][1], ShouldEqual, 3.0) + So(regressorOrderedY[0], ShouldEqual, 2) + + // Is split being updated properly based on threshold + leftData, leftY, rightData, rightY = regressorUpdateSplit(leftData, leftY, rightData, rightY, 1, 9.5) + So(len(leftData), ShouldEqual, 3) + So(len(rightData), ShouldEqual, 1) + + // Is the root Node null when tree is not trained? + regressorTreetree := NewDecisionTreeRegressor("mae", -1) + So(regressorTreetree.RootNode, ShouldBeNil) + So(regressorTreetree.triedSplits, ShouldBeEmpty) + + }) + +}