mirror of
https://github.com/sjwhitworth/golearn.git
synced 2025-04-26 13:49:14 +08:00
ensemble: tests pass
This commit is contained in:
parent
3e80230d3d
commit
e27215052b
@ -60,6 +60,12 @@ func copyFixedDataGridStructure(of FixedDataGrid) (*DenseInstances, []AttributeS
|
||||
// Add and store new AttributeSpec
|
||||
specs2[i] = ret.AddAttribute(a)
|
||||
}
|
||||
|
||||
// Add class attributes
|
||||
cAttrs := of.AllClassAttributes()
|
||||
for _, a := range cAttrs {
|
||||
ret.AddClassAttribute(a)
|
||||
}
|
||||
return ret, specs1, specs2
|
||||
}
|
||||
|
||||
|
@ -49,7 +49,11 @@ func (f *FunctionalTarReader) GetNamedFile (name string) ([]byte, error) {
|
||||
ret := make([]byte, hdr.Size)
|
||||
n, err := tr.Read(ret)
|
||||
if int64(n) != hdr.Size {
|
||||
return nil, WrapError(fmt.Errorf("Size mismatch, expected %d byte(s), got %d", n, hdr.Size))
|
||||
if int64(n) < hdr.Size {
|
||||
log.Printf("Size mismatch, expected %d byte(s) for %s, got %d", n, hdr.Name, hdr.Size)
|
||||
} else {
|
||||
return nil, WrapError(fmt.Errorf("Size mismatch, expected %d byte(s) for %s, got %d", n, hdr.Name, hdr.Size))
|
||||
}
|
||||
}
|
||||
if err != nil {
|
||||
return nil, err
|
||||
|
@ -4,10 +4,8 @@ import (
|
||||
"github.com/sjwhitworth/golearn/base"
|
||||
"github.com/sjwhitworth/golearn/linear_models"
|
||||
"github.com/sjwhitworth/golearn/meta"
|
||||
"io"
|
||||
"os"
|
||||
|
||||
"fmt"
|
||||
"encoding/json"
|
||||
)
|
||||
|
||||
// MultiLinearSVC implements a multi-class Support Vector Classifier using a one-vs-all
|
||||
@ -32,8 +30,6 @@ func NewMultiLinearSVC(loss, penalty string, dual bool, C float64, eps float64,
|
||||
panic(err)
|
||||
}
|
||||
|
||||
|
||||
|
||||
// Return me...
|
||||
ret := &MultiLinearSVC{
|
||||
parameters: params,
|
||||
@ -91,7 +87,7 @@ func (m *MultiLinearSVC) GetClassifierMetadata() base.ClassifierMetadataV1 {
|
||||
FormatVersion: 1,
|
||||
ClassifierName: "MultiLinearSVC",
|
||||
ClassifierVersion: "1",
|
||||
ClassifierMetadata: nil
|
||||
ClassifierMetadata: nil,
|
||||
}
|
||||
}
|
||||
|
||||
@ -101,7 +97,7 @@ func (m *MultiLinearSVC) Save(filePath string) error {
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
err := m.SaveWithPrefix(serializer, "")
|
||||
err = m.SaveWithPrefix(serializer, "")
|
||||
if err != nil {
|
||||
return fmt.Errorf("Unable to Save(): %v", err)
|
||||
}
|
||||
@ -168,12 +164,13 @@ func (m *MultiLinearSVC) LoadWithPrefix(reader *base.ClassifierDeserializer, pre
|
||||
return fmt.Errorf("Can't load parameters: %v", err)
|
||||
}
|
||||
|
||||
m.initializeOneVsAllModel()
|
||||
|
||||
// Load the model
|
||||
err = m.m.LoadWithPrefix(reader, p("one-vs-all"))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
m.initializeOneVsAllModel()
|
||||
return nil
|
||||
}
|
||||
|
@ -31,7 +31,7 @@ func TestMultiSVMUnweighted(t *testing.T) {
|
||||
So(err, ShouldBeNil)
|
||||
|
||||
Convey("Loading should work...", func() {
|
||||
mLoaded := NewMultiLinearSVC("l2", "l1", true, 1.00, 1e-8, weights)
|
||||
mLoaded := NewMultiLinearSVC("l1", "l2", true, 1.00, 1e-8, nil)
|
||||
err := mLoaded.Load(f.Name())
|
||||
So(err, ShouldBeNil)
|
||||
|
||||
@ -40,7 +40,7 @@ func TestMultiSVMUnweighted(t *testing.T) {
|
||||
So(err, ShouldBeNil)
|
||||
newPredictions, err := mLoaded.Predict(Y)
|
||||
So(err, ShouldBeNil)
|
||||
So(originalPredictions, ShouldEqual, newPredictions)
|
||||
So(base.InstancesAreEqual(originalPredictions, newPredictions), ShouldBeTrue)
|
||||
})
|
||||
|
||||
})
|
||||
@ -68,28 +68,29 @@ func TestMultiSVMWeighted(t *testing.T) {
|
||||
predictions, err := m.Predict(Y)
|
||||
cf, err := evaluation.GetConfusionMatrix(Y, predictions)
|
||||
So(err, ShouldEqual, nil)
|
||||
So(evaluation.GetAccuracy(cf), ShouldBeGreaterThan, 0.70)
|
||||
})
|
||||
So(evaluation.GetAccuracy(cf), ShouldBeGreaterThan, 0.60)
|
||||
|
||||
Convey("Saving should work...", func() {
|
||||
f, err := ioutil.TempFile("","tree")
|
||||
So(err, ShouldBeNil)
|
||||
err = m.Save(f.Name())
|
||||
So(err, ShouldBeNil)
|
||||
|
||||
Convey("Loading should work...", func() {
|
||||
mLoaded := NewMultiLinearSVC("l2", "l1", true, 1.00, 1e-8, weights)
|
||||
err := mLoaded.Load(f.Name())
|
||||
Convey("Saving should work...", func() {
|
||||
f, err := ioutil.TempFile("", "tree")
|
||||
So(err, ShouldBeNil)
|
||||
err = m.Save(f.Name())
|
||||
So(err, ShouldBeNil)
|
||||
|
||||
Convey("Predictions should be the same...", func() {
|
||||
originalPredictions, err := m.Predict(Y)
|
||||
Convey("Loading should work...", func() {
|
||||
mLoaded := NewMultiLinearSVC("l1", "l2", true, 1.00, 1e-8, weights)
|
||||
err := mLoaded.Load(f.Name())
|
||||
So(err, ShouldBeNil)
|
||||
newPredictions, err := mLoaded.Predict(Y)
|
||||
So(err, ShouldBeNil)
|
||||
So(originalPredictions, ShouldEqual, newPredictions)
|
||||
})
|
||||
|
||||
Convey("Predictions should be the same...", func() {
|
||||
originalPredictions, err := m.Predict(Y)
|
||||
So(err, ShouldBeNil)
|
||||
newPredictions, err := mLoaded.Predict(Y)
|
||||
So(err, ShouldBeNil)
|
||||
So(base.InstancesAreEqual(originalPredictions, newPredictions), ShouldBeTrue)
|
||||
})
|
||||
|
||||
})
|
||||
})
|
||||
|
||||
})
|
||||
|
@ -53,7 +53,7 @@ func (f *RandomForest) Fit(on base.FixedDataGrid) error {
|
||||
|
||||
// Predict generates predictions from a trained RandomForest.
|
||||
func (f *RandomForest) Predict(with base.FixedDataGrid) (base.FixedDataGrid, error) {
|
||||
return f.Model.Predict(with), nil
|
||||
return f.Model.Predict(with)
|
||||
}
|
||||
|
||||
// String returns a human-readable representation of this tree.
|
||||
|
@ -3,6 +3,7 @@ package meta
|
||||
import (
|
||||
"fmt"
|
||||
"github.com/sjwhitworth/golearn/base"
|
||||
"log"
|
||||
)
|
||||
|
||||
// OneVsAllModel replaces class Attributes with numeric versions
|
||||
@ -61,6 +62,10 @@ func (m *OneVsAllModel) Fit(using base.FixedDataGrid) {
|
||||
}
|
||||
m.maxClassVal = val
|
||||
|
||||
// If we're reloading, we may just be fitting to the structure
|
||||
_, srcRows := using.Size()
|
||||
fittingToStructure := srcRows == 0
|
||||
|
||||
// Create individual filtered instances for training
|
||||
filters := make([]*oneVsAllFilter, val+1)
|
||||
classifiers := make([]base.Classifier, val+1)
|
||||
@ -72,7 +77,9 @@ func (m *OneVsAllModel) Fit(using base.FixedDataGrid) {
|
||||
}
|
||||
filters[i] = f
|
||||
classifiers[i] = m.NewClassifierFunction(classVals[int(i)])
|
||||
classifiers[i].Fit(base.NewLazilyFilteredInstances(using, f))
|
||||
if !fittingToStructure {
|
||||
classifiers[i].Fit(base.NewLazilyFilteredInstances(using, f))
|
||||
}
|
||||
}
|
||||
|
||||
m.filters = filters
|
||||
@ -90,7 +97,13 @@ func (m *OneVsAllModel) Predict(what base.FixedDataGrid) (base.FixedDataGrid, er
|
||||
ret := base.GeneratePredictionVector(what)
|
||||
vecs := make([]base.FixedDataGrid, m.maxClassVal+1)
|
||||
specs := make([]base.AttributeSpec, m.maxClassVal+1)
|
||||
|
||||
if int(m.maxClassVal) > len(m.filters) || (m.maxClassVal == 0 && len(m.filters) == 0) {
|
||||
return nil, base.WrapError(fmt.Errorf("Internal error: m.Filter len = %d, maxClassVal = %d", len(m.filters), m.maxClassVal))
|
||||
}
|
||||
|
||||
for i := uint64(0); i <= m.maxClassVal; i++ {
|
||||
//log.Printf("i = %d, m.Filter len = %d, maxClassVal = %d", i, len(m.filters), m.maxClassVal)
|
||||
f := m.filters[i]
|
||||
c := base.NewLazilyFilteredInstances(what, f)
|
||||
p, err := m.classifiers[i].Predict(c)
|
||||
@ -139,7 +152,10 @@ func (m *OneVsAllModel) LoadWithPrefix(reader *base.ClassifierDeserializer, pref
|
||||
if err != nil {
|
||||
return base.DescribeError("Can't load INSTANCE_STRUCTURE", err)
|
||||
}
|
||||
m.fitOn = fitOn
|
||||
m.Fit(fitOn)
|
||||
/*if err != nil {
|
||||
base.DescribeError("Could not fit reloaded classifier to the structure", err)
|
||||
}*/
|
||||
|
||||
// Reload the filters
|
||||
numFiltersU64, err := reader.GetU64ForKey(reader.Prefix(prefix, "FILTER_COUNT"))
|
||||
@ -151,7 +167,7 @@ func (m *OneVsAllModel) LoadWithPrefix(reader *base.ClassifierDeserializer, pref
|
||||
for i := 0; i < numFilters; i++ {
|
||||
f := oneVsAllFilter{}
|
||||
|
||||
mapPrefix := pI(reader.Prefix(prefix, "FILTER"), i)
|
||||
mapPrefix := pI("FILTER", i)
|
||||
mapCountKey := reader.Prefix(mapPrefix, "COUNT")
|
||||
numAttrsInMapU64, err := reader.GetU64ForKey(mapCountKey)
|
||||
if err != nil {
|
||||
@ -161,7 +177,7 @@ func (m *OneVsAllModel) LoadWithPrefix(reader *base.ClassifierDeserializer, pref
|
||||
attrMap := make(map[base.Attribute]base.Attribute)
|
||||
|
||||
for j := 0; j < int(numAttrsInMapU64); j++ {
|
||||
mapTupleKey := pI(mapPrefix, j)
|
||||
mapTupleKey := reader.Prefix(mapPrefix, fmt.Sprintf("%d"))
|
||||
mapKeyKeyKey := reader.Prefix(mapTupleKey, "KEY")
|
||||
mapKeyValKey := reader.Prefix(mapTupleKey, "VAL")
|
||||
|
||||
@ -213,7 +229,7 @@ func (m *OneVsAllModel) LoadWithPrefix(reader *base.ClassifierDeserializer, pref
|
||||
m.classifiers = make([]base.Classifier, 0)
|
||||
for i, c := range classVals {
|
||||
cls := m.NewClassifierFunction(c)
|
||||
clsPrefix := pI(reader.Prefix(prefix, "CLASSIFIERS"), i)
|
||||
clsPrefix := pI("CLASSIFIERS", i)
|
||||
|
||||
err = cls.LoadWithPrefix(reader, clsPrefix)
|
||||
if err != nil {
|
||||
@ -248,6 +264,7 @@ func (m *OneVsAllModel) SaveWithPrefix(writer *base.ClassifierSerializer, prefix
|
||||
return writer.Prefix(prefix, writer.Prefix(n, fmt.Sprintf("%d", i)))
|
||||
}
|
||||
|
||||
|
||||
// Save the instances
|
||||
err := writer.WriteInstancesForKey(writer.Prefix(prefix, "INSTANCE_STRUCTURE"), m.fitOn, false)
|
||||
if err != nil {
|
||||
@ -266,7 +283,7 @@ func (m *OneVsAllModel) SaveWithPrefix(writer *base.ClassifierSerializer, prefix
|
||||
return base.DescribeError("Unable to write FILTER_COUNT", err)
|
||||
}
|
||||
for i, f := range m.filters {
|
||||
mapPrefix := pI(writer.Prefix(prefix, "FILTER"), i)
|
||||
mapPrefix := pI("FILTER", i)
|
||||
mapCountKey := writer.Prefix(mapPrefix, "COUNT")
|
||||
err := writer.WriteU64ForKey(mapCountKey, uint64(len(f.attrs)))
|
||||
if err != nil {
|
||||
@ -274,7 +291,7 @@ func (m *OneVsAllModel) SaveWithPrefix(writer *base.ClassifierSerializer, prefix
|
||||
}
|
||||
j := 0
|
||||
for key := range f.attrs {
|
||||
mapTupleKey := pI(mapPrefix, j)
|
||||
mapTupleKey := writer.Prefix(mapPrefix, fmt.Sprintf("%d"))
|
||||
mapKeyKeyKey := writer.Prefix(mapTupleKey, "KEY")
|
||||
mapKeyValKey := writer.Prefix(mapTupleKey, "VAL")
|
||||
|
||||
@ -302,7 +319,7 @@ func (m *OneVsAllModel) SaveWithPrefix(writer *base.ClassifierSerializer, prefix
|
||||
|
||||
// Save the classifiers
|
||||
for i, c := range m.classifiers {
|
||||
clsPrefix := pI(writer.Prefix(prefix, "CLASSIFIERS"), i)
|
||||
clsPrefix := pI("CLASSIFIERS", i)
|
||||
err = c.SaveWithPrefix(writer, clsPrefix)
|
||||
if err != nil {
|
||||
return base.FormatError(err, "Can't save classifier for class %s", m.classValues[i])
|
||||
@ -316,7 +333,7 @@ func (m *OneVsAllModel) generateAttributes(from base.FixedDataGrid) map[base.Att
|
||||
attrs := from.AllAttributes()
|
||||
classAttrs := from.AllClassAttributes()
|
||||
if len(classAttrs) != 1 {
|
||||
panic("Only 1 class Attribute is supported!")
|
||||
panic(fmt.Errorf("Only 1 class Attribute is supported, had %d", len(classAttrs)))
|
||||
}
|
||||
ret := make(map[base.Attribute]base.Attribute)
|
||||
for _, a := range attrs {
|
||||
|
36
trees/id3.go
36
trees/id3.go
@ -654,3 +654,39 @@ func (t *ID3DecisionTree) Predict(what base.FixedDataGrid) (base.FixedDataGrid,
|
||||
func (t *ID3DecisionTree) String() string {
|
||||
return fmt.Sprintf("ID3DecisionTree(%s\n)", t.Root)
|
||||
}
|
||||
|
||||
func (t *ID3DecisionTree) GetMetadata() base.ClassifierMetadataV1 {
|
||||
return base.ClassifierMetadataV1{
|
||||
FormatVersion: 1,
|
||||
ClassifierName: "KNN",
|
||||
ClassifierVersion: "1.0",
|
||||
ClassifierMetadata: nil,
|
||||
}
|
||||
}
|
||||
|
||||
func (t *ID3DecisionTree) Save(filePath string) error {
|
||||
writer, err := base.CreateSerializedClassifierStub(filePath, t.GetMetadata())
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
fmt.Printf("writer: %v", writer)
|
||||
return t.SaveWithPrefix(writer, "")
|
||||
}
|
||||
|
||||
func (t *ID3DecisionTree) SaveWithPrefix(writer *base.ClassifierSerializer, prefix string) error {
|
||||
return t.Root.SaveWithPrefix(writer, prefix)
|
||||
}
|
||||
|
||||
func (t *ID3DecisionTree) Load(filePath string) error {
|
||||
reader, err := base.ReadSerializedClassifierStub(filePath)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return t.LoadWithPrefix(reader, "")
|
||||
}
|
||||
|
||||
func (t *ID3DecisionTree) LoadWithPrefix(reader *base.ClassifierDeserializer, prefix string) error {
|
||||
t.Root = &DecisionTreeNode{}
|
||||
return t.Root.LoadWithPrefix(reader, "")
|
||||
}
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user