Mirror of https://github.com/sjwhitworth/golearn.git — synced 2025-04-26 13:49:14 +08:00

RandomForest returns an error when fitting data with fewer features than it plans to use

- BaseClassifier Predict and Fit methods return errors
- go fmt ./...

Conflicts:
	ensemble/randomforest.go
	ensemble/randomforest_test.go
	trees/tree_test.go
This commit is contained in:
Amit Kumar Gupta 2014-08-20 07:16:11 +00:00
parent 151df652ca
commit 1809a8b358
8 changed files with 98 additions and 30 deletions

View File

@ -10,10 +10,10 @@ type Classifier interface {
// and constructs a new set of Instances of equivalent
// length with only the class Attribute and fills it in
// with predictions.
Predict(FixedDataGrid) FixedDataGrid
Predict(FixedDataGrid) (FixedDataGrid, error)
// Takes a set of instances and updates the Classifier's
// internal structures to enable prediction
Fit(FixedDataGrid)
Fit(FixedDataGrid) error
// Why not make every classifier return a nice-looking string?
String() string
}

View File

@ -1,6 +1,7 @@
package ensemble
import (
"errors"
"fmt"
"github.com/sjwhitworth/golearn/base"
"github.com/sjwhitworth/golearn/meta"
@ -30,10 +31,10 @@ func NewRandomForest(forestSize int, features int) *RandomForest {
}
// Fit builds the RandomForest on the specified instances
func (f *RandomForest) Fit(on base.FixedDataGrid) {
func (f *RandomForest) Fit(on base.FixedDataGrid) error {
numNonClassAttributes := len(base.NonClassAttributes(on))
if numNonClassAttributes < f.Features {
panic(fmt.Sprintf(
return errors.New(fmt.Sprintf(
"Random forest with %d features cannot fit data grid with %d non-class attributes",
f.Features,
numNonClassAttributes,
@ -47,11 +48,12 @@ func (f *RandomForest) Fit(on base.FixedDataGrid) {
f.Model.AddModel(tree)
}
f.Model.Fit(on)
return nil
}
// Predict generates predictions from a trained RandomForest.
func (f *RandomForest) Predict(with base.FixedDataGrid) base.FixedDataGrid {
return f.Model.Predict(with)
// Predict produces a prediction grid for the supplied instances by
// delegating to the underlying bagged model. The error result is part of
// the base.Classifier contract; this implementation always returns nil.
func (f *RandomForest) Predict(with base.FixedDataGrid) (base.FixedDataGrid, error) {
	predictions := f.Model.Predict(with)
	return predictions, nil
}
// String returns a human-readable representation of this tree.

View File

@ -23,11 +23,31 @@ func TestRandomForest1(t *testing.T) {
trainData, testData := base.InstancesTrainTestSplit(instf, 0.60)
rf := NewRandomForest(10, 3)
rf.Fit(trainData)
predictions := rf.Predict(testData)
err = rf.Fit(trainData)
if err != nil {
t.Fatalf("Fitting failed: %s", err.Error())
}
predictions, err := rf.Predict(testData)
if err != nil {
t.Fatalf("Predicting failed: %s", err.Error())
}
confusionMat, err := evaluation.GetConfusionMatrix(testData, predictions)
if err != nil {
t.Fatalf("Unable to get confusion matrix: %s", err.Error())
}
_ = evaluation.GetSummary(confusionMat)
}
// TestRandomForestFitErrorWithNotEnoughFeatures verifies that Fit returns an
// error (rather than panicking) when the forest is configured to sample more
// features than the data grid actually provides.
func TestRandomForestFitErrorWithNotEnoughFeatures(t *testing.T) {
	inst, err := base.ParseCSVToInstances("../examples/datasets/iris_headers.csv", true)
	if err != nil {
		// Fatalf, not Fatal: the message carries a %s format verb.
		t.Fatalf("Unable to parse CSV to instances: %s", err.Error())
	}
	// Request one more feature than the data set has non-class attributes,
	// which must make Fit fail.
	rf := NewRandomForest(10, len(base.NonClassAttributes(inst))+1)
	err = rf.Fit(inst)
	if err == nil {
		// Do not touch err here — it is nil on this path; the original code
		// called err.Error() and would have panicked instead of reporting.
		t.Fatal("Expected Fit to fail when the forest wants more features than the data provides")
	}
}

View File

@ -43,10 +43,16 @@ func main() {
// (Parameter controls train-prune split.)
// Train the ID3 tree
tree.Fit(trainData)
err = tree.Fit(trainData)
if err != nil {
panic(err)
}
// Generate predictions
predictions := tree.Predict(testData)
predictions, err := tree.Predict(testData)
if err != nil {
panic(err)
}
// Evaluate
fmt.Println("ID3 Performance")
@ -62,8 +68,14 @@ func main() {
// Consider two randomly-chosen attributes
tree = trees.NewRandomTree(2)
tree.Fit(testData)
predictions = tree.Predict(testData)
err = tree.Fit(testData)
if err != nil {
panic(err)
}
predictions, err = tree.Predict(testData)
if err != nil {
panic(err)
}
fmt.Println("RandomTree Performance")
cf, err = evaluation.GetConfusionMatrix(testData, predictions)
if err != nil {
@ -75,8 +87,14 @@ func main() {
// Finally, Random Forests
//
tree = ensemble.NewRandomForest(100, 3)
tree.Fit(trainData)
predictions = tree.Predict(testData)
err = tree.Fit(trainData)
if err != nil {
panic(err)
}
predictions, err = tree.Predict(testData)
if err != nil {
panic(err)
}
fmt.Println("RandomForest Performance")
cf, err = evaluation.GetConfusionMatrix(testData, predictions)
if err != nil {

View File

@ -140,7 +140,8 @@ func (b *BaggedModel) Predict(from base.FixedDataGrid) base.FixedDataGrid {
if i, ok := <-processpipe; ok {
c := b.Models[i]
l := b.generatePredictionInstances(i, from)
votes <- c.Predict(l)
v, _ := c.Predict(l)
votes <- v
} else {
processwait.Done()
break

View File

@ -181,12 +181,15 @@ func (d *DecisionTreeNode) Prune(using base.FixedDataGrid) {
}
// Get a baseline accuracy
baselineAccuracy := computeAccuracy(d.Predict(using), using)
predictions, _ := d.Predict(using)
baselineAccuracy := computeAccuracy(predictions, using)
// Speculatively remove the children and re-evaluate
tmpChildren := d.Children
d.Children = nil
newAccuracy := computeAccuracy(d.Predict(using), using)
predictions, _ = d.Predict(using)
newAccuracy := computeAccuracy(predictions, using)
// Keep the children removed if better, else restore
if newAccuracy < baselineAccuracy {
@ -195,7 +198,7 @@ func (d *DecisionTreeNode) Prune(using base.FixedDataGrid) {
}
// Predict outputs a base.Instances containing predictions from this tree
func (d *DecisionTreeNode) Predict(what base.FixedDataGrid) base.FixedDataGrid {
func (d *DecisionTreeNode) Predict(what base.FixedDataGrid) (base.FixedDataGrid, error) {
predictions := base.GeneratePredictionVector(what)
classAttr := getClassAttr(predictions)
classAttrSpec, err := predictions.GetAttribute(classAttr)
@ -235,7 +238,7 @@ func (d *DecisionTreeNode) Predict(what base.FixedDataGrid) base.FixedDataGrid {
}
return true, nil
})
return predictions
return predictions, nil
}
//
@ -262,7 +265,7 @@ func NewID3DecisionTree(prune float64) *ID3DecisionTree {
}
// Fit builds the ID3 decision tree
func (t *ID3DecisionTree) Fit(on base.FixedDataGrid) {
func (t *ID3DecisionTree) Fit(on base.FixedDataGrid) error {
rule := new(InformationGainRuleGenerator)
if t.PruneSplit > 0.001 {
trainData, testData := base.InstancesTrainTestSplit(on, t.PruneSplit)
@ -271,10 +274,11 @@ func (t *ID3DecisionTree) Fit(on base.FixedDataGrid) {
} else {
t.Root = InferID3Tree(on, rule)
}
return nil
}
// Predict outputs predictions from the ID3 decision tree
func (t *ID3DecisionTree) Predict(what base.FixedDataGrid) base.FixedDataGrid {
// Predict forwards the instances to the root node built by Fit and returns
// whatever predictions (and error) the node produces.
func (t *ID3DecisionTree) Predict(what base.FixedDataGrid) (base.FixedDataGrid, error) {
	predictions, err := t.Root.Predict(what)
	return predictions, err
}

View File

@ -67,12 +67,13 @@ func NewRandomTree(attrs int) *RandomTree {
}
// Fit builds a RandomTree suitable for prediction
func (rt *RandomTree) Fit(from base.FixedDataGrid) {
// Fit trains the RandomTree by inferring an ID3 tree over the supplied
// instances with this tree's rule generator. It never fails; the error
// result exists to satisfy the base.Classifier contract.
func (rt *RandomTree) Fit(from base.FixedDataGrid) error {
	root := InferID3Tree(from, rt.Rule)
	rt.Root = root
	return nil
}
// Predict returns a set of Instances containing predictions
func (rt *RandomTree) Predict(from base.FixedDataGrid) base.FixedDataGrid {
// Predict delegates to the root node produced by Fit and passes its
// predictions and error straight through to the caller.
func (rt *RandomTree) Predict(from base.FixedDataGrid) (base.FixedDataGrid, error) {
	result, err := rt.Root.Predict(from)
	return result, err
}

View File

@ -46,8 +46,11 @@ func TestRandomTreeClassification(t *testing.T) {
r.Attributes = 2
root := InferID3Tree(trainDataF, r)
predictions, err := root.Predict(testDataF)
if err != nil {
t.Fatalf("Predicting failed: %s", err.Error())
}
predictions := root.Predict(testDataF)
confusionMat, err := evaluation.GetConfusionMatrix(testDataF, predictions)
if err != nil {
t.Fatalf("Unable to get confusion matrix: %s", err.Error())
@ -71,9 +74,16 @@ func TestRandomTreeClassification2(t *testing.T) {
testDataF := base.NewLazilyFilteredInstances(testData, filt)
root := NewRandomTree(2)
root.Fit(trainDataF)
err = root.Fit(trainDataF)
if err != nil {
t.Fatalf("Fitting failed: %s", err.Error())
}
predictions, err := root.Predict(testDataF)
if err != nil {
t.Fatalf("Predicting failed: %s", err.Error())
}
predictions := root.Predict(testDataF)
confusionMat, err := evaluation.GetConfusionMatrix(testDataF, predictions)
if err != nil {
t.Fatalf("Unable to get confusion matrix: %s", err.Error())
@ -98,10 +108,18 @@ func TestPruning(t *testing.T) {
root := NewRandomTree(2)
fittrainData, fittestData := base.InstancesTrainTestSplit(trainDataF, 0.6)
root.Fit(fittrainData)
root.Prune(fittestData)
predictions := root.Predict(testDataF)
err = root.Fit(fittrainData)
if err != nil {
t.Fatalf("Fitting failed: %s", err.Error())
}
root.Prune(fittestData)
predictions, err := root.Predict(testDataF)
if err != nil {
t.Fatalf("Predicting failed: %s", err.Error())
}
confusionMat, err := evaluation.GetConfusionMatrix(testDataF, predictions)
if err != nil {
t.Fatalf("Unable to get confusion matrix: %s", err.Error())
@ -195,7 +213,11 @@ func TestID3Classification(t *testing.T) {
rule := new(InformationGainRuleGenerator)
root := InferID3Tree(trainData, rule)
predictions := root.Predict(testData)
predictions, err := root.Predict(testData)
if err != nil {
t.Fatalf("Predicting failed: %s", err.Error())
}
confusionMat, err := evaluation.GetConfusionMatrix(testData, predictions)
if err != nil {
t.Fatalf("Unable to get confusion matrix: %s", err.Error())