mirror of
https://github.com/sjwhitworth/golearn.git
synced 2025-04-28 13:48:56 +08:00
base: fix unmarshalling attributes, add JSON
This commit is contained in:
parent
127a8e9162
commit
499ac7a493
@ -128,6 +128,18 @@ func getTarContent(tr *tar.Reader, name string) []byte {
|
|||||||
panic("File not found!")
|
panic("File not found!")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func MarshalAttribute(a Attribute) (map[string]interface{}, error) {
|
||||||
|
ret := make(map[string]interface{})
|
||||||
|
marshaledAttrRaw, err := a.MarshalJSON()
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
err = json.Unmarshal(marshaledAttrRaw, &ret)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return ret, nil
|
||||||
|
}
|
||||||
|
|
||||||
func DeserializeAttribute(data []byte) (Attribute, error) {
|
func DeserializeAttribute(data []byte) (Attribute, error) {
|
||||||
type JSONAttribute struct {
|
type JSONAttribute struct {
|
||||||
@ -170,7 +182,10 @@ func DeserializeAttributes(data []byte) ([]Attribute, error) {
|
|||||||
|
|
||||||
// Define a JSON shim Attribute
|
// Define a JSON shim Attribute
|
||||||
var attrs []json.RawMessage
|
var attrs []json.RawMessage
|
||||||
err := json.Unmarshal(data, attrs)
|
err := json.Unmarshal(data, &attrs)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("Failed to deserialize attributes: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
ret := make([]Attribute, len(attrs))
|
ret := make([]Attribute, len(attrs))
|
||||||
for i, v := range attrs {
|
for i, v := range attrs {
|
||||||
@ -303,12 +318,14 @@ type ClassifierMetadataV1 struct {
|
|||||||
|
|
||||||
type ClassifierDeserializer struct {
|
type ClassifierDeserializer struct {
|
||||||
gzipReader io.Reader
|
gzipReader io.Reader
|
||||||
fileReader io.Reader
|
fileReader io.ReadCloser
|
||||||
tarReader *tar.Reader
|
tarReader *tar.Reader
|
||||||
Metadata *ClassifierMetadataV1
|
Metadata *ClassifierMetadataV1
|
||||||
}
|
}
|
||||||
|
|
||||||
// ReadSerializedClassifierStub is the counterpart of CreateSerializedClassifierStub
|
// ReadSerializedClassifierStub is the counterpart of CreateSerializedClassifierStub.
|
||||||
|
// It's used inside SaveableClassifiers to read information from a perviously saved
|
||||||
|
// model file.
|
||||||
func ReadSerializedClassifierStub(filePath string) (*ClassifierDeserializer, error) {
|
func ReadSerializedClassifierStub(filePath string) (*ClassifierDeserializer, error) {
|
||||||
|
|
||||||
f, err := os.Open(filePath)
|
f, err := os.Open(filePath)
|
||||||
@ -333,9 +350,15 @@ func ReadSerializedClassifierStub(filePath string) (*ClassifierDeserializer, err
|
|||||||
//
|
//
|
||||||
// Parse METADATA
|
// Parse METADATA
|
||||||
//
|
//
|
||||||
metadataBytes := getTarContent(tz, "METADATA")
|
|
||||||
var metadata ClassifierMetadataV1
|
var metadata ClassifierMetadataV1
|
||||||
err = json.Unmarshal(metadataBytes, &metadata)
|
ret := &ClassifierDeserializer{
|
||||||
|
f,
|
||||||
|
gzr,
|
||||||
|
tz,
|
||||||
|
&metadata,
|
||||||
|
}
|
||||||
|
|
||||||
|
err = ret.GetJSONForKey("METADATA", ret.Metadata)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("Error whilst reading METADATA: %s", err)
|
return nil, fmt.Errorf("Error whilst reading METADATA: %s", err)
|
||||||
}
|
}
|
||||||
@ -345,31 +368,36 @@ func ReadSerializedClassifierStub(filePath string) (*ClassifierDeserializer, err
|
|||||||
return nil, fmt.Errorf("METADATA: wrong format_version for this version of golearn")
|
return nil, fmt.Errorf("METADATA: wrong format_version for this version of golearn")
|
||||||
}
|
}
|
||||||
|
|
||||||
ret := &ClassifierDeserializer{
|
|
||||||
f,
|
|
||||||
gzr,
|
|
||||||
tz,
|
|
||||||
&metadata,
|
|
||||||
}
|
|
||||||
|
|
||||||
return ret, nil
|
return ret, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// GetBytesForKey returns the bytes at a given location in the output.
|
||||||
func (c *ClassifierDeserializer) GetBytesForKey(key string) ([]byte, error) {
|
func (c *ClassifierDeserializer) GetBytesForKey(key string) ([]byte, error) {
|
||||||
return getTarContent(c.tarReader, key), nil
|
return getTarContent(c.tarReader, key), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *ClassifierDeserializer) Close() {
|
// GetJSONForKey deserializes a JSON key in the output file.
|
||||||
|
func (c *ClassifierDeserializer) GetJSONForKey(key string, v interface{}) error {
|
||||||
|
b, err := c.GetBytesForKey(key)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
return json.Unmarshal(b, v)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Close cleans up everything.
|
||||||
|
func (c *ClassifierDeserializer) Close() {
|
||||||
|
c.fileReader.Close()
|
||||||
|
}
|
||||||
|
|
||||||
|
// ClassifierSerializer is an object used by SaveableClassifiers.
|
||||||
type ClassifierSerializer struct {
|
type ClassifierSerializer struct {
|
||||||
gzipWriter *gzip.Writer
|
gzipWriter *gzip.Writer
|
||||||
fileWriter io.Writer
|
fileWriter io.WriteCloser
|
||||||
tarWriter *tar.Writer
|
tarWriter *tar.Writer
|
||||||
}
|
}
|
||||||
|
|
||||||
// Close finalizes the Classifier serialization session
|
// Close finalizes the Classifier serialization session.
|
||||||
func (c *ClassifierSerializer) Close() error {
|
func (c *ClassifierSerializer) Close() error {
|
||||||
|
|
||||||
// Finally, close and flush the various levels
|
// Finally, close and flush the various levels
|
||||||
@ -389,18 +417,14 @@ func (c *ClassifierSerializer) Close() error {
|
|||||||
return fmt.Errorf("Could not close gz: %s", err)
|
return fmt.Errorf("Could not close gz: %s", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
//if err := c.fileWriter.Flush(); err != nil {
|
if err := c.fileWriter.Close(); err != nil {
|
||||||
// return fmt.Errorf("Could not flush file: %s", err)
|
return fmt.Errorf("Could not close file: %s", err)
|
||||||
//}
|
}
|
||||||
|
|
||||||
//if err := c.fileWriter.Flush(); err != nil {
|
|
||||||
// return fmt.Errorf("Could not close file: %s", err)
|
|
||||||
//}
|
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// WriteBytesForKey creates a new entry in the serializer file
|
// WriteBytesForKey creates a new entry in the serializer file with some user-defined bytes.
|
||||||
func (c *ClassifierSerializer) WriteBytesForKey(key string, b []byte) error {
|
func (c *ClassifierSerializer) WriteBytesForKey(key string, b []byte) error {
|
||||||
|
|
||||||
//
|
//
|
||||||
@ -412,18 +436,30 @@ func (c *ClassifierSerializer) WriteBytesForKey(key string, b []byte) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
if err := c.tarWriter.WriteHeader(hdr); err != nil {
|
if err := c.tarWriter.WriteHeader(hdr); err != nil {
|
||||||
return fmt.Errorf("Could not write header for '%s': %s", err)
|
return fmt.Errorf("Could not write header for '%s': %s", key, err)
|
||||||
}
|
}
|
||||||
//
|
//
|
||||||
// Write data
|
// Write data
|
||||||
//
|
//
|
||||||
if _, err := c.tarWriter.Write(b); err != nil {
|
if _, err := c.tarWriter.Write(b); err != nil {
|
||||||
return fmt.Errorf("Could not write data for '%s': %s", err)
|
return fmt.Errorf("Could not write data for '%s': %s", key, err)
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// WriteJSONForKey creates a new entry in the file with an interface serialized as JSON.
|
||||||
|
func (c *ClassifierSerializer) WriteJSONForKey(key string, v interface{}) error {
|
||||||
|
|
||||||
|
b, err := json.Marshal(v)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
return c.WriteBytesForKey(key, b)
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
// CreateSerializedClassifierStub generates a file to serialize into
|
// CreateSerializedClassifierStub generates a file to serialize into
|
||||||
// and writes the METADATA header.
|
// and writes the METADATA header.
|
||||||
func CreateSerializedClassifierStub(filePath string, metadata ClassifierMetadataV1) (*ClassifierSerializer, error) {
|
func CreateSerializedClassifierStub(filePath string, metadata ClassifierMetadataV1) (*ClassifierSerializer, error) {
|
||||||
@ -463,28 +499,12 @@ func CreateSerializedClassifierStub(filePath string, metadata ClassifierMetadata
|
|||||||
// Write the METADATA entry
|
// Write the METADATA entry
|
||||||
//
|
//
|
||||||
|
|
||||||
// Marshal the classifier information
|
// Marshal the classifier information (TODO: split this into another method)
|
||||||
cB, err := json.Marshal(metadata)
|
err = ret.WriteJSONForKey("METADATA", &metadata)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
if len(cB) == 0 {
|
|
||||||
return nil, fmt.Errorf("JSON marshal error: %s", err)
|
return nil, fmt.Errorf("JSON marshal error: %s", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Write the information into the file
|
|
||||||
hdr = &tar.Header{
|
|
||||||
Name: "METADATA",
|
|
||||||
Size: int64(len(cB)),
|
|
||||||
}
|
|
||||||
if err := tw.WriteHeader(hdr); err != nil {
|
|
||||||
return nil, fmt.Errorf("Could not write METADATA object %s", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
if _, err := tw.Write(cB); err != nil {
|
|
||||||
return nil, fmt.Errorf("Could not write METDATA contents: %s", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
return ret, nil
|
return ret, nil
|
||||||
|
|
||||||
}
|
}
|
||||||
|
Loading…
x
Reference in New Issue
Block a user