mirror of
https://github.com/sjwhitworth/golearn.git
synced 2025-05-01 22:18:10 +08:00
Adding some new ways of serializing things
This commit is contained in:
parent
b7ad1fe499
commit
7b08820152
@ -268,6 +268,205 @@ func DeserializeInstances(f io.Reader) (ret *DenseInstances, err error) {
|
||||
return ret, nil
|
||||
}
|
||||
|
||||
// ClassifierMetadataV1 is what gets written into METADATA
// in a classification file format.
type ClassifierMetadataV1 struct {
	// FormatVersion should always be 1 for this structure.
	FormatVersion int `json:"format_version"`
	// ClassifierName is the classifier name (provided by the classifier).
	ClassifierName string `json:"classifier"`
	// ClassifierVersion is also provided by the classifier and is used to
	// check whether this version of GoLearn can read what has been written.
	//
	// The tag needs the `json:"..."` form (with the colon); without it
	// encoding/json ignores the tag and emits the Go field name instead.
	ClassifierVersion string `json:"classifier_version"`
	// ClassifierMetadata is a custom metadata field, provided by the classifier.
	ClassifierMetadata map[string]interface{} `json:"classifier_metadata"`
}
|
||||
|
||||
type ClassifierDeserializer struct {
|
||||
gzipReader io.Reader
|
||||
fileReader io.Reader
|
||||
tarReader *tar.Reader
|
||||
Metadata *ClassifierMetadataV1
|
||||
}
|
||||
|
||||
// ReadSerializedClassifierStub is the counterpart of CreateSerializedClassifierStub
|
||||
func ReadSerializedClassifierStub(filePath string) (*ClassifierDeserializer, error) {
|
||||
|
||||
f, err := os.Open(filePath)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("Can't open file: %s", err)
|
||||
}
|
||||
|
||||
gzr, err := gzip.NewReader(f)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("Can't decompress file: %s", err)
|
||||
}
|
||||
|
||||
tz := tar.NewReader(gzr)
|
||||
|
||||
// Check that the serialization format is right
|
||||
// Retrieve the MANIFEST and verify
|
||||
manifestBytes := getTarContent(tz, "MANIFEST")
|
||||
if !reflect.DeepEqual(manifestBytes, []byte(SerializationFormatVersion)) {
|
||||
return nil, fmt.Errorf("Unsupported MANIFEST: %s", string(manifestBytes))
|
||||
}
|
||||
|
||||
//
|
||||
// Parse METADATA
|
||||
//
|
||||
metadataBytes := getTarContent(tz, "METADATA")
|
||||
var metadata ClassifierMetadataV1
|
||||
err = json.Unmarshal(metadataBytes, &metadata)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("Error whilst reading METADATA: %s", err)
|
||||
}
|
||||
|
||||
// Check that we can understand this archive
|
||||
if metadata.FormatVersion != 1 {
|
||||
return nil, fmt.Errorf("METADATA: wrong format_version for this version of golearn")
|
||||
}
|
||||
|
||||
ret := &ClassifierDeserializer{
|
||||
f,
|
||||
gzr,
|
||||
tz,
|
||||
&metadata,
|
||||
}
|
||||
|
||||
return ret, nil
|
||||
}
|
||||
|
||||
func (c *ClassifierDeserializer) GetBytesForKey(key string) ([]byte, error) {
|
||||
return getTarContent(c.tarReader, key), nil
|
||||
}
|
||||
|
||||
// ClassifierSerializer bundles the writers used to produce a serialized
// classifier file: entries are written to tarWriter, and gzipWriter and
// fileWriter are the layers that Close finishes afterwards.
type ClassifierSerializer struct {
	gzipWriter *gzip.Writer
	fileWriter io.Writer
	tarWriter  *tar.Writer
}

// Close finalizes the Classifier serialization session.
//
// It flushes and closes the tar and gzip layers and then, if the underlying
// writer can be closed (it implements io.Closer, as an *os.File does),
// closes that as well so the handle is not leaked. The underlying writer is
// closed even when finishing the archive fails; the first error encountered
// is the one returned.
func (c *ClassifierSerializer) Close() error {
	err := c.finishArchive()

	// fileWriter is only declared as io.Writer, so probe for io.Closer.
	if closer, ok := c.fileWriter.(io.Closer); ok {
		if cerr := closer.Close(); cerr != nil && err == nil {
			err = fmt.Errorf("Could not close file: %s", cerr)
		}
	}

	return err
}

// finishArchive flushes and closes the tar layer, then the gzip layer.
func (c *ClassifierSerializer) finishArchive() error {
	if err := c.tarWriter.Flush(); err != nil {
		return fmt.Errorf("Could not flush tar: %s", err)
	}

	if err := c.tarWriter.Close(); err != nil {
		return fmt.Errorf("Could not close tar: %s", err)
	}

	if err := c.gzipWriter.Flush(); err != nil {
		return fmt.Errorf("Could not flush gz: %s", err)
	}

	if err := c.gzipWriter.Close(); err != nil {
		return fmt.Errorf("Could not close gz: %s", err)
	}

	return nil
}

// WriteBytesForKey creates a new entry in the serializer file.
//
// The entry is named key and contains exactly the bytes in b. An error is
// returned if either the tar header or the data cannot be written; the
// message names the offending key.
func (c *ClassifierSerializer) WriteBytesForKey(key string, b []byte) error {
	// Write header for key.
	hdr := &tar.Header{
		Name: key,
		Size: int64(len(b)),
	}

	if err := c.tarWriter.WriteHeader(hdr); err != nil {
		return fmt.Errorf("Could not write header for '%s': %s", key, err)
	}

	// Write data.
	if _, err := c.tarWriter.Write(b); err != nil {
		return fmt.Errorf("Could not write data for '%s': %s", key, err)
	}

	return nil
}
|
||||
|
||||
// CreateSerializedClassifierStub generates a file to serialize into
|
||||
// and writes the METADATA header.
|
||||
func CreateSerializedClassifierStub(filePath string, metadata ClassifierMetadataV1) (*ClassifierSerializer, error) {
|
||||
|
||||
// Open the filePath
|
||||
f, err := os.OpenFile(filePath, os.O_RDWR | os.O_TRUNC, 0600)
|
||||
if err != nil {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
var hdr *tar.Header
|
||||
gzWriter := gzip.NewWriter(f)
|
||||
tw := tar.NewWriter(gzWriter)
|
||||
|
||||
ret := &ClassifierSerializer{
|
||||
gzipWriter: gzWriter,
|
||||
fileWriter: f,
|
||||
tarWriter: tw,
|
||||
}
|
||||
|
||||
//
|
||||
// Write the MANIFEST entry
|
||||
//
|
||||
hdr = &tar.Header{
|
||||
Name: "MANIFEST",
|
||||
Size: int64(len(SerializationFormatVersion)),
|
||||
}
|
||||
if err := tw.WriteHeader(hdr); err != nil {
|
||||
return nil, fmt.Errorf("Could not write MANIFEST header: %s", err)
|
||||
}
|
||||
|
||||
if _, err := tw.Write([]byte(SerializationFormatVersion)); err != nil {
|
||||
return nil, fmt.Errorf("Could not write MANIFEST contents: %s", err)
|
||||
}
|
||||
|
||||
//
|
||||
// Write the METADATA entry
|
||||
//
|
||||
|
||||
// Marshal the classifier information
|
||||
cB, err := json.Marshal(metadata)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if len(cB) == 0 {
|
||||
return nil, fmt.Errorf("JSON marshal error: %s", err)
|
||||
}
|
||||
|
||||
// Write the information into the file
|
||||
hdr = &tar.Header{
|
||||
Name: "METADATA",
|
||||
Size: int64(len(cB)),
|
||||
}
|
||||
if err := tw.WriteHeader(hdr); err != nil {
|
||||
return nil, fmt.Errorf("Could not write METADATA object %s", err)
|
||||
}
|
||||
|
||||
if _, err := tw.Write(cB); err != nil {
|
||||
return nil, fmt.Errorf("Could not write METDATA contents: %s", err)
|
||||
}
|
||||
|
||||
return ret, nil
|
||||
|
||||
}
|
||||
|
||||
func SerializeInstances(inst FixedDataGrid, f io.Writer) error {
|
||||
var hdr *tar.Header
|
||||
|
||||
|
@ -1,107 +1,142 @@
|
||||
package base
|
||||
|
||||
// import (
|
||||
// "archive/tar"
|
||||
// "compress/gzip"
|
||||
// "fmt"
|
||||
// . "github.com/smartystreets/goconvey/convey"
|
||||
// "io"
|
||||
// "io/ioutil"
|
||||
// "testing"
|
||||
// )
|
||||
import (
|
||||
"archive/tar"
|
||||
"compress/gzip"
|
||||
"fmt"
|
||||
. "github.com/smartystreets/goconvey/convey"
|
||||
"io"
|
||||
"io/ioutil"
|
||||
"testing"
|
||||
)
|
||||
|
||||
// func TestSerializeToCSV(t *testing.T) {
|
||||
// Convey("Reading some instances...", t, func() {
|
||||
// inst, err := ParseCSVToInstances("../examples/datasets/iris_headers.csv", true)
|
||||
// So(err, ShouldBeNil)
|
||||
func TestSerializeToCSV(t *testing.T) {
|
||||
Convey("Reading some instances...", t, func() {
|
||||
inst, err := ParseCSVToInstances("../examples/datasets/iris_headers.csv", true)
|
||||
So(err, ShouldBeNil)
|
||||
|
||||
// Convey("Saving the instances to CSV...", func() {
|
||||
// f, err := ioutil.TempFile("", "instTmpCSV")
|
||||
// So(err, ShouldBeNil)
|
||||
// err = SerializeInstancesToCSV(inst, f.Name())
|
||||
// So(err, ShouldBeNil)
|
||||
// Convey("What's written out should match what's read in", func() {
|
||||
// dinst, err := ParseCSVToInstances(f.Name(), true)
|
||||
// So(err, ShouldBeNil)
|
||||
// So(inst.String(), ShouldEqual, dinst.String())
|
||||
// })
|
||||
// })
|
||||
// })
|
||||
// }
|
||||
Convey("Saving the instances to CSV...", func() {
|
||||
f, err := ioutil.TempFile("", "instTmpCSV")
|
||||
So(err, ShouldBeNil)
|
||||
err = SerializeInstancesToCSV(inst, f.Name())
|
||||
So(err, ShouldBeNil)
|
||||
Convey("What's written out should match what's read in", func() {
|
||||
dinst, err := ParseCSVToInstances(f.Name(), true)
|
||||
So(err, ShouldBeNil)
|
||||
So(inst.String(), ShouldEqual, dinst.String())
|
||||
})
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
// func TestSerializeToFile(t *testing.T) {
|
||||
// Convey("Reading some instances...", t, func() {
|
||||
// inst, err := ParseCSVToInstances("../examples/datasets/iris_headers.csv", true)
|
||||
// So(err, ShouldBeNil)
|
||||
|
||||
// Convey("Dumping to file...", func() {
|
||||
// f, err := ioutil.TempFile("", "instTmpF")
|
||||
// So(err, ShouldBeNil)
|
||||
// err = SerializeInstances(inst, f)
|
||||
// So(err, ShouldBeNil)
|
||||
// f.Seek(0, 0)
|
||||
// Convey("Contents of the archive should be right...", func() {
|
||||
// gzr, err := gzip.NewReader(f)
|
||||
// So(err, ShouldBeNil)
|
||||
// tr := tar.NewReader(gzr)
|
||||
// classAttrsPresent := false
|
||||
// // manifestPresent := false
|
||||
// regularAttrsPresent := false
|
||||
// dataPresent := false
|
||||
// dimsPresent := false
|
||||
// readBytes := make([]byte, len([]byte(SerializationFormatVersion)))
|
||||
// for {
|
||||
// hdr, err := tr.Next()
|
||||
// if err == io.EOF {
|
||||
// break
|
||||
// }
|
||||
// So(err, ShouldBeNil)
|
||||
// switch hdr.Name {
|
||||
// case "MANIFEST":
|
||||
// tr.Read(readBytes)
|
||||
// manifestPresent = true
|
||||
// break
|
||||
// case "CATTRS":
|
||||
// classAttrsPresent = true
|
||||
// break
|
||||
// case "ATTRS":
|
||||
// regularAttrsPresent = true
|
||||
// break
|
||||
// case "DATA":
|
||||
// dataPresent = true
|
||||
// break
|
||||
// case "DIMS":
|
||||
// dimsPresent = true
|
||||
// break
|
||||
// default:
|
||||
// fmt.Printf("Unknown file: %s\n", hdr.Name)
|
||||
// }
|
||||
// }
|
||||
// Convey("MANIFEST should be present", func() {
|
||||
// // So(manifestPresent, ShouldBeTrue)
|
||||
// // Convey("MANIFEST should be right...", func() {
|
||||
// // So(readBytes, ShouldResemble, []byte(SerializationFormatVersion))
|
||||
// // })
|
||||
// })
|
||||
// Convey("DATA should be present", func() {
|
||||
// So(dataPresent, ShouldBeTrue)
|
||||
// })
|
||||
// Convey("ATTRS should be present", func() {
|
||||
// So(regularAttrsPresent, ShouldBeTrue)
|
||||
// })
|
||||
// Convey("CATTRS should be present", func() {
|
||||
// So(classAttrsPresent, ShouldBeTrue)
|
||||
// })
|
||||
// Convey("DIMS should be present", func() {
|
||||
// So(dimsPresent, ShouldBeTrue)
|
||||
// })
|
||||
// })
|
||||
// Convey("Should be able to reconstruct...", func() {
|
||||
// f.Seek(0, 0)
|
||||
// dinst, err := DeserializeInstances(f)
|
||||
// So(err, ShouldBeNil)
|
||||
// So(InstancesAreEqual(inst, dinst), ShouldBeTrue)
|
||||
// })
|
||||
// })
|
||||
// })
|
||||
// }
|
||||
func TestCreateAndReadClassifierStub(t *testing.T) {
|
||||
Convey("Creating a classifier stub...", t, func() {
|
||||
|
||||
exampleClassifierMetadata := make(map[string]interface{})
|
||||
exampleClassifierMetadata["num_trees"] = 4
|
||||
|
||||
metadata := ClassifierMetadataV1 {
|
||||
FormatVersion: 1,
|
||||
ClassifierName: "test",
|
||||
ClassifierVersion: "1",
|
||||
ClassifierMetadata: exampleClassifierMetadata,
|
||||
}
|
||||
|
||||
Convey("Saving the classifier...", func() {
|
||||
f, err := ioutil.TempFile("", "classTmpF")
|
||||
So(err, ShouldBeNil)
|
||||
serializer, err := CreateSerializedClassifierStub(f.Name(), metadata)
|
||||
So(err, ShouldBeNil)
|
||||
err = serializer.Close()
|
||||
So(err, ShouldBeNil)
|
||||
Convey("Should be able to read the information back...", func() {
|
||||
reader, err := ReadSerializedClassifierStub(f.Name())
|
||||
So(err, ShouldBeNil)
|
||||
So(reader, ShouldNotBeNil)
|
||||
So(reader.Metadata.FormatVersion, ShouldEqual, 1)
|
||||
So(reader.Metadata.ClassifierName, ShouldEqual, "test")
|
||||
So(reader.Metadata.ClassifierVersion, ShouldEqual, "1")
|
||||
So(reader.Metadata.ClassifierMetadata["num_trees"], ShouldEqual, 4)
|
||||
})
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
func TestSerializeToFile(t *testing.T) {
|
||||
Convey("Reading some instances...", t, func() {
|
||||
inst, err := ParseCSVToInstances("../examples/datasets/iris_headers.csv", true)
|
||||
So(err, ShouldBeNil)
|
||||
|
||||
Convey("Dumping to file...", func() {
|
||||
f, err := ioutil.TempFile("", "instTmpF")
|
||||
So(err, ShouldBeNil)
|
||||
err = SerializeInstances(inst, f)
|
||||
So(err, ShouldBeNil)
|
||||
f.Seek(0, 0)
|
||||
Convey("Contents of the archive should be right...", func() {
|
||||
gzr, err := gzip.NewReader(f)
|
||||
So(err, ShouldBeNil)
|
||||
tr := tar.NewReader(gzr)
|
||||
classAttrsPresent := false
|
||||
manifestPresent := false
|
||||
regularAttrsPresent := false
|
||||
dataPresent := false
|
||||
dimsPresent := false
|
||||
readBytes := make([]byte, len([]byte(SerializationFormatVersion)))
|
||||
for {
|
||||
hdr, err := tr.Next()
|
||||
if err == io.EOF {
|
||||
break
|
||||
}
|
||||
So(err, ShouldBeNil)
|
||||
switch hdr.Name {
|
||||
case "MANIFEST":
|
||||
tr.Read(readBytes)
|
||||
manifestPresent = true
|
||||
break
|
||||
case "CATTRS":
|
||||
classAttrsPresent = true
|
||||
break
|
||||
case "ATTRS":
|
||||
regularAttrsPresent = true
|
||||
break
|
||||
case "DATA":
|
||||
dataPresent = true
|
||||
break
|
||||
case "DIMS":
|
||||
dimsPresent = true
|
||||
break
|
||||
default:
|
||||
fmt.Printf("Unknown file: %s\n", hdr.Name)
|
||||
}
|
||||
}
|
||||
Convey("MANIFEST should be present", func() {
|
||||
So(manifestPresent, ShouldBeTrue)
|
||||
Convey("MANIFEST should be right...", func() {
|
||||
So(readBytes, ShouldResemble, []byte(SerializationFormatVersion))
|
||||
})
|
||||
|
||||
Convey("DATA should be present", func() {
|
||||
So(dataPresent, ShouldBeTrue)
|
||||
})
|
||||
Convey("ATTRS should be present", func() {
|
||||
So(regularAttrsPresent, ShouldBeTrue)
|
||||
})
|
||||
Convey("CATTRS should be present", func() {
|
||||
So(classAttrsPresent, ShouldBeTrue)
|
||||
})
|
||||
Convey("DIMS should be present", func() {
|
||||
So(dimsPresent, ShouldBeTrue)
|
||||
})
|
||||
})
|
||||
Convey("Should be able to reconstruct...", func() {
|
||||
f.Seek(0, 0)
|
||||
dinst, err := DeserializeInstances(f)
|
||||
So(err, ShouldBeNil)
|
||||
So(InstancesAreEqual(inst, dinst), ShouldBeTrue)
|
||||
})
|
||||
})
|
||||
})
|
||||
})
|
||||
}
|
||||
|
Loading…
x
Reference in New Issue
Block a user