1
0
mirror of https://github.com/sjwhitworth/golearn.git synced 2025-04-25 13:48:49 +08:00

Fixing all tests

This commit is contained in:
Richard Townsend 2018-01-28 16:22:33 +00:00
parent ce78cd0406
commit e2279995c1
4 changed files with 1 additions and 127 deletions

View File

@ -3,7 +3,6 @@ package base
import (
"archive/tar"
"compress/gzip"
"encoding/csv"
"encoding/json"
"fmt"
"io"
@ -285,14 +284,6 @@ func (c *ClassifierSerializer) Close() error {
return fmt.Errorf("Could not close file writer: %s", err)
}
log.Printf("Closed ClassifierSerializerStub at %s", c.f.Name())
//if err := c.f.Close(); err != nil {
// return fmt.Errorf("Could not close file: %s", err)
//}
os.Rename(c.f.Name(), c.filePath)
return nil
}
@ -433,115 +424,3 @@ func CreateSerializedClassifierStub(filePath string, metadata ClassifierMetadata
return &ret, nil
}
func SerializeInstances(inst FixedDataGrid, f io.Writer) error {
var hdr *tar.Header
gzWriter := gzip.NewWriter(f)
tw := tar.NewWriter(gzWriter)
// Write the MANIFEST entry
hdr = &tar.Header{
Name: "MANIFEST",
Size: int64(len(SerializationFormatVersion)),
}
if err := tw.WriteHeader(hdr); err != nil {
return fmt.Errorf("Could not write MANIFEST header: %s", err)
}
if _, err := tw.Write([]byte(SerializationFormatVersion)); err != nil {
return fmt.Errorf("Could not write MANIFEST contents: %s", err)
}
// Now write the dimensions of the dataset
attrCount, rowCount := inst.Size()
hdr = &tar.Header{
Name: "DIMS",
Size: 16,
}
if err := tw.WriteHeader(hdr); err != nil {
return fmt.Errorf("Could not write DIMS header: %s", err)
}
if _, err := tw.Write(PackU64ToBytes(uint64(attrCount))); err != nil {
return fmt.Errorf("Could not write DIMS (attrCount): %s", err)
}
if _, err := tw.Write(PackU64ToBytes(uint64(rowCount))); err != nil {
return fmt.Errorf("Could not write DIMS (rowCount): %s", err)
}
// Write the ATTRIBUTES files
classAttrs := inst.AllClassAttributes()
normalAttrs := NonClassAttributes(inst)
if err := writeAttributesToFilePart(classAttrs, tw, "CATTRS"); err != nil {
return fmt.Errorf("Could not write CATTRS: %s", err)
}
if err := writeAttributesToFilePart(normalAttrs, tw, "ATTRS"); err != nil {
return fmt.Errorf("Could not write ATTRS: %s", err)
}
// Data must be written out in the same order as the Attributes
allAttrs := make([]Attribute, attrCount)
normCount := copy(allAttrs, normalAttrs)
for i, v := range classAttrs {
allAttrs[normCount+i] = v
}
allSpecs := ResolveAttributes(inst, allAttrs)
// First, estimate the amount of data we'll need...
dataLength := int64(0)
inst.MapOverRows(allSpecs, func(val [][]byte, row int) (bool, error) {
for _, v := range val {
dataLength += int64(len(v))
}
return true, nil
})
// Then write the header
hdr = &tar.Header{
Name: "DATA",
Size: dataLength,
}
if err := tw.WriteHeader(hdr); err != nil {
return fmt.Errorf("Could not write DATA: %s", err)
}
// Then write the actual data
writtenLength := int64(0)
if err := inst.MapOverRows(allSpecs, func(val [][]byte, row int) (bool, error) {
for _, v := range val {
wl, err := tw.Write(v)
writtenLength += int64(wl)
if err != nil {
return false, err
}
}
return true, nil
}); err != nil {
return err
}
if writtenLength != dataLength {
return fmt.Errorf("Could not write DATA: changed size from %v to %v", dataLength, writtenLength)
}
// Finally, close and flush the various levels
if err := tw.Flush(); err != nil {
return fmt.Errorf("Could not flush tar: %s", err)
}
if err := tw.Close(); err != nil {
return fmt.Errorf("Could not close tar: %s", err)
}
if err := gzWriter.Flush(); err != nil {
return fmt.Errorf("Could not flush gz: %s", err)
}
if err := gzWriter.Close(); err != nil {
return fmt.Errorf("Could not close gz: %s", err)
}
return nil
}

View File

@ -356,7 +356,6 @@ func (KNN *KNNClassifier) Save(filePath string) error {
// SaveWithPrefix outputs KNN as part of another file.
func (KNN *KNNClassifier) SaveWithPrefix(writer *base.ClassifierSerializer, prefix string) error {
fmt.Printf("writer: %v", writer)
err := writer.WriteInstancesForKey(writer.Prefix(prefix, "TrainingInstances"), KNN.TrainingData, true)
if err != nil {
return err

View File

@ -67,7 +67,6 @@ func (d *DecisionTreeRule) unmarshalJSON(data []byte) error {
if err != nil {
panic(err)
}
fmt.Printf("%s\n", splitBytes)
if string(splitBytes) != "\"unknown\"" {
d.SplitAttr, err = base.DeserializeAttribute(splitBytes)
if err != nil {
@ -107,7 +106,7 @@ type DecisionTreeNode struct {
Children map[string]*DecisionTreeNode `json:"children"`
ClassDist map[string]int `json:"class_dist"`
Class string `json:"class_string"`
ClassAttr base.Attribute `json:"-"`
ClassAttr base.Attribute `json:"-"`
SplitRule *DecisionTreeRule `json:"decision_tree_rule"`
}

View File

@ -7,7 +7,6 @@ import (
. "github.com/smartystreets/goconvey/convey"
"io/ioutil"
"math/rand"
"fmt"
"os"
"testing"
)
@ -145,8 +144,6 @@ func verifyTreeClassification(trainData, testData base.FixedDataGrid) {
f.Close()
}()
fmt.Printf("%s", root)
err = root.Save(f.Name())
So(err, ShouldBeNil)