1
0
mirror of https://github.com/sjwhitworth/golearn.git synced 2025-04-28 13:48:56 +08:00

Removed class prior pre-calculation

Since the number of instances in each class is stored, there is no need
to keep the pre-calculated priors.
This commit is contained in:
Thiago Cardoso 2014-06-08 00:01:42 -03:00
parent 0cf6d258e6
commit 94ef107fc8
2 changed files with 14 additions and 16 deletions

View File

@ -38,8 +38,6 @@ import (
// http://nlp.stanford.edu/IR-book/html/htmledition/the-bernoulli-model-1.html
type BernoulliNBClassifier struct {
base.BaseEstimator
// Logarithm of each class prior
logClassPrior map[string]float64
// Conditional probability for each term. This vector should be
// accessed in the following way: p(f|c) = condProb[c][f].
// Logarithm is used in order to avoid underflow.
@ -47,6 +45,8 @@ type BernoulliNBClassifier struct {
// Number of instances in each class. This is necessary in order to
// calculate the Laplace smoothing value during the Predict step.
classInstances map[string]int
// Number of instances used in training.
trainingInstances int
// Number of features in the training set
features int
}
@ -56,8 +56,8 @@ type BernoulliNBClassifier struct {
func NewBernoulliNBClassifier() *BernoulliNBClassifier {
nb := BernoulliNBClassifier{}
nb.condProb = make(map[string][]float64)
nb.logClassPrior = make(map[string]float64)
nb.features = 0
nb.trainingInstances = 0
return &nb
}
@ -65,7 +65,8 @@ func NewBernoulliNBClassifier() *BernoulliNBClassifier {
// necessary for calculating prior probability and p(f_i)
func (nb *BernoulliNBClassifier) Fit(X *base.Instances) {
// Number of features in this training set
// Number of features and instances in this training set
nb.trainingInstances = X.Rows
nb.features = 0
if X.Rows > 0 {
nb.features = len(X.GetRowVectorWithoutClass(0))
@ -110,7 +111,6 @@ func (nb *BernoulliNBClassifier) Fit(X *base.Instances) {
// Pre-calculate conditional probabilities for each class
for c, _ := range nb.classInstances {
nb.logClassPrior[c] = math.Log((float64(nb.classInstances[c]))/float64(X.Rows))
nb.condProb[c] = make([]float64, nb.features)
for feat := 0; feat < nb.features; feat++ {
classTerms, _ := docsContainingTerm[c]
@ -146,8 +146,9 @@ func (nb *BernoulliNBClassifier) PredictOne(vector []float64) string {
bestScore := -math.MaxFloat64
bestClass := ""
for class, prior := range nb.logClassPrior {
classScore := prior
for class, classCount := range nb.classInstances {
// Init classScore with log(prior)
classScore := math.Log((float64(classCount))/float64(nb.trainingInstances))
for f := 0; f < nb.features; f++ {
if vector[f] > 0 {
// Test document has feature f

View File

@ -1,7 +1,6 @@
package naive
import (
"math"
"github.com/sjwhitworth/golearn/base"
"testing"
. "github.com/smartystreets/goconvey/convey"
@ -29,12 +28,10 @@ func TestSimple(t *testing.T) {
nb.Fit(trainingData)
Convey("Check if Fit is working as expected", func() {
Convey("All log(prior) should be correctly calculated", func() {
logPriorBlue := nb.logClassPrior["blue"]
logPriorRed := nb.logClassPrior["red"]
So(logPriorBlue, ShouldAlmostEqual, math.Log(0.5))
So(logPriorRed, ShouldAlmostEqual, math.Log(0.5))
Convey("All data needed for prior should be correctly calculated", func() {
So(nb.classInstances["blue"], ShouldEqual, 2)
So(nb.classInstances["red"], ShouldEqual, 2)
So(nb.trainingInstances, ShouldEqual, 4)
})
Convey("'red' conditional probabilities should be correct", func() {