mirror of
https://github.com/sjwhitworth/golearn.git
synced 2025-04-26 13:49:14 +08:00
189 lines
4.7 KiB
Go
189 lines
4.7 KiB
Go
package filters
|
|
|
|
import (
|
|
"fmt"
|
|
"github.com/sjwhitworth/golearn/base"
|
|
"math"
|
|
)
|
|
|
|
// ChiMergeFilter implements supervised discretisation
|
|
// by merging successive numeric intervals if the difference
|
|
// in their class distribution is not statistically signficant.
|
|
// See Bramer, "Principles of Data Mining", 2nd Edition
|
|
// pp 105--115
|
|
type ChiMergeFilter struct {
|
|
AbstractDiscretizeFilter
|
|
tables map[base.Attribute][]*FrequencyTableEntry
|
|
Significance float64
|
|
MinRows int
|
|
MaxRows int
|
|
}
|
|
|
|
// NewChiMergeFilter creates a ChiMergeFilter with some helpful intialisations.
|
|
func NewChiMergeFilter(d base.FixedDataGrid, significance float64) *ChiMergeFilter {
|
|
_, rows := d.Size()
|
|
return &ChiMergeFilter{
|
|
AbstractDiscretizeFilter{
|
|
make(map[base.Attribute]bool),
|
|
false,
|
|
d,
|
|
},
|
|
make(map[base.Attribute][]*FrequencyTableEntry),
|
|
significance,
|
|
2,
|
|
rows,
|
|
}
|
|
}
|
|
|
|
// Train computes and stores the
|
|
// Produces a value mapping table
|
|
// inst: The base.Instances which need discretising
|
|
// sig: The significance level (e.g. 0.95)
|
|
// minrows: The minimum number of rows required in the frequency table
|
|
// maxrows: The maximum number of rows allowed in the frequency table
|
|
// If the number of rows is above this, statistically signficant
|
|
// adjacent rows will be merged
|
|
// precision: internal number of decimal places to round E value to
|
|
// (useful for verification)
|
|
func chiMerge(inst base.FixedDataGrid, attr base.Attribute, sig float64, minrows int, maxrows int) []*FrequencyTableEntry {
|
|
|
|
// Parameter sanity checking
|
|
if !(2 <= minrows) {
|
|
minrows = 2
|
|
}
|
|
if !(minrows < maxrows) {
|
|
maxrows = minrows + 1
|
|
}
|
|
if sig == 0 {
|
|
sig = 10
|
|
}
|
|
|
|
// Check that the attribute is numeric
|
|
_, ok := attr.(*base.FloatAttribute)
|
|
if !ok {
|
|
panic("only use Chi-M on numeric stuff")
|
|
}
|
|
|
|
// Build a frequency table
|
|
freq := ChiMBuildFrequencyTable(attr, inst)
|
|
// Count the number of classes
|
|
classes := chiCountClasses(freq)
|
|
for {
|
|
if len(freq) <= minrows {
|
|
break
|
|
}
|
|
minChiVal := math.Inf(1)
|
|
// There may be more than one index to merge
|
|
minChiIndexes := make([]int, 0)
|
|
for i := 0; i < len(freq)-1; i++ {
|
|
chiVal := chiComputeStatistic(freq[i], freq[i+1])
|
|
if chiVal < minChiVal {
|
|
minChiVal = chiVal
|
|
minChiIndexes = make([]int, 0)
|
|
}
|
|
if chiVal == minChiVal {
|
|
minChiIndexes = append(minChiIndexes, i)
|
|
}
|
|
}
|
|
// Only merge if:
|
|
// We're above the maximum number of rows
|
|
// OR the chiVal is significant
|
|
// AS LONG AS we're above the minimum row count
|
|
merge := false
|
|
if len(freq) > maxrows {
|
|
merge = true
|
|
}
|
|
// Compute the degress of freedom |classes - 1| * |rows - 1|
|
|
degsOfFree := len(classes) - 1
|
|
sigVal := chiSquaredPercentile(degsOfFree, minChiVal)
|
|
if sigVal < sig {
|
|
merge = true
|
|
}
|
|
// If we don't need to merge, then break
|
|
if !merge {
|
|
break
|
|
}
|
|
// Otherwise merge the rows i, i+1 by taking
|
|
// The higher of the two things as the value
|
|
// Combining the class frequencies
|
|
for i, v := range minChiIndexes {
|
|
freq = chiMergeMergeZipAdjacent(freq, v-i)
|
|
}
|
|
}
|
|
return freq
|
|
}
|
|
|
|
func (c *ChiMergeFilter) Train() error {
|
|
as := c.getAttributeSpecs()
|
|
|
|
for _, a := range as {
|
|
|
|
attr := a.GetAttribute()
|
|
|
|
// Skip if not set
|
|
if !c.attrs[attr] {
|
|
continue
|
|
}
|
|
|
|
// Build sort order
|
|
sortOrder := []base.AttributeSpec{a}
|
|
|
|
// Sort
|
|
sorted, err := base.LazySort(c.train, base.Ascending, sortOrder)
|
|
if err != nil {
|
|
panic(err)
|
|
}
|
|
|
|
// Perform ChiMerge
|
|
freq := chiMerge(sorted, attr, c.Significance, c.MinRows, c.MaxRows)
|
|
c.tables[attr] = freq
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
// GetAttributesAfterFiltering gets a list of before/after
|
|
// Attributes as base.FilteredAttributes
|
|
func (c *ChiMergeFilter) GetAttributesAfterFiltering() []base.FilteredAttribute {
|
|
oldAttrs := c.train.AllAttributes()
|
|
ret := make([]base.FilteredAttribute, len(oldAttrs))
|
|
for i, a := range oldAttrs {
|
|
if c.attrs[a] {
|
|
retAttr := new(base.CategoricalAttribute)
|
|
retAttr.SetName(a.GetName())
|
|
for _, k := range c.tables[a] {
|
|
retAttr.GetSysValFromString(fmt.Sprintf("%f", k.Value))
|
|
}
|
|
ret[i] = base.FilteredAttribute{a, retAttr}
|
|
} else {
|
|
ret[i] = base.FilteredAttribute{a, a}
|
|
}
|
|
}
|
|
return ret
|
|
}
|
|
|
|
// Transform returns the byte sequence after discretisation
|
|
func (c *ChiMergeFilter) Transform(a base.Attribute, n base.Attribute, field []byte) []byte {
|
|
// Do we use this Attribute?
|
|
if !c.attrs[a] {
|
|
return field
|
|
}
|
|
// Find the Attribute value in the table
|
|
table := c.tables[a]
|
|
dis := 0
|
|
val := base.UnpackBytesToFloat(field)
|
|
for j, k := range table {
|
|
if k.Value < val {
|
|
dis = j
|
|
continue
|
|
}
|
|
break
|
|
}
|
|
|
|
return base.PackU64ToBytes(uint64(dis))
|
|
}
|
|
|
|
func (c *ChiMergeFilter) String() string {
|
|
return fmt.Sprintf("ChiMergeFilter(%d Attributes, %.2f Significance)", len(c.tables), c.Significance)
|
|
}
|