diff --git a/base/arff.go b/base/arff.go new file mode 100644 index 0000000..1098c6b --- /dev/null +++ b/base/arff.go @@ -0,0 +1,292 @@ +package base + +import ( + "bufio" + "bytes" + "encoding/csv" + "fmt" + "io" + "os" + "runtime" + "strings" +) + +// SerializeInstancesToDenseARFF writes the given FixedDataGrid to a +// densely-formatted ARFF file. +func SerializeInstancesToDenseARFF(inst FixedDataGrid, path, relation string) error { + + // Get all of the Attributes in a reasonable order + attrs := NonClassAttributes(inst) + cAttrs := inst.AllClassAttributes() + for _, c := range cAttrs { + attrs = append(attrs, c) + } + + return SerializeInstancesToDenseARFFWithAttributes(inst, attrs, path, relation) + +} + +// SerializeInstancesToDenseARFFWithAttributes writes the given FixedDataGrid to a +// densely-formatted ARFF file with the header Attributes in the order given. +func SerializeInstancesToDenseARFFWithAttributes(inst FixedDataGrid, rawAttrs []Attribute, path, relation string) error { + + // Open output file + f, err := os.OpenFile(path, os.O_RDWR, 0600) + if err != nil { + return err + } + defer f.Close() + + // Write @relation header + f.WriteString(fmt.Sprintf("@relation %s\n\n", relation)) + + // Get all Attribute specifications + attrs := ResolveAttributes(inst, rawAttrs) + + // Write Attribute information + for _, s := range attrs { + attr := s.attr + t := "real" + if a, ok := attr.(*CategoricalAttribute); ok { + vals := a.GetValues() + t = fmt.Sprintf("{%s}", strings.Join(vals, ", ")) + } + f.WriteString(fmt.Sprintf("@attribute %s %s\n", attr.GetName(), t)) + } + f.WriteString("\n@data\n") + + buf := make([]string, len(attrs)) + inst.MapOverRows(attrs, func(val [][]byte, row int) (bool, error) { + for i, v := range val { + buf[i] = attrs[i].attr.GetStringFromSysVal(v) + } + f.WriteString(strings.Join(buf, ",")) + f.WriteString("\n") + return true, nil + }) + + return nil +} + +// ParseARFFGetRows returns the number of data rows in an ARFF file. 
+func ParseARFFGetRows(filepath string) (int, error) { + + f, err := os.Open(filepath) + if err != nil { + return 0, err + } + defer f.Close() + + counting := false + count := 0 + scanner := bufio.NewScanner(f) + for scanner.Scan() { + line := scanner.Text() + if len(line) == 0 { + continue + } + if counting { + if line[0] == '@' { + continue + } + if line[0] == '%' { + continue + } + count++ + continue + } + if line[0] == '@' { + line = strings.ToLower(line) + if line == "@data" { + counting = true + } + } + } + return count, nil +} + +// ParseARFFGetAttributes returns the set of Attributes represented in this ARFF +func ParseARFFGetAttributes(filepath string) []Attribute { + var ret []Attribute + + f, err := os.Open(filepath) + if err != nil { + panic(err) + } + defer f.Close() + + scanner := bufio.NewScanner(f) + for scanner.Scan() { + var attr Attribute + line := scanner.Text() + if len(line) == 0 { + continue + } + if line[0] != '@' { + continue + } + fields := strings.Fields(line) + if len(fields) < 3 { + continue + } + fields[0] = strings.ToLower(fields[0]) + attrType := strings.ToLower(fields[2]) + if fields[0] != "@attribute" { + continue + } + switch attrType { + case "real": + attr = new(FloatAttribute) + break + default: + if fields[2][0] == '{' { + if strings.HasSuffix(fields[len(fields)-1], "}") { + var cats []string + if len(fields) > 3 { + cats = fields[2:len(fields)] + } else { + cats = strings.Split(fields[2], ",") + } + if len(cats) == 0 { + panic(fmt.Errorf("Empty categorical field on line '%s'", line)) + } + cats[0] = cats[0][1:] // Remove leading '{' + cats[len(cats)-1] = cats[len(cats)-1][:len(cats[len(cats)-1])-1] // Remove trailing '}' + for i, v := range cats { // Miaow + cats[i] = strings.TrimSpace(v) + if strings.HasSuffix(cats[i], ",") { + // Strip end comma + cats[i] = cats[i][0 : len(cats[i])-1] + } + } + attr = NewCategoricalAttribute() + for _, v := range cats { + attr.GetSysValFromString(v) + } + } else { + 
panic(fmt.Errorf("Missing categorical bracket on line '%s'", line)) + } + } else { + panic(fmt.Errorf("Unsupported Attribute type %s on line '%s'", fields[2], line)) + } + } + + if attr == nil { + panic(fmt.Errorf(line)) + } + attr.SetName(fields[1]) + ret = append(ret, attr) + } + + maxPrecision, err := ParseCSVEstimateFilePrecision(filepath) + if err != nil { + panic(err) + } + for _, a := range ret { + if f, ok := a.(*FloatAttribute); ok { + f.Precision = maxPrecision + } + } + return ret +} + +// ParseDenseARFFBuildInstancesFromReader updates an [[#UpdatableDataGrid]] from a io.Reader +func ParseDenseARFFBuildInstancesFromReader(r io.Reader, attrs []Attribute, u UpdatableDataGrid) (err error) { + var rowCounter int + + defer func() { + if r := recover(); r != nil { + if _, ok := r.(runtime.Error); ok { + panic(err) + } + err = fmt.Errorf("Error at line %d (error %s)", rowCounter, r.(error)) + } + }() + + scanner := bufio.NewScanner(r) + reading := false + specs := ResolveAttributes(u, attrs) + for scanner.Scan() { + line := scanner.Text() + if strings.HasPrefix(line, "%") { + continue + } + if reading { + buf := bytes.NewBuffer([]byte(line)) + reader := csv.NewReader(buf) + for { + r, err := reader.Read() + if err == io.EOF { + break + } else if err != nil { + return err + } + for i, v := range r { + v = strings.TrimSpace(v) + if a, ok := specs[i].attr.(*CategoricalAttribute); ok { + if val := a.GetSysVal(v); val == nil { + panic(fmt.Errorf("Unexpected class on line '%s'", line)) + } + } + u.Set(specs[i], rowCounter, specs[i].attr.GetSysValFromString(v)) + } + rowCounter++ + } + } else { + line = strings.ToLower(line) + line = strings.TrimSpace(line) + if line == "@data" { + reading = true + } + } + } + + return nil +} + +// ParseDenseARFFToInstances parses the dense ARFF File into a FixedDataGrid +func ParseDenseARFFToInstances(filepath string) (ret *DenseInstances, err error) { + defer func() { + if r := recover(); r != nil { + if _, ok := r.(runtime.Error); 
ok { + panic(r) + } + err = r.(error) + } + }() + + // Find the number of rows in the file + rows, err := ParseARFFGetRows(filepath) + if err != nil { + return nil, err + } + + // Get the Attributes we want + attrs := ParseARFFGetAttributes(filepath) + + // Allocate return value + ret = NewDenseInstances() + + // Add all the Attributes + for _, a := range attrs { + ret.AddAttribute(a) + } + + // Set the last Attribute as the class + ret.AddClassAttribute(attrs[len(attrs)-1]) + ret.Extend(rows) + + f, err := os.Open(filepath) + if err != nil { + return nil, err + } + defer f.Close() + + // Read the data + // Seek past the header + err = ParseDenseARFFBuildInstancesFromReader(f, attrs, ret) + if err != nil { + ret = nil + } + return ret, err +} diff --git a/base/arff_test.go b/base/arff_test.go new file mode 100644 index 0000000..40b8251 --- /dev/null +++ b/base/arff_test.go @@ -0,0 +1,109 @@ +package base + +import ( + . "github.com/smartystreets/goconvey/convey" + "io/ioutil" + "testing" +) + +func TestParseARFFGetRows(t *testing.T) { + Convey("Getting the number of rows for a ARFF file", t, func() { + Convey("With a valid file path", func() { + numNonHeaderRows := 150 + lineCount, err := ParseARFFGetRows("../examples/datasets/iris.arff") + So(err, ShouldBeNil) + So(lineCount, ShouldEqual, numNonHeaderRows) + }) + }) +} + +func TestParseARFFGetAttributes(t *testing.T) { + Convey("Getting the attributes in the headers of a CSV file", t, func() { + attributes := ParseARFFGetAttributes("../examples/datasets/iris.arff") + sepalLengthAttribute := attributes[0] + sepalWidthAttribute := attributes[1] + petalLengthAttribute := attributes[2] + petalWidthAttribute := attributes[3] + speciesAttribute := attributes[4] + + Convey("It gets the correct types for the headers based on the column values", func() { + _, ok1 := sepalLengthAttribute.(*FloatAttribute) + _, ok2 := sepalWidthAttribute.(*FloatAttribute) + _, ok3 := petalLengthAttribute.(*FloatAttribute) + _, ok4 := 
petalWidthAttribute.(*FloatAttribute) + sA, ok5 := speciesAttribute.(*CategoricalAttribute) + So(ok1, ShouldBeTrue) + So(ok2, ShouldBeTrue) + So(ok3, ShouldBeTrue) + So(ok4, ShouldBeTrue) + So(ok5, ShouldBeTrue) + So(sA.GetValues(), ShouldResemble, []string{"Iris-setosa", "Iris-versicolor", "Iris-virginica"}) + }) + }) +} + +func TestParseARFF1(t *testing.T) { + Convey("Should just be able to load in an ARFF...", t, func() { + inst, err := ParseDenseARFFToInstances("../examples/datasets/iris.arff") + So(err, ShouldBeNil) + So(inst, ShouldNotBeNil) + So(inst.RowString(0), ShouldEqual, "5.1 3.5 1.4 0.2 Iris-setosa") + So(inst.RowString(50), ShouldEqual, "7.0 3.2 4.7 1.4 Iris-versicolor") + So(inst.RowString(100), ShouldEqual, "6.3 3.3 6.0 2.5 Iris-virginica") + }) +} + +func TestParseARFF2(t *testing.T) { + Convey("Loading the weather dataset...", t, func() { + inst, err := ParseDenseARFFToInstances("../examples/datasets/weather.arff") + So(err, ShouldBeNil) + + Convey("Attributes should be right...", func() { + So(GetAttributeByName(inst, "outlook"), ShouldNotBeNil) + So(GetAttributeByName(inst, "temperature"), ShouldNotBeNil) + So(GetAttributeByName(inst, "humidity"), ShouldNotBeNil) + So(GetAttributeByName(inst, "windy"), ShouldNotBeNil) + So(GetAttributeByName(inst, "play"), ShouldNotBeNil) + Convey("outlook attribute values should match reference...", func() { + outlookAttr := GetAttributeByName(inst, "outlook").(*CategoricalAttribute) + So(outlookAttr.GetValues(), ShouldResemble, []string{"sunny", "overcast", "rainy"}) + }) + Convey("windy values should match reference...", func() { + windyAttr := GetAttributeByName(inst, "windy").(*CategoricalAttribute) + So(windyAttr.GetValues(), ShouldResemble, []string{"TRUE", "FALSE"}) + }) + Convey("play values should match reference...", func() { + playAttr := GetAttributeByName(inst, "play").(*CategoricalAttribute) + So(playAttr.GetValues(), ShouldResemble, []string{"yes", "no"}) + }) + + }) + + }) +} + +func 
TestSerializeToARFF(t *testing.T) { + Convey("Loading the weather dataset...", t, func() { + inst, err := ParseDenseARFFToInstances("../examples/datasets/weather.arff") + So(err, ShouldBeNil) + Convey("Saving back should suceed...", func() { + attrs := ParseARFFGetAttributes("../examples/datasets/weather.arff") + f, err := ioutil.TempFile("", "inst") + So(err, ShouldBeNil) + err = SerializeInstancesToDenseARFFWithAttributes(inst, attrs, f.Name(), "weather") + So(err, ShouldBeNil) + Convey("Reading the file back should be lossless...", func() { + inst2, err := ParseDenseARFFToInstances(f.Name()) + So(err, ShouldBeNil) + So(InstancesAreEqual(inst, inst2), ShouldBeTrue) + }) + Convey("The file should be exactly the same as the original...", func() { + ref, err := ioutil.ReadFile("../examples/datasets/weather.arff") + So(err, ShouldBeNil) + gen, err := ioutil.ReadFile(f.Name()) + So(err, ShouldBeNil) + So(string(gen), ShouldEqual, string(ref)) + }) + }) + }) +} diff --git a/base/attributes.go b/base/attributes.go index 48df381..6729f08 100644 --- a/base/attributes.go +++ b/base/attributes.go @@ -1,5 +1,9 @@ package base +import ( + "encoding/json" +) + const ( // CategoricalType is for Attributes which represent values distinctly. CategoricalType = iota @@ -10,6 +14,8 @@ const ( // Attributes disambiguate columns of the feature matrix and declare their types. type Attribute interface { + json.Unmarshaler + json.Marshaler // Returns the general characterstics of this Attribute . // to avoid the overhead of casting GetType() int diff --git a/base/binary.go b/base/binary.go index 1842cc2..152ef01 100644 --- a/base/binary.go +++ b/base/binary.go @@ -1,6 +1,7 @@ package base import ( + "encoding/json" "fmt" "strconv" ) @@ -10,6 +11,20 @@ type BinaryAttribute struct { Name string } +// MarshalJSON returns a JSON version of this BinaryAttribute for serialisation. 
+func (b *BinaryAttribute) MarshalJSON() ([]byte, error) { + return json.Marshal(map[string]interface{}{ + "type": "binary", + "name": b.Name, + }) +} + +// UnmarshalJSON unpacks a BinaryAttribute from serialisation. +// Usually, there's nothing to deserialize. +func (b *BinaryAttribute) UnmarshalJSON(data []byte) error { + return nil +} + // NewBinaryAttribute creates a BinaryAttribute with the given name func NewBinaryAttribute(name string) *BinaryAttribute { return &BinaryAttribute{ diff --git a/base/categorical.go b/base/categorical.go index be6d2b7..9012014 100644 --- a/base/categorical.go +++ b/base/categorical.go @@ -1,6 +1,7 @@ package base import ( + "encoding/json" "fmt" ) @@ -9,7 +10,31 @@ import ( // - useful for representing classes. type CategoricalAttribute struct { Name string - values []string + values []string `json:"values"` +} + +// MarshalJSON returns a JSON version of this Attribute. +func (Attr *CategoricalAttribute) MarshalJSON() ([]byte, error) { + return json.Marshal(map[string]interface{}{ + "type": "categorical", + "name": Attr.Name, + "attr": map[string]interface{}{ + "values": Attr.values, + }, + }) +} + +// UnmarshalJSON returns a JSON version of this Attribute. +func (Attr *CategoricalAttribute) UnmarshalJSON(data []byte) error { + var d map[string]interface{} + err := json.Unmarshal(data, &d) + if err != nil { + return err + } + for _, v := range d["values"].([]interface{}) { + Attr.values = append(Attr.values, v.(string)) + } + return nil } // NewCategoricalAttribute creates a blank CategoricalAttribute. 
diff --git a/base/csv.go b/base/csv.go index 52e7004..58a2efe 100644 --- a/base/csv.go +++ b/base/csv.go @@ -1,11 +1,13 @@ package base import ( + "bufio" "encoding/csv" "fmt" "io" "os" "regexp" + "runtime" "strings" ) @@ -31,6 +33,52 @@ func ParseCSVGetRows(filepath string) (int, error) { return counter, nil } +// ParseCSVEstimateFilePrecision determines what the maximum number of +// digits occuring anywhere after the decimal point within the file. +func ParseCSVEstimateFilePrecision(filepath string) (int, error) { + // Creat a basic regexp + rexp := regexp.MustCompile("[0-9]+(.[0-9]+)?") + + // Open the source file + f, err := os.Open(filepath) + if err != nil { + return 0, err + } + defer f.Close() + + // Scan through the file line-by-line + maxL := 0 + scanner := bufio.NewScanner(f) + lineCount := 0 + for scanner.Scan() { + if lineCount > 5 { + break + } + line := scanner.Text() + if len(line) == 0 { + continue + } + if line[0] == '@' { + continue + } + if line[0] == '%' { + continue + } + matches := rexp.FindAllString(line, -1) + for _, m := range matches { + p := strings.Split(m, ".") + if len(p) == 2 { + l := len(p[len(p)-1]) + if l > maxL { + maxL = l + } + } + } + lineCount++ + } + return maxL, nil +} + // ParseCSVGetAttributes returns an ordered slice of appropriate-ly typed // and named Attributes. 
func ParseCSVGetAttributes(filepath string, hasHeaders bool) []Attribute { @@ -112,49 +160,69 @@ func ParseCSVSniffAttributeTypes(filepath string, hasHeaders bool) []Attribute { } } - return attrs -} - -// ParseCSVBuildInstances updates an [[#UpdatableDataGrid]] from a filepath in place -func ParseCSVBuildInstances(filepath string, hasHeaders bool, u UpdatableDataGrid) { - - // Read the input - file, err := os.Open(filepath) + // Estimate file precision + maxP, err := ParseCSVEstimateFilePrecision(filepath) if err != nil { panic(err) } - defer file.Close() - reader := csv.NewReader(file) + for _, a := range attrs { + if f, ok := a.(*FloatAttribute); ok { + f.Precision = maxP + } + } - rowCounter := 0 + return attrs +} - specs := ResolveAttributes(u, u.AllAttributes()) +// ParseCSVBuildInstancesFromReader updates an [[#UpdatableDataGrid]] from a io.Reader +func ParseCSVBuildInstancesFromReader(r io.Reader, attrs []Attribute, hasHeader bool, u UpdatableDataGrid) (err error) { + var rowCounter int + + defer func() { + if r := recover(); r != nil { + if _, ok := r.(runtime.Error); ok { + panic(err) + } + err = fmt.Errorf("Error at line %d (error %s)", rowCounter, r.(error)) + } + }() + + specs := ResolveAttributes(u, attrs) + reader := csv.NewReader(r) for { record, err := reader.Read() if err == io.EOF { break } else if err != nil { - panic(err) + return err } if rowCounter == 0 { - if hasHeaders { - hasHeaders = false + if hasHeader { + hasHeader = false continue } } for i, v := range record { - u.Set(specs[i], rowCounter, specs[i].attr.GetSysValFromString(v)) + u.Set(specs[i], rowCounter, specs[i].attr.GetSysValFromString(strings.TrimSpace(v))) } rowCounter++ } + return nil } // ParseCSVToInstances reads the CSV file given by filepath and returns // the read Instances. 
func ParseCSVToInstances(filepath string, hasHeaders bool) (instances *DenseInstances, err error) { + // Open the file + f, err := os.Open(filepath) + if err != nil { + return nil, err + } + defer f.Close() + // Read the number of rows in the file rowCount, err := ParseCSVGetRows(filepath) if err != nil { @@ -176,45 +244,41 @@ func ParseCSVToInstances(filepath string, hasHeaders bool) (instances *DenseInst } instances.Extend(rowCount) - // Read the input - file, err := os.Open(filepath) + err = ParseCSVBuildInstancesFromReader(f, attrs, hasHeaders, instances) if err != nil { return nil, err } - defer file.Close() - reader := csv.NewReader(file) - - rowCounter := 0 - - for { - record, err := reader.Read() - if err == io.EOF { - break - } else if err != nil { - return nil, err - } - if rowCounter == 0 { - if hasHeaders { - hasHeaders = false - continue - } - } - for i, v := range record { - v = strings.Trim(v, " ") - instances.Set(specs[i], rowCounter, attrs[i].GetSysValFromString(v)) - } - rowCounter++ - } instances.AddClassAttribute(attrs[len(attrs)-1]) return instances, nil } +// ParseUtilsMatchAttrs tries to match the set of Attributes read from one file with +// those read from another, and writes the matching Attributes back to the original set. +func ParseMatchAttributes(attrs, templateAttrs []Attribute) { + for i, a := range attrs { + for _, b := range templateAttrs { + if a.Equals(b) { + attrs[i] = b + } else if a.GetName() == b.GetName() { + attrs[i] = b + } + } + } +} + // ParseCSVToInstancesTemplated reads the CSV file given by filepath and returns // the read Instances, using another already read DenseInstances as a template. 
func ParseCSVToTemplatedInstances(filepath string, hasHeaders bool, template *DenseInstances) (instances *DenseInstances, err error) { + // Open the file + f, err := os.Open(filepath) + if err != nil { + return nil, err + } + defer f.Close() + // Read the number of rows in the file rowCount, err := ParseCSVGetRows(filepath) if err != nil { @@ -228,81 +292,16 @@ func ParseCSVToTemplatedInstances(filepath string, hasHeaders bool, template *De // Read the row headers attrs := ParseCSVGetAttributes(filepath, hasHeaders) templateAttrs := template.AllAttributes() - for i, a := range attrs { - for _, b := range templateAttrs { - if a.Equals(b) { - attrs[i] = b - } else if a.GetName() == b.GetName() { - attrs[i] = b - } - } - } + ParseMatchAttributes(attrs, templateAttrs) - specs := make([]AttributeSpec, len(attrs)) // Allocate the Instances to return - instances = NewDenseInstances() - - templateAgs := template.AllAttributeGroups() - for ag := range templateAgs { - agTemplate := templateAgs[ag] - if _, ok := agTemplate.(*BinaryAttributeGroup); ok { - instances.CreateAttributeGroup(ag, 0) - } else { - instances.CreateAttributeGroup(ag, 8) - } - } - - for i, a := range templateAttrs { - s, err := template.GetAttribute(a) - if err != nil { - panic(err) - } - if ag, ok := template.agRevMap[s.pond]; !ok { - panic(ag) - } else { - spec, err := instances.AddAttributeToAttributeGroup(a, ag) - if err != nil { - panic(err) - } - specs[i] = spec - } - } - + instances = CopyDenseInstances(template, templateAttrs) instances.Extend(rowCount) - // Read the input - file, err := os.Open(filepath) + err = ParseCSVBuildInstancesFromReader(f, attrs, hasHeaders, instances) if err != nil { return nil, err } - defer file.Close() - reader := csv.NewReader(file) - - rowCounter := 0 - - for { - record, err := reader.Read() - if err == io.EOF { - break - } else if err != nil { - return nil, err - } - if rowCounter == 0 { - if hasHeaders { - hasHeaders = false - continue - } - } - for i, v := range 
record { - v = strings.Trim(v, " ") - instances.Set(specs[i], rowCounter, attrs[i].GetSysValFromString(v)) - } - rowCounter++ - } - - for _, a := range template.AllClassAttributes() { - instances.AddClassAttribute(a) - } return instances, nil } @@ -312,6 +311,13 @@ func ParseCSVToTemplatedInstances(filepath string, hasHeaders bool, template *De // specified in the first argument and also any class Attributes specified in the second func ParseCSVToInstancesWithAttributeGroups(filepath string, attrGroups, classAttrGroups map[string]string, attrOverrides map[int]Attribute, hasHeaders bool) (instances *DenseInstances, err error) { + // Open file + f, err := os.Open(filepath) + if err != nil { + return nil, err + } + defer f.Close() + // Read row count rowCount, err := ParseCSVGetRows(filepath) if err != nil { @@ -386,45 +392,11 @@ func ParseCSVToInstancesWithAttributeGroups(filepath string, attrGroups, classAt // Allocate instances.Extend(rowCount) - // Read the input - file, err := os.Open(filepath) + err = ParseCSVBuildInstancesFromReader(f, attrs, hasHeaders, instances) if err != nil { return nil, err } - defer file.Close() - reader := csv.NewReader(file) - rowCounter := 0 - - for { - record, err := reader.Read() - if err == io.EOF { - break - } else if err != nil { - return nil, err - } - if rowCounter == 0 { - // Skip header row - rowCounter++ - continue - } - for i, v := range record { - v = strings.Trim(v, " ") - instances.Set(specs[i], rowCounter, attrs[i].GetSysValFromString(v)) - } - rowCounter++ - } - - // Add class Attributes - for _, a := range instances.AllAttributes() { - name := a.GetName() // classAttrGroups - if _, ok := classAttrGroups[name]; ok { - err = instances.AddClassAttribute(a) - if err != nil { - panic(err) - } - } - } return instances, nil } diff --git a/base/csv_test.go b/base/csv_test.go index 9f74d75..6742d31 100644 --- a/base/csv_test.go +++ b/base/csv_test.go @@ -93,12 +93,17 @@ func TestParseCSVToInstances(t *testing.T) { So(err, 
ShouldBeNil) Convey("Should parse the rows correctly", func() { - So(instances.RowString(0), ShouldEqual, "5.10 3.50 1.40 0.20 Iris-setosa") - So(instances.RowString(50), ShouldEqual, "7.00 3.20 4.70 1.40 Iris-versicolor") - So(instances.RowString(100), ShouldEqual, "6.30 3.30 6.00 2.50 Iris-virginica") + So(instances.RowString(0), ShouldEqual, "5.1 3.5 1.4 0.2 Iris-setosa") + So(instances.RowString(50), ShouldEqual, "7.0 3.2 4.7 1.4 Iris-versicolor") + So(instances.RowString(100), ShouldEqual, "6.3 3.3 6.0 2.5 Iris-virginica") }) }) + Convey("Given a path to another reasonable CSV file", func() { + _, err := ParseCSVToInstances("../examples/datasets/c45-numeric.csv", true) + So(err, ShouldBeNil) + }) + Convey("Given a path to a non-existent file", func() { _, err := ParseCSVToInstances("../examples/datasets/non-existent.csv", true) diff --git a/base/dense.go b/base/dense.go index e99e308..a4cedd7 100644 --- a/base/dense.go +++ b/base/dense.go @@ -443,35 +443,6 @@ func (inst *DenseInstances) swapRows(i, j int) { } } -// Equal checks whether a given Instance set is exactly the same -// as another: same size and same values (as determined by the Attributes) -// -// IMPORTANT: does not explicitly check if the Attributes are considered equal. -func (inst *DenseInstances) Equal(other DataGrid) bool { - - _, rows := inst.Size() - - for _, a := range inst.AllAttributes() { - as1, err := inst.GetAttribute(a) - if err != nil { - panic(err) // That indicates some kind of error - } - as2, err := inst.GetAttribute(a) - if err != nil { - return false // Obviously has different Attributes - } - for i := 0; i < rows; i++ { - b1 := inst.Get(as1, i) - b2 := inst.Get(as2, i) - if !byteSeqEqual(b1, b2) { - return false - } - } - } - - return true -} - // String returns a human-readable summary of this dataset. 
func (inst *DenseInstances) String() string { var buffer bytes.Buffer diff --git a/base/float.go b/base/float.go index efb4551..db01d09 100644 --- a/base/float.go +++ b/base/float.go @@ -1,6 +1,7 @@ package base import ( + "encoding/json" "fmt" "strconv" ) @@ -12,6 +13,32 @@ type FloatAttribute struct { Precision int } +// MarshalJSON returns a JSON representation of this Attribute +// for serialisation. +func (f *FloatAttribute) MarshalJSON() ([]byte, error) { + return json.Marshal(map[string]interface{}{ + "type": "float", + "name": f.Name, + "attr": map[string]interface{}{ + "precision": f.Precision, + }, + }) +} + +// UnmarshalJSON reads a JSON representation of this Attribute. +func (f *FloatAttribute) UnmarshalJSON(data []byte) error { + var d map[string]interface{} + err := json.Unmarshal(data, &d) + if err != nil { + return err + } + if precision, ok := d["precision"]; ok { + f.Precision = int(precision.(float64)) + return nil + } + return fmt.Errorf("Precision must be specified") +} + // NewFloatAttribute returns a new FloatAttribute with a default // precision of 2 decimal places func NewFloatAttribute(name string) *FloatAttribute { diff --git a/base/lazy_sort_test.go b/base/lazy_sort_test.go index 2c88aa8..635f5f9 100644 --- a/base/lazy_sort_test.go +++ b/base/lazy_sort_test.go @@ -29,7 +29,7 @@ func TestLazySortDesc(t *testing.T) { }) Convey("Result should match the reference", func() { - So(sortedDescending.Equal(result), ShouldBeTrue) + So(InstancesAreEqual(sortedDescending, result), ShouldBeTrue) }) }) }) @@ -60,11 +60,11 @@ func TestLazySortAsc(t *testing.T) { }) Convey("Result should match the reference", func() { - So(sortedAscending.Equal(result), ShouldBeTrue) + So(InstancesAreEqual(sortedAscending, result), ShouldBeTrue) }) Convey("First element of Result should equal known value", func() { - So(result.RowString(0), ShouldEqual, "4.30 3.00 1.10 0.10 Iris-setosa") + So(result.RowString(0), ShouldEqual, "4.3 3.0 1.1 0.1 Iris-setosa") }) }) }) 
diff --git a/base/serialize.go b/base/serialize.go new file mode 100644 index 0000000..6c86f72 --- /dev/null +++ b/base/serialize.go @@ -0,0 +1,381 @@ +package base + +import ( + "archive/tar" + "compress/gzip" + "encoding/csv" + "encoding/json" + "fmt" + "io" + "os" + "reflect" + "runtime" +) + +const ( + SerializationFormatVersion = "golearn 0.5" +) + +func SerializeInstancesToFile(inst FixedDataGrid, path string) error { + f, err := os.OpenFile(path, os.O_RDWR, 0600) + if err != nil { + return err + } + err = SerializeInstances(inst, f) + if err != nil { + return err + } + err = f.Sync() + if err != nil { + return fmt.Errorf("Couldn't flush file: %s", err) + } + f.Close() + return nil +} + +func SerializeInstancesToCSV(inst FixedDataGrid, path string) error { + f, err := os.OpenFile(path, os.O_RDWR, 0600) + if err != nil { + return err + } + defer func() { + f.Sync() + f.Close() + }() + + return SerializeInstancesToCSVStream(inst, f) +} + +func SerializeInstancesToCSVStream(inst FixedDataGrid, f io.Writer) error { + // Create the CSV writer + w := csv.NewWriter(f) + + colCount, _ := inst.Size() + + // Write out Attribute headers + // Start with the regular Attributes + normalAttrs := NonClassAttributes(inst) + classAttrs := inst.AllClassAttributes() + allAttrs := make([]Attribute, colCount) + n := copy(allAttrs, normalAttrs) + copy(allAttrs[n:], classAttrs) + headerRow := make([]string, colCount) + for i, v := range allAttrs { + headerRow[i] = v.GetName() + } + w.Write(headerRow) + + specs := ResolveAttributes(inst, allAttrs) + curRow := make([]string, colCount) + inst.MapOverRows(specs, func(row [][]byte, rowNo int) (bool, error) { + for i, v := range row { + attr := allAttrs[i] + curRow[i] = attr.GetStringFromSysVal(v) + } + w.Write(curRow) + return true, nil + }) + + w.Flush() + return nil +} + +func writeAttributesToFilePart(attrs []Attribute, f *tar.Writer, name string) error { + // Get the marshaled Attribute array + body, err := json.Marshal(attrs) + if 
err != nil { + return err + } + + // Write a header + hdr := &tar.Header{ + Name: name, + Size: int64(len(body)), + } + if err := f.WriteHeader(hdr); err != nil { + return err + } + + // Write the marshaled data + if _, err := f.Write([]byte(body)); err != nil { + return err + } + + return nil +} + +func getTarContent(tr *tar.Reader, name string) []byte { + for { + hdr, err := tr.Next() + if err == io.EOF { + break + } else if err != nil { + panic(err) + } + + if hdr.Name == name { + ret := make([]byte, hdr.Size) + n, err := tr.Read(ret) + if int64(n) != hdr.Size { + panic("Size mismatch") + } + if err != nil { + panic(err) + } + return ret + } + } + panic("File not found!") +} + +func deserializeAttributes(data []byte) []Attribute { + + // Define a JSON shim Attribute + type JSONAttribute struct { + Type string `json:type` + Name string `json:name` + Attr json.RawMessage `json:attr` + } + + var ret []Attribute + var attrs []JSONAttribute + + err := json.Unmarshal(data, &attrs) + if err != nil { + panic(fmt.Errorf("Attribute decode error: %s", err)) + } + + for _, a := range attrs { + var attr Attribute + var err error + switch a.Type { + case "binary": + attr = new(BinaryAttribute) + break + case "float": + attr = new(FloatAttribute) + break + case "categorical": + attr = new(CategoricalAttribute) + break + default: + panic(fmt.Errorf("Unrecognised Attribute format: %s", a.Type)) + } + err = attr.UnmarshalJSON(a.Attr) + if err != nil { + panic(fmt.Errorf("Can't deserialize: %s (error: %s)", a, err)) + } + attr.SetName(a.Name) + ret = append(ret, attr) + } + return ret +} + +func DeserializeInstances(f io.Reader) (ret *DenseInstances, err error) { + + // Recovery function + defer func() { + if r := recover(); r != nil { + if _, ok := r.(runtime.Error); ok { + panic(r) + } + err = r.(error) + } + }() + + // Open the .gz layer + gzReader, err := gzip.NewReader(f) + if err != nil { + return nil, fmt.Errorf("Can't open: %s", err) + } + // Open the .tar layer + tr := 
tar.NewReader(gzReader) + // Retrieve the MANIFEST and verify + manifestBytes := getTarContent(tr, "MANIFEST") + if !reflect.DeepEqual(manifestBytes, []byte(SerializationFormatVersion)) { + return nil, fmt.Errorf("Unsupported MANIFEST: %s", string(manifestBytes)) + } + + // Get the size + sizeBytes := getTarContent(tr, "DIMS") + attrCount := int(UnpackBytesToU64(sizeBytes[0:8])) + rowCount := int(UnpackBytesToU64(sizeBytes[8:])) + + // Unmarshal the Attributes + attrBytes := getTarContent(tr, "CATTRS") + cAttrs := deserializeAttributes(attrBytes) + attrBytes = getTarContent(tr, "ATTRS") + normalAttrs := deserializeAttributes(attrBytes) + + // Create the return instances + ret = NewDenseInstances() + + // Normal Attributes first, class Attributes on the end + allAttributes := make([]Attribute, attrCount) + for i, v := range normalAttrs { + ret.AddAttribute(v) + allAttributes[i] = v + } + for i, v := range cAttrs { + ret.AddAttribute(v) + err = ret.AddClassAttribute(v) + if err != nil { + return nil, fmt.Errorf("Could not set Attribute as class Attribute: %s", err) + } + allAttributes[i+len(normalAttrs)] = v + } + // Allocate memory + err = ret.Extend(int(rowCount)) + if err != nil { + return nil, fmt.Errorf("Could not allocate memory") + } + + // Seek through the TAR file until we get to the DATA section + for { + hdr, err := tr.Next() + if err == io.EOF { + return nil, fmt.Errorf("DATA section missing!") + } else if err != nil { + return nil, fmt.Errorf("Error seeking to DATA section: %s", err) + } + if hdr.Name == "DATA" { + break + } + } + + // Resolve AttributeSpecs + specs := ResolveAttributes(ret, allAttributes) + + // Finally, read the values out of the data section + for i := 0; i < rowCount; i++ { + for _, s := range specs { + r := ret.Get(s, i) + n, err := tr.Read(r) + if n != len(r) { + return nil, fmt.Errorf("Expected %d bytes (read %d) on row %d", len(r), n, i) + } + if err != nil { + return nil, fmt.Errorf("Read error: %s", err) + } + ret.Set(s, i, r) 
+ } + } + + if err = gzReader.Close(); err != nil { + return ret, fmt.Errorf("Error closing gzip stream: %s", err) + } + + return ret, nil +} + +func SerializeInstances(inst FixedDataGrid, f io.Writer) error { + var hdr *tar.Header + + gzWriter := gzip.NewWriter(f) + tw := tar.NewWriter(gzWriter) + + // Write the MANIFEST entry + hdr = &tar.Header{ + Name: "MANIFEST", + Size: int64(len(SerializationFormatVersion)), + } + if err := tw.WriteHeader(hdr); err != nil { + return fmt.Errorf("Could not write MANIFEST header: %s", err) + } + + if _, err := tw.Write([]byte(SerializationFormatVersion)); err != nil { + return fmt.Errorf("Could not write MANIFEST contents: %s", err) + } + + // Now write the dimensions of the dataset + attrCount, rowCount := inst.Size() + hdr = &tar.Header{ + Name: "DIMS", + Size: 16, + } + if err := tw.WriteHeader(hdr); err != nil { + return fmt.Errorf("Could not write DIMS header: %s", err) + } + + if _, err := tw.Write(PackU64ToBytes(uint64(attrCount))); err != nil { + return fmt.Errorf("Could not write DIMS (attrCount): %s", err) + } + if _, err := tw.Write(PackU64ToBytes(uint64(rowCount))); err != nil { + return fmt.Errorf("Could not write DIMS (rowCount): %s", err) + } + + // Write the ATTRIBUTES files + classAttrs := inst.AllClassAttributes() + normalAttrs := NonClassAttributes(inst) + if err := writeAttributesToFilePart(classAttrs, tw, "CATTRS"); err != nil { + return fmt.Errorf("Could not write CATTRS: %s", err) + } + if err := writeAttributesToFilePart(normalAttrs, tw, "ATTRS"); err != nil { + return fmt.Errorf("Could not write ATTRS: %s", err) + } + + // Data must be written out in the same order as the Attributes + allAttrs := make([]Attribute, attrCount) + normCount := copy(allAttrs, normalAttrs) + for i, v := range classAttrs { + allAttrs[normCount+i] = v + } + + allSpecs := ResolveAttributes(inst, allAttrs) + + // First, estimate the amount of data we'll need... 
+ dataLength := int64(0) + inst.MapOverRows(allSpecs, func(val [][]byte, row int) (bool, error) { + for _, v := range val { + dataLength += int64(len(v)) + } + return true, nil + }) + + // Then write the header + hdr = &tar.Header{ + Name: "DATA", + Size: dataLength, + } + if err := tw.WriteHeader(hdr); err != nil { + return fmt.Errorf("Could not write DATA: %s", err) + } + + // Then write the actual data + writtenLength := int64(0) + if err := inst.MapOverRows(allSpecs, func(val [][]byte, row int) (bool, error) { + for _, v := range val { + wl, err := tw.Write(v) + writtenLength += int64(wl) + if err != nil { + return false, err + } + } + return true, nil + }); err != nil { + return err + } + + if writtenLength != dataLength { + return fmt.Errorf("Could not write DATA: changed size from %v to %v", dataLength, writtenLength) + } + + // Finally, close and flush the various levels + if err := tw.Flush(); err != nil { + return fmt.Errorf("Could not flush tar: %s", err) + } + + if err := tw.Close(); err != nil { + return fmt.Errorf("Could not close tar: %s", err) + } + + if err := gzWriter.Flush(); err != nil { + return fmt.Errorf("Could not flush gz: %s", err) + } + + if err := gzWriter.Close(); err != nil { + return fmt.Errorf("Could not close gz: %s", err) + } + + return nil +} diff --git a/base/serialize_test.go b/base/serialize_test.go new file mode 100644 index 0000000..bb4c074 --- /dev/null +++ b/base/serialize_test.go @@ -0,0 +1,107 @@ +package base + +import ( + "archive/tar" + "compress/gzip" + "fmt" + . 
"github.com/smartystreets/goconvey/convey"
+	"io"
+	"io/ioutil"
+	"testing"
+)
+
+// TestSerializeToCSV round-trips a dataset through CSV and checks that the
+// reconstructed instances stringify identically to the originals.
+func TestSerializeToCSV(t *testing.T) {
+	Convey("Reading some instances...", t, func() {
+		inst, err := ParseCSVToInstances("../examples/datasets/iris_headers.csv", true)
+		So(err, ShouldBeNil)
+
+		Convey("Saving the instances to CSV...", func() {
+			f, err := ioutil.TempFile("", "instTmp")
+			So(err, ShouldBeNil)
+			err = SerializeInstancesToCSV(inst, f.Name())
+			So(err, ShouldBeNil)
+			Convey("What's written out should match what's read in", func() {
+				dinst, err := ParseCSVToInstances(f.Name(), true)
+				So(err, ShouldBeNil)
+				So(inst.String(), ShouldEqual, dinst.String())
+			})
+		})
+	})
+}
+
+// TestSerializeToFile dumps a dataset to the gzipped-tar serialization
+// format and verifies the archive contains every expected member.
+func TestSerializeToFile(t *testing.T) {
+	Convey("Reading some instances...", t, func() {
+		inst, err := ParseCSVToInstances("../examples/datasets/iris_headers.csv", true)
+		So(err, ShouldBeNil)
+
+		Convey("Dumping to file...", func() {
+			f, err := ioutil.TempFile("", "instTmp")
+			So(err, ShouldBeNil)
+			err = SerializeInstances(inst, f)
+			So(err, ShouldBeNil)
+			f.Seek(0, 0)
+			Convey("Contents of the archive should be right...", func() {
+				gzr, err := gzip.NewReader(f)
+				So(err, ShouldBeNil)
+				tr := tar.NewReader(gzr)
+				classAttrsPresent := false
+				manifestPresent := false
+				regularAttrsPresent := false
+				dataPresent := false
+				dimsPresent := false
+				readBytes := make([]byte, len([]byte(SerializationFormatVersion)))
+				for {
+					hdr, err := tr.Next()
+					if err == io.EOF {
+						break
+					}
+					So(err, ShouldBeNil)
+					switch hdr.Name {
+					case "MANIFEST":
+						// BUG FIX: tr.Read's result was previously ignored;
+						// a short read would leave readBytes partly zeroed
+						// and fail the MANIFEST comparison misleadingly.
+						_, err := io.ReadFull(tr, readBytes)
+						So(err, ShouldBeNil)
+						manifestPresent = true
+					case "CATTRS":
+						classAttrsPresent = true
+					case "ATTRS":
+						regularAttrsPresent = true
+					case "DATA":
+						dataPresent = true
+					case "DIMS":
+						dimsPresent = true
+					default:
+						fmt.Printf("Unknown file: %s\n", hdr.Name)
+					}
+				}
+				Convey("MANIFEST should be present", func() {
+					So(manifestPresent, ShouldBeTrue)
+					Convey("MANIFEST should be right...", func() {
+						So(readBytes,
ShouldResemble, []byte(SerializationFormatVersion)) + }) + }) + Convey("DATA should be present", func() { + So(dataPresent, ShouldBeTrue) + }) + Convey("ATTRS should be present", func() { + So(regularAttrsPresent, ShouldBeTrue) + }) + Convey("CATTRS should be present", func() { + So(classAttrsPresent, ShouldBeTrue) + }) + Convey("DIMS should be present", func() { + So(dimsPresent, ShouldBeTrue) + }) + }) + Convey("Should be able to reconstruct...", func() { + f.Seek(0, 0) + dinst, err := DeserializeInstances(f) + So(err, ShouldBeNil) + So(InstancesAreEqual(inst, dinst), ShouldBeTrue) + }) + }) + }) +} diff --git a/base/sort_test.go b/base/sort_test.go index 55b92b6..40df220 100644 --- a/base/sort_test.go +++ b/base/sort_test.go @@ -59,7 +59,7 @@ func TestSortDesc(t *testing.T) { }) Convey("Result should match the reference", func() { - So(sortedDescending.Equal(result), ShouldBeTrue) + So(InstancesAreEqual(sortedDescending, result), ShouldBeTrue) }) }) }) @@ -90,11 +90,11 @@ func TestSortAsc(t *testing.T) { }) Convey("Result should match the reference", func() { - So(sortedAscending.Equal(result), ShouldBeTrue) + So(InstancesAreEqual(sortedAscending, result), ShouldBeTrue) }) Convey("First element of Result should equal known value", func() { - So(result.RowString(0), ShouldEqual, "4.30 3.00 1.10 0.10 Iris-setosa") + So(result.RowString(0), ShouldEqual, "4.3 3.0 1.1 0.1 Iris-setosa") }) }) }) diff --git a/base/util_instances.go b/base/util_instances.go index 309e5f7..2def385 100644 --- a/base/util_instances.go +++ b/base/util_instances.go @@ -22,6 +22,37 @@ func GeneratePredictionVector(from FixedDataGrid) UpdatableDataGrid { return ret } +// CopyDenseInstancesStructure returns a new DenseInstances +// with identical structure (layout, Attributes) to the original +func CopyDenseInstances(template *DenseInstances, templateAttrs []Attribute) *DenseInstances { + instances := NewDenseInstances() + templateAgs := template.AllAttributeGroups() + for ag := range 
templateAgs { + agTemplate := templateAgs[ag] + if _, ok := agTemplate.(*BinaryAttributeGroup); ok { + instances.CreateAttributeGroup(ag, 0) + } else { + instances.CreateAttributeGroup(ag, 8) + } + } + + for _, a := range templateAttrs { + s, err := template.GetAttribute(a) + if err != nil { + panic(err) + } + if ag, ok := template.agRevMap[s.pond]; !ok { + panic(ag) + } else { + _, err := instances.AddAttributeToAttributeGroup(a, ag) + if err != nil { + panic(err) + } + } + } + return instances +} + // GetClass is a shortcut for returning the string value of the current // class on a given row. // @@ -470,3 +501,34 @@ func CheckStrictlyCompatible(s1 FixedDataGrid, s2 FixedDataGrid) bool { return true } + +// InstancesAreEqual checks whether a given Instance set is exactly +// the same as another (i.e. has the same size and values). +func InstancesAreEqual(inst, other FixedDataGrid) bool { + _, rows := inst.Size() + + for _, a := range inst.AllAttributes() { + as1, err := inst.GetAttribute(a) + if err != nil { + panic(err) // That indicates some kind of error + } + as2, err := inst.GetAttribute(a) + if err != nil { + return false // Obviously has different Attributes + } + + if !as1.GetAttribute().Equals(as2.GetAttribute()) { + return false + } + + for i := 0; i < rows; i++ { + b1 := inst.Get(as1, i) + b2 := inst.Get(as2, i) + if !byteSeqEqual(b1, b2) { + return false + } + } + } + + return true +} diff --git a/base/view_test.go b/base/view_test.go index 3d9860b..d28be6e 100644 --- a/base/view_test.go +++ b/base/view_test.go @@ -17,7 +17,7 @@ func TestInstancesViewRows(t *testing.T) { So(instView.rows[0], ShouldEqual, 5) }) Convey("The reconstructed values should be correct...", func() { - str := "5.40 3.90 1.70 0.40 Iris-setosa" + str := "5.4 3.9 1.7 0.4 Iris-setosa" row := instView.RowString(0) So(row, ShouldEqual, str) }) @@ -99,7 +99,7 @@ func TestInstancesViewAttrs(t *testing.T) { So(err, ShouldNotEqual, nil) }) Convey("The filtered Attribute should not appear 
in the RowString", func() { - str := "3.90 1.70 0.40 Iris-setosa" + str := "3.9 1.7 0.4 Iris-setosa" row := instView.RowString(5) So(row, ShouldEqual, str) }) diff --git a/examples/datasets/iris.arff b/examples/datasets/iris.arff new file mode 100644 index 0000000..780480c --- /dev/null +++ b/examples/datasets/iris.arff @@ -0,0 +1,225 @@ +% 1. Title: Iris Plants Database +% +% 2. Sources: +% (a) Creator: R.A. Fisher +% (b) Donor: Michael Marshall (MARSHALL%PLU@io.arc.nasa.gov) +% (c) Date: July, 1988 +% +% 3. Past Usage: +% - Publications: too many to mention!!! Here are a few. +% 1. Fisher,R.A. "The use of multiple measurements in taxonomic problems" +% Annual Eugenics, 7, Part II, 179-188 (1936); also in "Contributions +% to Mathematical Statistics" (John Wiley, NY, 1950). +% 2. Duda,R.O., & Hart,P.E. (1973) Pattern Classification and Scene Analysis. +% (Q327.D83) John Wiley & Sons. ISBN 0-471-22361-1. See page 218. +% 3. Dasarathy, B.V. (1980) "Nosing Around the Neighborhood: A New System +% Structure and Classification Rule for Recognition in Partially Exposed +% Environments". IEEE Transactions on Pattern Analysis and Machine +% Intelligence, Vol. PAMI-2, No. 1, 67-71. +% -- Results: +% -- very low misclassification rates (0% for the setosa class) +% 4. Gates, G.W. (1972) "The Reduced Nearest Neighbor Rule". IEEE +% Transactions on Information Theory, May 1972, 431-433. +% -- Results: +% -- very low misclassification rates again +% 5. See also: 1988 MLC Proceedings, 54-64. Cheeseman et al's AUTOCLASS II +% conceptual clustering system finds 3 classes in the data. +% +% 4. Relevant Information: +% --- This is perhaps the best known database to be found in the pattern +% recognition literature. Fisher's paper is a classic in the field +% and is referenced frequently to this day. (See Duda & Hart, for +% example.) The data set contains 3 classes of 50 instances each, +% where each class refers to a type of iris plant. 
One class is +% linearly separable from the other 2; the latter are NOT linearly +% separable from each other. +% --- Predicted attribute: class of iris plant. +% --- This is an exceedingly simple domain. +% +% 5. Number of Instances: 150 (50 in each of three classes) +% +% 6. Number of Attributes: 4 numeric, predictive attributes and the class +% +% 7. Attribute Information: +% 1. sepal length in cm +% 2. sepal width in cm +% 3. petal length in cm +% 4. petal width in cm +% 5. class: +% -- Iris Setosa +% -- Iris Versicolour +% -- Iris Virginica +% +% 8. Missing Attribute Values: None +% +% Summary Statistics: +% Min Max Mean SD Class Correlation +% sepal length: 4.3 7.9 5.84 0.83 0.7826 +% sepal width: 2.0 4.4 3.05 0.43 -0.4194 +% petal length: 1.0 6.9 3.76 1.76 0.9490 (high!) +% petal width: 0.1 2.5 1.20 0.76 0.9565 (high!) +% +% 9. Class Distribution: 33.3% for each of 3 classes. + +@RELATION iris + +@ATTRIBUTE sepallength REAL +@ATTRIBUTE sepalwidth REAL +@ATTRIBUTE petallength REAL +@ATTRIBUTE petalwidth REAL +@ATTRIBUTE class {Iris-setosa,Iris-versicolor,Iris-virginica} + +@DATA +5.1,3.5,1.4,0.2,Iris-setosa +4.9,3.0,1.4,0.2,Iris-setosa +4.7,3.2,1.3,0.2,Iris-setosa +4.6,3.1,1.5,0.2,Iris-setosa +5.0,3.6,1.4,0.2,Iris-setosa +5.4,3.9,1.7,0.4,Iris-setosa +4.6,3.4,1.4,0.3,Iris-setosa +5.0,3.4,1.5,0.2,Iris-setosa +4.4,2.9,1.4,0.2,Iris-setosa +4.9,3.1,1.5,0.1,Iris-setosa +5.4,3.7,1.5,0.2,Iris-setosa +4.8,3.4,1.6,0.2,Iris-setosa +4.8,3.0,1.4,0.1,Iris-setosa +4.3,3.0,1.1,0.1,Iris-setosa +5.8,4.0,1.2,0.2,Iris-setosa +5.7,4.4,1.5,0.4,Iris-setosa +5.4,3.9,1.3,0.4,Iris-setosa +5.1,3.5,1.4,0.3,Iris-setosa +5.7,3.8,1.7,0.3,Iris-setosa +5.1,3.8,1.5,0.3,Iris-setosa +5.4,3.4,1.7,0.2,Iris-setosa +5.1,3.7,1.5,0.4,Iris-setosa +4.6,3.6,1.0,0.2,Iris-setosa +5.1,3.3,1.7,0.5,Iris-setosa +4.8,3.4,1.9,0.2,Iris-setosa +5.0,3.0,1.6,0.2,Iris-setosa +5.0,3.4,1.6,0.4,Iris-setosa +5.2,3.5,1.5,0.2,Iris-setosa +5.2,3.4,1.4,0.2,Iris-setosa +4.7,3.2,1.6,0.2,Iris-setosa 
+4.8,3.1,1.6,0.2,Iris-setosa +5.4,3.4,1.5,0.4,Iris-setosa +5.2,4.1,1.5,0.1,Iris-setosa +5.5,4.2,1.4,0.2,Iris-setosa +4.9,3.1,1.5,0.1,Iris-setosa +5.0,3.2,1.2,0.2,Iris-setosa +5.5,3.5,1.3,0.2,Iris-setosa +4.9,3.1,1.5,0.1,Iris-setosa +4.4,3.0,1.3,0.2,Iris-setosa +5.1,3.4,1.5,0.2,Iris-setosa +5.0,3.5,1.3,0.3,Iris-setosa +4.5,2.3,1.3,0.3,Iris-setosa +4.4,3.2,1.3,0.2,Iris-setosa +5.0,3.5,1.6,0.6,Iris-setosa +5.1,3.8,1.9,0.4,Iris-setosa +4.8,3.0,1.4,0.3,Iris-setosa +5.1,3.8,1.6,0.2,Iris-setosa +4.6,3.2,1.4,0.2,Iris-setosa +5.3,3.7,1.5,0.2,Iris-setosa +5.0,3.3,1.4,0.2,Iris-setosa +7.0,3.2,4.7,1.4,Iris-versicolor +6.4,3.2,4.5,1.5,Iris-versicolor +6.9,3.1,4.9,1.5,Iris-versicolor +5.5,2.3,4.0,1.3,Iris-versicolor +6.5,2.8,4.6,1.5,Iris-versicolor +5.7,2.8,4.5,1.3,Iris-versicolor +6.3,3.3,4.7,1.6,Iris-versicolor +4.9,2.4,3.3,1.0,Iris-versicolor +6.6,2.9,4.6,1.3,Iris-versicolor +5.2,2.7,3.9,1.4,Iris-versicolor +5.0,2.0,3.5,1.0,Iris-versicolor +5.9,3.0,4.2,1.5,Iris-versicolor +6.0,2.2,4.0,1.0,Iris-versicolor +6.1,2.9,4.7,1.4,Iris-versicolor +5.6,2.9,3.6,1.3,Iris-versicolor +6.7,3.1,4.4,1.4,Iris-versicolor +5.6,3.0,4.5,1.5,Iris-versicolor +5.8,2.7,4.1,1.0,Iris-versicolor +6.2,2.2,4.5,1.5,Iris-versicolor +5.6,2.5,3.9,1.1,Iris-versicolor +5.9,3.2,4.8,1.8,Iris-versicolor +6.1,2.8,4.0,1.3,Iris-versicolor +6.3,2.5,4.9,1.5,Iris-versicolor +6.1,2.8,4.7,1.2,Iris-versicolor +6.4,2.9,4.3,1.3,Iris-versicolor +6.6,3.0,4.4,1.4,Iris-versicolor +6.8,2.8,4.8,1.4,Iris-versicolor +6.7,3.0,5.0,1.7,Iris-versicolor +6.0,2.9,4.5,1.5,Iris-versicolor +5.7,2.6,3.5,1.0,Iris-versicolor +5.5,2.4,3.8,1.1,Iris-versicolor +5.5,2.4,3.7,1.0,Iris-versicolor +5.8,2.7,3.9,1.2,Iris-versicolor +6.0,2.7,5.1,1.6,Iris-versicolor +5.4,3.0,4.5,1.5,Iris-versicolor +6.0,3.4,4.5,1.6,Iris-versicolor +6.7,3.1,4.7,1.5,Iris-versicolor +6.3,2.3,4.4,1.3,Iris-versicolor +5.6,3.0,4.1,1.3,Iris-versicolor +5.5,2.5,4.0,1.3,Iris-versicolor +5.5,2.6,4.4,1.2,Iris-versicolor +6.1,3.0,4.6,1.4,Iris-versicolor +5.8,2.6,4.0,1.2,Iris-versicolor 
+5.0,2.3,3.3,1.0,Iris-versicolor +5.6,2.7,4.2,1.3,Iris-versicolor +5.7,3.0,4.2,1.2,Iris-versicolor +5.7,2.9,4.2,1.3,Iris-versicolor +6.2,2.9,4.3,1.3,Iris-versicolor +5.1,2.5,3.0,1.1,Iris-versicolor +5.7,2.8,4.1,1.3,Iris-versicolor +6.3,3.3,6.0,2.5,Iris-virginica +5.8,2.7,5.1,1.9,Iris-virginica +7.1,3.0,5.9,2.1,Iris-virginica +6.3,2.9,5.6,1.8,Iris-virginica +6.5,3.0,5.8,2.2,Iris-virginica +7.6,3.0,6.6,2.1,Iris-virginica +4.9,2.5,4.5,1.7,Iris-virginica +7.3,2.9,6.3,1.8,Iris-virginica +6.7,2.5,5.8,1.8,Iris-virginica +7.2,3.6,6.1,2.5,Iris-virginica +6.5,3.2,5.1,2.0,Iris-virginica +6.4,2.7,5.3,1.9,Iris-virginica +6.8,3.0,5.5,2.1,Iris-virginica +5.7,2.5,5.0,2.0,Iris-virginica +5.8,2.8,5.1,2.4,Iris-virginica +6.4,3.2,5.3,2.3,Iris-virginica +6.5,3.0,5.5,1.8,Iris-virginica +7.7,3.8,6.7,2.2,Iris-virginica +7.7,2.6,6.9,2.3,Iris-virginica +6.0,2.2,5.0,1.5,Iris-virginica +6.9,3.2,5.7,2.3,Iris-virginica +5.6,2.8,4.9,2.0,Iris-virginica +7.7,2.8,6.7,2.0,Iris-virginica +6.3,2.7,4.9,1.8,Iris-virginica +6.7,3.3,5.7,2.1,Iris-virginica +7.2,3.2,6.0,1.8,Iris-virginica +6.2,2.8,4.8,1.8,Iris-virginica +6.1,3.0,4.9,1.8,Iris-virginica +6.4,2.8,5.6,2.1,Iris-virginica +7.2,3.0,5.8,1.6,Iris-virginica +7.4,2.8,6.1,1.9,Iris-virginica +7.9,3.8,6.4,2.0,Iris-virginica +6.4,2.8,5.6,2.2,Iris-virginica +6.3,2.8,5.1,1.5,Iris-virginica +6.1,2.6,5.6,1.4,Iris-virginica +7.7,3.0,6.1,2.3,Iris-virginica +6.3,3.4,5.6,2.4,Iris-virginica +6.4,3.1,5.5,1.8,Iris-virginica +6.0,3.0,4.8,1.8,Iris-virginica +6.9,3.1,5.4,2.1,Iris-virginica +6.7,3.1,5.6,2.4,Iris-virginica +6.9,3.1,5.1,2.3,Iris-virginica +5.8,2.7,5.1,1.9,Iris-virginica +6.8,3.2,5.9,2.3,Iris-virginica +6.7,3.3,5.7,2.5,Iris-virginica +6.7,3.0,5.2,2.3,Iris-virginica +6.3,2.5,5.0,1.9,Iris-virginica +6.5,3.0,5.2,2.0,Iris-virginica +6.2,3.4,5.4,2.3,Iris-virginica +5.9,3.0,5.1,1.8,Iris-virginica +% +% +% diff --git a/examples/datasets/weather.arff b/examples/datasets/weather.arff new file mode 100644 index 0000000..9426e3a --- /dev/null +++ 
b/examples/datasets/weather.arff @@ -0,0 +1,23 @@ +@relation weather + +@attribute outlook {sunny, overcast, rainy} +@attribute temperature real +@attribute humidity real +@attribute windy {TRUE, FALSE} +@attribute play {yes, no} + +@data +sunny,85,85,FALSE,no +sunny,80,90,TRUE,no +overcast,83,86,FALSE,yes +rainy,70,96,FALSE,yes +rainy,68,80,FALSE,yes +rainy,65,70,TRUE,no +overcast,64,65,TRUE,yes +sunny,72,95,FALSE,no +sunny,69,70,FALSE,yes +rainy,75,80,FALSE,yes +sunny,75,70,TRUE,yes +overcast,72,90,TRUE,yes +overcast,81,75,FALSE,yes +rainy,71,91,TRUE,no diff --git a/examples/serialization/attributes.go b/examples/serialization/attributes.go new file mode 100644 index 0000000..a827fbc --- /dev/null +++ b/examples/serialization/attributes.go @@ -0,0 +1,32 @@ +// Demonstrates decision tree classification + +package main + +import ( + "encoding/json" + "fmt" + "github.com/sjwhitworth/golearn/base" +) + +func main() { + + // Load in the iris dataset + iris, err := base.ParseCSVToInstances("../datasets/iris_headers.csv", true) + if err != nil { + panic(err) + } + + for _, a := range iris.AllAttributes() { + var ac base.CategoricalAttribute + var af base.FloatAttribute + s, err := json.Marshal(a) + if err != nil { + panic(err) + } + fmt.Println(string(s)) + err = json.Unmarshal(s, &af) + fmt.Println(af.String()) + err = json.Unmarshal(s, &ac) + fmt.Println(ac.String()) + } +} diff --git a/linear_models/linear_models_test.go b/linear_models/linear_models_test.go index 1992ae8..2a6cc41 100644 --- a/linear_models/linear_models_test.go +++ b/linear_models/linear_models_test.go @@ -24,14 +24,14 @@ func TestLogisticRegression(t *testing.T) { Z, err := lr.Predict(Y) So(err, ShouldEqual, nil) Convey("The result should be 1", func() { - So(Z.RowString(0), ShouldEqual, "1.00") + So(Z.RowString(0), ShouldEqual, "1.0") }) }) Convey("When predicting the label of second vector", func() { Z, err := lr.Predict(Y) So(err, ShouldEqual, nil) Convey("The result should be -1", func() { - 
So(Z.RowString(1), ShouldEqual, "-1.00") + So(Z.RowString(1), ShouldEqual, "-1.0") }) }) })