
base: Cleaned up duplicate Attribute resolution functions

Richard Townsend 2014-08-03 12:31:26 +01:00
parent ff97065261
commit 47341b2869
15 changed files with 151 additions and 155 deletions
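
In summary, four overlapping helpers in base collapse into two. The consolidated signatures (taken from the base hunk below) are:

	// Resolves a caller-chosen subset of Attributes
	// (replaces GetSomeAttributeSpecs and the old two-argument ResolveAllAttributes)
	func ResolveAttributes(d DataGrid, attrs []Attribute) []AttributeSpec

	// Resolves every Attribute in the DataGrid (replaces GetAllAttributeSpecs)
	func ResolveAllAttributes(d DataGrid) []AttributeSpec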

View File

@@ -129,7 +129,7 @@ func ParseCSVBuildInstances(filepath string, hasHeaders bool, u UpdatableDataGri
 	rowCounter := 0
-	specs := ResolveAllAttributes(u, u.AllAttributes())
+	specs := ResolveAttributes(u, u.AllAttributes())
 	for {
 		record, err := reader.Read()

View File

@@ -379,7 +379,7 @@ func (inst *DenseInstances) Size() (int, int) {
 // swapRows swaps over rows i and j
 func (inst *DenseInstances) swapRows(i, j int) {
-	as := GetAllAttributeSpecs(inst)
+	as := ResolveAllAttributes(inst)
 	for _, a := range as {
 		v1 := inst.Get(a, i)
 		v2 := inst.Get(a, j)
@@ -424,7 +424,7 @@ func (inst *DenseInstances) String() string {
 	var buffer bytes.Buffer
 	// Get all Attribute information
-	as := GetAllAttributeSpecs(inst)
+	as := ResolveAllAttributes(inst)
 	// Print header
 	cols, rows := inst.Size()

View File

@@ -153,7 +153,7 @@ func (l *LazilyFilteredInstances) MapOverRows(asv []AttributeSpec, mapFunc func(
 func (l *LazilyFilteredInstances) RowString(row int) string {
 	var buffer bytes.Buffer
-	as := GetAllAttributeSpecs(l) // Retrieve all Attribute data
+	as := ResolveAllAttributes(l) // Retrieve all Attribute data
 	first := true // Decide whether to prefix
 	for _, a := range as {
@@ -188,7 +188,7 @@ func (l *LazilyFilteredInstances) String() string {
 	}
 	// Get all Attribute information
-	as := GetAllAttributeSpecs(l)
+	as := ResolveAllAttributes(l)
 	// Print header
 	buffer.WriteString("Lazily filtered instances using ")

View File

@@ -17,8 +17,8 @@ func TestLazySortDesc(testEnv *testing.T) {
 		return
 	}
-	as1 := GetAllAttributeSpecs(inst1)
-	as2 := GetAllAttributeSpecs(inst2)
+	as1 := ResolveAllAttributes(inst1)
+	as2 := ResolveAllAttributes(inst2)
 	if isSortedDesc(inst1, as1[0]) {
 		testEnv.Error("Can't test descending sort order")
@@ -44,7 +44,7 @@ func TestLazySortDesc(testEnv *testing.T) {
 func TestLazySortAsc(testEnv *testing.T) {
 	inst, err := ParseCSVToInstances("../examples/datasets/iris_headers.csv", true)
-	as1 := GetAllAttributeSpecs(inst)
+	as1 := ResolveAllAttributes(inst)
 	if isSortedAsc(inst, as1[0]) {
 		testEnv.Error("Can't test ascending sort on something ascending already")
 	}
@@ -67,7 +67,7 @@ func TestLazySortAsc(testEnv *testing.T) {
 		testEnv.Error(err)
 		return
 	}
-	as2 := GetAllAttributeSpecs(inst2)
+	as2 := ResolveAllAttributes(inst2)
 	if !isSortedAsc(inst2, as2[0]) {
 		testEnv.Error("This file should be sorted in ascending order")
 	}

View File

@@ -44,8 +44,8 @@ func TestSortDesc(testEnv *testing.T) {
 		return
 	}
-	as1 := GetAllAttributeSpecs(inst1)
-	as2 := GetAllAttributeSpecs(inst2)
+	as1 := ResolveAllAttributes(inst1)
+	as2 := ResolveAllAttributes(inst2)
 	if isSortedDesc(inst1, as1[0]) {
 		testEnv.Error("Can't test descending sort order")
@@ -71,7 +71,7 @@ func TestSortDesc(testEnv *testing.T) {
 func TestSortAsc(testEnv *testing.T) {
 	inst, err := ParseCSVToInstances("../examples/datasets/iris_headers.csv", true)
-	as1 := GetAllAttributeSpecs(inst)
+	as1 := ResolveAllAttributes(inst)
 	if isSortedAsc(inst, as1[0]) {
 		testEnv.Error("Can't test ascending sort on something ascending already")
 	}
@@ -90,7 +90,7 @@ func TestSortAsc(testEnv *testing.T) {
 		testEnv.Error(err)
 		return
 	}
-	as2 := GetAllAttributeSpecs(inst2)
+	as2 := ResolveAllAttributes(inst2)
 	if !isSortedAsc(inst2, as2[0]) {
 		testEnv.Error("This file should be sorted in ascending order")
 	}

View File

@@ -38,9 +38,9 @@ func NonClassAttributes(d DataGrid) []Attribute {
 	return AttributeDifferenceReferences(allAttrs, classAttrs)
 }

-// ResolveAllAttributes returns AttributeSpecs describing
+// ResolveAttributes returns AttributeSpecs describing
 // all of the Attributes.
-func ResolveAllAttributes(d DataGrid, attrs []Attribute) []AttributeSpec {
+func ResolveAttributes(d DataGrid, attrs []Attribute) []AttributeSpec {
 	ret := make([]AttributeSpec, len(attrs))
 	for i, a := range attrs {
 		spec, err := d.GetAttribute(a)
@@ -52,25 +52,9 @@ func ResolveAttributes(d DataGrid, attrs []Attribute) []AttributeSpec {
 	return ret
 }

-// GetAllAttributeSpecs retrieves every Attribute specification
-// from a given DataGrid. Useful in conjunction with MapOverRows.
-func GetAllAttributeSpecs(from DataGrid) []AttributeSpec {
-	attrs := from.AllAttributes()
-	return GetSomeAttributeSpecs(from, attrs)
-}
-
-// GetSomeAttributeSpecs returns a subset of Attribute specifications
-// from a given DataGrid.
-func GetSomeAttributeSpecs(from DataGrid, attrs []Attribute) []AttributeSpec {
-	ret := make([]AttributeSpec, len(attrs))
-	for i, a := range attrs {
-		as, err := from.GetAttribute(a)
-		if err != nil {
-			panic(err)
-		}
-		ret[i] = as
-	}
-	return ret
-}
+// ResolveAllAttributes returns every AttributeSpec
+func ResolveAllAttributes(d DataGrid) []AttributeSpec {
+	return ResolveAttributes(d, d.AllAttributes())
+}

 func buildAttrSet(a []Attribute) map[Attribute]bool {
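
A minimal usage sketch of the consolidated pair (illustrative only; the CSV path comes from the sort tests in this commit):

	inst, err := ParseCSVToInstances("../examples/datasets/iris_headers.csv", true)
	if err != nil {
		panic(err)
	}
	all := ResolveAllAttributes(inst)                         // one spec per Attribute
	some := ResolveAttributes(inst, NonClassAttributes(inst)) // specs for a subset

Both helpers panic if any Attribute cannot be resolved, via the GetAttribute error check shown above.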

View File

@@ -144,7 +144,7 @@ func DecomposeOnAttributeValues(inst FixedDataGrid, at Attribute) map[string]Fix
 	rowMaps := make(map[string][]int)

 	// Build full Attribute set
-	fullAttrSpec := ResolveAllAttributes(inst, newAttrs)
+	fullAttrSpec := ResolveAttributes(inst, newAttrs)
 	fullAttrSpec = append(fullAttrSpec, attrSpec)

 	// Decompose

View File

@@ -78,7 +78,7 @@ func NewInstancesViewFromRows(src FixedDataGrid, rows map[int]int) *InstancesVie
 func NewInstancesViewFromVisible(src FixedDataGrid, rows []int, attrs []Attribute) *InstancesView {
 	ret := &InstancesView{
 		src,
-		GetSomeAttributeSpecs(src, attrs),
+		ResolveAttributes(src, attrs),
 		make(map[int]int),
 		make(map[Attribute]bool),
 		true,
@@ -99,7 +99,7 @@ func NewInstancesViewFromVisible(src FixedDataGrid, rows []int, attrs []Attribut
 func NewInstancesViewFromAttrs(src FixedDataGrid, attrs []Attribute) *InstancesView {
 	ret := &InstancesView{
 		src,
-		GetSomeAttributeSpecs(src, attrs),
+		ResolveAttributes(src, attrs),
 		nil,
 		make(map[Attribute]bool),
 		false,
@@ -252,7 +252,7 @@ func (v *InstancesView) String() string {
 	maxRows := 30
 	// Get all Attribute information
-	as := GetAllAttributeSpecs(v)
+	as := ResolveAllAttributes(v)
 	// Print header
 	cols, rows := v.Size()
@@ -305,7 +305,7 @@ func (v *InstancesView) String() string {
 // RowString returns a string representation of a given row.
 func (v *InstancesView) RowString(row int) string {
 	var buffer bytes.Buffer
-	as := GetAllAttributeSpecs(v)
+	as := ResolveAllAttributes(v)
 	first := true
 	for _, a := range as {
 		val := v.Get(a, row)

View File

@@ -46,7 +46,7 @@ func main() {
 	// for doing so is not very sophisticated.

 	// First, have to resolve Attribute Specifications
-	as := base.ResolveAllAttributes(rawData, rawData.AllAttributes())
+	as := base.ResolveAttributes(rawData, rawData.AllAttributes())

 	// Attribute Specifications describe where a given column lives
 	rawData.Set(as[0], 0, as[0].GetAttribute().GetSysValFromString("1.00"))

View File

@@ -112,7 +112,7 @@ func TestChiMerge2(testEnv *testing.T) {
 	// Sort the instances
 	allAttrs := inst.AllAttributes()
-	sortAttrSpecs := base.ResolveAllAttributes(inst, allAttrs)[0:1]
+	sortAttrSpecs := base.ResolveAttributes(inst, allAttrs)[0:1]
 	instSorted, err := base.Sort(inst, base.Ascending, sortAttrSpecs)
 	if err != nil {
 		panic(err)

View File

@@ -65,8 +65,8 @@ func (KNN *KNNClassifier) Predict(what base.FixedDataGrid) base.FixedDataGrid {
 	ret := base.GeneratePredictionVector(what)

 	// Resolve Attribute specifications for both
-	whatAttrSpecs := base.ResolveAllAttributes(what, allNumericAttrs)
-	trainAttrSpecs := base.ResolveAllAttributes(KNN.TrainingData, allNumericAttrs)
+	whatAttrSpecs := base.ResolveAttributes(what, allNumericAttrs)
+	trainAttrSpecs := base.ResolveAttributes(KNN.TrainingData, allNumericAttrs)

 	// Reserve storage for the most similar items
 	distances := make(map[int]float64)

View File

@@ -34,7 +34,7 @@ func convertInstancesToProblemVec(X base.FixedDataGrid) [][]float64 {
 	// Retrieve numeric non-class Attributes
 	numericAttrs := base.NonClassFloatAttributes(X)
-	numericAttrSpecs := base.ResolveAllAttributes(X, numericAttrs)
+	numericAttrSpecs := base.ResolveAttributes(X, numericAttrs)

 	// Convert each row
 	X.MapOverRows(numericAttrSpecs, func(row [][]byte, rowNo int) (bool, error) {
@@ -66,7 +66,7 @@ func convertInstancesToLabelVec(X base.FixedDataGrid) []float64 {
 	_, rows := X.Size()
 	labelVec := make([]float64, rows)
 	// Resolve class Attribute specification
-	classAttrSpecs := base.ResolveAllAttributes(X, classAttrs)
+	classAttrSpecs := base.ResolveAttributes(X, classAttrs)
 	X.MapOverRows(classAttrSpecs, func(row [][]byte, rowNo int) (bool, error) {
 		labelVec[rowNo] = base.UnpackBytesToFloat(row[0])
 		return true, nil
@@ -90,10 +90,10 @@ func (lr *LogisticRegression) Predict(X base.FixedDataGrid) base.FixedDataGrid {
 	}

 	// Generate return structure
 	ret := base.GeneratePredictionVector(X)
-	classAttrSpecs := base.ResolveAllAttributes(ret, classAttrs)
+	classAttrSpecs := base.ResolveAttributes(ret, classAttrs)
 	// Retrieve numeric non-class Attributes
 	numericAttrs := base.NonClassFloatAttributes(X)
-	numericAttrSpecs := base.ResolveAllAttributes(X, numericAttrs)
+	numericAttrSpecs := base.ResolveAttributes(X, numericAttrs)
 	// Allocate row storage
 	row := make([]float64, len(numericAttrSpecs))

View File

@@ -112,7 +112,7 @@ func (b *BaggedModel) Predict(from base.FixedDataGrid) base.FixedDataGrid {
 	for { // Need to resolve the voting problem
 		incoming, ok := <-votes
 		if ok {
-			cSpecs := base.ResolveAllAttributes(incoming, incoming.AllClassAttributes())
+			cSpecs := base.ResolveAttributes(incoming, incoming.AllClassAttributes())
 			incoming.MapOverRows(cSpecs, func(row [][]byte, predRow int) (bool, error) {
 				// Check if we've seen this class before...
 				if _, ok := voting[predRow]; !ok {
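
Each of these call sites follows the same resolve-then-iterate idiom. A minimal sketch of the pattern (hypothetical countRows helper; it uses only the API visible in this diff):

	func countRows(grid base.FixedDataGrid) int {
		n := 0
		specs := base.ResolveAllAttributes(grid)
		grid.MapOverRows(specs, func(row [][]byte, rowNo int) (bool, error) {
			n++
			return true, nil // returning true continues the iteration
		})
		return n
	}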

View File

@@ -1,8 +1,9 @@
 package naive

 import (
-	"math"
+	"fmt" // assumed addition: Fit's panic below uses fmt.Sprintf
 	base "github.com/sjwhitworth/golearn/base"
+	"math"
 )

 // A Bernoulli Naive Bayes Classifier. Naive Bayes classifiers assume
@@ -37,91 +37,103 @@ import (
 // Information Retrieval. Cambridge University Press, pp. 234-265.
 // http://nlp.stanford.edu/IR-book/html/htmledition/the-bernoulli-model-1.html
 type BernoulliNBClassifier struct {
 	base.BaseEstimator
 	// Conditional probability for each term. This vector should be
 	// accessed in the following way: p(f|c) = condProb[c][f].
 	// Logarithm is used in order to avoid underflow.
 	condProb map[string][]float64
 	// Number of instances in each class. This is necessary in order to
 	// calculate the laplace smooth value during the Predict step.
 	classInstances map[string]int
 	// Number of instances used in training.
 	trainingInstances int
 	// Number of features in the training set
 	features int
 }

 // Create a new Bernoulli Naive Bayes Classifier. The argument 'classes'
 // is the number of possible labels in the classification task.
 func NewBernoulliNBClassifier() *BernoulliNBClassifier {
 	nb := BernoulliNBClassifier{}
 	nb.condProb = make(map[string][]float64)
 	nb.features = 0
 	nb.trainingInstances = 0
 	return &nb
 }

 // Fill data matrix with Bernoulli Naive Bayes model. All values
 // necessary for calculating prior probability and p(f_i)
-func (nb *BernoulliNBClassifier) Fit(X *base.Instances) {
-	// Number of features and instances in this training set
-	nb.trainingInstances = X.Rows
-	nb.features = 0
-	if X.Rows > 0 {
-		nb.features = len(X.GetRowVectorWithoutClass(0))
-	}
-
-	// Number of instances in class
-	nb.classInstances = make(map[string]int)
-
-	// Number of documents with given term (by class)
-	docsContainingTerm := make(map[string][]int)
-
-	// This algorithm could be vectorized after binarizing the data
-	// matrix. Since mat64 doesn't have this function, a iterative
-	// version is used.
-	for r := 0; r < X.Rows; r++ {
-		class := X.GetClass(r)
-		docVector := X.GetRowVectorWithoutClass(r)
-
-		// increment number of instances in class
-		t, ok := nb.classInstances[class]
-		if !ok { t = 0 }
-		nb.classInstances[class] = t + 1
-
-		for feat := 0; feat < len(docVector); feat++ {
-			v := docVector[feat]
-			// In Bernoulli Naive Bayes the presence and absence of
-			// features are considered. All non-zero values are
-			// treated as presence.
-			if v > 0 {
-				// Update number of times this feature appeared within
-				// given label.
-				t, ok := docsContainingTerm[class]
-				if !ok {
-					t = make([]int, nb.features)
-					docsContainingTerm[class] = t
-				}
-				t[feat] += 1
-			}
-		}
-	}
+func (nb *BernoulliNBClassifier) Fit(X base.FixedDataGrid) {
+	// Check that all feature Attributes are binary
+	classAttrs := X.AllClassAttributes()
+	allAttrs := X.AllAttributes()
+	featAttrs := base.AttributeDifferenceReferences(allAttrs, classAttrs)
+	for i := range featAttrs {
+		if _, ok := featAttrs[i].(*base.BinaryAttribute); !ok {
+			panic(fmt.Sprintf("%v: Should be BinaryAttribute", featAttrs[i]))
+		}
+	}
+	featAttrSpecs := base.ResolveAttributes(X, featAttrs)
+
+	// Check that only one class Attribute is defined
+	if len(classAttrs) != 1 {
+		panic("Only one class Attribute can be used")
+	}
+
+	// Number of features and instances in this training set
+	nb.features, nb.trainingInstances = X.Size()
+
+	// Number of instances in class
+	nb.classInstances = make(map[string]int)
+
+	// Number of documents with given term (by class)
+	docsContainingTerm := make(map[string][]int)
+
+	// This algorithm could be vectorized after binarizing the data
+	// matrix. Since mat64 doesn't have this function, an iterative
+	// version is used.
+	X.MapOverRows(featAttrSpecs, func(docVector [][]byte, r int) (bool, error) {
+		class := base.GetClass(X, r)
+
+		// increment number of instances in class
+		t, ok := nb.classInstances[class]
+		if !ok {
+			t = 0
+		}
+		nb.classInstances[class] = t + 1
+
+		for feat := 0; feat < len(docVector); feat++ {
+			v := docVector[feat]
+			// In Bernoulli Naive Bayes the presence and absence of
+			// features are considered. All non-zero values are
+			// treated as presence.
+			if v[0] > 0 {
+				// Update number of times this feature appeared within
+				// given label.
+				t, ok := docsContainingTerm[class]
+				if !ok {
+					t = make([]int, nb.features)
+					docsContainingTerm[class] = t
+				}
+				t[feat] += 1
+			}
+		}
+		return true, nil
+	})

 	// Pre-calculate conditional probabilities for each class
 	for c, _ := range nb.classInstances {
 		nb.condProb[c] = make([]float64, nb.features)
 		for feat := 0; feat < nb.features; feat++ {
 			classTerms, _ := docsContainingTerm[c]
 			numDocs := classTerms[feat]
 			docsInClass, _ := nb.classInstances[c]

 			classCondProb, _ := nb.condProb[c]
 			// Calculate conditional probability with laplace smoothing
 			classCondProb[feat] = float64(numDocs+1) / float64(docsInClass+1)
 		}
 	}
 }

 // Use trained model to predict test vector's class. The following
@@ -134,43 +146,43 @@ func (nb *BernoulliNBClassifier) Fit(X base.FixedDataGrid) {
 // IMPORTANT: PredictOne panics if Fit was not called or if the
 // document vector and train matrix have a different number of columns.
 func (nb *BernoulliNBClassifier) PredictOne(vector []float64) string {
 	if nb.features == 0 {
 		panic("Fit should be called before predicting")
 	}

 	if len(vector) != nb.features {
 		panic("Different dimensions in Train and Test sets")
 	}

 	// Currently only the predicted class is returned.
 	bestScore := -math.MaxFloat64
 	bestClass := ""

 	for class, classCount := range nb.classInstances {
 		// Init classScore with log(prior)
 		classScore := math.Log((float64(classCount)) / float64(nb.trainingInstances))
 		for f := 0; f < nb.features; f++ {
 			if vector[f] > 0 {
 				// Test document has feature c
 				classScore += math.Log(nb.condProb[class][f])
 			} else {
 				if nb.condProb[class][f] == 1.0 {
 					// special case when prob = 1.0, consider laplace
 					// smooth
 					classScore += math.Log(1.0 / float64(nb.classInstances[class]+1))
 				} else {
 					classScore += math.Log(1.0 - nb.condProb[class][f])
 				}
 			}
 		}
 		if classScore > bestScore {
 			bestScore = classScore
 			bestClass = class
 		}
 	}

 	return bestClass
 }
 // Predict is just a wrapper for the PredictOne function.
@@ -178,9 +190,9 @@ func (nb *BernoulliNBClassifier) PredictOne(vector []float64) string {
 // IMPORTANT: Predict panics if Fit was not called or if the
 // document vector and train matrix have a different number of columns.
 func (nb *BernoulliNBClassifier) Predict(what *base.Instances) *base.Instances {
 	ret := what.GeneratePredictionVector()
 	for i := 0; i < what.Rows; i++ {
 		ret.SetAttrStr(i, 0, nb.PredictOne(what.GetRowVectorWithoutClass(i)))
 	}
 	return ret
 }
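
For reference, the score PredictOne maximises is the Bernoulli Naive Bayes log-likelihood that Fit pre-computes, with Laplace smoothing:

	p(f|c)   = (docsContainingTerm[c][f] + 1) / (classInstances[c] + 1)
	score(c) = log(classInstances[c] / trainingInstances)
	           + sum over features f of: log(p(f|c))     if vector[f] > 0
	                                     log(1 - p(f|c)) otherwise

Working in log space avoids the underflow the condProb comment mentions; the class with the highest score is returned.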

View File

@@ -203,7 +203,7 @@ func (d *DecisionTreeNode) Predict(what base.FixedDataGrid) base.FixedDataGrid {
 		panic(err)
 	}
 	predAttrs := base.AttributeDifferenceReferences(what.AllAttributes(), predictions.AllClassAttributes())
-	predAttrSpecs := base.ResolveAllAttributes(what, predAttrs)
+	predAttrSpecs := base.ResolveAttributes(what, predAttrs)
 	what.MapOverRows(predAttrSpecs, func(row [][]byte, rowNo int) (bool, error) {
 		cur := d
 		for {