mirror of
https://github.com/sjwhitworth/golearn.git
synced 2025-04-25 13:48:49 +08:00
Adds support for multi-class linear SVMs.
This patch * Adds a one-vs-all meta classifier into meta/ * Adds a LinearSVC (essentially the same as LogisticRegression but with different libsvm parameters) to linear_models/ * Adds a MultiLinearSVC into ensemble/ for predicting CategoricalAttribute classes with the LinearSVC * Adds a new example dataset based on classifying article headlines. The example dataset is drawn from WikiNews, and consists of an average, min and max Word2Vec representation of article headlines from three categories. The Word2Vec model was computed offline using gensim.
This commit is contained in:
parent
0e4d04af52
commit
981d43f1dd
48
ensemble/multisvc.go
Normal file
48
ensemble/multisvc.go
Normal file
@ -0,0 +1,48 @@
|
||||
package ensemble
|
||||
|
||||
import (
|
||||
"github.com/sjwhitworth/golearn/base"
|
||||
"github.com/sjwhitworth/golearn/linear_models"
|
||||
"github.com/sjwhitworth/golearn/meta"
|
||||
)
|
||||
|
||||
// MultiLinearSVC implements a multi-class Support Vector Classifier using a one-vs-all
|
||||
// voting scheme. Only one CategoricalAttribute class is supported.
|
||||
type MultiLinearSVC struct {
|
||||
m *meta.OneVsAllModel
|
||||
}
|
||||
|
||||
// NewMultiLinearSVC creates a new MultiLinearSVC using the OneVsAllModel.
|
||||
// The loss and penalty arguments can be "l1" or "l2". Typical values are
|
||||
// "l1" for the loss and "l2" for the penalty. The dual parameter controls
|
||||
// whether the system solves the dual or primal SVM form, true should be used
|
||||
// in most cases. C is the penalty term, normally 1.0. eps is the convergence
|
||||
// term, typically 1e-4.
|
||||
func NewMultiLinearSVC(loss, penalty string, dual bool, C float64, eps float64) *MultiLinearSVC {
|
||||
classifierFunc := func() base.Classifier {
|
||||
ret, err := linear_models.NewLinearSVC(loss, penalty, dual, C, eps)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
return ret
|
||||
}
|
||||
return &MultiLinearSVC{
|
||||
meta.NewOneVsAllModel(classifierFunc),
|
||||
}
|
||||
}
|
||||
|
||||
// Fit builds the MultiLinearSVC by building n (where n is the number of values
|
||||
// the singular CategoricalAttribute can take) seperate one-vs-rest models.
|
||||
func (m *MultiLinearSVC) Fit(instances base.FixedDataGrid) error {
|
||||
m.m.Fit(instances)
|
||||
return nil
|
||||
}
|
||||
|
||||
// Predict issues predictions from the MultiLinearSVC. Each underlying LinearSVC is
|
||||
// used to predict whether an instance takes on a class or some other class, and the
|
||||
// model which definitively reports a given class is the one chosen. The result is
|
||||
// undefined if all underlying models predict that the instance originates from some
|
||||
// other class.
|
||||
func (m *MultiLinearSVC) Predict(from base.FixedDataGrid) (base.FixedDataGrid, error) {
|
||||
return m.m.Predict(from)
|
||||
}
|
27
ensemble/multisvc_test.go
Normal file
27
ensemble/multisvc_test.go
Normal file
@ -0,0 +1,27 @@
|
||||
package ensemble
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"github.com/sjwhitworth/golearn/base"
|
||||
"github.com/sjwhitworth/golearn/evaluation"
|
||||
. "github.com/smartystreets/goconvey/convey"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestMultiSVM(t *testing.T) {
|
||||
Convey("Loading data...", t, func() {
|
||||
inst, err := base.ParseCSVToInstances("../examples/datasets/articles.csv", false)
|
||||
So(err, ShouldBeNil)
|
||||
X, Y := base.InstancesTrainTestSplit(inst, 0.4)
|
||||
|
||||
m := NewMultiLinearSVC("l1", "l2", true, 1.0, 1e-4)
|
||||
m.Fit(X)
|
||||
|
||||
Convey("Predictions should work...", func() {
|
||||
predictions, err := m.Predict(Y)
|
||||
cf, err := evaluation.GetConfusionMatrix(Y, predictions)
|
||||
So(err, ShouldEqual, nil)
|
||||
fmt.Println(evaluation.GetSummary(cf))
|
||||
})
|
||||
})
|
||||
}
|
9714
examples/datasets/articles.csv
Normal file
9714
examples/datasets/articles.csv
Normal file
File diff suppressed because one or more lines are too long
@ -21,13 +21,15 @@ func TestLogisticRegression(t *testing.T) {
|
||||
lr.Fit(X)
|
||||
|
||||
Convey("When predicting the label of first vector", func() {
|
||||
Z := lr.Predict(Y)
|
||||
Z, err := lr.Predict(Y)
|
||||
So(err, ShouldEqual, nil)
|
||||
Convey("The result should be 1", func() {
|
||||
So(Z.RowString(0), ShouldEqual, "1.00")
|
||||
})
|
||||
})
|
||||
Convey("When predicting the label of second vector", func() {
|
||||
Z := lr.Predict(Y)
|
||||
Z, err := lr.Predict(Y)
|
||||
So(err, ShouldEqual, nil)
|
||||
Convey("The result should be -1", func() {
|
||||
So(Z.RowString(1), ShouldEqual, "-1.00")
|
||||
})
|
||||
|
83
linear_models/linearsvc.go
Normal file
83
linear_models/linearsvc.go
Normal file
@ -0,0 +1,83 @@
|
||||
package linear_models
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"github.com/sjwhitworth/golearn/base"
|
||||
)
|
||||
|
||||
type LinearSVC struct {
|
||||
param *Parameter
|
||||
model *Model
|
||||
}
|
||||
|
||||
func NewLinearSVC(loss, penalty string, dual bool, C float64, eps float64) (*LinearSVC, error) {
|
||||
solver_type := 0
|
||||
if penalty == "l2" {
|
||||
if loss == "l1" {
|
||||
if dual {
|
||||
solver_type = L2R_L1LOSS_SVC_DUAL
|
||||
}
|
||||
} else {
|
||||
if dual {
|
||||
solver_type = L2R_L2LOSS_SVC_DUAL
|
||||
} else {
|
||||
solver_type = L2R_L2LOSS_SVC
|
||||
}
|
||||
}
|
||||
} else if penalty == "l1" {
|
||||
if loss == "l2" {
|
||||
if !dual {
|
||||
solver_type = L1R_L2LOSS_SVC
|
||||
}
|
||||
}
|
||||
}
|
||||
if solver_type == 0 {
|
||||
panic("Parameter combination")
|
||||
}
|
||||
|
||||
lr := LinearSVC{}
|
||||
lr.param = NewParameter(solver_type, C, eps)
|
||||
lr.model = nil
|
||||
return &lr, nil
|
||||
}
|
||||
|
||||
func (lr *LinearSVC) Fit(X base.FixedDataGrid) error {
|
||||
problemVec := convertInstancesToProblemVec(X)
|
||||
labelVec := convertInstancesToLabelVec(X)
|
||||
prob := NewProblem(problemVec, labelVec, 0)
|
||||
lr.model = Train(prob, lr.param)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (lr *LinearSVC) Predict(X base.FixedDataGrid) (base.FixedDataGrid, error) {
|
||||
|
||||
// Only support 1 class Attribute
|
||||
classAttrs := X.AllClassAttributes()
|
||||
if len(classAttrs) != 1 {
|
||||
panic(fmt.Sprintf("%d Wrong number of classes", len(classAttrs)))
|
||||
}
|
||||
// Generate return structure
|
||||
ret := base.GeneratePredictionVector(X)
|
||||
classAttrSpecs := base.ResolveAttributes(ret, classAttrs)
|
||||
// Retrieve numeric non-class Attributes
|
||||
numericAttrs := base.NonClassFloatAttributes(X)
|
||||
numericAttrSpecs := base.ResolveAttributes(X, numericAttrs)
|
||||
|
||||
// Allocate row storage
|
||||
row := make([]float64, len(numericAttrSpecs))
|
||||
X.MapOverRows(numericAttrSpecs, func(rowBytes [][]byte, rowNo int) (bool, error) {
|
||||
for i, r := range rowBytes {
|
||||
row[i] = base.UnpackBytesToFloat(r)
|
||||
}
|
||||
val := Predict(lr.model, row)
|
||||
vals := base.PackFloatToBytes(val)
|
||||
ret.Set(classAttrSpecs[0], rowNo, vals)
|
||||
return true, nil
|
||||
})
|
||||
|
||||
return ret, nil
|
||||
}
|
||||
|
||||
func (lr *LinearSVC) String() string {
|
||||
return "LogisticSVC"
|
||||
}
|
@ -27,61 +27,15 @@ func NewLogisticRegression(penalty string, C float64, eps float64) (*LogisticReg
|
||||
return &lr, nil
|
||||
}
|
||||
|
||||
func convertInstancesToProblemVec(X base.FixedDataGrid) [][]float64 {
|
||||
// Allocate problem array
|
||||
_, rows := X.Size()
|
||||
problemVec := make([][]float64, rows)
|
||||
|
||||
// Retrieve numeric non-class Attributes
|
||||
numericAttrs := base.NonClassFloatAttributes(X)
|
||||
numericAttrSpecs := base.ResolveAttributes(X, numericAttrs)
|
||||
|
||||
// Convert each row
|
||||
X.MapOverRows(numericAttrSpecs, func(row [][]byte, rowNo int) (bool, error) {
|
||||
// Allocate a new row
|
||||
probRow := make([]float64, len(numericAttrSpecs))
|
||||
// Read out the row
|
||||
for i, _ := range numericAttrSpecs {
|
||||
probRow[i] = base.UnpackBytesToFloat(row[i])
|
||||
}
|
||||
// Add the row
|
||||
problemVec[rowNo] = probRow
|
||||
return true, nil
|
||||
})
|
||||
return problemVec
|
||||
}
|
||||
|
||||
func convertInstancesToLabelVec(X base.FixedDataGrid) []float64 {
|
||||
// Get the class Attributes
|
||||
classAttrs := X.AllClassAttributes()
|
||||
// Only support 1 class Attribute
|
||||
if len(classAttrs) != 1 {
|
||||
panic(fmt.Sprintf("%d ClassAttributes (1 expected)", len(classAttrs)))
|
||||
}
|
||||
// ClassAttribute must be numeric
|
||||
if _, ok := classAttrs[0].(*base.FloatAttribute); !ok {
|
||||
panic(fmt.Sprintf("%s: ClassAttribute must be a FloatAttribute", classAttrs[0]))
|
||||
}
|
||||
// Allocate return structure
|
||||
_, rows := X.Size()
|
||||
labelVec := make([]float64, rows)
|
||||
// Resolve class Attribute specification
|
||||
classAttrSpecs := base.ResolveAttributes(X, classAttrs)
|
||||
X.MapOverRows(classAttrSpecs, func(row [][]byte, rowNo int) (bool, error) {
|
||||
labelVec[rowNo] = base.UnpackBytesToFloat(row[0])
|
||||
return true, nil
|
||||
})
|
||||
return labelVec
|
||||
}
|
||||
|
||||
func (lr *LogisticRegression) Fit(X base.FixedDataGrid) {
|
||||
func (lr *LogisticRegression) Fit(X base.FixedDataGrid) error {
|
||||
problemVec := convertInstancesToProblemVec(X)
|
||||
labelVec := convertInstancesToLabelVec(X)
|
||||
prob := NewProblem(problemVec, labelVec, 0)
|
||||
lr.model = Train(prob, lr.param)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (lr *LogisticRegression) Predict(X base.FixedDataGrid) base.FixedDataGrid {
|
||||
func (lr *LogisticRegression) Predict(X base.FixedDataGrid) (base.FixedDataGrid, error) {
|
||||
|
||||
// Only support 1 class Attribute
|
||||
classAttrs := X.AllClassAttributes()
|
||||
@ -107,5 +61,9 @@ func (lr *LogisticRegression) Predict(X base.FixedDataGrid) base.FixedDataGrid {
|
||||
return true, nil
|
||||
})
|
||||
|
||||
return ret
|
||||
return ret, nil
|
||||
}
|
||||
|
||||
func (lr *LogisticRegression) String() string {
|
||||
return "LogisticRegression"
|
||||
}
|
||||
|
53
linear_models/util.go
Normal file
53
linear_models/util.go
Normal file
@ -0,0 +1,53 @@
|
||||
package linear_models
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"github.com/sjwhitworth/golearn/base"
|
||||
)
|
||||
|
||||
func convertInstancesToProblemVec(X base.FixedDataGrid) [][]float64 {
|
||||
// Allocate problem array
|
||||
_, rows := X.Size()
|
||||
problemVec := make([][]float64, rows)
|
||||
|
||||
// Retrieve numeric non-class Attributes
|
||||
numericAttrs := base.NonClassFloatAttributes(X)
|
||||
numericAttrSpecs := base.ResolveAttributes(X, numericAttrs)
|
||||
|
||||
// Convert each row
|
||||
X.MapOverRows(numericAttrSpecs, func(row [][]byte, rowNo int) (bool, error) {
|
||||
// Allocate a new row
|
||||
probRow := make([]float64, len(numericAttrSpecs))
|
||||
// Read out the row
|
||||
for i, _ := range numericAttrSpecs {
|
||||
probRow[i] = base.UnpackBytesToFloat(row[i])
|
||||
}
|
||||
// Add the row
|
||||
problemVec[rowNo] = probRow
|
||||
return true, nil
|
||||
})
|
||||
return problemVec
|
||||
}
|
||||
|
||||
func convertInstancesToLabelVec(X base.FixedDataGrid) []float64 {
|
||||
// Get the class Attributes
|
||||
classAttrs := X.AllClassAttributes()
|
||||
// Only support 1 class Attribute
|
||||
if len(classAttrs) != 1 {
|
||||
panic(fmt.Sprintf("%d ClassAttributes (1 expected)", len(classAttrs)))
|
||||
}
|
||||
// ClassAttribute must be numeric
|
||||
if _, ok := classAttrs[0].(*base.FloatAttribute); !ok {
|
||||
panic(fmt.Sprintf("%s: ClassAttribute must be a FloatAttribute", classAttrs[0]))
|
||||
}
|
||||
// Allocate return structure
|
||||
_, rows := X.Size()
|
||||
labelVec := make([]float64, rows)
|
||||
// Resolve class Attribute specification
|
||||
classAttrSpecs := base.ResolveAttributes(X, classAttrs)
|
||||
X.MapOverRows(classAttrSpecs, func(row [][]byte, rowNo int) (bool, error) {
|
||||
labelVec[rowNo] = base.UnpackBytesToFloat(row[0])
|
||||
return true, nil
|
||||
})
|
||||
return labelVec
|
||||
}
|
172
meta/one_v_all.go
Normal file
172
meta/one_v_all.go
Normal file
@ -0,0 +1,172 @@
|
||||
package meta
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"github.com/sjwhitworth/golearn/base"
|
||||
)
|
||||
|
||||
// OneVsAllModel replaces class Attributes with numeric versions
|
||||
// and trains n wrapped classifiers. The actual class is chosen
|
||||
// by whichever is most confident. Only one CategoricalAttribute
|
||||
// class variable is supported.
|
||||
type OneVsAllModel struct {
|
||||
NewClassifierFunction func() base.Classifier
|
||||
filters []*oneVsAllFilter
|
||||
classifiers []base.Classifier
|
||||
maxClassVal uint64
|
||||
}
|
||||
|
||||
// NewOneVsAllModel creates a new OneVsAllModel. The argument
|
||||
// must be a function which returns a base.Classifier ready for training.
|
||||
func NewOneVsAllModel(f func() base.Classifier) *OneVsAllModel {
|
||||
return &OneVsAllModel{
|
||||
f,
|
||||
nil,
|
||||
nil,
|
||||
0,
|
||||
}
|
||||
}
|
||||
|
||||
func (m *OneVsAllModel) generateAttributes(from base.FixedDataGrid) map[base.Attribute]base.Attribute {
|
||||
attrs := from.AllAttributes()
|
||||
classAttrs := from.AllClassAttributes()
|
||||
if len(classAttrs) != 1 {
|
||||
panic("Only 1 class Attribute is supported!")
|
||||
}
|
||||
ret := make(map[base.Attribute]base.Attribute)
|
||||
for _, a := range attrs {
|
||||
ret[a] = a
|
||||
for _, b := range classAttrs {
|
||||
if a.Equals(b) {
|
||||
cur := base.NewFloatAttribute(b.GetName())
|
||||
ret[a] = cur
|
||||
}
|
||||
}
|
||||
}
|
||||
return ret
|
||||
}
|
||||
|
||||
// Fit creates n filtered datasets (where n is the number of values
|
||||
// a CategoricalAttribute can take) and uses them to train the
|
||||
// underlying classifiers.
|
||||
func (m *OneVsAllModel) Fit(using base.FixedDataGrid) {
|
||||
var classAttr *base.CategoricalAttribute
|
||||
// Do some validation
|
||||
classAttrs := using.AllClassAttributes()
|
||||
for _, a := range classAttrs {
|
||||
if c, ok := a.(*base.CategoricalAttribute); !ok {
|
||||
panic("Unsupported ClassAttribute type")
|
||||
} else {
|
||||
classAttr = c
|
||||
}
|
||||
}
|
||||
attrs := m.generateAttributes(using)
|
||||
|
||||
// Find the highest stored value
|
||||
val := uint64(0)
|
||||
for _, s := range classAttr.GetValues() {
|
||||
cur := base.UnpackBytesToU64(classAttr.GetSysValFromString(s))
|
||||
if cur > val {
|
||||
val = cur
|
||||
}
|
||||
}
|
||||
if val == 0 {
|
||||
panic("Must have more than one class!")
|
||||
}
|
||||
m.maxClassVal = val
|
||||
|
||||
// Create individual filtered instances for training
|
||||
filters := make([]*oneVsAllFilter, val+1)
|
||||
classifiers := make([]base.Classifier, val+1)
|
||||
for i := uint64(0); i <= val; i++ {
|
||||
f := &oneVsAllFilter{
|
||||
attrs,
|
||||
classAttr,
|
||||
i,
|
||||
}
|
||||
filters[i] = f
|
||||
classifiers[i] = m.NewClassifierFunction()
|
||||
classifiers[i].Fit(base.NewLazilyFilteredInstances(using, f))
|
||||
}
|
||||
|
||||
m.filters = filters
|
||||
m.classifiers = classifiers
|
||||
}
|
||||
|
||||
// Predict issues predictions. Each class-specific classifier is expected
|
||||
// to output a value between 0 (indicating that a given instance is not
|
||||
// a given class) and 1 (indicating that the given instance is definitely
|
||||
// that class). For each instance, the class with the highest value is chosen.
|
||||
// The result is undefined if several underlying models output the same value.
|
||||
func (m *OneVsAllModel) Predict(what base.FixedDataGrid) (base.FixedDataGrid, error) {
|
||||
ret := base.GeneratePredictionVector(what)
|
||||
vecs := make([]base.FixedDataGrid, m.maxClassVal+1)
|
||||
specs := make([]base.AttributeSpec, m.maxClassVal+1)
|
||||
for i := uint64(0); i <= m.maxClassVal; i++ {
|
||||
f := m.filters[i]
|
||||
c := base.NewLazilyFilteredInstances(what, f)
|
||||
p, err := m.classifiers[i].Predict(c)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
vecs[i] = p
|
||||
specs[i] = base.ResolveAttributes(p, p.AllClassAttributes())[0]
|
||||
}
|
||||
_, rows := ret.Size()
|
||||
spec := base.ResolveAttributes(ret, ret.AllClassAttributes())[0]
|
||||
for i := 0; i < rows; i++ {
|
||||
class := uint64(0)
|
||||
best := 0.0
|
||||
for j := uint64(0); j <= m.maxClassVal; j++ {
|
||||
val := base.UnpackBytesToFloat(vecs[j].Get(specs[j], i))
|
||||
if val > best {
|
||||
class = j
|
||||
best = val
|
||||
}
|
||||
}
|
||||
ret.Set(spec, i, base.PackU64ToBytes(class))
|
||||
}
|
||||
return ret, nil
|
||||
}
|
||||
|
||||
//
|
||||
// Filter implementation
|
||||
//
|
||||
type oneVsAllFilter struct {
|
||||
attrs map[base.Attribute]base.Attribute
|
||||
classAttr base.Attribute
|
||||
classAttrVal uint64
|
||||
}
|
||||
|
||||
func (f *oneVsAllFilter) AddAttribute(a base.Attribute) error {
|
||||
return fmt.Errorf("Not supported")
|
||||
}
|
||||
|
||||
func (f *oneVsAllFilter) GetAttributesAfterFiltering() []base.FilteredAttribute {
|
||||
ret := make([]base.FilteredAttribute, len(f.attrs))
|
||||
cnt := 0
|
||||
for i := range f.attrs {
|
||||
ret[cnt] = base.FilteredAttribute{i, f.attrs[i]}
|
||||
cnt++
|
||||
}
|
||||
return ret
|
||||
}
|
||||
|
||||
func (f *oneVsAllFilter) String() string {
|
||||
return "oneVsAllFilter"
|
||||
}
|
||||
|
||||
func (f *oneVsAllFilter) Transform(old, to base.Attribute, seq []byte) []byte {
|
||||
if !old.Equals(f.classAttr) {
|
||||
return seq
|
||||
}
|
||||
val := base.UnpackBytesToU64(seq)
|
||||
if val == f.classAttrVal {
|
||||
return base.PackFloatToBytes(1.0)
|
||||
}
|
||||
return base.PackFloatToBytes(0.0)
|
||||
}
|
||||
|
||||
func (f *oneVsAllFilter) Train() error {
|
||||
return fmt.Errorf("Unsupported")
|
||||
}
|
49
meta/one_v_all_test.go
Normal file
49
meta/one_v_all_test.go
Normal file
@ -0,0 +1,49 @@
|
||||
package meta
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"github.com/sjwhitworth/golearn/base"
|
||||
"github.com/sjwhitworth/golearn/evaluation"
|
||||
"github.com/sjwhitworth/golearn/linear_models"
|
||||
. "github.com/smartystreets/goconvey/convey"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestOneVsAllModel(t *testing.T) {
|
||||
|
||||
classifierFunc := func() base.Classifier {
|
||||
m, err := linear_models.NewLinearSVC("l1", "l2", true, 1.0, 1e-4)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
return m
|
||||
}
|
||||
|
||||
Convey("Given data", t, func() {
|
||||
inst, err := base.ParseCSVToInstances("../examples/datasets/iris_headers.csv", true)
|
||||
So(err, ShouldBeNil)
|
||||
|
||||
X, Y := base.InstancesTrainTestSplit(inst, 0.4)
|
||||
|
||||
m := NewOneVsAllModel(classifierFunc)
|
||||
m.Fit(X)
|
||||
|
||||
Convey("The maximum class index should be 2", func() {
|
||||
So(m.maxClassVal, ShouldEqual, 2)
|
||||
})
|
||||
|
||||
Convey("There should be three of everything...", func() {
|
||||
So(len(m.filters), ShouldEqual, 3)
|
||||
So(len(m.classifiers), ShouldEqual, 3)
|
||||
})
|
||||
|
||||
Convey("Predictions should work...", func() {
|
||||
predictions, err := m.Predict(Y)
|
||||
So(err, ShouldEqual, nil)
|
||||
cf, err := evaluation.GetConfusionMatrix(Y, predictions)
|
||||
So(err, ShouldEqual, nil)
|
||||
fmt.Println(evaluation.GetAccuracy(cf))
|
||||
fmt.Println(evaluation.GetSummary(cf))
|
||||
})
|
||||
})
|
||||
}
|
Loading…
x
Reference in New Issue
Block a user