1
0
mirror of https://github.com/sjwhitworth/golearn.git synced 2025-04-25 13:48:49 +08:00
golearn/ensemble/multisvc.go
2017-09-10 20:35:34 +01:00

177 lines
4.8 KiB
Go

package ensemble
import (
"github.com/sjwhitworth/golearn/base"
"github.com/sjwhitworth/golearn/linear_models"
"github.com/sjwhitworth/golearn/meta"
"fmt"
)
// MultiLinearSVC implements a multi-class Support Vector Classifier using a one-vs-all
// voting scheme. Only one CategoricalAttribute class is supported.
type MultiLinearSVC struct {
// m is the one-vs-all ensemble that Fit, Predict, SaveWithPrefix and
// LoadWithPrefix delegate to.
m *meta.OneVsAllModel
// parameters is the template LinearSVC configuration; a copy of it is
// made for every per-class classifier.
parameters *linear_models.LinearSVCParams
// weights optionally maps a class value to its weight. When nil, the
// per-class classifiers are created with nil ClassWeights.
weights map[string]float64
}
// NewMultiLinearSVC builds a MultiLinearSVC backed by a OneVsAllModel.
//
// loss and penalty may each be "l1" or "l2"; "l1" for the loss and "l2" for
// the penalty are the typical values. dual controls whether the dual or the
// primal SVM form is solved (true should be used in most cases). C is the
// penalty term, normally 1.0, and eps is the convergence term, typically
// 1e-4. weights optionally maps class values to class weights; when it is
// nil every underlying classifier is created with nil ClassWeights.
//
// It panics if SetKindFromStrings rejects the loss/penalty pair.
func NewMultiLinearSVC(loss, penalty string, dual bool, C float64, eps float64, weights map[string]float64) *MultiLinearSVC {
	// Positional literal: only C, eps and dual come from the caller here;
	// the solver kind is filled in by SetKindFromStrings below.
	svcParams := &linear_models.LinearSVCParams{0, nil, C, eps, false, dual}
	if err := svcParams.SetKindFromStrings(loss, penalty); err != nil {
		panic(err)
	}

	svc := new(MultiLinearSVC)
	svc.parameters = svcParams
	svc.weights = weights
	svc.initializeOneVsAllModel()
	return svc
}
// initializeOneVsAllModel (re)creates m.m with a factory that builds one
// LinearSVC per class value.
//
// Every classifier receives its own copy of m.parameters. When m.weights is
// non-nil, ClassWeights is a two-element vector: index 1 holds the weight of
// the class being modelled and index 0 the summed weight of all other
// classes. The factory panics if NewLinearSVCFromParams returns an error.
func (m *MultiLinearSVC) initializeOneVsAllModel() {
	newClassifier := func(cls string) base.Classifier {
		svcParams := m.parameters.Copy()

		// Stays nil when no weights were supplied.
		var classWeights []float64
		if m.weights != nil {
			classWeights = make([]float64, 2)
			for label, w := range m.weights {
				if label == cls {
					classWeights[1] = w
				} else {
					classWeights[0] += w
				}
			}
		}
		svcParams.ClassWeights = classWeights

		svc, err := linear_models.NewLinearSVCFromParams(svcParams)
		if err != nil {
			panic(err)
		}
		return svc
	}
	m.m = meta.NewOneVsAllModel(newClassifier)
}
// Fit builds the MultiLinearSVC by building n (where n is the number of values
// the singular CategoricalAttribute can take) separate one-vs-rest models.
// Training is delegated to the underlying OneVsAllModel; this method itself
// always returns nil.
func (m *MultiLinearSVC) Fit(instances base.FixedDataGrid) error {
m.m.Fit(instances)
return nil
}
// Predict issues predictions from the MultiLinearSVC. Each underlying LinearSVC is
// used to predict whether an instance takes on a class or some other class, and the
// model which definitively reports a given class is the one chosen. The result is
// undefined if all underlying models predict that the instance originates from some
// other class.
//
// The call is forwarded to the underlying OneVsAllModel and its result and
// error are returned unchanged.
func (m *MultiLinearSVC) Predict(from base.FixedDataGrid) (base.FixedDataGrid, error) {
return m.m.Predict(from)
}
// GetClassifierMetadata returns the metadata that Save writes into the
// serialized classifier stub.
//
// NOTE(review): this reports ClassifierVersion "1" whereas GetMetadata reports
// "1.0" — confirm which value is intended before unifying them.
func (m *MultiLinearSVC) GetClassifierMetadata() base.ClassifierMetadataV1 {
return base.ClassifierMetadataV1{
FormatVersion: 1,
ClassifierName: "MultiLinearSVC",
ClassifierVersion: "1",
ClassifierMetadata: nil,
}
}
// Save serializes the classifier to the file at filePath.
//
// It creates a serialized classifier stub carrying GetClassifierMetadata and
// writes the model under the empty prefix. Once the stub exists the
// serializer is closed on every return path, including when writing the
// model fails. An error from creating the stub is returned as-is; an error
// from SaveWithPrefix is returned wrapped in an "Unable to Save()" message.
func (m *MultiLinearSVC) Save(filePath string) error {
	metadata := m.GetClassifierMetadata()
	serializer, err := base.CreateSerializedClassifierStub(filePath, metadata)
	if err != nil {
		return err
	}
	// Close on every path so a failed write does not leave the serializer open.
	// NOTE(review): any result of Close is discarded here, as it was before;
	// if Close returns an error, consider propagating it.
	defer serializer.Close()

	if err := m.SaveWithPrefix(serializer, ""); err != nil {
		return fmt.Errorf("Unable to Save(): %v", err)
	}
	return nil
}
// SaveWithPrefix writes the classifier into an already-open serializer,
// placing every entry under prefix.
//
// Entries are written in this order: "<prefix>/params" (the linear
// parameters as JSON), "<prefix>/weights" (the class weights as JSON), and
// finally the one-vs-all model under "<prefix>/one-vs-all". The first error
// encountered is returned.
func (m *MultiLinearSVC) SaveWithPrefix(serializer *base.ClassifierSerializer, prefix string) error {
	// key namespaces an entry name under prefix.
	key := func(name string) string {
		return prefix + "/" + name
	}

	// Linear parameters.
	if err := serializer.WriteJSONForKey(key("params"), m.parameters); err != nil {
		return fmt.Errorf("Unable to marshal parameters: %v", err)
	}

	// Class weights.
	if err := serializer.WriteJSONForKey(key("weights"), m.weights); err != nil {
		return fmt.Errorf("Unable to write weights: %v", err)
	}

	// The underlying one-vs-all model.
	return m.m.SaveWithPrefix(serializer, key("one-vs-all"))
}
// GetMetadata returns classifier metadata identifying this model as
// "MultiLinearSVC" with format version 1.
//
// NOTE(review): this reports ClassifierVersion "1.0" whereas
// GetClassifierMetadata (the one used by Save) reports "1" — confirm which
// value is intended before unifying them.
func (m *MultiLinearSVC) GetMetadata() base.ClassifierMetadataV1 {
return base.ClassifierMetadataV1{
FormatVersion: 1,
ClassifierName: "MultiLinearSVC",
ClassifierVersion: "1.0",
ClassifierMetadata: nil,
}
}
// Load restores the classifier from the file at filePath, reading its
// entries from the empty prefix. It returns any error from opening the
// serialized stub or from LoadWithPrefix.
func (m *MultiLinearSVC) Load(filePath string) error {
	deserializer, err := base.ReadSerializedClassifierStub(filePath)
	if err != nil {
		return err
	}
	return m.LoadWithPrefix(deserializer, "")
}
// LoadWithPrefix restores the classifier from reader using the entries
// stored under prefix, i.e. the layout produced by SaveWithPrefix.
//
// It reads "<prefix>/params" and "<prefix>/weights" into the receiver,
// recreates the one-vs-all model, and then loads that model from
// "<prefix>/one-vs-all". The first error encountered is returned.
func (m *MultiLinearSVC) LoadWithPrefix(reader *base.ClassifierDeserializer, prefix string) error {
	p := func(fName string) string {
		return fmt.Sprintf("%s/%s", prefix, fName)
	}

	err := reader.GetJSONForKey(p("params"), &m.parameters)
	if err != nil {
		return fmt.Errorf("Can't load parameters: %v", err)
	}
	err = reader.GetJSONForKey(p("weights"), &m.weights)
	if err != nil {
		// Name the entry that actually failed; this used to say "parameters".
		return fmt.Errorf("Can't load weights: %v", err)
	}

	// Recreate m.m before loading into it; its classifier factory reads
	// m.parameters and m.weights, which have just been populated.
	m.initializeOneVsAllModel()

	// Load the model
	return m.m.LoadWithPrefix(reader, p("one-vs-all"))
}