1
0
mirror of https://github.com/sjwhitworth/golearn.git synced 2025-04-26 13:49:14 +08:00
golearn/base/serialize.go
Richard Townsend ff52c013eb Update gonum to latest version
Should fix #200 and #205
2018-03-24 00:19:35 +00:00

427 lines
12 KiB
Go

package base
import (
"archive/tar"
"compress/gzip"
"encoding/json"
"fmt"
"io"
"io/ioutil"
"log"
"os"
"reflect"
)
const (
// SerializationFormatVersion is the format marker for serialized classifier
// files. CreateSerializedClassifierStub writes it as the contents of the
// CLS_MANIFEST entry, and ReadSerializedClassifierStub rejects any file
// whose CLS_MANIFEST does not match it exactly.
SerializationFormatVersion = "golearn 1.0"
)
// FunctionalTarReader allows you to read anything in a tar file in any order, rather than just
// sequentially. It does so by asking Regenerate for a new tar.Reader before each lookup
// instead of holding on to a single, forward-only reader.
type FunctionalTarReader struct {
// Regenerate is called by GetNamedFile at the start of every lookup to obtain
// the tar.Reader that will be scanned.
Regenerate func() *tar.Reader
}
// NewFunctionalTarReader wraps regenFunc in a FunctionalTarReader. regenFunc is
// expected to hand back a tar.Reader positioned at the beginning of the archive
// each time it is called.
func NewFunctionalTarReader(regenFunc func() *tar.Reader) *FunctionalTarReader {
	reader := new(FunctionalTarReader)
	reader.Regenerate = regenFunc
	return reader
}
// GetNamedFile returns the contents of the tar entry called name. A fresh
// tar.Reader is obtained from Regenerate on every call and scanned to the end;
// if more than one entry carries that name, the last one encountered wins.
//
// An error is returned when the archive cannot be read, when an entry yields
// more bytes than its header declares, or when no entry matches name.
func (f *FunctionalTarReader) GetNamedFile(name string) ([]byte, error) {
	tr := f.Regenerate()

	var (
		returnCandidate []byte
		found           bool
	)
	for {
		hdr, err := tr.Next()
		if err == io.EOF {
			break
		}
		if err != nil {
			return nil, err
		}
		if hdr.Name != name {
			continue
		}

		ret, err := ioutil.ReadAll(tr)
		if err != nil {
			return nil, WrapError(err)
		}

		got := int64(len(ret))
		if got < hdr.Size {
			// A short entry is tolerated (as before) but recorded in the log.
			log.Printf("Size mismatch, got %d byte(s) for %s, expected %d", got, hdr.Name, hdr.Size)
		} else if got > hdr.Size {
			return nil, WrapError(fmt.Errorf("Size mismatch, expected %d byte(s) for %s, got %d", hdr.Size, hdr.Name, got))
		}

		// Keep scanning: a later entry with the same name supersedes this one.
		returnCandidate = ret
		found = true
	}

	// Track matches explicitly so the result does not depend on whether an
	// empty entry happens to be read back as a nil or a non-nil slice.
	if !found {
		return nil, WrapError(fmt.Errorf("Not found (looking for %s)", name))
	}
	return returnCandidate, nil
}
// tarPrefix builds a tar entry name from prefix and suffix. With an empty
// prefix the suffix is returned unchanged; otherwise the two are joined with
// a single "/".
func tarPrefix(prefix string, suffix string) string {
	if prefix != "" {
		return fmt.Sprintf("%s/%s", prefix, suffix)
	}
	return suffix
}
// ClassifierMetadataV1 is what gets written into METADATA
// in a classification file format.
type ClassifierMetadataV1 struct {
	// FormatVersion should always be 1 for this structure.
	FormatVersion int `json:"format_version"`
	// ClassifierName is the classifier name (provided by the classifier).
	ClassifierName string `json:"classifier"`
	// ClassifierVersion is also provided by the classifier
	// and checks whether this version of GoLearn can read what's
	// been written.
	ClassifierVersion string `json:"classifier_version"`
	// ClassifierMetadata is a custom metadata field, provided by the classifier.
	ClassifierMetadata map[string]interface{} `json:"classifier_metadata"`
}

// UnmarshalJSON decodes METADATA written by this or an earlier revision.
//
// Earlier revisions carried a malformed struct tag on ClassifierVersion, so
// encoding/json ignored it and stored the value under the Go field name
// "ClassifierVersion". That legacy key is still honoured when the current
// "classifier_version" key is absent, so previously saved files keep loading.
func (m *ClassifierMetadataV1) UnmarshalJSON(data []byte) error {
	// fields has the same layout but no methods, which stops the nested
	// Unmarshal from recursing back into this function.
	type fields ClassifierMetadataV1
	if err := json.Unmarshal(data, (*fields)(m)); err != nil {
		return err
	}

	var probe struct {
		Current *string `json:"classifier_version"`
		Legacy  *string `json:"ClassifierVersion"`
	}
	if err := json.Unmarshal(data, &probe); err != nil {
		return err
	}
	if probe.Current == nil && probe.Legacy != nil {
		m.ClassifierVersion = *probe.Legacy
	}
	return nil
}
// ClassifierDeserializer attaches helper functions useful for reading classifiers. (UNSTABLE).
//
// NOTE(review): the order and types of these fields matter to any code that
// builds the struct with an unkeyed composite literal; confirm all
// construction sites before reordering.
type ClassifierDeserializer struct {
// gzipReader is declared as a plain io.Reader (presumably the decompression
// layer, going by the name).
gzipReader io.Reader
// fileReader is the closable reader that Close releases.
fileReader io.ReadCloser
// tarReader serves every Get*ForKey lookup.
tarReader *FunctionalTarReader
// Metadata holds the METADATA entry decoded by ReadSerializedClassifierStub.
Metadata *ClassifierMetadataV1
}
// Prefix outputs a string in the right format for TAR: "prefix/suffix", or
// just suffix when prefix is empty. It delegates to tarPrefix, the
// package-level helper that implements this rule, instead of repeating it.
func (c *ClassifierDeserializer) Prefix(prefix string, suffix string) string {
	return tarPrefix(prefix, suffix)
}
// ReadMetadataAtPrefix decodes the METADATA entry stored under prefix. When the
// returned error is non-nil the metadata value must not be relied upon.
func (c *ClassifierDeserializer) ReadMetadataAtPrefix(prefix string) (ClassifierMetadataV1, error) {
	metadataKey := c.Prefix(prefix, "METADATA")

	var metadata ClassifierMetadataV1
	err := c.GetJSONForKey(metadataKey, &metadata)
	return metadata, err
}
// ReadSerializedClassifierStub is the counterpart of CreateSerializedClassifierStub.
// It's used inside SaveableClassifiers to read information from a previously saved
// model file.
//
// It opens filePath, checks that CLS_MANIFEST matches SerializationFormatVersion,
// decodes the top-level METADATA entry and checks its format_version. On success
// the caller owns the returned deserializer and must Close it; on failure the
// file is closed before returning.
func ReadSerializedClassifierStub(filePath string) (*ClassifierDeserializer, error) {
	f, err := os.Open(filePath)
	if err != nil {
		return nil, DescribeError("Can't open file", err)
	}
	// Until the deserializer is handed to the caller, every early return must
	// release the file handle.
	succeeded := false
	defer func() {
		if !succeeded {
			_ = f.Close() // already failing; the original error is more useful
		}
	}()

	gzr, err := gzip.NewReader(f)
	if err != nil {
		return nil, DescribeError("Can't decompress", err)
	}

	// failedReader returns a tar.Reader whose first Next reports err, so that a
	// failed rewind surfaces through GetNamedFile instead of being dropped.
	failedReader := func(err error) *tar.Reader {
		pr, pw := io.Pipe()
		pw.CloseWithError(err)
		return tar.NewReader(pr)
	}

	// regenerateFunc rewinds the file and restarts decompression so the archive
	// can be scanned from its first entry again.
	regenerateFunc := func() *tar.Reader {
		if _, err := f.Seek(0, io.SeekStart); err != nil {
			return failedReader(err)
		}
		if err := gzr.Reset(f); err != nil {
			return failedReader(err)
		}
		return tar.NewReader(gzr)
	}
	tz := NewFunctionalTarReader(regenerateFunc)

	// Check that the serialization format is right:
	// retrieve the MANIFEST and verify.
	manifestBytes, err := tz.GetNamedFile("CLS_MANIFEST")
	if err != nil {
		return nil, DescribeError("Error reading CLS_MANIFEST", err)
	}
	if !reflect.DeepEqual(manifestBytes, []byte(SerializationFormatVersion)) {
		return nil, fmt.Errorf("Unsupported CLS_MANIFEST: %s", string(manifestBytes))
	}

	// Keyed fields: both readers satisfy either field's type, so an unkeyed
	// literal can silently swap them and leave Close unable to close the file.
	ret := &ClassifierDeserializer{
		gzipReader: gzr,
		fileReader: f,
		tarReader:  tz,
	}

	// Parse METADATA.
	metadata, err := ret.ReadMetadataAtPrefix("")
	if err != nil {
		return nil, fmt.Errorf("Error whilst reading METADATA: %s", err)
	}
	ret.Metadata = &metadata

	// Check that we can understand this archive.
	if metadata.FormatVersion != 1 {
		return nil, fmt.Errorf("METADATA: wrong format_version for this version of golearn")
	}

	succeeded = true
	return ret, nil
}
// GetBytesForKey looks up the tar entry named key and returns its raw contents.
func (c *ClassifierDeserializer) GetBytesForKey(key string) ([]byte, error) {
	contents, err := c.tarReader.GetNamedFile(key)
	return contents, err
}
// GetStringForKey returns the contents stored at key converted to a string.
// On failure it returns the empty string and the lookup error.
func (c *ClassifierDeserializer) GetStringForKey(key string) (string, error) {
	raw, err := c.GetBytesForKey(key)
	if err == nil {
		return string(raw), nil
	}
	return "", err
}
// GetJSONForKey reads the entry stored at key and JSON-decodes it into v.
// It returns the lookup error, or otherwise whatever json.Unmarshal reports.
func (c *ClassifierDeserializer) GetJSONForKey(key string, v interface{}) error {
	raw, err := c.GetBytesForKey(key)
	if err == nil {
		err = json.Unmarshal(raw, v)
	}
	return err
}
// GetInstancesForKey loads the instances stored under key in a classifier
// output file by handing the tar reader to DeserializeInstancesFromTarReader.
func (c *ClassifierDeserializer) GetInstancesForKey(key string) (FixedDataGrid, error) {
	reader := c.tarReader
	return DeserializeInstancesFromTarReader(reader, key)
}
// GetU64ForKey returns the uint64 stored at a given key, decoding the entry's
// bytes with UnpackBytesToU64. A failed lookup yields 0 and the error.
func (c *ClassifierDeserializer) GetU64ForKey(key string) (uint64, error) {
	raw, err := c.GetBytesForKey(key)
	if err == nil {
		return UnpackBytesToU64(raw), nil
	}
	return 0, err
}
// GetAttributeForKey reads the entry at key and decodes it with
// DeserializeAttribute. Errors from either step are passed through WrapError.
func (c *ClassifierDeserializer) GetAttributeForKey(key string) (Attribute, error) {
	raw, readErr := c.GetBytesForKey(key)
	if readErr != nil {
		return nil, WrapError(readErr)
	}

	decoded, decodeErr := DeserializeAttribute(raw)
	if decodeErr != nil {
		return nil, WrapError(decodeErr)
	}
	return decoded, nil
}
// GetAttributesForKey returns the Attribute list stored under key. The list
// length comes from the "ATTR_COUNT" entry below key, and element i is read
// from the entry named by its decimal index below key.
func (c *ClassifierDeserializer) GetAttributesForKey(key string) ([]Attribute, error) {
	attrCount, err := c.GetU64ForKey(c.Prefix(key, "ATTR_COUNT"))
	if err != nil {
		return nil, DescribeError("Unable to read ATTR_COUNT", err)
	}

	attrs := make([]Attribute, attrCount)
	for idx := 0; idx < len(attrs); idx++ {
		attrKey := c.Prefix(key, fmt.Sprintf("%d", idx))
		attrs[idx], err = c.GetAttributeForKey(attrKey)
		if err != nil {
			return nil, DescribeError("Unable to read Attribute", err)
		}
	}
	return attrs, nil
}
// Close cleans up everything: it closes fileReader and, when the value held in
// gzipReader also implements io.Closer, closes that too. Closing both means the
// underlying file handle is released regardless of which of the two fields
// holds it.
//
// Close has no return value, so errors from the individual Close calls are
// deliberately discarded.
func (c *ClassifierDeserializer) Close() {
	// gzipReader is only declared as io.Reader; close it when the concrete
	// value supports it. A nil interface simply fails the assertion.
	if closer, ok := c.gzipReader.(io.Closer); ok {
		_ = closer.Close()
	}
	_ = c.fileReader.Close()
}
// ClassifierSerializer is an object used by SaveableClassifiers.
// CreateSerializedClassifierStub builds one whose tarWriter writes into
// gzipWriter, which in turn writes into fileWriter.
type ClassifierSerializer struct {
// gzipWriter compresses the tar stream into fileWriter.
gzipWriter *gzip.Writer
// fileWriter is the destination file.
fileWriter *os.File
// tarWriter receives one entry per Write*ForKey call.
tarWriter *tar.Writer
// NOTE(review): f and filePath are not assigned or read anywhere in the
// code visible here; confirm whether other files use them.
f *os.File
filePath string
}
// Close finalizes the Classifier serialization session. It flushes and closes
// the tar layer, flushes and closes the gzip layer, syncs the file, and then
// closes the file.
//
// The file is closed even when an earlier step fails, so a failed Close does
// not leak the descriptor. The first error encountered is the one returned.
func (c *ClassifierSerializer) Close() error {
	// finish flushes the layers from the innermost (tar) outwards and stops at
	// the first failure.
	finish := func() error {
		if err := c.tarWriter.Flush(); err != nil {
			return fmt.Errorf("Could not flush tar: %s", err)
		}
		if err := c.tarWriter.Close(); err != nil {
			return fmt.Errorf("Could not close tar: %s", err)
		}
		if err := c.gzipWriter.Flush(); err != nil {
			return fmt.Errorf("Could not flush gz: %s", err)
		}
		if err := c.gzipWriter.Close(); err != nil {
			return fmt.Errorf("Could not close gz: %s", err)
		}
		if err := c.fileWriter.Sync(); err != nil {
			return fmt.Errorf("Could not sync file writer: %s", err)
		}
		return nil
	}

	err := finish()
	if closeErr := c.fileWriter.Close(); closeErr != nil && err == nil {
		err = fmt.Errorf("Could not close file writer: %s", closeErr)
	}
	return err
}
// WriteBytesForKey creates a new entry in the serializer file with some user-defined bytes.
// The entry is named key and its header size is len(b). After the data is
// written, the tar and gzip layers are flushed and the file is synced; a
// failure in any of those steps is returned rather than ignored.
func (c *ClassifierSerializer) WriteBytesForKey(key string, b []byte) error {
	// Write header for key.
	hdr := &tar.Header{
		Name: key,
		Size: int64(len(b)),
	}
	if err := c.tarWriter.WriteHeader(hdr); err != nil {
		return fmt.Errorf("Could not write header for '%s': %s", key, err)
	}

	// Write data.
	if _, err := c.tarWriter.Write(b); err != nil {
		return fmt.Errorf("Could not write data for '%s': %s", key, err)
	}

	// Push the entry through every layer, innermost first.
	if err := c.tarWriter.Flush(); err != nil {
		return fmt.Errorf("Could not flush tar for '%s': %s", key, err)
	}
	if err := c.gzipWriter.Flush(); err != nil {
		return fmt.Errorf("Could not flush gz for '%s': %s", key, err)
	}
	if err := c.fileWriter.Sync(); err != nil {
		return fmt.Errorf("Could not sync file for '%s': %s", key, err)
	}
	return nil
}
// WriteU64ForKey stores v under key, encoded with PackU64ToBytes.
func (c *ClassifierSerializer) WriteU64ForKey(key string, v uint64) error {
	return c.WriteBytesForKey(key, PackU64ToBytes(v))
}
// WriteJSONForKey JSON-encodes v and stores the result as a new entry named
// key. A marshalling error is returned as-is and nothing is written.
func (c *ClassifierSerializer) WriteJSONForKey(key string, v interface{}) error {
	encoded, err := json.Marshal(v)
	if err == nil {
		err = c.WriteBytesForKey(key, encoded)
	}
	return err
}
// WriteAttributeForKey stores the SerializeAttribute encoding of a as a new
// entry named key. A serialization failure is returned through WrapError.
func (c *ClassifierSerializer) WriteAttributeForKey(key string, a Attribute) error {
	encoded, encodeErr := SerializeAttribute(a)
	if encodeErr != nil {
		return WrapError(encodeErr)
	}
	writeErr := c.WriteBytesForKey(key, encoded)
	return writeErr
}
// WriteAttributesForKey stores a list of Attributes below key: the list length
// goes into the "ATTR_COUNT" entry, and element i goes into the entry named by
// its decimal index.
func (c *ClassifierSerializer) WriteAttributesForKey(key string, attrs []Attribute) error {
	countKey := c.Prefix(key, "ATTR_COUNT")
	if err := c.WriteU64ForKey(countKey, uint64(len(attrs))); err != nil {
		return DescribeError("Unable to write ATTR_COUNT", err)
	}

	for idx := range attrs {
		elemKey := c.Prefix(key, fmt.Sprintf("%d", idx))
		if err := c.WriteAttributeForKey(elemKey, attrs[idx]); err != nil {
			return DescribeError("Unable to write Attribute", err)
		}
	}
	return nil
}
// WriteInstancesForKey creates a new entry in the file containing some training instances.
// It hands g, the tar writer, key and includeData to SerializeInstancesToTarWriter
// and returns its error.
// NOTE(review): includeData presumably controls whether row data is stored in
// addition to the structure; confirm against SerializeInstancesToTarWriter.
func (c *ClassifierSerializer) WriteInstancesForKey(key string, g FixedDataGrid, includeData bool) error {
	return SerializeInstancesToTarWriter(g, c.tarWriter, key, includeData)
}
// Prefix outputs a string in the right format for TAR: "prefix/suffix", or
// just suffix when prefix is empty. It delegates to tarPrefix, the
// package-level helper that implements this rule, instead of repeating it.
func (c *ClassifierSerializer) Prefix(prefix string, suffix string) string {
	return tarPrefix(prefix, suffix)
}
// WriteMetadataAtPrefix JSON-encodes metadata into the METADATA entry located
// under prefix.
func (c *ClassifierSerializer) WriteMetadataAtPrefix(prefix string, metadata ClassifierMetadataV1) error {
	metadataKey := c.Prefix(prefix, "METADATA")
	return c.WriteJSONForKey(metadataKey, &metadata)
}
// CreateSerializedClassifierStub generates a file to serialize into
// and writes the METADATA header.
//
// The file at filePath is created (or truncated) with mode 0600 and wrapped in
// gzip and tar writers. A CLS_MANIFEST entry holding SerializationFormatVersion
// is written first, followed by the top-level METADATA entry. On success the
// caller owns the returned serializer and must Close it; on failure the file
// is closed before returning.
func CreateSerializedClassifierStub(filePath string, metadata ClassifierMetadataV1) (*ClassifierSerializer, error) {
	// Open the filePath.
	f, err := os.OpenFile(filePath, os.O_RDWR|os.O_TRUNC|os.O_CREATE, 0600)
	if err != nil {
		return nil, err
	}
	// Until the serializer is handed to the caller, every early return must
	// release the file handle.
	succeeded := false
	defer func() {
		if !succeeded {
			_ = f.Close() // already failing; the original error is more useful
		}
	}()

	gzWriter := gzip.NewWriter(f)
	tw := tar.NewWriter(gzWriter)
	ret := &ClassifierSerializer{
		gzipWriter: gzWriter,
		fileWriter: f,
		tarWriter:  tw,
	}

	// Write the MANIFEST entry.
	hdr := &tar.Header{
		Name: "CLS_MANIFEST",
		Size: int64(len(SerializationFormatVersion)),
	}
	if err := tw.WriteHeader(hdr); err != nil {
		return nil, fmt.Errorf("Could not write CLS_MANIFEST header: %s", err)
	}
	if _, err := tw.Write([]byte(SerializationFormatVersion)); err != nil {
		return nil, fmt.Errorf("Could not write CLS_MANIFEST contents: %s", err)
	}

	// Write the METADATA entry.
	if err := ret.WriteMetadataAtPrefix("", metadata); err != nil {
		return nil, fmt.Errorf("JSON marshal error: %s", err)
	}

	succeeded = true
	return ret, nil
}