1
0
mirror of https://github.com/sjwhitworth/golearn.git synced 2025-04-26 13:49:14 +08:00
golearn/base/util_instances.go
Stephen Whitworth 7ea42ac80b Merge pull request #101 from Sentimentron/arff-staging
ARFF import/export, CSV export, lossless serialisation
2014-11-21 13:53:43 +00:00

535 lines
14 KiB
Go

package base
import (
"fmt"
"math/rand"
)
// This file contains utility functions relating to efficiently
// generating predictions and instantiating DataGrid implementations.
// GeneratePredictionVector builds an empty, writable grid that can hold
// predictions for the rows of from. Every class Attribute reported by
// from is registered on the result (both as an Attribute and as a class
// Attribute), and the result is then extended by the row count of from.
func GeneratePredictionVector(from FixedDataGrid) UpdatableDataGrid {
	targets := from.AllClassAttributes()
	_, numRows := from.Size()

	predictions := NewDenseInstances()
	for i := range targets {
		predictions.AddAttribute(targets[i])
		predictions.AddClassAttribute(targets[i])
	}
	predictions.Extend(numRows)
	return predictions
}
// CopyDenseInstances returns a new DenseInstances whose AttributeGroup
// layout mirrors that of template and which contains templateAttrs.
//
// Every AttributeGroup of template is re-created under the same name
// (passing 0 as the second argument for a *BinaryAttributeGroup and 8
// for anything else). Each Attribute in templateAttrs is then added to
// the group it occupies in template. This function itself does not copy
// any row data.
//
// IMPORTANT: panics if an Attribute in templateAttrs can't be resolved
// against template, if template has no group recorded for it, or if it
// can't be added to the new DenseInstances.
func CopyDenseInstances(template *DenseInstances, templateAttrs []Attribute) *DenseInstances {
	instances := NewDenseInstances()

	// Re-create the AttributeGroups under their original names.
	for name, group := range template.AllAttributeGroups() {
		size := 8
		if _, isBinary := group.(*BinaryAttributeGroup); isBinary {
			size = 0
		}
		instances.CreateAttributeGroup(name, size)
	}

	// Place each requested Attribute into the group it came from.
	for _, a := range templateAttrs {
		spec, err := template.GetAttribute(a)
		if err != nil {
			panic(err)
		}
		group, ok := template.agRevMap[spec.pond]
		if !ok {
			// The looked-up value is the zero value here, so panicking
			// with it would give an empty message; say what went wrong.
			panic(fmt.Sprintf("No AttributeGroup recorded for Attribute %s", a))
		}
		if _, err := instances.AddAttributeToAttributeGroup(a, group); err != nil {
			panic(err)
		}
	}
	return instances
}
// GetClass looks up the single class Attribute of from and returns the
// string form of its value on the given row.
//
// IMPORTANT: GetClass panics unless exactly one class Attribute is
// defined, if that Attribute can't be resolved, or if the stored value
// is nil.
func GetClass(from DataGrid, row int) string {
	classAttrs := from.AllClassAttributes()
	switch n := len(classAttrs); {
	case n > 1:
		panic("More than one class defined")
	case n == 0:
		panic("No class defined!")
	}
	attr := classAttrs[0]

	// Resolve the Attribute, then read and decode the stored value.
	spec, err := from.GetAttribute(attr)
	if err != nil {
		panic(fmt.Errorf("Can't resolve class Attribute %s", err))
	}
	raw := from.Get(spec, row)
	if raw == nil {
		panic("Class values shouldn't be missing")
	}
	return attr.GetStringFromSysVal(raw)
}
// SetClass converts class to the system representation of the single
// class Attribute of at and stores it on the given row.
//
// IMPORTANT: SetClass panics unless exactly one class Attribute is
// defined, or if that Attribute can't be resolved.
func SetClass(at UpdatableDataGrid, row int, class string) {
	classAttrs := at.AllClassAttributes()
	switch {
	case len(classAttrs) > 1:
		panic("More than one class defined")
	case len(classAttrs) == 0:
		panic("No class Attributes are defined")
	}
	target := classAttrs[0]

	// Resolve the Attribute, then encode and write the value.
	spec, err := at.GetAttribute(target)
	if err != nil {
		panic(fmt.Errorf("Can't resolve class Attribute %s", err))
	}
	at.Set(spec, row, target.GetSysValFromString(class))
}
// GetAttributeByName scans the Attributes of inst in the order returned
// by AllAttributes and returns the first one whose name equals name.
// It returns nil when no Attribute matches.
func GetAttributeByName(inst FixedDataGrid, name string) Attribute {
	candidates := inst.AllAttributes()
	for i := 0; i < len(candidates); i++ {
		if candidates[i].GetName() == name {
			return candidates[i]
		}
	}
	return nil
}
// GetClassDistributionByBinaryFloatValue returns a two-element slice
// counting the rows by their (float) class value: index 1 counts the
// rows whose value is greater than 0.5, index 0 counts all other rows.
//
// IMPORTANT: panics unless inst has exactly one class Attribute and
// that Attribute is a *FloatAttribute.
func GetClassDistributionByBinaryFloatValue(inst FixedDataGrid) []int {
	// Get the class variable
	attrs := inst.AllClassAttributes()
	if len(attrs) != 1 {
		panic(fmt.Errorf("Wrong number of class variables (has %d, should be 1)", len(attrs)))
	}
	if _, ok := attrs[0].(*FloatAttribute); !ok {
		panic(fmt.Errorf("Class Attribute must be FloatAttribute (is %s)", attrs[0]))
	}

	// One bucket for "close to 0.0", one for "close to 1.0"
	ret := make([]int, 2)

	// Map through everything
	specs := ResolveAttributes(inst, attrs)
	inst.MapOverRows(specs, func(vals [][]byte, row int) (bool, error) {
		if UnpackBytesToFloat(vals[0]) > 0.5 {
			ret[1]++
		} else {
			ret[0]++
		}
		// Return true so that every row is visited, matching the other
		// MapOverRows callbacks in this file that need a full scan.
		// NOTE(review): assumes a false return stops the iteration;
		// confirm against the MapOverRows contract.
		return true, nil
	})
	return ret
}
// GetClassDistributionByCategoricalValue returns a slice containing
// the number of rows observed for each class value, indexed by the
// class value's system integer representation. The slice has one
// entry per value reported by the class Attribute's GetValues.
//
// IMPORTANT: panics unless inst has exactly one class Attribute and
// that Attribute is a *CategoricalAttribute.
func GetClassDistributionByCategoricalValue(inst FixedDataGrid) []int {
	// Get the class variable
	attrs := inst.AllClassAttributes()
	if len(attrs) != 1 {
		panic(fmt.Errorf("Wrong number of class variables (has %d, should be 1)", len(attrs)))
	}
	classAttr, ok := attrs[0].(*CategoricalAttribute)
	if !ok {
		panic(fmt.Errorf("Class Attribute must be a CategoricalAttribute (is %s)", attrs[0]))
	}

	// One counter per class value
	ret := make([]int, len(classAttr.GetValues()))

	// Map through everything
	specs := ResolveAttributes(inst, attrs)
	inst.MapOverRows(specs, func(vals [][]byte, row int) (bool, error) {
		index := UnpackBytesToU64(vals[0])
		ret[int(index)]++
		// Return true so that every row is visited, matching the other
		// MapOverRows callbacks in this file that need a full scan.
		// NOTE(review): assumes a false return stops the iteration;
		// confirm against the MapOverRows contract.
		return true, nil
	})
	return ret
}
// GetClassDistribution tallies how many rows of inst carry each class,
// keyed by the string form of the class as returned by GetClass.
// It inherits the panics of GetClass.
func GetClassDistribution(inst FixedDataGrid) map[string]int {
	counts := map[string]int{}
	_, total := inst.Size()
	row := 0
	for row < total {
		counts[GetClass(inst, row)]++
		row++
	}
	return counts
}
// GetClassDistributionAfterThreshold returns the class distribution
// after a speculative split on a given Attribute using a threshold.
//
// The outer map is keyed by "1" for rows whose value of at is greater
// than val and by "0" for all other rows; a key is only present if at
// least one row falls on that side. Each inner map counts the rows on
// that side per class string (as returned by GetClass).
//
// IMPORTANT: panics if at can't be resolved against inst, if at is not
// a *FloatAttribute, or if GetClass panics for any row.
func GetClassDistributionAfterThreshold(inst FixedDataGrid, at Attribute, val float64) map[string]map[string]int {
	ret := make(map[string]map[string]int)

	// Find the attribute we're decomposing on
	attrSpec, err := inst.GetAttribute(at)
	if err != nil {
		panic(fmt.Sprintf("Invalid attribute %s (%s)", at, err))
	}

	// Validate
	if _, ok := at.(*FloatAttribute); !ok {
		panic("Must be numeric!")
	}

	_, rows := inst.Size()
	for i := 0; i < rows; i++ {
		splitVar := "0"
		if UnpackBytesToFloat(inst.Get(attrSpec, i)) > val {
			splitVar = "1"
		}
		classVar := GetClass(inst, i)

		// Create the per-side counter on first use, then count the row
		// in the same pass (no need to revisit the row).
		counts, ok := ret[splitVar]
		if !ok {
			counts = make(map[string]int)
			ret[splitVar] = counts
		}
		counts[classVar]++
	}
	return ret
}
// GetClassDistributionAfterSplit returns the class distribution
// after a speculative split on a given Attribute.
//
// The outer map is keyed by the string form of each value of at that
// occurs in inst (via at.GetStringFromSysVal). Each inner map counts
// the rows having that value per class string (as returned by GetClass).
//
// IMPORTANT: panics if at can't be resolved against inst, or if
// GetClass panics for any row.
func GetClassDistributionAfterSplit(inst FixedDataGrid, at Attribute) map[string]map[string]int {
	ret := make(map[string]map[string]int)

	// Find the attribute we're decomposing on
	attrSpec, err := inst.GetAttribute(at)
	if err != nil {
		panic(fmt.Sprintf("Invalid attribute %s (%s)", at, err))
	}

	_, rows := inst.Size()
	for i := 0; i < rows; i++ {
		splitVar := at.GetStringFromSysVal(inst.Get(attrSpec, i))
		classVar := GetClass(inst, i)

		// Create the per-value counter on first use, then count the row
		// in the same pass (no need to revisit the row).
		counts, ok := ret[splitVar]
		if !ok {
			counts = make(map[string]int)
			ret[splitVar] = counts
		}
		counts[classVar]++
	}
	return ret
}
// DecomposeOnNumericAttributeThreshold partitions the rows of inst by
// comparing their value of the numeric Attribute at against val. Rows
// with a value greater than val are collected under the key "1", all
// other rows under "0"; a key is only present if at least one row
// landed on that side. Each entry is a view over inst restricted to
// those rows and to every Attribute except at.
//
// IMPORTANT: calls panic() if at is not a *FloatAttribute, or if the
// AttributeSpec of at cannot be determined.
func DecomposeOnNumericAttributeThreshold(inst FixedDataGrid, at Attribute, val float64) map[string]FixedDataGrid {
	if _, isFloat := at.(*FloatAttribute); !isFloat {
		panic("Invalid argument")
	}
	splitSpec, err := inst.GetAttribute(at)
	if err != nil {
		panic(fmt.Sprintf("Invalid Attribute index %s", at))
	}

	// Every Attribute other than the split Attribute survives into the children.
	remaining := []Attribute{}
	for _, candidate := range inst.AllAttributes() {
		if !candidate.Equals(at) {
			remaining = append(remaining, candidate)
		}
	}

	// The split Attribute goes last so the callback can find it at the
	// end of each row.
	specs := append(ResolveAttributes(inst, remaining), splitSpec)

	// Collect the row numbers falling on each side of the threshold.
	rowsBySide := make(map[string][]int)
	inst.MapOverRows(specs, func(row [][]byte, rowNo int) (bool, error) {
		side := "0"
		if UnpackBytesToFloat(row[len(row)-1]) > val {
			side = "1"
		}
		rowsBySide[side] = append(rowsBySide[side], rowNo)
		return true, nil
	})

	children := make(map[string]FixedDataGrid)
	for side, rows := range rowsBySide {
		children[side] = NewInstancesViewFromVisible(inst, rows, remaining)
	}
	return children
}
// DecomposeOnAttributeValues partitions the rows of inst by their value
// of the Attribute at. The returned map is keyed by the string form of
// each value that occurs; each entry is a view over inst restricted to
// the rows with that value and to every Attribute except at.
//
// IMPORTANT: calls panic() if the AttributeSpec of at cannot be determined.
func DecomposeOnAttributeValues(inst FixedDataGrid, at Attribute) map[string]FixedDataGrid {
	splitSpec, err := inst.GetAttribute(at)
	if err != nil {
		panic(fmt.Sprintf("Invalid Attribute index %s", at))
	}

	// Every Attribute other than the split Attribute survives into the children.
	remaining := []Attribute{}
	for _, candidate := range inst.AllAttributes() {
		if !candidate.Equals(at) {
			remaining = append(remaining, candidate)
		}
	}

	// The split Attribute goes last so the callback can find it at the
	// end of each row.
	specs := append(ResolveAttributes(inst, remaining), splitSpec)
	splitAttr := specs[len(specs)-1].attr

	// Collect the row numbers observed for each value.
	rowsByValue := make(map[string][]int)
	inst.MapOverRows(specs, func(row [][]byte, rowNo int) (bool, error) {
		key := splitAttr.GetStringFromSysVal(row[len(row)-1])
		rowsByValue[key] = append(rowsByValue[key], rowNo)
		return true, nil
	})

	children := make(map[string]FixedDataGrid)
	for key, rows := range rowsByValue {
		children[key] = NewInstancesViewFromVisible(inst, rows, remaining)
	}
	return children
}
// InstancesTrainTestSplit shuffles src (via Shuffle) and then randomly
// assigns every row to one of two views, which are returned as
// (training, testing). For each row a number is drawn with
// rand.Intn(101); rows whose draw exceeds int(100*prop) go to the first
// (training) view and all remaining rows to the second (testing) view.
// In other words, prop approximates the share of rows that end up in
// the second return value.
//
// IMPORTANT: this function is only meaningful when prop is between 0.0
// and 1.0. Using any other values may result in odd behaviour.
//
// NOTE(review): because the draw covers 0..100 inclusive, prop == 0
// still sends a row to the second view whenever the draw is 0.
func InstancesTrainTestSplit(src FixedDataGrid, prop float64) (FixedDataGrid, FixedDataGrid) {
	src = Shuffle(src)
	_, rows := src.Size()

	cutoff := int(100 * prop)
	trainingRows := []int{}
	testingRows := []int{}
	for i := 0; i < rows; i++ {
		if rand.Intn(101) > cutoff {
			trainingRows = append(trainingRows, i)
			continue
		}
		testingRows = append(testingRows, i)
	}

	allAttrs := src.AllAttributes()
	training := NewInstancesViewFromVisible(src, trainingRows, allAttrs)
	testing := NewInstancesViewFromVisible(src, testingRows, allAttrs)
	return training, testing
}
// LazyShuffle randomizes the row order without re-ordering the rows
// of from: it builds a row mapping and returns it wrapped in a view
// created by NewInstancesViewFromRows.
//
// The mapping is a random permutation of 0..rows-1, i.e. every source
// row is referenced exactly once. It is built with the "inside-out"
// Fisher-Yates shuffle, which needs no pre-filled array.
func LazyShuffle(from FixedDataGrid) FixedDataGrid {
	_, rows := from.Size()
	rowMap := make(map[int]int)
	for i := 0; i < rows; i++ {
		j := rand.Intn(i + 1)
		// Move whatever currently sits in slot j to the new slot i,
		// then place row i in slot j. (When j == i the first write is
		// immediately overwritten by the second.) Simply cross-linking
		// i and j instead would clobber earlier assignments and
		// produce duplicated and missing rows.
		rowMap[i] = rowMap[j]
		rowMap[j] = i
	}
	return NewInstancesViewFromRows(from, rowMap)
}
// Shuffle randomizes the row order of from. A *DenseInstances is
// shuffled in place by swapping its rows and is itself returned; any
// other FixedDataGrid is handed to LazyShuffle instead.
func Shuffle(from FixedDataGrid) FixedDataGrid {
	_, rows := from.Size()
	dense, isDense := from.(*DenseInstances)
	if !isDense {
		return LazyShuffle(from)
	}
	for i := 0; i < rows; i++ {
		dense.swapRows(i, rand.Intn(i+1))
	}
	return dense
}
// SampleWithReplacement returns a view over from with size row
// mappings, each pointing at a source row chosen independently at
// random with rand.Intn.
//
// IMPORTANT: There's a high chance of seeing duplicate rows
// whenever size is close to the row count. rand.Intn panics if from
// has no rows and size is greater than zero.
func SampleWithReplacement(from FixedDataGrid, size int) FixedDataGrid {
	picks := make(map[int]int)
	_, available := from.Size()
	for dst := 0; dst < size; dst++ {
		picks[dst] = rand.Intn(available)
	}
	return NewInstancesViewFromRows(from, picks)
}
// CheckCompatible intersects the Attributes of s1 and s2 (using
// AttributeIntersect). If the intersection is as long as both
// Attribute lists it is returned; otherwise the result is nil.
func CheckCompatible(s1 FixedDataGrid, s2 FixedDataGrid) []Attribute {
	left := s1.AllAttributes()
	right := s2.AllAttributes()
	shared := AttributeIntersect(left, right)
	if len(shared) != len(left) || len(shared) != len(right) {
		return nil
	}
	return shared
}
// CheckStrictlyCompatible checks whether two DenseInstances have
// AttributeGroups with the same Attributes, in the same order,
// enabling optimisations.
//
// It returns false if either argument is not a *DenseInstances, if the
// two sets of AttributeGroup names differ, or if any pair of same-named
// groups differs in the number, order or identity (per Equals) of its
// Attributes.
func CheckStrictlyCompatible(s1 FixedDataGrid, s2 FixedDataGrid) bool {
	// Cast
	d1, ok1 := s1.(*DenseInstances)
	d2, ok2 := s2.(*DenseInstances)
	if !ok1 || !ok2 {
		return false
	}

	// Retrieve AttributeGroups
	d1ags := d1.AllAttributeGroups()
	d2ags := d2.AllAttributeGroups()

	// Equal sizes plus "every name in d1 is also in d2" (checked in the
	// loop below) means both hold exactly the same group names.
	if len(d1ags) != len(d2ags) {
		return false
	}

	for name, ag1 := range d1ags {
		ag2, ok := d2ags[name]
		if !ok {
			return false
		}
		a1 := ag1.Attributes()
		a2 := ag2.Attributes()
		// Without this check a shorter a2 would be indexed out of
		// range and a longer a2 would be silently accepted.
		if len(a1) != len(a2) {
			return false
		}
		for i := range a1 {
			if !a1[i].Equals(a2[i]) {
				return false
			}
		}
	}
	return true
}
// InstancesAreEqual checks whether a given Instance set is exactly
// the same as another (i.e. has the same size and values).
//
// It returns false if the two Size() results differ, if an Attribute
// of inst can't be resolved in other or resolves to a non-equal
// Attribute, or if any stored value differs between the two.
//
// IMPORTANT: panics if inst can't resolve one of its own Attributes.
func InstancesAreEqual(inst, other FixedDataGrid) bool {
	cols, rows := inst.Size()
	otherCols, otherRows := other.Size()
	if cols != otherCols || rows != otherRows {
		return false
	}

	for _, a := range inst.AllAttributes() {
		as1, err := inst.GetAttribute(a)
		if err != nil {
			panic(err) // inst failing to resolve its own Attribute is an internal error
		}
		// Resolve against other (not inst), otherwise the comparison
		// below would only ever compare inst with itself.
		as2, err := other.GetAttribute(a)
		if err != nil {
			return false // Obviously has different Attributes
		}
		if !as1.GetAttribute().Equals(as2.GetAttribute()) {
			return false
		}

		for i := 0; i < rows; i++ {
			if !byteSeqEqual(inst.Get(as1, i), other.Get(as2, i)) {
				return false
			}
		}
	}
	return true
}