2014-05-11 21:00:28 -03:00
|
|
|
package naive
|
|
|
|
|
|
|
|
import (
|
2014-05-18 23:23:51 -03:00
|
|
|
"math"
|
|
|
|
"github.com/sjwhitworth/golearn/base"
|
2014-05-11 21:00:28 -03:00
|
|
|
"testing"
|
2014-05-18 23:23:51 -03:00
|
|
|
. "github.com/smartystreets/goconvey/convey"
|
2014-05-11 21:00:28 -03:00
|
|
|
)
|
|
|
|
|
2014-05-18 23:23:51 -03:00
|
|
|
func TestFit(t *testing.T) {
|
|
|
|
Convey("Given a simple training data", t, func() {
|
|
|
|
trainingData, err1 := base.ParseCSVToInstances("test/simple_train.csv", false)
|
|
|
|
if err1 != nil {
|
|
|
|
t.Error(err1)
|
2014-05-11 21:00:28 -03:00
|
|
|
}
|
|
|
|
|
2014-05-18 23:23:51 -03:00
|
|
|
nb := NewBernoulliNBClassifier()
|
|
|
|
nb.Fit(trainingData)
|
2014-05-11 21:00:28 -03:00
|
|
|
|
2014-05-18 23:23:51 -03:00
|
|
|
Convey("All log(prior) should be correctly calculated", func() {
|
|
|
|
logPriorBlue := nb.logClassPrior["blue"]
|
|
|
|
logPriorRed := nb.logClassPrior["red"]
|
2014-05-11 21:00:28 -03:00
|
|
|
|
2014-05-18 23:23:51 -03:00
|
|
|
So(logPriorBlue, ShouldAlmostEqual, math.Log(0.5))
|
|
|
|
So(logPriorRed, ShouldAlmostEqual, math.Log(0.5))
|
|
|
|
})
|
|
|
|
|
|
|
|
Convey("'red' conditional probabilities should be correct", func() {
|
|
|
|
})
|
|
|
|
Convey("'blue' conditional probabilities should be correct", func() {
|
|
|
|
})
|
|
|
|
})
|
2014-05-11 21:00:28 -03:00
|
|
|
}
|