diff --git a/base/serialize.go b/base/serialize.go index 8dc7aff..df1c78d 100644 --- a/base/serialize.go +++ b/base/serialize.go @@ -3,6 +3,7 @@ package base import ( "archive/tar" "compress/gzip" + "encoding/csv" "encoding/json" "fmt" "io" @@ -390,8 +391,8 @@ func (c *ClassifierSerializer) WriteMetadataAtPrefix(prefix string, metadata Cla // and writes the METADATA header. func CreateSerializedClassifierStub(filePath string, metadata ClassifierMetadataV1) (*ClassifierSerializer, error) { - // Write to a temporary path so we don't corrupt the output file - f, err := ioutil.TempFile(os.TempDir(), "clsTmp") + // Open the filePath + f, err := os.OpenFile(filePath, os.O_RDWR|os.O_TRUNC, 0600) if err != nil { return nil, err } @@ -404,8 +405,6 @@ func CreateSerializedClassifierStub(filePath string, metadata ClassifierMetadata gzipWriter: gzWriter, fileWriter: f, tarWriter: tw, - f: f, - filePath: filePath, } // @@ -434,3 +433,115 @@ func CreateSerializedClassifierStub(filePath string, metadata ClassifierMetadata return &ret, nil } + +func SerializeInstances(inst FixedDataGrid, f io.Writer) error { + var hdr *tar.Header + + gzWriter := gzip.NewWriter(f) + tw := tar.NewWriter(gzWriter) + + // Write the MANIFEST entry + hdr = &tar.Header{ + Name: "MANIFEST", + Size: int64(len(SerializationFormatVersion)), + } + if err := tw.WriteHeader(hdr); err != nil { + return fmt.Errorf("Could not write MANIFEST header: %s", err) + } + + if _, err := tw.Write([]byte(SerializationFormatVersion)); err != nil { + return fmt.Errorf("Could not write MANIFEST contents: %s", err) + } + + // Now write the dimensions of the dataset + attrCount, rowCount := inst.Size() + hdr = &tar.Header{ + Name: "DIMS", + Size: 16, + } + if err := tw.WriteHeader(hdr); err != nil { + return fmt.Errorf("Could not write DIMS header: %s", err) + } + + if _, err := tw.Write(PackU64ToBytes(uint64(attrCount))); err != nil { + return fmt.Errorf("Could not write DIMS (attrCount): %s", err) + } + if _, err := tw.Write(PackU64ToBytes(uint64(rowCount))); err != nil { + return fmt.Errorf("Could not write DIMS (rowCount): %s", err) + } + + // Write the ATTRIBUTES files + classAttrs := inst.AllClassAttributes() + normalAttrs := NonClassAttributes(inst) + if err := writeAttributesToFilePart(classAttrs, tw, "CATTRS"); err != nil { + return fmt.Errorf("Could not write CATTRS: %s", err) + } + if err := writeAttributesToFilePart(normalAttrs, tw, "ATTRS"); err != nil { + return fmt.Errorf("Could not write ATTRS: %s", err) + } + + // Data must be written out in the same order as the Attributes + allAttrs := make([]Attribute, attrCount) + normCount := copy(allAttrs, normalAttrs) + for i, v := range classAttrs { + allAttrs[normCount+i] = v + } + + allSpecs := ResolveAttributes(inst, allAttrs) + + // First, estimate the amount of data we'll need... + dataLength := int64(0) + inst.MapOverRows(allSpecs, func(val [][]byte, row int) (bool, error) { + for _, v := range val { + dataLength += int64(len(v)) + } + return true, nil + }) + + // Then write the header + hdr = &tar.Header{ + Name: "DATA", + Size: dataLength, + } + if err := tw.WriteHeader(hdr); err != nil { + return fmt.Errorf("Could not write DATA: %s", err) + } + + // Then write the actual data + writtenLength := int64(0) + if err := inst.MapOverRows(allSpecs, func(val [][]byte, row int) (bool, error) { + for _, v := range val { + wl, err := tw.Write(v) + writtenLength += int64(wl) + if err != nil { + return false, err + } + } + return true, nil + }); err != nil { + return err + } + + if writtenLength != dataLength { + return fmt.Errorf("Could not write DATA: changed size from %v to %v", dataLength, writtenLength) + } + + // Finally, close and flush the various levels + if err := tw.Flush(); err != nil { + return fmt.Errorf("Could not flush tar: %s", err) + } + + if err := tw.Close(); err != nil { + return fmt.Errorf("Could not close tar: %s", err) + } + + if err := gzWriter.Flush(); err != nil { + return fmt.Errorf("Could not flush gz: %s", err) + } + + if err := gzWriter.Close(); err != nil { + return fmt.Errorf("Could not close gz: %s", err) + } + + return nil +} diff --git a/trees/id3.go b/trees/id3.go index 97f2b26..b88bea8 100644 --- a/trees/id3.go +++ b/trees/id3.go @@ -6,6 +6,7 @@ import ( "fmt" "github.com/sjwhitworth/golearn/base" "github.com/sjwhitworth/golearn/evaluation" + "encoding/json" "sort" ) @@ -106,155 +107,27 @@ func getClassAttr(from base.FixedDataGrid) base.Attribute { return allClassAttrs[0] } -// MarshalJSON returns a JSON representation of this Attribute -// for serialisation. -func (d *DecisionTreeNode) MarshalJSON() ([]byte, error) { - ret := map[string]interface{}{ - "type": d.Type, - "class_dist": d.ClassDist, - "class": d.Class, - } - - if d.SplitRule != nil && d.SplitRule.SplitAttr != nil { - rawDRule, err := d.SplitRule.MarshalJSON() - if err != nil { - return nil, err - } - var dRule map[string]interface{} - err = json.Unmarshal(rawDRule, &dRule) - if err != nil { - panic(err) - } - ret["split_rule"] = dRule - } - - rawClassAttr, err := d.ClassAttr.MarshalJSON() - if err != nil { - return nil, err - } - var classAttr map[string]interface{} - err = json.Unmarshal(rawClassAttr, &classAttr) - ret["class_attr"] = classAttr - - if len(d.Children) > 0 { - - children := make(map[string]interface{}) - for k := range d.Children { - cur, err := d.Children[k].MarshalJSON() - if err != nil { - return nil, err - } - var child map[string]interface{} - err = json.Unmarshal(cur, &child) - if err != nil { - panic(err) - } - children[k] = child - } - ret["children"] = children - } - return json.Marshal(ret) -} - -// UnmarshalJSON reads a JSON representation of this Attribute. -func (d *DecisionTreeNode) UnmarshalJSON(data []byte) error { - jsonMap := make(map[string]interface{}) - err := json.Unmarshal(data, &jsonMap) - if err != nil { - return err - } - rawType := int(jsonMap["type"].(float64)) - if rawType == 1 { - d.Type = LeafNode - } else if rawType == 2 { - d.Type = RuleNode - } else { - return fmt.Errorf("Unknown nodeType: %d", rawType) - } - //d.Type = NodeType(int(jsonMap["type"].(float64))) - // Convert the class distribution back - classDist := jsonMap["class_dist"].(map[string]interface{}) - d.ClassDist = make(map[string]int) - for k := range classDist { - d.ClassDist[k] = int(classDist[k].(float64)) - } - - d.Class = jsonMap["class"].(string) - - // - // Decode the class attribute - // - // Temporarily re-marshal this field back to bytes - rawClassAttr := jsonMap["class_attr"] - rawClassAttrBytes, err := json.Marshal(rawClassAttr) - if err != nil { - return err - } - - classAttr, err := base.DeserializeAttribute(rawClassAttrBytes) - if err != nil { - return err - } - d.ClassAttr = classAttr - d.SplitRule = nil - - if splitRule, ok := jsonMap["split_rule"]; ok { - d.SplitRule = &DecisionTreeRule{} - splitRuleBytes, err := json.Marshal(splitRule) - if err != nil { - panic(err) - } - err = d.SplitRule.UnmarshalJSON(splitRuleBytes) - if err != nil { - return err - } - - d.Children = make(map[string]*DecisionTreeNode) - childMap := jsonMap["children"].(map[string]interface{}) - for i := range childMap { - cur := &DecisionTreeNode{} - childBytes, err := json.Marshal(childMap[i]) - if err != nil { - panic(err) - } - err = cur.UnmarshalJSON(childBytes) - if err != nil { - return err - } - d.Children[i] = cur - } - - } - - return nil -} - // Save sends the classification tree to an output file func (d *DecisionTreeNode) Save(filePath string) error { - metadata := base.ClassifierMetadataV1{ - FormatVersion: 1, - ClassifierName: "DecisionTreeNode", - ClassifierVersion: "1", - ClassifierMetadata: nil, + metadata := base.ClassifierMetadataV1 { + FormatVersion: 1, + ClassifierName: "test", + ClassifierVersion: "1", + ClassifierMetadata: exampleClassifierMetadata, } serializer, err := base.CreateSerializedClassifierStub(filePath, metadata) if err != nil { return err } - err = d.SaveWithPrefix(serializer, "") - serializer.Close() - return err -} - -func (d *DecisionTreeNode) SaveWithPrefix(serializer *base.ClassifierSerializer, prefix string) error { b, err := json.Marshal(d) if err != nil { return err } - err = serializer.WriteBytesForKey(fmt.Sprintf("%s%s", prefix, "tree"), b) + err = serializer.WriteBytesForKey("tree", b) if err != nil { return err } + serializer.Close() return nil } @@ -266,13 +139,11 @@ func (d *DecisionTreeNode) Load(filePath string) error { return err } - err = d.LoadWithPrefix(reader, "") - reader.Close() - return err -} + defer func() { + reader.Close() + }() -func (d *DecisionTreeNode) LoadWithPrefix(reader *base.ClassifierDeserializer, prefix string) error { - b, err := reader.GetBytesForKey(fmt.Sprintf("%s%s", prefix, "tree")) + b, err := reader.GetBytesForKey("tree") if err != nil { return err } diff --git a/trees/tree_test.go b/trees/tree_test.go index 5820590..64adad2 100644 --- a/trees/tree_test.go +++ b/trees/tree_test.go @@ -1,7 +1,6 @@ package trees import ( - "fmt" "github.com/sjwhitworth/golearn/base" "github.com/sjwhitworth/golearn/evaluation" "github.com/sjwhitworth/golearn/filters" @@ -10,16 +9,12 @@ import ( "math/rand" "os" "testing" + "io/ioutil" ) -func TestCanSaveLoadPredictions(t *testing.T) { +func testCanSaveLoadPredictions(trainData, testData base.FixedDataGrid) { rand.Seed(44414515) - Convey("Using InferID3Tree to create the tree and do the fitting", t, func() { - instances, err := base.ParseCSVToInstances("../examples/datasets/iris_headers.csv", true) - So(err, ShouldBeNil) - - trainData, testData := base.InstancesTrainTestSplit(instances, 0.6) - + Convey("Using InferID3Tree to create the tree and do the fitting", func() { Convey("Using a RandomTreeRule", func() { randomTreeRuleGenerator := new(RandomTreeRuleGenerator) randomTreeRuleGenerator.Attributes = 2 @@ -30,20 +25,18 @@ func TestCanSaveLoadPredictions(t *testing.T) { So(err, ShouldBeNil) Convey("Saving the tree...", func() { - f, err := ioutil.TempFile("", "tree") + f, err := ioutil.TempFile("","tree") So(err, ShouldBeNil) err = root.Save(f.Name()) So(err, ShouldBeNil) - Convey("Loading the tree...", func() { + Convey("Loading the tree...", func(){ d := &DecisionTreeNode{} err := d.Load(f.Name()) So(err, ShouldBeNil) - So(d.String(), ShouldEqual, root.String()) Convey("Generating predictions from the loaded tree...", func() { predictions2, err := d.Predict(testData) - So(err, ShouldBeNil) - So(fmt.Sprintf("%v", predictions2), ShouldEqual, fmt.Sprintf("%v", predictions)) + So(predictions, ShouldEqual, predictions2) }) }) })