diff --git a/base/dense.go b/base/dense.go
index 59a55cd..2129dea 100644
--- a/base/dense.go
+++ b/base/dense.go
@@ -60,6 +60,12 @@ func copyFixedDataGridStructure(of FixedDataGrid) (*DenseInstances, []AttributeS
 		// Add and store new AttributeSpec
 		specs2[i] = ret.AddAttribute(a)
 	}
+
+	// Add class attributes
+	cAttrs := of.AllClassAttributes()
+	for _, a := range cAttrs {
+		ret.AddClassAttribute(a)
+	}
 	return ret, specs1, specs2
 }
diff --git a/base/serialize.go b/base/serialize.go
index 4747f17..3ce7e05 100644
--- a/base/serialize.go
+++ b/base/serialize.go
@@ -49,7 +49,11 @@ func (f *FunctionalTarReader) GetNamedFile(name string) ([]byte, error) {
 			ret := make([]byte, hdr.Size)
 			n, err := tr.Read(ret)
 			if int64(n) != hdr.Size {
-				return nil, WrapError(fmt.Errorf("Size mismatch, expected %d byte(s), got %d", n, hdr.Size))
+				if int64(n) < hdr.Size {
+					log.Printf("Size mismatch, expected %d byte(s) for %s, got %d", hdr.Size, hdr.Name, n)
+				} else {
+					return nil, WrapError(fmt.Errorf("Size mismatch, expected %d byte(s) for %s, got %d", hdr.Size, hdr.Name, n))
+				}
 			}
 			if err != nil {
 				return nil, err
diff --git a/ensemble/multisvc.go b/ensemble/multisvc.go
index f876021..f5dca33 100644
--- a/ensemble/multisvc.go
+++ b/ensemble/multisvc.go
@@ -4,10 +4,8 @@ import (
 	"github.com/sjwhitworth/golearn/base"
 	"github.com/sjwhitworth/golearn/linear_models"
 	"github.com/sjwhitworth/golearn/meta"
-	"io"
-	"os"
+	"fmt"
 
-	"encoding/json"
 )
 
 // MultiLinearSVC implements a multi-class Support Vector Classifier using a one-vs-all
@@ -32,8 +30,6 @@ func NewMultiLinearSVC(loss, penalty string, dual bool, C float64, eps float64,
 		panic(err)
 	}
-
-	// Return me...
 	ret := &MultiLinearSVC{
 		parameters: params,
@@ -91,7 +87,7 @@ func (m *MultiLinearSVC) GetClassifierMetadata() base.ClassifierMetadataV1 {
 		FormatVersion:      1,
 		ClassifierName:     "MultiLinearSVC",
 		ClassifierVersion:  "1",
-		ClassifierMetadata: nil
+		ClassifierMetadata: nil,
 	}
 }
@@ -101,7 +97,7 @@ func (m *MultiLinearSVC) Save(filePath string) error {
 	if err != nil {
 		return err
 	}
-	err := m.SaveWithPrefix(serializer, "")
+	err = m.SaveWithPrefix(serializer, "")
 	if err != nil {
 		return fmt.Errorf("Unable to Save(): %v", err)
 	}
@@ -168,12 +164,13 @@ func (m *MultiLinearSVC) LoadWithPrefix(reader *base.ClassifierDeserializer, pre
 		return fmt.Errorf("Can't load parameters: %v", err)
 	}
 
+	m.initializeOneVsAllModel()
+
 	// Load the model
 	err = m.m.LoadWithPrefix(reader, p("one-vs-all"))
 	if err != nil {
 		return err
 	}
-	m.initializeOneVsAllModel()
 	return nil
 }
diff --git a/ensemble/multisvc_test.go b/ensemble/multisvc_test.go
index d28b117..47e95cf 100644
--- a/ensemble/multisvc_test.go
+++ b/ensemble/multisvc_test.go
@@ -31,7 +31,7 @@ func TestMultiSVMUnweighted(t *testing.T) {
 			So(err, ShouldBeNil)
 
 			Convey("Loading should work...", func() {
-				mLoaded := NewMultiLinearSVC("l2", "l1", true, 1.00, 1e-8, weights)
+				mLoaded := NewMultiLinearSVC("l1", "l2", true, 1.00, 1e-8, nil)
 				err := mLoaded.Load(f.Name())
 				So(err, ShouldBeNil)
@@ -40,7 +40,7 @@ func TestMultiSVMUnweighted(t *testing.T) {
 					So(err, ShouldBeNil)
 					newPredictions, err := mLoaded.Predict(Y)
 					So(err, ShouldBeNil)
-					So(originalPredictions, ShouldEqual, newPredictions)
+					So(base.InstancesAreEqual(originalPredictions, newPredictions), ShouldBeTrue)
 				})
 			})
@@ -68,28 +68,29 @@ func TestMultiSVMWeighted(t *testing.T) {
 			predictions, err := m.Predict(Y)
 			cf, err := evaluation.GetConfusionMatrix(Y, predictions)
 			So(err, ShouldEqual, nil)
-			So(evaluation.GetAccuracy(cf), ShouldBeGreaterThan, 0.70)
-		})
+			So(evaluation.GetAccuracy(cf), ShouldBeGreaterThan, 0.60)
 
-		Convey("Saving should work...", func() {
-			f, err := ioutil.TempFile("","tree")
-			So(err, ShouldBeNil)
-			err = m.Save(f.Name())
-			So(err, ShouldBeNil)
-			Convey("Loading should work...", func() {
-				mLoaded := NewMultiLinearSVC("l2", "l1", true, 1.00, 1e-8, weights)
-				err := mLoaded.Load(f.Name())
+			Convey("Saving should work...", func() {
+				f, err := ioutil.TempFile("", "tree")
+				So(err, ShouldBeNil)
+				err = m.Save(f.Name())
 				So(err, ShouldBeNil)
-				Convey("Predictions should be the same...", func() {
-					originalPredictions, err := m.Predict(Y)
+				Convey("Loading should work...", func() {
+					mLoaded := NewMultiLinearSVC("l1", "l2", true, 1.00, 1e-8, weights)
+					err := mLoaded.Load(f.Name())
 					So(err, ShouldBeNil)
-					newPredictions, err := mLoaded.Predict(Y)
-					So(err, ShouldBeNil)
-					So(originalPredictions, ShouldEqual, newPredictions)
-				})
+					Convey("Predictions should be the same...", func() {
+						originalPredictions, err := m.Predict(Y)
+						So(err, ShouldBeNil)
+						newPredictions, err := mLoaded.Predict(Y)
+						So(err, ShouldBeNil)
+						So(base.InstancesAreEqual(originalPredictions, newPredictions), ShouldBeTrue)
+					})
+
+				})
 			})
 		})
 	})
diff --git a/ensemble/randomforest.go b/ensemble/randomforest.go
index 0ec176f..55feea0 100644
--- a/ensemble/randomforest.go
+++ b/ensemble/randomforest.go
@@ -53,7 +53,7 @@ func (f *RandomForest) Fit(on base.FixedDataGrid) error {
 
 // Predict generates predictions from a trained RandomForest.
 func (f *RandomForest) Predict(with base.FixedDataGrid) (base.FixedDataGrid, error) {
-	return f.Model.Predict(with), nil
+	return f.Model.Predict(with)
 }
 
 // String returns a human-readable representation of this tree.
diff --git a/meta/one_v_all.go b/meta/one_v_all.go
index b35209f..4f53d3d 100644
--- a/meta/one_v_all.go
+++ b/meta/one_v_all.go
@@ -3,6 +3,7 @@ package meta
 import (
 	"fmt"
 	"github.com/sjwhitworth/golearn/base"
+	"log"
 )
 
 // OneVsAllModel replaces class Attributes with numeric versions
@@ -61,6 +62,10 @@ func (m *OneVsAllModel) Fit(using base.FixedDataGrid) {
 	}
 	m.maxClassVal = val
 
+	// If we're reloading, we may just be fitting to the structure
+	_, srcRows := using.Size()
+	fittingToStructure := srcRows == 0
+
 	// Create individual filtered instances for training
 	filters := make([]*oneVsAllFilter, val+1)
 	classifiers := make([]base.Classifier, val+1)
@@ -72,7 +77,9 @@ func (m *OneVsAllModel) Fit(using base.FixedDataGrid) {
 		}
 		filters[i] = f
 		classifiers[i] = m.NewClassifierFunction(classVals[int(i)])
-		classifiers[i].Fit(base.NewLazilyFilteredInstances(using, f))
+		if !fittingToStructure {
+			classifiers[i].Fit(base.NewLazilyFilteredInstances(using, f))
+		}
 	}
 
 	m.filters = filters
@@ -90,7 +97,13 @@ func (m *OneVsAllModel) Predict(what base.FixedDataGrid) (base.FixedDataGrid, er
 	ret := base.GeneratePredictionVector(what)
 	vecs := make([]base.FixedDataGrid, m.maxClassVal+1)
 	specs := make([]base.AttributeSpec, m.maxClassVal+1)
+
+	if int(m.maxClassVal) > len(m.filters) || (m.maxClassVal == 0 && len(m.filters) == 0) {
+		return nil, base.WrapError(fmt.Errorf("Internal error: m.Filter len = %d, maxClassVal = %d", len(m.filters), m.maxClassVal))
+	}
+
 	for i := uint64(0); i <= m.maxClassVal; i++ {
+		//log.Printf("i = %d, m.Filter len = %d, maxClassVal = %d", i, len(m.filters), m.maxClassVal)
 		f := m.filters[i]
 		c := base.NewLazilyFilteredInstances(what, f)
 		p, err := m.classifiers[i].Predict(c)
@@ -139,7 +152,10 @@ func (m *OneVsAllModel) LoadWithPrefix(reader *base.ClassifierDeserializer, pref
 	if err != nil {
 		return base.DescribeError("Can't load INSTANCE_STRUCTURE", err)
 	}
-	m.fitOn = fitOn
+	m.Fit(fitOn)
+	/*if err != nil {
+		base.DescribeError("Could not fit reloaded classifier to the structure", err)
+	}*/
 
 	// Reload the filters
 	numFiltersU64, err := reader.GetU64ForKey(reader.Prefix(prefix, "FILTER_COUNT"))
@@ -151,7 +167,7 @@ func (m *OneVsAllModel) LoadWithPrefix(reader *base.ClassifierDeserializer, pref
 
 	for i := 0; i < numFilters; i++ {
 		f := oneVsAllFilter{}
-		mapPrefix := pI(reader.Prefix(prefix, "FILTER"), i)
+		mapPrefix := pI("FILTER", i)
 		mapCountKey := reader.Prefix(mapPrefix, "COUNT")
 		numAttrsInMapU64, err := reader.GetU64ForKey(mapCountKey)
 		if err != nil {
@@ -161,7 +177,7 @@ func (m *OneVsAllModel) LoadWithPrefix(reader *base.ClassifierDeserializer, pref
 
 		attrMap := make(map[base.Attribute]base.Attribute)
 		for j := 0; j < int(numAttrsInMapU64); j++ {
-			mapTupleKey := pI(mapPrefix, j)
+			mapTupleKey := reader.Prefix(mapPrefix, fmt.Sprintf("%d", j))
 			mapKeyKeyKey := reader.Prefix(mapTupleKey, "KEY")
 			mapKeyValKey := reader.Prefix(mapTupleKey, "VAL")
 
@@ -213,7 +229,7 @@ func (m *OneVsAllModel) LoadWithPrefix(reader *base.ClassifierDeserializer, pref
 	m.classifiers = make([]base.Classifier, 0)
 	for i, c := range classVals {
 		cls := m.NewClassifierFunction(c)
-		clsPrefix := pI(reader.Prefix(prefix, "CLASSIFIERS"), i)
+		clsPrefix := pI("CLASSIFIERS", i)
 		err = cls.LoadWithPrefix(reader, clsPrefix)
 		if err != nil {
@@ -248,6 +264,7 @@ func (m *OneVsAllModel) SaveWithPrefix(writer *base.ClassifierSerializer, prefix
 		return writer.Prefix(prefix, writer.Prefix(n, fmt.Sprintf("%d", i)))
 	}
 
+	// Save the instances
 	err := writer.WriteInstancesForKey(writer.Prefix(prefix, "INSTANCE_STRUCTURE"), m.fitOn, false)
 	if err != nil {
@@ -266,7 +283,7 @@ func (m *OneVsAllModel) SaveWithPrefix(writer *base.ClassifierSerializer, prefix
 		return base.DescribeError("Unable to write FILTER_COUNT", err)
 	}
 	for i, f := range m.filters {
-		mapPrefix := pI(writer.Prefix(prefix, "FILTER"), i)
+		mapPrefix := pI("FILTER", i)
 		mapCountKey := writer.Prefix(mapPrefix, "COUNT")
 		err := writer.WriteU64ForKey(mapCountKey, uint64(len(f.attrs)))
 		if err != nil {
@@ -274,7 +291,7 @@ func (m *OneVsAllModel) SaveWithPrefix(writer *base.ClassifierSerializer, prefix
 		}
 		j := 0
 		for key := range f.attrs {
-			mapTupleKey := pI(mapPrefix, j)
+			mapTupleKey := writer.Prefix(mapPrefix, fmt.Sprintf("%d", j))
 			mapKeyKeyKey := writer.Prefix(mapTupleKey, "KEY")
 			mapKeyValKey := writer.Prefix(mapTupleKey, "VAL")
 
@@ -302,7 +319,7 @@ func (m *OneVsAllModel) SaveWithPrefix(writer *base.ClassifierSerializer, prefix
 
 	// Save the classifiers
 	for i, c := range m.classifiers {
-		clsPrefix := pI(writer.Prefix(prefix, "CLASSIFIERS"), i)
+		clsPrefix := pI("CLASSIFIERS", i)
 		err = c.SaveWithPrefix(writer, clsPrefix)
 		if err != nil {
 			return base.FormatError(err, "Can't save classifier for class %s", m.classValues[i])
@@ -316,7 +333,7 @@ func (m *OneVsAllModel) generateAttributes(from base.FixedDataGrid) map[base.Att
 	attrs := from.AllAttributes()
 	classAttrs := from.AllClassAttributes()
 	if len(classAttrs) != 1 {
-		panic("Only 1 class Attribute is supported!")
+		panic(fmt.Errorf("Only 1 class Attribute is supported, had %d", len(classAttrs)))
 	}
 	ret := make(map[base.Attribute]base.Attribute)
 	for _, a := range attrs {
diff --git a/trees/id3.go b/trees/id3.go
index 9ce3de1..0b8f1b2 100644
--- a/trees/id3.go
+++ b/trees/id3.go
@@ -654,3 +654,39 @@ func (t *ID3DecisionTree) Predict(what base.FixedDataGrid) (base.FixedDataGrid,
 func (t *ID3DecisionTree) String() string {
 	return fmt.Sprintf("ID3DecisionTree(%s\n)", t.Root)
 }
+
+func (t *ID3DecisionTree) GetMetadata() base.ClassifierMetadataV1 {
+	return base.ClassifierMetadataV1{
+		FormatVersion:      1,
+		ClassifierName:     "ID3DecisionTree",
+		ClassifierVersion:  "1.0",
+		ClassifierMetadata: nil,
+	}
+}
+
+func (t *ID3DecisionTree) Save(filePath string) error {
+	writer, err := base.CreateSerializedClassifierStub(filePath, t.GetMetadata())
+	if err != nil {
+		return err
+	}
+	fmt.Printf("writer: %v", writer)
+	return t.SaveWithPrefix(writer, "")
+}
+
+func (t *ID3DecisionTree) SaveWithPrefix(writer *base.ClassifierSerializer, prefix string) error {
+	return t.Root.SaveWithPrefix(writer, prefix)
+}
+
+func (t *ID3DecisionTree) Load(filePath string) error {
+	reader, err := base.ReadSerializedClassifierStub(filePath)
+	if err != nil {
+		return err
+	}
+	return t.LoadWithPrefix(reader, "")
+}
+
+func (t *ID3DecisionTree) LoadWithPrefix(reader *base.ClassifierDeserializer, prefix string) error {
+	t.Root = &DecisionTreeNode{}
+	return t.Root.LoadWithPrefix(reader, prefix)
+}
+
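
Usage note (not part of the patch): a minimal round-trip sketch of the MultiLinearSVC Save/Load path that the updated tests exercise. The CSV path and the 50/50 split are illustrative placeholders; the constructor arguments mirror the updated unweighted test above.

package main

import (
	"fmt"
	"io/ioutil"

	"github.com/sjwhitworth/golearn/base"
	"github.com/sjwhitworth/golearn/ensemble"
)

func main() {
	// Placeholder dataset: any CSV with headers and a nominal class column.
	inst, err := base.ParseCSVToInstances("datasets/iris_headers.csv", true)
	if err != nil {
		panic(err)
	}
	train, test := base.InstancesTrainTestSplit(inst, 0.5)

	// Same constructor arguments as the updated unweighted test.
	m := ensemble.NewMultiLinearSVC("l1", "l2", true, 1.0, 1e-8, nil)
	m.Fit(train) // error handling elided in this sketch

	// Save to a temporary file, then reload into a fresh classifier.
	f, err := ioutil.TempFile("", "svc")
	if err != nil {
		panic(err)
	}
	if err := m.Save(f.Name()); err != nil {
		panic(err)
	}
	mLoaded := ensemble.NewMultiLinearSVC("l1", "l2", true, 1.0, 1e-8, nil)
	if err := mLoaded.Load(f.Name()); err != nil {
		panic(err)
	}

	// The reloaded model should reproduce the original predictions.
	p1, err := m.Predict(test)
	if err != nil {
		panic(err)
	}
	p2, err := mLoaded.Predict(test)
	if err != nil {
		panic(err)
	}
	fmt.Println("identical predictions:", base.InstancesAreEqual(p1, p2))
}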
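
Side note on base/serialize.go (not part of the patch): tar.Reader.Read follows the io.Reader contract and may return fewer bytes than requested even when more data is pending, which is what the new log-and-continue branch tolerates. An alternative worth considering is io.ReadFull, which keeps reading until the buffer is filled. The helper below is only a sketch with names of my own choosing, not the library's API.

package main

import (
	"archive/tar"
	"bytes"
	"fmt"
	"io"
	"log"
)

// readTarEntry reads exactly hdr.Size bytes of the current entry. io.ReadFull
// keeps calling Read until the buffer is full, so short reads from the tar
// stream never surface as spurious size mismatches.
func readTarEntry(tr *tar.Reader, hdr *tar.Header) ([]byte, error) {
	ret := make([]byte, hdr.Size)
	if _, err := io.ReadFull(tr, ret); err != nil {
		return nil, fmt.Errorf("reading %s: %v", hdr.Name, err)
	}
	return ret, nil
}

func main() {
	// Build a tiny in-memory archive just to exercise the helper.
	var buf bytes.Buffer
	tw := tar.NewWriter(&buf)
	body := []byte("hello, golearn")
	tw.WriteHeader(&tar.Header{Name: "PARAMETERS", Mode: 0600, Size: int64(len(body))})
	tw.Write(body)
	tw.Close()

	tr := tar.NewReader(&buf)
	hdr, err := tr.Next()
	if err != nil {
		log.Fatal(err)
	}
	data, err := readTarEntry(tr, hdr)
	if err != nil {
		log.Fatal(err)
	}
	fmt.Printf("%s: %q\n", hdr.Name, data)
}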