1
0
mirror of https://github.com/sjwhitworth/golearn.git synced 2025-04-26 13:49:14 +08:00

Merge pull request #101 from Sentimentron/arff-staging

ARFF import/export, CSV export, lossless serialisation
This commit is contained in:
Stephen Whitworth 2014-11-21 13:53:43 +00:00
commit 7ea42ac80b
19 changed files with 1440 additions and 188 deletions

292
base/arff.go Normal file
View File

@ -0,0 +1,292 @@
package base
import (
"bufio"
"bytes"
"encoding/csv"
"fmt"
"io"
"os"
"runtime"
"strings"
)
// SerializeInstancesToDenseARFF writes the given FixedDataGrid to a
// densely-formatted ARFF file.
func SerializeInstancesToDenseARFF(inst FixedDataGrid, path, relation string) error {
	// Column order: non-class Attributes first, class Attributes last.
	attrs := append(NonClassAttributes(inst), inst.AllClassAttributes()...)
	return SerializeInstancesToDenseARFFWithAttributes(inst, attrs, path, relation)
}
// SerializeInstancesToDenseARFFWithAttributes writes the given FixedDataGrid to a
// densely-formatted ARFF file with the header Attributes in the order given.
func SerializeInstancesToDenseARFFWithAttributes(inst FixedDataGrid, rawAttrs []Attribute, path, relation string) error {
	// Open the output file. BUG FIX: the original used os.O_RDWR alone,
	// which fails if the file doesn't exist and leaves stale trailing
	// bytes if it does; create and truncate instead.
	f, err := os.OpenFile(path, os.O_RDWR|os.O_CREATE|os.O_TRUNC, 0600)
	if err != nil {
		return err
	}
	defer f.Close()
	// Write @relation header (write errors were previously ignored)
	if _, err := fmt.Fprintf(f, "@relation %s\n\n", relation); err != nil {
		return err
	}
	// Get all Attribute specifications
	attrs := ResolveAttributes(inst, rawAttrs)
	// Write Attribute information
	for _, s := range attrs {
		attr := s.attr
		t := "real"
		if a, ok := attr.(*CategoricalAttribute); ok {
			// Categorical Attributes list their values in braces
			t = fmt.Sprintf("{%s}", strings.Join(a.GetValues(), ", "))
		}
		if _, err := fmt.Fprintf(f, "@attribute %s %s\n", attr.GetName(), t); err != nil {
			return err
		}
	}
	if _, err := f.WriteString("\n@data\n"); err != nil {
		return err
	}
	// Emit one comma-separated line per row, reusing a scratch buffer
	buf := make([]string, len(attrs))
	return inst.MapOverRows(attrs, func(val [][]byte, row int) (bool, error) {
		for i, v := range val {
			buf[i] = attrs[i].attr.GetStringFromSysVal(v)
		}
		if _, err := fmt.Fprintf(f, "%s\n", strings.Join(buf, ",")); err != nil {
			return false, err
		}
		return true, nil
	})
}
// ParseARFFGetRows returns the number of data rows in an ARFF file.
// Blank lines, directives ('@') and comments ('%') inside the data
// section are not counted.
func ParseARFFGetRows(filepath string) (int, error) {
	f, err := os.Open(filepath)
	if err != nil {
		return 0, err
	}
	defer f.Close()
	counting := false // set once the @data marker has been seen
	count := 0
	scanner := bufio.NewScanner(f)
	for scanner.Scan() {
		line := scanner.Text()
		if len(line) == 0 {
			continue
		}
		if counting {
			// Skip directives and comments within the data section
			if line[0] == '@' || line[0] == '%' {
				continue
			}
			count++
			continue
		}
		// Case-insensitive match for the @data marker
		if line[0] == '@' && strings.ToLower(line) == "@data" {
			counting = true
		}
	}
	// BUG FIX: propagate scanner read errors (previously ignored)
	if err := scanner.Err(); err != nil {
		return 0, err
	}
	return count, nil
}
// ParseARFFGetAttributes returns the set of Attributes represented in this ARFF
// header. "real" fields become FloatAttributes; "{a, b, c}" fields become
// CategoricalAttributes. Panics on unsupported or malformed declarations.
func ParseARFFGetAttributes(filepath string) []Attribute {
	var ret []Attribute
	f, err := os.Open(filepath)
	if err != nil {
		panic(err)
	}
	defer f.Close()
	scanner := bufio.NewScanner(f)
	for scanner.Scan() {
		var attr Attribute
		line := scanner.Text()
		if len(line) == 0 {
			continue
		}
		// Only directive lines can declare Attributes
		if line[0] != '@' {
			continue
		}
		fields := strings.Fields(line)
		// An @attribute line has at least: keyword, name, type
		if len(fields) < 3 {
			continue
		}
		fields[0] = strings.ToLower(fields[0])
		attrType := strings.ToLower(fields[2])
		if fields[0] != "@attribute" {
			continue
		}
		switch attrType {
		case "real":
			attr = new(FloatAttribute)
			break
		default:
			// Anything else must be a categorical specification {a, b, c}
			if fields[2][0] == '{' {
				if strings.HasSuffix(fields[len(fields)-1], "}") {
					var cats []string
					if len(fields) > 3 {
						// Values were split on whitespace by Fields above
						cats = fields[2:len(fields)]
					} else {
						// Single token: values separated only by commas
						cats = strings.Split(fields[2], ",")
					}
					if len(cats) == 0 {
						panic(fmt.Errorf("Empty categorical field on line '%s'", line))
					}
					cats[0] = cats[0][1:] // Remove leading '{'
					cats[len(cats)-1] = cats[len(cats)-1][:len(cats[len(cats)-1])-1] // Remove trailing '}'
					for i, v := range cats { // Miaow
						cats[i] = strings.TrimSpace(v)
						if strings.HasSuffix(cats[i], ",") {
							// Strip end comma
							cats[i] = cats[i][0 : len(cats[i])-1]
						}
					}
					attr = NewCategoricalAttribute()
					for _, v := range cats {
						// Registers v as a legal category value
						attr.GetSysValFromString(v)
					}
				} else {
					panic(fmt.Errorf("Missing categorical bracket on line '%s'", line))
				}
			} else {
				panic(fmt.Errorf("Unsupported Attribute type %s on line '%s'", fields[2], line))
			}
		}
		if attr == nil {
			panic(fmt.Errorf(line))
		}
		attr.SetName(fields[1])
		ret = append(ret, attr)
	}
	// Float values print with the precision they were read with, so
	// estimate it from the first few lines of the file.
	maxPrecision, err := ParseCSVEstimateFilePrecision(filepath)
	if err != nil {
		panic(err)
	}
	for _, a := range ret {
		if f, ok := a.(*FloatAttribute); ok {
			f.Precision = maxPrecision
		}
	}
	return ret
}
// ParseDenseARFFBuildInstancesFromReader updates an [[#UpdatableDataGrid]] from a io.Reader
// containing dense ARFF data. Conversion panics are reported as errors
// naming the offending row.
func ParseDenseARFFBuildInstancesFromReader(r io.Reader, attrs []Attribute, u UpdatableDataGrid) (err error) {
	var rowCounter int
	defer func() {
		if r := recover(); r != nil {
			if _, ok := r.(runtime.Error); ok {
				// BUG FIX: re-panic with the recovered value. The original
				// panicked with the named return err, which is still nil
				// at this point, losing the real runtime error.
				panic(r)
			}
			err = fmt.Errorf("Error at line %d (error %s)", rowCounter, r.(error))
		}
	}()
	scanner := bufio.NewScanner(r)
	reading := false
	specs := ResolveAttributes(u, attrs)
	for scanner.Scan() {
		line := scanner.Text()
		// '%' lines are comments
		if strings.HasPrefix(line, "%") {
			continue
		}
		if reading {
			// Each data line is itself a tiny CSV record
			buf := bytes.NewBuffer([]byte(line))
			reader := csv.NewReader(buf)
			for {
				r, err := reader.Read()
				if err == io.EOF {
					break
				} else if err != nil {
					return err
				}
				for i, v := range r {
					v = strings.TrimSpace(v)
					// Reject values outside the declared category set rather
					// than silently extending it during the Set call below
					if a, ok := specs[i].attr.(*CategoricalAttribute); ok {
						if val := a.GetSysVal(v); val == nil {
							panic(fmt.Errorf("Unexpected class on line '%s'", line))
						}
					}
					u.Set(specs[i], rowCounter, specs[i].attr.GetSysValFromString(v))
				}
				rowCounter++
			}
		} else {
			// Still in the header: watch for the @data marker
			line = strings.ToLower(line)
			line = strings.TrimSpace(line)
			if line == "@data" {
				reading = true
			}
		}
	}
	return nil
}
// ParseDenseARFFToInstances parses the dense ARFF File into a FixedDataGrid
func ParseDenseARFFToInstances(filepath string) (ret *DenseInstances, err error) {
	// Convert non-runtime panics raised below into the error return;
	// genuine runtime errors are re-raised.
	defer func() {
		if r := recover(); r != nil {
			if _, ok := r.(runtime.Error); ok {
				panic(r)
			}
			err = r.(error)
		}
	}()
	// Count the data rows so storage can be allocated up front
	rows, err := ParseARFFGetRows(filepath)
	if err != nil {
		return nil, err
	}
	// Read the header Attributes
	attrs := ParseARFFGetAttributes(filepath)
	// Build the instance set: add every Attribute, marking the last
	// one as the class, then allocate row storage
	ret = NewDenseInstances()
	for _, a := range attrs {
		ret.AddAttribute(a)
	}
	ret.AddClassAttribute(attrs[len(attrs)-1])
	ret.Extend(rows)
	// Re-open the file and stream the data section into place
	f, err := os.Open(filepath)
	if err != nil {
		return nil, err
	}
	defer f.Close()
	if err = ParseDenseARFFBuildInstancesFromReader(f, attrs, ret); err != nil {
		ret = nil
	}
	return ret, err
}

109
base/arff_test.go Normal file
View File

@ -0,0 +1,109 @@
package base
import (
. "github.com/smartystreets/goconvey/convey"
"io/ioutil"
"testing"
)
// TestParseARFFGetRows checks that header and comment lines of a
// reference ARFF file are excluded from the reported row count.
func TestParseARFFGetRows(t *testing.T) {
	Convey("Getting the number of rows for a ARFF file", t, func() {
		Convey("With a valid file path", func() {
			numNonHeaderRows := 150 // iris has 150 data rows after @data
			lineCount, err := ParseARFFGetRows("../examples/datasets/iris.arff")
			So(err, ShouldBeNil)
			So(lineCount, ShouldEqual, numNonHeaderRows)
		})
	})
}
// TestParseARFFGetAttributes checks the Attribute types and the
// categorical class values read from the iris ARFF header.
func TestParseARFFGetAttributes(t *testing.T) {
	// BUG FIX: the description previously said "CSV file" — this test
	// reads an ARFF header.
	Convey("Getting the attributes in the headers of an ARFF file", t, func() {
		attributes := ParseARFFGetAttributes("../examples/datasets/iris.arff")
		sepalLengthAttribute := attributes[0]
		sepalWidthAttribute := attributes[1]
		petalLengthAttribute := attributes[2]
		petalWidthAttribute := attributes[3]
		speciesAttribute := attributes[4]
		Convey("It gets the correct types for the headers based on the column values", func() {
			// The four measurements are numeric; the class is categorical
			_, ok1 := sepalLengthAttribute.(*FloatAttribute)
			_, ok2 := sepalWidthAttribute.(*FloatAttribute)
			_, ok3 := petalLengthAttribute.(*FloatAttribute)
			_, ok4 := petalWidthAttribute.(*FloatAttribute)
			sA, ok5 := speciesAttribute.(*CategoricalAttribute)
			So(ok1, ShouldBeTrue)
			So(ok2, ShouldBeTrue)
			So(ok3, ShouldBeTrue)
			So(ok4, ShouldBeTrue)
			So(ok5, ShouldBeTrue)
			So(sA.GetValues(), ShouldResemble, []string{"Iris-setosa", "Iris-versicolor", "Iris-virginica"})
		})
	})
}
// TestParseARFF1 loads the iris dataset from ARFF and spot-checks one
// row from each species block.
func TestParseARFF1(t *testing.T) {
	Convey("Should just be able to load in an ARFF...", t, func() {
		inst, err := ParseDenseARFFToInstances("../examples/datasets/iris.arff")
		So(err, ShouldBeNil)
		So(inst, ShouldNotBeNil)
		// One representative row per class (rows 0, 50, 100)
		So(inst.RowString(0), ShouldEqual, "5.1 3.5 1.4 0.2 Iris-setosa")
		So(inst.RowString(50), ShouldEqual, "7.0 3.2 4.7 1.4 Iris-versicolor")
		So(inst.RowString(100), ShouldEqual, "6.3 3.3 6.0 2.5 Iris-virginica")
	})
}
// TestParseARFF2 loads the weather dataset and verifies that the parsed
// categorical Attributes preserve the declared value order.
func TestParseARFF2(t *testing.T) {
	Convey("Loading the weather dataset...", t, func() {
		inst, err := ParseDenseARFFToInstances("../examples/datasets/weather.arff")
		So(err, ShouldBeNil)
		Convey("Attributes should be right...", func() {
			So(GetAttributeByName(inst, "outlook"), ShouldNotBeNil)
			So(GetAttributeByName(inst, "temperature"), ShouldNotBeNil)
			So(GetAttributeByName(inst, "humidity"), ShouldNotBeNil)
			So(GetAttributeByName(inst, "windy"), ShouldNotBeNil)
			So(GetAttributeByName(inst, "play"), ShouldNotBeNil)
			Convey("outlook attribute values should match reference...", func() {
				outlookAttr := GetAttributeByName(inst, "outlook").(*CategoricalAttribute)
				So(outlookAttr.GetValues(), ShouldResemble, []string{"sunny", "overcast", "rainy"})
			})
			Convey("windy values should match reference...", func() {
				windyAttr := GetAttributeByName(inst, "windy").(*CategoricalAttribute)
				So(windyAttr.GetValues(), ShouldResemble, []string{"TRUE", "FALSE"})
			})
			Convey("play values should match reference...", func() {
				playAttr := GetAttributeByName(inst, "play").(*CategoricalAttribute)
				So(playAttr.GetValues(), ShouldResemble, []string{"yes", "no"})
			})
		})
	})
}
// TestSerializeToARFF round-trips the weather dataset through the dense
// ARFF writer and checks both semantic and byte-level equality.
func TestSerializeToARFF(t *testing.T) {
	Convey("Loading the weather dataset...", t, func() {
		inst, err := ParseDenseARFFToInstances("../examples/datasets/weather.arff")
		So(err, ShouldBeNil)
		// BUG FIX: description typo "suceed" -> "succeed"
		Convey("Saving back should succeed...", func() {
			attrs := ParseARFFGetAttributes("../examples/datasets/weather.arff")
			f, err := ioutil.TempFile("", "inst")
			So(err, ShouldBeNil)
			err = SerializeInstancesToDenseARFFWithAttributes(inst, attrs, f.Name(), "weather")
			So(err, ShouldBeNil)
			Convey("Reading the file back should be lossless...", func() {
				inst2, err := ParseDenseARFFToInstances(f.Name())
				So(err, ShouldBeNil)
				So(InstancesAreEqual(inst, inst2), ShouldBeTrue)
			})
			Convey("The file should be exactly the same as the original...", func() {
				ref, err := ioutil.ReadFile("../examples/datasets/weather.arff")
				So(err, ShouldBeNil)
				gen, err := ioutil.ReadFile(f.Name())
				So(err, ShouldBeNil)
				So(string(gen), ShouldEqual, string(ref))
			})
		})
	})
}

View File

@ -1,5 +1,9 @@
package base
import (
"encoding/json"
)
const (
// CategoricalType is for Attributes which represent values distinctly.
CategoricalType = iota
@ -10,6 +14,8 @@ const (
// Attributes disambiguate columns of the feature matrix and declare their types.
type Attribute interface {
json.Unmarshaler
json.Marshaler
// Returns the general characterstics of this Attribute .
// to avoid the overhead of casting
GetType() int

View File

@ -1,6 +1,7 @@
package base
import (
"encoding/json"
"fmt"
"strconv"
)
@ -10,6 +11,20 @@ type BinaryAttribute struct {
Name string
}
// MarshalJSON returns a JSON version of this BinaryAttribute for serialisation.
func (b *BinaryAttribute) MarshalJSON() ([]byte, error) {
	// A BinaryAttribute carries no state beyond its type tag and name.
	desc := map[string]interface{}{
		"type": "binary",
		"name": b.Name,
	}
	return json.Marshal(desc)
}
// UnmarshalJSON unpacks a BinaryAttribute from serialisation.
// Usually, there's nothing to deserialize: the name is restored
// separately by the caller (via SetName) and there is no other state.
func (b *BinaryAttribute) UnmarshalJSON(data []byte) error {
	return nil
}
// NewBinaryAttribute creates a BinaryAttribute with the given name
func NewBinaryAttribute(name string) *BinaryAttribute {
return &BinaryAttribute{

View File

@ -1,6 +1,7 @@
package base
import (
"encoding/json"
"fmt"
)
@ -9,7 +10,31 @@ import (
// - useful for representing classes.
type CategoricalAttribute struct {
Name string
values []string
values []string `json:"values"`
}
// MarshalJSON returns a JSON version of this Attribute.
func (Attr *CategoricalAttribute) MarshalJSON() ([]byte, error) {
	// The category list travels in a nested "attr" payload so the
	// deserializer can decode it separately from the type tag.
	payload := map[string]interface{}{
		"values": Attr.values,
	}
	return json.Marshal(map[string]interface{}{
		"type": "categorical",
		"name": Attr.Name,
		"attr": payload,
	})
}
// UnmarshalJSON restores this Attribute's category list from its
// serialised form (the "attr" payload written by MarshalJSON).
func (Attr *CategoricalAttribute) UnmarshalJSON(data []byte) error {
	var d map[string]interface{}
	err := json.Unmarshal(data, &d)
	if err != nil {
		return err
	}
	// BUG FIX: the original asserted d["values"].([]interface{}) without
	// a check, panicking when the key was missing or not an array.
	rawValues, ok := d["values"].([]interface{})
	if !ok {
		return fmt.Errorf("values must be specified")
	}
	for _, v := range rawValues {
		s, ok := v.(string)
		if !ok {
			return fmt.Errorf("values must be strings")
		}
		Attr.values = append(Attr.values, s)
	}
	return nil
}
// NewCategoricalAttribute creates a blank CategoricalAttribute.

View File

@ -1,11 +1,13 @@
package base
import (
"bufio"
"encoding/csv"
"fmt"
"io"
"os"
"regexp"
"runtime"
"strings"
)
@ -31,6 +33,52 @@ func ParseCSVGetRows(filepath string) (int, error) {
return counter, nil
}
// ParseCSVEstimateFilePrecision estimates the maximum number of digits
// occurring anywhere after the decimal point within the file. Only the
// first few lines are sampled; ARFF-style directive ('@') and comment
// ('%') lines are skipped.
func ParseCSVEstimateFilePrecision(filepath string) (int, error) {
	// Create a basic regexp matching a number with an optional
	// fractional part. BUG FIX: the '.' is now escaped — the original
	// bare '.' matched any character, so e.g. "1 23.456" was consumed
	// as a single match and its fractional digits went uncounted.
	rexp := regexp.MustCompile(`[0-9]+(\.[0-9]+)?`)
	// Open the source file
	f, err := os.Open(filepath)
	if err != nil {
		return 0, err
	}
	defer f.Close()
	// Scan through the file line-by-line
	maxL := 0
	scanner := bufio.NewScanner(f)
	lineCount := 0
	for scanner.Scan() {
		// Sampling the first few data lines is enough for an estimate
		if lineCount > 5 {
			break
		}
		line := scanner.Text()
		if len(line) == 0 {
			continue
		}
		// Skip ARFF directives and comments
		if line[0] == '@' || line[0] == '%' {
			continue
		}
		for _, m := range rexp.FindAllString(line, -1) {
			// Count digits after the decimal point, if present
			p := strings.Split(m, ".")
			if len(p) == 2 {
				if l := len(p[1]); l > maxL {
					maxL = l
				}
			}
		}
		lineCount++
	}
	// BUG FIX: propagate scanner read errors (previously ignored)
	if err := scanner.Err(); err != nil {
		return 0, err
	}
	return maxL, nil
}
// ParseCSVGetAttributes returns an ordered slice of appropriate-ly typed
// and named Attributes.
func ParseCSVGetAttributes(filepath string, hasHeaders bool) []Attribute {
@ -112,49 +160,69 @@ func ParseCSVSniffAttributeTypes(filepath string, hasHeaders bool) []Attribute {
}
}
return attrs
}
// ParseCSVBuildInstances updates an [[#UpdatableDataGrid]] from a filepath in place
func ParseCSVBuildInstances(filepath string, hasHeaders bool, u UpdatableDataGrid) {
// Read the input
file, err := os.Open(filepath)
// Estimate file precision
maxP, err := ParseCSVEstimateFilePrecision(filepath)
if err != nil {
panic(err)
}
defer file.Close()
reader := csv.NewReader(file)
for _, a := range attrs {
if f, ok := a.(*FloatAttribute); ok {
f.Precision = maxP
}
}
rowCounter := 0
return attrs
}
specs := ResolveAttributes(u, u.AllAttributes())
// ParseCSVBuildInstancesFromReader updates an [[#UpdatableDataGrid]] from a io.Reader
func ParseCSVBuildInstancesFromReader(r io.Reader, attrs []Attribute, hasHeader bool, u UpdatableDataGrid) (err error) {
var rowCounter int
defer func() {
if r := recover(); r != nil {
if _, ok := r.(runtime.Error); ok {
panic(err)
}
err = fmt.Errorf("Error at line %d (error %s)", rowCounter, r.(error))
}
}()
specs := ResolveAttributes(u, attrs)
reader := csv.NewReader(r)
for {
record, err := reader.Read()
if err == io.EOF {
break
} else if err != nil {
panic(err)
return err
}
if rowCounter == 0 {
if hasHeaders {
hasHeaders = false
if hasHeader {
hasHeader = false
continue
}
}
for i, v := range record {
u.Set(specs[i], rowCounter, specs[i].attr.GetSysValFromString(v))
u.Set(specs[i], rowCounter, specs[i].attr.GetSysValFromString(strings.TrimSpace(v)))
}
rowCounter++
}
return nil
}
// ParseCSVToInstances reads the CSV file given by filepath and returns
// the read Instances.
func ParseCSVToInstances(filepath string, hasHeaders bool) (instances *DenseInstances, err error) {
// Open the file
f, err := os.Open(filepath)
if err != nil {
return nil, err
}
defer f.Close()
// Read the number of rows in the file
rowCount, err := ParseCSVGetRows(filepath)
if err != nil {
@ -176,45 +244,41 @@ func ParseCSVToInstances(filepath string, hasHeaders bool) (instances *DenseInst
}
instances.Extend(rowCount)
// Read the input
file, err := os.Open(filepath)
err = ParseCSVBuildInstancesFromReader(f, attrs, hasHeaders, instances)
if err != nil {
return nil, err
}
defer file.Close()
reader := csv.NewReader(file)
rowCounter := 0
for {
record, err := reader.Read()
if err == io.EOF {
break
} else if err != nil {
return nil, err
}
if rowCounter == 0 {
if hasHeaders {
hasHeaders = false
continue
}
}
for i, v := range record {
v = strings.Trim(v, " ")
instances.Set(specs[i], rowCounter, attrs[i].GetSysValFromString(v))
}
rowCounter++
}
instances.AddClassAttribute(attrs[len(attrs)-1])
return instances, nil
}
// ParseMatchAttributes tries to match the set of Attributes read from one
// file with those read from another, and writes the matching Attributes
// back to the original set.
func ParseMatchAttributes(attrs, templateAttrs []Attribute) {
	for i, a := range attrs {
		// Adopt the template's Attribute when it is equal to, or merely
		// shares a name with, the one read from the file. Comparisons use
		// the original element (a), so later template matches can still
		// overwrite earlier ones, as before.
		for _, tmpl := range templateAttrs {
			if a.Equals(tmpl) || a.GetName() == tmpl.GetName() {
				attrs[i] = tmpl
			}
		}
	}
}
// ParseCSVToInstancesTemplated reads the CSV file given by filepath and returns
// the read Instances, using another already read DenseInstances as a template.
func ParseCSVToTemplatedInstances(filepath string, hasHeaders bool, template *DenseInstances) (instances *DenseInstances, err error) {
// Open the file
f, err := os.Open(filepath)
if err != nil {
return nil, err
}
defer f.Close()
// Read the number of rows in the file
rowCount, err := ParseCSVGetRows(filepath)
if err != nil {
@ -228,81 +292,16 @@ func ParseCSVToTemplatedInstances(filepath string, hasHeaders bool, template *De
// Read the row headers
attrs := ParseCSVGetAttributes(filepath, hasHeaders)
templateAttrs := template.AllAttributes()
for i, a := range attrs {
for _, b := range templateAttrs {
if a.Equals(b) {
attrs[i] = b
} else if a.GetName() == b.GetName() {
attrs[i] = b
}
}
}
ParseMatchAttributes(attrs, templateAttrs)
specs := make([]AttributeSpec, len(attrs))
// Allocate the Instances to return
instances = NewDenseInstances()
templateAgs := template.AllAttributeGroups()
for ag := range templateAgs {
agTemplate := templateAgs[ag]
if _, ok := agTemplate.(*BinaryAttributeGroup); ok {
instances.CreateAttributeGroup(ag, 0)
} else {
instances.CreateAttributeGroup(ag, 8)
}
}
for i, a := range templateAttrs {
s, err := template.GetAttribute(a)
if err != nil {
panic(err)
}
if ag, ok := template.agRevMap[s.pond]; !ok {
panic(ag)
} else {
spec, err := instances.AddAttributeToAttributeGroup(a, ag)
if err != nil {
panic(err)
}
specs[i] = spec
}
}
instances = CopyDenseInstances(template, templateAttrs)
instances.Extend(rowCount)
// Read the input
file, err := os.Open(filepath)
err = ParseCSVBuildInstancesFromReader(f, attrs, hasHeaders, instances)
if err != nil {
return nil, err
}
defer file.Close()
reader := csv.NewReader(file)
rowCounter := 0
for {
record, err := reader.Read()
if err == io.EOF {
break
} else if err != nil {
return nil, err
}
if rowCounter == 0 {
if hasHeaders {
hasHeaders = false
continue
}
}
for i, v := range record {
v = strings.Trim(v, " ")
instances.Set(specs[i], rowCounter, attrs[i].GetSysValFromString(v))
}
rowCounter++
}
for _, a := range template.AllClassAttributes() {
instances.AddClassAttribute(a)
}
return instances, nil
}
@ -312,6 +311,13 @@ func ParseCSVToTemplatedInstances(filepath string, hasHeaders bool, template *De
// specified in the first argument and also any class Attributes specified in the second
func ParseCSVToInstancesWithAttributeGroups(filepath string, attrGroups, classAttrGroups map[string]string, attrOverrides map[int]Attribute, hasHeaders bool) (instances *DenseInstances, err error) {
// Open file
f, err := os.Open(filepath)
if err != nil {
return nil, err
}
defer f.Close()
// Read row count
rowCount, err := ParseCSVGetRows(filepath)
if err != nil {
@ -386,45 +392,11 @@ func ParseCSVToInstancesWithAttributeGroups(filepath string, attrGroups, classAt
// Allocate
instances.Extend(rowCount)
// Read the input
file, err := os.Open(filepath)
err = ParseCSVBuildInstancesFromReader(f, attrs, hasHeaders, instances)
if err != nil {
return nil, err
}
defer file.Close()
reader := csv.NewReader(file)
rowCounter := 0
for {
record, err := reader.Read()
if err == io.EOF {
break
} else if err != nil {
return nil, err
}
if rowCounter == 0 {
// Skip header row
rowCounter++
continue
}
for i, v := range record {
v = strings.Trim(v, " ")
instances.Set(specs[i], rowCounter, attrs[i].GetSysValFromString(v))
}
rowCounter++
}
// Add class Attributes
for _, a := range instances.AllAttributes() {
name := a.GetName() // classAttrGroups
if _, ok := classAttrGroups[name]; ok {
err = instances.AddClassAttribute(a)
if err != nil {
panic(err)
}
}
}
return instances, nil
}

View File

@ -93,12 +93,17 @@ func TestParseCSVToInstances(t *testing.T) {
So(err, ShouldBeNil)
Convey("Should parse the rows correctly", func() {
So(instances.RowString(0), ShouldEqual, "5.10 3.50 1.40 0.20 Iris-setosa")
So(instances.RowString(50), ShouldEqual, "7.00 3.20 4.70 1.40 Iris-versicolor")
So(instances.RowString(100), ShouldEqual, "6.30 3.30 6.00 2.50 Iris-virginica")
So(instances.RowString(0), ShouldEqual, "5.1 3.5 1.4 0.2 Iris-setosa")
So(instances.RowString(50), ShouldEqual, "7.0 3.2 4.7 1.4 Iris-versicolor")
So(instances.RowString(100), ShouldEqual, "6.3 3.3 6.0 2.5 Iris-virginica")
})
})
Convey("Given a path to another reasonable CSV file", func() {
_, err := ParseCSVToInstances("../examples/datasets/c45-numeric.csv", true)
So(err, ShouldBeNil)
})
Convey("Given a path to a non-existent file", func() {
_, err := ParseCSVToInstances("../examples/datasets/non-existent.csv", true)

View File

@ -443,35 +443,6 @@ func (inst *DenseInstances) swapRows(i, j int) {
}
}
// Equal checks whether a given Instance set is exactly the same
// as another: same size and same values (as determined by the Attributes)
//
// IMPORTANT: does not explicitly check if the Attributes are considered equal.
func (inst *DenseInstances) Equal(other DataGrid) bool {
	// Cell values can only be read from a fixed-size grid
	fixed, ok := other.(FixedDataGrid)
	if !ok {
		return false
	}
	// BUG FIX: the original resolved both AttributeSpecs and read both
	// cell values from inst, so it compared the receiver with itself and
	// could never detect a difference. It also never compared row counts.
	_, rows := inst.Size()
	_, otherRows := fixed.Size()
	if rows != otherRows {
		return false
	}
	for _, a := range inst.AllAttributes() {
		as1, err := inst.GetAttribute(a)
		if err != nil {
			panic(err) // That indicates some kind of error
		}
		as2, err := fixed.GetAttribute(a)
		if err != nil {
			return false // Obviously has different Attributes
		}
		for i := 0; i < rows; i++ {
			b1 := inst.Get(as1, i)
			b2 := fixed.Get(as2, i)
			if !byteSeqEqual(b1, b2) {
				return false
			}
		}
	}
	return true
}
// String returns a human-readable summary of this dataset.
func (inst *DenseInstances) String() string {
var buffer bytes.Buffer

View File

@ -1,6 +1,7 @@
package base
import (
"encoding/json"
"fmt"
"strconv"
)
@ -12,6 +13,32 @@ type FloatAttribute struct {
Precision int
}
// MarshalJSON returns a JSON representation of this Attribute
// for serialisation.
func (f *FloatAttribute) MarshalJSON() ([]byte, error) {
	// Precision travels in a nested "attr" payload so the deserializer
	// can decode it separately from the type tag.
	payload := map[string]interface{}{
		"precision": f.Precision,
	}
	return json.Marshal(map[string]interface{}{
		"type": "float",
		"name": f.Name,
		"attr": payload,
	})
}
// UnmarshalJSON reads a JSON representation of this Attribute,
// restoring the stored precision.
func (f *FloatAttribute) UnmarshalJSON(data []byte) error {
	var d map[string]interface{}
	err := json.Unmarshal(data, &d)
	if err != nil {
		return err
	}
	// json.Unmarshal decodes every JSON number as float64.
	// BUG FIX: use a checked assertion — the original panicked when
	// "precision" was present but not a number.
	precision, ok := d["precision"].(float64)
	if !ok {
		return fmt.Errorf("Precision must be specified")
	}
	f.Precision = int(precision)
	return nil
}
// NewFloatAttribute returns a new FloatAttribute with a default
// precision of 2 decimal places
func NewFloatAttribute(name string) *FloatAttribute {

View File

@ -29,7 +29,7 @@ func TestLazySortDesc(t *testing.T) {
})
Convey("Result should match the reference", func() {
So(sortedDescending.Equal(result), ShouldBeTrue)
So(InstancesAreEqual(sortedDescending, result), ShouldBeTrue)
})
})
})
@ -60,11 +60,11 @@ func TestLazySortAsc(t *testing.T) {
})
Convey("Result should match the reference", func() {
So(sortedAscending.Equal(result), ShouldBeTrue)
So(InstancesAreEqual(sortedAscending, result), ShouldBeTrue)
})
Convey("First element of Result should equal known value", func() {
So(result.RowString(0), ShouldEqual, "4.30 3.00 1.10 0.10 Iris-setosa")
So(result.RowString(0), ShouldEqual, "4.3 3.0 1.1 0.1 Iris-setosa")
})
})
})

381
base/serialize.go Normal file
View File

@ -0,0 +1,381 @@
package base
import (
"archive/tar"
"compress/gzip"
"encoding/csv"
"encoding/json"
"fmt"
"io"
"os"
"reflect"
"runtime"
)
const (
	// SerializationFormatVersion is the exact payload of the MANIFEST
	// entry; DeserializeInstances rejects archives whose MANIFEST differs.
	SerializationFormatVersion = "golearn 0.5"
)
// SerializeInstancesToFile serializes inst to a gzipped tar archive at
// path, creating or overwriting the file, and flushes it to disk.
func SerializeInstancesToFile(inst FixedDataGrid, path string) error {
	// BUG FIX: os.O_RDWR alone fails when the file doesn't exist and
	// leaves stale trailing bytes when it does; create and truncate.
	f, err := os.OpenFile(path, os.O_RDWR|os.O_CREATE|os.O_TRUNC, 0600)
	if err != nil {
		return err
	}
	// BUG FIX: close on every path — the original leaked the descriptor
	// when SerializeInstances or Sync failed.
	defer f.Close()
	err = SerializeInstances(inst, f)
	if err != nil {
		return err
	}
	err = f.Sync()
	if err != nil {
		return fmt.Errorf("Couldn't flush file: %s", err)
	}
	return nil
}
// SerializeInstancesToCSV writes inst as a CSV file at path, creating
// or overwriting the file.
func SerializeInstancesToCSV(inst FixedDataGrid, path string) error {
	// BUG FIX: os.O_RDWR alone fails when the file doesn't exist and
	// leaves stale trailing bytes when it does; create and truncate.
	f, err := os.OpenFile(path, os.O_RDWR|os.O_CREATE|os.O_TRUNC, 0600)
	if err != nil {
		return err
	}
	defer func() {
		// Best-effort flush-to-disk before closing
		f.Sync()
		f.Close()
	}()
	return SerializeInstancesToCSVStream(inst, f)
}
// SerializeInstancesToCSVStream writes inst to f in CSV form: a header
// row of Attribute names (regular Attributes first, class Attributes
// last) followed by one row per instance.
func SerializeInstancesToCSVStream(inst FixedDataGrid, f io.Writer) error {
	// Create the CSV writer
	w := csv.NewWriter(f)
	colCount, _ := inst.Size()
	// Write out Attribute headers
	// Start with the regular Attributes
	normalAttrs := NonClassAttributes(inst)
	classAttrs := inst.AllClassAttributes()
	allAttrs := make([]Attribute, colCount)
	n := copy(allAttrs, normalAttrs)
	copy(allAttrs[n:], classAttrs)
	headerRow := make([]string, colCount)
	for i, v := range allAttrs {
		headerRow[i] = v.GetName()
	}
	// BUG FIX: csv writer errors were previously discarded
	if err := w.Write(headerRow); err != nil {
		return err
	}
	specs := ResolveAttributes(inst, allAttrs)
	curRow := make([]string, colCount)
	err := inst.MapOverRows(specs, func(row [][]byte, rowNo int) (bool, error) {
		for i, v := range row {
			attr := allAttrs[i]
			curRow[i] = attr.GetStringFromSysVal(v)
		}
		if err := w.Write(curRow); err != nil {
			return false, err
		}
		return true, nil
	})
	if err != nil {
		return err
	}
	w.Flush()
	// Flush errors (e.g. a failing underlying writer) surface here
	return w.Error()
}
// writeAttributesToFilePart marshals attrs to JSON and stores the result
// as an entry called name inside the tar archive f.
func writeAttributesToFilePart(attrs []Attribute, f *tar.Writer, name string) error {
	// Get the marshaled Attribute array
	body, err := json.Marshal(attrs)
	if err != nil {
		return err
	}
	// Write a header describing the payload
	hdr := &tar.Header{
		Name: name,
		Size: int64(len(body)),
	}
	if err := f.WriteHeader(hdr); err != nil {
		return err
	}
	// Write the marshaled data (body is already a []byte; the original
	// performed a redundant []byte(body) conversion here)
	if _, err := f.Write(body); err != nil {
		return err
	}
	return nil
}
// getTarContent returns the full contents of the entry called name
// within the tar archive, panicking if it is absent or unreadable.
// Entries before the match are consumed from the stream.
func getTarContent(tr *tar.Reader, name string) []byte {
	for {
		hdr, err := tr.Next()
		if err == io.EOF {
			break
		} else if err != nil {
			panic(err)
		}
		if hdr.Name == name {
			ret := make([]byte, hdr.Size)
			// BUG FIX: use io.ReadFull — a single tar.Reader.Read call may
			// legally return fewer bytes than requested, which the original
			// reported as a spurious "Size mismatch" panic.
			if _, err := io.ReadFull(tr, ret); err != nil {
				panic(err)
			}
			return ret
		}
	}
	panic("File not found!")
}
// deserializeAttributes decodes a JSON array of Attributes produced by
// writeAttributesToFilePart, panicking on malformed input.
func deserializeAttributes(data []byte) []Attribute {
	// JSONAttribute mirrors the on-disk representation: a type tag, a
	// name, and a type-specific payload that is decoded in a second pass.
	// BUG FIX: the struct tags were `json:type` (unquoted value), which
	// Go's tag parser ignores; decoding only worked through the
	// case-insensitive field-name fallback.
	type JSONAttribute struct {
		Type string          `json:"type"`
		Name string          `json:"name"`
		Attr json.RawMessage `json:"attr"`
	}
	var ret []Attribute
	var attrs []JSONAttribute
	err := json.Unmarshal(data, &attrs)
	if err != nil {
		panic(fmt.Errorf("Attribute decode error: %s", err))
	}
	for _, a := range attrs {
		var attr Attribute
		var err error
		// Choose the concrete Attribute implementation from the type tag
		switch a.Type {
		case "binary":
			attr = new(BinaryAttribute)
		case "float":
			attr = new(FloatAttribute)
		case "categorical":
			attr = new(CategoricalAttribute)
		default:
			panic(fmt.Errorf("Unrecognised Attribute format: %s", a.Type))
		}
		err = attr.UnmarshalJSON(a.Attr)
		if err != nil {
			panic(fmt.Errorf("Can't deserialize: %s (error: %s)", a, err))
		}
		attr.SetName(a.Name)
		ret = append(ret, attr)
	}
	return ret
}
// DeserializeInstances reads a gzipped tar archive produced by
// SerializeInstances back into a DenseInstances. The MANIFEST format
// version must match SerializationFormatVersion exactly.
func DeserializeInstances(f io.Reader) (ret *DenseInstances, err error) {
	// Recovery function: getTarContent and deserializeAttributes panic
	// on malformed archives; convert those panics to the error return.
	defer func() {
		if r := recover(); r != nil {
			if _, ok := r.(runtime.Error); ok {
				panic(r)
			}
			err = r.(error)
		}
	}()
	// Open the .gz layer
	gzReader, err := gzip.NewReader(f)
	if err != nil {
		return nil, fmt.Errorf("Can't open: %s", err)
	}
	// Open the .tar layer
	tr := tar.NewReader(gzReader)
	// Retrieve the MANIFEST and verify
	manifestBytes := getTarContent(tr, "MANIFEST")
	if !reflect.DeepEqual(manifestBytes, []byte(SerializationFormatVersion)) {
		return nil, fmt.Errorf("Unsupported MANIFEST: %s", string(manifestBytes))
	}
	// Get the size: two packed uint64s (attribute count, then row count)
	sizeBytes := getTarContent(tr, "DIMS")
	attrCount := int(UnpackBytesToU64(sizeBytes[0:8]))
	rowCount := int(UnpackBytesToU64(sizeBytes[8:]))
	// Unmarshal the Attributes
	attrBytes := getTarContent(tr, "CATTRS")
	cAttrs := deserializeAttributes(attrBytes)
	attrBytes = getTarContent(tr, "ATTRS")
	normalAttrs := deserializeAttributes(attrBytes)
	// Create the return instances
	ret = NewDenseInstances()
	// Normal Attributes first, class Attributes on the end
	allAttributes := make([]Attribute, attrCount)
	for i, v := range normalAttrs {
		ret.AddAttribute(v)
		allAttributes[i] = v
	}
	for i, v := range cAttrs {
		ret.AddAttribute(v)
		err = ret.AddClassAttribute(v)
		if err != nil {
			return nil, fmt.Errorf("Could not set Attribute as class Attribute: %s", err)
		}
		allAttributes[i+len(normalAttrs)] = v
	}
	// Allocate memory
	err = ret.Extend(int(rowCount))
	if err != nil {
		return nil, fmt.Errorf("Could not allocate memory")
	}
	// Seek through the TAR file until we get to the DATA section
	for {
		hdr, err := tr.Next()
		if err == io.EOF {
			return nil, fmt.Errorf("DATA section missing!")
		} else if err != nil {
			return nil, fmt.Errorf("Error seeking to DATA section: %s", err)
		}
		if hdr.Name == "DATA" {
			break
		}
	}
	// Resolve AttributeSpecs
	specs := ResolveAttributes(ret, allAttributes)
	// Finally, read the values out of the data section
	for i := 0; i < rowCount; i++ {
		for _, s := range specs {
			r := ret.Get(s, i)
			// BUG FIX: use io.ReadFull — a single tar.Reader.Read call may
			// return fewer bytes than requested, which the original treated
			// as a fatal size mismatch.
			n, err := io.ReadFull(tr, r)
			if n != len(r) {
				return nil, fmt.Errorf("Expected %d bytes (read %d) on row %d", len(r), n, i)
			}
			if err != nil {
				return nil, fmt.Errorf("Read error: %s", err)
			}
			ret.Set(s, i, r)
		}
	}
	if err = gzReader.Close(); err != nil {
		return ret, fmt.Errorf("Error closing gzip stream: %s", err)
	}
	return ret, nil
}
// SerializeInstances writes inst to f as a gzipped tar archive with
// MANIFEST (format version), DIMS (dimensions), CATTRS/ATTRS (JSON
// Attribute descriptions) and DATA (raw row values) entries. The layout
// is the inverse of what DeserializeInstances expects.
func SerializeInstances(inst FixedDataGrid, f io.Writer) error {
	var hdr *tar.Header
	gzWriter := gzip.NewWriter(f)
	tw := tar.NewWriter(gzWriter)
	// Write the MANIFEST entry
	hdr = &tar.Header{
		Name: "MANIFEST",
		Size: int64(len(SerializationFormatVersion)),
	}
	if err := tw.WriteHeader(hdr); err != nil {
		return fmt.Errorf("Could not write MANIFEST header: %s", err)
	}
	if _, err := tw.Write([]byte(SerializationFormatVersion)); err != nil {
		return fmt.Errorf("Could not write MANIFEST contents: %s", err)
	}
	// Now write the dimensions of the dataset
	attrCount, rowCount := inst.Size()
	hdr = &tar.Header{
		Name: "DIMS",
		Size: 16, // two packed uint64 values
	}
	if err := tw.WriteHeader(hdr); err != nil {
		return fmt.Errorf("Could not write DIMS header: %s", err)
	}
	if _, err := tw.Write(PackU64ToBytes(uint64(attrCount))); err != nil {
		return fmt.Errorf("Could not write DIMS (attrCount): %s", err)
	}
	if _, err := tw.Write(PackU64ToBytes(uint64(rowCount))); err != nil {
		return fmt.Errorf("Could not write DIMS (rowCount): %s", err)
	}
	// Write the ATTRIBUTES files
	classAttrs := inst.AllClassAttributes()
	normalAttrs := NonClassAttributes(inst)
	if err := writeAttributesToFilePart(classAttrs, tw, "CATTRS"); err != nil {
		return fmt.Errorf("Could not write CATTRS: %s", err)
	}
	if err := writeAttributesToFilePart(normalAttrs, tw, "ATTRS"); err != nil {
		return fmt.Errorf("Could not write ATTRS: %s", err)
	}
	// Data must be written out in the same order as the Attributes
	allAttrs := make([]Attribute, attrCount)
	normCount := copy(allAttrs, normalAttrs)
	for i, v := range classAttrs {
		allAttrs[normCount+i] = v
	}
	allSpecs := ResolveAttributes(inst, allAttrs)
	// First, estimate the amount of data we'll need...
	// (a tar header must state the entry's exact size up front)
	dataLength := int64(0)
	inst.MapOverRows(allSpecs, func(val [][]byte, row int) (bool, error) {
		for _, v := range val {
			dataLength += int64(len(v))
		}
		return true, nil
	})
	// Then write the header
	hdr = &tar.Header{
		Name: "DATA",
		Size: dataLength,
	}
	if err := tw.WriteHeader(hdr); err != nil {
		return fmt.Errorf("Could not write DATA: %s", err)
	}
	// Then write the actual data
	writtenLength := int64(0)
	if err := inst.MapOverRows(allSpecs, func(val [][]byte, row int) (bool, error) {
		for _, v := range val {
			wl, err := tw.Write(v)
			writtenLength += int64(wl)
			if err != nil {
				return false, err
			}
		}
		return true, nil
	}); err != nil {
		return err
	}
	// Sanity check: the amount written must match the pre-computed header
	if writtenLength != dataLength {
		return fmt.Errorf("Could not write DATA: changed size from %v to %v", dataLength, writtenLength)
	}
	// Finally, close and flush the various levels
	if err := tw.Flush(); err != nil {
		return fmt.Errorf("Could not flush tar: %s", err)
	}
	if err := tw.Close(); err != nil {
		return fmt.Errorf("Could not close tar: %s", err)
	}
	if err := gzWriter.Flush(); err != nil {
		return fmt.Errorf("Could not flush gz: %s", err)
	}
	if err := gzWriter.Close(); err != nil {
		return fmt.Errorf("Could not close gz: %s", err)
	}
	return nil
}

107
base/serialize_test.go Normal file
View File

@ -0,0 +1,107 @@
package base
import (
"archive/tar"
"compress/gzip"
"fmt"
. "github.com/smartystreets/goconvey/convey"
"io"
"io/ioutil"
"testing"
)
// TestSerializeToCSV checks that an instance set written out with
// SerializeInstancesToCSV can be parsed back in and yields an identical
// string representation.
func TestSerializeToCSV(t *testing.T) {
	Convey("Reading some instances...", t, func() {
		inst, err := ParseCSVToInstances("../examples/datasets/iris_headers.csv", true)
		So(err, ShouldBeNil)
		Convey("Saving the instances to CSV...", func() {
			f, err := ioutil.TempFile("", "instTmp")
			So(err, ShouldBeNil)
			// SerializeInstancesToCSV re-opens the file by path, so the
			// handle returned by TempFile must be closed here or it leaks.
			defer f.Close()
			err = SerializeInstancesToCSV(inst, f.Name())
			So(err, ShouldBeNil)
			Convey("What's written out should match what's read in", func() {
				dinst, err := ParseCSVToInstances(f.Name(), true)
				So(err, ShouldBeNil)
				So(inst.String(), ShouldEqual, dinst.String())
			})
		})
	})
}
// TestSerializeToFile dumps an instance set to the gzipped-tar
// serialization format, then verifies that every expected archive member
// (MANIFEST, CATTRS, ATTRS, DATA, DIMS) is present, that the MANIFEST
// contents match SerializationFormatVersion, and that the archive can be
// deserialized back into an equal instance set.
func TestSerializeToFile(t *testing.T) {
	Convey("Reading some instances...", t, func() {
		inst, err := ParseCSVToInstances("../examples/datasets/iris_headers.csv", true)
		So(err, ShouldBeNil)
		Convey("Dumping to file...", func() {
			f, err := ioutil.TempFile("", "instTmp")
			So(err, ShouldBeNil)
			// Close the temporary file when this closure finishes so the
			// descriptor isn't leaked across test runs.
			defer f.Close()
			err = SerializeInstances(inst, f)
			So(err, ShouldBeNil)
			f.Seek(0, 0)
			Convey("Contents of the archive should be right...", func() {
				gzr, err := gzip.NewReader(f)
				So(err, ShouldBeNil)
				tr := tar.NewReader(gzr)
				// Flags recording which archive members have been seen.
				classAttrsPresent := false
				manifestPresent := false
				regularAttrsPresent := false
				dataPresent := false
				dimsPresent := false
				readBytes := make([]byte, len([]byte(SerializationFormatVersion)))
				for {
					hdr, err := tr.Next()
					if err == io.EOF {
						break
					}
					So(err, ShouldBeNil)
					switch hdr.Name {
					case "MANIFEST":
						// io.ReadFull guards against a short read:
						// tr.Read may legally return fewer bytes than
						// requested, and its error was previously ignored.
						_, err := io.ReadFull(tr, readBytes)
						So(err, ShouldBeNil)
						manifestPresent = true
					case "CATTRS":
						classAttrsPresent = true
					case "ATTRS":
						regularAttrsPresent = true
					case "DATA":
						dataPresent = true
					case "DIMS":
						dimsPresent = true
					default:
						fmt.Printf("Unknown file: %s\n", hdr.Name)
					}
				}
				Convey("MANIFEST should be present", func() {
					So(manifestPresent, ShouldBeTrue)
					Convey("MANIFEST should be right...", func() {
						So(readBytes, ShouldResemble, []byte(SerializationFormatVersion))
					})
				})
				Convey("DATA should be present", func() {
					So(dataPresent, ShouldBeTrue)
				})
				Convey("ATTRS should be present", func() {
					So(regularAttrsPresent, ShouldBeTrue)
				})
				Convey("CATTRS should be present", func() {
					So(classAttrsPresent, ShouldBeTrue)
				})
				Convey("DIMS should be present", func() {
					So(dimsPresent, ShouldBeTrue)
				})
			})
			Convey("Should be able to reconstruct...", func() {
				f.Seek(0, 0)
				dinst, err := DeserializeInstances(f)
				So(err, ShouldBeNil)
				So(InstancesAreEqual(inst, dinst), ShouldBeTrue)
			})
		})
	})
}

View File

@ -59,7 +59,7 @@ func TestSortDesc(t *testing.T) {
})
Convey("Result should match the reference", func() {
So(sortedDescending.Equal(result), ShouldBeTrue)
So(InstancesAreEqual(sortedDescending, result), ShouldBeTrue)
})
})
})
@ -90,11 +90,11 @@ func TestSortAsc(t *testing.T) {
})
Convey("Result should match the reference", func() {
So(sortedAscending.Equal(result), ShouldBeTrue)
So(InstancesAreEqual(sortedAscending, result), ShouldBeTrue)
})
Convey("First element of Result should equal known value", func() {
So(result.RowString(0), ShouldEqual, "4.30 3.00 1.10 0.10 Iris-setosa")
So(result.RowString(0), ShouldEqual, "4.3 3.0 1.1 0.1 Iris-setosa")
})
})
})

View File

@ -22,6 +22,37 @@ func GeneratePredictionVector(from FixedDataGrid) UpdatableDataGrid {
return ret
}
// CopyDenseInstances returns a new DenseInstances with identical
// structure (AttributeGroup layout and the given Attributes) to the
// template, but containing no rows.
//
// templateAttrs selects which of the template's Attributes are added to
// the copy, in order. Panics if an Attribute cannot be resolved against
// the template or cannot be added to the new set.
func CopyDenseInstances(template *DenseInstances, templateAttrs []Attribute) *DenseInstances {
	instances := NewDenseInstances()
	templateAgs := template.AllAttributeGroups()
	// Recreate each AttributeGroup: binary groups get 0-byte rows,
	// everything else gets 8-byte rows (matching the template's kinds).
	for ag := range templateAgs {
		agTemplate := templateAgs[ag]
		if _, ok := agTemplate.(*BinaryAttributeGroup); ok {
			instances.CreateAttributeGroup(ag, 0)
		} else {
			instances.CreateAttributeGroup(ag, 8)
		}
	}
	// Add each requested Attribute to the same-named group it occupied
	// in the template.
	for _, a := range templateAttrs {
		s, err := template.GetAttribute(a)
		if err != nil {
			panic(err)
		}
		if ag, ok := template.agRevMap[s.pond]; !ok {
			panic(ag)
		} else {
			_, err := instances.AddAttributeToAttributeGroup(a, ag)
			if err != nil {
				panic(err)
			}
		}
	}
	return instances
}
// GetClass is a shortcut for returning the string value of the current
// class on a given row.
//
@ -470,3 +501,34 @@ func CheckStrictlyCompatible(s1 FixedDataGrid, s2 FixedDataGrid) bool {
return true
}
// InstancesAreEqual checks whether a given Instance set is exactly
// the same as another (i.e. has the same size and values).
func InstancesAreEqual(inst, other FixedDataGrid) bool {
	// Differently-sized grids can never be equal.
	instCols, instRows := inst.Size()
	otherCols, otherRows := other.Size()
	if instCols != otherCols || instRows != otherRows {
		return false
	}
	for _, a := range inst.AllAttributes() {
		as1, err := inst.GetAttribute(a)
		if err != nil {
			panic(err) // That indicates some kind of error
		}
		// BUG FIX: the second AttributeSpec (and the values below) must
		// be resolved against the other grid — previously both were
		// resolved against inst, so the function compared inst to itself.
		as2, err := other.GetAttribute(a)
		if err != nil {
			return false // Obviously has different Attributes
		}
		if !as1.GetAttribute().Equals(as2.GetAttribute()) {
			return false
		}
		for i := 0; i < instRows; i++ {
			b1 := inst.Get(as1, i)
			b2 := other.Get(as2, i)
			if !byteSeqEqual(b1, b2) {
				return false
			}
		}
	}
	return true
}

View File

@ -17,7 +17,7 @@ func TestInstancesViewRows(t *testing.T) {
So(instView.rows[0], ShouldEqual, 5)
})
Convey("The reconstructed values should be correct...", func() {
str := "5.40 3.90 1.70 0.40 Iris-setosa"
str := "5.4 3.9 1.7 0.4 Iris-setosa"
row := instView.RowString(0)
So(row, ShouldEqual, str)
})
@ -99,7 +99,7 @@ func TestInstancesViewAttrs(t *testing.T) {
So(err, ShouldNotEqual, nil)
})
Convey("The filtered Attribute should not appear in the RowString", func() {
str := "3.90 1.70 0.40 Iris-setosa"
str := "3.9 1.7 0.4 Iris-setosa"
row := instView.RowString(5)
So(row, ShouldEqual, str)
})

225
examples/datasets/iris.arff Normal file
View File

@ -0,0 +1,225 @@
% 1. Title: Iris Plants Database
%
% 2. Sources:
% (a) Creator: R.A. Fisher
% (b) Donor: Michael Marshall (MARSHALL%PLU@io.arc.nasa.gov)
% (c) Date: July, 1988
%
% 3. Past Usage:
% - Publications: too many to mention!!! Here are a few.
% 1. Fisher,R.A. "The use of multiple measurements in taxonomic problems"
% Annual Eugenics, 7, Part II, 179-188 (1936); also in "Contributions
% to Mathematical Statistics" (John Wiley, NY, 1950).
% 2. Duda,R.O., & Hart,P.E. (1973) Pattern Classification and Scene Analysis.
% (Q327.D83) John Wiley & Sons. ISBN 0-471-22361-1. See page 218.
% 3. Dasarathy, B.V. (1980) "Nosing Around the Neighborhood: A New System
% Structure and Classification Rule for Recognition in Partially Exposed
% Environments". IEEE Transactions on Pattern Analysis and Machine
% Intelligence, Vol. PAMI-2, No. 1, 67-71.
% -- Results:
% -- very low misclassification rates (0% for the setosa class)
% 4. Gates, G.W. (1972) "The Reduced Nearest Neighbor Rule". IEEE
% Transactions on Information Theory, May 1972, 431-433.
% -- Results:
% -- very low misclassification rates again
% 5. See also: 1988 MLC Proceedings, 54-64. Cheeseman et al's AUTOCLASS II
% conceptual clustering system finds 3 classes in the data.
%
% 4. Relevant Information:
% --- This is perhaps the best known database to be found in the pattern
% recognition literature. Fisher's paper is a classic in the field
% and is referenced frequently to this day. (See Duda & Hart, for
% example.) The data set contains 3 classes of 50 instances each,
% where each class refers to a type of iris plant. One class is
% linearly separable from the other 2; the latter are NOT linearly
% separable from each other.
% --- Predicted attribute: class of iris plant.
% --- This is an exceedingly simple domain.
%
% 5. Number of Instances: 150 (50 in each of three classes)
%
% 6. Number of Attributes: 4 numeric, predictive attributes and the class
%
% 7. Attribute Information:
% 1. sepal length in cm
% 2. sepal width in cm
% 3. petal length in cm
% 4. petal width in cm
% 5. class:
% -- Iris Setosa
% -- Iris Versicolour
% -- Iris Virginica
%
% 8. Missing Attribute Values: None
%
% Summary Statistics:
% Min Max Mean SD Class Correlation
% sepal length: 4.3 7.9 5.84 0.83 0.7826
% sepal width: 2.0 4.4 3.05 0.43 -0.4194
% petal length: 1.0 6.9 3.76 1.76 0.9490 (high!)
% petal width: 0.1 2.5 1.20 0.76 0.9565 (high!)
%
% 9. Class Distribution: 33.3% for each of 3 classes.
@RELATION iris
@ATTRIBUTE sepallength REAL
@ATTRIBUTE sepalwidth REAL
@ATTRIBUTE petallength REAL
@ATTRIBUTE petalwidth REAL
@ATTRIBUTE class {Iris-setosa,Iris-versicolor,Iris-virginica}
@DATA
5.1,3.5,1.4,0.2,Iris-setosa
4.9,3.0,1.4,0.2,Iris-setosa
4.7,3.2,1.3,0.2,Iris-setosa
4.6,3.1,1.5,0.2,Iris-setosa
5.0,3.6,1.4,0.2,Iris-setosa
5.4,3.9,1.7,0.4,Iris-setosa
4.6,3.4,1.4,0.3,Iris-setosa
5.0,3.4,1.5,0.2,Iris-setosa
4.4,2.9,1.4,0.2,Iris-setosa
4.9,3.1,1.5,0.1,Iris-setosa
5.4,3.7,1.5,0.2,Iris-setosa
4.8,3.4,1.6,0.2,Iris-setosa
4.8,3.0,1.4,0.1,Iris-setosa
4.3,3.0,1.1,0.1,Iris-setosa
5.8,4.0,1.2,0.2,Iris-setosa
5.7,4.4,1.5,0.4,Iris-setosa
5.4,3.9,1.3,0.4,Iris-setosa
5.1,3.5,1.4,0.3,Iris-setosa
5.7,3.8,1.7,0.3,Iris-setosa
5.1,3.8,1.5,0.3,Iris-setosa
5.4,3.4,1.7,0.2,Iris-setosa
5.1,3.7,1.5,0.4,Iris-setosa
4.6,3.6,1.0,0.2,Iris-setosa
5.1,3.3,1.7,0.5,Iris-setosa
4.8,3.4,1.9,0.2,Iris-setosa
5.0,3.0,1.6,0.2,Iris-setosa
5.0,3.4,1.6,0.4,Iris-setosa
5.2,3.5,1.5,0.2,Iris-setosa
5.2,3.4,1.4,0.2,Iris-setosa
4.7,3.2,1.6,0.2,Iris-setosa
4.8,3.1,1.6,0.2,Iris-setosa
5.4,3.4,1.5,0.4,Iris-setosa
5.2,4.1,1.5,0.1,Iris-setosa
5.5,4.2,1.4,0.2,Iris-setosa
4.9,3.1,1.5,0.1,Iris-setosa
5.0,3.2,1.2,0.2,Iris-setosa
5.5,3.5,1.3,0.2,Iris-setosa
4.9,3.1,1.5,0.1,Iris-setosa
4.4,3.0,1.3,0.2,Iris-setosa
5.1,3.4,1.5,0.2,Iris-setosa
5.0,3.5,1.3,0.3,Iris-setosa
4.5,2.3,1.3,0.3,Iris-setosa
4.4,3.2,1.3,0.2,Iris-setosa
5.0,3.5,1.6,0.6,Iris-setosa
5.1,3.8,1.9,0.4,Iris-setosa
4.8,3.0,1.4,0.3,Iris-setosa
5.1,3.8,1.6,0.2,Iris-setosa
4.6,3.2,1.4,0.2,Iris-setosa
5.3,3.7,1.5,0.2,Iris-setosa
5.0,3.3,1.4,0.2,Iris-setosa
7.0,3.2,4.7,1.4,Iris-versicolor
6.4,3.2,4.5,1.5,Iris-versicolor
6.9,3.1,4.9,1.5,Iris-versicolor
5.5,2.3,4.0,1.3,Iris-versicolor
6.5,2.8,4.6,1.5,Iris-versicolor
5.7,2.8,4.5,1.3,Iris-versicolor
6.3,3.3,4.7,1.6,Iris-versicolor
4.9,2.4,3.3,1.0,Iris-versicolor
6.6,2.9,4.6,1.3,Iris-versicolor
5.2,2.7,3.9,1.4,Iris-versicolor
5.0,2.0,3.5,1.0,Iris-versicolor
5.9,3.0,4.2,1.5,Iris-versicolor
6.0,2.2,4.0,1.0,Iris-versicolor
6.1,2.9,4.7,1.4,Iris-versicolor
5.6,2.9,3.6,1.3,Iris-versicolor
6.7,3.1,4.4,1.4,Iris-versicolor
5.6,3.0,4.5,1.5,Iris-versicolor
5.8,2.7,4.1,1.0,Iris-versicolor
6.2,2.2,4.5,1.5,Iris-versicolor
5.6,2.5,3.9,1.1,Iris-versicolor
5.9,3.2,4.8,1.8,Iris-versicolor
6.1,2.8,4.0,1.3,Iris-versicolor
6.3,2.5,4.9,1.5,Iris-versicolor
6.1,2.8,4.7,1.2,Iris-versicolor
6.4,2.9,4.3,1.3,Iris-versicolor
6.6,3.0,4.4,1.4,Iris-versicolor
6.8,2.8,4.8,1.4,Iris-versicolor
6.7,3.0,5.0,1.7,Iris-versicolor
6.0,2.9,4.5,1.5,Iris-versicolor
5.7,2.6,3.5,1.0,Iris-versicolor
5.5,2.4,3.8,1.1,Iris-versicolor
5.5,2.4,3.7,1.0,Iris-versicolor
5.8,2.7,3.9,1.2,Iris-versicolor
6.0,2.7,5.1,1.6,Iris-versicolor
5.4,3.0,4.5,1.5,Iris-versicolor
6.0,3.4,4.5,1.6,Iris-versicolor
6.7,3.1,4.7,1.5,Iris-versicolor
6.3,2.3,4.4,1.3,Iris-versicolor
5.6,3.0,4.1,1.3,Iris-versicolor
5.5,2.5,4.0,1.3,Iris-versicolor
5.5,2.6,4.4,1.2,Iris-versicolor
6.1,3.0,4.6,1.4,Iris-versicolor
5.8,2.6,4.0,1.2,Iris-versicolor
5.0,2.3,3.3,1.0,Iris-versicolor
5.6,2.7,4.2,1.3,Iris-versicolor
5.7,3.0,4.2,1.2,Iris-versicolor
5.7,2.9,4.2,1.3,Iris-versicolor
6.2,2.9,4.3,1.3,Iris-versicolor
5.1,2.5,3.0,1.1,Iris-versicolor
5.7,2.8,4.1,1.3,Iris-versicolor
6.3,3.3,6.0,2.5,Iris-virginica
5.8,2.7,5.1,1.9,Iris-virginica
7.1,3.0,5.9,2.1,Iris-virginica
6.3,2.9,5.6,1.8,Iris-virginica
6.5,3.0,5.8,2.2,Iris-virginica
7.6,3.0,6.6,2.1,Iris-virginica
4.9,2.5,4.5,1.7,Iris-virginica
7.3,2.9,6.3,1.8,Iris-virginica
6.7,2.5,5.8,1.8,Iris-virginica
7.2,3.6,6.1,2.5,Iris-virginica
6.5,3.2,5.1,2.0,Iris-virginica
6.4,2.7,5.3,1.9,Iris-virginica
6.8,3.0,5.5,2.1,Iris-virginica
5.7,2.5,5.0,2.0,Iris-virginica
5.8,2.8,5.1,2.4,Iris-virginica
6.4,3.2,5.3,2.3,Iris-virginica
6.5,3.0,5.5,1.8,Iris-virginica
7.7,3.8,6.7,2.2,Iris-virginica
7.7,2.6,6.9,2.3,Iris-virginica
6.0,2.2,5.0,1.5,Iris-virginica
6.9,3.2,5.7,2.3,Iris-virginica
5.6,2.8,4.9,2.0,Iris-virginica
7.7,2.8,6.7,2.0,Iris-virginica
6.3,2.7,4.9,1.8,Iris-virginica
6.7,3.3,5.7,2.1,Iris-virginica
7.2,3.2,6.0,1.8,Iris-virginica
6.2,2.8,4.8,1.8,Iris-virginica
6.1,3.0,4.9,1.8,Iris-virginica
6.4,2.8,5.6,2.1,Iris-virginica
7.2,3.0,5.8,1.6,Iris-virginica
7.4,2.8,6.1,1.9,Iris-virginica
7.9,3.8,6.4,2.0,Iris-virginica
6.4,2.8,5.6,2.2,Iris-virginica
6.3,2.8,5.1,1.5,Iris-virginica
6.1,2.6,5.6,1.4,Iris-virginica
7.7,3.0,6.1,2.3,Iris-virginica
6.3,3.4,5.6,2.4,Iris-virginica
6.4,3.1,5.5,1.8,Iris-virginica
6.0,3.0,4.8,1.8,Iris-virginica
6.9,3.1,5.4,2.1,Iris-virginica
6.7,3.1,5.6,2.4,Iris-virginica
6.9,3.1,5.1,2.3,Iris-virginica
5.8,2.7,5.1,1.9,Iris-virginica
6.8,3.2,5.9,2.3,Iris-virginica
6.7,3.3,5.7,2.5,Iris-virginica
6.7,3.0,5.2,2.3,Iris-virginica
6.3,2.5,5.0,1.9,Iris-virginica
6.5,3.0,5.2,2.0,Iris-virginica
6.2,3.4,5.4,2.3,Iris-virginica
5.9,3.0,5.1,1.8,Iris-virginica
%
%
%

View File

@ -0,0 +1,23 @@
@relation weather
@attribute outlook {sunny, overcast, rainy}
@attribute temperature real
@attribute humidity real
@attribute windy {TRUE, FALSE}
@attribute play {yes, no}
@data
sunny,85,85,FALSE,no
sunny,80,90,TRUE,no
overcast,83,86,FALSE,yes
rainy,70,96,FALSE,yes
rainy,68,80,FALSE,yes
rainy,65,70,TRUE,no
overcast,64,65,TRUE,yes
sunny,72,95,FALSE,no
sunny,69,70,FALSE,yes
rainy,75,80,FALSE,yes
sunny,75,70,TRUE,yes
overcast,72,90,TRUE,yes
overcast,81,75,FALSE,yes
rainy,71,91,TRUE,no

View File

@ -0,0 +1,32 @@
// Demonstrates JSON marshalling and unmarshalling of Attributes
package main
import (
"encoding/json"
"fmt"
"github.com/sjwhitworth/golearn/base"
)
// main loads the iris dataset, marshals each of its Attributes to JSON,
// and prints the result of unmarshalling that JSON back into both a
// FloatAttribute and a CategoricalAttribute.
func main() {
	// Load in the iris dataset
	iris, err := base.ParseCSVToInstances("../datasets/iris_headers.csv", true)
	if err != nil {
		panic(err)
	}
	for _, a := range iris.AllAttributes() {
		var ac base.CategoricalAttribute
		var af base.FloatAttribute
		// Serialise the Attribute to JSON and show the raw encoding.
		s, err := json.Marshal(a)
		if err != nil {
			panic(err)
		}
		fmt.Println(string(s))
		// NOTE(review): the Unmarshal errors below are assigned but never
		// checked — presumably deliberate for a demo (one of the two
		// target types is expected not to match every Attribute), but
		// worth confirming.
		err = json.Unmarshal(s, &af)
		fmt.Println(af.String())
		err = json.Unmarshal(s, &ac)
		fmt.Println(ac.String())
	}
}

View File

@ -24,14 +24,14 @@ func TestLogisticRegression(t *testing.T) {
Z, err := lr.Predict(Y)
So(err, ShouldEqual, nil)
Convey("The result should be 1", func() {
So(Z.RowString(0), ShouldEqual, "1.00")
So(Z.RowString(0), ShouldEqual, "1.0")
})
})
Convey("When predicting the label of second vector", func() {
Z, err := lr.Predict(Y)
So(err, ShouldEqual, nil)
Convey("The result should be -1", func() {
So(Z.RowString(1), ShouldEqual, "-1.00")
So(Z.RowString(1), ShouldEqual, "-1.0")
})
})
})