mirror of
https://github.com/sjwhitworth/golearn.git
synced 2025-04-26 13:49:14 +08:00
RandomForest returns error when fitting data with fewer features than the RandomForest plans to use
- BaseClassifier Predict and Fit methods return errors
- go fmt ./...

Conflicts:
ensemble/randomforest.go
ensemble/randomforest_test.go
trees/tree_test.go
This commit is contained in:
parent
151df652ca
commit
1809a8b358
@ -10,10 +10,10 @@ type Classifier interface {
|
||||
// and constructs a new set of Instances of equivalent
|
||||
// length with only the class Attribute and fills it in
|
||||
// with predictions.
|
||||
Predict(FixedDataGrid) FixedDataGrid
|
||||
Predict(FixedDataGrid) (FixedDataGrid, error)
|
||||
// Takes a set of instances and updates the Classifier's
|
||||
// internal structures to enable prediction
|
||||
Fit(FixedDataGrid)
|
||||
Fit(FixedDataGrid) error
|
||||
// Why not make every classifier return a nice-looking string?
|
||||
String() string
|
||||
}
|
||||
|
@ -1,6 +1,7 @@
|
||||
package ensemble
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"github.com/sjwhitworth/golearn/base"
|
||||
"github.com/sjwhitworth/golearn/meta"
|
||||
@ -30,10 +31,10 @@ func NewRandomForest(forestSize int, features int) *RandomForest {
|
||||
}
|
||||
|
||||
// Fit builds the RandomForest on the specified instances
|
||||
func (f *RandomForest) Fit(on base.FixedDataGrid) {
|
||||
func (f *RandomForest) Fit(on base.FixedDataGrid) error {
|
||||
numNonClassAttributes := len(base.NonClassAttributes(on))
|
||||
if numNonClassAttributes < f.Features {
|
||||
panic(fmt.Sprintf(
|
||||
return errors.New(fmt.Sprintf(
|
||||
"Random forest with %d features cannot fit data grid with %d non-class attributes",
|
||||
f.Features,
|
||||
numNonClassAttributes,
|
||||
@ -47,11 +48,12 @@ func (f *RandomForest) Fit(on base.FixedDataGrid) {
|
||||
f.Model.AddModel(tree)
|
||||
}
|
||||
f.Model.Fit(on)
|
||||
return nil
|
||||
}
|
||||
|
||||
// Predict generates predictions from a trained RandomForest.
|
||||
func (f *RandomForest) Predict(with base.FixedDataGrid) base.FixedDataGrid {
|
||||
return f.Model.Predict(with)
|
||||
func (f *RandomForest) Predict(with base.FixedDataGrid) (base.FixedDataGrid, error) {
|
||||
return f.Model.Predict(with), nil
|
||||
}
|
||||
|
||||
// String returns a human-readable representation of this tree.
|
||||
|
@ -23,11 +23,31 @@ func TestRandomForest1(t *testing.T) {
|
||||
trainData, testData := base.InstancesTrainTestSplit(instf, 0.60)
|
||||
|
||||
rf := NewRandomForest(10, 3)
|
||||
rf.Fit(trainData)
|
||||
predictions := rf.Predict(testData)
|
||||
err = rf.Fit(trainData)
|
||||
if err != nil {
|
||||
t.Fatalf("Fitting failed: %s", err.Error())
|
||||
}
|
||||
predictions, err := rf.Predict(testData)
|
||||
if err != nil {
|
||||
t.Fatalf("Predicting failed: %s", err.Error())
|
||||
}
|
||||
|
||||
confusionMat, err := evaluation.GetConfusionMatrix(testData, predictions)
|
||||
if err != nil {
|
||||
t.Fatalf("Unable to get confusion matrix: %s", err.Error())
|
||||
}
|
||||
_ = evaluation.GetSummary(confusionMat)
|
||||
}
|
||||
|
||||
func TestRandomForestFitErrorWithNotEnoughFeatures(t *testing.T) {
|
||||
inst, err := base.ParseCSVToInstances("../examples/datasets/iris_headers.csv", true)
|
||||
if err != nil {
|
||||
t.Fatal("Unable to parse CSV to instances: %s", err.Error())
|
||||
}
|
||||
|
||||
rf := NewRandomForest(10, len(base.NonClassAttributes(inst))+1)
|
||||
err = rf.Fit(inst)
|
||||
if err == nil {
|
||||
t.Fatalf("Fitting failed: %s", err.Error())
|
||||
}
|
||||
}
|
||||
|
@ -43,10 +43,16 @@ func main() {
|
||||
// (Parameter controls train-prune split.)
|
||||
|
||||
// Train the ID3 tree
|
||||
tree.Fit(trainData)
|
||||
err = tree.Fit(trainData)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
// Generate predictions
|
||||
predictions := tree.Predict(testData)
|
||||
predictions, err := tree.Predict(testData)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
// Evaluate
|
||||
fmt.Println("ID3 Performance")
|
||||
@ -62,8 +68,14 @@ func main() {
|
||||
|
||||
// Consider two randomly-chosen attributes
|
||||
tree = trees.NewRandomTree(2)
|
||||
tree.Fit(testData)
|
||||
predictions = tree.Predict(testData)
|
||||
err = tree.Fit(testData)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
predictions, err = tree.Predict(testData)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
fmt.Println("RandomTree Performance")
|
||||
cf, err = evaluation.GetConfusionMatrix(testData, predictions)
|
||||
if err != nil {
|
||||
@ -75,8 +87,14 @@ func main() {
|
||||
// Finally, Random Forests
|
||||
//
|
||||
tree = ensemble.NewRandomForest(100, 3)
|
||||
tree.Fit(trainData)
|
||||
predictions = tree.Predict(testData)
|
||||
err = tree.Fit(trainData)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
predictions, err = tree.Predict(testData)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
fmt.Println("RandomForest Performance")
|
||||
cf, err = evaluation.GetConfusionMatrix(testData, predictions)
|
||||
if err != nil {
|
||||
|
@ -140,7 +140,8 @@ func (b *BaggedModel) Predict(from base.FixedDataGrid) base.FixedDataGrid {
|
||||
if i, ok := <-processpipe; ok {
|
||||
c := b.Models[i]
|
||||
l := b.generatePredictionInstances(i, from)
|
||||
votes <- c.Predict(l)
|
||||
v, _ := c.Predict(l)
|
||||
votes <- v
|
||||
} else {
|
||||
processwait.Done()
|
||||
break
|
||||
|
16
trees/id3.go
16
trees/id3.go
@ -181,12 +181,15 @@ func (d *DecisionTreeNode) Prune(using base.FixedDataGrid) {
|
||||
}
|
||||
|
||||
// Get a baseline accuracy
|
||||
baselineAccuracy := computeAccuracy(d.Predict(using), using)
|
||||
predictions, _ := d.Predict(using)
|
||||
baselineAccuracy := computeAccuracy(predictions, using)
|
||||
|
||||
// Speculatively remove the children and re-evaluate
|
||||
tmpChildren := d.Children
|
||||
d.Children = nil
|
||||
newAccuracy := computeAccuracy(d.Predict(using), using)
|
||||
|
||||
predictions, _ = d.Predict(using)
|
||||
newAccuracy := computeAccuracy(predictions, using)
|
||||
|
||||
// Keep the children removed if better, else restore
|
||||
if newAccuracy < baselineAccuracy {
|
||||
@ -195,7 +198,7 @@ func (d *DecisionTreeNode) Prune(using base.FixedDataGrid) {
|
||||
}
|
||||
|
||||
// Predict outputs a base.Instances containing predictions from this tree
|
||||
func (d *DecisionTreeNode) Predict(what base.FixedDataGrid) base.FixedDataGrid {
|
||||
func (d *DecisionTreeNode) Predict(what base.FixedDataGrid) (base.FixedDataGrid, error) {
|
||||
predictions := base.GeneratePredictionVector(what)
|
||||
classAttr := getClassAttr(predictions)
|
||||
classAttrSpec, err := predictions.GetAttribute(classAttr)
|
||||
@ -235,7 +238,7 @@ func (d *DecisionTreeNode) Predict(what base.FixedDataGrid) base.FixedDataGrid {
|
||||
}
|
||||
return true, nil
|
||||
})
|
||||
return predictions
|
||||
return predictions, nil
|
||||
}
|
||||
|
||||
//
|
||||
@ -262,7 +265,7 @@ func NewID3DecisionTree(prune float64) *ID3DecisionTree {
|
||||
}
|
||||
|
||||
// Fit builds the ID3 decision tree
|
||||
func (t *ID3DecisionTree) Fit(on base.FixedDataGrid) {
|
||||
func (t *ID3DecisionTree) Fit(on base.FixedDataGrid) error {
|
||||
rule := new(InformationGainRuleGenerator)
|
||||
if t.PruneSplit > 0.001 {
|
||||
trainData, testData := base.InstancesTrainTestSplit(on, t.PruneSplit)
|
||||
@ -271,10 +274,11 @@ func (t *ID3DecisionTree) Fit(on base.FixedDataGrid) {
|
||||
} else {
|
||||
t.Root = InferID3Tree(on, rule)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Predict outputs predictions from the ID3 decision tree
|
||||
func (t *ID3DecisionTree) Predict(what base.FixedDataGrid) base.FixedDataGrid {
|
||||
func (t *ID3DecisionTree) Predict(what base.FixedDataGrid) (base.FixedDataGrid, error) {
|
||||
return t.Root.Predict(what)
|
||||
}
|
||||
|
||||
|
@ -67,12 +67,13 @@ func NewRandomTree(attrs int) *RandomTree {
|
||||
}
|
||||
|
||||
// Fit builds a RandomTree suitable for prediction
|
||||
func (rt *RandomTree) Fit(from base.FixedDataGrid) {
|
||||
func (rt *RandomTree) Fit(from base.FixedDataGrid) error {
|
||||
rt.Root = InferID3Tree(from, rt.Rule)
|
||||
return nil
|
||||
}
|
||||
|
||||
// Predict returns a set of Instances containing predictions
|
||||
func (rt *RandomTree) Predict(from base.FixedDataGrid) base.FixedDataGrid {
|
||||
func (rt *RandomTree) Predict(from base.FixedDataGrid) (base.FixedDataGrid, error) {
|
||||
return rt.Root.Predict(from)
|
||||
}
|
||||
|
||||
|
@ -46,8 +46,11 @@ func TestRandomTreeClassification(t *testing.T) {
|
||||
r.Attributes = 2
|
||||
|
||||
root := InferID3Tree(trainDataF, r)
|
||||
predictions, err := root.Predict(testDataF)
|
||||
if err != nil {
|
||||
t.Fatalf("Predicting failed: %s", err.Error())
|
||||
}
|
||||
|
||||
predictions := root.Predict(testDataF)
|
||||
confusionMat, err := evaluation.GetConfusionMatrix(testDataF, predictions)
|
||||
if err != nil {
|
||||
t.Fatalf("Unable to get confusion matrix: %s", err.Error())
|
||||
@ -71,9 +74,16 @@ func TestRandomTreeClassification2(t *testing.T) {
|
||||
testDataF := base.NewLazilyFilteredInstances(testData, filt)
|
||||
|
||||
root := NewRandomTree(2)
|
||||
root.Fit(trainDataF)
|
||||
err = root.Fit(trainDataF)
|
||||
if err != nil {
|
||||
t.Fatalf("Fitting failed: %s", err.Error())
|
||||
}
|
||||
|
||||
predictions, err := root.Predict(testDataF)
|
||||
if err != nil {
|
||||
t.Fatalf("Predicting failed: %s", err.Error())
|
||||
}
|
||||
|
||||
predictions := root.Predict(testDataF)
|
||||
confusionMat, err := evaluation.GetConfusionMatrix(testDataF, predictions)
|
||||
if err != nil {
|
||||
t.Fatalf("Unable to get confusion matrix: %s", err.Error())
|
||||
@ -98,10 +108,18 @@ func TestPruning(t *testing.T) {
|
||||
|
||||
root := NewRandomTree(2)
|
||||
fittrainData, fittestData := base.InstancesTrainTestSplit(trainDataF, 0.6)
|
||||
root.Fit(fittrainData)
|
||||
root.Prune(fittestData)
|
||||
|
||||
predictions := root.Predict(testDataF)
|
||||
err = root.Fit(fittrainData)
|
||||
if err != nil {
|
||||
t.Fatalf("Fitting failed: %s", err.Error())
|
||||
}
|
||||
|
||||
root.Prune(fittestData)
|
||||
predictions, err := root.Predict(testDataF)
|
||||
if err != nil {
|
||||
t.Fatalf("Predicting failed: %s", err.Error())
|
||||
}
|
||||
|
||||
confusionMat, err := evaluation.GetConfusionMatrix(testDataF, predictions)
|
||||
if err != nil {
|
||||
t.Fatalf("Unable to get confusion matrix: %s", err.Error())
|
||||
@ -195,7 +213,11 @@ func TestID3Classification(t *testing.T) {
|
||||
rule := new(InformationGainRuleGenerator)
|
||||
root := InferID3Tree(trainData, rule)
|
||||
|
||||
predictions := root.Predict(testData)
|
||||
predictions, err := root.Predict(testData)
|
||||
if err != nil {
|
||||
t.Fatalf("Predicting failed: %s", err.Error())
|
||||
}
|
||||
|
||||
confusionMat, err := evaluation.GetConfusionMatrix(testData, predictions)
|
||||
if err != nil {
|
||||
t.Fatalf("Unable to get confusion matrix: %s", err.Error())
|
||||
|
Loading…
x
Reference in New Issue
Block a user