mirror of
https://github.com/sjwhitworth/golearn.git
synced 2025-04-26 13:49:14 +08:00
support PredictProba
This commit is contained in:
parent
51d7b7d262
commit
f56fce1a43
83
trees/id3.go
83
trees/id3.go
@ -277,6 +277,89 @@ func (d *DecisionTreeNode) Predict(what base.FixedDataGrid) (base.FixedDataGrid,
|
||||
return predictions, nil
|
||||
}
|
||||
|
||||
type ClassProba struct {
|
||||
probability float64
|
||||
classValue string
|
||||
}
|
||||
|
||||
type ClassesProba []ClassProba
|
||||
|
||||
func (o ClassesProba) Len() int {
|
||||
return len(o)
|
||||
}
|
||||
func (o ClassesProba) Swap(i, j int) {
|
||||
o[i], o[j] = o[j], o[i]
|
||||
}
|
||||
func (o ClassesProba) Less(i, j int) bool {
|
||||
return o[i].probability < o[j].probability
|
||||
}
|
||||
|
||||
// Predict class probabilities of the input samples what, returns a sorted array (by probability) of classes, and another array representing it's probabilities
|
||||
func (t *ID3DecisionTree) PredictProba(what base.FixedDataGrid) (ClassesProba, error) {
|
||||
d := t.Root
|
||||
predictions := base.GeneratePredictionVector(what)
|
||||
predAttrs := base.AttributeDifferenceReferences(what.AllAttributes(), predictions.AllClassAttributes())
|
||||
predAttrSpecs := base.ResolveAttributes(what, predAttrs)
|
||||
|
||||
var results ClassesProba
|
||||
what.MapOverRows(predAttrSpecs, func(row [][]byte, rowNo int) (bool, error) {
|
||||
cur := d
|
||||
for {
|
||||
if cur.Children == nil {
|
||||
totalDist := 0
|
||||
for _,dist:= range cur.ClassDist {
|
||||
totalDist += dist
|
||||
}
|
||||
for class,dist:= range cur.ClassDist {
|
||||
classProba := ClassProba{classValue:class, probability: float64(dist/totalDist)}
|
||||
results = append(results,classProba)
|
||||
}
|
||||
sort.Sort(results)
|
||||
break
|
||||
} else {
|
||||
splitVal := cur.SplitRule.SplitVal
|
||||
at := cur.SplitRule.SplitAttr
|
||||
ats, err := what.GetAttribute(at)
|
||||
if err != nil {
|
||||
//predictions.Set(classAttrSpec, rowNo, classAttr.GetSysValFromString(cur.Class))
|
||||
//break
|
||||
panic(err)
|
||||
}
|
||||
|
||||
var classVar string
|
||||
if _, ok := ats.GetAttribute().(*base.FloatAttribute); ok {
|
||||
// If it's a numeric Attribute (e.g. FloatAttribute) check that
|
||||
// the value of the current node is greater than the old one
|
||||
classVal := base.UnpackBytesToFloat(what.Get(ats, rowNo))
|
||||
if classVal > splitVal {
|
||||
classVar = "1"
|
||||
} else {
|
||||
classVar = "0"
|
||||
}
|
||||
} else {
|
||||
classVar = ats.GetAttribute().GetStringFromSysVal(what.Get(ats, rowNo))
|
||||
}
|
||||
if next, ok := cur.Children[classVar]; ok {
|
||||
cur = next
|
||||
} else {
|
||||
// Suspicious of this
|
||||
var bestChild string
|
||||
for c := range cur.Children {
|
||||
bestChild = c
|
||||
if c > classVar {
|
||||
break
|
||||
}
|
||||
}
|
||||
cur = cur.Children[bestChild]
|
||||
}
|
||||
}
|
||||
}
|
||||
return true, nil
|
||||
})
|
||||
return results, nil
|
||||
}
|
||||
|
||||
|
||||
//
|
||||
// ID3 Tree type
|
||||
//
|
||||
|
Loading…
x
Reference in New Issue
Block a user