mirror of
https://github.com/sjwhitworth/golearn.git
synced 2025-04-28 13:48:56 +08:00
Merge pull request #76 from amitkgupta/master
Improve ConfusionMatrix GetSummary formatting, and several tiny improvements
This commit is contained in:
commit
c8bf178662
@ -35,7 +35,7 @@ func TestBAGSimple(t *testing.T) {
|
|||||||
} else if name == "2" {
|
} else if name == "2" {
|
||||||
attrSpecs[2] = a
|
attrSpecs[2] = a
|
||||||
} else {
|
} else {
|
||||||
panic(name)
|
t.Fatalf("Unexpected attribute name '%s'", name)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -102,7 +102,7 @@ func TestBAG(t *testing.T) {
|
|||||||
} else if name == "2" {
|
} else if name == "2" {
|
||||||
attrSpecs[2] = a
|
attrSpecs[2] = a
|
||||||
} else {
|
} else {
|
||||||
panic(name)
|
t.Fatalf("Unexpected attribute name '%s'", name)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
18
base/csv.go
18
base/csv.go
@ -11,10 +11,10 @@ import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
// ParseCSVGetRows returns the number of rows in a given file.
|
// ParseCSVGetRows returns the number of rows in a given file.
|
||||||
func ParseCSVGetRows(filepath string) int {
|
func ParseCSVGetRows(filepath string) (int, error) {
|
||||||
file, err := os.Open(filepath)
|
file, err := os.Open(filepath)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
panic(err)
|
return 0, err
|
||||||
}
|
}
|
||||||
defer file.Close()
|
defer file.Close()
|
||||||
|
|
||||||
@ -25,11 +25,11 @@ func ParseCSVGetRows(filepath string) int {
|
|||||||
if err == io.EOF {
|
if err == io.EOF {
|
||||||
break
|
break
|
||||||
} else if err != nil {
|
} else if err != nil {
|
||||||
panic(err)
|
return 0, err
|
||||||
}
|
}
|
||||||
counter++
|
counter++
|
||||||
}
|
}
|
||||||
return counter
|
return counter, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// ParseCSVGetAttributes returns an ordered slice of appropriate-ly typed
|
// ParseCSVGetAttributes returns an ordered slice of appropriate-ly typed
|
||||||
@ -157,7 +157,11 @@ func ParseCSVBuildInstances(filepath string, hasHeaders bool, u UpdatableDataGri
|
|||||||
func ParseCSVToInstances(filepath string, hasHeaders bool) (instances *DenseInstances, err error) {
|
func ParseCSVToInstances(filepath string, hasHeaders bool) (instances *DenseInstances, err error) {
|
||||||
|
|
||||||
// Read the number of rows in the file
|
// Read the number of rows in the file
|
||||||
rowCount := ParseCSVGetRows(filepath)
|
rowCount, err := ParseCSVGetRows(filepath)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
if hasHeaders {
|
if hasHeaders {
|
||||||
rowCount--
|
rowCount--
|
||||||
}
|
}
|
||||||
@ -176,7 +180,7 @@ func ParseCSVToInstances(filepath string, hasHeaders bool) (instances *DenseInst
|
|||||||
// Read the input
|
// Read the input
|
||||||
file, err := os.Open(filepath)
|
file, err := os.Open(filepath)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
panic(err)
|
return nil, err
|
||||||
}
|
}
|
||||||
defer file.Close()
|
defer file.Close()
|
||||||
reader := csv.NewReader(file)
|
reader := csv.NewReader(file)
|
||||||
@ -188,7 +192,7 @@ func ParseCSVToInstances(filepath string, hasHeaders bool) (instances *DenseInst
|
|||||||
if err == io.EOF {
|
if err == io.EOF {
|
||||||
break
|
break
|
||||||
} else if err != nil {
|
} else if err != nil {
|
||||||
panic(err)
|
return nil, err
|
||||||
}
|
}
|
||||||
if rowCounter == 0 {
|
if rowCounter == 0 {
|
||||||
if hasHeaders {
|
if hasHeaders {
|
||||||
|
@ -4,78 +4,92 @@ import (
|
|||||||
"testing"
|
"testing"
|
||||||
)
|
)
|
||||||
|
|
||||||
func TestParseCSVGetRows(testEnv *testing.T) {
|
func TestParseCSVGetRows(t *testing.T) {
|
||||||
lineCount := ParseCSVGetRows("../examples/datasets/iris.csv")
|
lineCount, err := ParseCSVGetRows("../examples/datasets/iris.csv")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Unable to parse CSV to get number of rows: %s", err.Error())
|
||||||
|
}
|
||||||
if lineCount != 150 {
|
if lineCount != 150 {
|
||||||
testEnv.Errorf("Should have %d lines, has %d", 150, lineCount)
|
t.Errorf("Should have %d lines, has %d", 150, lineCount)
|
||||||
|
}
|
||||||
|
|
||||||
|
lineCount, err = ParseCSVGetRows("../examples/datasets/iris_headers.csv")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Unable to parse CSV to get number of rows: %s", err.Error())
|
||||||
}
|
}
|
||||||
|
|
||||||
lineCount = ParseCSVGetRows("../examples/datasets/iris_headers.csv")
|
|
||||||
if lineCount != 151 {
|
if lineCount != 151 {
|
||||||
testEnv.Errorf("Should have %d lines, has %d", 151, lineCount)
|
t.Errorf("Should have %d lines, has %d", 151, lineCount)
|
||||||
}
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestParseCCSVGetAttributes(testEnv *testing.T) {
|
func TestParseCSVGetRowsWithMissingFile(t *testing.T) {
|
||||||
|
_, err := ParseCSVGetRows("../examples/datasets/non-existent.csv")
|
||||||
|
if err == nil {
|
||||||
|
t.Fatal("Expected ParseCSVGetRows to return error when given path to non-existent file")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestParseCCSVGetAttributes(t *testing.T) {
|
||||||
attrs := ParseCSVGetAttributes("../examples/datasets/iris_headers.csv", true)
|
attrs := ParseCSVGetAttributes("../examples/datasets/iris_headers.csv", true)
|
||||||
if attrs[0].GetType() != Float64Type {
|
if attrs[0].GetType() != Float64Type {
|
||||||
testEnv.Errorf("First attribute should be a float, %s", attrs[0])
|
t.Errorf("First attribute should be a float, %s", attrs[0])
|
||||||
}
|
}
|
||||||
if attrs[0].GetName() != "Sepal length" {
|
if attrs[0].GetName() != "Sepal length" {
|
||||||
testEnv.Errorf(attrs[0].GetName())
|
t.Errorf(attrs[0].GetName())
|
||||||
}
|
}
|
||||||
|
|
||||||
if attrs[4].GetType() != CategoricalType {
|
if attrs[4].GetType() != CategoricalType {
|
||||||
testEnv.Errorf("Final attribute should be categorical, %s", attrs[4])
|
t.Errorf("Final attribute should be categorical, %s", attrs[4])
|
||||||
}
|
}
|
||||||
if attrs[4].GetName() != "Species" {
|
if attrs[4].GetName() != "Species" {
|
||||||
testEnv.Error(attrs[4])
|
t.Error(attrs[4])
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestParseCsvSniffAttributeTypes(testEnv *testing.T) {
|
func TestParseCsvSniffAttributeTypes(t *testing.T) {
|
||||||
attrs := ParseCSVSniffAttributeTypes("../examples/datasets/iris_headers.csv", true)
|
attrs := ParseCSVSniffAttributeTypes("../examples/datasets/iris_headers.csv", true)
|
||||||
if attrs[0].GetType() != Float64Type {
|
if attrs[0].GetType() != Float64Type {
|
||||||
testEnv.Errorf("First attribute should be a float, %s", attrs[0])
|
t.Errorf("First attribute should be a float, %s", attrs[0])
|
||||||
}
|
}
|
||||||
if attrs[1].GetType() != Float64Type {
|
if attrs[1].GetType() != Float64Type {
|
||||||
testEnv.Errorf("Second attribute should be a float, %s", attrs[1])
|
t.Errorf("Second attribute should be a float, %s", attrs[1])
|
||||||
}
|
}
|
||||||
if attrs[2].GetType() != Float64Type {
|
if attrs[2].GetType() != Float64Type {
|
||||||
testEnv.Errorf("Third attribute should be a float, %s", attrs[2])
|
t.Errorf("Third attribute should be a float, %s", attrs[2])
|
||||||
}
|
}
|
||||||
if attrs[3].GetType() != Float64Type {
|
if attrs[3].GetType() != Float64Type {
|
||||||
testEnv.Errorf("Fourth attribute should be a float, %s", attrs[3])
|
t.Errorf("Fourth attribute should be a float, %s", attrs[3])
|
||||||
}
|
}
|
||||||
if attrs[4].GetType() != CategoricalType {
|
if attrs[4].GetType() != CategoricalType {
|
||||||
testEnv.Errorf("Final attribute should be categorical, %s", attrs[4])
|
t.Errorf("Final attribute should be categorical, %s", attrs[4])
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestParseCSVSniffAttributeNamesWithHeaders(testEnv *testing.T) {
|
func TestParseCSVSniffAttributeNamesWithHeaders(t *testing.T) {
|
||||||
attrs := ParseCSVSniffAttributeNames("../examples/datasets/iris_headers.csv", true)
|
attrs := ParseCSVSniffAttributeNames("../examples/datasets/iris_headers.csv", true)
|
||||||
if attrs[0] != "Sepal length" {
|
if attrs[0] != "Sepal length" {
|
||||||
testEnv.Error(attrs[0])
|
t.Error(attrs[0])
|
||||||
}
|
}
|
||||||
if attrs[1] != "Sepal width" {
|
if attrs[1] != "Sepal width" {
|
||||||
testEnv.Error(attrs[1])
|
t.Error(attrs[1])
|
||||||
}
|
}
|
||||||
if attrs[2] != "Petal length" {
|
if attrs[2] != "Petal length" {
|
||||||
testEnv.Error(attrs[2])
|
t.Error(attrs[2])
|
||||||
}
|
}
|
||||||
if attrs[3] != "Petal width" {
|
if attrs[3] != "Petal width" {
|
||||||
testEnv.Error(attrs[3])
|
t.Error(attrs[3])
|
||||||
}
|
}
|
||||||
if attrs[4] != "Species" {
|
if attrs[4] != "Species" {
|
||||||
testEnv.Error(attrs[4])
|
t.Error(attrs[4])
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestReadInstances(testEnv *testing.T) {
|
func TestParseCSVToInstances(t *testing.T) {
|
||||||
inst, err := ParseCSVToInstances("../examples/datasets/iris_headers.csv", true)
|
inst, err := ParseCSVToInstances("../examples/datasets/iris_headers.csv", true)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
testEnv.Error(err)
|
t.Error(err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
row1 := inst.RowString(0)
|
row1 := inst.RowString(0)
|
||||||
@ -83,27 +97,34 @@ func TestReadInstances(testEnv *testing.T) {
|
|||||||
row3 := inst.RowString(100)
|
row3 := inst.RowString(100)
|
||||||
|
|
||||||
if row1 != "5.10 3.50 1.40 0.20 Iris-setosa" {
|
if row1 != "5.10 3.50 1.40 0.20 Iris-setosa" {
|
||||||
testEnv.Error(row1)
|
t.Error(row1)
|
||||||
}
|
}
|
||||||
if row2 != "7.00 3.20 4.70 1.40 Iris-versicolor" {
|
if row2 != "7.00 3.20 4.70 1.40 Iris-versicolor" {
|
||||||
testEnv.Error(row2)
|
t.Error(row2)
|
||||||
}
|
}
|
||||||
if row3 != "6.30 3.30 6.00 2.50 Iris-virginica" {
|
if row3 != "6.30 3.30 6.00 2.50 Iris-virginica" {
|
||||||
testEnv.Error(row3)
|
t.Error(row3)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestReadAwkwardInsatnces(testEnv *testing.T) {
|
func TestParseCSVToInstancesWithMissingFile(t *testing.T) {
|
||||||
|
_, err := ParseCSVToInstances("../examples/datasets/non-existent.csv", true)
|
||||||
|
if err == nil {
|
||||||
|
t.Fatal("Expected ParseCSVToInstances to return error when given path to non-existent file")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestReadAwkwardInsatnces(t *testing.T) {
|
||||||
inst, err := ParseCSVToInstances("../examples/datasets/chim.csv", true)
|
inst, err := ParseCSVToInstances("../examples/datasets/chim.csv", true)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
testEnv.Error(err)
|
t.Error(err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
attrs := inst.AllAttributes()
|
attrs := inst.AllAttributes()
|
||||||
if attrs[0].GetType() != Float64Type {
|
if attrs[0].GetType() != Float64Type {
|
||||||
testEnv.Error("Should be float!")
|
t.Error("Should be float!")
|
||||||
}
|
}
|
||||||
if attrs[1].GetType() != CategoricalType {
|
if attrs[1].GetType() != CategoricalType {
|
||||||
testEnv.Error("Should be discrete!")
|
t.Error("Should be discrete!")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -11,18 +11,18 @@ func TestAllocFixed(t *testing.T) {
|
|||||||
Convey("Creating a non-existent file should succeed", t, func() {
|
Convey("Creating a non-existent file should succeed", t, func() {
|
||||||
tempFile, err := ioutil.TempFile(os.TempDir(), "TestFileCreate")
|
tempFile, err := ioutil.TempFile(os.TempDir(), "TestFileCreate")
|
||||||
So(err, ShouldEqual, nil)
|
So(err, ShouldEqual, nil)
|
||||||
Convey("Mapping the file should suceed", func() {
|
Convey("Mapping the file should succeed", func() {
|
||||||
mapping, err := EdfMap(tempFile, EDF_CREATE)
|
mapping, err := EdfMap(tempFile, EDF_CREATE)
|
||||||
So(err, ShouldEqual, nil)
|
So(err, ShouldEqual, nil)
|
||||||
Convey("Allocation should suceed", func() {
|
Convey("Allocation should succeed", func() {
|
||||||
r, err := mapping.AllocPages(1, 2)
|
r, err := mapping.AllocPages(1, 2)
|
||||||
So(err, ShouldEqual, nil)
|
So(err, ShouldEqual, nil)
|
||||||
So(r.Start.Byte, ShouldEqual, 4*os.Getpagesize())
|
So(r.Start.Byte, ShouldEqual, 4*os.Getpagesize())
|
||||||
So(r.Start.Segment, ShouldEqual, 0)
|
So(r.Start.Segment, ShouldEqual, 0)
|
||||||
Convey("Unmapping the file should suceed", func() {
|
Convey("Unmapping the file should succeed", func() {
|
||||||
err = mapping.Unmap(EDF_UNMAP_SYNC)
|
err = mapping.Unmap(EDF_UNMAP_SYNC)
|
||||||
So(err, ShouldEqual, nil)
|
So(err, ShouldEqual, nil)
|
||||||
Convey("Remapping the file should suceed", func() {
|
Convey("Remapping the file should succeed", func() {
|
||||||
mapping, err = EdfMap(tempFile, EDF_READ_ONLY)
|
mapping, err = EdfMap(tempFile, EDF_READ_ONLY)
|
||||||
Convey("Should get the same allocations back", func() {
|
Convey("Should get the same allocations back", func() {
|
||||||
rr, err := mapping.GetThreadBlocks(2)
|
rr, err := mapping.GetThreadBlocks(2)
|
||||||
@ -41,20 +41,20 @@ func TestAllocWithExtraContentsBlock(t *testing.T) {
|
|||||||
Convey("Creating a non-existent file should succeed", t, func() {
|
Convey("Creating a non-existent file should succeed", t, func() {
|
||||||
tempFile, err := ioutil.TempFile(os.TempDir(), "TestFileCreate")
|
tempFile, err := ioutil.TempFile(os.TempDir(), "TestFileCreate")
|
||||||
So(err, ShouldEqual, nil)
|
So(err, ShouldEqual, nil)
|
||||||
Convey("Mapping the file should suceed", func() {
|
Convey("Mapping the file should succeed", func() {
|
||||||
mapping, err := EdfMap(tempFile, EDF_CREATE)
|
mapping, err := EdfMap(tempFile, EDF_CREATE)
|
||||||
So(err, ShouldEqual, nil)
|
So(err, ShouldEqual, nil)
|
||||||
Convey("Allocation of 10 pages should suceed", func() {
|
Convey("Allocation of 10 pages should succeed", func() {
|
||||||
allocated := make([]EdfRange, 10)
|
allocated := make([]EdfRange, 10)
|
||||||
for i := 0; i < 10; i++ {
|
for i := 0; i < 10; i++ {
|
||||||
r, err := mapping.AllocPages(1, 2)
|
r, err := mapping.AllocPages(1, 2)
|
||||||
So(err, ShouldEqual, nil)
|
So(err, ShouldEqual, nil)
|
||||||
allocated[i] = r
|
allocated[i] = r
|
||||||
}
|
}
|
||||||
Convey("Unmapping the file should suceed", func() {
|
Convey("Unmapping the file should succeed", func() {
|
||||||
err = mapping.Unmap(EDF_UNMAP_SYNC)
|
err = mapping.Unmap(EDF_UNMAP_SYNC)
|
||||||
So(err, ShouldEqual, nil)
|
So(err, ShouldEqual, nil)
|
||||||
Convey("Remapping the file should suceed", func() {
|
Convey("Remapping the file should succeed", func() {
|
||||||
mapping, err = EdfMap(tempFile, EDF_READ_ONLY)
|
mapping, err = EdfMap(tempFile, EDF_READ_ONLY)
|
||||||
Convey("Should get the same allocations back", func() {
|
Convey("Should get the same allocations back", func() {
|
||||||
rr, err := mapping.GetThreadBlocks(2)
|
rr, err := mapping.GetThreadBlocks(2)
|
||||||
|
@ -8,7 +8,7 @@ import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
func TestAnonMap(t *testing.T) {
|
func TestAnonMap(t *testing.T) {
|
||||||
Convey("Anonymous mapping should suceed", t, func() {
|
Convey("Anonymous mapping should succeed", t, func() {
|
||||||
mapping, err := EdfAnonMap()
|
mapping, err := EdfAnonMap()
|
||||||
So(err, ShouldEqual, nil)
|
So(err, ShouldEqual, nil)
|
||||||
bytes := mapping.m[0]
|
bytes := mapping.m[0]
|
||||||
@ -39,10 +39,10 @@ func TestFileCreate(t *testing.T) {
|
|||||||
Convey("Creating a non-existent file should succeed", t, func() {
|
Convey("Creating a non-existent file should succeed", t, func() {
|
||||||
tempFile, err := ioutil.TempFile(os.TempDir(), "TestFileCreate")
|
tempFile, err := ioutil.TempFile(os.TempDir(), "TestFileCreate")
|
||||||
So(err, ShouldEqual, nil)
|
So(err, ShouldEqual, nil)
|
||||||
Convey("Mapping the file should suceed", func() {
|
Convey("Mapping the file should succeed", func() {
|
||||||
mapping, err := EdfMap(tempFile, EDF_CREATE)
|
mapping, err := EdfMap(tempFile, EDF_CREATE)
|
||||||
So(err, ShouldEqual, nil)
|
So(err, ShouldEqual, nil)
|
||||||
Convey("Unmapping the file should suceed", func() {
|
Convey("Unmapping the file should succeed", func() {
|
||||||
err = mapping.Unmap(EDF_UNMAP_SYNC)
|
err = mapping.Unmap(EDF_UNMAP_SYNC)
|
||||||
So(err, ShouldEqual, nil)
|
So(err, ShouldEqual, nil)
|
||||||
})
|
})
|
||||||
@ -90,7 +90,7 @@ func TestFileThreadCounter(t *testing.T) {
|
|||||||
Convey("Creating a non-existent file should succeed", t, func() {
|
Convey("Creating a non-existent file should succeed", t, func() {
|
||||||
tempFile, err := ioutil.TempFile(os.TempDir(), "TestFileCreate")
|
tempFile, err := ioutil.TempFile(os.TempDir(), "TestFileCreate")
|
||||||
So(err, ShouldEqual, nil)
|
So(err, ShouldEqual, nil)
|
||||||
Convey("Mapping the file should suceed", func() {
|
Convey("Mapping the file should succeed", func() {
|
||||||
mapping, err := EdfMap(tempFile, EDF_CREATE)
|
mapping, err := EdfMap(tempFile, EDF_CREATE)
|
||||||
So(err, ShouldEqual, nil)
|
So(err, ShouldEqual, nil)
|
||||||
Convey("The file should have two threads to start with", func() {
|
Convey("The file should have two threads to start with", func() {
|
||||||
|
@ -2,17 +2,17 @@ package edf
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
. "github.com/smartystreets/goconvey/convey"
|
. "github.com/smartystreets/goconvey/convey"
|
||||||
"testing"
|
|
||||||
"os"
|
"os"
|
||||||
|
"testing"
|
||||||
)
|
)
|
||||||
|
|
||||||
func TestThreadDeserialize(T *testing.T) {
|
func TestThreadDeserialize(t *testing.T) {
|
||||||
bytes := []byte{0, 0, 0, 6, 83, 89, 83, 84, 69, 77, 0, 0, 0, 1}
|
bytes := []byte{0, 0, 0, 6, 83, 89, 83, 84, 69, 77, 0, 0, 0, 1}
|
||||||
Convey("Given a byte slice", T, func() {
|
Convey("Given a byte slice", t, func() {
|
||||||
var t Thread
|
var thread Thread
|
||||||
size := t.Deserialize(bytes)
|
size := thread.Deserialize(bytes)
|
||||||
Convey("Decoded name should be SYSTEM", func() {
|
Convey("Decoded name should be SYSTEM", func() {
|
||||||
So(t.name, ShouldEqual, "SYSTEM")
|
So(thread.name, ShouldEqual, "SYSTEM")
|
||||||
})
|
})
|
||||||
Convey("Size should be the same as the array", func() {
|
Convey("Size should be the same as the array", func() {
|
||||||
So(size, ShouldEqual, len(bytes))
|
So(size, ShouldEqual, len(bytes))
|
||||||
@ -20,34 +20,34 @@ func TestThreadDeserialize(T *testing.T) {
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestThreadSerialize(T *testing.T) {
|
func TestThreadSerialize(t *testing.T) {
|
||||||
var t Thread
|
var thread Thread
|
||||||
refBytes := []byte{0, 0, 0, 6, 83, 89, 83, 84, 69, 77, 0, 0, 0, 1}
|
refBytes := []byte{0, 0, 0, 6, 83, 89, 83, 84, 69, 77, 0, 0, 0, 1}
|
||||||
t.name = "SYSTEM"
|
thread.name = "SYSTEM"
|
||||||
t.id = 1
|
thread.id = 1
|
||||||
toBytes := make([]byte, len(refBytes))
|
toBytes := make([]byte, len(refBytes))
|
||||||
Convey("Should serialize correctly", T, func() {
|
Convey("Should serialize correctly", t, func() {
|
||||||
t.Serialize(toBytes)
|
thread.Serialize(toBytes)
|
||||||
So(toBytes, ShouldResemble, refBytes)
|
So(toBytes, ShouldResemble, refBytes)
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestThreadFindAndWrite(T *testing.T) {
|
func TestThreadFindAndWrite(t *testing.T) {
|
||||||
Convey("Creating a non-existent file should succeed", T, func() {
|
Convey("Creating a non-existent file should succeed", t, func() {
|
||||||
tempFile, err := os.OpenFile("hello.db", os.O_RDWR | os.O_TRUNC | os.O_CREATE, 0700) //ioutil.TempFile(os.TempDir(), "TestFileCreate")
|
tempFile, err := os.OpenFile("hello.db", os.O_RDWR|os.O_TRUNC|os.O_CREATE, 0700) //ioutil.TempFile(os.TempDir(), "TestFileCreate")
|
||||||
So(err, ShouldEqual, nil)
|
So(err, ShouldEqual, nil)
|
||||||
Convey("Mapping the file should suceed", func() {
|
Convey("Mapping the file should succeed", func() {
|
||||||
mapping, err := EdfMap(tempFile, EDF_CREATE)
|
mapping, err := EdfMap(tempFile, EDF_CREATE)
|
||||||
So(err, ShouldEqual, nil)
|
So(err, ShouldEqual, nil)
|
||||||
Convey("Writing the thread should succeed", func () {
|
Convey("Writing the thread should succeed", func() {
|
||||||
t := NewThread(mapping, "MyNameISWhat")
|
t := NewThread(mapping, "MyNameISWhat")
|
||||||
Convey("Thread number should be 3", func () {
|
Convey("Thread number should be 3", func() {
|
||||||
So(t.id, ShouldEqual, 3)
|
So(t.id, ShouldEqual, 3)
|
||||||
})
|
})
|
||||||
Convey("Writing the thread should succeed", func() {
|
Convey("Writing the thread should succeed", func() {
|
||||||
err := mapping.WriteThread(t)
|
err := mapping.WriteThread(t)
|
||||||
So(err, ShouldEqual, nil)
|
So(err, ShouldEqual, nil)
|
||||||
Convey("Should be able to find the thread again later", func() {
|
Convey("Should be able to find the thread again later", func() {
|
||||||
id, err := mapping.FindThread("MyNameISWhat")
|
id, err := mapping.FindThread("MyNameISWhat")
|
||||||
So(err, ShouldEqual, nil)
|
So(err, ShouldEqual, nil)
|
||||||
So(id, ShouldEqual, 3)
|
So(id, ShouldEqual, 3)
|
||||||
|
@ -1,28 +1,28 @@
|
|||||||
package base
|
package base
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"fmt"
|
"fmt"
|
||||||
"strconv"
|
"strconv"
|
||||||
)
|
)
|
||||||
|
|
||||||
// FloatAttribute is an implementation which stores floating point
|
// FloatAttribute is an implementation which stores floating point
|
||||||
// representations of numbers.
|
// representations of numbers.
|
||||||
type FloatAttribute struct {
|
type FloatAttribute struct {
|
||||||
Name string
|
Name string
|
||||||
Precision int
|
Precision int
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewFloatAttribute returns a new FloatAttribute with a default
|
// NewFloatAttribute returns a new FloatAttribute with a default
|
||||||
// precision of 2 decimal places
|
// precision of 2 decimal places
|
||||||
func NewFloatAttribute(name string) *FloatAttribute {
|
func NewFloatAttribute(name string) *FloatAttribute {
|
||||||
return &FloatAttribute{name, 2}
|
return &FloatAttribute{name, 2}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Compatable checks whether this FloatAttribute can be ponded with another
|
// Compatable checks whether this FloatAttribute can be ponded with another
|
||||||
// Attribute (checks if they're both FloatAttributes)
|
// Attribute (checks if they're both FloatAttributes)
|
||||||
func (Attr *FloatAttribute) Compatable(other Attribute) bool {
|
func (Attr *FloatAttribute) Compatable(other Attribute) bool {
|
||||||
_, ok := other.(*FloatAttribute)
|
_, ok := other.(*FloatAttribute)
|
||||||
return ok
|
return ok
|
||||||
}
|
}
|
||||||
|
|
||||||
// Equals tests a FloatAttribute for equality with another Attribute.
|
// Equals tests a FloatAttribute for equality with another Attribute.
|
||||||
@ -30,50 +30,50 @@ func (Attr *FloatAttribute) Compatable(other Attribute) bool {
|
|||||||
// Returns false if the other Attribute has a different name
|
// Returns false if the other Attribute has a different name
|
||||||
// or if the other Attribute is not a FloatAttribute.
|
// or if the other Attribute is not a FloatAttribute.
|
||||||
func (Attr *FloatAttribute) Equals(other Attribute) bool {
|
func (Attr *FloatAttribute) Equals(other Attribute) bool {
|
||||||
// Check whether this FloatAttribute is equal to another
|
// Check whether this FloatAttribute is equal to another
|
||||||
_, ok := other.(*FloatAttribute)
|
_, ok := other.(*FloatAttribute)
|
||||||
if !ok {
|
if !ok {
|
||||||
// Not the same type, so can't be equal
|
// Not the same type, so can't be equal
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
if Attr.GetName() != other.GetName() {
|
if Attr.GetName() != other.GetName() {
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetName returns this FloatAttribute's human-readable name.
|
// GetName returns this FloatAttribute's human-readable name.
|
||||||
func (Attr *FloatAttribute) GetName() string {
|
func (Attr *FloatAttribute) GetName() string {
|
||||||
return Attr.Name
|
return Attr.Name
|
||||||
}
|
}
|
||||||
|
|
||||||
// SetName sets this FloatAttribute's human-readable name.
|
// SetName sets this FloatAttribute's human-readable name.
|
||||||
func (Attr *FloatAttribute) SetName(name string) {
|
func (Attr *FloatAttribute) SetName(name string) {
|
||||||
Attr.Name = name
|
Attr.Name = name
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetType returns Float64Type.
|
// GetType returns Float64Type.
|
||||||
func (Attr *FloatAttribute) GetType() int {
|
func (Attr *FloatAttribute) GetType() int {
|
||||||
return Float64Type
|
return Float64Type
|
||||||
}
|
}
|
||||||
|
|
||||||
// String returns a human-readable summary of this Attribute.
|
// String returns a human-readable summary of this Attribute.
|
||||||
// e.g. "FloatAttribute(Sepal Width)"
|
// e.g. "FloatAttribute(Sepal Width)"
|
||||||
func (Attr *FloatAttribute) String() string {
|
func (Attr *FloatAttribute) String() string {
|
||||||
return fmt.Sprintf("FloatAttribute(%s)", Attr.Name)
|
return fmt.Sprintf("FloatAttribute(%s)", Attr.Name)
|
||||||
}
|
}
|
||||||
|
|
||||||
// CheckSysValFromString confirms whether a given rawVal can
|
// CheckSysValFromString confirms whether a given rawVal can
|
||||||
// be converted into a valid system representation. If it can't,
|
// be converted into a valid system representation. If it can't,
|
||||||
// the returned value is nil.
|
// the returned value is nil.
|
||||||
func (Attr *FloatAttribute) CheckSysValFromString(rawVal string) ([]byte, error) {
|
func (Attr *FloatAttribute) CheckSysValFromString(rawVal string) ([]byte, error) {
|
||||||
f, err := strconv.ParseFloat(rawVal, 64)
|
f, err := strconv.ParseFloat(rawVal, 64)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
ret := PackFloatToBytes(f)
|
ret := PackFloatToBytes(f)
|
||||||
return ret, nil
|
return ret, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetSysValFromString parses the given rawVal string to a float64 and returns it.
|
// GetSysValFromString parses the given rawVal string to a float64 and returns it.
|
||||||
@ -82,22 +82,22 @@ func (Attr *FloatAttribute) CheckSysValFromString(rawVal string) ([]byte, error)
|
|||||||
// IMPORTANT: This function panic()s if rawVal is not a valid float.
|
// IMPORTANT: This function panic()s if rawVal is not a valid float.
|
||||||
// Use CheckSysValFromString to confirm.
|
// Use CheckSysValFromString to confirm.
|
||||||
func (Attr *FloatAttribute) GetSysValFromString(rawVal string) []byte {
|
func (Attr *FloatAttribute) GetSysValFromString(rawVal string) []byte {
|
||||||
f, err := Attr.CheckSysValFromString(rawVal)
|
f, err := Attr.CheckSysValFromString(rawVal)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
panic(err)
|
panic(err)
|
||||||
}
|
}
|
||||||
return f
|
return f
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetFloatFromSysVal converts a given system value to a float
|
// GetFloatFromSysVal converts a given system value to a float
|
||||||
func (Attr *FloatAttribute) GetFloatFromSysVal(rawVal []byte) float64 {
|
func (Attr *FloatAttribute) GetFloatFromSysVal(rawVal []byte) float64 {
|
||||||
return UnpackBytesToFloat(rawVal)
|
return UnpackBytesToFloat(rawVal)
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetStringFromSysVal converts a given system value to to a string with two decimal
|
// GetStringFromSysVal converts a given system value to to a string with two decimal
|
||||||
// places of precision.
|
// places of precision.
|
||||||
func (Attr *FloatAttribute) GetStringFromSysVal(rawVal []byte) string {
|
func (Attr *FloatAttribute) GetStringFromSysVal(rawVal []byte) string {
|
||||||
f := UnpackBytesToFloat(rawVal)
|
f := UnpackBytesToFloat(rawVal)
|
||||||
formatString := fmt.Sprintf("%%.%df", Attr.Precision)
|
formatString := fmt.Sprintf("%%.%df", Attr.Precision)
|
||||||
return fmt.Sprintf(formatString, f)
|
return fmt.Sprintf(formatString, f)
|
||||||
}
|
}
|
||||||
|
@ -1,19 +1,18 @@
|
|||||||
package base
|
package base
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"fmt"
|
|
||||||
"testing"
|
"testing"
|
||||||
)
|
)
|
||||||
|
|
||||||
func TestLazySortDesc(testEnv *testing.T) {
|
func TestLazySortDesc(t *testing.T) {
|
||||||
inst1, err := ParseCSVToInstances("../examples/datasets/iris_headers.csv", true)
|
inst1, err := ParseCSVToInstances("../examples/datasets/iris_headers.csv", true)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
testEnv.Error(err)
|
t.Error(err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
inst2, err := ParseCSVToInstances("../examples/datasets/iris_sorted_desc.csv", true)
|
inst2, err := ParseCSVToInstances("../examples/datasets/iris_sorted_desc.csv", true)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
testEnv.Error(err)
|
t.Error(err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -21,67 +20,67 @@ func TestLazySortDesc(testEnv *testing.T) {
|
|||||||
as2 := ResolveAllAttributes(inst2)
|
as2 := ResolveAllAttributes(inst2)
|
||||||
|
|
||||||
if isSortedDesc(inst1, as1[0]) {
|
if isSortedDesc(inst1, as1[0]) {
|
||||||
testEnv.Error("Can't test descending sort order")
|
t.Error("Can't test descending sort order")
|
||||||
}
|
}
|
||||||
if !isSortedDesc(inst2, as2[0]) {
|
if !isSortedDesc(inst2, as2[0]) {
|
||||||
testEnv.Error("Reference data not sorted in descending order!")
|
t.Error("Reference data not sorted in descending order!")
|
||||||
}
|
}
|
||||||
|
|
||||||
inst, err := LazySort(inst1, Descending, as1[0:len(as1)-1])
|
inst, err := LazySort(inst1, Descending, as1[0:len(as1)-1])
|
||||||
if err != nil {
|
if err != nil {
|
||||||
testEnv.Error(err)
|
t.Error(err)
|
||||||
}
|
}
|
||||||
if !isSortedDesc(inst, as1[0]) {
|
if !isSortedDesc(inst, as1[0]) {
|
||||||
testEnv.Error("Instances are not sorted in descending order")
|
t.Error("Instances are not sorted in descending order")
|
||||||
testEnv.Error(inst1)
|
t.Error(inst1)
|
||||||
}
|
}
|
||||||
if !inst2.Equal(inst) {
|
if !inst2.Equal(inst) {
|
||||||
testEnv.Error("Instances don't match")
|
t.Error("Instances don't match")
|
||||||
testEnv.Error(inst)
|
t.Error(inst)
|
||||||
testEnv.Error(inst2)
|
t.Error(inst2)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestLazySortAsc(testEnv *testing.T) {
|
func TestLazySortAsc(t *testing.T) {
|
||||||
inst, err := ParseCSVToInstances("../examples/datasets/iris_headers.csv", true)
|
inst, err := ParseCSVToInstances("../examples/datasets/iris_headers.csv", true)
|
||||||
as1 := ResolveAllAttributes(inst)
|
as1 := ResolveAllAttributes(inst)
|
||||||
if isSortedAsc(inst, as1[0]) {
|
if isSortedAsc(inst, as1[0]) {
|
||||||
testEnv.Error("Can't test ascending sort on something ascending already")
|
t.Error("Can't test ascending sort on something ascending already")
|
||||||
}
|
}
|
||||||
if err != nil {
|
if err != nil {
|
||||||
testEnv.Error(err)
|
t.Error(err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
insts, err := LazySort(inst, Ascending, as1)
|
insts, err := LazySort(inst, Ascending, as1)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
testEnv.Error(err)
|
t.Error(err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
if !isSortedAsc(insts, as1[0]) {
|
if !isSortedAsc(insts, as1[0]) {
|
||||||
testEnv.Error("Instances are not sorted in ascending order")
|
t.Error("Instances are not sorted in ascending order")
|
||||||
testEnv.Error(insts)
|
t.Error(insts)
|
||||||
}
|
}
|
||||||
|
|
||||||
inst2, err := ParseCSVToInstances("../examples/datasets/iris_sorted_asc.csv", true)
|
inst2, err := ParseCSVToInstances("../examples/datasets/iris_sorted_asc.csv", true)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
testEnv.Error(err)
|
t.Error(err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
as2 := ResolveAllAttributes(inst2)
|
as2 := ResolveAllAttributes(inst2)
|
||||||
if !isSortedAsc(inst2, as2[0]) {
|
if !isSortedAsc(inst2, as2[0]) {
|
||||||
testEnv.Error("This file should be sorted in ascending order")
|
t.Error("This file should be sorted in ascending order")
|
||||||
}
|
}
|
||||||
|
|
||||||
if !inst2.Equal(insts) {
|
if !inst2.Equal(insts) {
|
||||||
testEnv.Error("Instances don't match")
|
t.Error("Instances don't match")
|
||||||
testEnv.Error(inst)
|
t.Error(inst)
|
||||||
testEnv.Error(inst2)
|
t.Error(inst2)
|
||||||
}
|
}
|
||||||
|
|
||||||
rowStr := insts.RowString(0)
|
rowStr := insts.RowString(0)
|
||||||
ref := "4.30 3.00 1.10 0.10 Iris-setosa"
|
ref := "4.30 3.00 1.10 0.10 Iris-setosa"
|
||||||
if rowStr != ref {
|
if rowStr != ref {
|
||||||
panic(fmt.Sprintf("'%s' != '%s'", rowStr, ref))
|
t.Fatalf("'%s' != '%s'", rowStr, ref)
|
||||||
}
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
|
@ -32,15 +32,15 @@ func isSortedDesc(inst FixedDataGrid, attr AttributeSpec) bool {
|
|||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestSortDesc(testEnv *testing.T) {
|
func TestSortDesc(t *testing.T) {
|
||||||
inst1, err := ParseCSVToInstances("../examples/datasets/iris_headers.csv", true)
|
inst1, err := ParseCSVToInstances("../examples/datasets/iris_headers.csv", true)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
testEnv.Error(err)
|
t.Error(err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
inst2, err := ParseCSVToInstances("../examples/datasets/iris_sorted_desc.csv", true)
|
inst2, err := ParseCSVToInstances("../examples/datasets/iris_sorted_desc.csv", true)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
testEnv.Error(err)
|
t.Error(err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -48,57 +48,57 @@ func TestSortDesc(testEnv *testing.T) {
|
|||||||
as2 := ResolveAllAttributes(inst2)
|
as2 := ResolveAllAttributes(inst2)
|
||||||
|
|
||||||
if isSortedDesc(inst1, as1[0]) {
|
if isSortedDesc(inst1, as1[0]) {
|
||||||
testEnv.Error("Can't test descending sort order")
|
t.Error("Can't test descending sort order")
|
||||||
}
|
}
|
||||||
if !isSortedDesc(inst2, as2[0]) {
|
if !isSortedDesc(inst2, as2[0]) {
|
||||||
testEnv.Error("Reference data not sorted in descending order!")
|
t.Error("Reference data not sorted in descending order!")
|
||||||
}
|
}
|
||||||
|
|
||||||
Sort(inst1, Descending, as1[0:len(as1)-1])
|
Sort(inst1, Descending, as1[0:len(as1)-1])
|
||||||
if err != nil {
|
if err != nil {
|
||||||
testEnv.Error(err)
|
t.Error(err)
|
||||||
}
|
}
|
||||||
if !isSortedDesc(inst1, as1[0]) {
|
if !isSortedDesc(inst1, as1[0]) {
|
||||||
testEnv.Error("Instances are not sorted in descending order")
|
t.Error("Instances are not sorted in descending order")
|
||||||
testEnv.Error(inst1)
|
t.Error(inst1)
|
||||||
}
|
}
|
||||||
if !inst2.Equal(inst1) {
|
if !inst2.Equal(inst1) {
|
||||||
testEnv.Error("Instances don't match")
|
t.Error("Instances don't match")
|
||||||
testEnv.Error(inst1)
|
t.Error(inst1)
|
||||||
testEnv.Error(inst2)
|
t.Error(inst2)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestSortAsc(testEnv *testing.T) {
|
func TestSortAsc(t *testing.T) {
|
||||||
inst, err := ParseCSVToInstances("../examples/datasets/iris_headers.csv", true)
|
inst, err := ParseCSVToInstances("../examples/datasets/iris_headers.csv", true)
|
||||||
as1 := ResolveAllAttributes(inst)
|
as1 := ResolveAllAttributes(inst)
|
||||||
if isSortedAsc(inst, as1[0]) {
|
if isSortedAsc(inst, as1[0]) {
|
||||||
testEnv.Error("Can't test ascending sort on something ascending already")
|
t.Error("Can't test ascending sort on something ascending already")
|
||||||
}
|
}
|
||||||
if err != nil {
|
if err != nil {
|
||||||
testEnv.Error(err)
|
t.Error(err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
Sort(inst, Ascending, as1[0:1])
|
Sort(inst, Ascending, as1[0:1])
|
||||||
if !isSortedAsc(inst, as1[0]) {
|
if !isSortedAsc(inst, as1[0]) {
|
||||||
testEnv.Error("Instances are not sorted in ascending order")
|
t.Error("Instances are not sorted in ascending order")
|
||||||
testEnv.Error(inst)
|
t.Error(inst)
|
||||||
}
|
}
|
||||||
|
|
||||||
inst2, err := ParseCSVToInstances("../examples/datasets/iris_sorted_asc.csv", true)
|
inst2, err := ParseCSVToInstances("../examples/datasets/iris_sorted_asc.csv", true)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
testEnv.Error(err)
|
t.Error(err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
as2 := ResolveAllAttributes(inst2)
|
as2 := ResolveAllAttributes(inst2)
|
||||||
if !isSortedAsc(inst2, as2[0]) {
|
if !isSortedAsc(inst2, as2[0]) {
|
||||||
testEnv.Error("This file should be sorted in ascending order")
|
t.Error("This file should be sorted in ascending order")
|
||||||
}
|
}
|
||||||
|
|
||||||
if !inst2.Equal(inst) {
|
if !inst2.Equal(inst) {
|
||||||
testEnv.Error("Instances don't match")
|
t.Error("Instances don't match")
|
||||||
testEnv.Error(inst)
|
t.Error(inst)
|
||||||
testEnv.Error(inst2)
|
t.Error(inst2)
|
||||||
}
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
|
26
base/util.go
26
base/util.go
@ -1,9 +1,6 @@
|
|||||||
package base
|
package base
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"bytes"
|
|
||||||
"encoding/binary"
|
|
||||||
"fmt"
|
|
||||||
"math"
|
"math"
|
||||||
"unsafe"
|
"unsafe"
|
||||||
)
|
)
|
||||||
@ -62,29 +59,6 @@ func UnpackBytesToFloat(val []byte) float64 {
|
|||||||
return *(*float64)(pb)
|
return *(*float64)(pb)
|
||||||
}
|
}
|
||||||
|
|
||||||
func xorFloatOp(item float64) float64 {
|
|
||||||
var ret float64
|
|
||||||
var tmp int64
|
|
||||||
buf := bytes.NewBuffer(nil)
|
|
||||||
binary.Write(buf, binary.LittleEndian, item)
|
|
||||||
binary.Read(buf, binary.LittleEndian, &tmp)
|
|
||||||
tmp ^= -1 << 63
|
|
||||||
binary.Write(buf, binary.LittleEndian, tmp)
|
|
||||||
binary.Read(buf, binary.LittleEndian, &ret)
|
|
||||||
return ret
|
|
||||||
}
|
|
||||||
|
|
||||||
func printFloatByteArr(arr [][]byte) {
|
|
||||||
buf := bytes.NewBuffer(nil)
|
|
||||||
var f float64
|
|
||||||
for _, b := range arr {
|
|
||||||
buf.Write(b)
|
|
||||||
binary.Read(buf, binary.LittleEndian, &f)
|
|
||||||
f = xorFloatOp(f)
|
|
||||||
fmt.Println(f)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func byteSeqEqual(a, b []byte) bool {
|
func byteSeqEqual(a, b []byte) bool {
|
||||||
if len(a) != len(b) {
|
if len(a) != len(b) {
|
||||||
return false
|
return false
|
||||||
|
@ -2,9 +2,9 @@ package ensemble
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"fmt"
|
"fmt"
|
||||||
base "github.com/sjwhitworth/golearn/base"
|
"github.com/sjwhitworth/golearn/base"
|
||||||
meta "github.com/sjwhitworth/golearn/meta"
|
"github.com/sjwhitworth/golearn/meta"
|
||||||
trees "github.com/sjwhitworth/golearn/trees"
|
"github.com/sjwhitworth/golearn/trees"
|
||||||
)
|
)
|
||||||
|
|
||||||
// RandomForest classifies instances using an ensemble
|
// RandomForest classifies instances using an ensemble
|
||||||
@ -31,6 +31,15 @@ func NewRandomForest(forestSize int, features int) *RandomForest {
|
|||||||
|
|
||||||
// Fit builds the RandomForest on the specified instances
|
// Fit builds the RandomForest on the specified instances
|
||||||
func (f *RandomForest) Fit(on base.FixedDataGrid) {
|
func (f *RandomForest) Fit(on base.FixedDataGrid) {
|
||||||
|
numNonClassAttributes := len(base.NonClassAttributes(on))
|
||||||
|
if numNonClassAttributes < f.Features {
|
||||||
|
panic(fmt.Sprintf(
|
||||||
|
"Random forest with %d features cannot fit data grid with %d non-class attributes",
|
||||||
|
f.Features,
|
||||||
|
numNonClassAttributes,
|
||||||
|
))
|
||||||
|
}
|
||||||
|
|
||||||
f.Model = new(meta.BaggedModel)
|
f.Model = new(meta.BaggedModel)
|
||||||
f.Model.RandomFeatures = f.Features
|
f.Model.RandomFeatures = f.Features
|
||||||
for i := 0; i < f.ForestSize; i++ {
|
for i := 0; i < f.ForestSize; i++ {
|
||||||
|
@ -1,17 +1,16 @@
|
|||||||
package ensemble
|
package ensemble
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"fmt"
|
"github.com/sjwhitworth/golearn/base"
|
||||||
base "github.com/sjwhitworth/golearn/base"
|
|
||||||
eval "github.com/sjwhitworth/golearn/evaluation"
|
eval "github.com/sjwhitworth/golearn/evaluation"
|
||||||
filters "github.com/sjwhitworth/golearn/filters"
|
"github.com/sjwhitworth/golearn/filters"
|
||||||
"testing"
|
"testing"
|
||||||
)
|
)
|
||||||
|
|
||||||
func TestRandomForest1(testEnv *testing.T) {
|
func TestRandomForest1(t *testing.T) {
|
||||||
inst, err := base.ParseCSVToInstances("../examples/datasets/iris_headers.csv", true)
|
inst, err := base.ParseCSVToInstances("../examples/datasets/iris_headers.csv", true)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
panic(err)
|
t.Fatal("Unable to parse CSV to instances: %s", err.Error())
|
||||||
}
|
}
|
||||||
|
|
||||||
filt := filters.NewChiMergeFilter(inst, 0.90)
|
filt := filters.NewChiMergeFilter(inst, 0.90)
|
||||||
@ -26,8 +25,6 @@ func TestRandomForest1(testEnv *testing.T) {
|
|||||||
rf := NewRandomForest(10, 3)
|
rf := NewRandomForest(10, 3)
|
||||||
rf.Fit(trainData)
|
rf.Fit(trainData)
|
||||||
predictions := rf.Predict(testData)
|
predictions := rf.Predict(testData)
|
||||||
fmt.Println(predictions)
|
|
||||||
confusionMat := eval.GetConfusionMatrix(testData, predictions)
|
confusionMat := eval.GetConfusionMatrix(testData, predictions)
|
||||||
fmt.Println(confusionMat)
|
_ = eval.GetSummary(confusionMat)
|
||||||
fmt.Println(eval.GetSummary(confusionMat))
|
|
||||||
}
|
}
|
||||||
|
@ -3,6 +3,8 @@ package evaluation
|
|||||||
import (
|
import (
|
||||||
"bytes"
|
"bytes"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"text/tabwriter"
|
||||||
|
|
||||||
"github.com/sjwhitworth/golearn/base"
|
"github.com/sjwhitworth/golearn/base"
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -176,18 +178,22 @@ func GetMacroRecall(c ConfusionMatrix) float64 {
|
|||||||
// ConfusionMatrix
|
// ConfusionMatrix
|
||||||
func GetSummary(c ConfusionMatrix) string {
|
func GetSummary(c ConfusionMatrix) string {
|
||||||
var buffer bytes.Buffer
|
var buffer bytes.Buffer
|
||||||
|
w := new(tabwriter.Writer)
|
||||||
|
w.Init(&buffer, 0, 8, 0, '\t', 0)
|
||||||
|
|
||||||
|
fmt.Fprintln(w, "Reference Class\tTrue Positives\tFalse Positives\tTrue Negatives\tPrecision\tRecall\tF1 Score")
|
||||||
|
fmt.Fprintln(w, "---------------\t--------------\t---------------\t--------------\t---------\t------\t--------")
|
||||||
for k := range c {
|
for k := range c {
|
||||||
buffer.WriteString(k)
|
|
||||||
buffer.WriteString("\t")
|
|
||||||
tp := GetTruePositives(k, c)
|
tp := GetTruePositives(k, c)
|
||||||
fp := GetFalsePositives(k, c)
|
fp := GetFalsePositives(k, c)
|
||||||
tn := GetTrueNegatives(k, c)
|
tn := GetTrueNegatives(k, c)
|
||||||
prec := GetPrecision(k, c)
|
prec := GetPrecision(k, c)
|
||||||
rec := GetRecall(k, c)
|
rec := GetRecall(k, c)
|
||||||
f1 := GetF1Score(k, c)
|
f1 := GetF1Score(k, c)
|
||||||
buffer.WriteString(fmt.Sprintf("%.0f\t%.0f\t%.0f\t%.4f\t%.4f\t%.4f\n", tp, fp, tn, prec, rec, f1))
|
|
||||||
}
|
|
||||||
|
|
||||||
|
fmt.Fprintf(w, "%s\t%.0f\t%.0f\t%.0f\t%.4f\t%.4f\t%.4f\n", k, tp, fp, tn, prec, rec, f1)
|
||||||
|
}
|
||||||
|
w.Flush()
|
||||||
buffer.WriteString(fmt.Sprintf("Overall accuracy: %.4f\n", GetAccuracy(c)))
|
buffer.WriteString(fmt.Sprintf("Overall accuracy: %.4f\n", GetAccuracy(c)))
|
||||||
|
|
||||||
return buffer.String()
|
return buffer.String()
|
||||||
|
@ -5,7 +5,7 @@ import (
|
|||||||
"testing"
|
"testing"
|
||||||
)
|
)
|
||||||
|
|
||||||
func TestMetrics(testEnv *testing.T) {
|
func TestMetrics(t *testing.T) {
|
||||||
confusionMat := make(ConfusionMatrix)
|
confusionMat := make(ConfusionMatrix)
|
||||||
confusionMat["a"] = make(map[string]int)
|
confusionMat["a"] = make(map[string]int)
|
||||||
confusionMat["b"] = make(map[string]int)
|
confusionMat["b"] = make(map[string]int)
|
||||||
@ -16,89 +16,89 @@ func TestMetrics(testEnv *testing.T) {
|
|||||||
|
|
||||||
tp := GetTruePositives("a", confusionMat)
|
tp := GetTruePositives("a", confusionMat)
|
||||||
if math.Abs(tp-75) >= 1 {
|
if math.Abs(tp-75) >= 1 {
|
||||||
testEnv.Error(tp)
|
t.Error(tp)
|
||||||
}
|
}
|
||||||
tp = GetTruePositives("b", confusionMat)
|
tp = GetTruePositives("b", confusionMat)
|
||||||
if math.Abs(tp-10) >= 1 {
|
if math.Abs(tp-10) >= 1 {
|
||||||
testEnv.Error(tp)
|
t.Error(tp)
|
||||||
}
|
}
|
||||||
|
|
||||||
fn := GetFalseNegatives("a", confusionMat)
|
fn := GetFalseNegatives("a", confusionMat)
|
||||||
if math.Abs(fn-5) >= 1 {
|
if math.Abs(fn-5) >= 1 {
|
||||||
testEnv.Error(fn)
|
t.Error(fn)
|
||||||
}
|
}
|
||||||
fn = GetFalseNegatives("b", confusionMat)
|
fn = GetFalseNegatives("b", confusionMat)
|
||||||
if math.Abs(fn-10) >= 1 {
|
if math.Abs(fn-10) >= 1 {
|
||||||
testEnv.Error(fn)
|
t.Error(fn)
|
||||||
}
|
}
|
||||||
|
|
||||||
tn := GetTrueNegatives("a", confusionMat)
|
tn := GetTrueNegatives("a", confusionMat)
|
||||||
if math.Abs(tn-10) >= 1 {
|
if math.Abs(tn-10) >= 1 {
|
||||||
testEnv.Error(tn)
|
t.Error(tn)
|
||||||
}
|
}
|
||||||
tn = GetTrueNegatives("b", confusionMat)
|
tn = GetTrueNegatives("b", confusionMat)
|
||||||
if math.Abs(tn-75) >= 1 {
|
if math.Abs(tn-75) >= 1 {
|
||||||
testEnv.Error(tn)
|
t.Error(tn)
|
||||||
}
|
}
|
||||||
|
|
||||||
fp := GetFalsePositives("a", confusionMat)
|
fp := GetFalsePositives("a", confusionMat)
|
||||||
if math.Abs(fp-10) >= 1 {
|
if math.Abs(fp-10) >= 1 {
|
||||||
testEnv.Error(fp)
|
t.Error(fp)
|
||||||
}
|
}
|
||||||
|
|
||||||
fp = GetFalsePositives("b", confusionMat)
|
fp = GetFalsePositives("b", confusionMat)
|
||||||
if math.Abs(fp-5) >= 1 {
|
if math.Abs(fp-5) >= 1 {
|
||||||
testEnv.Error(fp)
|
t.Error(fp)
|
||||||
}
|
}
|
||||||
|
|
||||||
precision := GetPrecision("a", confusionMat)
|
precision := GetPrecision("a", confusionMat)
|
||||||
recall := GetRecall("a", confusionMat)
|
recall := GetRecall("a", confusionMat)
|
||||||
|
|
||||||
if math.Abs(precision-0.88) >= 0.01 {
|
if math.Abs(precision-0.88) >= 0.01 {
|
||||||
testEnv.Error(precision)
|
t.Error(precision)
|
||||||
}
|
}
|
||||||
|
|
||||||
if math.Abs(recall-0.94) >= 0.01 {
|
if math.Abs(recall-0.94) >= 0.01 {
|
||||||
testEnv.Error(recall)
|
t.Error(recall)
|
||||||
}
|
}
|
||||||
|
|
||||||
precision = GetPrecision("b", confusionMat)
|
precision = GetPrecision("b", confusionMat)
|
||||||
recall = GetRecall("b", confusionMat)
|
recall = GetRecall("b", confusionMat)
|
||||||
if math.Abs(precision-0.666) >= 0.01 {
|
if math.Abs(precision-0.666) >= 0.01 {
|
||||||
testEnv.Error(precision)
|
t.Error(precision)
|
||||||
}
|
}
|
||||||
|
|
||||||
if math.Abs(recall-0.50) >= 0.01 {
|
if math.Abs(recall-0.50) >= 0.01 {
|
||||||
testEnv.Error(recall)
|
t.Error(recall)
|
||||||
}
|
}
|
||||||
|
|
||||||
precision = GetMicroPrecision(confusionMat)
|
precision = GetMicroPrecision(confusionMat)
|
||||||
if math.Abs(precision-0.85) >= 0.01 {
|
if math.Abs(precision-0.85) >= 0.01 {
|
||||||
testEnv.Error(precision)
|
t.Error(precision)
|
||||||
}
|
}
|
||||||
|
|
||||||
recall = GetMicroRecall(confusionMat)
|
recall = GetMicroRecall(confusionMat)
|
||||||
if math.Abs(recall-0.85) >= 0.01 {
|
if math.Abs(recall-0.85) >= 0.01 {
|
||||||
testEnv.Error(recall)
|
t.Error(recall)
|
||||||
}
|
}
|
||||||
|
|
||||||
precision = GetMacroPrecision(confusionMat)
|
precision = GetMacroPrecision(confusionMat)
|
||||||
if math.Abs(precision-0.775) >= 0.01 {
|
if math.Abs(precision-0.775) >= 0.01 {
|
||||||
testEnv.Error(precision)
|
t.Error(precision)
|
||||||
}
|
}
|
||||||
|
|
||||||
recall = GetMacroRecall(confusionMat)
|
recall = GetMacroRecall(confusionMat)
|
||||||
if math.Abs(recall-0.719) > 0.01 {
|
if math.Abs(recall-0.719) > 0.01 {
|
||||||
testEnv.Error(recall)
|
t.Error(recall)
|
||||||
}
|
}
|
||||||
|
|
||||||
fmeasure := GetF1Score("a", confusionMat)
|
fmeasure := GetF1Score("a", confusionMat)
|
||||||
if math.Abs(fmeasure-0.91) >= 0.1 {
|
if math.Abs(fmeasure-0.91) >= 0.1 {
|
||||||
testEnv.Error(fmeasure)
|
t.Error(fmeasure)
|
||||||
}
|
}
|
||||||
|
|
||||||
accuracy := GetAccuracy(confusionMat)
|
accuracy := GetAccuracy(confusionMat)
|
||||||
if math.Abs(accuracy-0.85) >= 0.1 {
|
if math.Abs(accuracy-0.85) >= 0.1 {
|
||||||
testEnv.Error(accuracy)
|
t.Error(accuracy)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -4,7 +4,7 @@ package main
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"fmt"
|
"fmt"
|
||||||
base "github.com/sjwhitworth/golearn/base"
|
"github.com/sjwhitworth/golearn/base"
|
||||||
)
|
)
|
||||||
|
|
||||||
func main() {
|
func main() {
|
||||||
|
@ -2,9 +2,9 @@ package main
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"fmt"
|
"fmt"
|
||||||
base "github.com/sjwhitworth/golearn/base"
|
"github.com/sjwhitworth/golearn/base"
|
||||||
evaluation "github.com/sjwhitworth/golearn/evaluation"
|
"github.com/sjwhitworth/golearn/evaluation"
|
||||||
knn "github.com/sjwhitworth/golearn/knn"
|
"github.com/sjwhitworth/golearn/knn"
|
||||||
)
|
)
|
||||||
|
|
||||||
func main() {
|
func main() {
|
||||||
|
@ -4,11 +4,11 @@ package main
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"fmt"
|
"fmt"
|
||||||
base "github.com/sjwhitworth/golearn/base"
|
"github.com/sjwhitworth/golearn/base"
|
||||||
ensemble "github.com/sjwhitworth/golearn/ensemble"
|
"github.com/sjwhitworth/golearn/ensemble"
|
||||||
eval "github.com/sjwhitworth/golearn/evaluation"
|
eval "github.com/sjwhitworth/golearn/evaluation"
|
||||||
filters "github.com/sjwhitworth/golearn/filters"
|
"github.com/sjwhitworth/golearn/filters"
|
||||||
trees "github.com/sjwhitworth/golearn/trees"
|
"github.com/sjwhitworth/golearn/trees"
|
||||||
"math/rand"
|
"math/rand"
|
||||||
"time"
|
"time"
|
||||||
)
|
)
|
||||||
|
@ -45,7 +45,7 @@ func main() {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
os.Mkdir("lib", os.ModeDir | 0777)
|
os.Mkdir("lib", os.ModeDir|0777)
|
||||||
|
|
||||||
log.Println("Installing libs")
|
log.Println("Installing libs")
|
||||||
if runtime.GOOS == "windows" {
|
if runtime.GOOS == "windows" {
|
||||||
|
@ -1,7 +1,6 @@
|
|||||||
package filters
|
package filters
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"fmt"
|
|
||||||
"github.com/sjwhitworth/golearn/base"
|
"github.com/sjwhitworth/golearn/base"
|
||||||
. "github.com/smartystreets/goconvey/convey"
|
. "github.com/smartystreets/goconvey/convey"
|
||||||
"testing"
|
"testing"
|
||||||
@ -38,9 +37,6 @@ func TestBinaryFilterClassPreservation(t *testing.T) {
|
|||||||
So(attrMap["arbitraryClass_there"], ShouldEqual, true)
|
So(attrMap["arbitraryClass_there"], ShouldEqual, true)
|
||||||
So(attrMap["arbitraryClass_world"], ShouldEqual, true)
|
So(attrMap["arbitraryClass_world"], ShouldEqual, true)
|
||||||
})
|
})
|
||||||
|
|
||||||
fmt.Println(instF)
|
|
||||||
|
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -91,7 +87,7 @@ func TestBinaryFilter(t *testing.T) {
|
|||||||
name := a.GetName()
|
name := a.GetName()
|
||||||
_, ok := origMap[name]
|
_, ok := origMap[name]
|
||||||
if !ok {
|
if !ok {
|
||||||
t.Error(fmt.Sprintf("Weird: %s", name))
|
t.Errorf("Weird: %s", name)
|
||||||
}
|
}
|
||||||
origMap[name] = true
|
origMap[name] = true
|
||||||
}
|
}
|
||||||
|
@ -2,7 +2,7 @@ package filters
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"fmt"
|
"fmt"
|
||||||
base "github.com/sjwhitworth/golearn/base"
|
"github.com/sjwhitworth/golearn/base"
|
||||||
"math"
|
"math"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -1,7 +1,7 @@
|
|||||||
package filters
|
package filters
|
||||||
|
|
||||||
import (
|
import (
|
||||||
base "github.com/sjwhitworth/golearn/base"
|
"github.com/sjwhitworth/golearn/base"
|
||||||
. "github.com/smartystreets/goconvey/convey"
|
. "github.com/smartystreets/goconvey/convey"
|
||||||
"testing"
|
"testing"
|
||||||
)
|
)
|
||||||
@ -11,12 +11,12 @@ func TestBinning(t *testing.T) {
|
|||||||
// Read the data
|
// Read the data
|
||||||
inst1, err := base.ParseCSVToInstances("../examples/datasets/iris_headers.csv", true)
|
inst1, err := base.ParseCSVToInstances("../examples/datasets/iris_headers.csv", true)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
panic(err)
|
t.Fatal("Unable to parse CSV to instances: %s", err.Error())
|
||||||
}
|
}
|
||||||
|
|
||||||
inst2, err := base.ParseCSVToInstances("../examples/datasets/iris_binned.csv", true)
|
inst2, err := base.ParseCSVToInstances("../examples/datasets/iris_binned.csv", true)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
panic(err)
|
t.Fatal("Unable to parse CSV to instances: %s", err.Error())
|
||||||
}
|
}
|
||||||
//
|
//
|
||||||
// Construct the binning filter
|
// Construct the binning filter
|
||||||
|
@ -2,7 +2,7 @@ package filters
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"fmt"
|
"fmt"
|
||||||
base "github.com/sjwhitworth/golearn/base"
|
"github.com/sjwhitworth/golearn/base"
|
||||||
"math"
|
"math"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -11,4 +11,4 @@ type FrequencyTableEntry struct {
|
|||||||
|
|
||||||
func (t *FrequencyTableEntry) String() string {
|
func (t *FrequencyTableEntry) String() string {
|
||||||
return fmt.Sprintf("%.2f %s", t.Value, t.Frequency)
|
return fmt.Sprintf("%.2f %s", t.Value, t.Frequency)
|
||||||
}
|
}
|
||||||
|
@ -2,7 +2,6 @@ package filters
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"github.com/sjwhitworth/golearn/base"
|
"github.com/sjwhitworth/golearn/base"
|
||||||
"fmt"
|
|
||||||
"math"
|
"math"
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -185,21 +184,3 @@ func chiMergeMergeZipAdjacent(freq []*FrequencyTableEntry, minIndex int) []*Freq
|
|||||||
freq = append(lowerSlice, upperSlice...)
|
freq = append(lowerSlice, upperSlice...)
|
||||||
return freq
|
return freq
|
||||||
}
|
}
|
||||||
|
|
||||||
func chiMergePrintTable(freq []*FrequencyTableEntry) {
|
|
||||||
classes := chiCountClasses(freq)
|
|
||||||
fmt.Printf("Attribute value\t")
|
|
||||||
for k := range classes {
|
|
||||||
fmt.Printf("\t%s", k)
|
|
||||||
}
|
|
||||||
fmt.Printf("\tTotal\n")
|
|
||||||
for _, f := range freq {
|
|
||||||
fmt.Printf("%.2f\t", f.Value)
|
|
||||||
total := 0
|
|
||||||
for k := range classes {
|
|
||||||
fmt.Printf("\t%d", f.Frequency[k])
|
|
||||||
total += f.Frequency[k]
|
|
||||||
}
|
|
||||||
fmt.Printf("\t%d\n", total)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
@ -1,113 +1,109 @@
|
|||||||
package filters
|
package filters
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"fmt"
|
"github.com/sjwhitworth/golearn/base"
|
||||||
base "github.com/sjwhitworth/golearn/base"
|
|
||||||
"math"
|
"math"
|
||||||
"testing"
|
"testing"
|
||||||
)
|
)
|
||||||
|
|
||||||
func TestChiMFreqTable(testEnv *testing.T) {
|
func TestChiMFreqTable(t *testing.T) {
|
||||||
|
|
||||||
inst, err := base.ParseCSVToInstances("../examples/datasets/chim.csv", true)
|
inst, err := base.ParseCSVToInstances("../examples/datasets/chim.csv", true)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
panic(err)
|
t.Fatal("Unable to parse CSV to instances: %s", err.Error())
|
||||||
}
|
}
|
||||||
|
|
||||||
freq := ChiMBuildFrequencyTable(inst.AllAttributes()[0], inst)
|
freq := ChiMBuildFrequencyTable(inst.AllAttributes()[0], inst)
|
||||||
|
|
||||||
if freq[0].Frequency["c1"] != 1 {
|
if freq[0].Frequency["c1"] != 1 {
|
||||||
testEnv.Error("Wrong frequency")
|
t.Error("Wrong frequency")
|
||||||
}
|
}
|
||||||
if freq[0].Frequency["c3"] != 4 {
|
if freq[0].Frequency["c3"] != 4 {
|
||||||
testEnv.Errorf("Wrong frequency %s", freq[1])
|
t.Errorf("Wrong frequency %s", freq[1])
|
||||||
}
|
}
|
||||||
if freq[10].Frequency["c2"] != 1 {
|
if freq[10].Frequency["c2"] != 1 {
|
||||||
testEnv.Error("Wrong frequency")
|
t.Error("Wrong frequency")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestChiClassCounter(testEnv *testing.T) {
|
func TestChiClassCounter(t *testing.T) {
|
||||||
inst, err := base.ParseCSVToInstances("../examples/datasets/chim.csv", true)
|
inst, err := base.ParseCSVToInstances("../examples/datasets/chim.csv", true)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
panic(err)
|
t.Fatal("Unable to parse CSV to instances: %s", err.Error())
|
||||||
}
|
}
|
||||||
freq := ChiMBuildFrequencyTable(inst.AllAttributes()[0], inst)
|
freq := ChiMBuildFrequencyTable(inst.AllAttributes()[0], inst)
|
||||||
classes := chiCountClasses(freq)
|
classes := chiCountClasses(freq)
|
||||||
if classes["c1"] != 27 {
|
if classes["c1"] != 27 {
|
||||||
testEnv.Error(classes)
|
t.Error(classes)
|
||||||
}
|
}
|
||||||
if classes["c2"] != 12 {
|
if classes["c2"] != 12 {
|
||||||
testEnv.Error(classes)
|
t.Error(classes)
|
||||||
}
|
}
|
||||||
if classes["c3"] != 21 {
|
if classes["c3"] != 21 {
|
||||||
testEnv.Error(classes)
|
t.Error(classes)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestStatisticValues(testEnv *testing.T) {
|
func TestStatisticValues(t *testing.T) {
|
||||||
inst, err := base.ParseCSVToInstances("../examples/datasets/chim.csv", true)
|
inst, err := base.ParseCSVToInstances("../examples/datasets/chim.csv", true)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
panic(err)
|
t.Fatal("Unable to parse CSV to instances: %s", err.Error())
|
||||||
}
|
}
|
||||||
freq := ChiMBuildFrequencyTable(inst.AllAttributes()[0], inst)
|
freq := ChiMBuildFrequencyTable(inst.AllAttributes()[0], inst)
|
||||||
chiVal := chiComputeStatistic(freq[5], freq[6])
|
chiVal := chiComputeStatistic(freq[5], freq[6])
|
||||||
if math.Abs(chiVal-1.89) > 0.01 {
|
if math.Abs(chiVal-1.89) > 0.01 {
|
||||||
testEnv.Error(chiVal)
|
t.Error(chiVal)
|
||||||
}
|
}
|
||||||
|
|
||||||
chiVal = chiComputeStatistic(freq[1], freq[2])
|
chiVal = chiComputeStatistic(freq[1], freq[2])
|
||||||
if math.Abs(chiVal-1.08) > 0.01 {
|
if math.Abs(chiVal-1.08) > 0.01 {
|
||||||
testEnv.Error(chiVal)
|
t.Error(chiVal)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestChiSquareDistValues(testEnv *testing.T) {
|
func TestChiSquareDistValues(t *testing.T) {
|
||||||
chiVal1 := chiSquaredPercentile(2, 4.61)
|
chiVal1 := chiSquaredPercentile(2, 4.61)
|
||||||
chiVal2 := chiSquaredPercentile(3, 7.82)
|
chiVal2 := chiSquaredPercentile(3, 7.82)
|
||||||
chiVal3 := chiSquaredPercentile(4, 13.28)
|
chiVal3 := chiSquaredPercentile(4, 13.28)
|
||||||
if math.Abs(chiVal1-0.90) > 0.001 {
|
if math.Abs(chiVal1-0.90) > 0.001 {
|
||||||
testEnv.Error(chiVal1)
|
t.Error(chiVal1)
|
||||||
}
|
}
|
||||||
if math.Abs(chiVal2-0.95) > 0.001 {
|
if math.Abs(chiVal2-0.95) > 0.001 {
|
||||||
testEnv.Error(chiVal2)
|
t.Error(chiVal2)
|
||||||
}
|
}
|
||||||
if math.Abs(chiVal3-0.99) > 0.001 {
|
if math.Abs(chiVal3-0.99) > 0.001 {
|
||||||
testEnv.Error(chiVal3)
|
t.Error(chiVal3)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestChiMerge1(testEnv *testing.T) {
|
func TestChiMerge1(t *testing.T) {
|
||||||
|
|
||||||
// Read the data
|
|
||||||
inst, err := base.ParseCSVToInstances("../examples/datasets/chim.csv", true)
|
inst, err := base.ParseCSVToInstances("../examples/datasets/chim.csv", true)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
panic(err)
|
t.Fatal("Unable to parse CSV to instances: %s", err.Error())
|
||||||
}
|
}
|
||||||
_, rows := inst.Size()
|
_, rows := inst.Size()
|
||||||
|
|
||||||
freq := chiMerge(inst, inst.AllAttributes()[0], 0.90, 0, rows)
|
freq := chiMerge(inst, inst.AllAttributes()[0], 0.90, 0, rows)
|
||||||
if len(freq) != 3 {
|
if len(freq) != 3 {
|
||||||
testEnv.Error("Wrong length")
|
t.Error("Wrong length")
|
||||||
}
|
}
|
||||||
if freq[0].Value != 1.3 {
|
if freq[0].Value != 1.3 {
|
||||||
testEnv.Error(freq[0])
|
t.Error(freq[0])
|
||||||
}
|
}
|
||||||
if freq[1].Value != 56.2 {
|
if freq[1].Value != 56.2 {
|
||||||
testEnv.Error(freq[1])
|
t.Error(freq[1])
|
||||||
}
|
}
|
||||||
if freq[2].Value != 87.1 {
|
if freq[2].Value != 87.1 {
|
||||||
testEnv.Error(freq[2])
|
t.Error(freq[2])
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestChiMerge2(testEnv *testing.T) {
|
func TestChiMerge2(t *testing.T) {
|
||||||
//
|
//
|
||||||
// See http://sci2s.ugr.es/keel/pdf/algorithm/congreso/1992-Kerber-ChimErge-AAAI92.pdf
|
// See http://sci2s.ugr.es/keel/pdf/algorithm/congreso/1992-Kerber-ChimErge-AAAI92.pdf
|
||||||
// Randy Kerber, ChiMerge: Discretisation of Numeric Attributes, 1992
|
// Randy Kerber, ChiMerge: Discretisation of Numeric Attributes, 1992
|
||||||
inst, err := base.ParseCSVToInstances("../examples/datasets/iris_headers.csv", true)
|
inst, err := base.ParseCSVToInstances("../examples/datasets/iris_headers.csv", true)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
panic(err)
|
t.Fatal("Unable to parse CSV to instances: %s", err.Error())
|
||||||
}
|
}
|
||||||
|
|
||||||
// Sort the instances
|
// Sort the instances
|
||||||
@ -115,35 +111,35 @@ func TestChiMerge2(testEnv *testing.T) {
|
|||||||
sortAttrSpecs := base.ResolveAttributes(inst, allAttrs)[0:1]
|
sortAttrSpecs := base.ResolveAttributes(inst, allAttrs)[0:1]
|
||||||
instSorted, err := base.Sort(inst, base.Ascending, sortAttrSpecs)
|
instSorted, err := base.Sort(inst, base.Ascending, sortAttrSpecs)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
panic(err)
|
t.Fatalf("Sort failed: %s", err.Error())
|
||||||
}
|
}
|
||||||
|
|
||||||
// Perform Chi-Merge
|
// Perform Chi-Merge
|
||||||
_, rows := inst.Size()
|
_, rows := inst.Size()
|
||||||
freq := chiMerge(instSorted, allAttrs[0], 0.90, 0, rows)
|
freq := chiMerge(instSorted, allAttrs[0], 0.90, 0, rows)
|
||||||
if len(freq) != 5 {
|
if len(freq) != 5 {
|
||||||
testEnv.Errorf("Wrong length (%d)", len(freq))
|
t.Errorf("Wrong length (%d)", len(freq))
|
||||||
testEnv.Error(freq)
|
t.Error(freq)
|
||||||
}
|
}
|
||||||
if freq[0].Value != 4.3 {
|
if freq[0].Value != 4.3 {
|
||||||
testEnv.Error(freq[0])
|
t.Error(freq[0])
|
||||||
}
|
}
|
||||||
if freq[1].Value != 5.5 {
|
if freq[1].Value != 5.5 {
|
||||||
testEnv.Error(freq[1])
|
t.Error(freq[1])
|
||||||
}
|
}
|
||||||
if freq[2].Value != 5.8 {
|
if freq[2].Value != 5.8 {
|
||||||
testEnv.Error(freq[2])
|
t.Error(freq[2])
|
||||||
}
|
}
|
||||||
if freq[3].Value != 6.3 {
|
if freq[3].Value != 6.3 {
|
||||||
testEnv.Error(freq[3])
|
t.Error(freq[3])
|
||||||
}
|
}
|
||||||
if freq[4].Value != 7.1 {
|
if freq[4].Value != 7.1 {
|
||||||
testEnv.Error(freq[4])
|
t.Error(freq[4])
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/*
|
/*
|
||||||
func TestChiMerge3(testEnv *testing.T) {
|
func TestChiMerge3(t *testing.T) {
|
||||||
// See http://sci2s.ugr.es/keel/pdf/algorithm/congreso/1992-Kerber-ChimErge-AAAI92.pdf
|
// See http://sci2s.ugr.es/keel/pdf/algorithm/congreso/1992-Kerber-ChimErge-AAAI92.pdf
|
||||||
// Randy Kerber, ChiMerge: Discretisation of Numeric Attributes, 1992
|
// Randy Kerber, ChiMerge: Discretisation of Numeric Attributes, 1992
|
||||||
inst, err := base.ParseCSVToInstances("../examples/datasets/iris_headers.csv", true)
|
inst, err := base.ParseCSVToInstances("../examples/datasets/iris_headers.csv", true)
|
||||||
@ -153,7 +149,7 @@ func TestChiMerge3(testEnv *testing.T) {
|
|||||||
|
|
||||||
insts, err := base.LazySort(inst, base.Ascending, base.ResolveAllAttributes(inst, inst.AllAttributes()))
|
insts, err := base.LazySort(inst, base.Ascending, base.ResolveAllAttributes(inst, inst.AllAttributes()))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
testEnv.Error(err)
|
t.Error(err)
|
||||||
}
|
}
|
||||||
filt := NewChiMergeFilter(inst, 0.90)
|
filt := NewChiMergeFilter(inst, 0.90)
|
||||||
filt.AddAttribute(inst.AllAttributes()[0])
|
filt.AddAttribute(inst.AllAttributes()[0])
|
||||||
@ -176,12 +172,12 @@ func TestChiMerge3(testEnv *testing.T) {
|
|||||||
}
|
}
|
||||||
*/
|
*/
|
||||||
|
|
||||||
func TestChiMerge4(testEnv *testing.T) {
|
func TestChiMerge4(t *testing.T) {
|
||||||
// See http://sci2s.ugr.es/keel/pdf/algorithm/congreso/1992-Kerber-ChimErge-AAAI92.pdf
|
// See http://sci2s.ugr.es/keel/pdf/algorithm/congreso/1992-Kerber-ChimErge-AAAI92.pdf
|
||||||
// Randy Kerber, ChiMerge: Discretisation of Numeric Attributes, 1992
|
// Randy Kerber, ChiMerge: Discretisation of Numeric Attributes, 1992
|
||||||
inst, err := base.ParseCSVToInstances("../examples/datasets/iris_headers.csv", true)
|
inst, err := base.ParseCSVToInstances("../examples/datasets/iris_headers.csv", true)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
panic(err)
|
t.Fatal("Unable to parse CSV to instances: %s", err.Error())
|
||||||
}
|
}
|
||||||
|
|
||||||
filt := NewChiMergeFilter(inst, 0.90)
|
filt := NewChiMergeFilter(inst, 0.90)
|
||||||
@ -189,13 +185,13 @@ func TestChiMerge4(testEnv *testing.T) {
|
|||||||
filt.AddAttribute(inst.AllAttributes()[1])
|
filt.AddAttribute(inst.AllAttributes()[1])
|
||||||
filt.Train()
|
filt.Train()
|
||||||
instf := base.NewLazilyFilteredInstances(inst, filt)
|
instf := base.NewLazilyFilteredInstances(inst, filt)
|
||||||
fmt.Println(instf)
|
|
||||||
fmt.Println(instf.String())
|
|
||||||
clsAttrs := instf.AllClassAttributes()
|
clsAttrs := instf.AllClassAttributes()
|
||||||
if len(clsAttrs) != 1 {
|
if len(clsAttrs) != 1 {
|
||||||
panic(fmt.Sprintf("%d != %d", len(clsAttrs), 1))
|
t.Fatalf("%d != %d", len(clsAttrs), 1)
|
||||||
}
|
}
|
||||||
if clsAttrs[0].GetName() != "Species" {
|
firstClassAttributeName := clsAttrs[0].GetName()
|
||||||
panic("Class Attribute wrong!")
|
expectedClassAttributeName := "Species"
|
||||||
|
if firstClassAttributeName != expectedClassAttributeName {
|
||||||
|
t.Fatalf("Expected class attribute '%s'; actual class attribute '%s'", expectedClassAttributeName, firstClassAttributeName)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -2,7 +2,7 @@ package filters
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"fmt"
|
"fmt"
|
||||||
base "github.com/sjwhitworth/golearn/base"
|
"github.com/sjwhitworth/golearn/base"
|
||||||
)
|
)
|
||||||
|
|
||||||
type AbstractDiscretizeFilter struct {
|
type AbstractDiscretizeFilter struct {
|
||||||
|
182
filters/float.go
182
filters/float.go
@ -1,8 +1,8 @@
|
|||||||
package filters
|
package filters
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"fmt"
|
"fmt"
|
||||||
"github.com/sjwhitworth/golearn/base"
|
"github.com/sjwhitworth/golearn/base"
|
||||||
)
|
)
|
||||||
|
|
||||||
// FloatConvertFilters convert a given DataGrid into one which
|
// FloatConvertFilters convert a given DataGrid into one which
|
||||||
@ -14,84 +14,84 @@ import (
|
|||||||
// CategoricalAttributes are discretised into one or more new
|
// CategoricalAttributes are discretised into one or more new
|
||||||
// BinaryAttributes.
|
// BinaryAttributes.
|
||||||
type FloatConvertFilter struct {
|
type FloatConvertFilter struct {
|
||||||
attrs []base.Attribute
|
attrs []base.Attribute
|
||||||
converted []base.FilteredAttribute
|
converted []base.FilteredAttribute
|
||||||
twoValuedCategoricalAttributes map[base.Attribute]bool // Two-valued categorical Attributes
|
twoValuedCategoricalAttributes map[base.Attribute]bool // Two-valued categorical Attributes
|
||||||
nValuedCategoricalAttributeMap map[base.Attribute]map[uint64]base.Attribute
|
nValuedCategoricalAttributeMap map[base.Attribute]map[uint64]base.Attribute
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewFloatConvertFilter creates a blank FloatConvertFilter
|
// NewFloatConvertFilter creates a blank FloatConvertFilter
|
||||||
func NewFloatConvertFilter() *FloatConvertFilter {
|
func NewFloatConvertFilter() *FloatConvertFilter {
|
||||||
ret := &FloatConvertFilter{
|
ret := &FloatConvertFilter{
|
||||||
make([]base.Attribute, 0),
|
make([]base.Attribute, 0),
|
||||||
make([]base.FilteredAttribute, 0),
|
make([]base.FilteredAttribute, 0),
|
||||||
make(map[base.Attribute]bool),
|
make(map[base.Attribute]bool),
|
||||||
make(map[base.Attribute]map[uint64]base.Attribute),
|
make(map[base.Attribute]map[uint64]base.Attribute),
|
||||||
}
|
}
|
||||||
return ret
|
return ret
|
||||||
}
|
}
|
||||||
|
|
||||||
// AddAttribute adds a new Attribute to this Filter
|
// AddAttribute adds a new Attribute to this Filter
|
||||||
func (f *FloatConvertFilter) AddAttribute(a base.Attribute) error {
|
func (f *FloatConvertFilter) AddAttribute(a base.Attribute) error {
|
||||||
f.attrs = append(f.attrs, a)
|
f.attrs = append(f.attrs, a)
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetAttributesAfterFiltering returns the Attributes previously computed via Train()
|
// GetAttributesAfterFiltering returns the Attributes previously computed via Train()
|
||||||
func (f *FloatConvertFilter) GetAttributesAfterFiltering() []base.FilteredAttribute {
|
func (f *FloatConvertFilter) GetAttributesAfterFiltering() []base.FilteredAttribute {
|
||||||
return f.converted
|
return f.converted
|
||||||
}
|
}
|
||||||
|
|
||||||
// String gets a human-readable string
|
// String gets a human-readable string
|
||||||
func (f *FloatConvertFilter) String() string {
|
func (f *FloatConvertFilter) String() string {
|
||||||
return fmt.Sprintf("FloatConvertFilter(%d Attribute(s))", len(f.attrs))
|
return fmt.Sprintf("FloatConvertFilter(%d Attribute(s))", len(f.attrs))
|
||||||
}
|
}
|
||||||
|
|
||||||
// Transform converts the given byte sequence using the old Attribute into the new
|
// Transform converts the given byte sequence using the old Attribute into the new
|
||||||
// byte sequence.
|
// byte sequence.
|
||||||
|
|
||||||
func (f *FloatConvertFilter) Transform(a base.Attribute, n base.Attribute, attrBytes []byte) []byte {
|
func (f *FloatConvertFilter) Transform(a base.Attribute, n base.Attribute, attrBytes []byte) []byte {
|
||||||
ret := make([]byte, 8)
|
ret := make([]byte, 8)
|
||||||
// Check for CategoricalAttribute
|
// Check for CategoricalAttribute
|
||||||
if _, ok := a.(*base.CategoricalAttribute); ok {
|
if _, ok := a.(*base.CategoricalAttribute); ok {
|
||||||
// Unpack byte value
|
// Unpack byte value
|
||||||
val := base.UnpackBytesToU64(attrBytes)
|
val := base.UnpackBytesToU64(attrBytes)
|
||||||
// If it's a two-valued one, check for non-zero
|
// If it's a two-valued one, check for non-zero
|
||||||
if f.twoValuedCategoricalAttributes[a] {
|
if f.twoValuedCategoricalAttributes[a] {
|
||||||
if val > 0 {
|
if val > 0 {
|
||||||
ret = base.PackFloatToBytes(1.0)
|
ret = base.PackFloatToBytes(1.0)
|
||||||
} else {
|
} else {
|
||||||
ret = base.PackFloatToBytes(0.0)
|
ret = base.PackFloatToBytes(0.0)
|
||||||
}
|
}
|
||||||
} else if an, ok := f.nValuedCategoricalAttributeMap[a]; ok {
|
} else if an, ok := f.nValuedCategoricalAttributeMap[a]; ok {
|
||||||
// If it's an n-valued one, check the new Attribute maps onto
|
// If it's an n-valued one, check the new Attribute maps onto
|
||||||
// the unpacked value
|
// the unpacked value
|
||||||
if af, ok := an[val]; ok {
|
if af, ok := an[val]; ok {
|
||||||
if af.Equals(n) {
|
if af.Equals(n) {
|
||||||
ret = base.PackFloatToBytes(1.0)
|
ret = base.PackFloatToBytes(1.0)
|
||||||
} else {
|
} else {
|
||||||
ret = base.PackFloatToBytes(0.0)
|
ret = base.PackFloatToBytes(0.0)
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
panic("Categorical value not defined!")
|
panic("Categorical value not defined!")
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
panic(fmt.Sprintf("Not a recognised Attribute %v", a))
|
panic(fmt.Sprintf("Not a recognised Attribute %v", a))
|
||||||
}
|
}
|
||||||
} else if _, ok := a.(*base.FloatAttribute); ok {
|
} else if _, ok := a.(*base.FloatAttribute); ok {
|
||||||
// Binary: just return the original value
|
// Binary: just return the original value
|
||||||
ret = attrBytes
|
ret = attrBytes
|
||||||
} else if _, ok := a.(*base.BinaryAttribute); ok {
|
} else if _, ok := a.(*base.BinaryAttribute); ok {
|
||||||
// Float: check for non-zero
|
// Float: check for non-zero
|
||||||
if attrBytes[0] > 0 {
|
if attrBytes[0] > 0 {
|
||||||
ret = base.PackFloatToBytes(1.0)
|
ret = base.PackFloatToBytes(1.0)
|
||||||
} else {
|
} else {
|
||||||
ret = base.PackFloatToBytes(0.0)
|
ret = base.PackFloatToBytes(0.0)
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
panic(fmt.Sprintf("Unrecognised Attribute: %v", a))
|
panic(fmt.Sprintf("Unrecognised Attribute: %v", a))
|
||||||
}
|
}
|
||||||
return ret
|
return ret
|
||||||
}
|
}
|
||||||
|
|
||||||
// Train converts the Attributes into equivalently named FloatAttributes,
|
// Train converts the Attributes into equivalently named FloatAttributes,
|
||||||
@ -105,37 +105,37 @@ func (f *FloatConvertFilter) Transform(a base.Attribute, n base.Attribute, attrB
|
|||||||
// If the CategoricalAttribute has more than two (n) values, the Filter
|
// If the CategoricalAttribute has more than two (n) values, the Filter
|
||||||
// generates n FloatAttributes and sets each of them if the value's observed.
|
// generates n FloatAttributes and sets each of them if the value's observed.
|
||||||
func (f *FloatConvertFilter) Train() error {
|
func (f *FloatConvertFilter) Train() error {
|
||||||
for _, a := range f.attrs {
|
for _, a := range f.attrs {
|
||||||
if ac, ok := a.(*base.CategoricalAttribute); ok {
|
if ac, ok := a.(*base.CategoricalAttribute); ok {
|
||||||
vals := ac.GetValues()
|
vals := ac.GetValues()
|
||||||
if len(vals) <= 2 {
|
if len(vals) <= 2 {
|
||||||
nAttr := base.NewFloatAttribute(ac.GetName())
|
nAttr := base.NewFloatAttribute(ac.GetName())
|
||||||
fAttr := base.FilteredAttribute{ac, nAttr}
|
fAttr := base.FilteredAttribute{ac, nAttr}
|
||||||
f.converted = append(f.converted, fAttr)
|
f.converted = append(f.converted, fAttr)
|
||||||
f.twoValuedCategoricalAttributes[a] = true
|
f.twoValuedCategoricalAttributes[a] = true
|
||||||
} else {
|
} else {
|
||||||
if _, ok := f.nValuedCategoricalAttributeMap[a]; !ok {
|
if _, ok := f.nValuedCategoricalAttributeMap[a]; !ok {
|
||||||
f.nValuedCategoricalAttributeMap[a] = make(map[uint64]base.Attribute)
|
f.nValuedCategoricalAttributeMap[a] = make(map[uint64]base.Attribute)
|
||||||
}
|
}
|
||||||
for i := uint64(0); i < uint64(len(vals)); i++ {
|
for i := uint64(0); i < uint64(len(vals)); i++ {
|
||||||
v := vals[i]
|
v := vals[i]
|
||||||
newName := fmt.Sprintf("%s_%s", ac.GetName(), v)
|
newName := fmt.Sprintf("%s_%s", ac.GetName(), v)
|
||||||
newAttr := base.NewFloatAttribute(newName)
|
newAttr := base.NewFloatAttribute(newName)
|
||||||
fAttr := base.FilteredAttribute{ac, newAttr}
|
fAttr := base.FilteredAttribute{ac, newAttr}
|
||||||
f.converted = append(f.converted, fAttr)
|
f.converted = append(f.converted, fAttr)
|
||||||
f.nValuedCategoricalAttributeMap[a][i] = newAttr
|
f.nValuedCategoricalAttributeMap[a][i] = newAttr
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
} else if ab, ok := a.(*base.FloatAttribute); ok {
|
} else if ab, ok := a.(*base.FloatAttribute); ok {
|
||||||
fAttr := base.FilteredAttribute{ab, ab}
|
fAttr := base.FilteredAttribute{ab, ab}
|
||||||
f.converted = append(f.converted, fAttr)
|
f.converted = append(f.converted, fAttr)
|
||||||
} else if af, ok := a.(*base.BinaryAttribute); ok {
|
} else if af, ok := a.(*base.BinaryAttribute); ok {
|
||||||
newAttr := base.NewFloatAttribute(af.GetName())
|
newAttr := base.NewFloatAttribute(af.GetName())
|
||||||
fAttr := base.FilteredAttribute{af, newAttr}
|
fAttr := base.FilteredAttribute{af, newAttr}
|
||||||
f.converted = append(f.converted, fAttr)
|
f.converted = append(f.converted, fAttr)
|
||||||
} else {
|
} else {
|
||||||
return fmt.Errorf("Unsupported Attribute type: %v", a)
|
return fmt.Errorf("Unsupported Attribute type: %v", a)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
@ -1,7 +1,6 @@
|
|||||||
package filters
|
package filters
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"fmt"
|
|
||||||
"github.com/sjwhitworth/golearn/base"
|
"github.com/sjwhitworth/golearn/base"
|
||||||
. "github.com/smartystreets/goconvey/convey"
|
. "github.com/smartystreets/goconvey/convey"
|
||||||
"testing"
|
"testing"
|
||||||
@ -54,7 +53,7 @@ func TestFloatFilter(t *testing.T) {
|
|||||||
name := a.GetName()
|
name := a.GetName()
|
||||||
_, ok := origMap[name]
|
_, ok := origMap[name]
|
||||||
if !ok {
|
if !ok {
|
||||||
t.Error(fmt.Sprintf("Weird: %s", name))
|
t.Errorf("Weird: %s", name)
|
||||||
}
|
}
|
||||||
origMap[name] = true
|
origMap[name] = true
|
||||||
}
|
}
|
||||||
|
@ -5,7 +5,7 @@ package knn
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"github.com/gonum/matrix/mat64"
|
"github.com/gonum/matrix/mat64"
|
||||||
base "github.com/sjwhitworth/golearn/base"
|
"github.com/sjwhitworth/golearn/base"
|
||||||
pairwiseMetrics "github.com/sjwhitworth/golearn/metrics/pairwise"
|
pairwiseMetrics "github.com/sjwhitworth/golearn/metrics/pairwise"
|
||||||
util "github.com/sjwhitworth/golearn/utilities"
|
util "github.com/sjwhitworth/golearn/utilities"
|
||||||
)
|
)
|
||||||
|
@ -1,7 +1,6 @@
|
|||||||
package linear_models
|
package linear_models
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"fmt"
|
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
"github.com/sjwhitworth/golearn/base"
|
"github.com/sjwhitworth/golearn/base"
|
||||||
@ -54,11 +53,7 @@ func TestLinearRegression(t *testing.T) {
|
|||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
|
|
||||||
_, rows := predictions.Size()
|
_, _ = predictions.Size()
|
||||||
|
|
||||||
for i := 0; i < rows; i++ {
|
|
||||||
fmt.Printf("Expected: %s || Predicted: %s\n", base.GetClass(testData, i), base.GetClass(predictions, i))
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func BenchmarkLinearRegressionOneRow(b *testing.B) {
|
func BenchmarkLinearRegressionOneRow(b *testing.B) {
|
||||||
|
@ -2,7 +2,7 @@ package linear_models
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"fmt"
|
"fmt"
|
||||||
base "github.com/sjwhitworth/golearn/base"
|
"github.com/sjwhitworth/golearn/base"
|
||||||
)
|
)
|
||||||
|
|
||||||
type LogisticRegression struct {
|
type LogisticRegression struct {
|
||||||
|
@ -2,7 +2,7 @@ package meta
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"fmt"
|
"fmt"
|
||||||
base "github.com/sjwhitworth/golearn/base"
|
"github.com/sjwhitworth/golearn/base"
|
||||||
"math/rand"
|
"math/rand"
|
||||||
"runtime"
|
"runtime"
|
||||||
"strings"
|
"strings"
|
||||||
|
@ -1,20 +1,19 @@
|
|||||||
package meta
|
package meta
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"fmt"
|
"github.com/sjwhitworth/golearn/base"
|
||||||
base "github.com/sjwhitworth/golearn/base"
|
|
||||||
eval "github.com/sjwhitworth/golearn/evaluation"
|
eval "github.com/sjwhitworth/golearn/evaluation"
|
||||||
filters "github.com/sjwhitworth/golearn/filters"
|
"github.com/sjwhitworth/golearn/filters"
|
||||||
trees "github.com/sjwhitworth/golearn/trees"
|
"github.com/sjwhitworth/golearn/trees"
|
||||||
"math/rand"
|
"math/rand"
|
||||||
"testing"
|
"testing"
|
||||||
"time"
|
"time"
|
||||||
)
|
)
|
||||||
|
|
||||||
func BenchmarkBaggingRandomForestFit(testEnv *testing.B) {
|
func BenchmarkBaggingRandomForestFit(t *testing.B) {
|
||||||
inst, err := base.ParseCSVToInstances("../examples/datasets/iris_headers.csv", true)
|
inst, err := base.ParseCSVToInstances("../examples/datasets/iris_headers.csv", true)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
panic(err)
|
t.Fatal("Unable to parse CSV to instances: %s", err.Error())
|
||||||
}
|
}
|
||||||
|
|
||||||
rand.Seed(time.Now().UnixNano())
|
rand.Seed(time.Now().UnixNano())
|
||||||
@ -24,20 +23,22 @@ func BenchmarkBaggingRandomForestFit(testEnv *testing.B) {
|
|||||||
}
|
}
|
||||||
filt.Train()
|
filt.Train()
|
||||||
instf := base.NewLazilyFilteredInstances(inst, filt)
|
instf := base.NewLazilyFilteredInstances(inst, filt)
|
||||||
|
|
||||||
rf := new(BaggedModel)
|
rf := new(BaggedModel)
|
||||||
for i := 0; i < 10; i++ {
|
for i := 0; i < 10; i++ {
|
||||||
rf.AddModel(trees.NewRandomTree(2))
|
rf.AddModel(trees.NewRandomTree(2))
|
||||||
}
|
}
|
||||||
testEnv.ResetTimer()
|
|
||||||
|
t.ResetTimer()
|
||||||
for i := 0; i < 20; i++ {
|
for i := 0; i < 20; i++ {
|
||||||
rf.Fit(instf)
|
rf.Fit(instf)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func BenchmarkBaggingRandomForestPredict(testEnv *testing.B) {
|
func BenchmarkBaggingRandomForestPredict(t *testing.B) {
|
||||||
inst, err := base.ParseCSVToInstances("../examples/datasets/iris_headers.csv", true)
|
inst, err := base.ParseCSVToInstances("../examples/datasets/iris_headers.csv", true)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
panic(err)
|
t.Fatal("Unable to parse CSV to instances: %s", err.Error())
|
||||||
}
|
}
|
||||||
|
|
||||||
rand.Seed(time.Now().UnixNano())
|
rand.Seed(time.Now().UnixNano())
|
||||||
@ -47,25 +48,27 @@ func BenchmarkBaggingRandomForestPredict(testEnv *testing.B) {
|
|||||||
}
|
}
|
||||||
filt.Train()
|
filt.Train()
|
||||||
instf := base.NewLazilyFilteredInstances(inst, filt)
|
instf := base.NewLazilyFilteredInstances(inst, filt)
|
||||||
|
|
||||||
rf := new(BaggedModel)
|
rf := new(BaggedModel)
|
||||||
for i := 0; i < 10; i++ {
|
for i := 0; i < 10; i++ {
|
||||||
rf.AddModel(trees.NewRandomTree(2))
|
rf.AddModel(trees.NewRandomTree(2))
|
||||||
}
|
}
|
||||||
|
|
||||||
rf.Fit(instf)
|
rf.Fit(instf)
|
||||||
testEnv.ResetTimer()
|
t.ResetTimer()
|
||||||
for i := 0; i < 20; i++ {
|
for i := 0; i < 20; i++ {
|
||||||
rf.Predict(instf)
|
rf.Predict(instf)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestRandomForest1(testEnv *testing.T) {
|
func TestRandomForest1(t *testing.T) {
|
||||||
inst, err := base.ParseCSVToInstances("../examples/datasets/iris_headers.csv", true)
|
inst, err := base.ParseCSVToInstances("../examples/datasets/iris_headers.csv", true)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
panic(err)
|
t.Fatal("Unable to parse CSV to instances: %s", err.Error())
|
||||||
}
|
}
|
||||||
|
trainData, testData := base.InstancesTrainTestSplit(inst, 0.6)
|
||||||
|
|
||||||
rand.Seed(time.Now().UnixNano())
|
rand.Seed(time.Now().UnixNano())
|
||||||
trainData, testData := base.InstancesTrainTestSplit(inst, 0.6)
|
|
||||||
filt := filters.NewChiMergeFilter(inst, 0.90)
|
filt := filters.NewChiMergeFilter(inst, 0.90)
|
||||||
for _, a := range base.NonClassFloatAttributes(inst) {
|
for _, a := range base.NonClassFloatAttributes(inst) {
|
||||||
filt.AddAttribute(a)
|
filt.AddAttribute(a)
|
||||||
@ -73,17 +76,14 @@ func TestRandomForest1(testEnv *testing.T) {
|
|||||||
filt.Train()
|
filt.Train()
|
||||||
trainDataf := base.NewLazilyFilteredInstances(trainData, filt)
|
trainDataf := base.NewLazilyFilteredInstances(trainData, filt)
|
||||||
testDataf := base.NewLazilyFilteredInstances(testData, filt)
|
testDataf := base.NewLazilyFilteredInstances(testData, filt)
|
||||||
|
|
||||||
rf := new(BaggedModel)
|
rf := new(BaggedModel)
|
||||||
for i := 0; i < 10; i++ {
|
for i := 0; i < 10; i++ {
|
||||||
rf.AddModel(trees.NewRandomTree(2))
|
rf.AddModel(trees.NewRandomTree(2))
|
||||||
}
|
}
|
||||||
|
|
||||||
rf.Fit(trainDataf)
|
rf.Fit(trainDataf)
|
||||||
fmt.Println(rf)
|
|
||||||
predictions := rf.Predict(testDataf)
|
predictions := rf.Predict(testDataf)
|
||||||
fmt.Println(predictions)
|
|
||||||
confusionMat := eval.GetConfusionMatrix(testDataf, predictions)
|
confusionMat := eval.GetConfusionMatrix(testDataf, predictions)
|
||||||
fmt.Println(confusionMat)
|
_ = eval.GetSummary(confusionMat)
|
||||||
fmt.Println(eval.GetMacroPrecision(confusionMat))
|
|
||||||
fmt.Println(eval.GetMacroRecall(confusionMat))
|
|
||||||
fmt.Println(eval.GetSummary(confusionMat))
|
|
||||||
}
|
}
|
||||||
|
@ -10,4 +10,4 @@
|
|||||||
are generated via majority voting.
|
are generated via majority voting.
|
||||||
*/
|
*/
|
||||||
|
|
||||||
package meta
|
package meta
|
||||||
|
@ -2,7 +2,7 @@ package naive
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"fmt"
|
"fmt"
|
||||||
base "github.com/sjwhitworth/golearn/base"
|
"github.com/sjwhitworth/golearn/base"
|
||||||
"math"
|
"math"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -1,7 +1,6 @@
|
|||||||
package neural
|
package neural
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"fmt"
|
|
||||||
"github.com/gonum/matrix/mat64"
|
"github.com/gonum/matrix/mat64"
|
||||||
"github.com/sjwhitworth/golearn/base"
|
"github.com/sjwhitworth/golearn/base"
|
||||||
. "github.com/smartystreets/goconvey/convey"
|
. "github.com/smartystreets/goconvey/convey"
|
||||||
@ -13,7 +12,6 @@ func TestLayerStructureNoHidden(t *testing.T) {
|
|||||||
Convey("Creating a network...", t, func() {
|
Convey("Creating a network...", t, func() {
|
||||||
XORData, err := base.ParseCSVToInstances("xor.csv", false)
|
XORData, err := base.ParseCSVToInstances("xor.csv", false)
|
||||||
So(err, ShouldEqual, nil)
|
So(err, ShouldEqual, nil)
|
||||||
fmt.Println(XORData)
|
|
||||||
Convey("Create a MultiLayerNet with no layers...", func() {
|
Convey("Create a MultiLayerNet with no layers...", func() {
|
||||||
net := NewMultiLayerNet(make([]int, 0))
|
net := NewMultiLayerNet(make([]int, 0))
|
||||||
net.MaxIterations = 0
|
net.MaxIterations = 0
|
||||||
@ -73,8 +71,6 @@ func TestLayerStructureNoHidden(t *testing.T) {
|
|||||||
})
|
})
|
||||||
|
|
||||||
Convey("The right nodes should be connected in the network...", func() {
|
Convey("The right nodes should be connected in the network...", func() {
|
||||||
|
|
||||||
fmt.Println(net.network)
|
|
||||||
So(net.network.GetWeight(1, 1), ShouldAlmostEqual, 1.000)
|
So(net.network.GetWeight(1, 1), ShouldAlmostEqual, 1.000)
|
||||||
So(net.network.GetWeight(2, 2), ShouldAlmostEqual, 1.000)
|
So(net.network.GetWeight(2, 2), ShouldAlmostEqual, 1.000)
|
||||||
So(net.network.GetWeight(1, 3), ShouldNotAlmostEqual, 0.000)
|
So(net.network.GetWeight(1, 3), ShouldNotAlmostEqual, 0.000)
|
||||||
@ -118,7 +114,6 @@ func TestLayeredXOR(t *testing.T) {
|
|||||||
XORData, err := base.ParseCSVToInstances("xor.csv", false)
|
XORData, err := base.ParseCSVToInstances("xor.csv", false)
|
||||||
So(err, ShouldEqual, nil)
|
So(err, ShouldEqual, nil)
|
||||||
|
|
||||||
fmt.Println(XORData)
|
|
||||||
net := NewMultiLayerNet([]int{3})
|
net := NewMultiLayerNet([]int{3})
|
||||||
net.MaxIterations = 20000
|
net.MaxIterations = 20000
|
||||||
net.Fit(XORData)
|
net.Fit(XORData)
|
||||||
@ -126,8 +121,6 @@ func TestLayeredXOR(t *testing.T) {
|
|||||||
Convey("After running for 20000 iterations, should have some predictive power...", func() {
|
Convey("After running for 20000 iterations, should have some predictive power...", func() {
|
||||||
|
|
||||||
Convey("The right nodes should be connected in the network...", func() {
|
Convey("The right nodes should be connected in the network...", func() {
|
||||||
|
|
||||||
fmt.Println(net.network)
|
|
||||||
So(net.network.GetWeight(1, 1), ShouldAlmostEqual, 1.000)
|
So(net.network.GetWeight(1, 1), ShouldAlmostEqual, 1.000)
|
||||||
So(net.network.GetWeight(2, 2), ShouldAlmostEqual, 1.000)
|
So(net.network.GetWeight(2, 2), ShouldAlmostEqual, 1.000)
|
||||||
|
|
||||||
@ -138,7 +131,6 @@ func TestLayeredXOR(t *testing.T) {
|
|||||||
})
|
})
|
||||||
out := mat64.NewDense(6, 1, []float64{1.0, 0.0, 0.0, 0.0, 0.0, 0.0})
|
out := mat64.NewDense(6, 1, []float64{1.0, 0.0, 0.0, 0.0, 0.0, 0.0})
|
||||||
net.network.Activate(out, 2)
|
net.network.Activate(out, 2)
|
||||||
fmt.Println(out)
|
|
||||||
So(out.At(5, 0), ShouldAlmostEqual, 1.0, 0.1)
|
So(out.At(5, 0), ShouldAlmostEqual, 1.0, 0.1)
|
||||||
|
|
||||||
Convey("And Predict() should do OK too...", func() {
|
Convey("And Predict() should do OK too...", func() {
|
||||||
@ -148,7 +140,7 @@ func TestLayeredXOR(t *testing.T) {
|
|||||||
for _, a := range pred.AllAttributes() {
|
for _, a := range pred.AllAttributes() {
|
||||||
af, ok := a.(*base.FloatAttribute)
|
af, ok := a.(*base.FloatAttribute)
|
||||||
if !ok {
|
if !ok {
|
||||||
panic("All of these should be FloatAttributes!")
|
t.Fatalf("Expected all attributes to be FloatAttributes; actually some were not")
|
||||||
}
|
}
|
||||||
af.Precision = 1
|
af.Precision = 1
|
||||||
}
|
}
|
||||||
|
@ -1,7 +1,6 @@
|
|||||||
package neural
|
package neural
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"fmt"
|
|
||||||
"github.com/gonum/matrix/mat64"
|
"github.com/gonum/matrix/mat64"
|
||||||
. "github.com/smartystreets/goconvey/convey"
|
. "github.com/smartystreets/goconvey/convey"
|
||||||
"testing"
|
"testing"
|
||||||
@ -61,7 +60,6 @@ func TestNetworkWith1Layer(t *testing.T) {
|
|||||||
for i := 1; i <= 6; i++ {
|
for i := 1; i <= 6; i++ {
|
||||||
for j := 1; j <= 6; j++ {
|
for j := 1; j <= 6; j++ {
|
||||||
v := n.GetWeight(i, j)
|
v := n.GetWeight(i, j)
|
||||||
fmt.Println(i, j, v)
|
|
||||||
switch i {
|
switch i {
|
||||||
case 1:
|
case 1:
|
||||||
switch j {
|
switch j {
|
||||||
|
@ -1,7 +1,7 @@
|
|||||||
package trees
|
package trees
|
||||||
|
|
||||||
import (
|
import (
|
||||||
base "github.com/sjwhitworth/golearn/base"
|
"github.com/sjwhitworth/golearn/base"
|
||||||
"math"
|
"math"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -3,7 +3,7 @@ package trees
|
|||||||
import (
|
import (
|
||||||
"bytes"
|
"bytes"
|
||||||
"fmt"
|
"fmt"
|
||||||
base "github.com/sjwhitworth/golearn/base"
|
"github.com/sjwhitworth/golearn/base"
|
||||||
eval "github.com/sjwhitworth/golearn/evaluation"
|
eval "github.com/sjwhitworth/golearn/evaluation"
|
||||||
"sort"
|
"sort"
|
||||||
)
|
)
|
||||||
|
@ -2,7 +2,7 @@ package trees
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"fmt"
|
"fmt"
|
||||||
base "github.com/sjwhitworth/golearn/base"
|
"github.com/sjwhitworth/golearn/base"
|
||||||
"math/rand"
|
"math/rand"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -1,19 +1,19 @@
|
|||||||
package trees
|
package trees
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"fmt"
|
"github.com/sjwhitworth/golearn/base"
|
||||||
base "github.com/sjwhitworth/golearn/base"
|
|
||||||
eval "github.com/sjwhitworth/golearn/evaluation"
|
eval "github.com/sjwhitworth/golearn/evaluation"
|
||||||
filters "github.com/sjwhitworth/golearn/filters"
|
"github.com/sjwhitworth/golearn/filters"
|
||||||
"math"
|
"math"
|
||||||
"testing"
|
"testing"
|
||||||
)
|
)
|
||||||
|
|
||||||
func TestRandomTree(testEnv *testing.T) {
|
func TestRandomTree(t *testing.T) {
|
||||||
inst, err := base.ParseCSVToInstances("../examples/datasets/iris_headers.csv", true)
|
inst, err := base.ParseCSVToInstances("../examples/datasets/iris_headers.csv", true)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
panic(err)
|
t.Fatal("Unable to parse CSV to instances: %s", err.Error())
|
||||||
}
|
}
|
||||||
|
|
||||||
filt := filters.NewChiMergeFilter(inst, 0.90)
|
filt := filters.NewChiMergeFilter(inst, 0.90)
|
||||||
for _, a := range base.NonClassFloatAttributes(inst) {
|
for _, a := range base.NonClassFloatAttributes(inst) {
|
||||||
filt.AddAttribute(a)
|
filt.AddAttribute(a)
|
||||||
@ -23,17 +23,17 @@ func TestRandomTree(testEnv *testing.T) {
|
|||||||
|
|
||||||
r := new(RandomTreeRuleGenerator)
|
r := new(RandomTreeRuleGenerator)
|
||||||
r.Attributes = 2
|
r.Attributes = 2
|
||||||
fmt.Println(instf)
|
|
||||||
root := InferID3Tree(instf, r)
|
_ = InferID3Tree(instf, r)
|
||||||
fmt.Println(root)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestRandomTreeClassification(testEnv *testing.T) {
|
func TestRandomTreeClassification(t *testing.T) {
|
||||||
inst, err := base.ParseCSVToInstances("../examples/datasets/iris_headers.csv", true)
|
inst, err := base.ParseCSVToInstances("../examples/datasets/iris_headers.csv", true)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
panic(err)
|
t.Fatal("Unable to parse CSV to instances: %s", err.Error())
|
||||||
}
|
}
|
||||||
trainData, testData := base.InstancesTrainTestSplit(inst, 0.6)
|
trainData, testData := base.InstancesTrainTestSplit(inst, 0.6)
|
||||||
|
|
||||||
filt := filters.NewChiMergeFilter(inst, 0.90)
|
filt := filters.NewChiMergeFilter(inst, 0.90)
|
||||||
for _, a := range base.NonClassFloatAttributes(inst) {
|
for _, a := range base.NonClassFloatAttributes(inst) {
|
||||||
filt.AddAttribute(a)
|
filt.AddAttribute(a)
|
||||||
@ -44,23 +44,21 @@ func TestRandomTreeClassification(testEnv *testing.T) {
|
|||||||
|
|
||||||
r := new(RandomTreeRuleGenerator)
|
r := new(RandomTreeRuleGenerator)
|
||||||
r.Attributes = 2
|
r.Attributes = 2
|
||||||
|
|
||||||
root := InferID3Tree(trainDataF, r)
|
root := InferID3Tree(trainDataF, r)
|
||||||
fmt.Println(root)
|
|
||||||
predictions := root.Predict(testDataF)
|
predictions := root.Predict(testDataF)
|
||||||
fmt.Println(predictions)
|
|
||||||
confusionMat := eval.GetConfusionMatrix(testDataF, predictions)
|
confusionMat := eval.GetConfusionMatrix(testDataF, predictions)
|
||||||
fmt.Println(confusionMat)
|
_ = eval.GetSummary(confusionMat)
|
||||||
fmt.Println(eval.GetMacroPrecision(confusionMat))
|
|
||||||
fmt.Println(eval.GetMacroRecall(confusionMat))
|
|
||||||
fmt.Println(eval.GetSummary(confusionMat))
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestRandomTreeClassification2(testEnv *testing.T) {
|
func TestRandomTreeClassification2(t *testing.T) {
|
||||||
inst, err := base.ParseCSVToInstances("../examples/datasets/iris_headers.csv", true)
|
inst, err := base.ParseCSVToInstances("../examples/datasets/iris_headers.csv", true)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
panic(err)
|
t.Fatal("Unable to parse CSV to instances: %s", err.Error())
|
||||||
}
|
}
|
||||||
trainData, testData := base.InstancesTrainTestSplit(inst, 0.4)
|
trainData, testData := base.InstancesTrainTestSplit(inst, 0.4)
|
||||||
|
|
||||||
filt := filters.NewChiMergeFilter(inst, 0.90)
|
filt := filters.NewChiMergeFilter(inst, 0.90)
|
||||||
for _, a := range base.NonClassFloatAttributes(inst) {
|
for _, a := range base.NonClassFloatAttributes(inst) {
|
||||||
filt.AddAttribute(a)
|
filt.AddAttribute(a)
|
||||||
@ -71,22 +69,19 @@ func TestRandomTreeClassification2(testEnv *testing.T) {
|
|||||||
|
|
||||||
root := NewRandomTree(2)
|
root := NewRandomTree(2)
|
||||||
root.Fit(trainDataF)
|
root.Fit(trainDataF)
|
||||||
fmt.Println(root)
|
|
||||||
predictions := root.Predict(testDataF)
|
predictions := root.Predict(testDataF)
|
||||||
fmt.Println(predictions)
|
|
||||||
confusionMat := eval.GetConfusionMatrix(testDataF, predictions)
|
confusionMat := eval.GetConfusionMatrix(testDataF, predictions)
|
||||||
fmt.Println(confusionMat)
|
_ = eval.GetSummary(confusionMat)
|
||||||
fmt.Println(eval.GetMacroPrecision(confusionMat))
|
|
||||||
fmt.Println(eval.GetMacroRecall(confusionMat))
|
|
||||||
fmt.Println(eval.GetSummary(confusionMat))
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestPruning(testEnv *testing.T) {
|
func TestPruning(t *testing.T) {
|
||||||
inst, err := base.ParseCSVToInstances("../examples/datasets/iris_headers.csv", true)
|
inst, err := base.ParseCSVToInstances("../examples/datasets/iris_headers.csv", true)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
panic(err)
|
t.Fatal("Unable to parse CSV to instances: %s", err.Error())
|
||||||
}
|
}
|
||||||
trainData, testData := base.InstancesTrainTestSplit(inst, 0.6)
|
trainData, testData := base.InstancesTrainTestSplit(inst, 0.6)
|
||||||
|
|
||||||
filt := filters.NewChiMergeFilter(inst, 0.90)
|
filt := filters.NewChiMergeFilter(inst, 0.90)
|
||||||
for _, a := range base.NonClassFloatAttributes(inst) {
|
for _, a := range base.NonClassFloatAttributes(inst) {
|
||||||
filt.AddAttribute(a)
|
filt.AddAttribute(a)
|
||||||
@ -99,17 +94,13 @@ func TestPruning(testEnv *testing.T) {
|
|||||||
fittrainData, fittestData := base.InstancesTrainTestSplit(trainDataF, 0.6)
|
fittrainData, fittestData := base.InstancesTrainTestSplit(trainDataF, 0.6)
|
||||||
root.Fit(fittrainData)
|
root.Fit(fittrainData)
|
||||||
root.Prune(fittestData)
|
root.Prune(fittestData)
|
||||||
fmt.Println(root)
|
|
||||||
predictions := root.Predict(testDataF)
|
predictions := root.Predict(testDataF)
|
||||||
fmt.Println(predictions)
|
|
||||||
confusionMat := eval.GetConfusionMatrix(testDataF, predictions)
|
confusionMat := eval.GetConfusionMatrix(testDataF, predictions)
|
||||||
fmt.Println(confusionMat)
|
_ = eval.GetSummary(confusionMat)
|
||||||
fmt.Println(eval.GetMacroPrecision(confusionMat))
|
|
||||||
fmt.Println(eval.GetMacroRecall(confusionMat))
|
|
||||||
fmt.Println(eval.GetSummary(confusionMat))
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestInformationGain(testEnv *testing.T) {
|
func TestInformationGain(t *testing.T) {
|
||||||
outlook := make(map[string]map[string]int)
|
outlook := make(map[string]map[string]int)
|
||||||
outlook["sunny"] = make(map[string]int)
|
outlook["sunny"] = make(map[string]int)
|
||||||
outlook["overcast"] = make(map[string]int)
|
outlook["overcast"] = make(map[string]int)
|
||||||
@ -122,16 +113,14 @@ func TestInformationGain(testEnv *testing.T) {
|
|||||||
|
|
||||||
entropy := getSplitEntropy(outlook)
|
entropy := getSplitEntropy(outlook)
|
||||||
if math.Abs(entropy-0.694) > 0.001 {
|
if math.Abs(entropy-0.694) > 0.001 {
|
||||||
testEnv.Error(entropy)
|
t.Error(entropy)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestID3Inference(testEnv *testing.T) {
|
func TestID3Inference(t *testing.T) {
|
||||||
|
|
||||||
// Import the "PlayTennis" dataset
|
|
||||||
inst, err := base.ParseCSVToInstances("../examples/datasets/tennis.csv", true)
|
inst, err := base.ParseCSVToInstances("../examples/datasets/tennis.csv", true)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
panic(err)
|
t.Fatal("Unable to parse CSV to instances: %s", err.Error())
|
||||||
}
|
}
|
||||||
|
|
||||||
// Build the decision tree
|
// Build the decision tree
|
||||||
@ -141,81 +130,71 @@ func TestID3Inference(testEnv *testing.T) {
|
|||||||
// Verify the tree
|
// Verify the tree
|
||||||
// First attribute should be "outlook"
|
// First attribute should be "outlook"
|
||||||
if root.SplitAttr.GetName() != "outlook" {
|
if root.SplitAttr.GetName() != "outlook" {
|
||||||
testEnv.Error(root)
|
t.Error(root)
|
||||||
}
|
}
|
||||||
sunnyChild := root.Children["sunny"]
|
sunnyChild := root.Children["sunny"]
|
||||||
overcastChild := root.Children["overcast"]
|
overcastChild := root.Children["overcast"]
|
||||||
rainyChild := root.Children["rainy"]
|
rainyChild := root.Children["rainy"]
|
||||||
if sunnyChild.SplitAttr.GetName() != "humidity" {
|
if sunnyChild.SplitAttr.GetName() != "humidity" {
|
||||||
testEnv.Error(sunnyChild)
|
t.Error(sunnyChild)
|
||||||
}
|
}
|
||||||
if rainyChild.SplitAttr.GetName() != "windy" {
|
if rainyChild.SplitAttr.GetName() != "windy" {
|
||||||
fmt.Println(rainyChild.SplitAttr)
|
t.Error(rainyChild)
|
||||||
testEnv.Error(rainyChild)
|
|
||||||
}
|
}
|
||||||
if overcastChild.SplitAttr != nil {
|
if overcastChild.SplitAttr != nil {
|
||||||
testEnv.Error(overcastChild)
|
t.Error(overcastChild)
|
||||||
}
|
}
|
||||||
|
|
||||||
sunnyLeafHigh := sunnyChild.Children["high"]
|
sunnyLeafHigh := sunnyChild.Children["high"]
|
||||||
sunnyLeafNormal := sunnyChild.Children["normal"]
|
sunnyLeafNormal := sunnyChild.Children["normal"]
|
||||||
if sunnyLeafHigh.Class != "no" {
|
if sunnyLeafHigh.Class != "no" {
|
||||||
testEnv.Error(sunnyLeafHigh)
|
t.Error(sunnyLeafHigh)
|
||||||
}
|
}
|
||||||
if sunnyLeafNormal.Class != "yes" {
|
if sunnyLeafNormal.Class != "yes" {
|
||||||
testEnv.Error(sunnyLeafNormal)
|
t.Error(sunnyLeafNormal)
|
||||||
}
|
}
|
||||||
windyLeafFalse := rainyChild.Children["false"]
|
windyLeafFalse := rainyChild.Children["false"]
|
||||||
windyLeafTrue := rainyChild.Children["true"]
|
windyLeafTrue := rainyChild.Children["true"]
|
||||||
if windyLeafFalse.Class != "yes" {
|
if windyLeafFalse.Class != "yes" {
|
||||||
testEnv.Error(windyLeafFalse)
|
t.Error(windyLeafFalse)
|
||||||
}
|
}
|
||||||
if windyLeafTrue.Class != "no" {
|
if windyLeafTrue.Class != "no" {
|
||||||
testEnv.Error(windyLeafTrue)
|
t.Error(windyLeafTrue)
|
||||||
}
|
}
|
||||||
|
|
||||||
if overcastChild.Class != "yes" {
|
if overcastChild.Class != "yes" {
|
||||||
testEnv.Error(overcastChild)
|
t.Error(overcastChild)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestID3Classification(testEnv *testing.T) {
|
func TestID3Classification(t *testing.T) {
|
||||||
inst, err := base.ParseCSVToInstances("../examples/datasets/iris_headers.csv", true)
|
inst, err := base.ParseCSVToInstances("../examples/datasets/iris_headers.csv", true)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
panic(err)
|
t.Fatal("Unable to parse CSV to instances: %s", err.Error())
|
||||||
}
|
}
|
||||||
fmt.Println(inst)
|
|
||||||
filt := filters.NewBinningFilter(inst, 10)
|
filt := filters.NewBinningFilter(inst, 10)
|
||||||
for _, a := range base.NonClassFloatAttributes(inst) {
|
for _, a := range base.NonClassFloatAttributes(inst) {
|
||||||
filt.AddAttribute(a)
|
filt.AddAttribute(a)
|
||||||
}
|
}
|
||||||
filt.Train()
|
filt.Train()
|
||||||
fmt.Println(filt)
|
|
||||||
instf := base.NewLazilyFilteredInstances(inst, filt)
|
instf := base.NewLazilyFilteredInstances(inst, filt)
|
||||||
fmt.Println("INSTFA", instf.AllAttributes())
|
|
||||||
fmt.Println("INSTF", instf)
|
|
||||||
trainData, testData := base.InstancesTrainTestSplit(instf, 0.70)
|
trainData, testData := base.InstancesTrainTestSplit(instf, 0.70)
|
||||||
|
|
||||||
// Build the decision tree
|
// Build the decision tree
|
||||||
rule := new(InformationGainRuleGenerator)
|
rule := new(InformationGainRuleGenerator)
|
||||||
root := InferID3Tree(trainData, rule)
|
root := InferID3Tree(trainData, rule)
|
||||||
fmt.Println(root)
|
|
||||||
predictions := root.Predict(testData)
|
predictions := root.Predict(testData)
|
||||||
fmt.Println(predictions)
|
|
||||||
confusionMat := eval.GetConfusionMatrix(testData, predictions)
|
confusionMat := eval.GetConfusionMatrix(testData, predictions)
|
||||||
fmt.Println(confusionMat)
|
_ = eval.GetSummary(confusionMat)
|
||||||
fmt.Println(eval.GetMacroPrecision(confusionMat))
|
|
||||||
fmt.Println(eval.GetMacroRecall(confusionMat))
|
|
||||||
fmt.Println(eval.GetSummary(confusionMat))
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestID3(testEnv *testing.T) {
|
func TestID3(t *testing.T) {
|
||||||
|
|
||||||
// Import the "PlayTennis" dataset
|
|
||||||
inst, err := base.ParseCSVToInstances("../examples/datasets/tennis.csv", true)
|
inst, err := base.ParseCSVToInstances("../examples/datasets/tennis.csv", true)
|
||||||
fmt.Println(inst)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
panic(err)
|
t.Fatal("Unable to parse CSV to instances: %s", err.Error())
|
||||||
}
|
}
|
||||||
|
|
||||||
// Build the decision tree
|
// Build the decision tree
|
||||||
@ -226,40 +205,40 @@ func TestID3(testEnv *testing.T) {
|
|||||||
// Verify the tree
|
// Verify the tree
|
||||||
// First attribute should be "outlook"
|
// First attribute should be "outlook"
|
||||||
if root.SplitAttr.GetName() != "outlook" {
|
if root.SplitAttr.GetName() != "outlook" {
|
||||||
testEnv.Error(root)
|
t.Error(root)
|
||||||
}
|
}
|
||||||
sunnyChild := root.Children["sunny"]
|
sunnyChild := root.Children["sunny"]
|
||||||
overcastChild := root.Children["overcast"]
|
overcastChild := root.Children["overcast"]
|
||||||
rainyChild := root.Children["rainy"]
|
rainyChild := root.Children["rainy"]
|
||||||
if sunnyChild.SplitAttr.GetName() != "humidity" {
|
if sunnyChild.SplitAttr.GetName() != "humidity" {
|
||||||
testEnv.Error(sunnyChild)
|
t.Error(sunnyChild)
|
||||||
}
|
}
|
||||||
if rainyChild.SplitAttr.GetName() != "windy" {
|
if rainyChild.SplitAttr.GetName() != "windy" {
|
||||||
testEnv.Error(rainyChild)
|
t.Error(rainyChild)
|
||||||
}
|
}
|
||||||
if overcastChild.SplitAttr != nil {
|
if overcastChild.SplitAttr != nil {
|
||||||
testEnv.Error(overcastChild)
|
t.Error(overcastChild)
|
||||||
}
|
}
|
||||||
|
|
||||||
sunnyLeafHigh := sunnyChild.Children["high"]
|
sunnyLeafHigh := sunnyChild.Children["high"]
|
||||||
sunnyLeafNormal := sunnyChild.Children["normal"]
|
sunnyLeafNormal := sunnyChild.Children["normal"]
|
||||||
if sunnyLeafHigh.Class != "no" {
|
if sunnyLeafHigh.Class != "no" {
|
||||||
testEnv.Error(sunnyLeafHigh)
|
t.Error(sunnyLeafHigh)
|
||||||
}
|
}
|
||||||
if sunnyLeafNormal.Class != "yes" {
|
if sunnyLeafNormal.Class != "yes" {
|
||||||
testEnv.Error(sunnyLeafNormal)
|
t.Error(sunnyLeafNormal)
|
||||||
}
|
}
|
||||||
|
|
||||||
windyLeafFalse := rainyChild.Children["false"]
|
windyLeafFalse := rainyChild.Children["false"]
|
||||||
windyLeafTrue := rainyChild.Children["true"]
|
windyLeafTrue := rainyChild.Children["true"]
|
||||||
if windyLeafFalse.Class != "yes" {
|
if windyLeafFalse.Class != "yes" {
|
||||||
testEnv.Error(windyLeafFalse)
|
t.Error(windyLeafFalse)
|
||||||
}
|
}
|
||||||
if windyLeafTrue.Class != "no" {
|
if windyLeafTrue.Class != "no" {
|
||||||
testEnv.Error(windyLeafTrue)
|
t.Error(windyLeafTrue)
|
||||||
}
|
}
|
||||||
|
|
||||||
if overcastChild.Class != "yes" {
|
if overcastChild.Class != "yes" {
|
||||||
testEnv.Error(overcastChild)
|
t.Error(overcastChild)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
Loading…
x
Reference in New Issue
Block a user