1
0
mirror of https://github.com/sjwhitworth/golearn.git synced 2025-04-30 13:48:57 +08:00

Replace Panics with error returns to BernoulliNBClassifier Fit method to satisfy base.Classifier interface

This commit is contained in:
Justin Judd 2019-07-17 11:44:53 +09:00
parent c3cae572f4
commit bffc4a52e6

View File

@ -182,7 +182,7 @@ func NewBernoulliNBClassifier() *BernoulliNBClassifier {
// Fill data matrix with Bernoulli Naive Bayes model. All values // Fill data matrix with Bernoulli Naive Bayes model. All values
// necessary for calculating prior probability and p(f_i) // necessary for calculating prior probability and p(f_i)
func (nb *BernoulliNBClassifier) Fit(X base.FixedDataGrid) { func (nb *BernoulliNBClassifier) Fit(X base.FixedDataGrid) error {
// Check that all Attributes are binary // Check that all Attributes are binary
classAttrs := X.AllClassAttributes() classAttrs := X.AllClassAttributes()
@ -190,14 +190,14 @@ func (nb *BernoulliNBClassifier) Fit(X base.FixedDataGrid) {
featAttrs := base.AttributeDifference(allAttrs, classAttrs) featAttrs := base.AttributeDifference(allAttrs, classAttrs)
for i := range featAttrs { for i := range featAttrs {
if _, ok := featAttrs[i].(*base.BinaryAttribute); !ok { if _, ok := featAttrs[i].(*base.BinaryAttribute); !ok {
panic(fmt.Sprintf("%v: Should be BinaryAttribute", featAttrs[i])) return fmt.Errorf("%v: Should be BinaryAttribute", featAttrs[i])
} }
} }
featAttrSpecs := base.ResolveAttributes(X, featAttrs) featAttrSpecs := base.ResolveAttributes(X, featAttrs)
// Check that only one classAttribute is defined // Check that only one classAttribute is defined
if len(classAttrs) != 1 { if len(classAttrs) != 1 {
panic("Only one class Attribute can be used") return fmt.Errorf("Only one class Attribute can be used")
} }
// Number of features and instances in this training set // Number of features and instances in this training set
@ -258,6 +258,7 @@ func (nb *BernoulliNBClassifier) Fit(X base.FixedDataGrid) {
} }
nb.fitOn = base.NewStructuralCopy(X) nb.fitOn = base.NewStructuralCopy(X)
return nil
} }
// Use trained model to predict test vector's class. The following // Use trained model to predict test vector's class. The following