mirror of
https://github.com/sjwhitworth/golearn.git
synced 2025-04-26 13:49:14 +08:00
155 lines
4.1 KiB
Go
155 lines
4.1 KiB
Go
package meta
|
|
|
|
import (
|
|
"fmt"
|
|
"github.com/sjwhitworth/golearn/base"
|
|
)
|
|
|
|
// OneVsAllModel replaces class Attributes with numeric versions
|
|
// and trains n wrapped classifiers. The actual class is chosen
|
|
// by whichever is most confident. Only one CategoricalAttribute
|
|
// class variable is supported.
|
|
type OneVsAllModel struct {
|
|
NewClassifierFunction func(string) base.Classifier
|
|
filters []*oneVsAllFilter
|
|
classifiers []base.Classifier
|
|
maxClassVal uint64
|
|
}
|
|
|
|
// NewOneVsAllModel creates a new OneVsAllModel. The argument
|
|
// must be a function which returns a base.Classifier ready for training.
|
|
func NewOneVsAllModel(f func(string) base.Classifier) *OneVsAllModel {
|
|
return &OneVsAllModel{
|
|
f,
|
|
nil,
|
|
nil,
|
|
0,
|
|
}
|
|
}
|
|
|
|
// Fit creates n filtered datasets (where n is the number of values
|
|
// a CategoricalAttribute can take) and uses them to train the
|
|
// underlying classifiers.
|
|
func (m *OneVsAllModel) Fit(using base.FixedDataGrid) {
|
|
var classAttr *base.CategoricalAttribute
|
|
// Do some validation
|
|
classAttrs := using.AllClassAttributes()
|
|
for _, a := range classAttrs {
|
|
if c, ok := a.(*base.CategoricalAttribute); !ok {
|
|
panic("Unsupported ClassAttribute type")
|
|
} else {
|
|
classAttr = c
|
|
}
|
|
}
|
|
attrs := m.generateAttributes(using)
|
|
|
|
// Find the highest stored value
|
|
val := uint64(0)
|
|
classVals := classAttr.GetValues()
|
|
for _, s := range classVals {
|
|
cur := base.UnpackBytesToU64(classAttr.GetSysValFromString(s))
|
|
if cur > val {
|
|
val = cur
|
|
}
|
|
}
|
|
if val == 0 {
|
|
panic("Must have more than one class!")
|
|
}
|
|
m.maxClassVal = val
|
|
|
|
// Create individual filtered instances for training
|
|
filters := make([]*oneVsAllFilter, val+1)
|
|
classifiers := make([]base.Classifier, val+1)
|
|
for i := uint64(0); i <= val; i++ {
|
|
f := &oneVsAllFilter{
|
|
attrs,
|
|
classAttr,
|
|
i,
|
|
}
|
|
filters[i] = f
|
|
classifiers[i] = m.NewClassifierFunction(classVals[int(i)])
|
|
classifiers[i].Fit(base.NewLazilyFilteredInstances(using, f))
|
|
}
|
|
|
|
m.filters = filters
|
|
m.classifiers = classifiers
|
|
}
|
|
|
|
// Predict issues predictions. Each class-specific classifier is expected
|
|
// to output a value between 0 (indicating that a given instance is not
|
|
// a given class) and 1 (indicating that the given instance is definitely
|
|
// that class). For each instance, the class with the highest value is chosen.
|
|
// The result is undefined if several underlying models output the same value.
|
|
func (m *OneVsAllModel) Predict(what base.FixedDataGrid) (base.FixedDataGrid, error) {
|
|
ret := base.GeneratePredictionVector(what)
|
|
vecs := make([]base.FixedDataGrid, m.maxClassVal+1)
|
|
specs := make([]base.AttributeSpec, m.maxClassVal+1)
|
|
for i := uint64(0); i <= m.maxClassVal; i++ {
|
|
f := m.filters[i]
|
|
c := base.NewLazilyFilteredInstances(what, f)
|
|
p, err := m.classifiers[i].Predict(c)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
vecs[i] = p
|
|
specs[i] = base.ResolveAttributes(p, p.AllClassAttributes())[0]
|
|
}
|
|
_, rows := ret.Size()
|
|
spec := base.ResolveAttributes(ret, ret.AllClassAttributes())[0]
|
|
for i := 0; i < rows; i++ {
|
|
class := uint64(0)
|
|
best := 0.0
|
|
for j := uint64(0); j <= m.maxClassVal; j++ {
|
|
val := base.UnpackBytesToFloat(vecs[j].Get(specs[j], i))
|
|
if val > best {
|
|
class = j
|
|
best = val
|
|
}
|
|
}
|
|
ret.Set(spec, i, base.PackU64ToBytes(class))
|
|
}
|
|
return ret, nil
|
|
}
|
|
|
|
//
|
|
// Filter implementation
|
|
//
|
|
type oneVsAllFilter struct {
|
|
attrs map[base.Attribute]base.Attribute
|
|
classAttr base.Attribute
|
|
classAttrVal uint64
|
|
}
|
|
|
|
func (f *oneVsAllFilter) AddAttribute(a base.Attribute) error {
|
|
return fmt.Errorf("Not supported")
|
|
}
|
|
|
|
func (f *oneVsAllFilter) GetAttributesAfterFiltering() []base.FilteredAttribute {
|
|
ret := make([]base.FilteredAttribute, len(f.attrs))
|
|
cnt := 0
|
|
for i := range f.attrs {
|
|
ret[cnt] = base.FilteredAttribute{i, f.attrs[i]}
|
|
cnt++
|
|
}
|
|
return ret
|
|
}
|
|
|
|
func (f *oneVsAllFilter) String() string {
|
|
return "oneVsAllFilter"
|
|
}
|
|
|
|
func (f *oneVsAllFilter) Transform(old, to base.Attribute, seq []byte) []byte {
|
|
if !old.Equals(f.classAttr) {
|
|
return seq
|
|
}
|
|
val := base.UnpackBytesToU64(seq)
|
|
if val == f.classAttrVal {
|
|
return base.PackFloatToBytes(1.0)
|
|
}
|
|
return base.PackFloatToBytes(0.0)
|
|
}
|
|
|
|
func (f *oneVsAllFilter) Train() error {
|
|
return fmt.Errorf("Unsupported")
|
|
}
|