mirror of
https://github.com/sjwhitworth/golearn.git
synced 2025-04-28 13:48:56 +08:00

This patch * Adds a one-vs-all meta classifier into meta/ * Adds a LinearSVC (essentially the same as LogisticRegression but with different libsvm parameters) to linear_models/ * Adds a MultiLinearSVC into ensemble/ for predicting CategoricalAttribute classes with the LinearSVC * Adds a new example dataset based on classifying article headlines. The example dataset is drawn from WikiNews, and consists of an average, min and max Word2Vec representation of article headlines from three categories. The Word2Vec model was computed offline using gensim.
173 lines
4.6 KiB
Go
173 lines
4.6 KiB
Go
package meta
|
|
|
|
import (
|
|
"fmt"
|
|
"github.com/sjwhitworth/golearn/base"
|
|
)
|
|
|
|
// OneVsAllModel replaces class Attributes with numeric versions
|
|
// and trains n wrapped classifiers. The actual class is chosen
|
|
// by whichever is most confident. Only one CategoricalAttribute
|
|
// class variable is supported.
|
|
type OneVsAllModel struct {
|
|
NewClassifierFunction func() base.Classifier
|
|
filters []*oneVsAllFilter
|
|
classifiers []base.Classifier
|
|
maxClassVal uint64
|
|
}
|
|
|
|
// NewOneVsAllModel creates a new OneVsAllModel. The argument
|
|
// must be a function which returns a base.Classifier ready for training.
|
|
func NewOneVsAllModel(f func() base.Classifier) *OneVsAllModel {
|
|
return &OneVsAllModel{
|
|
f,
|
|
nil,
|
|
nil,
|
|
0,
|
|
}
|
|
}
|
|
|
|
func (m *OneVsAllModel) generateAttributes(from base.FixedDataGrid) map[base.Attribute]base.Attribute {
|
|
attrs := from.AllAttributes()
|
|
classAttrs := from.AllClassAttributes()
|
|
if len(classAttrs) != 1 {
|
|
panic("Only 1 class Attribute is supported!")
|
|
}
|
|
ret := make(map[base.Attribute]base.Attribute)
|
|
for _, a := range attrs {
|
|
ret[a] = a
|
|
for _, b := range classAttrs {
|
|
if a.Equals(b) {
|
|
cur := base.NewFloatAttribute(b.GetName())
|
|
ret[a] = cur
|
|
}
|
|
}
|
|
}
|
|
return ret
|
|
}
|
|
|
|
// Fit creates n filtered datasets (where n is the number of values
|
|
// a CategoricalAttribute can take) and uses them to train the
|
|
// underlying classifiers.
|
|
func (m *OneVsAllModel) Fit(using base.FixedDataGrid) {
|
|
var classAttr *base.CategoricalAttribute
|
|
// Do some validation
|
|
classAttrs := using.AllClassAttributes()
|
|
for _, a := range classAttrs {
|
|
if c, ok := a.(*base.CategoricalAttribute); !ok {
|
|
panic("Unsupported ClassAttribute type")
|
|
} else {
|
|
classAttr = c
|
|
}
|
|
}
|
|
attrs := m.generateAttributes(using)
|
|
|
|
// Find the highest stored value
|
|
val := uint64(0)
|
|
for _, s := range classAttr.GetValues() {
|
|
cur := base.UnpackBytesToU64(classAttr.GetSysValFromString(s))
|
|
if cur > val {
|
|
val = cur
|
|
}
|
|
}
|
|
if val == 0 {
|
|
panic("Must have more than one class!")
|
|
}
|
|
m.maxClassVal = val
|
|
|
|
// Create individual filtered instances for training
|
|
filters := make([]*oneVsAllFilter, val+1)
|
|
classifiers := make([]base.Classifier, val+1)
|
|
for i := uint64(0); i <= val; i++ {
|
|
f := &oneVsAllFilter{
|
|
attrs,
|
|
classAttr,
|
|
i,
|
|
}
|
|
filters[i] = f
|
|
classifiers[i] = m.NewClassifierFunction()
|
|
classifiers[i].Fit(base.NewLazilyFilteredInstances(using, f))
|
|
}
|
|
|
|
m.filters = filters
|
|
m.classifiers = classifiers
|
|
}
|
|
|
|
// Predict issues predictions. Each class-specific classifier is expected
|
|
// to output a value between 0 (indicating that a given instance is not
|
|
// a given class) and 1 (indicating that the given instance is definitely
|
|
// that class). For each instance, the class with the highest value is chosen.
|
|
// The result is undefined if several underlying models output the same value.
|
|
func (m *OneVsAllModel) Predict(what base.FixedDataGrid) (base.FixedDataGrid, error) {
|
|
ret := base.GeneratePredictionVector(what)
|
|
vecs := make([]base.FixedDataGrid, m.maxClassVal+1)
|
|
specs := make([]base.AttributeSpec, m.maxClassVal+1)
|
|
for i := uint64(0); i <= m.maxClassVal; i++ {
|
|
f := m.filters[i]
|
|
c := base.NewLazilyFilteredInstances(what, f)
|
|
p, err := m.classifiers[i].Predict(c)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
vecs[i] = p
|
|
specs[i] = base.ResolveAttributes(p, p.AllClassAttributes())[0]
|
|
}
|
|
_, rows := ret.Size()
|
|
spec := base.ResolveAttributes(ret, ret.AllClassAttributes())[0]
|
|
for i := 0; i < rows; i++ {
|
|
class := uint64(0)
|
|
best := 0.0
|
|
for j := uint64(0); j <= m.maxClassVal; j++ {
|
|
val := base.UnpackBytesToFloat(vecs[j].Get(specs[j], i))
|
|
if val > best {
|
|
class = j
|
|
best = val
|
|
}
|
|
}
|
|
ret.Set(spec, i, base.PackU64ToBytes(class))
|
|
}
|
|
return ret, nil
|
|
}
|
|
|
|
//
|
|
// Filter implementation
|
|
//
|
|
type oneVsAllFilter struct {
|
|
attrs map[base.Attribute]base.Attribute
|
|
classAttr base.Attribute
|
|
classAttrVal uint64
|
|
}
|
|
|
|
func (f *oneVsAllFilter) AddAttribute(a base.Attribute) error {
|
|
return fmt.Errorf("Not supported")
|
|
}
|
|
|
|
func (f *oneVsAllFilter) GetAttributesAfterFiltering() []base.FilteredAttribute {
|
|
ret := make([]base.FilteredAttribute, len(f.attrs))
|
|
cnt := 0
|
|
for i := range f.attrs {
|
|
ret[cnt] = base.FilteredAttribute{i, f.attrs[i]}
|
|
cnt++
|
|
}
|
|
return ret
|
|
}
|
|
|
|
func (f *oneVsAllFilter) String() string {
|
|
return "oneVsAllFilter"
|
|
}
|
|
|
|
func (f *oneVsAllFilter) Transform(old, to base.Attribute, seq []byte) []byte {
|
|
if !old.Equals(f.classAttr) {
|
|
return seq
|
|
}
|
|
val := base.UnpackBytesToU64(seq)
|
|
if val == f.classAttrVal {
|
|
return base.PackFloatToBytes(1.0)
|
|
}
|
|
return base.PackFloatToBytes(0.0)
|
|
}
|
|
|
|
func (f *oneVsAllFilter) Train() error {
|
|
return fmt.Errorf("Unsupported")
|
|
}
|