1
0
mirror of https://github.com/sjwhitworth/golearn.git synced 2025-04-26 13:49:14 +08:00

Support for individual class weightings

This commit is contained in:
Richard Townsend 2014-10-25 15:14:13 +01:00
parent 056ccef9b6
commit 8fe06e7332
7 changed files with 301 additions and 28 deletions

View File

@ -89,6 +89,69 @@ func GetAttributeByName(inst FixedDataGrid, name string) Attribute {
return nil
}
// GetClassDistributionByBinaryFloatValue returns the count of each row
// which has a float value close to 0.0 or 1.0.
func GetClassDistributionByBinaryFloatValue(inst FixedDataGrid) []int {
	// Exactly one class Attribute is supported, and it must be a FloatAttribute.
	attrs := inst.AllClassAttributes()
	if len(attrs) != 1 {
		panic(fmt.Errorf("Wrong number of class variables (has %d, should be 1)", len(attrs)))
	}
	if _, ok := attrs[0].(*FloatAttribute); !ok {
		panic(fmt.Errorf("Class Attribute must be FloatAttribute (is %s)", attrs[0]))
	}
	// Two buckets: index 0 for values <= 0.5, index 1 for values > 0.5.
	counts := make([]int, 2)
	// Tally every row into the appropriate bucket.
	specs := ResolveAttributes(inst, attrs)
	inst.MapOverRows(specs, func(vals [][]byte, row int) (bool, error) {
		bucket := 0
		if UnpackBytesToFloat(vals[0]) > 0.5 {
			bucket = 1
		}
		counts[bucket]++
		return false, nil
	})
	return counts
}
// GetClassDistributionByCategoricalValue returns a vector containing
// the count of each class value (indexed by the class'
// system integer representation).
func GetClassDistributionByCategoricalValue(inst FixedDataGrid) []int {
	var classAttr *CategoricalAttribute
	var ok bool
	// Get the class variable: exactly one CategoricalAttribute is required.
	attrs := inst.AllClassAttributes()
	if len(attrs) != 1 {
		panic(fmt.Errorf("Wrong number of class variables (has %d, should be 1)", len(attrs)))
	}
	if classAttr, ok = attrs[0].(*CategoricalAttribute); !ok {
		panic(fmt.Errorf("Class Attribute must be a CategoricalAttribute (is %s)", attrs[0]))
	}
	// Get the number of class values: one counter slot per category.
	classLen := len(classAttr.GetValues())
	ret := make([]int, classLen)
	// Map through everything, incrementing the counter indexed by each
	// row's packed (system integer) class value.
	specs := ResolveAttributes(inst, attrs)
	inst.MapOverRows(specs, func(vals [][]byte, row int) (bool, error) {
		index := UnpackBytesToU64(vals[0])
		ret[int(index)]++
		return false, nil
	})
	return ret
}
// GetClassDistribution returns a map containing the count of each
// class type (indexed by the class' string representation).
func GetClassDistribution(inst FixedDataGrid) map[string]int {

View File

@ -18,14 +18,38 @@ type MultiLinearSVC struct {
// whether the system solves the dual or primal SVM form, true should be used
// in most cases. C is the penalty term, normally 1.0. eps is the convergence
// term, typically 1e-4.
func NewMultiLinearSVC(loss, penalty string, dual bool, C float64, eps float64) *MultiLinearSVC {
classifierFunc := func() base.Classifier {
ret, err := linear_models.NewLinearSVC(loss, penalty, dual, C, eps)
func NewMultiLinearSVC(loss, penalty string, dual bool, C float64, eps float64, weights map[string]float64) *MultiLinearSVC {
// Set up the training parameters
params := &linear_models.LinearSVCParams{0, nil, C, eps, false, dual}
err := params.SetKindFromStrings(loss, penalty)
if err != nil {
panic(err)
}
// Classifier creation function
classifierFunc := func(cls string) base.Classifier {
var weightVec []float64
newParams := params.Copy()
if weights != nil {
weightVec = make([]float64, 2)
for i := range weights {
if i != cls {
weightVec[0] += weights[i]
} else {
weightVec[1] = weights[i]
}
}
}
newParams.ClassWeights = weightVec
ret, err := linear_models.NewLinearSVCFromParams(newParams)
if err != nil {
panic(err)
}
return ret
}
// Return me...
return &MultiLinearSVC{
meta.NewOneVsAllModel(classifierFunc),
}

View File

@ -1,27 +1,49 @@
package ensemble
import (
"fmt"
"github.com/sjwhitworth/golearn/base"
"github.com/sjwhitworth/golearn/evaluation"
. "github.com/smartystreets/goconvey/convey"
"testing"
)
func TestMultiSVM(t *testing.T) {
// TestMultiSVMUnweighted checks that an unweighted multi-class linear SVC
// reaches reasonable accuracy on the articles dataset.
func TestMultiSVMUnweighted(t *testing.T) {
	Convey("Loading data...", t, func() {
		inst, err := base.ParseCSVToInstances("../examples/datasets/articles.csv", false)
		So(err, ShouldBeNil)
		X, Y := base.InstancesTrainTestSplit(inst, 0.4)
		m := NewMultiLinearSVC("l1", "l2", true, 1.0, 1e-4, nil)
		m.Fit(X)
		Convey("Predictions should work...", func() {
			predictions, err := m.Predict(Y)
			// BUG FIX: the Predict error was previously overwritten by
			// GetConfusionMatrix before ever being checked.
			So(err, ShouldBeNil)
			cf, err := evaluation.GetConfusionMatrix(Y, predictions)
			So(err, ShouldEqual, nil)
			fmt.Println(evaluation.GetSummary(cf))
			So(evaluation.GetAccuracy(cf), ShouldBeGreaterThan, 0.70)
		})
	})
}
// TestMultiSVMWeighted checks that per-class weightings can be supplied
// and that the weighted classifier still reaches reasonable accuracy.
func TestMultiSVMWeighted(t *testing.T) {
	Convey("Loading data...", t, func() {
		weights := make(map[string]float64)
		weights["Finance"] = 0.1739
		weights["Tech"] = 0.0750
		weights["Politics"] = 0.4928
		inst, err := base.ParseCSVToInstances("../examples/datasets/articles.csv", false)
		So(err, ShouldBeNil)
		X, Y := base.InstancesTrainTestSplit(inst, 0.4)
		m := NewMultiLinearSVC("l1", "l2", true, 0.62, 1e-4, weights)
		m.Fit(X)
		Convey("Predictions should work...", func() {
			predictions, err := m.Predict(Y)
			// BUG FIX: the Predict error was previously overwritten by
			// GetConfusionMatrix before ever being checked.
			So(err, ShouldBeNil)
			cf, err := evaluation.GetConfusionMatrix(Y, predictions)
			So(err, ShouldEqual, nil)
			So(evaluation.GetAccuracy(cf), ShouldBeGreaterThan, 0.70)
		})
	})
}

View File

@ -1,54 +1,186 @@
package linear_models
import "C"
import (
"fmt"
"github.com/sjwhitworth/golearn/base"
"unsafe"
)
type LinearSVC struct {
param *Parameter
model *Model
// LinearSVCParams represents all available LinearSVC options.
//
// SolverType: can be linear_models.L2R_L1LOSS_SVC_DUAL,
// L2R_L2LOSS_SVC_DUAL, L2R_L2LOSS_SVC, L1R_L2LOSS_SVC.
// It must be set via SetKindFromStrings.
//
// ClassWeights describes how each class is weighted, and can
// be used in class-imbalanced scenarios. If this is nil, then
// all classes will be weighted the same unless WeightClassesAutomatically
// is True.
//
// C is a float64 representing the misclassification penalty.
//
// Eps is a float64 convergence threshold.
//
// Dual indicates whether the solution is primal or dual.
type LinearSVCParams struct {
	SolverType int
	ClassWeights []float64
	C float64
	Eps float64
	WeightClassesAutomatically bool
	Dual bool
}
func NewLinearSVC(loss, penalty string, dual bool, C float64, eps float64) (*LinearSVC, error) {
solver_type := 0
// Copy returns an independent duplicate of these parameters, so the
// duplicate can be modified without affecting the original.
func (p *LinearSVCParams) Copy() *LinearSVCParams {
	// Shallow-copy every scalar field in one go.
	dup := *p
	// ClassWeights is a slice, so it needs an explicit deep copy.
	if p.ClassWeights != nil {
		dup.ClassWeights = make([]float64, len(p.ClassWeights))
		copy(dup.ClassWeights, p.ClassWeights)
	}
	return &dup
}
// SetKindFromStrings configures the solver kind from strings.
// Penalty and Loss parameters can either be l1 or l2.
func (p *LinearSVCParams) SetKindFromStrings(loss, penalty string) error {
var ret error
p.SolverType = 0
// Loss validation
if loss == "l1" {
} else if loss == "l2" {
} else {
return fmt.Errorf("loss must be \"l1\" or \"l2\"")
}
// Penalty validation
if penalty == "l2" {
if loss == "l1" {
if dual {
solver_type = L2R_L1LOSS_SVC_DUAL
if !p.Dual {
ret = fmt.Errorf("Important: changed to dual form")
}
p.SolverType = L2R_L1LOSS_SVC_DUAL
p.Dual = true
} else {
if dual {
solver_type = L2R_L2LOSS_SVC_DUAL
if p.Dual {
p.SolverType = L2R_L2LOSS_SVC_DUAL
} else {
solver_type = L2R_L2LOSS_SVC
p.SolverType = L2R_L2LOSS_SVC
}
}
} else if penalty == "l1" {
if loss == "l2" {
if !dual {
solver_type = L1R_L2LOSS_SVC
if p.Dual {
ret = fmt.Errorf("Important: changed to primary form")
}
p.Dual = false
p.SolverType = L1R_L2LOSS_SVC
} else {
return fmt.Errorf("Must have L2 loss with L1 penalty")
}
}
if solver_type == 0 {
panic("Parameter combination")
} else {
return fmt.Errorf("Penalty must be \"l1\" or \"l2\"")
}
// Final validation
if p.SolverType == 0 {
return fmt.Errorf("Invalid parameter combination")
}
return ret
}
// convertToNativeFormat converts the LinearSVCParams given into a format
// for liblinear. Only SolverType, C and Eps are encoded here; class
// weightings are applied separately (see Fit).
func (p *LinearSVCParams) convertToNativeFormat() *Parameter {
	return NewParameter(p.SolverType, p.C, p.Eps)
}
// LinearSVC represents a linear support-vector classifier.
type LinearSVC struct {
	param *Parameter       // native liblinear parameter block, derived from Param
	model *Model           // trained liblinear model; nil until Fit succeeds
	Param *LinearSVCParams // user-visible training options
}
// NewLinearSVC creates a new support classifier.
//
// loss and penalty: see LinearSVCParams#SetKindFromString
//
// dual: see LinearSVCParams
//
// eps: see LinearSVCParams
//
// C: see LinearSVCParams
func NewLinearSVC(loss, penalty string, dual bool, C float64, eps float64) (*LinearSVC, error) {
	// Build the parameter block, then resolve the solver kind from the
	// loss/penalty strings before handing off to the params constructor.
	params := &LinearSVCParams{
		SolverType:                 0,
		ClassWeights:               nil,
		C:                          C,
		Eps:                        eps,
		WeightClassesAutomatically: false,
		Dual:                       dual,
	}
	if err := params.SetKindFromStrings(loss, penalty); err != nil {
		return nil, err
	}
	return NewLinearSVCFromParams(params)
}
// NewLinearSVCFromParams constructs a LinearSVC from the given
// LinearSVCParams structure. The native liblinear parameter block is
// derived immediately; the model itself is nil until Fit is called.
//
// NOTE: a leftover removed diff line (NewParameter(solver_type, C, eps))
// was interleaved here; this is the reconstructed post-change version.
func NewLinearSVCFromParams(params *LinearSVCParams) (*LinearSVC, error) {
	// Construct model
	lr := LinearSVC{}
	lr.param = params.convertToNativeFormat()
	lr.Param = params
	lr.model = nil
	return &lr, nil
}
// Fit automatically weights the class vector (if configured to do so),
// converts the FixedDataGrid into the right format and trains the model.
// Always returns nil (helpers may panic on malformed input).
func (lr *LinearSVC) Fit(X base.FixedDataGrid) error {
	var weightVec []float64
	var weightClasses []C.int
	// Creates the class weighting: explicit ClassWeights take priority;
	// otherwise derive weights from the class distribution (automatic)
	// or use a uniform vector (fixed).
	if lr.Param.ClassWeights == nil {
		if lr.Param.WeightClassesAutomatically {
			weightVec = generateClassWeightVectorFromDist(X)
		} else {
			weightVec = generateClassWeightVectorFromFixed(X)
		}
	} else {
		weightVec = lr.Param.ClassWeights
	}
	// Label i receives weight weightVec[i].
	weightClasses = make([]C.int, len(weightVec))
	for i := range weightVec {
		weightClasses[i] = C.int(i)
	}
	// Convert the problem
	problemVec := convertInstancesToProblemVec(X)
	labelVec := convertInstancesToLabelVec(X)
	// Train
	prob := NewProblem(problemVec, labelVec, 0)
	// Pass the weighting straight into the native liblinear parameter struct.
	lr.param.c_param.nr_weight = C.int(len(weightVec))
	// NOTE(review): these pointers reference Go-allocated slices; confirm
	// liblinear copies or finishes with them inside Train, since nothing
	// pins them after this function returns. Also assumes weightVec is
	// non-empty — an empty weight vector would panic on indexing.
	lr.param.c_param.weight_label = &(weightClasses[0])
	lr.param.c_param.weight = (*C.double)(unsafe.Pointer(&weightVec[0]))
	// lr.param.weights = (*C.double)unsafe.Pointer(&(weightVec[0]));
	lr.model = Train(prob, lr.param)
	return nil
}
// Predict issues predictions from a trained LinearSVC.
func (lr *LinearSVC) Predict(X base.FixedDataGrid) (base.FixedDataGrid, error) {
// Only support 1 class Attribute
@ -59,6 +191,7 @@ func (lr *LinearSVC) Predict(X base.FixedDataGrid) (base.FixedDataGrid, error) {
// Generate return structure
ret := base.GeneratePredictionVector(X)
classAttrSpecs := base.ResolveAttributes(ret, classAttrs)
// Retrieve numeric non-class Attributes
numericAttrs := base.NonClassFloatAttributes(X)
numericAttrSpecs := base.ResolveAttributes(X, numericAttrs)
@ -78,6 +211,7 @@ func (lr *LinearSVC) Predict(X base.FixedDataGrid) (base.FixedDataGrid, error) {
return ret, nil
}
// String returns a human-readable version of this classifier.
func (lr *LinearSVC) String() string {
	// BUG FIX: previously returned "LogisticSVC", an apparent copy-paste
	// from the logistic-regression classifier.
	return "LinearSVC"
}

View File

@ -5,6 +5,35 @@ import (
"github.com/sjwhitworth/golearn/base"
)
// generateClassWeightVectorFromDist weights each binary class inversely
// to its row count in X; classes with no rows get a weight of 1.0.
func generateClassWeightVectorFromDist(X base.FixedDataGrid) []float64 {
	dist := base.GetClassDistributionByBinaryFloatValue(X)
	weights := make([]float64, len(dist))
	for i, count := range dist {
		weights[i] = 1.0
		if count != 0 {
			weights[i] = 1.0 / float64(count)
		}
	}
	return weights
}
// generateClassWeightVectorFromFixed returns a uniform weighting for the
// two binary classes. The single class Attribute must be a FloatAttribute.
func generateClassWeightVectorFromFixed(X base.FixedDataGrid) []float64 {
	classAttrs := X.AllClassAttributes()
	if len(classAttrs) != 1 {
		panic("Wrong number of class Attributes")
	}
	// Guard clause instead of if/else: reject anything but a FloatAttribute.
	if _, ok := classAttrs[0].(*base.FloatAttribute); !ok {
		panic("Must be a FloatAttribute")
	}
	return []float64{1.0, 1.0}
}
func convertInstancesToProblemVec(X base.FixedDataGrid) [][]float64 {
// Allocate problem array
_, rows := X.Size()

View File

@ -10,7 +10,7 @@ import (
// by whichever is most confident. Only one CategoricalAttribute
// class variable is supported.
type OneVsAllModel struct {
NewClassifierFunction func() base.Classifier
NewClassifierFunction func(string) base.Classifier
filters []*oneVsAllFilter
classifiers []base.Classifier
maxClassVal uint64
@ -18,7 +18,7 @@ type OneVsAllModel struct {
// NewOneVsAllModel creates a new OneVsAllModel. The argument
// must be a function which returns a base.Classifier ready for training.
func NewOneVsAllModel(f func() base.Classifier) *OneVsAllModel {
func NewOneVsAllModel(f func(string) base.Classifier) *OneVsAllModel {
return &OneVsAllModel{
f,
nil,
@ -64,7 +64,8 @@ func (m *OneVsAllModel) Fit(using base.FixedDataGrid) {
// Find the highest stored value
val := uint64(0)
for _, s := range classAttr.GetValues() {
classVals := classAttr.GetValues()
for _, s := range classVals {
cur := base.UnpackBytesToU64(classAttr.GetSysValFromString(s))
if cur > val {
val = cur
@ -85,7 +86,7 @@ func (m *OneVsAllModel) Fit(using base.FixedDataGrid) {
i,
}
filters[i] = f
classifiers[i] = m.NewClassifierFunction()
classifiers[i] = m.NewClassifierFunction(classVals[int(i)])
classifiers[i].Fit(base.NewLazilyFilteredInstances(using, f))
}

View File

@ -11,7 +11,7 @@ import (
func TestOneVsAllModel(t *testing.T) {
classifierFunc := func() base.Classifier {
classifierFunc := func(c string) base.Classifier {
m, err := linear_models.NewLinearSVC("l1", "l2", true, 1.0, 1e-4)
if err != nil {
panic(err)