diff --git a/knn/knn.go b/knn/knn.go index 38c9e7e..c003a54 100644 --- a/knn/knn.go +++ b/knn/knn.go @@ -124,6 +124,12 @@ func (KNN *KNNClassifier) Predict(what base.FixedDataGrid) base.FixedDataGrid { } } + // If every Attribute is a FloatAttribute, then we remove the last one + // because that is the Attribute we are trying to predict. + if len(allNumericAttrs) == len(allAttrs) { + allNumericAttrs = allNumericAttrs[:len(allNumericAttrs)-1] + } + // Generate return vector ret := base.GeneratePredictionVector(what)