mirror of
https://github.com/sjwhitworth/golearn.git
synced 2025-04-30 13:48:57 +08:00
Removed class prior pre-calculation
Since the number of instances in each class are stored, there is no need to keep the pre-calculated priors.
This commit is contained in:
parent
0cf6d258e6
commit
94ef107fc8
@ -38,8 +38,6 @@ import (
|
|||||||
// http://nlp.stanford.edu/IR-book/html/htmledition/the-bernoulli-model-1.html
|
// http://nlp.stanford.edu/IR-book/html/htmledition/the-bernoulli-model-1.html
|
||||||
type BernoulliNBClassifier struct {
|
type BernoulliNBClassifier struct {
|
||||||
base.BaseEstimator
|
base.BaseEstimator
|
||||||
// Logarithm of each class prior
|
|
||||||
logClassPrior map[string]float64
|
|
||||||
// Conditional probability for each term. This vector should be
|
// Conditional probability for each term. This vector should be
|
||||||
// accessed in the following way: p(f|c) = condProb[c][f].
|
// accessed in the following way: p(f|c) = condProb[c][f].
|
||||||
// Logarithm is used in order to avoid underflow.
|
// Logarithm is used in order to avoid underflow.
|
||||||
@ -47,6 +45,8 @@ type BernoulliNBClassifier struct {
|
|||||||
// Number of instances in each class. This is necessary in order to
|
// Number of instances in each class. This is necessary in order to
|
||||||
// calculate the laplace smooth value during the Predict step.
|
// calculate the laplace smooth value during the Predict step.
|
||||||
classInstances map[string]int
|
classInstances map[string]int
|
||||||
|
// Number of instances used in training.
|
||||||
|
trainingInstances int
|
||||||
// Number of features in the training set
|
// Number of features in the training set
|
||||||
features int
|
features int
|
||||||
}
|
}
|
||||||
@ -56,8 +56,8 @@ type BernoulliNBClassifier struct {
|
|||||||
func NewBernoulliNBClassifier() *BernoulliNBClassifier {
|
func NewBernoulliNBClassifier() *BernoulliNBClassifier {
|
||||||
nb := BernoulliNBClassifier{}
|
nb := BernoulliNBClassifier{}
|
||||||
nb.condProb = make(map[string][]float64)
|
nb.condProb = make(map[string][]float64)
|
||||||
nb.logClassPrior = make(map[string]float64)
|
|
||||||
nb.features = 0
|
nb.features = 0
|
||||||
|
nb.trainingInstances = 0
|
||||||
return &nb
|
return &nb
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -65,7 +65,8 @@ func NewBernoulliNBClassifier() *BernoulliNBClassifier {
|
|||||||
// necessary for calculating prior probability and p(f_i)
|
// necessary for calculating prior probability and p(f_i)
|
||||||
func (nb *BernoulliNBClassifier) Fit(X *base.Instances) {
|
func (nb *BernoulliNBClassifier) Fit(X *base.Instances) {
|
||||||
|
|
||||||
// Number of features in this training set
|
// Number of features and instances in this training set
|
||||||
|
nb.trainingInstances = X.Rows
|
||||||
nb.features = 0
|
nb.features = 0
|
||||||
if X.Rows > 0 {
|
if X.Rows > 0 {
|
||||||
nb.features = len(X.GetRowVectorWithoutClass(0))
|
nb.features = len(X.GetRowVectorWithoutClass(0))
|
||||||
@ -110,7 +111,6 @@ func (nb *BernoulliNBClassifier) Fit(X *base.Instances) {
|
|||||||
|
|
||||||
// Pre-calculate conditional probabilities for each class
|
// Pre-calculate conditional probabilities for each class
|
||||||
for c, _ := range nb.classInstances {
|
for c, _ := range nb.classInstances {
|
||||||
nb.logClassPrior[c] = math.Log((float64(nb.classInstances[c]))/float64(X.Rows))
|
|
||||||
nb.condProb[c] = make([]float64, nb.features)
|
nb.condProb[c] = make([]float64, nb.features)
|
||||||
for feat := 0; feat < nb.features; feat++ {
|
for feat := 0; feat < nb.features; feat++ {
|
||||||
classTerms, _ := docsContainingTerm[c]
|
classTerms, _ := docsContainingTerm[c]
|
||||||
@ -146,8 +146,9 @@ func (nb *BernoulliNBClassifier) PredictOne(vector []float64) string {
|
|||||||
bestScore := -math.MaxFloat64
|
bestScore := -math.MaxFloat64
|
||||||
bestClass := ""
|
bestClass := ""
|
||||||
|
|
||||||
for class, prior := range nb.logClassPrior {
|
for class, classCount := range nb.classInstances {
|
||||||
classScore := prior
|
// Init classScore with log(prior)
|
||||||
|
classScore := math.Log((float64(classCount))/float64(nb.trainingInstances))
|
||||||
for f := 0; f < nb.features; f++ {
|
for f := 0; f < nb.features; f++ {
|
||||||
if vector[f] > 0 {
|
if vector[f] > 0 {
|
||||||
// Test document has feature c
|
// Test document has feature c
|
||||||
|
@ -1,7 +1,6 @@
|
|||||||
package naive
|
package naive
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"math"
|
|
||||||
"github.com/sjwhitworth/golearn/base"
|
"github.com/sjwhitworth/golearn/base"
|
||||||
"testing"
|
"testing"
|
||||||
. "github.com/smartystreets/goconvey/convey"
|
. "github.com/smartystreets/goconvey/convey"
|
||||||
@ -29,12 +28,10 @@ func TestSimple(t *testing.T) {
|
|||||||
nb.Fit(trainingData)
|
nb.Fit(trainingData)
|
||||||
|
|
||||||
Convey("Check if Fit is working as expected", func() {
|
Convey("Check if Fit is working as expected", func() {
|
||||||
Convey("All log(prior) should be correctly calculated", func() {
|
Convey("All data needed for prior should be correctly calculated", func() {
|
||||||
logPriorBlue := nb.logClassPrior["blue"]
|
So(nb.classInstances["blue"], ShouldEqual, 2)
|
||||||
logPriorRed := nb.logClassPrior["red"]
|
So(nb.classInstances["red"], ShouldEqual, 2)
|
||||||
|
So(nb.trainingInstances, ShouldEqual, 4)
|
||||||
So(logPriorBlue, ShouldAlmostEqual, math.Log(0.5))
|
|
||||||
So(logPriorRed, ShouldAlmostEqual, math.Log(0.5))
|
|
||||||
})
|
})
|
||||||
|
|
||||||
Convey("'red' conditional probabilities should be correct", func() {
|
Convey("'red' conditional probabilities should be correct", func() {
|
||||||
|
Loading…
x
Reference in New Issue
Block a user