1
0
mirror of https://github.com/sjwhitworth/golearn.git synced 2025-04-26 13:49:14 +08:00

base: fix unmarshalling attributes, add JSON

This commit is contained in:
Richard Townsend 2017-08-26 14:56:17 +01:00
parent 127a8e9162
commit 499ac7a493

View File

@ -128,6 +128,18 @@ func getTarContent(tr *tar.Reader, name string) []byte {
panic("File not found!")
}
func MarshalAttribute(a Attribute) (map[string]interface{}, error) {
ret := make(map[string]interface{})
marshaledAttrRaw, err := a.MarshalJSON()
if err != nil {
return nil, err
}
err = json.Unmarshal(marshaledAttrRaw, &ret)
if err != nil {
return nil, err
}
return ret, nil
}
func DeserializeAttribute(data []byte) (Attribute, error) {
type JSONAttribute struct {
@ -170,7 +182,10 @@ func DeserializeAttributes(data []byte) ([]Attribute, error) {
// Define a JSON shim Attribute
var attrs []json.RawMessage
err := json.Unmarshal(data, attrs)
err := json.Unmarshal(data, &attrs)
if err != nil {
return nil, fmt.Errorf("Failed to deserialize attributes: %v", err)
}
ret := make([]Attribute, len(attrs))
for i, v := range attrs {
@ -303,12 +318,14 @@ type ClassifierMetadataV1 struct {
type ClassifierDeserializer struct {
gzipReader io.Reader
fileReader io.Reader
fileReader io.ReadCloser
tarReader *tar.Reader
Metadata *ClassifierMetadataV1
}
// ReadSerializedClassifierStub is the counterpart of CreateSerializedClassifierStub
// ReadSerializedClassifierStub is the counterpart of CreateSerializedClassifierStub.
// It's used inside SaveableClassifiers to read information from a perviously saved
// model file.
func ReadSerializedClassifierStub(filePath string) (*ClassifierDeserializer, error) {
f, err := os.Open(filePath)
@ -333,9 +350,15 @@ func ReadSerializedClassifierStub(filePath string) (*ClassifierDeserializer, err
//
// Parse METADATA
//
metadataBytes := getTarContent(tz, "METADATA")
var metadata ClassifierMetadataV1
err = json.Unmarshal(metadataBytes, &metadata)
ret := &ClassifierDeserializer{
f,
gzr,
tz,
&metadata,
}
err = ret.GetJSONForKey("METADATA", ret.Metadata)
if err != nil {
return nil, fmt.Errorf("Error whilst reading METADATA: %s", err)
}
@ -345,31 +368,36 @@ func ReadSerializedClassifierStub(filePath string) (*ClassifierDeserializer, err
return nil, fmt.Errorf("METADATA: wrong format_version for this version of golearn")
}
ret := &ClassifierDeserializer{
f,
gzr,
tz,
&metadata,
}
return ret, nil
}
// GetBytesForKey returns the bytes at a given location in the output.
func (c *ClassifierDeserializer) GetBytesForKey(key string) ([]byte, error) {
return getTarContent(c.tarReader, key), nil
}
func (c *ClassifierDeserializer) Close() {
// GetJSONForKey deserializes a JSON key in the output file.
func (c *ClassifierDeserializer) GetJSONForKey(key string, v interface{}) error {
b, err := c.GetBytesForKey(key)
if err != nil {
return err
}
return json.Unmarshal(b, v)
}
// Close cleans up everything.
func (c *ClassifierDeserializer) Close() {
c.fileReader.Close()
}
// ClassifierSerializer is an object used by SaveableClassifiers.
type ClassifierSerializer struct {
gzipWriter *gzip.Writer
fileWriter io.Writer
fileWriter io.WriteCloser
tarWriter *tar.Writer
}
// Close finalizes the Classifier serialization session
// Close finalizes the Classifier serialization session.
func (c *ClassifierSerializer) Close() error {
// Finally, close and flush the various levels
@ -389,18 +417,14 @@ func (c *ClassifierSerializer) Close() error {
return fmt.Errorf("Could not close gz: %s", err)
}
//if err := c.fileWriter.Flush(); err != nil {
// return fmt.Errorf("Could not flush file: %s", err)
//}
//if err := c.fileWriter.Flush(); err != nil {
// return fmt.Errorf("Could not close file: %s", err)
//}
if err := c.fileWriter.Close(); err != nil {
return fmt.Errorf("Could not close file: %s", err)
}
return nil
}
// WriteBytesForKey creates a new entry in the serializer file
// WriteBytesForKey creates a new entry in the serializer file with some user-defined bytes.
func (c *ClassifierSerializer) WriteBytesForKey(key string, b []byte) error {
//
@ -412,18 +436,30 @@ func (c *ClassifierSerializer) WriteBytesForKey(key string, b []byte) error {
}
if err := c.tarWriter.WriteHeader(hdr); err != nil {
return fmt.Errorf("Could not write header for '%s': %s", err)
return fmt.Errorf("Could not write header for '%s': %s", key, err)
}
//
// Write data
//
if _, err := c.tarWriter.Write(b); err != nil {
return fmt.Errorf("Could not write data for '%s': %s", err)
return fmt.Errorf("Could not write data for '%s': %s", key, err)
}
return nil
}
// WriteJSONForKey creates a new entry in the file with an interface serialized as JSON.
func (c *ClassifierSerializer) WriteJSONForKey(key string, v interface{}) error {
b, err := json.Marshal(v)
if err != nil {
return err
}
return c.WriteBytesForKey(key, b)
}
// CreateSerializedClassifierStub generates a file to serialize into
// and writes the METADATA header.
func CreateSerializedClassifierStub(filePath string, metadata ClassifierMetadataV1) (*ClassifierSerializer, error) {
@ -463,28 +499,12 @@ func CreateSerializedClassifierStub(filePath string, metadata ClassifierMetadata
// Write the METADATA entry
//
// Marshal the classifier information
cB, err := json.Marshal(metadata)
// Marshal the classifier information (TODO: split this into another method)
err = ret.WriteJSONForKey("METADATA", &metadata)
if err != nil {
return nil, err
}
if len(cB) == 0 {
return nil, fmt.Errorf("JSON marshal error: %s", err)
}
// Write the information into the file
hdr = &tar.Header{
Name: "METADATA",
Size: int64(len(cB)),
}
if err := tw.WriteHeader(hdr); err != nil {
return nil, fmt.Errorf("Could not write METADATA object %s", err)
}
if _, err := tw.Write(cB); err != nil {
return nil, fmt.Errorf("Could not write METDATA contents: %s", err)
}
return ret, nil
}