1
0
mirror of https://github.com/sjwhitworth/golearn.git synced 2025-04-26 13:49:14 +08:00
golearn/filters/binary.go
Richard Townsend ff97065261 base: BinaryConvertFilter, Transform()
Transform now takes a new Attribute so BinaryConvertFilter
 can work correctly
2014-08-03 15:17:18 +01:00

152 lines
4.9 KiB
Go

package filters
import (
"fmt"
"github.com/sjwhitworth/golearn/base"
)
// BinaryConvertFilter converts a given DataGrid into one which
// only contains BinaryAttributes.
//
// FloatAttributes are each replaced by one identically named
// BinaryAttribute; the intended encoding is 0 for a zero value and 1
// otherwise (see Transform for the exact rule applied).
//
// CategoricalAttributes are discretised into one or more new
// BinaryAttributes: a single one when the Attribute has at most two
// values, or one per value otherwise.
//
// BinaryAttributes are passed through unchanged.
//
// Attributes are registered with AddAttribute, the output Attributes are
// built by Train, and individual values are converted with Transform.
type BinaryConvertFilter struct {
// attrs holds the Attributes registered via AddAttribute, in the order added.
attrs []base.Attribute
// converted holds the old->new Attribute pairs built by Train and
// returned by GetAttributesAfterFiltering.
converted []base.FilteredAttribute
// twoValuedCategoricalAttributes marks CategoricalAttributes with at most
// two values, which Train maps onto a single BinaryAttribute.
twoValuedCategoricalAttributes map[base.Attribute]bool // Two-valued categorical Attributes
// nValuedCategoricalAttributeMap maps each CategoricalAttribute with more
// than two values to a table from category index to the BinaryAttribute
// that Train generated for that category.
nValuedCategoricalAttributeMap map[base.Attribute]map[uint64]base.Attribute
}
// NewBinaryConvertFilter returns an empty BinaryConvertFilter whose internal
// slices and lookup maps are already allocated, ready for AddAttribute and
// Train.
func NewBinaryConvertFilter() *BinaryConvertFilter {
	// Keyed fields make the literal independent of the struct's field order.
	return &BinaryConvertFilter{
		attrs:                          []base.Attribute{},
		converted:                      []base.FilteredAttribute{},
		twoValuedCategoricalAttributes: map[base.Attribute]bool{},
		nValuedCategoricalAttributeMap: map[base.Attribute]map[uint64]base.Attribute{},
	}
}
// AddAttribute queues a for conversion. Attributes are kept in the order in
// which they are added and are only processed when Train is called.
// The returned error is always nil.
func (b *BinaryConvertFilter) AddAttribute(a base.Attribute) error {
	registered := b.attrs
	registered = append(registered, a)
	b.attrs = registered
	return nil
}
// GetAttributesAfterFiltering exposes the old-to-new Attribute pairs recorded
// by Train. The filter's internal slice is returned directly rather than
// copied, so callers should treat it as read-only.
func (b *BinaryConvertFilter) GetAttributesAfterFiltering() []base.FilteredAttribute {
	pairs := b.converted
	return pairs
}
// String returns a short human-readable summary of the filter, reporting how
// many Attributes have been registered with AddAttribute.
func (b *BinaryConvertFilter) String() string {
	count := len(b.attrs)
	return fmt.Sprintf("BinaryConvertFilter(%d Attribute(s))", count)
}
// Transform converts the raw bytes of one value of the old Attribute a into
// the byte representation of the new Attribute n (as produced by Train).
//
// The result is a single byte holding 0 or 1, except for BinaryAttributes:
//
//   - CategoricalAttribute with at most two values: 1 if the stored category
//     index is non-zero, otherwise 0.
//   - CategoricalAttribute with more than two values: 1 if the stored category
//     index is the one Train mapped onto n, otherwise 0.
//   - FloatAttribute: its value is unpacked; 1 if it is non-zero, otherwise 0.
//     NaN is treated as 0.
//   - BinaryAttribute: attrBytes is returned as-is (the same slice, not a copy).
//
// Transform panics if a is a CategoricalAttribute that Train did not see, if a
// categorical index has no generated Attribute, or if a's type is unsupported.
func (b *BinaryConvertFilter) Transform(a base.Attribute, n base.Attribute, attrBytes []byte) []byte {
	ret := make([]byte, 1)
	switch a.(type) {
	case *base.CategoricalAttribute:
		// Categorical values are stored as a packed uint64 category index.
		val := base.UnpackBytesToU64(attrBytes)
		if b.twoValuedCategoricalAttributes[a] {
			// Two-valued: index 0 encodes as 0, any other index as 1.
			if val != 0 {
				ret[0] = 1
			}
			return ret
		}
		an, ok := b.nValuedCategoricalAttributeMap[a]
		if !ok {
			panic(fmt.Sprintf("Not a recognised Attribute %v", a))
		}
		af, ok := an[val]
		if !ok {
			panic("Categorical value not defined!")
		}
		// n-valued: only the column generated for this category is set.
		if af.Equals(n) {
			ret[0] = 1
		}
	case *base.BinaryAttribute:
		// Already binary: pass the original bytes through untouched.
		return attrBytes
	case *base.FloatAttribute:
		val := base.UnpackBytesToFloat(attrBytes)
		// Non-zero (including negative) values encode as 1. Written as
		// `val < 0 || val > 0` rather than `val != 0` so that NaN, which
		// compares false both ways, still encodes as 0.
		if val < 0 || val > 0 {
			ret[0] = 1
		}
	default:
		panic(fmt.Sprintf("Unrecognised Attribute: %v", a))
	}
	return ret
}
// Train builds the output Attributes for every Attribute registered with
// AddAttribute, and records the lookup tables that Transform relies on:
//
//   - FloatAttributes become an identically named BinaryAttribute.
//   - BinaryAttributes are left unmodified (old and new are the same).
//   - CategoricalAttributes with at most two values become a single,
//     identically named BinaryAttribute.
//   - CategoricalAttributes with n > 2 values become n BinaryAttributes named
//     "<name>_<value>", one per category; Transform sets each only when its
//     category is observed.
//
// Any results from a previous call are discarded first, so calling Train again
// (e.g. after adding more Attributes) does not duplicate output Attributes.
// An error is returned if an Attribute of any other type was registered.
func (b *BinaryConvertFilter) Train() error {
	// Reset derived state: appending onto a previous run's results would
	// duplicate every output Attribute. Allocating fresh storage also avoids
	// mutating a slice previously handed out by GetAttributesAfterFiltering,
	// and makes a zero-value filter safe to train.
	b.converted = make([]base.FilteredAttribute, 0, len(b.attrs))
	b.twoValuedCategoricalAttributes = make(map[base.Attribute]bool)
	b.nValuedCategoricalAttributeMap = make(map[base.Attribute]map[uint64]base.Attribute)

	for _, a := range b.attrs {
		switch at := a.(type) {
		case *base.CategoricalAttribute:
			vals := at.GetValues()
			if len(vals) <= 2 {
				nAttr := base.NewBinaryAttribute(at.GetName())
				b.converted = append(b.converted, base.FilteredAttribute{at, nAttr})
				b.twoValuedCategoricalAttributes[a] = true
				continue
			}
			// One BinaryAttribute per category, keyed by category index so
			// Transform can look it up from the packed value.
			byIndex := make(map[uint64]base.Attribute, len(vals))
			for i, v := range vals {
				newName := fmt.Sprintf("%s_%s", at.GetName(), v)
				newAttr := base.NewBinaryAttribute(newName)
				b.converted = append(b.converted, base.FilteredAttribute{at, newAttr})
				byIndex[uint64(i)] = newAttr
			}
			b.nValuedCategoricalAttributeMap[a] = byIndex
		case *base.BinaryAttribute:
			b.converted = append(b.converted, base.FilteredAttribute{at, at})
		case *base.FloatAttribute:
			newAttr := base.NewBinaryAttribute(at.GetName())
			b.converted = append(b.converted, base.FilteredAttribute{at, newAttr})
		default:
			return fmt.Errorf("Unsupported Attribute type: %v", a)
		}
	}
	return nil
}