mirror of
https://github.com/sjwhitworth/golearn.git
synced 2025-04-28 13:48:56 +08:00

This patch:
* Adds a one-vs-all meta classifier into meta/
* Adds a LinearSVC (essentially the same as LogisticRegression but with different libsvm parameters) to linear_models/
* Adds a MultiLinearSVC into ensemble/ for predicting CategoricalAttribute classes with the LinearSVC
* Adds a new example dataset based on classifying article headlines.

The example dataset is drawn from WikiNews and consists of the average, minimum and maximum Word2Vec representations of article headlines from three categories. The Word2Vec model was computed offline using gensim.
28 lines
701 B
Go
28 lines
701 B
Go
package ensemble
|
|
|
|
import (
|
|
"fmt"
|
|
"github.com/sjwhitworth/golearn/base"
|
|
"github.com/sjwhitworth/golearn/evaluation"
|
|
. "github.com/smartystreets/goconvey/convey"
|
|
"testing"
|
|
)
|
|
|
|
func TestMultiSVM(t *testing.T) {
|
|
Convey("Loading data...", t, func() {
|
|
inst, err := base.ParseCSVToInstances("../examples/datasets/articles.csv", false)
|
|
So(err, ShouldBeNil)
|
|
X, Y := base.InstancesTrainTestSplit(inst, 0.4)
|
|
|
|
m := NewMultiLinearSVC("l1", "l2", true, 1.0, 1e-4)
|
|
m.Fit(X)
|
|
|
|
Convey("Predictions should work...", func() {
|
|
predictions, err := m.Predict(Y)
|
|
cf, err := evaluation.GetConfusionMatrix(Y, predictions)
|
|
So(err, ShouldEqual, nil)
|
|
fmt.Println(evaluation.GetSummary(cf))
|
|
})
|
|
})
|
|
}
|