1
0
mirror of https://github.com/sjwhitworth/golearn.git synced 2025-04-28 13:48:56 +08:00
golearn/meta/one_v_all.go
Richard Townsend 981d43f1dd Adds support for multi-class linear SVMs.
This patch
  * Adds a one-vs-all meta classifier into meta/
  * Adds a LinearSVC (essentially the same as LogisticRegression
    but with different libsvm parameters) to linear_models/
  * Adds a MultiLinearSVC into ensemble/ for predicting
    CategoricalAttribute classes with the LinearSVC
  * Adds a new example dataset based on classifying article headlines.

The example dataset is drawn from WikiNews, and consists of an average,
min and max Word2Vec representation of article headlines from three
categories. The Word2Vec model was computed offline using gensim.
2014-10-05 11:15:41 +01:00

173 lines
4.6 KiB
Go

package meta
import (
"fmt"
"github.com/sjwhitworth/golearn/base"
)
// OneVsAllModel is a meta-classifier for a single CategoricalAttribute
// class variable. During Fit it swaps the class Attribute for a numeric
// (float) version and trains one wrapped classifier per class value;
// during Predict the class whose classifier reports the highest value
// wins. Only one CategoricalAttribute class variable is supported.
type OneVsAllModel struct {
	// NewClassifierFunction builds a fresh, untrained classifier. Fit
	// calls it once for every class value.
	NewClassifierFunction func() base.Classifier

	filters     []*oneVsAllFilter // one filter per class system value
	classifiers []base.Classifier // trained classifier per class system value
	maxClassVal uint64            // largest class system value seen by Fit
}
// NewOneVsAllModel returns an untrained OneVsAllModel. The argument is a
// factory that must return a base.Classifier ready for training; Fit
// invokes it once per class value.
func NewOneVsAllModel(f func() base.Classifier) *OneVsAllModel {
	model := new(OneVsAllModel)
	model.NewClassifierFunction = f
	return model
}
// generateAttributes maps every Attribute of from onto the Attribute it
// should become after filtering. Non-class Attributes map to themselves;
// the class Attribute maps to a new FloatAttribute with the same name.
// It panics unless from has exactly one class Attribute.
func (m *OneVsAllModel) generateAttributes(from base.FixedDataGrid) map[base.Attribute]base.Attribute {
	all := from.AllAttributes()
	classAttrs := from.AllClassAttributes()
	if len(classAttrs) != 1 {
		panic("Only 1 class Attribute is supported!")
	}
	// Exactly one class Attribute exists past the check above.
	target := classAttrs[0]

	mapping := make(map[base.Attribute]base.Attribute, len(all))
	for _, attr := range all {
		if attr.Equals(target) {
			mapping[attr] = base.NewFloatAttribute(target.GetName())
			continue
		}
		mapping[attr] = attr
	}
	return mapping
}
// Fit trains one wrapped classifier per value of the class
// CategoricalAttribute: n classifiers in total, where n is one more than
// the largest class system value. Classifier v is fitted on a lazily
// filtered view of using in which the class column reads 1.0 for rows
// whose class system value is v and 0.0 for every other row.
//
// Fit panics if a class Attribute is not a *base.CategoricalAttribute, if
// there is not exactly one class Attribute, or if the largest class system
// value is 0.
func (m *OneVsAllModel) Fit(using base.FixedDataGrid) {
	// Validate the class Attribute type here; generateAttributes below
	// enforces that there is exactly one.
	var classAttr *base.CategoricalAttribute
	for _, a := range using.AllClassAttributes() {
		c, ok := a.(*base.CategoricalAttribute)
		if !ok {
			panic("Unsupported ClassAttribute type")
		}
		classAttr = c
	}

	attrs := m.generateAttributes(using)

	// System values are used directly as slice indices, so size everything
	// by the largest one rather than by the number of labels.
	var top uint64
	for _, label := range classAttr.GetValues() {
		if sys := base.UnpackBytesToU64(classAttr.GetSysValFromString(label)); sys > top {
			top = sys
		}
	}
	if top == 0 {
		panic("Must have more than one class!")
	}
	m.maxClassVal = top

	// Build and train one filter/classifier pair per class system value.
	// The slices are stored on m only once every classifier has been fitted.
	filters := make([]*oneVsAllFilter, top+1)
	classifiers := make([]base.Classifier, top+1)
	for idx := range filters {
		flt := &oneVsAllFilter{
			attrs:        attrs,
			classAttr:    classAttr,
			classAttrVal: uint64(idx),
		}
		filters[idx] = flt
		classifiers[idx] = m.NewClassifierFunction()
		classifiers[idx].Fit(base.NewLazilyFilteredInstances(using, flt))
	}
	m.filters = filters
	m.classifiers = classifiers
}
// Predict issues predictions. Each class-specific classifier is expected
// to output a value between 0 (indicating that a given instance is not
// a given class) and 1 (indicating that the given instance is definitely
// that class). For each instance, the class whose classifier outputs the
// highest value is chosen. Scores outside [0, 1] (for example raw margins)
// are compared the same way, and NaN scores are never selected. When
// several classifiers tie for the highest value, the lowest class system
// value wins.
//
// Predict returns an error if the model has not been fitted, or if any of
// the underlying classifiers returns an error.
func (m *OneVsAllModel) Predict(what base.FixedDataGrid) (base.FixedDataGrid, error) {
	// Fit stores exactly maxClassVal+1 filters and classifiers. Anything else
	// means Fit has not run yet, and the indexing below would panic.
	n := m.maxClassVal + 1
	if uint64(len(m.filters)) != n || uint64(len(m.classifiers)) != n {
		return nil, fmt.Errorf("meta: OneVsAllModel must be fitted before calling Predict")
	}

	ret := base.GeneratePredictionVector(what)

	// Run every per-class classifier over its filtered view of what.
	vecs := make([]base.FixedDataGrid, n)
	specs := make([]base.AttributeSpec, n)
	for i := range vecs {
		c := base.NewLazilyFilteredInstances(what, m.filters[i])
		p, err := m.classifiers[i].Predict(c)
		if err != nil {
			return nil, err
		}
		vecs[i] = p
		specs[i] = base.ResolveAttributes(p, p.AllClassAttributes())[0]
	}

	// For each row pick the class with the largest score. Seeding from the
	// first real score (instead of 0.0) keeps the argmax correct when every
	// score is <= 0.
	_, rows := ret.Size()
	spec := base.ResolveAttributes(ret, ret.AllClassAttributes())[0]
	for row := 0; row < rows; row++ {
		class := uint64(0)
		best := 0.0
		seen := false
		for j := range vecs {
			val := base.UnpackBytesToFloat(vecs[j].Get(specs[j], row))
			// val == val is false only for NaN, which must never win.
			if val == val && (!seen || val > best) {
				class, best, seen = uint64(j), val, true
			}
		}
		ret.Set(spec, row, base.PackU64ToBytes(class))
	}
	return ret, nil
}
//
// Filter implementation
//

// oneVsAllFilter rewrites the class Attribute into a 0/1 indicator for a
// single class system value and passes every other Attribute's values
// through unchanged.
type oneVsAllFilter struct {
	attrs        map[base.Attribute]base.Attribute // original Attribute -> filtered Attribute
	classAttr    base.Attribute                    // class Attribute whose values are rewritten
	classAttrVal uint64                            // class system value that maps to 1.0
}
// AddAttribute always returns an error. This filter's Attribute mapping
// comes from its attrs field and cannot be extended.
func (f *oneVsAllFilter) AddAttribute(a base.Attribute) error {
	err := fmt.Errorf("Not supported")
	return err
}
// GetAttributesAfterFiltering lists every (original, filtered) Attribute
// pair held by this filter, one entry per key in attrs. The order follows
// Go map iteration, so it is not stable from one call to the next.
func (f *oneVsAllFilter) GetAttributesAfterFiltering() []base.FilteredAttribute {
	pairs := make([]base.FilteredAttribute, 0, len(f.attrs))
	for before, after := range f.attrs {
		pairs = append(pairs, base.FilteredAttribute{before, after})
	}
	return pairs
}
// String returns the fixed name "oneVsAllFilter".
func (f *oneVsAllFilter) String() string {
	const name = "oneVsAllFilter"
	return name
}
// Transform converts one value of Attribute old into its filtered form.
// Values of any Attribute other than the class Attribute are returned
// unchanged. For the class Attribute, seq is read as a packed uint64
// system value and replaced by a packed float: 1.0 when it equals
// classAttrVal, 0.0 otherwise. The to argument is not consulted.
func (f *oneVsAllFilter) Transform(old, to base.Attribute, seq []byte) []byte {
	if old.Equals(f.classAttr) {
		indicator := 0.0
		if base.UnpackBytesToU64(seq) == f.classAttrVal {
			indicator = 1.0
		}
		return base.PackFloatToBytes(indicator)
	}
	return seq
}
// Train always returns an error. This filter is configured entirely
// through its fields and has no training step.
func (f *oneVsAllFilter) Train() error {
	err := fmt.Errorf("Unsupported")
	return err
}