mirror of
https://github.com/sjwhitworth/golearn.git
synced 2025-04-26 13:49:14 +08:00
123 lines
3.1 KiB
Go
123 lines
3.1 KiB
Go
package filters
|
|
|
|
import (
|
|
"fmt"
|
|
"github.com/sjwhitworth/golearn/base"
|
|
"math"
|
|
)
|
|
|
|
// BinningFilter does equal-width binning for numeric
|
|
// Attributes (aka "histogram binning")
|
|
type BinningFilter struct {
|
|
AbstractDiscretizeFilter
|
|
bins int
|
|
minVals map[base.Attribute]float64
|
|
maxVals map[base.Attribute]float64
|
|
}
|
|
|
|
// NewBinningFilter creates a BinningFilter structure
|
|
// with some helpful default initialisations.
|
|
func NewBinningFilter(d base.FixedDataGrid, bins int) *BinningFilter {
|
|
return &BinningFilter{
|
|
AbstractDiscretizeFilter{
|
|
make(map[base.Attribute]bool),
|
|
false,
|
|
d,
|
|
},
|
|
bins,
|
|
make(map[base.Attribute]float64),
|
|
make(map[base.Attribute]float64),
|
|
}
|
|
}
|
|
|
|
func (b *BinningFilter) String() string {
|
|
return fmt.Sprintf("BinningFilter(%d Attribute(s), %d bin(s)", b.attrs, b.bins)
|
|
}
|
|
|
|
// Train computes and stores the bin values
|
|
// for the training instances.
|
|
func (b *BinningFilter) Train() error {
|
|
|
|
as := b.getAttributeSpecs()
|
|
// Set up the AttributeSpecs, and values
|
|
for attr := range b.attrs {
|
|
if !b.attrs[attr] {
|
|
continue
|
|
}
|
|
b.minVals[attr] = float64(math.Inf(1))
|
|
b.maxVals[attr] = float64(math.Inf(-1))
|
|
}
|
|
|
|
err := b.train.MapOverRows(as, func(row [][]byte, rowNo int) (bool, error) {
|
|
for i, a := range row {
|
|
attr := as[i].GetAttribute()
|
|
attrf := attr.(*base.FloatAttribute)
|
|
val := float64(attrf.GetFloatFromSysVal(a))
|
|
if val > b.maxVals[attr] {
|
|
b.maxVals[attr] = val
|
|
}
|
|
if val < b.minVals[attr] {
|
|
b.minVals[attr] = val
|
|
}
|
|
}
|
|
return true, nil
|
|
})
|
|
|
|
if err != nil {
|
|
return fmt.Errorf("Training error: %s", err)
|
|
}
|
|
b.trained = true
|
|
return nil
|
|
}
|
|
|
|
// Transform takes an Attribute and byte sequence and returns
|
|
// the transformed byte sequence.
|
|
func (b *BinningFilter) Transform(a base.Attribute, n base.Attribute, field []byte) []byte {
|
|
|
|
if !b.attrs[a] {
|
|
return field
|
|
}
|
|
af, ok := a.(*base.FloatAttribute)
|
|
if !ok {
|
|
panic("Attribute is the wrong type")
|
|
}
|
|
minVal := b.minVals[a]
|
|
maxVal := b.maxVals[a]
|
|
disc := 0
|
|
// Casts to float64 to replicate a floating point precision error
|
|
delta := float64(maxVal-minVal) / float64(b.bins)
|
|
val := float64(af.GetFloatFromSysVal(field))
|
|
if val <= minVal {
|
|
disc = 0
|
|
} else {
|
|
disc = int(math.Floor(float64(float64(val-minVal)/delta + 0.0001)))
|
|
}
|
|
return base.PackU64ToBytes(uint64(disc))
|
|
}
|
|
|
|
// GetAttributesAfterFiltering gets a list of before/after
|
|
// Attributes as base.FilteredAttributes
|
|
func (b *BinningFilter) GetAttributesAfterFiltering() []base.FilteredAttribute {
|
|
oldAttrs := b.train.AllAttributes()
|
|
ret := make([]base.FilteredAttribute, len(oldAttrs))
|
|
for i, a := range oldAttrs {
|
|
if b.attrs[a] {
|
|
retAttr := new(base.CategoricalAttribute)
|
|
minVal := b.minVals[a]
|
|
maxVal := b.maxVals[a]
|
|
delta := float64(maxVal-minVal) / float64(b.bins)
|
|
retAttr.SetName(a.GetName())
|
|
for i := 0; i <= b.bins; i++ {
|
|
floatVal := float64(i)*delta + minVal
|
|
fmtStr := fmt.Sprintf("%%.%df", a.(*base.FloatAttribute).Precision)
|
|
binVal := fmt.Sprintf(fmtStr, floatVal)
|
|
retAttr.GetSysValFromString(binVal)
|
|
}
|
|
ret[i] = base.FilteredAttribute{a, retAttr}
|
|
} else {
|
|
ret[i] = base.FilteredAttribute{a, a}
|
|
}
|
|
}
|
|
return ret
|
|
}
|