mirror of
https://github.com/sjwhitworth/golearn.git
synced 2025-04-26 13:49:14 +08:00
filters: merge from v2-instances
This commit is contained in:
parent
a9028b8174
commit
3821477b0f
@ -9,113 +9,110 @@ import (
|
|||||||
// BinningFilter does equal-width binning for numeric
|
// BinningFilter does equal-width binning for numeric
|
||||||
// Attributes (aka "histogram binning")
|
// Attributes (aka "histogram binning")
|
||||||
type BinningFilter struct {
|
type BinningFilter struct {
|
||||||
Attributes []int
|
AbstractDiscretizeFilter
|
||||||
Instances *base.Instances
|
bins int
|
||||||
BinCount int
|
minVals map[base.Attribute]float64
|
||||||
MinVals map[int]float64
|
maxVals map[base.Attribute]float64
|
||||||
MaxVals map[int]float64
|
|
||||||
trained bool
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewBinningFilter creates a BinningFilter structure
|
// NewBinningFilter creates a BinningFilter structure
|
||||||
// with some helpful default initialisations.
|
// with some helpful default initialisations.
|
||||||
func NewBinningFilter(inst *base.Instances, bins int) BinningFilter {
|
func NewBinningFilter(d base.FixedDataGrid, bins int) *BinningFilter {
|
||||||
return BinningFilter{
|
return &BinningFilter{
|
||||||
make([]int, 0),
|
AbstractDiscretizeFilter{
|
||||||
inst,
|
make(map[base.Attribute]bool),
|
||||||
bins,
|
|
||||||
make(map[int]float64),
|
|
||||||
make(map[int]float64),
|
|
||||||
false,
|
false,
|
||||||
|
d,
|
||||||
|
},
|
||||||
|
bins,
|
||||||
|
make(map[base.Attribute]float64),
|
||||||
|
make(map[base.Attribute]float64),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// AddAttribute adds the index of the given attribute `a'
|
// Train computes and stores the bin values
|
||||||
// to the BinningFilter for discretisation.
|
|
||||||
func (b *BinningFilter) AddAttribute(a base.Attribute) {
|
|
||||||
attrIndex := b.Instances.GetAttrIndex(a)
|
|
||||||
if attrIndex == -1 {
|
|
||||||
panic("invalid attribute")
|
|
||||||
}
|
|
||||||
b.Attributes = append(b.Attributes, attrIndex)
|
|
||||||
}
|
|
||||||
|
|
||||||
// AddAllNumericAttributes adds every suitable attribute
|
|
||||||
// to the BinningFilter for discretiation
|
|
||||||
func (b *BinningFilter) AddAllNumericAttributes() {
|
|
||||||
for i := 0; i < b.Instances.Cols; i++ {
|
|
||||||
if i == b.Instances.ClassIndex {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
attr := b.Instances.GetAttr(i)
|
|
||||||
if attr.GetType() != base.Float64Type {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
b.Attributes = append(b.Attributes, i)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Build computes and stores the bin values
|
|
||||||
// for the training instances.
|
// for the training instances.
|
||||||
func (b *BinningFilter) Build() {
|
func (b *BinningFilter) Train() error {
|
||||||
for _, attr := range b.Attributes {
|
|
||||||
maxVal := math.Inf(-1)
|
as := b.getAttributeSpecs()
|
||||||
minVal := math.Inf(1)
|
// Set up the AttributeSpecs, and values
|
||||||
for i := 0; i < b.Instances.Rows; i++ {
|
for attr := range b.attrs {
|
||||||
val := b.Instances.Get(i, attr)
|
if !b.attrs[attr] {
|
||||||
if val > maxVal {
|
continue
|
||||||
maxVal = val
|
|
||||||
}
|
|
||||||
if val < minVal {
|
|
||||||
minVal = val
|
|
||||||
}
|
|
||||||
}
|
|
||||||
b.MaxVals[attr] = maxVal
|
|
||||||
b.MinVals[attr] = minVal
|
|
||||||
b.trained = true
|
|
||||||
}
|
}
|
||||||
|
b.minVals[attr] = float64(math.Inf(1))
|
||||||
|
b.maxVals[attr] = float64(math.Inf(-1))
|
||||||
}
|
}
|
||||||
|
|
||||||
// Run applies a trained BinningFilter to a set of Instances,
|
err := b.train.MapOverRows(as, func(row [][]byte, rowNo int) (bool, error) {
|
||||||
// discretising any numeric attributes added.
|
for i, a := range row {
|
||||||
//
|
attr := as[i].GetAttribute()
|
||||||
// IMPORTANT: Run discretises in-place, so make sure to take
|
attrf := attr.(*base.FloatAttribute)
|
||||||
// a copy if the original instances are still needed
|
val := float64(attrf.GetFloatFromSysVal(a))
|
||||||
//
|
if val > b.maxVals[attr] {
|
||||||
// IMPORTANT: This function panic()s if the filter has not been
|
b.maxVals[attr] = val
|
||||||
// trained. Call Build() before running this function
|
|
||||||
//
|
|
||||||
// IMPORTANT: Call Build() after adding any additional attributes.
|
|
||||||
// Otherwise, the training structure will be out of date from
|
|
||||||
// the values expected and could cause a panic.
|
|
||||||
func (b *BinningFilter) Run(on *base.Instances) {
|
|
||||||
if !b.trained {
|
|
||||||
panic("Call Build() beforehand")
|
|
||||||
}
|
}
|
||||||
for attr := range b.Attributes {
|
if val < b.minVals[attr] {
|
||||||
minVal := b.MinVals[attr]
|
b.minVals[attr] = val
|
||||||
maxVal := b.MaxVals[attr]
|
}
|
||||||
|
}
|
||||||
|
return true, nil
|
||||||
|
})
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("Training error: %s", err)
|
||||||
|
}
|
||||||
|
b.trained = true
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Transform takes an Attribute and byte sequence and returns
|
||||||
|
// the transformed byte sequence.
|
||||||
|
func (b *BinningFilter) Transform(a base.Attribute, field []byte) []byte {
|
||||||
|
|
||||||
|
if !b.attrs[a] {
|
||||||
|
return field
|
||||||
|
}
|
||||||
|
af, ok := a.(*base.FloatAttribute)
|
||||||
|
if !ok {
|
||||||
|
panic("Attribute is the wrong type")
|
||||||
|
}
|
||||||
|
minVal := b.minVals[a]
|
||||||
|
maxVal := b.maxVals[a]
|
||||||
disc := 0
|
disc := 0
|
||||||
// Casts to float32 to replicate a floating point precision error
|
// Casts to float64 to replicate a floating point precision error
|
||||||
delta := float32(maxVal - minVal)
|
delta := float64(maxVal-minVal) / float64(b.bins)
|
||||||
delta /= float32(b.BinCount)
|
val := float64(af.GetFloatFromSysVal(field))
|
||||||
for i := 0; i < on.Rows; i++ {
|
|
||||||
val := on.Get(i, attr)
|
|
||||||
if val <= minVal {
|
if val <= minVal {
|
||||||
disc = 0
|
disc = 0
|
||||||
} else {
|
} else {
|
||||||
disc = int(math.Floor(float64(float32(val-minVal) / delta)))
|
disc = int(math.Floor(float64(float64(val-minVal)/delta + 0.0001)))
|
||||||
if disc >= b.BinCount {
|
}
|
||||||
disc = b.BinCount - 1
|
return base.PackU64ToBytes(uint64(disc))
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetAttributesAfterFiltering gets a list of before/after
|
||||||
|
// Attributes as base.FilteredAttributes
|
||||||
|
func (b *BinningFilter) GetAttributesAfterFiltering() []base.FilteredAttribute {
|
||||||
|
oldAttrs := b.train.AllAttributes()
|
||||||
|
ret := make([]base.FilteredAttribute, len(oldAttrs))
|
||||||
|
for i, a := range oldAttrs {
|
||||||
|
if b.attrs[a] {
|
||||||
|
retAttr := new(base.CategoricalAttribute)
|
||||||
|
minVal := b.minVals[a]
|
||||||
|
maxVal := b.maxVals[a]
|
||||||
|
delta := float64(maxVal-minVal) / float64(b.bins)
|
||||||
|
retAttr.SetName(a.GetName())
|
||||||
|
for i := 0; i <= b.bins; i++ {
|
||||||
|
floatVal := float64(i)*delta + minVal
|
||||||
|
fmtStr := fmt.Sprintf("%%.%df", a.(*base.FloatAttribute).Precision)
|
||||||
|
binVal := fmt.Sprintf(fmtStr, floatVal)
|
||||||
|
retAttr.GetSysValFromString(binVal)
|
||||||
|
}
|
||||||
|
ret[i] = base.FilteredAttribute{a, retAttr}
|
||||||
|
} else {
|
||||||
|
ret[i] = base.FilteredAttribute{a, a}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
on.Set(i, attr, float64(disc))
|
return ret
|
||||||
}
|
|
||||||
newAttribute := new(base.CategoricalAttribute)
|
|
||||||
newAttribute.SetName(on.GetAttr(attr).GetName())
|
|
||||||
for i := 0; i < b.BinCount; i++ {
|
|
||||||
newAttribute.GetSysValFromString(fmt.Sprintf("%d", i))
|
|
||||||
}
|
|
||||||
on.ReplaceAttr(attr, newAttribute)
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
@ -2,27 +2,39 @@ package filters
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
base "github.com/sjwhitworth/golearn/base"
|
base "github.com/sjwhitworth/golearn/base"
|
||||||
"math"
|
. "github.com/smartystreets/goconvey/convey"
|
||||||
"testing"
|
"testing"
|
||||||
)
|
)
|
||||||
|
|
||||||
func TestBinning(testEnv *testing.T) {
|
func TestBinning(testEnv *testing.T) {
|
||||||
|
//
|
||||||
|
// Read the data
|
||||||
inst1, err := base.ParseCSVToInstances("../examples/datasets/iris_headers.csv", true)
|
inst1, err := base.ParseCSVToInstances("../examples/datasets/iris_headers.csv", true)
|
||||||
inst2, err := base.ParseCSVToInstances("../examples/datasets/iris_binned.csv", true)
|
|
||||||
inst3, err := base.ParseCSVToInstances("../examples/datasets/iris_headers.csv", true)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
panic(err)
|
panic(err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
inst2, err := base.ParseCSVToInstances("../examples/datasets/iris_binned.csv", true)
|
||||||
|
if err != nil {
|
||||||
|
panic(err)
|
||||||
|
}
|
||||||
|
//
|
||||||
|
// Construct the binning filter
|
||||||
|
binAttr := inst1.AllAttributes()[0]
|
||||||
filt := NewBinningFilter(inst1, 10)
|
filt := NewBinningFilter(inst1, 10)
|
||||||
filt.AddAttribute(inst1.GetAttr(0))
|
filt.AddAttribute(binAttr)
|
||||||
filt.Build()
|
filt.Train()
|
||||||
filt.Run(inst1)
|
inst1f := base.NewLazilyFilteredInstances(inst1, filt)
|
||||||
for i := 0; i < inst1.Rows; i++ {
|
|
||||||
val1 := inst1.Get(i, 0)
|
// Retrieve the categorical version of the original Attribute
|
||||||
val2 := inst2.Get(i, 0)
|
|
||||||
val3 := inst3.Get(i, 0)
|
//
|
||||||
if math.Abs(val1-val2) >= 1 {
|
// Create the LazilyFilteredInstances
|
||||||
testEnv.Error(val1, val2, val3, i)
|
// and check the values
|
||||||
}
|
Convey("Discretized version should match reference", testEnv, func() {
|
||||||
|
_, rows := inst1.Size()
|
||||||
|
for i := 0; i < rows; i++ {
|
||||||
|
So(inst1f.RowString(i), ShouldEqual, inst2.RowString(i))
|
||||||
}
|
}
|
||||||
|
})
|
||||||
}
|
}
|
||||||
|
@ -12,301 +12,30 @@ import (
|
|||||||
// See Bramer, "Principles of Data Mining", 2nd Edition
|
// See Bramer, "Principles of Data Mining", 2nd Edition
|
||||||
// pp 105--115
|
// pp 105--115
|
||||||
type ChiMergeFilter struct {
|
type ChiMergeFilter struct {
|
||||||
Attributes []int
|
AbstractDiscretizeFilter
|
||||||
Instances *base.Instances
|
tables map[base.Attribute][]*FrequencyTableEntry
|
||||||
Tables map[int][]*FrequencyTableEntry
|
|
||||||
Significance float64
|
Significance float64
|
||||||
MinRows int
|
MinRows int
|
||||||
MaxRows int
|
MaxRows int
|
||||||
_Trained bool
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewChiMergeFilter creates a ChiMergeFilter with some helpful initialisations.
|
// NewChiMergeFilter creates a ChiMergeFilter with some helpful intialisations.
|
||||||
func NewChiMergeFilter(inst *base.Instances, significance float64) ChiMergeFilter {
|
func NewChiMergeFilter(d base.FixedDataGrid, significance float64) *ChiMergeFilter {
|
||||||
return ChiMergeFilter{
|
_, rows := d.Size()
|
||||||
make([]int, 0),
|
return &ChiMergeFilter{
|
||||||
inst,
|
AbstractDiscretizeFilter{
|
||||||
make(map[int][]*FrequencyTableEntry),
|
make(map[base.Attribute]bool),
|
||||||
significance,
|
|
||||||
0,
|
|
||||||
0,
|
|
||||||
false,
|
false,
|
||||||
|
d,
|
||||||
|
},
|
||||||
|
make(map[base.Attribute][]*FrequencyTableEntry),
|
||||||
|
significance,
|
||||||
|
2,
|
||||||
|
rows,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Build trains a ChiMergeFilter on the ChiMergeFilter.Instances given
|
// Train computes and stores the
|
||||||
func (c *ChiMergeFilter) Build() {
|
|
||||||
for _, attr := range c.Attributes {
|
|
||||||
tab := chiMerge(c.Instances, attr, c.Significance, c.MinRows, c.MaxRows)
|
|
||||||
c.Tables[attr] = tab
|
|
||||||
c._Trained = true
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// AddAllNumericAttributes adds every suitable attribute
|
|
||||||
// to the ChiMergeFilter for discretisation
|
|
||||||
func (c *ChiMergeFilter) AddAllNumericAttributes() {
|
|
||||||
for i := 0; i < c.Instances.Cols; i++ {
|
|
||||||
if i == c.Instances.ClassIndex {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
attr := c.Instances.GetAttr(i)
|
|
||||||
if attr.GetType() != base.Float64Type {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
c.Attributes = append(c.Attributes, i)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Run discretises the set of Instances `on'
|
|
||||||
//
|
|
||||||
// IMPORTANT: ChiMergeFilter discretises in place.
|
|
||||||
func (c *ChiMergeFilter) Run(on *base.Instances) {
|
|
||||||
if !c._Trained {
|
|
||||||
panic("Call Build() beforehand")
|
|
||||||
}
|
|
||||||
for attr := range c.Tables {
|
|
||||||
table := c.Tables[attr]
|
|
||||||
for i := 0; i < on.Rows; i++ {
|
|
||||||
val := on.Get(i, attr)
|
|
||||||
dis := 0
|
|
||||||
for j, k := range table {
|
|
||||||
if k.Value < val {
|
|
||||||
dis = j
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
break
|
|
||||||
}
|
|
||||||
on.Set(i, attr, float64(dis))
|
|
||||||
}
|
|
||||||
newAttribute := new(base.CategoricalAttribute)
|
|
||||||
newAttribute.SetName(on.GetAttr(attr).GetName())
|
|
||||||
for _, k := range table {
|
|
||||||
newAttribute.GetSysValFromString(fmt.Sprintf("%f", k.Value))
|
|
||||||
}
|
|
||||||
on.ReplaceAttr(attr, newAttribute)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// AddAttribute add a given numeric Attribute `attr' to the
|
|
||||||
// filter.
|
|
||||||
//
|
|
||||||
// IMPORTANT: This function panic()s if it can't locate the
|
|
||||||
// attribute in the Instances set.
|
|
||||||
func (c *ChiMergeFilter) AddAttribute(attr base.Attribute) {
|
|
||||||
if attr.GetType() != base.Float64Type {
|
|
||||||
panic("ChiMerge only works on Float64Attributes")
|
|
||||||
}
|
|
||||||
attrIndex := c.Instances.GetAttrIndex(attr)
|
|
||||||
if attrIndex == -1 {
|
|
||||||
panic("Invalid attribute!")
|
|
||||||
}
|
|
||||||
c.Attributes = append(c.Attributes, attrIndex)
|
|
||||||
}
|
|
||||||
|
|
||||||
type FrequencyTableEntry struct {
|
|
||||||
Value float64
|
|
||||||
Frequency map[string]int
|
|
||||||
}
|
|
||||||
|
|
||||||
func (t *FrequencyTableEntry) String() string {
|
|
||||||
return fmt.Sprintf("%.2f %v", t.Value, t.Frequency)
|
|
||||||
}
|
|
||||||
|
|
||||||
func ChiMBuildFrequencyTable(attr int, inst *base.Instances) []*FrequencyTableEntry {
|
|
||||||
ret := make([]*FrequencyTableEntry, 0)
|
|
||||||
var attribute *base.FloatAttribute
|
|
||||||
attribute, ok := inst.GetAttr(attr).(*base.FloatAttribute)
|
|
||||||
if !ok {
|
|
||||||
panic("only use Chi-M on numeric stuff")
|
|
||||||
}
|
|
||||||
for i := 0; i < inst.Rows; i++ {
|
|
||||||
value := inst.Get(i, attr)
|
|
||||||
valueConv := attribute.GetUsrVal(value)
|
|
||||||
class := inst.GetClass(i)
|
|
||||||
// Search the frequency table for the value
|
|
||||||
found := false
|
|
||||||
for _, entry := range ret {
|
|
||||||
if entry.Value == valueConv {
|
|
||||||
found = true
|
|
||||||
entry.Frequency[class]++
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if !found {
|
|
||||||
newEntry := &FrequencyTableEntry{
|
|
||||||
valueConv,
|
|
||||||
make(map[string]int),
|
|
||||||
}
|
|
||||||
newEntry.Frequency[class] = 1
|
|
||||||
ret = append(ret, newEntry)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return ret
|
|
||||||
}
|
|
||||||
|
|
||||||
func chiSquaredPdf(k float64, x float64) float64 {
|
|
||||||
if x < 0 {
|
|
||||||
return 0
|
|
||||||
}
|
|
||||||
top := math.Pow(x, (k/2)-1) * math.Exp(-x/2)
|
|
||||||
bottom := math.Pow(2, k/2) * math.Gamma(k/2)
|
|
||||||
return top / bottom
|
|
||||||
}
|
|
||||||
|
|
||||||
func chiSquaredPercentile(k int, x float64) float64 {
|
|
||||||
// Implements Yahya et al.'s "A Numerical Procedure
|
|
||||||
// for Computing Chi-Square Percentage Points"
|
|
||||||
// InterStat Journal 01/2007; April 25:page:1-8.
|
|
||||||
steps := 32
|
|
||||||
intervals := 4 * steps
|
|
||||||
w := x / (4.0 * float64(steps))
|
|
||||||
values := make([]float64, intervals+1)
|
|
||||||
for i := 0; i < intervals+1; i++ {
|
|
||||||
c := w * float64(i)
|
|
||||||
v := chiSquaredPdf(float64(k), c)
|
|
||||||
values[i] = v
|
|
||||||
}
|
|
||||||
|
|
||||||
ret1 := values[0] + values[len(values)-1]
|
|
||||||
ret2 := 0.0
|
|
||||||
ret3 := 0.0
|
|
||||||
ret4 := 0.0
|
|
||||||
|
|
||||||
for i := 2; i < intervals-1; i += 4 {
|
|
||||||
ret2 += values[i]
|
|
||||||
}
|
|
||||||
|
|
||||||
for i := 4; i < intervals-3; i += 4 {
|
|
||||||
ret3 += values[i]
|
|
||||||
}
|
|
||||||
|
|
||||||
for i := 1; i < intervals; i += 2 {
|
|
||||||
ret4 += values[i]
|
|
||||||
}
|
|
||||||
|
|
||||||
return (2.0 * w / 45) * (7*ret1 + 12*ret2 + 14*ret3 + 32*ret4)
|
|
||||||
}
|
|
||||||
|
|
||||||
func chiCountClasses(entries []*FrequencyTableEntry) map[string]int {
|
|
||||||
classCounter := make(map[string]int)
|
|
||||||
for _, e := range entries {
|
|
||||||
for k := range e.Frequency {
|
|
||||||
classCounter[k] += e.Frequency[k]
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return classCounter
|
|
||||||
}
|
|
||||||
|
|
||||||
func chiComputeStatistic(entry1 *FrequencyTableEntry, entry2 *FrequencyTableEntry) float64 {
|
|
||||||
|
|
||||||
// Sum the number of things observed per class
|
|
||||||
classCounter := make(map[string]int)
|
|
||||||
for k := range entry1.Frequency {
|
|
||||||
classCounter[k] += entry1.Frequency[k]
|
|
||||||
}
|
|
||||||
for k := range entry2.Frequency {
|
|
||||||
classCounter[k] += entry2.Frequency[k]
|
|
||||||
}
|
|
||||||
|
|
||||||
// Sum the number of things observed per value
|
|
||||||
entryObservations1 := 0
|
|
||||||
entryObservations2 := 0
|
|
||||||
for k := range entry1.Frequency {
|
|
||||||
entryObservations1 += entry1.Frequency[k]
|
|
||||||
}
|
|
||||||
for k := range entry2.Frequency {
|
|
||||||
entryObservations2 += entry2.Frequency[k]
|
|
||||||
}
|
|
||||||
|
|
||||||
totalObservations := entryObservations1 + entryObservations2
|
|
||||||
// Compute the expected values per class
|
|
||||||
expectedClassValues1 := make(map[string]float64)
|
|
||||||
expectedClassValues2 := make(map[string]float64)
|
|
||||||
for k := range classCounter {
|
|
||||||
expectedClassValues1[k] = float64(classCounter[k])
|
|
||||||
expectedClassValues1[k] *= float64(entryObservations1)
|
|
||||||
expectedClassValues1[k] /= float64(totalObservations)
|
|
||||||
}
|
|
||||||
for k := range classCounter {
|
|
||||||
expectedClassValues2[k] = float64(classCounter[k])
|
|
||||||
expectedClassValues2[k] *= float64(entryObservations2)
|
|
||||||
expectedClassValues2[k] /= float64(totalObservations)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Compute chi-squared value
|
|
||||||
chiSum := 0.0
|
|
||||||
for k := range expectedClassValues1 {
|
|
||||||
numerator := float64(entry1.Frequency[k])
|
|
||||||
numerator -= expectedClassValues1[k]
|
|
||||||
numerator = math.Pow(numerator, 2)
|
|
||||||
denominator := float64(expectedClassValues1[k])
|
|
||||||
if denominator < 0.5 {
|
|
||||||
denominator = 0.5
|
|
||||||
}
|
|
||||||
chiSum += numerator / denominator
|
|
||||||
}
|
|
||||||
for k := range expectedClassValues2 {
|
|
||||||
numerator := float64(entry2.Frequency[k])
|
|
||||||
numerator -= expectedClassValues2[k]
|
|
||||||
numerator = math.Pow(numerator, 2)
|
|
||||||
denominator := float64(expectedClassValues2[k])
|
|
||||||
if denominator < 0.5 {
|
|
||||||
denominator = 0.5
|
|
||||||
}
|
|
||||||
chiSum += numerator / denominator
|
|
||||||
}
|
|
||||||
|
|
||||||
return chiSum
|
|
||||||
}
|
|
||||||
|
|
||||||
func chiMergeMergeZipAdjacent(freq []*FrequencyTableEntry, minIndex int) []*FrequencyTableEntry {
|
|
||||||
mergeEntry1 := freq[minIndex]
|
|
||||||
mergeEntry2 := freq[minIndex+1]
|
|
||||||
classCounter := make(map[string]int)
|
|
||||||
for k := range mergeEntry1.Frequency {
|
|
||||||
classCounter[k] += mergeEntry1.Frequency[k]
|
|
||||||
}
|
|
||||||
for k := range mergeEntry2.Frequency {
|
|
||||||
classCounter[k] += mergeEntry2.Frequency[k]
|
|
||||||
}
|
|
||||||
newVal := freq[minIndex].Value
|
|
||||||
newEntry := &FrequencyTableEntry{
|
|
||||||
newVal,
|
|
||||||
classCounter,
|
|
||||||
}
|
|
||||||
lowerSlice := freq
|
|
||||||
upperSlice := freq
|
|
||||||
if minIndex > 0 {
|
|
||||||
lowerSlice = freq[0:minIndex]
|
|
||||||
upperSlice = freq[minIndex+1:]
|
|
||||||
} else {
|
|
||||||
lowerSlice = make([]*FrequencyTableEntry, 0)
|
|
||||||
upperSlice = freq[1:]
|
|
||||||
}
|
|
||||||
upperSlice[0] = newEntry
|
|
||||||
freq = append(lowerSlice, upperSlice...)
|
|
||||||
return freq
|
|
||||||
}
|
|
||||||
|
|
||||||
func chiMergePrintTable(freq []*FrequencyTableEntry) {
|
|
||||||
classes := chiCountClasses(freq)
|
|
||||||
fmt.Printf("Attribute value\t")
|
|
||||||
for k := range classes {
|
|
||||||
fmt.Printf("\t%s", k)
|
|
||||||
}
|
|
||||||
fmt.Printf("\tTotal\n")
|
|
||||||
for _, f := range freq {
|
|
||||||
fmt.Printf("%.2f\t", f.Value)
|
|
||||||
total := 0
|
|
||||||
for k := range classes {
|
|
||||||
fmt.Printf("\t%d", f.Frequency[k])
|
|
||||||
total += f.Frequency[k]
|
|
||||||
}
|
|
||||||
fmt.Printf("\t%d\n", total)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Produces a value mapping table
|
// Produces a value mapping table
|
||||||
// inst: The base.Instances which need discretising
|
// inst: The base.Instances which need discretising
|
||||||
// sig: The significance level (e.g. 0.95)
|
// sig: The significance level (e.g. 0.95)
|
||||||
@ -316,7 +45,7 @@ func chiMergePrintTable(freq []*FrequencyTableEntry) {
|
|||||||
// adjacent rows will be merged
|
// adjacent rows will be merged
|
||||||
// precision: internal number of decimal places to round E value to
|
// precision: internal number of decimal places to round E value to
|
||||||
// (useful for verification)
|
// (useful for verification)
|
||||||
func chiMerge(inst *base.Instances, attr int, sig float64, minrows int, maxrows int) []*FrequencyTableEntry {
|
func chiMerge(inst base.FixedDataGrid, attr base.Attribute, sig float64, minrows int, maxrows int) []*FrequencyTableEntry {
|
||||||
|
|
||||||
// Parameter sanity checking
|
// Parameter sanity checking
|
||||||
if !(2 <= minrows) {
|
if !(2 <= minrows) {
|
||||||
@ -329,12 +58,17 @@ func chiMerge(inst *base.Instances, attr int, sig float64, minrows int, maxrows
|
|||||||
sig = 10
|
sig = 10
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Check that the attribute is numeric
|
||||||
|
_, ok := attr.(*base.FloatAttribute)
|
||||||
|
if !ok {
|
||||||
|
panic("only use Chi-M on numeric stuff")
|
||||||
|
}
|
||||||
|
|
||||||
// Build a frequency table
|
// Build a frequency table
|
||||||
freq := ChiMBuildFrequencyTable(attr, inst)
|
freq := ChiMBuildFrequencyTable(attr, inst)
|
||||||
// Count the number of classes
|
// Count the number of classes
|
||||||
classes := chiCountClasses(freq)
|
classes := chiCountClasses(freq)
|
||||||
for {
|
for {
|
||||||
// chiMergePrintTable(freq) DEBUG
|
|
||||||
if len(freq) <= minrows {
|
if len(freq) <= minrows {
|
||||||
break
|
break
|
||||||
}
|
}
|
||||||
@ -378,3 +112,72 @@ func chiMerge(inst *base.Instances, attr int, sig float64, minrows int, maxrows
|
|||||||
}
|
}
|
||||||
return freq
|
return freq
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (c *ChiMergeFilter) Train() error {
|
||||||
|
as := c.getAttributeSpecs()
|
||||||
|
|
||||||
|
for _, a := range as {
|
||||||
|
|
||||||
|
attr := a.GetAttribute()
|
||||||
|
|
||||||
|
// Skip if not set
|
||||||
|
if !c.attrs[attr] {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
// Build sort order
|
||||||
|
sortOrder := []base.AttributeSpec{a}
|
||||||
|
|
||||||
|
// Sort
|
||||||
|
sorted, err := base.LazySort(c.train, base.Ascending, sortOrder)
|
||||||
|
if err != nil {
|
||||||
|
panic(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Perform ChiMerge
|
||||||
|
freq := chiMerge(sorted, attr, c.Significance, c.MinRows, c.MaxRows)
|
||||||
|
c.tables[attr] = freq
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetAttributesAfterFiltering gets a list of before/after
|
||||||
|
// Attributes as base.FilteredAttributes
|
||||||
|
func (c *ChiMergeFilter) GetAttributesAfterFiltering() []base.FilteredAttribute {
|
||||||
|
oldAttrs := c.train.AllAttributes()
|
||||||
|
ret := make([]base.FilteredAttribute, len(oldAttrs))
|
||||||
|
for i, a := range oldAttrs {
|
||||||
|
if c.attrs[a] {
|
||||||
|
retAttr := new(base.CategoricalAttribute)
|
||||||
|
retAttr.SetName(a.GetName())
|
||||||
|
for _, k := range c.tables[a] {
|
||||||
|
retAttr.GetSysValFromString(fmt.Sprintf("%f", k.Value))
|
||||||
|
}
|
||||||
|
ret[i] = base.FilteredAttribute{a, retAttr}
|
||||||
|
} else {
|
||||||
|
ret[i] = base.FilteredAttribute{a, a}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return ret
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *ChiMergeFilter) Transform(a base.Attribute, field []byte) []byte {
|
||||||
|
// Do we use this Attribute?
|
||||||
|
if !c.attrs[a] {
|
||||||
|
return field
|
||||||
|
}
|
||||||
|
// Find the Attribute value in the table
|
||||||
|
table := c.tables[a]
|
||||||
|
dis := 0
|
||||||
|
val := base.UnpackBytesToFloat(field)
|
||||||
|
for j, k := range table {
|
||||||
|
if k.Value < val {
|
||||||
|
dis = j
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
break
|
||||||
|
}
|
||||||
|
|
||||||
|
return base.PackU64ToBytes(uint64(dis))
|
||||||
|
}
|
||||||
|
14
filters/chimerge_freq.go
Normal file
14
filters/chimerge_freq.go
Normal file
@ -0,0 +1,14 @@
|
|||||||
|
package filters
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
)
|
||||||
|
|
||||||
|
type FrequencyTableEntry struct {
|
||||||
|
Value float64
|
||||||
|
Frequency map[string]int
|
||||||
|
}
|
||||||
|
|
||||||
|
func (t *FrequencyTableEntry) String() string {
|
||||||
|
return fmt.Sprintf("%.2f %s", t.Value, t.Frequency)
|
||||||
|
}
|
205
filters/chimerge_funcs.go
Normal file
205
filters/chimerge_funcs.go
Normal file
@ -0,0 +1,205 @@
|
|||||||
|
package filters
|
||||||
|
|
||||||
|
import (
|
||||||
|
"github.com/sjwhitworth/golearn/base"
|
||||||
|
"fmt"
|
||||||
|
"math"
|
||||||
|
)
|
||||||
|
|
||||||
|
func ChiMBuildFrequencyTable(attr base.Attribute, inst base.FixedDataGrid) []*FrequencyTableEntry {
|
||||||
|
ret := make([]*FrequencyTableEntry, 0)
|
||||||
|
attribute := attr.(*base.FloatAttribute)
|
||||||
|
|
||||||
|
attrSpec, err := inst.GetAttribute(attr)
|
||||||
|
if err != nil {
|
||||||
|
panic(err)
|
||||||
|
}
|
||||||
|
attrSpecs := []base.AttributeSpec{attrSpec}
|
||||||
|
|
||||||
|
err = inst.MapOverRows(attrSpecs, func(row [][]byte, rowNo int) (bool, error) {
|
||||||
|
value := row[0]
|
||||||
|
valueConv := attribute.GetFloatFromSysVal(value)
|
||||||
|
class := base.GetClass(inst, rowNo)
|
||||||
|
// Search the frequency table for the value
|
||||||
|
found := false
|
||||||
|
for _, entry := range ret {
|
||||||
|
if entry.Value == valueConv {
|
||||||
|
found = true
|
||||||
|
entry.Frequency[class] += 1
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if !found {
|
||||||
|
newEntry := &FrequencyTableEntry{
|
||||||
|
valueConv,
|
||||||
|
make(map[string]int),
|
||||||
|
}
|
||||||
|
newEntry.Frequency[class] = 1
|
||||||
|
ret = append(ret, newEntry)
|
||||||
|
}
|
||||||
|
return true, nil
|
||||||
|
})
|
||||||
|
|
||||||
|
return ret
|
||||||
|
}
|
||||||
|
|
||||||
|
func chiSquaredPdf(k float64, x float64) float64 {
|
||||||
|
if x < 0 {
|
||||||
|
return 0
|
||||||
|
}
|
||||||
|
top := math.Pow(x, (k/2)-1) * math.Exp(-x/2)
|
||||||
|
bottom := math.Pow(2, k/2) * math.Gamma(k/2)
|
||||||
|
return top / bottom
|
||||||
|
}
|
||||||
|
|
||||||
|
func chiSquaredPercentile(k int, x float64) float64 {
|
||||||
|
// Implements Yahya et al.'s "A Numerical Procedure
|
||||||
|
// for Computing Chi-Square Percentage Points"
|
||||||
|
// InterStat Journal 01/2007; April 25:page:1-8.
|
||||||
|
steps := 32
|
||||||
|
intervals := 4 * steps
|
||||||
|
w := x / (4.0 * float64(steps))
|
||||||
|
values := make([]float64, intervals+1)
|
||||||
|
for i := 0; i < intervals+1; i++ {
|
||||||
|
c := w * float64(i)
|
||||||
|
v := chiSquaredPdf(float64(k), c)
|
||||||
|
values[i] = v
|
||||||
|
}
|
||||||
|
|
||||||
|
ret1 := values[0] + values[len(values)-1]
|
||||||
|
ret2 := 0.0
|
||||||
|
ret3 := 0.0
|
||||||
|
ret4 := 0.0
|
||||||
|
|
||||||
|
for i := 2; i < intervals-1; i += 4 {
|
||||||
|
ret2 += values[i]
|
||||||
|
}
|
||||||
|
|
||||||
|
for i := 4; i < intervals-3; i += 4 {
|
||||||
|
ret3 += values[i]
|
||||||
|
}
|
||||||
|
|
||||||
|
for i := 1; i < intervals; i += 2 {
|
||||||
|
ret4 += values[i]
|
||||||
|
}
|
||||||
|
|
||||||
|
return (2.0 * w / 45) * (7*ret1 + 12*ret2 + 14*ret3 + 32*ret4)
|
||||||
|
}
|
||||||
|
|
||||||
|
func chiCountClasses(entries []*FrequencyTableEntry) map[string]int {
|
||||||
|
classCounter := make(map[string]int)
|
||||||
|
for _, e := range entries {
|
||||||
|
for k := range e.Frequency {
|
||||||
|
classCounter[k] += e.Frequency[k]
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return classCounter
|
||||||
|
}
|
||||||
|
|
||||||
|
func chiComputeStatistic(entry1 *FrequencyTableEntry, entry2 *FrequencyTableEntry) float64 {
|
||||||
|
|
||||||
|
// Sum the number of things observed per class
|
||||||
|
classCounter := make(map[string]int)
|
||||||
|
for k := range entry1.Frequency {
|
||||||
|
classCounter[k] += entry1.Frequency[k]
|
||||||
|
}
|
||||||
|
for k := range entry2.Frequency {
|
||||||
|
classCounter[k] += entry2.Frequency[k]
|
||||||
|
}
|
||||||
|
|
||||||
|
// Sum the number of things observed per value
|
||||||
|
entryObservations1 := 0
|
||||||
|
entryObservations2 := 0
|
||||||
|
for k := range entry1.Frequency {
|
||||||
|
entryObservations1 += entry1.Frequency[k]
|
||||||
|
}
|
||||||
|
for k := range entry2.Frequency {
|
||||||
|
entryObservations2 += entry2.Frequency[k]
|
||||||
|
}
|
||||||
|
|
||||||
|
totalObservations := entryObservations1 + entryObservations2
|
||||||
|
// Compute the expected values per class
|
||||||
|
expectedClassValues1 := make(map[string]float64)
|
||||||
|
expectedClassValues2 := make(map[string]float64)
|
||||||
|
for k := range classCounter {
|
||||||
|
expectedClassValues1[k] = float64(classCounter[k])
|
||||||
|
expectedClassValues1[k] *= float64(entryObservations1)
|
||||||
|
expectedClassValues1[k] /= float64(totalObservations)
|
||||||
|
}
|
||||||
|
for k := range classCounter {
|
||||||
|
expectedClassValues2[k] = float64(classCounter[k])
|
||||||
|
expectedClassValues2[k] *= float64(entryObservations2)
|
||||||
|
expectedClassValues2[k] /= float64(totalObservations)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Compute chi-squared value
|
||||||
|
chiSum := 0.0
|
||||||
|
for k := range expectedClassValues1 {
|
||||||
|
numerator := float64(entry1.Frequency[k])
|
||||||
|
numerator -= expectedClassValues1[k]
|
||||||
|
numerator = math.Pow(numerator, 2)
|
||||||
|
denominator := float64(expectedClassValues1[k])
|
||||||
|
if denominator < 0.5 {
|
||||||
|
denominator = 0.5
|
||||||
|
}
|
||||||
|
chiSum += numerator / denominator
|
||||||
|
}
|
||||||
|
for k := range expectedClassValues2 {
|
||||||
|
numerator := float64(entry2.Frequency[k])
|
||||||
|
numerator -= expectedClassValues2[k]
|
||||||
|
numerator = math.Pow(numerator, 2)
|
||||||
|
denominator := float64(expectedClassValues2[k])
|
||||||
|
if denominator < 0.5 {
|
||||||
|
denominator = 0.5
|
||||||
|
}
|
||||||
|
chiSum += numerator / denominator
|
||||||
|
}
|
||||||
|
|
||||||
|
return chiSum
|
||||||
|
}
|
||||||
|
|
||||||
|
func chiMergeMergeZipAdjacent(freq []*FrequencyTableEntry, minIndex int) []*FrequencyTableEntry {
|
||||||
|
mergeEntry1 := freq[minIndex]
|
||||||
|
mergeEntry2 := freq[minIndex+1]
|
||||||
|
classCounter := make(map[string]int)
|
||||||
|
for k := range mergeEntry1.Frequency {
|
||||||
|
classCounter[k] += mergeEntry1.Frequency[k]
|
||||||
|
}
|
||||||
|
for k := range mergeEntry2.Frequency {
|
||||||
|
classCounter[k] += mergeEntry2.Frequency[k]
|
||||||
|
}
|
||||||
|
newVal := freq[minIndex].Value
|
||||||
|
newEntry := &FrequencyTableEntry{
|
||||||
|
newVal,
|
||||||
|
classCounter,
|
||||||
|
}
|
||||||
|
lowerSlice := freq
|
||||||
|
upperSlice := freq
|
||||||
|
if minIndex > 0 {
|
||||||
|
lowerSlice = freq[0:minIndex]
|
||||||
|
upperSlice = freq[minIndex+1:]
|
||||||
|
} else {
|
||||||
|
lowerSlice = make([]*FrequencyTableEntry, 0)
|
||||||
|
upperSlice = freq[1:]
|
||||||
|
}
|
||||||
|
upperSlice[0] = newEntry
|
||||||
|
freq = append(lowerSlice, upperSlice...)
|
||||||
|
return freq
|
||||||
|
}
|
||||||
|
|
||||||
|
func chiMergePrintTable(freq []*FrequencyTableEntry) {
|
||||||
|
classes := chiCountClasses(freq)
|
||||||
|
fmt.Printf("Attribute value\t")
|
||||||
|
for k := range classes {
|
||||||
|
fmt.Printf("\t%s", k)
|
||||||
|
}
|
||||||
|
fmt.Printf("\tTotal\n")
|
||||||
|
for _, f := range freq {
|
||||||
|
fmt.Printf("%.2f\t", f.Value)
|
||||||
|
total := 0
|
||||||
|
for k := range classes {
|
||||||
|
fmt.Printf("\t%d", f.Frequency[k])
|
||||||
|
total += f.Frequency[k]
|
||||||
|
}
|
||||||
|
fmt.Printf("\t%d\n", total)
|
||||||
|
}
|
||||||
|
}
|
@ -14,7 +14,7 @@ func TestChiMFreqTable(testEnv *testing.T) {
|
|||||||
panic(err)
|
panic(err)
|
||||||
}
|
}
|
||||||
|
|
||||||
freq := ChiMBuildFrequencyTable(0, inst)
|
freq := ChiMBuildFrequencyTable(inst.AllAttributes()[0], inst)
|
||||||
|
|
||||||
if freq[0].Frequency["c1"] != 1 {
|
if freq[0].Frequency["c1"] != 1 {
|
||||||
testEnv.Error("Wrong frequency")
|
testEnv.Error("Wrong frequency")
|
||||||
@ -32,7 +32,7 @@ func TestChiClassCounter(testEnv *testing.T) {
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
panic(err)
|
panic(err)
|
||||||
}
|
}
|
||||||
freq := ChiMBuildFrequencyTable(0, inst)
|
freq := ChiMBuildFrequencyTable(inst.AllAttributes()[0], inst)
|
||||||
classes := chiCountClasses(freq)
|
classes := chiCountClasses(freq)
|
||||||
if classes["c1"] != 27 {
|
if classes["c1"] != 27 {
|
||||||
testEnv.Error(classes)
|
testEnv.Error(classes)
|
||||||
@ -50,7 +50,7 @@ func TestStatisticValues(testEnv *testing.T) {
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
panic(err)
|
panic(err)
|
||||||
}
|
}
|
||||||
freq := ChiMBuildFrequencyTable(0, inst)
|
freq := ChiMBuildFrequencyTable(inst.AllAttributes()[0], inst)
|
||||||
chiVal := chiComputeStatistic(freq[5], freq[6])
|
chiVal := chiComputeStatistic(freq[5], freq[6])
|
||||||
if math.Abs(chiVal-1.89) > 0.01 {
|
if math.Abs(chiVal-1.89) > 0.01 {
|
||||||
testEnv.Error(chiVal)
|
testEnv.Error(chiVal)
|
||||||
@ -78,12 +78,15 @@ func TestChiSquareDistValues(testEnv *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestChiMerge1(testEnv *testing.T) {
|
func TestChiMerge1(testEnv *testing.T) {
|
||||||
// See Bramer, Principles of Machine Learning
|
|
||||||
|
// Read the data
|
||||||
inst, err := base.ParseCSVToInstances("../examples/datasets/chim.csv", true)
|
inst, err := base.ParseCSVToInstances("../examples/datasets/chim.csv", true)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
panic(err)
|
panic(err)
|
||||||
}
|
}
|
||||||
freq := chiMerge(inst, 0, 0.90, 0, inst.Rows)
|
_, rows := inst.Size()
|
||||||
|
|
||||||
|
freq := chiMerge(inst, inst.AllAttributes()[0], 0.90, 0, rows)
|
||||||
if len(freq) != 3 {
|
if len(freq) != 3 {
|
||||||
testEnv.Error("Wrong length")
|
testEnv.Error("Wrong length")
|
||||||
}
|
}
|
||||||
@ -106,10 +109,18 @@ func TestChiMerge2(testEnv *testing.T) {
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
panic(err)
|
panic(err)
|
||||||
}
|
}
|
||||||
attrs := make([]int, 1)
|
|
||||||
attrs[0] = 0
|
// Sort the instances
|
||||||
inst.Sort(base.Ascending, attrs)
|
allAttrs := inst.AllAttributes()
|
||||||
freq := chiMerge(inst, 0, 0.90, 0, inst.Rows)
|
sortAttrSpecs := base.ResolveAllAttributes(inst, allAttrs)[0:1]
|
||||||
|
instSorted, err := base.Sort(inst, base.Ascending, sortAttrSpecs)
|
||||||
|
if err != nil {
|
||||||
|
panic(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Perform Chi-Merge
|
||||||
|
_, rows := inst.Size()
|
||||||
|
freq := chiMerge(instSorted, allAttrs[0], 0.90, 0, rows)
|
||||||
if len(freq) != 5 {
|
if len(freq) != 5 {
|
||||||
testEnv.Errorf("Wrong length (%d)", len(freq))
|
testEnv.Errorf("Wrong length (%d)", len(freq))
|
||||||
testEnv.Error(freq)
|
testEnv.Error(freq)
|
||||||
@ -131,6 +142,7 @@ func TestChiMerge2(testEnv *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/*
|
||||||
func TestChiMerge3(testEnv *testing.T) {
|
func TestChiMerge3(testEnv *testing.T) {
|
||||||
// See http://sci2s.ugr.es/keel/pdf/algorithm/congreso/1992-Kerber-ChimErge-AAAI92.pdf
|
// See http://sci2s.ugr.es/keel/pdf/algorithm/congreso/1992-Kerber-ChimErge-AAAI92.pdf
|
||||||
// Randy Kerber, ChiMerge: Discretisation of Numeric Attributes, 1992
|
// Randy Kerber, ChiMerge: Discretisation of Numeric Attributes, 1992
|
||||||
@ -138,12 +150,52 @@ func TestChiMerge3(testEnv *testing.T) {
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
panic(err)
|
panic(err)
|
||||||
}
|
}
|
||||||
attrs := make([]int, 1)
|
|
||||||
attrs[0] = 0
|
insts, err := base.LazySort(inst, base.Ascending, base.ResolveAllAttributes(inst, inst.AllAttributes()))
|
||||||
inst.Sort(base.Ascending, attrs)
|
if err != nil {
|
||||||
filt := NewChiMergeFilter(inst, 0.90)
|
testEnv.Error(err)
|
||||||
filt.AddAttribute(inst.GetAttr(0))
|
}
|
||||||
filt.Build()
|
filt := NewChiMergeFilter(inst, 0.90)
|
||||||
filt.Run(inst)
|
filt.AddAttribute(inst.AllAttributes()[0])
|
||||||
fmt.Println(inst)
|
filt.Train()
|
||||||
|
instf := base.NewLazilyFilteredInstances(insts, filt)
|
||||||
|
fmt.Println(instf)
|
||||||
|
fmt.Println(instf.String())
|
||||||
|
rowStr := instf.RowString(0)
|
||||||
|
ref := "4.300000 3.00 1.10 0.10 Iris-setosa"
|
||||||
|
if rowStr != ref {
|
||||||
|
panic(fmt.Sprintf("'%s' != '%s'", rowStr, ref))
|
||||||
|
}
|
||||||
|
clsAttrs := instf.AllClassAttributes()
|
||||||
|
if len(clsAttrs) != 1 {
|
||||||
|
panic(fmt.Sprintf("%d != %d", len(clsAttrs), 1))
|
||||||
|
}
|
||||||
|
if clsAttrs[0].GetName() != "Species" {
|
||||||
|
panic("Class Attribute wrong!")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
*/
|
||||||
|
|
||||||
|
func TestChiMerge4(testEnv *testing.T) {
|
||||||
|
// See http://sci2s.ugr.es/keel/pdf/algorithm/congreso/1992-Kerber-ChimErge-AAAI92.pdf
|
||||||
|
// Randy Kerber, ChiMerge: Discretisation of Numeric Attributes, 1992
|
||||||
|
inst, err := base.ParseCSVToInstances("../examples/datasets/iris_headers.csv", true)
|
||||||
|
if err != nil {
|
||||||
|
panic(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
filt := NewChiMergeFilter(inst, 0.90)
|
||||||
|
filt.AddAttribute(inst.AllAttributes()[0])
|
||||||
|
filt.AddAttribute(inst.AllAttributes()[1])
|
||||||
|
filt.Train()
|
||||||
|
instf := base.NewLazilyFilteredInstances(inst, filt)
|
||||||
|
fmt.Println(instf)
|
||||||
|
fmt.Println(instf.String())
|
||||||
|
clsAttrs := instf.AllClassAttributes()
|
||||||
|
if len(clsAttrs) != 1 {
|
||||||
|
panic(fmt.Sprintf("%d != %d", len(clsAttrs), 1))
|
||||||
|
}
|
||||||
|
if clsAttrs[0].GetName() != "Species" {
|
||||||
|
panic("Class Attribute wrong!")
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
62
filters/disc.go
Normal file
62
filters/disc.go
Normal file
@ -0,0 +1,62 @@
|
|||||||
|
package filters
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
base "github.com/sjwhitworth/golearn/base"
|
||||||
|
)
|
||||||
|
|
||||||
|
type AbstractDiscretizeFilter struct {
|
||||||
|
attrs map[base.Attribute]bool
|
||||||
|
trained bool
|
||||||
|
train base.FixedDataGrid
|
||||||
|
}
|
||||||
|
|
||||||
|
// AddAttribute adds the AttributeSpec of the given attribute `a'
|
||||||
|
// to the AbstractFloatFilter for discretisation.
|
||||||
|
func (d *AbstractDiscretizeFilter) AddAttribute(a base.Attribute) error {
|
||||||
|
if _, ok := a.(*base.FloatAttribute); !ok {
|
||||||
|
return fmt.Errorf("%s is not a FloatAttribute", a)
|
||||||
|
}
|
||||||
|
_, err := d.train.GetAttribute(a)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("invalid attribute")
|
||||||
|
}
|
||||||
|
d.attrs[a] = true
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetAttributesAfterFiltering gets a list of before/after
|
||||||
|
// Attributes as base.FilteredAttributes
|
||||||
|
func (d *AbstractDiscretizeFilter) GetAttributesAfterFiltering() []base.FilteredAttribute {
|
||||||
|
oldAttrs := d.train.AllAttributes()
|
||||||
|
ret := make([]base.FilteredAttribute, len(oldAttrs))
|
||||||
|
for i, a := range oldAttrs {
|
||||||
|
if d.attrs[a] {
|
||||||
|
retAttr := new(base.CategoricalAttribute)
|
||||||
|
retAttr.SetName(a.GetName())
|
||||||
|
ret[i] = base.FilteredAttribute{a, retAttr}
|
||||||
|
} else {
|
||||||
|
ret[i] = base.FilteredAttribute{a, a}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return ret
|
||||||
|
}
|
||||||
|
|
||||||
|
func (d *AbstractDiscretizeFilter) getAttributeSpecs() []base.AttributeSpec {
|
||||||
|
as := make([]base.AttributeSpec, 0)
|
||||||
|
// Set up the AttributeSpecs, and values
|
||||||
|
for attr := range d.attrs {
|
||||||
|
// If for some reason we've un-added it...
|
||||||
|
if !d.attrs[attr] {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
// Get the AttributeSpec for the training set
|
||||||
|
a, err := d.train.GetAttribute(attr)
|
||||||
|
if err != nil {
|
||||||
|
panic(fmt.Errorf("Attribute resolution error: %s", err))
|
||||||
|
}
|
||||||
|
// Append to return set
|
||||||
|
as = append(as, a)
|
||||||
|
}
|
||||||
|
return as
|
||||||
|
}
|
Loading…
x
Reference in New Issue
Block a user