diff --git a/naive/bernoulli_nb.go b/naive/bernoulli_nb.go index 83ad6c9..b1f24af 100644 --- a/naive/bernoulli_nb.go +++ b/naive/bernoulli_nb.go @@ -38,8 +38,6 @@ import ( // http://nlp.stanford.edu/IR-book/html/htmledition/the-bernoulli-model-1.html type BernoulliNBClassifier struct { base.BaseEstimator - // Logarithm of each class prior - logClassPrior map[string]float64 // Conditional probability for each term. This vector should be // accessed in the following way: p(f|c) = condProb[c][f]. // Logarithm is used in order to avoid underflow. @@ -47,6 +45,8 @@ type BernoulliNBClassifier struct { // Number of instances in each class. This is necessary in order to // calculate the laplace smooth value during the Predict step. classInstances map[string]int + // Number of instances used in training. + trainingInstances int // Number of features in the training set features int } @@ -56,8 +56,8 @@ type BernoulliNBClassifier struct { func NewBernoulliNBClassifier() *BernoulliNBClassifier { nb := BernoulliNBClassifier{} nb.condProb = make(map[string][]float64) - nb.logClassPrior = make(map[string]float64) nb.features = 0 + nb.trainingInstances = 0 return &nb } @@ -65,7 +65,8 @@ func NewBernoulliNBClassifier() *BernoulliNBClassifier { // necessary for calculating prior probability and p(f_i) func (nb *BernoulliNBClassifier) Fit(X *base.Instances) { - // Number of features in this training set + // Number of features and instances in this training set + nb.trainingInstances = X.Rows nb.features = 0 if X.Rows > 0 { nb.features = len(X.GetRowVectorWithoutClass(0)) @@ -86,8 +87,8 @@ func (nb *BernoulliNBClassifier) Fit(X *base.Instances) { // increment number of instances in class t, ok := nb.classInstances[class] - if !ok { t = 0 } - nb.classInstances[class] = t + 1 + if !ok { t = 0 } + nb.classInstances[class] = t + 1 for feat := 0; feat < len(docVector); feat++ { @@ -110,7 +111,6 @@ func (nb *BernoulliNBClassifier) Fit(X *base.Instances) { // Pre-calculate conditional probabilities for each class for c, _ := range nb.classInstances { - nb.logClassPrior[c] = math.Log((float64(nb.classInstances[c]))/float64(X.Rows)) nb.condProb[c] = make([]float64, nb.features) for feat := 0; feat < nb.features; feat++ { classTerms, _ := docsContainingTerm[c] @@ -146,8 +146,9 @@ func (nb *BernoulliNBClassifier) PredictOne(vector []float64) string { bestScore := -math.MaxFloat64 bestClass := "" - for class, prior := range nb.logClassPrior { - classScore := prior + for class, classCount := range nb.classInstances { + // Init classScore with log(prior) + classScore := math.Log((float64(classCount))/float64(nb.trainingInstances)) for f := 0; f < nb.features; f++ { if vector[f] > 0 { // Test document has feature c diff --git a/naive/bernoulli_nb_test.go b/naive/bernoulli_nb_test.go index e99df1f..bdd3702 100644 --- a/naive/bernoulli_nb_test.go +++ b/naive/bernoulli_nb_test.go @@ -1,7 +1,6 @@ package naive import ( - "math" "github.com/sjwhitworth/golearn/base" "testing" . "github.com/smartystreets/goconvey/convey" @@ -29,12 +28,10 @@ func TestSimple(t *testing.T) { nb.Fit(trainingData) Convey("Check if Fit is working as expected", func() { - Convey("All log(prior) should be correctly calculated", func() { - logPriorBlue := nb.logClassPrior["blue"] - logPriorRed := nb.logClassPrior["red"] - - So(logPriorBlue, ShouldAlmostEqual, math.Log(0.5)) - So(logPriorRed, ShouldAlmostEqual, math.Log(0.5)) + Convey("All data needed for prior should be correctly calculated", func() { + So(nb.classInstances["blue"], ShouldEqual, 2) + So(nb.classInstances["red"], ShouldEqual, 2) + So(nb.trainingInstances, ShouldEqual, 4) }) Convey("'red' conditional probabilities should be correct", func() {