1
0
mirror of https://github.com/sjwhitworth/golearn.git synced 2025-05-03 22:17:14 +08:00

Adding some new ways of serializing things

This commit is contained in:
Richard Townsend 2017-08-07 14:43:21 +01:00
parent b7ad1fe499
commit 7b08820152
2 changed files with 335 additions and 101 deletions

View File

@ -268,6 +268,205 @@ func DeserializeInstances(f io.Reader) (ret *DenseInstances, err error) {
return ret, nil
}
// ClassifierMetadataV1 is what gets written into METADATA
// in a classification file format.
type ClassifierMetadataV1 struct {
// FormatVersion should always be 1 for this structure
FormatVersion int `json:"format_version"`
// Uses the classifier name (provided by the classifier)
ClassifierName string `json:"classifier"`
// ClassifierVersion is also provided by the classifier
// and checks whether this version of GoLearn can read what's
// be written.
ClassifierVersion string `json"classifier_version"`
// This is a custom metadata field, provided by the classifier
ClassifierMetadata map[string]interface{} `json:"classifier_metadata"`
}
type ClassifierDeserializer struct {
gzipReader io.Reader
fileReader io.Reader
tarReader *tar.Reader
Metadata *ClassifierMetadataV1
}
// ReadSerializedClassifierStub is the counterpart of CreateSerializedClassifierStub
func ReadSerializedClassifierStub(filePath string) (*ClassifierDeserializer, error) {
f, err := os.Open(filePath)
if err != nil {
return nil, fmt.Errorf("Can't open file: %s", err)
}
gzr, err := gzip.NewReader(f)
if err != nil {
return nil, fmt.Errorf("Can't decompress file: %s", err)
}
tz := tar.NewReader(gzr)
// Check that the serialization format is right
// Retrieve the MANIFEST and verify
manifestBytes := getTarContent(tz, "MANIFEST")
if !reflect.DeepEqual(manifestBytes, []byte(SerializationFormatVersion)) {
return nil, fmt.Errorf("Unsupported MANIFEST: %s", string(manifestBytes))
}
//
// Parse METADATA
//
metadataBytes := getTarContent(tz, "METADATA")
var metadata ClassifierMetadataV1
err = json.Unmarshal(metadataBytes, &metadata)
if err != nil {
return nil, fmt.Errorf("Error whilst reading METADATA: %s", err)
}
// Check that we can understand this archive
if metadata.FormatVersion != 1 {
return nil, fmt.Errorf("METADATA: wrong format_version for this version of golearn")
}
ret := &ClassifierDeserializer{
f,
gzr,
tz,
&metadata,
}
return ret, nil
}
func (c *ClassifierDeserializer) GetBytesForKey(key string) ([]byte, error) {
return getTarContent(c.tarReader, key), nil
}
type ClassifierSerializer struct {
gzipWriter *gzip.Writer
fileWriter io.Writer
tarWriter *tar.Writer
}
// Close finalizes the Classifier serialization session
func (c *ClassifierSerializer) Close() error {
// Finally, close and flush the various levels
if err := c.tarWriter.Flush(); err != nil {
return fmt.Errorf("Could not flush tar: %s", err)
}
if err := c.tarWriter.Close(); err != nil {
return fmt.Errorf("Could not close tar: %s", err)
}
if err := c.gzipWriter.Flush(); err != nil {
return fmt.Errorf("Could not flush gz: %s", err)
}
if err := c.gzipWriter.Close(); err != nil {
return fmt.Errorf("Could not close gz: %s", err)
}
//if err := c.fileWriter.Flush(); err != nil {
// return fmt.Errorf("Could not flush file: %s", err)
//}
//if err := c.fileWriter.Flush(); err != nil {
// return fmt.Errorf("Could not close file: %s", err)
//}
return nil
}
// WriteBytesForKey creates a new entry in the serializer file
func (c *ClassifierSerializer) WriteBytesForKey(key string, b []byte) error {
//
// Write header for key
//
hdr := &tar.Header{
Name: key,
Size: int64(len(b)),
}
if err := c.tarWriter.WriteHeader(hdr); err != nil {
return fmt.Errorf("Could not write header for '%s': %s", err)
}
//
// Write data
//
if _, err := c.tarWriter.Write(b); err != nil {
return fmt.Errorf("Could not write data for '%s': %s", err)
}
return nil
}
// CreateSerializedClassifierStub generates a file to serialize into
// and writes the METADATA header.
func CreateSerializedClassifierStub(filePath string, metadata ClassifierMetadataV1) (*ClassifierSerializer, error) {
// Open the filePath
f, err := os.OpenFile(filePath, os.O_RDWR | os.O_TRUNC, 0600)
if err != nil {
return nil, nil
}
var hdr *tar.Header
gzWriter := gzip.NewWriter(f)
tw := tar.NewWriter(gzWriter)
ret := &ClassifierSerializer{
gzipWriter: gzWriter,
fileWriter: f,
tarWriter: tw,
}
//
// Write the MANIFEST entry
//
hdr = &tar.Header{
Name: "MANIFEST",
Size: int64(len(SerializationFormatVersion)),
}
if err := tw.WriteHeader(hdr); err != nil {
return nil, fmt.Errorf("Could not write MANIFEST header: %s", err)
}
if _, err := tw.Write([]byte(SerializationFormatVersion)); err != nil {
return nil, fmt.Errorf("Could not write MANIFEST contents: %s", err)
}
//
// Write the METADATA entry
//
// Marshal the classifier information
cB, err := json.Marshal(metadata)
if err != nil {
return nil, err
}
if len(cB) == 0 {
return nil, fmt.Errorf("JSON marshal error: %s", err)
}
// Write the information into the file
hdr = &tar.Header{
Name: "METADATA",
Size: int64(len(cB)),
}
if err := tw.WriteHeader(hdr); err != nil {
return nil, fmt.Errorf("Could not write METADATA object %s", err)
}
if _, err := tw.Write(cB); err != nil {
return nil, fmt.Errorf("Could not write METDATA contents: %s", err)
}
return ret, nil
}
func SerializeInstances(inst FixedDataGrid, f io.Writer) error {
var hdr *tar.Header

View File

@ -1,107 +1,142 @@
package base
// import (
// "archive/tar"
// "compress/gzip"
// "fmt"
// . "github.com/smartystreets/goconvey/convey"
// "io"
// "io/ioutil"
// "testing"
// )
import (
"archive/tar"
"compress/gzip"
"fmt"
. "github.com/smartystreets/goconvey/convey"
"io"
"io/ioutil"
"testing"
)
// func TestSerializeToCSV(t *testing.T) {
// Convey("Reading some instances...", t, func() {
// inst, err := ParseCSVToInstances("../examples/datasets/iris_headers.csv", true)
// So(err, ShouldBeNil)
func TestSerializeToCSV(t *testing.T) {
Convey("Reading some instances...", t, func() {
inst, err := ParseCSVToInstances("../examples/datasets/iris_headers.csv", true)
So(err, ShouldBeNil)
// Convey("Saving the instances to CSV...", func() {
// f, err := ioutil.TempFile("", "instTmpCSV")
// So(err, ShouldBeNil)
// err = SerializeInstancesToCSV(inst, f.Name())
// So(err, ShouldBeNil)
// Convey("What's written out should match what's read in", func() {
// dinst, err := ParseCSVToInstances(f.Name(), true)
// So(err, ShouldBeNil)
// So(inst.String(), ShouldEqual, dinst.String())
// })
// })
// })
// }
Convey("Saving the instances to CSV...", func() {
f, err := ioutil.TempFile("", "instTmpCSV")
So(err, ShouldBeNil)
err = SerializeInstancesToCSV(inst, f.Name())
So(err, ShouldBeNil)
Convey("What's written out should match what's read in", func() {
dinst, err := ParseCSVToInstances(f.Name(), true)
So(err, ShouldBeNil)
So(inst.String(), ShouldEqual, dinst.String())
})
})
})
}
// func TestSerializeToFile(t *testing.T) {
// Convey("Reading some instances...", t, func() {
// inst, err := ParseCSVToInstances("../examples/datasets/iris_headers.csv", true)
// So(err, ShouldBeNil)
// Convey("Dumping to file...", func() {
// f, err := ioutil.TempFile("", "instTmpF")
// So(err, ShouldBeNil)
// err = SerializeInstances(inst, f)
// So(err, ShouldBeNil)
// f.Seek(0, 0)
// Convey("Contents of the archive should be right...", func() {
// gzr, err := gzip.NewReader(f)
// So(err, ShouldBeNil)
// tr := tar.NewReader(gzr)
// classAttrsPresent := false
// // manifestPresent := false
// regularAttrsPresent := false
// dataPresent := false
// dimsPresent := false
// readBytes := make([]byte, len([]byte(SerializationFormatVersion)))
// for {
// hdr, err := tr.Next()
// if err == io.EOF {
// break
// }
// So(err, ShouldBeNil)
// switch hdr.Name {
// case "MANIFEST":
// tr.Read(readBytes)
// manifestPresent = true
// break
// case "CATTRS":
// classAttrsPresent = true
// break
// case "ATTRS":
// regularAttrsPresent = true
// break
// case "DATA":
// dataPresent = true
// break
// case "DIMS":
// dimsPresent = true
// break
// default:
// fmt.Printf("Unknown file: %s\n", hdr.Name)
// }
// }
// Convey("MANIFEST should be present", func() {
// // So(manifestPresent, ShouldBeTrue)
// // Convey("MANIFEST should be right...", func() {
// // So(readBytes, ShouldResemble, []byte(SerializationFormatVersion))
// // })
// })
// Convey("DATA should be present", func() {
// So(dataPresent, ShouldBeTrue)
// })
// Convey("ATTRS should be present", func() {
// So(regularAttrsPresent, ShouldBeTrue)
// })
// Convey("CATTRS should be present", func() {
// So(classAttrsPresent, ShouldBeTrue)
// })
// Convey("DIMS should be present", func() {
// So(dimsPresent, ShouldBeTrue)
// })
// })
// Convey("Should be able to reconstruct...", func() {
// f.Seek(0, 0)
// dinst, err := DeserializeInstances(f)
// So(err, ShouldBeNil)
// So(InstancesAreEqual(inst, dinst), ShouldBeTrue)
// })
// })
// })
// }
func TestCreateAndReadClassifierStub(t *testing.T) {
Convey("Creating a classifier stub...", t, func() {
exampleClassifierMetadata := make(map[string]interface{})
exampleClassifierMetadata["num_trees"] = 4
metadata := ClassifierMetadataV1 {
FormatVersion: 1,
ClassifierName: "test",
ClassifierVersion: "1",
ClassifierMetadata: exampleClassifierMetadata,
}
Convey("Saving the classifier...", func() {
f, err := ioutil.TempFile("", "classTmpF")
So(err, ShouldBeNil)
serializer, err := CreateSerializedClassifierStub(f.Name(), metadata)
So(err, ShouldBeNil)
err = serializer.Close()
So(err, ShouldBeNil)
Convey("Should be able to read the information back...", func() {
reader, err := ReadSerializedClassifierStub(f.Name())
So(err, ShouldBeNil)
So(reader, ShouldNotBeNil)
So(reader.Metadata.FormatVersion, ShouldEqual, 1)
So(reader.Metadata.ClassifierName, ShouldEqual, "test")
So(reader.Metadata.ClassifierVersion, ShouldEqual, "1")
So(reader.Metadata.ClassifierMetadata["num_trees"], ShouldEqual, 4)
})
})
})
}
func TestSerializeToFile(t *testing.T) {
Convey("Reading some instances...", t, func() {
inst, err := ParseCSVToInstances("../examples/datasets/iris_headers.csv", true)
So(err, ShouldBeNil)
Convey("Dumping to file...", func() {
f, err := ioutil.TempFile("", "instTmpF")
So(err, ShouldBeNil)
err = SerializeInstances(inst, f)
So(err, ShouldBeNil)
f.Seek(0, 0)
Convey("Contents of the archive should be right...", func() {
gzr, err := gzip.NewReader(f)
So(err, ShouldBeNil)
tr := tar.NewReader(gzr)
classAttrsPresent := false
manifestPresent := false
regularAttrsPresent := false
dataPresent := false
dimsPresent := false
readBytes := make([]byte, len([]byte(SerializationFormatVersion)))
for {
hdr, err := tr.Next()
if err == io.EOF {
break
}
So(err, ShouldBeNil)
switch hdr.Name {
case "MANIFEST":
tr.Read(readBytes)
manifestPresent = true
break
case "CATTRS":
classAttrsPresent = true
break
case "ATTRS":
regularAttrsPresent = true
break
case "DATA":
dataPresent = true
break
case "DIMS":
dimsPresent = true
break
default:
fmt.Printf("Unknown file: %s\n", hdr.Name)
}
}
Convey("MANIFEST should be present", func() {
So(manifestPresent, ShouldBeTrue)
Convey("MANIFEST should be right...", func() {
So(readBytes, ShouldResemble, []byte(SerializationFormatVersion))
})
Convey("DATA should be present", func() {
So(dataPresent, ShouldBeTrue)
})
Convey("ATTRS should be present", func() {
So(regularAttrsPresent, ShouldBeTrue)
})
Convey("CATTRS should be present", func() {
So(classAttrsPresent, ShouldBeTrue)
})
Convey("DIMS should be present", func() {
So(dimsPresent, ShouldBeTrue)
})
})
Convey("Should be able to reconstruct...", func() {
f.Seek(0, 0)
dinst, err := DeserializeInstances(f)
So(err, ShouldBeNil)
So(InstancesAreEqual(inst, dinst), ShouldBeTrue)
})
})
})
})
}