1
0
mirror of https://github.com/sjwhitworth/golearn.git synced 2025-05-03 22:17:14 +08:00
golearn/naive/bernoulli_nb_test.go

34 lines
948 B
Go
Raw Normal View History

package naive
import (
"math"
"github.com/sjwhitworth/golearn/base"
"testing"
. "github.com/smartystreets/goconvey/convey"
)
// TestFit checks the state BernoulliNBClassifier.Fit leaves behind after
// training on the fixture test/simple_train.csv. Currently only the per-class
// log-priors are asserted; the conditional-probability scopes are placeholders.
func TestFit(t *testing.T) {
	Convey("Given a simple training data", t, func() {
		trainingData, err := base.ParseCSVToInstances("test/simple_train.csv", false)
		// Stop here if the fixture cannot be loaded. Failing So assertions
		// halt the current scope (goconvey's default failure mode), so Fit is
		// never run on nil instances and the real cause is what gets reported.
		So(err, ShouldBeNil)

		nb := NewBernoulliNBClassifier()
		nb.Fit(trainingData)

		Convey("All log(prior) should be correctly calculated", func() {
			// The assertions below expect the fixture to be balanced between
			// the "blue" and "red" classes, i.e. each prior is 0.5.
			logPriorBlue := nb.logClassPrior["blue"]
			logPriorRed := nb.logClassPrior["red"]
			So(logPriorBlue, ShouldAlmostEqual, math.Log(0.5))
			So(logPriorRed, ShouldAlmostEqual, math.Log(0.5))
		})

		// TODO: assert the per-class conditional probabilities. Passing nil
		// instead of an empty func makes goconvey report these scopes as
		// skipped (unimplemented) rather than as silently passing.
		Convey("'red' conditional probabilities should be correct", nil)
		Convey("'blue' conditional probabilities should be correct", nil)
	})
}