mirror of
https://github.com/sjwhitworth/golearn.git
synced 2025-05-03 22:17:14 +08:00
Adding some new ways of serializing things
This commit is contained in:
parent
b7ad1fe499
commit
7b08820152
@ -268,6 +268,205 @@ func DeserializeInstances(f io.Reader) (ret *DenseInstances, err error) {
|
|||||||
return ret, nil
|
return ret, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// ClassifierMetadataV1 is what gets written into METADATA
|
||||||
|
// in a classification file format.
|
||||||
|
type ClassifierMetadataV1 struct {
|
||||||
|
// FormatVersion should always be 1 for this structure
|
||||||
|
FormatVersion int `json:"format_version"`
|
||||||
|
// Uses the classifier name (provided by the classifier)
|
||||||
|
ClassifierName string `json:"classifier"`
|
||||||
|
// ClassifierVersion is also provided by the classifier
|
||||||
|
// and checks whether this version of GoLearn can read what's
|
||||||
|
// be written.
|
||||||
|
ClassifierVersion string `json"classifier_version"`
|
||||||
|
// This is a custom metadata field, provided by the classifier
|
||||||
|
ClassifierMetadata map[string]interface{} `json:"classifier_metadata"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type ClassifierDeserializer struct {
|
||||||
|
gzipReader io.Reader
|
||||||
|
fileReader io.Reader
|
||||||
|
tarReader *tar.Reader
|
||||||
|
Metadata *ClassifierMetadataV1
|
||||||
|
}
|
||||||
|
|
||||||
|
// ReadSerializedClassifierStub is the counterpart of CreateSerializedClassifierStub
|
||||||
|
func ReadSerializedClassifierStub(filePath string) (*ClassifierDeserializer, error) {
|
||||||
|
|
||||||
|
f, err := os.Open(filePath)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("Can't open file: %s", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
gzr, err := gzip.NewReader(f)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("Can't decompress file: %s", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
tz := tar.NewReader(gzr)
|
||||||
|
|
||||||
|
// Check that the serialization format is right
|
||||||
|
// Retrieve the MANIFEST and verify
|
||||||
|
manifestBytes := getTarContent(tz, "MANIFEST")
|
||||||
|
if !reflect.DeepEqual(manifestBytes, []byte(SerializationFormatVersion)) {
|
||||||
|
return nil, fmt.Errorf("Unsupported MANIFEST: %s", string(manifestBytes))
|
||||||
|
}
|
||||||
|
|
||||||
|
//
|
||||||
|
// Parse METADATA
|
||||||
|
//
|
||||||
|
metadataBytes := getTarContent(tz, "METADATA")
|
||||||
|
var metadata ClassifierMetadataV1
|
||||||
|
err = json.Unmarshal(metadataBytes, &metadata)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("Error whilst reading METADATA: %s", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check that we can understand this archive
|
||||||
|
if metadata.FormatVersion != 1 {
|
||||||
|
return nil, fmt.Errorf("METADATA: wrong format_version for this version of golearn")
|
||||||
|
}
|
||||||
|
|
||||||
|
ret := &ClassifierDeserializer{
|
||||||
|
f,
|
||||||
|
gzr,
|
||||||
|
tz,
|
||||||
|
&metadata,
|
||||||
|
}
|
||||||
|
|
||||||
|
return ret, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *ClassifierDeserializer) GetBytesForKey(key string) ([]byte, error) {
|
||||||
|
return getTarContent(c.tarReader, key), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
type ClassifierSerializer struct {
|
||||||
|
gzipWriter *gzip.Writer
|
||||||
|
fileWriter io.Writer
|
||||||
|
tarWriter *tar.Writer
|
||||||
|
}
|
||||||
|
|
||||||
|
// Close finalizes the Classifier serialization session
|
||||||
|
func (c *ClassifierSerializer) Close() error {
|
||||||
|
|
||||||
|
// Finally, close and flush the various levels
|
||||||
|
if err := c.tarWriter.Flush(); err != nil {
|
||||||
|
return fmt.Errorf("Could not flush tar: %s", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := c.tarWriter.Close(); err != nil {
|
||||||
|
return fmt.Errorf("Could not close tar: %s", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := c.gzipWriter.Flush(); err != nil {
|
||||||
|
return fmt.Errorf("Could not flush gz: %s", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := c.gzipWriter.Close(); err != nil {
|
||||||
|
return fmt.Errorf("Could not close gz: %s", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
//if err := c.fileWriter.Flush(); err != nil {
|
||||||
|
// return fmt.Errorf("Could not flush file: %s", err)
|
||||||
|
//}
|
||||||
|
|
||||||
|
//if err := c.fileWriter.Flush(); err != nil {
|
||||||
|
// return fmt.Errorf("Could not close file: %s", err)
|
||||||
|
//}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// WriteBytesForKey creates a new entry in the serializer file
|
||||||
|
func (c *ClassifierSerializer) WriteBytesForKey(key string, b []byte) error {
|
||||||
|
|
||||||
|
//
|
||||||
|
// Write header for key
|
||||||
|
//
|
||||||
|
hdr := &tar.Header{
|
||||||
|
Name: key,
|
||||||
|
Size: int64(len(b)),
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := c.tarWriter.WriteHeader(hdr); err != nil {
|
||||||
|
return fmt.Errorf("Could not write header for '%s': %s", err)
|
||||||
|
}
|
||||||
|
//
|
||||||
|
// Write data
|
||||||
|
//
|
||||||
|
if _, err := c.tarWriter.Write(b); err != nil {
|
||||||
|
return fmt.Errorf("Could not write data for '%s': %s", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// CreateSerializedClassifierStub generates a file to serialize into
|
||||||
|
// and writes the METADATA header.
|
||||||
|
func CreateSerializedClassifierStub(filePath string, metadata ClassifierMetadataV1) (*ClassifierSerializer, error) {
|
||||||
|
|
||||||
|
// Open the filePath
|
||||||
|
f, err := os.OpenFile(filePath, os.O_RDWR | os.O_TRUNC, 0600)
|
||||||
|
if err != nil {
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
var hdr *tar.Header
|
||||||
|
gzWriter := gzip.NewWriter(f)
|
||||||
|
tw := tar.NewWriter(gzWriter)
|
||||||
|
|
||||||
|
ret := &ClassifierSerializer{
|
||||||
|
gzipWriter: gzWriter,
|
||||||
|
fileWriter: f,
|
||||||
|
tarWriter: tw,
|
||||||
|
}
|
||||||
|
|
||||||
|
//
|
||||||
|
// Write the MANIFEST entry
|
||||||
|
//
|
||||||
|
hdr = &tar.Header{
|
||||||
|
Name: "MANIFEST",
|
||||||
|
Size: int64(len(SerializationFormatVersion)),
|
||||||
|
}
|
||||||
|
if err := tw.WriteHeader(hdr); err != nil {
|
||||||
|
return nil, fmt.Errorf("Could not write MANIFEST header: %s", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if _, err := tw.Write([]byte(SerializationFormatVersion)); err != nil {
|
||||||
|
return nil, fmt.Errorf("Could not write MANIFEST contents: %s", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
//
|
||||||
|
// Write the METADATA entry
|
||||||
|
//
|
||||||
|
|
||||||
|
// Marshal the classifier information
|
||||||
|
cB, err := json.Marshal(metadata)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
if len(cB) == 0 {
|
||||||
|
return nil, fmt.Errorf("JSON marshal error: %s", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Write the information into the file
|
||||||
|
hdr = &tar.Header{
|
||||||
|
Name: "METADATA",
|
||||||
|
Size: int64(len(cB)),
|
||||||
|
}
|
||||||
|
if err := tw.WriteHeader(hdr); err != nil {
|
||||||
|
return nil, fmt.Errorf("Could not write METADATA object %s", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if _, err := tw.Write(cB); err != nil {
|
||||||
|
return nil, fmt.Errorf("Could not write METDATA contents: %s", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return ret, nil
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
func SerializeInstances(inst FixedDataGrid, f io.Writer) error {
|
func SerializeInstances(inst FixedDataGrid, f io.Writer) error {
|
||||||
var hdr *tar.Header
|
var hdr *tar.Header
|
||||||
|
|
||||||
|
@ -1,107 +1,142 @@
|
|||||||
package base
|
package base
|
||||||
|
|
||||||
// import (
|
import (
|
||||||
// "archive/tar"
|
"archive/tar"
|
||||||
// "compress/gzip"
|
"compress/gzip"
|
||||||
// "fmt"
|
"fmt"
|
||||||
// . "github.com/smartystreets/goconvey/convey"
|
. "github.com/smartystreets/goconvey/convey"
|
||||||
// "io"
|
"io"
|
||||||
// "io/ioutil"
|
"io/ioutil"
|
||||||
// "testing"
|
"testing"
|
||||||
// )
|
)
|
||||||
|
|
||||||
// func TestSerializeToCSV(t *testing.T) {
|
func TestSerializeToCSV(t *testing.T) {
|
||||||
// Convey("Reading some instances...", t, func() {
|
Convey("Reading some instances...", t, func() {
|
||||||
// inst, err := ParseCSVToInstances("../examples/datasets/iris_headers.csv", true)
|
inst, err := ParseCSVToInstances("../examples/datasets/iris_headers.csv", true)
|
||||||
// So(err, ShouldBeNil)
|
So(err, ShouldBeNil)
|
||||||
|
|
||||||
// Convey("Saving the instances to CSV...", func() {
|
Convey("Saving the instances to CSV...", func() {
|
||||||
// f, err := ioutil.TempFile("", "instTmpCSV")
|
f, err := ioutil.TempFile("", "instTmpCSV")
|
||||||
// So(err, ShouldBeNil)
|
So(err, ShouldBeNil)
|
||||||
// err = SerializeInstancesToCSV(inst, f.Name())
|
err = SerializeInstancesToCSV(inst, f.Name())
|
||||||
// So(err, ShouldBeNil)
|
So(err, ShouldBeNil)
|
||||||
// Convey("What's written out should match what's read in", func() {
|
Convey("What's written out should match what's read in", func() {
|
||||||
// dinst, err := ParseCSVToInstances(f.Name(), true)
|
dinst, err := ParseCSVToInstances(f.Name(), true)
|
||||||
// So(err, ShouldBeNil)
|
So(err, ShouldBeNil)
|
||||||
// So(inst.String(), ShouldEqual, dinst.String())
|
So(inst.String(), ShouldEqual, dinst.String())
|
||||||
// })
|
})
|
||||||
// })
|
})
|
||||||
// })
|
})
|
||||||
// }
|
}
|
||||||
|
|
||||||
// func TestSerializeToFile(t *testing.T) {
|
|
||||||
// Convey("Reading some instances...", t, func() {
|
|
||||||
// inst, err := ParseCSVToInstances("../examples/datasets/iris_headers.csv", true)
|
|
||||||
// So(err, ShouldBeNil)
|
|
||||||
|
|
||||||
// Convey("Dumping to file...", func() {
|
func TestCreateAndReadClassifierStub(t *testing.T) {
|
||||||
// f, err := ioutil.TempFile("", "instTmpF")
|
Convey("Creating a classifier stub...", t, func() {
|
||||||
// So(err, ShouldBeNil)
|
|
||||||
// err = SerializeInstances(inst, f)
|
exampleClassifierMetadata := make(map[string]interface{})
|
||||||
// So(err, ShouldBeNil)
|
exampleClassifierMetadata["num_trees"] = 4
|
||||||
// f.Seek(0, 0)
|
|
||||||
// Convey("Contents of the archive should be right...", func() {
|
metadata := ClassifierMetadataV1 {
|
||||||
// gzr, err := gzip.NewReader(f)
|
FormatVersion: 1,
|
||||||
// So(err, ShouldBeNil)
|
ClassifierName: "test",
|
||||||
// tr := tar.NewReader(gzr)
|
ClassifierVersion: "1",
|
||||||
// classAttrsPresent := false
|
ClassifierMetadata: exampleClassifierMetadata,
|
||||||
// // manifestPresent := false
|
}
|
||||||
// regularAttrsPresent := false
|
|
||||||
// dataPresent := false
|
Convey("Saving the classifier...", func() {
|
||||||
// dimsPresent := false
|
f, err := ioutil.TempFile("", "classTmpF")
|
||||||
// readBytes := make([]byte, len([]byte(SerializationFormatVersion)))
|
So(err, ShouldBeNil)
|
||||||
// for {
|
serializer, err := CreateSerializedClassifierStub(f.Name(), metadata)
|
||||||
// hdr, err := tr.Next()
|
So(err, ShouldBeNil)
|
||||||
// if err == io.EOF {
|
err = serializer.Close()
|
||||||
// break
|
So(err, ShouldBeNil)
|
||||||
// }
|
Convey("Should be able to read the information back...", func() {
|
||||||
// So(err, ShouldBeNil)
|
reader, err := ReadSerializedClassifierStub(f.Name())
|
||||||
// switch hdr.Name {
|
So(err, ShouldBeNil)
|
||||||
// case "MANIFEST":
|
So(reader, ShouldNotBeNil)
|
||||||
// tr.Read(readBytes)
|
So(reader.Metadata.FormatVersion, ShouldEqual, 1)
|
||||||
// manifestPresent = true
|
So(reader.Metadata.ClassifierName, ShouldEqual, "test")
|
||||||
// break
|
So(reader.Metadata.ClassifierVersion, ShouldEqual, "1")
|
||||||
// case "CATTRS":
|
So(reader.Metadata.ClassifierMetadata["num_trees"], ShouldEqual, 4)
|
||||||
// classAttrsPresent = true
|
})
|
||||||
// break
|
})
|
||||||
// case "ATTRS":
|
})
|
||||||
// regularAttrsPresent = true
|
}
|
||||||
// break
|
|
||||||
// case "DATA":
|
func TestSerializeToFile(t *testing.T) {
|
||||||
// dataPresent = true
|
Convey("Reading some instances...", t, func() {
|
||||||
// break
|
inst, err := ParseCSVToInstances("../examples/datasets/iris_headers.csv", true)
|
||||||
// case "DIMS":
|
So(err, ShouldBeNil)
|
||||||
// dimsPresent = true
|
|
||||||
// break
|
Convey("Dumping to file...", func() {
|
||||||
// default:
|
f, err := ioutil.TempFile("", "instTmpF")
|
||||||
// fmt.Printf("Unknown file: %s\n", hdr.Name)
|
So(err, ShouldBeNil)
|
||||||
// }
|
err = SerializeInstances(inst, f)
|
||||||
// }
|
So(err, ShouldBeNil)
|
||||||
// Convey("MANIFEST should be present", func() {
|
f.Seek(0, 0)
|
||||||
// // So(manifestPresent, ShouldBeTrue)
|
Convey("Contents of the archive should be right...", func() {
|
||||||
// // Convey("MANIFEST should be right...", func() {
|
gzr, err := gzip.NewReader(f)
|
||||||
// // So(readBytes, ShouldResemble, []byte(SerializationFormatVersion))
|
So(err, ShouldBeNil)
|
||||||
// // })
|
tr := tar.NewReader(gzr)
|
||||||
// })
|
classAttrsPresent := false
|
||||||
// Convey("DATA should be present", func() {
|
manifestPresent := false
|
||||||
// So(dataPresent, ShouldBeTrue)
|
regularAttrsPresent := false
|
||||||
// })
|
dataPresent := false
|
||||||
// Convey("ATTRS should be present", func() {
|
dimsPresent := false
|
||||||
// So(regularAttrsPresent, ShouldBeTrue)
|
readBytes := make([]byte, len([]byte(SerializationFormatVersion)))
|
||||||
// })
|
for {
|
||||||
// Convey("CATTRS should be present", func() {
|
hdr, err := tr.Next()
|
||||||
// So(classAttrsPresent, ShouldBeTrue)
|
if err == io.EOF {
|
||||||
// })
|
break
|
||||||
// Convey("DIMS should be present", func() {
|
}
|
||||||
// So(dimsPresent, ShouldBeTrue)
|
So(err, ShouldBeNil)
|
||||||
// })
|
switch hdr.Name {
|
||||||
// })
|
case "MANIFEST":
|
||||||
// Convey("Should be able to reconstruct...", func() {
|
tr.Read(readBytes)
|
||||||
// f.Seek(0, 0)
|
manifestPresent = true
|
||||||
// dinst, err := DeserializeInstances(f)
|
break
|
||||||
// So(err, ShouldBeNil)
|
case "CATTRS":
|
||||||
// So(InstancesAreEqual(inst, dinst), ShouldBeTrue)
|
classAttrsPresent = true
|
||||||
// })
|
break
|
||||||
// })
|
case "ATTRS":
|
||||||
// })
|
regularAttrsPresent = true
|
||||||
// }
|
break
|
||||||
|
case "DATA":
|
||||||
|
dataPresent = true
|
||||||
|
break
|
||||||
|
case "DIMS":
|
||||||
|
dimsPresent = true
|
||||||
|
break
|
||||||
|
default:
|
||||||
|
fmt.Printf("Unknown file: %s\n", hdr.Name)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
Convey("MANIFEST should be present", func() {
|
||||||
|
So(manifestPresent, ShouldBeTrue)
|
||||||
|
Convey("MANIFEST should be right...", func() {
|
||||||
|
So(readBytes, ShouldResemble, []byte(SerializationFormatVersion))
|
||||||
|
})
|
||||||
|
|
||||||
|
Convey("DATA should be present", func() {
|
||||||
|
So(dataPresent, ShouldBeTrue)
|
||||||
|
})
|
||||||
|
Convey("ATTRS should be present", func() {
|
||||||
|
So(regularAttrsPresent, ShouldBeTrue)
|
||||||
|
})
|
||||||
|
Convey("CATTRS should be present", func() {
|
||||||
|
So(classAttrsPresent, ShouldBeTrue)
|
||||||
|
})
|
||||||
|
Convey("DIMS should be present", func() {
|
||||||
|
So(dimsPresent, ShouldBeTrue)
|
||||||
|
})
|
||||||
|
})
|
||||||
|
Convey("Should be able to reconstruct...", func() {
|
||||||
|
f.Seek(0, 0)
|
||||||
|
dinst, err := DeserializeInstances(f)
|
||||||
|
So(err, ShouldBeNil)
|
||||||
|
So(InstancesAreEqual(inst, dinst), ShouldBeTrue)
|
||||||
|
})
|
||||||
|
})
|
||||||
|
})
|
||||||
|
})
|
||||||
|
}
|
||||||
|
Loading…
x
Reference in New Issue
Block a user