Mirror of https://github.com/sjwhitworth/golearn.git (synced 2025-04-25 13:48:49 +08:00)

Commit: e2279995c1 ("Fixing all tests")
Parent: ce78cd0406
@@ -3,7 +3,6 @@ package base
import (
	"archive/tar"
	"compress/gzip"
	"encoding/csv"
	"encoding/json"
	"fmt"
	"io"
@@ -285,14 +284,6 @@ func (c *ClassifierSerializer) Close() error {
		return fmt.Errorf("Could not close file writer: %s", err)
	}

	log.Printf("Closed ClassifierSerializerStub at %s", c.f.Name())

	//if err := c.f.Close(); err != nil {
	//	return fmt.Errorf("Could not close file: %s", err)
	//}

	os.Rename(c.f.Name(), c.filePath)

	return nil
}
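For context, Close() above writes to a temporary file and then renames it into place. A minimal standalone sketch of that pattern, assuming nothing from golearn (writeTempThenRename is an illustrative name, not part of the library):

package main

import (
	"fmt"
	"io/ioutil"
	"os"
)

func writeTempThenRename(finalPath string, payload []byte) error {
	tmp, err := ioutil.TempFile("", "classifier")
	if err != nil {
		return fmt.Errorf("Could not create temp file: %s", err)
	}
	if _, err := tmp.Write(payload); err != nil {
		tmp.Close()
		return fmt.Errorf("Could not write temp file: %s", err)
	}
	if err := tmp.Close(); err != nil {
		return fmt.Errorf("Could not close temp file: %s", err)
	}
	// The rename is atomic on a single filesystem, so finalPath is never
	// observed half-written. Unlike the diff above, check the error.
	return os.Rename(tmp.Name(), finalPath)
}

func main() {
	if err := writeTempThenRename("model.bin", []byte("payload")); err != nil {
		panic(err)
	}
}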
@@ -433,115 +424,3 @@ func CreateSerializedClassifierStub(filePath string, metadata ClassifierMetadata
	return &ret, nil

}

func SerializeInstances(inst FixedDataGrid, f io.Writer) error {
	var hdr *tar.Header

	gzWriter := gzip.NewWriter(f)
	tw := tar.NewWriter(gzWriter)

	// Write the MANIFEST entry
	hdr = &tar.Header{
		Name: "MANIFEST",
		Size: int64(len(SerializationFormatVersion)),
	}
	if err := tw.WriteHeader(hdr); err != nil {
		return fmt.Errorf("Could not write MANIFEST header: %s", err)
	}

	if _, err := tw.Write([]byte(SerializationFormatVersion)); err != nil {
		return fmt.Errorf("Could not write MANIFEST contents: %s", err)
	}

	// Now write the dimensions of the dataset
	attrCount, rowCount := inst.Size()
	hdr = &tar.Header{
		Name: "DIMS",
		Size: 16,
	}
	if err := tw.WriteHeader(hdr); err != nil {
		return fmt.Errorf("Could not write DIMS header: %s", err)
	}

	if _, err := tw.Write(PackU64ToBytes(uint64(attrCount))); err != nil {
		return fmt.Errorf("Could not write DIMS (attrCount): %s", err)
	}
	if _, err := tw.Write(PackU64ToBytes(uint64(rowCount))); err != nil {
		return fmt.Errorf("Could not write DIMS (rowCount): %s", err)
	}

	// Write the ATTRIBUTES files
	classAttrs := inst.AllClassAttributes()
	normalAttrs := NonClassAttributes(inst)
	if err := writeAttributesToFilePart(classAttrs, tw, "CATTRS"); err != nil {
		return fmt.Errorf("Could not write CATTRS: %s", err)
	}
	if err := writeAttributesToFilePart(normalAttrs, tw, "ATTRS"); err != nil {
		return fmt.Errorf("Could not write ATTRS: %s", err)
	}

	// Data must be written out in the same order as the Attributes
	allAttrs := make([]Attribute, attrCount)
	normCount := copy(allAttrs, normalAttrs)
	for i, v := range classAttrs {
		allAttrs[normCount+i] = v
	}

	allSpecs := ResolveAttributes(inst, allAttrs)

	// First, estimate the amount of data we'll need...
	dataLength := int64(0)
	inst.MapOverRows(allSpecs, func(val [][]byte, row int) (bool, error) {
		for _, v := range val {
			dataLength += int64(len(v))
		}
		return true, nil
	})

	// Then write the header
	hdr = &tar.Header{
		Name: "DATA",
		Size: dataLength,
	}
	if err := tw.WriteHeader(hdr); err != nil {
		return fmt.Errorf("Could not write DATA: %s", err)
	}

	// Then write the actual data
	writtenLength := int64(0)
	if err := inst.MapOverRows(allSpecs, func(val [][]byte, row int) (bool, error) {
		for _, v := range val {
			wl, err := tw.Write(v)
			writtenLength += int64(wl)
			if err != nil {
				return false, err
			}
		}
		return true, nil
	}); err != nil {
		return err
	}

	if writtenLength != dataLength {
		return fmt.Errorf("Could not write DATA: changed size from %v to %v", dataLength, writtenLength)
	}

	// Finally, close and flush the various levels
	if err := tw.Flush(); err != nil {
		return fmt.Errorf("Could not flush tar: %s", err)
	}

	if err := tw.Close(); err != nil {
		return fmt.Errorf("Could not close tar: %s", err)
	}

	if err := gzWriter.Flush(); err != nil {
		return fmt.Errorf("Could not flush gz: %s", err)
	}

	if err := gzWriter.Close(); err != nil {
		return fmt.Errorf("Could not close gz: %s", err)
	}

	return nil
}
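The archive the removed SerializeInstances produced is a gzip-compressed tar holding MANIFEST, DIMS, CATTRS, ATTRS and DATA entries. A hedged sketch of reading the DIMS entry back out (readDims and the little-endian byte order are illustrative assumptions; PackU64ToBytes defines the real layout in golearn):

package main

import (
	"archive/tar"
	"compress/gzip"
	"encoding/binary"
	"fmt"
	"io"
	"io/ioutil"
	"os"
)

func readDims(r io.Reader) (attrCount, rowCount uint64, err error) {
	gz, err := gzip.NewReader(r)
	if err != nil {
		return 0, 0, err
	}
	tr := tar.NewReader(gz)
	for {
		hdr, err := tr.Next()
		if err == io.EOF {
			return 0, 0, fmt.Errorf("no DIMS entry found")
		}
		if err != nil {
			return 0, 0, err
		}
		if hdr.Name != "DIMS" {
			continue
		}
		// DIMS is two packed uint64s: 16 bytes in total
		buf, err := ioutil.ReadAll(tr)
		if err != nil || len(buf) != 16 {
			return 0, 0, fmt.Errorf("bad DIMS entry")
		}
		// Assuming little-endian packing here; check PackU64ToBytes
		// for the byte order golearn actually uses.
		return binary.LittleEndian.Uint64(buf[:8]), binary.LittleEndian.Uint64(buf[8:]), nil
	}
}

func main() {
	f, err := os.Open("instances.tar.gz") // path is illustrative
	if err != nil {
		panic(err)
	}
	defer f.Close()
	a, r, err := readDims(f)
	if err != nil {
		panic(err)
	}
	fmt.Printf("attributes=%d rows=%d\n", a, r)
}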
@@ -356,7 +356,6 @@ func (KNN *KNNClassifier) Save(filePath string) error {

// SaveWithPrefix outputs KNN as part of another file.
func (KNN *KNNClassifier) SaveWithPrefix(writer *base.ClassifierSerializer, prefix string) error {
	fmt.Printf("writer: %v", writer)
	err := writer.WriteInstancesForKey(writer.Prefix(prefix, "TrainingInstances"), KNN.TrainingData, true)
	if err != nil {
		return err
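SaveWithPrefix nests the model's entries under a caller-supplied prefix so several classifiers can share one serialized archive. A toy sketch of that key composition (prefixJoin and the "/" separator are illustrative, not golearn's Prefix implementation):

package main

import "fmt"

// prefixJoin mimics how writer.Prefix(prefix, key) is used above to nest
// one model's entries inside a larger archive.
func prefixJoin(prefix, key string) string {
	if prefix == "" {
		return key
	}
	return prefix + "/" + key
}

func main() {
	fmt.Println(prefixJoin("KNN", "TrainingInstances")) // KNN/TrainingInstances
}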
@@ -67,7 +67,6 @@ func (d *DecisionTreeRule) unmarshalJSON(data []byte) error {
	if err != nil {
		panic(err)
	}
	fmt.Printf("%s\n", splitBytes)
	if string(splitBytes) != "\"unknown\"" {
		d.SplitAttr, err = base.DeserializeAttribute(splitBytes)
		if err != nil {
@@ -107,7 +106,7 @@ type DecisionTreeNode struct {
	Children  map[string]*DecisionTreeNode `json:"children"`
	ClassDist map[string]int               `json:"class_dist"`
	Class     string                       `json:"class_string"`
	ClassAttr base.Attribute               `json:"-"`
	ClassAttr base.Attribute               `json:"-"`
	SplitRule *DecisionTreeRule            `json:"decision_tree_rule"`
}
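The json:"-" tag keeps ClassAttr out of the marshalled JSON, while the other fields are emitted under their renamed keys. A standalone illustration of that encoding/json behaviour (Node is a cut-down stand-in for DecisionTreeNode):

package main

import (
	"encoding/json"
	"fmt"
)

// Node shows the effect of the json tags used above.
type Node struct {
	Class     string         `json:"class_string"`
	ClassDist map[string]int `json:"class_dist"`
	Hidden    string         `json:"-"` // never marshalled, like ClassAttr
}

func main() {
	n := Node{Class: "iris-setosa", ClassDist: map[string]int{"iris-setosa": 3}, Hidden: "skipped"}
	out, _ := json.Marshal(n)
	fmt.Println(string(out)) // {"class_string":"iris-setosa","class_dist":{"iris-setosa":3}}
}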
@@ -7,7 +7,6 @@ import (
	. "github.com/smartystreets/goconvey/convey"
	"io/ioutil"
	"math/rand"
	"fmt"
	"os"
	"testing"
)

@@ -145,8 +144,6 @@ func verifyTreeClassification(trainData, testData base.FixedDataGrid) {
		f.Close()
	}()

	fmt.Printf("%s", root)

	err = root.Save(f.Name())
	So(err, ShouldBeNil)

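The test above saves the tree to a temporary file and tidies up in a deferred closure. A minimal sketch of that pattern outside GoConvey (saveSomething stands in for root.Save; everything else is standard library):

package main

import (
	"fmt"
	"io/ioutil"
	"os"
)

// saveSomething stands in for root.Save in the test above.
func saveSomething(path string) error {
	return ioutil.WriteFile(path, []byte("model"), 0644)
}

func main() {
	f, err := ioutil.TempFile(os.TempDir(), "tree")
	if err != nil {
		panic(err)
	}
	defer func() {
		// Mirror the test's cleanup: remove the file, then close the handle.
		os.Remove(f.Name())
		f.Close()
	}()
	if err := saveSomething(f.Name()); err != nil {
		panic(err)
	}
	fmt.Println("saved to", f.Name())
}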