1
0
mirror of https://github.com/sjwhitworth/golearn.git synced 2025-04-30 13:48:57 +08:00

Removed class prior pre-calculation

Since the number of instances in each class are stored, there is no need
to keep the pre-calculated priors.
This commit is contained in:
Thiago Cardoso 2014-06-08 00:01:42 -03:00
parent 0cf6d258e6
commit 94ef107fc8
2 changed files with 14 additions and 16 deletions

View File

@ -38,8 +38,6 @@ import (
// http://nlp.stanford.edu/IR-book/html/htmledition/the-bernoulli-model-1.html // http://nlp.stanford.edu/IR-book/html/htmledition/the-bernoulli-model-1.html
type BernoulliNBClassifier struct { type BernoulliNBClassifier struct {
base.BaseEstimator base.BaseEstimator
// Logarithm of each class prior
logClassPrior map[string]float64
// Conditional probability for each term. This vector should be // Conditional probability for each term. This vector should be
// accessed in the following way: p(f|c) = condProb[c][f]. // accessed in the following way: p(f|c) = condProb[c][f].
// Logarithm is used in order to avoid underflow. // Logarithm is used in order to avoid underflow.
@ -47,6 +45,8 @@ type BernoulliNBClassifier struct {
// Number of instances in each class. This is necessary in order to // Number of instances in each class. This is necessary in order to
// calculate the laplace smooth value during the Predict step. // calculate the laplace smooth value during the Predict step.
classInstances map[string]int classInstances map[string]int
// Number of instances used in training.
trainingInstances int
// Number of features in the training set // Number of features in the training set
features int features int
} }
@ -56,8 +56,8 @@ type BernoulliNBClassifier struct {
func NewBernoulliNBClassifier() *BernoulliNBClassifier { func NewBernoulliNBClassifier() *BernoulliNBClassifier {
nb := BernoulliNBClassifier{} nb := BernoulliNBClassifier{}
nb.condProb = make(map[string][]float64) nb.condProb = make(map[string][]float64)
nb.logClassPrior = make(map[string]float64)
nb.features = 0 nb.features = 0
nb.trainingInstances = 0
return &nb return &nb
} }
@ -65,7 +65,8 @@ func NewBernoulliNBClassifier() *BernoulliNBClassifier {
// necessary for calculating prior probability and p(f_i) // necessary for calculating prior probability and p(f_i)
func (nb *BernoulliNBClassifier) Fit(X *base.Instances) { func (nb *BernoulliNBClassifier) Fit(X *base.Instances) {
// Number of features in this training set // Number of features and instances in this training set
nb.trainingInstances = X.Rows
nb.features = 0 nb.features = 0
if X.Rows > 0 { if X.Rows > 0 {
nb.features = len(X.GetRowVectorWithoutClass(0)) nb.features = len(X.GetRowVectorWithoutClass(0))
@ -110,7 +111,6 @@ func (nb *BernoulliNBClassifier) Fit(X *base.Instances) {
// Pre-calculate conditional probabilities for each class // Pre-calculate conditional probabilities for each class
for c, _ := range nb.classInstances { for c, _ := range nb.classInstances {
nb.logClassPrior[c] = math.Log((float64(nb.classInstances[c]))/float64(X.Rows))
nb.condProb[c] = make([]float64, nb.features) nb.condProb[c] = make([]float64, nb.features)
for feat := 0; feat < nb.features; feat++ { for feat := 0; feat < nb.features; feat++ {
classTerms, _ := docsContainingTerm[c] classTerms, _ := docsContainingTerm[c]
@ -146,8 +146,9 @@ func (nb *BernoulliNBClassifier) PredictOne(vector []float64) string {
bestScore := -math.MaxFloat64 bestScore := -math.MaxFloat64
bestClass := "" bestClass := ""
for class, prior := range nb.logClassPrior { for class, classCount := range nb.classInstances {
classScore := prior // Init classScore with log(prior)
classScore := math.Log((float64(classCount))/float64(nb.trainingInstances))
for f := 0; f < nb.features; f++ { for f := 0; f < nb.features; f++ {
if vector[f] > 0 { if vector[f] > 0 {
// Test document has feature c // Test document has feature c

View File

@ -1,7 +1,6 @@
package naive package naive
import ( import (
"math"
"github.com/sjwhitworth/golearn/base" "github.com/sjwhitworth/golearn/base"
"testing" "testing"
. "github.com/smartystreets/goconvey/convey" . "github.com/smartystreets/goconvey/convey"
@ -29,12 +28,10 @@ func TestSimple(t *testing.T) {
nb.Fit(trainingData) nb.Fit(trainingData)
Convey("Check if Fit is working as expected", func() { Convey("Check if Fit is working as expected", func() {
Convey("All log(prior) should be correctly calculated", func() { Convey("All data needed for prior should be correctly calculated", func() {
logPriorBlue := nb.logClassPrior["blue"] So(nb.classInstances["blue"], ShouldEqual, 2)
logPriorRed := nb.logClassPrior["red"] So(nb.classInstances["red"], ShouldEqual, 2)
So(nb.trainingInstances, ShouldEqual, 4)
So(logPriorBlue, ShouldAlmostEqual, math.Log(0.5))
So(logPriorRed, ShouldAlmostEqual, math.Log(0.5))
}) })
Convey("'red' conditional probabilities should be correct", func() { Convey("'red' conditional probabilities should be correct", func() {