2020-07-18 12:26:50 +05:30
|
|
|
package trees
|
|
|
|
|
|
|
|
import (
|
2020-07-31 11:01:20 +05:30
|
|
|
"errors"
|
2020-07-18 12:26:50 +05:30
|
|
|
"fmt"
|
|
|
|
"math"
|
|
|
|
"sort"
|
2020-07-22 14:34:59 +05:30
|
|
|
"strconv"
|
2020-07-18 12:26:50 +05:30
|
|
|
"strings"
|
|
|
|
|
|
|
|
"github.com/sjwhitworth/golearn/base"
|
|
|
|
)
|
|
|
|
|
2020-07-27 15:03:12 +05:30
|
|
|
// Split criteria accepted by the decision tree regressor.
const (
	// MAE selects mean absolute error as the impurity measure.
	MAE string = "mae"
	// MSE selects mean squared error as the impurity measure.
	MSE string = "mse"
)
|
2020-07-18 12:26:50 +05:30
|
|
|
|
2020-07-18 14:21:50 +05:30
|
|
|
// regressorNode is one split of the decision tree regressor.
// It records which feature is compared against which threshold, the constant
// prediction for each side of the split, and the child nodes that refine
// either side further.
type regressorNode struct {
	// Left and Right are the sub-splits for rows below / not below Threshold.
	// A nil child means the matching prediction (LeftPred / RightPred) is used.
	Left  *regressorNode
	Right *regressorNode
	// Threshold is the value the feature is compared against with "<".
	Threshold float64
	// Feature is the column index of the feature tested by this node.
	Feature int64
	// LeftPred and RightPred are the predictions for the two sides.
	LeftPred  float64
	RightPred float64
	// isNodeNeeded is set to false while building the tree when no split
	// lowers the loss, so the parent does not attach this node.
	isNodeNeeded bool
}
|
|
|
|
|
2020-07-22 14:34:59 +05:30
|
|
|
// CARTDecisionTreeRegressor - Tree struct for Decision Tree Regressor
|
2020-07-27 15:03:12 +05:30
|
|
|
// It contains the rootNode, as well as the hyperparameters chosen by user.
|
|
|
|
// Also keeps track of splits used at tree level.
|
2020-07-22 14:34:59 +05:30
|
|
|
type CARTDecisionTreeRegressor struct {
|
2020-07-26 11:21:20 +05:30
|
|
|
RootNode *regressorNode
|
2020-07-18 12:26:50 +05:30
|
|
|
criterion string
|
|
|
|
maxDepth int64
|
|
|
|
triedSplits [][]float64
|
|
|
|
}
|
|
|
|
|
2020-07-18 14:21:50 +05:30
|
|
|
// average returns the arithmetic mean of y.
// For an empty slice the result is NaN (0 divided by 0).
func average(y []float64) float64 {
	sum := 0.0
	for _, v := range y {
		sum += v
	}
	return sum / float64(len(y))
}
|
|
|
|
|
2020-07-26 11:21:20 +05:30
|
|
|
// meanAbsoluteError returns the mean absolute error of predicting the
// constant yBar for every target in y.
// For an empty slice the result is NaN (0 divided by 0).
func meanAbsoluteError(y []float64, yBar float64) float64 {
	// Named total rather than error so the builtin error type is not shadowed.
	total := 0.0
	for _, target := range y {
		total += math.Abs(target - yBar)
	}
	return total / float64(len(y))
}
|
|
|
|
|
2020-07-18 14:21:50 +05:30
|
|
|
// Turn Mean Absolute Error into impurity function for decision trees.
|
2020-07-30 11:21:06 +05:30
|
|
|
func computeMaeImpurityAndAverage(y []float64) (float64, float64) {
|
2020-07-18 12:26:50 +05:30
|
|
|
yHat := average(y)
|
|
|
|
return meanAbsoluteError(y, yHat), yHat
|
|
|
|
}
|
|
|
|
|
2020-07-18 14:21:50 +05:30
|
|
|
// meanSquaredError returns the mean squared error of predicting the
// constant yBar for every target in y.
// For an empty slice the result is NaN (0 divided by 0).
func meanSquaredError(y []float64, yBar float64) float64 {
	// Named total rather than error so the builtin error type is not shadowed.
	total := 0.0
	for _, target := range y {
		diff := target - yBar
		total += diff * diff
	}
	return total / float64(len(y))
}
|
|
|
|
|
2020-07-18 14:21:50 +05:30
|
|
|
// Convert mean squared error into impurity function for decision trees
|
2020-07-30 11:21:06 +05:30
|
|
|
func computeMseImpurityAndAverage(y []float64) (float64, float64) {
|
2020-07-18 12:26:50 +05:30
|
|
|
yHat := average(y)
|
|
|
|
return meanSquaredError(y, yHat), yHat
|
|
|
|
}
|
|
|
|
|
2020-07-31 11:01:20 +05:30
|
|
|
func calculateRegressionLoss(y []float64, criterion string) (float64, float64, error) {
|
2020-07-28 14:17:18 +05:30
|
|
|
if criterion == MAE {
|
2020-07-31 11:01:20 +05:30
|
|
|
loss, avg := computeMaeImpurityAndAverage(y)
|
|
|
|
return loss, avg, nil
|
2020-07-28 14:17:18 +05:30
|
|
|
} else if criterion == MSE {
|
2020-07-31 11:01:20 +05:30
|
|
|
loss, avg := computeMseImpurityAndAverage(y)
|
|
|
|
return loss, avg, nil
|
2020-07-28 14:17:18 +05:30
|
|
|
} else {
|
|
|
|
panic("Invalid impurity function, choose from MAE or MSE")
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
2020-07-27 15:03:12 +05:30
|
|
|
// regressorCreateSplit splits the data into left and right based on
// threshold and feature. Rows whose value in column feature is strictly
// below threshold go left; all other rows go right. The targets in y are
// split alongside their rows and the input order is preserved on both sides.
//
// It returns, in this order: left rows, right rows, left targets, right
// targets. A side that receives no rows is returned as a nil slice.
func regressorCreateSplit(data [][]float64, feature int64, y []float64, threshold float64) ([][]float64, [][]float64, []float64, []float64) {
	var (
		left, right   [][]float64
		lefty, righty []float64
	)

	for i, example := range data {
		if example[feature] < threshold {
			left = append(left, example)
			lefty = append(lefty, y[i])
		} else {
			right = append(right, example)
			righty = append(righty, y[i])
		}
	}

	return left, right, lefty, righty
}
|
|
|
|
|
2020-07-27 15:03:12 +05:30
|
|
|
// Interface for creating new Decision Tree Regressor
|
2020-07-22 14:34:59 +05:30
|
|
|
func NewDecisionTreeRegressor(criterion string, maxDepth int64) *CARTDecisionTreeRegressor {
|
|
|
|
var tree CARTDecisionTreeRegressor
|
2020-07-18 12:26:50 +05:30
|
|
|
tree.maxDepth = maxDepth
|
|
|
|
tree.criterion = strings.ToLower(criterion)
|
|
|
|
return &tree
|
|
|
|
}
|
|
|
|
|
2020-07-18 14:21:50 +05:30
|
|
|
// Re order data based on a feature for optimizing code
|
2020-07-27 15:03:12 +05:30
|
|
|
// Helps in updating splits without reiterating entire dataset
|
2020-07-26 11:21:20 +05:30
|
|
|
func regressorReOrderData(featureVal []float64, data [][]float64, y []float64) ([][]float64, []float64) {
|
2020-07-25 13:22:15 +05:30
|
|
|
s := NewSlice(featureVal)
|
2020-07-18 12:26:50 +05:30
|
|
|
sort.Sort(s)
|
|
|
|
|
|
|
|
indexes := s.Idx
|
|
|
|
|
|
|
|
var dataSorted [][]float64
|
|
|
|
var ySorted []float64
|
|
|
|
|
|
|
|
for _, index := range indexes {
|
|
|
|
dataSorted = append(dataSorted, data[index])
|
|
|
|
ySorted = append(ySorted, y[index])
|
|
|
|
}
|
|
|
|
|
|
|
|
return dataSorted, ySorted
|
|
|
|
}
|
|
|
|
|
2020-07-18 14:21:50 +05:30
|
|
|
// regressorUpdateSplit updates the left and right data after the threshold
// has been raised. Rows are moved from the front of right to the end of left
// for as long as the first remaining right row has a value in column feature
// strictly below threshold; the targets move with their rows.
//
// right is expected to be ordered by the feature (ascending) and rightY to
// be parallel to right. The loop stops when right is exhausted, so a
// threshold above every remaining value drains right instead of indexing
// past its end.
//
// It returns, in this order: left rows, left targets, right rows, right
// targets.
func regressorUpdateSplit(left [][]float64, leftY []float64, right [][]float64, rightY []float64, feature int64, threshold float64) ([][]float64, []float64, [][]float64, []float64) {
	for len(right) > 0 && right[0][feature] < threshold {
		left = append(left, right[0])
		right = right[1:]
		leftY = append(leftY, rightY[0])
		rightY = rightY[1:]
	}

	return left, leftY, right, rightY
}
|
|
|
|
|
2020-07-27 15:03:12 +05:30
|
|
|
// Fit - Build the tree using the data
|
|
|
|
// Creates empty root node and builds tree by calling regressorBestSplit
|
2020-07-31 11:01:20 +05:30
|
|
|
func (tree *CARTDecisionTreeRegressor) Fit(X base.FixedDataGrid) error {
|
2020-07-26 11:21:20 +05:30
|
|
|
var emptyNode regressorNode
|
2020-07-31 11:01:20 +05:30
|
|
|
var err error
|
2020-07-18 12:26:50 +05:30
|
|
|
|
2020-07-31 11:01:20 +05:30
|
|
|
data := regressorConvertInstancesToProblemVec(X)
|
|
|
|
y, err := regressorConvertInstancesToLabelVec(X)
|
|
|
|
if err != nil {
|
|
|
|
return err
|
|
|
|
}
|
2020-07-18 12:26:50 +05:30
|
|
|
|
2020-07-31 11:01:20 +05:30
|
|
|
emptyNode, err = regressorBestSplit(*tree, data, y, emptyNode, tree.criterion, tree.maxDepth, 0)
|
|
|
|
if err != nil {
|
|
|
|
return err
|
|
|
|
}
|
2020-07-18 12:26:50 +05:30
|
|
|
tree.RootNode = &emptyNode
|
2020-07-31 11:01:20 +05:30
|
|
|
return nil
|
2020-07-18 12:26:50 +05:30
|
|
|
}
|
|
|
|
|
2020-07-27 15:03:12 +05:30
|
|
|
// Builds the tree by iteratively finding the best split.
|
|
|
|
// Recursive function - stops if maxDepth is reached or nodes are pure
|
2020-07-31 11:01:20 +05:30
|
|
|
func regressorBestSplit(tree CARTDecisionTreeRegressor, data [][]float64, y []float64, upperNode regressorNode, criterion string, maxDepth int64, depth int64) (regressorNode, error) {
|
2020-07-18 12:26:50 +05:30
|
|
|
|
2020-07-28 14:17:18 +05:30
|
|
|
// Ensure that we have not reached maxDepth. maxDepth =-1 means split until nodes are pure
|
2020-07-18 12:26:50 +05:30
|
|
|
depth++
|
|
|
|
|
|
|
|
if depth > maxDepth && maxDepth != -1 {
|
2020-07-31 11:01:20 +05:30
|
|
|
return upperNode, nil
|
2020-07-18 12:26:50 +05:30
|
|
|
}
|
|
|
|
|
|
|
|
numFeatures := len(data[0])
|
2020-07-28 14:17:18 +05:30
|
|
|
var bestLoss, origLoss float64
|
2020-07-31 11:01:20 +05:30
|
|
|
var err error
|
|
|
|
origLoss, upperNode.LeftPred, err = calculateRegressionLoss(y, criterion)
|
|
|
|
if err != nil {
|
|
|
|
return upperNode, err
|
|
|
|
}
|
2020-07-18 12:26:50 +05:30
|
|
|
|
|
|
|
bestLoss = origLoss
|
|
|
|
|
2020-07-28 14:17:18 +05:30
|
|
|
bestLeft, bestRight, bestLefty, bestRighty := data, data, y, y
|
2020-07-18 12:26:50 +05:30
|
|
|
|
|
|
|
numData := len(data)
|
|
|
|
|
2020-07-28 14:17:18 +05:30
|
|
|
bestLeftLoss, bestRightLoss := bestLoss, bestLoss
|
2020-07-18 12:26:50 +05:30
|
|
|
|
2020-07-30 10:27:16 +05:30
|
|
|
upperNode.isNodeNeeded = true
|
2020-07-18 12:26:50 +05:30
|
|
|
|
2020-07-28 14:17:18 +05:30
|
|
|
var leftN, rightN regressorNode
|
2020-07-18 12:26:50 +05:30
|
|
|
// Iterate over all features
|
|
|
|
for i := 0; i < numFeatures; i++ {
|
2020-07-28 14:17:18 +05:30
|
|
|
|
|
|
|
featureVal := getFeature(data, int64(i))
|
|
|
|
unique := findUnique(featureVal)
|
2020-07-18 12:26:50 +05:30
|
|
|
sort.Float64s(unique)
|
|
|
|
|
2020-07-26 11:21:20 +05:30
|
|
|
sortData, sortY := regressorReOrderData(featureVal, data, y)
|
2020-07-18 12:26:50 +05:30
|
|
|
|
|
|
|
firstTime := true
|
|
|
|
|
|
|
|
var left, right [][]float64
|
2020-08-01 11:43:14 +05:30
|
|
|
var leftY, rightY []float64
|
2020-07-18 12:26:50 +05:30
|
|
|
|
2020-07-28 14:17:18 +05:30
|
|
|
for j := 0; j < len(unique)-1; j++ {
|
|
|
|
threshold := (unique[j] + unique[j+1]) / 2
|
|
|
|
if validate(tree.triedSplits, int64(i), threshold) {
|
|
|
|
if firstTime {
|
2020-08-01 11:43:14 +05:30
|
|
|
left, right, leftY, rightY = regressorCreateSplit(sortData, int64(i), sortY, threshold)
|
2020-07-28 14:17:18 +05:30
|
|
|
firstTime = false
|
|
|
|
} else {
|
2020-08-01 11:43:14 +05:30
|
|
|
left, leftY, right, rightY = regressorUpdateSplit(left, leftY, right, rightY, int64(i), threshold)
|
2020-07-18 12:26:50 +05:30
|
|
|
}
|
|
|
|
|
2020-07-28 14:17:18 +05:30
|
|
|
var leftLoss, rightLoss float64
|
|
|
|
var leftPred, rightPred float64
|
|
|
|
|
2020-08-01 11:43:14 +05:30
|
|
|
leftLoss, leftPred, _ = calculateRegressionLoss(leftY, criterion)
|
|
|
|
rightLoss, rightPred, _ = calculateRegressionLoss(rightY, criterion)
|
2020-07-28 14:17:18 +05:30
|
|
|
|
|
|
|
subLoss := (leftLoss * float64(len(left)) / float64(numData)) + (rightLoss * float64(len(right)) / float64(numData))
|
|
|
|
|
|
|
|
if subLoss < bestLoss {
|
|
|
|
bestLoss = subLoss
|
|
|
|
|
|
|
|
bestLeft, bestRight = left, right
|
2020-08-01 11:43:14 +05:30
|
|
|
bestLefty, bestRighty = leftY, rightY
|
2020-07-28 14:17:18 +05:30
|
|
|
|
|
|
|
upperNode.Threshold, upperNode.Feature = threshold, int64(i)
|
|
|
|
|
|
|
|
upperNode.LeftPred, upperNode.RightPred = leftPred, rightPred
|
|
|
|
|
|
|
|
bestLeftLoss, bestRightLoss = leftLoss, rightLoss
|
|
|
|
}
|
2020-07-18 12:26:50 +05:30
|
|
|
}
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
if bestLoss == origLoss {
|
2020-07-30 10:27:16 +05:30
|
|
|
upperNode.isNodeNeeded = false
|
2020-07-31 11:01:20 +05:30
|
|
|
return upperNode, nil
|
2020-07-18 12:26:50 +05:30
|
|
|
}
|
|
|
|
|
|
|
|
if bestLoss > 0 {
|
|
|
|
|
|
|
|
if bestLeftLoss > 0 {
|
|
|
|
tree.triedSplits = append(tree.triedSplits, []float64{float64(upperNode.Feature), upperNode.Threshold})
|
2020-07-31 11:01:20 +05:30
|
|
|
leftN, err = regressorBestSplit(tree, bestLeft, bestLefty, leftN, criterion, maxDepth, depth)
|
|
|
|
if err != nil {
|
|
|
|
return upperNode, err
|
|
|
|
}
|
2020-07-30 10:27:16 +05:30
|
|
|
if leftN.isNodeNeeded == true {
|
2020-07-18 12:26:50 +05:30
|
|
|
upperNode.Left = &leftN
|
|
|
|
}
|
|
|
|
}
|
2020-07-28 14:17:18 +05:30
|
|
|
|
2020-07-18 12:26:50 +05:30
|
|
|
if bestRightLoss > 0 {
|
|
|
|
tree.triedSplits = append(tree.triedSplits, []float64{float64(upperNode.Feature), upperNode.Threshold})
|
2020-07-31 11:01:20 +05:30
|
|
|
rightN, err = regressorBestSplit(tree, bestRight, bestRighty, rightN, criterion, maxDepth, depth)
|
|
|
|
if err != nil {
|
|
|
|
return upperNode, err
|
|
|
|
}
|
2020-07-30 10:27:16 +05:30
|
|
|
if rightN.isNodeNeeded == true {
|
2020-07-18 12:26:50 +05:30
|
|
|
upperNode.Right = &rightN
|
|
|
|
}
|
|
|
|
}
|
|
|
|
}
|
2020-07-31 11:01:20 +05:30
|
|
|
return upperNode, nil
|
2020-07-18 12:26:50 +05:30
|
|
|
}
|
|
|
|
|
2020-07-27 15:03:12 +05:30
|
|
|
// Print Tree for Visualtion - calls regressorPrintTreeFromNode()
|
2020-07-22 14:34:59 +05:30
|
|
|
func (tree *CARTDecisionTreeRegressor) String() string {
|
2020-07-18 12:26:50 +05:30
|
|
|
rootNode := *tree.RootNode
|
2020-07-26 11:21:20 +05:30
|
|
|
return regressorPrintTreeFromNode(rootNode, "")
|
2020-07-18 12:26:50 +05:30
|
|
|
}
|
|
|
|
|
2020-07-27 15:03:12 +05:30
|
|
|
// Recursively explore the entire tree and print out all details such as threshold, feature, prediction
|
2020-07-26 11:21:20 +05:30
|
|
|
func regressorPrintTreeFromNode(tree regressorNode, spacing string) string {
|
2020-07-22 14:34:59 +05:30
|
|
|
returnString := ""
|
|
|
|
returnString += spacing + "Feature "
|
|
|
|
returnString += strconv.FormatInt(tree.Feature, 10)
|
|
|
|
returnString += " < "
|
|
|
|
returnString += fmt.Sprintf("%.3f", tree.Threshold)
|
|
|
|
returnString += "\n"
|
2020-07-18 12:26:50 +05:30
|
|
|
|
|
|
|
if tree.Left == nil {
|
2020-07-22 14:34:59 +05:30
|
|
|
returnString += spacing + "---> True" + "\n"
|
|
|
|
returnString += " " + spacing + "PREDICT "
|
|
|
|
returnString += fmt.Sprintf("%.3f", tree.LeftPred) + "\n"
|
2020-07-18 12:26:50 +05:30
|
|
|
}
|
|
|
|
if tree.Right == nil {
|
2020-07-22 14:34:59 +05:30
|
|
|
returnString += spacing + "---> False" + "\n"
|
|
|
|
returnString += " " + spacing + "PREDICT "
|
|
|
|
returnString += fmt.Sprintf("%.3f", tree.RightPred) + "\n"
|
2020-07-18 12:26:50 +05:30
|
|
|
}
|
|
|
|
|
|
|
|
if tree.Left != nil {
|
2020-07-22 14:34:59 +05:30
|
|
|
returnString += spacing + "---> True" + "\n"
|
2020-07-26 11:21:20 +05:30
|
|
|
returnString += regressorPrintTreeFromNode(*tree.Left, spacing+" ")
|
2020-07-18 12:26:50 +05:30
|
|
|
}
|
|
|
|
|
|
|
|
if tree.Right != nil {
|
2020-07-22 14:34:59 +05:30
|
|
|
returnString += spacing + "---> False" + "\n"
|
2020-07-26 11:21:20 +05:30
|
|
|
returnString += regressorPrintTreeFromNode(*tree.Right, spacing+" ")
|
2020-07-18 12:26:50 +05:30
|
|
|
}
|
|
|
|
|
2020-07-22 14:34:59 +05:30
|
|
|
return returnString
|
2020-07-18 12:26:50 +05:30
|
|
|
}
|
|
|
|
|
2020-07-27 15:03:12 +05:30
|
|
|
// Predict a single data point by navigating to rootNodes.
|
|
|
|
// Uses a recursive logic
|
2020-07-26 11:21:20 +05:30
|
|
|
func regressorPredictSingle(tree regressorNode, instance []float64) float64 {
|
2020-07-18 12:26:50 +05:30
|
|
|
if instance[tree.Feature] < tree.Threshold {
|
|
|
|
if tree.Left == nil {
|
|
|
|
return tree.LeftPred
|
|
|
|
} else {
|
2020-07-26 11:21:20 +05:30
|
|
|
return regressorPredictSingle(*tree.Left, instance)
|
2020-07-18 12:26:50 +05:30
|
|
|
}
|
|
|
|
} else {
|
|
|
|
if tree.Right == nil {
|
|
|
|
return tree.RightPred
|
|
|
|
} else {
|
2020-07-26 11:21:20 +05:30
|
|
|
return regressorPredictSingle(*tree.Right, instance)
|
2020-07-18 12:26:50 +05:30
|
|
|
}
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
2020-07-27 15:03:12 +05:30
|
|
|
// Predict method for multiple data points.
|
|
|
|
// First converts input data into usable format, and then calls regressorPredictFromNode
|
2020-07-22 14:34:59 +05:30
|
|
|
func (tree *CARTDecisionTreeRegressor) Predict(X_test base.FixedDataGrid) []float64 {
|
2020-07-18 12:26:50 +05:30
|
|
|
root := *tree.RootNode
|
|
|
|
test := regressorConvertInstancesToProblemVec(X_test)
|
2020-07-26 11:21:20 +05:30
|
|
|
return regressorPredictFromNode(root, test)
|
2020-07-18 12:26:50 +05:30
|
|
|
}
|
|
|
|
|
2020-07-27 15:03:12 +05:30
|
|
|
// Use tree's root node to print out entire tree.
|
|
|
|
// Iterates over all data points and calls regressorPredictSingle to predict individual datapoints.
|
2020-07-26 11:21:20 +05:30
|
|
|
func regressorPredictFromNode(tree regressorNode, test [][]float64) []float64 {
|
2020-07-18 12:26:50 +05:30
|
|
|
var preds []float64
|
|
|
|
for i := range test {
|
2020-07-26 11:21:20 +05:30
|
|
|
i_pred := regressorPredictSingle(tree, test[i])
|
2020-07-18 12:26:50 +05:30
|
|
|
preds = append(preds, i_pred)
|
|
|
|
}
|
|
|
|
return preds
|
|
|
|
}
|
|
|
|
|
2020-07-18 14:21:50 +05:30
|
|
|
// Helper function to convert base.FixedDataGrid into required format. Called in Fit, Predict
|
2020-07-18 12:26:50 +05:30
|
|
|
func regressorConvertInstancesToProblemVec(X base.FixedDataGrid) [][]float64 {
|
|
|
|
// Allocate problem array
|
|
|
|
_, rows := X.Size()
|
|
|
|
problemVec := make([][]float64, rows)
|
|
|
|
|
|
|
|
// Retrieve numeric non-class Attributes
|
|
|
|
numericAttrs := base.NonClassFloatAttributes(X)
|
|
|
|
numericAttrSpecs := base.ResolveAttributes(X, numericAttrs)
|
|
|
|
|
|
|
|
// Convert each row
|
|
|
|
X.MapOverRows(numericAttrSpecs, func(row [][]byte, rowNo int) (bool, error) {
|
|
|
|
// Allocate a new row
|
|
|
|
probRow := make([]float64, len(numericAttrSpecs))
|
|
|
|
// Read out the row
|
|
|
|
for i, _ := range numericAttrSpecs {
|
|
|
|
probRow[i] = base.UnpackBytesToFloat(row[i])
|
|
|
|
}
|
|
|
|
// Add the row
|
|
|
|
problemVec[rowNo] = probRow
|
|
|
|
return true, nil
|
|
|
|
})
|
|
|
|
return problemVec
|
|
|
|
}
|
|
|
|
|
2020-07-18 14:21:50 +05:30
|
|
|
// Helper function to convert base.FixedDataGrid into required format. Called in Fit, Predict
|
2020-07-31 11:01:20 +05:30
|
|
|
func regressorConvertInstancesToLabelVec(X base.FixedDataGrid) ([]float64, error) {
|
2020-07-18 12:26:50 +05:30
|
|
|
// Get the class Attributes
|
|
|
|
classAttrs := X.AllClassAttributes()
|
|
|
|
// Only support 1 class Attribute
|
|
|
|
if len(classAttrs) != 1 {
|
2020-07-31 11:01:20 +05:30
|
|
|
return []float64{0}, errors.New(fmt.Sprintf("%d ClassAttributes (1 expected)", len(classAttrs)))
|
2020-07-18 12:26:50 +05:30
|
|
|
}
|
|
|
|
// ClassAttribute must be numeric
|
|
|
|
if _, ok := classAttrs[0].(*base.FloatAttribute); !ok {
|
2020-07-31 11:01:20 +05:30
|
|
|
return []float64{0}, errors.New(fmt.Sprintf("%s: ClassAttribute must be a FloatAttribute", classAttrs[0]))
|
2020-07-18 12:26:50 +05:30
|
|
|
}
|
|
|
|
// Allocate return structure
|
|
|
|
_, rows := X.Size()
|
2020-07-31 11:01:20 +05:30
|
|
|
|
2020-07-18 12:26:50 +05:30
|
|
|
labelVec := make([]float64, rows)
|
|
|
|
// Resolve class Attribute specification
|
|
|
|
classAttrSpecs := base.ResolveAttributes(X, classAttrs)
|
|
|
|
X.MapOverRows(classAttrSpecs, func(row [][]byte, rowNo int) (bool, error) {
|
|
|
|
labelVec[rowNo] = base.UnpackBytesToFloat(row[0])
|
|
|
|
return true, nil
|
|
|
|
})
|
2020-07-31 11:01:20 +05:30
|
|
|
return labelVec, nil
|
2020-07-18 12:26:50 +05:30
|
|
|
}
|