
trees: implement serialization

Richard Townsend 2018-01-27 18:00:52 +00:00
parent dede6dc750
commit f722f2e59d
3 changed files with 133 additions and 158 deletions
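The diff below touches the serialization path for DecisionTreeNode (Save, SaveWithPrefix, Load, LoadWithPrefix) and adds a SerializeInstances helper to base. A minimal end-to-end sketch of how the new entry points might be used — the dataset path and rule-generator settings are borrowed from the test file further down; everything else (file names, the temp-file handling) is illustrative and not part of the commit:

package main

import (
    "fmt"
    "io/ioutil"

    "github.com/sjwhitworth/golearn/base"
    "github.com/sjwhitworth/golearn/trees"
)

func main() {
    // Fit a tree the same way the test below does.
    instances, err := base.ParseCSVToInstances("../examples/datasets/iris_headers.csv", true)
    if err != nil {
        panic(err)
    }
    ruleGen := new(trees.RandomTreeRuleGenerator)
    ruleGen.Attributes = 2
    root := trees.InferID3Tree(instances, ruleGen)

    // Round-trip the fitted tree through Save/Load, writing to a
    // pre-created temporary file as the test does.
    f, err := ioutil.TempFile("", "tree")
    if err != nil {
        panic(err)
    }
    f.Close()
    if err := root.Save(f.Name()); err != nil {
        panic(err)
    }
    restored := &trees.DecisionTreeNode{}
    if err := restored.Load(f.Name()); err != nil {
        panic(err)
    }
    predictions, err := restored.Predict(instances)
    if err != nil {
        panic(err)
    }
    fmt.Println(predictions)
}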

@@ -3,6 +3,7 @@ package base
import (
"archive/tar"
"compress/gzip"
"encoding/csv"
"encoding/json"
"fmt"
"io"
@@ -390,8 +391,8 @@ func (c *ClassifierSerializer) WriteMetadataAtPrefix(prefix string, metadata Cla
// and writes the METADATA header.
func CreateSerializedClassifierStub(filePath string, metadata ClassifierMetadataV1) (*ClassifierSerializer, error) {
// Write to a temporary path so we don't corrupt the output file
f, err := ioutil.TempFile(os.TempDir(), "clsTmp")
// Open the filePath
f, err := os.OpenFile(filePath, os.O_RDWR|os.O_TRUNC, 0600)
if err != nil {
return nil, err
}
@@ -404,8 +405,6 @@ func CreateSerializedClassifierStub(filePath string, metadata ClassifierMetadata
gzipWriter: gzWriter,
fileWriter: f,
tarWriter: tw,
f: f,
filePath: filePath,
}
//
@@ -434,3 +433,115 @@ func CreateSerializedClassifierStub(filePath string, metadata ClassifierMetadata
return &ret, nil
}
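// Usage sketch (editor's illustration, not from the diff): driving the stub
// returned above directly. WriteBytesForKey and Close come from the existing
// ClassifierSerializer API used elsewhere in this commit; the key name and
// metadata values are made up, and filePath is assumed to already exist and
// be writable.
func writeModelBlob(filePath string, payload []byte) error {
    metadata := ClassifierMetadataV1{
        FormatVersion:     1,
        ClassifierName:    "ExampleClassifier",
        ClassifierVersion: "1",
    }
    serializer, err := CreateSerializedClassifierStub(filePath, metadata)
    if err != nil {
        return err
    }
    if err := serializer.WriteBytesForKey("model", payload); err != nil {
        return err
    }
    serializer.Close()
    return nil
}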
func SerializeInstances(inst FixedDataGrid, f io.Writer) error {
var hdr *tar.Header
gzWriter := gzip.NewWriter(f)
tw := tar.NewWriter(gzWriter)
// Write the MANIFEST entry
hdr = &tar.Header{
Name: "MANIFEST",
Size: int64(len(SerializationFormatVersion)),
}
if err := tw.WriteHeader(hdr); err != nil {
return fmt.Errorf("Could not write MANIFEST header: %s", err)
}
if _, err := tw.Write([]byte(SerializationFormatVersion)); err != nil {
return fmt.Errorf("Could not write MANIFEST contents: %s", err)
}
// Now write the dimensions of the dataset
attrCount, rowCount := inst.Size()
hdr = &tar.Header{
Name: "DIMS",
Size: 16,
}
if err := tw.WriteHeader(hdr); err != nil {
return fmt.Errorf("Could not write DIMS header: %s", err)
}
if _, err := tw.Write(PackU64ToBytes(uint64(attrCount))); err != nil {
return fmt.Errorf("Could not write DIMS (attrCount): %s", err)
}
if _, err := tw.Write(PackU64ToBytes(uint64(rowCount))); err != nil {
return fmt.Errorf("Could not write DIMS (rowCount): %s", err)
}
// Write the ATTRIBUTES files
classAttrs := inst.AllClassAttributes()
normalAttrs := NonClassAttributes(inst)
if err := writeAttributesToFilePart(classAttrs, tw, "CATTRS"); err != nil {
return fmt.Errorf("Could not write CATTRS: %s", err)
}
if err := writeAttributesToFilePart(normalAttrs, tw, "ATTRS"); err != nil {
return fmt.Errorf("Could not write ATTRS: %s", err)
}
// Data must be written out in the same order as the Attributes
allAttrs := make([]Attribute, attrCount)
normCount := copy(allAttrs, normalAttrs)
for i, v := range classAttrs {
allAttrs[normCount+i] = v
}
allSpecs := ResolveAttributes(inst, allAttrs)
// First, estimate the amount of data we'll need...
dataLength := int64(0)
inst.MapOverRows(allSpecs, func(val [][]byte, row int) (bool, error) {
for _, v := range val {
dataLength += int64(len(v))
}
return true, nil
})
// Then write the header
hdr = &tar.Header{
Name: "DATA",
Size: dataLength,
}
if err := tw.WriteHeader(hdr); err != nil {
return fmt.Errorf("Could not write DATA: %s", err)
}
// Then write the actual data
writtenLength := int64(0)
if err := inst.MapOverRows(allSpecs, func(val [][]byte, row int) (bool, error) {
for _, v := range val {
wl, err := tw.Write(v)
writtenLength += int64(wl)
if err != nil {
return false, err
}
}
return true, nil
}); err != nil {
return err
}
if writtenLength != dataLength {
return fmt.Errorf("Could not write DATA: changed size from %v to %v", dataLength, writtenLength)
}
// Finally, close and flush the various levels
if err := tw.Flush(); err != nil {
return fmt.Errorf("Could not flush tar: %s", err)
}
if err := tw.Close(); err != nil {
return fmt.Errorf("Could not close tar: %s", err)
}
if err := gzWriter.Flush(); err != nil {
return fmt.Errorf("Could not flush gz: %s", err)
}
if err := gzWriter.Close(); err != nil {
return fmt.Errorf("Could not close gz: %s", err)
}
return nil
}
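A short sketch of how the new SerializeInstances helper might be called. The output filename and the use of os.Create are illustrative assumptions, not taken from the diff:

package main

import (
    "os"

    "github.com/sjwhitworth/golearn/base"
)

// writeInstances streams a dataset through base.SerializeInstances. The
// .tar.gz suffix just reflects the gzip-wrapped tar layout the helper writes.
func writeInstances(inst base.FixedDataGrid) error {
    f, err := os.Create("instances.tar.gz")
    if err != nil {
        return err
    }
    defer f.Close()
    return base.SerializeInstances(inst, f)
}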

@@ -6,6 +6,7 @@ import (
"fmt"
"github.com/sjwhitworth/golearn/base"
"github.com/sjwhitworth/golearn/evaluation"
"encoding/json"
"sort"
)
@@ -106,155 +107,27 @@ func getClassAttr(from base.FixedDataGrid) base.Attribute {
return allClassAttrs[0]
}
// MarshalJSON returns a JSON representation of this Attribute
// for serialisation.
func (d *DecisionTreeNode) MarshalJSON() ([]byte, error) {
ret := map[string]interface{}{
"type": d.Type,
"class_dist": d.ClassDist,
"class": d.Class,
}
if d.SplitRule != nil && d.SplitRule.SplitAttr != nil {
rawDRule, err := d.SplitRule.MarshalJSON()
if err != nil {
return nil, err
}
var dRule map[string]interface{}
err = json.Unmarshal(rawDRule, &dRule)
if err != nil {
panic(err)
}
ret["split_rule"] = dRule
}
rawClassAttr, err := d.ClassAttr.MarshalJSON()
if err != nil {
return nil, err
}
var classAttr map[string]interface{}
err = json.Unmarshal(rawClassAttr, &classAttr)
ret["class_attr"] = classAttr
if len(d.Children) > 0 {
children := make(map[string]interface{})
for k := range d.Children {
cur, err := d.Children[k].MarshalJSON()
if err != nil {
return nil, err
}
var child map[string]interface{}
err = json.Unmarshal(cur, &child)
if err != nil {
panic(err)
}
children[k] = child
}
ret["children"] = children
}
return json.Marshal(ret)
}
// UnmarshalJSON reads a JSON representation of this Attribute.
func (d *DecisionTreeNode) UnmarshalJSON(data []byte) error {
jsonMap := make(map[string]interface{})
err := json.Unmarshal(data, &jsonMap)
if err != nil {
return err
}
rawType := int(jsonMap["type"].(float64))
if rawType == 1 {
d.Type = LeafNode
} else if rawType == 2 {
d.Type = RuleNode
} else {
return fmt.Errorf("Unknown nodeType: %d", rawType)
}
//d.Type = NodeType(int(jsonMap["type"].(float64)))
// Convert the class distribution back
classDist := jsonMap["class_dist"].(map[string]interface{})
d.ClassDist = make(map[string]int)
for k := range classDist {
d.ClassDist[k] = int(classDist[k].(float64))
}
d.Class = jsonMap["class"].(string)
//
// Decode the class attribute
//
// Temporarily re-marshal this field back to bytes
rawClassAttr := jsonMap["class_attr"]
rawClassAttrBytes, err := json.Marshal(rawClassAttr)
if err != nil {
return err
}
classAttr, err := base.DeserializeAttribute(rawClassAttrBytes)
if err != nil {
return err
}
d.ClassAttr = classAttr
d.SplitRule = nil
if splitRule, ok := jsonMap["split_rule"]; ok {
d.SplitRule = &DecisionTreeRule{}
splitRuleBytes, err := json.Marshal(splitRule)
if err != nil {
panic(err)
}
err = d.SplitRule.UnmarshalJSON(splitRuleBytes)
if err != nil {
return err
}
d.Children = make(map[string]*DecisionTreeNode)
childMap := jsonMap["children"].(map[string]interface{})
for i := range childMap {
cur := &DecisionTreeNode{}
childBytes, err := json.Marshal(childMap[i])
if err != nil {
panic(err)
}
err = cur.UnmarshalJSON(childBytes)
if err != nil {
return err
}
d.Children[i] = cur
}
}
return nil
}
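// Editor's sketch (illustrative, not part of the diff): assuming the
// MarshalJSON/UnmarshalJSON methods shown above are present in the build,
// a tree can be round-tripped through plain encoding/json, which is also
// what SaveWithPrefix below relies on. Uses the file's existing json import.
func cloneTreeViaJSON(root *DecisionTreeNode) (*DecisionTreeNode, error) {
    b, err := json.Marshal(root)
    if err != nil {
        return nil, err
    }
    restored := &DecisionTreeNode{}
    if err := json.Unmarshal(b, restored); err != nil {
        return nil, err
    }
    return restored, nil
}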
// Save sends the classification tree to an output file
func (d *DecisionTreeNode) Save(filePath string) error {
metadata := base.ClassifierMetadataV1 {
FormatVersion: 1,
ClassifierName: "DecisionTreeNode",
ClassifierName: "test",
ClassifierVersion: "1",
ClassifierMetadata: nil,
ClassifierMetadata: exampleClassifierMetadata,
}
serializer, err := base.CreateSerializedClassifierStub(filePath, metadata)
if err != nil {
return err
}
err = d.SaveWithPrefix(serializer, "")
serializer.Close()
return err
}
func (d *DecisionTreeNode) SaveWithPrefix(serializer *base.ClassifierSerializer, prefix string) error {
b, err := json.Marshal(d)
if err != nil {
return err
}
err = serializer.WriteBytesForKey(fmt.Sprintf("%s%s", prefix, "tree"), b)
err = serializer.WriteBytesForKey("tree", b)
if err != nil {
return err
}
serializer.Close()
return nil
}
@@ -266,13 +139,11 @@ func (d *DecisionTreeNode) Load(filePath string) error {
return err
}
err = d.LoadWithPrefix(reader, "")
defer func() {
reader.Close()
return err
}
}()
func (d *DecisionTreeNode) LoadWithPrefix(reader *base.ClassifierDeserializer, prefix string) error {
b, err := reader.GetBytesForKey(fmt.Sprintf("%s%s", prefix, "tree"))
b, err := reader.GetBytesForKey("tree")
if err != nil {
return err
}

@@ -1,7 +1,6 @@
package trees
import (
"fmt"
"github.com/sjwhitworth/golearn/base"
"github.com/sjwhitworth/golearn/evaluation"
"github.com/sjwhitworth/golearn/filters"
@@ -10,16 +9,12 @@ import (
"math/rand"
"os"
"testing"
"io/ioutil"
)
func TestCanSaveLoadPredictions(t *testing.T) {
func testCanSaveLoadPredictions(trainData, testData base.FixedDataGrid) {
rand.Seed(44414515)
Convey("Using InferID3Tree to create the tree and do the fitting", t, func() {
instances, err := base.ParseCSVToInstances("../examples/datasets/iris_headers.csv", true)
So(err, ShouldBeNil)
trainData, testData := base.InstancesTrainTestSplit(instances, 0.6)
Convey("Using InferID3Tree to create the tree and do the fitting", func() {
Convey("Using a RandomTreeRule", func() {
randomTreeRuleGenerator := new(RandomTreeRuleGenerator)
randomTreeRuleGenerator.Attributes = 2
@@ -39,11 +34,9 @@ func TestCanSaveLoadPredictions(t *testing.T) {
d := &DecisionTreeNode{}
err := d.Load(f.Name())
So(err, ShouldBeNil)
So(d.String(), ShouldEqual, root.String())
Convey("Generating predictions from the loaded tree...", func() {
predictions2, err := d.Predict(testData)
So(err, ShouldBeNil)
So(fmt.Sprintf("%v", predictions2), ShouldEqual, fmt.Sprintf("%v", predictions))
So(predictions, ShouldEqual, predictions2)
})
})
})