1
0
mirror of https://github.com/sjwhitworth/golearn.git synced 2025-04-26 13:49:14 +08:00
golearn/filters/binning.go
2014-08-22 07:21:24 +00:00

123 lines
3.1 KiB
Go

package filters
import (
"fmt"
"github.com/sjwhitworth/golearn/base"
"math"
)
// BinningFilter does equal-width binning for numeric
// Attributes (aka "histogram binning"): the range between the
// smallest and largest value recorded for an Attribute is divided
// by the bin count to obtain a fixed bin width.
type BinningFilter struct {
AbstractDiscretizeFilter
// bins is the divisor applied to (max - min) when computing
// the width of each bin.
bins int
// minVals holds, per Attribute, the smallest value recorded by Train.
minVals map[base.Attribute]float64
// maxVals holds, per Attribute, the largest value recorded by Train.
maxVals map[base.Attribute]float64
}
// NewBinningFilter builds a BinningFilter over the data grid d that
// will divide each Attribute's range by bins. The returned filter starts
// with empty per-Attribute minimum/maximum tables; they are filled in
// by Train.
func NewBinningFilter(d base.FixedDataGrid, bins int) *BinningFilter {
	// The embedded filter is initialised positionally: an empty
	// Attribute->bool map, a false flag and the data grid
	// (presumably its attrs, trained and train fields - confirm against
	// the AbstractDiscretizeFilter declaration).
	embedded := AbstractDiscretizeFilter{
		make(map[base.Attribute]bool),
		false,
		d,
	}
	return &BinningFilter{
		AbstractDiscretizeFilter: embedded,
		bins:                     bins,
		minVals:                  make(map[base.Attribute]float64),
		maxVals:                  make(map[base.Attribute]float64),
	}
}
// String returns a short human-readable summary of the filter: the
// number of Attributes currently enabled for binning and the
// configured bin count.
func (b *BinningFilter) String() string {
	// b.attrs is a map, so it cannot be printed with %d directly; count
	// the entries that are switched on, mirroring the check Train and
	// Transform use to decide whether an Attribute is binned.
	enabled := 0
	for _, on := range b.attrs {
		if on {
			enabled++
		}
	}
	return fmt.Sprintf("BinningFilter(%d Attribute(s), %d bin(s))", enabled, b.bins)
}
// Train computes and stores the bin values
// for the training instances.
//
// It walks every row of the training data once and records, for each
// Attribute enabled for binning, the smallest and largest value seen.
// On success the filter is marked as trained.
//
// An error is returned if one of the Attributes to be scanned is not a
// *base.FloatAttribute, or if iterating over the rows fails.
func (b *BinningFilter) Train() error {
	as := b.getAttributeSpecs()

	// Resolve the Attribute behind every column once, before touching
	// any state, so a non-float Attribute is reported as an error
	// instead of panicking in the middle of the row scan.
	attrs := make([]base.Attribute, len(as))
	floatAttrs := make([]*base.FloatAttribute, len(as))
	for i := range as {
		attr := as[i].GetAttribute()
		attrf, ok := attr.(*base.FloatAttribute)
		if !ok {
			return fmt.Errorf("Training error: Attribute %s is not a FloatAttribute", attr.GetName())
		}
		attrs[i] = attr
		floatAttrs[i] = attrf
	}

	// Seed the running extremes so that the first observed value
	// always replaces them.
	for attr := range b.attrs {
		if !b.attrs[attr] {
			continue
		}
		b.minVals[attr] = math.Inf(1)
		b.maxVals[attr] = math.Inf(-1)
	}

	err := b.train.MapOverRows(as, func(row [][]byte, rowNo int) (bool, error) {
		for i, sysVal := range row {
			attr := attrs[i]
			val := float64(floatAttrs[i].GetFloatFromSysVal(sysVal))
			if val > b.maxVals[attr] {
				b.maxVals[attr] = val
			}
			if val < b.minVals[attr] {
				b.minVals[attr] = val
			}
		}
		return true, nil
	})
	if err != nil {
		return fmt.Errorf("Training error: %s", err)
	}

	b.trained = true
	return nil
}
// Transform takes an Attribute and byte sequence and returns
// the transformed byte sequence.
//
// Fields belonging to Attributes that are not enabled for binning are
// returned unchanged. Otherwise the field is decoded as a float, mapped
// to a bin index using the minimum/maximum recorded for a, and the
// index is returned packed as a uint64.
//
// The index is always within [0, b.bins]: values at or below the
// recorded minimum map to 0 and values above the recorded maximum map
// to b.bins, so the result never falls outside the b.bins+1 categorical
// values declared by GetAttributesAfterFiltering. A zero-width range
// (maximum equal to minimum) and NaN inputs map to 0.
//
// The parameter n is not used by this filter.
//
// Transform panics if a is enabled for binning but is not a
// *base.FloatAttribute.
func (b *BinningFilter) Transform(a base.Attribute, n base.Attribute, field []byte) []byte {
	if !b.attrs[a] {
		return field
	}
	af, ok := a.(*base.FloatAttribute)
	if !ok {
		panic("Attribute is the wrong type")
	}
	minVal := b.minVals[a]
	maxVal := b.maxVals[a]
	// Casts to float64 to replicate a floating point precision error
	delta := float64(maxVal-minVal) / float64(b.bins)
	val := float64(af.GetFloatFromSysVal(field))

	disc := 0
	// Only divide when the bin width is positive; a zero or negative
	// width would yield Inf/NaN, and converting those to int is
	// implementation-dependent. NaN values fail the val > minVal test
	// and therefore stay in bin 0.
	if val > minVal && delta > 0 {
		scaled := math.Floor(float64(float64(val-minVal)/delta + 0.0001))
		switch {
		case scaled >= float64(b.bins):
			// Covers the recorded maximum itself as well as any
			// value beyond it (including +Inf).
			disc = b.bins
		case scaled > 0:
			disc = int(scaled)
		}
	}
	return base.PackU64ToBytes(uint64(disc))
}
// GetAttributesAfterFiltering returns one base.FilteredAttribute per
// Attribute of the training data, pairing each original Attribute with
// the Attribute that replaces it after filtering.
//
// Attributes that are not enabled for binning are paired with
// themselves. Enabled Attributes are paired with a new
// CategoricalAttribute of the same name, populated with b.bins+1 values:
// minimum + k*width for k = 0..b.bins, formatted with the
// FloatAttribute's Precision.
func (b *BinningFilter) GetAttributesAfterFiltering() []base.FilteredAttribute {
	inputs := b.train.AllAttributes()
	pairs := make([]base.FilteredAttribute, len(inputs))
	for idx, a := range inputs {
		// Pass-through for anything this filter does not bin.
		if !b.attrs[a] {
			pairs[idx] = base.FilteredAttribute{a, a}
			continue
		}

		binned := new(base.CategoricalAttribute)
		minVal, maxVal := b.minVals[a], b.maxVals[a]
		delta := float64(maxVal-minVal) / float64(b.bins)
		binned.SetName(a.GetName())

		// Register every bin value, in ascending order, with the new
		// categorical Attribute.
		for k := 0; k <= b.bins; k++ {
			floatVal := float64(k)*delta + minVal
			layout := fmt.Sprintf("%%.%df", a.(*base.FloatAttribute).Precision)
			binned.GetSysValFromString(fmt.Sprintf(layout, floatVal))
		}
		pairs[idx] = base.FilteredAttribute{a, binned}
	}
	return pairs
}