diff --git a/filters/binning.go b/filters/binning.go
new file mode 100644
index 0000000..157896e
--- /dev/null
+++ b/filters/binning.go
@@ -0,0 +1,121 @@
+package filters
+
+import (
+	"fmt"
+	base "github.com/sjwhitworth/golearn/base"
+	"math"
+)
+
+// BinningFilter does equal-width binning for numeric
+// Attributes (aka "histogram binning").
+type BinningFilter struct {
+	Attributes []int
+	Instances  *base.Instances
+	BinCount   int
+	MinVals    map[int]float64
+	MaxVals    map[int]float64
+	trained    bool
+}
+
+// NewBinningFilter creates a BinningFilter structure
+// with some helpful default initialisations.
+func NewBinningFilter(inst *base.Instances, bins int) BinningFilter {
+	return BinningFilter{
+		make([]int, 0),
+		inst,
+		bins,
+		make(map[int]float64),
+		make(map[int]float64),
+		false,
+	}
+}
+
+// AddAttribute adds the index of the given attribute `a'
+// to the BinningFilter for discretisation.
+func (b *BinningFilter) AddAttribute(a base.Attribute) {
+	attrIndex := b.Instances.GetAttrIndex(a)
+	if attrIndex == -1 {
+		panic("invalid attribute")
+	}
+	b.Attributes = append(b.Attributes, attrIndex)
+}
+
+// AddAllNumericAttributes adds every suitable attribute
+// to the BinningFilter for discretisation.
+func (b *BinningFilter) AddAllNumericAttributes() {
+	for i := 0; i < b.Instances.Cols; i++ {
+		if i == b.Instances.ClassIndex {
+			continue
+		}
+		attr := b.Instances.GetAttr(i)
+		if attr.GetType() != base.Float64Type {
+			continue
+		}
+		b.Attributes = append(b.Attributes, i)
+	}
+}
+
+// Build computes and stores the bin boundaries (attribute minima
+// and maxima) for the training instances.
+func (b *BinningFilter) Build() {
+	for _, attr := range b.Attributes {
+		maxVal := math.Inf(-1)
+		minVal := math.Inf(1)
+		for i := 0; i < b.Instances.Rows; i++ {
+			val := b.Instances.Get(i, attr)
+			if val > maxVal {
+				maxVal = val
+			}
+			if val < minVal {
+				minVal = val
+			}
+		}
+		b.MaxVals[attr] = maxVal
+		b.MinVals[attr] = minVal
+	}
+	b.trained = true
+}
+
+// Run applies a trained BinningFilter to a set of Instances,
+// discretising any numeric attributes added.
+//
+// IMPORTANT: Run discretises in-place, so make sure to take
+// a copy if the original instances are still needed.
+//
+// IMPORTANT: This function panic()s if the filter has not been
+// trained. Call Build() before running this function.
+//
+// IMPORTANT: Call Build() after adding any additional attributes.
+// Otherwise, the stored minimum and maximum values will be out of
+// date and could cause a panic.
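+//
+// A minimal usage sketch (assumes `inst' is a *base.Instances with
+// a numeric attribute at index 0, mirroring TestBinning below):
+//
+//	filt := NewBinningFilter(inst, 10) // ten equal-width bins
+//	filt.AddAttribute(inst.GetAttr(0))
+//	filt.Build()   // record each attribute's minimum and maximum
+//	filt.Run(inst) // discretise in place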
+func (b *BinningFilter) Run(on *base.Instances) {
+	if !b.trained {
+		panic("Call Build() beforehand")
+	}
+	for _, attr := range b.Attributes {
+		minVal := b.MinVals[attr]
+		maxVal := b.MaxVals[attr]
+		disc := 0
+		// Casts to float32 to replicate a floating point precision error
+		delta := float32(maxVal - minVal)
+		delta /= float32(b.BinCount)
+		for i := 0; i < on.Rows; i++ {
+			val := on.Get(i, attr)
+			if val <= minVal {
+				disc = 0
+			} else {
+				disc = int(math.Floor(float64(float32(val-minVal) / delta)))
+				if disc >= b.BinCount {
+					disc = b.BinCount - 1
+				}
+			}
+			on.Set(i, attr, float64(disc))
+		}
+		newAttribute := new(base.CategoricalAttribute)
+		newAttribute.SetName(on.GetAttr(attr).GetName())
+		for i := 0; i < b.BinCount; i++ {
+			newAttribute.GetSysValFromString(fmt.Sprintf("%d", i))
+		}
+		on.ReplaceAttr(attr, newAttribute)
+	}
+}
diff --git a/filters/binning_test.go b/filters/binning_test.go
new file mode 100644
index 0000000..ad6a01c
--- /dev/null
+++ b/filters/binning_test.go
@@ -0,0 +1,28 @@
+package filters
+
+import (
+	base "github.com/sjwhitworth/golearn/base"
+	"math"
+	"testing"
+)
+
+func TestBinning(testEnv *testing.T) {
+	inst1, err := base.ParseCSVToInstances("../examples/datasets/iris_headers.csv", true)
+	inst2, err := base.ParseCSVToInstances("../examples/datasets/iris_binned.csv", true)
+	inst3, err := base.ParseCSVToInstances("../examples/datasets/iris_headers.csv", true)
+	if err != nil {
+		panic(err)
+	}
+	filt := NewBinningFilter(inst1, 10)
+	filt.AddAttribute(inst1.GetAttr(0))
+	filt.Build()
+	filt.Run(inst1)
+	for i := 0; i < inst1.Rows; i++ {
+		val1 := inst1.Get(i, 0)
+		val2 := inst2.Get(i, 0)
+		val3 := inst3.Get(i, 0)
+		if math.Abs(val1-val2) >= 1 {
+			testEnv.Error(val1, val2, val3, i)
+		}
+	}
+}
diff --git a/filters/chimerge.go b/filters/chimerge.go
new file mode 100644
index 0000000..5339add
--- /dev/null
+++ b/filters/chimerge.go
@@ -0,0 +1,365 @@
+package filters
+
+import (
+	"fmt"
+	base "github.com/sjwhitworth/golearn/base"
+	"math"
+)
+
+// ChiMergeFilter implements supervised discretisation
+// by merging successive numeric intervals if the difference
+// in their class distribution is not statistically significant.
+// See Bramer, "Principles of Data Mining", 2nd Edition,
+// pp 105--115.
+type ChiMergeFilter struct {
+	Attributes   []int
+	Instances    *base.Instances
+	Tables       map[int][]*FrequencyTableEntry
+	Significance float64
+	MinRows      int
+	MaxRows      int
+	_Trained     bool
+}
+
+// NewChiMergeFilter creates a ChiMergeFilter with some helpful initialisations.
+func NewChiMergeFilter(inst *base.Instances, significance float64) ChiMergeFilter {
+	return ChiMergeFilter{
+		make([]int, 0),
+		inst,
+		make(map[int][]*FrequencyTableEntry),
+		significance,
+		0,
+		0,
+		false,
+	}
+}
+
+// Build trains a ChiMergeFilter on the ChiMergeFilter.Instances given.
+func (c *ChiMergeFilter) Build() {
+	for _, attr := range c.Attributes {
+		tab := chiMerge(c.Instances, attr, c.Significance, c.MinRows, c.MaxRows)
+		c.Tables[attr] = tab
+	}
+	c._Trained = true
+}
+
+// Run discretises the set of Instances `on'.
+//
+// IMPORTANT: ChiMergeFilter discretises in place.
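+//
+// A minimal usage sketch (assumes `inst' is a *base.Instances sorted
+// ascending on the numeric attribute at index 0, as in the tests below):
+//
+//	filt := NewChiMergeFilter(inst, 0.90) // 0.90 significance level
+//	filt.AddAttribute(inst.GetAttr(0))
+//	filt.Build()   // compute the merged intervals per attribute
+//	filt.Run(inst) // discretise in place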
+func (c *ChiMergeFilter) Run(on *base.Instances) {
+	if !c._Trained {
+		panic("Call Build() beforehand")
+	}
+	for attr := range c.Tables {
+		table := c.Tables[attr]
+		for i := 0; i < on.Rows; i++ {
+			val := on.Get(i, attr)
+			dis := 0
+			for j, k := range table {
+				if k.Value < val {
+					dis = j
+					continue
+				}
+				break
+			}
+			on.Set(i, attr, float64(dis))
+		}
+		newAttribute := new(base.CategoricalAttribute)
+		newAttribute.SetName(on.GetAttr(attr).GetName())
+		for _, k := range table {
+			newAttribute.GetSysValFromString(fmt.Sprintf("%f", k.Value))
+		}
+		on.ReplaceAttr(attr, newAttribute)
+	}
+}
+
+// AddAttribute adds a given numeric Attribute `attr' to the
+// filter.
+//
+// IMPORTANT: This function panic()s if it can't locate the
+// attribute in the Instances set.
+func (c *ChiMergeFilter) AddAttribute(attr base.Attribute) {
+	if attr.GetType() != base.Float64Type {
+		panic("ChiMerge only works on Float64Attributes")
+	}
+	attrIndex := c.Instances.GetAttrIndex(attr)
+	if attrIndex == -1 {
+		panic("Invalid attribute!")
+	}
+	c.Attributes = append(c.Attributes, attrIndex)
+}
+
+// FrequencyTableEntry is a row in the ChiMerge frequency table:
+// a distinct attribute value and the class counts observed at it.
+type FrequencyTableEntry struct {
+	Value     float64
+	Frequency map[string]int
+}
+
+func (t *FrequencyTableEntry) String() string {
+	return fmt.Sprintf("%.2f %v", t.Value, t.Frequency)
+}
+
+// ChiMBuildFrequencyTable builds the initial frequency table for
+// attribute `attr': one entry per distinct value, with class counts.
+func ChiMBuildFrequencyTable(attr int, inst *base.Instances) []*FrequencyTableEntry {
+	ret := make([]*FrequencyTableEntry, 0)
+	var attribute *base.FloatAttribute
+	attribute, ok := inst.GetAttr(attr).(*base.FloatAttribute)
+	if !ok {
+		panic("only use Chi-M on numeric stuff")
+	}
+	for i := 0; i < inst.Rows; i++ {
+		value := inst.Get(i, attr)
+		valueConv := attribute.GetUsrVal(value)
+		class := inst.GetClass(i)
+		// Search the frequency table for the value
+		found := false
+		for _, entry := range ret {
+			if entry.Value == valueConv {
+				found = true
+				entry.Frequency[class] += 1
+			}
+		}
+		if !found {
+			newEntry := &FrequencyTableEntry{
+				valueConv,
+				make(map[string]int),
+			}
+			newEntry.Frequency[class] = 1
+			ret = append(ret, newEntry)
+		}
+	}
+
+	return ret
+}
+
+// chiSquaredPdf computes the probability density of the chi-squared
+// distribution with k degrees of freedom at x.
+func chiSquaredPdf(k float64, x float64) float64 {
+	if x < 0 {
+		return 0
+	}
+	top := math.Pow(x, (k/2)-1) * math.Exp(-x/2)
+	bottom := math.Pow(2, k/2) * math.Gamma(k/2)
+	return top / bottom
+}
+
+// chiSquaredPercentile approximates P(X <= x) for a chi-squared
+// random variable X with k degrees of freedom.
+func chiSquaredPercentile(k int, x float64) float64 {
+	// Implements Yahya et al.'s "A Numerical Procedure
+	// for Computing Chi-Square Percentage Points"
+	// InterStat Journal 01/2007; April 25:page:1-8.
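+	//
+	// The percentile is obtained by numerically integrating
+	// chiSquaredPdf over [0, x]: the interval is split into 4*steps
+	// panels of width w and the samples are combined with what appears
+	// to be a composite Boole's (Newton-Cotes) rule, hence the
+	// 7/12/14/32 weights below.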
+	steps := 32
+	intervals := 4 * steps
+	w := x / (4.0 * float64(steps))
+	values := make([]float64, intervals+1)
+	for i := 0; i < intervals+1; i++ {
+		c := w * float64(i)
+		v := chiSquaredPdf(float64(k), c)
+		values[i] = v
+	}
+
+	ret1 := values[0] + values[len(values)-1]
+	ret2 := 0.0
+	ret3 := 0.0
+	ret4 := 0.0
+
+	for i := 2; i < intervals-1; i += 4 {
+		ret2 += values[i]
+	}
+
+	for i := 4; i < intervals-3; i += 4 {
+		ret3 += values[i]
+	}
+
+	for i := 1; i < intervals; i += 2 {
+		ret4 += values[i]
+	}
+
+	return (2.0 * w / 45) * (7*ret1 + 12*ret2 + 14*ret3 + 32*ret4)
+}
+
+func chiCountClasses(entries []*FrequencyTableEntry) map[string]int {
+	classCounter := make(map[string]int)
+	for _, e := range entries {
+		for k := range e.Frequency {
+			classCounter[k] += e.Frequency[k]
+		}
+	}
+	return classCounter
+}
+
+func chiComputeStatistic(entry1 *FrequencyTableEntry, entry2 *FrequencyTableEntry) float64 {
+
+	// Sum the number of things observed per class
+	classCounter := make(map[string]int)
+	for k := range entry1.Frequency {
+		classCounter[k] += entry1.Frequency[k]
+	}
+	for k := range entry2.Frequency {
+		classCounter[k] += entry2.Frequency[k]
+	}
+
+	// Sum the number of things observed per value
+	entryObservations1 := 0
+	entryObservations2 := 0
+	for k := range entry1.Frequency {
+		entryObservations1 += entry1.Frequency[k]
+	}
+	for k := range entry2.Frequency {
+		entryObservations2 += entry2.Frequency[k]
+	}
+
+	totalObservations := entryObservations1 + entryObservations2
+	// Compute the expected values per class
+	expectedClassValues1 := make(map[string]float64)
+	expectedClassValues2 := make(map[string]float64)
+	for k := range classCounter {
+		expectedClassValues1[k] = float64(classCounter[k])
+		expectedClassValues1[k] *= float64(entryObservations1)
+		expectedClassValues1[k] /= float64(totalObservations)
+	}
+	for k := range classCounter {
+		expectedClassValues2[k] = float64(classCounter[k])
+		expectedClassValues2[k] *= float64(entryObservations2)
+		expectedClassValues2[k] /= float64(totalObservations)
+	}
+
+	// Compute chi-squared value
+	chiSum := 0.0
+	for k := range expectedClassValues1 {
+		numerator := float64(entry1.Frequency[k])
+		numerator -= expectedClassValues1[k]
+		numerator = math.Pow(numerator, 2)
+		denominator := float64(expectedClassValues1[k])
+		if denominator < 0.5 {
+			denominator = 0.5
+		}
+		chiSum += numerator / denominator
+	}
+	for k := range expectedClassValues2 {
+		numerator := float64(entry2.Frequency[k])
+		numerator -= expectedClassValues2[k]
+		numerator = math.Pow(numerator, 2)
+		denominator := float64(expectedClassValues2[k])
+		if denominator < 0.5 {
+			denominator = 0.5
+		}
+		chiSum += numerator / denominator
+	}
+
+	return chiSum
+}
+
+func chiMergeMergeZipAdjacent(freq []*FrequencyTableEntry, minIndex int) []*FrequencyTableEntry {
+	mergeEntry1 := freq[minIndex]
+	mergeEntry2 := freq[minIndex+1]
+	classCounter := make(map[string]int)
+	for k := range mergeEntry1.Frequency {
+		classCounter[k] += mergeEntry1.Frequency[k]
+	}
+	for k := range mergeEntry2.Frequency {
+		classCounter[k] += mergeEntry2.Frequency[k]
+	}
+	newVal := freq[minIndex].Value
+	newEntry := &FrequencyTableEntry{
+		newVal,
+		classCounter,
+	}
+	lowerSlice := freq
+	upperSlice := freq
+	if minIndex > 0 {
+		lowerSlice = freq[0:minIndex]
+		upperSlice = freq[minIndex+1:]
+	} else {
+		lowerSlice = make([]*FrequencyTableEntry, 0)
+		upperSlice = freq[1:]
+	}
+	upperSlice[0] = newEntry
+	freq = append(lowerSlice, upperSlice...)
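+	// The merged row keeps the lower interval's boundary value and the
+	// combined class counts, so freq now has one fewer entry.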
+	return freq
+}
+
+// chiMergePrintTable prints the frequency table to standard output
+// (useful when debugging).
+func chiMergePrintTable(freq []*FrequencyTableEntry) {
+	classes := chiCountClasses(freq)
+	fmt.Printf("Attribute value\t")
+	for k := range classes {
+		fmt.Printf("\t%s", k)
+	}
+	fmt.Printf("\tTotal\n")
+	for _, f := range freq {
+		fmt.Printf("%.2f\t", f.Value)
+		total := 0
+		for k := range classes {
+			fmt.Printf("\t%d", f.Frequency[k])
+			total += f.Frequency[k]
+		}
+		fmt.Printf("\t%d\n", total)
+	}
+}
+
+// chiMerge produces a value mapping table.
+//   inst: the base.Instances which need discretising
+//   attr: the index of the attribute to discretise
+//   sig: the significance level (e.g. 0.95)
+//   minrows: the minimum number of rows required in the frequency table
+//   maxrows: the maximum number of rows allowed in the frequency table;
+//            if the number of rows is above this, adjacent rows are
+//            merged even when their difference is statistically significant
+func chiMerge(inst *base.Instances, attr int, sig float64, minrows int, maxrows int) []*FrequencyTableEntry {
+
+	// Parameter sanity checking
+	if !(2 <= minrows) {
+		minrows = 2
+	}
+	if !(minrows < maxrows) {
+		maxrows = minrows + 1
+	}
+	if sig == 0 {
+		sig = 10
+	}
+
+	// Build a frequency table
+	freq := ChiMBuildFrequencyTable(attr, inst)
+	// Count the number of classes
+	classes := chiCountClasses(freq)
+	for {
+		// chiMergePrintTable(freq) DEBUG
+		if len(freq) <= minrows {
+			break
+		}
+		minChiVal := math.Inf(1)
+		// There may be more than one index to merge
+		minChiIndexes := make([]int, 0)
+		for i := 0; i < len(freq)-1; i++ {
+			chiVal := chiComputeStatistic(freq[i], freq[i+1])
+			if chiVal < minChiVal {
+				minChiVal = chiVal
+				minChiIndexes = make([]int, 0)
+			}
+			if chiVal == minChiVal {
+				minChiIndexes = append(minChiIndexes, i)
+			}
+		}
+		// Only merge if:
+		//  we're above the maximum number of rows, OR
+		//  the chi-squared value is not significant at the chosen level,
+		// AS LONG AS we're above the minimum row count
+		merge := false
+		if len(freq) > maxrows {
+			merge = true
+		}
+		// Compute the degrees of freedom: number of classes minus one
+		// (each comparison involves exactly two adjacent rows)
+		degsOfFree := len(classes) - 1
+		sigVal := chiSquaredPercentile(degsOfFree, minChiVal)
+		if sigVal < sig {
+			merge = true
+		}
+		// If we don't need to merge, then break
+		if !merge {
+			break
+		}
+		// Otherwise merge the rows i, i+1 by keeping the first (lower)
+		// value and combining the class frequencies
+		for i, v := range minChiIndexes {
+			freq = chiMergeMergeZipAdjacent(freq, v-i)
+		}
+	}
+	return freq
+}
diff --git a/filters/chimerge_test.go b/filters/chimerge_test.go
new file mode 100644
index 0000000..0f49404
--- /dev/null
+++ b/filters/chimerge_test.go
@@ -0,0 +1,149 @@
+package filters
+
+import (
+	"fmt"
+	base "github.com/sjwhitworth/golearn/base"
+	"math"
+	"testing"
+)
+
+func TestChiMFreqTable(testEnv *testing.T) {
+
+	inst, err := base.ParseCSVToInstances("../examples/datasets/chim.csv", true)
+	if err != nil {
+		panic(err)
+	}
+
+	freq := ChiMBuildFrequencyTable(0, inst)
+
+	if freq[0].Frequency["c1"] != 1 {
+		testEnv.Error("Wrong frequency")
+	}
+	if freq[0].Frequency["c3"] != 4 {
+		testEnv.Errorf("Wrong frequency %s", freq[1])
+	}
+	if freq[10].Frequency["c2"] != 1 {
+		testEnv.Error("Wrong frequency")
+	}
+}
+
+func TestChiClassCounter(testEnv *testing.T) {
+	inst, err := base.ParseCSVToInstances("../examples/datasets/chim.csv", true)
+	if err != nil {
+		panic(err)
+	}
+	freq := ChiMBuildFrequencyTable(0, inst)
+	classes := chiCountClasses(freq)
+	if classes["c1"] != 27 {
+		testEnv.Error(classes)
+	}
+	if classes["c2"] != 12 {
+		testEnv.Error(classes)
+	}
+	if classes["c3"] != 21 {
+		testEnv.Error(classes)
+	}
+}
+
+func TestStatisticValues(testEnv *testing.T) {
+	inst, err := base.ParseCSVToInstances("../examples/datasets/chim.csv", true)
+	if err != nil {
+		panic(err)
+	}
+	freq := ChiMBuildFrequencyTable(0, inst)
+	chiVal := chiComputeStatistic(freq[5], freq[6])
+	if math.Abs(chiVal-1.89) > 0.01 {
+		testEnv.Error(chiVal)
+	}
+
+	chiVal = chiComputeStatistic(freq[1], freq[2])
+	if math.Abs(chiVal-1.08) > 0.01 {
+		testEnv.Error(chiVal)
+	}
+}
+
+func TestChiSquareDistValues(testEnv *testing.T) {
+	chiVal1 := chiSquaredPercentile(2, 4.61)
+	chiVal2 := chiSquaredPercentile(3, 7.82)
+	chiVal3 := chiSquaredPercentile(4, 13.28)
+	if math.Abs(chiVal1-0.90) > 0.001 {
+		testEnv.Error(chiVal1)
+	}
+	if math.Abs(chiVal2-0.95) > 0.001 {
+		testEnv.Error(chiVal2)
+	}
+	if math.Abs(chiVal3-0.99) > 0.001 {
+		testEnv.Error(chiVal3)
+	}
+}
+
+func TestChiMerge1(testEnv *testing.T) {
+	// See Bramer, Principles of Data Mining
+	inst, err := base.ParseCSVToInstances("../examples/datasets/chim.csv", true)
+	if err != nil {
+		panic(err)
+	}
+	freq := chiMerge(inst, 0, 0.90, 0, inst.Rows)
+	if len(freq) != 3 {
+		testEnv.Error("Wrong length")
+	}
+	if freq[0].Value != 1.3 {
+		testEnv.Error(freq[0])
+	}
+	if freq[1].Value != 56.2 {
+		testEnv.Error(freq[1])
+	}
+	if freq[2].Value != 87.1 {
+		testEnv.Error(freq[2])
+	}
+}
+
+func TestChiMerge2(testEnv *testing.T) {
+	// See http://sci2s.ugr.es/keel/pdf/algorithm/congreso/1992-Kerber-ChimErge-AAAI92.pdf
+	// Randy Kerber, ChiMerge: Discretisation of Numeric Attributes, 1992
+	inst, err := base.ParseCSVToInstances("../examples/datasets/iris_headers.csv", true)
+	if err != nil {
+		panic(err)
+	}
+	attrs := make([]int, 1)
+	attrs[0] = 0
+	inst.Sort(base.Ascending, attrs)
+	freq := chiMerge(inst, 0, 0.90, 0, inst.Rows)
+	if len(freq) != 5 {
+		testEnv.Errorf("Wrong length (%d)", len(freq))
+		testEnv.Error(freq)
+	}
+	if freq[0].Value != 4.3 {
+		testEnv.Error(freq[0])
+	}
+	if freq[1].Value != 5.5 {
+		testEnv.Error(freq[1])
+	}
+	if freq[2].Value != 5.8 {
+		testEnv.Error(freq[2])
+	}
+	if freq[3].Value != 6.3 {
+		testEnv.Error(freq[3])
+	}
+	if freq[4].Value != 7.1 {
+		testEnv.Error(freq[4])
+	}
+}
+
+func TestChiMerge3(testEnv *testing.T) {
+	// See http://sci2s.ugr.es/keel/pdf/algorithm/congreso/1992-Kerber-ChimErge-AAAI92.pdf
+	// Randy Kerber, ChiMerge: Discretisation of Numeric Attributes, 1992
+	inst, err := base.ParseCSVToInstances("../examples/datasets/iris_headers.csv", true)
+	if err != nil {
+		panic(err)
+	}
+	attrs := make([]int, 1)
+	attrs[0] = 0
+	inst.Sort(base.Ascending, attrs)
+	filt := NewChiMergeFilter(inst, 0.90)
+	filt.AddAttribute(inst.GetAttr(0))
+	filt.Build()
+	filt.Run(inst)
+	fmt.Println(inst)
+}