diff --git a/naive/bernoulli_nb.go b/naive/bernoulli_nb.go
index 794dc90..83ad6c9 100644
--- a/naive/bernoulli_nb.go
+++ b/naive/bernoulli_nb.go
@@ -40,18 +40,24 @@ type BernoulliNBClassifier struct {
 	base.BaseEstimator
 	// Logarithm of each class prior
 	logClassPrior map[string]float64
-	// Log of conditional probability for each term. This vector should be
-	// accessed in the following way: p(f|c) = logCondProb[c][f].
+	// Conditional probability for each term. This vector should be
+	// accessed in the following way: p(f|c) = condProb[c][f].
 	// Logarithm is used in order to avoid underflow.
-	logCondProb map[string][]float64
+	condProb map[string][]float64
+	// Number of instances in each class. This is necessary in order to
+	// calculate the Laplace smoothing value during the Predict step.
+	classInstances map[string]int
+	// Number of features in the training set
+	features int
 }
 
 // Create a new Bernoulli Naive Bayes Classifier. The argument 'classes'
 // is the number of possible labels in the classification task.
 func NewBernoulliNBClassifier() *BernoulliNBClassifier {
 	nb := BernoulliNBClassifier{}
-	nb.logCondProb = make(map[string][]float64)
+	nb.condProb = make(map[string][]float64)
 	nb.logClassPrior = make(map[string]float64)
+	nb.features = 0
 	return &nb
 }
 
@@ -59,8 +65,14 @@ func NewBernoulliNBClassifier() *BernoulliNBClassifier {
 // necessary for calculating prior probability and p(f_i)
 func (nb *BernoulliNBClassifier) Fit(X *base.Instances) {
 
+	// Number of features in this training set
+	nb.features = 0
+	if X.Rows > 0 {
+		nb.features = len(X.GetRowVectorWithoutClass(0))
+	}
+
 	// Number of instances in class
-	classInstances := make(map[string]int)
+	nb.classInstances = make(map[string]int)
 
 	// Number of documents with given term (by class)
 	docsContainingTerm := make(map[string][]int)
@@ -70,14 +82,16 @@ func (nb *BernoulliNBClassifier) Fit(X *base.Instances) {
 	// version is used.
 	for r := 0; r < X.Rows; r++ {
 		class := X.GetClass(r)
+		docVector := X.GetRowVectorWithoutClass(r)
 
 		// increment number of instances in class
-		t, ok := classInstances[class]
+		t, ok := nb.classInstances[class]
 		if !ok {
 			t = 0
 		}
-		classInstances[class] = t + 1
+		nb.classInstances[class] = t + 1
 
-		for feat := 0; feat < X.Cols; feat++ {
-			v := X.Get(r, feat)
+
+		for feat := 0; feat < len(docVector); feat++ {
+			v := docVector[feat]
 			// In Bernoulli Naive Bayes the presence and absence of
 			// features are considered. All non-zero values are
 			// treated as presence.
@@ -86,7 +100,7 @@ func (nb *BernoulliNBClassifier) Fit(X *base.Instances) {
 				// given label.
 				t, ok := docsContainingTerm[class]
 				if !ok {
-					t = make([]int, X.Cols)
+					t = make([]int, nb.features)
 					docsContainingTerm[class] = t
 				}
 				t[feat] += 1
@@ -95,20 +109,77 @@ func (nb *BernoulliNBClassifier) Fit(X *base.Instances) {
 	}
 
 	// Pre-calculate conditional probabilities for each class
-	for c, _ := range classInstances {
-		nb.logClassPrior[c] = math.Log((float64(classInstances[c]))/float64(X.Rows))
-		nb.logCondProb[c] = make([]float64, X.Cols)
-		for feat := 0; feat < X.Cols; feat++ {
+	for c, _ := range nb.classInstances {
+		nb.logClassPrior[c] = math.Log((float64(nb.classInstances[c]))/float64(X.Rows))
+		nb.condProb[c] = make([]float64, nb.features)
+		for feat := 0; feat < nb.features; feat++ {
 			classTerms, _ := docsContainingTerm[c]
 			numDocs := classTerms[feat]
-			docsInClass, _ := classInstances[c]
+			docsInClass, _ := nb.classInstances[c]
 
-			classLogCondProb, _ := nb.logCondProb[c]
+			classCondProb, _ := nb.condProb[c]
 
 			// Calculate conditional probability with laplace smoothing
-			classLogCondProb[feat] = math.Log(float64(numDocs + 1) / float64(docsInClass + 1))
+			classCondProb[feat] = float64(numDocs + 1) / float64(docsInClass + 1)
 		}
 	}
 }
+
+// Use trained model to predict test vector's class. The following
+// operation is used in order to score each class:
+//
+// classScore = log(p(c)) + \sum_{f}{log(p(f|c))}
+//
+// PredictOne returns the string that represents the predicted class.
+//
+// IMPORTANT: PredictOne panics if Fit was not called or if the
+// document vector and train matrix have a different number of columns.
+func (nb *BernoulliNBClassifier) PredictOne(vector []float64) string {
+	if nb.features == 0 {
+		panic("Fit should be called before predicting")
+	}
+	if len(vector) != nb.features {
+		panic("Different dimensions in Train and Test sets")
+	}
+	// Currently only the predicted class is returned.
+	bestScore := -math.MaxFloat64
+	bestClass := ""
+
+	for class, prior := range nb.logClassPrior {
+		classScore := prior
+		for f := 0; f < nb.features; f++ {
+			if vector[f] > 0 {
+				// Test document has feature f
+				classScore += math.Log(nb.condProb[class][f])
+			} else {
+				if nb.condProb[class][f] == 1.0 {
+					// special case when prob = 1.0, apply Laplace
+					// smoothing
+					classScore += math.Log(1.0 / float64(nb.classInstances[class] + 1))
+				} else {
+					classScore += math.Log(1.0 - nb.condProb[class][f])
+				}
+			}
+		}
+
+		if classScore > bestScore {
+			bestScore = classScore
+			bestClass = class
+		}
+	}
+
+	return bestClass
+}
+
+// Predict is just a wrapper for the PredictOne function.
+//
+// IMPORTANT: Predict panics if Fit was not called or if the
+// document vector and train matrix have a different number of columns.
+func (nb *BernoulliNBClassifier) Predict(what *base.Instances) *base.Instances {
+	ret := what.GeneratePredictionVector()
+	for i := 0; i < what.Rows; i++ {
+		ret.SetAttrStr(i, 0, nb.PredictOne(what.GetRowVectorWithoutClass(i)))
+	}
+	return ret
+}
diff --git a/naive/bernoulli_nb_test.go b/naive/bernoulli_nb_test.go
index 69d86ad..e99df1f 100644
--- a/naive/bernoulli_nb_test.go
+++ b/naive/bernoulli_nb_test.go
@@ -7,7 +7,18 @@ import (
 	. "github.com/smartystreets/goconvey/convey"
 )
 
-func TestFit(t *testing.T) {
+func TestNoFit(t *testing.T) {
+	Convey("Given an empty BernoulliNaiveBayes", t, func() {
+		nb := NewBernoulliNBClassifier()
+
+		Convey("PredictOne should panic if Fit was not called", func() {
+			testDoc := []float64{0.0, 1.0}
+			So(func() { nb.PredictOne(testDoc) }, ShouldPanic)
+		})
+	})
+}
+
+func TestSimple(t *testing.T) {
 	Convey("Given a simple training data", t, func() {
 		trainingData, err1 := base.ParseCSVToInstances("test/simple_train.csv", false)
 		if err1 != nil {
@@ -17,32 +28,72 @@ func TestFit(t *testing.T) {
 		nb := NewBernoulliNBClassifier()
 		nb.Fit(trainingData)
 
-		Convey("All log(prior) should be correctly calculated", func() {
-			logPriorBlue := nb.logClassPrior["blue"]
-			logPriorRed := nb.logClassPrior["red"]
+		Convey("Check if Fit is working as expected", func() {
+			Convey("All log(prior) should be correctly calculated", func() {
+				logPriorBlue := nb.logClassPrior["blue"]
+				logPriorRed := nb.logClassPrior["red"]
 
-			So(logPriorBlue, ShouldAlmostEqual, math.Log(0.5))
-			So(logPriorRed, ShouldAlmostEqual, math.Log(0.5))
+				So(logPriorBlue, ShouldAlmostEqual, math.Log(0.5))
+				So(logPriorRed, ShouldAlmostEqual, math.Log(0.5))
+			})
+
+			Convey("'red' conditional probabilities should be correct", func() {
+				condProbTok0 := nb.condProb["red"][0]
+				condProbTok1 := nb.condProb["red"][1]
+				condProbTok2 := nb.condProb["red"][2]
+
+				So(condProbTok0, ShouldAlmostEqual, 1.0)
+				So(condProbTok1, ShouldAlmostEqual, 1.0/3.0)
+				So(condProbTok2, ShouldAlmostEqual, 1.0)
+			})
+
+			Convey("'blue' conditional probabilities should be correct", func() {
+				condProbTok0 := nb.condProb["blue"][0]
+				condProbTok1 := nb.condProb["blue"][1]
+				condProbTok2 := nb.condProb["blue"][2]
+
+				So(condProbTok0, ShouldAlmostEqual, 1.0)
+				So(condProbTok1, ShouldAlmostEqual, 1.0)
+				So(condProbTok2, ShouldAlmostEqual, 1.0/3.0)
+			})
 		})
 
-		Convey("'red' conditional probabilities should be correct", func() {
-			logCondProbTok0 := nb.logCondProb["red"][0]
-			logCondProbTok1 := nb.logCondProb["red"][1]
-			logCondProbTok2 := nb.logCondProb["red"][2]
+		Convey("PredictOne should work as expected", func() {
+			Convey("Using a document with different number of cols should panic", func() {
+				testDoc := []float64{0.0, 2.0}
+				So(func() { nb.PredictOne(testDoc) }, ShouldPanic)
+			})
 
-			So(logCondProbTok0, ShouldAlmostEqual, math.Log(1.0))
-			So(logCondProbTok1, ShouldAlmostEqual, math.Log(1.0/3.0))
-			So(logCondProbTok2, ShouldAlmostEqual, math.Log(1.0))
+			Convey("Token 1 should be a good predictor of the blue class", func() {
+				testDoc := []float64{0.0, 123.0, 0.0}
+				So(nb.PredictOne(testDoc), ShouldEqual, "blue")
+
+				testDoc = []float64{120.0, 123.0, 0.0}
+				So(nb.PredictOne(testDoc), ShouldEqual, "blue")
+			})
+
+			Convey("Token 2 should be a good predictor of the red class", func() {
+				testDoc := []float64{0.0, 0.0, 120.0}
+				So(nb.PredictOne(testDoc), ShouldEqual, "red")
+
+				testDoc = []float64{10.0, 0.0, 120.0}
+				So(nb.PredictOne(testDoc), ShouldEqual, "red")
+			})
 		})
 
-		Convey("'blue' conditional probabilities should be correct", func() {
-			logCondProbTok0 := nb.logCondProb["blue"][0]
-			logCondProbTok1 := nb.logCondProb["blue"][1]
-			logCondProbTok2 := nb.logCondProb["blue"][2]
+		Convey("Predict should work as expected", func() {
+			testData, err := base.ParseCSVToInstances("test/simple_test.csv", false)
+			if err != nil {
+				t.Error(err)
+			}
+			predictions := nb.Predict(testData)
 
-			So(logCondProbTok0, ShouldAlmostEqual, math.Log(1.0))
-			So(logCondProbTok1, ShouldAlmostEqual, math.Log(1.0))
-			So(logCondProbTok2, ShouldAlmostEqual, math.Log(1.0/3.0))
+			Convey("All simple predictions should be correct", func() {
+				So(predictions.GetClass(0), ShouldEqual, "blue")
+				So(predictions.GetClass(1), ShouldEqual, "red")
+				So(predictions.GetClass(2), ShouldEqual, "blue")
+				So(predictions.GetClass(3), ShouldEqual, "red")
+			})
 		})
 	})
 }
diff --git a/naive/test/simple_test.csv b/naive/test/simple_test.csv
new file mode 100644
index 0000000..8109b16
--- /dev/null
+++ b/naive/test/simple_test.csv
@@ -0,0 +1,4 @@
+0,12,0,blue
+0,0,645,red
+9,213,0,blue
+21,0,987,red
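
For reference, a minimal usage sketch of the API this patch introduces (Fit, PredictOne, Predict). The patch and its tests confirm NewBernoulliNBClassifier, base.ParseCSVToInstances, PredictOne, Predict, Rows and GetClass; the main wrapper, the golearn import paths, and the CSV locations below are assumptions, not part of the change.

package main

import (
	"fmt"

	"github.com/sjwhitworth/golearn/base"
	"github.com/sjwhitworth/golearn/naive"
)

func main() {
	// Assumed path to the training fixture used by the tests above.
	trainData, err := base.ParseCSVToInstances("naive/test/simple_train.csv", false)
	if err != nil {
		panic(err)
	}

	nb := naive.NewBernoulliNBClassifier()
	nb.Fit(trainData) // PredictOne/Predict panic if Fit was not called

	// Score a single document vector; any non-zero value counts as "present".
	fmt.Println(nb.PredictOne([]float64{0.0, 123.0, 0.0})) // expected "blue" per the new tests

	// Classify a whole test set at once.
	testData, err := base.ParseCSVToInstances("naive/test/simple_test.csv", false)
	if err != nil {
		panic(err)
	}
	predictions := nb.Predict(testData)
	for i := 0; i < predictions.Rows; i++ {
		fmt.Println(predictions.GetClass(i))
	}
}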