mirror of
https://github.com/sjwhitworth/golearn.git
synced 2025-04-26 13:49:14 +08:00
trees: implement serialization
This commit is contained in:
parent
dede6dc750
commit
f722f2e59d
@ -3,6 +3,7 @@ package base
|
|||||||
import (
|
import (
|
||||||
"archive/tar"
|
"archive/tar"
|
||||||
"compress/gzip"
|
"compress/gzip"
|
||||||
|
"encoding/csv"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
@ -390,8 +391,8 @@ func (c *ClassifierSerializer) WriteMetadataAtPrefix(prefix string, metadata Cla
|
|||||||
// and writes the METADATA header.
|
// and writes the METADATA header.
|
||||||
func CreateSerializedClassifierStub(filePath string, metadata ClassifierMetadataV1) (*ClassifierSerializer, error) {
|
func CreateSerializedClassifierStub(filePath string, metadata ClassifierMetadataV1) (*ClassifierSerializer, error) {
|
||||||
|
|
||||||
// Write to a temporary path so we don't corrupt the output file
|
// Open the filePath
|
||||||
f, err := ioutil.TempFile(os.TempDir(), "clsTmp")
|
f, err := os.OpenFile(filePath, os.O_RDWR|os.O_TRUNC, 0600)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
@ -404,8 +405,6 @@ func CreateSerializedClassifierStub(filePath string, metadata ClassifierMetadata
|
|||||||
gzipWriter: gzWriter,
|
gzipWriter: gzWriter,
|
||||||
fileWriter: f,
|
fileWriter: f,
|
||||||
tarWriter: tw,
|
tarWriter: tw,
|
||||||
f: f,
|
|
||||||
filePath: filePath,
|
|
||||||
}
|
}
|
||||||
|
|
||||||
//
|
//
|
||||||
@ -434,3 +433,115 @@ func CreateSerializedClassifierStub(filePath string, metadata ClassifierMetadata
|
|||||||
return &ret, nil
|
return &ret, nil
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func SerializeInstances(inst FixedDataGrid, f io.Writer) error {
|
||||||
|
var hdr *tar.Header
|
||||||
|
|
||||||
|
gzWriter := gzip.NewWriter(f)
|
||||||
|
tw := tar.NewWriter(gzWriter)
|
||||||
|
|
||||||
|
// Write the MANIFEST entry
|
||||||
|
hdr = &tar.Header{
|
||||||
|
Name: "MANIFEST",
|
||||||
|
Size: int64(len(SerializationFormatVersion)),
|
||||||
|
}
|
||||||
|
if err := tw.WriteHeader(hdr); err != nil {
|
||||||
|
return fmt.Errorf("Could not write MANIFEST header: %s", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if _, err := tw.Write([]byte(SerializationFormatVersion)); err != nil {
|
||||||
|
return fmt.Errorf("Could not write MANIFEST contents: %s", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Now write the dimensions of the dataset
|
||||||
|
attrCount, rowCount := inst.Size()
|
||||||
|
hdr = &tar.Header{
|
||||||
|
Name: "DIMS",
|
||||||
|
Size: 16,
|
||||||
|
}
|
||||||
|
if err := tw.WriteHeader(hdr); err != nil {
|
||||||
|
return fmt.Errorf("Could not write DIMS header: %s", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if _, err := tw.Write(PackU64ToBytes(uint64(attrCount))); err != nil {
|
||||||
|
return fmt.Errorf("Could not write DIMS (attrCount): %s", err)
|
||||||
|
}
|
||||||
|
if _, err := tw.Write(PackU64ToBytes(uint64(rowCount))); err != nil {
|
||||||
|
return fmt.Errorf("Could not write DIMS (rowCount): %s", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Write the ATTRIBUTES files
|
||||||
|
classAttrs := inst.AllClassAttributes()
|
||||||
|
normalAttrs := NonClassAttributes(inst)
|
||||||
|
if err := writeAttributesToFilePart(classAttrs, tw, "CATTRS"); err != nil {
|
||||||
|
return fmt.Errorf("Could not write CATTRS: %s", err)
|
||||||
|
}
|
||||||
|
if err := writeAttributesToFilePart(normalAttrs, tw, "ATTRS"); err != nil {
|
||||||
|
return fmt.Errorf("Could not write ATTRS: %s", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Data must be written out in the same order as the Attributes
|
||||||
|
allAttrs := make([]Attribute, attrCount)
|
||||||
|
normCount := copy(allAttrs, normalAttrs)
|
||||||
|
for i, v := range classAttrs {
|
||||||
|
allAttrs[normCount+i] = v
|
||||||
|
}
|
||||||
|
|
||||||
|
allSpecs := ResolveAttributes(inst, allAttrs)
|
||||||
|
|
||||||
|
// First, estimate the amount of data we'll need...
|
||||||
|
dataLength := int64(0)
|
||||||
|
inst.MapOverRows(allSpecs, func(val [][]byte, row int) (bool, error) {
|
||||||
|
for _, v := range val {
|
||||||
|
dataLength += int64(len(v))
|
||||||
|
}
|
||||||
|
return true, nil
|
||||||
|
})
|
||||||
|
|
||||||
|
// Then write the header
|
||||||
|
hdr = &tar.Header{
|
||||||
|
Name: "DATA",
|
||||||
|
Size: dataLength,
|
||||||
|
}
|
||||||
|
if err := tw.WriteHeader(hdr); err != nil {
|
||||||
|
return fmt.Errorf("Could not write DATA: %s", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Then write the actual data
|
||||||
|
writtenLength := int64(0)
|
||||||
|
if err := inst.MapOverRows(allSpecs, func(val [][]byte, row int) (bool, error) {
|
||||||
|
for _, v := range val {
|
||||||
|
wl, err := tw.Write(v)
|
||||||
|
writtenLength += int64(wl)
|
||||||
|
if err != nil {
|
||||||
|
return false, err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return true, nil
|
||||||
|
}); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
if writtenLength != dataLength {
|
||||||
|
return fmt.Errorf("Could not write DATA: changed size from %v to %v", dataLength, writtenLength)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Finally, close and flush the various levels
|
||||||
|
if err := tw.Flush(); err != nil {
|
||||||
|
return fmt.Errorf("Could not flush tar: %s", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := tw.Close(); err != nil {
|
||||||
|
return fmt.Errorf("Could not close tar: %s", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := gzWriter.Flush(); err != nil {
|
||||||
|
return fmt.Errorf("Could not flush gz: %s", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := gzWriter.Close(); err != nil {
|
||||||
|
return fmt.Errorf("Could not close gz: %s", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
153
trees/id3.go
153
trees/id3.go
@ -6,6 +6,7 @@ import (
|
|||||||
"fmt"
|
"fmt"
|
||||||
"github.com/sjwhitworth/golearn/base"
|
"github.com/sjwhitworth/golearn/base"
|
||||||
"github.com/sjwhitworth/golearn/evaluation"
|
"github.com/sjwhitworth/golearn/evaluation"
|
||||||
|
"encoding/json"
|
||||||
"sort"
|
"sort"
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -106,155 +107,27 @@ func getClassAttr(from base.FixedDataGrid) base.Attribute {
|
|||||||
return allClassAttrs[0]
|
return allClassAttrs[0]
|
||||||
}
|
}
|
||||||
|
|
||||||
// MarshalJSON returns a JSON representation of this Attribute
|
|
||||||
// for serialisation.
|
|
||||||
func (d *DecisionTreeNode) MarshalJSON() ([]byte, error) {
|
|
||||||
ret := map[string]interface{}{
|
|
||||||
"type": d.Type,
|
|
||||||
"class_dist": d.ClassDist,
|
|
||||||
"class": d.Class,
|
|
||||||
}
|
|
||||||
|
|
||||||
if d.SplitRule != nil && d.SplitRule.SplitAttr != nil {
|
|
||||||
rawDRule, err := d.SplitRule.MarshalJSON()
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
var dRule map[string]interface{}
|
|
||||||
err = json.Unmarshal(rawDRule, &dRule)
|
|
||||||
if err != nil {
|
|
||||||
panic(err)
|
|
||||||
}
|
|
||||||
ret["split_rule"] = dRule
|
|
||||||
}
|
|
||||||
|
|
||||||
rawClassAttr, err := d.ClassAttr.MarshalJSON()
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
var classAttr map[string]interface{}
|
|
||||||
err = json.Unmarshal(rawClassAttr, &classAttr)
|
|
||||||
ret["class_attr"] = classAttr
|
|
||||||
|
|
||||||
if len(d.Children) > 0 {
|
|
||||||
|
|
||||||
children := make(map[string]interface{})
|
|
||||||
for k := range d.Children {
|
|
||||||
cur, err := d.Children[k].MarshalJSON()
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
var child map[string]interface{}
|
|
||||||
err = json.Unmarshal(cur, &child)
|
|
||||||
if err != nil {
|
|
||||||
panic(err)
|
|
||||||
}
|
|
||||||
children[k] = child
|
|
||||||
}
|
|
||||||
ret["children"] = children
|
|
||||||
}
|
|
||||||
return json.Marshal(ret)
|
|
||||||
}
|
|
||||||
|
|
||||||
// UnmarshalJSON reads a JSON representation of this Attribute.
|
|
||||||
func (d *DecisionTreeNode) UnmarshalJSON(data []byte) error {
|
|
||||||
jsonMap := make(map[string]interface{})
|
|
||||||
err := json.Unmarshal(data, &jsonMap)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
rawType := int(jsonMap["type"].(float64))
|
|
||||||
if rawType == 1 {
|
|
||||||
d.Type = LeafNode
|
|
||||||
} else if rawType == 2 {
|
|
||||||
d.Type = RuleNode
|
|
||||||
} else {
|
|
||||||
return fmt.Errorf("Unknown nodeType: %d", rawType)
|
|
||||||
}
|
|
||||||
//d.Type = NodeType(int(jsonMap["type"].(float64)))
|
|
||||||
// Convert the class distribution back
|
|
||||||
classDist := jsonMap["class_dist"].(map[string]interface{})
|
|
||||||
d.ClassDist = make(map[string]int)
|
|
||||||
for k := range classDist {
|
|
||||||
d.ClassDist[k] = int(classDist[k].(float64))
|
|
||||||
}
|
|
||||||
|
|
||||||
d.Class = jsonMap["class"].(string)
|
|
||||||
|
|
||||||
//
|
|
||||||
// Decode the class attribute
|
|
||||||
//
|
|
||||||
// Temporarily re-marshal this field back to bytes
|
|
||||||
rawClassAttr := jsonMap["class_attr"]
|
|
||||||
rawClassAttrBytes, err := json.Marshal(rawClassAttr)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
classAttr, err := base.DeserializeAttribute(rawClassAttrBytes)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
d.ClassAttr = classAttr
|
|
||||||
d.SplitRule = nil
|
|
||||||
|
|
||||||
if splitRule, ok := jsonMap["split_rule"]; ok {
|
|
||||||
d.SplitRule = &DecisionTreeRule{}
|
|
||||||
splitRuleBytes, err := json.Marshal(splitRule)
|
|
||||||
if err != nil {
|
|
||||||
panic(err)
|
|
||||||
}
|
|
||||||
err = d.SplitRule.UnmarshalJSON(splitRuleBytes)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
d.Children = make(map[string]*DecisionTreeNode)
|
|
||||||
childMap := jsonMap["children"].(map[string]interface{})
|
|
||||||
for i := range childMap {
|
|
||||||
cur := &DecisionTreeNode{}
|
|
||||||
childBytes, err := json.Marshal(childMap[i])
|
|
||||||
if err != nil {
|
|
||||||
panic(err)
|
|
||||||
}
|
|
||||||
err = cur.UnmarshalJSON(childBytes)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
d.Children[i] = cur
|
|
||||||
}
|
|
||||||
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// Save sends the classification tree to an output file
|
// Save sends the classification tree to an output file
|
||||||
func (d *DecisionTreeNode) Save(filePath string) error {
|
func (d *DecisionTreeNode) Save(filePath string) error {
|
||||||
metadata := base.ClassifierMetadataV1{
|
metadata := base.ClassifierMetadataV1 {
|
||||||
FormatVersion: 1,
|
FormatVersion: 1,
|
||||||
ClassifierName: "DecisionTreeNode",
|
ClassifierName: "test",
|
||||||
ClassifierVersion: "1",
|
ClassifierVersion: "1",
|
||||||
ClassifierMetadata: nil,
|
ClassifierMetadata: exampleClassifierMetadata,
|
||||||
}
|
}
|
||||||
serializer, err := base.CreateSerializedClassifierStub(filePath, metadata)
|
serializer, err := base.CreateSerializedClassifierStub(filePath, metadata)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
err = d.SaveWithPrefix(serializer, "")
|
|
||||||
serializer.Close()
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
func (d *DecisionTreeNode) SaveWithPrefix(serializer *base.ClassifierSerializer, prefix string) error {
|
|
||||||
b, err := json.Marshal(d)
|
b, err := json.Marshal(d)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
err = serializer.WriteBytesForKey(fmt.Sprintf("%s%s", prefix, "tree"), b)
|
err = serializer.WriteBytesForKey("tree", b)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
serializer.Close()
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -266,13 +139,11 @@ func (d *DecisionTreeNode) Load(filePath string) error {
|
|||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
err = d.LoadWithPrefix(reader, "")
|
defer func() {
|
||||||
reader.Close()
|
reader.Close()
|
||||||
return err
|
}()
|
||||||
}
|
|
||||||
|
|
||||||
func (d *DecisionTreeNode) LoadWithPrefix(reader *base.ClassifierDeserializer, prefix string) error {
|
b, err := reader.GetBytesForKey("tree")
|
||||||
b, err := reader.GetBytesForKey(fmt.Sprintf("%s%s", prefix, "tree"))
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
@ -1,7 +1,6 @@
|
|||||||
package trees
|
package trees
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"fmt"
|
|
||||||
"github.com/sjwhitworth/golearn/base"
|
"github.com/sjwhitworth/golearn/base"
|
||||||
"github.com/sjwhitworth/golearn/evaluation"
|
"github.com/sjwhitworth/golearn/evaluation"
|
||||||
"github.com/sjwhitworth/golearn/filters"
|
"github.com/sjwhitworth/golearn/filters"
|
||||||
@ -10,16 +9,12 @@ import (
|
|||||||
"math/rand"
|
"math/rand"
|
||||||
"os"
|
"os"
|
||||||
"testing"
|
"testing"
|
||||||
|
"io/ioutil"
|
||||||
)
|
)
|
||||||
|
|
||||||
func TestCanSaveLoadPredictions(t *testing.T) {
|
func testCanSaveLoadPredictions(trainData, testData base.FixedDataGrid) {
|
||||||
rand.Seed(44414515)
|
rand.Seed(44414515)
|
||||||
Convey("Using InferID3Tree to create the tree and do the fitting", t, func() {
|
Convey("Using InferID3Tree to create the tree and do the fitting", func() {
|
||||||
instances, err := base.ParseCSVToInstances("../examples/datasets/iris_headers.csv", true)
|
|
||||||
So(err, ShouldBeNil)
|
|
||||||
|
|
||||||
trainData, testData := base.InstancesTrainTestSplit(instances, 0.6)
|
|
||||||
|
|
||||||
Convey("Using a RandomTreeRule", func() {
|
Convey("Using a RandomTreeRule", func() {
|
||||||
randomTreeRuleGenerator := new(RandomTreeRuleGenerator)
|
randomTreeRuleGenerator := new(RandomTreeRuleGenerator)
|
||||||
randomTreeRuleGenerator.Attributes = 2
|
randomTreeRuleGenerator.Attributes = 2
|
||||||
@ -30,20 +25,18 @@ func TestCanSaveLoadPredictions(t *testing.T) {
|
|||||||
So(err, ShouldBeNil)
|
So(err, ShouldBeNil)
|
||||||
|
|
||||||
Convey("Saving the tree...", func() {
|
Convey("Saving the tree...", func() {
|
||||||
f, err := ioutil.TempFile("", "tree")
|
f, err := ioutil.TempFile("","tree")
|
||||||
So(err, ShouldBeNil)
|
So(err, ShouldBeNil)
|
||||||
err = root.Save(f.Name())
|
err = root.Save(f.Name())
|
||||||
So(err, ShouldBeNil)
|
So(err, ShouldBeNil)
|
||||||
|
|
||||||
Convey("Loading the tree...", func() {
|
Convey("Loading the tree...", func(){
|
||||||
d := &DecisionTreeNode{}
|
d := &DecisionTreeNode{}
|
||||||
err := d.Load(f.Name())
|
err := d.Load(f.Name())
|
||||||
So(err, ShouldBeNil)
|
So(err, ShouldBeNil)
|
||||||
So(d.String(), ShouldEqual, root.String())
|
|
||||||
Convey("Generating predictions from the loaded tree...", func() {
|
Convey("Generating predictions from the loaded tree...", func() {
|
||||||
predictions2, err := d.Predict(testData)
|
predictions2, err := d.Predict(testData)
|
||||||
So(err, ShouldBeNil)
|
So(predictions, ShouldEqual, predictions2)
|
||||||
So(fmt.Sprintf("%v", predictions2), ShouldEqual, fmt.Sprintf("%v", predictions))
|
|
||||||
})
|
})
|
||||||
})
|
})
|
||||||
})
|
})
|
||||||
|
Loading…
x
Reference in New Issue
Block a user