1
0
mirror of https://github.com/sjwhitworth/golearn.git synced 2025-04-25 13:48:49 +08:00
golearn/filters/chimerge.go
2014-08-22 07:21:24 +00:00

189 lines
4.7 KiB
Go

package filters
import (
"fmt"
"github.com/sjwhitworth/golearn/base"
"math"
)
// ChiMergeFilter implements supervised discretisation
// by merging successive numeric intervals if the difference
// in their class distribution is not statistically significant.
// See Bramer, "Principles of Data Mining", 2nd Edition
// pp 105--115
type ChiMergeFilter struct {
AbstractDiscretizeFilter
// tables holds, per Attribute, the frequency table stored by Train.
tables map[base.Attribute][]*FrequencyTableEntry
// Significance is handed to chiMerge as its sig argument.
Significance float64
// MinRows is handed to chiMerge as the minimum frequency table size.
MinRows int
// MaxRows is handed to chiMerge as the maximum frequency table size.
MaxRows int
}
// NewChiMergeFilter creates a ChiMergeFilter over d with some helpful
// initialisations: an empty table map, the given significance level,
// MinRows set to 2 and MaxRows set to the second value returned by
// d.Size() (the row count).
func NewChiMergeFilter(d base.FixedDataGrid, significance float64) *ChiMergeFilter {
	_, rowCount := d.Size()
	return &ChiMergeFilter{
		// The embedded filter's field names are not spelled out here, so
		// its values stay positional.
		AbstractDiscretizeFilter: AbstractDiscretizeFilter{
			make(map[base.Attribute]bool),
			false,
			d,
		},
		tables:       make(map[base.Attribute][]*FrequencyTableEntry),
		Significance: significance,
		MinRows:      2,
		MaxRows:      rowCount,
	}
}
// chiMerge builds the frequency table for a single numeric Attribute and
// repeatedly merges the adjacent rows whose chi statistic is lowest.
//
// inst: The base.FixedDataGrid which needs discretising
// attr: The Attribute to discretise; chiMerge panics unless it is a
// *base.FloatAttribute
// sig: The significance level (e.g. 0.95); a value of 0 is replaced by 10
// minrows: The minimum number of rows required in the frequency table
// (anything below 2 is raised to 2)
// maxrows: The maximum number of rows allowed in the frequency table
// (raised to minrows+1 when it is not above minrows).
// If the number of rows is above this, the lowest-scoring adjacent
// rows are merged regardless of their significance
//
// Returns the frequency table after merging.
func chiMerge(inst base.FixedDataGrid, attr base.Attribute, sig float64, minrows int, maxrows int) []*FrequencyTableEntry {
	// Parameter sanity checking
	if minrows < 2 {
		minrows = 2
	}
	if maxrows <= minrows {
		maxrows = minrows + 1
	}
	if sig == 0 {
		sig = 10
	}

	// Check that the attribute is numeric
	if _, ok := attr.(*base.FloatAttribute); !ok {
		panic("only use Chi-M on numeric stuff")
	}

	// Build a frequency table
	freq := ChiMBuildFrequencyTable(attr, inst)

	// Count the number of classes
	classes := chiCountClasses(freq)

	// Reused between iterations; there may be more than one index to merge.
	var minChiIndexes []int
	for len(freq) > minrows {
		// Find every adjacent pair sharing the lowest chi statistic.
		minChiVal := math.Inf(1)
		minChiIndexes = minChiIndexes[:0]
		for i := 0; i < len(freq)-1; i++ {
			chiVal := chiComputeStatistic(freq[i], freq[i+1])
			if chiVal < minChiVal {
				minChiVal = chiVal
				minChiIndexes = minChiIndexes[:0]
			}
			if chiVal == minChiVal {
				minChiIndexes = append(minChiIndexes, i)
			}
		}

		// Compute the degrees of freedom: two rows are compared at a
		// time, so (rows - 1) * (classes - 1) reduces to classes - 1.
		degsOfFree := len(classes) - 1
		sigVal := chiSquaredPercentile(degsOfFree, minChiVal)

		// Only merge if:
		//  We're above the maximum number of rows
		//  OR the lowest chi value falls below the significance level
		//  AS LONG AS we're above the minimum row count
		if !(len(freq) > maxrows || sigVal < sig) {
			break
		}

		// No pair qualified (e.g. every statistic was NaN): nothing can
		// shrink the table, so looping again would never terminate.
		if len(minChiIndexes) == 0 {
			break
		}

		// Merge each selected pair. Every merge is assumed to shorten
		// the table by one row, hence the shift of later indexes by the
		// number of merges already done.
		for merged, idx := range minChiIndexes {
			// Ties must not push the table below the minimum size.
			if len(freq) <= minrows {
				break
			}
			freq = chiMergeMergeZipAdjacent(freq, idx-merged)
		}
	}
	return freq
}
// Train builds a frequency table with chiMerge for every Attribute
// selected in c.attrs and stores it in c.tables.
//
// For each selected Attribute the training data is first sorted in
// ascending order on that Attribute. If sorting fails, the error from
// base.LazySort is returned and the remaining Attributes are not
// processed; tables stored for earlier Attributes are left in place.
// Returns nil on success.
func (c *ChiMergeFilter) Train() error {
	for _, spec := range c.getAttributeSpecs() {
		attr := spec.GetAttribute()

		// Skip if not set
		if !c.attrs[attr] {
			continue
		}

		// Sort the training data on this Attribute only.
		sortOrder := []base.AttributeSpec{spec}
		sorted, err := base.LazySort(c.train, base.Ascending, sortOrder)
		if err != nil {
			// Report the failure through the declared error return
			// instead of panicking in library code.
			return err
		}

		// Perform ChiMerge
		c.tables[attr] = chiMerge(sorted, attr, c.Significance, c.MinRows, c.MaxRows)
	}
	return nil
}
// GetAttributesAfterFiltering returns one base.FilteredAttribute per
// Attribute of the training data, pairing the original Attribute with
// the one in use after filtering. Attributes that are not selected are
// paired with themselves; selected ones are paired with a new
// base.CategoricalAttribute of the same name, to which the "%f"-formatted
// Value of every row in the stored frequency table is passed via
// GetSysValFromString.
func (c *ChiMergeFilter) GetAttributesAfterFiltering() []base.FilteredAttribute {
	before := c.train.AllAttributes()
	pairs := make([]base.FilteredAttribute, len(before))
	for idx, old := range before {
		// Untouched Attributes map onto themselves.
		if !c.attrs[old] {
			pairs[idx] = base.FilteredAttribute{old, old}
			continue
		}

		cat := new(base.CategoricalAttribute)
		cat.SetName(old.GetName())
		for _, row := range c.tables[old] {
			label := fmt.Sprintf("%f", row.Value)
			// Result ignored; presumably this registers the label
			// with the new Attribute.
			cat.GetSysValFromString(label)
		}
		pairs[idx] = base.FilteredAttribute{old, cat}
	}
	return pairs
}
// Transform returns the byte sequence after discretisation.
//
// If a is not selected for discretisation, field is returned unchanged.
// Otherwise field is unpacked as a float and the result is the packed
// index of the last row, in the leading run of frequency table rows whose
// Value is strictly below it (0 if the first row already fails that test
// or the table is empty). n is not used.
func (c *ChiMergeFilter) Transform(a base.Attribute, n base.Attribute, field []byte) []byte {
	// Do we use this Attribute?
	if selected := c.attrs[a]; !selected {
		return field
	}

	entries := c.tables[a]
	val := base.UnpackBytesToFloat(field)

	// Walk the table until a row is no longer strictly below val.
	bin := 0
	for idx := 0; idx < len(entries); idx++ {
		if !(entries[idx].Value < val) {
			break
		}
		bin = idx
	}
	return base.PackU64ToBytes(uint64(bin))
}
// String returns a short human-readable summary of the filter: the
// number of stored frequency tables and the significance level.
func (c *ChiMergeFilter) String() string {
	tableCount := len(c.tables)
	return fmt.Sprintf("ChiMergeFilter(%d Attributes, %.2f Significance)", tableCount, c.Significance)
}