2020-07-23 16:45:31 +05:30
|
|
|
|
// Example of how to use CART trees for both Classification and Regression
|
|
|
|
|
|
|
|
|
|
package main
|
|
|
|
|
|
|
|
|
|
import (
|
|
|
|
|
"fmt"
|
|
|
|
|
|
|
|
|
|
"github.com/sjwhitworth/golearn/base"
|
2020-08-01 15:32:59 +05:30
|
|
|
|
"github.com/sjwhitworth/golearn/trees"
|
|
|
|
|
|
2020-07-23 16:45:31 +05:30
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
func main() {
|
2020-07-30 11:48:50 +05:30
|
|
|
|
/* Performance of CART Algorithm:
|
|
|
|
|
|
2020-08-01 13:16:34 +05:30
|
|
|
|
Training Time for Titanic Dataset ≈ 611 µs
|
|
|
|
|
Prediction Time for Titanic Datset ≈ 101 µs
|
|
|
|
|
|
|
|
|
|
Complexity Analysis:
|
|
|
|
|
1x Dataset -- x ms
|
|
|
|
|
2x Dataset -- 1.7x ms
|
|
|
|
|
128x Dataset -- 74x ms
|
|
|
|
|
|
|
|
|
|
Complexity is sub linear
|
2020-07-30 11:48:50 +05:30
|
|
|
|
|
|
|
|
|
Sklearn:
|
|
|
|
|
Training Time for Titanic Dataset ≈ 8.8 µs
|
|
|
|
|
Prediction Time for Titanic Datset ≈ 7.87 µs
|
|
|
|
|
|
2020-08-01 13:16:34 +05:30
|
|
|
|
|
2020-07-30 11:48:50 +05:30
|
|
|
|
This implementation and sci-kit learn produce the exact same tree for the exact same dataset.
|
|
|
|
|
Predictions on the same test set also yield the exact same accuracy.
|
|
|
|
|
|
|
|
|
|
This implementation is optimized to prevent redundant iterations over the dataset, but it is not completely optimized. Also, sklearn makes use of numpy to access column easily, whereas here a complete iteration is required.
|
|
|
|
|
In terms of Hyperparameters, this implmentation gives you the ability to choose the impurity function and the maxDepth.
|
|
|
|
|
Many of the other hyperparameters used in sklearn are not here, but pruning and impurity is included.
|
|
|
|
|
*/
|
2020-07-23 16:45:31 +05:30
|
|
|
|
|
|
|
|
|
// Load Titanic Data For classification
|
|
|
|
|
classificationData, err := base.ParseCSVToInstances("../datasets/titanic.csv", false)
|
|
|
|
|
if err != nil {
|
|
|
|
|
panic(err)
|
|
|
|
|
}
|
|
|
|
|
trainData, testData := base.InstancesTrainTestSplit(classificationData, 0.5)
|
|
|
|
|
|
|
|
|
|
// Create New Classification Tree
|
|
|
|
|
// Hyperparameters - loss function, max Depth (-1 will split until pure), list of unique labels
|
2020-08-01 11:25:53 +05:30
|
|
|
|
decTree := NewDecisionTreeClassifier("entropy", -1, []int64{0, 1})
|
2020-07-23 16:45:31 +05:30
|
|
|
|
|
|
|
|
|
// Train Tree
|
2020-08-01 11:25:53 +05:30
|
|
|
|
err = decTree.Fit(trainData)
|
|
|
|
|
if err != nil {
|
|
|
|
|
panic(err)
|
|
|
|
|
}
|
2020-07-23 16:45:31 +05:30
|
|
|
|
// Print out tree for visualization - shows splits and feature and predictions
|
|
|
|
|
fmt.Println(decTree.String())
|
|
|
|
|
|
|
|
|
|
// Access Predictions
|
|
|
|
|
classificationPreds := decTree.Predict(testData)
|
|
|
|
|
|
|
|
|
|
fmt.Println("Titanic Predictions")
|
|
|
|
|
fmt.Println(classificationPreds)
|
|
|
|
|
|
|
|
|
|
// Evaluate Accuracy on Test Data
|
|
|
|
|
fmt.Println(decTree.Evaluate(testData))
|
|
|
|
|
|
|
|
|
|
// Load House Price Data For Regression
|
|
|
|
|
regressionData, err := base.ParseCSVToInstances("../datasets/boston_house_prices.csv", false)
|
|
|
|
|
if err != nil {
|
|
|
|
|
panic(err)
|
|
|
|
|
}
|
|
|
|
|
trainRegData, testRegData := base.InstancesTrainTestSplit(regressionData, 0.5)
|
|
|
|
|
|
|
|
|
|
// Hyperparameters - Loss function, max Depth (-1 will split until pure)
|
|
|
|
|
regTree := NewDecisionTreeRegressor("mse", -1)
|
|
|
|
|
|
|
|
|
|
// Train Tree
|
2020-08-01 11:25:53 +05:30
|
|
|
|
err = regTree.Fit(trainRegData)
|
|
|
|
|
if err != nil {
|
|
|
|
|
panic(err)
|
|
|
|
|
}
|
2020-07-23 16:45:31 +05:30
|
|
|
|
|
|
|
|
|
// Print out tree for visualization
|
|
|
|
|
fmt.Println(regTree.String())
|
|
|
|
|
|
|
|
|
|
// Access Predictions
|
|
|
|
|
regressionPreds := regTree.Predict(testRegData)
|
|
|
|
|
|
|
|
|
|
fmt.Println("Boston House Price Predictions")
|
|
|
|
|
fmt.Println(regressionPreds)
|
|
|
|
|
|
|
|
|
|
}
|