From 90458d92ed895c86baf4415e254af9dc573dd4fa Mon Sep 17 00:00:00 2001 From: Thiago Cardoso Date: Tue, 20 May 2014 22:59:03 -0300 Subject: [PATCH] Added Predict function Added predict function along with its test. The current interface is the same as in the KNN example. In other words, only the class string is returned from the PredictOne function. --- naive/bernoulli_nb.go | 105 +++++++++++++++++++++++++++++++------ naive/bernoulli_nb_test.go | 91 +++++++++++++++++++++++++------- naive/test/simple_test.csv | 4 ++ 3 files changed, 163 insertions(+), 37 deletions(-) create mode 100644 naive/test/simple_test.csv diff --git a/naive/bernoulli_nb.go b/naive/bernoulli_nb.go index 794dc90..83ad6c9 100644 --- a/naive/bernoulli_nb.go +++ b/naive/bernoulli_nb.go @@ -40,18 +40,24 @@ type BernoulliNBClassifier struct { base.BaseEstimator // Logarithm of each class prior logClassPrior map[string]float64 - // Log of conditional probability for each term. This vector should be - // accessed in the following way: p(f|c) = logCondProb[c][f]. + // Conditional probability for each term. This vector should be + // accessed in the following way: p(f|c) = condProb[c][f]. // Logarithm is used in order to avoid underflow. - logCondProb map[string][]float64 + condProb map[string][]float64 + // Number of instances in each class. This is necessary in order to + // calculate the Laplace smoothing value during the Predict step. + classInstances map[string]int + // Number of features in the training set + features int } // Create a new Bernoulli Naive Bayes Classifier. The argument 'classes' // is the number of possible labels in the classification task. 
func NewBernoulliNBClassifier() *BernoulliNBClassifier { nb := BernoulliNBClassifier{} - nb.logCondProb = make(map[string][]float64) + nb.condProb = make(map[string][]float64) nb.logClassPrior = make(map[string]float64) + nb.features = 0 return &nb } @@ -59,8 +65,14 @@ func NewBernoulliNBClassifier() *BernoulliNBClassifier { // necessary for calculating prior probability and p(f_i) func (nb *BernoulliNBClassifier) Fit(X *base.Instances) { + // Number of features in this training set + nb.features = 0 + if X.Rows > 0 { + nb.features = len(X.GetRowVectorWithoutClass(0)) + } + // Number of instances in class - classInstances := make(map[string]int) + nb.classInstances = make(map[string]int) // Number of documents with given term (by class) docsContainingTerm := make(map[string][]int) @@ -70,14 +82,16 @@ func (nb *BernoulliNBClassifier) Fit(X *base.Instances) { // version is used. for r := 0; r < X.Rows; r++ { class := X.GetClass(r) + docVector := X.GetRowVectorWithoutClass(r) // increment number of instances in class - t, ok := classInstances[class] + t, ok := nb.classInstances[class] if !ok { t = 0 } - classInstances[class] = t + 1 + nb.classInstances[class] = t + 1 - for feat := 0; feat < X.Cols; feat++ { - v := X.Get(r, feat) + + for feat := 0; feat < len(docVector); feat++ { + v := docVector[feat] // In Bernoulli Naive Bayes the presence and absence of // features are considered. All non-zero values are // treated as presence. @@ -86,7 +100,7 @@ func (nb *BernoulliNBClassifier) Fit(X *base.Instances) { // given label. 
t, ok := docsContainingTerm[class] if !ok { - t = make([]int, X.Cols) + t = make([]int, nb.features) docsContainingTerm[class] = t } t[feat] += 1 @@ -95,20 +109,77 @@ func (nb *BernoulliNBClassifier) Fit(X *base.Instances) { } // Pre-calculate conditional probabilities for each class - for c, _ := range classInstances { - nb.logClassPrior[c] = math.Log((float64(classInstances[c]))/float64(X.Rows)) - nb.logCondProb[c] = make([]float64, X.Cols) - for feat := 0; feat < X.Cols; feat++ { + for c, _ := range nb.classInstances { + nb.logClassPrior[c] = math.Log((float64(nb.classInstances[c]))/float64(X.Rows)) + nb.condProb[c] = make([]float64, nb.features) + for feat := 0; feat < nb.features; feat++ { classTerms, _ := docsContainingTerm[c] numDocs := classTerms[feat] - docsInClass, _ := classInstances[c] + docsInClass, _ := nb.classInstances[c] - classLogCondProb, _ := nb.logCondProb[c] + classCondProb, _ := nb.condProb[c] // Calculate conditional probability with laplace smoothing - classLogCondProb[feat] = math.Log(float64(numDocs + 1) / float64(docsInClass + 1)) + classCondProb[feat] = float64(numDocs + 1) / float64(docsInClass + 1) } } } +// Use trained model to predict test vector's class. The following +// operation is used in order to score each class: +// +// classScore = log(p(c)) + \sum_{f}{log(p(f|c))} +// +// PredictOne returns the string that represents the predicted class. +// +// IMPORTANT: PredictOne panics if Fit was not called or if the +// document vector and train matrix have a different number of columns. +func (nb *BernoulliNBClassifier) PredictOne(vector []float64) string { + if nb.features == 0 { + panic("Fit should be called before predicting") + } + if len(vector) != nb.features { + panic("Different dimensions in Train and Test sets") + } + // Currently only the predicted class is returned. 
+ bestScore := -math.MaxFloat64 + bestClass := "" + + for class, prior := range nb.logClassPrior { + classScore := prior + for f := 0; f < nb.features; f++ { + if vector[f] > 0 { + // Test document has feature f + classScore += math.Log(nb.condProb[class][f]) + } else { + if nb.condProb[class][f] == 1.0 { + // special case when prob = 1.0, apply Laplace + // smoothing + classScore += math.Log(1.0 / float64(nb.classInstances[class] + 1)) + } else { + classScore += math.Log(1.0 - nb.condProb[class][f]) + } + } + } + + if classScore > bestScore { + bestScore = classScore + bestClass = class + } + } + + return bestClass +} + +// Predict is just a wrapper for the PredictOne function. +// +// IMPORTANT: Predict panics if Fit was not called or if the +// document vector and train matrix have a different number of columns. +func (nb *BernoulliNBClassifier) Predict(what *base.Instances) *base.Instances { + ret := what.GeneratePredictionVector() + for i := 0; i < what.Rows; i++ { + ret.SetAttrStr(i, 0, nb.PredictOne(what.GetRowVectorWithoutClass(i))) + } + return ret +} diff --git a/naive/bernoulli_nb_test.go b/naive/bernoulli_nb_test.go index 69d86ad..e99df1f 100644 --- a/naive/bernoulli_nb_test.go +++ b/naive/bernoulli_nb_test.go @@ -7,7 +7,18 @@ import ( . 
"github.com/smartystreets/goconvey/convey" ) -func TestFit(t *testing.T) { +func TestNoFit(t *testing.T) { + Convey("Given an empty BernoulliNaiveBayes", t, func() { + nb := NewBernoulliNBClassifier() + + Convey("PredictOne should panic if Fit was not called", func() { + testDoc := []float64{0.0, 1.0} + So(func() { nb.PredictOne(testDoc) }, ShouldPanic) + }) + }) +} + +func TestSimple(t *testing.T) { Convey("Given a simple training data", t, func() { trainingData, err1 := base.ParseCSVToInstances("test/simple_train.csv", false) if err1 != nil { @@ -17,32 +28,72 @@ func TestFit(t *testing.T) { nb := NewBernoulliNBClassifier() nb.Fit(trainingData) - Convey("All log(prior) should be correctly calculated", func() { - logPriorBlue := nb.logClassPrior["blue"] - logPriorRed := nb.logClassPrior["red"] + Convey("Check if Fit is working as expected", func() { + Convey("All log(prior) should be correctly calculated", func() { + logPriorBlue := nb.logClassPrior["blue"] + logPriorRed := nb.logClassPrior["red"] - So(logPriorBlue, ShouldAlmostEqual, math.Log(0.5)) - So(logPriorRed, ShouldAlmostEqual, math.Log(0.5)) + So(logPriorBlue, ShouldAlmostEqual, math.Log(0.5)) + So(logPriorRed, ShouldAlmostEqual, math.Log(0.5)) + }) + + Convey("'red' conditional probabilities should be correct", func() { + logCondProbTok0 := nb.condProb["red"][0] + logCondProbTok1 := nb.condProb["red"][1] + logCondProbTok2 := nb.condProb["red"][2] + + So(logCondProbTok0, ShouldAlmostEqual, 1.0) + So(logCondProbTok1, ShouldAlmostEqual, 1.0/3.0) + So(logCondProbTok2, ShouldAlmostEqual, 1.0) + }) + + Convey("'blue' conditional probabilities should be correct", func() { + logCondProbTok0 := nb.condProb["blue"][0] + logCondProbTok1 := nb.condProb["blue"][1] + logCondProbTok2 := nb.condProb["blue"][2] + + So(logCondProbTok0, ShouldAlmostEqual, 1.0) + So(logCondProbTok1, ShouldAlmostEqual, 1.0) + So(logCondProbTok2, ShouldAlmostEqual, 1.0/3.0) + }) }) - Convey("'red' conditional probabilities should be correct", 
func() { - logCondProbTok0 := nb.logCondProb["red"][0] - logCondProbTok1 := nb.logCondProb["red"][1] - logCondProbTok2 := nb.logCondProb["red"][2] + Convey("PredictOne should work as expected", func() { + Convey("Using a document with different number of cols should panic", func() { + testDoc := []float64{0.0, 2.0} + So(func() { nb.PredictOne(testDoc) }, ShouldPanic) + }) - So(logCondProbTok0, ShouldAlmostEqual, math.Log(1.0)) - So(logCondProbTok1, ShouldAlmostEqual, math.Log(1.0/3.0)) - So(logCondProbTok2, ShouldAlmostEqual, math.Log(1.0)) + Convey("Token 1 should be a good predictor of the blue class", func() { + testDoc := []float64{0.0, 123.0, 0.0} + So(nb.PredictOne(testDoc), ShouldEqual, "blue") + + testDoc = []float64{120.0, 123.0, 0.0} + So(nb.PredictOne(testDoc), ShouldEqual, "blue") + }) + + Convey("Token 2 should be a good predictor of the red class", func() { + testDoc := []float64{0.0, 0.0, 120.0} + So(nb.PredictOne(testDoc), ShouldEqual, "red") + + testDoc = []float64{10.0, 0.0, 120.0} + So(nb.PredictOne(testDoc), ShouldEqual, "red") + }) }) - Convey("'blue' conditional probabilities should be correct", func() { - logCondProbTok0 := nb.logCondProb["blue"][0] - logCondProbTok1 := nb.logCondProb["blue"][1] - logCondProbTok2 := nb.logCondProb["blue"][2] + Convey("Predict should work as expected", func() { + testData, err := base.ParseCSVToInstances("test/simple_test.csv", false) + if err != nil { + t.Error(err) + } + predictions := nb.Predict(testData) - So(logCondProbTok0, ShouldAlmostEqual, math.Log(1.0)) - So(logCondProbTok1, ShouldAlmostEqual, math.Log(1.0)) - So(logCondProbTok2, ShouldAlmostEqual, math.Log(1.0/3.0)) + Convey("All simple predicitions should be correct", func() { + So(predictions.GetClass(0), ShouldEqual, "blue") + So(predictions.GetClass(1), ShouldEqual, "red") + So(predictions.GetClass(2), ShouldEqual, "blue") + So(predictions.GetClass(3), ShouldEqual, "red") + }) }) }) } diff --git a/naive/test/simple_test.csv 
b/naive/test/simple_test.csv new file mode 100644 index 0000000..8109b16 --- /dev/null +++ b/naive/test/simple_test.csv @@ -0,0 +1,4 @@ +0,12,0,blue +0,0,645,red +9,213,0,blue +21,0,987,red