1
0
mirror of https://github.com/sjwhitworth/golearn.git synced 2025-04-26 13:49:14 +08:00

support PredictProba

This commit is contained in:
meirwahnon 2017-07-17 14:48:38 +03:00
parent 51d7b7d262
commit f56fce1a43

View File

@ -277,6 +277,89 @@ func (d *DecisionTreeNode) Predict(what base.FixedDataGrid) (base.FixedDataGrid,
return predictions, nil
}
type ClassProba struct {
probability float64
classValue string
}
type ClassesProba []ClassProba
func (o ClassesProba) Len() int {
return len(o)
}
func (o ClassesProba) Swap(i, j int) {
o[i], o[j] = o[j], o[i]
}
func (o ClassesProba) Less(i, j int) bool {
return o[i].probability < o[j].probability
}
// Predict class probabilities of the input samples what, returns a sorted array (by probability) of classes, and another array representing it's probabilities
func (t *ID3DecisionTree) PredictProba(what base.FixedDataGrid) (ClassesProba, error) {
d := t.Root
predictions := base.GeneratePredictionVector(what)
predAttrs := base.AttributeDifferenceReferences(what.AllAttributes(), predictions.AllClassAttributes())
predAttrSpecs := base.ResolveAttributes(what, predAttrs)
var results ClassesProba
what.MapOverRows(predAttrSpecs, func(row [][]byte, rowNo int) (bool, error) {
cur := d
for {
if cur.Children == nil {
totalDist := 0
for _,dist:= range cur.ClassDist {
totalDist += dist
}
for class,dist:= range cur.ClassDist {
classProba := ClassProba{classValue:class, probability: float64(dist/totalDist)}
results = append(results,classProba)
}
sort.Sort(results)
break
} else {
splitVal := cur.SplitRule.SplitVal
at := cur.SplitRule.SplitAttr
ats, err := what.GetAttribute(at)
if err != nil {
//predictions.Set(classAttrSpec, rowNo, classAttr.GetSysValFromString(cur.Class))
//break
panic(err)
}
var classVar string
if _, ok := ats.GetAttribute().(*base.FloatAttribute); ok {
// If it's a numeric Attribute (e.g. FloatAttribute) check that
// the value of the current node is greater than the old one
classVal := base.UnpackBytesToFloat(what.Get(ats, rowNo))
if classVal > splitVal {
classVar = "1"
} else {
classVar = "0"
}
} else {
classVar = ats.GetAttribute().GetStringFromSysVal(what.Get(ats, rowNo))
}
if next, ok := cur.Children[classVar]; ok {
cur = next
} else {
// Suspicious of this
var bestChild string
for c := range cur.Children {
bestChild = c
if c > classVar {
break
}
}
cur = cur.Children[bestChild]
}
}
}
return true, nil
})
return results, nil
}
//
// ID3 Tree type
//