mirror of
https://github.com/sjwhitworth/golearn.git
synced 2025-04-26 13:49:14 +08:00
Support for individual class weightings
This commit is contained in:
parent
056ccef9b6
commit
8fe06e7332
@ -89,6 +89,69 @@ func GetAttributeByName(inst FixedDataGrid, name string) Attribute {
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetClassDistributionByBinaryFloatValue returns the count of each row
|
||||
// which has a float value close to 0.0 or 1.0.
|
||||
func GetClassDistributionByBinaryFloatValue(inst FixedDataGrid) []int {
|
||||
|
||||
// Get the class variable
|
||||
attrs := inst.AllClassAttributes()
|
||||
if len(attrs) != 1 {
|
||||
panic(fmt.Errorf("Wrong number of class variables (has %d, should be 1)", len(attrs)))
|
||||
}
|
||||
if _, ok := attrs[0].(*FloatAttribute); !ok {
|
||||
panic(fmt.Errorf("Class Attribute must be FloatAttribute (is %s)", attrs[0]))
|
||||
}
|
||||
|
||||
// Get the number of class values
|
||||
ret := make([]int, 2)
|
||||
|
||||
// Map through everything
|
||||
specs := ResolveAttributes(inst, attrs)
|
||||
inst.MapOverRows(specs, func(vals [][]byte, row int) (bool, error) {
|
||||
index := UnpackBytesToFloat(vals[0])
|
||||
if index > 0.5 {
|
||||
ret[1]++
|
||||
} else {
|
||||
ret[0]++
|
||||
}
|
||||
|
||||
return false, nil
|
||||
})
|
||||
|
||||
return ret
|
||||
}
|
||||
|
||||
// GetClassDistributionByIntegerVal returns a vector containing
|
||||
// the count of each class vector (indexed by the class' system
|
||||
// integer representation)
|
||||
func GetClassDistributionByCategoricalValue(inst FixedDataGrid) []int {
|
||||
|
||||
var classAttr *CategoricalAttribute
|
||||
var ok bool
|
||||
// Get the class variable
|
||||
attrs := inst.AllClassAttributes()
|
||||
if len(attrs) != 1 {
|
||||
panic(fmt.Errorf("Wrong number of class variables (has %d, should be 1)", len(attrs)))
|
||||
}
|
||||
if classAttr, ok = attrs[0].(*CategoricalAttribute); !ok {
|
||||
panic(fmt.Errorf("Class Attribute must be a CategoricalAttribute (is %s)", attrs[0]))
|
||||
}
|
||||
|
||||
// Get the number of class values
|
||||
classLen := len(classAttr.GetValues())
|
||||
ret := make([]int, classLen)
|
||||
|
||||
// Map through everything
|
||||
specs := ResolveAttributes(inst, attrs)
|
||||
inst.MapOverRows(specs, func(vals [][]byte, row int) (bool, error) {
|
||||
index := UnpackBytesToU64(vals[0])
|
||||
ret[int(index)]++
|
||||
return false, nil
|
||||
})
|
||||
|
||||
return ret
|
||||
}
|
||||
|
||||
// GetClassDistribution returns a map containing the count of each
|
||||
// class type (indexed by the class' string representation).
|
||||
func GetClassDistribution(inst FixedDataGrid) map[string]int {
|
||||
|
@ -18,14 +18,38 @@ type MultiLinearSVC struct {
|
||||
// whether the system solves the dual or primal SVM form, true should be used
|
||||
// in most cases. C is the penalty term, normally 1.0. eps is the convergence
|
||||
// term, typically 1e-4.
|
||||
func NewMultiLinearSVC(loss, penalty string, dual bool, C float64, eps float64) *MultiLinearSVC {
|
||||
classifierFunc := func() base.Classifier {
|
||||
ret, err := linear_models.NewLinearSVC(loss, penalty, dual, C, eps)
|
||||
func NewMultiLinearSVC(loss, penalty string, dual bool, C float64, eps float64, weights map[string]float64) *MultiLinearSVC {
|
||||
// Set up the training parameters
|
||||
params := &linear_models.LinearSVCParams{0, nil, C, eps, false, dual}
|
||||
err := params.SetKindFromStrings(loss, penalty)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
// Classifier creation function
|
||||
classifierFunc := func(cls string) base.Classifier {
|
||||
var weightVec []float64
|
||||
newParams := params.Copy()
|
||||
if weights != nil {
|
||||
weightVec = make([]float64, 2)
|
||||
for i := range weights {
|
||||
if i != cls {
|
||||
weightVec[0] += weights[i]
|
||||
} else {
|
||||
weightVec[1] = weights[i]
|
||||
}
|
||||
}
|
||||
}
|
||||
newParams.ClassWeights = weightVec
|
||||
|
||||
ret, err := linear_models.NewLinearSVCFromParams(newParams)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
return ret
|
||||
}
|
||||
|
||||
// Return me...
|
||||
return &MultiLinearSVC{
|
||||
meta.NewOneVsAllModel(classifierFunc),
|
||||
}
|
||||
|
@ -1,27 +1,49 @@
|
||||
package ensemble
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"github.com/sjwhitworth/golearn/base"
|
||||
"github.com/sjwhitworth/golearn/evaluation"
|
||||
. "github.com/smartystreets/goconvey/convey"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestMultiSVM(t *testing.T) {
|
||||
func TestMultiSVMUnweighted(t *testing.T) {
|
||||
Convey("Loading data...", t, func() {
|
||||
inst, err := base.ParseCSVToInstances("../examples/datasets/articles.csv", false)
|
||||
So(err, ShouldBeNil)
|
||||
X, Y := base.InstancesTrainTestSplit(inst, 0.4)
|
||||
|
||||
m := NewMultiLinearSVC("l1", "l2", true, 1.0, 1e-4)
|
||||
m := NewMultiLinearSVC("l1", "l2", true, 1.0, 1e-4, nil)
|
||||
m.Fit(X)
|
||||
|
||||
Convey("Predictions should work...", func() {
|
||||
predictions, err := m.Predict(Y)
|
||||
cf, err := evaluation.GetConfusionMatrix(Y, predictions)
|
||||
So(err, ShouldEqual, nil)
|
||||
fmt.Println(evaluation.GetSummary(cf))
|
||||
So(evaluation.GetAccuracy(cf), ShouldBeGreaterThan, 0.70)
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
func TestMultiSVMWeighted(t *testing.T) {
|
||||
Convey("Loading data...", t, func() {
|
||||
weights := make(map[string]float64)
|
||||
weights["Finance"] = 0.1739
|
||||
weights["Tech"] = 0.0750
|
||||
weights["Politics"] = 0.4928
|
||||
|
||||
inst, err := base.ParseCSVToInstances("../examples/datasets/articles.csv", false)
|
||||
So(err, ShouldBeNil)
|
||||
X, Y := base.InstancesTrainTestSplit(inst, 0.4)
|
||||
|
||||
m := NewMultiLinearSVC("l1", "l2", true, 0.62, 1e-4, weights)
|
||||
m.Fit(X)
|
||||
|
||||
Convey("Predictions should work...", func() {
|
||||
predictions, err := m.Predict(Y)
|
||||
cf, err := evaluation.GetConfusionMatrix(Y, predictions)
|
||||
So(err, ShouldEqual, nil)
|
||||
So(evaluation.GetAccuracy(cf), ShouldBeGreaterThan, 0.70)
|
||||
})
|
||||
})
|
||||
}
|
||||
|
@ -1,54 +1,186 @@
|
||||
package linear_models
|
||||
|
||||
import "C"
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"github.com/sjwhitworth/golearn/base"
|
||||
"unsafe"
|
||||
)
|
||||
|
||||
type LinearSVC struct {
|
||||
param *Parameter
|
||||
model *Model
|
||||
// LinearSVCParams represnts all available LinearSVC options.
|
||||
//
|
||||
// SolverKind: can be linear_models.L2_L1LOSS_SVC_DUAL,
|
||||
// L2R_L2LOSS_SVC_DUAL, L2R_L2LOSS_SVC, L1R_L2LOSS_SVC.
|
||||
// It must be set via SetKindFromStrings.
|
||||
//
|
||||
// ClassWeights describes how each class is weighted, and can
|
||||
// be used in class-imabalanced scenarios. If this is nil, then
|
||||
// all classes will be weighted the same unless WeightClassesAutomatically
|
||||
// is True.
|
||||
//
|
||||
// C is a float64 represnenting the misclassification penalty.
|
||||
//
|
||||
// Eps is a float64 convergence threshold.
|
||||
//
|
||||
// Dual indicates whether the solution is primary or dual.
|
||||
type LinearSVCParams struct {
|
||||
SolverType int
|
||||
ClassWeights []float64
|
||||
C float64
|
||||
Eps float64
|
||||
WeightClassesAutomatically bool
|
||||
Dual bool
|
||||
}
|
||||
|
||||
func NewLinearSVC(loss, penalty string, dual bool, C float64, eps float64) (*LinearSVC, error) {
|
||||
solver_type := 0
|
||||
// Copy return s a copy of these parameters
|
||||
func (p *LinearSVCParams) Copy() *LinearSVCParams {
|
||||
ret := &LinearSVCParams{
|
||||
p.SolverType,
|
||||
nil,
|
||||
p.C,
|
||||
p.Eps,
|
||||
p.WeightClassesAutomatically,
|
||||
p.Dual,
|
||||
}
|
||||
if p.ClassWeights != nil {
|
||||
ret.ClassWeights = make([]float64, len(p.ClassWeights))
|
||||
copy(ret.ClassWeights, p.ClassWeights)
|
||||
}
|
||||
return ret
|
||||
}
|
||||
|
||||
// SetKindFromStrings configures the solver kind from strings.
|
||||
// Penalty and Loss parameters can either be l1 or l2.
|
||||
func (p *LinearSVCParams) SetKindFromStrings(loss, penalty string) error {
|
||||
var ret error
|
||||
p.SolverType = 0
|
||||
// Loss validation
|
||||
if loss == "l1" {
|
||||
} else if loss == "l2" {
|
||||
} else {
|
||||
return fmt.Errorf("loss must be \"l1\" or \"l2\"")
|
||||
}
|
||||
|
||||
// Penalty validation
|
||||
if penalty == "l2" {
|
||||
if loss == "l1" {
|
||||
if dual {
|
||||
solver_type = L2R_L1LOSS_SVC_DUAL
|
||||
if !p.Dual {
|
||||
ret = fmt.Errorf("Important: changed to dual form")
|
||||
}
|
||||
p.SolverType = L2R_L1LOSS_SVC_DUAL
|
||||
p.Dual = true
|
||||
} else {
|
||||
if dual {
|
||||
solver_type = L2R_L2LOSS_SVC_DUAL
|
||||
if p.Dual {
|
||||
p.SolverType = L2R_L2LOSS_SVC_DUAL
|
||||
} else {
|
||||
solver_type = L2R_L2LOSS_SVC
|
||||
p.SolverType = L2R_L2LOSS_SVC
|
||||
}
|
||||
}
|
||||
} else if penalty == "l1" {
|
||||
if loss == "l2" {
|
||||
if !dual {
|
||||
solver_type = L1R_L2LOSS_SVC
|
||||
if p.Dual {
|
||||
ret = fmt.Errorf("Important: changed to primary form")
|
||||
}
|
||||
p.Dual = false
|
||||
p.SolverType = L1R_L2LOSS_SVC
|
||||
} else {
|
||||
return fmt.Errorf("Must have L2 loss with L1 penalty")
|
||||
}
|
||||
}
|
||||
if solver_type == 0 {
|
||||
panic("Parameter combination")
|
||||
} else {
|
||||
return fmt.Errorf("Penalty must be \"l1\" or \"l2\"")
|
||||
}
|
||||
|
||||
// Final validation
|
||||
if p.SolverType == 0 {
|
||||
return fmt.Errorf("Invalid parameter combination")
|
||||
}
|
||||
return ret
|
||||
}
|
||||
|
||||
// convertToNativeFormat converts the LinearSVCParams given into a format
|
||||
// for liblinear.
|
||||
func (p *LinearSVCParams) convertToNativeFormat() *Parameter {
|
||||
return NewParameter(p.SolverType, p.C, p.Eps)
|
||||
}
|
||||
|
||||
// LinearSVC represents a linear support-vector classifier.
|
||||
type LinearSVC struct {
|
||||
param *Parameter
|
||||
model *Model
|
||||
Param *LinearSVCParams
|
||||
}
|
||||
|
||||
// NewLinearSVC creates a new support classifier.
|
||||
//
|
||||
// loss and penalty: see LinearSVCParams#SetKindFromString
|
||||
//
|
||||
// dual: see LinearSVCParams
|
||||
//
|
||||
// eps: see LinearSVCParams
|
||||
//
|
||||
// C: see LinearSVCParams
|
||||
func NewLinearSVC(loss, penalty string, dual bool, C float64, eps float64) (*LinearSVC, error) {
|
||||
|
||||
// Convert and check parameters
|
||||
params := &LinearSVCParams{0, nil, C, eps, false, dual}
|
||||
err := params.SetKindFromStrings(loss, penalty)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return NewLinearSVCFromParams(params)
|
||||
}
|
||||
|
||||
// NewLinearSVCFromParams constructs a LinearSVC from the given LinearSVCParams structure.
|
||||
func NewLinearSVCFromParams(params *LinearSVCParams) (*LinearSVC, error) {
|
||||
// Construct model
|
||||
lr := LinearSVC{}
|
||||
lr.param = NewParameter(solver_type, C, eps)
|
||||
lr.param = params.convertToNativeFormat()
|
||||
lr.Param = params
|
||||
lr.model = nil
|
||||
return &lr, nil
|
||||
}
|
||||
|
||||
// Fit automatically weights the class vector (if configured to do so)
|
||||
// converts the FixedDataGrid into the right format and trains the model.
|
||||
func (lr *LinearSVC) Fit(X base.FixedDataGrid) error {
|
||||
|
||||
var weightVec []float64
|
||||
var weightClasses []C.int
|
||||
|
||||
// Creates the class weighting
|
||||
if lr.Param.ClassWeights == nil {
|
||||
if lr.Param.WeightClassesAutomatically {
|
||||
weightVec = generateClassWeightVectorFromDist(X)
|
||||
} else {
|
||||
weightVec = generateClassWeightVectorFromFixed(X)
|
||||
}
|
||||
} else {
|
||||
weightVec = lr.Param.ClassWeights
|
||||
}
|
||||
|
||||
weightClasses = make([]C.int, len(weightVec))
|
||||
for i := range weightVec {
|
||||
weightClasses[i] = C.int(i)
|
||||
}
|
||||
|
||||
// Convert the problem
|
||||
problemVec := convertInstancesToProblemVec(X)
|
||||
labelVec := convertInstancesToLabelVec(X)
|
||||
|
||||
// Train
|
||||
prob := NewProblem(problemVec, labelVec, 0)
|
||||
lr.param.c_param.nr_weight = C.int(len(weightVec))
|
||||
lr.param.c_param.weight_label = &(weightClasses[0])
|
||||
lr.param.c_param.weight = (*C.double)(unsafe.Pointer(&weightVec[0]))
|
||||
|
||||
// lr.param.weights = (*C.double)unsafe.Pointer(&(weightVec[0]));
|
||||
lr.model = Train(prob, lr.param)
|
||||
return nil
|
||||
}
|
||||
|
||||
// Predict issues predictions from a trained LinearSVC.
|
||||
func (lr *LinearSVC) Predict(X base.FixedDataGrid) (base.FixedDataGrid, error) {
|
||||
|
||||
// Only support 1 class Attribute
|
||||
@ -59,6 +191,7 @@ func (lr *LinearSVC) Predict(X base.FixedDataGrid) (base.FixedDataGrid, error) {
|
||||
// Generate return structure
|
||||
ret := base.GeneratePredictionVector(X)
|
||||
classAttrSpecs := base.ResolveAttributes(ret, classAttrs)
|
||||
|
||||
// Retrieve numeric non-class Attributes
|
||||
numericAttrs := base.NonClassFloatAttributes(X)
|
||||
numericAttrSpecs := base.ResolveAttributes(X, numericAttrs)
|
||||
@ -78,6 +211,7 @@ func (lr *LinearSVC) Predict(X base.FixedDataGrid) (base.FixedDataGrid, error) {
|
||||
return ret, nil
|
||||
}
|
||||
|
||||
// String return a humaan-readable version.
|
||||
func (lr *LinearSVC) String() string {
|
||||
return "LogisticSVC"
|
||||
}
|
||||
|
@ -5,6 +5,35 @@ import (
|
||||
"github.com/sjwhitworth/golearn/base"
|
||||
)
|
||||
|
||||
func generateClassWeightVectorFromDist(X base.FixedDataGrid) []float64 {
|
||||
classDist := base.GetClassDistributionByBinaryFloatValue(X)
|
||||
ret := make([]float64, len(classDist))
|
||||
for i, c := range classDist {
|
||||
if c == 0 {
|
||||
ret[i] = 1.0
|
||||
} else {
|
||||
ret[i] = 1.0 / float64(c)
|
||||
}
|
||||
}
|
||||
return ret
|
||||
}
|
||||
|
||||
func generateClassWeightVectorFromFixed(X base.FixedDataGrid) []float64 {
|
||||
classAttrs := X.AllClassAttributes()
|
||||
if len(classAttrs) != 1 {
|
||||
panic("Wrong number of class Attributes")
|
||||
}
|
||||
if _, ok := classAttrs[0].(*base.FloatAttribute); ok {
|
||||
ret := make([]float64, 2)
|
||||
for i := range ret {
|
||||
ret[i] = 1.0
|
||||
}
|
||||
return ret
|
||||
} else {
|
||||
panic("Must be a FloatAttribute")
|
||||
}
|
||||
}
|
||||
|
||||
func convertInstancesToProblemVec(X base.FixedDataGrid) [][]float64 {
|
||||
// Allocate problem array
|
||||
_, rows := X.Size()
|
||||
|
@ -10,7 +10,7 @@ import (
|
||||
// by whichever is most confident. Only one CategoricalAttribute
|
||||
// class variable is supported.
|
||||
type OneVsAllModel struct {
|
||||
NewClassifierFunction func() base.Classifier
|
||||
NewClassifierFunction func(string) base.Classifier
|
||||
filters []*oneVsAllFilter
|
||||
classifiers []base.Classifier
|
||||
maxClassVal uint64
|
||||
@ -18,7 +18,7 @@ type OneVsAllModel struct {
|
||||
|
||||
// NewOneVsAllModel creates a new OneVsAllModel. The argument
|
||||
// must be a function which returns a base.Classifier ready for training.
|
||||
func NewOneVsAllModel(f func() base.Classifier) *OneVsAllModel {
|
||||
func NewOneVsAllModel(f func(string) base.Classifier) *OneVsAllModel {
|
||||
return &OneVsAllModel{
|
||||
f,
|
||||
nil,
|
||||
@ -64,7 +64,8 @@ func (m *OneVsAllModel) Fit(using base.FixedDataGrid) {
|
||||
|
||||
// Find the highest stored value
|
||||
val := uint64(0)
|
||||
for _, s := range classAttr.GetValues() {
|
||||
classVals := classAttr.GetValues()
|
||||
for _, s := range classVals {
|
||||
cur := base.UnpackBytesToU64(classAttr.GetSysValFromString(s))
|
||||
if cur > val {
|
||||
val = cur
|
||||
@ -85,7 +86,7 @@ func (m *OneVsAllModel) Fit(using base.FixedDataGrid) {
|
||||
i,
|
||||
}
|
||||
filters[i] = f
|
||||
classifiers[i] = m.NewClassifierFunction()
|
||||
classifiers[i] = m.NewClassifierFunction(classVals[int(i)])
|
||||
classifiers[i].Fit(base.NewLazilyFilteredInstances(using, f))
|
||||
}
|
||||
|
||||
|
@ -11,7 +11,7 @@ import (
|
||||
|
||||
func TestOneVsAllModel(t *testing.T) {
|
||||
|
||||
classifierFunc := func() base.Classifier {
|
||||
classifierFunc := func(c string) base.Classifier {
|
||||
m, err := linear_models.NewLinearSVC("l1", "l2", true, 1.0, 1e-4)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
|
Loading…
x
Reference in New Issue
Block a user