mirror of
https://github.com/sjwhitworth/golearn.git
synced 2025-04-25 13:48:49 +08:00
161 lines
5.1 KiB
Go
161 lines
5.1 KiB
Go
package filters
|
|
|
|
import (
|
|
"fmt"
|
|
"github.com/sjwhitworth/golearn/base"
|
|
"testing"
|
|
|
|
. "github.com/smartystreets/goconvey/convey"
|
|
)
|
|
|
|
func TestChiMergeFrequencyTable(t *testing.T) {
|
|
Convey("Chi-Merge Frequency Table", t, func() {
|
|
instances, err := base.ParseCSVToInstances("../examples/datasets/chim.csv", true)
|
|
So(err, ShouldBeNil)
|
|
|
|
frequencyTable := ChiMBuildFrequencyTable(instances.AllAttributes()[0], instances)
|
|
|
|
Convey("Computes frequencies correctly", func() {
|
|
So(frequencyTable[0].Frequency["c1"], ShouldEqual, 1)
|
|
So(frequencyTable[0].Frequency["c3"], ShouldEqual, 4)
|
|
So(frequencyTable[10].Frequency["c2"], ShouldEqual, 1)
|
|
})
|
|
|
|
Convey("Counts classes correctly", func() {
|
|
classes := chiCountClasses(frequencyTable)
|
|
|
|
So(classes["c1"], ShouldEqual, 27)
|
|
So(classes["c2"], ShouldEqual, 12)
|
|
So(classes["c3"], ShouldEqual, 21)
|
|
})
|
|
|
|
Convey("Computes statistics correctly", func() {
|
|
So(chiComputeStatistic(frequencyTable[5], frequencyTable[6]), ShouldAlmostEqual, 1.89, 0.01)
|
|
So(chiComputeStatistic(frequencyTable[1], frequencyTable[2]), ShouldAlmostEqual, 1.08, 0.01)
|
|
})
|
|
})
|
|
}
|
|
|
|
func TestChiSquaredDistribution(t *testing.T) {
|
|
Convey("Chi-Squared Distribution percentiles are computed correctly", t, func() {
|
|
So(chiSquaredPercentile(2, 4.61), ShouldAlmostEqual, 0.9, 0.001)
|
|
So(chiSquaredPercentile(3, 7.82), ShouldAlmostEqual, 0.95, 0.001)
|
|
So(chiSquaredPercentile(4, 13.28), ShouldAlmostEqual, 0.99, 0.001)
|
|
})
|
|
}
|
|
|
|
func TestChiMergeDiscretization(t *testing.T) {
|
|
Convey("Chi-Merge Discretization", t, func() {
|
|
chimDatasetPath := "../examples/datasets/chim.csv"
|
|
|
|
Convey(fmt.Sprintf("With the '%s' dataset", chimDatasetPath), func() {
|
|
instances, err := base.ParseCSVToInstances(chimDatasetPath, true)
|
|
So(err, ShouldBeNil)
|
|
|
|
_, rows := instances.Size()
|
|
|
|
frequencies := chiMerge(instances, instances.AllAttributes()[0], 0.9, 0, rows)
|
|
values := []float64{}
|
|
for _, entry := range frequencies {
|
|
values = append(values, entry.Value)
|
|
}
|
|
|
|
Convey("Computes frequencies correctly", func() {
|
|
So(values, ShouldResemble, []float64{1.3, 56.2, 87.1})
|
|
})
|
|
})
|
|
|
|
irisHeadersDatasetpath := "../examples/datasets/iris_headers.csv"
|
|
|
|
Convey(fmt.Sprintf("With the '%s' dataset", irisHeadersDatasetpath), func() {
|
|
instances, err := base.ParseCSVToInstances(irisHeadersDatasetpath, true)
|
|
So(err, ShouldBeNil)
|
|
|
|
Convey("Sorting the instances first", func() {
|
|
allAttributes := instances.AllAttributes()
|
|
sortedAttributesSpecs := base.ResolveAttributes(instances, allAttributes)[0:1]
|
|
sortedInstances, err := base.Sort(instances, base.Ascending, sortedAttributesSpecs)
|
|
So(err, ShouldBeNil)
|
|
|
|
_, rows := sortedInstances.Size()
|
|
|
|
frequencies := chiMerge(sortedInstances, sortedInstances.AllAttributes()[0], 0.9, 0, rows)
|
|
values := []float64{}
|
|
for _, entry := range frequencies {
|
|
values = append(values, entry.Value)
|
|
}
|
|
|
|
Convey("Computes frequencies correctly", func() {
|
|
So(values, ShouldResemble, []float64{4.3, 5.5, 5.8, 6.3, 7.1})
|
|
})
|
|
})
|
|
})
|
|
})
|
|
}
|
|
|
|
func TestChiMergeFilter(t *testing.T) {
|
|
Convey("Chi-Merge Filter", t, func() {
|
|
// See http://sci2s.ugr.es/keel/pdf/algorithm/congreso/1992-Kerber-ChimErge-AAAI92.pdf
|
|
// Randy Kerber, ChiMerge: Discretisation of Numeric Attributes, 1992
|
|
instances, err := base.ParseCSVToInstances("../examples/datasets/iris_headers.csv", true)
|
|
So(err, ShouldBeNil)
|
|
|
|
Convey("Create and train the filter", func() {
|
|
filter := NewChiMergeFilter(instances, 0.90)
|
|
filter.AddAttribute(instances.AllAttributes()[0])
|
|
filter.AddAttribute(instances.AllAttributes()[1])
|
|
filter.Train()
|
|
|
|
Convey("Filter the dataset", func() {
|
|
filteredInstances := base.NewLazilyFilteredInstances(instances, filter)
|
|
|
|
classAttributes := filteredInstances.AllClassAttributes()
|
|
|
|
Convey("There should only be one class attribute", func() {
|
|
So(len(classAttributes), ShouldEqual, 1)
|
|
})
|
|
|
|
expectedClassAttribute := "Species"
|
|
|
|
Convey(fmt.Sprintf("The class attribute should be %s", expectedClassAttribute), func() {
|
|
So(classAttributes[0].GetName(), ShouldEqual, expectedClassAttribute)
|
|
})
|
|
})
|
|
})
|
|
})
|
|
}
|
|
|
|
/*
func TestChiMerge3(t *testing.T) {
	// See http://sci2s.ugr.es/keel/pdf/algorithm/congreso/1992-Kerber-ChimErge-AAAI92.pdf
	// Randy Kerber, ChiMerge: Discretisation of Numeric Attributes, 1992
	inst, err := base.ParseCSVToInstances("../examples/datasets/iris_headers.csv", true)
	if err != nil {
		panic(err)
	}

	insts, err := base.LazySort(inst, base.Ascending, base.ResolveAllAttributes(inst, inst.AllAttributes()))
	if err != nil {
		t.Error(err)
	}
	filt := NewChiMergeFilter(inst, 0.90)
	filt.AddAttribute(inst.AllAttributes()[0])
	filt.Train()
	instf := base.NewLazilyFilteredInstances(insts, filt)
	fmt.Println(instf)
	fmt.Println(instf.String())
	rowStr := instf.RowString(0)
	ref := "4.300000 3.00 1.10 0.10 Iris-setosa"
	if rowStr != ref {
		panic(fmt.Sprintf("'%s' != '%s'", rowStr, ref))
	}
	clsAttrs := instf.AllClassAttributes()
	if len(clsAttrs) != 1 {
		panic(fmt.Sprintf("%d != %d", len(clsAttrs), 1))
	}
	if clsAttrs[0].GetName() != "Species" {
		panic("Class Attribute wrong!")
	}
}
*/
|