mirror of
https://github.com/sjwhitworth/golearn.git
synced 2025-04-30 13:48:57 +08:00

This patch:
* Adds a one-vs-all meta-classifier in meta/.
* Adds a LinearSVC to linear_models/ (essentially the same as LogisticRegression, but with different libsvm parameters).
* Adds a MultiLinearSVC in ensemble/ for predicting CategoricalAttribute classes with the LinearSVC.
* Adds a new example dataset based on classifying article headlines.

The example dataset is drawn from WikiNews and consists of the average, minimum, and maximum Word2Vec representations of article headlines from three categories. The Word2Vec model was computed offline using gensim.
39 lines
978 B
Go
39 lines
978 B
Go
package linear_models
|
|
|
|
import (
|
|
"github.com/sjwhitworth/golearn/base"
|
|
. "github.com/smartystreets/goconvey/convey"
|
|
"testing"
|
|
)
|
|
|
|
func TestLogisticRegression(t *testing.T) {
|
|
Convey("Given labels, a classifier and data", t, func() {
|
|
// Load data
|
|
X, err := base.ParseCSVToInstances("train.csv", false)
|
|
So(err, ShouldEqual, nil)
|
|
Y, err := base.ParseCSVToInstances("test.csv", false)
|
|
So(err, ShouldEqual, nil)
|
|
|
|
// Setup the problem
|
|
lr, err := NewLogisticRegression("l2", 1.0, 1e-6)
|
|
So(err, ShouldBeNil)
|
|
|
|
lr.Fit(X)
|
|
|
|
Convey("When predicting the label of first vector", func() {
|
|
Z, err := lr.Predict(Y)
|
|
So(err, ShouldEqual, nil)
|
|
Convey("The result should be 1", func() {
|
|
So(Z.RowString(0), ShouldEqual, "1.00")
|
|
})
|
|
})
|
|
Convey("When predicting the label of second vector", func() {
|
|
Z, err := lr.Predict(Y)
|
|
So(err, ShouldEqual, nil)
|
|
Convey("The result should be -1", func() {
|
|
So(Z.RowString(1), ShouldEqual, "-1.00")
|
|
})
|
|
})
|
|
})
|
|
}
|