Mirror of https://github.com/sjwhitworth/golearn.git
Removed class prior pre-calculation
Since the number of instances in each class is stored, there is no need to keep the pre-calculated priors.
parent: 0cf6d258e6
commit: 94ef107fc8
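The rationale is simple: the classifier already records how many training instances fall into each class and how many instances there are in total, so the log prior can be recomputed at prediction time instead of being cached in a separate map. Below is a minimal standalone sketch of that idea; the counts type and logPrior helper are illustrative only and are not part of the golearn codebase, although the field names mirror the diff that follows.

package main

import (
	"fmt"
	"math"
)

// counts keeps only what the commit keeps: per-class instance counts
// and the total number of training instances.
type counts struct {
	classInstances    map[string]int
	trainingInstances int
}

// logPrior derives log(P(c)) = log(count(c) / N) on demand,
// which is what PredictOne does after this change.
func (c counts) logPrior(class string) float64 {
	return math.Log(float64(c.classInstances[class]) / float64(c.trainingInstances))
}

func main() {
	c := counts{
		classInstances:    map[string]int{"blue": 2, "red": 2},
		trainingInstances: 4,
	}
	// With 2 of 4 instances in each class, both priors are log(0.5).
	fmt.Println(c.logPrior("blue"), math.Log(0.5)) // both ≈ -0.6931
}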
Changes to the Bernoulli naive Bayes classifier (package naive):

@@ -38,8 +38,6 @@ import (
 // http://nlp.stanford.edu/IR-book/html/htmledition/the-bernoulli-model-1.html
 type BernoulliNBClassifier struct {
 	base.BaseEstimator
-	// Logarithm of each class prior
-	logClassPrior map[string]float64
 	// Conditional probability for each term. This vector should be
 	// accessed in the following way: p(f|c) = condProb[c][f].
 	// Logarithm is used in order to avoid underflow.
@@ -47,6 +45,8 @@ type BernoulliNBClassifier struct {
 	// Number of instances in each class. This is necessary in order to
 	// calculate the laplace smooth value during the Predict step.
 	classInstances map[string]int
+	// Number of instances used in training.
+	trainingInstances int
 	// Number of features in the training set
 	features int
 }
@@ -56,8 +56,8 @@ type BernoulliNBClassifier struct {
 func NewBernoulliNBClassifier() *BernoulliNBClassifier {
 	nb := BernoulliNBClassifier{}
 	nb.condProb = make(map[string][]float64)
-	nb.logClassPrior = make(map[string]float64)
 	nb.features = 0
+	nb.trainingInstances = 0
 	return &nb
 }

@@ -65,7 +65,8 @@ func NewBernoulliNBClassifier() *BernoulliNBClassifier {
 // necessary for calculating prior probability and p(f_i)
 func (nb *BernoulliNBClassifier) Fit(X *base.Instances) {

-	// Number of features in this training set
+	// Number of features and instances in this training set
+	nb.trainingInstances = X.Rows
 	nb.features = 0
 	if X.Rows > 0 {
 		nb.features = len(X.GetRowVectorWithoutClass(0))
@@ -110,7 +111,6 @@ func (nb *BernoulliNBClassifier) Fit(X *base.Instances) {

 	// Pre-calculate conditional probabilities for each class
 	for c, _ := range nb.classInstances {
-		nb.logClassPrior[c] = math.Log((float64(nb.classInstances[c]))/float64(X.Rows))
 		nb.condProb[c] = make([]float64, nb.features)
 		for feat := 0; feat < nb.features; feat++ {
 			classTerms, _ := docsContainingTerm[c]
@@ -146,8 +146,9 @@ func (nb *BernoulliNBClassifier) PredictOne(vector []float64) string {
 	bestScore := -math.MaxFloat64
 	bestClass := ""

-	for class, prior := range nb.logClassPrior {
-		classScore := prior
+	for class, classCount := range nb.classInstances {
+		// Init classScore with log(prior)
+		classScore := math.Log((float64(classCount))/float64(nb.trainingInstances))
 		for f := 0; f < nb.features; f++ {
 			if vector[f] > 0 {
 				// Test document has feature c
Changes to the corresponding test:

@@ -1,7 +1,6 @@
 package naive

 import (
-	"math"
 	"github.com/sjwhitworth/golearn/base"
 	"testing"
 	. "github.com/smartystreets/goconvey/convey"
@@ -29,12 +28,10 @@ func TestSimple(t *testing.T) {
 	nb.Fit(trainingData)

 	Convey("Check if Fit is working as expected", func() {
-		Convey("All log(prior) should be correctly calculated", func() {
-			logPriorBlue := nb.logClassPrior["blue"]
-			logPriorRed := nb.logClassPrior["red"]
-
-			So(logPriorBlue, ShouldAlmostEqual, math.Log(0.5))
-			So(logPriorRed, ShouldAlmostEqual, math.Log(0.5))
+		Convey("All data needed for prior should be correctly calculated", func() {
+			So(nb.classInstances["blue"], ShouldEqual, 2)
+			So(nb.classInstances["red"], ShouldEqual, 2)
+			So(nb.trainingInstances, ShouldEqual, 4)
 		})

 		Convey("'red' conditional probabilities should be correct", func() {
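On the test side, the new assertions carry the same information as the removed ones: with 2 "blue" and 2 "red" instances out of 4, math.Log(2.0/4.0) is exactly the math.Log(0.5) the old log-prior checks expected, and after this commit PredictOne performs that division itself.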