diff --git a/base/serialize.go b/base/serialize.go index 1b3da23..094257f 100644 --- a/base/serialize.go +++ b/base/serialize.go @@ -128,6 +128,18 @@ func getTarContent(tr *tar.Reader, name string) []byte { panic("File not found!") } +func MarshalAttribute(a Attribute) (map[string]interface{}, error) { + ret := make(map[string]interface{}) + marshaledAttrRaw, err := a.MarshalJSON() + if err != nil { + return nil, err + } + err = json.Unmarshal(marshaledAttrRaw, &ret) + if err != nil { + return nil, err + } + return ret, nil +} func DeserializeAttribute(data []byte) (Attribute, error) { type JSONAttribute struct { @@ -170,7 +182,10 @@ func DeserializeAttributes(data []byte) ([]Attribute, error) { // Define a JSON shim Attribute var attrs []json.RawMessage - err := json.Unmarshal(data, attrs) + err := json.Unmarshal(data, &attrs) + if err != nil { + return nil, fmt.Errorf("Failed to deserialize attributes: %v", err) + } ret := make([]Attribute, len(attrs)) for i, v := range attrs { @@ -303,12 +318,14 @@ type ClassifierMetadataV1 struct { type ClassifierDeserializer struct { gzipReader io.Reader - fileReader io.Reader + fileReader io.ReadCloser tarReader *tar.Reader Metadata *ClassifierMetadataV1 } -// ReadSerializedClassifierStub is the counterpart of CreateSerializedClassifierStub +// ReadSerializedClassifierStub is the counterpart of CreateSerializedClassifierStub. +// It's used inside SaveableClassifiers to read information from a previously saved +// model file. 
func ReadSerializedClassifierStub(filePath string) (*ClassifierDeserializer, error) { f, err := os.Open(filePath) @@ -333,9 +350,15 @@ func ReadSerializedClassifierStub(filePath string) (*ClassifierDeserializer, err // // Parse METADATA // - metadataBytes := getTarContent(tz, "METADATA") var metadata ClassifierMetadataV1 - err = json.Unmarshal(metadataBytes, &metadata) + ret := &ClassifierDeserializer{ + f, + gzr, + tz, + &metadata, + } + + err = ret.GetJSONForKey("METADATA", ret.Metadata) if err != nil { return nil, fmt.Errorf("Error whilst reading METADATA: %s", err) } @@ -345,31 +368,36 @@ func ReadSerializedClassifierStub(filePath string) (*ClassifierDeserializer, err return nil, fmt.Errorf("METADATA: wrong format_version for this version of golearn") } - ret := &ClassifierDeserializer{ - f, - gzr, - tz, - &metadata, - } - return ret, nil } +// GetBytesForKey returns the bytes at a given location in the output. func (c *ClassifierDeserializer) GetBytesForKey(key string) ([]byte, error) { return getTarContent(c.tarReader, key), nil } -func (c *ClassifierDeserializer) Close() { - +// GetJSONForKey deserializes a JSON key in the output file. +func (c *ClassifierDeserializer) GetJSONForKey(key string, v interface{}) error { + b, err := c.GetBytesForKey(key) + if err != nil { + return err + } + return json.Unmarshal(b, v) } +// Close cleans up everything. +func (c *ClassifierDeserializer) Close() { + c.fileReader.Close() +} + +// ClassifierSerializer is an object used by SaveableClassifiers. type ClassifierSerializer struct { gzipWriter *gzip.Writer - fileWriter io.Writer + fileWriter io.WriteCloser tarWriter *tar.Writer } -// Close finalizes the Classifier serialization session +// Close finalizes the Classifier serialization session. 
func (c *ClassifierSerializer) Close() error { // Finally, close and flush the various levels @@ -389,18 +417,14 @@ func (c *ClassifierSerializer) Close() error { return fmt.Errorf("Could not close gz: %s", err) } - //if err := c.fileWriter.Flush(); err != nil { - // return fmt.Errorf("Could not flush file: %s", err) - //} - - //if err := c.fileWriter.Flush(); err != nil { - // return fmt.Errorf("Could not close file: %s", err) - //} + if err := c.fileWriter.Close(); err != nil { + return fmt.Errorf("Could not close file: %s", err) + } return nil } -// WriteBytesForKey creates a new entry in the serializer file +// WriteBytesForKey creates a new entry in the serializer file with some user-defined bytes. func (c *ClassifierSerializer) WriteBytesForKey(key string, b []byte) error { // @@ -412,18 +436,30 @@ func (c *ClassifierSerializer) WriteBytesForKey(key string, b []byte) error { } if err := c.tarWriter.WriteHeader(hdr); err != nil { - return fmt.Errorf("Could not write header for '%s': %s", err) + return fmt.Errorf("Could not write header for '%s': %s", key, err) } // // Write data // if _, err := c.tarWriter.Write(b); err != nil { - return fmt.Errorf("Could not write data for '%s': %s", err) + return fmt.Errorf("Could not write data for '%s': %s", key, err) } return nil } +// WriteJSONForKey creates a new entry in the file with an interface serialized as JSON. +func (c *ClassifierSerializer) WriteJSONForKey(key string, v interface{}) error { + + b, err := json.Marshal(v) + if err != nil { + return err + } + + return c.WriteBytesForKey(key, b) + +} + // CreateSerializedClassifierStub generates a file to serialize into // and writes the METADATA header. 
func CreateSerializedClassifierStub(filePath string, metadata ClassifierMetadataV1) (*ClassifierSerializer, error) { @@ -463,28 +499,12 @@ func CreateSerializedClassifierStub(filePath string, metadata ClassifierMetadata // Write the METADATA entry // - // Marshal the classifier information - cB, err := json.Marshal(metadata) + // Marshal the classifier information (TODO: split this into another method) + err = ret.WriteJSONForKey("METADATA", &metadata) if err != nil { - return nil, err - } - if len(cB) == 0 { return nil, fmt.Errorf("JSON marshal error: %s", err) } - // Write the information into the file - hdr = &tar.Header{ - Name: "METADATA", - Size: int64(len(cB)), - } - if err := tw.WriteHeader(hdr); err != nil { - return nil, fmt.Errorf("Could not write METADATA object %s", err) - } - - if _, err := tw.Write(cB); err != nil { - return nil, fmt.Errorf("Could not write METDATA contents: %s", err) - } - return ret, nil }