mirror of
https://github.com/sjwhitworth/golearn.git
synced 2025-04-26 13:49:14 +08:00
Added Predict function
Added predict function along with its test. Current interface is the same of the KNN example. In other words, only the class string is returned from the PredictOne function.
This commit is contained in:
parent
86b18fe1c9
commit
90458d92ed
@ -40,18 +40,24 @@ type BernoulliNBClassifier struct {
|
||||
base.BaseEstimator
|
||||
// Logarithm of each class prior
|
||||
logClassPrior map[string]float64
|
||||
// Log of conditional probability for each term. This vector should be
|
||||
// accessed in the following way: p(f|c) = logCondProb[c][f].
|
||||
// Conditional probability for each term. This vector should be
|
||||
// accessed in the following way: p(f|c) = condProb[c][f].
|
||||
// Logarithm is used in order to avoid underflow.
|
||||
logCondProb map[string][]float64
|
||||
condProb map[string][]float64
|
||||
// Number of instances in each class. This is necessary in order to
|
||||
// calculate the laplace smooth value during the Predict step.
|
||||
classInstances map[string]int
|
||||
// Number of features in the training set
|
||||
features int
|
||||
}
|
||||
|
||||
// Create a new Bernoulli Naive Bayes Classifier. The argument 'classes'
|
||||
// is the number of possible labels in the classification task.
|
||||
func NewBernoulliNBClassifier() *BernoulliNBClassifier {
|
||||
nb := BernoulliNBClassifier{}
|
||||
nb.logCondProb = make(map[string][]float64)
|
||||
nb.condProb = make(map[string][]float64)
|
||||
nb.logClassPrior = make(map[string]float64)
|
||||
nb.features = 0
|
||||
return &nb
|
||||
}
|
||||
|
||||
@ -59,8 +65,14 @@ func NewBernoulliNBClassifier() *BernoulliNBClassifier {
|
||||
// necessary for calculating prior probability and p(f_i)
|
||||
func (nb *BernoulliNBClassifier) Fit(X *base.Instances) {
|
||||
|
||||
// Number of features in this training set
|
||||
nb.features = 0
|
||||
if X.Rows > 0 {
|
||||
nb.features = len(X.GetRowVectorWithoutClass(0))
|
||||
}
|
||||
|
||||
// Number of instances in class
|
||||
classInstances := make(map[string]int)
|
||||
nb.classInstances = make(map[string]int)
|
||||
|
||||
// Number of documents with given term (by class)
|
||||
docsContainingTerm := make(map[string][]int)
|
||||
@ -70,14 +82,16 @@ func (nb *BernoulliNBClassifier) Fit(X *base.Instances) {
|
||||
// version is used.
|
||||
for r := 0; r < X.Rows; r++ {
|
||||
class := X.GetClass(r)
|
||||
docVector := X.GetRowVectorWithoutClass(r)
|
||||
|
||||
// increment number of instances in class
|
||||
t, ok := classInstances[class]
|
||||
t, ok := nb.classInstances[class]
|
||||
if !ok { t = 0 }
|
||||
classInstances[class] = t + 1
|
||||
nb.classInstances[class] = t + 1
|
||||
|
||||
for feat := 0; feat < X.Cols; feat++ {
|
||||
v := X.Get(r, feat)
|
||||
|
||||
for feat := 0; feat < len(docVector); feat++ {
|
||||
v := docVector[feat]
|
||||
// In Bernoulli Naive Bayes the presence and absence of
|
||||
// features are considered. All non-zero values are
|
||||
// treated as presence.
|
||||
@ -86,7 +100,7 @@ func (nb *BernoulliNBClassifier) Fit(X *base.Instances) {
|
||||
// given label.
|
||||
t, ok := docsContainingTerm[class]
|
||||
if !ok {
|
||||
t = make([]int, X.Cols)
|
||||
t = make([]int, nb.features)
|
||||
docsContainingTerm[class] = t
|
||||
}
|
||||
t[feat] += 1
|
||||
@ -95,20 +109,77 @@ func (nb *BernoulliNBClassifier) Fit(X *base.Instances) {
|
||||
}
|
||||
|
||||
// Pre-calculate conditional probabilities for each class
|
||||
for c, _ := range classInstances {
|
||||
nb.logClassPrior[c] = math.Log((float64(classInstances[c]))/float64(X.Rows))
|
||||
nb.logCondProb[c] = make([]float64, X.Cols)
|
||||
for feat := 0; feat < X.Cols; feat++ {
|
||||
for c, _ := range nb.classInstances {
|
||||
nb.logClassPrior[c] = math.Log((float64(nb.classInstances[c]))/float64(X.Rows))
|
||||
nb.condProb[c] = make([]float64, nb.features)
|
||||
for feat := 0; feat < nb.features; feat++ {
|
||||
classTerms, _ := docsContainingTerm[c]
|
||||
numDocs := classTerms[feat]
|
||||
docsInClass, _ := classInstances[c]
|
||||
docsInClass, _ := nb.classInstances[c]
|
||||
|
||||
classLogCondProb, _ := nb.logCondProb[c]
|
||||
classCondProb, _ := nb.condProb[c]
|
||||
// Calculate conditional probability with laplace smoothing
|
||||
classLogCondProb[feat] = math.Log(float64(numDocs + 1) / float64(docsInClass + 1))
|
||||
classCondProb[feat] = float64(numDocs + 1) / float64(docsInClass + 1)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Use trained model to predict test vector's class. The following
|
||||
// operation is used in order to score each class:
|
||||
//
|
||||
// classScore = log(p(c)) + \sum_{f}{log(p(f|c))}
|
||||
//
|
||||
// PredictOne returns the string that represents the predicted class.
|
||||
//
|
||||
// IMPORTANT: PredictOne panics if Fit was not called or if the
|
||||
// document vector and train matrix have a different number of columns.
|
||||
func (nb *BernoulliNBClassifier) PredictOne(vector []float64) string {
|
||||
if nb.features == 0 {
|
||||
panic("Fit should be called before predicting")
|
||||
}
|
||||
|
||||
if len(vector) != nb.features {
|
||||
panic("Different dimensions in Train and Test sets")
|
||||
}
|
||||
|
||||
// Currently only the predicted class is returned.
|
||||
bestScore := -math.MaxFloat64
|
||||
bestClass := ""
|
||||
|
||||
for class, prior := range nb.logClassPrior {
|
||||
classScore := prior
|
||||
for f := 0; f < nb.features; f++ {
|
||||
if vector[f] > 0 {
|
||||
// Test document has feature c
|
||||
classScore += math.Log(nb.condProb[class][f])
|
||||
} else {
|
||||
if nb.condProb[class][f] == 1.0 {
|
||||
// special case when prob = 1.0, consider laplace
|
||||
// smooth
|
||||
classScore += math.Log(1.0 / float64(nb.classInstances[class] + 1))
|
||||
} else {
|
||||
classScore += math.Log(1.0 - nb.condProb[class][f])
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if classScore > bestScore {
|
||||
bestScore = classScore
|
||||
bestClass = class
|
||||
}
|
||||
}
|
||||
|
||||
return bestClass
|
||||
}
|
||||
|
||||
// Predict is just a wrapper for the PredictOne function.
|
||||
//
|
||||
// IMPORTANT: Predict panics if Fit was not called or if the
|
||||
// document vector and train matrix have a different number of columns.
|
||||
func (nb *BernoulliNBClassifier) Predict(what *base.Instances) *base.Instances {
|
||||
ret := what.GeneratePredictionVector()
|
||||
for i := 0; i < what.Rows; i++ {
|
||||
ret.SetAttrStr(i, 0, nb.PredictOne(what.GetRowVectorWithoutClass(i)))
|
||||
}
|
||||
return ret
|
||||
}
|
||||
|
@ -7,7 +7,18 @@ import (
|
||||
. "github.com/smartystreets/goconvey/convey"
|
||||
)
|
||||
|
||||
func TestFit(t *testing.T) {
|
||||
func TestNoFit(t *testing.T) {
|
||||
Convey("Given an empty BernoulliNaiveBayes", t, func() {
|
||||
nb := NewBernoulliNBClassifier()
|
||||
|
||||
Convey("PredictOne should panic if Fit was not called", func() {
|
||||
testDoc := []float64{0.0, 1.0}
|
||||
So(func() { nb.PredictOne(testDoc) }, ShouldPanic)
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
func TestSimple(t *testing.T) {
|
||||
Convey("Given a simple training data", t, func() {
|
||||
trainingData, err1 := base.ParseCSVToInstances("test/simple_train.csv", false)
|
||||
if err1 != nil {
|
||||
@ -17,32 +28,72 @@ func TestFit(t *testing.T) {
|
||||
nb := NewBernoulliNBClassifier()
|
||||
nb.Fit(trainingData)
|
||||
|
||||
Convey("All log(prior) should be correctly calculated", func() {
|
||||
logPriorBlue := nb.logClassPrior["blue"]
|
||||
logPriorRed := nb.logClassPrior["red"]
|
||||
Convey("Check if Fit is working as expected", func() {
|
||||
Convey("All log(prior) should be correctly calculated", func() {
|
||||
logPriorBlue := nb.logClassPrior["blue"]
|
||||
logPriorRed := nb.logClassPrior["red"]
|
||||
|
||||
So(logPriorBlue, ShouldAlmostEqual, math.Log(0.5))
|
||||
So(logPriorRed, ShouldAlmostEqual, math.Log(0.5))
|
||||
So(logPriorBlue, ShouldAlmostEqual, math.Log(0.5))
|
||||
So(logPriorRed, ShouldAlmostEqual, math.Log(0.5))
|
||||
})
|
||||
|
||||
Convey("'red' conditional probabilities should be correct", func() {
|
||||
logCondProbTok0 := nb.condProb["red"][0]
|
||||
logCondProbTok1 := nb.condProb["red"][1]
|
||||
logCondProbTok2 := nb.condProb["red"][2]
|
||||
|
||||
So(logCondProbTok0, ShouldAlmostEqual, 1.0)
|
||||
So(logCondProbTok1, ShouldAlmostEqual, 1.0/3.0)
|
||||
So(logCondProbTok2, ShouldAlmostEqual, 1.0)
|
||||
})
|
||||
|
||||
Convey("'blue' conditional probabilities should be correct", func() {
|
||||
logCondProbTok0 := nb.condProb["blue"][0]
|
||||
logCondProbTok1 := nb.condProb["blue"][1]
|
||||
logCondProbTok2 := nb.condProb["blue"][2]
|
||||
|
||||
So(logCondProbTok0, ShouldAlmostEqual, 1.0)
|
||||
So(logCondProbTok1, ShouldAlmostEqual, 1.0)
|
||||
So(logCondProbTok2, ShouldAlmostEqual, 1.0/3.0)
|
||||
})
|
||||
})
|
||||
|
||||
Convey("'red' conditional probabilities should be correct", func() {
|
||||
logCondProbTok0 := nb.logCondProb["red"][0]
|
||||
logCondProbTok1 := nb.logCondProb["red"][1]
|
||||
logCondProbTok2 := nb.logCondProb["red"][2]
|
||||
Convey("PredictOne should work as expected", func() {
|
||||
Convey("Using a document with different number of cols should panic", func() {
|
||||
testDoc := []float64{0.0, 2.0}
|
||||
So(func() { nb.PredictOne(testDoc) }, ShouldPanic)
|
||||
})
|
||||
|
||||
So(logCondProbTok0, ShouldAlmostEqual, math.Log(1.0))
|
||||
So(logCondProbTok1, ShouldAlmostEqual, math.Log(1.0/3.0))
|
||||
So(logCondProbTok2, ShouldAlmostEqual, math.Log(1.0))
|
||||
Convey("Token 1 should be a good predictor of the blue class", func() {
|
||||
testDoc := []float64{0.0, 123.0, 0.0}
|
||||
So(nb.PredictOne(testDoc), ShouldEqual, "blue")
|
||||
|
||||
testDoc = []float64{120.0, 123.0, 0.0}
|
||||
So(nb.PredictOne(testDoc), ShouldEqual, "blue")
|
||||
})
|
||||
|
||||
Convey("Token 2 should be a good predictor of the red class", func() {
|
||||
testDoc := []float64{0.0, 0.0, 120.0}
|
||||
So(nb.PredictOne(testDoc), ShouldEqual, "red")
|
||||
|
||||
testDoc = []float64{10.0, 0.0, 120.0}
|
||||
So(nb.PredictOne(testDoc), ShouldEqual, "red")
|
||||
})
|
||||
})
|
||||
|
||||
Convey("'blue' conditional probabilities should be correct", func() {
|
||||
logCondProbTok0 := nb.logCondProb["blue"][0]
|
||||
logCondProbTok1 := nb.logCondProb["blue"][1]
|
||||
logCondProbTok2 := nb.logCondProb["blue"][2]
|
||||
Convey("Predict should work as expected", func() {
|
||||
testData, err := base.ParseCSVToInstances("test/simple_test.csv", false)
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
predictions := nb.Predict(testData)
|
||||
|
||||
So(logCondProbTok0, ShouldAlmostEqual, math.Log(1.0))
|
||||
So(logCondProbTok1, ShouldAlmostEqual, math.Log(1.0))
|
||||
So(logCondProbTok2, ShouldAlmostEqual, math.Log(1.0/3.0))
|
||||
Convey("All simple predicitions should be correct", func() {
|
||||
So(predictions.GetClass(0), ShouldEqual, "blue")
|
||||
So(predictions.GetClass(1), ShouldEqual, "red")
|
||||
So(predictions.GetClass(2), ShouldEqual, "blue")
|
||||
So(predictions.GetClass(3), ShouldEqual, "red")
|
||||
})
|
||||
})
|
||||
})
|
||||
}
|
||||
|
4
naive/test/simple_test.csv
Normal file
4
naive/test/simple_test.csv
Normal file
@ -0,0 +1,4 @@
|
||||
0,12,0,blue
|
||||
0,0,645,red
|
||||
9,213,0,blue
|
||||
21,0,987,red
|
|
Loading…
x
Reference in New Issue
Block a user