1
0
mirror of https://github.com/sjwhitworth/golearn.git synced 2025-04-26 13:49:14 +08:00
2020-07-23 16:45:31 +05:30

62 lines
1.6 KiB
Go

// Example of how to use CART trees for both Classification and Regression
package main
import (
"fmt"
"github.com/sjwhitworth/golearn/base"
)
func main() {
// Load Titanic Data For classification
classificationData, err := base.ParseCSVToInstances("../datasets/titanic.csv", false)
if err != nil {
panic(err)
}
trainData, testData := base.InstancesTrainTestSplit(classificationData, 0.5)
// Create New Classification Tree
// Hyperparameters - loss function, max Depth (-1 will split until pure), list of unique labels
decTree = NewDecisionTreeClassifier("entropy", -1, []int64{0, 1})
// Train Tree
decTree.Fit(trainData)
// Print out tree for visualization - shows splits and feature and predictions
fmt.Println(decTree.String())
// Access Predictions
classificationPreds := decTree.Predict(testData)
fmt.Println("Titanic Predictions")
fmt.Println(classificationPreds)
// Evaluate Accuracy on Test Data
fmt.Println(decTree.Evaluate(testData))
// Load House Price Data For Regression
regressionData, err := base.ParseCSVToInstances("../datasets/boston_house_prices.csv", false)
if err != nil {
panic(err)
}
trainRegData, testRegData := base.InstancesTrainTestSplit(regressionData, 0.5)
// Hyperparameters - Loss function, max Depth (-1 will split until pure)
regTree := NewDecisionTreeRegressor("mse", -1)
// Train Tree
regTree.Fit(trainRegData)
// Print out tree for visualization
fmt.Println(regTree.String())
// Access Predictions
regressionPreds := regTree.Predict(testRegData)
fmt.Println("Boston House Price Predictions")
fmt.Println(regressionPreds)
}