1
0
mirror of https://github.com/sjwhitworth/golearn.git synced 2025-04-26 13:49:14 +08:00
golearn/filters/binning.go
2014-05-13 22:45:52 +01:00

122 lines
3.1 KiB
Go

package filters
import (
"fmt"
base "github.com/sjwhitworth/golearn/base"
"math"
)
// BinningFilter performs equal-width ("histogram") binning of
// numeric Attributes, replacing each value with a bin index.
type BinningFilter struct {
	// Attributes holds the indices of the attributes to discretise.
	Attributes []int
	// Instances is the training set whose values Build() scans and
	// against which attribute indices are resolved.
	Instances *base.Instances
	// BinCount is the number of bins each attribute is divided into.
	BinCount int
	// MinVals and MaxVals map an attribute index to the smallest and
	// largest value Build() observed for it in Instances.
	MinVals, MaxVals map[int]float64
	// trained is set by Build() and checked by Run().
	trained bool
}
// NewBinningFilter returns an untrained BinningFilter bound to the
// training instances `inst' that will discretise into `bins' bins.
// The attribute list and the min/max tables start out empty.
func NewBinningFilter(inst *base.Instances, bins int) BinningFilter {
	return BinningFilter{
		Attributes: []int{},
		Instances:  inst,
		BinCount:   bins,
		MinVals:    map[int]float64{},
		MaxVals:    map[int]float64{},
		trained:    false,
	}
}
// AddAttribute registers attribute `a' for discretisation by
// recording its index within the filter's training Instances.
//
// It panics if the Instances report index -1 for `a'.
func (b *BinningFilter) AddAttribute(a base.Attribute) {
	if idx := b.Instances.GetAttrIndex(a); idx != -1 {
		b.Attributes = append(b.Attributes, idx)
		return
	}
	panic("invalid attribute")
}
// AddAllNumericAttributes registers every suitable attribute for
// discretisation: each column of the training Instances whose type
// is base.Float64Type, skipping the class column.
func (b *BinningFilter) AddAllNumericAttributes() {
	inst := b.Instances
	for col := 0; col < inst.Cols; col++ {
		// The class column is never binned; the short-circuit keeps
		// its attribute from being looked up at all.
		if col != inst.ClassIndex && inst.GetAttr(col).GetType() == base.Float64Type {
			b.Attributes = append(b.Attributes, col)
		}
	}
}
// Build scans the training instances and records, for every added
// attribute, the smallest and largest value seen (in MinVals and
// MaxVals, keyed by attribute index). Run() later derives the bin
// boundaries from these.
//
// After Build returns the filter is marked as trained, even when no
// attributes have been added (in which case Run() has nothing to do).
// If the training instances have no rows, the recorded minimum is
// +Inf and the maximum is -Inf.
func (b *BinningFilter) Build() {
	for _, attr := range b.Attributes {
		maxVal := math.Inf(-1)
		minVal := math.Inf(1)
		for i := 0; i < b.Instances.Rows; i++ {
			val := b.Instances.Get(i, attr)
			if val > maxVal {
				maxVal = val
			}
			if val < minVal {
				minVal = val
			}
		}
		b.MaxVals[attr] = maxVal
		b.MinVals[attr] = minVal
	}
	// Set the flag once, outside the loop: previously a filter with no
	// attributes stayed "untrained" and Run() panicked despite Build()
	// having been called.
	b.trained = true
}
// Run applies a trained BinningFilter to a set of Instances,
// discretising any numeric attributes added.
//
// Every value of an added attribute is replaced by a bin index in
// [0, BinCount-1], computed from the minimum and maximum recorded by
// Build(): values at or below the minimum fall into bin 0 and values
// beyond the last bin boundary are clamped into the last bin. The
// attribute itself is then replaced by a CategoricalAttribute of the
// same name that has the strings "0".."BinCount-1" registered as values.
//
// IMPORTANT: Run discretises in-place, so make sure to take
// a copy if the original instances are still needed
//
// IMPORTANT: This function panic()s if the filter has not been
// trained. Call Build() before running this function
//
// IMPORTANT: Call Build() after adding any additional attributes.
// Otherwise, the training structure will be out of date from
// the values expected and could cause a panic.
func (b *BinningFilter) Run(on *base.Instances) {
	if !b.trained {
		panic("Call Build() beforehand")
	}
	// Iterate over the stored attribute indices (the slice values),
	// not over their positions within b.Attributes.
	for _, attr := range b.Attributes {
		minVal := b.MinVals[attr]
		maxVal := b.MaxVals[attr]
		disc := 0
		// Casts to float32 to replicate a floating point precision error
		delta := float32(maxVal - minVal)
		delta /= float32(b.BinCount)
		for i := 0; i < on.Rows; i++ {
			val := on.Get(i, attr)
			if val <= minVal {
				disc = 0
			} else {
				ratio := math.Floor(float64(float32(val-minVal) / delta))
				// Clamp while still in floating point: when delta is 0
				// (constant training column) the ratio is +Inf, and
				// converting that to int is implementation-dependent.
				if ratio >= float64(b.BinCount) {
					disc = b.BinCount - 1
				} else {
					disc = int(ratio)
				}
			}
			on.Set(i, attr, float64(disc))
		}
		// Swap the numeric attribute for a categorical one carrying
		// the same name and one value per bin.
		newAttribute := new(base.CategoricalAttribute)
		newAttribute.SetName(on.GetAttr(attr).GetName())
		for i := 0; i < b.BinCount; i++ {
			newAttribute.GetSysValFromString(fmt.Sprintf("%d", i))
		}
		on.ReplaceAttr(attr, newAttribute)
	}
}