mirror of
https://github.com/sjwhitworth/golearn.git
synced 2025-04-30 13:48:57 +08:00
Replace Panics with error returns to BernoulliNBClassifier Fit method to satisfy base.Classifier interface
This commit is contained in:
parent
c3cae572f4
commit
bffc4a52e6
@ -182,7 +182,7 @@ func NewBernoulliNBClassifier() *BernoulliNBClassifier {
|
|||||||
|
|
||||||
// Fill data matrix with Bernoulli Naive Bayes model. All values
|
// Fill data matrix with Bernoulli Naive Bayes model. All values
|
||||||
// necessary for calculating prior probability and p(f_i)
|
// necessary for calculating prior probability and p(f_i)
|
||||||
func (nb *BernoulliNBClassifier) Fit(X base.FixedDataGrid) {
|
func (nb *BernoulliNBClassifier) Fit(X base.FixedDataGrid) error {
|
||||||
|
|
||||||
// Check that all Attributes are binary
|
// Check that all Attributes are binary
|
||||||
classAttrs := X.AllClassAttributes()
|
classAttrs := X.AllClassAttributes()
|
||||||
@ -190,14 +190,14 @@ func (nb *BernoulliNBClassifier) Fit(X base.FixedDataGrid) {
|
|||||||
featAttrs := base.AttributeDifference(allAttrs, classAttrs)
|
featAttrs := base.AttributeDifference(allAttrs, classAttrs)
|
||||||
for i := range featAttrs {
|
for i := range featAttrs {
|
||||||
if _, ok := featAttrs[i].(*base.BinaryAttribute); !ok {
|
if _, ok := featAttrs[i].(*base.BinaryAttribute); !ok {
|
||||||
panic(fmt.Sprintf("%v: Should be BinaryAttribute", featAttrs[i]))
|
return fmt.Errorf("%v: Should be BinaryAttribute", featAttrs[i])
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
featAttrSpecs := base.ResolveAttributes(X, featAttrs)
|
featAttrSpecs := base.ResolveAttributes(X, featAttrs)
|
||||||
|
|
||||||
// Check that only one classAttribute is defined
|
// Check that only one classAttribute is defined
|
||||||
if len(classAttrs) != 1 {
|
if len(classAttrs) != 1 {
|
||||||
panic("Only one class Attribute can be used")
|
return fmt.Errorf("Only one class Attribute can be used")
|
||||||
}
|
}
|
||||||
|
|
||||||
// Number of features and instances in this training set
|
// Number of features and instances in this training set
|
||||||
@ -258,6 +258,7 @@ func (nb *BernoulliNBClassifier) Fit(X base.FixedDataGrid) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
nb.fitOn = base.NewStructuralCopy(X)
|
nb.fitOn = base.NewStructuralCopy(X)
|
||||||
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// Use trained model to predict test vector's class. The following
|
// Use trained model to predict test vector's class. The following
|
||||||
|
Loading…
x
Reference in New Issue
Block a user