1
0
mirror of https://github.com/sjwhitworth/golearn.git synced 2025-04-26 13:49:14 +08:00

Consistently use (t *testing.T) instead of T or testEnv

This commit is contained in:
Amit Kumar Gupta 2014-08-22 08:13:19 +00:00
parent 695aec6eb6
commit 14aad31821
9 changed files with 195 additions and 195 deletions

View File

@ -4,92 +4,92 @@ import (
"testing"
)
func TestParseCSVGetRows(testEnv *testing.T) {
func TestParseCSVGetRows(t *testing.T) {
lineCount, err := ParseCSVGetRows("../examples/datasets/iris.csv")
if err != nil {
testEnv.Fatalf("Unable to parse CSV to get number of rows: %s", err.Error())
t.Fatalf("Unable to parse CSV to get number of rows: %s", err.Error())
}
if lineCount != 150 {
testEnv.Errorf("Should have %d lines, has %d", 150, lineCount)
t.Errorf("Should have %d lines, has %d", 150, lineCount)
}
lineCount, err = ParseCSVGetRows("../examples/datasets/iris_headers.csv")
if err != nil {
testEnv.Fatalf("Unable to parse CSV to get number of rows: %s", err.Error())
t.Fatalf("Unable to parse CSV to get number of rows: %s", err.Error())
}
if lineCount != 151 {
testEnv.Errorf("Should have %d lines, has %d", 151, lineCount)
t.Errorf("Should have %d lines, has %d", 151, lineCount)
}
}
func TestParseCSVGetRowsWithMissingFile(testEnv *testing.T) {
func TestParseCSVGetRowsWithMissingFile(t *testing.T) {
_, err := ParseCSVGetRows("../examples/datasets/non-existent.csv")
if err == nil {
testEnv.Fatal("Expected ParseCSVGetRows to return error when given path to non-existent file")
t.Fatal("Expected ParseCSVGetRows to return error when given path to non-existent file")
}
}
func TestParseCCSVGetAttributes(testEnv *testing.T) {
func TestParseCCSVGetAttributes(t *testing.T) {
attrs := ParseCSVGetAttributes("../examples/datasets/iris_headers.csv", true)
if attrs[0].GetType() != Float64Type {
testEnv.Errorf("First attribute should be a float, %s", attrs[0])
t.Errorf("First attribute should be a float, %s", attrs[0])
}
if attrs[0].GetName() != "Sepal length" {
testEnv.Errorf(attrs[0].GetName())
t.Errorf(attrs[0].GetName())
}
if attrs[4].GetType() != CategoricalType {
testEnv.Errorf("Final attribute should be categorical, %s", attrs[4])
t.Errorf("Final attribute should be categorical, %s", attrs[4])
}
if attrs[4].GetName() != "Species" {
testEnv.Error(attrs[4])
t.Error(attrs[4])
}
}
func TestParseCsvSniffAttributeTypes(testEnv *testing.T) {
func TestParseCsvSniffAttributeTypes(t *testing.T) {
attrs := ParseCSVSniffAttributeTypes("../examples/datasets/iris_headers.csv", true)
if attrs[0].GetType() != Float64Type {
testEnv.Errorf("First attribute should be a float, %s", attrs[0])
t.Errorf("First attribute should be a float, %s", attrs[0])
}
if attrs[1].GetType() != Float64Type {
testEnv.Errorf("Second attribute should be a float, %s", attrs[1])
t.Errorf("Second attribute should be a float, %s", attrs[1])
}
if attrs[2].GetType() != Float64Type {
testEnv.Errorf("Third attribute should be a float, %s", attrs[2])
t.Errorf("Third attribute should be a float, %s", attrs[2])
}
if attrs[3].GetType() != Float64Type {
testEnv.Errorf("Fourth attribute should be a float, %s", attrs[3])
t.Errorf("Fourth attribute should be a float, %s", attrs[3])
}
if attrs[4].GetType() != CategoricalType {
testEnv.Errorf("Final attribute should be categorical, %s", attrs[4])
t.Errorf("Final attribute should be categorical, %s", attrs[4])
}
}
func TestParseCSVSniffAttributeNamesWithHeaders(testEnv *testing.T) {
func TestParseCSVSniffAttributeNamesWithHeaders(t *testing.T) {
attrs := ParseCSVSniffAttributeNames("../examples/datasets/iris_headers.csv", true)
if attrs[0] != "Sepal length" {
testEnv.Error(attrs[0])
t.Error(attrs[0])
}
if attrs[1] != "Sepal width" {
testEnv.Error(attrs[1])
t.Error(attrs[1])
}
if attrs[2] != "Petal length" {
testEnv.Error(attrs[2])
t.Error(attrs[2])
}
if attrs[3] != "Petal width" {
testEnv.Error(attrs[3])
t.Error(attrs[3])
}
if attrs[4] != "Species" {
testEnv.Error(attrs[4])
t.Error(attrs[4])
}
}
func TestParseCSVToInstances(testEnv *testing.T) {
func TestParseCSVToInstances(t *testing.T) {
inst, err := ParseCSVToInstances("../examples/datasets/iris_headers.csv", true)
if err != nil {
testEnv.Error(err)
t.Error(err)
return
}
row1 := inst.RowString(0)
@ -97,34 +97,34 @@ func TestParseCSVToInstances(testEnv *testing.T) {
row3 := inst.RowString(100)
if row1 != "5.10 3.50 1.40 0.20 Iris-setosa" {
testEnv.Error(row1)
t.Error(row1)
}
if row2 != "7.00 3.20 4.70 1.40 Iris-versicolor" {
testEnv.Error(row2)
t.Error(row2)
}
if row3 != "6.30 3.30 6.00 2.50 Iris-virginica" {
testEnv.Error(row3)
t.Error(row3)
}
}
func TestParseCSVToInstancesWithMissingFile(testEnv *testing.T) {
func TestParseCSVToInstancesWithMissingFile(t *testing.T) {
_, err := ParseCSVToInstances("../examples/datasets/non-existent.csv", true)
if err == nil {
testEnv.Fatal("Expected ParseCSVToInstances to return error when given path to non-existent file")
t.Fatal("Expected ParseCSVToInstances to return error when given path to non-existent file")
}
}
func TestReadAwkwardInsatnces(testEnv *testing.T) {
func TestReadAwkwardInsatnces(t *testing.T) {
inst, err := ParseCSVToInstances("../examples/datasets/chim.csv", true)
if err != nil {
testEnv.Error(err)
t.Error(err)
return
}
attrs := inst.AllAttributes()
if attrs[0].GetType() != Float64Type {
testEnv.Error("Should be float!")
t.Error("Should be float!")
}
if attrs[1].GetType() != CategoricalType {
testEnv.Error("Should be discrete!")
t.Error("Should be discrete!")
}
}

View File

@ -6,13 +6,13 @@ import (
"testing"
)
func TestThreadDeserialize(T *testing.T) {
func TestThreadDeserialize(t *testing.T) {
bytes := []byte{0, 0, 0, 6, 83, 89, 83, 84, 69, 77, 0, 0, 0, 1}
Convey("Given a byte slice", T, func() {
var t Thread
size := t.Deserialize(bytes)
Convey("Given a byte slice", t, func() {
var thread Thread
size := thread.Deserialize(bytes)
Convey("Decoded name should be SYSTEM", func() {
So(t.name, ShouldEqual, "SYSTEM")
So(thread.name, ShouldEqual, "SYSTEM")
})
Convey("Size should be the same as the array", func() {
So(size, ShouldEqual, len(bytes))
@ -20,20 +20,20 @@ func TestThreadDeserialize(T *testing.T) {
})
}
func TestThreadSerialize(T *testing.T) {
var t Thread
func TestThreadSerialize(t *testing.T) {
var thread Thread
refBytes := []byte{0, 0, 0, 6, 83, 89, 83, 84, 69, 77, 0, 0, 0, 1}
t.name = "SYSTEM"
t.id = 1
thread.name = "SYSTEM"
thread.id = 1
toBytes := make([]byte, len(refBytes))
Convey("Should serialize correctly", T, func() {
t.Serialize(toBytes)
Convey("Should serialize correctly", t, func() {
thread.Serialize(toBytes)
So(toBytes, ShouldResemble, refBytes)
})
}
func TestThreadFindAndWrite(T *testing.T) {
Convey("Creating a non-existent file should succeed", T, func() {
func TestThreadFindAndWrite(t *testing.T) {
Convey("Creating a non-existent file should succeed", t, func() {
tempFile, err := os.OpenFile("hello.db", os.O_RDWR|os.O_TRUNC|os.O_CREATE, 0700) //ioutil.TempFile(os.TempDir(), "TestFileCreate")
So(err, ShouldEqual, nil)
Convey("Mapping the file should succeed", func() {

View File

@ -4,15 +4,15 @@ import (
"testing"
)
func TestLazySortDesc(testEnv *testing.T) {
func TestLazySortDesc(t *testing.T) {
inst1, err := ParseCSVToInstances("../examples/datasets/iris_headers.csv", true)
if err != nil {
testEnv.Error(err)
t.Error(err)
return
}
inst2, err := ParseCSVToInstances("../examples/datasets/iris_sorted_desc.csv", true)
if err != nil {
testEnv.Error(err)
t.Error(err)
return
}
@ -20,67 +20,67 @@ func TestLazySortDesc(testEnv *testing.T) {
as2 := ResolveAllAttributes(inst2)
if isSortedDesc(inst1, as1[0]) {
testEnv.Error("Can't test descending sort order")
t.Error("Can't test descending sort order")
}
if !isSortedDesc(inst2, as2[0]) {
testEnv.Error("Reference data not sorted in descending order!")
t.Error("Reference data not sorted in descending order!")
}
inst, err := LazySort(inst1, Descending, as1[0:len(as1)-1])
if err != nil {
testEnv.Error(err)
t.Error(err)
}
if !isSortedDesc(inst, as1[0]) {
testEnv.Error("Instances are not sorted in descending order")
testEnv.Error(inst1)
t.Error("Instances are not sorted in descending order")
t.Error(inst1)
}
if !inst2.Equal(inst) {
testEnv.Error("Instances don't match")
testEnv.Error(inst)
testEnv.Error(inst2)
t.Error("Instances don't match")
t.Error(inst)
t.Error(inst2)
}
}
func TestLazySortAsc(testEnv *testing.T) {
func TestLazySortAsc(t *testing.T) {
inst, err := ParseCSVToInstances("../examples/datasets/iris_headers.csv", true)
as1 := ResolveAllAttributes(inst)
if isSortedAsc(inst, as1[0]) {
testEnv.Error("Can't test ascending sort on something ascending already")
t.Error("Can't test ascending sort on something ascending already")
}
if err != nil {
testEnv.Error(err)
t.Error(err)
return
}
insts, err := LazySort(inst, Ascending, as1)
if err != nil {
testEnv.Error(err)
t.Error(err)
return
}
if !isSortedAsc(insts, as1[0]) {
testEnv.Error("Instances are not sorted in ascending order")
testEnv.Error(insts)
t.Error("Instances are not sorted in ascending order")
t.Error(insts)
}
inst2, err := ParseCSVToInstances("../examples/datasets/iris_sorted_asc.csv", true)
if err != nil {
testEnv.Error(err)
t.Error(err)
return
}
as2 := ResolveAllAttributes(inst2)
if !isSortedAsc(inst2, as2[0]) {
testEnv.Error("This file should be sorted in ascending order")
t.Error("This file should be sorted in ascending order")
}
if !inst2.Equal(insts) {
testEnv.Error("Instances don't match")
testEnv.Error(inst)
testEnv.Error(inst2)
t.Error("Instances don't match")
t.Error(inst)
t.Error(inst2)
}
rowStr := insts.RowString(0)
ref := "4.30 3.00 1.10 0.10 Iris-setosa"
if rowStr != ref {
testEnv.Fatalf("'%s' != '%s'", rowStr, ref)
t.Fatalf("'%s' != '%s'", rowStr, ref)
}
}

View File

@ -32,15 +32,15 @@ func isSortedDesc(inst FixedDataGrid, attr AttributeSpec) bool {
return true
}
func TestSortDesc(testEnv *testing.T) {
func TestSortDesc(t *testing.T) {
inst1, err := ParseCSVToInstances("../examples/datasets/iris_headers.csv", true)
if err != nil {
testEnv.Error(err)
t.Error(err)
return
}
inst2, err := ParseCSVToInstances("../examples/datasets/iris_sorted_desc.csv", true)
if err != nil {
testEnv.Error(err)
t.Error(err)
return
}
@ -48,57 +48,57 @@ func TestSortDesc(testEnv *testing.T) {
as2 := ResolveAllAttributes(inst2)
if isSortedDesc(inst1, as1[0]) {
testEnv.Error("Can't test descending sort order")
t.Error("Can't test descending sort order")
}
if !isSortedDesc(inst2, as2[0]) {
testEnv.Error("Reference data not sorted in descending order!")
t.Error("Reference data not sorted in descending order!")
}
Sort(inst1, Descending, as1[0:len(as1)-1])
if err != nil {
testEnv.Error(err)
t.Error(err)
}
if !isSortedDesc(inst1, as1[0]) {
testEnv.Error("Instances are not sorted in descending order")
testEnv.Error(inst1)
t.Error("Instances are not sorted in descending order")
t.Error(inst1)
}
if !inst2.Equal(inst1) {
testEnv.Error("Instances don't match")
testEnv.Error(inst1)
testEnv.Error(inst2)
t.Error("Instances don't match")
t.Error(inst1)
t.Error(inst2)
}
}
func TestSortAsc(testEnv *testing.T) {
func TestSortAsc(t *testing.T) {
inst, err := ParseCSVToInstances("../examples/datasets/iris_headers.csv", true)
as1 := ResolveAllAttributes(inst)
if isSortedAsc(inst, as1[0]) {
testEnv.Error("Can't test ascending sort on something ascending already")
t.Error("Can't test ascending sort on something ascending already")
}
if err != nil {
testEnv.Error(err)
t.Error(err)
return
}
Sort(inst, Ascending, as1[0:1])
if !isSortedAsc(inst, as1[0]) {
testEnv.Error("Instances are not sorted in ascending order")
testEnv.Error(inst)
t.Error("Instances are not sorted in ascending order")
t.Error(inst)
}
inst2, err := ParseCSVToInstances("../examples/datasets/iris_sorted_asc.csv", true)
if err != nil {
testEnv.Error(err)
t.Error(err)
return
}
as2 := ResolveAllAttributes(inst2)
if !isSortedAsc(inst2, as2[0]) {
testEnv.Error("This file should be sorted in ascending order")
t.Error("This file should be sorted in ascending order")
}
if !inst2.Equal(inst) {
testEnv.Error("Instances don't match")
testEnv.Error(inst)
testEnv.Error(inst2)
t.Error("Instances don't match")
t.Error(inst)
t.Error(inst2)
}
}

View File

@ -7,10 +7,10 @@ import (
"testing"
)
func TestRandomForest1(testEnv *testing.T) {
func TestRandomForest1(t *testing.T) {
inst, err := base.ParseCSVToInstances("../examples/datasets/iris_headers.csv", true)
if err != nil {
testEnv.Fatal("Unable to parse CSV to instances: %s", err.Error())
t.Fatal("Unable to parse CSV to instances: %s", err.Error())
}
filt := filters.NewChiMergeFilter(inst, 0.90)

View File

@ -5,7 +5,7 @@ import (
"testing"
)
func TestMetrics(testEnv *testing.T) {
func TestMetrics(t *testing.T) {
confusionMat := make(ConfusionMatrix)
confusionMat["a"] = make(map[string]int)
confusionMat["b"] = make(map[string]int)
@ -16,89 +16,89 @@ func TestMetrics(testEnv *testing.T) {
tp := GetTruePositives("a", confusionMat)
if math.Abs(tp-75) >= 1 {
testEnv.Error(tp)
t.Error(tp)
}
tp = GetTruePositives("b", confusionMat)
if math.Abs(tp-10) >= 1 {
testEnv.Error(tp)
t.Error(tp)
}
fn := GetFalseNegatives("a", confusionMat)
if math.Abs(fn-5) >= 1 {
testEnv.Error(fn)
t.Error(fn)
}
fn = GetFalseNegatives("b", confusionMat)
if math.Abs(fn-10) >= 1 {
testEnv.Error(fn)
t.Error(fn)
}
tn := GetTrueNegatives("a", confusionMat)
if math.Abs(tn-10) >= 1 {
testEnv.Error(tn)
t.Error(tn)
}
tn = GetTrueNegatives("b", confusionMat)
if math.Abs(tn-75) >= 1 {
testEnv.Error(tn)
t.Error(tn)
}
fp := GetFalsePositives("a", confusionMat)
if math.Abs(fp-10) >= 1 {
testEnv.Error(fp)
t.Error(fp)
}
fp = GetFalsePositives("b", confusionMat)
if math.Abs(fp-5) >= 1 {
testEnv.Error(fp)
t.Error(fp)
}
precision := GetPrecision("a", confusionMat)
recall := GetRecall("a", confusionMat)
if math.Abs(precision-0.88) >= 0.01 {
testEnv.Error(precision)
t.Error(precision)
}
if math.Abs(recall-0.94) >= 0.01 {
testEnv.Error(recall)
t.Error(recall)
}
precision = GetPrecision("b", confusionMat)
recall = GetRecall("b", confusionMat)
if math.Abs(precision-0.666) >= 0.01 {
testEnv.Error(precision)
t.Error(precision)
}
if math.Abs(recall-0.50) >= 0.01 {
testEnv.Error(recall)
t.Error(recall)
}
precision = GetMicroPrecision(confusionMat)
if math.Abs(precision-0.85) >= 0.01 {
testEnv.Error(precision)
t.Error(precision)
}
recall = GetMicroRecall(confusionMat)
if math.Abs(recall-0.85) >= 0.01 {
testEnv.Error(recall)
t.Error(recall)
}
precision = GetMacroPrecision(confusionMat)
if math.Abs(precision-0.775) >= 0.01 {
testEnv.Error(precision)
t.Error(precision)
}
recall = GetMacroRecall(confusionMat)
if math.Abs(recall-0.719) > 0.01 {
testEnv.Error(recall)
t.Error(recall)
}
fmeasure := GetF1Score("a", confusionMat)
if math.Abs(fmeasure-0.91) >= 0.1 {
testEnv.Error(fmeasure)
t.Error(fmeasure)
}
accuracy := GetAccuracy(confusionMat)
if math.Abs(accuracy-0.85) >= 0.1 {
testEnv.Error(accuracy)
t.Error(accuracy)
}
}

View File

@ -6,104 +6,104 @@ import (
"testing"
)
func TestChiMFreqTable(testEnv *testing.T) {
func TestChiMFreqTable(t *testing.T) {
inst, err := base.ParseCSVToInstances("../examples/datasets/chim.csv", true)
if err != nil {
testEnv.Fatal("Unable to parse CSV to instances: %s", err.Error())
t.Fatal("Unable to parse CSV to instances: %s", err.Error())
}
freq := ChiMBuildFrequencyTable(inst.AllAttributes()[0], inst)
if freq[0].Frequency["c1"] != 1 {
testEnv.Error("Wrong frequency")
t.Error("Wrong frequency")
}
if freq[0].Frequency["c3"] != 4 {
testEnv.Errorf("Wrong frequency %s", freq[1])
t.Errorf("Wrong frequency %s", freq[1])
}
if freq[10].Frequency["c2"] != 1 {
testEnv.Error("Wrong frequency")
t.Error("Wrong frequency")
}
}
func TestChiClassCounter(testEnv *testing.T) {
func TestChiClassCounter(t *testing.T) {
inst, err := base.ParseCSVToInstances("../examples/datasets/chim.csv", true)
if err != nil {
testEnv.Fatal("Unable to parse CSV to instances: %s", err.Error())
t.Fatal("Unable to parse CSV to instances: %s", err.Error())
}
freq := ChiMBuildFrequencyTable(inst.AllAttributes()[0], inst)
classes := chiCountClasses(freq)
if classes["c1"] != 27 {
testEnv.Error(classes)
t.Error(classes)
}
if classes["c2"] != 12 {
testEnv.Error(classes)
t.Error(classes)
}
if classes["c3"] != 21 {
testEnv.Error(classes)
t.Error(classes)
}
}
func TestStatisticValues(testEnv *testing.T) {
func TestStatisticValues(t *testing.T) {
inst, err := base.ParseCSVToInstances("../examples/datasets/chim.csv", true)
if err != nil {
testEnv.Fatal("Unable to parse CSV to instances: %s", err.Error())
t.Fatal("Unable to parse CSV to instances: %s", err.Error())
}
freq := ChiMBuildFrequencyTable(inst.AllAttributes()[0], inst)
chiVal := chiComputeStatistic(freq[5], freq[6])
if math.Abs(chiVal-1.89) > 0.01 {
testEnv.Error(chiVal)
t.Error(chiVal)
}
chiVal = chiComputeStatistic(freq[1], freq[2])
if math.Abs(chiVal-1.08) > 0.01 {
testEnv.Error(chiVal)
t.Error(chiVal)
}
}
func TestChiSquareDistValues(testEnv *testing.T) {
func TestChiSquareDistValues(t *testing.T) {
chiVal1 := chiSquaredPercentile(2, 4.61)
chiVal2 := chiSquaredPercentile(3, 7.82)
chiVal3 := chiSquaredPercentile(4, 13.28)
if math.Abs(chiVal1-0.90) > 0.001 {
testEnv.Error(chiVal1)
t.Error(chiVal1)
}
if math.Abs(chiVal2-0.95) > 0.001 {
testEnv.Error(chiVal2)
t.Error(chiVal2)
}
if math.Abs(chiVal3-0.99) > 0.001 {
testEnv.Error(chiVal3)
t.Error(chiVal3)
}
}
func TestChiMerge1(testEnv *testing.T) {
func TestChiMerge1(t *testing.T) {
inst, err := base.ParseCSVToInstances("../examples/datasets/chim.csv", true)
if err != nil {
testEnv.Fatal("Unable to parse CSV to instances: %s", err.Error())
t.Fatal("Unable to parse CSV to instances: %s", err.Error())
}
_, rows := inst.Size()
freq := chiMerge(inst, inst.AllAttributes()[0], 0.90, 0, rows)
if len(freq) != 3 {
testEnv.Error("Wrong length")
t.Error("Wrong length")
}
if freq[0].Value != 1.3 {
testEnv.Error(freq[0])
t.Error(freq[0])
}
if freq[1].Value != 56.2 {
testEnv.Error(freq[1])
t.Error(freq[1])
}
if freq[2].Value != 87.1 {
testEnv.Error(freq[2])
t.Error(freq[2])
}
}
func TestChiMerge2(testEnv *testing.T) {
func TestChiMerge2(t *testing.T) {
//
// See http://sci2s.ugr.es/keel/pdf/algorithm/congreso/1992-Kerber-ChimErge-AAAI92.pdf
// Randy Kerber, ChiMerge: Discretisation of Numeric Attributes, 1992
inst, err := base.ParseCSVToInstances("../examples/datasets/iris_headers.csv", true)
if err != nil {
testEnv.Fatal("Unable to parse CSV to instances: %s", err.Error())
t.Fatal("Unable to parse CSV to instances: %s", err.Error())
}
// Sort the instances
@ -111,35 +111,35 @@ func TestChiMerge2(testEnv *testing.T) {
sortAttrSpecs := base.ResolveAttributes(inst, allAttrs)[0:1]
instSorted, err := base.Sort(inst, base.Ascending, sortAttrSpecs)
if err != nil {
testEnv.Fatalf("Sort failed: %s", err.Error())
t.Fatalf("Sort failed: %s", err.Error())
}
// Perform Chi-Merge
_, rows := inst.Size()
freq := chiMerge(instSorted, allAttrs[0], 0.90, 0, rows)
if len(freq) != 5 {
testEnv.Errorf("Wrong length (%d)", len(freq))
testEnv.Error(freq)
t.Errorf("Wrong length (%d)", len(freq))
t.Error(freq)
}
if freq[0].Value != 4.3 {
testEnv.Error(freq[0])
t.Error(freq[0])
}
if freq[1].Value != 5.5 {
testEnv.Error(freq[1])
t.Error(freq[1])
}
if freq[2].Value != 5.8 {
testEnv.Error(freq[2])
t.Error(freq[2])
}
if freq[3].Value != 6.3 {
testEnv.Error(freq[3])
t.Error(freq[3])
}
if freq[4].Value != 7.1 {
testEnv.Error(freq[4])
t.Error(freq[4])
}
}
/*
func TestChiMerge3(testEnv *testing.T) {
func TestChiMerge3(t *testing.T) {
// See http://sci2s.ugr.es/keel/pdf/algorithm/congreso/1992-Kerber-ChimErge-AAAI92.pdf
// Randy Kerber, ChiMerge: Discretisation of Numeric Attributes, 1992
inst, err := base.ParseCSVToInstances("../examples/datasets/iris_headers.csv", true)
@ -149,7 +149,7 @@ func TestChiMerge3(testEnv *testing.T) {
insts, err := base.LazySort(inst, base.Ascending, base.ResolveAllAttributes(inst, inst.AllAttributes()))
if err != nil {
testEnv.Error(err)
t.Error(err)
}
filt := NewChiMergeFilter(inst, 0.90)
filt.AddAttribute(inst.AllAttributes()[0])
@ -172,12 +172,12 @@ func TestChiMerge3(testEnv *testing.T) {
}
*/
func TestChiMerge4(testEnv *testing.T) {
func TestChiMerge4(t *testing.T) {
// See http://sci2s.ugr.es/keel/pdf/algorithm/congreso/1992-Kerber-ChimErge-AAAI92.pdf
// Randy Kerber, ChiMerge: Discretisation of Numeric Attributes, 1992
inst, err := base.ParseCSVToInstances("../examples/datasets/iris_headers.csv", true)
if err != nil {
testEnv.Fatal("Unable to parse CSV to instances: %s", err.Error())
t.Fatal("Unable to parse CSV to instances: %s", err.Error())
}
filt := NewChiMergeFilter(inst, 0.90)
@ -187,11 +187,11 @@ func TestChiMerge4(testEnv *testing.T) {
instf := base.NewLazilyFilteredInstances(inst, filt)
clsAttrs := instf.AllClassAttributes()
if len(clsAttrs) != 1 {
testEnv.Fatalf("%d != %d", len(clsAttrs), 1)
t.Fatalf("%d != %d", len(clsAttrs), 1)
}
firstClassAttributeName := clsAttrs[0].GetName()
expectedClassAttributeName := "Species"
if firstClassAttributeName != expectedClassAttributeName {
testEnv.Fatalf("Expected class attribute '%s'; actual class attribute '%s'", expectedClassAttributeName, firstClassAttributeName)
t.Fatalf("Expected class attribute '%s'; actual class attribute '%s'", expectedClassAttributeName, firstClassAttributeName)
}
}

View File

@ -10,10 +10,10 @@ import (
"time"
)
func BenchmarkBaggingRandomForestFit(testEnv *testing.B) {
func BenchmarkBaggingRandomForestFit(t *testing.B) {
inst, err := base.ParseCSVToInstances("../examples/datasets/iris_headers.csv", true)
if err != nil {
testEnv.Fatal("Unable to parse CSV to instances: %s", err.Error())
t.Fatal("Unable to parse CSV to instances: %s", err.Error())
}
rand.Seed(time.Now().UnixNano())
@ -29,16 +29,16 @@ func BenchmarkBaggingRandomForestFit(testEnv *testing.B) {
rf.AddModel(trees.NewRandomTree(2))
}
testEnv.ResetTimer()
t.ResetTimer()
for i := 0; i < 20; i++ {
rf.Fit(instf)
}
}
func BenchmarkBaggingRandomForestPredict(testEnv *testing.B) {
func BenchmarkBaggingRandomForestPredict(t *testing.B) {
inst, err := base.ParseCSVToInstances("../examples/datasets/iris_headers.csv", true)
if err != nil {
testEnv.Fatal("Unable to parse CSV to instances: %s", err.Error())
t.Fatal("Unable to parse CSV to instances: %s", err.Error())
}
rand.Seed(time.Now().UnixNano())
@ -55,16 +55,16 @@ func BenchmarkBaggingRandomForestPredict(testEnv *testing.B) {
}
rf.Fit(instf)
testEnv.ResetTimer()
t.ResetTimer()
for i := 0; i < 20; i++ {
rf.Predict(instf)
}
}
func TestRandomForest1(testEnv *testing.T) {
func TestRandomForest1(t *testing.T) {
inst, err := base.ParseCSVToInstances("../examples/datasets/iris_headers.csv", true)
if err != nil {
testEnv.Fatal("Unable to parse CSV to instances: %s", err.Error())
t.Fatal("Unable to parse CSV to instances: %s", err.Error())
}
trainData, testData := base.InstancesTrainTestSplit(inst, 0.6)

View File

@ -8,10 +8,10 @@ import (
"testing"
)
func TestRandomTree(testEnv *testing.T) {
func TestRandomTree(t *testing.T) {
inst, err := base.ParseCSVToInstances("../examples/datasets/iris_headers.csv", true)
if err != nil {
testEnv.Fatal("Unable to parse CSV to instances: %s", err.Error())
t.Fatal("Unable to parse CSV to instances: %s", err.Error())
}
filt := filters.NewChiMergeFilter(inst, 0.90)
@ -27,10 +27,10 @@ func TestRandomTree(testEnv *testing.T) {
_ = InferID3Tree(instf, r)
}
func TestRandomTreeClassification(testEnv *testing.T) {
func TestRandomTreeClassification(t *testing.T) {
inst, err := base.ParseCSVToInstances("../examples/datasets/iris_headers.csv", true)
if err != nil {
testEnv.Fatal("Unable to parse CSV to instances: %s", err.Error())
t.Fatal("Unable to parse CSV to instances: %s", err.Error())
}
trainData, testData := base.InstancesTrainTestSplit(inst, 0.6)
@ -52,10 +52,10 @@ func TestRandomTreeClassification(testEnv *testing.T) {
_ = eval.GetSummary(confusionMat)
}
func TestRandomTreeClassification2(testEnv *testing.T) {
func TestRandomTreeClassification2(t *testing.T) {
inst, err := base.ParseCSVToInstances("../examples/datasets/iris_headers.csv", true)
if err != nil {
testEnv.Fatal("Unable to parse CSV to instances: %s", err.Error())
t.Fatal("Unable to parse CSV to instances: %s", err.Error())
}
trainData, testData := base.InstancesTrainTestSplit(inst, 0.4)
@ -75,10 +75,10 @@ func TestRandomTreeClassification2(testEnv *testing.T) {
_ = eval.GetSummary(confusionMat)
}
func TestPruning(testEnv *testing.T) {
func TestPruning(t *testing.T) {
inst, err := base.ParseCSVToInstances("../examples/datasets/iris_headers.csv", true)
if err != nil {
testEnv.Fatal("Unable to parse CSV to instances: %s", err.Error())
t.Fatal("Unable to parse CSV to instances: %s", err.Error())
}
trainData, testData := base.InstancesTrainTestSplit(inst, 0.6)
@ -100,7 +100,7 @@ func TestPruning(testEnv *testing.T) {
_ = eval.GetSummary(confusionMat)
}
func TestInformationGain(testEnv *testing.T) {
func TestInformationGain(t *testing.T) {
outlook := make(map[string]map[string]int)
outlook["sunny"] = make(map[string]int)
outlook["overcast"] = make(map[string]int)
@ -113,14 +113,14 @@ func TestInformationGain(testEnv *testing.T) {
entropy := getSplitEntropy(outlook)
if math.Abs(entropy-0.694) > 0.001 {
testEnv.Error(entropy)
t.Error(entropy)
}
}
func TestID3Inference(testEnv *testing.T) {
func TestID3Inference(t *testing.T) {
inst, err := base.ParseCSVToInstances("../examples/datasets/tennis.csv", true)
if err != nil {
testEnv.Fatal("Unable to parse CSV to instances: %s", err.Error())
t.Fatal("Unable to parse CSV to instances: %s", err.Error())
}
// Build the decision tree
@ -130,47 +130,47 @@ func TestID3Inference(testEnv *testing.T) {
// Verify the tree
// First attribute should be "outlook"
if root.SplitAttr.GetName() != "outlook" {
testEnv.Error(root)
t.Error(root)
}
sunnyChild := root.Children["sunny"]
overcastChild := root.Children["overcast"]
rainyChild := root.Children["rainy"]
if sunnyChild.SplitAttr.GetName() != "humidity" {
testEnv.Error(sunnyChild)
t.Error(sunnyChild)
}
if rainyChild.SplitAttr.GetName() != "windy" {
testEnv.Error(rainyChild)
t.Error(rainyChild)
}
if overcastChild.SplitAttr != nil {
testEnv.Error(overcastChild)
t.Error(overcastChild)
}
sunnyLeafHigh := sunnyChild.Children["high"]
sunnyLeafNormal := sunnyChild.Children["normal"]
if sunnyLeafHigh.Class != "no" {
testEnv.Error(sunnyLeafHigh)
t.Error(sunnyLeafHigh)
}
if sunnyLeafNormal.Class != "yes" {
testEnv.Error(sunnyLeafNormal)
t.Error(sunnyLeafNormal)
}
windyLeafFalse := rainyChild.Children["false"]
windyLeafTrue := rainyChild.Children["true"]
if windyLeafFalse.Class != "yes" {
testEnv.Error(windyLeafFalse)
t.Error(windyLeafFalse)
}
if windyLeafTrue.Class != "no" {
testEnv.Error(windyLeafTrue)
t.Error(windyLeafTrue)
}
if overcastChild.Class != "yes" {
testEnv.Error(overcastChild)
t.Error(overcastChild)
}
}
func TestID3Classification(testEnv *testing.T) {
func TestID3Classification(t *testing.T) {
inst, err := base.ParseCSVToInstances("../examples/datasets/iris_headers.csv", true)
if err != nil {
testEnv.Fatal("Unable to parse CSV to instances: %s", err.Error())
t.Fatal("Unable to parse CSV to instances: %s", err.Error())
}
filt := filters.NewBinningFilter(inst, 10)
@ -191,10 +191,10 @@ func TestID3Classification(testEnv *testing.T) {
_ = eval.GetSummary(confusionMat)
}
func TestID3(testEnv *testing.T) {
func TestID3(t *testing.T) {
inst, err := base.ParseCSVToInstances("../examples/datasets/tennis.csv", true)
if err != nil {
testEnv.Fatal("Unable to parse CSV to instances: %s", err.Error())
t.Fatal("Unable to parse CSV to instances: %s", err.Error())
}
// Build the decision tree
@ -205,40 +205,40 @@ func TestID3(testEnv *testing.T) {
// Verify the tree
// First attribute should be "outlook"
if root.SplitAttr.GetName() != "outlook" {
testEnv.Error(root)
t.Error(root)
}
sunnyChild := root.Children["sunny"]
overcastChild := root.Children["overcast"]
rainyChild := root.Children["rainy"]
if sunnyChild.SplitAttr.GetName() != "humidity" {
testEnv.Error(sunnyChild)
t.Error(sunnyChild)
}
if rainyChild.SplitAttr.GetName() != "windy" {
testEnv.Error(rainyChild)
t.Error(rainyChild)
}
if overcastChild.SplitAttr != nil {
testEnv.Error(overcastChild)
t.Error(overcastChild)
}
sunnyLeafHigh := sunnyChild.Children["high"]
sunnyLeafNormal := sunnyChild.Children["normal"]
if sunnyLeafHigh.Class != "no" {
testEnv.Error(sunnyLeafHigh)
t.Error(sunnyLeafHigh)
}
if sunnyLeafNormal.Class != "yes" {
testEnv.Error(sunnyLeafNormal)
t.Error(sunnyLeafNormal)
}
windyLeafFalse := rainyChild.Children["false"]
windyLeafTrue := rainyChild.Children["true"]
if windyLeafFalse.Class != "yes" {
testEnv.Error(windyLeafFalse)
t.Error(windyLeafFalse)
}
if windyLeafTrue.Class != "no" {
testEnv.Error(windyLeafTrue)
t.Error(windyLeafTrue)
}
if overcastChild.Class != "yes" {
testEnv.Error(overcastChild)
t.Error(overcastChild)
}
}