1
0
mirror of https://github.com/sjwhitworth/golearn.git synced 2025-04-26 13:49:14 +08:00
golearn/ensemble/multisvc.go
2014-10-30 23:28:26 +00:00

73 lines
2.3 KiB
Go

package ensemble
import (
"github.com/sjwhitworth/golearn/base"
"github.com/sjwhitworth/golearn/linear_models"
"github.com/sjwhitworth/golearn/meta"
)
// MultiLinearSVC implements a multi-class Support Vector Classifier using a one-vs-all
// voting scheme. Only one CategoricalAttribute class is supported.
//
// Values are created with NewMultiLinearSVC.
type MultiLinearSVC struct {
// m is the underlying one-vs-all model; Fit and Predict both delegate to it.
m *meta.OneVsAllModel
}
// NewMultiLinearSVC builds a MultiLinearSVC backed by a OneVsAllModel whose
// per-class classifiers are LinearSVCs sharing one set of training parameters.
//
// loss and penalty can each be "l1" or "l2"; typical values are "l1" for the
// loss and "l2" for the penalty. dual selects between the dual and primal SVM
// formulations, and true should be used in most cases. C is the penalty term,
// normally 1.0. eps is the convergence term, typically 1e-4.
//
// weights optionally maps class names to weights. When it is non-nil, the
// classifier built for a class receives a two-element ClassWeights slice:
// element 0 is the sum of the weights of every other class in the map, and
// element 1 is the weight of the class itself. When it is nil, ClassWeights
// is set to nil.
//
// The function panics if the loss/penalty pair is rejected by
// SetKindFromStrings; the classifier factory handed to the OneVsAllModel
// panics if NewLinearSVCFromParams returns an error.
func NewMultiLinearSVC(loss, penalty string, dual bool, C float64, eps float64, weights map[string]float64) *MultiLinearSVC {
	// Shared parameter template; every per-class classifier works on a copy.
	template := &linear_models.LinearSVCParams{0, nil, C, eps, false, dual}
	if err := template.SetKindFromStrings(loss, penalty); err != nil {
		panic(err)
	}

	// build produces the binary classifier responsible for the class cls.
	build := func(cls string) base.Classifier {
		// Collapse the per-class weights into a {rest, this class} pair.
		var pair []float64
		if weights != nil {
			pair = []float64{0, 0}
			for name, w := range weights {
				if name == cls {
					pair[1] = w
					continue
				}
				pair[0] += w
			}
		}

		classParams := template.Copy()
		classParams.ClassWeights = pair

		svc, err := linear_models.NewLinearSVCFromParams(classParams)
		if err != nil {
			panic(err)
		}
		return svc
	}

	return &MultiLinearSVC{m: meta.NewOneVsAllModel(build)}
}
// Fit builds the MultiLinearSVC by building n (where n is the number of values
// the singular CategoricalAttribute can take) separate one-vs-rest models.
//
// Training is delegated to the wrapped OneVsAllModel. The returned error is
// always nil.
// NOTE(review): nothing reported by the wrapped Fit call is propagated here —
// confirm whether OneVsAllModel.Fit can signal failure other than by panicking.
func (m *MultiLinearSVC) Fit(instances base.FixedDataGrid) error {
m.m.Fit(instances)
return nil
}
// Predict issues predictions from the MultiLinearSVC. Each underlying LinearSVC is
// used to predict whether an instance takes on a class or some other class, and the
// model which definitively reports a given class is the one chosen. The result is
// undefined if all underlying models predict that the instance originates from some
// other class.
//
// The call is forwarded to the wrapped OneVsAllModel, and its result and error
// are returned unchanged.
func (m *MultiLinearSVC) Predict(from base.FixedDataGrid) (base.FixedDataGrid, error) {
return m.m.Predict(from)
}