mirror of
https://github.com/sjwhitworth/golearn.git
synced 2025-04-28 13:48:56 +08:00
Bernoulli Naive Bayes: first draft
This is the first draft of the bernoulli naive bayes implementation. It is missing the Fit function tests and the Predict function.
This commit is contained in:
parent
9f3d9eaa64
commit
0035dd184e
85
naive/bernoulli_nb.go
Normal file
85
naive/bernoulli_nb.go
Normal file
@ -0,0 +1,85 @@
|
||||
package naive
|
||||
|
||||
import (
|
||||
"github.com/gonum/matrix/mat64"
|
||||
base "github.com/sjwhitworth/golearn/base"
|
||||
)
|
||||
|
||||
// A Bernoulli Naive Bayes Classifier. Naive Bayes classifiers assumes
|
||||
// that features probabilities are independent. In order to classify an
|
||||
// instance, it is calculated the probability that it was generated by
|
||||
// each known class, that is, for each class C, the following
|
||||
// probability is calculated.
|
||||
//
|
||||
// p(C|F1, F2, F3... Fn)
|
||||
//
|
||||
// Being F1, F2... Fn the instance features. Using the bayes theorem
|
||||
// this can be written as:
|
||||
//
|
||||
// \frac{p(C) \times p(F1, F2... Fn|C)}{p(F1, F2... Fn)}
|
||||
//
|
||||
// In the Bernoulli Naive Bayes features are considered independent
|
||||
// booleans, this means that the likelihood of a document given a class
|
||||
// C is given by:
|
||||
//
|
||||
// p(F1, F2... Fn) =
|
||||
// \prod_{i=1}^{n}{[F_i \times p(f_i|C)) + (1-F_i)(1 - p(f_i|C)))]}
|
||||
//
|
||||
// where
|
||||
// - F_i equals to 1 if feature is present in vector and zero
|
||||
// otherwise
|
||||
// - p(f_i|C) the probability of class C generating the feature
|
||||
// f_i
|
||||
//
|
||||
// For more information:
|
||||
//
|
||||
// C.D. Manning, P. Raghavan and H. Schuetze (2008). Introduction to
|
||||
// Information Retrieval. Cambridge University Press, pp. 234-265.
|
||||
// http://nlp.stanford.edu/IR-book/html/htmledition/the-bernoulli-model-1.html
|
||||
type BernoulliNBClassifier struct {
|
||||
base.BaseEstimator
|
||||
// Number of instances in each class. Used for calculating the prior
|
||||
// probability and p(f_i|C)
|
||||
classInstances []int
|
||||
}
|
||||
|
||||
// Create a new Bernoulli Naive Bayes Classifier. The argument 'classes'
|
||||
// is the number of possible labels in the classification task.
|
||||
func NewBernoulliNBClassifier(classes int) *BernoulliNBClassifier {
|
||||
nb := BernoulliNBClassifier{}
|
||||
nb.classInstances = make([]int, classes)
|
||||
return &nb
|
||||
}
|
||||
|
||||
// Fill data matrix with Bernoulli Naive Bayes model. All values
|
||||
// necessary for calculating prior probability and p(f_i
|
||||
func (nb *BernoulliNBClassifier) Fit(X *mat64.Dense, y []int) {
|
||||
instances, features := X.Dims()
|
||||
if instances != len(y) {
|
||||
panic(mat64.ErrShape)
|
||||
}
|
||||
|
||||
nb.Data = mat64.NewDense(len(nb.classInstances), features, nil)
|
||||
|
||||
for r := 0; r < instances; r++ {
|
||||
// Get label of this instance. This should be a value between
|
||||
// zero and nb.classes.
|
||||
label := y[r]
|
||||
nb.classInstances[label]++
|
||||
|
||||
for c := 0; c < features; c++ {
|
||||
v := X.At(r, c)
|
||||
// In Bernoulli Naive Bayes the presence and absence of
|
||||
// features are considered. All non-zero values are
|
||||
// treated as presence.
|
||||
if v > 0 {
|
||||
// Update number of times this feature appeared within
|
||||
// given label.
|
||||
nb.Data.Set(label, c, nb.Data.At(label, c) + 1.0)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
|
24
naive/bernoulli_nb_test.go
Normal file
24
naive/bernoulli_nb_test.go
Normal file
@ -0,0 +1,24 @@
|
||||
package naive
|
||||
|
||||
import (
|
||||
"testing"
|
||||
"github.com/gonum/matrix/mat64"
|
||||
)
|
||||
|
||||
// Test if panic is correctly called when matrices with different
|
||||
// dimensions are used as arguments.
|
||||
func TestFitPanic(t *testing.T) {
|
||||
defer func() {
|
||||
if recover() == nil {
|
||||
t.Fatalf("invalid matrix dim did not panic")
|
||||
}
|
||||
}()
|
||||
|
||||
nb := NewBernoulliNBClassifier(2)
|
||||
|
||||
X := mat64.NewDense(10, 20, nil)
|
||||
// simulating user mistake, one extra label
|
||||
y := make([]int, 11)
|
||||
|
||||
nb.Fit(X, y)
|
||||
}
|
2
naive/naive.go
Normal file
2
naive/naive.go
Normal file
@ -0,0 +1,2 @@
|
||||
// Package naive implements...
|
||||
package naive
|
Loading…
x
Reference in New Issue
Block a user