mirror of
https://github.com/sjwhitworth/golearn.git
synced 2025-04-28 13:48:56 +08:00
commit
76ef9ede34
@ -1,16 +1,14 @@
|
||||
package base
|
||||
|
||||
import "fmt"
|
||||
import "strconv"
|
||||
|
||||
const (
|
||||
// CategoricalType is for Attributes which represent values distinctly.
|
||||
CategoricalType = iota
|
||||
// Float64Type should be replaced with a FractionalNumeric type [DEPRECATED].
|
||||
Float64Type
|
||||
BinaryType
|
||||
)
|
||||
|
||||
// Attribute Attributes disambiguate columns of the feature matrix and declare their types.
|
||||
// Attributes disambiguate columns of the feature matrix and declare their types.
|
||||
type Attribute interface {
|
||||
// Returns the general characterstics of this Attribute .
|
||||
// to avoid the overhead of casting
|
||||
@ -25,12 +23,12 @@ type Attribute interface {
|
||||
// representation. For example, a CategoricalAttribute with values
|
||||
// ["iris-setosa", "iris-virginica"] would return the float64
|
||||
// representation of 0 when given "iris-setosa".
|
||||
GetSysValFromString(string) float64
|
||||
GetSysValFromString(string) []byte
|
||||
// Converts a given value from a system representation into a human
|
||||
// representation. For example, a CategoricalAttribute with values
|
||||
// ["iris-setosa", "iris-viriginica"] might return "iris-setosa"
|
||||
// when given 0.0 as the argument.
|
||||
GetStringFromSysVal(float64) string
|
||||
GetStringFromSysVal([]byte) string
|
||||
// Tests for equality with another Attribute. Other Attributes are
|
||||
// considered equal if:
|
||||
// * They have the same type (i.e. FloatAttribute <> CategoricalAttribute)
|
||||
@ -38,230 +36,8 @@ type Attribute interface {
|
||||
// * If applicable, they have the same categorical values (though not
|
||||
// necessarily in the same order).
|
||||
Equals(Attribute) bool
|
||||
}
|
||||
|
||||
// FloatAttribute is an implementation which stores floating point
|
||||
// representations of numbers.
|
||||
type FloatAttribute struct {
|
||||
Name string
|
||||
Precision int
|
||||
}
|
||||
|
||||
// NewFloatAttribute returns a new FloatAttribute with a default
|
||||
// precision of 2 decimal places
|
||||
func NewFloatAttribute() *FloatAttribute {
|
||||
return &FloatAttribute{"", 2}
|
||||
}
|
||||
|
||||
// Equals tests a FloatAttribute for equality with another Attribute.
|
||||
//
|
||||
// Returns false if the other Attribute has a different name
|
||||
// or if the other Attribute is not a FloatAttribute.
|
||||
func (Attr *FloatAttribute) Equals(other Attribute) bool {
|
||||
// Check whether this FloatAttribute is equal to another
|
||||
_, ok := other.(*FloatAttribute)
|
||||
if !ok {
|
||||
// Not the same type, so can't be equal
|
||||
return false
|
||||
}
|
||||
if Attr.GetName() != other.GetName() {
|
||||
return false
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
// GetName returns this FloatAttribute's human-readable name.
|
||||
func (Attr *FloatAttribute) GetName() string {
|
||||
return Attr.Name
|
||||
}
|
||||
|
||||
// SetName sets this FloatAttribute's human-readable name.
|
||||
func (Attr *FloatAttribute) SetName(name string) {
|
||||
Attr.Name = name
|
||||
}
|
||||
|
||||
// GetType returns Float64Type.
|
||||
func (Attr *FloatAttribute) GetType() int {
|
||||
return Float64Type
|
||||
}
|
||||
|
||||
// String returns a human-readable summary of this Attribute.
|
||||
// e.g. "FloatAttribute(Sepal Width)"
|
||||
func (Attr *FloatAttribute) String() string {
|
||||
return fmt.Sprintf("FloatAttribute(%s)", Attr.Name)
|
||||
}
|
||||
|
||||
// CheckSysValFromString confirms whether a given rawVal can
|
||||
// be converted into a valid system representation.
|
||||
func (Attr *FloatAttribute) CheckSysValFromString(rawVal string) (float64, error) {
|
||||
f, err := strconv.ParseFloat(rawVal, 64)
|
||||
if err != nil {
|
||||
return 0.0, err
|
||||
}
|
||||
return f, nil
|
||||
}
|
||||
|
||||
// GetSysValFromString parses the given rawVal string to a float64 and returns it.
|
||||
//
|
||||
// float64 happens to be a 1-to-1 mapping to the system representation.
|
||||
// IMPORTANT: This function panic()s if rawVal is not a valid float.
|
||||
// Use CheckSysValFromString to confirm.
|
||||
func (Attr *FloatAttribute) GetSysValFromString(rawVal string) float64 {
|
||||
f, err := strconv.ParseFloat(rawVal, 64)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
return f
|
||||
}
|
||||
|
||||
// GetStringFromSysVal converts a given system value to to a string with two decimal
|
||||
// places of precision [TODO: revise this and allow more precision].
|
||||
func (Attr *FloatAttribute) GetStringFromSysVal(rawVal float64) string {
|
||||
formatString := fmt.Sprintf("%%.%df", Attr.Precision)
|
||||
return fmt.Sprintf(formatString, rawVal)
|
||||
}
|
||||
|
||||
// GetSysVal returns the system representation of userVal.
|
||||
//
|
||||
// Because FloatAttribute represents float64 types, this
|
||||
// just returns its argument.
|
||||
func (Attr *FloatAttribute) GetSysVal(userVal float64) float64 {
|
||||
return userVal
|
||||
}
|
||||
|
||||
// GetUsrVal returns the user representation of sysVal.
|
||||
//
|
||||
// Because FloatAttribute represents float64 types, this
|
||||
// just returns its argument.
|
||||
func (Attr *FloatAttribute) GetUsrVal(sysVal float64) float64 {
|
||||
return sysVal
|
||||
}
|
||||
|
||||
// CategoricalAttribute is an Attribute implementation
|
||||
// which stores discrete string values
|
||||
// - useful for representing classes.
|
||||
type CategoricalAttribute struct {
|
||||
Name string
|
||||
values []string
|
||||
}
|
||||
|
||||
func NewCategoricalAttribute() *CategoricalAttribute {
|
||||
return &CategoricalAttribute{
|
||||
"",
|
||||
make([]string, 0),
|
||||
}
|
||||
}
|
||||
|
||||
// GetName returns the human-readable name assigned to this attribute.
|
||||
func (Attr *CategoricalAttribute) GetName() string {
|
||||
return Attr.Name
|
||||
}
|
||||
|
||||
// SetName sets the human-readable name on this attribute.
|
||||
func (Attr *CategoricalAttribute) SetName(name string) {
|
||||
Attr.Name = name
|
||||
}
|
||||
|
||||
// GetType returns CategoricalType to avoid casting overhead.
|
||||
func (Attr *CategoricalAttribute) GetType() int {
|
||||
return CategoricalType
|
||||
}
|
||||
|
||||
// GetSysVal returns the system representation of userVal as an index into the Values slice
|
||||
// If the userVal can't be found, it returns -1.
|
||||
func (Attr *CategoricalAttribute) GetSysVal(userVal string) float64 {
|
||||
for idx, val := range Attr.values {
|
||||
if val == userVal {
|
||||
return float64(idx)
|
||||
}
|
||||
}
|
||||
return -1
|
||||
}
|
||||
|
||||
// GetUsrVal returns a human-readable representation of the given sysVal.
|
||||
//
|
||||
// IMPORTANT: this function doesn't check the boundaries of the array.
|
||||
func (Attr *CategoricalAttribute) GetUsrVal(sysVal float64) string {
|
||||
idx := int(sysVal)
|
||||
return Attr.values[idx]
|
||||
}
|
||||
|
||||
// GetSysValFromString returns the system representation of rawVal
|
||||
// as an index into the Values slice. If rawVal is not inside
|
||||
// the Values slice, it is appended.
|
||||
//
|
||||
// IMPORTANT: If no system representation yet exists, this functions adds it.
|
||||
// If you need to determine whether rawVal exists: use GetSysVal and check
|
||||
// for a -1 return value.
|
||||
//
|
||||
// Example: if the CategoricalAttribute contains the values ["iris-setosa",
|
||||
// "iris-virginica"] and "iris-versicolor" is provided as the argument,
|
||||
// the Values slide becomes ["iris-setosa", "iris-virginica", "iris-versicolor"]
|
||||
// and 2.00 is returned as the system representation.
|
||||
func (Attr *CategoricalAttribute) GetSysValFromString(rawVal string) float64 {
|
||||
// Match in raw values
|
||||
catIndex := -1
|
||||
for i, s := range Attr.values {
|
||||
if s == rawVal {
|
||||
catIndex = i
|
||||
break
|
||||
}
|
||||
}
|
||||
if catIndex == -1 {
|
||||
Attr.values = append(Attr.values, rawVal)
|
||||
catIndex = len(Attr.values) - 1
|
||||
}
|
||||
return float64(catIndex)
|
||||
}
|
||||
|
||||
// String returns a human-readable summary of this Attribute.
|
||||
//
|
||||
// Returns a string containing the list of human-readable values this
|
||||
// CategoricalAttribute can take.
|
||||
func (Attr *CategoricalAttribute) String() string {
|
||||
return fmt.Sprintf("CategoricalAttribute(\"%s\", %s)", Attr.Name, Attr.values)
|
||||
}
|
||||
|
||||
// GetStringFromSysVal returns a human-readable value from the given system-representation
|
||||
// value val.
|
||||
//
|
||||
// IMPORTANT: This function calls panic() if the value is greater than
|
||||
// the length of the array.
|
||||
// TODO: Return a user-configurable default instead.
|
||||
func (Attr *CategoricalAttribute) GetStringFromSysVal(val float64) string {
|
||||
convVal := int(val)
|
||||
if convVal >= len(Attr.values) {
|
||||
panic(fmt.Sprintf("Out of range: %d in %d", convVal, len(Attr.values)))
|
||||
}
|
||||
return Attr.values[convVal]
|
||||
}
|
||||
|
||||
// Equals checks equality against another Attribute.
|
||||
//
|
||||
// Two CategoricalAttributes are considered equal if they contain
|
||||
// the same values and have the same name. Otherwise, this function
|
||||
// returns false.
|
||||
func (Attr *CategoricalAttribute) Equals(other Attribute) bool {
|
||||
attribute, ok := other.(*CategoricalAttribute)
|
||||
if !ok {
|
||||
// Not the same type, so can't be equal
|
||||
return false
|
||||
}
|
||||
if Attr.GetName() != attribute.GetName() {
|
||||
return false
|
||||
}
|
||||
|
||||
// Check that this CategoricalAttribute has the same
|
||||
// values as the other, in the same order
|
||||
if len(attribute.values) != len(Attr.values) {
|
||||
return false
|
||||
}
|
||||
|
||||
for i, a := range Attr.values {
|
||||
if a != attribute.values[i] {
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
return true
|
||||
// Tests whether two Attributes can be represented in the same pond
|
||||
// i.e. they're the same size, and their byte order makes them meaningful
|
||||
// when considered together
|
||||
Compatable(Attribute) bool
|
||||
}
|
||||
|
69
base/attributes_test.go
Normal file
69
base/attributes_test.go
Normal file
@ -0,0 +1,69 @@
|
||||
package base
|
||||
|
||||
import (
|
||||
. "github.com/smartystreets/goconvey/convey"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestFloatAttributeSysVal(t *testing.T) {
|
||||
Convey("Given some float", t, func() {
|
||||
x := "1.21"
|
||||
attr := NewFloatAttribute()
|
||||
Convey("When the float gets packed", func() {
|
||||
packed := attr.GetSysValFromString(x)
|
||||
Convey("And then unpacked", func() {
|
||||
unpacked := attr.GetStringFromSysVal(packed)
|
||||
Convey("The unpacked version should be the same", func() {
|
||||
So(unpacked, ShouldEqual, x)
|
||||
})
|
||||
})
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
func TestCategoricalAttributeVal(t *testing.T) {
|
||||
attr := NewCategoricalAttribute()
|
||||
Convey("Given some string", t, func() {
|
||||
x := "hello world!"
|
||||
Convey("When the string gets converted", func() {
|
||||
packed := attr.GetSysValFromString(x)
|
||||
Convey("And then unconverted", func() {
|
||||
unpacked := attr.GetStringFromSysVal(packed)
|
||||
Convey("The unpacked version should be the same", func() {
|
||||
So(unpacked, ShouldEqual, x)
|
||||
})
|
||||
})
|
||||
})
|
||||
})
|
||||
Convey("Given some second string", t, func() {
|
||||
x := "hello world 1!"
|
||||
Convey("When the string gets converted", func() {
|
||||
packed := attr.GetSysValFromString(x)
|
||||
So(packed[0], ShouldEqual, 0x1)
|
||||
Convey("And then unconverted", func() {
|
||||
unpacked := attr.GetStringFromSysVal(packed)
|
||||
Convey("The unpacked version should be the same", func() {
|
||||
So(unpacked, ShouldEqual, x)
|
||||
})
|
||||
})
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
func TestBinaryAttribute(t *testing.T) {
|
||||
attr := new(BinaryAttribute)
|
||||
Convey("Given some binary Attribute", t, func() {
|
||||
Convey("SetName, GetName should be equal", func() {
|
||||
attr.SetName("Hello")
|
||||
So(attr.GetName(), ShouldEqual, "Hello")
|
||||
})
|
||||
Convey("Non-zero values should equal 1", func() {
|
||||
sysVal := attr.GetSysValFromString("1")
|
||||
So(sysVal[0], ShouldEqual, 1)
|
||||
})
|
||||
Convey("Zero values should equal 0", func() {
|
||||
sysVal := attr.GetSysValFromString("0")
|
||||
So(sysVal[0], ShouldEqual, 0)
|
||||
})
|
||||
})
|
||||
}
|
78
base/binary.go
Normal file
78
base/binary.go
Normal file
@ -0,0 +1,78 @@
|
||||
package base
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strconv"
|
||||
)
|
||||
|
||||
// BinaryAttributes can only represent 1 or 0.
|
||||
type BinaryAttribute struct {
|
||||
Name string
|
||||
}
|
||||
|
||||
// NewBinaryAttribute creates a BinaryAttribute with the given name
|
||||
func NewBinaryAttribute(name string) *BinaryAttribute {
|
||||
return &BinaryAttribute{
|
||||
name,
|
||||
}
|
||||
}
|
||||
|
||||
// GetName returns the name of this Attribute.
|
||||
func (b *BinaryAttribute) GetName() string {
|
||||
return b.Name
|
||||
}
|
||||
|
||||
// SetName sets the name of this Attribute.
|
||||
func (b *BinaryAttribute) SetName(name string) {
|
||||
b.Name = name
|
||||
}
|
||||
|
||||
// GetType returns BinaryType.
|
||||
func (b *BinaryAttribute) GetType() int {
|
||||
return BinaryType
|
||||
}
|
||||
|
||||
// GetSysValFromString returns either 1 or 0 in a single byte.
|
||||
func (b *BinaryAttribute) GetSysValFromString(userVal string) []byte {
|
||||
f, err := strconv.ParseFloat(userVal, 64)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
ret := make([]byte, 1)
|
||||
if f > 0 {
|
||||
ret[0] = 1
|
||||
}
|
||||
return ret
|
||||
}
|
||||
|
||||
// GetStringFromSysVal returns either 1 or 0.
|
||||
func (b *BinaryAttribute) GetStringFromSysVal(val []byte) string {
|
||||
if val[0] > 0 {
|
||||
return "1"
|
||||
}
|
||||
return "0"
|
||||
}
|
||||
|
||||
// Equals checks for equality with another BinaryAttribute.
|
||||
func (b *BinaryAttribute) Equals(other Attribute) bool {
|
||||
if a, ok := other.(*BinaryAttribute); !ok {
|
||||
return false
|
||||
} else {
|
||||
return a.Name == b.Name
|
||||
}
|
||||
}
|
||||
|
||||
// Compatable checks whether this Attribute can be represented
|
||||
// in the same pond as another.
|
||||
func (b *BinaryAttribute) Compatable(other Attribute) bool {
|
||||
if _, ok := other.(*BinaryAttribute); !ok {
|
||||
return false
|
||||
} else {
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
// String returns a human-redable representation.
|
||||
func (b *BinaryAttribute) String() string {
|
||||
return fmt.Sprintf("BinaryAttribute(%s)", b.Name)
|
||||
}
|
165
base/categorical.go
Normal file
165
base/categorical.go
Normal file
@ -0,0 +1,165 @@
|
||||
package base
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
)
|
||||
|
||||
// CategoricalAttribute is an Attribute implementation
|
||||
// which stores discrete string values
|
||||
// - useful for representing classes.
|
||||
type CategoricalAttribute struct {
|
||||
Name string
|
||||
values []string
|
||||
}
|
||||
|
||||
// NewCategoricalAttribute creates a blank CategoricalAttribute.
|
||||
func NewCategoricalAttribute() *CategoricalAttribute {
|
||||
return &CategoricalAttribute{
|
||||
"",
|
||||
make([]string, 0),
|
||||
}
|
||||
}
|
||||
|
||||
// GetValues returns all the values currently defined
|
||||
func (Attr *CategoricalAttribute) GetValues() []string {
|
||||
return Attr.values
|
||||
}
|
||||
|
||||
// GetName returns the human-readable name assigned to this attribute.
|
||||
func (Attr *CategoricalAttribute) GetName() string {
|
||||
return Attr.Name
|
||||
}
|
||||
|
||||
// SetName sets the human-readable name on this attribute.
|
||||
func (Attr *CategoricalAttribute) SetName(name string) {
|
||||
Attr.Name = name
|
||||
}
|
||||
|
||||
// GetType returns CategoricalType to avoid casting overhead.
|
||||
func (Attr *CategoricalAttribute) GetType() int {
|
||||
return CategoricalType
|
||||
}
|
||||
|
||||
// GetSysVal returns the system representation of userVal as an index into the Values slice
|
||||
// If the userVal can't be found, it returns nothing.
|
||||
func (Attr *CategoricalAttribute) GetSysVal(userVal string) []byte {
|
||||
for idx, val := range Attr.values {
|
||||
if val == userVal {
|
||||
return PackU64ToBytes(uint64(idx))
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetUsrVal returns a human-readable representation of the given sysVal.
|
||||
//
|
||||
// IMPORTANT: this function doesn't check the boundaries of the array.
|
||||
func (Attr *CategoricalAttribute) GetUsrVal(sysVal []byte) string {
|
||||
idx := UnpackBytesToU64(sysVal)
|
||||
return Attr.values[idx]
|
||||
}
|
||||
|
||||
// GetSysValFromString returns the system representation of rawVal
|
||||
// as an index into the Values slice. If rawVal is not inside
|
||||
// the Values slice, it is appended.
|
||||
//
|
||||
// IMPORTANT: If no system representation yet exists, this functions adds it.
|
||||
// If you need to determine whether rawVal exists: use GetSysVal and check
|
||||
// for a zero-length return value.
|
||||
//
|
||||
// Example: if the CategoricalAttribute contains the values ["iris-setosa",
|
||||
// "iris-virginica"] and "iris-versicolor" is provided as the argument,
|
||||
// the Values slide becomes ["iris-setosa", "iris-virginica", "iris-versicolor"]
|
||||
// and 2.00 is returned as the system representation.
|
||||
func (Attr *CategoricalAttribute) GetSysValFromString(rawVal string) []byte {
|
||||
// Match in raw values
|
||||
catIndex := -1
|
||||
for i, s := range Attr.values {
|
||||
if s == rawVal {
|
||||
catIndex = i
|
||||
break
|
||||
}
|
||||
}
|
||||
if catIndex == -1 {
|
||||
Attr.values = append(Attr.values, rawVal)
|
||||
catIndex = len(Attr.values) - 1
|
||||
}
|
||||
|
||||
ret := PackU64ToBytes(uint64(catIndex))
|
||||
return ret
|
||||
}
|
||||
|
||||
// String returns a human-readable summary of this Attribute.
|
||||
//
|
||||
// Returns a string containing the list of human-readable values this
|
||||
// CategoricalAttribute can take.
|
||||
func (Attr *CategoricalAttribute) String() string {
|
||||
return fmt.Sprintf("CategoricalAttribute(\"%s\", %s)", Attr.Name, Attr.values)
|
||||
}
|
||||
|
||||
// GetStringFromSysVal returns a human-readable value from the given system-representation
|
||||
// value val.
|
||||
//
|
||||
// IMPORTANT: This function calls panic() if the value is greater than
|
||||
// the length of the array.
|
||||
// TODO: Return a user-configurable default instead.
|
||||
func (Attr *CategoricalAttribute) GetStringFromSysVal(rawVal []byte) string {
|
||||
convVal := int(UnpackBytesToU64(rawVal))
|
||||
if convVal >= len(Attr.values) {
|
||||
panic(fmt.Sprintf("Out of range: %d in %d (%s)", convVal, len(Attr.values), Attr))
|
||||
}
|
||||
return Attr.values[convVal]
|
||||
}
|
||||
|
||||
// Equals checks equality against another Attribute.
|
||||
//
|
||||
// Two CategoricalAttributes are considered equal if they contain
|
||||
// the same values and have the same name. Otherwise, this function
|
||||
// returns false.
|
||||
func (Attr *CategoricalAttribute) Equals(other Attribute) bool {
|
||||
attribute, ok := other.(*CategoricalAttribute)
|
||||
if !ok {
|
||||
// Not the same type, so can't be equal
|
||||
return false
|
||||
}
|
||||
if Attr.GetName() != attribute.GetName() {
|
||||
return false
|
||||
}
|
||||
|
||||
// Check that this CategoricalAttribute has the same
|
||||
// values as the other, in the same order
|
||||
if len(attribute.values) != len(Attr.values) {
|
||||
return false
|
||||
}
|
||||
|
||||
for i, a := range Attr.values {
|
||||
if a != attribute.values[i] {
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
return true
|
||||
}
|
||||
|
||||
// Compatable checks that this CategoricalAttribute has the same
|
||||
// values as another, in the same order.
|
||||
func (Attr *CategoricalAttribute) Compatable(other Attribute) bool {
|
||||
attribute, ok := other.(*CategoricalAttribute)
|
||||
if !ok {
|
||||
return false
|
||||
}
|
||||
|
||||
// Check that this CategoricalAttribute has the same
|
||||
// values as the other, in the same order
|
||||
if len(attribute.values) != len(Attr.values) {
|
||||
return false
|
||||
}
|
||||
|
||||
for i, a := range Attr.values {
|
||||
if a != attribute.values[i] {
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
return true
|
||||
}
|
@ -10,17 +10,17 @@ type Classifier interface {
|
||||
// and constructs a new set of Instances of equivalent
|
||||
// length with only the class Attribute and fills it in
|
||||
// with predictions.
|
||||
Predict(*Instances) *Instances
|
||||
Predict(FixedDataGrid) FixedDataGrid
|
||||
// Takes a set of instances and updates the Classifier's
|
||||
// internal structures to enable prediction
|
||||
Fit(*Instances)
|
||||
Fit(FixedDataGrid)
|
||||
// Why not make every classifier return a nice-looking string?
|
||||
String() string
|
||||
}
|
||||
|
||||
// BaseClassifier stores options common to every classifier.
|
||||
type BaseClassifier struct {
|
||||
TrainingData *Instances
|
||||
TrainingData *DataGrid
|
||||
}
|
||||
|
||||
type BaseRegressor struct {
|
||||
|
98
base/csv.go
98
base/csv.go
@ -77,28 +77,32 @@ func ParseCSVSniffAttributeNames(filepath string, hasHeaders bool) []string {
|
||||
// The type of a given attribute is determined by looking at the first data row
|
||||
// of the CSV.
|
||||
func ParseCSVSniffAttributeTypes(filepath string, hasHeaders bool) []Attribute {
|
||||
var attrs []Attribute
|
||||
// Open file
|
||||
file, err := os.Open(filepath)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
defer file.Close()
|
||||
// Create the CSV reader
|
||||
reader := csv.NewReader(file)
|
||||
attrs := make([]Attribute, 0)
|
||||
if hasHeaders {
|
||||
// Skip the headers
|
||||
_, err := reader.Read()
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
}
|
||||
// Read the first line of the file
|
||||
columns, err := reader.Read()
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
for _, entry := range columns {
|
||||
// Match the Attribute type with regular expressions
|
||||
entry = strings.Trim(entry, " ")
|
||||
matched, err := regexp.MatchString("^[-+]?[0-9]*\\.?[0-9]+([eE][-+]?[0-9]+)?$", entry)
|
||||
//fmt.Println(entry, matched)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
@ -112,30 +116,8 @@ func ParseCSVSniffAttributeTypes(filepath string, hasHeaders bool) []Attribute {
|
||||
return attrs
|
||||
}
|
||||
|
||||
// ParseCSVToInstances reads the CSV file given by filepath and returns
|
||||
// the read Instances.
|
||||
func ParseCSVToInstances(filepath string, hasHeaders bool) (instances *Instances, err error) {
|
||||
|
||||
defer func() {
|
||||
if r := recover(); r != nil {
|
||||
var ok bool
|
||||
if err, ok = r.(error); !ok {
|
||||
err = fmt.Errorf("golearn: ParseCSVToInstances: %v", r)
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
// Read the number of rows in the file
|
||||
rowCount := ParseCSVGetRows(filepath)
|
||||
if hasHeaders {
|
||||
rowCount--
|
||||
}
|
||||
|
||||
// Read the row headers
|
||||
attrs := ParseCSVGetAttributes(filepath, hasHeaders)
|
||||
|
||||
// Allocate the Instances to return
|
||||
instances = NewInstances(attrs, rowCount)
|
||||
// ParseCSVBuildInstances updates an [[#UpdatableDataGrid]] from a filepath in place
|
||||
func ParseCSVBuildInstances(filepath string, hasHeaders bool, u UpdatableDataGrid) {
|
||||
|
||||
// Read the input
|
||||
file, err := os.Open(filepath)
|
||||
@ -146,6 +128,9 @@ func ParseCSVToInstances(filepath string, hasHeaders bool) (instances *Instances
|
||||
reader := csv.NewReader(file)
|
||||
|
||||
rowCounter := 0
|
||||
|
||||
specs := ResolveAttributes(u, u.AllAttributes())
|
||||
|
||||
for {
|
||||
record, err := reader.Read()
|
||||
if err == io.EOF {
|
||||
@ -159,13 +144,68 @@ func ParseCSVToInstances(filepath string, hasHeaders bool) (instances *Instances
|
||||
continue
|
||||
}
|
||||
}
|
||||
for i := range attrs {
|
||||
instances.SetAttrStr(rowCounter, i, strings.Trim(record[i], " "))
|
||||
for i, v := range record {
|
||||
u.Set(specs[i], rowCounter, specs[i].attr.GetSysValFromString(v))
|
||||
}
|
||||
rowCounter++
|
||||
}
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
// ParseCSVToInstances reads the CSV file given by filepath and returns
|
||||
// the read Instances.
|
||||
func ParseCSVToInstances(filepath string, hasHeaders bool) (instances *DenseInstances, err error) {
|
||||
|
||||
// Read the number of rows in the file
|
||||
rowCount := ParseCSVGetRows(filepath)
|
||||
if hasHeaders {
|
||||
rowCount--
|
||||
}
|
||||
|
||||
// Read the row headers
|
||||
attrs := ParseCSVGetAttributes(filepath, hasHeaders)
|
||||
specs := make([]AttributeSpec, len(attrs))
|
||||
// Allocate the Instances to return
|
||||
instances = NewDenseInstances()
|
||||
for i, a := range attrs {
|
||||
spec := instances.AddAttribute(a)
|
||||
specs[i] = spec
|
||||
}
|
||||
instances.Extend(rowCount)
|
||||
|
||||
// Read the input
|
||||
file, err := os.Open(filepath)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
defer file.Close()
|
||||
reader := csv.NewReader(file)
|
||||
|
||||
rowCounter := 0
|
||||
|
||||
for {
|
||||
record, err := reader.Read()
|
||||
if err == io.EOF {
|
||||
break
|
||||
} else if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
if rowCounter == 0 {
|
||||
if hasHeaders {
|
||||
hasHeaders = false
|
||||
continue
|
||||
}
|
||||
}
|
||||
for i, v := range record {
|
||||
v = strings.Trim(v, " ")
|
||||
instances.Set(specs[i], rowCounter, attrs[i].GetSysValFromString(v))
|
||||
}
|
||||
rowCounter++
|
||||
}
|
||||
|
||||
instances.AddClassAttribute(attrs[len(attrs)-1])
|
||||
|
||||
return instances, nil
|
||||
}
|
||||
|
||||
//ParseCSV parses a CSV file and returns the number of columns and rows, the headers, the labels associated with
|
||||
|
@ -1,6 +1,8 @@
|
||||
package base
|
||||
|
||||
import "testing"
|
||||
import (
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestParseCSVGetRows(testEnv *testing.T) {
|
||||
lineCount := ParseCSVGetRows("../examples/datasets/iris.csv")
|
||||
@ -76,9 +78,9 @@ func TestReadInstances(testEnv *testing.T) {
|
||||
testEnv.Error(err)
|
||||
return
|
||||
}
|
||||
row1 := inst.RowStr(0)
|
||||
row2 := inst.RowStr(50)
|
||||
row3 := inst.RowStr(100)
|
||||
row1 := inst.RowString(0)
|
||||
row2 := inst.RowString(50)
|
||||
row3 := inst.RowString(100)
|
||||
|
||||
if row1 != "5.10 3.50 1.40 0.20 Iris-setosa" {
|
||||
testEnv.Error(row1)
|
||||
@ -97,10 +99,11 @@ func TestReadAwkwardInsatnces(testEnv *testing.T) {
|
||||
testEnv.Error(err)
|
||||
return
|
||||
}
|
||||
if inst.GetAttr(0).GetType() != Float64Type {
|
||||
attrs := inst.AllAttributes()
|
||||
if attrs[0].GetType() != Float64Type {
|
||||
testEnv.Error("Should be float!")
|
||||
}
|
||||
if inst.GetAttr(1).GetType() != CategoricalType {
|
||||
if attrs[1].GetType() != CategoricalType {
|
||||
testEnv.Error("Should be discrete!")
|
||||
}
|
||||
}
|
||||
|
51
base/data.go
Normal file
51
base/data.go
Normal file
@ -0,0 +1,51 @@
|
||||
package base
|
||||
|
||||
// SortDirection specifies sorting direction...
|
||||
type SortDirection int
|
||||
|
||||
const (
|
||||
// Descending says that Instances should be sorted high to low...
|
||||
Descending SortDirection = 1
|
||||
// Ascending states that Instances should be sorted low to high...
|
||||
Ascending SortDirection = 2
|
||||
)
|
||||
|
||||
// DataGrid implementations represent data addressable by rows and columns.
|
||||
type DataGrid interface {
|
||||
// Retrieves a given Attribute's specification
|
||||
GetAttribute(Attribute) (AttributeSpec, error)
|
||||
// Retrieves details of every Attribute
|
||||
AllAttributes() []Attribute
|
||||
// Marks an Attribute as a class Attribute
|
||||
AddClassAttribute(Attribute) error
|
||||
// Unmarks an Attribute as a class Attribute
|
||||
RemoveClassAttribute(Attribute) error
|
||||
// Returns details of all class Attributes
|
||||
AllClassAttributes() []Attribute
|
||||
// Gets the bytes at a given position or nil
|
||||
Get(AttributeSpec, int) []byte
|
||||
// Convenience function for iteration.
|
||||
MapOverRows([]AttributeSpec, func([][]byte, int) (bool, error)) error
|
||||
}
|
||||
|
||||
// FixedDataGrid implementations have a size known in advance and implement
|
||||
// all of the functionality offered by DataGrid implementations.
|
||||
type FixedDataGrid interface {
|
||||
DataGrid
|
||||
// Returns a string representation of a given row
|
||||
RowString(int) string
|
||||
// Returns the number of Attributes and rows currently allocated
|
||||
Size() (int, int)
|
||||
}
|
||||
|
||||
// UpdatableDataGrid implementations can be changed in addition to implementing
|
||||
// all of the functionality offered by FixedDataGrid implementations.
|
||||
type UpdatableDataGrid interface {
|
||||
FixedDataGrid
|
||||
// Sets a given Attribute and row to a byte sequence.
|
||||
Set(AttributeSpec, int, []byte)
|
||||
// Adds an Attribute to the grid.
|
||||
AddAttribute(Attribute) AttributeSpec
|
||||
// Allocates additional room to hold a number of rows
|
||||
Extend(int) error
|
||||
}
|
@ -1,33 +0,0 @@
|
||||
package base
|
||||
|
||||
import "testing"
|
||||
|
||||
func TestDecomp(testEnv *testing.T) {
|
||||
inst, err := ParseCSVToInstances("../examples/datasets/iris_binned.csv", true)
|
||||
if err != nil {
|
||||
testEnv.Error(err)
|
||||
return
|
||||
}
|
||||
decomp := inst.DecomposeOnAttributeValues(inst.GetAttr(0))
|
||||
|
||||
row0 := decomp["0.00"].RowStr(0)
|
||||
row1 := decomp["1.00"].RowStr(0)
|
||||
/* row2 := decomp["2.00"].RowStr(0)
|
||||
row3 := decomp["3.00"].RowStr(0)
|
||||
row4 := decomp["4.00"].RowStr(0)
|
||||
row5 := decomp["5.00"].RowStr(0)
|
||||
row6 := decomp["6.00"].RowStr(0)
|
||||
row7 := decomp["7.00"].RowStr(0)*/
|
||||
row8 := decomp["8.00"].RowStr(0)
|
||||
// row9 := decomp["9.00"].RowStr(0)
|
||||
|
||||
if row0 != "3.10 1.50 0.20 Iris-setosa" {
|
||||
testEnv.Error(row0)
|
||||
}
|
||||
if row1 != "3.00 1.40 0.20 Iris-setosa" {
|
||||
testEnv.Error(row1)
|
||||
}
|
||||
if row8 != "2.90 6.30 1.80 Iris-virginica" {
|
||||
testEnv.Error(row8)
|
||||
}
|
||||
}
|
476
base/dense.go
Normal file
476
base/dense.go
Normal file
@ -0,0 +1,476 @@
|
||||
package base
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"fmt"
|
||||
"github.com/sjwhitworth/golearn/base/edf"
|
||||
"math"
|
||||
"sync"
|
||||
)
|
||||
|
||||
// DenseInstances stores each Attribute value explicitly
|
||||
// in a large grid.
|
||||
type DenseInstances struct {
|
||||
storage *edf.EdfFile
|
||||
pondMap map[string]int
|
||||
ponds []*Pond
|
||||
lock sync.Mutex
|
||||
fixed bool
|
||||
classAttrs map[AttributeSpec]bool
|
||||
maxRow int
|
||||
attributes []Attribute
|
||||
}
|
||||
|
||||
// NewDenseInstances generates a new DenseInstances set
|
||||
// with an anonymous EDF mapping and default settings.
|
||||
func NewDenseInstances() *DenseInstances {
|
||||
storage, err := edf.EdfAnonMap()
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
return &DenseInstances{
|
||||
storage,
|
||||
make(map[string]int),
|
||||
make([]*Pond, 0),
|
||||
sync.Mutex{},
|
||||
false,
|
||||
make(map[AttributeSpec]bool),
|
||||
0,
|
||||
make([]Attribute, 0),
|
||||
}
|
||||
}
|
||||
|
||||
//
|
||||
// Pond functions
|
||||
//
|
||||
|
||||
// createPond adds a new Pond to this set of Instances
|
||||
// IMPORTANT: do not call unless you've acquired the lock
|
||||
func (inst *DenseInstances) createPond(name string, size int) {
|
||||
if inst.fixed {
|
||||
panic("Can't add additional Attributes")
|
||||
}
|
||||
|
||||
// Resolve or create thread
|
||||
threads, err := inst.storage.GetThreads()
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
ok := false
|
||||
for i := range threads {
|
||||
if threads[i] == name {
|
||||
ok = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if ok {
|
||||
panic("Can't create pond: pond thread already exists")
|
||||
}
|
||||
|
||||
// Write the pool's thread into the file
|
||||
thread := edf.NewThread(inst.storage, name)
|
||||
err = inst.storage.WriteThread(thread)
|
||||
if err != nil {
|
||||
panic(fmt.Sprintf("Can't write thread: %s", err))
|
||||
}
|
||||
|
||||
// Create the pond information
|
||||
pond := new(Pond)
|
||||
pond.threadNo = thread.GetId()
|
||||
pond.parent = inst
|
||||
pond.attributes = make([]Attribute, 0)
|
||||
pond.size = size
|
||||
pond.alloc = make([][]byte, 0)
|
||||
// Store within instances
|
||||
inst.pondMap[name] = len(inst.ponds)
|
||||
inst.ponds = append(inst.ponds, pond)
|
||||
}
|
||||
|
||||
// CreatePond adds a new Pond to this set of instances
|
||||
// with a given name. If the size is 0, a bit-pond is added
|
||||
// if the size of not 0, then the size of each pond attribute
|
||||
// is set to that number of bytes.
|
||||
func (inst *DenseInstances) CreatePond(name string, size int) (err error) {
|
||||
defer func() {
|
||||
if r := recover(); r != nil {
|
||||
var ok bool
|
||||
if err, ok = r.(error); !ok {
|
||||
err = fmt.Errorf("CreatePond: %v (not created)", r)
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
inst.lock.Lock()
|
||||
defer inst.lock.Unlock()
|
||||
|
||||
inst.createPond(name, size)
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetPond returns a reference to a Pond of a given name /
|
||||
func (inst *DenseInstances) GetPond(name string) (*Pond, error) {
|
||||
inst.lock.Lock()
|
||||
defer inst.lock.Unlock()
|
||||
|
||||
// Check if the pond exists
|
||||
if id, ok := inst.pondMap[name]; !ok {
|
||||
return nil, fmt.Errorf("Pond '%s' doesn't exist", name)
|
||||
} else {
|
||||
// Return the pond
|
||||
return inst.ponds[id], nil
|
||||
}
|
||||
}
|
||||
|
||||
//
|
||||
// Attribute creation and handling
|
||||
//
|
||||
|
||||
// AddAttribute adds an Attribute to this set of DenseInstances
|
||||
// Creates a default Pond for it if a suitable one doesn't exist.
|
||||
// Returns an AttributeSpec for subsequent Set() calls.
|
||||
//
|
||||
// IMPORTANT: will panic if storage has been allocated via Extend.
|
||||
func (inst *DenseInstances) AddAttribute(a Attribute) AttributeSpec {
|
||||
inst.lock.Lock()
|
||||
defer inst.lock.Unlock()
|
||||
|
||||
if inst.fixed {
|
||||
panic("Can't add additional Attributes")
|
||||
}
|
||||
|
||||
// Generate a default Pond name
|
||||
pond := "FLOAT"
|
||||
if _, ok := a.(*CategoricalAttribute); ok {
|
||||
pond = "CAT"
|
||||
} else if _, ok := a.(*FloatAttribute); ok {
|
||||
pond = "FLOAT"
|
||||
} else {
|
||||
panic("Unrecognised Attribute type")
|
||||
}
|
||||
|
||||
// Create the pond if it doesn't exist
|
||||
if _, ok := inst.pondMap[pond]; !ok {
|
||||
inst.createPond(pond, 8)
|
||||
}
|
||||
id := inst.pondMap[pond]
|
||||
p := inst.ponds[id]
|
||||
p.attributes = append(p.attributes, a)
|
||||
inst.attributes = append(inst.attributes, a)
|
||||
return AttributeSpec{id, len(p.attributes) - 1, a}
|
||||
}
|
||||
|
||||
// AddAttributeToPond adds an Attribute to a given pond
|
||||
func (inst *DenseInstances) AddAttributeToPond(newAttribute Attribute, pond string) (AttributeSpec, error) {
|
||||
inst.lock.Lock()
|
||||
defer inst.lock.Unlock()
|
||||
|
||||
// Check if the pond exists
|
||||
if _, ok := inst.pondMap[pond]; !ok {
|
||||
return AttributeSpec{-1, 0, nil}, fmt.Errorf("Pond '%s' doesn't exist. Call CreatePond() first", pond)
|
||||
}
|
||||
|
||||
id := inst.pondMap[pond]
|
||||
p := inst.ponds[id]
|
||||
for i, a := range p.attributes {
|
||||
if !a.Compatable(newAttribute) {
|
||||
return AttributeSpec{-1, 0, nil}, fmt.Errorf("Attribute %s is not compatable with %s in pond '%s' (position %d)", newAttribute, a, pond, i)
|
||||
}
|
||||
}
|
||||
|
||||
p.attributes = append(p.attributes, newAttribute)
|
||||
inst.attributes = append(inst.attributes, newAttribute)
|
||||
return AttributeSpec{id, len(p.attributes) - 1, newAttribute}, nil
|
||||
}
|
||||
|
||||
// GetAttribute returns an Attribute equal to the argument.
|
||||
//
|
||||
// TODO: Write a function to pre-compute this once we've allocated
|
||||
// TODO: Write a utility function which retrieves all AttributeSpecs for
|
||||
// a given instance set.
|
||||
func (inst *DenseInstances) GetAttribute(get Attribute) (AttributeSpec, error) {
|
||||
inst.lock.Lock()
|
||||
defer inst.lock.Unlock()
|
||||
|
||||
for i, p := range inst.ponds {
|
||||
for j, a := range p.attributes {
|
||||
if a.Equals(get) {
|
||||
return AttributeSpec{i, j, a}, nil
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return AttributeSpec{-1, 0, nil}, fmt.Errorf("Couldn't resolve %s", get)
|
||||
}
|
||||
|
||||
// AllAttributes returns a slice of all Attributes.
|
||||
func (inst *DenseInstances) AllAttributes() []Attribute {
|
||||
inst.lock.Lock()
|
||||
defer inst.lock.Unlock()
|
||||
|
||||
ret := make([]Attribute, 0)
|
||||
for _, p := range inst.ponds {
|
||||
for _, a := range p.attributes {
|
||||
ret = append(ret, a)
|
||||
}
|
||||
}
|
||||
|
||||
return ret
|
||||
}
|
||||
|
||||
// AddClassAttribute sets an Attribute to be a class Attribute.
|
||||
func (inst *DenseInstances) AddClassAttribute(a Attribute) error {
|
||||
|
||||
as, err := inst.GetAttribute(a)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
inst.lock.Lock()
|
||||
defer inst.lock.Unlock()
|
||||
|
||||
inst.classAttrs[as] = true
|
||||
return nil
|
||||
}
|
||||
|
||||
// RemoveClassAttribute removes an Attribute from the set of class Attributes.
|
||||
func (inst *DenseInstances) RemoveClassAttribute(a Attribute) error {
|
||||
inst.lock.Lock()
|
||||
defer inst.lock.Unlock()
|
||||
|
||||
as, err := inst.GetAttribute(a)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
inst.lock.Lock()
|
||||
defer inst.lock.Unlock()
|
||||
|
||||
inst.classAttrs[as] = false
|
||||
return nil
|
||||
}
|
||||
|
||||
// AllClassAttributes returns a slice of Attributes which have
|
||||
// been designated class Attributes.
|
||||
func (inst *DenseInstances) AllClassAttributes() []Attribute {
|
||||
var ret []Attribute
|
||||
inst.lock.Lock()
|
||||
defer inst.lock.Unlock()
|
||||
for a := range inst.classAttrs {
|
||||
if inst.classAttrs[a] {
|
||||
ret = append(ret, a.attr)
|
||||
}
|
||||
}
|
||||
return ret
|
||||
}
|
||||
|
||||
//
|
||||
// Allocation functions
|
||||
//
|
||||
|
||||
// Extend extends this set of Instances to store rows additional rows.
|
||||
// It's recommended to set rows to something quite large.
|
||||
//
|
||||
// IMPORTANT: panics if the allocation fails
|
||||
func (inst *DenseInstances) Extend(rows int) error {
|
||||
|
||||
inst.lock.Lock()
|
||||
defer inst.lock.Unlock()
|
||||
|
||||
// Get the size of each page
|
||||
pageSize := inst.storage.GetPageSize()
|
||||
|
||||
for pondName := range inst.ponds {
|
||||
p := inst.ponds[pondName]
|
||||
|
||||
// Compute pond row storage requirements
|
||||
rowSize := p.RowSize()
|
||||
|
||||
// How many rows can we store per page?
|
||||
rowsPerPage := float64(pageSize) / float64(rowSize)
|
||||
|
||||
// How many pages?
|
||||
pagesNeeded := uint32(math.Ceil(float64(rows) / rowsPerPage))
|
||||
|
||||
// Allocate those pages
|
||||
r, err := inst.storage.AllocPages(pagesNeeded, p.threadNo)
|
||||
if err != nil {
|
||||
panic(fmt.Sprintf("Allocation error: %s (rowSize %d, pageSize %d, rowsPerPage %.2f, tried to allocate %d page(s) and extend by %d row(s))", err, rowSize, pageSize, rowsPerPage, pagesNeeded, rows))
|
||||
}
|
||||
// Resolve and assign those pages
|
||||
byteBlock := inst.storage.ResolveRange(r)
|
||||
for _, block := range byteBlock {
|
||||
p.alloc = append(p.alloc, block)
|
||||
}
|
||||
}
|
||||
inst.fixed = true
|
||||
inst.maxRow += rows
|
||||
return nil
|
||||
}
|
||||
|
||||
// Set sets a particular Attribute (given as an AttributeSpec) on a particular
// row to a particular value.
//
// AttributeSpecs can be obtained using GetAttribute() or AddAttribute().
//
// IMPORTANT: Will panic() if the AttributeSpec isn't valid
//
// IMPORTANT: Will panic() if the row is too large
//
// IMPORTANT: Will panic() if the val is not the right length
func (inst *DenseInstances) Set(a AttributeSpec, row int, val []byte) {
	// Delegates to the owning pond; a.pond/a.position locate the column.
	inst.ponds[a.pond].set(a.position, row, val)
}
|
||||
|
||||
// Get gets a particular Attribute (given as an AttributeSpec) on a particular
// row.
// AttributeSpecs can be obtained using GetAttribute() or AddAttribute().
func (inst *DenseInstances) Get(a AttributeSpec, row int) []byte {
	// Delegates to the owning pond; the returned slice is the system
	// representation of the value (see Attribute.GetStringFromSysVal).
	return inst.ponds[a.pond].get(a.position, row)
}
|
||||
|
||||
// RowString returns a string representation of a given row.
|
||||
func (inst *DenseInstances) RowString(row int) string {
|
||||
var buffer bytes.Buffer
|
||||
first := true
|
||||
for name := range inst.ponds {
|
||||
if first {
|
||||
first = false
|
||||
} else {
|
||||
buffer.WriteString(" ")
|
||||
}
|
||||
p := inst.ponds[name]
|
||||
p.appendToRowBuf(row, &buffer)
|
||||
}
|
||||
return buffer.String()
|
||||
}
|
||||
|
||||
//
|
||||
// Row handling functions
|
||||
//
|
||||
|
||||
func (inst *DenseInstances) allocateRowVector(asv []AttributeSpec) [][]byte {
|
||||
ret := make([][]byte, len(asv))
|
||||
for i, as := range asv {
|
||||
p := inst.ponds[as.pond]
|
||||
ret[i] = make([]byte, p.size)
|
||||
}
|
||||
return ret
|
||||
}
|
||||
|
||||
// MapOverRows passes each row map into a function.
|
||||
// First argument is a list of AttributeSpec in the order
|
||||
// they're needed in for the function. The second is the function
|
||||
// to call on each row.
|
||||
func (inst *DenseInstances) MapOverRows(asv []AttributeSpec, mapFunc func([][]byte, int) (bool, error)) error {
|
||||
rowBuf := make([][]byte, len(asv))
|
||||
for i := 0; i < inst.maxRow; i++ {
|
||||
for j, as := range asv {
|
||||
p := inst.ponds[as.pond]
|
||||
rowBuf[j] = p.get(as.position, i)
|
||||
}
|
||||
ok, err := mapFunc(rowBuf, i)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if !ok {
|
||||
break
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Size returns the number of Attributes as the first return value
// and the maximum allocated row as the second value.
func (inst *DenseInstances) Size() (int, int) {
	// AllAttributes handles its own locking.
	return len(inst.AllAttributes()), inst.maxRow
}
|
||||
|
||||
// swapRows swaps over rows i and j
|
||||
func (inst *DenseInstances) swapRows(i, j int) {
|
||||
as := ResolveAllAttributes(inst)
|
||||
for _, a := range as {
|
||||
v1 := inst.Get(a, i)
|
||||
v2 := inst.Get(a, j)
|
||||
v3 := make([]byte, len(v2))
|
||||
copy(v3, v2)
|
||||
inst.Set(a, j, v1)
|
||||
inst.Set(a, i, v3)
|
||||
}
|
||||
}
|
||||
|
||||
// Equal checks whether a given Instance set is exactly the same
|
||||
// as another: same size and same values (as determined by the Attributes)
|
||||
//
|
||||
// IMPORTANT: does not explicitly check if the Attributes are considered equal.
|
||||
func (inst *DenseInstances) Equal(other DataGrid) bool {
|
||||
|
||||
_, rows := inst.Size()
|
||||
|
||||
for _, a := range inst.AllAttributes() {
|
||||
as1, err := inst.GetAttribute(a)
|
||||
if err != nil {
|
||||
panic(err) // That indicates some kind of error
|
||||
}
|
||||
as2, err := inst.GetAttribute(a)
|
||||
if err != nil {
|
||||
return false // Obviously has different Attributes
|
||||
}
|
||||
for i := 0; i < rows; i++ {
|
||||
b1 := inst.Get(as1, i)
|
||||
b2 := inst.Get(as2, i)
|
||||
if !byteSeqEqual(b1, b2) {
|
||||
return false
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return true
|
||||
}
|
||||
|
||||
// String returns a human-readable summary of this dataset.
func (inst *DenseInstances) String() string {
	var buffer bytes.Buffer

	// Get all Attribute information
	as := ResolveAllAttributes(inst)

	// Print header
	cols, rows := inst.Size()
	buffer.WriteString("Instances with ")
	buffer.WriteString(fmt.Sprintf("%d row(s) ", rows))
	buffer.WriteString(fmt.Sprintf("%d attribute(s)\n", cols))
	buffer.WriteString(fmt.Sprintf("Attributes: \n"))

	// Class Attributes are marked with a leading asterisk.
	for _, a := range as {
		prefix := "\t"
		if inst.classAttrs[a] {
			prefix = "*\t"
		}
		buffer.WriteString(fmt.Sprintf("%s%s\n", prefix, a.attr))
	}

	// Print at most 30 rows of data.
	buffer.WriteString("\nData:\n")
	maxRows := 30
	if rows < maxRows {
		maxRows = rows
	}

	for i := 0; i < maxRows; i++ {
		buffer.WriteString("\t")
		for _, a := range as {
			// Convert each system value back to its human form.
			val := inst.Get(a, i)
			buffer.WriteString(fmt.Sprintf("%s ", a.attr.GetStringFromSysVal(val)))
		}
		buffer.WriteString("\n")
	}

	// Report how many rows were omitted, if any.
	missingRows := rows - maxRows
	if missingRows != 0 {
		buffer.WriteString(fmt.Sprintf("\t...\n%d row(s) undisplayed", missingRows))
	} else {
		buffer.WriteString("All rows displayed")
	}

	return buffer.String()
}
|
261
base/edf/alloc.go
Normal file
261
base/edf/alloc.go
Normal file
@ -0,0 +1,261 @@
|
||||
package edf
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
)
|
||||
|
||||
// ContentEntry structs are stored in ContentEntry blocks,
// which always start at block 2.
type ContentEntry struct {
	// Which thread this entry is assigned to
	Thread uint32
	// Which page this block starts at
	Start uint32
	// The page up to and including which the block ends
	End uint32
}
|
||||
|
||||
// extend grows the backing file by additionalPages pages.
//
// NOTE(review): the current size is computed as Size()/pageSize, which
// truncates if the file is ever not an exact multiple of the page size
// — confirm that invariant holds. The argument to truncate appears to
// be measured in pages (cf. truncate(4) in EdfMap) — verify.
func (e *EdfFile) extend(additionalPages uint32) error {
	fileInfo, err := e.f.Stat()
	if err != nil {
		panic(err)
	}
	newSize := uint64(fileInfo.Size())/e.pageSize + uint64(additionalPages)
	return e.truncate(int64(newSize))
}
|
||||
|
||||
func (e *EdfFile) getFreeMapSize() uint64 {
|
||||
if e.f != nil {
|
||||
fileInfo, err := e.f.Stat()
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
return uint64(fileInfo.Size()) / e.pageSize
|
||||
}
|
||||
return uint64(EDF_SIZE) / e.pageSize
|
||||
}
|
||||
|
||||
// FixedAlloc allocates a |bytesRequested| chunk of pages
// on the FIXED thread (thread 2, see writeInitialData).
//
// NOTE(review): (pageSize*bytesRequested+pageSize/2)/pageSize reduces
// to bytesRequested for any pageSize > 1, so this allocates
// bytesRequested *pages*, not the pages needed to hold bytesRequested
// bytes. If the parameter really is a byte count, the expression should
// probably be (bytesRequested+pageSize-1)/pageSize — confirm against
// callers before changing.
func (e *EdfFile) FixedAlloc(bytesRequested uint32) (EdfRange, error) {
	pageSize := uint32(e.pageSize)
	return e.AllocPages((pageSize*bytesRequested+pageSize/2)/pageSize, 2)
}
|
||||
|
||||
func (e *EdfFile) getContiguousOffset(pagesRequested uint32) (uint32, error) {
|
||||
// Create the free bitmap
|
||||
bitmap := make([]bool, e.getFreeMapSize())
|
||||
for i := 0; i < 4; i++ {
|
||||
bitmap[i] = true
|
||||
}
|
||||
// Traverse the contents table and build a free bitmap
|
||||
block := uint64(2)
|
||||
for {
|
||||
// Get the range for this block
|
||||
r := e.GetPageRange(block, block)
|
||||
if r.Start.Segment != r.End.Segment {
|
||||
return 0, fmt.Errorf("Contents block split across segments")
|
||||
}
|
||||
bytes := e.m[r.Start.Segment]
|
||||
bytes = bytes[r.Start.Byte : r.End.Byte+1]
|
||||
// Get the address of the next contents block
|
||||
block = uint64FromBytes(bytes)
|
||||
if block != 0 {
|
||||
// No point in checking this block for free space
|
||||
continue
|
||||
}
|
||||
bytes = bytes[8:]
|
||||
// Look for a blank entry in the table
|
||||
for i := 0; i < len(bytes); i += 12 {
|
||||
threadID := uint32FromBytes(bytes[i:])
|
||||
if threadID == 0 {
|
||||
continue
|
||||
}
|
||||
start := uint32FromBytes(bytes[i+4:])
|
||||
end := uint32FromBytes(bytes[i+8:])
|
||||
for j := start; j <= end; j++ {
|
||||
if int(j) >= len(bitmap) {
|
||||
break
|
||||
}
|
||||
bitmap[j] = true
|
||||
}
|
||||
}
|
||||
break
|
||||
}
|
||||
// Look through the freemap and find a good spot
|
||||
for i := 0; i < len(bitmap); i++ {
|
||||
if bitmap[i] {
|
||||
continue
|
||||
}
|
||||
for j := i; j < len(bitmap); j++ {
|
||||
if !bitmap[j] {
|
||||
diff := j - 1 - i
|
||||
if diff > int(pagesRequested) {
|
||||
return uint32(i), nil
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
return 0, nil
|
||||
}
|
||||
|
||||
// addNewContentsBlock adds a new contents block in the next available
// space and links it from the current terminal contents block.
//
// NOTE(review): several things here look suspect and should be
// confirmed before relying on this path:
//  1. e.extend(uint32(e.pageSize)) extends by pageSize *pages*, not by
//     one page — probably meant extend(1).
//  2. The error returned by extend is discarded.
//  3. After extending, startBlock is still 0: the NEXT pointer is then
//     written as 0 (meaning "no next block") and the TOC entry covers
//     pages 0-1 — the offset is never re-resolved post-extend.
func (e *EdfFile) addNewContentsBlock() error {

	var toc ContentEntry

	// Find the next available offset
	startBlock, err := e.getContiguousOffset(1)
	if startBlock == 0 && err == nil {
		// Increase the size of the file if necessary
		e.extend(uint32(e.pageSize))
	} else if err != nil {
		return err
	}

	// Traverse the contents blocks looking for one with a blank NEXT pointer
	block := uint64(2)
	for {
		// Get the range for this block
		r := e.GetPageRange(block, block)
		if r.Start.Segment != r.End.Segment {
			return fmt.Errorf("Contents block split across segments")
		}
		bytes := e.m[r.Start.Segment]
		bytes = bytes[r.Start.Byte : r.End.Byte+1]
		// Get the address of the next contents block
		block = uint64FromBytes(bytes)
		if block == 0 {
			// Terminal block found: point it at the new block
			uint64ToBytes(uint64(startBlock), bytes)
			break
		}
	}
	// Record the new contents block against the SYSTEM thread
	toc.Start = startBlock
	toc.End = startBlock + 1
	toc.Thread = 1 // SYSTEM thread
	return e.addToTOC(&toc, false)
}
|
||||
|
||||
// addToTOC adds a ContentsEntry structure in the next available place
|
||||
func (e *EdfFile) addToTOC(c *ContentEntry, extend bool) error {
|
||||
// Traverse the contents table looking for a free spot
|
||||
block := uint64(2)
|
||||
for {
|
||||
// Get the range for this block
|
||||
r := e.GetPageRange(block, block)
|
||||
if r.Start.Segment != r.End.Segment {
|
||||
return fmt.Errorf("Contents block split across segments")
|
||||
}
|
||||
bytes := e.m[r.Start.Segment]
|
||||
bytes = bytes[r.Start.Byte : r.End.Byte+1]
|
||||
// Get the address of the next contents block
|
||||
block = uint64FromBytes(bytes)
|
||||
if block != 0 {
|
||||
// No point in checking this block for free space
|
||||
continue
|
||||
}
|
||||
bytes = bytes[8:]
|
||||
// Look for a blank entry in the table
|
||||
cur := 0
|
||||
for {
|
||||
threadID := uint32FromBytes(bytes)
|
||||
if threadID == 0 {
|
||||
break
|
||||
}
|
||||
cur += 12
|
||||
bytes = bytes[12:]
|
||||
if len(bytes) < 12 {
|
||||
if extend {
|
||||
// Append a new contents block and try again
|
||||
e.addNewContentsBlock()
|
||||
return e.addToTOC(c, false)
|
||||
}
|
||||
return fmt.Errorf("Can't add to contents: no space available")
|
||||
}
|
||||
}
|
||||
// Write the contents information into this block
|
||||
uint32ToBytes(c.Thread, bytes)
|
||||
bytes = bytes[4:]
|
||||
uint32ToBytes(c.Start, bytes)
|
||||
bytes = bytes[4:]
|
||||
uint32ToBytes(c.End, bytes)
|
||||
break
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// AllocPages allocates a |pagesRequested| chunk of pages on the Thread
// with the given identifier. Returns an EdfRange describing the result.
//
// NOTE(review): toc.End is set to startBlock+pagesRequested, but
// ContentEntry documents End as inclusive — so each allocation records
// (and GetPageRange spans) pagesRequested+1 pages. The off-by-one is
// applied consistently by GetThreadBlocks, so round-trips agree, but
// confirm before changing either side.
func (e *EdfFile) AllocPages(pagesRequested uint32, thread uint32) (EdfRange, error) {

	var ret EdfRange
	var toc ContentEntry

	// Parameter check
	if pagesRequested == 0 {
		return ret, fmt.Errorf("Must request some pages")
	}
	if thread == 0 {
		return ret, fmt.Errorf("Need a valid page identifier")
	}

	// Find the next available offset
	startBlock, err := e.getContiguousOffset(pagesRequested)
	if startBlock == 0 && err == nil {
		// Increase the size of the file if necessary, then retry
		e.extend(pagesRequested)
		return e.AllocPages(pagesRequested, thread)
	} else if err != nil {
		return ret, err
	}

	// Add to the table of contents
	toc.Thread = thread
	toc.Start = startBlock
	toc.End = startBlock + pagesRequested
	err = e.addToTOC(&toc, true)

	// Compute the range
	ret = e.GetPageRange(uint64(startBlock), uint64(startBlock+pagesRequested))

	return ret, err
}
|
||||
|
||||
// GetThreadBlocks returns EdfRanges containing blocks assigned to a given thread.
func (e *EdfFile) GetThreadBlocks(thread uint32) ([]EdfRange, error) {
	var ret []EdfRange
	// Traverse the contents table, starting at block 2
	block := uint64(2)
	for {
		// Get the range for this block
		r := e.GetPageRange(block, block)
		if r.Start.Segment != r.End.Segment {
			return nil, fmt.Errorf("Contents block split across segments")
		}
		bytes := e.m[r.Start.Segment]
		bytes = bytes[r.Start.Byte : r.End.Byte+1]
		// First 8 bytes: address of the next contents block (0 = none)
		block = uint64FromBytes(bytes)
		bytes = bytes[8:]
		// Collect every 12-byte entry belonging to the thread.
		// Unlike getContiguousOffset, this loop visits every contents
		// block, not just the terminal one.
		for {
			threadID := uint32FromBytes(bytes)
			if threadID == thread {
				blockStart := uint32FromBytes(bytes[4:])
				blockEnd := uint32FromBytes(bytes[8:])
				r = e.GetPageRange(uint64(blockStart), uint64(blockEnd))
				ret = append(ret, r)
			}
			bytes = bytes[12:]
			if len(bytes) < 12 {
				break
			}
		}
		// Time to stop
		if block == 0 {
			break
		}
	}
	return ret, nil
}
|
70
base/edf/alloc_test.go
Normal file
70
base/edf/alloc_test.go
Normal file
@ -0,0 +1,70 @@
|
||||
package edf
|
||||
|
||||
import (
|
||||
. "github.com/smartystreets/goconvey/convey"
|
||||
"io/ioutil"
|
||||
"os"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestAllocFixed(t *testing.T) {
|
||||
Convey("Creating a non-existent file should succeed", t, func() {
|
||||
tempFile, err := ioutil.TempFile(os.TempDir(), "TestFileCreate")
|
||||
So(err, ShouldEqual, nil)
|
||||
Convey("Mapping the file should suceed", func() {
|
||||
mapping, err := EdfMap(tempFile, EDF_CREATE)
|
||||
So(err, ShouldEqual, nil)
|
||||
Convey("Allocation should suceed", func() {
|
||||
r, err := mapping.AllocPages(1, 2)
|
||||
So(err, ShouldEqual, nil)
|
||||
So(r.Start.Byte, ShouldEqual, 4*os.Getpagesize())
|
||||
So(r.Start.Segment, ShouldEqual, 0)
|
||||
Convey("Unmapping the file should suceed", func() {
|
||||
err = mapping.Unmap(EDF_UNMAP_SYNC)
|
||||
So(err, ShouldEqual, nil)
|
||||
Convey("Remapping the file should suceed", func() {
|
||||
mapping, err = EdfMap(tempFile, EDF_READ_ONLY)
|
||||
Convey("Should get the same allocations back", func() {
|
||||
rr, err := mapping.GetThreadBlocks(2)
|
||||
So(err, ShouldEqual, nil)
|
||||
So(len(rr), ShouldEqual, 1)
|
||||
So(rr[0], ShouldResemble, r)
|
||||
})
|
||||
})
|
||||
})
|
||||
})
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
func TestAllocWithExtraContentsBlock(t *testing.T) {
|
||||
Convey("Creating a non-existent file should succeed", t, func() {
|
||||
tempFile, err := ioutil.TempFile(os.TempDir(), "TestFileCreate")
|
||||
So(err, ShouldEqual, nil)
|
||||
Convey("Mapping the file should suceed", func() {
|
||||
mapping, err := EdfMap(tempFile, EDF_CREATE)
|
||||
So(err, ShouldEqual, nil)
|
||||
Convey("Allocation of 10 pages should suceed", func() {
|
||||
allocated := make([]EdfRange, 10)
|
||||
for i := 0; i < 10; i++ {
|
||||
r, err := mapping.AllocPages(1, 2)
|
||||
So(err, ShouldEqual, nil)
|
||||
allocated[i] = r
|
||||
}
|
||||
Convey("Unmapping the file should suceed", func() {
|
||||
err = mapping.Unmap(EDF_UNMAP_SYNC)
|
||||
So(err, ShouldEqual, nil)
|
||||
Convey("Remapping the file should suceed", func() {
|
||||
mapping, err = EdfMap(tempFile, EDF_READ_ONLY)
|
||||
Convey("Should get the same allocations back", func() {
|
||||
rr, err := mapping.GetThreadBlocks(2)
|
||||
So(err, ShouldEqual, nil)
|
||||
So(len(rr), ShouldEqual, 10)
|
||||
So(rr, ShouldResemble, allocated)
|
||||
})
|
||||
})
|
||||
})
|
||||
})
|
||||
})
|
||||
})
|
||||
}
|
40
base/edf/edf.go
Normal file
40
base/edf/edf.go
Normal file
@ -0,0 +1,40 @@
|
||||
package edf
|
||||
|
||||
// map.go: handles mmaping, truncation, header creation, verification,
|
||||
// creation of initial thread contents block (todo)
|
||||
// creation of initial thread metadata block (todo)
|
||||
// thread.go: handles extending thread contents block (todo)
|
||||
// extending thread metadata block (todo), adding threads (todo),
|
||||
// retrieving the segments and offsets relevant to a thread (todo)
|
||||
// resolution of threads by name (todo)
|
||||
// appending data to a thread (todo)
|
||||
// deleting threads (todo)
|
||||
|
||||
const (
	// EDF_VERSION is the file format version
	EDF_VERSION = 1
	// EDF_LENGTH is the number of OS pages in each slice
	EDF_LENGTH = 32
	// EDF_SIZE sets the maximum size of the mapping, represented with
	// EDF_LENGTH segments
	// Currently set arbitrarily to 128 MiB
	EDF_SIZE = 128 * (1024 * 1024)
)

// Open modes accepted by EdfMap.
const (
	// EDF_READ_ONLY means the file will only be read, modifications fail
	EDF_READ_ONLY = iota
	// EDF_READ_WRITE specifies that the file will be read and written
	EDF_READ_WRITE
	// EDF_CREATE means the file will be created and opened with EDF_READ_WRITE
	EDF_CREATE
)

// Unmap behaviours accepted by EdfFile.Unmap.
const (
	// EDF_UNMAP_NOSYNC means the file won't be
	// Sync'd to disk before unmapping
	EDF_UNMAP_NOSYNC = iota
	// EDF_UNMAP_SYNC synchronises the EDF file to disk
	// during unmapping
	EDF_UNMAP_SYNC
)
|
386
base/edf/map.go
Normal file
386
base/edf/map.go
Normal file
@ -0,0 +1,386 @@
|
||||
package edf
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
mmap "github.com/riobard/go-mmap"
|
||||
"os"
|
||||
"runtime"
|
||||
)
|
||||
|
||||
// EdfFile represents a mapped file on disk or
// an anonymous mapping for instance storage.
type EdfFile struct {
	f           *os.File    // Backing file; nil for anonymous mappings
	m           []mmap.Mmap // One mapped segment per EDF_LENGTH pages
	segmentSize uint64      // Bytes per mapped segment
	pageSize    uint64      // OS page size at mapping time
}
|
||||
|
||||
// GetPageSize returns the pageSize of an EdfFile.
func (e *EdfFile) GetPageSize() uint64 {
	return e.pageSize
}
|
||||
|
||||
// GetSegmentSize returns the segmentSize of an EdfFile.
func (e *EdfFile) GetSegmentSize() uint64 {
	return e.segmentSize
}
|
||||
|
||||
// EdfPosition represents a single location within the mapping:
// a segment index plus a byte offset within that segment.
type EdfPosition struct {
	Segment uint64
	Byte    uint64
}
|
||||
|
||||
// EdfRange represents a start and an end segment
// mapped in an EdfFile and also the byte offsets
// within that segment.
type EdfRange struct {
	Start EdfPosition
	End   EdfPosition
	// segmentSize is captured at creation so Size() works without
	// a reference back to the owning EdfFile.
	segmentSize uint64
}
|
||||
|
||||
// Size returns the size (in bytes) of a given EdfRange.
//
// NOTE(review): ranges produced by GetPageRange have an inclusive End
// byte, so this arithmetic yields one byte less than the span actually
// covers — confirm whether callers expect the inclusive or exclusive
// interpretation. (The uint64 subtraction End.Byte-Start.Byte may wrap,
// but the wrap cancels against the segment term modulo 2^64.)
func (r *EdfRange) Size() uint64 {
	ret := uint64(r.End.Segment-r.Start.Segment) * r.segmentSize
	ret += uint64(r.End.Byte - r.Start.Byte)
	return ret
}
|
||||
|
||||
// edfCallFree is a half-baked finalizer called on garbage
// collection to ensure that the mapping gets freed.
// It deliberately skips the disk sync (EDF_UNMAP_NOSYNC) since
// finalization order gives no durability guarantees anyway.
func edfCallFree(e *EdfFile) {
	e.Unmap(EDF_UNMAP_NOSYNC)
}
|
||||
|
||||
// EdfAnonMap maps the EdfFile structure into RAM.
// IMPORTANT: everything's lost if unmapped.
//
// NOTE(review): mapFlags includes MAP_FILE even though this is an
// anonymous mapping — confirm that go-mmap's AnonMap ignores or
// overrides it.
func EdfAnonMap() (*EdfFile, error) {

	var err error

	ret := new(EdfFile)

	// Figure out the flags
	protFlags := mmap.PROT_READ | mmap.PROT_WRITE
	mapFlags := mmap.MAP_FILE | mmap.MAP_SHARED
	// Create mapping references
	ret.m = make([]mmap.Mmap, 0)
	// Get the page size
	pageSize := int64(os.Getpagesize())
	// Segment size is the size of each mapped region
	ret.pageSize = uint64(pageSize)
	ret.segmentSize = uint64(EDF_LENGTH) * uint64(os.Getpagesize())

	// Map the memory in EDF_LENGTH-page segments until EDF_SIZE
	// bytes of address space are covered.
	for i := int64(0); i < EDF_SIZE; i += int64(EDF_LENGTH) * pageSize {
		thisMapping, err := mmap.AnonMap(int(ret.segmentSize), protFlags, mapFlags)
		if err != nil {
			// TODO: cleanup
			return nil, err
		}
		ret.m = append(ret.m, thisMapping)
	}

	// Generate the header
	ret.createHeader()
	err = ret.writeInitialData()

	// Make sure this gets unmapped on garbage collection
	runtime.SetFinalizer(ret, edfCallFree)

	return ret, err
}
|
||||
|
||||
// EdfMap takes an os.File and returns an EdfMappedFile
|
||||
// structure, which represents the mmap'd underlying file
|
||||
//
|
||||
// The `mode` parameter takes the following values
|
||||
// EDF_CREATE: EdfMap will truncate the file to the right length and write the correct header information
|
||||
// EDF_READ_WRITE: EdfMap will verify header information
|
||||
// EDF_READ_ONLY: EdfMap will verify header information
|
||||
// IMPORTANT: EDF_LENGTH (edf.go) controls the size of the address
|
||||
// space mapping. This means that the file can be truncated to the
|
||||
// correct size without remapping. On 32-bit systems, this
|
||||
// is set to 2GiB.
|
||||
func EdfMap(f *os.File, mode int) (*EdfFile, error) {
|
||||
var err error
|
||||
|
||||
// Set up various things
|
||||
ret := new(EdfFile)
|
||||
ret.f = f
|
||||
ret.m = make([]mmap.Mmap, 0)
|
||||
|
||||
// Figure out the flags
|
||||
protFlags := mmap.PROT_READ
|
||||
if mode == EDF_READ_WRITE || mode == EDF_CREATE {
|
||||
protFlags |= mmap.PROT_WRITE
|
||||
}
|
||||
mapFlags := mmap.MAP_FILE | mmap.MAP_SHARED
|
||||
// Get the page size
|
||||
pageSize := int64(os.Getpagesize())
|
||||
// Segment size is the size of each mapped region
|
||||
ret.pageSize = uint64(pageSize)
|
||||
ret.segmentSize = uint64(EDF_LENGTH) * uint64(os.Getpagesize())
|
||||
|
||||
// Map the file
|
||||
for i := int64(0); i < EDF_SIZE; i += int64(EDF_LENGTH) * pageSize {
|
||||
thisMapping, err := mmap.Map(f, i*pageSize, int(int64(EDF_LENGTH)*pageSize), protFlags, mapFlags)
|
||||
if err != nil {
|
||||
// TODO: cleanup
|
||||
return nil, err
|
||||
}
|
||||
ret.m = append(ret.m, thisMapping)
|
||||
}
|
||||
|
||||
// Verify or generate the header
|
||||
if mode == EDF_READ_WRITE || mode == EDF_READ_ONLY {
|
||||
err = ret.VerifyHeader()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
} else if mode == EDF_CREATE {
|
||||
err = ret.truncate(4)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
ret.createHeader()
|
||||
err = ret.writeInitialData()
|
||||
} else {
|
||||
err = fmt.Errorf("Unrecognised flags")
|
||||
}
|
||||
|
||||
// Make sure this gets unmapped on garbage collection
|
||||
runtime.SetFinalizer(ret, edfCallFree)
|
||||
|
||||
return ret, err
|
||||
|
||||
}
|
||||
|
||||
// Range returns the segment offset and range of
// two byte positions in the file.
func (e *EdfFile) Range(byteStart uint64, byteEnd uint64) EdfRange {
	var ret EdfRange
	// Split each absolute byte offset into (segment, in-segment byte)
	ret.Start.Segment = byteStart / e.segmentSize
	ret.End.Segment = byteEnd / e.segmentSize
	ret.Start.Byte = byteStart % e.segmentSize
	ret.End.Byte = byteEnd % e.segmentSize
	ret.segmentSize = e.segmentSize
	return ret
}
|
||||
|
||||
// GetPageRange returns the segment offset and range of
// two pages in the file. The end byte is inclusive (last byte of
// pageEnd).
func (e *EdfFile) GetPageRange(pageStart uint64, pageEnd uint64) EdfRange {
	return e.Range(pageStart*e.pageSize, pageEnd*e.pageSize+e.pageSize-1)
}
|
||||
|
||||
// VerifyHeader checks that this version of Golearn can
|
||||
// read the file presented.
|
||||
func (e *EdfFile) VerifyHeader() error {
|
||||
// Check the magic bytes
|
||||
diff := (e.m[0][0] ^ byte('G')) | (e.m[0][1] ^ byte('O'))
|
||||
diff |= (e.m[0][2] ^ byte('L')) | (e.m[0][3] ^ byte('N'))
|
||||
if diff != 0 {
|
||||
return fmt.Errorf("Invalid magic bytes")
|
||||
}
|
||||
// Check the file version
|
||||
version := uint32FromBytes(e.m[0][4:8])
|
||||
if version != EDF_VERSION {
|
||||
return fmt.Errorf("Unsupported version: %d", version)
|
||||
}
|
||||
// Check the page size
|
||||
pageSize := uint32FromBytes(e.m[0][8:12])
|
||||
if pageSize != uint32(os.Getpagesize()) {
|
||||
return fmt.Errorf("Unsupported page size: (file: %d, system: %d", pageSize, os.Getpagesize())
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// createHeader writes a valid header file into the file.
|
||||
// Unexported since it can cause data loss.
|
||||
func (e *EdfFile) createHeader() {
|
||||
e.m[0][0] = byte('G')
|
||||
e.m[0][1] = byte('O')
|
||||
e.m[0][2] = byte('L')
|
||||
e.m[0][3] = byte('N')
|
||||
uint32ToBytes(EDF_VERSION, e.m[0][4:8])
|
||||
uint32ToBytes(uint32(os.Getpagesize()), e.m[0][8:12])
|
||||
e.Sync()
|
||||
}
|
||||
|
||||
// writeInitialData writes system thread information
|
||||
func (e *EdfFile) writeInitialData() error {
|
||||
var t Thread
|
||||
t.name = "SYSTEM"
|
||||
t.id = 1
|
||||
err := e.WriteThread(&t)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
t.name = "FIXED"
|
||||
t.id = 2
|
||||
err = e.WriteThread(&t)
|
||||
return err
|
||||
}
|
||||
|
||||
// GetThreadCount returns the number of threads in this file.
|
||||
func (e *EdfFile) GetThreadCount() uint32 {
|
||||
// The number of threads is stored in bytes 12-16 in the header
|
||||
return uint32FromBytes(e.m[0][12:])
|
||||
}
|
||||
|
||||
// incrementThreadCount increments the record of the number
|
||||
// of threads in this file
|
||||
func (e *EdfFile) incrementThreadCount() uint32 {
|
||||
cur := e.GetThreadCount()
|
||||
cur++
|
||||
uint32ToBytes(cur, e.m[0][12:])
|
||||
return cur
|
||||
}
|
||||
|
||||
// GetThreads returns the thread identifier -> name map.
//
// Threads are stored in a chain of blocks beginning at page 1.
// The first 8 bytes of each block hold the page number of the next
// block (0 terminates the chain); the rest is a run of serialised
// Thread records (see Thread.Deserialize), ended by a record whose
// 4-byte length prefix is zero.
func (e *EdfFile) GetThreads() (map[uint32]string, error) {
	ret := make(map[uint32]string)
	count := e.GetThreadCount()
	// The starting block
	block := uint64(1)
	for {
		// Decode the block offset into a mapped segment + byte range
		r := e.GetPageRange(block, block)
		if r.Start.Segment != r.End.Segment {
			return nil, fmt.Errorf("Thread range split across segments")
		}
		bytes := e.m[r.Start.Segment]
		bytes = bytes[r.Start.Byte : r.End.Byte+1]
		// The first 8 bytes say where to go next
		block = uint64FromBytes(bytes)
		bytes = bytes[8:]

		for {
			length := uint32FromBytes(bytes)
			if length == 0 {
				// Zero length prefix: no more records in this block
				break
			}
			t := &Thread{}
			size := t.Deserialize(bytes)
			bytes = bytes[size:]
			ret[t.id] = t.name[0:len(t.name)]
		}
		// If next block offset is zero, no more threads to read
		if block == 0 {
			break
		}
	}
	// A count that disagrees with the header indicates corruption
	if len(ret) != int(count) {
		return ret, fmt.Errorf("Thread mismatch: %d/%d, indicates possible corruption", len(ret), count)
	}
	return ret, nil
}
|
||||
|
||||
// Sync writes information to physical storage.
|
||||
func (e *EdfFile) Sync() error {
|
||||
for _, m := range e.m {
|
||||
err := m.Sync(mmap.MS_SYNC)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// truncate changes the size of the underlying file
|
||||
// The size of the address space doesn't change.
|
||||
func (e *EdfFile) truncate(size int64) error {
|
||||
pageSize := int64(os.Getpagesize())
|
||||
newSize := pageSize * size
|
||||
|
||||
// Synchronise
|
||||
// e.Sync()
|
||||
|
||||
// Double-check that we're not reducing file size
|
||||
fileInfo, err := e.f.Stat()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if fileInfo.Size() > newSize {
|
||||
return fmt.Errorf("Can't reduce file size!")
|
||||
}
|
||||
|
||||
// Truncate the file
|
||||
err = e.f.Truncate(newSize)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Verify that the file is larger now than it was
|
||||
fileInfo, err = e.f.Stat()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if fileInfo.Size() != newSize {
|
||||
return fmt.Errorf("Truncation failed: %d, %d", fileInfo.Size(), newSize)
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
// Unmap unlinks the EdfFile from the address space.
|
||||
// EDF_UNMAP_NOSYNC skips calling Sync() on the underlying
|
||||
// file before this happens.
|
||||
// IMPORTANT: attempts to use this mapping after Unmap() is
|
||||
// called will result in crashes.
|
||||
func (e *EdfFile) Unmap(flags int) error {
|
||||
// Sync the file
|
||||
e.Sync()
|
||||
if flags != EDF_UNMAP_NOSYNC {
|
||||
e.Sync()
|
||||
}
|
||||
// Unmap the file
|
||||
for _, m := range e.m {
|
||||
err := m.Unmap()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// ResolveRange returns a slice of byte slices representing
|
||||
// the underlying memory referenced by EdfRange.
|
||||
//
|
||||
// WARNING: slow.
|
||||
func (e *EdfFile) ResolveRange(r EdfRange) [][]byte {
|
||||
var ret [][]byte
|
||||
segCounter := 0
|
||||
for segment := r.Start.Segment; segment <= r.End.Segment; segment++ {
|
||||
if segment == r.Start.Segment {
|
||||
ret = append(ret, e.m[segment][r.Start.Byte:])
|
||||
} else if segment == r.End.Segment {
|
||||
ret = append(ret, e.m[segment][:r.End.Byte+1])
|
||||
} else {
|
||||
ret = append(ret, e.m[segment])
|
||||
}
|
||||
segCounter++
|
||||
}
|
||||
return ret
|
||||
}
|
||||
|
||||
// IResolveRange returns a byte slice representing the current EdfRange
|
||||
// and returns a value saying whether there's more. Subsequent calls to IncrementallyResolveRange
|
||||
// should use the value returned by the previous one until no more ranges are available.
|
||||
func (e *EdfFile) IResolveRange(r EdfRange, prev uint64) ([]byte, uint64) {
|
||||
segment := r.Start.Segment + prev
|
||||
if segment > r.End.Segment {
|
||||
return nil, 0
|
||||
}
|
||||
if segment == r.Start.Segment {
|
||||
return e.m[segment][r.Start.Byte:], prev + 1
|
||||
}
|
||||
if segment == r.End.Segment {
|
||||
return e.m[segment][:r.End.Byte+1], 0
|
||||
}
|
||||
return e.m[segment], prev + 1
|
||||
}
|
118
base/edf/map_test.go
Normal file
118
base/edf/map_test.go
Normal file
@ -0,0 +1,118 @@
|
||||
package edf
|
||||
|
||||
import (
|
||||
. "github.com/smartystreets/goconvey/convey"
|
||||
"io/ioutil"
|
||||
"os"
|
||||
"testing"
|
||||
)
|
||||
|
||||
// TestAnonMap verifies that an anonymous EDF mapping is created with
// a valid header: magic bytes "GOLN", the current EDF_VERSION, and
// the system page size.
// NOTE(review): "suceed" typos in the Convey descriptions are runtime
// strings and are deliberately left untouched here.
func TestAnonMap(t *testing.T) {
	Convey("Anonymous mapping should suceed", t, func() {
		mapping, err := EdfAnonMap()
		So(err, ShouldEqual, nil)
		bytes := mapping.m[0]
		// Read the magic bytes
		magic := bytes[0:4]
		Convey("Magic bytes should be correct", func() {
			So(magic[0], ShouldEqual, byte('G'))
			So(magic[1], ShouldEqual, byte('O'))
			So(magic[2], ShouldEqual, byte('L'))
			So(magic[3], ShouldEqual, byte('N'))
		})
		// Read the file version
		versionBytes := bytes[4:8]
		Convey("Version should be correct", func() {
			version := uint32FromBytes(versionBytes)
			So(version, ShouldEqual, EDF_VERSION)
		})
		// Read the block size
		blockBytes := bytes[8:12]
		Convey("Page size should be correct", func() {
			pageSize := uint32FromBytes(blockBytes)
			So(pageSize, ShouldEqual, os.Getpagesize())
		})
	})
}
|
||||
|
||||
// TestFileCreate maps a fresh temporary file with EDF_CREATE and
// verifies the header written to disk (magic bytes, version, page
// size) and that the file was truncated to at least four pages.
// NOTE(review): the temporary file is never removed — consider a
// deferred os.Remove.
func TestFileCreate(t *testing.T) {
	Convey("Creating a non-existent file should succeed", t, func() {
		tempFile, err := ioutil.TempFile(os.TempDir(), "TestFileCreate")
		So(err, ShouldEqual, nil)
		Convey("Mapping the file should suceed", func() {
			mapping, err := EdfMap(tempFile, EDF_CREATE)
			So(err, ShouldEqual, nil)
			Convey("Unmapping the file should suceed", func() {
				err = mapping.Unmap(EDF_UNMAP_SYNC)
				So(err, ShouldEqual, nil)
			})

			// Read the magic bytes straight from the file (not the mapping)
			magic := make([]byte, 4)
			read, err := tempFile.ReadAt(magic, 0)
			Convey("Magic bytes should be correct", func() {
				So(err, ShouldEqual, nil)
				So(read, ShouldEqual, 4)
				So(magic[0], ShouldEqual, byte('G'))
				So(magic[1], ShouldEqual, byte('O'))
				So(magic[2], ShouldEqual, byte('L'))
				So(magic[3], ShouldEqual, byte('N'))
			})
			// Read the file version
			versionBytes := make([]byte, 4)
			read, err = tempFile.ReadAt(versionBytes, 4)
			Convey("Version should be correct", func() {
				So(err, ShouldEqual, nil)
				So(read, ShouldEqual, 4)
				version := uint32FromBytes(versionBytes)
				So(version, ShouldEqual, EDF_VERSION)
			})
			// Read the block size
			blockBytes := make([]byte, 4)
			read, err = tempFile.ReadAt(blockBytes, 8)
			Convey("Page size should be correct", func() {
				So(err, ShouldEqual, nil)
				So(read, ShouldEqual, 4)
				pageSize := uint32FromBytes(blockBytes)
				So(pageSize, ShouldEqual, os.Getpagesize())
			})
			// Check the file size is at least four * page size
			info, err := tempFile.Stat()
			Convey("File should be the right size", func() {
				So(err, ShouldEqual, nil)
				So(info.Size(), ShouldBeGreaterThanOrEqualTo, 4*os.Getpagesize())
			})
		})
	})
}
|
||||
|
||||
// TestFileThreadCounter checks that a freshly created file contains
// exactly the two built-in threads (SYSTEM, FIXED), and that bumping
// the header's thread counter without writing a record makes
// GetThreads report corruption.
func TestFileThreadCounter(t *testing.T) {
	Convey("Creating a non-existent file should succeed", t, func() {
		tempFile, err := ioutil.TempFile(os.TempDir(), "TestFileCreate")
		So(err, ShouldEqual, nil)
		Convey("Mapping the file should suceed", func() {
			mapping, err := EdfMap(tempFile, EDF_CREATE)
			So(err, ShouldEqual, nil)
			Convey("The file should have two threads to start with", func() {
				count := mapping.GetThreadCount()
				So(count, ShouldEqual, 2)
				Convey("They should be SYSTEM and FIXED", func() {
					threads, err := mapping.GetThreads()
					So(err, ShouldEqual, nil)
					So(len(threads), ShouldEqual, 2)
					So(threads[1], ShouldEqual, "SYSTEM")
					So(threads[2], ShouldEqual, "FIXED")
				})
			})
			Convey("Incrementing the threadcount should result in three threads", func() {
				mapping.incrementThreadCount()
				count := mapping.GetThreadCount()
				So(count, ShouldEqual, 3)
				// No third record exists, so the counter now disagrees
				// with the on-disk thread list.
				Convey("Thread information should indicate corruption", func() {
					_, err := mapping.GetThreads()
					So(err, ShouldNotEqual, nil)
				})
			})
		})
	})
}
|
137
base/edf/thread.go
Normal file
137
base/edf/thread.go
Normal file
@ -0,0 +1,137 @@
|
||||
package edf
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
)
|
||||
|
||||
// Threads are streams of data encapsulated within the file.
type Thread struct {
	name string // human-readable thread name, serialized with a 4-byte length prefix
	id   uint32 // thread identifier (built-in threads use 1 and 2 — see writeInitialData)
}
|
||||
|
||||
// NewThread returns a new thread.
|
||||
func NewThread(e *EdfFile, name string) *Thread {
|
||||
return &Thread{name, e.GetThreadCount() + 1}
|
||||
}
|
||||
|
||||
// GetSpaceNeeded the number of bytes needed to serialize this
|
||||
// Thread.
|
||||
func (t *Thread) GetSpaceNeeded() int {
|
||||
return 8 + len(t.name)
|
||||
}
|
||||
|
||||
// Serialize copies this thread to the output byte slice
|
||||
// Returns the number of bytes used.
|
||||
func (t *Thread) Serialize(out []byte) int {
|
||||
// ret keeps track of written bytes
|
||||
ret := 0
|
||||
// Write the length of the name first
|
||||
nameLength := len(t.name)
|
||||
uint32ToBytes(uint32(nameLength), out)
|
||||
out = out[4:]
|
||||
ret += 4
|
||||
// Then write the string
|
||||
copy(out, t.name)
|
||||
out = out[nameLength:]
|
||||
ret += nameLength
|
||||
// Then the thread number
|
||||
uint32ToBytes(t.id, out)
|
||||
ret += 4
|
||||
return ret
|
||||
}
|
||||
|
||||
// Deserialize copies the input byte slice into a thread.
|
||||
func (t *Thread) Deserialize(out []byte) int {
|
||||
ret := 0
|
||||
// Read the length of the thread's name
|
||||
nameLength := uint32FromBytes(out)
|
||||
ret += 4
|
||||
out = out[4:]
|
||||
// Copy out the string
|
||||
t.name = string(out[:nameLength])
|
||||
ret += int(nameLength)
|
||||
out = out[nameLength:]
|
||||
// Read the identifier
|
||||
t.id = uint32FromBytes(out)
|
||||
ret += 4
|
||||
return ret
|
||||
}
|
||||
|
||||
// FindThread obtains the index of a thread in the EdfFile, scanning
// the thread records in page 1 for one whose name matches targetName.
// Returns the 1-based position of the match, or an error if no
// record matches.
func (e *EdfFile) FindThread(targetName string) (uint32, error) {

	var offset uint32
	var counter uint32
	// Resolve the initial thread block
	blockRange := e.GetPageRange(1, 1)
	if blockRange.Start.Segment != blockRange.End.Segment {
		return 0, fmt.Errorf("Thread block split across segments!")
	}
	// NOTE(review): this slice ends at End.Byte (exclusive), whereas
	// GetThreads uses End.Byte+1 — confirm the off-by-one is intended.
	bytes := e.m[blockRange.Start.Segment][blockRange.Start.Byte:blockRange.End.Byte]
	// Skip the first 8 bytes, since we don't support multiple thread blocks yet
	// TODO: fix that
	bytes = bytes[8:]
	counter = 1
	for {
		// Each record is [4-byte length][name][4-byte id];
		// a zero length prefix terminates the record list.
		length := uint32FromBytes(bytes)
		if length == 0 {
			return 0, fmt.Errorf("No matching threads")
		}

		name := string(bytes[4 : 4+length])
		if name == targetName {
			offset = counter
			break
		}

		// Advance past this record (8 bytes of metadata + the name)
		bytes = bytes[8+length:]
		counter++
	}

	return offset, nil
}
|
||||
|
||||
// WriteThread inserts a new thread into the EdfFile, appending its
// record after the existing ones in the thread block (page 1) and
// incrementing the header's thread counter.
// Returns an error if a thread of the same name exists or there is
// no room left in the block.
func (e *EdfFile) WriteThread(t *Thread) error {

	// Refuse duplicates by name (FindThread's error is deliberately
	// ignored: "not found" is the success path here)
	offset, _ := e.FindThread(t.name)
	if offset != 0 {
		return fmt.Errorf("Writing a duplicate thread")
	}
	// Resolve the initial Thread block
	blockRange := e.GetPageRange(1, 1)
	if blockRange.Start.Segment != blockRange.End.Segment {
		return fmt.Errorf("Thread block split across segments!")
	}
	bytes := e.m[blockRange.Start.Segment][blockRange.Start.Byte:blockRange.End.Byte]
	// Skip the first 8 bytes, since we don't support multiple thread blocks yet
	// TODO: fix that
	bytes = bytes[8:]
	cur := 0
	// Walk records ([4-byte length][name][4-byte id]) until a zero
	// length prefix marks the first free slot
	for {
		length := uint32FromBytes(bytes)
		if length == 0 {
			break
		}
		cur += 8 + int(length)
		bytes = bytes[8+length:]
	}
	// cur should have now found an empty offset
	// Check that we have enough room left to insert
	roomLeft := len(bytes)
	roomNeeded := t.GetSpaceNeeded()
	if roomLeft < roomNeeded {
		return fmt.Errorf("Not enough space available")
	}
	// If everything's fine, serialise
	t.Serialize(bytes)
	// Increment thread count
	e.incrementThreadCount()
	return nil
}
|
||||
|
||||
// GetId returns this Thread's identifier.
// (Go convention would be ID/GetID, but the method is exported and
// renaming it would break callers.)
func (t *Thread) GetId() uint32 {
	return t.id
}
|
59
base/edf/thread_test.go
Normal file
59
base/edf/thread_test.go
Normal file
@ -0,0 +1,59 @@
|
||||
package edf
|
||||
|
||||
import (
|
||||
. "github.com/smartystreets/goconvey/convey"
|
||||
"testing"
|
||||
"os"
|
||||
)
|
||||
|
||||
// TestThreadDeserialize decodes a hand-built record
// ([len=6]["SYSTEM"][id=1]) and checks the name and consumed size.
func TestThreadDeserialize(T *testing.T) {
	bytes := []byte{0, 0, 0, 6, 83, 89, 83, 84, 69, 77, 0, 0, 0, 1}
	Convey("Given a byte slice", T, func() {
		var t Thread
		size := t.Deserialize(bytes)
		Convey("Decoded name should be SYSTEM", func() {
			So(t.name, ShouldEqual, "SYSTEM")
		})
		Convey("Size should be the same as the array", func() {
			So(size, ShouldEqual, len(bytes))
		})
	})
}
|
||||
|
||||
// TestThreadSerialize encodes a SYSTEM/id-1 thread and compares the
// output byte-for-byte with the reference record.
func TestThreadSerialize(T *testing.T) {
	var t Thread
	refBytes := []byte{0, 0, 0, 6, 83, 89, 83, 84, 69, 77, 0, 0, 0, 1}
	t.name = "SYSTEM"
	t.id = 1
	toBytes := make([]byte, len(refBytes))
	Convey("Should serialize correctly", T, func() {
		t.Serialize(toBytes)
		So(toBytes, ShouldResemble, refBytes)
	})
}
|
||||
|
||||
// TestThreadFindAndWrite creates a file, writes a third thread, and
// checks FindThread can locate it again.
// NOTE(review): this writes "hello.db" into the working directory and
// never removes it — the commented-out TempFile call suggests it
// should be restored.
func TestThreadFindAndWrite(T *testing.T) {
	Convey("Creating a non-existent file should succeed", T, func() {
		tempFile, err := os.OpenFile("hello.db", os.O_RDWR|os.O_TRUNC|os.O_CREATE, 0700) //ioutil.TempFile(os.TempDir(), "TestFileCreate")
		So(err, ShouldEqual, nil)
		Convey("Mapping the file should suceed", func() {
			mapping, err := EdfMap(tempFile, EDF_CREATE)
			So(err, ShouldEqual, nil)
			Convey("Writing the thread should succeed", func() {
				t := NewThread(mapping, "MyNameISWhat")
				// SYSTEM and FIXED already occupy ids 1 and 2
				Convey("Thread number should be 3", func() {
					So(t.id, ShouldEqual, 3)
				})
				Convey("Writing the thread should succeed", func() {
					err := mapping.WriteThread(t)
					So(err, ShouldEqual, nil)
					Convey("Should be able to find the thread again later", func() {
						id, err := mapping.FindThread("MyNameISWhat")
						So(err, ShouldEqual, nil)
						So(id, ShouldEqual, 3)
					})
				})
			})
		})
	})
}
|
32
base/edf/util.go
Normal file
32
base/edf/util.go
Normal file
@ -0,0 +1,32 @@
|
||||
package edf
|
||||
|
||||
// uint64ToBytes writes in to out as 8 big-endian bytes
// (most significant byte first).
func uint64ToBytes(in uint64, out []byte) {
	for i := 0; i < 8; i++ {
		out[i] = byte(in >> uint(56-8*i))
	}
}
|
||||
|
||||
// uint64FromBytes decodes an 8-byte big-endian value from in.
func uint64FromBytes(in []byte) uint64 {
	var out uint64
	for i := 0; i < 8; i++ {
		// BUG FIX: the byte must be widened to uint64 BEFORE shifting.
		// Previously `uint64(in[7-i] << (i*8))` shifted the 8-bit byte
		// first, so every shift of 8 or more discarded the byte and the
		// function only ever returned the lowest byte's value.
		out |= uint64(in[i]) << uint(56-8*i)
	}
	return out
}
|
||||
// uint32FromBytes decodes a 4-byte big-endian value from in.
func uint32FromBytes(in []byte) uint32 {
	return uint32(in[0])<<24 | uint32(in[1])<<16 | uint32(in[2])<<8 | uint32(in[3])
}
|
||||
|
||||
// uint32ToBytes writes in to out as 4 big-endian bytes
// (most significant byte first).
func uint32ToBytes(in uint32, out []byte) {
	out[0] = byte(in >> 24)
	out[1] = byte(in >> 16)
	out[2] = byte(in >> 8)
	out[3] = byte(in)
}
|
20
base/edf/util_test.go
Normal file
20
base/edf/util_test.go
Normal file
@ -0,0 +1,20 @@
|
||||
package edf
|
||||
|
||||
// Utility function tests
|
||||
|
||||
import (
|
||||
. "github.com/smartystreets/goconvey/convey"
|
||||
"testing"
|
||||
)
|
||||
|
||||
// TestInt32Conversion round-trips a uint32 through uint32ToBytes and
// uint32FromBytes and checks the value survives.
func TestInt32Conversion(t *testing.T) {
	Convey("Given deadbeef", t, func() {
		buf := make([]byte, 4)
		original := uint32(0xDEAD)
		uint32ToBytes(original, buf)
		converted := uint32FromBytes(buf)
		Convey("Decoded value should be the original...", func() {
			So(converted, ShouldEqual, original)
		})
	})
}
|
258
base/filtered.go
Normal file
258
base/filtered.go
Normal file
@ -0,0 +1,258 @@
|
||||
package base
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"fmt"
|
||||
)
|
||||
|
||||
// TODO: maybe include a TransformedAttribute struct
// so we can map from ClassAttribute to ClassAttribute

// LazilyFilteredInstances map a Filter over an underlying
// FixedDataGrid and are a memory-efficient way of applying them:
// values are transformed on access rather than copied up front.
type LazilyFilteredInstances struct {
	filter        Filter             // the Filter applied on access
	src           FixedDataGrid      // the underlying, unmodified data
	attrs         []FilteredAttribute // old -> new Attribute mappings produced by the filter
	classAttrs    map[Attribute]bool // which (post-filter) Attributes are class Attributes
	unfilteredMap map[Attribute]bool // true for source Attributes the filter passes through untouched
}
|
||||
|
||||
// NewLazilyFilteredInstances returns a new FixedDataGrid after
// applying the given Filter to the Attributes it includes. Unfiltered
// Attributes are passed through without modification.
func NewLazilyFilteredInstances(src FixedDataGrid, f Filter) *LazilyFilteredInstances {

	// Get the Attributes after filtering
	attrs := f.GetAttributesAfterFiltering()

	// Build a map marking which Attributes have undergone filtering:
	// start with everything pass-through (true), then mark filtered
	// Attributes false
	unFilteredMap := make(map[Attribute]bool)
	for _, a := range src.AllAttributes() {
		unFilteredMap[a] = true
	}
	for _, a := range attrs {
		unFilteredMap[a.Old] = false
	}

	// Create the return structure
	ret := &LazilyFilteredInstances{
		f,
		src,
		attrs,
		make(map[Attribute]bool),
		unFilteredMap,
	}

	// Transfer class Attributes
	// NOTE(review): AddClassAttribute's error is discarded here —
	// a source class Attribute that fails to resolve is silently dropped
	for _, a := range src.AllClassAttributes() {
		ret.AddClassAttribute(a)
	}
	return ret
}
|
||||
|
||||
// GetAttribute returns an AttributeSpec for a given Attribute,
// either delegating to the source grid (pass-through Attributes)
// or locating the filtered Attribute by position.
func (l *LazilyFilteredInstances) GetAttribute(target Attribute) (AttributeSpec, error) {
	// Pass-through Attributes resolve against the source directly
	if l.unfilteredMap[target] {
		return l.src.GetAttribute(target)
	}
	var ret AttributeSpec
	// -1 marks the spec as not belonging to any storage pond
	ret.pond = -1
	for i, a := range l.attrs {
		if a.New.Equals(target) {
			ret.position = i
			ret.attr = target
			return ret, nil
		}
	}
	return ret, fmt.Errorf("Couldn't resolve %s", target)
}
|
||||
|
||||
// AllAttributes returns every Attribute defined in the source datagrid,
|
||||
// in addition to the revised Attributes created by the filter.
|
||||
func (l *LazilyFilteredInstances) AllAttributes() []Attribute {
|
||||
ret := make([]Attribute, 0)
|
||||
for _, a := range l.src.AllAttributes() {
|
||||
if l.unfilteredMap[a] {
|
||||
ret = append(ret, a)
|
||||
} else {
|
||||
for _, b := range l.attrs {
|
||||
if a.Equals(b.Old) {
|
||||
ret = append(ret, b.New)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
return ret
|
||||
}
|
||||
|
||||
// AddClassAttribute adds a given Attribute (either before or after filtering)
|
||||
// to the set of defined class Attributes.
|
||||
func (l *LazilyFilteredInstances) AddClassAttribute(cls Attribute) error {
|
||||
if l.unfilteredMap[cls] {
|
||||
l.classAttrs[cls] = true
|
||||
return nil
|
||||
}
|
||||
for _, a := range l.attrs {
|
||||
if a.Old.Equals(cls) || a.New.Equals(cls) {
|
||||
l.classAttrs[a.New] = true
|
||||
return nil
|
||||
}
|
||||
}
|
||||
return fmt.Errorf("Attribute %s could not be resolved", cls)
|
||||
}
|
||||
|
||||
// RemoveClassAttribute removes a given Attribute (either before or
|
||||
// after filtering) from the set of defined class Attributes.
|
||||
func (l *LazilyFilteredInstances) RemoveClassAttribute(cls Attribute) error {
|
||||
if l.unfilteredMap[cls] {
|
||||
l.classAttrs[cls] = false
|
||||
return nil
|
||||
}
|
||||
for _, a := range l.attrs {
|
||||
if a.Old.Equals(cls) || a.New.Equals(cls) {
|
||||
l.classAttrs[a.New] = false
|
||||
return nil
|
||||
}
|
||||
}
|
||||
return fmt.Errorf("Attribute %s could not be resolved", cls)
|
||||
}
|
||||
|
||||
// AllClassAttributes returns details of all Attributes currently specified
|
||||
// as being class Attributes.
|
||||
//
|
||||
// If applicable, the Attributes returned are those after modification
|
||||
// by the Filter.
|
||||
func (l *LazilyFilteredInstances) AllClassAttributes() []Attribute {
|
||||
ret := make([]Attribute, 0)
|
||||
for a := range l.classAttrs {
|
||||
if l.classAttrs[a] {
|
||||
ret = append(ret, a)
|
||||
}
|
||||
}
|
||||
return ret
|
||||
}
|
||||
|
||||
// transformNewToOldAttribute maps an AttributeSpec expressed in
// post-filter terms back to the corresponding spec on the source grid.
// Pass-through specs are returned unchanged.
func (l *LazilyFilteredInstances) transformNewToOldAttribute(as AttributeSpec) (AttributeSpec, error) {
	// Pass-through Attributes need no translation
	if l.unfilteredMap[as.GetAttribute()] {
		return as, nil
	}
	// Find the mapping whose Old or New side matches, then resolve
	// the Old Attribute against the source grid
	for _, a := range l.attrs {
		if a.Old.Equals(as.attr) || a.New.Equals(as.attr) {
			as, err := l.src.GetAttribute(a.Old)
			if err != nil {
				return AttributeSpec{}, fmt.Errorf("Internal error in Attribute resolution: '%s'", err)
			}
			return as, nil
		}
	}
	return AttributeSpec{}, fmt.Errorf("No matching Attribute")
}
|
||||
|
||||
// Get returns a transformed byte slice stored at a given AttributeSpec
// and row: the raw value is fetched from the source grid, then run
// through the filter unless the Attribute is a pass-through.
//
// Panics if the AttributeSpec cannot be mapped back to the source.
func (l *LazilyFilteredInstances) Get(as AttributeSpec, row int) []byte {
	asOld, err := l.transformNewToOldAttribute(as)
	if err != nil {
		panic(fmt.Sprintf("Attribute %s could not be resolved. (Error: %s)", as, err))
	}
	byteSeq := l.src.Get(asOld, row)
	// Pass-through Attributes are returned untransformed
	if l.unfilteredMap[as.attr] {
		return byteSeq
	}
	newByteSeq := l.filter.Transform(asOld.attr, as.attr, byteSeq)
	return newByteSeq
}
|
||||
|
||||
// MapOverRows maps an iteration mapFunc over the bytes contained in
// the source FixedDataGrid, after modification by the filter.
//
// NOTE(review): newRowBuf is reused across iterations, so mapFunc must
// not retain the row slice between calls.
func (l *LazilyFilteredInstances) MapOverRows(asv []AttributeSpec, mapFunc func([][]byte, int) (bool, error)) error {

	// Have to transform each item of asv into an
	// AttributeSpec in the original
	oldAsv := make([]AttributeSpec, len(asv))
	for i, a := range asv {
		old, err := l.transformNewToOldAttribute(a)
		if err != nil {
			return fmt.Errorf("Couldn't fetch old Attribute: '%s'", a)
		}
		oldAsv[i] = old
	}

	// Then map over each row in the original, transforming every
	// column value through the filter before handing it to mapFunc
	newRowBuf := make([][]byte, len(asv))
	return l.src.MapOverRows(oldAsv, func(oldRow [][]byte, oldRowNo int) (bool, error) {
		for i, b := range oldRow {
			newField := l.filter.Transform(oldAsv[i].attr, asv[i].attr, b)
			newRowBuf[i] = newField
		}
		return mapFunc(newRowBuf, oldRowNo)
	})
}
|
||||
|
||||
// RowString returns a string representation of a given row
|
||||
// after filtering.
|
||||
func (l *LazilyFilteredInstances) RowString(row int) string {
|
||||
var buffer bytes.Buffer
|
||||
|
||||
as := ResolveAllAttributes(l) // Retrieve all Attribute data
|
||||
first := true // Decide whether to prefix
|
||||
|
||||
for _, a := range as {
|
||||
prefix := " " // What to print before value
|
||||
if first {
|
||||
first = false // Don't print space on first value
|
||||
prefix = ""
|
||||
}
|
||||
val := l.Get(a, row) // Retrieve filtered value
|
||||
buffer.WriteString(fmt.Sprintf("%s%s", prefix, a.attr.GetStringFromSysVal(val)))
|
||||
}
|
||||
|
||||
return buffer.String() // Return the result
|
||||
}
|
||||
|
||||
// Size returns the number of Attributes and rows of the underlying
// FixedDataGrid (filtering changes values, not dimensions).
func (l *LazilyFilteredInstances) Size() (int, int) {
	return l.src.Size()
}
|
||||
|
||||
// String returns a human-readable summary of this FixedDataGrid
|
||||
// after filtering.
|
||||
func (l *LazilyFilteredInstances) String() string {
|
||||
var buffer bytes.Buffer
|
||||
|
||||
// Decide on rows to print
|
||||
_, rows := l.Size()
|
||||
maxRows := 5
|
||||
if rows < maxRows {
|
||||
maxRows = rows
|
||||
}
|
||||
|
||||
// Get all Attribute information
|
||||
as := ResolveAllAttributes(l)
|
||||
|
||||
// Print header
|
||||
buffer.WriteString("Lazily filtered instances using ")
|
||||
buffer.WriteString(fmt.Sprintf("%s\n", l.filter))
|
||||
buffer.WriteString(fmt.Sprintf("Attributes: \n"))
|
||||
|
||||
for _, a := range as {
|
||||
prefix := "\t"
|
||||
if l.classAttrs[a.attr] {
|
||||
prefix = "*\t"
|
||||
}
|
||||
buffer.WriteString(fmt.Sprintf("%s%s\n", prefix, a.attr))
|
||||
}
|
||||
|
||||
buffer.WriteString("\nData:\n")
|
||||
for i := 0; i < maxRows; i++ {
|
||||
buffer.WriteString("\t")
|
||||
for _, a := range as {
|
||||
val := l.Get(a, i)
|
||||
buffer.WriteString(fmt.Sprintf("%s ", a.attr.GetStringFromSysVal(val)))
|
||||
}
|
||||
buffer.WriteString("\n")
|
||||
}
|
||||
|
||||
return buffer.String()
|
||||
}
|
23
base/filters.go
Normal file
23
base/filters.go
Normal file
@ -0,0 +1,23 @@
|
||||
package base
|
||||
|
||||
// FilteredAttribute represents a mapping from the Attribute
// generated by a filter (New) back to the original value (Old).
type FilteredAttribute struct {
	Old Attribute // the Attribute as it exists in the source data
	New Attribute // the replacement Attribute produced by the filter
}
|
||||
|
||||
// Filters transform the byte sequences stored in DataGrid
// implementations.
type Filter interface {
	// AddAttribute adds an Attribute to the filter.
	AddAttribute(Attribute) error
	// GetAttributesAfterFiltering maps old to new Attributes.
	GetAttributesAfterFiltering() []FilteredAttribute
	// String gets a human-readable description for printing.
	String() string
	// Transform accepts an old Attribute, the new one and a byte
	// sequence in the old representation; it returns the byte
	// sequence in the new representation.
	Transform(Attribute, Attribute, []byte) []byte
	// Train builds the filter (must be called before Transform).
	Train() error
}
|
103
base/float.go
Normal file
103
base/float.go
Normal file
@ -0,0 +1,103 @@
|
||||
package base
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strconv"
|
||||
)
|
||||
|
||||
// FloatAttribute is an implementation which stores floating point
// representations of numbers.
type FloatAttribute struct {
	Name      string // human-readable Attribute name
	Precision int    // decimal places used by GetStringFromSysVal
}
|
||||
|
||||
// NewFloatAttribute returns a new FloatAttribute with a default
|
||||
// precision of 2 decimal places
|
||||
func NewFloatAttribute() *FloatAttribute {
|
||||
return &FloatAttribute{"", 2}
|
||||
}
|
||||
|
||||
// Compatable checks whether this FloatAttribute can be ponded with another
// Attribute (checks if they're both FloatAttributes).
// NOTE(review): "Compatable" misspells "Compatible", but the method is
// exported so renaming it would break callers.
func (Attr *FloatAttribute) Compatable(other Attribute) bool {
	_, ok := other.(*FloatAttribute)
	return ok
}
|
||||
|
||||
// Equals tests a FloatAttribute for equality with another Attribute.
|
||||
//
|
||||
// Returns false if the other Attribute has a different name
|
||||
// or if the other Attribute is not a FloatAttribute.
|
||||
func (Attr *FloatAttribute) Equals(other Attribute) bool {
|
||||
// Check whether this FloatAttribute is equal to another
|
||||
_, ok := other.(*FloatAttribute)
|
||||
if !ok {
|
||||
// Not the same type, so can't be equal
|
||||
return false
|
||||
}
|
||||
if Attr.GetName() != other.GetName() {
|
||||
return false
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
// GetName returns this FloatAttribute's human-readable name.
func (Attr *FloatAttribute) GetName() string {
	return Attr.Name
}
|
||||
|
||||
// SetName sets this FloatAttribute's human-readable name.
func (Attr *FloatAttribute) SetName(name string) {
	Attr.Name = name
}
|
||||
|
||||
// GetType returns Float64Type (the general characteristic of this
// Attribute, used to avoid the overhead of type assertions).
func (Attr *FloatAttribute) GetType() int {
	return Float64Type
}
|
||||
|
||||
// String returns a human-readable summary of this Attribute.
|
||||
// e.g. "FloatAttribute(Sepal Width)"
|
||||
func (Attr *FloatAttribute) String() string {
|
||||
return fmt.Sprintf("FloatAttribute(%s)", Attr.Name)
|
||||
}
|
||||
|
||||
// CheckSysValFromString confirms whether a given rawVal can
// be converted into a valid system representation. If it can't,
// the returned byte slice is nil and the parse error is returned.
func (Attr *FloatAttribute) CheckSysValFromString(rawVal string) ([]byte, error) {
	f, err := strconv.ParseFloat(rawVal, 64)
	if err != nil {
		return nil, err
	}

	// Pack the parsed float64 into its byte-level system representation
	ret := PackFloatToBytes(f)
	return ret, nil
}
|
||||
|
||||
// GetSysValFromString parses the given rawVal string to a float64 and returns it.
|
||||
//
|
||||
// float64 happens to be a 1-to-1 mapping to the system representation.
|
||||
// IMPORTANT: This function panic()s if rawVal is not a valid float.
|
||||
// Use CheckSysValFromString to confirm.
|
||||
func (Attr *FloatAttribute) GetSysValFromString(rawVal string) []byte {
|
||||
f, err := Attr.CheckSysValFromString(rawVal)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
return f
|
||||
}
|
||||
|
||||
// GetFloatFromSysVal converts a given system value (packed bytes)
// back to a float64.
func (Attr *FloatAttribute) GetFloatFromSysVal(rawVal []byte) float64 {
	return UnpackBytesToFloat(rawVal)
}
|
||||
|
||||
// GetStringFromSysVal converts a given system value to to a string with two decimal
|
||||
// places of precision.
|
||||
func (Attr *FloatAttribute) GetStringFromSysVal(rawVal []byte) string {
|
||||
f := UnpackBytesToFloat(rawVal)
|
||||
formatString := fmt.Sprintf("%%.%df", Attr.Precision)
|
||||
return fmt.Sprintf(formatString, f)
|
||||
}
|
@ -1,561 +0,0 @@
|
||||
package base
|
||||
|
||||
import (
	"bytes"
	"encoding/binary"
	"errors"
	"fmt"
	"math"
	"math/rand"

	"github.com/gonum/matrix/mat64"
)
|
||||
|
||||
// SortDirection specifies sorting direction...
type SortDirection int

const (
	// Descending says that Instances should be sorted high to low...
	Descending SortDirection = 1
	// Ascending states that Instances should be sorted low to high...
	Ascending SortDirection = 2
)

// highBit masks the sign bit of an IEEE 754 float64 reinterpreted as
// an int64. NOTE(review): appears unused here — Sort and xorFloatOp
// use the literal -1 << 63 directly; confirm before removing.
const highBit int64 = -1 << 63
|
||||
|
||||
// Instances represents a grid of numbers (typed by Attributes)
// stored internally in mat.DenseMatrix as float64's.
// See docs/instances.md for more information.
type Instances struct {
	storage    *mat64.Dense // underlying Rows x Cols float64 matrix
	attributes []Attribute  // column metadata; attributes[i] types column i
	Rows       int
	Cols       int
	ClassIndex int // column holding the class Attribute
}
|
||||
|
||||
// xorFloatOp flips the sign bit of item's IEEE 754 representation,
// which makes the byte-wise radix sort order correct for floats.
// The original round-tripped the value through a bytes.Buffer with
// binary.Write/Read (ignoring both errors); math.Float64bits performs
// the identical reinterpretation without allocation or error paths.
func xorFloatOp(item float64) float64 {
	return math.Float64frombits(math.Float64bits(item) ^ (1 << 63))
}
|
||||
|
||||
func printFloatByteArr(arr [][]byte) {
|
||||
buf := bytes.NewBuffer(nil)
|
||||
var f float64
|
||||
for _, b := range arr {
|
||||
buf.Write(b)
|
||||
binary.Read(buf, binary.LittleEndian, &f)
|
||||
f = xorFloatOp(f)
|
||||
fmt.Println(f)
|
||||
}
|
||||
}
|
||||
|
||||
// Sort does an in-place radix sort of Instances, using SortDirection
// direction (Ascending or Descending) with attrs as a slice of Attribute
// indices that you want to sort by.
//
// IMPORTANT: Radix sort is not stable, so ordering outside
// the attributes used for sorting is arbitrary.
func (inst *Instances) Sort(direction SortDirection, attrs []int) {
	// Serialise each row's sort key: for every sort attribute, write the
	// sign-bit-flipped float64 (xorFloatOp makes byte order match numeric
	// order) little-endian into a per-row key of 8*len(attrs) bytes.
	buf := bytes.NewBuffer(nil)
	ds := make([][]byte, inst.Rows)
	rs := make([]int, inst.Rows)
	for i := 0; i < inst.Rows; i++ {
		byteBuf := make([]byte, 8*len(attrs))
		for _, a := range attrs {
			x := inst.storage.At(i, a)
			binary.Write(buf, binary.LittleEndian, xorFloatOp(x))
		}
		buf.Read(byteBuf)
		ds[i] = byteBuf
		rs[i] = i
	}
	// LSD radix sort over the key bytes: 256 bins per pass, one pass per
	// byte position, carrying the original row index (rs) alongside.
	valueBins := make([][][]byte, 256)
	rowBins := make([][]int, 256)
	for i := 0; i < 8*len(attrs); i++ {
		for j := 0; j < len(ds); j++ {
			// Address each row value by its ith byte
			b := ds[j]
			valueBins[b[i]] = append(valueBins[b[i]], b)
			rowBins[b[i]] = append(rowBins[b[i]], rs[j])
		}
		// Drain the bins back into ds/rs in bin order, reusing capacity.
		j := 0
		for k := 0; k < 256; k++ {
			bs := valueBins[k]
			rc := rowBins[k]
			copy(ds[j:], bs)
			copy(rs[j:], rc)
			j += len(bs)
			valueBins[k] = bs[:0]
			rowBins[k] = rc[:0]
		}
	}

	// NOTE(review): this loop decodes each key into v and discards it —
	// looks like leftover debug code; confirm before removing.
	for _, b := range ds {
		var v float64
		buf.Write(b)
		binary.Read(buf, binary.LittleEndian, &v)
	}

	// Apply the permutation in rs to the matrix rows, following each
	// cycle once (done[] marks visited rows).
	done := make([]bool, inst.Rows)
	for index := range rs {
		if done[index] {
			continue
		}
		j := index
		for {
			done[j] = true
			if rs[j] != index {
				inst.swapRows(j, rs[j])
				j = rs[j]
			} else {
				break
			}
		}
	}

	if direction == Descending {
		// Reverse the matrix
		for i, j := 0, inst.Rows-1; i < j; i, j = i+1, j-1 {
			inst.swapRows(i, j)
		}
	}
}
|
||||
|
||||
// NewInstances returns a preallocated Instances structure
// with some helpful values pre-filled (zeroed float64 storage).
func NewInstances(attrs []Attribute, rows int) *Instances {
	rawStorage := make([]float64, rows*len(attrs))
	return NewInstancesFromRaw(attrs, rows, rawStorage)
}
|
||||
|
||||
// CheckNewInstancesFromRaw checks whether a call to NewInstancesFromRaw
|
||||
// is likely to produce an error-free result.
|
||||
func CheckNewInstancesFromRaw(attrs []Attribute, rows int, data []float64) error {
|
||||
size := rows * len(attrs)
|
||||
if size < len(data) {
|
||||
return errors.New("base: data length is larger than the rows * attribute space")
|
||||
} else if size > len(data) {
|
||||
return errors.New("base: data is smaller than the rows * attribute space")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// NewInstancesFromRaw wraps a slice of float64 numbers in a
// mat64.Dense structure, reshaping it with the given number of rows
// and representing it with the given attrs (Attribute slice)
//
// IMPORTANT: if the |attrs| * |rows| value doesn't equal len(data)
// then panic()s may occur. Use CheckNewInstancesFromRaw to confirm.
func NewInstancesFromRaw(attrs []Attribute, rows int, data []float64) *Instances {
	rawStorage := mat64.NewDense(rows, len(attrs), data)
	return NewInstancesFromDense(attrs, rows, rawStorage)
}
|
||||
|
||||
// NewInstancesFromDense creates a set of Instances from a mat64.Dense
// matrix. The class index defaults to the last attribute column.
func NewInstancesFromDense(attrs []Attribute, rows int, mat *mat64.Dense) *Instances {
	return &Instances{mat, attrs, rows, len(attrs), len(attrs) - 1}
}
|
||||
|
||||
// InstancesTrainTestSplit takes a given Instances (src) and a train-test fraction
// (prop) and returns an array of two new Instances, one containing approximately
// that fraction and the other containing what's left.
//
// NOTE(review): rows land in the TRAINING set when rand.Intn(101) > 100*prop,
// so training holds roughly (1 - prop) of the rows — confirm whether the doc
// or the comparison is the intended behaviour before relying on it.
//
// IMPORTANT: this function is only meaningful when prop is between 0.0 and 1.0.
// Using any other values may result in odd behaviour.
// Also note: src is shuffled in place as a side effect.
func InstancesTrainTestSplit(src *Instances, prop float64) (*Instances, *Instances) {
	trainingRows := make([]int, 0)
	testingRows := make([]int, 0)
	numAttrs := len(src.attributes)
	src.Shuffle()
	for i := 0; i < src.Rows; i++ {
		trainOrTest := rand.Intn(101)
		if trainOrTest > int(100*prop) {
			trainingRows = append(trainingRows, i)
		} else {
			testingRows = append(testingRows, i)
		}
	}

	// Allocate destination matrices sized to the chosen row counts.
	rawTrainMatrix := mat64.NewDense(len(trainingRows), numAttrs, make([]float64, len(trainingRows)*numAttrs))
	rawTestMatrix := mat64.NewDense(len(testingRows), numAttrs, make([]float64, len(testingRows)*numAttrs))

	for i, row := range trainingRows {
		rowDat := src.storage.RowView(row)
		rawTrainMatrix.SetRow(i, rowDat)
	}
	for i, row := range testingRows {
		rowDat := src.storage.RowView(row)
		rawTestMatrix.SetRow(i, rowDat)
	}

	trainingRet := NewInstancesFromDense(src.attributes, len(trainingRows), rawTrainMatrix)
	testRet := NewInstancesFromDense(src.attributes, len(testingRows), rawTestMatrix)
	return trainingRet, testRet
}
|
||||
|
||||
// CountAttrValues returns the distribution of values of a given
|
||||
// Attribute.
|
||||
// IMPORTANT: calls panic() if the attribute index of a cannot be
|
||||
// determined. Call GetAttrIndex(a) and check for a -1 return value.
|
||||
func (inst *Instances) CountAttrValues(a Attribute) map[string]int {
|
||||
ret := make(map[string]int)
|
||||
attrIndex := inst.GetAttrIndex(a)
|
||||
if attrIndex == -1 {
|
||||
panic("Invalid attribute")
|
||||
}
|
||||
for i := 0; i < inst.Rows; i++ {
|
||||
sysVal := inst.Get(i, attrIndex)
|
||||
stringVal := a.GetStringFromSysVal(sysVal)
|
||||
ret[stringVal]++
|
||||
}
|
||||
return ret
|
||||
}
|
||||
|
||||
// CountClassValues returns the class distribution of this
// Instances set, keyed by the class' string representation.
func (inst *Instances) CountClassValues() map[string]int {
	a := inst.GetAttr(inst.ClassIndex)
	return inst.CountAttrValues(a)
}
|
||||
|
||||
// DecomposeOnAttributeValues divides the instance set depending on the
// value of a given Attribute, constructs child instances, and returns
// them in a map keyed on the string value of that Attribute.
// IMPORTANT: calls panic() if the attribute index of at cannot be determined.
// Use GetAttrIndex(at) and check for a -1 return value.
func (inst *Instances) DecomposeOnAttributeValues(at Attribute) map[string]*Instances {
	// Find the attribute we're decomposing on
	attrIndex := inst.GetAttrIndex(at)
	if attrIndex == -1 {
		panic("Invalid attribute index")
	}
	// Construct the new attribute set (everything except at)
	newAttrs := make([]Attribute, 0)
	for i := range inst.attributes {
		a := inst.attributes[i]
		if a.Equals(at) {
			continue
		}
		newAttrs = append(newAttrs, a)
	}
	// Create the return map, several counting maps
	ret := make(map[string]*Instances)
	counts := inst.CountAttrValues(at) // So we know what to allocate
	rows := make(map[string]int)       // next free row per child
	for k := range counts {
		tmp := NewInstances(newAttrs, counts[k])
		ret[k] = tmp
	}
	// Copy every row (minus the decomposed column) into the child
	// keyed by the row's value of at.
	for i := 0; i < inst.Rows; i++ {
		newAttrCounter := 0
		classVar := at.GetStringFromSysVal(inst.Get(i, attrIndex))
		dest := ret[classVar]
		destRow := rows[classVar]
		for j := 0; j < inst.Cols; j++ {
			a := inst.attributes[j]
			if a.Equals(at) {
				continue
			}
			dest.Set(destRow, newAttrCounter, inst.Get(i, j))
			newAttrCounter++
		}
		rows[classVar]++
	}
	return ret
}
|
||||
|
||||
func (inst *Instances) GetClassDistributionAfterSplit(at Attribute) map[string]map[string]int {
|
||||
|
||||
ret := make(map[string]map[string]int)
|
||||
|
||||
// Find the attribute we're decomposing on
|
||||
attrIndex := inst.GetAttrIndex(at)
|
||||
if attrIndex == -1 {
|
||||
panic("Invalid attribute index")
|
||||
}
|
||||
|
||||
// Get the class index
|
||||
classAttr := inst.GetAttr(inst.ClassIndex)
|
||||
|
||||
for i := 0; i < inst.Rows; i++ {
|
||||
splitVar := at.GetStringFromSysVal(inst.Get(i, attrIndex))
|
||||
classVar := classAttr.GetStringFromSysVal(inst.Get(i, inst.ClassIndex))
|
||||
if _, ok := ret[splitVar]; !ok {
|
||||
ret[splitVar] = make(map[string]int)
|
||||
i--
|
||||
continue
|
||||
}
|
||||
ret[splitVar][classVar]++
|
||||
}
|
||||
|
||||
return ret
|
||||
}
|
||||
|
||||
// Get returns the system representation (float64) of the value
// stored at the given row and col coordinate.
func (inst *Instances) Get(row int, col int) float64 {
	return inst.storage.At(row, col)
}
|
||||
|
||||
// Set sets the system representation (float64) to val at the
// given row and column coordinate.
func (inst *Instances) Set(row int, col int, val float64) {
	inst.storage.Set(row, col, val)
}
|
||||
|
||||
// GetRowVector returns a row of system representation
// values at the given row index.
// NOTE: the returned slice is a view into the underlying storage
// (mat64 RowView) — mutating it mutates the Instances.
func (inst *Instances) GetRowVector(row int) []float64 {
	return inst.storage.RowView(row)
}
|
||||
|
||||
// GetRowVectorWithoutClass returns a row of system representation
|
||||
// values at the given row index, excluding the class attribute
|
||||
func (inst *Instances) GetRowVectorWithoutClass(row int) []float64 {
|
||||
rawRow := make([]float64, inst.Cols)
|
||||
copy(rawRow, inst.GetRowVector(row))
|
||||
return append(rawRow[0:inst.ClassIndex], rawRow[inst.ClassIndex+1:inst.Cols]...)
|
||||
}
|
||||
|
||||
// GetClass returns the string representation of the given
|
||||
// row's class, as determined by the Attribute at the ClassIndex
|
||||
// position from GetAttr
|
||||
func (inst *Instances) GetClass(row int) string {
|
||||
attr := inst.GetAttr(inst.ClassIndex)
|
||||
val := inst.Get(row, inst.ClassIndex)
|
||||
return attr.GetStringFromSysVal(val)
|
||||
}
|
||||
|
||||
// GetClassDistribution returns a map containing the count of each
|
||||
// class type (indexed by the class' string representation)
|
||||
func (inst *Instances) GetClassDistribution() map[string]int {
|
||||
ret := make(map[string]int)
|
||||
attr := inst.GetAttr(inst.ClassIndex)
|
||||
for i := 0; i < inst.Rows; i++ {
|
||||
val := inst.Get(i, inst.ClassIndex)
|
||||
cls := attr.GetStringFromSysVal(val)
|
||||
ret[cls]++
|
||||
}
|
||||
|
||||
return ret
|
||||
}
|
||||
|
||||
// GetClassAttrPtr returns a pointer to a copy of the class Attribute.
// NOTE: the pointer addresses a local interface copy, not the slot in
// inst.attributes — writing through it does not modify the Instances.
func (inst *Instances) GetClassAttrPtr() *Attribute {
	attr := inst.GetAttr(inst.ClassIndex)
	return &attr
}
|
||||
|
||||
// GetClassAttr returns the Attribute at the ClassIndex position.
func (inst *Instances) GetClassAttr() Attribute {
	return inst.GetAttr(inst.ClassIndex)
}
|
||||
|
||||
//
|
||||
// Attribute functions
|
||||
//
|
||||
|
||||
// GetAttributeCount returns the number of attributes represented.
func (inst *Instances) GetAttributeCount() int {
	// Return the number of attributes attached to this Instance set
	return len(inst.attributes)
}
|
||||
|
||||
// SetAttrStr sets the system-representation value of row in column attr
// to value val, implicitly converting the string to system-representation
// via the appropriate Attribute function.
// NOTE: may panic if val cannot be parsed by the column's Attribute
// (e.g. FloatAttribute.GetSysValFromString panics on invalid floats).
func (inst *Instances) SetAttrStr(row int, attr int, val string) {
	// Set an attribute on a particular row from a string value
	a := inst.attributes[attr]
	sysVal := a.GetSysValFromString(val)
	inst.storage.Set(row, attr, sysVal)
}
|
||||
|
||||
// GetAttrStr returns a human-readable string value stored in column `attr'
|
||||
// and row `row', as determined by the appropriate Attribute function.
|
||||
func (inst *Instances) GetAttrStr(row int, attr int) string {
|
||||
// Get a human-readable value from a particular row
|
||||
a := inst.attributes[attr]
|
||||
usrVal := a.GetStringFromSysVal(inst.Get(row, attr))
|
||||
return usrVal
|
||||
}
|
||||
|
||||
// GetAttr returns information about an attribute at given index
// in the attributes slice.
func (inst *Instances) GetAttr(attrIndex int) Attribute {
	// Return a copy of an attribute attached to this Instance set
	return inst.attributes[attrIndex]
}
|
||||
|
||||
// GetAttrIndex returns the offset of a given Attribute `a' to an
|
||||
// index in the attributes slice
|
||||
func (inst *Instances) GetAttrIndex(of Attribute) int {
|
||||
// Finds the offset of an Attribute in this instance set
|
||||
// Returns -1 if no Attribute matches
|
||||
for i, a := range inst.attributes {
|
||||
if a.Equals(of) {
|
||||
return i
|
||||
}
|
||||
}
|
||||
return -1
|
||||
}
|
||||
|
||||
// ReplaceAttr overwrites the attribute at `index' with `a'.
// IMPORTANT: DOESN'T CONVERT ANY EXISTING VALUES — the stored
// float64s are reinterpreted under the new Attribute as-is.
func (inst *Instances) ReplaceAttr(index int, a Attribute) {
	inst.attributes[index] = a
}
|
||||
|
||||
//
|
||||
// Printing functions
|
||||
//
|
||||
|
||||
// RowStr returns a human-readable representation of a given row.
|
||||
func (inst *Instances) RowStr(row int) string {
|
||||
// Prints a given row
|
||||
var buffer bytes.Buffer
|
||||
for j := 0; j < inst.Cols; j++ {
|
||||
val := inst.storage.At(row, j)
|
||||
a := inst.attributes[j]
|
||||
postfix := " "
|
||||
if j == inst.Cols-1 {
|
||||
postfix = ""
|
||||
}
|
||||
buffer.WriteString(fmt.Sprintf("%s%s", a.GetStringFromSysVal(val), postfix))
|
||||
}
|
||||
return buffer.String()
|
||||
}
|
||||
|
||||
// String returns a human-readable summary of the Instance set:
// dimensions, the attribute list (class attribute starred), and up
// to 30 rows of data.
func (inst *Instances) String() string {
	var buffer bytes.Buffer

	buffer.WriteString("Instances with ")
	buffer.WriteString(fmt.Sprintf("%d row(s) ", inst.Rows))
	buffer.WriteString(fmt.Sprintf("%d attribute(s)\n", inst.Cols))

	buffer.WriteString(fmt.Sprintf("Attributes: \n"))
	for i, a := range inst.attributes {
		prefix := "\t"
		// Mark the class attribute with an asterisk.
		if i == inst.ClassIndex {
			prefix = "*\t"
		}
		buffer.WriteString(fmt.Sprintf("%s%s\n", prefix, a))
	}

	buffer.WriteString("\nData:\n")
	// Cap the number of printed rows at 30.
	maxRows := 30
	if inst.Rows < maxRows {
		maxRows = inst.Rows
	}

	for i := 0; i < maxRows; i++ {
		buffer.WriteString("\t")
		for j := 0; j < inst.Cols; j++ {
			val := inst.storage.At(i, j)
			a := inst.attributes[j]
			buffer.WriteString(fmt.Sprintf("%s ", a.GetStringFromSysVal(val)))
		}
		buffer.WriteString("\n")
	}

	missingRows := inst.Rows - maxRows
	if missingRows != 0 {
		buffer.WriteString(fmt.Sprintf("\t...\n%d row(s) undisplayed", missingRows))
	} else {
		buffer.WriteString("All rows displayed")
	}

	return buffer.String()
}
|
||||
|
||||
// SelectAttributes returns a new instance set containing
// the values from this one with only the Attributes specified.
// NOTE(review): if an attrs entry does not match any Attribute,
// GetAttrIndex returns -1 and the copy loop will panic on Get —
// callers should validate attrs first.
func (inst *Instances) SelectAttributes(attrs []Attribute) *Instances {
	ret := NewInstances(attrs, inst.Rows)
	attrIndices := make([]int, 0)
	for _, a := range attrs {
		attrIndex := inst.GetAttrIndex(a)
		attrIndices = append(attrIndices, attrIndex)
	}
	for i := 0; i < inst.Rows; i++ {
		for j, a := range attrIndices {
			ret.Set(i, j, inst.Get(i, a))
		}
	}
	return ret
}
|
||||
|
||||
// GeneratePredictionVector generates a new set of Instances
|
||||
// with the same number of rows, but only this Instance set's
|
||||
// class Attribute.
|
||||
func (inst *Instances) GeneratePredictionVector() *Instances {
|
||||
attrs := make([]Attribute, 1)
|
||||
attrs[0] = inst.GetClassAttr()
|
||||
ret := NewInstances(attrs, inst.Rows)
|
||||
return ret
|
||||
}
|
||||
|
||||
// Shuffle randomizes the row order in place using a forward
// Fisher-Yates pass (each row i swaps with a random row in [0, i]).
func (inst *Instances) Shuffle() {
	for i := 0; i < inst.Rows; i++ {
		j := rand.Intn(i + 1)
		inst.swapRows(i, j)
	}
}
|
||||
|
||||
// SampleWithReplacement returns a new set of Instances of size `size'
|
||||
// containing random rows from this set of Instances.
|
||||
//
|
||||
// IMPORTANT: There's a high chance of seeing duplicate rows
|
||||
// whenever size is close to the row count.
|
||||
func (inst *Instances) SampleWithReplacement(size int) *Instances {
|
||||
ret := NewInstances(inst.attributes, size)
|
||||
for i := 0; i < size; i++ {
|
||||
srcRow := rand.Intn(inst.Rows)
|
||||
for j := 0; j < inst.Cols; j++ {
|
||||
ret.Set(i, j, inst.Get(srcRow, j))
|
||||
}
|
||||
}
|
||||
return ret
|
||||
}
|
||||
|
||||
// Equal checks whether a given Instance set is exactly the same
|
||||
// as another: same size and same values (as determined by the Attributes)
|
||||
//
|
||||
// IMPORTANT: does not explicitly check if the Attributes are considered equal.
|
||||
func (inst *Instances) Equal(other *Instances) bool {
|
||||
if inst.Rows != other.Rows {
|
||||
return false
|
||||
}
|
||||
if inst.Cols != other.Cols {
|
||||
return false
|
||||
}
|
||||
for i := 0; i < inst.Rows; i++ {
|
||||
for j := 0; j < inst.Cols; j++ {
|
||||
if inst.GetAttrStr(i, j) != other.GetAttrStr(i, j) {
|
||||
return false
|
||||
}
|
||||
}
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
func (inst *Instances) swapRows(r1 int, r2 int) {
|
||||
row1buf := make([]float64, inst.Cols)
|
||||
row2buf := make([]float64, inst.Cols)
|
||||
row1 := inst.storage.RowView(r1)
|
||||
row2 := inst.storage.RowView(r2)
|
||||
copy(row1buf, row1)
|
||||
copy(row2buf, row2)
|
||||
inst.storage.SetRow(r1, row2buf)
|
||||
inst.storage.SetRow(r2, row1buf)
|
||||
}
|
87
base/lazy_sort_test.go
Normal file
87
base/lazy_sort_test.go
Normal file
@ -0,0 +1,87 @@
|
||||
package base
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"testing"
|
||||
)
|
||||
|
||||
// TestLazySortDesc checks that LazySort in Descending order matches a
// pre-sorted reference CSV without mutating the underlying data.
func TestLazySortDesc(testEnv *testing.T) {
	inst1, err := ParseCSVToInstances("../examples/datasets/iris_headers.csv", true)
	if err != nil {
		testEnv.Error(err)
		return
	}
	inst2, err := ParseCSVToInstances("../examples/datasets/iris_sorted_desc.csv", true)
	if err != nil {
		testEnv.Error(err)
		return
	}

	as1 := ResolveAllAttributes(inst1)
	as2 := ResolveAllAttributes(inst2)

	// Sanity-check both fixtures before sorting.
	if isSortedDesc(inst1, as1[0]) {
		testEnv.Error("Can't test descending sort order")
	}
	if !isSortedDesc(inst2, as2[0]) {
		testEnv.Error("Reference data not sorted in descending order!")
	}

	inst, err := LazySort(inst1, Descending, as1[0:len(as1)-1])
	if err != nil {
		testEnv.Error(err)
	}
	if !isSortedDesc(inst, as1[0]) {
		testEnv.Error("Instances are not sorted in descending order")
		testEnv.Error(inst1)
	}
	if !inst2.Equal(inst) {
		testEnv.Error("Instances don't match")
		testEnv.Error(inst)
		testEnv.Error(inst2)
	}
}
|
||||
|
||||
func TestLazySortAsc(testEnv *testing.T) {
|
||||
inst, err := ParseCSVToInstances("../examples/datasets/iris_headers.csv", true)
|
||||
as1 := ResolveAllAttributes(inst)
|
||||
if isSortedAsc(inst, as1[0]) {
|
||||
testEnv.Error("Can't test ascending sort on something ascending already")
|
||||
}
|
||||
if err != nil {
|
||||
testEnv.Error(err)
|
||||
return
|
||||
}
|
||||
insts, err := LazySort(inst, Ascending, as1)
|
||||
if err != nil {
|
||||
testEnv.Error(err)
|
||||
return
|
||||
}
|
||||
if !isSortedAsc(insts, as1[0]) {
|
||||
testEnv.Error("Instances are not sorted in ascending order")
|
||||
testEnv.Error(insts)
|
||||
}
|
||||
|
||||
inst2, err := ParseCSVToInstances("../examples/datasets/iris_sorted_asc.csv", true)
|
||||
if err != nil {
|
||||
testEnv.Error(err)
|
||||
return
|
||||
}
|
||||
as2 := ResolveAllAttributes(inst2)
|
||||
if !isSortedAsc(inst2, as2[0]) {
|
||||
testEnv.Error("This file should be sorted in ascending order")
|
||||
}
|
||||
|
||||
if !inst2.Equal(insts) {
|
||||
testEnv.Error("Instances don't match")
|
||||
testEnv.Error(inst)
|
||||
testEnv.Error(inst2)
|
||||
}
|
||||
|
||||
rowStr := insts.RowString(0)
|
||||
ref := "4.30 3.00 1.10 0.10 Iris-setosa"
|
||||
if rowStr != ref {
|
||||
panic(fmt.Sprintf("'%s' != '%s'", rowStr, ref))
|
||||
}
|
||||
|
||||
}
|
122
base/pond.go
Normal file
122
base/pond.go
Normal file
@ -0,0 +1,122 @@
|
||||
package base
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"fmt"
|
||||
)
|
||||
|
||||
// Ponds contain a particular number of rows of
// a particular number of Attributes, all of a given type.
type Pond struct {
	threadNo   uint32
	parent     DataGrid
	attributes []Attribute
	size       int      // bytes per attribute value in this Pond
	alloc      [][]byte // backing byte blocks; rows never span blocks
	maxRow     int      // highest row index written + 1
}
|
||||
|
||||
func (p *Pond) String() string {
|
||||
if len(p.alloc) > 1 {
|
||||
return fmt.Sprintf("Pond(%d attributes\n thread: %d\n size: %d\n)", len(p.attributes), p.threadNo, p.size)
|
||||
}
|
||||
return fmt.Sprintf("Pond(%d attributes\n thread: %d\n size: %d\n %d \n)", len(p.attributes), p.threadNo, p.size, p.alloc[0][0:60])
|
||||
}
|
||||
|
||||
// PondStorageRef is a reference to a particular set
// of allocated rows within a Pond.
type PondStorageRef struct {
	Storage []byte // raw backing bytes for these rows
	Rows    int    // number of complete rows in Storage
}
|
||||
|
||||
// RowSize returns the size of each row in bytes.
// NOTE(review): computed as len(attributes) * p.size, i.e. every
// attribute in the Pond is assumed to occupy p.size bytes — confirm
// that Ponds are always homogeneous in width.
func (p *Pond) RowSize() int {
	return len(p.attributes) * p.size
}
|
||||
|
||||
// Attributes returns a slice of Attributes in this Pond.
// The slice is not copied; callers must not modify it.
func (p *Pond) Attributes() []Attribute {
	return p.attributes
}
|
||||
|
||||
// Storage returns a slice of PondStorageRefs which can
|
||||
// be used to access the memory in this pond.
|
||||
func (p *Pond) Storage() []PondStorageRef {
|
||||
ret := make([]PondStorageRef, len(p.alloc))
|
||||
rowSize := p.RowSize()
|
||||
for i, b := range p.alloc {
|
||||
ret[i] = PondStorageRef{b, len(b) / rowSize}
|
||||
}
|
||||
return ret
|
||||
}
|
||||
|
||||
// resolveBlock maps a (col, row) coordinate to (block index, byte
// offset within that block) in p.alloc. Rows never span blocks, so
// any trailing partial row in a block is skipped.
// panic()s if there are no blocks or the coordinate lies beyond the
// allocated blocks.
func (p *Pond) resolveBlock(col int, row int) (int, int) {

	if len(p.alloc) == 0 {
		panic("No blocks to resolve")
	}

	// Find where in the pond the byte is
	byteOffset := row*p.RowSize() + col*p.size
	curOffset := 0
	curBlock := 0
	blockOffset := 0
	for {
		if curBlock >= len(p.alloc) {
			panic("Don't have enough blocks to fulfill")
		}

		// Rows are not allowed to span blocks, so only count whole rows
		// of each block towards the running offset.
		blockAdd := len(p.alloc[curBlock])
		blockAdd -= blockAdd % p.RowSize()

		// Case 1: we need to skip this allocation
		if curOffset+blockAdd < byteOffset {
			curOffset += blockAdd
			curBlock++
		} else {
			blockOffset = byteOffset - curOffset
			break
		}
	}

	return curBlock, blockOffset
}
|
||||
|
||||
// set writes the p.size-byte value val at coordinate (col, row) and
// bumps maxRow if the row extends the Pond.
// panic()s if val has the wrong length or the copy cannot complete
// within a single block.
func (p *Pond) set(col int, row int, val []byte) {

	// Double-check the length
	if len(val) != p.size {
		panic(fmt.Sprintf("Tried to call set() with %d bytes, should be %d", len(val), p.size))
	}

	// Find where in the pond the byte is
	curBlock, blockOffset := p.resolveBlock(col, row)

	// Copy the value in
	copied := copy(p.alloc[curBlock][blockOffset:], val)
	if copied != p.size {
		panic(fmt.Sprintf("set() terminated by only copying %d bytes into the current block (should be %d). Check EDF allocation", copied, p.size))
	}

	// Track the high-water mark of written rows.
	row++
	if row > p.maxRow {
		p.maxRow = row
	}
}
|
||||
|
||||
// get returns the p.size-byte value at coordinate (col, row) as a
// view into the backing block (not a copy).
func (p *Pond) get(col int, row int) []byte {
	curBlock, blockOffset := p.resolveBlock(col, row)
	return p.alloc[curBlock][blockOffset : blockOffset+p.size]
}
|
||||
|
||||
func (p *Pond) appendToRowBuf(row int, buffer *bytes.Buffer) {
|
||||
for i, a := range p.attributes {
|
||||
postfix := " "
|
||||
if i == len(p.attributes)-1 {
|
||||
postfix = ""
|
||||
}
|
||||
buffer.WriteString(fmt.Sprintf("%s%s", a.GetStringFromSysVal(p.get(i, row)), postfix))
|
||||
}
|
||||
}
|
168
base/sort.go
Normal file
168
base/sort.go
Normal file
@ -0,0 +1,168 @@
|
||||
package base
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/binary"
|
||||
)
|
||||
|
||||
// sortXorOp returns a copy of b with the high bit of the first byte
// flipped, so byte-wise comparison orders signed keys correctly.
// The input slice is left untouched.
func sortXorOp(b []byte) []byte {
	flipped := append([]byte(nil), b...)
	flipped[0] ^= 0x80
	return flipped
}
|
||||
|
||||
// sortSpec records a single row swap (r1 <-> r2) to be applied,
// in order, to realise a sort permutation.
type sortSpec struct {
	r1 int
	r2 int
}
|
||||
|
||||
// Returns sortSpecs for inst in ascending order: the sequence of row
// swaps that, applied in order, sorts inst by attrsArg (last attribute
// most significant after the reversal below).
func createSortSpec(inst FixedDataGrid, attrsArg []AttributeSpec) []sortSpec {

	attrs := make([]AttributeSpec, len(attrsArg))
	copy(attrs, attrsArg)
	// Reverse attribute order to be more intuitive
	for i, j := 0, len(attrs)-1; i < j; i, j = i+1, j-1 {
		attrs[i], attrs[j] = attrs[j], attrs[i]
	}

	_, rows := inst.Size()
	ret := make([]sortSpec, 0)

	// Serialise each row's attribute bytes into a per-row sort key.
	buf := bytes.NewBuffer(nil)
	ds := make([][]byte, rows)
	rs := make([]int, rows)
	rowSize := 0

	inst.MapOverRows(attrs, func(row [][]byte, rowNo int) (bool, error) {
		if rowSize == 0 {
			// Determine the key width from the first row seen.
			for _, r := range row {
				rowSize += len(r)
			}
		}

		byteBuf := make([]byte, rowSize)

		// The first attribute's value gets its sign bit flipped
		// (sortXorOp) so byte order matches numeric order.
		for i, r := range row {
			if i == 0 {
				binary.Write(buf, binary.LittleEndian, sortXorOp(r))
			} else {
				binary.Write(buf, binary.LittleEndian, r)
			}
		}

		buf.Read(byteBuf)
		ds[rowNo] = byteBuf
		rs[rowNo] = rowNo
		return true, nil
	})

	// LSD radix sort of the keys: 256 bins, one pass per byte position,
	// carrying the original row index (rs) alongside each key.
	valueBins := make([][][]byte, 256)
	rowBins := make([][]int, 256)
	for i := 0; i < rowSize; i++ {
		for j := 0; j < len(ds); j++ {
			// Address each row value by its ith byte
			b := ds[j]
			valueBins[b[i]] = append(valueBins[b[i]], b)
			rowBins[b[i]] = append(rowBins[b[i]], rs[j])
		}
		// Drain the bins back into ds/rs in bin order, reusing capacity.
		j := 0
		for k := 0; k < 256; k++ {
			bs := valueBins[k]
			rc := rowBins[k]
			copy(ds[j:], bs)
			copy(rs[j:], rc)
			j += len(bs)
			valueBins[k] = bs[:0]
			rowBins[k] = rc[:0]
		}
	}

	// Convert the permutation in rs into an ordered list of swaps by
	// walking each cycle once.
	done := make([]bool, rows)
	for index := range rs {
		if done[index] {
			continue
		}
		j := index
		for {
			done[j] = true
			if rs[j] != index {
				ret = append(ret, sortSpec{j, rs[j]})
				j = rs[j]
			} else {
				break
			}
		}
	}
	return ret
}
|
||||
|
||||
// Sort does a radix sort of DenseInstances, using SortDirection
|
||||
// direction (Ascending or Descending) with attrs as a slice of Attribute
|
||||
// indices that you want to sort by.
|
||||
//
|
||||
// IMPORTANT: Radix sort is not stable, so ordering outside
|
||||
// the attributes used for sorting is arbitrary.
|
||||
func Sort(inst FixedDataGrid, direction SortDirection, attrs []AttributeSpec) (FixedDataGrid, error) {
|
||||
sortInstructions := createSortSpec(inst, attrs)
|
||||
instUpdatable, ok := inst.(*DenseInstances)
|
||||
if ok {
|
||||
for _, i := range sortInstructions {
|
||||
instUpdatable.swapRows(i.r1, i.r2)
|
||||
}
|
||||
if direction == Descending {
|
||||
// Reverse the matrix
|
||||
_, rows := inst.Size()
|
||||
for i, j := 0, rows-1; i < j; i, j = i+1, j-1 {
|
||||
instUpdatable.swapRows(i, j)
|
||||
}
|
||||
}
|
||||
} else {
|
||||
panic("Sort is not supported for this yet!")
|
||||
}
|
||||
return instUpdatable, nil
|
||||
}
|
||||
|
||||
// LazySort also does a sort, but returns an InstanceView and doesn't actually
// reorder the rows, just makes it look like they've been reordered
// See also: Sort
func LazySort(inst FixedDataGrid, direction SortDirection, attrs []AttributeSpec) (FixedDataGrid, error) {
	// Run the sort operation
	sortInstructions := createSortSpec(inst, attrs)

	// Build the row -> row mapping
	_, rows := inst.Size() // Get the total row count
	rowArr := make([]int, rows) // Create an array of positions
	for i := 0; i < len(rowArr); i++ {
		rowArr[i] = i
	}
	// Replay the swap list on the position array instead of the data.
	for i := range sortInstructions {
		r1 := rowArr[sortInstructions[i].r1]
		r2 := rowArr[sortInstructions[i].r2]
		// Swap
		rowArr[sortInstructions[i].r1] = r2
		rowArr[sortInstructions[i].r2] = r1
	}
	if direction == Descending {
		// Reverse the mapping for a descending view.
		for i, j := 0, rows-1; i < j; i, j = i+1, j-1 {
			tmp := rowArr[i]
			rowArr[i] = rowArr[j]
			rowArr[j] = tmp
		}
	}
	// Create a mapping dictionary, omitting rows that stay in place.
	rowMap := make(map[int]int)
	for i, a := range rowArr {
		if i == a {
			continue
		}
		rowMap[i] = a
	}
	// Create the return structure
	ret := NewInstancesViewFromRows(inst, rowMap)
	return ret, nil
}
|
@ -2,10 +2,11 @@ package base
|
||||
|
||||
import "testing"
|
||||
|
||||
func isSortedAsc(inst *Instances, attrIndex int) bool {
|
||||
func isSortedAsc(inst FixedDataGrid, attr AttributeSpec) bool {
|
||||
valPrev := 0.0
|
||||
for i := 0; i < inst.Rows; i++ {
|
||||
cur := inst.Get(i, attrIndex)
|
||||
_, rows := inst.Size()
|
||||
for i := 0; i < rows; i++ {
|
||||
cur := UnpackBytesToFloat(inst.Get(attr, i))
|
||||
if i > 0 {
|
||||
if valPrev > cur {
|
||||
return false
|
||||
@ -16,10 +17,11 @@ func isSortedAsc(inst *Instances, attrIndex int) bool {
|
||||
return true
|
||||
}
|
||||
|
||||
func isSortedDesc(inst *Instances, attrIndex int) bool {
|
||||
func isSortedDesc(inst FixedDataGrid, attr AttributeSpec) bool {
|
||||
valPrev := 0.0
|
||||
for i := 0; i < inst.Rows; i++ {
|
||||
cur := inst.Get(i, attrIndex)
|
||||
_, rows := inst.Size()
|
||||
for i := 0; i < rows; i++ {
|
||||
cur := UnpackBytesToFloat(inst.Get(attr, i))
|
||||
if i > 0 {
|
||||
if valPrev < cur {
|
||||
return false
|
||||
@ -42,25 +44,25 @@ func TestSortDesc(testEnv *testing.T) {
|
||||
return
|
||||
}
|
||||
|
||||
if isSortedDesc(inst1, 0) {
|
||||
as1 := ResolveAllAttributes(inst1)
|
||||
as2 := ResolveAllAttributes(inst2)
|
||||
|
||||
if isSortedDesc(inst1, as1[0]) {
|
||||
testEnv.Error("Can't test descending sort order")
|
||||
}
|
||||
if !isSortedDesc(inst2, 0) {
|
||||
if !isSortedDesc(inst2, as2[0]) {
|
||||
testEnv.Error("Reference data not sorted in descending order!")
|
||||
}
|
||||
attrs := make([]int, 4)
|
||||
attrs[0] = 3
|
||||
attrs[1] = 2
|
||||
attrs[2] = 1
|
||||
attrs[3] = 0
|
||||
inst1.Sort(Descending, attrs)
|
||||
if !isSortedDesc(inst1, 0) {
|
||||
|
||||
Sort(inst1, Descending, as1[0:len(as1)-1])
|
||||
if err != nil {
|
||||
testEnv.Error(err)
|
||||
}
|
||||
if !isSortedDesc(inst1, as1[0]) {
|
||||
testEnv.Error("Instances are not sorted in descending order")
|
||||
testEnv.Error(inst1)
|
||||
}
|
||||
if !inst2.Equal(inst1) {
|
||||
inst1.storage.Sub(inst1.storage, inst2.storage)
|
||||
testEnv.Error(inst1.storage)
|
||||
testEnv.Error("Instances don't match")
|
||||
testEnv.Error(inst1)
|
||||
testEnv.Error(inst2)
|
||||
@ -69,20 +71,16 @@ func TestSortDesc(testEnv *testing.T) {
|
||||
|
||||
func TestSortAsc(testEnv *testing.T) {
|
||||
inst, err := ParseCSVToInstances("../examples/datasets/iris_headers.csv", true)
|
||||
if isSortedAsc(inst, 0) {
|
||||
as1 := ResolveAllAttributes(inst)
|
||||
if isSortedAsc(inst, as1[0]) {
|
||||
testEnv.Error("Can't test ascending sort on something ascending already")
|
||||
}
|
||||
if err != nil {
|
||||
testEnv.Error(err)
|
||||
return
|
||||
}
|
||||
attrs := make([]int, 4)
|
||||
attrs[0] = 3
|
||||
attrs[1] = 2
|
||||
attrs[2] = 1
|
||||
attrs[3] = 0
|
||||
inst.Sort(Ascending, attrs)
|
||||
if !isSortedAsc(inst, 0) {
|
||||
Sort(inst, Ascending, as1[0:1])
|
||||
if !isSortedAsc(inst, as1[0]) {
|
||||
testEnv.Error("Instances are not sorted in ascending order")
|
||||
testEnv.Error(inst)
|
||||
}
|
||||
@ -92,13 +90,12 @@ func TestSortAsc(testEnv *testing.T) {
|
||||
testEnv.Error(err)
|
||||
return
|
||||
}
|
||||
if !isSortedAsc(inst2, 0) {
|
||||
as2 := ResolveAllAttributes(inst2)
|
||||
if !isSortedAsc(inst2, as2[0]) {
|
||||
testEnv.Error("This file should be sorted in ascending order")
|
||||
}
|
||||
|
||||
if !inst2.Equal(inst) {
|
||||
inst.storage.Sub(inst.storage, inst2.storage)
|
||||
testEnv.Error(inst.storage)
|
||||
testEnv.Error("Instances don't match")
|
||||
testEnv.Error(inst)
|
||||
testEnv.Error(inst2)
|
||||
|
25
base/spec.go
Normal file
25
base/spec.go
Normal file
@ -0,0 +1,25 @@
|
||||
package base
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
)
|
||||
|
||||
// AttributeSpec is a pointer to a particular Attribute
|
||||
// within a particular Instance structure and encodes position
|
||||
// and storage information associated with that Attribute.
|
||||
type AttributeSpec struct {
|
||||
pond int
|
||||
position int
|
||||
attr Attribute
|
||||
}
|
||||
|
||||
// GetAttribute returns an AttributeSpec which matches a given
|
||||
// Attribute.
|
||||
func (a *AttributeSpec) GetAttribute() Attribute {
|
||||
return a.attr
|
||||
}
|
||||
|
||||
// String returns a human-readable description of this AttributeSpec.
|
||||
func (a *AttributeSpec) String() string {
|
||||
return fmt.Sprintf("AttributeSpec(Attribute: '%s', Pond: %d/%d)", a.attr, a.pond, a.position)
|
||||
}
|
98
base/util.go
Normal file
98
base/util.go
Normal file
@ -0,0 +1,98 @@
|
||||
package base
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/binary"
|
||||
"fmt"
|
||||
"math"
|
||||
"unsafe"
|
||||
)
|
||||
|
||||
// PackU64ToBytesInline fills ret with the byte values of
|
||||
// val. Ret must have length at least 8.
|
||||
func PackU64ToBytesInline(val uint64, ret []byte) {
|
||||
ret[7] = byte(val & (0xFF << 56) >> 56)
|
||||
ret[6] = byte(val & (0xFF << 48) >> 48)
|
||||
ret[5] = byte(val & (0xFF << 40) >> 40)
|
||||
ret[4] = byte(val & (0xFF << 32) >> 32)
|
||||
ret[3] = byte(val & (0xFF << 24) >> 24)
|
||||
ret[2] = byte(val & (0xFF << 16) >> 16)
|
||||
ret[1] = byte(val & (0xFF << 8) >> 8)
|
||||
ret[0] = byte(val & (0xFF << 0) >> 0)
|
||||
}
|
||||
|
||||
// PackFloatToBytesInline fills ret with the byte values of
|
||||
// the float64 argument. ret must be at least 8 bytes in size.
|
||||
func PackFloatToBytesInline(val float64, ret []byte) {
|
||||
PackU64ToBytesInline(math.Float64bits(val), ret)
|
||||
}
|
||||
|
||||
// PackU64ToBytes allocates a return value of appropriate length
|
||||
// and fills it with the values of val.
|
||||
func PackU64ToBytes(val uint64) []byte {
|
||||
ret := make([]byte, 8)
|
||||
ret[7] = byte(val & (0xFF << 56) >> 56)
|
||||
ret[6] = byte(val & (0xFF << 48) >> 48)
|
||||
ret[5] = byte(val & (0xFF << 40) >> 40)
|
||||
ret[4] = byte(val & (0xFF << 32) >> 32)
|
||||
ret[3] = byte(val & (0xFF << 24) >> 24)
|
||||
ret[2] = byte(val & (0xFF << 16) >> 16)
|
||||
ret[1] = byte(val & (0xFF << 8) >> 8)
|
||||
ret[0] = byte(val & (0xFF << 0) >> 0)
|
||||
return ret
|
||||
}
|
||||
|
||||
// UnpackBytesToU64 converst a given byte slice into
|
||||
// a uint64 value.
|
||||
func UnpackBytesToU64(val []byte) uint64 {
|
||||
pb := unsafe.Pointer(&val[0])
|
||||
return *(*uint64)(pb)
|
||||
}
|
||||
|
||||
// PackFloatToBytes returns a 8-byte slice containing
|
||||
// the byte values of a float64.
|
||||
func PackFloatToBytes(val float64) []byte {
|
||||
return PackU64ToBytes(math.Float64bits(val))
|
||||
}
|
||||
|
||||
// UnpackBytesToFloat converts a given byte slice into an
|
||||
// equivalent float64.
|
||||
func UnpackBytesToFloat(val []byte) float64 {
|
||||
pb := unsafe.Pointer(&val[0])
|
||||
return *(*float64)(pb)
|
||||
}
|
||||
|
||||
func xorFloatOp(item float64) float64 {
|
||||
var ret float64
|
||||
var tmp int64
|
||||
buf := bytes.NewBuffer(nil)
|
||||
binary.Write(buf, binary.LittleEndian, item)
|
||||
binary.Read(buf, binary.LittleEndian, &tmp)
|
||||
tmp ^= -1 << 63
|
||||
binary.Write(buf, binary.LittleEndian, tmp)
|
||||
binary.Read(buf, binary.LittleEndian, &ret)
|
||||
return ret
|
||||
}
|
||||
|
||||
func printFloatByteArr(arr [][]byte) {
|
||||
buf := bytes.NewBuffer(nil)
|
||||
var f float64
|
||||
for _, b := range arr {
|
||||
buf.Write(b)
|
||||
binary.Read(buf, binary.LittleEndian, &f)
|
||||
f = xorFloatOp(f)
|
||||
fmt.Println(f)
|
||||
}
|
||||
}
|
||||
|
||||
func byteSeqEqual(a, b []byte) bool {
|
||||
if len(a) != len(b) {
|
||||
return false
|
||||
}
|
||||
for i, v := range a {
|
||||
if v != b[i] {
|
||||
return false
|
||||
}
|
||||
}
|
||||
return true
|
||||
}
|
148
base/util_attributes.go
Normal file
148
base/util_attributes.go
Normal file
@ -0,0 +1,148 @@
|
||||
package base
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
)
|
||||
|
||||
// This file contains utility functions relating to Attributes and Attribute specifications.
|
||||
|
||||
// NonClassFloatAttributes returns all FloatAttributes which
|
||||
// aren't designated as a class Attribute.
|
||||
func NonClassFloatAttributes(d DataGrid) []Attribute {
|
||||
classAttrs := d.AllClassAttributes()
|
||||
allAttrs := d.AllAttributes()
|
||||
ret := make([]Attribute, 0)
|
||||
for _, a := range allAttrs {
|
||||
matched := false
|
||||
if _, ok := a.(*FloatAttribute); !ok {
|
||||
continue
|
||||
}
|
||||
for _, b := range classAttrs {
|
||||
if a.Equals(b) {
|
||||
matched = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !matched {
|
||||
ret = append(ret, a)
|
||||
}
|
||||
}
|
||||
return ret
|
||||
}
|
||||
|
||||
// NonClassAttrs returns all Attributes which aren't designated as a
|
||||
// class Attribute.
|
||||
func NonClassAttributes(d DataGrid) []Attribute {
|
||||
classAttrs := d.AllClassAttributes()
|
||||
allAttrs := d.AllAttributes()
|
||||
return AttributeDifferenceReferences(allAttrs, classAttrs)
|
||||
}
|
||||
|
||||
// ResolveAttributes returns AttributeSpecs describing
|
||||
// all of the Attributes.
|
||||
func ResolveAttributes(d DataGrid, attrs []Attribute) []AttributeSpec {
|
||||
ret := make([]AttributeSpec, len(attrs))
|
||||
for i, a := range attrs {
|
||||
spec, err := d.GetAttribute(a)
|
||||
if err != nil {
|
||||
panic(fmt.Errorf("Error resolving Attribute %s: %s", a, err))
|
||||
}
|
||||
ret[i] = spec
|
||||
}
|
||||
return ret
|
||||
}
|
||||
|
||||
// ResolveAllAttributes returns every AttributeSpec
|
||||
func ResolveAllAttributes(d DataGrid) []AttributeSpec {
|
||||
return ResolveAttributes(d, d.AllAttributes())
|
||||
}
|
||||
|
||||
func buildAttrSet(a []Attribute) map[Attribute]bool {
|
||||
ret := make(map[Attribute]bool)
|
||||
for _, a := range a {
|
||||
ret[a] = true
|
||||
}
|
||||
return ret
|
||||
}
|
||||
|
||||
// AttributeIntersect returns the intersection of two Attribute slices.
|
||||
//
|
||||
// IMPORTANT: result is ordered in order of the first []Attribute argument.
|
||||
//
|
||||
// IMPORTANT: result contains only Attributes from a1.
|
||||
func AttributeIntersect(a1, a2 []Attribute) []Attribute {
|
||||
ret := make([]Attribute, 0)
|
||||
for _, a := range a1 {
|
||||
matched := false
|
||||
for _, b := range a2 {
|
||||
if a.Equals(b) {
|
||||
matched = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if matched {
|
||||
ret = append(ret, a)
|
||||
}
|
||||
}
|
||||
return ret
|
||||
}
|
||||
|
||||
// AttributeIntersectReferences returns the intersection of two Attribute slices.
|
||||
//
|
||||
// IMPORTANT: result is not guaranteed to be ordered.
|
||||
//
|
||||
// IMPORTANT: done using pointers for speed, use AttributeDifference
|
||||
// if the Attributes originate from different DataGrids.
|
||||
func AttributeIntersectReferences(a1, a2 []Attribute) []Attribute {
|
||||
a1b := buildAttrSet(a1)
|
||||
a2b := buildAttrSet(a2)
|
||||
ret := make([]Attribute, 0)
|
||||
for a := range a1b {
|
||||
if _, ok := a2b[a]; ok {
|
||||
ret = append(ret, a)
|
||||
}
|
||||
}
|
||||
return ret
|
||||
}
|
||||
|
||||
// AttributeDifference returns the difference between two Attribute
|
||||
// slices: i.e. all the values in a1 which do not occur in a2.
|
||||
//
|
||||
// IMPORTANT: result is ordered the same as a1.
|
||||
//
|
||||
// IMPORTANT: result only contains values from a1.
|
||||
func AttributeDifference(a1, a2 []Attribute) []Attribute {
|
||||
ret := make([]Attribute, 0)
|
||||
for _, a := range a1 {
|
||||
matched := false
|
||||
for _, b := range a2 {
|
||||
if a.Equals(b) {
|
||||
matched = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !matched {
|
||||
ret = append(ret, a)
|
||||
}
|
||||
}
|
||||
return ret
|
||||
}
|
||||
|
||||
// AttributeDifferenceReferences returns the difference between two Attribute
|
||||
// slices: i.e. all the values in a1 which do not occur in a2.
|
||||
//
|
||||
// IMPORTANT: result is not guaranteed to be ordered.
|
||||
//
|
||||
// IMPORTANT: done using pointers for speed, use AttributeDifference
|
||||
// if the Attributes originate from different DataGrids.
|
||||
func AttributeDifferenceReferences(a1, a2 []Attribute) []Attribute {
|
||||
a1b := buildAttrSet(a1)
|
||||
a2b := buildAttrSet(a2)
|
||||
ret := make([]Attribute, 0)
|
||||
for a := range a1b {
|
||||
if _, ok := a2b[a]; !ok {
|
||||
ret = append(ret, a)
|
||||
}
|
||||
}
|
||||
return ret
|
||||
}
|
254
base/util_instances.go
Normal file
254
base/util_instances.go
Normal file
@ -0,0 +1,254 @@
|
||||
package base
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"math/rand"
|
||||
)
|
||||
|
||||
// This file contains utility functions relating to efficiently
|
||||
// generating predictions and instantiating DataGrid implementations.
|
||||
|
||||
// GeneratePredictionVector selects the class Attributes from a given
|
||||
// FixedDataGrid and returns something which can hold the predictions.
|
||||
func GeneratePredictionVector(from FixedDataGrid) UpdatableDataGrid {
|
||||
classAttrs := from.AllClassAttributes()
|
||||
_, rowCount := from.Size()
|
||||
ret := NewDenseInstances()
|
||||
for _, a := range classAttrs {
|
||||
ret.AddAttribute(a)
|
||||
ret.AddClassAttribute(a)
|
||||
}
|
||||
ret.Extend(rowCount)
|
||||
return ret
|
||||
}
|
||||
|
||||
// GetClass is a shortcut for returning the string value of the current
|
||||
// class on a given row.
|
||||
//
|
||||
// IMPORTANT: GetClass will panic if the number of ClassAttributes is
|
||||
// set to anything other than one.
|
||||
func GetClass(from DataGrid, row int) string {
|
||||
|
||||
// Get the Attribute
|
||||
classAttrs := from.AllClassAttributes()
|
||||
if len(classAttrs) > 1 {
|
||||
panic("More than one class defined")
|
||||
} else if len(classAttrs) == 0 {
|
||||
panic("No class defined!")
|
||||
}
|
||||
classAttr := classAttrs[0]
|
||||
|
||||
// Fetch and convert the class value
|
||||
classAttrSpec, err := from.GetAttribute(classAttr)
|
||||
if err != nil {
|
||||
panic(fmt.Errorf("Can't resolve class Attribute %s", err))
|
||||
}
|
||||
|
||||
classVal := from.Get(classAttrSpec, row)
|
||||
if classVal == nil {
|
||||
panic("Class values shouldn't be missing")
|
||||
}
|
||||
|
||||
return classAttr.GetStringFromSysVal(classVal)
|
||||
}
|
||||
|
||||
// SetClass is a shortcut for updating the given class of a row.
|
||||
//
|
||||
// IMPORTANT: SetClass will panic if the number of class Attributes
|
||||
// is anything other than one.
|
||||
func SetClass(at UpdatableDataGrid, row int, class string) {
|
||||
|
||||
// Get the Attribute
|
||||
classAttrs := at.AllClassAttributes()
|
||||
if len(classAttrs) > 1 {
|
||||
panic("More than one class defined")
|
||||
} else if len(classAttrs) == 0 {
|
||||
panic("No class Attributes are defined")
|
||||
}
|
||||
|
||||
classAttr := classAttrs[0]
|
||||
|
||||
// Fetch and convert the class value
|
||||
classAttrSpec, err := at.GetAttribute(classAttr)
|
||||
if err != nil {
|
||||
panic(fmt.Errorf("Can't resolve class Attribute %s", err))
|
||||
}
|
||||
|
||||
classBytes := classAttr.GetSysValFromString(class)
|
||||
at.Set(classAttrSpec, row, classBytes)
|
||||
}
|
||||
|
||||
// GetClassDistribution returns a map containing the count of each
|
||||
// class type (indexed by the class' string representation).
|
||||
func GetClassDistribution(inst FixedDataGrid) map[string]int {
|
||||
ret := make(map[string]int)
|
||||
_, rows := inst.Size()
|
||||
for i := 0; i < rows; i++ {
|
||||
cls := GetClass(inst, i)
|
||||
ret[cls]++
|
||||
}
|
||||
return ret
|
||||
}
|
||||
|
||||
// GetClassDistributionAfterSplit returns the class distribution
|
||||
// after a speculative split on a given Attribute.
|
||||
func GetClassDistributionAfterSplit(inst FixedDataGrid, at Attribute) map[string]map[string]int {
|
||||
|
||||
ret := make(map[string]map[string]int)
|
||||
|
||||
// Find the attribute we're decomposing on
|
||||
attrSpec, err := inst.GetAttribute(at)
|
||||
if err != nil {
|
||||
panic(fmt.Sprintf("Invalid attribute %s (%s)", at, err))
|
||||
}
|
||||
|
||||
_, rows := inst.Size()
|
||||
|
||||
for i := 0; i < rows; i++ {
|
||||
splitVar := at.GetStringFromSysVal(inst.Get(attrSpec, i))
|
||||
classVar := GetClass(inst, i)
|
||||
if _, ok := ret[splitVar]; !ok {
|
||||
ret[splitVar] = make(map[string]int)
|
||||
i--
|
||||
continue
|
||||
}
|
||||
ret[splitVar][classVar]++
|
||||
}
|
||||
|
||||
return ret
|
||||
}
|
||||
|
||||
// DecomposeOnAttributeValues divides the instance set depending on the
|
||||
// value of a given Attribute, constructs child instances, and returns
|
||||
// them in a map keyed on the string value of that Attribute.
|
||||
//
|
||||
// IMPORTANT: calls panic() if the AttributeSpec of at cannot be determined.
|
||||
func DecomposeOnAttributeValues(inst FixedDataGrid, at Attribute) map[string]FixedDataGrid {
|
||||
// Find the Attribute we're decomposing on
|
||||
attrSpec, err := inst.GetAttribute(at)
|
||||
if err != nil {
|
||||
panic(fmt.Sprintf("Invalid Attribute index %s", at))
|
||||
}
|
||||
// Construct the new Attribute set
|
||||
newAttrs := make([]Attribute, 0)
|
||||
for _, a := range inst.AllAttributes() {
|
||||
if a.Equals(at) {
|
||||
continue
|
||||
}
|
||||
newAttrs = append(newAttrs, a)
|
||||
}
|
||||
// Create the return map
|
||||
ret := make(map[string]FixedDataGrid)
|
||||
|
||||
// Create the return row mapping
|
||||
rowMaps := make(map[string][]int)
|
||||
|
||||
// Build full Attribute set
|
||||
fullAttrSpec := ResolveAttributes(inst, newAttrs)
|
||||
fullAttrSpec = append(fullAttrSpec, attrSpec)
|
||||
|
||||
// Decompose
|
||||
inst.MapOverRows(fullAttrSpec, func(row [][]byte, rowNo int) (bool, error) {
|
||||
// Find the output instance set
|
||||
targetBytes := row[len(row)-1]
|
||||
targetAttr := fullAttrSpec[len(fullAttrSpec)-1].attr
|
||||
targetSet := targetAttr.GetStringFromSysVal(targetBytes)
|
||||
if _, ok := rowMaps[targetSet]; !ok {
|
||||
rowMaps[targetSet] = make([]int, 0)
|
||||
}
|
||||
rowMap := rowMaps[targetSet]
|
||||
rowMaps[targetSet] = append(rowMap, rowNo)
|
||||
return true, nil
|
||||
})
|
||||
|
||||
for a := range rowMaps {
|
||||
ret[a] = NewInstancesViewFromVisible(inst, rowMaps[a], newAttrs)
|
||||
}
|
||||
|
||||
return ret
|
||||
}
|
||||
|
||||
// InstancesTrainTestSplit takes a given Instances (src) and a train-test fraction
|
||||
// (prop) and returns an array of two new Instances, one containing approximately
|
||||
// that fraction and the other containing what's left.
|
||||
//
|
||||
// IMPORTANT: this function is only meaningful when prop is between 0.0 and 1.0.
|
||||
// Using any other values may result in odd behaviour.
|
||||
func InstancesTrainTestSplit(src FixedDataGrid, prop float64) (FixedDataGrid, FixedDataGrid) {
|
||||
trainingRows := make([]int, 0)
|
||||
testingRows := make([]int, 0)
|
||||
src = Shuffle(src)
|
||||
|
||||
// Create the return structure
|
||||
_, rows := src.Size()
|
||||
for i := 0; i < rows; i++ {
|
||||
trainOrTest := rand.Intn(101)
|
||||
if trainOrTest > int(100*prop) {
|
||||
trainingRows = append(trainingRows, i)
|
||||
} else {
|
||||
testingRows = append(testingRows, i)
|
||||
}
|
||||
}
|
||||
|
||||
allAttrs := src.AllAttributes()
|
||||
|
||||
return NewInstancesViewFromVisible(src, trainingRows, allAttrs), NewInstancesViewFromVisible(src, testingRows, allAttrs)
|
||||
|
||||
}
|
||||
|
||||
// LazyShuffle randomizes the row order without re-ordering the rows
|
||||
// via an InstancesView.
|
||||
func LazyShuffle(from FixedDataGrid) FixedDataGrid {
|
||||
_, rows := from.Size()
|
||||
rowMap := make(map[int]int)
|
||||
for i := 0; i < rows; i++ {
|
||||
j := rand.Intn(i + 1)
|
||||
rowMap[i] = j
|
||||
rowMap[j] = i
|
||||
}
|
||||
return NewInstancesViewFromRows(from, rowMap)
|
||||
}
|
||||
|
||||
// Shuffle randomizes the row order either in place (if DenseInstances)
|
||||
// or using LazyShuffle.
|
||||
func Shuffle(from FixedDataGrid) FixedDataGrid {
|
||||
_, rows := from.Size()
|
||||
if inst, ok := from.(*DenseInstances); ok {
|
||||
for i := 0; i < rows; i++ {
|
||||
j := rand.Intn(i + 1)
|
||||
inst.swapRows(i, j)
|
||||
}
|
||||
return inst
|
||||
} else {
|
||||
return LazyShuffle(from)
|
||||
}
|
||||
}
|
||||
|
||||
// SampleWithReplacement returns a new FixedDataGrid containing
|
||||
// an equal number of random rows drawn from the original FixedDataGrid
|
||||
//
|
||||
// IMPORTANT: There's a high chance of seeing duplicate rows
|
||||
// whenever size is close to the row count.
|
||||
func SampleWithReplacement(from FixedDataGrid, size int) FixedDataGrid {
|
||||
rowMap := make(map[int]int)
|
||||
_, rows := from.Size()
|
||||
for i := 0; i < size; i++ {
|
||||
srcRow := rand.Intn(rows)
|
||||
rowMap[i] = srcRow
|
||||
}
|
||||
return NewInstancesViewFromRows(from, rowMap)
|
||||
}
|
||||
|
||||
// CheckCompatable checks whether two DataGrids have the same Attributes
|
||||
// and if they do, it returns them.
|
||||
func CheckCompatable(s1 FixedDataGrid, s2 FixedDataGrid) []Attribute {
|
||||
s1Attrs := s1.AllAttributes()
|
||||
s2Attrs := s2.AllAttributes()
|
||||
interAttrs := AttributeIntersect(s1Attrs, s2Attrs)
|
||||
if len(interAttrs) != len(s1Attrs) {
|
||||
return nil
|
||||
} else if len(interAttrs) != len(s2Attrs) {
|
||||
return nil
|
||||
}
|
||||
return interAttrs
|
||||
}
|
66
base/util_test.go
Normal file
66
base/util_test.go
Normal file
@ -0,0 +1,66 @@
|
||||
package base
|
||||
|
||||
import (
|
||||
. "github.com/smartystreets/goconvey/convey"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestClassDistributionAfterSplit(t *testing.T) {
|
||||
Convey("Given the PlayTennis dataset", t, func() {
|
||||
inst, err := ParseCSVToInstances("../examples/datasets/tennis.csv", true)
|
||||
So(err, ShouldEqual, nil)
|
||||
|
||||
Convey("Splitting on Sunny should give the right result...", func() {
|
||||
result := GetClassDistributionAfterSplit(inst, inst.AllAttributes()[0])
|
||||
So(result["sunny"]["no"], ShouldEqual, 3)
|
||||
So(result["sunny"]["yes"], ShouldEqual, 2)
|
||||
So(result["overcast"]["yes"], ShouldEqual, 4)
|
||||
So(result["rainy"]["yes"], ShouldEqual, 3)
|
||||
So(result["rainy"]["no"], ShouldEqual, 2)
|
||||
})
|
||||
|
||||
})
|
||||
}
|
||||
|
||||
func TestPackAndUnpack(t *testing.T) {
|
||||
Convey("Given some uint64", t, func() {
|
||||
x := uint64(0xDEADBEEF)
|
||||
Convey("When the integer is packed", func() {
|
||||
packed := PackU64ToBytes(x)
|
||||
Convey("And then unpacked", func() {
|
||||
unpacked := UnpackBytesToU64(packed)
|
||||
Convey("The unpacked version should be the same", func() {
|
||||
So(x, ShouldEqual, unpacked)
|
||||
})
|
||||
})
|
||||
})
|
||||
})
|
||||
|
||||
Convey("Given another uint64", t, func() {
|
||||
x := uint64(1)
|
||||
Convey("When the integer is packed", func() {
|
||||
packed := PackU64ToBytes(x)
|
||||
Convey("And then unpacked", func() {
|
||||
unpacked := UnpackBytesToU64(packed)
|
||||
Convey("The unpacked version should be the same", func() {
|
||||
So(x, ShouldEqual, unpacked)
|
||||
})
|
||||
})
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
func TestPackAndUnpackFloat(t *testing.T) {
|
||||
Convey("Given some float", t, func() {
|
||||
x := 1.2011
|
||||
Convey("When the float gets packed", func() {
|
||||
packed := PackFloatToBytes(x)
|
||||
Convey("And then unpacked", func() {
|
||||
unpacked := UnpackBytesToFloat(packed)
|
||||
Convey("The unpacked version should be the same", func() {
|
||||
So(unpacked, ShouldEqual, x)
|
||||
})
|
||||
})
|
||||
})
|
||||
})
|
||||
}
|
320
base/view.go
Normal file
320
base/view.go
Normal file
@ -0,0 +1,320 @@
|
||||
package base
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"fmt"
|
||||
)
|
||||
|
||||
// InstancesViews hide or re-order Attributes and rows from
|
||||
// a given DataGrid to make it appear that they've been deleted.
|
||||
type InstancesView struct {
|
||||
src FixedDataGrid
|
||||
attrs []AttributeSpec
|
||||
rows map[int]int
|
||||
classAttrs map[Attribute]bool
|
||||
maskRows bool
|
||||
}
|
||||
|
||||
func (v *InstancesView) addClassAttrsFromSrc(src FixedDataGrid) {
|
||||
for _, a := range src.AllClassAttributes() {
|
||||
matched := true
|
||||
if v.attrs != nil {
|
||||
matched = false
|
||||
for _, b := range v.attrs {
|
||||
if b.attr.Equals(a) {
|
||||
matched = true
|
||||
}
|
||||
}
|
||||
}
|
||||
if matched {
|
||||
v.classAttrs[a] = true
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (v *InstancesView) resolveRow(origRow int) int {
|
||||
if v.rows != nil {
|
||||
if newRow, ok := v.rows[origRow]; !ok {
|
||||
if v.maskRows {
|
||||
return -1
|
||||
}
|
||||
} else {
|
||||
return newRow
|
||||
}
|
||||
}
|
||||
|
||||
return origRow
|
||||
|
||||
}
|
||||
|
||||
// NewInstancesViewFromRows creates a new InstancesView from a source
|
||||
// FixedDataGrid and row -> row mapping. The key of the rows map is the
|
||||
// row as it exists within this mapping: for example an entry like 5 -> 1
|
||||
// means that row 1 in src will appear at row 5 in the Instancesview.
|
||||
//
|
||||
// Rows are not masked in this implementation, meaning that all rows which
|
||||
// are left unspecified appear as normal.
|
||||
func NewInstancesViewFromRows(src FixedDataGrid, rows map[int]int) *InstancesView {
|
||||
ret := &InstancesView{
|
||||
src,
|
||||
nil,
|
||||
rows,
|
||||
make(map[Attribute]bool),
|
||||
false,
|
||||
}
|
||||
|
||||
ret.addClassAttrsFromSrc(src)
|
||||
return ret
|
||||
}
|
||||
|
||||
// NewInstancesViewFromVisible creates a new InstancesView from a source
|
||||
// FixedDataGrid, a slice of row numbers and a slice of Attributes.
|
||||
//
|
||||
// Only the rows specified will appear in this InstancesView, and they will
|
||||
// appear in the same order they appear within the rows array.
|
||||
//
|
||||
// Only the Attributes specified will appear in this InstancesView. Retrieving
|
||||
// Attribute specifications from this InstancesView will maintain their order.
|
||||
func NewInstancesViewFromVisible(src FixedDataGrid, rows []int, attrs []Attribute) *InstancesView {
|
||||
ret := &InstancesView{
|
||||
src,
|
||||
ResolveAttributes(src, attrs),
|
||||
make(map[int]int),
|
||||
make(map[Attribute]bool),
|
||||
true,
|
||||
}
|
||||
|
||||
for i, a := range rows {
|
||||
ret.rows[i] = a
|
||||
}
|
||||
|
||||
ret.addClassAttrsFromSrc(src)
|
||||
return ret
|
||||
}
|
||||
|
||||
// NewInstancesViewFromAttrs creates a new InstancesView from a source
|
||||
// FixedDataGrid and a slice of Attributes.
|
||||
//
|
||||
// Only the Attributes specified will appear in this InstancesView.
|
||||
func NewInstancesViewFromAttrs(src FixedDataGrid, attrs []Attribute) *InstancesView {
|
||||
ret := &InstancesView{
|
||||
src,
|
||||
ResolveAttributes(src, attrs),
|
||||
nil,
|
||||
make(map[Attribute]bool),
|
||||
false,
|
||||
}
|
||||
|
||||
ret.addClassAttrsFromSrc(src)
|
||||
return ret
|
||||
}
|
||||
|
||||
// GetAttribute returns an Attribute specification matching an Attribute
|
||||
// if it has not been filtered.
|
||||
//
|
||||
// The AttributeSpecs returned are the same as those returned by the
|
||||
// source FixedDataGrid.
|
||||
func (v *InstancesView) GetAttribute(a Attribute) (AttributeSpec, error) {
|
||||
if a == nil {
|
||||
return AttributeSpec{}, fmt.Errorf("Attribute can't be nil")
|
||||
}
|
||||
// Pass-through on nil
|
||||
if v.attrs == nil {
|
||||
return v.src.GetAttribute(a)
|
||||
}
|
||||
// Otherwise
|
||||
for _, r := range v.attrs {
|
||||
// If the attribute matches...
|
||||
if r.GetAttribute().Equals(a) {
|
||||
return r, nil
|
||||
}
|
||||
}
|
||||
return AttributeSpec{}, fmt.Errorf("Requested Attribute has been filtered")
|
||||
}
|
||||
|
||||
// AllAttributes returns every Attribute which hasn't been filtered.
|
||||
func (v *InstancesView) AllAttributes() []Attribute {
|
||||
|
||||
if v.attrs == nil {
|
||||
return v.src.AllAttributes()
|
||||
}
|
||||
|
||||
ret := make([]Attribute, len(v.attrs))
|
||||
|
||||
for i, a := range v.attrs {
|
||||
ret[i] = a.GetAttribute()
|
||||
}
|
||||
|
||||
return ret
|
||||
}
|
||||
|
||||
// AddClassAttribute adds the given Attribute to the set of defined
|
||||
// class Attributes, if it hasn't been filtered.
|
||||
func (v *InstancesView) AddClassAttribute(a Attribute) error {
|
||||
// Check that this Attribute is defined
|
||||
matched := false
|
||||
for _, r := range v.AllAttributes() {
|
||||
if r.Equals(a) {
|
||||
matched = true
|
||||
}
|
||||
}
|
||||
if !matched {
|
||||
return fmt.Errorf("Attribute has been filtered")
|
||||
}
|
||||
|
||||
v.classAttrs[a] = true
|
||||
return nil
|
||||
}
|
||||
|
||||
// RemoveClassAttribute removes the given Attribute from the set of
|
||||
// class Attributes.
|
||||
func (v *InstancesView) RemoveClassAttribute(a Attribute) error {
|
||||
v.classAttrs[a] = false
|
||||
return nil
|
||||
}
|
||||
|
||||
// AllClassAttributes returns all the Attributes currently defined
|
||||
// as being class Attributes.
|
||||
func (v *InstancesView) AllClassAttributes() []Attribute {
|
||||
ret := make([]Attribute, 0)
|
||||
for a := range v.classAttrs {
|
||||
if v.classAttrs[a] {
|
||||
ret = append(ret, a)
|
||||
}
|
||||
}
|
||||
return ret
|
||||
}
|
||||
|
||||
// Get returns a sequence of bytes stored under a given Attribute
|
||||
// on a given row.
|
||||
//
|
||||
// IMPORTANT: The AttributeSpec is unverified, meaning it's possible
|
||||
// to return values from Attributes filtered by this InstancesView
|
||||
// if the underlying AttributeSpec is known.
|
||||
func (v *InstancesView) Get(as AttributeSpec, row int) []byte {
|
||||
// Change the row if necessary
|
||||
row = v.resolveRow(row)
|
||||
if row == -1 {
|
||||
panic("Out of range")
|
||||
}
|
||||
return v.src.Get(as, row)
|
||||
}
|
||||
|
||||
// MapOverRows, see DenseInstances.MapOverRows.
|
||||
//
|
||||
// IMPORTANT: MapOverRows is not guaranteed to be ordered, but this one
|
||||
// especially so.
|
||||
func (v *InstancesView) MapOverRows(as []AttributeSpec, rowFunc func([][]byte, int) (bool, error)) error {
|
||||
if v.maskRows {
|
||||
rowBuf := make([][]byte, len(as))
|
||||
for r := range v.rows {
|
||||
row := v.rows[r]
|
||||
for i, a := range as {
|
||||
rowBuf[i] = v.src.Get(a, row)
|
||||
}
|
||||
ok, err := rowFunc(rowBuf, r)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if !ok {
|
||||
break
|
||||
}
|
||||
}
|
||||
return nil
|
||||
} else {
|
||||
return v.src.MapOverRows(as, rowFunc)
|
||||
}
|
||||
}
|
||||
|
||||
// Size Returns the number of Attributes and rows this InstancesView
|
||||
// contains.
|
||||
func (v *InstancesView) Size() (int, int) {
|
||||
// Get the original size
|
||||
hSize, vSize := v.src.Size()
|
||||
// Adjust to the number of defined Attributes
|
||||
if v.attrs != nil {
|
||||
hSize = len(v.attrs)
|
||||
}
|
||||
// Adjust to the number of defined rows
|
||||
if v.rows != nil {
|
||||
if v.maskRows {
|
||||
vSize = len(v.rows)
|
||||
} else if len(v.rows) > vSize {
|
||||
vSize = len(v.rows)
|
||||
}
|
||||
}
|
||||
return hSize, vSize
|
||||
}
|
||||
|
||||
// String returns a human-readable summary of this InstancesView.
|
||||
func (v *InstancesView) String() string {
|
||||
var buffer bytes.Buffer
|
||||
maxRows := 30
|
||||
|
||||
// Get all Attribute information
|
||||
as := ResolveAllAttributes(v)
|
||||
|
||||
// Print header
|
||||
cols, rows := v.Size()
|
||||
buffer.WriteString("InstancesView with ")
|
||||
buffer.WriteString(fmt.Sprintf("%d row(s) ", rows))
|
||||
buffer.WriteString(fmt.Sprintf("%d attribute(s)\n", cols))
|
||||
if v.attrs != nil {
|
||||
buffer.WriteString(fmt.Sprintf("With defined Attribute view\n"))
|
||||
}
|
||||
if v.rows != nil {
|
||||
buffer.WriteString(fmt.Sprintf("With defined Row view\n"))
|
||||
}
|
||||
if v.maskRows {
|
||||
buffer.WriteString("Row masking on.\n")
|
||||
}
|
||||
buffer.WriteString(fmt.Sprintf("Attributes:\n"))
|
||||
|
||||
for _, a := range as {
|
||||
prefix := "\t"
|
||||
if v.classAttrs[a.attr] {
|
||||
prefix = "*\t"
|
||||
}
|
||||
buffer.WriteString(fmt.Sprintf("%s%s\n", prefix, a.attr))
|
||||
}
|
||||
|
||||
// Print data
|
||||
if rows < maxRows {
|
||||
maxRows = rows
|
||||
}
|
||||
buffer.WriteString("Data:")
|
||||
for i := 0; i < maxRows; i++ {
|
||||
buffer.WriteString("\t")
|
||||
for _, a := range as {
|
||||
val := v.Get(a, i)
|
||||
buffer.WriteString(fmt.Sprintf("%s ", a.attr.GetStringFromSysVal(val)))
|
||||
}
|
||||
buffer.WriteString("\n")
|
||||
}
|
||||
|
||||
missingRows := rows - maxRows
|
||||
if missingRows != 0 {
|
||||
buffer.WriteString(fmt.Sprintf("\t...\n%d row(s) undisplayed", missingRows))
|
||||
} else {
|
||||
buffer.WriteString("All rows displayed")
|
||||
}
|
||||
|
||||
return buffer.String()
|
||||
}
|
||||
|
||||
// RowString returns a string representation of a given row.
|
||||
func (v *InstancesView) RowString(row int) string {
|
||||
var buffer bytes.Buffer
|
||||
as := ResolveAllAttributes(v)
|
||||
first := true
|
||||
for _, a := range as {
|
||||
val := v.Get(a, row)
|
||||
prefix := " "
|
||||
if first {
|
||||
prefix = ""
|
||||
first = false
|
||||
}
|
||||
buffer.WriteString(fmt.Sprintf("%s%s", prefix, a.attr.GetStringFromSysVal(val)))
|
||||
}
|
||||
return buffer.String()
|
||||
}
|
119
base/view_test.go
Normal file
119
base/view_test.go
Normal file
@ -0,0 +1,119 @@
|
||||
package base
|
||||
|
||||
import (
|
||||
. "github.com/smartystreets/goconvey/convey"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestInstancesViewRows(t *testing.T) {
|
||||
Convey("Given Iris", t, func() {
|
||||
instOrig, err := ParseCSVToInstances("../examples/datasets/iris_headers.csv", true)
|
||||
So(err, ShouldEqual, nil)
|
||||
Convey("Given a new row map containing only row 5", func() {
|
||||
rMap := make(map[int]int)
|
||||
rMap[0] = 5
|
||||
instView := NewInstancesViewFromRows(instOrig, rMap)
|
||||
Convey("The internal structure should be right...", func() {
|
||||
So(instView.rows[0], ShouldEqual, 5)
|
||||
})
|
||||
Convey("The reconstructed values should be correct...", func() {
|
||||
str := "5.40 3.90 1.70 0.40 Iris-setosa"
|
||||
row := instView.RowString(0)
|
||||
So(row, ShouldEqual, str)
|
||||
})
|
||||
Convey("And the size should be correct...", func() {
|
||||
width, height := instView.Size()
|
||||
So(width, ShouldEqual, 5)
|
||||
So(height, ShouldEqual, 150)
|
||||
})
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
func TestInstancesViewFromVisible(t *testing.T) {
|
||||
Convey("Given Iris", t, func() {
|
||||
instOrig, err := ParseCSVToInstances("../examples/datasets/iris_headers.csv", true)
|
||||
So(err, ShouldEqual, nil)
|
||||
Convey("Generate something that says every other row should be visible", func() {
|
||||
rowVisiblex1 := make([]int, 0)
|
||||
_, totalRows := instOrig.Size()
|
||||
for i := 0; i < totalRows; i += 2 {
|
||||
rowVisiblex1 = append(rowVisiblex1, i)
|
||||
}
|
||||
instViewx1 := NewInstancesViewFromVisible(instOrig, rowVisiblex1, instOrig.AllAttributes())
|
||||
for i, a := range rowVisiblex1 {
|
||||
rowStr1 := instViewx1.RowString(i)
|
||||
rowStr2 := instOrig.RowString(a)
|
||||
So(rowStr1, ShouldEqual, rowStr2)
|
||||
}
|
||||
Convey("And then generate something that says that every other row than that should be visible", func() {
|
||||
rowVisiblex2 := make([]int, 0)
|
||||
for i := 0; i < totalRows; i += 4 {
|
||||
rowVisiblex2 = append(rowVisiblex1, i)
|
||||
}
|
||||
instViewx2 := NewInstancesViewFromVisible(instOrig, rowVisiblex2, instOrig.AllAttributes())
|
||||
for i, a := range rowVisiblex2 {
|
||||
rowStr1 := instViewx2.RowString(i)
|
||||
rowStr2 := instOrig.RowString(a)
|
||||
So(rowStr1, ShouldEqual, rowStr2)
|
||||
}
|
||||
})
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
func TestInstancesViewAttrs(t *testing.T) {
|
||||
Convey("Given Iris", t, func() {
|
||||
instOrig, err := ParseCSVToInstances("../examples/datasets/iris_headers.csv", true)
|
||||
So(err, ShouldEqual, nil)
|
||||
Convey("Given a new Attribute vector with the last 4...", func() {
|
||||
cMap := instOrig.AllAttributes()[1:]
|
||||
instView := NewInstancesViewFromAttrs(instOrig, cMap)
|
||||
Convey("The size should be correct", func() {
|
||||
h, v := instView.Size()
|
||||
So(h, ShouldEqual, 4)
|
||||
_, vOrig := instOrig.Size()
|
||||
So(v, ShouldEqual, vOrig)
|
||||
})
|
||||
Convey("There should be 4 Attributes...", func() {
|
||||
attrs := instView.AllAttributes()
|
||||
So(len(attrs), ShouldEqual, 4)
|
||||
})
|
||||
Convey("There should be 4 Attributes with the right headers...", func() {
|
||||
attrs := instView.AllAttributes()
|
||||
So(attrs[0].GetName(), ShouldEqual, "Sepal width")
|
||||
So(attrs[1].GetName(), ShouldEqual, "Petal length")
|
||||
So(attrs[2].GetName(), ShouldEqual, "Petal width")
|
||||
So(attrs[3].GetName(), ShouldEqual, "Species")
|
||||
})
|
||||
Convey("There should be a class Attribute...", func() {
|
||||
attrs := instView.AllClassAttributes()
|
||||
So(len(attrs), ShouldEqual, 1)
|
||||
})
|
||||
Convey("The class Attribute should be preserved...", func() {
|
||||
attrs := instView.AllClassAttributes()
|
||||
So(attrs[0].GetName(), ShouldEqual, "Species")
|
||||
})
|
||||
Convey("Attempts to get the filtered Attribute should fail...", func() {
|
||||
_, err := instView.GetAttribute(instOrig.AllAttributes()[0])
|
||||
So(err, ShouldNotEqual, nil)
|
||||
})
|
||||
Convey("The filtered Attribute should not appear in the RowString", func() {
|
||||
str := "3.90 1.70 0.40 Iris-setosa"
|
||||
row := instView.RowString(5)
|
||||
So(row, ShouldEqual, str)
|
||||
})
|
||||
Convey("The filtered Attributes should all be the same type...", func() {
|
||||
attrs := instView.AllAttributes()
|
||||
_, ok1 := attrs[0].(*FloatAttribute)
|
||||
_, ok2 := attrs[1].(*FloatAttribute)
|
||||
_, ok3 := attrs[2].(*FloatAttribute)
|
||||
_, ok4 := attrs[3].(*CategoricalAttribute)
|
||||
So(ok1, ShouldEqual, true)
|
||||
So(ok2, ShouldEqual, true)
|
||||
So(ok3, ShouldEqual, true)
|
||||
So(ok4, ShouldEqual, true)
|
||||
})
|
||||
})
|
||||
})
|
||||
}
|
@ -1,13 +1,12 @@
|
||||
/*
|
||||
//
|
||||
//
|
||||
// Ensemble contains classifiers which combine other classifiers.
|
||||
//
|
||||
// RandomForest:
|
||||
// Generates ForestSize bagged decision trees (currently ID3-based)
|
||||
// each considering a fixed number of random features.
|
||||
//
|
||||
// Built on meta.Bagging
|
||||
//
|
||||
|
||||
Ensemble contains classifiers which combine other classifiers.
|
||||
|
||||
RandomForest:
|
||||
Generates ForestSize bagged decision trees (currently ID3-based)
|
||||
each considering a fixed number of random features.
|
||||
|
||||
Built on meta.Bagging
|
||||
|
||||
*/
|
||||
|
||||
package ensemble
|
||||
package ensemble
|
||||
|
@ -8,7 +8,7 @@ import (
|
||||
)
|
||||
|
||||
// RandomForest classifies instances using an ensemble
|
||||
// of bagged random decision trees
|
||||
// of bagged random decision trees.
|
||||
type RandomForest struct {
|
||||
base.BaseClassifier
|
||||
ForestSize int
|
||||
@ -18,7 +18,7 @@ type RandomForest struct {
|
||||
|
||||
// NewRandomForest generates and return a new random forests
|
||||
// forestSize controls the number of trees that get built
|
||||
// features controls the number of features used to build each tree
|
||||
// features controls the number of features used to build each tree.
|
||||
func NewRandomForest(forestSize int, features int) *RandomForest {
|
||||
ret := &RandomForest{
|
||||
base.BaseClassifier{},
|
||||
@ -30,7 +30,7 @@ func NewRandomForest(forestSize int, features int) *RandomForest {
|
||||
}
|
||||
|
||||
// Fit builds the RandomForest on the specified instances
|
||||
func (f *RandomForest) Fit(on *base.Instances) {
|
||||
func (f *RandomForest) Fit(on base.FixedDataGrid) {
|
||||
f.Model = new(meta.BaggedModel)
|
||||
f.Model.RandomFeatures = f.Features
|
||||
for i := 0; i < f.ForestSize; i++ {
|
||||
@ -40,11 +40,12 @@ func (f *RandomForest) Fit(on *base.Instances) {
|
||||
f.Model.Fit(on)
|
||||
}
|
||||
|
||||
// Predict generates predictions from a trained RandomForest
|
||||
func (f *RandomForest) Predict(with *base.Instances) *base.Instances {
|
||||
// Predict generates predictions from a trained RandomForest.
|
||||
func (f *RandomForest) Predict(with base.FixedDataGrid) base.FixedDataGrid {
|
||||
return f.Model.Predict(with)
|
||||
}
|
||||
|
||||
// String returns a human-readable representation of this tree.
|
||||
func (f *RandomForest) String() string {
|
||||
return fmt.Sprintf("RandomForest(ForestSize: %d, Features:%d, %s\n)", f.ForestSize, f.Features, f.Model)
|
||||
}
|
||||
|
@ -13,12 +13,16 @@ func TestRandomForest1(testEnv *testing.T) {
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
trainData, testData := base.InstancesTrainTestSplit(inst, 0.60)
|
||||
filt := filters.NewChiMergeFilter(trainData, 0.90)
|
||||
filt.AddAllNumericAttributes()
|
||||
filt.Build()
|
||||
filt.Run(testData)
|
||||
filt.Run(trainData)
|
||||
|
||||
filt := filters.NewChiMergeFilter(inst, 0.90)
|
||||
for _, a := range base.NonClassFloatAttributes(inst) {
|
||||
filt.AddAttribute(a)
|
||||
}
|
||||
filt.Train()
|
||||
instf := base.NewLazilyFilteredInstances(inst, filt)
|
||||
|
||||
trainData, testData := base.InstancesTrainTestSplit(instf, 0.60)
|
||||
|
||||
rf := NewRandomForest(10, 3)
|
||||
rf.Fit(trainData)
|
||||
predictions := rf.Predict(testData)
|
||||
|
@ -11,19 +11,22 @@ type ConfusionMatrix map[string]map[string]int
|
||||
|
||||
// GetConfusionMatrix builds a ConfusionMatrix from a set of reference (`ref')
|
||||
// and generate (`gen') Instances.
|
||||
func GetConfusionMatrix(ref *base.Instances, gen *base.Instances) map[string]map[string]int {
|
||||
func GetConfusionMatrix(ref base.FixedDataGrid, gen base.FixedDataGrid) map[string]map[string]int {
|
||||
|
||||
if ref.Rows != gen.Rows {
|
||||
_, refRows := ref.Size()
|
||||
_, genRows := gen.Size()
|
||||
|
||||
if refRows != genRows {
|
||||
panic("Row counts should match")
|
||||
}
|
||||
|
||||
ret := make(map[string]map[string]int)
|
||||
|
||||
for i := 0; i < ref.Rows; i++ {
|
||||
referenceClass := ref.GetClass(i)
|
||||
predictedClass := gen.GetClass(i)
|
||||
for i := 0; i < int(refRows); i++ {
|
||||
referenceClass := base.GetClass(ref, i)
|
||||
predictedClass := base.GetClass(gen, i)
|
||||
if _, ok := ret[referenceClass]; ok {
|
||||
ret[referenceClass][predictedClass]++
|
||||
ret[referenceClass][predictedClass] += 1
|
||||
} else {
|
||||
ret[referenceClass] = make(map[string]int)
|
||||
ret[referenceClass][predictedClass] = 1
|
||||
|
@ -1,152 +1,151 @@
|
||||
Sepal length,Sepal width,Petal length,Petal width,Species
|
||||
2,3.5,1.4,0.2,Iris-setosa
|
||||
1,3,1.4,0.2,Iris-setosa
|
||||
1,3.2,1.3,0.2,Iris-setosa
|
||||
0,3.1,1.5,0.2,Iris-setosa
|
||||
1,3.6,1.4,0.2,Iris-setosa
|
||||
3,3.9,1.7,0.4,Iris-setosa
|
||||
0,3.4,1.4,0.3,Iris-setosa
|
||||
1,3.4,1.5,0.2,Iris-setosa
|
||||
0,2.9,1.4,0.2,Iris-setosa
|
||||
1,3.1,1.5,0.1,Iris-setosa
|
||||
3,3.7,1.5,0.2,Iris-setosa
|
||||
1,3.4,1.6,0.2,Iris-setosa
|
||||
1,3,1.4,0.1,Iris-setosa
|
||||
0,3,1.1,0.1,Iris-setosa
|
||||
4,4,1.2,0.2,Iris-setosa
|
||||
3,4.4,1.5,0.4,Iris-setosa
|
||||
3,3.9,1.3,0.4,Iris-setosa
|
||||
2,3.5,1.4,0.3,Iris-setosa
|
||||
3,3.8,1.7,0.3,Iris-setosa
|
||||
2,3.8,1.5,0.3,Iris-setosa
|
||||
3,3.4,1.7,0.2,Iris-setosa
|
||||
2,3.7,1.5,0.4,Iris-setosa
|
||||
0,3.6,1,0.2,Iris-setosa
|
||||
2,3.3,1.7,0.5,Iris-setosa
|
||||
1,3.4,1.9,0.2,Iris-setosa
|
||||
1,3,1.6,0.2,Iris-setosa
|
||||
1,3.4,1.6,0.4,Iris-setosa
|
||||
2,3.5,1.5,0.2,Iris-setosa
|
||||
2,3.4,1.4,0.2,Iris-setosa
|
||||
1,3.2,1.6,0.2,Iris-setosa
|
||||
1,3.1,1.6,0.2,Iris-setosa
|
||||
3,3.4,1.5,0.4,Iris-setosa
|
||||
2,4.1,1.5,0.1,Iris-setosa
|
||||
3,4.2,1.4,0.2,Iris-setosa
|
||||
1,3.1,1.5,0.1,Iris-setosa
|
||||
1,3.2,1.2,0.2,Iris-setosa
|
||||
3,3.5,1.3,0.2,Iris-setosa
|
||||
1,3.1,1.5,0.1,Iris-setosa
|
||||
0,3,1.3,0.2,Iris-setosa
|
||||
2,3.4,1.5,0.2,Iris-setosa
|
||||
1,3.5,1.3,0.3,Iris-setosa
|
||||
0,2.3,1.3,0.3,Iris-setosa
|
||||
0,3.2,1.3,0.2,Iris-setosa
|
||||
1,3.5,1.6,0.6,Iris-setosa
|
||||
2,3.8,1.9,0.4,Iris-setosa
|
||||
1,3,1.4,0.3,Iris-setosa
|
||||
2,3.8,1.6,0.2,Iris-setosa
|
||||
0,3.2,1.4,0.2,Iris-setosa
|
||||
2,3.7,1.5,0.2,Iris-setosa
|
||||
1,3.3,1.4,0.2,Iris-setosa
|
||||
7,3.2,4.7,1.4,Iris-versicolor
|
||||
5,3.2,4.5,1.5,Iris-versicolor
|
||||
7,3.1,4.9,1.5,Iris-versicolor
|
||||
3,2.3,4,1.3,Iris-versicolor
|
||||
6,2.8,4.6,1.5,Iris-versicolor
|
||||
3,2.8,4.5,1.3,Iris-versicolor
|
||||
5,3.3,4.7,1.6,Iris-versicolor
|
||||
1,2.4,3.3,1,Iris-versicolor
|
||||
6,2.9,4.6,1.3,Iris-versicolor
|
||||
2,2.7,3.9,1.4,Iris-versicolor
|
||||
1,2,3.5,1,Iris-versicolor
|
||||
4,3,4.2,1.5,Iris-versicolor
|
||||
4,2.2,4,1,Iris-versicolor
|
||||
5,2.9,4.7,1.4,Iris-versicolor
|
||||
3,2.9,3.6,1.3,Iris-versicolor
|
||||
6,3.1,4.4,1.4,Iris-versicolor
|
||||
3,3,4.5,1.5,Iris-versicolor
|
||||
4,2.7,4.1,1,Iris-versicolor
|
||||
5,2.2,4.5,1.5,Iris-versicolor
|
||||
3,2.5,3.9,1.1,Iris-versicolor
|
||||
4,3.2,4.8,1.8,Iris-versicolor
|
||||
5,2.8,4,1.3,Iris-versicolor
|
||||
5,2.5,4.9,1.5,Iris-versicolor
|
||||
5,2.8,4.7,1.2,Iris-versicolor
|
||||
5,2.9,4.3,1.3,Iris-versicolor
|
||||
6,3,4.4,1.4,Iris-versicolor
|
||||
6,2.8,4.8,1.4,Iris-versicolor
|
||||
6,3,5,1.7,Iris-versicolor
|
||||
4,2.9,4.5,1.5,Iris-versicolor
|
||||
3,2.6,3.5,1,Iris-versicolor
|
||||
3,2.4,3.8,1.1,Iris-versicolor
|
||||
3,2.4,3.7,1,Iris-versicolor
|
||||
4,2.7,3.9,1.2,Iris-versicolor
|
||||
4,2.7,5.1,1.6,Iris-versicolor
|
||||
3,3,4.5,1.5,Iris-versicolor
|
||||
4,3.4,4.5,1.6,Iris-versicolor
|
||||
6,3.1,4.7,1.5,Iris-versicolor
|
||||
5,2.3,4.4,1.3,Iris-versicolor
|
||||
3,3,4.1,1.3,Iris-versicolor
|
||||
3,2.5,4,1.3,Iris-versicolor
|
||||
3,2.6,4.4,1.2,Iris-versicolor
|
||||
5,3,4.6,1.4,Iris-versicolor
|
||||
4,2.6,4,1.2,Iris-versicolor
|
||||
1,2.3,3.3,1,Iris-versicolor
|
||||
3,2.7,4.2,1.3,Iris-versicolor
|
||||
3,3,4.2,1.2,Iris-versicolor
|
||||
3,2.9,4.2,1.3,Iris-versicolor
|
||||
5,2.9,4.3,1.3,Iris-versicolor
|
||||
2,2.5,3,1.1,Iris-versicolor
|
||||
3,2.8,4.1,1.3,Iris-versicolor
|
||||
5,3.3,6,2.5,Iris-virginica
|
||||
4,2.7,5.1,1.9,Iris-virginica
|
||||
7,3,5.9,2.1,Iris-virginica
|
||||
5,2.9,5.6,1.8,Iris-virginica
|
||||
6,3,5.8,2.2,Iris-virginica
|
||||
9,3,6.6,2.1,Iris-virginica
|
||||
1,2.5,4.5,1.7,Iris-virginica
|
||||
8,2.9,6.3,1.8,Iris-virginica
|
||||
6,2.5,5.8,1.8,Iris-virginica
|
||||
8,3.6,6.1,2.5,Iris-virginica
|
||||
6,3.2,5.1,2,Iris-virginica
|
||||
5,2.7,5.3,1.9,Iris-virginica
|
||||
6,3,5.5,2.1,Iris-virginica
|
||||
3,2.5,5,2,Iris-virginica
|
||||
4,2.8,5.1,2.4,Iris-virginica
|
||||
5,3.2,5.3,2.3,Iris-virginica
|
||||
6,3,5.5,1.8,Iris-virginica
|
||||
9,3.8,6.7,2.2,Iris-virginica
|
||||
9,2.6,6.9,2.3,Iris-virginica
|
||||
4,2.2,5,1.5,Iris-virginica
|
||||
7,3.2,5.7,2.3,Iris-virginica
|
||||
3,2.8,4.9,2,Iris-virginica
|
||||
9,2.8,6.7,2,Iris-virginica
|
||||
5,2.7,4.9,1.8,Iris-virginica
|
||||
6,3.3,5.7,2.1,Iris-virginica
|
||||
8,3.2,6,1.8,Iris-virginica
|
||||
5,2.8,4.8,1.8,Iris-virginica
|
||||
5,3,4.9,1.8,Iris-virginica
|
||||
5,2.8,5.6,2.1,Iris-virginica
|
||||
8,3,5.8,1.6,Iris-virginica
|
||||
8,2.8,6.1,1.9,Iris-virginica
|
||||
9,3.8,6.4,2,Iris-virginica
|
||||
5,2.8,5.6,2.2,Iris-virginica
|
||||
5,2.8,5.1,1.5,Iris-virginica
|
||||
5,2.6,5.6,1.4,Iris-virginica
|
||||
9,3,6.1,2.3,Iris-virginica
|
||||
5,3.4,5.6,2.4,Iris-virginica
|
||||
5,3.1,5.5,1.8,Iris-virginica
|
||||
4,3,4.8,1.8,Iris-virginica
|
||||
7,3.1,5.4,2.1,Iris-virginica
|
||||
6,3.1,5.6,2.4,Iris-virginica
|
||||
7,3.1,5.1,2.3,Iris-virginica
|
||||
4,2.7,5.1,1.9,Iris-virginica
|
||||
6,3.2,5.9,2.3,Iris-virginica
|
||||
6,3.3,5.7,2.5,Iris-virginica
|
||||
6,3,5.2,2.3,Iris-virginica
|
||||
5,2.5,5,1.9,Iris-virginica
|
||||
6,3,5.2,2,Iris-virginica
|
||||
5,3.4,5.4,2.3,Iris-virginica
|
||||
4,3,5.1,1.8,Iris-virginica
|
||||
|
||||
Sepal length,Sepal width,Petal length, Petal width,Species
|
||||
5.02,3.5,1.4,0.2,Iris-setosa
|
||||
4.66,3,1.4,0.2,Iris-setosa
|
||||
4.66,3.2,1.3,0.2,Iris-setosa
|
||||
4.3,3.1,1.5,0.2,Iris-setosa
|
||||
4.66,3.6,1.4,0.2,Iris-setosa
|
||||
5.38,3.9,1.7,0.4,Iris-setosa
|
||||
4.3,3.4,1.4,0.3,Iris-setosa
|
||||
4.66,3.4,1.5,0.2,Iris-setosa
|
||||
4.3,2.9,1.4,0.2,Iris-setosa
|
||||
4.66,3.1,1.5,0.1,Iris-setosa
|
||||
5.38,3.7,1.5,0.2,Iris-setosa
|
||||
4.66,3.4,1.6,0.2,Iris-setosa
|
||||
4.66,3,1.4,0.1,Iris-setosa
|
||||
4.3,3,1.1,0.1,Iris-setosa
|
||||
5.74,4,1.2,0.2,Iris-setosa
|
||||
5.38,4.4,1.5,0.4,Iris-setosa
|
||||
5.38,3.9,1.3,0.4,Iris-setosa
|
||||
5.02,3.5,1.4,0.3,Iris-setosa
|
||||
5.38,3.8,1.7,0.3,Iris-setosa
|
||||
5.02,3.8,1.5,0.3,Iris-setosa
|
||||
5.38,3.4,1.7,0.2,Iris-setosa
|
||||
5.02,3.7,1.5,0.4,Iris-setosa
|
||||
4.3,3.6,1,0.2,Iris-setosa
|
||||
5.02,3.3,1.7,0.5,Iris-setosa
|
||||
4.66,3.4,1.9,0.2,Iris-setosa
|
||||
4.66,3,1.6,0.2,Iris-setosa
|
||||
4.66,3.4,1.6,0.4,Iris-setosa
|
||||
5.02,3.5,1.5,0.2,Iris-setosa
|
||||
5.02,3.4,1.4,0.2,Iris-setosa
|
||||
4.66,3.2,1.6,0.2,Iris-setosa
|
||||
4.66,3.1,1.6,0.2,Iris-setosa
|
||||
5.38,3.4,1.5,0.4,Iris-setosa
|
||||
5.02,4.1,1.5,0.1,Iris-setosa
|
||||
5.38,4.2,1.4,0.2,Iris-setosa
|
||||
4.66,3.1,1.5,0.1,Iris-setosa
|
||||
4.66,3.2,1.2,0.2,Iris-setosa
|
||||
5.38,3.5,1.3,0.2,Iris-setosa
|
||||
4.66,3.1,1.5,0.1,Iris-setosa
|
||||
4.3,3,1.3,0.2,Iris-setosa
|
||||
5.02,3.4,1.5,0.2,Iris-setosa
|
||||
4.66,3.5,1.3,0.3,Iris-setosa
|
||||
4.3,2.3,1.3,0.3,Iris-setosa
|
||||
4.3,3.2,1.3,0.2,Iris-setosa
|
||||
4.66,3.5,1.6,0.6,Iris-setosa
|
||||
5.02,3.8,1.9,0.4,Iris-setosa
|
||||
4.66,3,1.4,0.3,Iris-setosa
|
||||
5.02,3.8,1.6,0.2,Iris-setosa
|
||||
4.3,3.2,1.4,0.2,Iris-setosa
|
||||
5.02,3.7,1.5,0.2,Iris-setosa
|
||||
4.66,3.3,1.4,0.2,Iris-setosa
|
||||
6.82,3.2,4.7,1.4,Iris-versicolor
|
||||
6.1,3.2,4.5,1.5,Iris-versicolor
|
||||
6.82,3.1,4.9,1.5,Iris-versicolor
|
||||
5.38,2.3,4,1.3,Iris-versicolor
|
||||
6.46,2.8,4.6,1.5,Iris-versicolor
|
||||
5.38,2.8,4.5,1.3,Iris-versicolor
|
||||
6.1,3.3,4.7,1.6,Iris-versicolor
|
||||
4.66,2.4,3.3,1,Iris-versicolor
|
||||
6.46,2.9,4.6,1.3,Iris-versicolor
|
||||
5.02,2.7,3.9,1.4,Iris-versicolor
|
||||
4.66,2,3.5,1,Iris-versicolor
|
||||
5.74,3,4.2,1.5,Iris-versicolor
|
||||
5.74,2.2,4,1,Iris-versicolor
|
||||
6.1,2.9,4.7,1.4,Iris-versicolor
|
||||
5.38,2.9,3.6,1.3,Iris-versicolor
|
||||
6.46,3.1,4.4,1.4,Iris-versicolor
|
||||
5.38,3,4.5,1.5,Iris-versicolor
|
||||
5.74,2.7,4.1,1,Iris-versicolor
|
||||
6.1,2.2,4.5,1.5,Iris-versicolor
|
||||
5.38,2.5,3.9,1.1,Iris-versicolor
|
||||
5.74,3.2,4.8,1.8,Iris-versicolor
|
||||
6.1,2.8,4,1.3,Iris-versicolor
|
||||
6.1,2.5,4.9,1.5,Iris-versicolor
|
||||
6.1,2.8,4.7,1.2,Iris-versicolor
|
||||
6.1,2.9,4.3,1.3,Iris-versicolor
|
||||
6.46,3,4.4,1.4,Iris-versicolor
|
||||
6.46,2.8,4.8,1.4,Iris-versicolor
|
||||
6.46,3,5,1.7,Iris-versicolor
|
||||
5.74,2.9,4.5,1.5,Iris-versicolor
|
||||
5.38,2.6,3.5,1,Iris-versicolor
|
||||
5.38,2.4,3.8,1.1,Iris-versicolor
|
||||
5.38,2.4,3.7,1,Iris-versicolor
|
||||
5.74,2.7,3.9,1.2,Iris-versicolor
|
||||
5.74,2.7,5.1,1.6,Iris-versicolor
|
||||
5.38,3,4.5,1.5,Iris-versicolor
|
||||
5.74,3.4,4.5,1.6,Iris-versicolor
|
||||
6.46,3.1,4.7,1.5,Iris-versicolor
|
||||
6.1,2.3,4.4,1.3,Iris-versicolor
|
||||
5.38,3,4.1,1.3,Iris-versicolor
|
||||
5.38,2.5,4,1.3,Iris-versicolor
|
||||
5.38,2.6,4.4,1.2,Iris-versicolor
|
||||
6.1,3,4.6,1.4,Iris-versicolor
|
||||
5.74,2.6,4,1.2,Iris-versicolor
|
||||
4.66,2.3,3.3,1,Iris-versicolor
|
||||
5.38,2.7,4.2,1.3,Iris-versicolor
|
||||
5.38,3,4.2,1.2,Iris-versicolor
|
||||
5.38,2.9,4.2,1.3,Iris-versicolor
|
||||
6.1,2.9,4.3,1.3,Iris-versicolor
|
||||
5.02,2.5,3,1.1,Iris-versicolor
|
||||
5.38,2.8,4.1,1.3,Iris-versicolor
|
||||
6.1,3.3,6,2.5,Iris-virginica
|
||||
5.74,2.7,5.1,1.9,Iris-virginica
|
||||
6.82,3,5.9,2.1,Iris-virginica
|
||||
6.1,2.9,5.6,1.8,Iris-virginica
|
||||
6.46,3,5.8,2.2,Iris-virginica
|
||||
7.54,3,6.6,2.1,Iris-virginica
|
||||
4.66,2.5,4.5,1.7,Iris-virginica
|
||||
7.18,2.9,6.3,1.8,Iris-virginica
|
||||
6.46,2.5,5.8,1.8,Iris-virginica
|
||||
7.18,3.6,6.1,2.5,Iris-virginica
|
||||
6.46,3.2,5.1,2,Iris-virginica
|
||||
6.1,2.7,5.3,1.9,Iris-virginica
|
||||
6.46,3,5.5,2.1,Iris-virginica
|
||||
5.38,2.5,5,2,Iris-virginica
|
||||
5.74,2.8,5.1,2.4,Iris-virginica
|
||||
6.1,3.2,5.3,2.3,Iris-virginica
|
||||
6.46,3,5.5,1.8,Iris-virginica
|
||||
7.54,3.8,6.7,2.2,Iris-virginica
|
||||
7.54,2.6,6.9,2.3,Iris-virginica
|
||||
5.74,2.2,5,1.5,Iris-virginica
|
||||
6.82,3.2,5.7,2.3,Iris-virginica
|
||||
5.38,2.8,4.9,2,Iris-virginica
|
||||
7.54,2.8,6.7,2,Iris-virginica
|
||||
6.1,2.7,4.9,1.8,Iris-virginica
|
||||
6.46,3.3,5.7,2.1,Iris-virginica
|
||||
7.18,3.2,6,1.8,Iris-virginica
|
||||
6.1,2.8,4.8,1.8,Iris-virginica
|
||||
6.1,3,4.9,1.8,Iris-virginica
|
||||
6.1,2.8,5.6,2.1,Iris-virginica
|
||||
7.18,3,5.8,1.6,Iris-virginica
|
||||
7.18,2.8,6.1,1.9,Iris-virginica
|
||||
7.9,3.8,6.4,2,Iris-virginica
|
||||
6.1,2.8,5.6,2.2,Iris-virginica
|
||||
6.1,2.8,5.1,1.5,Iris-virginica
|
||||
6.1,2.6,5.6,1.4,Iris-virginica
|
||||
7.54,3,6.1,2.3,Iris-virginica
|
||||
6.1,3.4,5.6,2.4,Iris-virginica
|
||||
6.1,3.1,5.5,1.8,Iris-virginica
|
||||
5.74,3,4.8,1.8,Iris-virginica
|
||||
6.82,3.1,5.4,2.1,Iris-virginica
|
||||
6.46,3.1,5.6,2.4,Iris-virginica
|
||||
6.82,3.1,5.1,2.3,Iris-virginica
|
||||
5.74,2.7,5.1,1.9,Iris-virginica
|
||||
6.46,3.2,5.9,2.3,Iris-virginica
|
||||
6.46,3.3,5.7,2.5,Iris-virginica
|
||||
6.46,3,5.2,2.3,Iris-virginica
|
||||
6.1,2.5,5,1.9,Iris-virginica
|
||||
6.46,3,5.2,2,Iris-virginica
|
||||
6.1,3.4,5.4,2.3,Iris-virginica
|
||||
5.74,3,5.1,1.8,Iris-virginica
|
||||
|
|
@ -34,7 +34,7 @@ func main() {
|
||||
|
||||
// If two decimal places isn't enough, you can update the
|
||||
// Precision field on any FloatAttribute
|
||||
if attr, ok := rawData.GetAttr(0).(*base.FloatAttribute); !ok {
|
||||
if attr, ok := rawData.AllAttributes()[0].(*base.FloatAttribute); !ok {
|
||||
panic("Invalid cast")
|
||||
} else {
|
||||
attr.Precision = 4
|
||||
@ -44,8 +44,15 @@ func main() {
|
||||
|
||||
// We can update the set of Instances, although the API
|
||||
// for doing so is not very sophisticated.
|
||||
rawData.SetAttrStr(0, 0, "1.00")
|
||||
rawData.SetAttrStr(0, rawData.ClassIndex, "Iris-unusual")
|
||||
|
||||
// First, have to resolve Attribute Specifications
|
||||
as := base.ResolveAttributes(rawData, rawData.AllAttributes())
|
||||
|
||||
// Attribute Specifications describe where a given column lives
|
||||
rawData.Set(as[0], 0, as[0].GetAttribute().GetSysValFromString("1.00"))
|
||||
|
||||
// A SetClass method exists as a shortcut
|
||||
base.SetClass(rawData, 0, "Iris-unusual")
|
||||
fmt.Println(rawData)
|
||||
|
||||
// There is a way of creating new Instances from scratch.
|
||||
@ -64,6 +71,21 @@ func main() {
|
||||
attrs[1].GetSysValFromString("A")
|
||||
|
||||
// Now let's create the final instances set
|
||||
newInst := base.NewInstancesFromRaw(attrs, 1, newData)
|
||||
newInst := base.NewDenseInstances()
|
||||
|
||||
// Add the attributes
|
||||
newSpecs := make([]base.AttributeSpec, len(attrs))
|
||||
for i, a := range attrs {
|
||||
newSpecs[i] = newInst.AddAttribute(a)
|
||||
}
|
||||
|
||||
// Allocate space
|
||||
newInst.Extend(1)
|
||||
|
||||
// Write the data
|
||||
newInst.Set(newSpecs[0], 0, newSpecs[0].GetAttribute().GetSysValFromString("1.0"))
|
||||
newInst.Set(newSpecs[1], 0, newSpecs[1].GetAttribute().GetSysValFromString("A"))
|
||||
|
||||
fmt.Println(newInst)
|
||||
|
||||
}
|
||||
|
@ -12,7 +12,7 @@ func main() {
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
rawData.Shuffle()
|
||||
|
||||
//Initialises a new KNN classifier
|
||||
cls := knn.NewKnnClassifier("euclidean", 2)
|
||||
|
||||
|
@ -5,15 +5,15 @@ package main
|
||||
import (
|
||||
"fmt"
|
||||
base "github.com/sjwhitworth/golearn/base"
|
||||
ensemble "github.com/sjwhitworth/golearn/ensemble"
|
||||
eval "github.com/sjwhitworth/golearn/evaluation"
|
||||
filters "github.com/sjwhitworth/golearn/filters"
|
||||
ensemble "github.com/sjwhitworth/golearn/ensemble"
|
||||
trees "github.com/sjwhitworth/golearn/trees"
|
||||
"math/rand"
|
||||
"time"
|
||||
)
|
||||
|
||||
func main () {
|
||||
func main() {
|
||||
|
||||
var tree base.Classifier
|
||||
|
||||
@ -27,12 +27,14 @@ func main () {
|
||||
|
||||
// Discretise the iris dataset with Chi-Merge
|
||||
filt := filters.NewChiMergeFilter(iris, 0.99)
|
||||
filt.AddAllNumericAttributes()
|
||||
filt.Build()
|
||||
filt.Run(iris)
|
||||
for _, a := range base.NonClassFloatAttributes(iris) {
|
||||
filt.AddAttribute(a)
|
||||
}
|
||||
filt.Train()
|
||||
irisf := base.NewLazilyFilteredInstances(iris, filt)
|
||||
|
||||
// Create a 60-40 training-test split
|
||||
trainData, testData := base.InstancesTrainTestSplit(iris, 0.60)
|
||||
trainData, testData := base.InstancesTrainTestSplit(irisf, 0.60)
|
||||
|
||||
//
|
||||
// First up, use ID3
|
||||
|
151
filters/binary.go
Normal file
151
filters/binary.go
Normal file
@ -0,0 +1,151 @@
|
||||
package filters
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"github.com/sjwhitworth/golearn/base"
|
||||
)
|
||||
|
||||
// BinaryConvertFilters convert a given DataGrid into one which
|
||||
// only contains BinaryAttributes.
|
||||
//
|
||||
// FloatAttributes are discretised into either 0 (if the value is 0)
|
||||
// or 1 (if the value is not 0).
|
||||
//
|
||||
// CategoricalAttributes are discretised into one or more new
|
||||
// BinaryAttributes.
|
||||
type BinaryConvertFilter struct {
|
||||
attrs []base.Attribute
|
||||
converted []base.FilteredAttribute
|
||||
twoValuedCategoricalAttributes map[base.Attribute]bool // Two-valued categorical Attributes
|
||||
nValuedCategoricalAttributeMap map[base.Attribute]map[uint64]base.Attribute
|
||||
}
|
||||
|
||||
// NewBinaryConvertFilter creates a blank BinaryConvertFilter
|
||||
func NewBinaryConvertFilter() *BinaryConvertFilter {
|
||||
ret := &BinaryConvertFilter{
|
||||
make([]base.Attribute, 0),
|
||||
make([]base.FilteredAttribute, 0),
|
||||
make(map[base.Attribute]bool),
|
||||
make(map[base.Attribute]map[uint64]base.Attribute),
|
||||
}
|
||||
return ret
|
||||
}
|
||||
|
||||
// AddAttribute adds a new Attribute to this Filter
|
||||
func (b *BinaryConvertFilter) AddAttribute(a base.Attribute) error {
|
||||
b.attrs = append(b.attrs, a)
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetAttributesAfterFiltering returns the Attributes previously computed via Train()
|
||||
func (b *BinaryConvertFilter) GetAttributesAfterFiltering() []base.FilteredAttribute {
|
||||
return b.converted
|
||||
}
|
||||
|
||||
// String gets a human-readable string
|
||||
func (b *BinaryConvertFilter) String() string {
|
||||
return fmt.Sprintf("BinaryConvertFilter(%d Attribute(s))", len(b.attrs))
|
||||
}
|
||||
|
||||
// Transform converts the given byte sequence using the old Attribute into the new
|
||||
// byte sequence.
|
||||
//
|
||||
// If the old Attribute has a categorical value of at most two items, then a zero or
|
||||
// non-zero byte sequence is returned.
|
||||
//
|
||||
// If the old Attribute has a categorical value of at most n-items, then a non-zero
|
||||
// or zero byte sequence is returned based on the value of the new Attribute passed in.
|
||||
//
|
||||
// If the old Attribute is a float, it's value's unpacked and we check for non-zeroness
|
||||
//
|
||||
// If the old Attribute is a BinaryAttribute, just return the input
|
||||
func (b *BinaryConvertFilter) Transform(a base.Attribute, n base.Attribute, attrBytes []byte) []byte {
|
||||
ret := make([]byte, 1)
|
||||
// Check for CategoricalAttribute
|
||||
if _, ok := a.(*base.CategoricalAttribute); ok {
|
||||
// Unpack byte value
|
||||
val := base.UnpackBytesToU64(attrBytes)
|
||||
// If it's a two-valued one, check for non-zero
|
||||
if b.twoValuedCategoricalAttributes[a] {
|
||||
if val > 0 {
|
||||
ret[0] = 1
|
||||
} else {
|
||||
ret[0] = 0
|
||||
}
|
||||
} else if an, ok := b.nValuedCategoricalAttributeMap[a]; ok {
|
||||
// If it's an n-valued one, check the new Attribute maps onto
|
||||
// the unpacked value
|
||||
if af, ok := an[val]; ok {
|
||||
if af.Equals(n) {
|
||||
ret[0] = 1
|
||||
} else {
|
||||
ret[0] = 0
|
||||
}
|
||||
} else {
|
||||
panic("Categorical value not defined!")
|
||||
}
|
||||
} else {
|
||||
panic(fmt.Sprintf("Not a recognised Attribute %v", a))
|
||||
}
|
||||
} else if _, ok := a.(*base.BinaryAttribute); ok {
|
||||
// Binary: just return the original value
|
||||
ret = attrBytes
|
||||
} else if _, ok := a.(*base.FloatAttribute); ok {
|
||||
// Float: check for non-zero
|
||||
val := base.UnpackBytesToFloat(attrBytes)
|
||||
if val > 0 {
|
||||
ret[0] = 1
|
||||
} else {
|
||||
ret[0] = 0
|
||||
}
|
||||
} else {
|
||||
panic(fmt.Sprintf("Unrecognised Attribute: %v", a))
|
||||
}
|
||||
return ret
|
||||
}
|
||||
|
||||
// Train converts the FloatAttributesinto equivalently named BinaryAttributes,
|
||||
// leaves BinaryAttributes unmodified and processes
|
||||
// CategoricalAttributes as follows.
|
||||
//
|
||||
// If the CategoricalAttribute has two values, one of them is
|
||||
// designated 0 and the other 1, and a single identically-named
|
||||
// binary Attribute is returned.
|
||||
//
|
||||
// If the CategoricalAttribute has more than two (n) values, the Filter
|
||||
// generates n BinaryAttributes and sets each of them if the value's observed.
|
||||
func (b *BinaryConvertFilter) Train() error {
|
||||
for _, a := range b.attrs {
|
||||
if ac, ok := a.(*base.CategoricalAttribute); ok {
|
||||
vals := ac.GetValues()
|
||||
if len(vals) <= 2 {
|
||||
nAttr := base.NewBinaryAttribute(ac.GetName())
|
||||
fAttr := base.FilteredAttribute{ac, nAttr}
|
||||
b.converted = append(b.converted, fAttr)
|
||||
b.twoValuedCategoricalAttributes[a] = true
|
||||
} else {
|
||||
if _, ok := b.nValuedCategoricalAttributeMap[a]; !ok {
|
||||
b.nValuedCategoricalAttributeMap[a] = make(map[uint64]base.Attribute)
|
||||
}
|
||||
for i := uint64(0); i < uint64(len(vals)); i++ {
|
||||
v := vals[i]
|
||||
newName := fmt.Sprintf("%s_%s", ac.GetName(), v)
|
||||
newAttr := base.NewBinaryAttribute(newName)
|
||||
fAttr := base.FilteredAttribute{ac, newAttr}
|
||||
b.converted = append(b.converted, fAttr)
|
||||
b.nValuedCategoricalAttributeMap[a][i] = newAttr
|
||||
}
|
||||
}
|
||||
} else if ab, ok := a.(*base.BinaryAttribute); ok {
|
||||
fAttr := base.FilteredAttribute{ab, ab}
|
||||
b.converted = append(b.converted, fAttr)
|
||||
} else if af, ok := a.(*base.FloatAttribute); ok {
|
||||
newAttr := base.NewBinaryAttribute(af.GetName())
|
||||
fAttr := base.FilteredAttribute{af, newAttr}
|
||||
b.converted = append(b.converted, fAttr)
|
||||
} else {
|
||||
return fmt.Errorf("Unsupported Attribute type: %v", a)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
4
filters/binary_test.csv
Normal file
4
filters/binary_test.csv
Normal file
@ -0,0 +1,4 @@
|
||||
floatAttr,shouldBe1Binary,shouldBe3Binary
|
||||
1.0,true,stoicism
|
||||
1.0,false,heroism
|
||||
0.0,false,romanticism
|
|
@ -9,113 +9,114 @@ import (
|
||||
// BinningFilter does equal-width binning for numeric
|
||||
// Attributes (aka "histogram binning")
|
||||
type BinningFilter struct {
|
||||
Attributes []int
|
||||
Instances *base.Instances
|
||||
BinCount int
|
||||
MinVals map[int]float64
|
||||
MaxVals map[int]float64
|
||||
trained bool
|
||||
AbstractDiscretizeFilter
|
||||
bins int
|
||||
minVals map[base.Attribute]float64
|
||||
maxVals map[base.Attribute]float64
|
||||
}
|
||||
|
||||
// NewBinningFilter creates a BinningFilter structure
|
||||
// with some helpful default initialisations.
|
||||
func NewBinningFilter(inst *base.Instances, bins int) BinningFilter {
|
||||
return BinningFilter{
|
||||
make([]int, 0),
|
||||
inst,
|
||||
func NewBinningFilter(d base.FixedDataGrid, bins int) *BinningFilter {
|
||||
return &BinningFilter{
|
||||
AbstractDiscretizeFilter{
|
||||
make(map[base.Attribute]bool),
|
||||
false,
|
||||
d,
|
||||
},
|
||||
bins,
|
||||
make(map[int]float64),
|
||||
make(map[int]float64),
|
||||
false,
|
||||
make(map[base.Attribute]float64),
|
||||
make(map[base.Attribute]float64),
|
||||
}
|
||||
}
|
||||
|
||||
// AddAttribute adds the index of the given attribute `a'
|
||||
// to the BinningFilter for discretisation.
|
||||
func (b *BinningFilter) AddAttribute(a base.Attribute) {
|
||||
attrIndex := b.Instances.GetAttrIndex(a)
|
||||
if attrIndex == -1 {
|
||||
panic("invalid attribute")
|
||||
}
|
||||
b.Attributes = append(b.Attributes, attrIndex)
|
||||
func (b *BinningFilter) String() string {
|
||||
return fmt.Sprintf("BinningFilter(%d Attribute(s), %d bin(s)", b.attrs, b.bins)
|
||||
}
|
||||
|
||||
// AddAllNumericAttributes adds every suitable attribute
|
||||
// to the BinningFilter for discretiation
|
||||
func (b *BinningFilter) AddAllNumericAttributes() {
|
||||
for i := 0; i < b.Instances.Cols; i++ {
|
||||
if i == b.Instances.ClassIndex {
|
||||
continue
|
||||
}
|
||||
attr := b.Instances.GetAttr(i)
|
||||
if attr.GetType() != base.Float64Type {
|
||||
continue
|
||||
}
|
||||
b.Attributes = append(b.Attributes, i)
|
||||
}
|
||||
}
|
||||
|
||||
// Build computes and stores the bin values
|
||||
// Train computes and stores the bin values
|
||||
// for the training instances.
|
||||
func (b *BinningFilter) Build() {
|
||||
for _, attr := range b.Attributes {
|
||||
maxVal := math.Inf(-1)
|
||||
minVal := math.Inf(1)
|
||||
for i := 0; i < b.Instances.Rows; i++ {
|
||||
val := b.Instances.Get(i, attr)
|
||||
if val > maxVal {
|
||||
maxVal = val
|
||||
func (b *BinningFilter) Train() error {
|
||||
|
||||
as := b.getAttributeSpecs()
|
||||
// Set up the AttributeSpecs, and values
|
||||
for attr := range b.attrs {
|
||||
if !b.attrs[attr] {
|
||||
continue
|
||||
}
|
||||
b.minVals[attr] = float64(math.Inf(1))
|
||||
b.maxVals[attr] = float64(math.Inf(-1))
|
||||
}
|
||||
|
||||
err := b.train.MapOverRows(as, func(row [][]byte, rowNo int) (bool, error) {
|
||||
for i, a := range row {
|
||||
attr := as[i].GetAttribute()
|
||||
attrf := attr.(*base.FloatAttribute)
|
||||
val := float64(attrf.GetFloatFromSysVal(a))
|
||||
if val > b.maxVals[attr] {
|
||||
b.maxVals[attr] = val
|
||||
}
|
||||
if val < minVal {
|
||||
minVal = val
|
||||
if val < b.minVals[attr] {
|
||||
b.minVals[attr] = val
|
||||
}
|
||||
}
|
||||
b.MaxVals[attr] = maxVal
|
||||
b.MinVals[attr] = minVal
|
||||
b.trained = true
|
||||
return true, nil
|
||||
})
|
||||
|
||||
if err != nil {
|
||||
return fmt.Errorf("Training error: %s", err)
|
||||
}
|
||||
b.trained = true
|
||||
return nil
|
||||
}
|
||||
|
||||
// Run applies a trained BinningFilter to a set of Instances,
|
||||
// discretising any numeric attributes added.
|
||||
//
|
||||
// IMPORTANT: Run discretises in-place, so make sure to take
|
||||
// a copy if the original instances are still needed
|
||||
//
|
||||
// IMPORTANT: This function panic()s if the filter has not been
|
||||
// trained. Call Build() before running this function
|
||||
//
|
||||
// IMPORTANT: Call Build() after adding any additional attributes.
|
||||
// Otherwise, the training structure will be out of date from
|
||||
// the values expected and could cause a panic.
|
||||
func (b *BinningFilter) Run(on *base.Instances) {
|
||||
if !b.trained {
|
||||
panic("Call Build() beforehand")
|
||||
// Transform takes an Attribute and byte sequence and returns
|
||||
// the transformed byte sequence.
|
||||
func (b *BinningFilter) Transform(a base.Attribute, n base.Attribute, field []byte) []byte {
|
||||
|
||||
if !b.attrs[a] {
|
||||
return field
|
||||
}
|
||||
for attr := range b.Attributes {
|
||||
minVal := b.MinVals[attr]
|
||||
maxVal := b.MaxVals[attr]
|
||||
disc := 0
|
||||
// Casts to float32 to replicate a floating point precision error
|
||||
delta := float32(maxVal - minVal)
|
||||
delta /= float32(b.BinCount)
|
||||
for i := 0; i < on.Rows; i++ {
|
||||
val := on.Get(i, attr)
|
||||
if val <= minVal {
|
||||
disc = 0
|
||||
} else {
|
||||
disc = int(math.Floor(float64(float32(val-minVal) / delta)))
|
||||
if disc >= b.BinCount {
|
||||
disc = b.BinCount - 1
|
||||
}
|
||||
}
|
||||
on.Set(i, attr, float64(disc))
|
||||
}
|
||||
newAttribute := new(base.CategoricalAttribute)
|
||||
newAttribute.SetName(on.GetAttr(attr).GetName())
|
||||
for i := 0; i < b.BinCount; i++ {
|
||||
newAttribute.GetSysValFromString(fmt.Sprintf("%d", i))
|
||||
}
|
||||
on.ReplaceAttr(attr, newAttribute)
|
||||
af, ok := a.(*base.FloatAttribute)
|
||||
if !ok {
|
||||
panic("Attribute is the wrong type")
|
||||
}
|
||||
minVal := b.minVals[a]
|
||||
maxVal := b.maxVals[a]
|
||||
disc := 0
|
||||
// Casts to float64 to replicate a floating point precision error
|
||||
delta := float64(maxVal-minVal) / float64(b.bins)
|
||||
val := float64(af.GetFloatFromSysVal(field))
|
||||
if val <= minVal {
|
||||
disc = 0
|
||||
} else {
|
||||
disc = int(math.Floor(float64(float64(val-minVal)/delta + 0.0001)))
|
||||
}
|
||||
return base.PackU64ToBytes(uint64(disc))
|
||||
}
|
||||
|
||||
// GetAttributesAfterFiltering gets a list of before/after
|
||||
// Attributes as base.FilteredAttributes
|
||||
func (b *BinningFilter) GetAttributesAfterFiltering() []base.FilteredAttribute {
|
||||
oldAttrs := b.train.AllAttributes()
|
||||
ret := make([]base.FilteredAttribute, len(oldAttrs))
|
||||
for i, a := range oldAttrs {
|
||||
if b.attrs[a] {
|
||||
retAttr := new(base.CategoricalAttribute)
|
||||
minVal := b.minVals[a]
|
||||
maxVal := b.maxVals[a]
|
||||
delta := float64(maxVal-minVal) / float64(b.bins)
|
||||
retAttr.SetName(a.GetName())
|
||||
for i := 0; i <= b.bins; i++ {
|
||||
floatVal := float64(i)*delta + minVal
|
||||
fmtStr := fmt.Sprintf("%%.%df", a.(*base.FloatAttribute).Precision)
|
||||
binVal := fmt.Sprintf(fmtStr, floatVal)
|
||||
retAttr.GetSysValFromString(binVal)
|
||||
}
|
||||
ret[i] = base.FilteredAttribute{a, retAttr}
|
||||
} else {
|
||||
ret[i] = base.FilteredAttribute{a, a}
|
||||
}
|
||||
}
|
||||
return ret
|
||||
}
|
||||
|
@ -2,27 +2,55 @@ package filters
|
||||
|
||||
import (
|
||||
base "github.com/sjwhitworth/golearn/base"
|
||||
"math"
|
||||
. "github.com/smartystreets/goconvey/convey"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestBinning(testEnv *testing.T) {
|
||||
inst1, err := base.ParseCSVToInstances("../examples/datasets/iris_headers.csv", true)
|
||||
inst2, err := base.ParseCSVToInstances("../examples/datasets/iris_binned.csv", true)
|
||||
inst3, err := base.ParseCSVToInstances("../examples/datasets/iris_headers.csv", true)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
filt := NewBinningFilter(inst1, 10)
|
||||
filt.AddAttribute(inst1.GetAttr(0))
|
||||
filt.Build()
|
||||
filt.Run(inst1)
|
||||
for i := 0; i < inst1.Rows; i++ {
|
||||
val1 := inst1.Get(i, 0)
|
||||
val2 := inst2.Get(i, 0)
|
||||
val3 := inst3.Get(i, 0)
|
||||
if math.Abs(val1-val2) >= 1 {
|
||||
testEnv.Error(val1, val2, val3, i)
|
||||
func TestBinning(t *testing.T) {
|
||||
Convey("Given some data and a reference", t, func() {
|
||||
// Read the data
|
||||
inst1, err := base.ParseCSVToInstances("../examples/datasets/iris_headers.csv", true)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
}
|
||||
|
||||
inst2, err := base.ParseCSVToInstances("../examples/datasets/iris_binned.csv", true)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
//
|
||||
// Construct the binning filter
|
||||
binAttr := inst1.AllAttributes()[0]
|
||||
filt := NewBinningFilter(inst1, 10)
|
||||
filt.AddAttribute(binAttr)
|
||||
filt.Train()
|
||||
inst1f := base.NewLazilyFilteredInstances(inst1, filt)
|
||||
|
||||
// Retrieve the categorical version of the original Attribute
|
||||
var cAttr base.Attribute
|
||||
for _, a := range inst1f.AllAttributes() {
|
||||
if a.GetName() == binAttr.GetName() {
|
||||
cAttr = a
|
||||
}
|
||||
}
|
||||
|
||||
cAttrSpec, err := inst1f.GetAttribute(cAttr)
|
||||
So(err, ShouldEqual, nil)
|
||||
binAttrSpec, err := inst2.GetAttribute(binAttr)
|
||||
So(err, ShouldEqual, nil)
|
||||
|
||||
//
|
||||
// Create the LazilyFilteredInstances
|
||||
// and check the values
|
||||
Convey("Discretized version should match reference", func() {
|
||||
_, rows := inst1.Size()
|
||||
for i := 0; i < rows; i++ {
|
||||
val1 := inst1f.Get(cAttrSpec, i)
|
||||
val2 := inst2.Get(binAttrSpec, i)
|
||||
val1s := cAttr.GetStringFromSysVal(val1)
|
||||
val2s := binAttr.GetStringFromSysVal(val2)
|
||||
So(val1s, ShouldEqual, val2s)
|
||||
}
|
||||
})
|
||||
})
|
||||
}
|
||||
|
@ -12,301 +12,30 @@ import (
|
||||
// See Bramer, "Principles of Data Mining", 2nd Edition
|
||||
// pp 105--115
|
||||
type ChiMergeFilter struct {
|
||||
Attributes []int
|
||||
Instances *base.Instances
|
||||
Tables map[int][]*FrequencyTableEntry
|
||||
AbstractDiscretizeFilter
|
||||
tables map[base.Attribute][]*FrequencyTableEntry
|
||||
Significance float64
|
||||
MinRows int
|
||||
MaxRows int
|
||||
_Trained bool
|
||||
}
|
||||
|
||||
// NewChiMergeFilter creates a ChiMergeFilter with some helpful initialisations.
|
||||
func NewChiMergeFilter(inst *base.Instances, significance float64) ChiMergeFilter {
|
||||
return ChiMergeFilter{
|
||||
make([]int, 0),
|
||||
inst,
|
||||
make(map[int][]*FrequencyTableEntry),
|
||||
// NewChiMergeFilter creates a ChiMergeFilter with some helpful intialisations.
|
||||
func NewChiMergeFilter(d base.FixedDataGrid, significance float64) *ChiMergeFilter {
|
||||
_, rows := d.Size()
|
||||
return &ChiMergeFilter{
|
||||
AbstractDiscretizeFilter{
|
||||
make(map[base.Attribute]bool),
|
||||
false,
|
||||
d,
|
||||
},
|
||||
make(map[base.Attribute][]*FrequencyTableEntry),
|
||||
significance,
|
||||
0,
|
||||
0,
|
||||
false,
|
||||
}
|
||||
}
|
||||
|
||||
// Build trains a ChiMergeFilter on the ChiMergeFilter.Instances given
|
||||
func (c *ChiMergeFilter) Build() {
|
||||
for _, attr := range c.Attributes {
|
||||
tab := chiMerge(c.Instances, attr, c.Significance, c.MinRows, c.MaxRows)
|
||||
c.Tables[attr] = tab
|
||||
c._Trained = true
|
||||
}
|
||||
}
|
||||
|
||||
// AddAllNumericAttributes adds every suitable attribute
|
||||
// to the ChiMergeFilter for discretisation
|
||||
func (c *ChiMergeFilter) AddAllNumericAttributes() {
|
||||
for i := 0; i < c.Instances.Cols; i++ {
|
||||
if i == c.Instances.ClassIndex {
|
||||
continue
|
||||
}
|
||||
attr := c.Instances.GetAttr(i)
|
||||
if attr.GetType() != base.Float64Type {
|
||||
continue
|
||||
}
|
||||
c.Attributes = append(c.Attributes, i)
|
||||
}
|
||||
}
|
||||
|
||||
// Run discretises the set of Instances `on'
|
||||
//
|
||||
// IMPORTANT: ChiMergeFilter discretises in place.
|
||||
func (c *ChiMergeFilter) Run(on *base.Instances) {
|
||||
if !c._Trained {
|
||||
panic("Call Build() beforehand")
|
||||
}
|
||||
for attr := range c.Tables {
|
||||
table := c.Tables[attr]
|
||||
for i := 0; i < on.Rows; i++ {
|
||||
val := on.Get(i, attr)
|
||||
dis := 0
|
||||
for j, k := range table {
|
||||
if k.Value < val {
|
||||
dis = j
|
||||
continue
|
||||
}
|
||||
break
|
||||
}
|
||||
on.Set(i, attr, float64(dis))
|
||||
}
|
||||
newAttribute := new(base.CategoricalAttribute)
|
||||
newAttribute.SetName(on.GetAttr(attr).GetName())
|
||||
for _, k := range table {
|
||||
newAttribute.GetSysValFromString(fmt.Sprintf("%f", k.Value))
|
||||
}
|
||||
on.ReplaceAttr(attr, newAttribute)
|
||||
}
|
||||
}
|
||||
|
||||
// AddAttribute add a given numeric Attribute `attr' to the
|
||||
// filter.
|
||||
//
|
||||
// IMPORTANT: This function panic()s if it can't locate the
|
||||
// attribute in the Instances set.
|
||||
func (c *ChiMergeFilter) AddAttribute(attr base.Attribute) {
|
||||
if attr.GetType() != base.Float64Type {
|
||||
panic("ChiMerge only works on Float64Attributes")
|
||||
}
|
||||
attrIndex := c.Instances.GetAttrIndex(attr)
|
||||
if attrIndex == -1 {
|
||||
panic("Invalid attribute!")
|
||||
}
|
||||
c.Attributes = append(c.Attributes, attrIndex)
|
||||
}
|
||||
|
||||
type FrequencyTableEntry struct {
|
||||
Value float64
|
||||
Frequency map[string]int
|
||||
}
|
||||
|
||||
func (t *FrequencyTableEntry) String() string {
|
||||
return fmt.Sprintf("%.2f %v", t.Value, t.Frequency)
|
||||
}
|
||||
|
||||
func ChiMBuildFrequencyTable(attr int, inst *base.Instances) []*FrequencyTableEntry {
|
||||
ret := make([]*FrequencyTableEntry, 0)
|
||||
var attribute *base.FloatAttribute
|
||||
attribute, ok := inst.GetAttr(attr).(*base.FloatAttribute)
|
||||
if !ok {
|
||||
panic("only use Chi-M on numeric stuff")
|
||||
}
|
||||
for i := 0; i < inst.Rows; i++ {
|
||||
value := inst.Get(i, attr)
|
||||
valueConv := attribute.GetUsrVal(value)
|
||||
class := inst.GetClass(i)
|
||||
// Search the frequency table for the value
|
||||
found := false
|
||||
for _, entry := range ret {
|
||||
if entry.Value == valueConv {
|
||||
found = true
|
||||
entry.Frequency[class]++
|
||||
}
|
||||
}
|
||||
if !found {
|
||||
newEntry := &FrequencyTableEntry{
|
||||
valueConv,
|
||||
make(map[string]int),
|
||||
}
|
||||
newEntry.Frequency[class] = 1
|
||||
ret = append(ret, newEntry)
|
||||
}
|
||||
}
|
||||
|
||||
return ret
|
||||
}
|
||||
|
||||
func chiSquaredPdf(k float64, x float64) float64 {
|
||||
if x < 0 {
|
||||
return 0
|
||||
}
|
||||
top := math.Pow(x, (k/2)-1) * math.Exp(-x/2)
|
||||
bottom := math.Pow(2, k/2) * math.Gamma(k/2)
|
||||
return top / bottom
|
||||
}
|
||||
|
||||
func chiSquaredPercentile(k int, x float64) float64 {
|
||||
// Implements Yahya et al.'s "A Numerical Procedure
|
||||
// for Computing Chi-Square Percentage Points"
|
||||
// InterStat Journal 01/2007; April 25:page:1-8.
|
||||
steps := 32
|
||||
intervals := 4 * steps
|
||||
w := x / (4.0 * float64(steps))
|
||||
values := make([]float64, intervals+1)
|
||||
for i := 0; i < intervals+1; i++ {
|
||||
c := w * float64(i)
|
||||
v := chiSquaredPdf(float64(k), c)
|
||||
values[i] = v
|
||||
}
|
||||
|
||||
ret1 := values[0] + values[len(values)-1]
|
||||
ret2 := 0.0
|
||||
ret3 := 0.0
|
||||
ret4 := 0.0
|
||||
|
||||
for i := 2; i < intervals-1; i += 4 {
|
||||
ret2 += values[i]
|
||||
}
|
||||
|
||||
for i := 4; i < intervals-3; i += 4 {
|
||||
ret3 += values[i]
|
||||
}
|
||||
|
||||
for i := 1; i < intervals; i += 2 {
|
||||
ret4 += values[i]
|
||||
}
|
||||
|
||||
return (2.0 * w / 45) * (7*ret1 + 12*ret2 + 14*ret3 + 32*ret4)
|
||||
}
|
||||
|
||||
func chiCountClasses(entries []*FrequencyTableEntry) map[string]int {
|
||||
classCounter := make(map[string]int)
|
||||
for _, e := range entries {
|
||||
for k := range e.Frequency {
|
||||
classCounter[k] += e.Frequency[k]
|
||||
}
|
||||
}
|
||||
return classCounter
|
||||
}
|
||||
|
||||
func chiComputeStatistic(entry1 *FrequencyTableEntry, entry2 *FrequencyTableEntry) float64 {
|
||||
|
||||
// Sum the number of things observed per class
|
||||
classCounter := make(map[string]int)
|
||||
for k := range entry1.Frequency {
|
||||
classCounter[k] += entry1.Frequency[k]
|
||||
}
|
||||
for k := range entry2.Frequency {
|
||||
classCounter[k] += entry2.Frequency[k]
|
||||
}
|
||||
|
||||
// Sum the number of things observed per value
|
||||
entryObservations1 := 0
|
||||
entryObservations2 := 0
|
||||
for k := range entry1.Frequency {
|
||||
entryObservations1 += entry1.Frequency[k]
|
||||
}
|
||||
for k := range entry2.Frequency {
|
||||
entryObservations2 += entry2.Frequency[k]
|
||||
}
|
||||
|
||||
totalObservations := entryObservations1 + entryObservations2
|
||||
// Compute the expected values per class
|
||||
expectedClassValues1 := make(map[string]float64)
|
||||
expectedClassValues2 := make(map[string]float64)
|
||||
for k := range classCounter {
|
||||
expectedClassValues1[k] = float64(classCounter[k])
|
||||
expectedClassValues1[k] *= float64(entryObservations1)
|
||||
expectedClassValues1[k] /= float64(totalObservations)
|
||||
}
|
||||
for k := range classCounter {
|
||||
expectedClassValues2[k] = float64(classCounter[k])
|
||||
expectedClassValues2[k] *= float64(entryObservations2)
|
||||
expectedClassValues2[k] /= float64(totalObservations)
|
||||
}
|
||||
|
||||
// Compute chi-squared value
|
||||
chiSum := 0.0
|
||||
for k := range expectedClassValues1 {
|
||||
numerator := float64(entry1.Frequency[k])
|
||||
numerator -= expectedClassValues1[k]
|
||||
numerator = math.Pow(numerator, 2)
|
||||
denominator := float64(expectedClassValues1[k])
|
||||
if denominator < 0.5 {
|
||||
denominator = 0.5
|
||||
}
|
||||
chiSum += numerator / denominator
|
||||
}
|
||||
for k := range expectedClassValues2 {
|
||||
numerator := float64(entry2.Frequency[k])
|
||||
numerator -= expectedClassValues2[k]
|
||||
numerator = math.Pow(numerator, 2)
|
||||
denominator := float64(expectedClassValues2[k])
|
||||
if denominator < 0.5 {
|
||||
denominator = 0.5
|
||||
}
|
||||
chiSum += numerator / denominator
|
||||
}
|
||||
|
||||
return chiSum
|
||||
}
|
||||
|
||||
func chiMergeMergeZipAdjacent(freq []*FrequencyTableEntry, minIndex int) []*FrequencyTableEntry {
|
||||
mergeEntry1 := freq[minIndex]
|
||||
mergeEntry2 := freq[minIndex+1]
|
||||
classCounter := make(map[string]int)
|
||||
for k := range mergeEntry1.Frequency {
|
||||
classCounter[k] += mergeEntry1.Frequency[k]
|
||||
}
|
||||
for k := range mergeEntry2.Frequency {
|
||||
classCounter[k] += mergeEntry2.Frequency[k]
|
||||
}
|
||||
newVal := freq[minIndex].Value
|
||||
newEntry := &FrequencyTableEntry{
|
||||
newVal,
|
||||
classCounter,
|
||||
}
|
||||
lowerSlice := freq
|
||||
upperSlice := freq
|
||||
if minIndex > 0 {
|
||||
lowerSlice = freq[0:minIndex]
|
||||
upperSlice = freq[minIndex+1:]
|
||||
} else {
|
||||
lowerSlice = make([]*FrequencyTableEntry, 0)
|
||||
upperSlice = freq[1:]
|
||||
}
|
||||
upperSlice[0] = newEntry
|
||||
freq = append(lowerSlice, upperSlice...)
|
||||
return freq
|
||||
}
|
||||
|
||||
func chiMergePrintTable(freq []*FrequencyTableEntry) {
|
||||
classes := chiCountClasses(freq)
|
||||
fmt.Printf("Attribute value\t")
|
||||
for k := range classes {
|
||||
fmt.Printf("\t%s", k)
|
||||
}
|
||||
fmt.Printf("\tTotal\n")
|
||||
for _, f := range freq {
|
||||
fmt.Printf("%.2f\t", f.Value)
|
||||
total := 0
|
||||
for k := range classes {
|
||||
fmt.Printf("\t%d", f.Frequency[k])
|
||||
total += f.Frequency[k]
|
||||
}
|
||||
fmt.Printf("\t%d\n", total)
|
||||
2,
|
||||
rows,
|
||||
}
|
||||
}
|
||||
|
||||
// Train computes and stores the
|
||||
// Produces a value mapping table
|
||||
// inst: The base.Instances which need discretising
|
||||
// sig: The significance level (e.g. 0.95)
|
||||
@ -316,7 +45,7 @@ func chiMergePrintTable(freq []*FrequencyTableEntry) {
|
||||
// adjacent rows will be merged
|
||||
// precision: internal number of decimal places to round E value to
|
||||
// (useful for verification)
|
||||
func chiMerge(inst *base.Instances, attr int, sig float64, minrows int, maxrows int) []*FrequencyTableEntry {
|
||||
func chiMerge(inst base.FixedDataGrid, attr base.Attribute, sig float64, minrows int, maxrows int) []*FrequencyTableEntry {
|
||||
|
||||
// Parameter sanity checking
|
||||
if !(2 <= minrows) {
|
||||
@ -329,12 +58,17 @@ func chiMerge(inst *base.Instances, attr int, sig float64, minrows int, maxrows
|
||||
sig = 10
|
||||
}
|
||||
|
||||
// Check that the attribute is numeric
|
||||
_, ok := attr.(*base.FloatAttribute)
|
||||
if !ok {
|
||||
panic("only use Chi-M on numeric stuff")
|
||||
}
|
||||
|
||||
// Build a frequency table
|
||||
freq := ChiMBuildFrequencyTable(attr, inst)
|
||||
// Count the number of classes
|
||||
classes := chiCountClasses(freq)
|
||||
for {
|
||||
// chiMergePrintTable(freq) DEBUG
|
||||
if len(freq) <= minrows {
|
||||
break
|
||||
}
|
||||
@ -378,3 +112,77 @@ func chiMerge(inst *base.Instances, attr int, sig float64, minrows int, maxrows
|
||||
}
|
||||
return freq
|
||||
}
|
||||
|
||||
func (c *ChiMergeFilter) Train() error {
|
||||
as := c.getAttributeSpecs()
|
||||
|
||||
for _, a := range as {
|
||||
|
||||
attr := a.GetAttribute()
|
||||
|
||||
// Skip if not set
|
||||
if !c.attrs[attr] {
|
||||
continue
|
||||
}
|
||||
|
||||
// Build sort order
|
||||
sortOrder := []base.AttributeSpec{a}
|
||||
|
||||
// Sort
|
||||
sorted, err := base.LazySort(c.train, base.Ascending, sortOrder)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
// Perform ChiMerge
|
||||
freq := chiMerge(sorted, attr, c.Significance, c.MinRows, c.MaxRows)
|
||||
c.tables[attr] = freq
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetAttributesAfterFiltering gets a list of before/after
|
||||
// Attributes as base.FilteredAttributes
|
||||
func (c *ChiMergeFilter) GetAttributesAfterFiltering() []base.FilteredAttribute {
|
||||
oldAttrs := c.train.AllAttributes()
|
||||
ret := make([]base.FilteredAttribute, len(oldAttrs))
|
||||
for i, a := range oldAttrs {
|
||||
if c.attrs[a] {
|
||||
retAttr := new(base.CategoricalAttribute)
|
||||
retAttr.SetName(a.GetName())
|
||||
for _, k := range c.tables[a] {
|
||||
retAttr.GetSysValFromString(fmt.Sprintf("%f", k.Value))
|
||||
}
|
||||
ret[i] = base.FilteredAttribute{a, retAttr}
|
||||
} else {
|
||||
ret[i] = base.FilteredAttribute{a, a}
|
||||
}
|
||||
}
|
||||
return ret
|
||||
}
|
||||
|
||||
// Transform returns the byte sequence after discretisation
|
||||
func (c *ChiMergeFilter) Transform(a base.Attribute, n base.Attribute, field []byte) []byte {
|
||||
// Do we use this Attribute?
|
||||
if !c.attrs[a] {
|
||||
return field
|
||||
}
|
||||
// Find the Attribute value in the table
|
||||
table := c.tables[a]
|
||||
dis := 0
|
||||
val := base.UnpackBytesToFloat(field)
|
||||
for j, k := range table {
|
||||
if k.Value < val {
|
||||
dis = j
|
||||
continue
|
||||
}
|
||||
break
|
||||
}
|
||||
|
||||
return base.PackU64ToBytes(uint64(dis))
|
||||
}
|
||||
|
||||
func (c *ChiMergeFilter) String() string {
|
||||
return fmt.Sprintf("ChiMergeFilter(%d Attributes, %.2f Significance)", len(c.tables), c.Significance)
|
||||
}
|
||||
|
14
filters/chimerge_freq.go
Normal file
14
filters/chimerge_freq.go
Normal file
@ -0,0 +1,14 @@
|
||||
package filters
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
)
|
||||
|
||||
type FrequencyTableEntry struct {
|
||||
Value float64
|
||||
Frequency map[string]int
|
||||
}
|
||||
|
||||
func (t *FrequencyTableEntry) String() string {
|
||||
return fmt.Sprintf("%.2f %s", t.Value, t.Frequency)
|
||||
}
|
205
filters/chimerge_funcs.go
Normal file
205
filters/chimerge_funcs.go
Normal file
@ -0,0 +1,205 @@
|
||||
package filters
|
||||
|
||||
import (
|
||||
"github.com/sjwhitworth/golearn/base"
|
||||
"fmt"
|
||||
"math"
|
||||
)
|
||||
|
||||
func ChiMBuildFrequencyTable(attr base.Attribute, inst base.FixedDataGrid) []*FrequencyTableEntry {
|
||||
ret := make([]*FrequencyTableEntry, 0)
|
||||
attribute := attr.(*base.FloatAttribute)
|
||||
|
||||
attrSpec, err := inst.GetAttribute(attr)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
attrSpecs := []base.AttributeSpec{attrSpec}
|
||||
|
||||
err = inst.MapOverRows(attrSpecs, func(row [][]byte, rowNo int) (bool, error) {
|
||||
value := row[0]
|
||||
valueConv := attribute.GetFloatFromSysVal(value)
|
||||
class := base.GetClass(inst, rowNo)
|
||||
// Search the frequency table for the value
|
||||
found := false
|
||||
for _, entry := range ret {
|
||||
if entry.Value == valueConv {
|
||||
found = true
|
||||
entry.Frequency[class] += 1
|
||||
}
|
||||
}
|
||||
if !found {
|
||||
newEntry := &FrequencyTableEntry{
|
||||
valueConv,
|
||||
make(map[string]int),
|
||||
}
|
||||
newEntry.Frequency[class] = 1
|
||||
ret = append(ret, newEntry)
|
||||
}
|
||||
return true, nil
|
||||
})
|
||||
|
||||
return ret
|
||||
}
|
||||
|
||||
func chiSquaredPdf(k float64, x float64) float64 {
|
||||
if x < 0 {
|
||||
return 0
|
||||
}
|
||||
top := math.Pow(x, (k/2)-1) * math.Exp(-x/2)
|
||||
bottom := math.Pow(2, k/2) * math.Gamma(k/2)
|
||||
return top / bottom
|
||||
}
|
||||
|
||||
func chiSquaredPercentile(k int, x float64) float64 {
|
||||
// Implements Yahya et al.'s "A Numerical Procedure
|
||||
// for Computing Chi-Square Percentage Points"
|
||||
// InterStat Journal 01/2007; April 25:page:1-8.
|
||||
steps := 32
|
||||
intervals := 4 * steps
|
||||
w := x / (4.0 * float64(steps))
|
||||
values := make([]float64, intervals+1)
|
||||
for i := 0; i < intervals+1; i++ {
|
||||
c := w * float64(i)
|
||||
v := chiSquaredPdf(float64(k), c)
|
||||
values[i] = v
|
||||
}
|
||||
|
||||
ret1 := values[0] + values[len(values)-1]
|
||||
ret2 := 0.0
|
||||
ret3 := 0.0
|
||||
ret4 := 0.0
|
||||
|
||||
for i := 2; i < intervals-1; i += 4 {
|
||||
ret2 += values[i]
|
||||
}
|
||||
|
||||
for i := 4; i < intervals-3; i += 4 {
|
||||
ret3 += values[i]
|
||||
}
|
||||
|
||||
for i := 1; i < intervals; i += 2 {
|
||||
ret4 += values[i]
|
||||
}
|
||||
|
||||
return (2.0 * w / 45) * (7*ret1 + 12*ret2 + 14*ret3 + 32*ret4)
|
||||
}
|
||||
|
||||
func chiCountClasses(entries []*FrequencyTableEntry) map[string]int {
|
||||
classCounter := make(map[string]int)
|
||||
for _, e := range entries {
|
||||
for k := range e.Frequency {
|
||||
classCounter[k] += e.Frequency[k]
|
||||
}
|
||||
}
|
||||
return classCounter
|
||||
}
|
||||
|
||||
func chiComputeStatistic(entry1 *FrequencyTableEntry, entry2 *FrequencyTableEntry) float64 {
|
||||
|
||||
// Sum the number of things observed per class
|
||||
classCounter := make(map[string]int)
|
||||
for k := range entry1.Frequency {
|
||||
classCounter[k] += entry1.Frequency[k]
|
||||
}
|
||||
for k := range entry2.Frequency {
|
||||
classCounter[k] += entry2.Frequency[k]
|
||||
}
|
||||
|
||||
// Sum the number of things observed per value
|
||||
entryObservations1 := 0
|
||||
entryObservations2 := 0
|
||||
for k := range entry1.Frequency {
|
||||
entryObservations1 += entry1.Frequency[k]
|
||||
}
|
||||
for k := range entry2.Frequency {
|
||||
entryObservations2 += entry2.Frequency[k]
|
||||
}
|
||||
|
||||
totalObservations := entryObservations1 + entryObservations2
|
||||
// Compute the expected values per class
|
||||
expectedClassValues1 := make(map[string]float64)
|
||||
expectedClassValues2 := make(map[string]float64)
|
||||
for k := range classCounter {
|
||||
expectedClassValues1[k] = float64(classCounter[k])
|
||||
expectedClassValues1[k] *= float64(entryObservations1)
|
||||
expectedClassValues1[k] /= float64(totalObservations)
|
||||
}
|
||||
for k := range classCounter {
|
||||
expectedClassValues2[k] = float64(classCounter[k])
|
||||
expectedClassValues2[k] *= float64(entryObservations2)
|
||||
expectedClassValues2[k] /= float64(totalObservations)
|
||||
}
|
||||
|
||||
// Compute chi-squared value
|
||||
chiSum := 0.0
|
||||
for k := range expectedClassValues1 {
|
||||
numerator := float64(entry1.Frequency[k])
|
||||
numerator -= expectedClassValues1[k]
|
||||
numerator = math.Pow(numerator, 2)
|
||||
denominator := float64(expectedClassValues1[k])
|
||||
if denominator < 0.5 {
|
||||
denominator = 0.5
|
||||
}
|
||||
chiSum += numerator / denominator
|
||||
}
|
||||
for k := range expectedClassValues2 {
|
||||
numerator := float64(entry2.Frequency[k])
|
||||
numerator -= expectedClassValues2[k]
|
||||
numerator = math.Pow(numerator, 2)
|
||||
denominator := float64(expectedClassValues2[k])
|
||||
if denominator < 0.5 {
|
||||
denominator = 0.5
|
||||
}
|
||||
chiSum += numerator / denominator
|
||||
}
|
||||
|
||||
return chiSum
|
||||
}
|
||||
|
||||
func chiMergeMergeZipAdjacent(freq []*FrequencyTableEntry, minIndex int) []*FrequencyTableEntry {
|
||||
mergeEntry1 := freq[minIndex]
|
||||
mergeEntry2 := freq[minIndex+1]
|
||||
classCounter := make(map[string]int)
|
||||
for k := range mergeEntry1.Frequency {
|
||||
classCounter[k] += mergeEntry1.Frequency[k]
|
||||
}
|
||||
for k := range mergeEntry2.Frequency {
|
||||
classCounter[k] += mergeEntry2.Frequency[k]
|
||||
}
|
||||
newVal := freq[minIndex].Value
|
||||
newEntry := &FrequencyTableEntry{
|
||||
newVal,
|
||||
classCounter,
|
||||
}
|
||||
lowerSlice := freq
|
||||
upperSlice := freq
|
||||
if minIndex > 0 {
|
||||
lowerSlice = freq[0:minIndex]
|
||||
upperSlice = freq[minIndex+1:]
|
||||
} else {
|
||||
lowerSlice = make([]*FrequencyTableEntry, 0)
|
||||
upperSlice = freq[1:]
|
||||
}
|
||||
upperSlice[0] = newEntry
|
||||
freq = append(lowerSlice, upperSlice...)
|
||||
return freq
|
||||
}
|
||||
|
||||
func chiMergePrintTable(freq []*FrequencyTableEntry) {
|
||||
classes := chiCountClasses(freq)
|
||||
fmt.Printf("Attribute value\t")
|
||||
for k := range classes {
|
||||
fmt.Printf("\t%s", k)
|
||||
}
|
||||
fmt.Printf("\tTotal\n")
|
||||
for _, f := range freq {
|
||||
fmt.Printf("%.2f\t", f.Value)
|
||||
total := 0
|
||||
for k := range classes {
|
||||
fmt.Printf("\t%d", f.Frequency[k])
|
||||
total += f.Frequency[k]
|
||||
}
|
||||
fmt.Printf("\t%d\n", total)
|
||||
}
|
||||
}
|
@ -14,7 +14,7 @@ func TestChiMFreqTable(testEnv *testing.T) {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
freq := ChiMBuildFrequencyTable(0, inst)
|
||||
freq := ChiMBuildFrequencyTable(inst.AllAttributes()[0], inst)
|
||||
|
||||
if freq[0].Frequency["c1"] != 1 {
|
||||
testEnv.Error("Wrong frequency")
|
||||
@ -32,7 +32,7 @@ func TestChiClassCounter(testEnv *testing.T) {
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
freq := ChiMBuildFrequencyTable(0, inst)
|
||||
freq := ChiMBuildFrequencyTable(inst.AllAttributes()[0], inst)
|
||||
classes := chiCountClasses(freq)
|
||||
if classes["c1"] != 27 {
|
||||
testEnv.Error(classes)
|
||||
@ -50,7 +50,7 @@ func TestStatisticValues(testEnv *testing.T) {
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
freq := ChiMBuildFrequencyTable(0, inst)
|
||||
freq := ChiMBuildFrequencyTable(inst.AllAttributes()[0], inst)
|
||||
chiVal := chiComputeStatistic(freq[5], freq[6])
|
||||
if math.Abs(chiVal-1.89) > 0.01 {
|
||||
testEnv.Error(chiVal)
|
||||
@ -78,12 +78,15 @@ func TestChiSquareDistValues(testEnv *testing.T) {
|
||||
}
|
||||
|
||||
func TestChiMerge1(testEnv *testing.T) {
|
||||
// See Bramer, Principles of Machine Learning
|
||||
|
||||
// Read the data
|
||||
inst, err := base.ParseCSVToInstances("../examples/datasets/chim.csv", true)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
freq := chiMerge(inst, 0, 0.90, 0, inst.Rows)
|
||||
_, rows := inst.Size()
|
||||
|
||||
freq := chiMerge(inst, inst.AllAttributes()[0], 0.90, 0, rows)
|
||||
if len(freq) != 3 {
|
||||
testEnv.Error("Wrong length")
|
||||
}
|
||||
@ -106,10 +109,18 @@ func TestChiMerge2(testEnv *testing.T) {
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
attrs := make([]int, 1)
|
||||
attrs[0] = 0
|
||||
inst.Sort(base.Ascending, attrs)
|
||||
freq := chiMerge(inst, 0, 0.90, 0, inst.Rows)
|
||||
|
||||
// Sort the instances
|
||||
allAttrs := inst.AllAttributes()
|
||||
sortAttrSpecs := base.ResolveAttributes(inst, allAttrs)[0:1]
|
||||
instSorted, err := base.Sort(inst, base.Ascending, sortAttrSpecs)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
// Perform Chi-Merge
|
||||
_, rows := inst.Size()
|
||||
freq := chiMerge(instSorted, allAttrs[0], 0.90, 0, rows)
|
||||
if len(freq) != 5 {
|
||||
testEnv.Errorf("Wrong length (%d)", len(freq))
|
||||
testEnv.Error(freq)
|
||||
@ -131,6 +142,7 @@ func TestChiMerge2(testEnv *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
/*
|
||||
func TestChiMerge3(testEnv *testing.T) {
|
||||
// See http://sci2s.ugr.es/keel/pdf/algorithm/congreso/1992-Kerber-ChimErge-AAAI92.pdf
|
||||
// Randy Kerber, ChiMerge: Discretisation of Numeric Attributes, 1992
|
||||
@ -138,12 +150,52 @@ func TestChiMerge3(testEnv *testing.T) {
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
attrs := make([]int, 1)
|
||||
attrs[0] = 0
|
||||
inst.Sort(base.Ascending, attrs)
|
||||
|
||||
insts, err := base.LazySort(inst, base.Ascending, base.ResolveAllAttributes(inst, inst.AllAttributes()))
|
||||
if err != nil {
|
||||
testEnv.Error(err)
|
||||
}
|
||||
filt := NewChiMergeFilter(inst, 0.90)
|
||||
filt.AddAttribute(inst.GetAttr(0))
|
||||
filt.Build()
|
||||
filt.Run(inst)
|
||||
fmt.Println(inst)
|
||||
filt.AddAttribute(inst.AllAttributes()[0])
|
||||
filt.Train()
|
||||
instf := base.NewLazilyFilteredInstances(insts, filt)
|
||||
fmt.Println(instf)
|
||||
fmt.Println(instf.String())
|
||||
rowStr := instf.RowString(0)
|
||||
ref := "4.300000 3.00 1.10 0.10 Iris-setosa"
|
||||
if rowStr != ref {
|
||||
panic(fmt.Sprintf("'%s' != '%s'", rowStr, ref))
|
||||
}
|
||||
clsAttrs := instf.AllClassAttributes()
|
||||
if len(clsAttrs) != 1 {
|
||||
panic(fmt.Sprintf("%d != %d", len(clsAttrs), 1))
|
||||
}
|
||||
if clsAttrs[0].GetName() != "Species" {
|
||||
panic("Class Attribute wrong!")
|
||||
}
|
||||
}
|
||||
*/
|
||||
|
||||
func TestChiMerge4(testEnv *testing.T) {
|
||||
// See http://sci2s.ugr.es/keel/pdf/algorithm/congreso/1992-Kerber-ChimErge-AAAI92.pdf
|
||||
// Randy Kerber, ChiMerge: Discretisation of Numeric Attributes, 1992
|
||||
inst, err := base.ParseCSVToInstances("../examples/datasets/iris_headers.csv", true)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
filt := NewChiMergeFilter(inst, 0.90)
|
||||
filt.AddAttribute(inst.AllAttributes()[0])
|
||||
filt.AddAttribute(inst.AllAttributes()[1])
|
||||
filt.Train()
|
||||
instf := base.NewLazilyFilteredInstances(inst, filt)
|
||||
fmt.Println(instf)
|
||||
fmt.Println(instf.String())
|
||||
clsAttrs := instf.AllClassAttributes()
|
||||
if len(clsAttrs) != 1 {
|
||||
panic(fmt.Sprintf("%d != %d", len(clsAttrs), 1))
|
||||
}
|
||||
if clsAttrs[0].GetName() != "Species" {
|
||||
panic("Class Attribute wrong!")
|
||||
}
|
||||
}
|
||||
|
62
filters/disc.go
Normal file
62
filters/disc.go
Normal file
@ -0,0 +1,62 @@
|
||||
package filters
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
base "github.com/sjwhitworth/golearn/base"
|
||||
)
|
||||
|
||||
type AbstractDiscretizeFilter struct {
|
||||
attrs map[base.Attribute]bool
|
||||
trained bool
|
||||
train base.FixedDataGrid
|
||||
}
|
||||
|
||||
// AddAttribute adds the AttributeSpec of the given attribute `a'
|
||||
// to the AbstractFloatFilter for discretisation.
|
||||
func (d *AbstractDiscretizeFilter) AddAttribute(a base.Attribute) error {
|
||||
if _, ok := a.(*base.FloatAttribute); !ok {
|
||||
return fmt.Errorf("%s is not a FloatAttribute", a)
|
||||
}
|
||||
_, err := d.train.GetAttribute(a)
|
||||
if err != nil {
|
||||
return fmt.Errorf("invalid attribute")
|
||||
}
|
||||
d.attrs[a] = true
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetAttributesAfterFiltering gets a list of before/after
|
||||
// Attributes as base.FilteredAttributes
|
||||
func (d *AbstractDiscretizeFilter) GetAttributesAfterFiltering() []base.FilteredAttribute {
|
||||
oldAttrs := d.train.AllAttributes()
|
||||
ret := make([]base.FilteredAttribute, len(oldAttrs))
|
||||
for i, a := range oldAttrs {
|
||||
if d.attrs[a] {
|
||||
retAttr := new(base.CategoricalAttribute)
|
||||
retAttr.SetName(a.GetName())
|
||||
ret[i] = base.FilteredAttribute{a, retAttr}
|
||||
} else {
|
||||
ret[i] = base.FilteredAttribute{a, a}
|
||||
}
|
||||
}
|
||||
return ret
|
||||
}
|
||||
|
||||
func (d *AbstractDiscretizeFilter) getAttributeSpecs() []base.AttributeSpec {
|
||||
as := make([]base.AttributeSpec, 0)
|
||||
// Set up the AttributeSpecs, and values
|
||||
for attr := range d.attrs {
|
||||
// If for some reason we've un-added it...
|
||||
if !d.attrs[attr] {
|
||||
continue
|
||||
}
|
||||
// Get the AttributeSpec for the training set
|
||||
a, err := d.train.GetAttribute(attr)
|
||||
if err != nil {
|
||||
panic(fmt.Errorf("Attribute resolution error: %s", err))
|
||||
}
|
||||
// Append to return set
|
||||
as = append(as, a)
|
||||
}
|
||||
return as
|
||||
}
|
126
knn/knn.go
126
knn/knn.go
@ -14,7 +14,7 @@ import (
|
||||
// The accepted distance functions at this time are 'euclidean' and 'manhattan'.
|
||||
type KNNClassifier struct {
|
||||
base.BaseEstimator
|
||||
TrainingData *base.Instances
|
||||
TrainingData base.FixedDataGrid
|
||||
DistanceFunc string
|
||||
NearestNeighbours int
|
||||
}
|
||||
@ -28,20 +28,12 @@ func NewKnnClassifier(distfunc string, neighbours int) *KNNClassifier {
|
||||
}
|
||||
|
||||
// Fit stores the training data for later
|
||||
func (KNN *KNNClassifier) Fit(trainingData *base.Instances) {
|
||||
func (KNN *KNNClassifier) Fit(trainingData base.FixedDataGrid) {
|
||||
KNN.TrainingData = trainingData
|
||||
}
|
||||
|
||||
// PredictOne returns a classification for the vector, based on a vector input, using the KNN algorithm.
|
||||
// See http://en.wikipedia.org/wiki/K-nearest_neighbors_algorithm.
|
||||
func (KNN *KNNClassifier) PredictOne(vector []float64) string {
|
||||
|
||||
rows := KNN.TrainingData.Rows
|
||||
rownumbers := make(map[int]float64)
|
||||
labels := make([]string, 0)
|
||||
maxmap := make(map[string]int)
|
||||
|
||||
convertedVector := util.FloatsToMatrix(vector)
|
||||
// Predict returns a classification for the vector, based on a vector input, using the KNN algorithm.
|
||||
func (KNN *KNNClassifier) Predict(what base.FixedDataGrid) base.FixedDataGrid {
|
||||
|
||||
// Check what distance function we are using
|
||||
var distanceFunc pairwiseMetrics.PairwiseDistanceFunc
|
||||
@ -52,40 +44,96 @@ func (KNN *KNNClassifier) PredictOne(vector []float64) string {
|
||||
distanceFunc = pairwiseMetrics.NewManhattan()
|
||||
default:
|
||||
panic("unsupported distance function")
|
||||
|
||||
}
|
||||
// Check compatability
|
||||
allAttrs := base.CheckCompatable(what, KNN.TrainingData)
|
||||
if allAttrs == nil {
|
||||
// Don't have the same Attributes
|
||||
return nil
|
||||
}
|
||||
|
||||
for i := 0; i < rows; i++ {
|
||||
row := KNN.TrainingData.GetRowVectorWithoutClass(i)
|
||||
rowMat := util.FloatsToMatrix(row)
|
||||
distance := distanceFunc.Distance(rowMat, convertedVector)
|
||||
rownumbers[i] = distance
|
||||
}
|
||||
|
||||
sorted := util.SortIntMap(rownumbers)
|
||||
values := sorted[:KNN.NearestNeighbours]
|
||||
|
||||
for _, elem := range values {
|
||||
label := KNN.TrainingData.GetClass(elem)
|
||||
labels = append(labels, label)
|
||||
|
||||
if _, ok := maxmap[label]; ok {
|
||||
maxmap[label]++
|
||||
} else {
|
||||
maxmap[label] = 1
|
||||
// Remove the Attributes which aren't numeric
|
||||
allNumericAttrs := make([]base.Attribute, 0)
|
||||
for _, a := range allAttrs {
|
||||
if fAttr, ok := a.(*base.FloatAttribute); ok {
|
||||
allNumericAttrs = append(allNumericAttrs, fAttr)
|
||||
}
|
||||
}
|
||||
|
||||
sortedlabels := util.SortStringMap(maxmap)
|
||||
label := sortedlabels[0]
|
||||
// Generate return vector
|
||||
ret := base.GeneratePredictionVector(what)
|
||||
|
||||
return label
|
||||
}
|
||||
// Resolve Attribute specifications for both
|
||||
whatAttrSpecs := base.ResolveAttributes(what, allNumericAttrs)
|
||||
trainAttrSpecs := base.ResolveAttributes(KNN.TrainingData, allNumericAttrs)
|
||||
|
||||
// Reserve storage for most the most similar items
|
||||
distances := make(map[int]float64)
|
||||
|
||||
// Reserve storage for voting map
|
||||
maxmap := make(map[string]int)
|
||||
|
||||
// Reserve storage for row computations
|
||||
trainRowBuf := make([]float64, len(allNumericAttrs))
|
||||
predRowBuf := make([]float64, len(allNumericAttrs))
|
||||
|
||||
// Iterate over all outer rows
|
||||
what.MapOverRows(whatAttrSpecs, func(predRow [][]byte, predRowNo int) (bool, error) {
|
||||
// Read the float values out
|
||||
for i, _ := range allNumericAttrs {
|
||||
predRowBuf[i] = base.UnpackBytesToFloat(predRow[i])
|
||||
}
|
||||
|
||||
predMat := util.FloatsToMatrix(predRowBuf)
|
||||
|
||||
// Find the closest match in the training data
|
||||
KNN.TrainingData.MapOverRows(trainAttrSpecs, func(trainRow [][]byte, srcRowNo int) (bool, error) {
|
||||
|
||||
// Read the float values out
|
||||
for i, _ := range allNumericAttrs {
|
||||
trainRowBuf[i] = base.UnpackBytesToFloat(trainRow[i])
|
||||
}
|
||||
|
||||
// Compute the distance
|
||||
trainMat := util.FloatsToMatrix(trainRowBuf)
|
||||
distances[srcRowNo] = distanceFunc.Distance(predMat, trainMat)
|
||||
return true, nil
|
||||
})
|
||||
|
||||
sorted := util.SortIntMap(distances)
|
||||
values := sorted[:KNN.NearestNeighbours]
|
||||
|
||||
// Reset maxMap
|
||||
for a := range maxmap {
|
||||
maxmap[a] = 0
|
||||
}
|
||||
|
||||
// Refresh maxMap
|
||||
for _, elem := range values {
|
||||
label := base.GetClass(KNN.TrainingData, elem)
|
||||
if _, ok := maxmap[label]; ok {
|
||||
maxmap[label]++
|
||||
} else {
|
||||
maxmap[label] = 1
|
||||
}
|
||||
}
|
||||
|
||||
// Sort the maxMap
|
||||
var maxClass string
|
||||
maxVal := -1
|
||||
for a := range maxmap {
|
||||
if maxmap[a] > maxVal {
|
||||
maxVal = maxmap[a]
|
||||
maxClass = a
|
||||
}
|
||||
}
|
||||
|
||||
base.SetClass(ret, predRowNo, maxClass)
|
||||
return true, nil
|
||||
|
||||
})
|
||||
|
||||
func (KNN *KNNClassifier) Predict(what *base.Instances) *base.Instances {
|
||||
ret := what.GeneratePredictionVector()
|
||||
for i := 0; i < what.Rows; i++ {
|
||||
ret.SetAttrStr(i, 0, KNN.PredictOne(what.GetRowVectorWithoutClass(i)))
|
||||
}
|
||||
return ret
|
||||
}
|
||||
|
||||
|
@ -24,16 +24,17 @@ func TestKnnClassifier(t *testing.T) {
|
||||
cls := NewKnnClassifier("euclidean", 2)
|
||||
cls.Fit(trainingData)
|
||||
predictions := cls.Predict(testingData)
|
||||
So(predictions, ShouldNotEqual, nil)
|
||||
|
||||
Convey("When predicting the label for our first vector", func() {
|
||||
result := predictions.GetClass(0)
|
||||
result := base.GetClass(predictions, 0)
|
||||
Convey("The result should be 'blue", func() {
|
||||
So(result, ShouldEqual, "blue")
|
||||
})
|
||||
})
|
||||
|
||||
Convey("When predicting the label for our first vector", func() {
|
||||
result2 := predictions.GetClass(1)
|
||||
result2 := base.GetClass(predictions, 1)
|
||||
Convey("The result should be 'red", func() {
|
||||
So(result2, ShouldEqual, "red")
|
||||
})
|
||||
|
@ -8,23 +8,26 @@ import (
|
||||
|
||||
func TestLogisticRegression(t *testing.T) {
|
||||
Convey("Given labels, a classifier and data", t, func() {
|
||||
// Load data
|
||||
X, err := base.ParseCSVToInstances("train.csv", false)
|
||||
So(err, ShouldEqual, nil)
|
||||
Y, err := base.ParseCSVToInstances("test.csv", false)
|
||||
So(err, ShouldEqual, nil)
|
||||
|
||||
// Setup the problem
|
||||
lr := NewLogisticRegression("l2", 1.0, 1e-6)
|
||||
lr.Fit(X)
|
||||
|
||||
Convey("When predicting the label of first vector", func() {
|
||||
Z := lr.Predict(Y)
|
||||
Convey("The result should be 1", func() {
|
||||
So(Z.Get(0, 0), ShouldEqual, 1.0)
|
||||
So(Z.RowString(0), ShouldEqual, "1.00")
|
||||
})
|
||||
})
|
||||
Convey("When predicting the label of second vector", func() {
|
||||
Z := lr.Predict(Y)
|
||||
Convey("The result should be -1", func() {
|
||||
So(Z.Get(1, 0), ShouldEqual, -1.0)
|
||||
So(Z.RowString(1), ShouldEqual, "-1.00")
|
||||
})
|
||||
})
|
||||
})
|
||||
|
@ -5,6 +5,7 @@ import (
|
||||
|
||||
"github.com/sjwhitworth/golearn/base"
|
||||
|
||||
"fmt"
|
||||
_ "github.com/gonum/blas"
|
||||
"github.com/gonum/blas/cblas"
|
||||
"github.com/gonum/matrix/mat64"
|
||||
@ -19,6 +20,8 @@ type LinearRegression struct {
|
||||
fitted bool
|
||||
disturbance float64
|
||||
regressionCoefficients []float64
|
||||
attrs []base.Attribute
|
||||
cls base.Attribute
|
||||
}
|
||||
|
||||
func init() {
|
||||
@ -29,31 +32,59 @@ func NewLinearRegression() *LinearRegression {
|
||||
return &LinearRegression{fitted: false}
|
||||
}
|
||||
|
||||
func (lr *LinearRegression) Fit(inst *base.Instances) error {
|
||||
if inst.Rows < inst.GetAttributeCount() {
|
||||
return NotEnoughDataError
|
||||
func (lr *LinearRegression) Fit(inst base.FixedDataGrid) error {
|
||||
|
||||
// Retrieve row size
|
||||
_, rows := inst.Size()
|
||||
|
||||
// Validate class Attribute count
|
||||
classAttrs := inst.AllClassAttributes()
|
||||
if len(classAttrs) != 1 {
|
||||
return fmt.Errorf("Only 1 class variable is permitted")
|
||||
}
|
||||
classAttrSpecs := base.ResolveAttributes(inst, classAttrs)
|
||||
|
||||
// Split into two matrices, observed results (dependent variable y)
|
||||
// and the explanatory variables (X) - see http://en.wikipedia.org/wiki/Linear_regression
|
||||
observed := mat64.NewDense(inst.Rows, 1, nil)
|
||||
explVariables := mat64.NewDense(inst.Rows, inst.GetAttributeCount(), nil)
|
||||
|
||||
for i := 0; i < inst.Rows; i++ {
|
||||
observed.Set(i, 0, inst.Get(i, inst.ClassIndex)) // Set observed data
|
||||
|
||||
for j := 0; j < inst.GetAttributeCount(); j++ {
|
||||
if j == 0 {
|
||||
// Set intercepts to 1.0
|
||||
// Could / should be done better: http://www.theanalysisfactor.com/interpret-the-intercept/
|
||||
explVariables.Set(i, 0, 1.0)
|
||||
} else {
|
||||
explVariables.Set(i, j, inst.Get(i, j-1))
|
||||
}
|
||||
// Retrieve relevant Attributes
|
||||
allAttrs := base.NonClassAttributes(inst)
|
||||
attrs := make([]base.Attribute, 0)
|
||||
for _, a := range allAttrs {
|
||||
if _, ok := a.(*base.FloatAttribute); ok {
|
||||
attrs = append(attrs, a)
|
||||
}
|
||||
}
|
||||
|
||||
n := inst.GetAttributeCount()
|
||||
cols := len(attrs) + 1
|
||||
|
||||
if rows < cols {
|
||||
return NotEnoughDataError
|
||||
}
|
||||
|
||||
// Retrieve relevant Attribute specifications
|
||||
attrSpecs := base.ResolveAttributes(inst, attrs)
|
||||
|
||||
// Split into two matrices, observed results (dependent variable y)
|
||||
// and the explanatory variables (X) - see http://en.wikipedia.org/wiki/Linear_regression
|
||||
observed := mat64.NewDense(rows, 1, nil)
|
||||
explVariables := mat64.NewDense(rows, cols, nil)
|
||||
|
||||
// Build the observed matrix
|
||||
inst.MapOverRows(classAttrSpecs, func(row [][]byte, i int) (bool, error) {
|
||||
val := base.UnpackBytesToFloat(row[0])
|
||||
observed.Set(i, 0, val)
|
||||
return true, nil
|
||||
})
|
||||
|
||||
// Build the explainatory variables
|
||||
inst.MapOverRows(attrSpecs, func(row [][]byte, i int) (bool, error) {
|
||||
// Set intercepts to 1.0
|
||||
explVariables.Set(i, 0, 1.0)
|
||||
for j, r := range row {
|
||||
explVariables.Set(i, j+1, base.UnpackBytesToFloat(r))
|
||||
}
|
||||
return true, nil
|
||||
})
|
||||
|
||||
n := cols
|
||||
qr := mat64.QR(explVariables)
|
||||
q := qr.Q()
|
||||
reg := qr.R()
|
||||
@ -74,25 +105,32 @@ func (lr *LinearRegression) Fit(inst *base.Instances) error {
|
||||
lr.disturbance = regressionCoefficients[0]
|
||||
lr.regressionCoefficients = regressionCoefficients[1:]
|
||||
lr.fitted = true
|
||||
|
||||
lr.attrs = attrs
|
||||
lr.cls = classAttrs[0]
|
||||
return nil
|
||||
}
|
||||
|
||||
func (lr *LinearRegression) Predict(X *base.Instances) (*base.Instances, error) {
|
||||
func (lr *LinearRegression) Predict(X base.FixedDataGrid) (base.FixedDataGrid, error) {
|
||||
if !lr.fitted {
|
||||
return nil, NoTrainingDataError
|
||||
}
|
||||
|
||||
ret := X.GeneratePredictionVector()
|
||||
for i := 0; i < X.Rows; i++ {
|
||||
var prediction float64 = lr.disturbance
|
||||
for j := 0; j < X.Cols; j++ {
|
||||
if j != X.ClassIndex {
|
||||
prediction += X.Get(i, j) * lr.regressionCoefficients[j]
|
||||
}
|
||||
}
|
||||
ret.Set(i, 0, prediction)
|
||||
ret := base.GeneratePredictionVector(X)
|
||||
attrSpecs := base.ResolveAttributes(X, lr.attrs)
|
||||
clsSpec, err := ret.GetAttribute(lr.cls)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
X.MapOverRows(attrSpecs, func(row [][]byte, i int) (bool, error) {
|
||||
var prediction float64 = lr.disturbance
|
||||
for j, r := range row {
|
||||
prediction += base.UnpackBytesToFloat(r) * lr.regressionCoefficients[j]
|
||||
}
|
||||
|
||||
ret.Set(clsSpec, i, base.PackFloatToBytes(prediction))
|
||||
return true, nil
|
||||
})
|
||||
|
||||
return ret, nil
|
||||
}
|
||||
|
@ -54,8 +54,10 @@ func TestLinearRegression(t *testing.T) {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
for i := 0; i < predictions.Rows; i++ {
|
||||
fmt.Printf("Expected: %f || Predicted: %f\n", testData.Get(i, testData.ClassIndex), predictions.Get(i, predictions.ClassIndex))
|
||||
_, rows := predictions.Size()
|
||||
|
||||
for i := 0; i < rows; i++ {
|
||||
fmt.Printf("Expected: %s || Predicted: %s\n", base.GetClass(testData, i), base.GetClass(predictions, i))
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -27,51 +27,85 @@ func NewLogisticRegression(penalty string, C float64, eps float64) *LogisticRegr
|
||||
return &lr
|
||||
}
|
||||
|
||||
func convertInstancesToProblemVec(X *base.Instances) [][]float64 {
|
||||
problemVec := make([][]float64, X.Rows)
|
||||
for i := 0; i < X.Rows; i++ {
|
||||
problemVecCounter := 0
|
||||
problemVec[i] = make([]float64, X.Cols-1)
|
||||
for j := 0; j < X.Cols; j++ {
|
||||
if j == X.ClassIndex {
|
||||
continue
|
||||
}
|
||||
problemVec[i][problemVecCounter] = X.Get(i, j)
|
||||
problemVecCounter++
|
||||
func convertInstancesToProblemVec(X base.FixedDataGrid) [][]float64 {
|
||||
// Allocate problem array
|
||||
_, rows := X.Size()
|
||||
problemVec := make([][]float64, rows)
|
||||
|
||||
// Retrieve numeric non-class Attributes
|
||||
numericAttrs := base.NonClassFloatAttributes(X)
|
||||
numericAttrSpecs := base.ResolveAttributes(X, numericAttrs)
|
||||
|
||||
// Convert each row
|
||||
X.MapOverRows(numericAttrSpecs, func(row [][]byte, rowNo int) (bool, error) {
|
||||
// Allocate a new row
|
||||
probRow := make([]float64, len(numericAttrSpecs))
|
||||
// Read out the row
|
||||
for i, _ := range numericAttrSpecs {
|
||||
probRow[i] = base.UnpackBytesToFloat(row[i])
|
||||
}
|
||||
}
|
||||
base.Logger.Println(problemVec, X)
|
||||
// Add the row
|
||||
problemVec[rowNo] = probRow
|
||||
return true, nil
|
||||
})
|
||||
return problemVec
|
||||
}
|
||||
|
||||
func convertInstancesToLabelVec(X *base.Instances) []float64 {
|
||||
labelVec := make([]float64, X.Rows)
|
||||
for i := 0; i < X.Rows; i++ {
|
||||
labelVec[i] = X.Get(i, X.ClassIndex)
|
||||
func convertInstancesToLabelVec(X base.FixedDataGrid) []float64 {
|
||||
// Get the class Attributes
|
||||
classAttrs := X.AllClassAttributes()
|
||||
// Only support 1 class Attribute
|
||||
if len(classAttrs) != 1 {
|
||||
panic(fmt.Sprintf("%d ClassAttributes (1 expected)", len(classAttrs)))
|
||||
}
|
||||
// ClassAttribute must be numeric
|
||||
if _, ok := classAttrs[0].(*base.FloatAttribute); !ok {
|
||||
panic(fmt.Sprintf("%s: ClassAttribute must be a FloatAttribute", classAttrs[0]))
|
||||
}
|
||||
// Allocate return structure
|
||||
_, rows := X.Size()
|
||||
labelVec := make([]float64, rows)
|
||||
// Resolve class Attribute specification
|
||||
classAttrSpecs := base.ResolveAttributes(X, classAttrs)
|
||||
X.MapOverRows(classAttrSpecs, func(row [][]byte, rowNo int) (bool, error) {
|
||||
labelVec[rowNo] = base.UnpackBytesToFloat(row[0])
|
||||
return true, nil
|
||||
})
|
||||
return labelVec
|
||||
}
|
||||
|
||||
func (lr *LogisticRegression) Fit(X *base.Instances) {
|
||||
func (lr *LogisticRegression) Fit(X base.FixedDataGrid) {
|
||||
problemVec := convertInstancesToProblemVec(X)
|
||||
labelVec := convertInstancesToLabelVec(X)
|
||||
prob := NewProblem(problemVec, labelVec, 0)
|
||||
lr.model = Train(prob, lr.param)
|
||||
}
|
||||
|
||||
func (lr *LogisticRegression) Predict(X *base.Instances) *base.Instances {
|
||||
ret := X.GeneratePredictionVector()
|
||||
row := make([]float64, X.Cols-1)
|
||||
for i := 0; i < X.Rows; i++ {
|
||||
rowCounter := 0
|
||||
for j := 0; j < X.Cols; j++ {
|
||||
if j != X.ClassIndex {
|
||||
row[rowCounter] = X.Get(i, j)
|
||||
rowCounter++
|
||||
}
|
||||
}
|
||||
base.Logger.Println(Predict(lr.model, row), row)
|
||||
ret.Set(i, 0, Predict(lr.model, row))
|
||||
func (lr *LogisticRegression) Predict(X base.FixedDataGrid) base.FixedDataGrid {
|
||||
|
||||
// Only support 1 class Attribute
|
||||
classAttrs := X.AllClassAttributes()
|
||||
if len(classAttrs) != 1 {
|
||||
panic(fmt.Sprintf("%d Wrong number of classes", len(classAttrs)))
|
||||
}
|
||||
// Generate return structure
|
||||
ret := base.GeneratePredictionVector(X)
|
||||
classAttrSpecs := base.ResolveAttributes(ret, classAttrs)
|
||||
// Retrieve numeric non-class Attributes
|
||||
numericAttrs := base.NonClassFloatAttributes(X)
|
||||
numericAttrSpecs := base.ResolveAttributes(X, numericAttrs)
|
||||
|
||||
// Allocate row storage
|
||||
row := make([]float64, len(numericAttrSpecs))
|
||||
X.MapOverRows(numericAttrSpecs, func(rowBytes [][]byte, rowNo int) (bool, error) {
|
||||
for i, r := range rowBytes {
|
||||
row[i] = base.UnpackBytesToFloat(r)
|
||||
}
|
||||
val := Predict(lr.model, row)
|
||||
vals := base.PackFloatToBytes(val)
|
||||
ret.Set(classAttrSpecs[0], rowNo, vals)
|
||||
return true, nil
|
||||
})
|
||||
|
||||
return ret
|
||||
}
|
||||
|
@ -21,23 +21,18 @@ type BaggedModel struct {
|
||||
|
||||
// generateTrainingAttrs selects RandomFeatures number of base.Attributes from
|
||||
// the provided base.Instances.
|
||||
func (b *BaggedModel) generateTrainingAttrs(model int, from *base.Instances) []base.Attribute {
|
||||
func (b *BaggedModel) generateTrainingAttrs(model int, from base.FixedDataGrid) []base.Attribute {
|
||||
ret := make([]base.Attribute, 0)
|
||||
attrs := base.NonClassAttributes(from)
|
||||
if b.RandomFeatures == 0 {
|
||||
for j := 0; j < from.Cols; j++ {
|
||||
attr := from.GetAttr(j)
|
||||
ret = append(ret, attr)
|
||||
}
|
||||
ret = attrs
|
||||
} else {
|
||||
for {
|
||||
if len(ret) >= b.RandomFeatures {
|
||||
break
|
||||
}
|
||||
attrIndex := rand.Intn(from.Cols)
|
||||
if attrIndex == from.ClassIndex {
|
||||
continue
|
||||
}
|
||||
attr := from.GetAttr(attrIndex)
|
||||
attrIndex := rand.Intn(len(attrs))
|
||||
attr := attrs[attrIndex]
|
||||
matched := false
|
||||
for _, a := range ret {
|
||||
if a.Equals(attr) {
|
||||
@ -50,7 +45,9 @@ func (b *BaggedModel) generateTrainingAttrs(model int, from *base.Instances) []b
|
||||
}
|
||||
}
|
||||
}
|
||||
ret = append(ret, from.GetClassAttr())
|
||||
for _, a := range from.AllClassAttributes() {
|
||||
ret = append(ret, a)
|
||||
}
|
||||
b.lock.Lock()
|
||||
b.selectedAttributes[model] = ret
|
||||
b.lock.Unlock()
|
||||
@ -60,18 +57,19 @@ func (b *BaggedModel) generateTrainingAttrs(model int, from *base.Instances) []b
|
||||
// generatePredictionInstances returns a modified version of the
|
||||
// requested base.Instances with only the base.Attributes selected
|
||||
// for training the model.
|
||||
func (b *BaggedModel) generatePredictionInstances(model int, from *base.Instances) *base.Instances {
|
||||
func (b *BaggedModel) generatePredictionInstances(model int, from base.FixedDataGrid) base.FixedDataGrid {
|
||||
selected := b.selectedAttributes[model]
|
||||
return from.SelectAttributes(selected)
|
||||
return base.NewInstancesViewFromAttrs(from, selected)
|
||||
}
|
||||
|
||||
// generateTrainingInstances generates RandomFeatures number of
|
||||
// attributes and returns a modified version of base.Instances
|
||||
// for training the model
|
||||
func (b *BaggedModel) generateTrainingInstances(model int, from *base.Instances) *base.Instances {
|
||||
insts := from.SampleWithReplacement(from.Rows)
|
||||
func (b *BaggedModel) generateTrainingInstances(model int, from base.FixedDataGrid) base.FixedDataGrid {
|
||||
_, rows := from.Size()
|
||||
insts := base.SampleWithReplacement(from, rows)
|
||||
selected := b.generateTrainingAttrs(model, from)
|
||||
return insts.SelectAttributes(selected)
|
||||
return base.NewInstancesViewFromAttrs(insts, selected)
|
||||
}
|
||||
|
||||
// AddModel adds a base.Classifier to the current model
|
||||
@ -81,12 +79,12 @@ func (b *BaggedModel) AddModel(m base.Classifier) {
|
||||
|
||||
// Fit generates and trains each model on a randomised subset of
|
||||
// Instances.
|
||||
func (b *BaggedModel) Fit(from *base.Instances) {
|
||||
func (b *BaggedModel) Fit(from base.FixedDataGrid) {
|
||||
var wait sync.WaitGroup
|
||||
b.selectedAttributes = make(map[int][]base.Attribute)
|
||||
for i, m := range b.Models {
|
||||
wait.Add(1)
|
||||
go func(c base.Classifier, f *base.Instances, model int) {
|
||||
go func(c base.Classifier, f base.FixedDataGrid, model int) {
|
||||
l := b.generateTrainingInstances(model, f)
|
||||
c.Fit(l)
|
||||
wait.Done()
|
||||
@ -100,10 +98,10 @@ func (b *BaggedModel) Fit(from *base.Instances) {
|
||||
//
|
||||
// IMPORTANT: in the event of a tie, the first class which
|
||||
// achieved the tie value is output.
|
||||
func (b *BaggedModel) Predict(from *base.Instances) *base.Instances {
|
||||
func (b *BaggedModel) Predict(from base.FixedDataGrid) base.FixedDataGrid {
|
||||
n := runtime.NumCPU()
|
||||
// Channel to receive the results as they come in
|
||||
votes := make(chan *base.Instances, n)
|
||||
votes := make(chan base.DataGrid, n)
|
||||
// Count the votes for each class
|
||||
voting := make(map[int](map[string]int))
|
||||
|
||||
@ -111,21 +109,20 @@ func (b *BaggedModel) Predict(from *base.Instances) *base.Instances {
|
||||
var votingwait sync.WaitGroup
|
||||
votingwait.Add(1)
|
||||
go func() {
|
||||
for {
|
||||
for { // Need to resolve the voting problem
|
||||
incoming, ok := <-votes
|
||||
if ok {
|
||||
// Step through each prediction
|
||||
for j := 0; j < incoming.Rows; j++ {
|
||||
cSpecs := base.ResolveAttributes(incoming, incoming.AllClassAttributes())
|
||||
incoming.MapOverRows(cSpecs, func(row [][]byte, predRow int) (bool, error) {
|
||||
// Check if we've seen this class before...
|
||||
if _, ok := voting[j]; !ok {
|
||||
if _, ok := voting[predRow]; !ok {
|
||||
// If we haven't, create an entry
|
||||
voting[j] = make(map[string]int)
|
||||
voting[predRow] = make(map[string]int)
|
||||
// Continue on the current row
|
||||
j--
|
||||
continue
|
||||
}
|
||||
voting[j][incoming.GetClass(j)]++
|
||||
}
|
||||
voting[predRow][base.GetClass(incoming, predRow)]++
|
||||
return true, nil
|
||||
})
|
||||
} else {
|
||||
votingwait.Done()
|
||||
break
|
||||
@ -162,7 +159,7 @@ func (b *BaggedModel) Predict(from *base.Instances) *base.Instances {
|
||||
votingwait.Wait() // All the votes are in
|
||||
|
||||
// Generate the overall consensus
|
||||
ret := from.GeneratePredictionVector()
|
||||
ret := base.GeneratePredictionVector(from)
|
||||
for i := range voting {
|
||||
maxClass := ""
|
||||
maxCount := 0
|
||||
@ -174,7 +171,7 @@ func (b *BaggedModel) Predict(from *base.Instances) *base.Instances {
|
||||
maxCount = votes
|
||||
}
|
||||
}
|
||||
ret.SetAttrStr(i, 0, maxClass)
|
||||
base.SetClass(ret, i, maxClass)
|
||||
}
|
||||
return ret
|
||||
}
|
||||
|
@ -19,16 +19,18 @@ func BenchmarkBaggingRandomForestFit(testEnv *testing.B) {
|
||||
|
||||
rand.Seed(time.Now().UnixNano())
|
||||
filt := filters.NewChiMergeFilter(inst, 0.90)
|
||||
filt.AddAllNumericAttributes()
|
||||
filt.Build()
|
||||
filt.Run(inst)
|
||||
for _, a := range base.NonClassFloatAttributes(inst) {
|
||||
filt.AddAttribute(a)
|
||||
}
|
||||
filt.Train()
|
||||
instf := base.NewLazilyFilteredInstances(inst, filt)
|
||||
rf := new(BaggedModel)
|
||||
for i := 0; i < 10; i++ {
|
||||
rf.AddModel(trees.NewRandomTree(2))
|
||||
}
|
||||
testEnv.ResetTimer()
|
||||
for i := 0; i < 20; i++ {
|
||||
rf.Fit(inst)
|
||||
rf.Fit(instf)
|
||||
}
|
||||
}
|
||||
|
||||
@ -40,17 +42,19 @@ func BenchmarkBaggingRandomForestPredict(testEnv *testing.B) {
|
||||
|
||||
rand.Seed(time.Now().UnixNano())
|
||||
filt := filters.NewChiMergeFilter(inst, 0.90)
|
||||
filt.AddAllNumericAttributes()
|
||||
filt.Build()
|
||||
filt.Run(inst)
|
||||
for _, a := range base.NonClassFloatAttributes(inst) {
|
||||
filt.AddAttribute(a)
|
||||
}
|
||||
filt.Train()
|
||||
instf := base.NewLazilyFilteredInstances(inst, filt)
|
||||
rf := new(BaggedModel)
|
||||
for i := 0; i < 10; i++ {
|
||||
rf.AddModel(trees.NewRandomTree(2))
|
||||
}
|
||||
rf.Fit(inst)
|
||||
rf.Fit(instf)
|
||||
testEnv.ResetTimer()
|
||||
for i := 0; i < 20; i++ {
|
||||
rf.Predict(inst)
|
||||
rf.Predict(instf)
|
||||
}
|
||||
}
|
||||
|
||||
@ -63,19 +67,21 @@ func TestRandomForest1(testEnv *testing.T) {
|
||||
rand.Seed(time.Now().UnixNano())
|
||||
trainData, testData := base.InstancesTrainTestSplit(inst, 0.6)
|
||||
filt := filters.NewChiMergeFilter(inst, 0.90)
|
||||
filt.AddAllNumericAttributes()
|
||||
filt.Build()
|
||||
filt.Run(testData)
|
||||
filt.Run(trainData)
|
||||
for _, a := range base.NonClassFloatAttributes(inst) {
|
||||
filt.AddAttribute(a)
|
||||
}
|
||||
filt.Train()
|
||||
trainDataf := base.NewLazilyFilteredInstances(trainData, filt)
|
||||
testDataf := base.NewLazilyFilteredInstances(testData, filt)
|
||||
rf := new(BaggedModel)
|
||||
for i := 0; i < 10; i++ {
|
||||
rf.AddModel(trees.NewRandomTree(2))
|
||||
}
|
||||
rf.Fit(trainData)
|
||||
rf.Fit(trainDataf)
|
||||
fmt.Println(rf)
|
||||
predictions := rf.Predict(testData)
|
||||
predictions := rf.Predict(testDataf)
|
||||
fmt.Println(predictions)
|
||||
confusionMat := eval.GetConfusionMatrix(testData, predictions)
|
||||
confusionMat := eval.GetConfusionMatrix(testDataf, predictions)
|
||||
fmt.Println(confusionMat)
|
||||
fmt.Println(eval.GetMacroPrecision(confusionMat))
|
||||
fmt.Println(eval.GetMacroRecall(confusionMat))
|
||||
|
13
meta/meta.go
Normal file
13
meta/meta.go
Normal file
@ -0,0 +1,13 @@
|
||||
/*
|
||||
|
||||
Meta contains base.Classifier implementations which
|
||||
combine the outputs of others defined elsewhere.
|
||||
|
||||
Bagging:
|
||||
Bootstraps samples of the original training set
|
||||
with a number of selected attributes, and uses
|
||||
that to train an ensemble of models. Predictions
|
||||
are generated via majority voting.
|
||||
*/
|
||||
|
||||
package meta
|
@ -1,8 +1,9 @@
|
||||
package naive
|
||||
|
||||
import (
|
||||
"math"
|
||||
base "github.com/sjwhitworth/golearn/base"
|
||||
"fmt"
|
||||
base "github.com/sjwhitworth/golearn/base"
|
||||
"math"
|
||||
)
|
||||
|
||||
// A Bernoulli Naive Bayes Classifier. Naive Bayes classifiers assumes
|
||||
@ -37,91 +38,108 @@ import (
|
||||
// Information Retrieval. Cambridge University Press, pp. 234-265.
|
||||
// http://nlp.stanford.edu/IR-book/html/htmledition/the-bernoulli-model-1.html
|
||||
type BernoulliNBClassifier struct {
|
||||
base.BaseEstimator
|
||||
// Conditional probability for each term. This vector should be
|
||||
// accessed in the following way: p(f|c) = condProb[c][f].
|
||||
// Logarithm is used in order to avoid underflow.
|
||||
condProb map[string][]float64
|
||||
// Number of instances in each class. This is necessary in order to
|
||||
// calculate the laplace smooth value during the Predict step.
|
||||
classInstances map[string]int
|
||||
// Number of instances used in training.
|
||||
trainingInstances int
|
||||
// Number of features in the training set
|
||||
features int
|
||||
base.BaseEstimator
|
||||
// Conditional probability for each term. This vector should be
|
||||
// accessed in the following way: p(f|c) = condProb[c][f].
|
||||
// Logarithm is used in order to avoid underflow.
|
||||
condProb map[string][]float64
|
||||
// Number of instances in each class. This is necessary in order to
|
||||
// calculate the laplace smooth value during the Predict step.
|
||||
classInstances map[string]int
|
||||
// Number of instances used in training.
|
||||
trainingInstances int
|
||||
// Number of features used in training
|
||||
features int
|
||||
// Attributes used to Train
|
||||
attrs []base.Attribute
|
||||
}
|
||||
|
||||
// Create a new Bernoulli Naive Bayes Classifier. The argument 'classes'
|
||||
// is the number of possible labels in the classification task.
|
||||
func NewBernoulliNBClassifier() *BernoulliNBClassifier {
|
||||
nb := BernoulliNBClassifier{}
|
||||
nb.condProb = make(map[string][]float64)
|
||||
nb.features = 0
|
||||
nb.trainingInstances = 0
|
||||
return &nb
|
||||
nb := BernoulliNBClassifier{}
|
||||
nb.condProb = make(map[string][]float64)
|
||||
nb.features = 0
|
||||
nb.trainingInstances = 0
|
||||
return &nb
|
||||
}
|
||||
|
||||
// Fill data matrix with Bernoulli Naive Bayes model. All values
|
||||
// necessary for calculating prior probability and p(f_i)
|
||||
func (nb *BernoulliNBClassifier) Fit(X *base.Instances) {
|
||||
func (nb *BernoulliNBClassifier) Fit(X base.FixedDataGrid) {
|
||||
|
||||
// Number of features and instances in this training set
|
||||
nb.trainingInstances = X.Rows
|
||||
nb.features = 0
|
||||
if X.Rows > 0 {
|
||||
nb.features = len(X.GetRowVectorWithoutClass(0))
|
||||
}
|
||||
// Check that all Attributes are binary
|
||||
classAttrs := X.AllClassAttributes()
|
||||
allAttrs := X.AllAttributes()
|
||||
featAttrs := base.AttributeDifference(allAttrs, classAttrs)
|
||||
for i := range featAttrs {
|
||||
if _, ok := featAttrs[i].(*base.BinaryAttribute); !ok {
|
||||
panic(fmt.Sprintf("%v: Should be BinaryAttribute", featAttrs[i]))
|
||||
}
|
||||
}
|
||||
featAttrSpecs := base.ResolveAttributes(X, featAttrs)
|
||||
|
||||
// Number of instances in class
|
||||
nb.classInstances = make(map[string]int)
|
||||
// Check that only one classAttribute is defined
|
||||
if len(classAttrs) != 1 {
|
||||
panic("Only one class Attribute can be used")
|
||||
}
|
||||
|
||||
// Number of documents with given term (by class)
|
||||
docsContainingTerm := make(map[string][]int)
|
||||
// Number of features and instances in this training set
|
||||
_, nb.trainingInstances = X.Size()
|
||||
nb.attrs = featAttrs
|
||||
nb.features = len(featAttrs)
|
||||
|
||||
// This algorithm could be vectorized after binarizing the data
|
||||
// matrix. Since mat64 doesn't have this function, a iterative
|
||||
// version is used.
|
||||
for r := 0; r < X.Rows; r++ {
|
||||
class := X.GetClass(r)
|
||||
docVector := X.GetRowVectorWithoutClass(r)
|
||||
// Number of instances in class
|
||||
nb.classInstances = make(map[string]int)
|
||||
|
||||
// increment number of instances in class
|
||||
t, ok := nb.classInstances[class]
|
||||
if !ok { t = 0 }
|
||||
nb.classInstances[class] = t + 1
|
||||
// Number of documents with given term (by class)
|
||||
docsContainingTerm := make(map[string][]int)
|
||||
|
||||
// This algorithm could be vectorized after binarizing the data
|
||||
// matrix. Since mat64 doesn't have this function, a iterative
|
||||
// version is used.
|
||||
X.MapOverRows(featAttrSpecs, func(docVector [][]byte, r int) (bool, error) {
|
||||
class := base.GetClass(X, r)
|
||||
|
||||
for feat := 0; feat < len(docVector); feat++ {
|
||||
v := docVector[feat]
|
||||
// In Bernoulli Naive Bayes the presence and absence of
|
||||
// features are considered. All non-zero values are
|
||||
// treated as presence.
|
||||
if v > 0 {
|
||||
// Update number of times this feature appeared within
|
||||
// given label.
|
||||
t, ok := docsContainingTerm[class]
|
||||
if !ok {
|
||||
t = make([]int, nb.features)
|
||||
docsContainingTerm[class] = t
|
||||
}
|
||||
t[feat] += 1
|
||||
}
|
||||
}
|
||||
}
|
||||
// increment number of instances in class
|
||||
t, ok := nb.classInstances[class]
|
||||
if !ok {
|
||||
t = 0
|
||||
}
|
||||
nb.classInstances[class] = t + 1
|
||||
|
||||
// Pre-calculate conditional probabilities for each class
|
||||
for c, _ := range nb.classInstances {
|
||||
nb.condProb[c] = make([]float64, nb.features)
|
||||
for feat := 0; feat < nb.features; feat++ {
|
||||
classTerms, _ := docsContainingTerm[c]
|
||||
numDocs := classTerms[feat]
|
||||
docsInClass, _ := nb.classInstances[c]
|
||||
for feat := 0; feat < len(docVector); feat++ {
|
||||
v := docVector[feat]
|
||||
// In Bernoulli Naive Bayes the presence and absence of
|
||||
// features are considered. All non-zero values are
|
||||
// treated as presence.
|
||||
if v[0] > 0 {
|
||||
// Update number of times this feature appeared within
|
||||
// given label.
|
||||
t, ok := docsContainingTerm[class]
|
||||
if !ok {
|
||||
t = make([]int, nb.features)
|
||||
docsContainingTerm[class] = t
|
||||
}
|
||||
t[feat] += 1
|
||||
}
|
||||
}
|
||||
return true, nil
|
||||
})
|
||||
|
||||
classCondProb, _ := nb.condProb[c]
|
||||
// Calculate conditional probability with laplace smoothing
|
||||
classCondProb[feat] = float64(numDocs + 1) / float64(docsInClass + 1)
|
||||
}
|
||||
}
|
||||
// Pre-calculate conditional probabilities for each class
|
||||
for c, _ := range nb.classInstances {
|
||||
nb.condProb[c] = make([]float64, nb.features)
|
||||
for feat := 0; feat < nb.features; feat++ {
|
||||
classTerms, _ := docsContainingTerm[c]
|
||||
numDocs := classTerms[feat]
|
||||
docsInClass, _ := nb.classInstances[c]
|
||||
|
||||
classCondProb, _ := nb.condProb[c]
|
||||
// Calculate conditional probability with laplace smoothing
|
||||
classCondProb[feat] = float64(numDocs+1) / float64(docsInClass+1)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Use trained model to predict test vector's class. The following
|
||||
@ -133,54 +151,61 @@ func (nb *BernoulliNBClassifier) Fit(X *base.Instances) {
|
||||
//
|
||||
// IMPORTANT: PredictOne panics if Fit was not called or if the
|
||||
// document vector and train matrix have a different number of columns.
|
||||
func (nb *BernoulliNBClassifier) PredictOne(vector []float64) string {
|
||||
if nb.features == 0 {
|
||||
panic("Fit should be called before predicting")
|
||||
}
|
||||
func (nb *BernoulliNBClassifier) PredictOne(vector [][]byte) string {
|
||||
if nb.features == 0 {
|
||||
panic("Fit should be called before predicting")
|
||||
}
|
||||
|
||||
if len(vector) != nb.features {
|
||||
panic("Different dimensions in Train and Test sets")
|
||||
}
|
||||
if len(vector) != nb.features {
|
||||
panic("Different dimensions in Train and Test sets")
|
||||
}
|
||||
|
||||
// Currently only the predicted class is returned.
|
||||
bestScore := -math.MaxFloat64
|
||||
bestClass := ""
|
||||
// Currently only the predicted class is returned.
|
||||
bestScore := -math.MaxFloat64
|
||||
bestClass := ""
|
||||
|
||||
for class, classCount := range nb.classInstances {
|
||||
// Init classScore with log(prior)
|
||||
classScore := math.Log((float64(classCount))/float64(nb.trainingInstances))
|
||||
for f := 0; f < nb.features; f++ {
|
||||
if vector[f] > 0 {
|
||||
// Test document has feature c
|
||||
classScore += math.Log(nb.condProb[class][f])
|
||||
} else {
|
||||
if nb.condProb[class][f] == 1.0 {
|
||||
// special case when prob = 1.0, consider laplace
|
||||
// smooth
|
||||
classScore += math.Log(1.0 / float64(nb.classInstances[class] + 1))
|
||||
} else {
|
||||
classScore += math.Log(1.0 - nb.condProb[class][f])
|
||||
}
|
||||
}
|
||||
}
|
||||
for class, classCount := range nb.classInstances {
|
||||
// Init classScore with log(prior)
|
||||
classScore := math.Log((float64(classCount)) / float64(nb.trainingInstances))
|
||||
for f := 0; f < nb.features; f++ {
|
||||
if vector[f][0] > 0 {
|
||||
// Test document has feature c
|
||||
classScore += math.Log(nb.condProb[class][f])
|
||||
} else {
|
||||
if nb.condProb[class][f] == 1.0 {
|
||||
// special case when prob = 1.0, consider laplace
|
||||
// smooth
|
||||
classScore += math.Log(1.0 / float64(nb.classInstances[class]+1))
|
||||
} else {
|
||||
classScore += math.Log(1.0 - nb.condProb[class][f])
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if classScore > bestScore {
|
||||
bestScore = classScore
|
||||
bestClass = class
|
||||
}
|
||||
}
|
||||
if classScore > bestScore {
|
||||
bestScore = classScore
|
||||
bestClass = class
|
||||
}
|
||||
}
|
||||
|
||||
return bestClass
|
||||
return bestClass
|
||||
}
|
||||
|
||||
// Predict is just a wrapper for the PredictOne function.
|
||||
//
|
||||
// IMPORTANT: Predict panics if Fit was not called or if the
|
||||
// document vector and train matrix have a different number of columns.
|
||||
func (nb *BernoulliNBClassifier) Predict(what *base.Instances) *base.Instances {
|
||||
ret := what.GeneratePredictionVector()
|
||||
for i := 0; i < what.Rows; i++ {
|
||||
ret.SetAttrStr(i, 0, nb.PredictOne(what.GetRowVectorWithoutClass(i)))
|
||||
}
|
||||
return ret
|
||||
func (nb *BernoulliNBClassifier) Predict(what base.FixedDataGrid) base.FixedDataGrid {
|
||||
// Generate return vector
|
||||
ret := base.GeneratePredictionVector(what)
|
||||
|
||||
// Get the features
|
||||
featAttrSpecs := base.ResolveAttributes(what, nb.attrs)
|
||||
|
||||
what.MapOverRows(featAttrSpecs, func(row [][]byte, i int) (bool, error) {
|
||||
base.SetClass(ret, i, nb.PredictOne(row))
|
||||
return true, nil
|
||||
})
|
||||
|
||||
return ret
|
||||
}
|
||||
|
@ -1,96 +1,109 @@
|
||||
package naive
|
||||
|
||||
import (
|
||||
"github.com/sjwhitworth/golearn/base"
|
||||
"testing"
|
||||
. "github.com/smartystreets/goconvey/convey"
|
||||
"github.com/sjwhitworth/golearn/base"
|
||||
"github.com/sjwhitworth/golearn/filters"
|
||||
. "github.com/smartystreets/goconvey/convey"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestNoFit(t *testing.T) {
|
||||
Convey("Given an empty BernoulliNaiveBayes", t, func() {
|
||||
nb := NewBernoulliNBClassifier()
|
||||
Convey("Given an empty BernoulliNaiveBayes", t, func() {
|
||||
nb := NewBernoulliNBClassifier()
|
||||
|
||||
Convey("PredictOne should panic if Fit was not called", func() {
|
||||
testDoc := []float64{0.0, 1.0}
|
||||
So(func() { nb.PredictOne(testDoc) }, ShouldPanic)
|
||||
})
|
||||
})
|
||||
Convey("PredictOne should panic if Fit was not called", func() {
|
||||
testDoc := [][]byte{[]byte{0}, []byte{1}}
|
||||
So(func() { nb.PredictOne(testDoc) }, ShouldPanic)
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
func convertToBinary(src base.FixedDataGrid) base.FixedDataGrid {
|
||||
// Convert to binary
|
||||
b := filters.NewBinaryConvertFilter()
|
||||
attrs := base.NonClassAttributes(src)
|
||||
for _, a := range attrs {
|
||||
b.AddAttribute(a)
|
||||
}
|
||||
b.Train()
|
||||
ret := base.NewLazilyFilteredInstances(src, b)
|
||||
return ret
|
||||
}
|
||||
|
||||
func TestSimple(t *testing.T) {
|
||||
Convey("Given a simple training data", t, func() {
|
||||
trainingData, err1 := base.ParseCSVToInstances("test/simple_train.csv", false)
|
||||
if err1 != nil {
|
||||
t.Error(err1)
|
||||
}
|
||||
Convey("Given a simple training data", t, func() {
|
||||
trainingData, err1 := base.ParseCSVToInstances("test/simple_train.csv", false)
|
||||
if err1 != nil {
|
||||
t.Error(err1)
|
||||
}
|
||||
|
||||
nb := NewBernoulliNBClassifier()
|
||||
nb.Fit(trainingData)
|
||||
nb := NewBernoulliNBClassifier()
|
||||
nb.Fit(convertToBinary(trainingData))
|
||||
|
||||
Convey("Check if Fit is working as expected", func() {
|
||||
Convey("All data needed for prior should be correctly calculated", func() {
|
||||
So(nb.classInstances["blue"], ShouldEqual, 2)
|
||||
So(nb.classInstances["red"], ShouldEqual, 2)
|
||||
So(nb.trainingInstances, ShouldEqual, 4)
|
||||
})
|
||||
Convey("Check if Fit is working as expected", func() {
|
||||
Convey("All data needed for prior should be correctly calculated", func() {
|
||||
So(nb.classInstances["blue"], ShouldEqual, 2)
|
||||
So(nb.classInstances["red"], ShouldEqual, 2)
|
||||
So(nb.trainingInstances, ShouldEqual, 4)
|
||||
})
|
||||
|
||||
Convey("'red' conditional probabilities should be correct", func() {
|
||||
logCondProbTok0 := nb.condProb["red"][0]
|
||||
logCondProbTok1 := nb.condProb["red"][1]
|
||||
logCondProbTok2 := nb.condProb["red"][2]
|
||||
Convey("'red' conditional probabilities should be correct", func() {
|
||||
logCondProbTok0 := nb.condProb["red"][0]
|
||||
logCondProbTok1 := nb.condProb["red"][1]
|
||||
logCondProbTok2 := nb.condProb["red"][2]
|
||||
|
||||
So(logCondProbTok0, ShouldAlmostEqual, 1.0)
|
||||
So(logCondProbTok1, ShouldAlmostEqual, 1.0/3.0)
|
||||
So(logCondProbTok2, ShouldAlmostEqual, 1.0)
|
||||
})
|
||||
So(logCondProbTok0, ShouldAlmostEqual, 1.0)
|
||||
So(logCondProbTok1, ShouldAlmostEqual, 1.0/3.0)
|
||||
So(logCondProbTok2, ShouldAlmostEqual, 1.0)
|
||||
})
|
||||
|
||||
Convey("'blue' conditional probabilities should be correct", func() {
|
||||
logCondProbTok0 := nb.condProb["blue"][0]
|
||||
logCondProbTok1 := nb.condProb["blue"][1]
|
||||
logCondProbTok2 := nb.condProb["blue"][2]
|
||||
Convey("'blue' conditional probabilities should be correct", func() {
|
||||
logCondProbTok0 := nb.condProb["blue"][0]
|
||||
logCondProbTok1 := nb.condProb["blue"][1]
|
||||
logCondProbTok2 := nb.condProb["blue"][2]
|
||||
|
||||
So(logCondProbTok0, ShouldAlmostEqual, 1.0)
|
||||
So(logCondProbTok1, ShouldAlmostEqual, 1.0)
|
||||
So(logCondProbTok2, ShouldAlmostEqual, 1.0/3.0)
|
||||
})
|
||||
})
|
||||
So(logCondProbTok0, ShouldAlmostEqual, 1.0)
|
||||
So(logCondProbTok1, ShouldAlmostEqual, 1.0)
|
||||
So(logCondProbTok2, ShouldAlmostEqual, 1.0/3.0)
|
||||
})
|
||||
})
|
||||
|
||||
Convey("PredictOne should work as expected", func() {
|
||||
Convey("Using a document with different number of cols should panic", func() {
|
||||
testDoc := []float64{0.0, 2.0}
|
||||
So(func() { nb.PredictOne(testDoc) }, ShouldPanic)
|
||||
})
|
||||
Convey("PredictOne should work as expected", func() {
|
||||
Convey("Using a document with different number of cols should panic", func() {
|
||||
testDoc := [][]byte{[]byte{0}, []byte{2}}
|
||||
So(func() { nb.PredictOne(testDoc) }, ShouldPanic)
|
||||
})
|
||||
|
||||
Convey("Token 1 should be a good predictor of the blue class", func() {
|
||||
testDoc := []float64{0.0, 123.0, 0.0}
|
||||
So(nb.PredictOne(testDoc), ShouldEqual, "blue")
|
||||
Convey("Token 1 should be a good predictor of the blue class", func() {
|
||||
testDoc := [][]byte{[]byte{0}, []byte{1}, []byte{0}}
|
||||
So(nb.PredictOne(testDoc), ShouldEqual, "blue")
|
||||
|
||||
testDoc = []float64{120.0, 123.0, 0.0}
|
||||
So(nb.PredictOne(testDoc), ShouldEqual, "blue")
|
||||
})
|
||||
testDoc = [][]byte{[]byte{1}, []byte{1}, []byte{0}}
|
||||
So(nb.PredictOne(testDoc), ShouldEqual, "blue")
|
||||
})
|
||||
|
||||
Convey("Token 2 should be a good predictor of the red class", func() {
|
||||
testDoc := []float64{0.0, 0.0, 120.0}
|
||||
So(nb.PredictOne(testDoc), ShouldEqual, "red")
|
||||
Convey("Token 2 should be a good predictor of the red class", func() {
|
||||
testDoc := [][]byte{[]byte{0}, []byte{0}, []byte{1}}
|
||||
So(nb.PredictOne(testDoc), ShouldEqual, "red")
|
||||
testDoc = [][]byte{[]byte{1}, []byte{0}, []byte{1}}
|
||||
So(nb.PredictOne(testDoc), ShouldEqual, "red")
|
||||
})
|
||||
})
|
||||
|
||||
testDoc = []float64{10.0, 0.0, 120.0}
|
||||
So(nb.PredictOne(testDoc), ShouldEqual, "red")
|
||||
})
|
||||
})
|
||||
Convey("Predict should work as expected", func() {
|
||||
testData, err := base.ParseCSVToInstances("test/simple_test.csv", false)
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
|
||||
Convey("Predict should work as expected", func() {
|
||||
testData, err := base.ParseCSVToInstances("test/simple_test.csv", false)
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
predictions := nb.Predict(testData)
|
||||
predictions := nb.Predict(convertToBinary(testData))
|
||||
|
||||
Convey("All simple predicitions should be correct", func() {
|
||||
So(predictions.GetClass(0), ShouldEqual, "blue")
|
||||
So(predictions.GetClass(1), ShouldEqual, "red")
|
||||
So(predictions.GetClass(2), ShouldEqual, "blue")
|
||||
So(predictions.GetClass(3), ShouldEqual, "red")
|
||||
})
|
||||
})
|
||||
})
|
||||
Convey("All simple predicitions should be correct", func() {
|
||||
So(base.GetClass(predictions, 0), ShouldEqual, "blue")
|
||||
So(base.GetClass(predictions, 1), ShouldEqual, "red")
|
||||
So(base.GetClass(predictions, 2), ShouldEqual, "blue")
|
||||
So(base.GetClass(predictions, 3), ShouldEqual, "red")
|
||||
})
|
||||
})
|
||||
})
|
||||
}
|
||||
|
@ -17,35 +17,40 @@ type InformationGainRuleGenerator struct {
|
||||
//
|
||||
// IMPORTANT: passing a base.Instances with no Attributes other than the class
|
||||
// variable will panic()
|
||||
func (r *InformationGainRuleGenerator) GenerateSplitAttribute(f *base.Instances) base.Attribute {
|
||||
allAttributes := make([]int, 0)
|
||||
for i := 0; i < f.Cols; i++ {
|
||||
if i != f.ClassIndex {
|
||||
allAttributes = append(allAttributes, i)
|
||||
}
|
||||
}
|
||||
return r.GetSplitAttributeFromSelection(allAttributes, f)
|
||||
func (r *InformationGainRuleGenerator) GenerateSplitAttribute(f base.FixedDataGrid) base.Attribute {
|
||||
|
||||
attrs := f.AllAttributes()
|
||||
classAttrs := f.AllClassAttributes()
|
||||
candidates := base.AttributeDifferenceReferences(attrs, classAttrs)
|
||||
|
||||
return r.GetSplitAttributeFromSelection(candidates, f)
|
||||
}
|
||||
|
||||
// GetSplitAttributeFromSelection returns the class Attribute which maximises
|
||||
// the information gain amongst consideredAttributes
|
||||
//
|
||||
// IMPORTANT: passing a zero-length consideredAttributes parameter will panic()
|
||||
func (r *InformationGainRuleGenerator) GetSplitAttributeFromSelection(consideredAttributes []int, f *base.Instances) base.Attribute {
|
||||
func (r *InformationGainRuleGenerator) GetSplitAttributeFromSelection(consideredAttributes []base.Attribute, f base.FixedDataGrid) base.Attribute {
|
||||
|
||||
var selectedAttribute base.Attribute
|
||||
|
||||
// Parameter check
|
||||
if len(consideredAttributes) == 0 {
|
||||
panic("More Attributes should be considered")
|
||||
}
|
||||
|
||||
// Next step is to compute the information gain at this node
|
||||
// for each randomly chosen attribute, and pick the one
|
||||
// which maximises it
|
||||
maxGain := math.Inf(-1)
|
||||
selectedAttribute := -1
|
||||
|
||||
// Compute the base entropy
|
||||
classDist := f.GetClassDistribution()
|
||||
classDist := base.GetClassDistribution(f)
|
||||
baseEntropy := getBaseEntropy(classDist)
|
||||
|
||||
// Compute the information gain for each attribute
|
||||
for _, s := range consideredAttributes {
|
||||
proposedClassDist := f.GetClassDistributionAfterSplit(f.GetAttr(s))
|
||||
proposedClassDist := base.GetClassDistributionAfterSplit(f, s)
|
||||
localEntropy := getSplitEntropy(proposedClassDist)
|
||||
informationGain := baseEntropy - localEntropy
|
||||
if informationGain > maxGain {
|
||||
@ -55,7 +60,7 @@ func (r *InformationGainRuleGenerator) GetSplitAttributeFromSelection(considered
|
||||
}
|
||||
|
||||
// Pick the one which maximises IG
|
||||
return f.GetAttr(selectedAttribute)
|
||||
return selectedAttribute
|
||||
}
|
||||
|
||||
//
|
||||
|
69
trees/id3.go
69
trees/id3.go
@ -21,7 +21,7 @@ const (
|
||||
// RuleGenerator implementations analyse instances and determine
|
||||
// the best value to split on
|
||||
type RuleGenerator interface {
|
||||
GenerateSplitAttribute(*base.Instances) base.Attribute
|
||||
GenerateSplitAttribute(base.FixedDataGrid) base.Attribute
|
||||
}
|
||||
|
||||
// DecisionTreeNode represents a given portion of a decision tree
|
||||
@ -31,14 +31,19 @@ type DecisionTreeNode struct {
|
||||
SplitAttr base.Attribute
|
||||
ClassDist map[string]int
|
||||
Class string
|
||||
ClassAttr *base.Attribute
|
||||
ClassAttr base.Attribute
|
||||
}
|
||||
|
||||
func getClassAttr(from base.FixedDataGrid) base.Attribute {
|
||||
allClassAttrs := from.AllClassAttributes()
|
||||
return allClassAttrs[0]
|
||||
}
|
||||
|
||||
// InferID3Tree builds a decision tree using a RuleGenerator
|
||||
// from a set of Instances (implements the ID3 algorithm)
|
||||
func InferID3Tree(from *base.Instances, with RuleGenerator) *DecisionTreeNode {
|
||||
func InferID3Tree(from base.FixedDataGrid, with RuleGenerator) *DecisionTreeNode {
|
||||
// Count the number of classes at this node
|
||||
classes := from.CountClassValues()
|
||||
classes := base.GetClassDistribution(from)
|
||||
// If there's only one class, return a DecisionTreeLeaf with
|
||||
// the only class available
|
||||
if len(classes) == 1 {
|
||||
@ -52,7 +57,7 @@ func InferID3Tree(from *base.Instances, with RuleGenerator) *DecisionTreeNode {
|
||||
nil,
|
||||
classes,
|
||||
maxClass,
|
||||
from.GetClassAttrPtr(),
|
||||
getClassAttr(from),
|
||||
}
|
||||
return ret
|
||||
}
|
||||
@ -69,28 +74,29 @@ func InferID3Tree(from *base.Instances, with RuleGenerator) *DecisionTreeNode {
|
||||
|
||||
// If there are no more Attributes left to split on,
|
||||
// return a DecisionTreeLeaf with the majority class
|
||||
if from.GetAttributeCount() == 2 {
|
||||
cols, _ := from.Size()
|
||||
if cols == 2 {
|
||||
ret := &DecisionTreeNode{
|
||||
LeafNode,
|
||||
nil,
|
||||
nil,
|
||||
classes,
|
||||
maxClass,
|
||||
from.GetClassAttrPtr(),
|
||||
getClassAttr(from),
|
||||
}
|
||||
return ret
|
||||
}
|
||||
|
||||
// Generate a return structure
|
||||
ret := &DecisionTreeNode{
|
||||
RuleNode,
|
||||
nil,
|
||||
nil,
|
||||
classes,
|
||||
maxClass,
|
||||
from.GetClassAttrPtr(),
|
||||
getClassAttr(from),
|
||||
}
|
||||
|
||||
// Generate a return structure
|
||||
// Generate the splitting attribute
|
||||
splitOnAttribute := with.GenerateSplitAttribute(from)
|
||||
if splitOnAttribute == nil {
|
||||
@ -98,7 +104,7 @@ func InferID3Tree(from *base.Instances, with RuleGenerator) *DecisionTreeNode {
|
||||
return ret
|
||||
}
|
||||
// Split the attributes based on this attribute's value
|
||||
splitInstances := from.DecomposeOnAttributeValues(splitOnAttribute)
|
||||
splitInstances := base.DecomposeOnAttributeValues(from, splitOnAttribute)
|
||||
// Create new children from these attributes
|
||||
ret.Children = make(map[string]*DecisionTreeNode)
|
||||
for k := range splitInstances {
|
||||
@ -146,13 +152,13 @@ func (d *DecisionTreeNode) String() string {
|
||||
}
|
||||
|
||||
// computeAccuracy is a helper method for Prune()
|
||||
func computeAccuracy(predictions *base.Instances, from *base.Instances) float64 {
|
||||
func computeAccuracy(predictions base.FixedDataGrid, from base.FixedDataGrid) float64 {
|
||||
cf := eval.GetConfusionMatrix(from, predictions)
|
||||
return eval.GetAccuracy(cf)
|
||||
}
|
||||
|
||||
// Prune eliminates branches which hurt accuracy
|
||||
func (d *DecisionTreeNode) Prune(using *base.Instances) {
|
||||
func (d *DecisionTreeNode) Prune(using base.FixedDataGrid) {
|
||||
// If you're a leaf, you're already pruned
|
||||
if d.Children == nil {
|
||||
return
|
||||
@ -162,11 +168,15 @@ func (d *DecisionTreeNode) Prune(using *base.Instances) {
|
||||
}
|
||||
|
||||
// Recursively prune children of this node
|
||||
sub := using.DecomposeOnAttributeValues(d.SplitAttr)
|
||||
sub := base.DecomposeOnAttributeValues(using, d.SplitAttr)
|
||||
for k := range d.Children {
|
||||
if sub[k] == nil {
|
||||
continue
|
||||
}
|
||||
subH, subV := sub[k].Size()
|
||||
if subH == 0 || subV == 0 {
|
||||
continue
|
||||
}
|
||||
d.Children[k].Prune(sub[k])
|
||||
}
|
||||
|
||||
@ -185,24 +195,30 @@ func (d *DecisionTreeNode) Prune(using *base.Instances) {
|
||||
}
|
||||
|
||||
// Predict outputs a base.Instances containing predictions from this tree
|
||||
func (d *DecisionTreeNode) Predict(what *base.Instances) *base.Instances {
|
||||
outputAttrs := make([]base.Attribute, 1)
|
||||
outputAttrs[0] = what.GetClassAttr()
|
||||
predictions := base.NewInstances(outputAttrs, what.Rows)
|
||||
for i := 0; i < what.Rows; i++ {
|
||||
func (d *DecisionTreeNode) Predict(what base.FixedDataGrid) base.FixedDataGrid {
|
||||
predictions := base.GeneratePredictionVector(what)
|
||||
classAttr := getClassAttr(predictions)
|
||||
classAttrSpec, err := predictions.GetAttribute(classAttr)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
predAttrs := base.AttributeDifferenceReferences(what.AllAttributes(), predictions.AllClassAttributes())
|
||||
predAttrSpecs := base.ResolveAttributes(what, predAttrs)
|
||||
what.MapOverRows(predAttrSpecs, func(row [][]byte, rowNo int) (bool, error) {
|
||||
cur := d
|
||||
for {
|
||||
if cur.Children == nil {
|
||||
predictions.SetAttrStr(i, 0, cur.Class)
|
||||
predictions.Set(classAttrSpec, rowNo, classAttr.GetSysValFromString(cur.Class))
|
||||
break
|
||||
} else {
|
||||
at := cur.SplitAttr
|
||||
j := what.GetAttrIndex(at)
|
||||
if j == -1 {
|
||||
predictions.SetAttrStr(i, 0, cur.Class)
|
||||
ats, err := what.GetAttribute(at)
|
||||
if err != nil {
|
||||
predictions.Set(classAttrSpec, rowNo, classAttr.GetSysValFromString(cur.Class))
|
||||
break
|
||||
}
|
||||
classVar := at.GetStringFromSysVal(what.Get(i, j))
|
||||
|
||||
classVar := ats.GetAttribute().GetStringFromSysVal(what.Get(ats, rowNo))
|
||||
if next, ok := cur.Children[classVar]; ok {
|
||||
cur = next
|
||||
} else {
|
||||
@ -217,7 +233,8 @@ func (d *DecisionTreeNode) Predict(what *base.Instances) *base.Instances {
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
return true, nil
|
||||
})
|
||||
return predictions
|
||||
}
|
||||
|
||||
@ -245,7 +262,7 @@ func NewID3DecisionTree(prune float64) *ID3DecisionTree {
|
||||
}
|
||||
|
||||
// Fit builds the ID3 decision tree
|
||||
func (t *ID3DecisionTree) Fit(on *base.Instances) {
|
||||
func (t *ID3DecisionTree) Fit(on base.FixedDataGrid) {
|
||||
rule := new(InformationGainRuleGenerator)
|
||||
if t.PruneSplit > 0.001 {
|
||||
trainData, testData := base.InstancesTrainTestSplit(on, t.PruneSplit)
|
||||
@ -257,7 +274,7 @@ func (t *ID3DecisionTree) Fit(on *base.Instances) {
|
||||
}
|
||||
|
||||
// Predict outputs predictions from the ID3 decision tree
|
||||
func (t *ID3DecisionTree) Predict(what *base.Instances) *base.Instances {
|
||||
func (t *ID3DecisionTree) Predict(what base.FixedDataGrid) base.FixedDataGrid {
|
||||
return t.Root.Predict(what)
|
||||
}
|
||||
|
||||
|
@ -14,32 +14,32 @@ type RandomTreeRuleGenerator struct {
|
||||
|
||||
// GenerateSplitAttribute returns the best attribute out of those randomly chosen
|
||||
// which maximises Information Gain
|
||||
func (r *RandomTreeRuleGenerator) GenerateSplitAttribute(f *base.Instances) base.Attribute {
|
||||
func (r *RandomTreeRuleGenerator) GenerateSplitAttribute(f base.FixedDataGrid) base.Attribute {
|
||||
|
||||
// First step is to generate the random attributes that we'll consider
|
||||
maximumAttribute := f.GetAttributeCount()
|
||||
consideredAttributes := make([]int, r.Attributes)
|
||||
allAttributes := base.AttributeDifferenceReferences(f.AllAttributes(), f.AllClassAttributes())
|
||||
maximumAttribute := len(allAttributes)
|
||||
consideredAttributes := make([]base.Attribute, 0)
|
||||
|
||||
attrCounter := 0
|
||||
for {
|
||||
if len(consideredAttributes) >= r.Attributes {
|
||||
break
|
||||
}
|
||||
selectedAttribute := rand.Intn(maximumAttribute)
|
||||
base.Logger.Println(selectedAttribute, attrCounter, consideredAttributes, len(consideredAttributes))
|
||||
if selectedAttribute != f.ClassIndex {
|
||||
matched := false
|
||||
for _, a := range consideredAttributes {
|
||||
if a == selectedAttribute {
|
||||
matched = true
|
||||
break
|
||||
}
|
||||
selectedAttrIndex := rand.Intn(maximumAttribute)
|
||||
selectedAttribute := allAttributes[selectedAttrIndex]
|
||||
matched := false
|
||||
for _, a := range consideredAttributes {
|
||||
if a.Equals(selectedAttribute) {
|
||||
matched = true
|
||||
break
|
||||
}
|
||||
if matched {
|
||||
continue
|
||||
}
|
||||
consideredAttributes = append(consideredAttributes, selectedAttribute)
|
||||
attrCounter++
|
||||
}
|
||||
if matched {
|
||||
continue
|
||||
}
|
||||
consideredAttributes = append(consideredAttributes, selectedAttribute)
|
||||
attrCounter++
|
||||
}
|
||||
|
||||
return r.internalRule.GetSplitAttributeFromSelection(consideredAttributes, f)
|
||||
@ -67,12 +67,12 @@ func NewRandomTree(attrs int) *RandomTree {
|
||||
}
|
||||
|
||||
// Fit builds a RandomTree suitable for prediction
|
||||
func (rt *RandomTree) Fit(from *base.Instances) {
|
||||
func (rt *RandomTree) Fit(from base.FixedDataGrid) {
|
||||
rt.Root = InferID3Tree(from, rt.Rule)
|
||||
}
|
||||
|
||||
// Predict returns a set of Instances containing predictions
|
||||
func (rt *RandomTree) Predict(from *base.Instances) *base.Instances {
|
||||
func (rt *RandomTree) Predict(from base.FixedDataGrid) base.FixedDataGrid {
|
||||
return rt.Root.Predict(from)
|
||||
}
|
||||
|
||||
@ -83,6 +83,6 @@ func (rt *RandomTree) String() string {
|
||||
|
||||
// Prune removes nodes from the tree which are detrimental
|
||||
// to determining the accuracy of the test set (with)
|
||||
func (rt *RandomTree) Prune(with *base.Instances) {
|
||||
func (rt *RandomTree) Prune(with base.FixedDataGrid) {
|
||||
rt.Root.Prune(with)
|
||||
}
|
||||
|
@ -14,15 +14,17 @@ func TestRandomTree(testEnv *testing.T) {
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
filt := filters.NewChiMergeFilter(inst, 0.90)
|
||||
filt.AddAllNumericAttributes()
|
||||
filt.Build()
|
||||
filt.Run(inst)
|
||||
fmt.Println(inst)
|
||||
for _, a := range base.NonClassFloatAttributes(inst) {
|
||||
filt.AddAttribute(a)
|
||||
}
|
||||
filt.Train()
|
||||
instf := base.NewLazilyFilteredInstances(inst, filt)
|
||||
|
||||
r := new(RandomTreeRuleGenerator)
|
||||
r.Attributes = 2
|
||||
root := InferID3Tree(inst, r)
|
||||
fmt.Println(instf)
|
||||
root := InferID3Tree(instf, r)
|
||||
fmt.Println(root)
|
||||
}
|
||||
|
||||
@ -33,18 +35,20 @@ func TestRandomTreeClassification(testEnv *testing.T) {
|
||||
}
|
||||
trainData, testData := base.InstancesTrainTestSplit(inst, 0.6)
|
||||
filt := filters.NewChiMergeFilter(inst, 0.90)
|
||||
filt.AddAllNumericAttributes()
|
||||
filt.Build()
|
||||
filt.Run(trainData)
|
||||
filt.Run(testData)
|
||||
fmt.Println(inst)
|
||||
for _, a := range base.NonClassFloatAttributes(inst) {
|
||||
filt.AddAttribute(a)
|
||||
}
|
||||
filt.Train()
|
||||
trainDataF := base.NewLazilyFilteredInstances(trainData, filt)
|
||||
testDataF := base.NewLazilyFilteredInstances(testData, filt)
|
||||
|
||||
r := new(RandomTreeRuleGenerator)
|
||||
r.Attributes = 2
|
||||
root := InferID3Tree(trainData, r)
|
||||
root := InferID3Tree(trainDataF, r)
|
||||
fmt.Println(root)
|
||||
predictions := root.Predict(testData)
|
||||
predictions := root.Predict(testDataF)
|
||||
fmt.Println(predictions)
|
||||
confusionMat := eval.GetConfusionMatrix(testData, predictions)
|
||||
confusionMat := eval.GetConfusionMatrix(testDataF, predictions)
|
||||
fmt.Println(confusionMat)
|
||||
fmt.Println(eval.GetMacroPrecision(confusionMat))
|
||||
fmt.Println(eval.GetMacroRecall(confusionMat))
|
||||
@ -58,17 +62,19 @@ func TestRandomTreeClassification2(testEnv *testing.T) {
|
||||
}
|
||||
trainData, testData := base.InstancesTrainTestSplit(inst, 0.4)
|
||||
filt := filters.NewChiMergeFilter(inst, 0.90)
|
||||
filt.AddAllNumericAttributes()
|
||||
filt.Build()
|
||||
fmt.Println(testData)
|
||||
filt.Run(testData)
|
||||
filt.Run(trainData)
|
||||
for _, a := range base.NonClassFloatAttributes(inst) {
|
||||
filt.AddAttribute(a)
|
||||
}
|
||||
filt.Train()
|
||||
trainDataF := base.NewLazilyFilteredInstances(trainData, filt)
|
||||
testDataF := base.NewLazilyFilteredInstances(testData, filt)
|
||||
|
||||
root := NewRandomTree(2)
|
||||
root.Fit(trainData)
|
||||
root.Fit(trainDataF)
|
||||
fmt.Println(root)
|
||||
predictions := root.Predict(testData)
|
||||
predictions := root.Predict(testDataF)
|
||||
fmt.Println(predictions)
|
||||
confusionMat := eval.GetConfusionMatrix(testData, predictions)
|
||||
confusionMat := eval.GetConfusionMatrix(testDataF, predictions)
|
||||
fmt.Println(confusionMat)
|
||||
fmt.Println(eval.GetMacroPrecision(confusionMat))
|
||||
fmt.Println(eval.GetMacroRecall(confusionMat))
|
||||
@ -82,19 +88,21 @@ func TestPruning(testEnv *testing.T) {
|
||||
}
|
||||
trainData, testData := base.InstancesTrainTestSplit(inst, 0.6)
|
||||
filt := filters.NewChiMergeFilter(inst, 0.90)
|
||||
filt.AddAllNumericAttributes()
|
||||
filt.Build()
|
||||
fmt.Println(testData)
|
||||
filt.Run(testData)
|
||||
filt.Run(trainData)
|
||||
for _, a := range base.NonClassFloatAttributes(inst) {
|
||||
filt.AddAttribute(a)
|
||||
}
|
||||
filt.Train()
|
||||
trainDataF := base.NewLazilyFilteredInstances(trainData, filt)
|
||||
testDataF := base.NewLazilyFilteredInstances(testData, filt)
|
||||
|
||||
root := NewRandomTree(2)
|
||||
fittrainData, fittestData := base.InstancesTrainTestSplit(trainData, 0.6)
|
||||
fittrainData, fittestData := base.InstancesTrainTestSplit(trainDataF, 0.6)
|
||||
root.Fit(fittrainData)
|
||||
root.Prune(fittestData)
|
||||
fmt.Println(root)
|
||||
predictions := root.Predict(testData)
|
||||
predictions := root.Predict(testDataF)
|
||||
fmt.Println(predictions)
|
||||
confusionMat := eval.GetConfusionMatrix(testData, predictions)
|
||||
confusionMat := eval.GetConfusionMatrix(testDataF, predictions)
|
||||
fmt.Println(confusionMat)
|
||||
fmt.Println(eval.GetMacroPrecision(confusionMat))
|
||||
fmt.Println(eval.GetMacroRecall(confusionMat))
|
||||
@ -142,6 +150,7 @@ func TestID3Inference(testEnv *testing.T) {
|
||||
testEnv.Error(sunnyChild)
|
||||
}
|
||||
if rainyChild.SplitAttr.GetName() != "windy" {
|
||||
fmt.Println(rainyChild.SplitAttr)
|
||||
testEnv.Error(rainyChild)
|
||||
}
|
||||
if overcastChild.SplitAttr != nil {
|
||||
@ -156,7 +165,6 @@ func TestID3Inference(testEnv *testing.T) {
|
||||
if sunnyLeafNormal.Class != "yes" {
|
||||
testEnv.Error(sunnyLeafNormal)
|
||||
}
|
||||
|
||||
windyLeafFalse := rainyChild.Children["false"]
|
||||
windyLeafTrue := rainyChild.Children["true"]
|
||||
if windyLeafFalse.Class != "yes" {
|
||||
@ -176,12 +184,18 @@ func TestID3Classification(testEnv *testing.T) {
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
filt := filters.NewBinningFilter(inst, 10)
|
||||
filt.AddAllNumericAttributes()
|
||||
filt.Build()
|
||||
filt.Run(inst)
|
||||
fmt.Println(inst)
|
||||
trainData, testData := base.InstancesTrainTestSplit(inst, 0.70)
|
||||
filt := filters.NewBinningFilter(inst, 10)
|
||||
for _, a := range base.NonClassFloatAttributes(inst) {
|
||||
filt.AddAttribute(a)
|
||||
}
|
||||
filt.Train()
|
||||
fmt.Println(filt)
|
||||
instf := base.NewLazilyFilteredInstances(inst, filt)
|
||||
fmt.Println("INSTFA", instf.AllAttributes())
|
||||
fmt.Println("INSTF", instf)
|
||||
trainData, testData := base.InstancesTrainTestSplit(instf, 0.70)
|
||||
|
||||
// Build the decision tree
|
||||
rule := new(InformationGainRuleGenerator)
|
||||
root := InferID3Tree(trainData, rule)
|
||||
@ -199,6 +213,7 @@ func TestID3(testEnv *testing.T) {
|
||||
|
||||
// Import the "PlayTennis" dataset
|
||||
inst, err := base.ParseCSVToInstances("../examples/datasets/tennis.csv", true)
|
||||
fmt.Println(inst)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
Loading…
x
Reference in New Issue
Block a user