2014-07-02 15:48:35 +01:00
|
|
|
package linear_models
|
2014-05-06 12:55:58 +08:00
|
|
|
|
|
|
|
import (
|
2020-09-06 10:01:07 +01:00
|
|
|
"testing"
|
|
|
|
|
2014-07-02 15:48:35 +01:00
|
|
|
"github.com/sjwhitworth/golearn/base"
|
2014-05-06 12:55:58 +08:00
|
|
|
. "github.com/smartystreets/goconvey/convey"
|
|
|
|
)
|
|
|
|
|
|
|
|
func TestLogisticRegression(t *testing.T) {
|
|
|
|
Convey("Given labels, a classifier and data", t, func() {
|
2014-08-02 16:22:15 +01:00
|
|
|
// Load data
|
2014-07-02 15:48:35 +01:00
|
|
|
X, err := base.ParseCSVToInstances("train.csv", false)
|
|
|
|
So(err, ShouldEqual, nil)
|
|
|
|
Y, err := base.ParseCSVToInstances("test.csv", false)
|
|
|
|
So(err, ShouldEqual, nil)
|
2014-08-02 16:22:15 +01:00
|
|
|
|
|
|
|
// Setup the problem
|
2014-08-22 09:18:01 +00:00
|
|
|
lr, err := NewLogisticRegression("l2", 1.0, 1e-6)
|
2014-08-22 13:16:11 +00:00
|
|
|
So(err, ShouldBeNil)
|
|
|
|
|
2014-07-02 15:48:35 +01:00
|
|
|
lr.Fit(X)
|
2014-05-06 12:55:58 +08:00
|
|
|
|
|
|
|
Convey("When predicting the label of first vector", func() {
|
2014-10-04 17:57:39 +01:00
|
|
|
Z, err := lr.Predict(Y)
|
|
|
|
So(err, ShouldEqual, nil)
|
2014-05-06 12:55:58 +08:00
|
|
|
Convey("The result should be 1", func() {
|
2020-09-06 10:01:07 +01:00
|
|
|
So(Z.RowString(0), ShouldEqual, "1")
|
2014-05-06 12:55:58 +08:00
|
|
|
})
|
|
|
|
})
|
|
|
|
Convey("When predicting the label of second vector", func() {
|
2014-10-04 17:57:39 +01:00
|
|
|
Z, err := lr.Predict(Y)
|
|
|
|
So(err, ShouldEqual, nil)
|
2014-05-06 12:55:58 +08:00
|
|
|
Convey("The result should be -1", func() {
|
2020-09-06 10:01:07 +01:00
|
|
|
So(Z.RowString(1), ShouldEqual, "0")
|
2014-05-06 12:55:58 +08:00
|
|
|
})
|
|
|
|
})
|
|
|
|
})
|
|
|
|
}
|