mirror of
https://github.com/sjwhitworth/golearn.git
synced 2025-04-26 13:49:14 +08:00
base: fix unmarshalling attributes, add JSON
This commit is contained in:
parent
127a8e9162
commit
499ac7a493
@ -128,6 +128,18 @@ func getTarContent(tr *tar.Reader, name string) []byte {
|
||||
panic("File not found!")
|
||||
}
|
||||
|
||||
func MarshalAttribute(a Attribute) (map[string]interface{}, error) {
|
||||
ret := make(map[string]interface{})
|
||||
marshaledAttrRaw, err := a.MarshalJSON()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
err = json.Unmarshal(marshaledAttrRaw, &ret)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return ret, nil
|
||||
}
|
||||
|
||||
func DeserializeAttribute(data []byte) (Attribute, error) {
|
||||
type JSONAttribute struct {
|
||||
@ -170,7 +182,10 @@ func DeserializeAttributes(data []byte) ([]Attribute, error) {
|
||||
|
||||
// Define a JSON shim Attribute
|
||||
var attrs []json.RawMessage
|
||||
err := json.Unmarshal(data, attrs)
|
||||
err := json.Unmarshal(data, &attrs)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("Failed to deserialize attributes: %v", err)
|
||||
}
|
||||
|
||||
ret := make([]Attribute, len(attrs))
|
||||
for i, v := range attrs {
|
||||
@ -303,12 +318,14 @@ type ClassifierMetadataV1 struct {
|
||||
|
||||
type ClassifierDeserializer struct {
|
||||
gzipReader io.Reader
|
||||
fileReader io.Reader
|
||||
fileReader io.ReadCloser
|
||||
tarReader *tar.Reader
|
||||
Metadata *ClassifierMetadataV1
|
||||
}
|
||||
|
||||
// ReadSerializedClassifierStub is the counterpart of CreateSerializedClassifierStub
|
||||
// ReadSerializedClassifierStub is the counterpart of CreateSerializedClassifierStub.
|
||||
// It's used inside SaveableClassifiers to read information from a perviously saved
|
||||
// model file.
|
||||
func ReadSerializedClassifierStub(filePath string) (*ClassifierDeserializer, error) {
|
||||
|
||||
f, err := os.Open(filePath)
|
||||
@ -333,9 +350,15 @@ func ReadSerializedClassifierStub(filePath string) (*ClassifierDeserializer, err
|
||||
//
|
||||
// Parse METADATA
|
||||
//
|
||||
metadataBytes := getTarContent(tz, "METADATA")
|
||||
var metadata ClassifierMetadataV1
|
||||
err = json.Unmarshal(metadataBytes, &metadata)
|
||||
ret := &ClassifierDeserializer{
|
||||
f,
|
||||
gzr,
|
||||
tz,
|
||||
&metadata,
|
||||
}
|
||||
|
||||
err = ret.GetJSONForKey("METADATA", ret.Metadata)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("Error whilst reading METADATA: %s", err)
|
||||
}
|
||||
@ -345,31 +368,36 @@ func ReadSerializedClassifierStub(filePath string) (*ClassifierDeserializer, err
|
||||
return nil, fmt.Errorf("METADATA: wrong format_version for this version of golearn")
|
||||
}
|
||||
|
||||
ret := &ClassifierDeserializer{
|
||||
f,
|
||||
gzr,
|
||||
tz,
|
||||
&metadata,
|
||||
}
|
||||
|
||||
return ret, nil
|
||||
}
|
||||
|
||||
// GetBytesForKey returns the bytes at a given location in the output.
|
||||
func (c *ClassifierDeserializer) GetBytesForKey(key string) ([]byte, error) {
|
||||
return getTarContent(c.tarReader, key), nil
|
||||
}
|
||||
|
||||
func (c *ClassifierDeserializer) Close() {
|
||||
|
||||
// GetJSONForKey deserializes a JSON key in the output file.
|
||||
func (c *ClassifierDeserializer) GetJSONForKey(key string, v interface{}) error {
|
||||
b, err := c.GetBytesForKey(key)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return json.Unmarshal(b, v)
|
||||
}
|
||||
|
||||
// Close cleans up everything.
|
||||
func (c *ClassifierDeserializer) Close() {
|
||||
c.fileReader.Close()
|
||||
}
|
||||
|
||||
// ClassifierSerializer is an object used by SaveableClassifiers.
|
||||
type ClassifierSerializer struct {
|
||||
gzipWriter *gzip.Writer
|
||||
fileWriter io.Writer
|
||||
fileWriter io.WriteCloser
|
||||
tarWriter *tar.Writer
|
||||
}
|
||||
|
||||
// Close finalizes the Classifier serialization session
|
||||
// Close finalizes the Classifier serialization session.
|
||||
func (c *ClassifierSerializer) Close() error {
|
||||
|
||||
// Finally, close and flush the various levels
|
||||
@ -389,18 +417,14 @@ func (c *ClassifierSerializer) Close() error {
|
||||
return fmt.Errorf("Could not close gz: %s", err)
|
||||
}
|
||||
|
||||
//if err := c.fileWriter.Flush(); err != nil {
|
||||
// return fmt.Errorf("Could not flush file: %s", err)
|
||||
//}
|
||||
|
||||
//if err := c.fileWriter.Flush(); err != nil {
|
||||
// return fmt.Errorf("Could not close file: %s", err)
|
||||
//}
|
||||
if err := c.fileWriter.Close(); err != nil {
|
||||
return fmt.Errorf("Could not close file: %s", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// WriteBytesForKey creates a new entry in the serializer file
|
||||
// WriteBytesForKey creates a new entry in the serializer file with some user-defined bytes.
|
||||
func (c *ClassifierSerializer) WriteBytesForKey(key string, b []byte) error {
|
||||
|
||||
//
|
||||
@ -412,18 +436,30 @@ func (c *ClassifierSerializer) WriteBytesForKey(key string, b []byte) error {
|
||||
}
|
||||
|
||||
if err := c.tarWriter.WriteHeader(hdr); err != nil {
|
||||
return fmt.Errorf("Could not write header for '%s': %s", err)
|
||||
return fmt.Errorf("Could not write header for '%s': %s", key, err)
|
||||
}
|
||||
//
|
||||
// Write data
|
||||
//
|
||||
if _, err := c.tarWriter.Write(b); err != nil {
|
||||
return fmt.Errorf("Could not write data for '%s': %s", err)
|
||||
return fmt.Errorf("Could not write data for '%s': %s", key, err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// WriteJSONForKey creates a new entry in the file with an interface serialized as JSON.
|
||||
func (c *ClassifierSerializer) WriteJSONForKey(key string, v interface{}) error {
|
||||
|
||||
b, err := json.Marshal(v)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return c.WriteBytesForKey(key, b)
|
||||
|
||||
}
|
||||
|
||||
// CreateSerializedClassifierStub generates a file to serialize into
|
||||
// and writes the METADATA header.
|
||||
func CreateSerializedClassifierStub(filePath string, metadata ClassifierMetadataV1) (*ClassifierSerializer, error) {
|
||||
@ -463,28 +499,12 @@ func CreateSerializedClassifierStub(filePath string, metadata ClassifierMetadata
|
||||
// Write the METADATA entry
|
||||
//
|
||||
|
||||
// Marshal the classifier information
|
||||
cB, err := json.Marshal(metadata)
|
||||
// Marshal the classifier information (TODO: split this into another method)
|
||||
err = ret.WriteJSONForKey("METADATA", &metadata)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if len(cB) == 0 {
|
||||
return nil, fmt.Errorf("JSON marshal error: %s", err)
|
||||
}
|
||||
|
||||
// Write the information into the file
|
||||
hdr = &tar.Header{
|
||||
Name: "METADATA",
|
||||
Size: int64(len(cB)),
|
||||
}
|
||||
if err := tw.WriteHeader(hdr); err != nil {
|
||||
return nil, fmt.Errorf("Could not write METADATA object %s", err)
|
||||
}
|
||||
|
||||
if _, err := tw.Write(cB); err != nil {
|
||||
return nil, fmt.Errorf("Could not write METDATA contents: %s", err)
|
||||
}
|
||||
|
||||
return ret, nil
|
||||
|
||||
}
|
||||
|
Loading…
x
Reference in New Issue
Block a user