// Example of how to use CART trees for both Classification and Regression package main import ( "fmt" "github.com/sjwhitworth/golearn/base" ) func main() { /* Performance of CART Algorithm: Training Time for Titanic Dataset ≈ 713 µs Prediction Time for Titanic Datset ≈ 133 µs Sklearn: Training Time for Titanic Dataset ≈ 8.8 µs Prediction Time for Titanic Datset ≈ 7.87 µs This implementation and sci-kit learn produce the exact same tree for the exact same dataset. Predictions on the same test set also yield the exact same accuracy. This implementation is optimized to prevent redundant iterations over the dataset, but it is not completely optimized. Also, sklearn makes use of numpy to access column easily, whereas here a complete iteration is required. In terms of Hyperparameters, this implmentation gives you the ability to choose the impurity function and the maxDepth. Many of the other hyperparameters used in sklearn are not here, but pruning and impurity is included. */ // Load Titanic Data For classification classificationData, err := base.ParseCSVToInstances("../datasets/titanic.csv", false) if err != nil { panic(err) } trainData, testData := base.InstancesTrainTestSplit(classificationData, 0.5) // Create New Classification Tree // Hyperparameters - loss function, max Depth (-1 will split until pure), list of unique labels decTree = NewDecisionTreeClassifier("entropy", -1, []int64{0, 1}) // Train Tree decTree.Fit(trainData) // Print out tree for visualization - shows splits and feature and predictions fmt.Println(decTree.String()) // Access Predictions classificationPreds := decTree.Predict(testData) fmt.Println("Titanic Predictions") fmt.Println(classificationPreds) // Evaluate Accuracy on Test Data fmt.Println(decTree.Evaluate(testData)) // Load House Price Data For Regression regressionData, err := base.ParseCSVToInstances("../datasets/boston_house_prices.csv", false) if err != nil { panic(err) } trainRegData, testRegData := base.InstancesTrainTestSplit(regressionData, 0.5) // Hyperparameters - Loss function, max Depth (-1 will split until pure) regTree := NewDecisionTreeRegressor("mse", -1) // Train Tree regTree.Fit(trainRegData) // Print out tree for visualization fmt.Println(regTree.String()) // Access Predictions regressionPreds := regTree.Predict(testRegData) fmt.Println("Boston House Price Predictions") fmt.Println(regressionPreds) }