mirror of
https://github.com/sjwhitworth/golearn.git
synced 2025-04-28 13:48:56 +08:00
base: Cleaned up duplicate Attribute resolution functions
This commit is contained in:
parent
ff97065261
commit
47341b2869
@ -129,7 +129,7 @@ func ParseCSVBuildInstances(filepath string, hasHeaders bool, u UpdatableDataGri
|
|||||||
|
|
||||||
rowCounter := 0
|
rowCounter := 0
|
||||||
|
|
||||||
specs := ResolveAllAttributes(u, u.AllAttributes())
|
specs := ResolveAttributes(u, u.AllAttributes())
|
||||||
|
|
||||||
for {
|
for {
|
||||||
record, err := reader.Read()
|
record, err := reader.Read()
|
||||||
|
@ -379,7 +379,7 @@ func (inst *DenseInstances) Size() (int, int) {
|
|||||||
|
|
||||||
// swapRows swaps over rows i and j
|
// swapRows swaps over rows i and j
|
||||||
func (inst *DenseInstances) swapRows(i, j int) {
|
func (inst *DenseInstances) swapRows(i, j int) {
|
||||||
as := GetAllAttributeSpecs(inst)
|
as := ResolveAllAttributes(inst)
|
||||||
for _, a := range as {
|
for _, a := range as {
|
||||||
v1 := inst.Get(a, i)
|
v1 := inst.Get(a, i)
|
||||||
v2 := inst.Get(a, j)
|
v2 := inst.Get(a, j)
|
||||||
@ -424,7 +424,7 @@ func (inst *DenseInstances) String() string {
|
|||||||
var buffer bytes.Buffer
|
var buffer bytes.Buffer
|
||||||
|
|
||||||
// Get all Attribute information
|
// Get all Attribute information
|
||||||
as := GetAllAttributeSpecs(inst)
|
as := ResolveAllAttributes(inst)
|
||||||
|
|
||||||
// Print header
|
// Print header
|
||||||
cols, rows := inst.Size()
|
cols, rows := inst.Size()
|
||||||
|
@ -153,7 +153,7 @@ func (l *LazilyFilteredInstances) MapOverRows(asv []AttributeSpec, mapFunc func(
|
|||||||
func (l *LazilyFilteredInstances) RowString(row int) string {
|
func (l *LazilyFilteredInstances) RowString(row int) string {
|
||||||
var buffer bytes.Buffer
|
var buffer bytes.Buffer
|
||||||
|
|
||||||
as := GetAllAttributeSpecs(l) // Retrieve all Attribute data
|
as := ResolveAllAttributes(l) // Retrieve all Attribute data
|
||||||
first := true // Decide whether to prefix
|
first := true // Decide whether to prefix
|
||||||
|
|
||||||
for _, a := range as {
|
for _, a := range as {
|
||||||
@ -188,7 +188,7 @@ func (l *LazilyFilteredInstances) String() string {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Get all Attribute information
|
// Get all Attribute information
|
||||||
as := GetAllAttributeSpecs(l)
|
as := ResolveAllAttributes(l)
|
||||||
|
|
||||||
// Print header
|
// Print header
|
||||||
buffer.WriteString("Lazily filtered instances using ")
|
buffer.WriteString("Lazily filtered instances using ")
|
||||||
|
@ -17,8 +17,8 @@ func TestLazySortDesc(testEnv *testing.T) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
as1 := GetAllAttributeSpecs(inst1)
|
as1 := ResolveAllAttributes(inst1)
|
||||||
as2 := GetAllAttributeSpecs(inst2)
|
as2 := ResolveAllAttributes(inst2)
|
||||||
|
|
||||||
if isSortedDesc(inst1, as1[0]) {
|
if isSortedDesc(inst1, as1[0]) {
|
||||||
testEnv.Error("Can't test descending sort order")
|
testEnv.Error("Can't test descending sort order")
|
||||||
@ -44,7 +44,7 @@ func TestLazySortDesc(testEnv *testing.T) {
|
|||||||
|
|
||||||
func TestLazySortAsc(testEnv *testing.T) {
|
func TestLazySortAsc(testEnv *testing.T) {
|
||||||
inst, err := ParseCSVToInstances("../examples/datasets/iris_headers.csv", true)
|
inst, err := ParseCSVToInstances("../examples/datasets/iris_headers.csv", true)
|
||||||
as1 := GetAllAttributeSpecs(inst)
|
as1 := ResolveAllAttributes(inst)
|
||||||
if isSortedAsc(inst, as1[0]) {
|
if isSortedAsc(inst, as1[0]) {
|
||||||
testEnv.Error("Can't test ascending sort on something ascending already")
|
testEnv.Error("Can't test ascending sort on something ascending already")
|
||||||
}
|
}
|
||||||
@ -67,7 +67,7 @@ func TestLazySortAsc(testEnv *testing.T) {
|
|||||||
testEnv.Error(err)
|
testEnv.Error(err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
as2 := GetAllAttributeSpecs(inst2)
|
as2 := ResolveAllAttributes(inst2)
|
||||||
if !isSortedAsc(inst2, as2[0]) {
|
if !isSortedAsc(inst2, as2[0]) {
|
||||||
testEnv.Error("This file should be sorted in ascending order")
|
testEnv.Error("This file should be sorted in ascending order")
|
||||||
}
|
}
|
||||||
|
@ -44,8 +44,8 @@ func TestSortDesc(testEnv *testing.T) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
as1 := GetAllAttributeSpecs(inst1)
|
as1 := ResolveAllAttributes(inst1)
|
||||||
as2 := GetAllAttributeSpecs(inst2)
|
as2 := ResolveAllAttributes(inst2)
|
||||||
|
|
||||||
if isSortedDesc(inst1, as1[0]) {
|
if isSortedDesc(inst1, as1[0]) {
|
||||||
testEnv.Error("Can't test descending sort order")
|
testEnv.Error("Can't test descending sort order")
|
||||||
@ -71,7 +71,7 @@ func TestSortDesc(testEnv *testing.T) {
|
|||||||
|
|
||||||
func TestSortAsc(testEnv *testing.T) {
|
func TestSortAsc(testEnv *testing.T) {
|
||||||
inst, err := ParseCSVToInstances("../examples/datasets/iris_headers.csv", true)
|
inst, err := ParseCSVToInstances("../examples/datasets/iris_headers.csv", true)
|
||||||
as1 := GetAllAttributeSpecs(inst)
|
as1 := ResolveAllAttributes(inst)
|
||||||
if isSortedAsc(inst, as1[0]) {
|
if isSortedAsc(inst, as1[0]) {
|
||||||
testEnv.Error("Can't test ascending sort on something ascending already")
|
testEnv.Error("Can't test ascending sort on something ascending already")
|
||||||
}
|
}
|
||||||
@ -90,7 +90,7 @@ func TestSortAsc(testEnv *testing.T) {
|
|||||||
testEnv.Error(err)
|
testEnv.Error(err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
as2 := GetAllAttributeSpecs(inst2)
|
as2 := ResolveAllAttributes(inst2)
|
||||||
if !isSortedAsc(inst2, as2[0]) {
|
if !isSortedAsc(inst2, as2[0]) {
|
||||||
testEnv.Error("This file should be sorted in ascending order")
|
testEnv.Error("This file should be sorted in ascending order")
|
||||||
}
|
}
|
||||||
|
@ -38,9 +38,9 @@ func NonClassAttributes(d DataGrid) []Attribute {
|
|||||||
return AttributeDifferenceReferences(allAttrs, classAttrs)
|
return AttributeDifferenceReferences(allAttrs, classAttrs)
|
||||||
}
|
}
|
||||||
|
|
||||||
// ResolveAllAttributes returns AttributeSpecs describing
|
// ResolveAttributes returns AttributeSpecs describing
|
||||||
// all of the Attributes.
|
// all of the Attributes.
|
||||||
func ResolveAllAttributes(d DataGrid, attrs []Attribute) []AttributeSpec {
|
func ResolveAttributes(d DataGrid, attrs []Attribute) []AttributeSpec {
|
||||||
ret := make([]AttributeSpec, len(attrs))
|
ret := make([]AttributeSpec, len(attrs))
|
||||||
for i, a := range attrs {
|
for i, a := range attrs {
|
||||||
spec, err := d.GetAttribute(a)
|
spec, err := d.GetAttribute(a)
|
||||||
@ -52,25 +52,9 @@ func ResolveAllAttributes(d DataGrid, attrs []Attribute) []AttributeSpec {
|
|||||||
return ret
|
return ret
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetAllAttributeSpecs retrieves every Attribute specification
|
// ResolveAllAttributes returns every AttributeSpec
|
||||||
// from a given DataGrid. Useful in conjunction with MapOverRows.
|
func ResolveAllAttributes(d DataGrid) []AttributeSpec {
|
||||||
func GetAllAttributeSpecs(from DataGrid) []AttributeSpec {
|
return ResolveAttributes(d, d.AllAttributes())
|
||||||
attrs := from.AllAttributes()
|
|
||||||
return GetSomeAttributeSpecs(from, attrs)
|
|
||||||
}
|
|
||||||
|
|
||||||
// GetSomeAttributeSpecs returns a subset of Attribute specifications
|
|
||||||
// from a given DataGrid.
|
|
||||||
func GetSomeAttributeSpecs(from DataGrid, attrs []Attribute) []AttributeSpec {
|
|
||||||
ret := make([]AttributeSpec, len(attrs))
|
|
||||||
for i, a := range attrs {
|
|
||||||
as, err := from.GetAttribute(a)
|
|
||||||
if err != nil {
|
|
||||||
panic(err)
|
|
||||||
}
|
|
||||||
ret[i] = as
|
|
||||||
}
|
|
||||||
return ret
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func buildAttrSet(a []Attribute) map[Attribute]bool {
|
func buildAttrSet(a []Attribute) map[Attribute]bool {
|
||||||
|
@ -144,7 +144,7 @@ func DecomposeOnAttributeValues(inst FixedDataGrid, at Attribute) map[string]Fix
|
|||||||
rowMaps := make(map[string][]int)
|
rowMaps := make(map[string][]int)
|
||||||
|
|
||||||
// Build full Attribute set
|
// Build full Attribute set
|
||||||
fullAttrSpec := ResolveAllAttributes(inst, newAttrs)
|
fullAttrSpec := ResolveAttributes(inst, newAttrs)
|
||||||
fullAttrSpec = append(fullAttrSpec, attrSpec)
|
fullAttrSpec = append(fullAttrSpec, attrSpec)
|
||||||
|
|
||||||
// Decompose
|
// Decompose
|
||||||
|
@ -78,7 +78,7 @@ func NewInstancesViewFromRows(src FixedDataGrid, rows map[int]int) *InstancesVie
|
|||||||
func NewInstancesViewFromVisible(src FixedDataGrid, rows []int, attrs []Attribute) *InstancesView {
|
func NewInstancesViewFromVisible(src FixedDataGrid, rows []int, attrs []Attribute) *InstancesView {
|
||||||
ret := &InstancesView{
|
ret := &InstancesView{
|
||||||
src,
|
src,
|
||||||
GetSomeAttributeSpecs(src, attrs),
|
ResolveAttributes(src, attrs),
|
||||||
make(map[int]int),
|
make(map[int]int),
|
||||||
make(map[Attribute]bool),
|
make(map[Attribute]bool),
|
||||||
true,
|
true,
|
||||||
@ -99,7 +99,7 @@ func NewInstancesViewFromVisible(src FixedDataGrid, rows []int, attrs []Attribut
|
|||||||
func NewInstancesViewFromAttrs(src FixedDataGrid, attrs []Attribute) *InstancesView {
|
func NewInstancesViewFromAttrs(src FixedDataGrid, attrs []Attribute) *InstancesView {
|
||||||
ret := &InstancesView{
|
ret := &InstancesView{
|
||||||
src,
|
src,
|
||||||
GetSomeAttributeSpecs(src, attrs),
|
ResolveAttributes(src, attrs),
|
||||||
nil,
|
nil,
|
||||||
make(map[Attribute]bool),
|
make(map[Attribute]bool),
|
||||||
false,
|
false,
|
||||||
@ -252,7 +252,7 @@ func (v *InstancesView) String() string {
|
|||||||
maxRows := 30
|
maxRows := 30
|
||||||
|
|
||||||
// Get all Attribute information
|
// Get all Attribute information
|
||||||
as := GetAllAttributeSpecs(v)
|
as := ResolveAllAttributes(v)
|
||||||
|
|
||||||
// Print header
|
// Print header
|
||||||
cols, rows := v.Size()
|
cols, rows := v.Size()
|
||||||
@ -305,7 +305,7 @@ func (v *InstancesView) String() string {
|
|||||||
// RowString returns a string representation of a given row.
|
// RowString returns a string representation of a given row.
|
||||||
func (v *InstancesView) RowString(row int) string {
|
func (v *InstancesView) RowString(row int) string {
|
||||||
var buffer bytes.Buffer
|
var buffer bytes.Buffer
|
||||||
as := GetAllAttributeSpecs(v)
|
as := ResolveAllAttributes(v)
|
||||||
first := true
|
first := true
|
||||||
for _, a := range as {
|
for _, a := range as {
|
||||||
val := v.Get(a, row)
|
val := v.Get(a, row)
|
||||||
|
@ -46,7 +46,7 @@ func main() {
|
|||||||
// for doing so is not very sophisticated.
|
// for doing so is not very sophisticated.
|
||||||
|
|
||||||
// First, have to resolve Attribute Specifications
|
// First, have to resolve Attribute Specifications
|
||||||
as := base.ResolveAllAttributes(rawData, rawData.AllAttributes())
|
as := base.ResolveAttributes(rawData, rawData.AllAttributes())
|
||||||
|
|
||||||
// Attribute Specifications describe where a given column lives
|
// Attribute Specifications describe where a given column lives
|
||||||
rawData.Set(as[0], 0, as[0].GetAttribute().GetSysValFromString("1.00"))
|
rawData.Set(as[0], 0, as[0].GetAttribute().GetSysValFromString("1.00"))
|
||||||
|
@ -112,7 +112,7 @@ func TestChiMerge2(testEnv *testing.T) {
|
|||||||
|
|
||||||
// Sort the instances
|
// Sort the instances
|
||||||
allAttrs := inst.AllAttributes()
|
allAttrs := inst.AllAttributes()
|
||||||
sortAttrSpecs := base.ResolveAllAttributes(inst, allAttrs)[0:1]
|
sortAttrSpecs := base.ResolveAttributes(inst, allAttrs)[0:1]
|
||||||
instSorted, err := base.Sort(inst, base.Ascending, sortAttrSpecs)
|
instSorted, err := base.Sort(inst, base.Ascending, sortAttrSpecs)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
panic(err)
|
panic(err)
|
||||||
|
@ -65,8 +65,8 @@ func (KNN *KNNClassifier) Predict(what base.FixedDataGrid) base.FixedDataGrid {
|
|||||||
ret := base.GeneratePredictionVector(what)
|
ret := base.GeneratePredictionVector(what)
|
||||||
|
|
||||||
// Resolve Attribute specifications for both
|
// Resolve Attribute specifications for both
|
||||||
whatAttrSpecs := base.ResolveAllAttributes(what, allNumericAttrs)
|
whatAttrSpecs := base.ResolveAttributes(what, allNumericAttrs)
|
||||||
trainAttrSpecs := base.ResolveAllAttributes(KNN.TrainingData, allNumericAttrs)
|
trainAttrSpecs := base.ResolveAttributes(KNN.TrainingData, allNumericAttrs)
|
||||||
|
|
||||||
// Reserve storage for most the most similar items
|
// Reserve storage for most the most similar items
|
||||||
distances := make(map[int]float64)
|
distances := make(map[int]float64)
|
||||||
|
@ -34,7 +34,7 @@ func convertInstancesToProblemVec(X base.FixedDataGrid) [][]float64 {
|
|||||||
|
|
||||||
// Retrieve numeric non-class Attributes
|
// Retrieve numeric non-class Attributes
|
||||||
numericAttrs := base.NonClassFloatAttributes(X)
|
numericAttrs := base.NonClassFloatAttributes(X)
|
||||||
numericAttrSpecs := base.ResolveAllAttributes(X, numericAttrs)
|
numericAttrSpecs := base.ResolveAttributes(X, numericAttrs)
|
||||||
|
|
||||||
// Convert each row
|
// Convert each row
|
||||||
X.MapOverRows(numericAttrSpecs, func(row [][]byte, rowNo int) (bool, error) {
|
X.MapOverRows(numericAttrSpecs, func(row [][]byte, rowNo int) (bool, error) {
|
||||||
@ -66,7 +66,7 @@ func convertInstancesToLabelVec(X base.FixedDataGrid) []float64 {
|
|||||||
_, rows := X.Size()
|
_, rows := X.Size()
|
||||||
labelVec := make([]float64, rows)
|
labelVec := make([]float64, rows)
|
||||||
// Resolve class Attribute specification
|
// Resolve class Attribute specification
|
||||||
classAttrSpecs := base.ResolveAllAttributes(X, classAttrs)
|
classAttrSpecs := base.ResolveAttributes(X, classAttrs)
|
||||||
X.MapOverRows(classAttrSpecs, func(row [][]byte, rowNo int) (bool, error) {
|
X.MapOverRows(classAttrSpecs, func(row [][]byte, rowNo int) (bool, error) {
|
||||||
labelVec[rowNo] = base.UnpackBytesToFloat(row[0])
|
labelVec[rowNo] = base.UnpackBytesToFloat(row[0])
|
||||||
return true, nil
|
return true, nil
|
||||||
@ -90,10 +90,10 @@ func (lr *LogisticRegression) Predict(X base.FixedDataGrid) base.FixedDataGrid {
|
|||||||
}
|
}
|
||||||
// Generate return structure
|
// Generate return structure
|
||||||
ret := base.GeneratePredictionVector(X)
|
ret := base.GeneratePredictionVector(X)
|
||||||
classAttrSpecs := base.ResolveAllAttributes(ret, classAttrs)
|
classAttrSpecs := base.ResolveAttributes(ret, classAttrs)
|
||||||
// Retrieve numeric non-class Attributes
|
// Retrieve numeric non-class Attributes
|
||||||
numericAttrs := base.NonClassFloatAttributes(X)
|
numericAttrs := base.NonClassFloatAttributes(X)
|
||||||
numericAttrSpecs := base.ResolveAllAttributes(X, numericAttrs)
|
numericAttrSpecs := base.ResolveAttributes(X, numericAttrs)
|
||||||
|
|
||||||
// Allocate row storage
|
// Allocate row storage
|
||||||
row := make([]float64, len(numericAttrSpecs))
|
row := make([]float64, len(numericAttrSpecs))
|
||||||
|
@ -112,7 +112,7 @@ func (b *BaggedModel) Predict(from base.FixedDataGrid) base.FixedDataGrid {
|
|||||||
for { // Need to resolve the voting problem
|
for { // Need to resolve the voting problem
|
||||||
incoming, ok := <-votes
|
incoming, ok := <-votes
|
||||||
if ok {
|
if ok {
|
||||||
cSpecs := base.ResolveAllAttributes(incoming, incoming.AllClassAttributes())
|
cSpecs := base.ResolveAttributes(incoming, incoming.AllClassAttributes())
|
||||||
incoming.MapOverRows(cSpecs, func(row [][]byte, predRow int) (bool, error) {
|
incoming.MapOverRows(cSpecs, func(row [][]byte, predRow int) (bool, error) {
|
||||||
// Check if we've seen this class before...
|
// Check if we've seen this class before...
|
||||||
if _, ok := voting[predRow]; !ok {
|
if _, ok := voting[predRow]; !ok {
|
||||||
|
@ -1,8 +1,8 @@
|
|||||||
package naive
|
package naive
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"math"
|
base "github.com/sjwhitworth/golearn/base"
|
||||||
base "github.com/sjwhitworth/golearn/base"
|
"math"
|
||||||
)
|
)
|
||||||
|
|
||||||
// A Bernoulli Naive Bayes Classifier. Naive Bayes classifiers assumes
|
// A Bernoulli Naive Bayes Classifier. Naive Bayes classifiers assumes
|
||||||
@ -37,91 +37,103 @@ import (
|
|||||||
// Information Retrieval. Cambridge University Press, pp. 234-265.
|
// Information Retrieval. Cambridge University Press, pp. 234-265.
|
||||||
// http://nlp.stanford.edu/IR-book/html/htmledition/the-bernoulli-model-1.html
|
// http://nlp.stanford.edu/IR-book/html/htmledition/the-bernoulli-model-1.html
|
||||||
type BernoulliNBClassifier struct {
|
type BernoulliNBClassifier struct {
|
||||||
base.BaseEstimator
|
base.BaseEstimator
|
||||||
// Conditional probability for each term. This vector should be
|
// Conditional probability for each term. This vector should be
|
||||||
// accessed in the following way: p(f|c) = condProb[c][f].
|
// accessed in the following way: p(f|c) = condProb[c][f].
|
||||||
// Logarithm is used in order to avoid underflow.
|
// Logarithm is used in order to avoid underflow.
|
||||||
condProb map[string][]float64
|
condProb map[string][]float64
|
||||||
// Number of instances in each class. This is necessary in order to
|
// Number of instances in each class. This is necessary in order to
|
||||||
// calculate the laplace smooth value during the Predict step.
|
// calculate the laplace smooth value during the Predict step.
|
||||||
classInstances map[string]int
|
classInstances map[string]int
|
||||||
// Number of instances used in training.
|
// Number of instances used in training.
|
||||||
trainingInstances int
|
trainingInstances int
|
||||||
// Number of features in the training set
|
// Number of features in the training set
|
||||||
features int
|
features int
|
||||||
}
|
}
|
||||||
|
|
||||||
// Create a new Bernoulli Naive Bayes Classifier. The argument 'classes'
|
// Create a new Bernoulli Naive Bayes Classifier. The argument 'classes'
|
||||||
// is the number of possible labels in the classification task.
|
// is the number of possible labels in the classification task.
|
||||||
func NewBernoulliNBClassifier() *BernoulliNBClassifier {
|
func NewBernoulliNBClassifier() *BernoulliNBClassifier {
|
||||||
nb := BernoulliNBClassifier{}
|
nb := BernoulliNBClassifier{}
|
||||||
nb.condProb = make(map[string][]float64)
|
nb.condProb = make(map[string][]float64)
|
||||||
nb.features = 0
|
nb.features = 0
|
||||||
nb.trainingInstances = 0
|
nb.trainingInstances = 0
|
||||||
return &nb
|
return &nb
|
||||||
}
|
}
|
||||||
|
|
||||||
// Fill data matrix with Bernoulli Naive Bayes model. All values
|
// Fill data matrix with Bernoulli Naive Bayes model. All values
|
||||||
// necessary for calculating prior probability and p(f_i)
|
// necessary for calculating prior probability and p(f_i)
|
||||||
func (nb *BernoulliNBClassifier) Fit(X *base.Instances) {
|
func (nb *BernoulliNBClassifier) Fit(X base.FixedDataGrid) {
|
||||||
|
|
||||||
// Number of features and instances in this training set
|
// Check that all Attributes are binary
|
||||||
nb.trainingInstances = X.Rows
|
classAttrs := X.AllClassAttributes()
|
||||||
nb.features = 0
|
allAttrs := X.AllAttributes()
|
||||||
if X.Rows > 0 {
|
featAttrs := base.AttributeDifferenceReference(allAttrs, classAttrs)
|
||||||
nb.features = len(X.GetRowVectorWithoutClass(0))
|
for i := range featAttrs {
|
||||||
}
|
if _, ok := featAttrs[i].(*base.BinaryAttribute); !ok {
|
||||||
|
panic(fmt.Sprintf("%v: Should be BinaryAttribute", featAttrs[i]))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
featAttrSpecs := base.ResolveAllAttributes(featAttrs, X)
|
||||||
|
|
||||||
// Number of instances in class
|
// Check that only one classAttribute is defined
|
||||||
nb.classInstances = make(map[string]int)
|
if len(classAttrs) > 0 {
|
||||||
|
panic("Only one class Attribute can be used")
|
||||||
|
}
|
||||||
|
|
||||||
// Number of documents with given term (by class)
|
// Number of features and instances in this training set
|
||||||
docsContainingTerm := make(map[string][]int)
|
nb.features, nb.trainingInstances() = X.Size()
|
||||||
|
|
||||||
// This algorithm could be vectorized after binarizing the data
|
// Number of instances in class
|
||||||
// matrix. Since mat64 doesn't have this function, a iterative
|
nb.classInstances = make(map[string]int)
|
||||||
// version is used.
|
|
||||||
for r := 0; r < X.Rows; r++ {
|
|
||||||
class := X.GetClass(r)
|
|
||||||
docVector := X.GetRowVectorWithoutClass(r)
|
|
||||||
|
|
||||||
// increment number of instances in class
|
// Number of documents with given term (by class)
|
||||||
t, ok := nb.classInstances[class]
|
docsContainingTerm := make(map[string][]int)
|
||||||
if !ok { t = 0 }
|
|
||||||
nb.classInstances[class] = t + 1
|
|
||||||
|
|
||||||
|
// This algorithm could be vectorized after binarizing the data
|
||||||
|
// matrix. Since mat64 doesn't have this function, a iterative
|
||||||
|
// version is used.
|
||||||
|
X.MapOverRows(featAttrSpecs, func(docVector [][]byte, r int) (bool, error) {
|
||||||
|
class := base.GetClass(X, r)
|
||||||
|
|
||||||
for feat := 0; feat < len(docVector); feat++ {
|
// increment number of instances in class
|
||||||
v := docVector[feat]
|
t, ok := nb.classInstances[class]
|
||||||
// In Bernoulli Naive Bayes the presence and absence of
|
if !ok {
|
||||||
// features are considered. All non-zero values are
|
t = 0
|
||||||
// treated as presence.
|
}
|
||||||
if v > 0 {
|
nb.classInstances[class] = t + 1
|
||||||
// Update number of times this feature appeared within
|
|
||||||
// given label.
|
|
||||||
t, ok := docsContainingTerm[class]
|
|
||||||
if !ok {
|
|
||||||
t = make([]int, nb.features)
|
|
||||||
docsContainingTerm[class] = t
|
|
||||||
}
|
|
||||||
t[feat] += 1
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Pre-calculate conditional probabilities for each class
|
for feat := 0; feat < len(docVector); feat++ {
|
||||||
for c, _ := range nb.classInstances {
|
v := docVector[feat]
|
||||||
nb.condProb[c] = make([]float64, nb.features)
|
// In Bernoulli Naive Bayes the presence and absence of
|
||||||
for feat := 0; feat < nb.features; feat++ {
|
// features are considered. All non-zero values are
|
||||||
classTerms, _ := docsContainingTerm[c]
|
// treated as presence.
|
||||||
numDocs := classTerms[feat]
|
if v[0] > 0 {
|
||||||
docsInClass, _ := nb.classInstances[c]
|
// Update number of times this feature appeared within
|
||||||
|
// given label.
|
||||||
|
t, ok := docsContainingTerm[class]
|
||||||
|
if !ok {
|
||||||
|
t = make([]int, nb.features)
|
||||||
|
docsContainingTerm[class] = t
|
||||||
|
}
|
||||||
|
t[feat] += 1
|
||||||
|
}
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
classCondProb, _ := nb.condProb[c]
|
// Pre-calculate conditional probabilities for each class
|
||||||
// Calculate conditional probability with laplace smoothing
|
for c, _ := range nb.classInstances {
|
||||||
classCondProb[feat] = float64(numDocs + 1) / float64(docsInClass + 1)
|
nb.condProb[c] = make([]float64, nb.features)
|
||||||
}
|
for feat := 0; feat < nb.features; feat++ {
|
||||||
}
|
classTerms, _ := docsContainingTerm[c]
|
||||||
|
numDocs := classTerms[feat]
|
||||||
|
docsInClass, _ := nb.classInstances[c]
|
||||||
|
|
||||||
|
classCondProb, _ := nb.condProb[c]
|
||||||
|
// Calculate conditional probability with laplace smoothing
|
||||||
|
classCondProb[feat] = float64(numDocs+1) / float64(docsInClass+1)
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Use trained model to predict test vector's class. The following
|
// Use trained model to predict test vector's class. The following
|
||||||
@ -134,43 +146,43 @@ func (nb *BernoulliNBClassifier) Fit(X *base.Instances) {
|
|||||||
// IMPORTANT: PredictOne panics if Fit was not called or if the
|
// IMPORTANT: PredictOne panics if Fit was not called or if the
|
||||||
// document vector and train matrix have a different number of columns.
|
// document vector and train matrix have a different number of columns.
|
||||||
func (nb *BernoulliNBClassifier) PredictOne(vector []float64) string {
|
func (nb *BernoulliNBClassifier) PredictOne(vector []float64) string {
|
||||||
if nb.features == 0 {
|
if nb.features == 0 {
|
||||||
panic("Fit should be called before predicting")
|
panic("Fit should be called before predicting")
|
||||||
}
|
}
|
||||||
|
|
||||||
if len(vector) != nb.features {
|
if len(vector) != nb.features {
|
||||||
panic("Different dimensions in Train and Test sets")
|
panic("Different dimensions in Train and Test sets")
|
||||||
}
|
}
|
||||||
|
|
||||||
// Currently only the predicted class is returned.
|
// Currently only the predicted class is returned.
|
||||||
bestScore := -math.MaxFloat64
|
bestScore := -math.MaxFloat64
|
||||||
bestClass := ""
|
bestClass := ""
|
||||||
|
|
||||||
for class, classCount := range nb.classInstances {
|
for class, classCount := range nb.classInstances {
|
||||||
// Init classScore with log(prior)
|
// Init classScore with log(prior)
|
||||||
classScore := math.Log((float64(classCount))/float64(nb.trainingInstances))
|
classScore := math.Log((float64(classCount)) / float64(nb.trainingInstances))
|
||||||
for f := 0; f < nb.features; f++ {
|
for f := 0; f < nb.features; f++ {
|
||||||
if vector[f] > 0 {
|
if vector[f] > 0 {
|
||||||
// Test document has feature c
|
// Test document has feature c
|
||||||
classScore += math.Log(nb.condProb[class][f])
|
classScore += math.Log(nb.condProb[class][f])
|
||||||
} else {
|
} else {
|
||||||
if nb.condProb[class][f] == 1.0 {
|
if nb.condProb[class][f] == 1.0 {
|
||||||
// special case when prob = 1.0, consider laplace
|
// special case when prob = 1.0, consider laplace
|
||||||
// smooth
|
// smooth
|
||||||
classScore += math.Log(1.0 / float64(nb.classInstances[class] + 1))
|
classScore += math.Log(1.0 / float64(nb.classInstances[class]+1))
|
||||||
} else {
|
} else {
|
||||||
classScore += math.Log(1.0 - nb.condProb[class][f])
|
classScore += math.Log(1.0 - nb.condProb[class][f])
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if classScore > bestScore {
|
if classScore > bestScore {
|
||||||
bestScore = classScore
|
bestScore = classScore
|
||||||
bestClass = class
|
bestClass = class
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return bestClass
|
return bestClass
|
||||||
}
|
}
|
||||||
|
|
||||||
// Predict is just a wrapper for the PredictOne function.
|
// Predict is just a wrapper for the PredictOne function.
|
||||||
@ -178,9 +190,9 @@ func (nb *BernoulliNBClassifier) PredictOne(vector []float64) string {
|
|||||||
// IMPORTANT: Predict panics if Fit was not called or if the
|
// IMPORTANT: Predict panics if Fit was not called or if the
|
||||||
// document vector and train matrix have a different number of columns.
|
// document vector and train matrix have a different number of columns.
|
||||||
func (nb *BernoulliNBClassifier) Predict(what *base.Instances) *base.Instances {
|
func (nb *BernoulliNBClassifier) Predict(what *base.Instances) *base.Instances {
|
||||||
ret := what.GeneratePredictionVector()
|
ret := what.GeneratePredictionVector()
|
||||||
for i := 0; i < what.Rows; i++ {
|
for i := 0; i < what.Rows; i++ {
|
||||||
ret.SetAttrStr(i, 0, nb.PredictOne(what.GetRowVectorWithoutClass(i)))
|
ret.SetAttrStr(i, 0, nb.PredictOne(what.GetRowVectorWithoutClass(i)))
|
||||||
}
|
}
|
||||||
return ret
|
return ret
|
||||||
}
|
}
|
||||||
|
@ -203,7 +203,7 @@ func (d *DecisionTreeNode) Predict(what base.FixedDataGrid) base.FixedDataGrid {
|
|||||||
panic(err)
|
panic(err)
|
||||||
}
|
}
|
||||||
predAttrs := base.AttributeDifferenceReferences(what.AllAttributes(), predictions.AllClassAttributes())
|
predAttrs := base.AttributeDifferenceReferences(what.AllAttributes(), predictions.AllClassAttributes())
|
||||||
predAttrSpecs := base.ResolveAllAttributes(what, predAttrs)
|
predAttrSpecs := base.ResolveAttributes(what, predAttrs)
|
||||||
what.MapOverRows(predAttrSpecs, func(row [][]byte, rowNo int) (bool, error) {
|
what.MapOverRows(predAttrSpecs, func(row [][]byte, rowNo int) (bool, error) {
|
||||||
cur := d
|
cur := d
|
||||||
for {
|
for {
|
||||||
|
Loading…
x
Reference in New Issue
Block a user