1
0
mirror of https://github.com/sjwhitworth/golearn.git synced 2025-04-26 13:49:14 +08:00

ensemble: tests pass

This commit is contained in:
Richard Townsend 2017-09-10 19:30:02 +01:00
parent 3e80230d3d
commit e27215052b
7 changed files with 98 additions and 37 deletions

View File

@ -60,6 +60,12 @@ func copyFixedDataGridStructure(of FixedDataGrid) (*DenseInstances, []AttributeS
// Add and store new AttributeSpec
specs2[i] = ret.AddAttribute(a)
}
// Add class attributes
cAttrs := of.AllClassAttributes()
for _, a := range cAttrs {
ret.AddClassAttribute(a)
}
return ret, specs1, specs2
}

View File

@ -49,7 +49,11 @@ func (f *FunctionalTarReader) GetNamedFile (name string) ([]byte, error) {
ret := make([]byte, hdr.Size)
n, err := tr.Read(ret)
if int64(n) != hdr.Size {
return nil, WrapError(fmt.Errorf("Size mismatch, expected %d byte(s), got %d", n, hdr.Size))
if int64(n) < hdr.Size {
log.Printf("Size mismatch, expected %d byte(s) for %s, got %d", n, hdr.Name, hdr.Size)
} else {
return nil, WrapError(fmt.Errorf("Size mismatch, expected %d byte(s) for %s, got %d", n, hdr.Name, hdr.Size))
}
}
if err != nil {
return nil, err

View File

@ -4,10 +4,8 @@ import (
"github.com/sjwhitworth/golearn/base"
"github.com/sjwhitworth/golearn/linear_models"
"github.com/sjwhitworth/golearn/meta"
"io"
"os"
"fmt"
"encoding/json"
)
// MultiLinearSVC implements a multi-class Support Vector Classifier using a one-vs-all
@ -32,8 +30,6 @@ func NewMultiLinearSVC(loss, penalty string, dual bool, C float64, eps float64,
panic(err)
}
// Return me...
ret := &MultiLinearSVC{
parameters: params,
@ -91,7 +87,7 @@ func (m *MultiLinearSVC) GetClassifierMetadata() base.ClassifierMetadataV1 {
FormatVersion: 1,
ClassifierName: "MultiLinearSVC",
ClassifierVersion: "1",
ClassifierMetadata: nil
ClassifierMetadata: nil,
}
}
@ -101,7 +97,7 @@ func (m *MultiLinearSVC) Save(filePath string) error {
if err != nil {
return err
}
err := m.SaveWithPrefix(serializer, "")
err = m.SaveWithPrefix(serializer, "")
if err != nil {
return fmt.Errorf("Unable to Save(): %v", err)
}
@ -168,12 +164,13 @@ func (m *MultiLinearSVC) LoadWithPrefix(reader *base.ClassifierDeserializer, pre
return fmt.Errorf("Can't load parameters: %v", err)
}
m.initializeOneVsAllModel()
// Load the model
err = m.m.LoadWithPrefix(reader, p("one-vs-all"))
if err != nil {
return err
}
m.initializeOneVsAllModel()
return nil
}

View File

@ -31,7 +31,7 @@ func TestMultiSVMUnweighted(t *testing.T) {
So(err, ShouldBeNil)
Convey("Loading should work...", func() {
mLoaded := NewMultiLinearSVC("l2", "l1", true, 1.00, 1e-8, weights)
mLoaded := NewMultiLinearSVC("l1", "l2", true, 1.00, 1e-8, nil)
err := mLoaded.Load(f.Name())
So(err, ShouldBeNil)
@ -40,7 +40,7 @@ func TestMultiSVMUnweighted(t *testing.T) {
So(err, ShouldBeNil)
newPredictions, err := mLoaded.Predict(Y)
So(err, ShouldBeNil)
So(originalPredictions, ShouldEqual, newPredictions)
So(base.InstancesAreEqual(originalPredictions, newPredictions), ShouldBeTrue)
})
})
@ -68,28 +68,29 @@ func TestMultiSVMWeighted(t *testing.T) {
predictions, err := m.Predict(Y)
cf, err := evaluation.GetConfusionMatrix(Y, predictions)
So(err, ShouldEqual, nil)
So(evaluation.GetAccuracy(cf), ShouldBeGreaterThan, 0.70)
})
So(evaluation.GetAccuracy(cf), ShouldBeGreaterThan, 0.60)
Convey("Saving should work...", func() {
f, err := ioutil.TempFile("","tree")
So(err, ShouldBeNil)
err = m.Save(f.Name())
So(err, ShouldBeNil)
Convey("Loading should work...", func() {
mLoaded := NewMultiLinearSVC("l2", "l1", true, 1.00, 1e-8, weights)
err := mLoaded.Load(f.Name())
Convey("Saving should work...", func() {
f, err := ioutil.TempFile("", "tree")
So(err, ShouldBeNil)
err = m.Save(f.Name())
So(err, ShouldBeNil)
Convey("Predictions should be the same...", func() {
originalPredictions, err := m.Predict(Y)
Convey("Loading should work...", func() {
mLoaded := NewMultiLinearSVC("l1", "l2", true, 1.00, 1e-8, weights)
err := mLoaded.Load(f.Name())
So(err, ShouldBeNil)
newPredictions, err := mLoaded.Predict(Y)
So(err, ShouldBeNil)
So(originalPredictions, ShouldEqual, newPredictions)
})
Convey("Predictions should be the same...", func() {
originalPredictions, err := m.Predict(Y)
So(err, ShouldBeNil)
newPredictions, err := mLoaded.Predict(Y)
So(err, ShouldBeNil)
So(base.InstancesAreEqual(originalPredictions, newPredictions), ShouldBeTrue)
})
})
})
})

View File

@ -53,7 +53,7 @@ func (f *RandomForest) Fit(on base.FixedDataGrid) error {
// Predict generates predictions from a trained RandomForest.
func (f *RandomForest) Predict(with base.FixedDataGrid) (base.FixedDataGrid, error) {
return f.Model.Predict(with), nil
return f.Model.Predict(with)
}
// String returns a human-readable representation of this tree.

View File

@ -3,6 +3,7 @@ package meta
import (
"fmt"
"github.com/sjwhitworth/golearn/base"
"log"
)
// OneVsAllModel replaces class Attributes with numeric versions
@ -61,6 +62,10 @@ func (m *OneVsAllModel) Fit(using base.FixedDataGrid) {
}
m.maxClassVal = val
// If we're reloading, we may just be fitting to the structure
_, srcRows := using.Size()
fittingToStructure := srcRows == 0
// Create individual filtered instances for training
filters := make([]*oneVsAllFilter, val+1)
classifiers := make([]base.Classifier, val+1)
@ -72,7 +77,9 @@ func (m *OneVsAllModel) Fit(using base.FixedDataGrid) {
}
filters[i] = f
classifiers[i] = m.NewClassifierFunction(classVals[int(i)])
classifiers[i].Fit(base.NewLazilyFilteredInstances(using, f))
if !fittingToStructure {
classifiers[i].Fit(base.NewLazilyFilteredInstances(using, f))
}
}
m.filters = filters
@ -90,7 +97,13 @@ func (m *OneVsAllModel) Predict(what base.FixedDataGrid) (base.FixedDataGrid, er
ret := base.GeneratePredictionVector(what)
vecs := make([]base.FixedDataGrid, m.maxClassVal+1)
specs := make([]base.AttributeSpec, m.maxClassVal+1)
if int(m.maxClassVal) > len(m.filters) || (m.maxClassVal == 0 && len(m.filters) == 0) {
return nil, base.WrapError(fmt.Errorf("Internal error: m.Filter len = %d, maxClassVal = %d", len(m.filters), m.maxClassVal))
}
for i := uint64(0); i <= m.maxClassVal; i++ {
//log.Printf("i = %d, m.Filter len = %d, maxClassVal = %d", i, len(m.filters), m.maxClassVal)
f := m.filters[i]
c := base.NewLazilyFilteredInstances(what, f)
p, err := m.classifiers[i].Predict(c)
@ -139,7 +152,10 @@ func (m *OneVsAllModel) LoadWithPrefix(reader *base.ClassifierDeserializer, pref
if err != nil {
return base.DescribeError("Can't load INSTANCE_STRUCTURE", err)
}
m.fitOn = fitOn
m.Fit(fitOn)
/*if err != nil {
base.DescribeError("Could not fit reloaded classifier to the structure", err)
}*/
// Reload the filters
numFiltersU64, err := reader.GetU64ForKey(reader.Prefix(prefix, "FILTER_COUNT"))
@ -151,7 +167,7 @@ func (m *OneVsAllModel) LoadWithPrefix(reader *base.ClassifierDeserializer, pref
for i := 0; i < numFilters; i++ {
f := oneVsAllFilter{}
mapPrefix := pI(reader.Prefix(prefix, "FILTER"), i)
mapPrefix := pI("FILTER", i)
mapCountKey := reader.Prefix(mapPrefix, "COUNT")
numAttrsInMapU64, err := reader.GetU64ForKey(mapCountKey)
if err != nil {
@ -161,7 +177,7 @@ func (m *OneVsAllModel) LoadWithPrefix(reader *base.ClassifierDeserializer, pref
attrMap := make(map[base.Attribute]base.Attribute)
for j := 0; j < int(numAttrsInMapU64); j++ {
mapTupleKey := pI(mapPrefix, j)
mapTupleKey := reader.Prefix(mapPrefix, fmt.Sprintf("%d"))
mapKeyKeyKey := reader.Prefix(mapTupleKey, "KEY")
mapKeyValKey := reader.Prefix(mapTupleKey, "VAL")
@ -213,7 +229,7 @@ func (m *OneVsAllModel) LoadWithPrefix(reader *base.ClassifierDeserializer, pref
m.classifiers = make([]base.Classifier, 0)
for i, c := range classVals {
cls := m.NewClassifierFunction(c)
clsPrefix := pI(reader.Prefix(prefix, "CLASSIFIERS"), i)
clsPrefix := pI("CLASSIFIERS", i)
err = cls.LoadWithPrefix(reader, clsPrefix)
if err != nil {
@ -248,6 +264,7 @@ func (m *OneVsAllModel) SaveWithPrefix(writer *base.ClassifierSerializer, prefix
return writer.Prefix(prefix, writer.Prefix(n, fmt.Sprintf("%d", i)))
}
// Save the instances
err := writer.WriteInstancesForKey(writer.Prefix(prefix, "INSTANCE_STRUCTURE"), m.fitOn, false)
if err != nil {
@ -266,7 +283,7 @@ func (m *OneVsAllModel) SaveWithPrefix(writer *base.ClassifierSerializer, prefix
return base.DescribeError("Unable to write FILTER_COUNT", err)
}
for i, f := range m.filters {
mapPrefix := pI(writer.Prefix(prefix, "FILTER"), i)
mapPrefix := pI("FILTER", i)
mapCountKey := writer.Prefix(mapPrefix, "COUNT")
err := writer.WriteU64ForKey(mapCountKey, uint64(len(f.attrs)))
if err != nil {
@ -274,7 +291,7 @@ func (m *OneVsAllModel) SaveWithPrefix(writer *base.ClassifierSerializer, prefix
}
j := 0
for key := range f.attrs {
mapTupleKey := pI(mapPrefix, j)
mapTupleKey := writer.Prefix(mapPrefix, fmt.Sprintf("%d"))
mapKeyKeyKey := writer.Prefix(mapTupleKey, "KEY")
mapKeyValKey := writer.Prefix(mapTupleKey, "VAL")
@ -302,7 +319,7 @@ func (m *OneVsAllModel) SaveWithPrefix(writer *base.ClassifierSerializer, prefix
// Save the classifiers
for i, c := range m.classifiers {
clsPrefix := pI(writer.Prefix(prefix, "CLASSIFIERS"), i)
clsPrefix := pI("CLASSIFIERS", i)
err = c.SaveWithPrefix(writer, clsPrefix)
if err != nil {
return base.FormatError(err, "Can't save classifier for class %s", m.classValues[i])
@ -316,7 +333,7 @@ func (m *OneVsAllModel) generateAttributes(from base.FixedDataGrid) map[base.Att
attrs := from.AllAttributes()
classAttrs := from.AllClassAttributes()
if len(classAttrs) != 1 {
panic("Only 1 class Attribute is supported!")
panic(fmt.Errorf("Only 1 class Attribute is supported, had %d", len(classAttrs)))
}
ret := make(map[base.Attribute]base.Attribute)
for _, a := range attrs {

View File

@ -654,3 +654,39 @@ func (t *ID3DecisionTree) Predict(what base.FixedDataGrid) (base.FixedDataGrid,
// String returns a human-readable representation of this tree: the Root
// node formatted with %s inside an "ID3DecisionTree(...)" wrapper.
func (t *ID3DecisionTree) String() string {
	const layout = "ID3DecisionTree(%s\n)"
	return fmt.Sprintf(layout, t.Root)
}
// GetMetadata returns the metadata header describing this classifier:
// format version 1, classifier version "1.0", and no additional
// classifier-specific metadata.
//
// The classifier name is the type's own name. It previously read "KNN",
// which mislabelled serialized decision trees as a different model.
func (t *ID3DecisionTree) GetMetadata() base.ClassifierMetadataV1 {
	return base.ClassifierMetadataV1{
		FormatVersion:      1,
		ClassifierName:     "ID3DecisionTree",
		ClassifierVersion:  "1.0",
		ClassifierMetadata: nil,
	}
}
// Save serializes this tree to the file at filePath.
//
// It creates a serialized-classifier stub carrying the tree's metadata and
// then writes the tree under an empty key prefix. An error from creating the
// stub or from writing the tree is returned unchanged.
//
// NOTE(review): the writer is never explicitly closed or flushed here —
// confirm whether the serializer needs finalizing after SaveWithPrefix.
func (t *ID3DecisionTree) Save(filePath string) error {
	writer, err := base.CreateSerializedClassifierStub(filePath, t.GetMetadata())
	if err != nil {
		return err
	}
	// No debug printing of the writer: library code should not write to
	// stdout as a side effect of saving a model.
	return t.SaveWithPrefix(writer, "")
}
// SaveWithPrefix writes this tree through writer by delegating to the Root
// node, forwarding prefix unchanged. The Root node's error, if any, is
// returned as-is.
func (t *ID3DecisionTree) SaveWithPrefix(writer *base.ClassifierSerializer, prefix string) error {
	if err := t.Root.SaveWithPrefix(writer, prefix); err != nil {
		return err
	}
	return nil
}
// Load restores this tree from the serialized classifier file at filePath.
//
// It opens the file as a serialized-classifier stub and then loads the tree
// from it using an empty key prefix. An error from opening the stub or from
// loading the tree is returned unchanged.
func (t *ID3DecisionTree) Load(filePath string) error {
	// Top-level loads use no key prefix.
	const rootPrefix = ""

	deserializer, err := base.ReadSerializedClassifierStub(filePath)
	if err != nil {
		return err
	}
	return t.LoadWithPrefix(deserializer, rootPrefix)
}
// LoadWithPrefix replaces Root with a fresh DecisionTreeNode and asks that
// node to load itself from reader, forwarding prefix unchanged.
//
// The prefix is passed through so that loading mirrors SaveWithPrefix, which
// forwards its prefix to the Root node; previously an empty prefix was always
// used here, so a tree saved under a non-empty prefix could not be loaded
// from the same keys.
//
// Root is replaced before loading, so on error the tree is left with the
// newly created node rather than its previous Root.
func (t *ID3DecisionTree) LoadWithPrefix(reader *base.ClassifierDeserializer, prefix string) error {
	t.Root = &DecisionTreeNode{}
	return t.Root.LoadWithPrefix(reader, prefix)
}