mirror of
https://github.com/sjwhitworth/golearn.git
synced 2025-04-26 13:49:14 +08:00
606 lines
14 KiB
Go
606 lines
14 KiB
Go
package base
|
|
|
|
import (
|
|
"archive/tar"
|
|
"compress/gzip"
|
|
"encoding/csv"
|
|
"encoding/json"
|
|
"fmt"
|
|
"io"
|
|
"os"
|
|
"reflect"
|
|
"runtime"
|
|
"log"
|
|
)
|
|
|
|
const (
|
|
SerializationFormatVersion = "golearn 0.5"
|
|
)
|
|
|
|
func SerializeInstancesToFile(inst FixedDataGrid, path string) error {
|
|
f, err := os.OpenFile(path, os.O_RDWR, 0600)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
err = SerializeInstances(inst, f)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
err = f.Sync()
|
|
if err != nil {
|
|
return fmt.Errorf("Couldn't flush file: %s", err)
|
|
}
|
|
f.Close()
|
|
return nil
|
|
}
|
|
|
|
func SerializeInstancesToCSV(inst FixedDataGrid, path string) error {
|
|
f, err := os.OpenFile(path, os.O_RDWR, 0600)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
defer func() {
|
|
f.Sync()
|
|
f.Close()
|
|
}()
|
|
|
|
return SerializeInstancesToCSVStream(inst, f)
|
|
}
|
|
|
|
func SerializeInstancesToCSVStream(inst FixedDataGrid, f io.Writer) error {
|
|
// Create the CSV writer
|
|
w := csv.NewWriter(f)
|
|
|
|
colCount, _ := inst.Size()
|
|
|
|
// Write out Attribute headers
|
|
// Start with the regular Attributes
|
|
normalAttrs := NonClassAttributes(inst)
|
|
classAttrs := inst.AllClassAttributes()
|
|
allAttrs := make([]Attribute, colCount)
|
|
n := copy(allAttrs, normalAttrs)
|
|
copy(allAttrs[n:], classAttrs)
|
|
headerRow := make([]string, colCount)
|
|
for i, v := range allAttrs {
|
|
headerRow[i] = v.GetName()
|
|
}
|
|
w.Write(headerRow)
|
|
|
|
specs := ResolveAttributes(inst, allAttrs)
|
|
curRow := make([]string, colCount)
|
|
inst.MapOverRows(specs, func(row [][]byte, rowNo int) (bool, error) {
|
|
for i, v := range row {
|
|
attr := allAttrs[i]
|
|
curRow[i] = attr.GetStringFromSysVal(v)
|
|
}
|
|
w.Write(curRow)
|
|
return true, nil
|
|
})
|
|
|
|
w.Flush()
|
|
return nil
|
|
}
|
|
|
|
func writeAttributesToFilePart(attrs []Attribute, f *tar.Writer, name string) error {
|
|
// Get the marshaled Attribute array
|
|
body, err := json.Marshal(attrs)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
// Write a header
|
|
hdr := &tar.Header{
|
|
Name: name,
|
|
Size: int64(len(body)),
|
|
}
|
|
if err := f.WriteHeader(hdr); err != nil {
|
|
return err
|
|
}
|
|
|
|
// Write the marshaled data
|
|
if _, err := f.Write([]byte(body)); err != nil {
|
|
return err
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
func getTarContent(tr *tar.Reader, name string) []byte {
|
|
for {
|
|
hdr, err := tr.Next()
|
|
if err == io.EOF {
|
|
break
|
|
} else if err != nil {
|
|
panic(err)
|
|
}
|
|
|
|
if hdr.Name == name {
|
|
ret := make([]byte, hdr.Size)
|
|
n, err := tr.Read(ret)
|
|
if int64(n) != hdr.Size {
|
|
panic("Size mismatch")
|
|
}
|
|
if err != nil {
|
|
panic(err)
|
|
}
|
|
return ret
|
|
}
|
|
}
|
|
panic("File not found!")
|
|
}
|
|
|
|
|
|
func DeserializeAttribute(data []byte) (Attribute, error) {
|
|
type JSONAttribute struct {
|
|
Type string `json:"type"`
|
|
Name string `json:"name"`
|
|
Attr json.RawMessage `json:"attr"`
|
|
}
|
|
|
|
log.Printf("DeserializeAttribute = %s", data)
|
|
var rawAttr JSONAttribute
|
|
err := json.Unmarshal(data, &rawAttr)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
var attr Attribute
|
|
|
|
switch rawAttr.Type {
|
|
case "binary":
|
|
attr = new(BinaryAttribute)
|
|
break
|
|
case "float":
|
|
attr = new(FloatAttribute)
|
|
break
|
|
case "categorical":
|
|
attr = new(CategoricalAttribute)
|
|
break
|
|
default:
|
|
return nil, fmt.Errorf("Unrecognised Attribute format: %s", rawAttr.Type)
|
|
}
|
|
|
|
err = attr.UnmarshalJSON(rawAttr.Attr)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("Can't deserialize: %s (error: %s)", rawAttr, err)
|
|
}
|
|
attr.SetName(rawAttr.Name)
|
|
log.Printf("DeserializeAttribute = %s", attr)
|
|
return attr, nil
|
|
}
|
|
|
|
// DeserializeAttributes constructs a ve
|
|
func DeserializeAttributes(data []byte) ([]Attribute, error) {
|
|
|
|
// Define a JSON shim Attribute
|
|
var attrs []json.RawMessage
|
|
err := json.Unmarshal(data, attrs)
|
|
|
|
ret := make([]Attribute, len(attrs))
|
|
for i, v := range attrs {
|
|
ret[i], err = DeserializeAttribute(v)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
}
|
|
|
|
return ret, nil
|
|
}
|
|
|
|
func DeserializeInstances(f io.Reader) (ret *DenseInstances, err error) {
|
|
|
|
// Recovery function
|
|
defer func() {
|
|
if r := recover(); r != nil {
|
|
if _, ok := r.(runtime.Error); ok {
|
|
panic(r)
|
|
}
|
|
err = r.(error)
|
|
}
|
|
}()
|
|
|
|
// Open the .gz layer
|
|
gzReader, err := gzip.NewReader(f)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("Can't open: %s", err)
|
|
}
|
|
// Open the .tar layer
|
|
tr := tar.NewReader(gzReader)
|
|
// Retrieve the MANIFEST and verify
|
|
manifestBytes := getTarContent(tr, "MANIFEST")
|
|
if !reflect.DeepEqual(manifestBytes, []byte(SerializationFormatVersion)) {
|
|
return nil, fmt.Errorf("Unsupported MANIFEST: %s", string(manifestBytes))
|
|
}
|
|
|
|
// Get the size
|
|
sizeBytes := getTarContent(tr, "DIMS")
|
|
attrCount := int(UnpackBytesToU64(sizeBytes[0:8]))
|
|
rowCount := int(UnpackBytesToU64(sizeBytes[8:]))
|
|
|
|
// Unmarshal the Attributes
|
|
attrBytes := getTarContent(tr, "CATTRS")
|
|
cAttrs, err := DeserializeAttributes(attrBytes)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
attrBytes = getTarContent(tr, "ATTRS")
|
|
normalAttrs, err := DeserializeAttributes(attrBytes)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
// Create the return instances
|
|
ret = NewDenseInstances()
|
|
|
|
// Normal Attributes first, class Attributes on the end
|
|
allAttributes := make([]Attribute, attrCount)
|
|
for i, v := range normalAttrs {
|
|
ret.AddAttribute(v)
|
|
allAttributes[i] = v
|
|
}
|
|
for i, v := range cAttrs {
|
|
ret.AddAttribute(v)
|
|
err = ret.AddClassAttribute(v)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("Could not set Attribute as class Attribute: %s", err)
|
|
}
|
|
allAttributes[i+len(normalAttrs)] = v
|
|
}
|
|
// Allocate memory
|
|
err = ret.Extend(int(rowCount))
|
|
if err != nil {
|
|
return nil, fmt.Errorf("Could not allocate memory")
|
|
}
|
|
|
|
// Seek through the TAR file until we get to the DATA section
|
|
for {
|
|
hdr, err := tr.Next()
|
|
if err == io.EOF {
|
|
return nil, fmt.Errorf("DATA section missing!")
|
|
} else if err != nil {
|
|
return nil, fmt.Errorf("Error seeking to DATA section: %s", err)
|
|
}
|
|
if hdr.Name == "DATA" {
|
|
break
|
|
}
|
|
}
|
|
|
|
// Resolve AttributeSpecs
|
|
specs := ResolveAttributes(ret, allAttributes)
|
|
|
|
// Finally, read the values out of the data section
|
|
for i := 0; i < rowCount; i++ {
|
|
for _, s := range specs {
|
|
r := ret.Get(s, i)
|
|
n, err := tr.Read(r)
|
|
if n != len(r) {
|
|
return nil, fmt.Errorf("Expected %d bytes (read %d) on row %d", len(r), n, i)
|
|
}
|
|
if err != nil {
|
|
return nil, fmt.Errorf("Read error: %s", err)
|
|
}
|
|
ret.Set(s, i, r)
|
|
}
|
|
}
|
|
|
|
if err = gzReader.Close(); err != nil {
|
|
return ret, fmt.Errorf("Error closing gzip stream: %s", err)
|
|
}
|
|
|
|
return ret, nil
|
|
}
|
|
|
|
// ClassifierMetadataV1 is what gets written into METADATA
|
|
// in a classification file format.
|
|
type ClassifierMetadataV1 struct {
|
|
// FormatVersion should always be 1 for this structure
|
|
FormatVersion int `json:"format_version"`
|
|
// Uses the classifier name (provided by the classifier)
|
|
ClassifierName string `json:"classifier"`
|
|
// ClassifierVersion is also provided by the classifier
|
|
// and checks whether this version of GoLearn can read what's
|
|
// be written.
|
|
ClassifierVersion string `json"classifier_version"`
|
|
// This is a custom metadata field, provided by the classifier
|
|
ClassifierMetadata map[string]interface{} `json:"classifier_metadata"`
|
|
}
|
|
|
|
type ClassifierDeserializer struct {
|
|
gzipReader io.Reader
|
|
fileReader io.Reader
|
|
tarReader *tar.Reader
|
|
Metadata *ClassifierMetadataV1
|
|
}
|
|
|
|
// ReadSerializedClassifierStub is the counterpart of CreateSerializedClassifierStub
|
|
func ReadSerializedClassifierStub(filePath string) (*ClassifierDeserializer, error) {
|
|
|
|
f, err := os.Open(filePath)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("Can't open file: %s", err)
|
|
}
|
|
|
|
gzr, err := gzip.NewReader(f)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("Can't decompress file: %s", err)
|
|
}
|
|
|
|
tz := tar.NewReader(gzr)
|
|
|
|
// Check that the serialization format is right
|
|
// Retrieve the MANIFEST and verify
|
|
manifestBytes := getTarContent(tz, "MANIFEST")
|
|
if !reflect.DeepEqual(manifestBytes, []byte(SerializationFormatVersion)) {
|
|
return nil, fmt.Errorf("Unsupported MANIFEST: %s", string(manifestBytes))
|
|
}
|
|
|
|
//
|
|
// Parse METADATA
|
|
//
|
|
metadataBytes := getTarContent(tz, "METADATA")
|
|
var metadata ClassifierMetadataV1
|
|
err = json.Unmarshal(metadataBytes, &metadata)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("Error whilst reading METADATA: %s", err)
|
|
}
|
|
|
|
// Check that we can understand this archive
|
|
if metadata.FormatVersion != 1 {
|
|
return nil, fmt.Errorf("METADATA: wrong format_version for this version of golearn")
|
|
}
|
|
|
|
ret := &ClassifierDeserializer{
|
|
f,
|
|
gzr,
|
|
tz,
|
|
&metadata,
|
|
}
|
|
|
|
return ret, nil
|
|
}
|
|
|
|
func (c *ClassifierDeserializer) GetBytesForKey(key string) ([]byte, error) {
|
|
return getTarContent(c.tarReader, key), nil
|
|
}
|
|
|
|
func (c *ClassifierDeserializer) Close() {
|
|
|
|
}
|
|
|
|
type ClassifierSerializer struct {
|
|
gzipWriter *gzip.Writer
|
|
fileWriter io.Writer
|
|
tarWriter *tar.Writer
|
|
}
|
|
|
|
// Close finalizes the Classifier serialization session
|
|
func (c *ClassifierSerializer) Close() error {
|
|
|
|
// Finally, close and flush the various levels
|
|
if err := c.tarWriter.Flush(); err != nil {
|
|
return fmt.Errorf("Could not flush tar: %s", err)
|
|
}
|
|
|
|
if err := c.tarWriter.Close(); err != nil {
|
|
return fmt.Errorf("Could not close tar: %s", err)
|
|
}
|
|
|
|
if err := c.gzipWriter.Flush(); err != nil {
|
|
return fmt.Errorf("Could not flush gz: %s", err)
|
|
}
|
|
|
|
if err := c.gzipWriter.Close(); err != nil {
|
|
return fmt.Errorf("Could not close gz: %s", err)
|
|
}
|
|
|
|
//if err := c.fileWriter.Flush(); err != nil {
|
|
// return fmt.Errorf("Could not flush file: %s", err)
|
|
//}
|
|
|
|
//if err := c.fileWriter.Flush(); err != nil {
|
|
// return fmt.Errorf("Could not close file: %s", err)
|
|
//}
|
|
|
|
return nil
|
|
}
|
|
|
|
// WriteBytesForKey creates a new entry in the serializer file
|
|
func (c *ClassifierSerializer) WriteBytesForKey(key string, b []byte) error {
|
|
|
|
//
|
|
// Write header for key
|
|
//
|
|
hdr := &tar.Header{
|
|
Name: key,
|
|
Size: int64(len(b)),
|
|
}
|
|
|
|
if err := c.tarWriter.WriteHeader(hdr); err != nil {
|
|
return fmt.Errorf("Could not write header for '%s': %s", err)
|
|
}
|
|
//
|
|
// Write data
|
|
//
|
|
if _, err := c.tarWriter.Write(b); err != nil {
|
|
return fmt.Errorf("Could not write data for '%s': %s", err)
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
// CreateSerializedClassifierStub generates a file to serialize into
|
|
// and writes the METADATA header.
|
|
func CreateSerializedClassifierStub(filePath string, metadata ClassifierMetadataV1) (*ClassifierSerializer, error) {
|
|
|
|
// Open the filePath
|
|
f, err := os.OpenFile(filePath, os.O_RDWR | os.O_TRUNC, 0600)
|
|
if err != nil {
|
|
return nil, nil
|
|
}
|
|
|
|
var hdr *tar.Header
|
|
gzWriter := gzip.NewWriter(f)
|
|
tw := tar.NewWriter(gzWriter)
|
|
|
|
ret := &ClassifierSerializer{
|
|
gzipWriter: gzWriter,
|
|
fileWriter: f,
|
|
tarWriter: tw,
|
|
}
|
|
|
|
//
|
|
// Write the MANIFEST entry
|
|
//
|
|
hdr = &tar.Header{
|
|
Name: "MANIFEST",
|
|
Size: int64(len(SerializationFormatVersion)),
|
|
}
|
|
if err := tw.WriteHeader(hdr); err != nil {
|
|
return nil, fmt.Errorf("Could not write MANIFEST header: %s", err)
|
|
}
|
|
|
|
if _, err := tw.Write([]byte(SerializationFormatVersion)); err != nil {
|
|
return nil, fmt.Errorf("Could not write MANIFEST contents: %s", err)
|
|
}
|
|
|
|
//
|
|
// Write the METADATA entry
|
|
//
|
|
|
|
// Marshal the classifier information
|
|
cB, err := json.Marshal(metadata)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
if len(cB) == 0 {
|
|
return nil, fmt.Errorf("JSON marshal error: %s", err)
|
|
}
|
|
|
|
// Write the information into the file
|
|
hdr = &tar.Header{
|
|
Name: "METADATA",
|
|
Size: int64(len(cB)),
|
|
}
|
|
if err := tw.WriteHeader(hdr); err != nil {
|
|
return nil, fmt.Errorf("Could not write METADATA object %s", err)
|
|
}
|
|
|
|
if _, err := tw.Write(cB); err != nil {
|
|
return nil, fmt.Errorf("Could not write METDATA contents: %s", err)
|
|
}
|
|
|
|
return ret, nil
|
|
|
|
}
|
|
|
|
func SerializeInstances(inst FixedDataGrid, f io.Writer) error {
|
|
var hdr *tar.Header
|
|
|
|
gzWriter := gzip.NewWriter(f)
|
|
tw := tar.NewWriter(gzWriter)
|
|
|
|
// Write the MANIFEST entry
|
|
hdr = &tar.Header{
|
|
Name: "MANIFEST",
|
|
Size: int64(len(SerializationFormatVersion)),
|
|
}
|
|
if err := tw.WriteHeader(hdr); err != nil {
|
|
return fmt.Errorf("Could not write MANIFEST header: %s", err)
|
|
}
|
|
|
|
if _, err := tw.Write([]byte(SerializationFormatVersion)); err != nil {
|
|
return fmt.Errorf("Could not write MANIFEST contents: %s", err)
|
|
}
|
|
|
|
// Now write the dimensions of the dataset
|
|
attrCount, rowCount := inst.Size()
|
|
hdr = &tar.Header{
|
|
Name: "DIMS",
|
|
Size: 16,
|
|
}
|
|
if err := tw.WriteHeader(hdr); err != nil {
|
|
return fmt.Errorf("Could not write DIMS header: %s", err)
|
|
}
|
|
|
|
if _, err := tw.Write(PackU64ToBytes(uint64(attrCount))); err != nil {
|
|
return fmt.Errorf("Could not write DIMS (attrCount): %s", err)
|
|
}
|
|
if _, err := tw.Write(PackU64ToBytes(uint64(rowCount))); err != nil {
|
|
return fmt.Errorf("Could not write DIMS (rowCount): %s", err)
|
|
}
|
|
|
|
// Write the ATTRIBUTES files
|
|
classAttrs := inst.AllClassAttributes()
|
|
normalAttrs := NonClassAttributes(inst)
|
|
if err := writeAttributesToFilePart(classAttrs, tw, "CATTRS"); err != nil {
|
|
return fmt.Errorf("Could not write CATTRS: %s", err)
|
|
}
|
|
if err := writeAttributesToFilePart(normalAttrs, tw, "ATTRS"); err != nil {
|
|
return fmt.Errorf("Could not write ATTRS: %s", err)
|
|
}
|
|
|
|
// Data must be written out in the same order as the Attributes
|
|
allAttrs := make([]Attribute, attrCount)
|
|
normCount := copy(allAttrs, normalAttrs)
|
|
for i, v := range classAttrs {
|
|
allAttrs[normCount+i] = v
|
|
}
|
|
|
|
allSpecs := ResolveAttributes(inst, allAttrs)
|
|
|
|
// First, estimate the amount of data we'll need...
|
|
dataLength := int64(0)
|
|
inst.MapOverRows(allSpecs, func(val [][]byte, row int) (bool, error) {
|
|
for _, v := range val {
|
|
dataLength += int64(len(v))
|
|
}
|
|
return true, nil
|
|
})
|
|
|
|
// Then write the header
|
|
hdr = &tar.Header{
|
|
Name: "DATA",
|
|
Size: dataLength,
|
|
}
|
|
if err := tw.WriteHeader(hdr); err != nil {
|
|
return fmt.Errorf("Could not write DATA: %s", err)
|
|
}
|
|
|
|
// Then write the actual data
|
|
writtenLength := int64(0)
|
|
if err := inst.MapOverRows(allSpecs, func(val [][]byte, row int) (bool, error) {
|
|
for _, v := range val {
|
|
wl, err := tw.Write(v)
|
|
writtenLength += int64(wl)
|
|
if err != nil {
|
|
return false, err
|
|
}
|
|
}
|
|
return true, nil
|
|
}); err != nil {
|
|
return err
|
|
}
|
|
|
|
if writtenLength != dataLength {
|
|
return fmt.Errorf("Could not write DATA: changed size from %v to %v", dataLength, writtenLength)
|
|
}
|
|
|
|
// Finally, close and flush the various levels
|
|
if err := tw.Flush(); err != nil {
|
|
return fmt.Errorf("Could not flush tar: %s", err)
|
|
}
|
|
|
|
if err := tw.Close(); err != nil {
|
|
return fmt.Errorf("Could not close tar: %s", err)
|
|
}
|
|
|
|
if err := gzWriter.Flush(); err != nil {
|
|
return fmt.Errorf("Could not flush gz: %s", err)
|
|
}
|
|
|
|
if err := gzWriter.Close(); err != nil {
|
|
return fmt.Errorf("Could not close gz: %s", err)
|
|
}
|
|
|
|
return nil
|
|
}
|