mirror of
https://github.com/sjwhitworth/golearn.git
synced 2025-04-25 13:48:49 +08:00
152 lines
4.9 KiB
Go
152 lines
4.9 KiB
Go
package filters
|
|
|
|
import (
|
|
"fmt"
|
|
"github.com/sjwhitworth/golearn/base"
|
|
)
|
|
|
|
// BinaryConvertFilters convert a given DataGrid into one which
|
|
// only contains BinaryAttributes.
|
|
//
|
|
// FloatAttributes are discretised into either 0 (if the value is 0)
|
|
// or 1 (if the value is not 0).
|
|
//
|
|
// CategoricalAttributes are discretised into one or more new
|
|
// BinaryAttributes.
|
|
type BinaryConvertFilter struct {
|
|
attrs []base.Attribute
|
|
converted []base.FilteredAttribute
|
|
twoValuedCategoricalAttributes map[base.Attribute]bool // Two-valued categorical Attributes
|
|
nValuedCategoricalAttributeMap map[base.Attribute]map[uint64]base.Attribute
|
|
}
|
|
|
|
// NewBinaryConvertFilter creates a blank BinaryConvertFilter
|
|
func NewBinaryConvertFilter() *BinaryConvertFilter {
|
|
ret := &BinaryConvertFilter{
|
|
make([]base.Attribute, 0),
|
|
make([]base.FilteredAttribute, 0),
|
|
make(map[base.Attribute]bool),
|
|
make(map[base.Attribute]map[uint64]base.Attribute),
|
|
}
|
|
return ret
|
|
}
|
|
|
|
// AddAttribute adds a new Attribute to this Filter
|
|
func (b *BinaryConvertFilter) AddAttribute(a base.Attribute) error {
|
|
b.attrs = append(b.attrs, a)
|
|
return nil
|
|
}
|
|
|
|
// GetAttributesAfterFiltering returns the Attributes previously computed via Train()
|
|
func (b *BinaryConvertFilter) GetAttributesAfterFiltering() []base.FilteredAttribute {
|
|
return b.converted
|
|
}
|
|
|
|
// String gets a human-readable string
|
|
func (b *BinaryConvertFilter) String() string {
|
|
return fmt.Sprintf("BinaryConvertFilter(%d Attribute(s))", len(b.attrs))
|
|
}
|
|
|
|
// Transform converts the given byte sequence using the old Attribute into the new
|
|
// byte sequence.
|
|
//
|
|
// If the old Attribute has a categorical value of at most two items, then a zero or
|
|
// non-zero byte sequence is returned.
|
|
//
|
|
// If the old Attribute has a categorical value of at most n-items, then a non-zero
|
|
// or zero byte sequence is returned based on the value of the new Attribute passed in.
|
|
//
|
|
// If the old Attribute is a float, it's value's unpacked and we check for non-zeroness
|
|
//
|
|
// If the old Attribute is a BinaryAttribute, just return the input
|
|
func (b *BinaryConvertFilter) Transform(a base.Attribute, n base.Attribute, attrBytes []byte) []byte {
|
|
ret := make([]byte, 1)
|
|
// Check for CategoricalAttribute
|
|
if _, ok := a.(*base.CategoricalAttribute); ok {
|
|
// Unpack byte value
|
|
val := base.UnpackBytesToU64(attrBytes)
|
|
// If it's a two-valued one, check for non-zero
|
|
if b.twoValuedCategoricalAttributes[a] {
|
|
if val > 0 {
|
|
ret[0] = 1
|
|
} else {
|
|
ret[0] = 0
|
|
}
|
|
} else if an, ok := b.nValuedCategoricalAttributeMap[a]; ok {
|
|
// If it's an n-valued one, check the new Attribute maps onto
|
|
// the unpacked value
|
|
if af, ok := an[val]; ok {
|
|
if af.Equals(n) {
|
|
ret[0] = 1
|
|
} else {
|
|
ret[0] = 0
|
|
}
|
|
} else {
|
|
panic("Categorical value not defined!")
|
|
}
|
|
} else {
|
|
panic(fmt.Sprintf("Not a recognised Attribute %v", a))
|
|
}
|
|
} else if _, ok := a.(*base.BinaryAttribute); ok {
|
|
// Binary: just return the original value
|
|
ret = attrBytes
|
|
} else if _, ok := a.(*base.FloatAttribute); ok {
|
|
// Float: check for non-zero
|
|
val := base.UnpackBytesToFloat(attrBytes)
|
|
if val > 0 {
|
|
ret[0] = 1
|
|
} else {
|
|
ret[0] = 0
|
|
}
|
|
} else {
|
|
panic(fmt.Sprintf("Unrecognised Attribute: %v", a))
|
|
}
|
|
return ret
|
|
}
|
|
|
|
// Train converts the FloatAttributesinto equivalently named BinaryAttributes,
|
|
// leaves BinaryAttributes unmodified and processes
|
|
// CategoricalAttributes as follows.
|
|
//
|
|
// If the CategoricalAttribute has two values, one of them is
|
|
// designated 0 and the other 1, and a single identically-named
|
|
// binary Attribute is returned.
|
|
//
|
|
// If the CategoricalAttribute has more than two (n) values, the Filter
|
|
// generates n BinaryAttributes and sets each of them if the value's observed.
|
|
func (b *BinaryConvertFilter) Train() error {
|
|
for _, a := range b.attrs {
|
|
if ac, ok := a.(*base.CategoricalAttribute); ok {
|
|
vals := ac.GetValues()
|
|
if len(vals) <= 2 {
|
|
nAttr := base.NewBinaryAttribute(ac.GetName())
|
|
fAttr := base.FilteredAttribute{ac, nAttr}
|
|
b.converted = append(b.converted, fAttr)
|
|
b.twoValuedCategoricalAttributes[a] = true
|
|
} else {
|
|
if _, ok := b.nValuedCategoricalAttributeMap[a]; !ok {
|
|
b.nValuedCategoricalAttributeMap[a] = make(map[uint64]base.Attribute)
|
|
}
|
|
for i := uint64(0); i < uint64(len(vals)); i++ {
|
|
v := vals[i]
|
|
newName := fmt.Sprintf("%s_%s", ac.GetName(), v)
|
|
newAttr := base.NewBinaryAttribute(newName)
|
|
fAttr := base.FilteredAttribute{ac, newAttr}
|
|
b.converted = append(b.converted, fAttr)
|
|
b.nValuedCategoricalAttributeMap[a][i] = newAttr
|
|
}
|
|
}
|
|
} else if ab, ok := a.(*base.BinaryAttribute); ok {
|
|
fAttr := base.FilteredAttribute{ab, ab}
|
|
b.converted = append(b.converted, fAttr)
|
|
} else if af, ok := a.(*base.FloatAttribute); ok {
|
|
newAttr := base.NewBinaryAttribute(af.GetName())
|
|
fAttr := base.FilteredAttribute{af, newAttr}
|
|
b.converted = append(b.converted, fAttr)
|
|
} else {
|
|
return fmt.Errorf("Unsupported Attribute type: %v", a)
|
|
}
|
|
}
|
|
return nil
|
|
}
|