Mirror of https://github.com/sjwhitworth/golearn.git (synced 2025-04-26 13:49:14 +08:00).
Commit: "trees: implement serialization".
This commit is contained in:
parent
dede6dc750
commit
f722f2e59d
@ -3,6 +3,7 @@ package base
|
||||
import (
|
||||
"archive/tar"
|
||||
"compress/gzip"
|
||||
"encoding/csv"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
@ -390,8 +391,8 @@ func (c *ClassifierSerializer) WriteMetadataAtPrefix(prefix string, metadata Cla
|
||||
// and writes the METADATA header.
|
||||
func CreateSerializedClassifierStub(filePath string, metadata ClassifierMetadataV1) (*ClassifierSerializer, error) {
|
||||
|
||||
// Write to a temporary path so we don't corrupt the output file
|
||||
f, err := ioutil.TempFile(os.TempDir(), "clsTmp")
|
||||
// Open the filePath
|
||||
f, err := os.OpenFile(filePath, os.O_RDWR|os.O_TRUNC, 0600)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@ -404,8 +405,6 @@ func CreateSerializedClassifierStub(filePath string, metadata ClassifierMetadata
|
||||
gzipWriter: gzWriter,
|
||||
fileWriter: f,
|
||||
tarWriter: tw,
|
||||
f: f,
|
||||
filePath: filePath,
|
||||
}
|
||||
|
||||
//
|
||||
@ -434,3 +433,115 @@ func CreateSerializedClassifierStub(filePath string, metadata ClassifierMetadata
|
||||
return &ret, nil
|
||||
|
||||
}
|
||||
|
||||
func SerializeInstances(inst FixedDataGrid, f io.Writer) error {
|
||||
var hdr *tar.Header
|
||||
|
||||
gzWriter := gzip.NewWriter(f)
|
||||
tw := tar.NewWriter(gzWriter)
|
||||
|
||||
// Write the MANIFEST entry
|
||||
hdr = &tar.Header{
|
||||
Name: "MANIFEST",
|
||||
Size: int64(len(SerializationFormatVersion)),
|
||||
}
|
||||
if err := tw.WriteHeader(hdr); err != nil {
|
||||
return fmt.Errorf("Could not write MANIFEST header: %s", err)
|
||||
}
|
||||
|
||||
if _, err := tw.Write([]byte(SerializationFormatVersion)); err != nil {
|
||||
return fmt.Errorf("Could not write MANIFEST contents: %s", err)
|
||||
}
|
||||
|
||||
// Now write the dimensions of the dataset
|
||||
attrCount, rowCount := inst.Size()
|
||||
hdr = &tar.Header{
|
||||
Name: "DIMS",
|
||||
Size: 16,
|
||||
}
|
||||
if err := tw.WriteHeader(hdr); err != nil {
|
||||
return fmt.Errorf("Could not write DIMS header: %s", err)
|
||||
}
|
||||
|
||||
if _, err := tw.Write(PackU64ToBytes(uint64(attrCount))); err != nil {
|
||||
return fmt.Errorf("Could not write DIMS (attrCount): %s", err)
|
||||
}
|
||||
if _, err := tw.Write(PackU64ToBytes(uint64(rowCount))); err != nil {
|
||||
return fmt.Errorf("Could not write DIMS (rowCount): %s", err)
|
||||
}
|
||||
|
||||
// Write the ATTRIBUTES files
|
||||
classAttrs := inst.AllClassAttributes()
|
||||
normalAttrs := NonClassAttributes(inst)
|
||||
if err := writeAttributesToFilePart(classAttrs, tw, "CATTRS"); err != nil {
|
||||
return fmt.Errorf("Could not write CATTRS: %s", err)
|
||||
}
|
||||
if err := writeAttributesToFilePart(normalAttrs, tw, "ATTRS"); err != nil {
|
||||
return fmt.Errorf("Could not write ATTRS: %s", err)
|
||||
}
|
||||
|
||||
// Data must be written out in the same order as the Attributes
|
||||
allAttrs := make([]Attribute, attrCount)
|
||||
normCount := copy(allAttrs, normalAttrs)
|
||||
for i, v := range classAttrs {
|
||||
allAttrs[normCount+i] = v
|
||||
}
|
||||
|
||||
allSpecs := ResolveAttributes(inst, allAttrs)
|
||||
|
||||
// First, estimate the amount of data we'll need...
|
||||
dataLength := int64(0)
|
||||
inst.MapOverRows(allSpecs, func(val [][]byte, row int) (bool, error) {
|
||||
for _, v := range val {
|
||||
dataLength += int64(len(v))
|
||||
}
|
||||
return true, nil
|
||||
})
|
||||
|
||||
// Then write the header
|
||||
hdr = &tar.Header{
|
||||
Name: "DATA",
|
||||
Size: dataLength,
|
||||
}
|
||||
if err := tw.WriteHeader(hdr); err != nil {
|
||||
return fmt.Errorf("Could not write DATA: %s", err)
|
||||
}
|
||||
|
||||
// Then write the actual data
|
||||
writtenLength := int64(0)
|
||||
if err := inst.MapOverRows(allSpecs, func(val [][]byte, row int) (bool, error) {
|
||||
for _, v := range val {
|
||||
wl, err := tw.Write(v)
|
||||
writtenLength += int64(wl)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
}
|
||||
return true, nil
|
||||
}); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if writtenLength != dataLength {
|
||||
return fmt.Errorf("Could not write DATA: changed size from %v to %v", dataLength, writtenLength)
|
||||
}
|
||||
|
||||
// Finally, close and flush the various levels
|
||||
if err := tw.Flush(); err != nil {
|
||||
return fmt.Errorf("Could not flush tar: %s", err)
|
||||
}
|
||||
|
||||
if err := tw.Close(); err != nil {
|
||||
return fmt.Errorf("Could not close tar: %s", err)
|
||||
}
|
||||
|
||||
if err := gzWriter.Flush(); err != nil {
|
||||
return fmt.Errorf("Could not flush gz: %s", err)
|
||||
}
|
||||
|
||||
if err := gzWriter.Close(); err != nil {
|
||||
return fmt.Errorf("Could not close gz: %s", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
153
trees/id3.go
153
trees/id3.go
@ -6,6 +6,7 @@ import (
|
||||
"fmt"
|
||||
"github.com/sjwhitworth/golearn/base"
|
||||
"github.com/sjwhitworth/golearn/evaluation"
|
||||
"encoding/json"
|
||||
"sort"
|
||||
)
|
||||
|
||||
@ -106,155 +107,27 @@ func getClassAttr(from base.FixedDataGrid) base.Attribute {
|
||||
return allClassAttrs[0]
|
||||
}
|
||||
|
||||
// MarshalJSON returns a JSON representation of this Attribute
|
||||
// for serialisation.
|
||||
func (d *DecisionTreeNode) MarshalJSON() ([]byte, error) {
|
||||
ret := map[string]interface{}{
|
||||
"type": d.Type,
|
||||
"class_dist": d.ClassDist,
|
||||
"class": d.Class,
|
||||
}
|
||||
|
||||
if d.SplitRule != nil && d.SplitRule.SplitAttr != nil {
|
||||
rawDRule, err := d.SplitRule.MarshalJSON()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
var dRule map[string]interface{}
|
||||
err = json.Unmarshal(rawDRule, &dRule)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
ret["split_rule"] = dRule
|
||||
}
|
||||
|
||||
rawClassAttr, err := d.ClassAttr.MarshalJSON()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
var classAttr map[string]interface{}
|
||||
err = json.Unmarshal(rawClassAttr, &classAttr)
|
||||
ret["class_attr"] = classAttr
|
||||
|
||||
if len(d.Children) > 0 {
|
||||
|
||||
children := make(map[string]interface{})
|
||||
for k := range d.Children {
|
||||
cur, err := d.Children[k].MarshalJSON()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
var child map[string]interface{}
|
||||
err = json.Unmarshal(cur, &child)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
children[k] = child
|
||||
}
|
||||
ret["children"] = children
|
||||
}
|
||||
return json.Marshal(ret)
|
||||
}
|
||||
|
||||
// UnmarshalJSON reads a JSON representation of this Attribute.
|
||||
func (d *DecisionTreeNode) UnmarshalJSON(data []byte) error {
|
||||
jsonMap := make(map[string]interface{})
|
||||
err := json.Unmarshal(data, &jsonMap)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
rawType := int(jsonMap["type"].(float64))
|
||||
if rawType == 1 {
|
||||
d.Type = LeafNode
|
||||
} else if rawType == 2 {
|
||||
d.Type = RuleNode
|
||||
} else {
|
||||
return fmt.Errorf("Unknown nodeType: %d", rawType)
|
||||
}
|
||||
//d.Type = NodeType(int(jsonMap["type"].(float64)))
|
||||
// Convert the class distribution back
|
||||
classDist := jsonMap["class_dist"].(map[string]interface{})
|
||||
d.ClassDist = make(map[string]int)
|
||||
for k := range classDist {
|
||||
d.ClassDist[k] = int(classDist[k].(float64))
|
||||
}
|
||||
|
||||
d.Class = jsonMap["class"].(string)
|
||||
|
||||
//
|
||||
// Decode the class attribute
|
||||
//
|
||||
// Temporarily re-marshal this field back to bytes
|
||||
rawClassAttr := jsonMap["class_attr"]
|
||||
rawClassAttrBytes, err := json.Marshal(rawClassAttr)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
classAttr, err := base.DeserializeAttribute(rawClassAttrBytes)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
d.ClassAttr = classAttr
|
||||
d.SplitRule = nil
|
||||
|
||||
if splitRule, ok := jsonMap["split_rule"]; ok {
|
||||
d.SplitRule = &DecisionTreeRule{}
|
||||
splitRuleBytes, err := json.Marshal(splitRule)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
err = d.SplitRule.UnmarshalJSON(splitRuleBytes)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
d.Children = make(map[string]*DecisionTreeNode)
|
||||
childMap := jsonMap["children"].(map[string]interface{})
|
||||
for i := range childMap {
|
||||
cur := &DecisionTreeNode{}
|
||||
childBytes, err := json.Marshal(childMap[i])
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
err = cur.UnmarshalJSON(childBytes)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
d.Children[i] = cur
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Save sends the classification tree to an output file
|
||||
func (d *DecisionTreeNode) Save(filePath string) error {
|
||||
metadata := base.ClassifierMetadataV1{
|
||||
FormatVersion: 1,
|
||||
ClassifierName: "DecisionTreeNode",
|
||||
ClassifierVersion: "1",
|
||||
ClassifierMetadata: nil,
|
||||
metadata := base.ClassifierMetadataV1 {
|
||||
FormatVersion: 1,
|
||||
ClassifierName: "test",
|
||||
ClassifierVersion: "1",
|
||||
ClassifierMetadata: exampleClassifierMetadata,
|
||||
}
|
||||
serializer, err := base.CreateSerializedClassifierStub(filePath, metadata)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
err = d.SaveWithPrefix(serializer, "")
|
||||
serializer.Close()
|
||||
return err
|
||||
}
|
||||
|
||||
func (d *DecisionTreeNode) SaveWithPrefix(serializer *base.ClassifierSerializer, prefix string) error {
|
||||
b, err := json.Marshal(d)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
err = serializer.WriteBytesForKey(fmt.Sprintf("%s%s", prefix, "tree"), b)
|
||||
err = serializer.WriteBytesForKey("tree", b)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
serializer.Close()
|
||||
return nil
|
||||
}
|
||||
|
||||
@ -266,13 +139,11 @@ func (d *DecisionTreeNode) Load(filePath string) error {
|
||||
return err
|
||||
}
|
||||
|
||||
err = d.LoadWithPrefix(reader, "")
|
||||
reader.Close()
|
||||
return err
|
||||
}
|
||||
defer func() {
|
||||
reader.Close()
|
||||
}()
|
||||
|
||||
func (d *DecisionTreeNode) LoadWithPrefix(reader *base.ClassifierDeserializer, prefix string) error {
|
||||
b, err := reader.GetBytesForKey(fmt.Sprintf("%s%s", prefix, "tree"))
|
||||
b, err := reader.GetBytesForKey("tree")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
@ -1,7 +1,6 @@
|
||||
package trees
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"github.com/sjwhitworth/golearn/base"
|
||||
"github.com/sjwhitworth/golearn/evaluation"
|
||||
"github.com/sjwhitworth/golearn/filters"
|
||||
@ -10,16 +9,12 @@ import (
|
||||
"math/rand"
|
||||
"os"
|
||||
"testing"
|
||||
"io/ioutil"
|
||||
)
|
||||
|
||||
func TestCanSaveLoadPredictions(t *testing.T) {
|
||||
func testCanSaveLoadPredictions(trainData, testData base.FixedDataGrid) {
|
||||
rand.Seed(44414515)
|
||||
Convey("Using InferID3Tree to create the tree and do the fitting", t, func() {
|
||||
instances, err := base.ParseCSVToInstances("../examples/datasets/iris_headers.csv", true)
|
||||
So(err, ShouldBeNil)
|
||||
|
||||
trainData, testData := base.InstancesTrainTestSplit(instances, 0.6)
|
||||
|
||||
Convey("Using InferID3Tree to create the tree and do the fitting", func() {
|
||||
Convey("Using a RandomTreeRule", func() {
|
||||
randomTreeRuleGenerator := new(RandomTreeRuleGenerator)
|
||||
randomTreeRuleGenerator.Attributes = 2
|
||||
@ -30,20 +25,18 @@ func TestCanSaveLoadPredictions(t *testing.T) {
|
||||
So(err, ShouldBeNil)
|
||||
|
||||
Convey("Saving the tree...", func() {
|
||||
f, err := ioutil.TempFile("", "tree")
|
||||
f, err := ioutil.TempFile("","tree")
|
||||
So(err, ShouldBeNil)
|
||||
err = root.Save(f.Name())
|
||||
So(err, ShouldBeNil)
|
||||
|
||||
Convey("Loading the tree...", func() {
|
||||
Convey("Loading the tree...", func(){
|
||||
d := &DecisionTreeNode{}
|
||||
err := d.Load(f.Name())
|
||||
So(err, ShouldBeNil)
|
||||
So(d.String(), ShouldEqual, root.String())
|
||||
Convey("Generating predictions from the loaded tree...", func() {
|
||||
predictions2, err := d.Predict(testData)
|
||||
So(err, ShouldBeNil)
|
||||
So(fmt.Sprintf("%v", predictions2), ShouldEqual, fmt.Sprintf("%v", predictions))
|
||||
So(predictions, ShouldEqual, predictions2)
|
||||
})
|
||||
})
|
||||
})
|
||||
|
Loading…
x
Reference in New Issue
Block a user