mirror of
https://github.com/sjwhitworth/golearn.git
synced 2025-04-26 13:49:14 +08:00
61 lines
1.6 KiB
Go
61 lines
1.6 KiB
Go
// Example of how to use CART trees for both Classification and Regression
|
|
|
|
package main
|
|
|
|
import (
|
|
"fmt"
|
|
|
|
"github.com/sjwhitworth/golearn/base"
|
|
)
|
|
|
|
func main() {
|
|
|
|
// Load Titanic Data For classification
|
|
classificationData, err := base.ParseCSVToInstances("../datasets/titanic.csv", false)
|
|
if err != nil {
|
|
panic(err)
|
|
}
|
|
trainData, testData := base.InstancesTrainTestSplit(classificationData, 0.5)
|
|
|
|
// Create New Classification Tree
|
|
// Hyperparameters - loss function, max Depth (-1 will split until pure), list of unique labels
|
|
decTree = NewDecisionTreeClassifier("entropy", -1, []int64{0, 1})
|
|
|
|
// Train Tree
|
|
decTree.Fit(trainData)
|
|
// Print out tree for visualization - shows splits and feature and predictions
|
|
fmt.Println(decTree.String())
|
|
|
|
// Access Predictions
|
|
classificationPreds := decTree.Predict(testData)
|
|
|
|
fmt.Println("Titanic Predictions")
|
|
fmt.Println(classificationPreds)
|
|
|
|
// Evaluate Accuracy on Test Data
|
|
fmt.Println(decTree.Evaluate(testData))
|
|
|
|
// Load House Price Data For Regression
|
|
regressionData, err := base.ParseCSVToInstances("../datasets/boston_house_prices.csv", false)
|
|
if err != nil {
|
|
panic(err)
|
|
}
|
|
trainRegData, testRegData := base.InstancesTrainTestSplit(regressionData, 0.5)
|
|
|
|
// Hyperparameters - Loss function, max Depth (-1 will split until pure)
|
|
regTree := NewDecisionTreeRegressor("mse", -1)
|
|
|
|
// Train Tree
|
|
regTree.Fit(trainRegData)
|
|
|
|
// Print out tree for visualization
|
|
fmt.Println(regTree.String())
|
|
|
|
// Access Predictions
|
|
regressionPreds := regTree.Predict(testRegData)
|
|
|
|
fmt.Println("Boston House Price Predictions")
|
|
fmt.Println(regressionPreds)
|
|
|
|
}
|