Reformat, fix tests
This commit is contained in:
parent fc110aab48
commit e7fee0a2d1
@@ -45,13 +45,15 @@ func (f *FunctionalTarReader) GetNamedFile(name string) ([]byte, error) {
 		}
 		if hdr.Name == name {
-			ret := make([]byte, hdr.Size)
-			n, err := tr.Read(ret)
-			if int64(n) != hdr.Size {
-				if int64(n) < hdr.Size {
-					log.Printf("Size mismatch, expected %d byte(s) for %s, got %d", n, hdr.Name, hdr.Size)
+			ret, err := ioutil.ReadAll(tr)
+			if err != nil {
+				return nil, WrapError(err)
+			}
+			if int64(len(ret)) != hdr.Size {
+				if int64(len(ret)) < hdr.Size {
+					log.Printf("Size mismatch, got %d byte(s) for %s, expected %d (err was %s)", len(ret), hdr.Name, hdr.Size, err)
 				} else {
-					return nil, WrapError(fmt.Errorf("Size mismatch, expected %d byte(s) for %s, got %d", n, hdr.Name, hdr.Size))
+					return nil, WrapError(fmt.Errorf("Size mismatch, expected %d byte(s) for %s, got %d", len(ret), hdr.Name, hdr.Size))
 				}
 			}
 			if err != nil {
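Note on the hunk above: a single tr.Read is not guaranteed to fill the buffer, since io.Reader permits short reads, which is why the patch switches to ioutil.ReadAll and only then compares lengths. A minimal standalone sketch of the difference (the in-memory archive and entry name are made up for illustration, not taken from the commit):

package main

import (
	"archive/tar"
	"bytes"
	"fmt"
	"io/ioutil"
)

func main() {
	// Build a one-entry tar archive in memory (hypothetical payload).
	var archive bytes.Buffer
	tw := tar.NewWriter(&archive)
	payload := []byte("serialised model bytes")
	tw.WriteHeader(&tar.Header{Name: "MODEL", Mode: 0600, Size: int64(len(payload))})
	tw.Write(payload)
	tw.Close()

	tr := tar.NewReader(&archive)
	hdr, _ := tr.Next()

	// A single Read may legally return fewer bytes than hdr.Size.
	buf := make([]byte, hdr.Size)
	n, _ := tr.Read(buf)
	fmt.Printf("single Read returned %d of %d byte(s)\n", n, hdr.Size)

	// ReadAll keeps reading until EOF, so the length check afterwards is reliable.
	rest, _ := ioutil.ReadAll(tr)
	fmt.Printf("ReadAll picked up the remaining %d byte(s)\n", len(rest))
}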
@@ -248,7 +250,7 @@ func (c *ClassifierDeserializer) Close() {
 // ClassifierSerializer is an object used by SaveableClassifiers.
 type ClassifierSerializer struct {
 	gzipWriter *gzip.Writer
-	fileWriter io.WriteCloser
+	fileWriter *os.File
 	tarWriter  *tar.Writer
 	f          *os.File
 	filePath   string
@@ -274,6 +276,10 @@ func (c *ClassifierSerializer) Close() error {
 		return fmt.Errorf("Could not close gz: %s", err)
 	}
 
+	if err := c.fileWriter.Sync(); err != nil {
+		return fmt.Errorf("Could not close file writer: %s", err)
+	}
+
 	if err := c.fileWriter.Close(); err != nil {
 		return fmt.Errorf("Could not close file writer: %s", err)
 	}
@@ -310,6 +316,9 @@ func (c *ClassifierSerializer) WriteBytesForKey(key string, b []byte) error {
 		return fmt.Errorf("Could not write data for '%s': %s", key, err)
 	}
 
+	c.tarWriter.Flush()
+	c.gzipWriter.Flush()
+	c.fileWriter.Sync()
 	return nil
 }
 
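For context on the Flush/Sync ordering added above: the serializer layers a tar writer on top of a gzip writer on top of the file, so buffered data has to be flushed from the innermost writer outwards before a file sync achieves anything. A self-contained sketch of that layering (the file name and entry key are placeholders, not part of the commit):

package main

import (
	"archive/tar"
	"compress/gzip"
	"log"
	"os"
)

func main() {
	// Hypothetical output file, for illustration only.
	f, err := os.Create("classifier.cls")
	if err != nil {
		log.Fatal(err)
	}
	defer f.Close()

	// Writers are layered: tar -> gzip -> file.
	gz := gzip.NewWriter(f)
	tw := tar.NewWriter(gz)

	payload := []byte("some serialised state")
	if err := tw.WriteHeader(&tar.Header{Name: "KEY", Mode: 0600, Size: int64(len(payload))}); err != nil {
		log.Fatal(err)
	}
	if _, err := tw.Write(payload); err != nil {
		log.Fatal(err)
	}

	// Flush from the innermost writer outwards, then sync the file,
	// mirroring the order added in WriteBytesForKey above.
	tw.Flush()
	gz.Flush()
	f.Sync()

	// A full Close (tar, then gzip, then file) is still required at the end,
	// as ClassifierSerializer.Close does.
	tw.Close()
	gz.Close()
}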
@@ -23,7 +23,7 @@ func TestSerializeToCSV(t *testing.T) {
 		Convey("What's written out should match what's read in", func() {
 			dinst, err := ParseCSVToInstances(f.Name(), true)
 			So(err, ShouldBeNil)
-			So(inst.String(), ShouldEqual, dinst.String())
+			So(InstancesAreEqual(inst, dinst), ShouldBeTrue)
 		})
 	})
 })
@@ -60,3 +60,40 @@ func (f *RandomForest) Predict(with base.FixedDataGrid) (base.FixedDataGrid, err
 func (f *RandomForest) String() string {
 	return fmt.Sprintf("RandomForest(ForestSize: %d, Features:%d, %s\n)", f.ForestSize, f.Features, f.Model)
 }
+
+func (f *RandomForest) GetMetadata() base.ClassifierMetadataV1 {
+	return base.ClassifierMetadataV1{
+		FormatVersion:      1,
+		ClassifierName:     "KNN",
+		ClassifierVersion:  "1.0",
+		ClassifierMetadata: nil,
+	}
+}
+
+func (f *RandomForest) Save(filePath string) error {
+	writer, err := base.CreateSerializedClassifierStub(filePath, f.GetMetadata())
+	if err != nil {
+		return err
+	}
+
+	err = f.SaveWithPrefix(writer, "model")
+	writer.Close()
+	return err
+}
+
+func (f *RandomForest) SaveWithPrefix(writer *base.ClassifierSerializer, prefix string) error {
+	return f.Model.SaveWithPrefix(writer, prefix)
+}
+
+func (f *RandomForest) Load(filePath string) error {
+	reader, err := base.ReadSerializedClassifierStub(filePath)
+	if err != nil {
+		return err
+	}
+	return f.LoadWithPrefix(reader, "model")
+}
+
+func (f *RandomForest) LoadWithPrefix(reader *base.ClassifierDeserializer, prefix string) error {
+	f.Model = new(meta.BaggedModel)
+	return f.Model.LoadWithPrefix(reader, prefix)
+}
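A rough usage sketch of the Save/Load round-trip these new methods enable, mirroring the serialization test added below; the dataset path and output file name are assumptions for illustration, not part of the commit:

package main

import (
	"fmt"

	"github.com/sjwhitworth/golearn/base"
	"github.com/sjwhitworth/golearn/ensemble"
	"github.com/sjwhitworth/golearn/filters"
)

func main() {
	// Assumed dataset location.
	inst, err := base.ParseCSVToInstances("examples/datasets/iris_headers.csv", true)
	if err != nil {
		panic(err)
	}

	// Discretise the continuous attributes first, as the test does.
	filt := filters.NewChiMergeFilter(inst, 0.90)
	for _, a := range base.NonClassFloatAttributes(inst) {
		filt.AddAttribute(a)
	}
	filt.Train()
	instf := base.NewLazilyFilteredInstances(inst, filt)

	rf := ensemble.NewRandomForest(10, 3) // 10 trees, 3 features per split
	if err := rf.Fit(instf); err != nil {
		panic(err)
	}

	// Round-trip through the new Save/Load methods.
	if err := rf.Save("rf.cls"); err != nil {
		panic(err)
	}
	restored := ensemble.NewRandomForest(10, 3)
	if err := restored.Load("rf.cls"); err != nil {
		panic(err)
	}

	predictions, err := restored.Predict(instf)
	if err != nil {
		panic(err)
	}
	fmt.Println(predictions)
}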
@@ -7,6 +7,8 @@ import (
 	"github.com/sjwhitworth/golearn/evaluation"
 	"github.com/sjwhitworth/golearn/filters"
 	. "github.com/smartystreets/goconvey/convey"
+	"io/ioutil"
+	"os"
 )
 
 func TestRandomForest(t *testing.T) {
@@ -53,3 +55,51 @@ func TestRandomForest(t *testing.T) {
 		})
 	})
 }
+
+func TestRandomForestSerialization(t *testing.T) {
+	Convey("Given a valid CSV file", t, func() {
+		inst, err := base.ParseCSVToInstances("../examples/datasets/iris_headers.csv", true)
+		So(err, ShouldBeNil)
+
+		Convey("When Chi-Merge filtering the data", func() {
+			filt := filters.NewChiMergeFilter(inst, 0.90)
+			for _, a := range base.NonClassFloatAttributes(inst) {
+				filt.AddAttribute(a)
+			}
+			filt.Train()
+			instf := base.NewLazilyFilteredInstances(inst, filt)
+
+			Convey("Splitting the data into test and training sets", func() {
+				trainData, testData := base.InstancesTrainTestSplit(instf, 0.60)
+
+				Convey("Fitting and predicting with a Random Forest", func() {
+					rf := NewRandomForest(10, 3)
+					err = rf.Fit(trainData)
+					So(err, ShouldBeNil)
+
+					oldPredictions, err := rf.Predict(testData)
+					So(err, ShouldBeNil)
+
+					Convey("Saving the model should work...", func() {
+						f, err := ioutil.TempFile(os.TempDir(), "rf")
+						err = rf.Save(f.Name())
+						defer func() {
+							f.Close()
+						}()
+						So(err, ShouldBeNil)
+						Convey("Loading the model should work...", func() {
+							newRf := NewRandomForest(10, 3)
+							err := newRf.Load(f.Name())
+							So(err, ShouldBeNil)
+							Convey("Predictions should be the same...", func() {
+								newPredictions, err := newRf.Predict(testData)
+								So(err, ShouldBeNil)
+								So(base.InstancesAreEqual(newPredictions, oldPredictions), ShouldBeTrue)
+							})
+						})
+					})
+				})
+			})
+		})
+	})
+}
@@ -253,7 +253,7 @@ func (b *BaggedModel) SaveWithPrefix(writer *base.ClassifierSerializer, prefix s
 
 	// Save the classifiers
 	for i, c := range b.Models {
-		clsPrefix := pI(writer.Prefix(prefix, "CLASSIFIERS"), i)
+		clsPrefix := fmt.Sprintf("%s/", pI("CLASSIFIERS", i))
 		err = c.SaveWithPrefix(writer, clsPrefix)
 		if err != nil {
 			return base.FormatError(err, "Can't save classifier %d", i)
@@ -303,16 +303,12 @@ func (b *BaggedModel) LoadWithPrefix(reader *base.ClassifierDeserializer, prefix
 	if err != nil {
 		return base.DescribeError("Can't read NUM_RANDOM_FEATURES", err)
 	}
-	/*classifiersKey := reader.Prefix(prefix, "NUM_CLASSIFIERS")
-	numClassifiers, err := reader.GetU64ForKey(classifiersKey)
-	if err != nil {
-		return base.DescribeError("Can't read NUM_CLASSIFIERS", err)
-	}*/
+
 	b.RandomFeatures = int(randomFeatures)
 
 	// Reload the classifiers
 	for i, m := range b.Models {
-		clsPrefix := pI(reader.Prefix(prefix, "CLASSIFIERS"), i)
+		clsPrefix := fmt.Sprintf("%s/", pI("CLASSIFIERS", i))
 		err := m.LoadWithPrefix(reader, clsPrefix)
 		if err != nil {
 			return base.DescribeError("Can't read classifier", err)
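On the clsPrefix change in the two hunks above: the classifier sub-keys move from being nested under the caller-supplied prefix to a flat, slash-terminated "CLASSIFIERS<i>/" prefix. The helpers below are stand-ins written for this sketch (assumptions about how pI and the prefix-joining helper behave, not golearn's actual implementations), just to show the shape of the resulting keys:

package main

import "fmt"

// Assumed stand-ins for the helpers used in bagging.go.
func pI(prefix string, i int) string { return fmt.Sprintf("%s%d", prefix, i) }
func joinPrefix(parts ...string) string {
	out := ""
	for _, p := range parts {
		out += p + "/"
	}
	return out
}

func main() {
	prefix := "model"
	i := 2
	// Old scheme: classifier keys nested under the caller-supplied prefix.
	oldCls := pI(joinPrefix(prefix, "CLASSIFIERS"), i)
	// New scheme: a flat, slash-terminated prefix independent of the caller's prefix.
	newCls := fmt.Sprintf("%s/", pI("CLASSIFIERS", i))
	fmt.Println(oldCls) // model/CLASSIFIERS/2 (under these stand-in helpers)
	fmt.Println(newCls) // CLASSIFIERS2/
}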
@@ -657,7 +657,7 @@ func (t *ID3DecisionTree) String() string {
 func (t *ID3DecisionTree) GetMetadata() base.ClassifierMetadataV1 {
 	return base.ClassifierMetadataV1{
 		FormatVersion:      1,
-		ClassifierName:     "KNN",
+		ClassifierName:     "ID3",
 		ClassifierVersion:  "1.0",
 		ClassifierMetadata: nil,
 	}