
Added Predict function

Added predict function along with its test. The current interface is
the same as in the KNN example; in other words, only the class string
is returned from the PredictOne function (a usage sketch follows the
commit metadata below).
This commit is contained in:
Thiago Cardoso 2014-05-20 22:59:03 -03:00
parent 86b18fe1c9
commit 90458d92ed
3 changed files with 163 additions and 37 deletions
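
The commit itself ships no usage example, so here is a minimal sketch of how the new interface might be exercised. The import path "github.com/sjwhitworth/golearn/naive" and the CSV file locations are assumptions for illustration, not part of this commit; the function names and signatures are taken from the diff below.

package main

import (
	"fmt"

	"github.com/sjwhitworth/golearn/base"
	"github.com/sjwhitworth/golearn/naive" // assumed import path for the classifier
)

func main() {
	// Load the training set used by the tests (path is an assumption).
	trainingData, err := base.ParseCSVToInstances("test/simple_train.csv", false)
	if err != nil {
		panic(err)
	}

	nb := naive.NewBernoulliNBClassifier()
	nb.Fit(trainingData)

	// PredictOne scores a single feature vector and returns the class string.
	fmt.Println(nb.PredictOne([]float64{0.0, 123.0, 0.0})) // "blue" in the test data below

	// Predict wraps PredictOne over every row of a test set.
	testData, err := base.ParseCSVToInstances("test/simple_test.csv", false)
	if err != nil {
		panic(err)
	}
	predictions := nb.Predict(testData)
	fmt.Println(predictions.GetClass(0)) // "blue" in the test data below
}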

View File

@@ -40,18 +40,24 @@ type BernoulliNBClassifier struct {
base.BaseEstimator
// Logarithm of each class prior
logClassPrior map[string]float64
// Log of conditional probability for each term. This vector should be
// accessed in the following way: p(f|c) = logCondProb[c][f].
// Conditional probability for each term. This vector should be
// accessed in the following way: p(f|c) = condProb[c][f].
// Logarithm is used in order to avoid underflow.
logCondProb map[string][]float64
condProb map[string][]float64
// Number of instances in each class. This is necessary in order to
// calculate the Laplace smoothing value during the Predict step.
classInstances map[string]int
// Number of features in the training set
features int
}
// Create a new Bernoulli Naive Bayes Classifier. The constructor takes no
// arguments; class priors and conditional probabilities are learned in Fit.
func NewBernoulliNBClassifier() *BernoulliNBClassifier {
nb := BernoulliNBClassifier{}
nb.logCondProb = make(map[string][]float64)
nb.condProb = make(map[string][]float64)
nb.logClassPrior = make(map[string]float64)
nb.features = 0
return &nb
}
@@ -59,8 +65,14 @@ func NewBernoulliNBClassifier() *BernoulliNBClassifier {
// necessary for calculating prior probability and p(f_i)
func (nb *BernoulliNBClassifier) Fit(X *base.Instances) {
// Number of features in this training set
nb.features = 0
if X.Rows > 0 {
nb.features = len(X.GetRowVectorWithoutClass(0))
}
// Number of instances in class
classInstances := make(map[string]int)
nb.classInstances = make(map[string]int)
// Number of documents with given term (by class)
docsContainingTerm := make(map[string][]int)
@@ -70,14 +82,16 @@ func (nb *BernoulliNBClassifier) Fit(X *base.Instances) {
// version is used.
for r := 0; r < X.Rows; r++ {
class := X.GetClass(r)
docVector := X.GetRowVectorWithoutClass(r)
// increment number of instances in class
t, ok := classInstances[class]
t, ok := nb.classInstances[class]
if !ok {
	t = 0
}
classInstances[class] = t + 1
nb.classInstances[class] = t + 1
for feat := 0; feat < X.Cols; feat++ {
v := X.Get(r, feat)
for feat := 0; feat < len(docVector); feat++ {
v := docVector[feat]
// In Bernoulli Naive Bayes the presence and absence of
// features are considered. All non-zero values are
// treated as presence.
@@ -86,7 +100,7 @@ func (nb *BernoulliNBClassifier) Fit(X *base.Instances) {
// given label.
t, ok := docsContainingTerm[class]
if !ok {
t = make([]int, X.Cols)
t = make([]int, nb.features)
docsContainingTerm[class] = t
}
t[feat] += 1
@@ -95,20 +109,77 @@ func (nb *BernoulliNBClassifier) Fit(X *base.Instances) {
}
// Pre-calculate conditional probabilities for each class
for c, _ := range classInstances {
nb.logClassPrior[c] = math.Log((float64(classInstances[c]))/float64(X.Rows))
nb.logCondProb[c] = make([]float64, X.Cols)
for feat := 0; feat < X.Cols; feat++ {
for c, _ := range nb.classInstances {
nb.logClassPrior[c] = math.Log((float64(nb.classInstances[c]))/float64(X.Rows))
nb.condProb[c] = make([]float64, nb.features)
for feat := 0; feat < nb.features; feat++ {
classTerms, _ := docsContainingTerm[c]
numDocs := classTerms[feat]
docsInClass, _ := classInstances[c]
docsInClass, _ := nb.classInstances[c]
classLogCondProb, _ := nb.logCondProb[c]
classCondProb, _ := nb.condProb[c]
// Calculate conditional probability with Laplace smoothing
classLogCondProb[feat] = math.Log(float64(numDocs + 1) / float64(docsInClass + 1))
classCondProb[feat] = float64(numDocs + 1) / float64(docsInClass + 1)
}
}
}
// Use the trained model to predict the test vector's class. The
// following operation is used in order to score each class:
//
// classScore = log(p(c)) + \sum_{f: x_f > 0}{log(p(f|c))} + \sum_{f: x_f = 0}{log(1 - p(f|c))}
//
// PredictOne returns the string that represents the predicted class.
//
// IMPORTANT: PredictOne panics if Fit was not called or if the
// document vector and train matrix have a different number of columns.
func (nb *BernoulliNBClassifier) PredictOne(vector []float64) string {
if nb.features == 0 {
panic("Fit should be called before predicting")
}
if len(vector) != nb.features {
panic("Different dimensions in Train and Test sets")
}
// Currently only the predicted class is returned.
bestScore := -math.MaxFloat64
bestClass := ""
for class, prior := range nb.logClassPrior {
classScore := prior
for f := 0; f < nb.features; f++ {
if vector[f] > 0 {
// Test document has feature f
classScore += math.Log(nb.condProb[class][f])
} else {
if nb.condProb[class][f] == 1.0 {
// Special case when prob = 1.0: fall back to the Laplace
// smoothing value
classScore += math.Log(1.0 / float64(nb.classInstances[class] + 1))
} else {
classScore += math.Log(1.0 - nb.condProb[class][f])
}
}
}
if classScore > bestScore {
bestScore = classScore
bestClass = class
}
}
return bestClass
}
// Predict is just a wrapper for the PredictOne function.
//
// IMPORTANT: Predict panics if Fit was not called or if the
// document vector and train matrix have a different number of columns.
func (nb *BernoulliNBClassifier) Predict(what *base.Instances) *base.Instances {
ret := what.GeneratePredictionVector()
for i := 0; i < what.Rows; i++ {
ret.SetAttrStr(i, 0, nb.PredictOne(what.GetRowVectorWithoutClass(i)))
}
return ret
}
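
For reference, the score computed by PredictOne above, together with the Laplace-smoothed estimate produced in Fit, can be written as follows. The symbols n_{fc} and n_c are introduced here for illustration and do not appear in the code.

    \hat{p}(f \mid c) = \frac{n_{fc} + 1}{n_c + 1}

    \mathrm{classScore}(c, x) = \log p(c)
        + \sum_{f \,:\, x_f > 0} \log \hat{p}(f \mid c)
        + \sum_{f \,:\, x_f = 0} \log\bigl(1 - \hat{p}(f \mid c)\bigr)

Here n_{fc} is the number of class-c training documents containing feature f (docsContainingTerm[c][feat]) and n_c is the number of training documents in class c (classInstances[c]). When \hat{p}(f \mid c) = 1, the absent-feature term is replaced by \log\frac{1}{n_c + 1}, matching the special case handled in the loop above.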

View File

@@ -7,7 +7,18 @@ import (
. "github.com/smartystreets/goconvey/convey"
)
func TestFit(t *testing.T) {
func TestNoFit(t *testing.T) {
Convey("Given an empty BernoulliNaiveBayes", t, func() {
nb := NewBernoulliNBClassifier()
Convey("PredictOne should panic if Fit was not called", func() {
testDoc := []float64{0.0, 1.0}
So(func() { nb.PredictOne(testDoc) }, ShouldPanic)
})
})
}
func TestSimple(t *testing.T) {
Convey("Given a simple training data", t, func() {
trainingData, err1 := base.ParseCSVToInstances("test/simple_train.csv", false)
if err1 != nil {
@@ -17,32 +28,72 @@ func TestFit(t *testing.T) {
nb := NewBernoulliNBClassifier()
nb.Fit(trainingData)
Convey("All log(prior) should be correctly calculated", func() {
logPriorBlue := nb.logClassPrior["blue"]
logPriorRed := nb.logClassPrior["red"]
Convey("Check if Fit is working as expected", func() {
Convey("All log(prior) should be correctly calculated", func() {
logPriorBlue := nb.logClassPrior["blue"]
logPriorRed := nb.logClassPrior["red"]
So(logPriorBlue, ShouldAlmostEqual, math.Log(0.5))
So(logPriorRed, ShouldAlmostEqual, math.Log(0.5))
So(logPriorBlue, ShouldAlmostEqual, math.Log(0.5))
So(logPriorRed, ShouldAlmostEqual, math.Log(0.5))
})
Convey("'red' conditional probabilities should be correct", func() {
logCondProbTok0 := nb.condProb["red"][0]
logCondProbTok1 := nb.condProb["red"][1]
logCondProbTok2 := nb.condProb["red"][2]
So(logCondProbTok0, ShouldAlmostEqual, 1.0)
So(logCondProbTok1, ShouldAlmostEqual, 1.0/3.0)
So(logCondProbTok2, ShouldAlmostEqual, 1.0)
})
Convey("'blue' conditional probabilities should be correct", func() {
logCondProbTok0 := nb.condProb["blue"][0]
logCondProbTok1 := nb.condProb["blue"][1]
logCondProbTok2 := nb.condProb["blue"][2]
So(logCondProbTok0, ShouldAlmostEqual, 1.0)
So(logCondProbTok1, ShouldAlmostEqual, 1.0)
So(logCondProbTok2, ShouldAlmostEqual, 1.0/3.0)
})
})
Convey("'red' conditional probabilities should be correct", func() {
logCondProbTok0 := nb.logCondProb["red"][0]
logCondProbTok1 := nb.logCondProb["red"][1]
logCondProbTok2 := nb.logCondProb["red"][2]
Convey("PredictOne should work as expected", func() {
Convey("Using a document with different number of cols should panic", func() {
testDoc := []float64{0.0, 2.0}
So(func() { nb.PredictOne(testDoc) }, ShouldPanic)
})
So(logCondProbTok0, ShouldAlmostEqual, math.Log(1.0))
So(logCondProbTok1, ShouldAlmostEqual, math.Log(1.0/3.0))
So(logCondProbTok2, ShouldAlmostEqual, math.Log(1.0))
Convey("Token 1 should be a good predictor of the blue class", func() {
testDoc := []float64{0.0, 123.0, 0.0}
So(nb.PredictOne(testDoc), ShouldEqual, "blue")
testDoc = []float64{120.0, 123.0, 0.0}
So(nb.PredictOne(testDoc), ShouldEqual, "blue")
})
Convey("Token 2 should be a good predictor of the red class", func() {
testDoc := []float64{0.0, 0.0, 120.0}
So(nb.PredictOne(testDoc), ShouldEqual, "red")
testDoc = []float64{10.0, 0.0, 120.0}
So(nb.PredictOne(testDoc), ShouldEqual, "red")
})
})
Convey("'blue' conditional probabilities should be correct", func() {
logCondProbTok0 := nb.logCondProb["blue"][0]
logCondProbTok1 := nb.logCondProb["blue"][1]
logCondProbTok2 := nb.logCondProb["blue"][2]
Convey("Predict should work as expected", func() {
testData, err := base.ParseCSVToInstances("test/simple_test.csv", false)
if err != nil {
t.Error(err)
}
predictions := nb.Predict(testData)
So(logCondProbTok0, ShouldAlmostEqual, math.Log(1.0))
So(logCondProbTok1, ShouldAlmostEqual, math.Log(1.0))
So(logCondProbTok2, ShouldAlmostEqual, math.Log(1.0/3.0))
Convey("All simple predicitions should be correct", func() {
So(predictions.GetClass(0), ShouldEqual, "blue")
So(predictions.GetClass(1), ShouldEqual, "red")
So(predictions.GetClass(2), ShouldEqual, "blue")
So(predictions.GetClass(3), ShouldEqual, "red")
})
})
})
}

View File

@@ -0,0 +1,4 @@
0,12,0,blue
0,0,645,red
9,213,0,blue
21,0,987,red