mirror of
https://github.com/sjwhitworth/golearn.git
synced 2025-04-25 13:48:49 +08:00

This patch:
* Adds a one-vs-all meta classifier into meta/
* Adds a LinearSVC (essentially the same as LogisticRegression but with different libsvm parameters) to linear_models/
* Adds a MultiLinearSVC into ensemble/ for predicting CategoricalAttribute classes with the LinearSVC
* Adds a new example dataset based on classifying article headlines.

The example dataset is drawn from WikiNews, and consists of average, min and max Word2Vec representations of article headlines from three categories. The Word2Vec model was computed offline using gensim.
70 lines
1.8 KiB
Go
70 lines
1.8 KiB
Go
package linear_models
|
|
|
|
import (
|
|
"errors"
|
|
"fmt"
|
|
"github.com/sjwhitworth/golearn/base"
|
|
)
|
|
|
|
type LogisticRegression struct {
|
|
param *Parameter
|
|
model *Model
|
|
}
|
|
|
|
func NewLogisticRegression(penalty string, C float64, eps float64) (*LogisticRegression, error) {
|
|
solver_type := 0
|
|
if penalty == "l2" {
|
|
solver_type = L2R_LR
|
|
} else if penalty == "l1" {
|
|
solver_type = L1R_LR
|
|
} else {
|
|
return nil, errors.New(fmt.Sprintf("Invalid penalty '%s'", penalty))
|
|
}
|
|
|
|
lr := LogisticRegression{}
|
|
lr.param = NewParameter(solver_type, C, eps)
|
|
lr.model = nil
|
|
return &lr, nil
|
|
}
|
|
|
|
func (lr *LogisticRegression) Fit(X base.FixedDataGrid) error {
|
|
problemVec := convertInstancesToProblemVec(X)
|
|
labelVec := convertInstancesToLabelVec(X)
|
|
prob := NewProblem(problemVec, labelVec, 0)
|
|
lr.model = Train(prob, lr.param)
|
|
return nil
|
|
}
|
|
|
|
func (lr *LogisticRegression) Predict(X base.FixedDataGrid) (base.FixedDataGrid, error) {
|
|
|
|
// Only support 1 class Attribute
|
|
classAttrs := X.AllClassAttributes()
|
|
if len(classAttrs) != 1 {
|
|
panic(fmt.Sprintf("%d Wrong number of classes", len(classAttrs)))
|
|
}
|
|
// Generate return structure
|
|
ret := base.GeneratePredictionVector(X)
|
|
classAttrSpecs := base.ResolveAttributes(ret, classAttrs)
|
|
// Retrieve numeric non-class Attributes
|
|
numericAttrs := base.NonClassFloatAttributes(X)
|
|
numericAttrSpecs := base.ResolveAttributes(X, numericAttrs)
|
|
|
|
// Allocate row storage
|
|
row := make([]float64, len(numericAttrSpecs))
|
|
X.MapOverRows(numericAttrSpecs, func(rowBytes [][]byte, rowNo int) (bool, error) {
|
|
for i, r := range rowBytes {
|
|
row[i] = base.UnpackBytesToFloat(r)
|
|
}
|
|
val := Predict(lr.model, row)
|
|
vals := base.PackFloatToBytes(val)
|
|
ret.Set(classAttrSpecs[0], rowNo, vals)
|
|
return true, nil
|
|
})
|
|
|
|
return ret, nil
|
|
}
|
|
|
|
func (lr *LogisticRegression) String() string {
|
|
return "LogisticRegression"
|
|
}
|