1
0
mirror of https://github.com/sjwhitworth/golearn.git synced 2025-04-28 13:48:56 +08:00

Merge pull request #62 from Sentimentron/instances-v2

Instances v2
This commit is contained in:
Stephen Whitworth 2014-08-10 09:12:36 +01:00
commit 76ef9ede34
66 changed files with 5473 additions and 1926 deletions

View File

@ -1,16 +1,14 @@
package base
import "fmt"
import "strconv"
const (
// CategoricalType is for Attributes which represent values distinctly.
CategoricalType = iota
// Float64Type should be replaced with a FractionalNumeric type [DEPRECATED].
Float64Type
BinaryType
)
// Attribute Attributes disambiguate columns of the feature matrix and declare their types.
// Attributes disambiguate columns of the feature matrix and declare their types.
type Attribute interface {
// Returns the general characterstics of this Attribute .
// to avoid the overhead of casting
@ -25,12 +23,12 @@ type Attribute interface {
// representation. For example, a CategoricalAttribute with values
// ["iris-setosa", "iris-virginica"] would return the float64
// representation of 0 when given "iris-setosa".
GetSysValFromString(string) float64
GetSysValFromString(string) []byte
// Converts a given value from a system representation into a human
// representation. For example, a CategoricalAttribute with values
// ["iris-setosa", "iris-viriginica"] might return "iris-setosa"
// when given 0.0 as the argument.
GetStringFromSysVal(float64) string
GetStringFromSysVal([]byte) string
// Tests for equality with another Attribute. Other Attributes are
// considered equal if:
// * They have the same type (i.e. FloatAttribute <> CategoricalAttribute)
@ -38,230 +36,8 @@ type Attribute interface {
// * If applicable, they have the same categorical values (though not
// necessarily in the same order).
Equals(Attribute) bool
}
// FloatAttribute is an implementation which stores floating point
// representations of numbers.
type FloatAttribute struct {
Name string
Precision int
}
// NewFloatAttribute returns a new FloatAttribute with a default
// precision of 2 decimal places
func NewFloatAttribute() *FloatAttribute {
return &FloatAttribute{"", 2}
}
// Equals tests a FloatAttribute for equality with another Attribute.
//
// Returns false if the other Attribute has a different name
// or if the other Attribute is not a FloatAttribute.
func (Attr *FloatAttribute) Equals(other Attribute) bool {
// Check whether this FloatAttribute is equal to another
_, ok := other.(*FloatAttribute)
if !ok {
// Not the same type, so can't be equal
return false
}
if Attr.GetName() != other.GetName() {
return false
}
return true
}
// GetName returns this FloatAttribute's human-readable name.
func (Attr *FloatAttribute) GetName() string {
return Attr.Name
}
// SetName sets this FloatAttribute's human-readable name.
func (Attr *FloatAttribute) SetName(name string) {
Attr.Name = name
}
// GetType returns Float64Type.
func (Attr *FloatAttribute) GetType() int {
return Float64Type
}
// String returns a human-readable summary of this Attribute.
// e.g. "FloatAttribute(Sepal Width)"
func (Attr *FloatAttribute) String() string {
return fmt.Sprintf("FloatAttribute(%s)", Attr.Name)
}
// CheckSysValFromString confirms whether a given rawVal can
// be converted into a valid system representation.
func (Attr *FloatAttribute) CheckSysValFromString(rawVal string) (float64, error) {
f, err := strconv.ParseFloat(rawVal, 64)
if err != nil {
return 0.0, err
}
return f, nil
}
// GetSysValFromString parses the given rawVal string to a float64 and returns it.
//
// float64 happens to be a 1-to-1 mapping to the system representation.
// IMPORTANT: This function panic()s if rawVal is not a valid float.
// Use CheckSysValFromString to confirm.
func (Attr *FloatAttribute) GetSysValFromString(rawVal string) float64 {
f, err := strconv.ParseFloat(rawVal, 64)
if err != nil {
panic(err)
}
return f
}
// GetStringFromSysVal converts a given system value to to a string with two decimal
// places of precision [TODO: revise this and allow more precision].
func (Attr *FloatAttribute) GetStringFromSysVal(rawVal float64) string {
formatString := fmt.Sprintf("%%.%df", Attr.Precision)
return fmt.Sprintf(formatString, rawVal)
}
// GetSysVal returns the system representation of userVal.
//
// Because FloatAttribute represents float64 types, this
// just returns its argument.
func (Attr *FloatAttribute) GetSysVal(userVal float64) float64 {
return userVal
}
// GetUsrVal returns the user representation of sysVal.
//
// Because FloatAttribute represents float64 types, this
// just returns its argument.
func (Attr *FloatAttribute) GetUsrVal(sysVal float64) float64 {
return sysVal
}
// CategoricalAttribute is an Attribute implementation
// which stores discrete string values
// - useful for representing classes.
type CategoricalAttribute struct {
Name string
values []string
}
func NewCategoricalAttribute() *CategoricalAttribute {
return &CategoricalAttribute{
"",
make([]string, 0),
}
}
// GetName returns the human-readable name assigned to this attribute.
func (Attr *CategoricalAttribute) GetName() string {
return Attr.Name
}
// SetName sets the human-readable name on this attribute.
func (Attr *CategoricalAttribute) SetName(name string) {
Attr.Name = name
}
// GetType returns CategoricalType to avoid casting overhead.
func (Attr *CategoricalAttribute) GetType() int {
return CategoricalType
}
// GetSysVal returns the system representation of userVal as an index into the Values slice
// If the userVal can't be found, it returns -1.
func (Attr *CategoricalAttribute) GetSysVal(userVal string) float64 {
for idx, val := range Attr.values {
if val == userVal {
return float64(idx)
}
}
return -1
}
// GetUsrVal returns a human-readable representation of the given sysVal.
//
// IMPORTANT: this function doesn't check the boundaries of the array.
func (Attr *CategoricalAttribute) GetUsrVal(sysVal float64) string {
idx := int(sysVal)
return Attr.values[idx]
}
// GetSysValFromString returns the system representation of rawVal
// as an index into the Values slice. If rawVal is not inside
// the Values slice, it is appended.
//
// IMPORTANT: If no system representation yet exists, this functions adds it.
// If you need to determine whether rawVal exists: use GetSysVal and check
// for a -1 return value.
//
// Example: if the CategoricalAttribute contains the values ["iris-setosa",
// "iris-virginica"] and "iris-versicolor" is provided as the argument,
// the Values slide becomes ["iris-setosa", "iris-virginica", "iris-versicolor"]
// and 2.00 is returned as the system representation.
func (Attr *CategoricalAttribute) GetSysValFromString(rawVal string) float64 {
// Match in raw values
catIndex := -1
for i, s := range Attr.values {
if s == rawVal {
catIndex = i
break
}
}
if catIndex == -1 {
Attr.values = append(Attr.values, rawVal)
catIndex = len(Attr.values) - 1
}
return float64(catIndex)
}
// String returns a human-readable summary of this Attribute.
//
// Returns a string containing the list of human-readable values this
// CategoricalAttribute can take.
func (Attr *CategoricalAttribute) String() string {
return fmt.Sprintf("CategoricalAttribute(\"%s\", %s)", Attr.Name, Attr.values)
}
// GetStringFromSysVal returns a human-readable value from the given system-representation
// value val.
//
// IMPORTANT: This function calls panic() if the value is greater than
// the length of the array.
// TODO: Return a user-configurable default instead.
func (Attr *CategoricalAttribute) GetStringFromSysVal(val float64) string {
convVal := int(val)
if convVal >= len(Attr.values) {
panic(fmt.Sprintf("Out of range: %d in %d", convVal, len(Attr.values)))
}
return Attr.values[convVal]
}
// Equals checks equality against another Attribute.
//
// Two CategoricalAttributes are considered equal if they contain
// the same values and have the same name. Otherwise, this function
// returns false.
func (Attr *CategoricalAttribute) Equals(other Attribute) bool {
attribute, ok := other.(*CategoricalAttribute)
if !ok {
// Not the same type, so can't be equal
return false
}
if Attr.GetName() != attribute.GetName() {
return false
}
// Check that this CategoricalAttribute has the same
// values as the other, in the same order
if len(attribute.values) != len(Attr.values) {
return false
}
for i, a := range Attr.values {
if a != attribute.values[i] {
return false
}
}
return true
// Tests whether two Attributes can be represented in the same pond
// i.e. they're the same size, and their byte order makes them meaningful
// when considered together
Compatable(Attribute) bool
}

69
base/attributes_test.go Normal file
View File

@ -0,0 +1,69 @@
package base
import (
. "github.com/smartystreets/goconvey/convey"
"testing"
)
// TestFloatAttributeSysVal checks that packing a float string into the
// system representation and unpacking it again round-trips the value.
func TestFloatAttributeSysVal(t *testing.T) {
Convey("Given some float", t, func() {
x := "1.21"
attr := NewFloatAttribute()
Convey("When the float gets packed", func() {
packed := attr.GetSysValFromString(x)
Convey("And then unpacked", func() {
unpacked := attr.GetStringFromSysVal(packed)
Convey("The unpacked version should be the same", func() {
So(unpacked, ShouldEqual, x)
})
})
})
})
}
// TestCategoricalAttributeVal checks that categorical values round-trip
// through the system representation, and that a second distinct value
// is assigned the next index (first byte 0x1).
func TestCategoricalAttributeVal(t *testing.T) {
attr := NewCategoricalAttribute()
Convey("Given some string", t, func() {
x := "hello world!"
Convey("When the string gets converted", func() {
packed := attr.GetSysValFromString(x)
Convey("And then unconverted", func() {
unpacked := attr.GetStringFromSysVal(packed)
Convey("The unpacked version should be the same", func() {
So(unpacked, ShouldEqual, x)
})
})
})
})
Convey("Given some second string", t, func() {
x := "hello world 1!"
Convey("When the string gets converted", func() {
packed := attr.GetSysValFromString(x)
// Second distinct value gets index 1 in the shared attr.
So(packed[0], ShouldEqual, 0x1)
Convey("And then unconverted", func() {
unpacked := attr.GetStringFromSysVal(packed)
Convey("The unpacked version should be the same", func() {
So(unpacked, ShouldEqual, x)
})
})
})
})
}
// TestBinaryAttribute checks name get/set and that GetSysValFromString
// maps non-zero input to 1 and zero input to 0.
func TestBinaryAttribute(t *testing.T) {
attr := new(BinaryAttribute)
Convey("Given some binary Attribute", t, func() {
Convey("SetName, GetName should be equal", func() {
attr.SetName("Hello")
So(attr.GetName(), ShouldEqual, "Hello")
})
Convey("Non-zero values should equal 1", func() {
sysVal := attr.GetSysValFromString("1")
So(sysVal[0], ShouldEqual, 1)
})
Convey("Zero values should equal 0", func() {
sysVal := attr.GetSysValFromString("0")
So(sysVal[0], ShouldEqual, 0)
})
})
}

78
base/binary.go Normal file
View File

@ -0,0 +1,78 @@
package base
import (
"fmt"
"strconv"
)
// BinaryAttribute is an Attribute implementation which can only
// represent the values 1 or 0, stored in a single byte.
type BinaryAttribute struct {
// Name is the human-readable name of this Attribute.
Name string
}
// NewBinaryAttribute creates a BinaryAttribute with the given name.
func NewBinaryAttribute(name string) *BinaryAttribute {
	return &BinaryAttribute{Name: name}
}
// GetName returns the human-readable name of this Attribute.
func (b *BinaryAttribute) GetName() string {
return b.Name
}
// SetName sets the human-readable name of this Attribute.
func (b *BinaryAttribute) SetName(name string) {
b.Name = name
}
// GetType returns BinaryType, avoiding the overhead of a type assertion.
func (b *BinaryAttribute) GetType() int {
return BinaryType
}
// GetSysValFromString returns a single-byte system representation:
// 1 when the parsed value is positive, 0 otherwise.
//
// IMPORTANT: panics if userVal cannot be parsed as a float.
func (b *BinaryAttribute) GetSysValFromString(userVal string) []byte {
	parsed, err := strconv.ParseFloat(userVal, 64)
	if err != nil {
		panic(err)
	}
	sysVal := []byte{0}
	if parsed > 0 {
		sysVal[0] = 1
	}
	return sysVal
}
// GetStringFromSysVal returns "1" when the first byte of the system
// representation is non-zero, and "0" otherwise.
func (b *BinaryAttribute) GetStringFromSysVal(val []byte) string {
	if val[0] == 0 {
		return "0"
	}
	return "1"
}
// Equals checks for equality with another Attribute.
//
// Returns true only when the other Attribute is also a
// BinaryAttribute and has the same name.
func (b *BinaryAttribute) Equals(other Attribute) bool {
	// Idiomatic form: no else after a terminating return
	// (was an if/else with the assertion inside the condition).
	a, ok := other.(*BinaryAttribute)
	if !ok {
		// Not the same type, so can't be equal
		return false
	}
	return a.Name == b.Name
}
// Compatable checks whether this Attribute can be represented
// in the same pond as another: any other BinaryAttribute qualifies.
func (b *BinaryAttribute) Compatable(other Attribute) bool {
	// The if/else only ever returned the assertion result, so
	// return the ok flag directly.
	_, ok := other.(*BinaryAttribute)
	return ok
}
// String returns a human-readable representation of this Attribute,
// e.g. "BinaryAttribute(Class)".
func (b *BinaryAttribute) String() string {
return fmt.Sprintf("BinaryAttribute(%s)", b.Name)
}

165
base/categorical.go Normal file
View File

@ -0,0 +1,165 @@
package base
import (
"fmt"
)
// CategoricalAttribute is an Attribute implementation
// which stores discrete string values
// - useful for representing classes.
type CategoricalAttribute struct {
// Name is the human-readable name of this Attribute.
Name string
// values holds every distinct string seen so far; a value's
// index in this slice is its system representation.
values []string
}
// NewCategoricalAttribute creates a blank CategoricalAttribute
// with no name and no values.
func NewCategoricalAttribute() *CategoricalAttribute {
	return &CategoricalAttribute{
		Name:   "",
		values: []string{},
	}
}
// GetValues returns all the values currently defined.
//
// NOTE: this returns the internal slice, not a copy, so callers
// must not modify the result.
func (Attr *CategoricalAttribute) GetValues() []string {
return Attr.values
}
// GetName returns the human-readable name assigned to this attribute.
func (Attr *CategoricalAttribute) GetName() string {
return Attr.Name
}
// SetName sets the human-readable name on this attribute.
func (Attr *CategoricalAttribute) SetName(name string) {
Attr.Name = name
}
// GetType returns CategoricalType, avoiding the overhead of a
// type assertion.
func (Attr *CategoricalAttribute) GetType() int {
return CategoricalType
}
// GetSysVal returns the system representation of userVal as a packed
// index into the values slice. If userVal isn't defined, nil is
// returned.
func (Attr *CategoricalAttribute) GetSysVal(userVal string) []byte {
	for i := range Attr.values {
		if Attr.values[i] == userVal {
			return PackU64ToBytes(uint64(i))
		}
	}
	return nil
}
// GetUsrVal returns a human-readable representation of the given sysVal.
//
// IMPORTANT: this function doesn't check the boundaries of the array,
// so an out-of-range unpacked index will panic.
func (Attr *CategoricalAttribute) GetUsrVal(sysVal []byte) string {
idx := UnpackBytesToU64(sysVal)
return Attr.values[idx]
}
// GetSysValFromString returns the system representation of rawVal
// as a packed index into the values slice. If rawVal is not inside
// the values slice, it is appended first.
//
// IMPORTANT: If no system representation yet exists, this function
// adds one. To test for existence without mutating the Attribute,
// use GetSysVal and check for a nil return value.
//
// Example: if the CategoricalAttribute contains the values
// ["iris-setosa", "iris-virginica"] and "iris-versicolor" is given,
// the values slice becomes ["iris-setosa", "iris-virginica",
// "iris-versicolor"] and the packed index 2 is returned.
func (Attr *CategoricalAttribute) GetSysValFromString(rawVal string) []byte {
	// Return the existing index if the value is already known.
	for i, v := range Attr.values {
		if v == rawVal {
			return PackU64ToBytes(uint64(i))
		}
	}
	// Unknown value: append it and use the newly-created index.
	Attr.values = append(Attr.values, rawVal)
	return PackU64ToBytes(uint64(len(Attr.values) - 1))
}
// String returns a human-readable summary of this Attribute.
//
// Returns a string containing the list of human-readable values this
// CategoricalAttribute can take.
func (Attr *CategoricalAttribute) String() string {
return fmt.Sprintf("CategoricalAttribute(\"%s\", %s)", Attr.Name, Attr.values)
}
// GetStringFromSysVal returns a human-readable value from the given
// system-representation value rawVal.
//
// IMPORTANT: This function calls panic() if the unpacked index is
// greater than or equal to the length of the values array.
// TODO: Return a user-configurable default instead.
func (Attr *CategoricalAttribute) GetStringFromSysVal(rawVal []byte) string {
convVal := int(UnpackBytesToU64(rawVal))
if convVal >= len(Attr.values) {
panic(fmt.Sprintf("Out of range: %d in %d (%s)", convVal, len(Attr.values), Attr))
}
return Attr.values[convVal]
}
// Equals checks equality against another Attribute.
//
// Two CategoricalAttributes are considered equal when they share a
// name and contain exactly the same values in the same order;
// otherwise this function returns false.
func (Attr *CategoricalAttribute) Equals(other Attribute) bool {
	o, ok := other.(*CategoricalAttribute)
	if !ok {
		// Different concrete type: never equal.
		return false
	}
	if o.GetName() != Attr.GetName() {
		return false
	}
	// The value sets must match element-for-element, in order.
	if len(o.values) != len(Attr.values) {
		return false
	}
	for i := range Attr.values {
		if Attr.values[i] != o.values[i] {
			return false
		}
	}
	return true
}
// Compatable checks that this CategoricalAttribute has the same
// values as another, in the same order. (Unlike Equals, the names
// are not compared.)
func (Attr *CategoricalAttribute) Compatable(other Attribute) bool {
	o, ok := other.(*CategoricalAttribute)
	if !ok {
		return false
	}
	// The value sets must match element-for-element, in order.
	if len(o.values) != len(Attr.values) {
		return false
	}
	for i := range Attr.values {
		if Attr.values[i] != o.values[i] {
			return false
		}
	}
	return true
}

View File

@ -10,17 +10,17 @@ type Classifier interface {
// and constructs a new set of Instances of equivalent
// length with only the class Attribute and fills it in
// with predictions.
Predict(*Instances) *Instances
Predict(FixedDataGrid) FixedDataGrid
// Takes a set of instances and updates the Classifier's
// internal structures to enable prediction
Fit(*Instances)
Fit(FixedDataGrid)
// Why not make every classifier return a nice-looking string?
String() string
}
// BaseClassifier stores options common to every classifier.
type BaseClassifier struct {
TrainingData *Instances
TrainingData *DataGrid
}
type BaseRegressor struct {

View File

@ -77,28 +77,32 @@ func ParseCSVSniffAttributeNames(filepath string, hasHeaders bool) []string {
// The type of a given attribute is determined by looking at the first data row
// of the CSV.
func ParseCSVSniffAttributeTypes(filepath string, hasHeaders bool) []Attribute {
var attrs []Attribute
// Open file
file, err := os.Open(filepath)
if err != nil {
panic(err)
}
defer file.Close()
// Create the CSV reader
reader := csv.NewReader(file)
attrs := make([]Attribute, 0)
if hasHeaders {
// Skip the headers
_, err := reader.Read()
if err != nil {
panic(err)
}
}
// Read the first line of the file
columns, err := reader.Read()
if err != nil {
panic(err)
}
for _, entry := range columns {
// Match the Attribute type with regular expressions
entry = strings.Trim(entry, " ")
matched, err := regexp.MatchString("^[-+]?[0-9]*\\.?[0-9]+([eE][-+]?[0-9]+)?$", entry)
//fmt.Println(entry, matched)
if err != nil {
panic(err)
}
@ -112,30 +116,8 @@ func ParseCSVSniffAttributeTypes(filepath string, hasHeaders bool) []Attribute {
return attrs
}
// ParseCSVToInstances reads the CSV file given by filepath and returns
// the read Instances.
func ParseCSVToInstances(filepath string, hasHeaders bool) (instances *Instances, err error) {
defer func() {
if r := recover(); r != nil {
var ok bool
if err, ok = r.(error); !ok {
err = fmt.Errorf("golearn: ParseCSVToInstances: %v", r)
}
}
}()
// Read the number of rows in the file
rowCount := ParseCSVGetRows(filepath)
if hasHeaders {
rowCount--
}
// Read the row headers
attrs := ParseCSVGetAttributes(filepath, hasHeaders)
// Allocate the Instances to return
instances = NewInstances(attrs, rowCount)
// ParseCSVBuildInstances updates an [[#UpdatableDataGrid]] from a filepath in place
func ParseCSVBuildInstances(filepath string, hasHeaders bool, u UpdatableDataGrid) {
// Read the input
file, err := os.Open(filepath)
@ -146,6 +128,9 @@ func ParseCSVToInstances(filepath string, hasHeaders bool) (instances *Instances
reader := csv.NewReader(file)
rowCounter := 0
specs := ResolveAttributes(u, u.AllAttributes())
for {
record, err := reader.Read()
if err == io.EOF {
@ -159,13 +144,68 @@ func ParseCSVToInstances(filepath string, hasHeaders bool) (instances *Instances
continue
}
}
for i := range attrs {
instances.SetAttrStr(rowCounter, i, strings.Trim(record[i], " "))
for i, v := range record {
u.Set(specs[i], rowCounter, specs[i].attr.GetSysValFromString(v))
}
rowCounter++
}
return
}
// ParseCSVToInstances reads the CSV file given by filepath and returns
// the read Instances. The last column is marked as the class Attribute.
//
// NOTE(review): despite the (instances, err) signature, every failure
// path here panics (ParseCSVGetRows, os.Open, reader.Read) — unlike
// the previous version there is no recover() converting panics into
// err. Confirm whether callers are expected to recover themselves.
func ParseCSVToInstances(filepath string, hasHeaders bool) (instances *DenseInstances, err error) {
// Read the number of rows in the file
rowCount := ParseCSVGetRows(filepath)
if hasHeaders {
rowCount--
}
// Read the row headers
attrs := ParseCSVGetAttributes(filepath, hasHeaders)
specs := make([]AttributeSpec, len(attrs))
// Allocate the Instances to return
instances = NewDenseInstances()
for i, a := range attrs {
spec := instances.AddAttribute(a)
specs[i] = spec
}
instances.Extend(rowCount)
// Read the input
file, err := os.Open(filepath)
if err != nil {
panic(err)
}
defer file.Close()
reader := csv.NewReader(file)
rowCounter := 0
for {
record, err := reader.Read()
if err == io.EOF {
break
} else if err != nil {
panic(err)
}
// Skip the header row exactly once.
if rowCounter == 0 {
if hasHeaders {
hasHeaders = false
continue
}
}
for i, v := range record {
v = strings.Trim(v, " ")
instances.Set(specs[i], rowCounter, attrs[i].GetSysValFromString(v))
}
rowCounter++
}
// The final column is treated as the class Attribute.
// NOTE(review): the returned error is ignored here, and this
// indexes attrs[len(attrs)-1] — panics on an empty CSV.
instances.AddClassAttribute(attrs[len(attrs)-1])
return instances, nil
}
//ParseCSV parses a CSV file and returns the number of columns and rows, the headers, the labels associated with

View File

@ -1,6 +1,8 @@
package base
import "testing"
import (
"testing"
)
func TestParseCSVGetRows(testEnv *testing.T) {
lineCount := ParseCSVGetRows("../examples/datasets/iris.csv")
@ -76,9 +78,9 @@ func TestReadInstances(testEnv *testing.T) {
testEnv.Error(err)
return
}
row1 := inst.RowStr(0)
row2 := inst.RowStr(50)
row3 := inst.RowStr(100)
row1 := inst.RowString(0)
row2 := inst.RowString(50)
row3 := inst.RowString(100)
if row1 != "5.10 3.50 1.40 0.20 Iris-setosa" {
testEnv.Error(row1)
@ -97,10 +99,11 @@ func TestReadAwkwardInsatnces(testEnv *testing.T) {
testEnv.Error(err)
return
}
if inst.GetAttr(0).GetType() != Float64Type {
attrs := inst.AllAttributes()
if attrs[0].GetType() != Float64Type {
testEnv.Error("Should be float!")
}
if inst.GetAttr(1).GetType() != CategoricalType {
if attrs[1].GetType() != CategoricalType {
testEnv.Error("Should be discrete!")
}
}

51
base/data.go Normal file
View File

@ -0,0 +1,51 @@
package base
// SortDirection specifies the direction in which Instances are sorted.
type SortDirection int
const (
// Descending sorts Instances from high to low.
Descending SortDirection = 1
// Ascending sorts Instances from low to high.
Ascending SortDirection = 2
)
// DataGrid implementations represent data addressable by rows and columns.
type DataGrid interface {
// GetAttribute retrieves a given Attribute's specification, or an
// error if the Attribute can't be resolved.
GetAttribute(Attribute) (AttributeSpec, error)
// AllAttributes retrieves details of every Attribute.
AllAttributes() []Attribute
// AddClassAttribute marks an Attribute as a class Attribute.
AddClassAttribute(Attribute) error
// RemoveClassAttribute unmarks an Attribute as a class Attribute.
RemoveClassAttribute(Attribute) error
// AllClassAttributes returns details of all class Attributes.
AllClassAttributes() []Attribute
// Get returns the bytes at a given position, or nil.
Get(AttributeSpec, int) []byte
// MapOverRows is a convenience function for iteration: the supplied
// function is called once per row with the requested values and the
// row number, and iteration stops when it returns false or an error.
MapOverRows([]AttributeSpec, func([][]byte, int) (bool, error)) error
}
// FixedDataGrid implementations have a size known in advance and implement
// all of the functionality offered by DataGrid implementations.
type FixedDataGrid interface {
DataGrid
// RowString returns a string representation of a given row.
RowString(int) string
// Size returns the number of Attributes and rows currently allocated.
Size() (int, int)
}
// UpdatableDataGrid implementations can be changed in addition to implementing
// all of the functionality offered by FixedDataGrid implementations.
type UpdatableDataGrid interface {
FixedDataGrid
// Set assigns a byte sequence to a given Attribute and row.
Set(AttributeSpec, int, []byte)
// AddAttribute adds an Attribute to the grid and returns its
// specification for subsequent Set calls.
AddAttribute(Attribute) AttributeSpec
// Extend allocates additional room to hold a number of rows.
Extend(int) error
}

View File

@ -1,33 +0,0 @@
package base
import "testing"
func TestDecomp(testEnv *testing.T) {
inst, err := ParseCSVToInstances("../examples/datasets/iris_binned.csv", true)
if err != nil {
testEnv.Error(err)
return
}
decomp := inst.DecomposeOnAttributeValues(inst.GetAttr(0))
row0 := decomp["0.00"].RowStr(0)
row1 := decomp["1.00"].RowStr(0)
/* row2 := decomp["2.00"].RowStr(0)
row3 := decomp["3.00"].RowStr(0)
row4 := decomp["4.00"].RowStr(0)
row5 := decomp["5.00"].RowStr(0)
row6 := decomp["6.00"].RowStr(0)
row7 := decomp["7.00"].RowStr(0)*/
row8 := decomp["8.00"].RowStr(0)
// row9 := decomp["9.00"].RowStr(0)
if row0 != "3.10 1.50 0.20 Iris-setosa" {
testEnv.Error(row0)
}
if row1 != "3.00 1.40 0.20 Iris-setosa" {
testEnv.Error(row1)
}
if row8 != "2.90 6.30 1.80 Iris-virginica" {
testEnv.Error(row8)
}
}

476
base/dense.go Normal file
View File

@ -0,0 +1,476 @@
package base
import (
"bytes"
"fmt"
"github.com/sjwhitworth/golearn/base/edf"
"math"
"sync"
)
// DenseInstances stores each Attribute value explicitly
// in a large grid.
type DenseInstances struct {
// storage is the backing EDF storage (an anonymous mapping when
// created via NewDenseInstances).
storage *edf.EdfFile
// pondMap maps pond names to indices into ponds.
pondMap map[string]int
// ponds groups Attributes which share storage characteristics
// (see AddAttributeToPond's Compatable check).
ponds []*Pond
// lock guards the fields of this struct against concurrent access.
lock sync.Mutex
// fixed is set once Extend has allocated storage; no further
// Attributes may be added after that.
fixed bool
// classAttrs records which AttributeSpecs are class Attributes.
classAttrs map[AttributeSpec]bool
// maxRow is the number of rows currently allocated.
maxRow int
// attributes lists every Attribute in insertion order.
attributes []Attribute
}
// NewDenseInstances generates a new DenseInstances set
// with an anonymous EDF mapping and default settings.
//
// IMPORTANT: panics if the anonymous mapping can't be created.
func NewDenseInstances() *DenseInstances {
	storage, err := edf.EdfAnonMap()
	if err != nil {
		panic(err)
	}
	// Fields not listed (lock, fixed, maxRow) start at their
	// zero values, matching the defaults.
	return &DenseInstances{
		storage:    storage,
		pondMap:    make(map[string]int),
		ponds:      make([]*Pond, 0),
		classAttrs: make(map[AttributeSpec]bool),
		attributes: make([]Attribute, 0),
	}
}
//
// Pond functions
//
// createPond adds a new Pond to this set of Instances, backed by a
// fresh EDF thread of the same name.
//
// IMPORTANT: do not call unless you've acquired the lock.
// IMPORTANT: panics if storage is already fixed, if a thread with this
// name already exists, or if the thread can't be written.
func (inst *DenseInstances) createPond(name string, size int) {
if inst.fixed {
panic("Can't add additional Attributes")
}
// Resolve or create thread
threads, err := inst.storage.GetThreads()
if err != nil {
panic(err)
}
// A pond's thread name must be unique within the file.
ok := false
for i := range threads {
if threads[i] == name {
ok = true
break
}
}
if ok {
panic("Can't create pond: pond thread already exists")
}
// Write the pool's thread into the file
thread := edf.NewThread(inst.storage, name)
err = inst.storage.WriteThread(thread)
if err != nil {
panic(fmt.Sprintf("Can't write thread: %s", err))
}
// Create the pond information
pond := new(Pond)
pond.threadNo = thread.GetId()
pond.parent = inst
pond.attributes = make([]Attribute, 0)
pond.size = size
pond.alloc = make([][]byte, 0)
// Store within instances: the pond's index doubles as its id
// in pondMap.
inst.pondMap[name] = len(inst.ponds)
inst.ponds = append(inst.ponds, pond)
}
// CreatePond adds a new Pond to this set of instances
// with a given name. If the size is 0, a bit-pond is added;
// if the size is not 0, then the size of each pond attribute
// is set to that number of bytes.
//
// Any panic raised by the underlying createPond (duplicate name,
// fixed storage, EDF write failure) is converted into an error.
func (inst *DenseInstances) CreatePond(name string, size int) (err error) {
defer func() {
if r := recover(); r != nil {
var ok bool
if err, ok = r.(error); !ok {
err = fmt.Errorf("CreatePond: %v (not created)", r)
}
}
}()
inst.lock.Lock()
defer inst.lock.Unlock()
inst.createPond(name, size)
return nil
}
// GetPond returns a reference to the Pond with the given name,
// or an error if no Pond with that name exists.
func (inst *DenseInstances) GetPond(name string) (*Pond, error) {
	inst.lock.Lock()
	defer inst.lock.Unlock()
	// Happy path left-aligned: look the pond up, fail early if absent
	// (was an if/else after a terminating return).
	id, ok := inst.pondMap[name]
	if !ok {
		return nil, fmt.Errorf("Pond '%s' doesn't exist", name)
	}
	return inst.ponds[id], nil
}
//
// Attribute creation and handling
//
// AddAttribute adds an Attribute to this set of DenseInstances,
// creating a default Pond for it if a suitable one doesn't exist.
// Returns an AttributeSpec for subsequent Set() calls.
//
// IMPORTANT: will panic if storage has been allocated via Extend,
// or if the Attribute is neither Categorical nor Float.
// NOTE(review): *BinaryAttribute is not handled here and will panic —
// confirm whether binary Attributes are meant to go through
// AddAttributeToPond instead.
func (inst *DenseInstances) AddAttribute(a Attribute) AttributeSpec {
	inst.lock.Lock()
	defer inst.lock.Unlock()
	if inst.fixed {
		panic("Can't add additional Attributes")
	}
	// Choose a default Pond name by concrete Attribute type
	// (type switch replaces the original if/else chain).
	var pond string
	switch a.(type) {
	case *CategoricalAttribute:
		pond = "CAT"
	case *FloatAttribute:
		pond = "FLOAT"
	default:
		panic("Unrecognised Attribute type")
	}
	// Create the pond if it doesn't exist
	if _, ok := inst.pondMap[pond]; !ok {
		inst.createPond(pond, 8)
	}
	id := inst.pondMap[pond]
	p := inst.ponds[id]
	p.attributes = append(p.attributes, a)
	inst.attributes = append(inst.attributes, a)
	return AttributeSpec{id, len(p.attributes) - 1, a}
}
// AddAttributeToPond adds an Attribute to a given pond, which must
// already exist (see CreatePond). The new Attribute must be Compatable
// with every Attribute already in the pond.
func (inst *DenseInstances) AddAttributeToPond(newAttribute Attribute, pond string) (AttributeSpec, error) {
inst.lock.Lock()
defer inst.lock.Unlock()
// Check if the pond exists
if _, ok := inst.pondMap[pond]; !ok {
return AttributeSpec{-1, 0, nil}, fmt.Errorf("Pond '%s' doesn't exist. Call CreatePond() first", pond)
}
id := inst.pondMap[pond]
p := inst.ponds[id]
// Reject Attributes which can't share this pond's storage layout.
for i, a := range p.attributes {
if !a.Compatable(newAttribute) {
return AttributeSpec{-1, 0, nil}, fmt.Errorf("Attribute %s is not compatable with %s in pond '%s' (position %d)", newAttribute, a, pond, i)
}
}
p.attributes = append(p.attributes, newAttribute)
inst.attributes = append(inst.attributes, newAttribute)
return AttributeSpec{id, len(p.attributes) - 1, newAttribute}, nil
}
// GetAttribute returns the AttributeSpec for an Attribute equal
// (via Equals) to the argument, or an error if none matches.
// Linear scan over every pond.
//
// TODO: Write a function to pre-compute this once we've allocated
// TODO: Write a utility function which retrieves all AttributeSpecs for
// a given instance set.
func (inst *DenseInstances) GetAttribute(get Attribute) (AttributeSpec, error) {
inst.lock.Lock()
defer inst.lock.Unlock()
for i, p := range inst.ponds {
for j, a := range p.attributes {
if a.Equals(get) {
return AttributeSpec{i, j, a}, nil
}
}
}
return AttributeSpec{-1, 0, nil}, fmt.Errorf("Couldn't resolve %s", get)
}
// AllAttributes returns a slice of every Attribute, in pond order.
func (inst *DenseInstances) AllAttributes() []Attribute {
	inst.lock.Lock()
	defer inst.lock.Unlock()
	// Always non-nil, even when there are no ponds.
	ret := make([]Attribute, 0)
	for _, pond := range inst.ponds {
		ret = append(ret, pond.attributes...)
	}
	return ret
}
// AddClassAttribute sets an Attribute to be a class Attribute.
// Returns an error if the Attribute can't be resolved.
//
// NOTE: GetAttribute is called before taking the lock because it
// acquires inst.lock itself.
func (inst *DenseInstances) AddClassAttribute(a Attribute) error {
as, err := inst.GetAttribute(a)
if err != nil {
return err
}
inst.lock.Lock()
defer inst.lock.Unlock()
inst.classAttrs[as] = true
return nil
}
// RemoveClassAttribute removes an Attribute from the set of class
// Attributes. Returns an error if the Attribute can't be resolved.
func (inst *DenseInstances) RemoveClassAttribute(a Attribute) error {
	// BUG FIX: the original locked inst.lock, then called
	// GetAttribute (which locks the same non-reentrant sync.Mutex),
	// then locked a third time — guaranteeing a deadlock on any
	// call. Resolve the Attribute first, then lock once, mirroring
	// AddClassAttribute.
	as, err := inst.GetAttribute(a)
	if err != nil {
		return err
	}
	inst.lock.Lock()
	defer inst.lock.Unlock()
	inst.classAttrs[as] = false
	return nil
}
// AllClassAttributes returns a slice of Attributes which have
// been designated class Attributes. May be nil when there are none.
func (inst *DenseInstances) AllClassAttributes() []Attribute {
	inst.lock.Lock()
	defer inst.lock.Unlock()
	var ret []Attribute
	// Entries set to false (removed class Attributes) are skipped.
	for spec, isClass := range inst.classAttrs {
		if isClass {
			ret = append(ret, spec.attr)
		}
	}
	return ret
}
//
// Allocation functions
//
// Extend extends this set of Instances to store rows additional rows,
// allocating whole EDF pages per pond. It's recommended to set rows to
// something quite large. After the first call, the Attribute set is
// fixed (see AddAttribute).
//
// IMPORTANT: panics if the allocation fails
func (inst *DenseInstances) Extend(rows int) error {
inst.lock.Lock()
defer inst.lock.Unlock()
// Get the size of each page
pageSize := inst.storage.GetPageSize()
for pondName := range inst.ponds {
p := inst.ponds[pondName]
// Compute pond row storage requirements
rowSize := p.RowSize()
// How many rows can we store per page?
rowsPerPage := float64(pageSize) / float64(rowSize)
// How many pages? Round up so the last partial page is counted.
pagesNeeded := uint32(math.Ceil(float64(rows) / rowsPerPage))
// Allocate those pages on the pond's own thread
r, err := inst.storage.AllocPages(pagesNeeded, p.threadNo)
if err != nil {
panic(fmt.Sprintf("Allocation error: %s (rowSize %d, pageSize %d, rowsPerPage %.2f, tried to allocate %d page(s) and extend by %d row(s))", err, rowSize, pageSize, rowsPerPage, pagesNeeded, rows))
}
// Resolve and assign those pages to the pond's allocation list
byteBlock := inst.storage.ResolveRange(r)
for _, block := range byteBlock {
p.alloc = append(p.alloc, block)
}
}
// Lock the Attribute set and record the new row count.
inst.fixed = true
inst.maxRow += rows
return nil
}
// Set sets a particular Attribute (given as an AttributeSpec) on a particular
// row to a particular value, delegating to the relevant Pond.
//
// AttributeSpecs can be obtained using GetAttribute() or AddAttribute().
//
// IMPORTANT: Will panic() if the AttributeSpec isn't valid
//
// IMPORTANT: Will panic() if the row is too large
//
// IMPORTANT: Will panic() if the val is not the right length
func (inst *DenseInstances) Set(a AttributeSpec, row int, val []byte) {
inst.ponds[a.pond].set(a.position, row, val)
}
// Get gets a particular Attribute (given as an AttributeSpec) on a particular
// row.
// AttributeSpecs can be obtained using GetAttribute() or AddAttribute()
func (inst *DenseInstances) Get(a AttributeSpec, row int) []byte {
	// Delegate to the pond which stores this Attribute's column
	pond := inst.ponds[a.pond]
	return pond.get(a.position, row)
}
// RowString returns a string representation of a given row.
func (inst *DenseInstances) RowString(row int) string {
	var buffer bytes.Buffer
	// Separator is empty before the first pond, a space afterwards
	sep := ""
	for name := range inst.ponds {
		buffer.WriteString(sep)
		inst.ponds[name].appendToRowBuf(row, &buffer)
		sep = " "
	}
	return buffer.String()
}
//
// Row handling functions
//
// allocateRowVector builds a slice of byte buffers, one per requested
// AttributeSpec, each sized to that Attribute's pond storage width.
func (inst *DenseInstances) allocateRowVector(asv []AttributeSpec) [][]byte {
	ret := make([][]byte, len(asv))
	for i := range asv {
		pond := inst.ponds[asv[i].pond]
		ret[i] = make([]byte, pond.size)
	}
	return ret
}
// MapOverRows passes each row map into a function.
// First argument is a list of AttributeSpec in the order
// they're needed in for the function. The second is the function
// to call on each row.
func (inst *DenseInstances) MapOverRows(asv []AttributeSpec, mapFunc func([][]byte, int) (bool, error)) error {
	// One shared buffer is reused across rows: the callback must not
	// retain the slices beyond its invocation
	buf := make([][]byte, len(asv))
	for row := 0; row < inst.maxRow; row++ {
		for i := range asv {
			buf[i] = inst.ponds[asv[i].pond].get(asv[i].position, row)
		}
		cont, err := mapFunc(buf, row)
		if err != nil {
			return err
		}
		// A false return from the callback stops iteration early
		if !cont {
			return nil
		}
	}
	return nil
}
// Size returns the number of Attributes as the first return value
// and the maximum allocated row as the second value.
func (inst *DenseInstances) Size() (int, int) {
	attrCount := len(inst.AllAttributes())
	return attrCount, inst.maxRow
}
// swapRows swaps over rows i and j, Attribute by Attribute.
func (inst *DenseInstances) swapRows(i, j int) {
	as := ResolveAllAttributes(inst)
	for _, a := range as {
		// NOTE(review): v1 and v2 appear to be live views into the
		// underlying pond storage (hence the defensive copy of v2 below)
		// — confirm Get's aliasing behaviour.
		v1 := inst.Get(a, i)
		v2 := inst.Get(a, j)
		v3 := make([]byte, len(v2))
		copy(v3, v2)
		// Order matters: row j is overwritten from v1 while v1 still
		// reflects row i's original bytes, then row i takes the saved copy.
		inst.Set(a, j, v1)
		inst.Set(a, i, v3)
	}
}
// Equal checks whether a given Instance set is exactly the same
// as another: same size and same values (as determined by the Attributes)
//
// IMPORTANT: does not explicitly check if the Attributes are considered equal.
func (inst *DenseInstances) Equal(other DataGrid) bool {
	// Cell values must be read out of `other`, which needs the
	// DenseInstances Get method, so comparison is only supported against
	// another *DenseInstances. (The previous version resolved both
	// AttributeSpecs and both cell reads from `inst`, so it compared the
	// receiver with itself and always returned true.)
	o, ok := other.(*DenseInstances)
	if !ok {
		return false
	}

	_, rows := inst.Size()
	_, otherRows := o.Size()
	if rows != otherRows {
		return false
	}

	for _, a := range inst.AllAttributes() {
		as1, err := inst.GetAttribute(a)
		if err != nil {
			panic(err) // That indicates some kind of error
		}
		as2, err := o.GetAttribute(a)
		if err != nil {
			return false // Obviously has different Attributes
		}
		// Compare the raw byte representation of every cell
		for i := 0; i < rows; i++ {
			b1 := inst.Get(as1, i)
			b2 := o.Get(as2, i)
			if !byteSeqEqual(b1, b2) {
				return false
			}
		}
	}
	return true
}
// String returns a human-readable summary of this dataset.
//
// The header lists every Attribute (class Attributes prefixed with '*'),
// then up to the first 30 rows of data; any remainder is summarised
// with a count of undisplayed rows.
func (inst *DenseInstances) String() string {
	var buffer bytes.Buffer

	// Get all Attribute information
	as := ResolveAllAttributes(inst)

	// Print header
	cols, rows := inst.Size()
	buffer.WriteString("Instances with ")
	buffer.WriteString(fmt.Sprintf("%d row(s) ", rows))
	buffer.WriteString(fmt.Sprintf("%d attribute(s)\n", cols))
	// (was fmt.Sprintf with no arguments — flagged by go vet)
	buffer.WriteString("Attributes: \n")

	for _, a := range as {
		prefix := "\t"
		if inst.classAttrs[a] {
			// Class Attributes are prefixed with an asterisk
			prefix = "*\t"
		}
		buffer.WriteString(fmt.Sprintf("%s%s\n", prefix, a.attr))
	}

	buffer.WriteString("\nData:\n")
	maxRows := 30
	if rows < maxRows {
		maxRows = rows
	}

	for i := 0; i < maxRows; i++ {
		buffer.WriteString("\t")
		for _, a := range as {
			val := inst.Get(a, i)
			buffer.WriteString(fmt.Sprintf("%s ", a.attr.GetStringFromSysVal(val)))
		}
		buffer.WriteString("\n")
	}

	missingRows := rows - maxRows
	if missingRows != 0 {
		buffer.WriteString(fmt.Sprintf("\t...\n%d row(s) undisplayed", missingRows))
	} else {
		buffer.WriteString("All rows displayed")
	}

	return buffer.String()
}

261
base/edf/alloc.go Normal file
View File

@ -0,0 +1,261 @@
package edf
import (
"fmt"
)
// ContentEntry structs are stored in contents blocks, the first of
// which always starts at page 2. Each entry records one contiguous
// page allocation belonging to a thread.
type ContentEntry struct {
	// Which thread this entry is assigned to
	Thread uint32
	// Which page this block starts at
	Start uint32
	// The page up to and including which the block ends
	End uint32
}
// extend grows the underlying file by additionalPages pages.
// The current size is derived from the on-disk file, so repeated
// calls accumulate.
func (e *EdfFile) extend(additionalPages uint32) error {
	fileInfo, err := e.f.Stat()
	if err != nil {
		// Report the failure instead of panicking: every caller
		// already has an error path for extend.
		return err
	}
	// truncate takes a size in pages
	newSize := uint64(fileInfo.Size())/e.pageSize + uint64(additionalPages)
	return e.truncate(int64(newSize))
}
// getFreeMapSize returns the total number of pages tracked by the
// free map: the file size in pages, or the maximum mapping size for
// anonymous (file-less) mappings.
func (e *EdfFile) getFreeMapSize() uint64 {
	// Anonymous mappings have no backing file
	if e.f == nil {
		return uint64(EDF_SIZE) / e.pageSize
	}
	fileInfo, err := e.f.Stat()
	if err != nil {
		panic(err)
	}
	return uint64(fileInfo.Size()) / e.pageSize
}
// FixedAlloc allocates a |bytesRequested| chunk of pages
// on the FIXED thread (thread identifier 2).
func (e *EdfFile) FixedAlloc(bytesRequested uint32) (EdfRange, error) {
	pageSize := uint32(e.pageSize)
	// NOTE(review): (pageSize*bytesRequested+pageSize/2)/pageSize reduces
	// to roughly bytesRequested, i.e. this requests bytesRequested *pages*
	// rather than the pages needed to hold bytesRequested *bytes* —
	// confirm the intended units against callers.
	return e.AllocPages((pageSize*bytesRequested+pageSize/2)/pageSize, 2)
}
// getContiguousOffset scans the contents table for a run of free pages
// large enough to satisfy pagesRequested, returning the starting page
// offset (0 when no suitable run was found).
func (e *EdfFile) getContiguousOffset(pagesRequested uint32) (uint32, error) {
	// Create the free bitmap (true = page in use)
	bitmap := make([]bool, e.getFreeMapSize())
	// Pages 0-3 hold header/system structures and are always in use
	for i := 0; i < 4; i++ {
		bitmap[i] = true
	}
	// Traverse the contents table and build a free bitmap
	block := uint64(2)
	for {
		// Get the range for this block
		r := e.GetPageRange(block, block)
		if r.Start.Segment != r.End.Segment {
			return 0, fmt.Errorf("Contents block split across segments")
		}
		bytes := e.m[r.Start.Segment]
		bytes = bytes[r.Start.Byte : r.End.Byte+1]
		// Get the address of the next contents block
		block = uint64FromBytes(bytes)
		if block != 0 {
			// No point in checking this block for free space
			// NOTE(review): this skips scanning the *current* block's
			// entries, so allocations recorded in earlier contents
			// blocks never reach the bitmap — confirm intended.
			continue
		}
		bytes = bytes[8:]
		// Look for a blank entry in the table
		// (entries are 12 bytes: thread id, start page, end page)
		for i := 0; i < len(bytes); i += 12 {
			threadID := uint32FromBytes(bytes[i:])
			if threadID == 0 {
				continue
			}
			start := uint32FromBytes(bytes[i+4:])
			end := uint32FromBytes(bytes[i+8:])
			// Mark every page of this entry as used
			for j := start; j <= end; j++ {
				if int(j) >= len(bitmap) {
					break
				}
				bitmap[j] = true
			}
		}
		break
	}
	// Look through the freemap and find a good spot
	for i := 0; i < len(bitmap); i++ {
		if bitmap[i] {
			continue
		}
		// NOTE(review): the inner scan doesn't stop at the first in-use
		// page after i, so the candidate window can span used pages, and
		// `diff > int(pagesRequested)` requires one more free page than
		// requested — verify both before relying on tight packing.
		for j := i; j < len(bitmap); j++ {
			if !bitmap[j] {
				diff := j - 1 - i
				if diff > int(pagesRequested) {
					return uint32(i), nil
				}
			}
		}
	}
	return 0, nil
}
// addNewContentsBlock adds a new contents block in the next available space
// and links it into the contents-block chain.
func (e *EdfFile) addNewContentsBlock() error {

	var toc ContentEntry

	// Find the next available offset
	startBlock, err := e.getContiguousOffset(1)
	if startBlock == 0 && err == nil {
		// Increase the size of the file if necessary
		// NOTE(review): startBlock is still 0 after extending — the
		// offset search isn't retried, so the new block would be linked
		// at page 0. Also extend's argument looks like bytes, not pages.
		// Confirm this path is ever exercised.
		e.extend(uint32(e.pageSize))
	} else if err != nil {
		return err
	}

	// Traverse the contents blocks looking for one with a blank NEXT pointer
	block := uint64(2)
	for {
		// Get the range for this block
		r := e.GetPageRange(block, block)
		if r.Start.Segment != r.End.Segment {
			return fmt.Errorf("Contents block split across segments")
		}
		bytes := e.m[r.Start.Segment]
		bytes = bytes[r.Start.Byte : r.End.Byte+1]
		// Get the address of the next contents block
		block = uint64FromBytes(bytes)
		if block == 0 {
			// Found the tail: link the new block in
			uint64ToBytes(uint64(startBlock), bytes)
			break
		}
	}

	// Record the new block itself in the table of contents,
	// owned by the SYSTEM thread
	toc.Start = startBlock
	toc.End = startBlock + 1
	toc.Thread = 1 // SYSTEM thread
	return e.addToTOC(&toc, false)
}
// addToTOC writes a ContentEntry structure into the next available
// 12-byte slot of the contents table. When extend is true and no slot
// is free, a new contents block is appended and the insert retried once.
func (e *EdfFile) addToTOC(c *ContentEntry, extend bool) error {
	// Traverse the contents table looking for a free spot
	block := uint64(2)
	for {
		// Get the range for this block
		r := e.GetPageRange(block, block)
		if r.Start.Segment != r.End.Segment {
			return fmt.Errorf("Contents block split across segments")
		}
		bytes := e.m[r.Start.Segment]
		bytes = bytes[r.Start.Byte : r.End.Byte+1]
		// Get the address of the next contents block
		block = uint64FromBytes(bytes)
		if block != 0 {
			// No point in checking this block for free space
			// NOTE(review): as in getContiguousOffset, this skips the
			// current block's entries rather than scanning them first.
			continue
		}
		bytes = bytes[8:]
		// Look for a blank entry in the table (threadID == 0)
		cur := 0
		for {
			threadID := uint32FromBytes(bytes)
			if threadID == 0 {
				break
			}
			cur += 12
			bytes = bytes[12:]
			if len(bytes) < 12 {
				if extend {
					// Append a new contents block and try again
					// NOTE(review): the error from addNewContentsBlock
					// is discarded here — confirm that's acceptable.
					e.addNewContentsBlock()
					return e.addToTOC(c, false)
				}
				return fmt.Errorf("Can't add to contents: no space available")
			}
		}
		// Write the contents information into this block
		uint32ToBytes(c.Thread, bytes)
		bytes = bytes[4:]
		uint32ToBytes(c.Start, bytes)
		bytes = bytes[4:]
		uint32ToBytes(c.End, bytes)
		break
	}
	return nil
}
// AllocPages allocates a |pagesRequested| chunk of pages on the Thread
// with the given identifier. Returns an EdfRange describing the result.
//
// Returns an error if no pages were requested, the thread identifier
// is invalid, or the file could not be grown to fit the request.
func (e *EdfFile) AllocPages(pagesRequested uint32, thread uint32) (EdfRange, error) {
	var ret EdfRange
	var toc ContentEntry

	// Parameter check
	if pagesRequested == 0 {
		return ret, fmt.Errorf("Must request some pages")
	}
	if thread == 0 {
		return ret, fmt.Errorf("Need a valid page identifier")
	}

	// Find the next available offset
	startBlock, err := e.getContiguousOffset(pagesRequested)
	if err != nil {
		return ret, err
	}
	if startBlock == 0 {
		// No room: grow the file and retry. The extend error must be
		// checked — ignoring it (as before) could recurse forever if
		// the file can never grow.
		if err = e.extend(pagesRequested); err != nil {
			return ret, err
		}
		return e.AllocPages(pagesRequested, thread)
	}

	// Add the allocation to the table of contents
	toc.Thread = thread
	toc.Start = startBlock
	toc.End = startBlock + pagesRequested
	err = e.addToTOC(&toc, true)

	// Compute the range for the caller
	ret = e.GetPageRange(uint64(startBlock), uint64(startBlock+pagesRequested))
	return ret, err
}
// GetThreadBlocks returns EdfRanges containing blocks assigned to a given thread.
func (e *EdfFile) GetThreadBlocks(thread uint32) ([]EdfRange, error) {
	var ret []EdfRange
	// Traverse the contents table, starting at the first contents block
	block := uint64(2)
	for {
		// Get the range for this block
		r := e.GetPageRange(block, block)
		if r.Start.Segment != r.End.Segment {
			return nil, fmt.Errorf("Contents block split across segments")
		}
		bytes := e.m[r.Start.Segment]
		bytes = bytes[r.Start.Byte : r.End.Byte+1]
		// Get the address of the next contents block
		block = uint64FromBytes(bytes)
		bytes = bytes[8:]
		// Look for matching contents entries
		// (12 bytes each: thread id, start page, end page)
		for {
			threadID := uint32FromBytes(bytes)
			if threadID == thread {
				blockStart := uint32FromBytes(bytes[4:])
				blockEnd := uint32FromBytes(bytes[8:])
				r = e.GetPageRange(uint64(blockStart), uint64(blockEnd))
				ret = append(ret, r)
			}
			bytes = bytes[12:]
			if len(bytes) < 12 {
				break
			}
		}
		// Time to stop: a zero NEXT pointer means this was the last block
		if block == 0 {
			break
		}
	}
	return ret, nil
}

70
base/edf/alloc_test.go Normal file
View File

@ -0,0 +1,70 @@
package edf
import (
. "github.com/smartystreets/goconvey/convey"
"io/ioutil"
"os"
"testing"
)
// TestAllocFixed checks that a single-page allocation on the FIXED
// thread lands after the four reserved header pages and survives an
// unmap/remap round trip.
func TestAllocFixed(t *testing.T) {
	Convey("Creating a non-existent file should succeed", t, func() {
		tempFile, err := ioutil.TempFile(os.TempDir(), "TestFileCreate")
		So(err, ShouldEqual, nil)
		Convey("Mapping the file should suceed", func() {
			mapping, err := EdfMap(tempFile, EDF_CREATE)
			So(err, ShouldEqual, nil)
			Convey("Allocation should suceed", func() {
				// Thread 2 is the FIXED thread
				r, err := mapping.AllocPages(1, 2)
				So(err, ShouldEqual, nil)
				So(r.Start.Byte, ShouldEqual, 4*os.Getpagesize())
				So(r.Start.Segment, ShouldEqual, 0)
				Convey("Unmapping the file should suceed", func() {
					err = mapping.Unmap(EDF_UNMAP_SYNC)
					So(err, ShouldEqual, nil)
					Convey("Remapping the file should suceed", func() {
						mapping, err = EdfMap(tempFile, EDF_READ_ONLY)
						Convey("Should get the same allocations back", func() {
							rr, err := mapping.GetThreadBlocks(2)
							So(err, ShouldEqual, nil)
							So(len(rr), ShouldEqual, 1)
							So(rr[0], ShouldResemble, r)
						})
					})
				})
			})
		})
	})
}

// TestAllocWithExtraContentsBlock verifies that ten successive
// allocations are all recorded and recoverable after remapping.
func TestAllocWithExtraContentsBlock(t *testing.T) {
	Convey("Creating a non-existent file should succeed", t, func() {
		tempFile, err := ioutil.TempFile(os.TempDir(), "TestFileCreate")
		So(err, ShouldEqual, nil)
		Convey("Mapping the file should suceed", func() {
			mapping, err := EdfMap(tempFile, EDF_CREATE)
			So(err, ShouldEqual, nil)
			Convey("Allocation of 10 pages should suceed", func() {
				allocated := make([]EdfRange, 10)
				for i := 0; i < 10; i++ {
					r, err := mapping.AllocPages(1, 2)
					So(err, ShouldEqual, nil)
					allocated[i] = r
				}
				Convey("Unmapping the file should suceed", func() {
					err = mapping.Unmap(EDF_UNMAP_SYNC)
					So(err, ShouldEqual, nil)
					Convey("Remapping the file should suceed", func() {
						mapping, err = EdfMap(tempFile, EDF_READ_ONLY)
						Convey("Should get the same allocations back", func() {
							rr, err := mapping.GetThreadBlocks(2)
							So(err, ShouldEqual, nil)
							So(len(rr), ShouldEqual, 10)
							So(rr, ShouldResemble, allocated)
						})
					})
				})
			})
		})
	})
}

40
base/edf/edf.go Normal file
View File

@ -0,0 +1,40 @@
package edf
// map.go: handles mmaping, truncation, header creation, verification,
// creation of initial thread contents block (todo)
// creation of initial thread metadata block (todo)
// thread.go: handles extending thread contents block (todo)
// extending thread metadata block (todo), adding threads (todo),
// retrieving the segments and offsets relevant to a thread (todo)
// resolution of threads by name (todo)
// appending data to a thread (todo)
// deleting threads (todo)
const (
	// EDF_VERSION is the file format version
	EDF_VERSION = 1
	// EDF_LENGTH is the number of OS pages in each slice
	EDF_LENGTH = 32
	// EDF_SIZE sets the maximum size of the mapping, represented with
	// EDF_LENGTH segments
	// Currently set arbitrarily to 128 MiB
	EDF_SIZE = 128 * (1024 * 1024)
)

// Modes accepted by EdfMap
const (
	// EDF_READ_ONLY means the file will only be read, modifications fail
	EDF_READ_ONLY = iota
	// EDF_READ_WRITE specifies that the file will be read and written
	EDF_READ_WRITE
	// EDF_CREATE means the file will be created and opened with EDF_READ_WRITE
	EDF_CREATE
)

// Flags accepted by Unmap
const (
	// EDF_UNMAP_NOSYNC means the file won't be
	// Sync'd to disk before unmapping
	EDF_UNMAP_NOSYNC = iota
	// EDF_UNMAP_SYNC synchronises the EDF file to disk
	// during unmapping
	EDF_UNMAP_SYNC
)

386
base/edf/map.go Normal file
View File

@ -0,0 +1,386 @@
package edf
import (
"fmt"
mmap "github.com/riobard/go-mmap"
"os"
"runtime"
)
// EdfFile represents a mapped file on disk or
// an anonymous mapping for instance storage
type EdfFile struct {
	// Backing file; nil for anonymous mappings
	f *os.File
	// One mmap segment per EDF_LENGTH-page slice of the address space
	m []mmap.Mmap
	// Size of each mapped segment, in bytes
	segmentSize uint64
	// OS page size, in bytes
	pageSize uint64
}

// GetPageSize returns the pageSize of an EdfFile
func (e *EdfFile) GetPageSize() uint64 {
	return e.pageSize
}

// GetSegmentSize returns the segmentSize of an EdfFile
func (e *EdfFile) GetSegmentSize() uint64 {
	return e.segmentSize
}
// EdfPosition represents a single location within the mapping,
// as a segment index and a byte offset within that segment.
type EdfPosition struct {
	Segment uint64
	Byte    uint64
}

// EdfRange represents a start and an end segment
// mapped in an EdfFile and also the byte offsets
// within that segment
type EdfRange struct {
	Start EdfPosition
	End   EdfPosition
	// segmentSize is recorded so Size() can span segments
	segmentSize uint64
}
// Size returns the size (in bytes) of a given EdfRange
func (r *EdfRange) Size() uint64 {
	segments := r.End.Segment - r.Start.Segment
	return segments*r.segmentSize + (r.End.Byte - r.Start.Byte)
}
// edfCallFree is a half-baked finalizer called on garbage
// collection to ensure that the mapping gets freed.
// It deliberately skips the disk sync (EDF_UNMAP_NOSYNC).
func edfCallFree(e *EdfFile) {
	e.Unmap(EDF_UNMAP_NOSYNC)
}
// EdfAnonMap maps the EdfFile structure into RAM.
// The header and initial system threads are written immediately.
//
// IMPORTANT: everything's lost if unmapped
func EdfAnonMap() (*EdfFile, error) {

	var err error

	ret := new(EdfFile)

	// Figure out the flags
	protFlags := mmap.PROT_READ | mmap.PROT_WRITE
	mapFlags := mmap.MAP_FILE | mmap.MAP_SHARED

	// Create mapping references
	ret.m = make([]mmap.Mmap, 0)

	// Get the page size
	pageSize := int64(os.Getpagesize())

	// Segment size is the size of each mapped region
	ret.pageSize = uint64(pageSize)
	ret.segmentSize = uint64(EDF_LENGTH) * uint64(os.Getpagesize())

	// Map the memory
	for i := int64(0); i < EDF_SIZE; i += int64(EDF_LENGTH) * pageSize {
		thisMapping, err := mmap.AnonMap(int(ret.segmentSize), protFlags, mapFlags)
		if err != nil {
			// Unwind the segments mapped so far rather than leaking them
			for _, m := range ret.m {
				m.Unmap()
			}
			return nil, err
		}
		ret.m = append(ret.m, thisMapping)
	}

	// Generate the header and the SYSTEM/FIXED threads
	ret.createHeader()
	err = ret.writeInitialData()

	// Make sure this gets unmapped on garbage collection
	runtime.SetFinalizer(ret, edfCallFree)

	return ret, err
}
// EdfMap takes an os.File and returns an EdfMappedFile
// structure, which represents the mmap'd underlying file
//
// The `mode` parameter takes the following values
// EDF_CREATE: EdfMap will truncate the file to the right length and write the correct header information
// EDF_READ_WRITE: EdfMap will verify header information
// EDF_READ_ONLY: EdfMap will verify header information
// IMPORTANT: EDF_LENGTH (edf.go) controls the size of the address
// space mapping. This means that the file can be truncated to the
// correct size without remapping. On 32-bit systems, this
// is set to 2GiB.
func EdfMap(f *os.File, mode int) (*EdfFile, error) {
	var err error

	// Set up various things
	ret := new(EdfFile)
	ret.f = f
	ret.m = make([]mmap.Mmap, 0)

	// Figure out the flags
	protFlags := mmap.PROT_READ
	if mode == EDF_READ_WRITE || mode == EDF_CREATE {
		protFlags |= mmap.PROT_WRITE
	}
	mapFlags := mmap.MAP_FILE | mmap.MAP_SHARED
	// Get the page size
	pageSize := int64(os.Getpagesize())
	// Segment size is the size of each mapped region
	ret.pageSize = uint64(pageSize)
	ret.segmentSize = uint64(EDF_LENGTH) * uint64(os.Getpagesize())

	// Map the file
	// NOTE(review): i already advances in bytes, so the offset
	// i*pageSize looks doubly scaled — confirm against go-mmap usage.
	for i := int64(0); i < EDF_SIZE; i += int64(EDF_LENGTH) * pageSize {
		thisMapping, err := mmap.Map(f, i*pageSize, int(int64(EDF_LENGTH)*pageSize), protFlags, mapFlags)
		if err != nil {
			// Unwind the segments mapped so far rather than leaking them
			for _, m := range ret.m {
				m.Unmap()
			}
			return nil, err
		}
		ret.m = append(ret.m, thisMapping)
	}

	// Verify or generate the header
	if mode == EDF_READ_WRITE || mode == EDF_READ_ONLY {
		err = ret.VerifyHeader()
		if err != nil {
			// Invalid file: release the mappings before bailing out
			for _, m := range ret.m {
				m.Unmap()
			}
			return nil, err
		}
	} else if mode == EDF_CREATE {
		err = ret.truncate(4)
		if err != nil {
			// Couldn't size the file: release the mappings
			for _, m := range ret.m {
				m.Unmap()
			}
			return nil, err
		}
		ret.createHeader()
		err = ret.writeInitialData()
	} else {
		err = fmt.Errorf("Unrecognised flags")
	}

	// Make sure this gets unmapped on garbage collection
	runtime.SetFinalizer(ret, edfCallFree)

	return ret, err
}
// Range returns the segment offset and range of
// two byte positions in the file.
func (e *EdfFile) Range(byteStart uint64, byteEnd uint64) EdfRange {
	return EdfRange{
		Start: EdfPosition{
			Segment: byteStart / e.segmentSize,
			Byte:    byteStart % e.segmentSize,
		},
		End: EdfPosition{
			Segment: byteEnd / e.segmentSize,
			Byte:    byteEnd % e.segmentSize,
		},
		segmentSize: e.segmentSize,
	}
}
// GetPageRange returns the segment offset and range of
// two pages in the file (inclusive of the final page's last byte).
func (e *EdfFile) GetPageRange(pageStart uint64, pageEnd uint64) EdfRange {
	firstByte := pageStart * e.pageSize
	lastByte := pageEnd*e.pageSize + e.pageSize - 1
	return e.Range(firstByte, lastByte)
}
// VerifyHeader checks that this version of Golearn can
// read the file presented: magic bytes, format version,
// and the creating system's page size must all match.
func (e *EdfFile) VerifyHeader() error {
	// Check the magic bytes ("GOLN")
	diff := (e.m[0][0] ^ byte('G')) | (e.m[0][1] ^ byte('O'))
	diff |= (e.m[0][2] ^ byte('L')) | (e.m[0][3] ^ byte('N'))
	if diff != 0 {
		return fmt.Errorf("Invalid magic bytes")
	}
	// Check the file version
	version := uint32FromBytes(e.m[0][4:8])
	if version != EDF_VERSION {
		return fmt.Errorf("Unsupported version: %d", version)
	}
	// Check the page size (closing paren was missing from this message)
	pageSize := uint32FromBytes(e.m[0][8:12])
	if pageSize != uint32(os.Getpagesize()) {
		return fmt.Errorf("Unsupported page size: (file: %d, system: %d)", pageSize, os.Getpagesize())
	}
	return nil
}
// createHeader writes a valid header into the mapping: the "GOLN"
// magic bytes, the format version and the system page size.
// Unexported since it can cause data loss.
func (e *EdfFile) createHeader() {
	// Magic bytes
	for i, c := range []byte("GOLN") {
		e.m[0][i] = c
	}
	// Version, then page size
	uint32ToBytes(EDF_VERSION, e.m[0][4:8])
	uint32ToBytes(uint32(os.Getpagesize()), e.m[0][8:12])
	e.Sync()
}
// writeInitialData registers the two built-in threads:
// SYSTEM (id 1) and FIXED (id 2).
func (e *EdfFile) writeInitialData() error {
	system := Thread{name: "SYSTEM", id: 1}
	if err := e.WriteThread(&system); err != nil {
		return err
	}
	fixed := Thread{name: "FIXED", id: 2}
	return e.WriteThread(&fixed)
}
// GetThreadCount returns the number of threads in this file.
func (e *EdfFile) GetThreadCount() uint32 {
	// The number of threads is stored in bytes 12-16 in the header
	return uint32FromBytes(e.m[0][12:])
}
// incrementThreadCount increments the header's record of the number
// of threads in this file and returns the new count.
func (e *EdfFile) incrementThreadCount() uint32 {
	next := e.GetThreadCount() + 1
	// The count lives at bytes 12-16 of the header
	uint32ToBytes(next, e.m[0][12:])
	return next
}
// GetThreads returns the thread identifier -> name map.
//
// Returns an error if a thread block is split across segments or the
// number of threads found disagrees with the header count (which
// indicates possible corruption).
func (e *EdfFile) GetThreads() (map[uint32]string, error) {
	ret := make(map[uint32]string)
	count := e.GetThreadCount()
	// The starting block
	block := uint64(1)
	for {
		// Decode the block offset
		r := e.GetPageRange(block, block)
		if r.Start.Segment != r.End.Segment {
			return nil, fmt.Errorf("Thread range split across segments")
		}
		bytes := e.m[r.Start.Segment]
		bytes = bytes[r.Start.Byte : r.End.Byte+1]
		// The first 8 bytes say where to go next
		block = uint64FromBytes(bytes)
		bytes = bytes[8:]
		// Deserialize every thread record (a zero length marks the end)
		for {
			length := uint32FromBytes(bytes)
			if length == 0 {
				break
			}
			t := &Thread{}
			size := t.Deserialize(bytes)
			bytes = bytes[size:]
			// (was t.name[0:len(t.name)], a redundant full-slice expression)
			ret[t.id] = t.name
		}
		// If next block offset is zero, no more threads to read
		if block == 0 {
			break
		}
	}
	// Cross-check against the header's thread count
	if len(ret) != int(count) {
		return ret, fmt.Errorf("Thread mismatch: %d/%d, indicates possible corruption", len(ret), count)
	}
	return ret, nil
}
// Sync writes every mapped segment out to physical storage.
// Returns the first error encountered, if any.
func (e *EdfFile) Sync() error {
	for _, segment := range e.m {
		if err := segment.Sync(mmap.MS_SYNC); err != nil {
			return err
		}
	}
	return nil
}
// truncate changes the size of the underlying file to `size` pages.
// The size of the address space doesn't change.
// Shrinking is refused; the new size is verified after truncation.
func (e *EdfFile) truncate(size int64) error {
	pageSize := int64(os.Getpagesize())
	// Convert from pages to bytes
	newSize := pageSize * size

	// Synchronise
	// e.Sync()

	// Double-check that we're not reducing file size
	fileInfo, err := e.f.Stat()
	if err != nil {
		return err
	}
	if fileInfo.Size() > newSize {
		return fmt.Errorf("Can't reduce file size!")
	}

	// Truncate the file
	err = e.f.Truncate(newSize)
	if err != nil {
		return err
	}

	// Verify that the file is larger now than it was
	fileInfo, err = e.f.Stat()
	if err != nil {
		return err
	}
	if fileInfo.Size() != newSize {
		return fmt.Errorf("Truncation failed: %d, %d", fileInfo.Size(), newSize)
	}
	return err
}
// Unmap unlinks the EdfFile from the address space.
// EDF_UNMAP_NOSYNC skips calling Sync() on the underlying
// file before this happens.
// IMPORTANT: attempts to use this mapping after Unmap() is
// called will result in crashes.
func (e *EdfFile) Unmap(flags int) error {
	// Only sync when the caller didn't opt out. (A previous,
	// unconditional Sync() call here made EDF_UNMAP_NOSYNC a no-op.)
	if flags != EDF_UNMAP_NOSYNC {
		e.Sync()
	}
	// Unmap every segment
	for _, m := range e.m {
		err := m.Unmap()
		if err != nil {
			return err
		}
	}
	return nil
}
// ResolveRange returns a slice of byte slices representing
// the underlying memory referenced by EdfRange.
//
// WARNING: slow.
func (e *EdfFile) ResolveRange(r EdfRange) [][]byte {
	var ret [][]byte
	for segment := r.Start.Segment; segment <= r.End.Segment; segment++ {
		// First segment starts at Start.Byte, all others at 0
		lo := uint64(0)
		if segment == r.Start.Segment {
			lo = r.Start.Byte
		}
		if segment == r.End.Segment {
			// Last segment is clipped at End.Byte (inclusive).
			// The previous version skipped this clip when the range
			// fell inside one segment, returning the segment's whole
			// tail beyond End.Byte.
			ret = append(ret, e.m[segment][lo:r.End.Byte+1])
		} else {
			ret = append(ret, e.m[segment][lo:])
		}
	}
	// (an unused segCounter variable was removed)
	return ret
}
// IResolveRange returns a byte slice representing the current EdfRange
// and returns a value saying whether there's more. Subsequent calls to
// IResolveRange should use the value returned by the previous one until
// zero is returned (no more ranges available).
func (e *EdfFile) IResolveRange(r EdfRange, prev uint64) ([]byte, uint64) {
	segment := r.Start.Segment + prev
	if segment > r.End.Segment {
		return nil, 0
	}
	// First segment starts at Start.Byte, all others at 0
	lo := uint64(0)
	if segment == r.Start.Segment {
		lo = r.Start.Byte
	}
	if segment == r.End.Segment {
		// Final piece: clip at End.Byte (inclusive) and signal completion.
		// (The previous version returned the unclipped tail when the
		// range fell inside a single segment.)
		return e.m[segment][lo : r.End.Byte+1], 0
	}
	return e.m[segment][lo:], prev + 1
}

118
base/edf/map_test.go Normal file
View File

@ -0,0 +1,118 @@
package edf
import (
. "github.com/smartystreets/goconvey/convey"
"io/ioutil"
"os"
"testing"
)
// TestAnonMap checks that an anonymous mapping is created with a
// valid header: magic bytes, version and page size.
func TestAnonMap(t *testing.T) {
	Convey("Anonymous mapping should suceed", t, func() {
		mapping, err := EdfAnonMap()
		So(err, ShouldEqual, nil)
		bytes := mapping.m[0]
		// Read the magic bytes
		magic := bytes[0:4]
		Convey("Magic bytes should be correct", func() {
			So(magic[0], ShouldEqual, byte('G'))
			So(magic[1], ShouldEqual, byte('O'))
			So(magic[2], ShouldEqual, byte('L'))
			So(magic[3], ShouldEqual, byte('N'))
		})
		// Read the file version
		versionBytes := bytes[4:8]
		Convey("Version should be correct", func() {
			version := uint32FromBytes(versionBytes)
			So(version, ShouldEqual, EDF_VERSION)
		})
		// Read the block size
		blockBytes := bytes[8:12]
		Convey("Page size should be correct", func() {
			pageSize := uint32FromBytes(blockBytes)
			So(pageSize, ShouldEqual, os.Getpagesize())
		})
	})
}

// TestFileCreate checks that EDF_CREATE writes a well-formed header
// to disk and sizes the file to at least the four reserved pages.
func TestFileCreate(t *testing.T) {
	Convey("Creating a non-existent file should succeed", t, func() {
		tempFile, err := ioutil.TempFile(os.TempDir(), "TestFileCreate")
		So(err, ShouldEqual, nil)
		Convey("Mapping the file should suceed", func() {
			mapping, err := EdfMap(tempFile, EDF_CREATE)
			So(err, ShouldEqual, nil)
			Convey("Unmapping the file should suceed", func() {
				err = mapping.Unmap(EDF_UNMAP_SYNC)
				So(err, ShouldEqual, nil)
			})
			// Read the magic bytes
			magic := make([]byte, 4)
			read, err := tempFile.ReadAt(magic, 0)
			Convey("Magic bytes should be correct", func() {
				So(err, ShouldEqual, nil)
				So(read, ShouldEqual, 4)
				So(magic[0], ShouldEqual, byte('G'))
				So(magic[1], ShouldEqual, byte('O'))
				So(magic[2], ShouldEqual, byte('L'))
				So(magic[3], ShouldEqual, byte('N'))
			})
			// Read the file version
			versionBytes := make([]byte, 4)
			read, err = tempFile.ReadAt(versionBytes, 4)
			Convey("Version should be correct", func() {
				So(err, ShouldEqual, nil)
				So(read, ShouldEqual, 4)
				version := uint32FromBytes(versionBytes)
				So(version, ShouldEqual, EDF_VERSION)
			})
			// Read the block size
			blockBytes := make([]byte, 4)
			read, err = tempFile.ReadAt(blockBytes, 8)
			Convey("Page size should be correct", func() {
				So(err, ShouldEqual, nil)
				So(read, ShouldEqual, 4)
				pageSize := uint32FromBytes(blockBytes)
				So(pageSize, ShouldEqual, os.Getpagesize())
			})
			// Check the file size is at least four * page size
			info, err := tempFile.Stat()
			Convey("File should be the right size", func() {
				So(err, ShouldEqual, nil)
				So(info.Size(), ShouldBeGreaterThanOrEqualTo, 4*os.Getpagesize())
			})
		})
	})
}

// TestFileThreadCounter checks the built-in SYSTEM/FIXED threads and
// that an unmatched counter increment is detected as corruption.
func TestFileThreadCounter(t *testing.T) {
	Convey("Creating a non-existent file should succeed", t, func() {
		tempFile, err := ioutil.TempFile(os.TempDir(), "TestFileCreate")
		So(err, ShouldEqual, nil)
		Convey("Mapping the file should suceed", func() {
			mapping, err := EdfMap(tempFile, EDF_CREATE)
			So(err, ShouldEqual, nil)
			Convey("The file should have two threads to start with", func() {
				count := mapping.GetThreadCount()
				So(count, ShouldEqual, 2)
				Convey("They should be SYSTEM and FIXED", func() {
					threads, err := mapping.GetThreads()
					So(err, ShouldEqual, nil)
					So(len(threads), ShouldEqual, 2)
					So(threads[1], ShouldEqual, "SYSTEM")
					So(threads[2], ShouldEqual, "FIXED")
				})
			})
			Convey("Incrementing the threadcount should result in three threads", func() {
				mapping.incrementThreadCount()
				count := mapping.GetThreadCount()
				So(count, ShouldEqual, 3)
				Convey("Thread information should indicate corruption", func() {
					_, err := mapping.GetThreads()
					So(err, ShouldNotEqual, nil)
				})
			})
		})
	})
}

137
base/edf/thread.go Normal file
View File

@ -0,0 +1,137 @@
package edf
import (
"fmt"
)
// Threads are streams of data encapsulated within the file.
type Thread struct {
	// Human-readable thread name (e.g. "SYSTEM", "FIXED")
	name string
	// Numeric identifier; 0 is reserved to mean "no thread"
	id uint32
}
// NewThread returns a new thread named `name`, numbered one past the
// file's current thread count. The thread is not persisted until
// WriteThread is called.
func NewThread(e *EdfFile, name string) *Thread {
	return &Thread{name, e.GetThreadCount() + 1}
}
// GetSpaceNeeded returns the number of bytes needed to serialize this
// Thread: a 4-byte name length, the name itself, and a 4-byte id.
func (t *Thread) GetSpaceNeeded() int {
	return 4 + len(t.name) + 4
}
// Serialize copies this thread to the output byte slice as a
// length-prefixed name followed by the thread identifier.
// Returns the number of bytes used.
func (t *Thread) Serialize(out []byte) int {
	nameLen := len(t.name)
	// 4-byte name length prefix
	uint32ToBytes(uint32(nameLen), out)
	// Name bytes
	copy(out[4:], t.name)
	// 4-byte thread identifier
	uint32ToBytes(t.id, out[4+nameLen:])
	return 4 + nameLen + 4
}
// Deserialize copies the input byte slice into this thread,
// reversing Serialize. Returns the number of bytes consumed.
func (t *Thread) Deserialize(buf []byte) int {
	// 4-byte name length prefix
	nameLen := int(uint32FromBytes(buf))
	buf = buf[4:]
	// Name bytes
	t.name = string(buf[:nameLen])
	buf = buf[nameLen:]
	// 4-byte thread identifier
	t.id = uint32FromBytes(buf)
	return 4 + nameLen + 4
}
// FindThread obtains the index of a thread in the EdfFile by name,
// or an error if no thread with that name exists.
func (e *EdfFile) FindThread(targetName string) (uint32, error) {
	var offset uint32
	var counter uint32
	// Resolve the initial thread block (page 1)
	blockRange := e.GetPageRange(1, 1)
	if blockRange.Start.Segment != blockRange.End.Segment {
		return 0, fmt.Errorf("Thread block split across segments!")
	}
	// NOTE(review): the slice ends at End.Byte, not End.Byte+1 as
	// elsewhere, so the final byte of the page is excluded — confirm.
	bytes := e.m[blockRange.Start.Segment][blockRange.Start.Byte:blockRange.End.Byte]
	// Skip the first 8 bytes, since we don't support multiple thread blocks yet
	// TODO: fix that
	bytes = bytes[8:]
	counter = 1
	for {
		// Each record: 4-byte name length, name, 4-byte id;
		// a zero length terminates the table
		length := uint32FromBytes(bytes)
		if length == 0 {
			return 0, fmt.Errorf("No matching threads")
		}
		name := string(bytes[4 : 4+length])
		if name == targetName {
			offset = counter
			break
		}
		bytes = bytes[8+length:]
		counter++
	}
	return offset, nil
}
// WriteThread inserts a new thread into the EdfFile's thread table
// and bumps the header thread count. Duplicate names are rejected.
func (e *EdfFile) WriteThread(t *Thread) error {
	offset, _ := e.FindThread(t.name)
	if offset != 0 {
		return fmt.Errorf("Writing a duplicate thread")
	}
	// Resolve the initial Thread block (page 1)
	blockRange := e.GetPageRange(1, 1)
	if blockRange.Start.Segment != blockRange.End.Segment {
		return fmt.Errorf("Thread block split across segments!")
	}
	bytes := e.m[blockRange.Start.Segment][blockRange.Start.Byte:blockRange.End.Byte]
	// Skip the first 8 bytes, since we don't support multiple thread blocks yet
	// TODO: fix that
	bytes = bytes[8:]
	cur := 0
	// Walk past existing records until the zero-length terminator
	for {
		length := uint32FromBytes(bytes)
		if length == 0 {
			break
		}
		cur += 8 + int(length)
		bytes = bytes[8+length:]
	}
	// cur should have now found an empty offset
	// Check that we have enough room left to insert
	roomLeft := len(bytes)
	roomNeeded := t.GetSpaceNeeded()
	if roomLeft < roomNeeded {
		return fmt.Errorf("Not enough space available")
	}
	// If everything's fine, serialise
	t.Serialize(bytes)
	// Increment thread count
	e.incrementThreadCount()
	return nil
}
// GetId returns this Thread's identifier.
// (Idiomatic Go would name this GetID/ID, but renaming would break callers.)
func (t *Thread) GetId() uint32 {
	return t.id
}

59
base/edf/thread_test.go Normal file
View File

@ -0,0 +1,59 @@
package edf
import (
. "github.com/smartystreets/goconvey/convey"
"testing"
"os"
)
// TestThreadDeserialize checks decoding of a hand-built SYSTEM record.
func TestThreadDeserialize(T *testing.T) {
	// 4-byte length (6), "SYSTEM", 4-byte id (1)
	bytes := []byte{0, 0, 0, 6, 83, 89, 83, 84, 69, 77, 0, 0, 0, 1}
	Convey("Given a byte slice", T, func() {
		var t Thread
		size := t.Deserialize(bytes)
		Convey("Decoded name should be SYSTEM", func() {
			So(t.name, ShouldEqual, "SYSTEM")
		})
		Convey("Size should be the same as the array", func() {
			So(size, ShouldEqual, len(bytes))
		})
	})
}

// TestThreadSerialize checks that encoding reproduces the reference bytes.
func TestThreadSerialize(T *testing.T) {
	var t Thread
	refBytes := []byte{0, 0, 0, 6, 83, 89, 83, 84, 69, 77, 0, 0, 0, 1}
	t.name = "SYSTEM"
	t.id = 1
	toBytes := make([]byte, len(refBytes))
	Convey("Should serialize correctly", T, func() {
		t.Serialize(toBytes)
		So(toBytes, ShouldResemble, refBytes)
	})
}

// TestThreadFindAndWrite writes a new thread and finds it again by name.
// NOTE(review): this creates hello.db in the working directory and never
// removes it — consider reverting to ioutil.TempFile plus cleanup.
func TestThreadFindAndWrite(T *testing.T) {
	Convey("Creating a non-existent file should succeed", T, func() {
		tempFile, err := os.OpenFile("hello.db", os.O_RDWR | os.O_TRUNC | os.O_CREATE, 0700) //ioutil.TempFile(os.TempDir(), "TestFileCreate")
		So(err, ShouldEqual, nil)
		Convey("Mapping the file should suceed", func() {
			mapping, err := EdfMap(tempFile, EDF_CREATE)
			So(err, ShouldEqual, nil)
			Convey("Writing the thread should succeed", func () {
				t := NewThread(mapping, "MyNameISWhat")
				Convey("Thread number should be 3", func () {
					So(t.id, ShouldEqual, 3)
				})
				Convey("Writing the thread should succeed", func() {
					err := mapping.WriteThread(t)
					So(err, ShouldEqual, nil)
					Convey("Should be able to find the thread again later", func() {
						id, err := mapping.FindThread("MyNameISWhat")
						So(err, ShouldEqual, nil)
						So(id, ShouldEqual, 3)
					})
				})
			})
		})
	})
}

32
base/edf/util.go Normal file
View File

@ -0,0 +1,32 @@
package edf
// uint64ToBytes writes `in` into out[0:8] in big-endian byte order.
//
// The previous version computed the mask as `0xFF << i * 8`, which in
// Go parses as `(0xFF << i) * 8` because << and * share precedence,
// corrupting every byte written.
func uint64ToBytes(in uint64, out []byte) {
	for i := uint64(0); i < 8; i++ {
		// out[0] is the most significant byte
		out[7-i] = byte(in >> (8 * i))
	}
}
// uint64FromBytes decodes eight big-endian bytes from in and returns
// the corresponding uint64. It is the inverse of uint64ToBytes.
//
// Bug fix: the previous version computed uint64(in[7-i] << uint64(i*0x8)),
// shifting the byte value BEFORE widening it to uint64. Any shift of 8
// bits or more therefore discarded the byte entirely, so only the
// least-significant byte was ever decoded. The conversion must happen
// before the shift.
func uint64FromBytes(in []byte) uint64 {
	out := uint64(0)
	for i := uint(0); i < 8; i++ {
		out |= uint64(in[7-i]) << (i * 8)
	}
	return out
}
// uint32FromBytes decodes four big-endian bytes from in and returns
// the corresponding uint32.
func uint32FromBytes(in []byte) uint32 {
	return uint32(in[0])<<24 | uint32(in[1])<<16 | uint32(in[2])<<8 | uint32(in[3])
}
// uint32ToBytes writes in to out as four big-endian bytes.
// out must have length >= 4.
func uint32ToBytes(in uint32, out []byte) {
	out[0] = byte(in >> 24)
	out[1] = byte(in >> 16)
	out[2] = byte(in >> 8)
	out[3] = byte(in)
}

20
base/edf/util_test.go Normal file
View File

@ -0,0 +1,20 @@
package edf
// Utility function tests
import (
. "github.com/smartystreets/goconvey/convey"
"testing"
)
// TestInt32Conversion round-trips a uint32 through uint32ToBytes and
// uint32FromBytes and checks the value is preserved.
func TestInt32Conversion(t *testing.T) {
	Convey("Given deadbeef", t, func() {
		buf := make([]byte, 4)
		original := uint32(0xDEAD)
		uint32ToBytes(original, buf)
		converted := uint32FromBytes(buf)
		Convey("Decoded value should be the original...", func() {
			So(converted, ShouldEqual, original)
		})
	})
}

258
base/filtered.go Normal file
View File

@ -0,0 +1,258 @@
package base
import (
"bytes"
"fmt"
)
// Maybe included a TransformedAttribute struct
// so we can map from ClassAttribute to ClassAttribute

// LazilyFilteredInstances map a Filter over an underlying
// FixedDataGrid and are a memory-efficient way of applying them:
// values are transformed on access rather than copied up front.
type LazilyFilteredInstances struct {
	filter        Filter             // Transformation applied on access
	src           FixedDataGrid      // Underlying, unmodified grid
	attrs         []FilteredAttribute // Old -> new Attribute mappings
	classAttrs    map[Attribute]bool // Post-filter class Attributes (true = active)
	unfilteredMap map[Attribute]bool // true if an Attribute passes through unchanged
}
// NewLazilyFilteredInstances returns a new FixedDataGrid after
// applying the given Filter to the Attributes it includes. Unfiltered
// Attributes are passed through without modification.
func NewLazilyFilteredInstances(src FixedDataGrid, f Filter) *LazilyFilteredInstances {
	// Get the Attributes after filtering
	attrs := f.GetAttributesAfterFiltering()
	// Build a set of Attributes which have undergone filtering:
	// start with everything marked as pass-through, then clear the
	// flag for each Attribute the Filter replaces.
	unFilteredMap := make(map[Attribute]bool)
	for _, a := range src.AllAttributes() {
		unFilteredMap[a] = true
	}
	for _, a := range attrs {
		unFilteredMap[a.Old] = false
	}
	// Create the return structure
	ret := &LazilyFilteredInstances{
		f,
		src,
		attrs,
		make(map[Attribute]bool),
		unFilteredMap,
	}
	// Transfer class Attributes from the source grid.
	for _, a := range src.AllClassAttributes() {
		ret.AddClassAttribute(a)
	}
	return ret
}
// GetAttribute returns an AttributeSpec for a given Attribute,
// resolving pass-through Attributes against the underlying grid and
// filtered Attributes against the filter's mapping table.
func (l *LazilyFilteredInstances) GetAttribute(target Attribute) (AttributeSpec, error) {
	// Unfiltered Attributes are owned by the source grid.
	if l.unfilteredMap[target] {
		return l.src.GetAttribute(target)
	}
	spec := AttributeSpec{}
	spec.pond = -1 // Filtered Attributes have no backing pond.
	for pos, m := range l.attrs {
		if m.New.Equals(target) {
			spec.position = pos
			spec.attr = target
			return spec, nil
		}
	}
	return spec, fmt.Errorf("Couldn't resolve %s", target)
}
// AllAttributes returns every Attribute defined in the source datagrid,
// in addition to the revised Attributes created by the filter.
func (l *LazilyFilteredInstances) AllAttributes() []Attribute {
	out := make([]Attribute, 0)
	for _, src := range l.src.AllAttributes() {
		if l.unfilteredMap[src] {
			out = append(out, src)
			continue
		}
		// Substitute each filtered replacement for this Attribute.
		for _, m := range l.attrs {
			if src.Equals(m.Old) {
				out = append(out, m.New)
			}
		}
	}
	return out
}
// AddClassAttribute adds a given Attribute (either before or after filtering)
// to the set of defined class Attributes.
func (l *LazilyFilteredInstances) AddClassAttribute(cls Attribute) error {
	// Pass-through Attributes are recorded directly.
	if l.unfilteredMap[cls] {
		l.classAttrs[cls] = true
		return nil
	}
	// Accept either side of a filter mapping, but always record
	// the post-filter Attribute.
	for _, m := range l.attrs {
		if !m.Old.Equals(cls) && !m.New.Equals(cls) {
			continue
		}
		l.classAttrs[m.New] = true
		return nil
	}
	return fmt.Errorf("Attribute %s could not be resolved", cls)
}
// RemoveClassAttribute removes a given Attribute (either before or
// after filtering) from the set of defined class Attributes.
func (l *LazilyFilteredInstances) RemoveClassAttribute(cls Attribute) error {
	// Pass-through Attributes are cleared directly.
	if l.unfilteredMap[cls] {
		l.classAttrs[cls] = false
		return nil
	}
	// Accept either side of a filter mapping, but always clear
	// the post-filter Attribute.
	for _, m := range l.attrs {
		if !m.Old.Equals(cls) && !m.New.Equals(cls) {
			continue
		}
		l.classAttrs[m.New] = false
		return nil
	}
	return fmt.Errorf("Attribute %s could not be resolved", cls)
}
// AllClassAttributes returns details of all Attributes currently specified
// as being class Attributes.
//
// If applicable, the Attributes returned are those after modification
// by the Filter.
func (l *LazilyFilteredInstances) AllClassAttributes() []Attribute {
	ret := make([]Attribute, 0)
	for attr, isClass := range l.classAttrs {
		if isClass {
			ret = append(ret, attr)
		}
	}
	return ret
}
// transformNewToOldAttribute maps an AttributeSpec for a post-filter
// Attribute back onto the equivalent AttributeSpec in the source grid.
// Pass-through Attributes are returned unchanged.
func (l *LazilyFilteredInstances) transformNewToOldAttribute(as AttributeSpec) (AttributeSpec, error) {
	if l.unfilteredMap[as.GetAttribute()] {
		return as, nil
	}
	for _, a := range l.attrs {
		if a.Old.Equals(as.attr) || a.New.Equals(as.attr) {
			// Resolve the pre-filter Attribute against the source.
			as, err := l.src.GetAttribute(a.Old)
			if err != nil {
				return AttributeSpec{}, fmt.Errorf("Internal error in Attribute resolution: '%s'", err)
			}
			return as, nil
		}
	}
	return AttributeSpec{}, fmt.Errorf("No matching Attribute")
}
// Get returns a transformed byte slice stored at a given AttributeSpec
// and row. Pass-through Attributes return the source bytes unchanged;
// filtered Attributes are run through the Filter's Transform.
//
// IMPORTANT: panics if the AttributeSpec cannot be resolved against
// the source grid.
func (l *LazilyFilteredInstances) Get(as AttributeSpec, row int) []byte {
	asOld, err := l.transformNewToOldAttribute(as)
	if err != nil {
		panic(fmt.Sprintf("Attribute %s could not be resolved. (Error: %s)", as, err))
	}
	byteSeq := l.src.Get(asOld, row)
	if l.unfilteredMap[as.attr] {
		return byteSeq
	}
	newByteSeq := l.filter.Transform(asOld.attr, as.attr, byteSeq)
	return newByteSeq
}
// MapOverRows maps an iteration mapFunc over the bytes contained in the source
// FixedDataGrid, after modification by the filter.
func (l *LazilyFilteredInstances) MapOverRows(asv []AttributeSpec, mapFunc func([][]byte, int) (bool, error)) error {
	// Have to transform each item of asv into an
	// AttributeSpec in the original
	oldAsv := make([]AttributeSpec, len(asv))
	for i, a := range asv {
		old, err := l.transformNewToOldAttribute(a)
		if err != nil {
			return fmt.Errorf("Couldn't fetch old Attribute: '%s'", a)
		}
		oldAsv[i] = old
	}
	// Then map over each row in the original.
	// NOTE(review): newRowBuf is reused between rows, so mapFunc must
	// not retain the slices it receives beyond a single invocation.
	// NOTE(review): unlike Get, this calls filter.Transform even for
	// pass-through Attributes — confirm Transform is an identity there.
	newRowBuf := make([][]byte, len(asv))
	return l.src.MapOverRows(oldAsv, func(oldRow [][]byte, oldRowNo int) (bool, error) {
		for i, b := range oldRow {
			newField := l.filter.Transform(oldAsv[i].attr, asv[i].attr, b)
			newRowBuf[i] = newField
		}
		return mapFunc(newRowBuf, oldRowNo)
	})
}
// RowString returns a space-separated string representation of a given
// row after filtering, using each Attribute's human-readable form.
func (l *LazilyFilteredInstances) RowString(row int) string {
	var buffer bytes.Buffer
	as := ResolveAllAttributes(l) // Retrieve all Attribute data
	first := true                 // Decide whether to prefix
	for _, a := range as {
		prefix := " " // What to print before value
		if first {
			first = false // Don't print space on first value
			prefix = ""
		}
		val := l.Get(a, row) // Retrieve filtered value
		buffer.WriteString(fmt.Sprintf("%s%s", prefix, a.attr.GetStringFromSysVal(val)))
	}
	return buffer.String() // Return the result
}
// Size returns the number of Attributes and rows of the underlying
// FixedDataGrid (filtering never changes either dimension).
func (l *LazilyFilteredInstances) Size() (int, int) {
	return l.src.Size()
}
// String returns a human-readable summary of this FixedDataGrid
// after filtering: the filter description, the Attribute list
// (class Attributes starred), and up to the first 5 rows of data.
func (l *LazilyFilteredInstances) String() string {
	var buffer bytes.Buffer

	// Decide on rows to print
	_, rows := l.Size()
	maxRows := 5
	if rows < maxRows {
		maxRows = rows
	}

	// Get all Attribute information
	as := ResolveAllAttributes(l)

	// Print header
	buffer.WriteString("Lazily filtered instances using ")
	buffer.WriteString(fmt.Sprintf("%s\n", l.filter))
	buffer.WriteString(fmt.Sprintf("Attributes: \n"))

	for _, a := range as {
		prefix := "\t"
		if l.classAttrs[a.attr] {
			prefix = "*\t" // Star marks a class Attribute
		}
		buffer.WriteString(fmt.Sprintf("%s%s\n", prefix, a.attr))
	}

	buffer.WriteString("\nData:\n")
	for i := 0; i < maxRows; i++ {
		buffer.WriteString("\t")
		for _, a := range as {
			val := l.Get(a, i)
			buffer.WriteString(fmt.Sprintf("%s ", a.attr.GetStringFromSysVal(val)))
		}
		buffer.WriteString("\n")
	}

	return buffer.String()
}

23
base/filters.go Normal file
View File

@ -0,0 +1,23 @@
package base
// FilteredAttribute represents a mapping from the Attribute
// generated by a Filter (New) back to the original Attribute
// it was derived from (Old).
type FilteredAttribute struct {
	Old Attribute // Attribute as it appears in the source grid
	New Attribute // Replacement Attribute produced by the Filter
}
// Filter transforms the byte sequences stored in DataGrid
// implementations, e.g. discretising or binarising values.
// Implementations are trained first, then queried for their
// Attribute mappings and applied via Transform.
type Filter interface {
	// AddAttribute adds an Attribute for this filter to process.
	AddAttribute(Attribute) error
	// GetAttributesAfterFiltering maps old to new Attributes.
	GetAttributesAfterFiltering() []FilteredAttribute
	// String gets a human-readable description for printing.
	String() string
	// Transform accepts an old Attribute, the new one, and the raw
	// value; it returns the transformed byte sequence.
	Transform(Attribute, Attribute, []byte) []byte
	// Train builds the filter from the Attributes added so far.
	Train() error
}

103
base/float.go Normal file
View File

@ -0,0 +1,103 @@
package base
import (
"fmt"
"strconv"
)
// FloatAttribute is an implementation which stores floating point
// representations of numbers.
type FloatAttribute struct {
	Name      string // Human-readable column name
	Precision int    // Decimal places used when formatting values
}
// NewFloatAttribute returns a new FloatAttribute with an empty name
// and a default precision of 2 decimal places.
func NewFloatAttribute() *FloatAttribute {
	return &FloatAttribute{Name: "", Precision: 2}
}
// Compatable checks whether this FloatAttribute can be ponded with another
// Attribute (checks if they're both FloatAttributes).
//
// NOTE(review): "Compatable" is a typo for "Compatible", but the name
// is exported (and presumably part of an interface), so renaming it
// would break callers.
func (Attr *FloatAttribute) Compatable(other Attribute) bool {
	_, ok := other.(*FloatAttribute)
	return ok
}
// Equals tests a FloatAttribute for equality with another Attribute.
//
// Returns false if the other Attribute has a different name
// or if the other Attribute is not a FloatAttribute.
func (Attr *FloatAttribute) Equals(other Attribute) bool {
	// A different concrete type can never be equal.
	if _, ok := other.(*FloatAttribute); !ok {
		return false
	}
	return Attr.GetName() == other.GetName()
}
// GetName returns this FloatAttribute's human-readable name.
func (Attr *FloatAttribute) GetName() string {
	return Attr.Name
}
// SetName sets this FloatAttribute's human-readable name.
func (Attr *FloatAttribute) SetName(name string) {
	Attr.Name = name
}
// GetType returns Float64Type, the general characteristic of this
// Attribute (avoids the overhead of a type assertion).
func (Attr *FloatAttribute) GetType() int {
	return Float64Type
}
// String returns a human-readable summary of this Attribute,
// e.g. "FloatAttribute(Sepal Width)".
func (Attr *FloatAttribute) String() string {
	return fmt.Sprintf("FloatAttribute(%s)", Attr.Name)
}
// CheckSysValFromString confirms whether a given rawVal can
// be converted into a valid system (byte) representation. If it
// can't, the returned slice is nil and the parse error is returned.
func (Attr *FloatAttribute) CheckSysValFromString(rawVal string) ([]byte, error) {
	f, err := strconv.ParseFloat(rawVal, 64)
	if err != nil {
		return nil, err
	}
	// Pack the parsed float64 into its byte representation.
	ret := PackFloatToBytes(f)
	return ret, nil
}
// GetSysValFromString parses the given rawVal string to a float64
// and returns its packed byte representation.
//
// IMPORTANT: This function panic()s if rawVal is not a valid float.
// Use CheckSysValFromString to confirm first.
func (Attr *FloatAttribute) GetSysValFromString(rawVal string) []byte {
	f, err := Attr.CheckSysValFromString(rawVal)
	if err != nil {
		panic(err)
	}
	return f
}
// GetFloatFromSysVal converts a given system (byte) value back
// into the float64 it encodes.
func (Attr *FloatAttribute) GetFloatFromSysVal(rawVal []byte) float64 {
	return UnpackBytesToFloat(rawVal)
}
// GetStringFromSysVal converts a given system value to a string,
// formatted with this Attribute's configured Precision (decimal
// places, 2 by default).
func (Attr *FloatAttribute) GetStringFromSysVal(rawVal []byte) string {
	// %.*f takes the precision as an argument, avoiding the
	// two-step format-string construction.
	return fmt.Sprintf("%.*f", Attr.Precision, UnpackBytesToFloat(rawVal))
}

View File

@ -1,561 +0,0 @@
package base
import (
	"bytes"
	"encoding/binary"
	"errors"
	"fmt"
	"math"
	"math/rand"

	"github.com/gonum/matrix/mat64"
)
// SortDirection specifies the direction used by Instances.Sort.
type SortDirection int

const (
	// Descending says that Instances should be sorted high to low.
	Descending SortDirection = 1
	// Ascending states that Instances should be sorted low to high.
	Ascending SortDirection = 2
)

// highBit is the float64 sign-bit mask used when radix-sorting
// IEEE-754 encodings as unsigned byte strings.
const highBit int64 = -1 << 63
// Instances represents a grid of numbers (typed by Attributes)
// stored internally in mat.DenseMatrix as float64's.
// See docs/instances.md for more information.
type Instances struct {
	storage    *mat64.Dense // Backing matrix of system (float64) values
	attributes []Attribute  // Column types, one per matrix column
	Rows       int          // Number of rows in storage
	Cols       int          // Number of columns (== len(attributes))
	ClassIndex int          // Column index of the class Attribute
}
// xorFloatOp flips the sign bit of item's IEEE-754 representation and
// returns the resulting float64. This makes the raw encodings of
// float64 values order correctly when compared as unsigned byte
// strings in the radix sort below; applying it twice restores the
// original value.
//
// Rewritten with math.Float64bits/Float64frombits: the old version
// round-tripped the value through a bytes.Buffer with encoding/binary
// (ignoring every error return), allocating on each call for what is
// a pure bit manipulation.
func xorFloatOp(item float64) float64 {
	return math.Float64frombits(math.Float64bits(item) ^ (1 << 63))
}
// printFloatByteArr prints the float64 values encoded in arr to
// standard output, undoing the sign-bit flip applied by xorFloatOp.
// Debugging helper only; errors from binary.Read are ignored.
func printFloatByteArr(arr [][]byte) {
	buf := bytes.NewBuffer(nil)
	var f float64
	for _, b := range arr {
		buf.Write(b)
		binary.Read(buf, binary.LittleEndian, &f)
		f = xorFloatOp(f)
		fmt.Println(f)
	}
}
// Sort does an in-place radix sort of Instances, using SortDirection
// direction (Ascending or Descending) with attrs as a slice of Attribute
// indices that you want to sort by.
//
// IMPORTANT: Radix sort is not stable, so ordering outside
// the attributes used for sorting is arbitrary.
func (inst *Instances) Sort(direction SortDirection, attrs []int) {
	// Create a buffer
	buf := bytes.NewBuffer(nil)
	ds := make([][]byte, inst.Rows) // Encoded sort key per row
	rs := make([]int, inst.Rows)    // Original row number per key
	for i := 0; i < inst.Rows; i++ {
		byteBuf := make([]byte, 8*len(attrs))
		for _, a := range attrs {
			x := inst.storage.At(i, a)
			// xorFloatOp flips the sign bit so the encodings order
			// correctly when compared as unsigned bytes.
			binary.Write(buf, binary.LittleEndian, xorFloatOp(x))
		}
		buf.Read(byteBuf)
		ds[i] = byteBuf
		rs[i] = i
	}
	// LSD radix sort: one pass per key byte, bucketing keys (and
	// their row numbers) into 256 bins, then concatenating the bins.
	valueBins := make([][][]byte, 256)
	rowBins := make([][]int, 256)
	for i := 0; i < 8*len(attrs); i++ {
		for j := 0; j < len(ds); j++ {
			// Address each row value by it's ith byte
			b := ds[j]
			valueBins[b[i]] = append(valueBins[b[i]], b)
			rowBins[b[i]] = append(rowBins[b[i]], rs[j])
		}
		j := 0
		for k := 0; k < 256; k++ {
			bs := valueBins[k]
			rc := rowBins[k]
			copy(ds[j:], bs)
			copy(rs[j:], rc)
			j += len(bs)
			// Reset the bins but keep their capacity for the next pass.
			valueBins[k] = bs[:0]
			rowBins[k] = rc[:0]
		}
	}
	// NOTE(review): this loop decodes each key and discards the
	// result — it looks like leftover debugging code.
	for _, b := range ds {
		var v float64
		buf.Write(b)
		binary.Read(buf, binary.LittleEndian, &v)
	}
	// Apply the permutation in rs to the storage by following its
	// cycles, swapping rows as we go.
	done := make([]bool, inst.Rows)
	for index := range rs {
		if done[index] {
			continue
		}
		j := index
		for {
			done[j] = true
			if rs[j] != index {
				inst.swapRows(j, rs[j])
				j = rs[j]
			} else {
				break
			}
		}
	}
	if direction == Descending {
		// Reverse the matrix to convert ascending into descending order.
		for i, j := 0, inst.Rows-1; i < j; i, j = i+1, j-1 {
			inst.swapRows(i, j)
		}
	}
}
// NewInstances returns a preallocated Instances structure
// with some helpful values pre-filled.
func NewInstances(attrs []Attribute, rows int) *Instances {
	backing := make([]float64, rows*len(attrs))
	return NewInstancesFromRaw(attrs, rows, backing)
}
// CheckNewInstancesFromRaw checks whether a call to NewInstancesFromRaw
// is likely to produce an error-free result: the data slice must contain
// exactly rows * len(attrs) values.
func CheckNewInstancesFromRaw(attrs []Attribute, rows int, data []float64) error {
	size := rows * len(attrs)
	switch {
	case size < len(data):
		return errors.New("base: data length is larger than the rows * attribute space")
	case size > len(data):
		return errors.New("base: data is smaller than the rows * attribute space")
	}
	return nil
}
// NewInstancesFromRaw wraps a slice of float64 numbers in a
// mat64.Dense structure, reshaping it with the given number of rows
// and representing it with the given attrs (Attribute slice)
//
// IMPORTANT: if the |attrs| * |rows| value doesn't equal len(data)
// then panic()s may occur. Use CheckNewInstancesFromRaw to confirm.
func NewInstancesFromRaw(attrs []Attribute, rows int, data []float64) *Instances {
	rawStorage := mat64.NewDense(rows, len(attrs), data)
	return NewInstancesFromDense(attrs, rows, rawStorage)
}
// NewInstancesFromDense creates a set of Instances from a mat64.Dense
// matrix. The class Attribute defaults to the last column.
func NewInstancesFromDense(attrs []Attribute, rows int, mat *mat64.Dense) *Instances {
	return &Instances{mat, attrs, rows, len(attrs), len(attrs) - 1}
}
// InstancesTrainTestSplit takes a given Instances (src) and a train-test fraction
// (prop) and returns an array of two new Instances, one containing approximately
// that fraction and the other containing what's left.
//
// NOTE(review): rows where rand.Intn(101) > 100*prop go to the TRAINING
// set, so prop is effectively the test-set fraction — confirm against
// callers. Also note src is shuffled in place as a side effect.
//
// IMPORTANT: this function is only meaningful when prop is between 0.0 and 1.0.
// Using any other values may result in odd behaviour.
func InstancesTrainTestSplit(src *Instances, prop float64) (*Instances, *Instances) {
	trainingRows := make([]int, 0)
	testingRows := make([]int, 0)
	numAttrs := len(src.attributes)
	src.Shuffle()
	// Assign each row to one of the two partitions at random.
	for i := 0; i < src.Rows; i++ {
		trainOrTest := rand.Intn(101)
		if trainOrTest > int(100*prop) {
			trainingRows = append(trainingRows, i)
		} else {
			testingRows = append(testingRows, i)
		}
	}

	// Copy the selected rows into two freshly-allocated matrices.
	rawTrainMatrix := mat64.NewDense(len(trainingRows), numAttrs, make([]float64, len(trainingRows)*numAttrs))
	rawTestMatrix := mat64.NewDense(len(testingRows), numAttrs, make([]float64, len(testingRows)*numAttrs))

	for i, row := range trainingRows {
		rowDat := src.storage.RowView(row)
		rawTrainMatrix.SetRow(i, rowDat)
	}
	for i, row := range testingRows {
		rowDat := src.storage.RowView(row)
		rawTestMatrix.SetRow(i, rowDat)
	}

	trainingRet := NewInstancesFromDense(src.attributes, len(trainingRows), rawTrainMatrix)
	testRet := NewInstancesFromDense(src.attributes, len(testingRows), rawTestMatrix)
	return trainingRet, testRet
}
// CountAttrValues returns the distribution of values of a given
// Attribute, keyed on each value's human-readable string form.
//
// IMPORTANT: calls panic() if the attribute index of a cannot be
// determined. Call GetAttrIndex(a) and check for a -1 return value.
func (inst *Instances) CountAttrValues(a Attribute) map[string]int {
	ret := make(map[string]int)
	attrIndex := inst.GetAttrIndex(a)
	if attrIndex == -1 {
		panic("Invalid attribute")
	}
	for i := 0; i < inst.Rows; i++ {
		sysVal := inst.Get(i, attrIndex)
		stringVal := a.GetStringFromSysVal(sysVal)
		ret[stringVal]++
	}
	return ret
}
// CountClassValues returns the class distribution of this
// Instances set (CountAttrValues applied to the class Attribute).
func (inst *Instances) CountClassValues() map[string]int {
	a := inst.GetAttr(inst.ClassIndex)
	return inst.CountAttrValues(a)
}
// DecomposeOnAttributeValues divides the instance set depending on the
// value of a given Attribute, constructs child instances, and returns
// them in a map keyed on the string value of that Attribute. The
// decomposed Attribute itself is dropped from the children.
//
// IMPORTANT: calls panic() if the attribute index of at cannot be determined.
// Use GetAttrIndex(at) and check for a -1 return value.
func (inst *Instances) DecomposeOnAttributeValues(at Attribute) map[string]*Instances {
	// Find the attribute we're decomposing on
	attrIndex := inst.GetAttrIndex(at)
	if attrIndex == -1 {
		panic("Invalid attribute index")
	}
	// Construct the new attribute set, omitting the split Attribute
	newAttrs := make([]Attribute, 0)
	for i := range inst.attributes {
		a := inst.attributes[i]
		if a.Equals(at) {
			continue
		}
		newAttrs = append(newAttrs, a)
	}
	// Create the return map, several counting maps
	ret := make(map[string]*Instances)
	counts := inst.CountAttrValues(at) // So we know what to allocate
	rows := make(map[string]int)       // Next free row per child
	for k := range counts {
		tmp := NewInstances(newAttrs, counts[k])
		ret[k] = tmp
	}
	// Copy each row (minus the split column) into its child set
	for i := 0; i < inst.Rows; i++ {
		newAttrCounter := 0
		classVar := at.GetStringFromSysVal(inst.Get(i, attrIndex))
		dest := ret[classVar]
		destRow := rows[classVar]
		for j := 0; j < inst.Cols; j++ {
			a := inst.attributes[j]
			if a.Equals(at) {
				continue
			}
			dest.Set(destRow, newAttrCounter, inst.Get(i, j))
			newAttrCounter++
		}
		rows[classVar]++
	}
	return ret
}
// GetClassDistributionAfterSplit computes, for each value of the split
// Attribute at, the distribution of class values amongst the rows with
// that split value. The outer map is keyed on the split value's string
// form, the inner map on the class value's string form.
//
// IMPORTANT: calls panic() if the attribute index of at cannot be
// determined. Use GetAttrIndex(at) and check for a -1 return value.
func (inst *Instances) GetClassDistributionAfterSplit(at Attribute) map[string]map[string]int {
	ret := make(map[string]map[string]int)

	// Find the attribute we're decomposing on
	attrIndex := inst.GetAttrIndex(at)
	if attrIndex == -1 {
		panic("Invalid attribute index")
	}

	// Get the class index
	classAttr := inst.GetAttr(inst.ClassIndex)

	for i := 0; i < inst.Rows; i++ {
		splitVar := at.GetStringFromSysVal(inst.Get(i, attrIndex))
		classVar := classAttr.GetStringFromSysVal(inst.Get(i, inst.ClassIndex))
		if _, ok := ret[splitVar]; !ok {
			// Allocate the inner map lazily, then retry this row:
			// the i-- undoes the loop increment.
			ret[splitVar] = make(map[string]int)
			i--
			continue
		}
		ret[splitVar][classVar]++
	}

	return ret
}
// Get returns the system representation (float64) of the value
// stored at the given row and col coordinate.
func (inst *Instances) Get(row int, col int) float64 {
	return inst.storage.At(row, col)
}
// Set sets the system representation (float64) to val at the
// given row and column coordinate.
func (inst *Instances) Set(row int, col int, val float64) {
	inst.storage.Set(row, col, val)
}
// GetRowVector returns a row of system representation
// values at the given row index. The slice aliases the backing
// matrix, so mutating it mutates the Instances.
func (inst *Instances) GetRowVector(row int) []float64 {
	return inst.storage.RowView(row)
}
// GetRowVectorWithoutClass returns a row of system representation
// values at the given row index, excluding the class attribute.
// The row is copied first, so the result does not alias storage.
func (inst *Instances) GetRowVectorWithoutClass(row int) []float64 {
	rawRow := make([]float64, inst.Cols)
	copy(rawRow, inst.GetRowVector(row))
	// Splice the class column out of the copy.
	return append(rawRow[0:inst.ClassIndex], rawRow[inst.ClassIndex+1:inst.Cols]...)
}
// GetClass returns the string representation of the given
// row's class, as determined by the Attribute at the ClassIndex
// position from GetAttr.
func (inst *Instances) GetClass(row int) string {
	attr := inst.GetAttr(inst.ClassIndex)
	val := inst.Get(row, inst.ClassIndex)
	return attr.GetStringFromSysVal(val)
}
// GetClassDistribution returns a map containing the count of each
// class type (indexed by the class' string representation).
func (inst *Instances) GetClassDistribution() map[string]int {
	counts := make(map[string]int)
	classAttr := inst.GetAttr(inst.ClassIndex)
	for row := 0; row < inst.Rows; row++ {
		label := classAttr.GetStringFromSysVal(inst.Get(row, inst.ClassIndex))
		counts[label]++
	}
	return counts
}
// GetClassAttrPtr returns a pointer to (a copy of) the class Attribute.
func (inst *Instances) GetClassAttrPtr() *Attribute {
	attr := inst.GetAttr(inst.ClassIndex)
	return &attr
}
// GetClassAttr returns the Attribute at the class index.
func (inst *Instances) GetClassAttr() Attribute {
	return inst.GetAttr(inst.ClassIndex)
}
//
// Attribute functions
//

// GetAttributeCount returns the number of attributes represented.
func (inst *Instances) GetAttributeCount() int {
	// Return the number of attributes attached to this Instance set
	return len(inst.attributes)
}
// SetAttrStr sets the system-representation value of row in column attr
// to value val, implicitly converting the string to system-representation
// via the appropriate Attribute function.
//
// NOTE(review): GetSysValFromString panics on unparseable input for
// some Attribute implementations — validate beforehand if needed.
func (inst *Instances) SetAttrStr(row int, attr int, val string) {
	// Set an attribute on a particular row from a string value
	a := inst.attributes[attr]
	sysVal := a.GetSysValFromString(val)
	inst.storage.Set(row, attr, sysVal)
}
// GetAttrStr returns a human-readable string value stored in column `attr'
// and row `row', as determined by the appropriate Attribute function.
func (inst *Instances) GetAttrStr(row int, attr int) string {
	// Get a human-readable value from a particular row
	a := inst.attributes[attr]
	usrVal := a.GetStringFromSysVal(inst.Get(row, attr))
	return usrVal
}
// GetAttr returns information about an attribute at given index
// in the attributes slice.
func (inst *Instances) GetAttr(attrIndex int) Attribute {
	// Return a copy of an attribute attached to this Instance set
	return inst.attributes[attrIndex]
}
// GetAttrIndex returns the offset of a given Attribute `of` within
// the attributes slice, or -1 if no Attribute matches.
func (inst *Instances) GetAttrIndex(of Attribute) int {
	for idx := range inst.attributes {
		if inst.attributes[idx].Equals(of) {
			return idx
		}
	}
	return -1
}
// ReplaceAttr overwrites the attribute at `index' with `a'.
//
// IMPORTANT: existing stored values are NOT converted to the new
// Attribute's representation.
func (inst *Instances) ReplaceAttr(index int, a Attribute) {
	// Replace an Attribute at index with another
	// DOESN'T CONVERT ANY EXISTING VALUES
	inst.attributes[index] = a
}
//
// Printing functions
//

// RowStr returns a human-readable, space-separated representation
// of a given row, using each Attribute's string conversion.
func (inst *Instances) RowStr(row int) string {
	// Prints a given row
	var buffer bytes.Buffer
	for j := 0; j < inst.Cols; j++ {
		val := inst.storage.At(row, j)
		a := inst.attributes[j]
		postfix := " "
		if j == inst.Cols-1 {
			postfix = "" // No trailing space after the last column
		}
		buffer.WriteString(fmt.Sprintf("%s%s", a.GetStringFromSysVal(val), postfix))
	}
	return buffer.String()
}
// String returns a human-readable summary of this Instances set:
// dimensions, the Attribute list (class Attribute starred), and up
// to the first 30 rows of data.
func (inst *Instances) String() string {
	var buffer bytes.Buffer

	buffer.WriteString("Instances with ")
	buffer.WriteString(fmt.Sprintf("%d row(s) ", inst.Rows))
	buffer.WriteString(fmt.Sprintf("%d attribute(s)\n", inst.Cols))

	buffer.WriteString(fmt.Sprintf("Attributes: \n"))
	for i, a := range inst.attributes {
		prefix := "\t"
		if i == inst.ClassIndex {
			prefix = "*\t" // Star marks the class Attribute
		}
		buffer.WriteString(fmt.Sprintf("%s%s\n", prefix, a))
	}

	buffer.WriteString("\nData:\n")
	maxRows := 30
	if inst.Rows < maxRows {
		maxRows = inst.Rows
	}
	for i := 0; i < maxRows; i++ {
		buffer.WriteString("\t")
		for j := 0; j < inst.Cols; j++ {
			val := inst.storage.At(i, j)
			a := inst.attributes[j]
			buffer.WriteString(fmt.Sprintf("%s ", a.GetStringFromSysVal(val)))
		}
		buffer.WriteString("\n")
	}

	missingRows := inst.Rows - maxRows
	if missingRows != 0 {
		buffer.WriteString(fmt.Sprintf("\t...\n%d row(s) undisplayed", missingRows))
	} else {
		buffer.WriteString("All rows displayed")
	}

	return buffer.String()
}
// SelectAttributes returns a new instance set containing
// the values from this one with only the Attributes specified.
//
// NOTE(review): if an Attribute in attrs is not present,
// GetAttrIndex returns -1 and the subsequent Get panics —
// callers should ensure every requested Attribute exists.
func (inst *Instances) SelectAttributes(attrs []Attribute) *Instances {
	ret := NewInstances(attrs, inst.Rows)
	attrIndices := make([]int, 0)
	for _, a := range attrs {
		attrIndex := inst.GetAttrIndex(a)
		attrIndices = append(attrIndices, attrIndex)
	}
	for i := 0; i < inst.Rows; i++ {
		for j, a := range attrIndices {
			ret.Set(i, j, inst.Get(i, a))
		}
	}
	return ret
}
// GeneratePredictionVector generates a new set of Instances
// with the same number of rows, but only this Instance set's
// class Attribute.
func (inst *Instances) GeneratePredictionVector() *Instances {
	return NewInstances([]Attribute{inst.GetClassAttr()}, inst.Rows)
}
// Shuffle randomizes the row order in place (Fisher-Yates:
// each row i is swapped with a uniformly chosen row in [0, i]).
func (inst *Instances) Shuffle() {
	for i := 0; i < inst.Rows; i++ {
		j := rand.Intn(i + 1)
		inst.swapRows(i, j)
	}
}
// SampleWithReplacement returns a new set of Instances of size `size'
// containing random rows from this set of Instances.
//
// IMPORTANT: There's a high chance of seeing duplicate rows
// whenever size is close to the row count.
func (inst *Instances) SampleWithReplacement(size int) *Instances {
	ret := NewInstances(inst.attributes, size)
	for i := 0; i < size; i++ {
		// Pick a source row uniformly at random, independently each time.
		srcRow := rand.Intn(inst.Rows)
		for j := 0; j < inst.Cols; j++ {
			ret.Set(i, j, inst.Get(srcRow, j))
		}
	}
	return ret
}
// Equal checks whether a given Instance set is exactly the same
// as another: same size and same values (as determined by the Attributes)
//
// IMPORTANT: does not explicitly check if the Attributes are considered equal.
func (inst *Instances) Equal(other *Instances) bool {
	if inst.Rows != other.Rows || inst.Cols != other.Cols {
		return false
	}
	// Compare the human-readable form of every cell.
	for row := 0; row < inst.Rows; row++ {
		for col := 0; col < inst.Cols; col++ {
			if inst.GetAttrStr(row, col) != other.GetAttrStr(row, col) {
				return false
			}
		}
	}
	return true
}
// swapRows exchanges the contents of rows r1 and r2 in place.
// Both rows are copied into temporary buffers first because
// RowView returns slices that alias the backing matrix.
func (inst *Instances) swapRows(r1 int, r2 int) {
	row1buf := make([]float64, inst.Cols)
	row2buf := make([]float64, inst.Cols)
	row1 := inst.storage.RowView(r1)
	row2 := inst.storage.RowView(r2)
	copy(row1buf, row1)
	copy(row2buf, row2)
	inst.storage.SetRow(r1, row2buf)
	inst.storage.SetRow(r2, row1buf)
}

87
base/lazy_sort_test.go Normal file
View File

@ -0,0 +1,87 @@
package base
import (
"fmt"
"testing"
)
// TestLazySortDesc checks that LazySort in Descending order produces
// output matching a pre-sorted reference CSV file.
func TestLazySortDesc(testEnv *testing.T) {
	inst1, err := ParseCSVToInstances("../examples/datasets/iris_headers.csv", true)
	if err != nil {
		testEnv.Error(err)
		return
	}

	inst2, err := ParseCSVToInstances("../examples/datasets/iris_sorted_desc.csv", true)
	if err != nil {
		testEnv.Error(err)
		return
	}

	as1 := ResolveAllAttributes(inst1)
	as2 := ResolveAllAttributes(inst2)

	// Sanity-check the fixtures before testing the sort itself.
	if isSortedDesc(inst1, as1[0]) {
		testEnv.Error("Can't test descending sort order")
	}
	if !isSortedDesc(inst2, as2[0]) {
		testEnv.Error("Reference data not sorted in descending order!")
	}

	// Sort on all Attributes except the last (the class column).
	inst, err := LazySort(inst1, Descending, as1[0:len(as1)-1])
	if err != nil {
		testEnv.Error(err)
	}
	if !isSortedDesc(inst, as1[0]) {
		testEnv.Error("Instances are not sorted in descending order")
		testEnv.Error(inst1)
	}
	if !inst2.Equal(inst) {
		testEnv.Error("Instances don't match")
		testEnv.Error(inst)
		testEnv.Error(inst2)
	}
}
// TestLazySortAsc checks that LazySort in Ascending order produces
// output matching a pre-sorted reference CSV file, and that
// RowString formats the first sorted row as expected.
//
// Fixes: the old version called ResolveAllAttributes(inst) and
// isSortedAsc(inst, ...) BEFORE checking the ParseCSVToInstances
// error, so a parse failure caused a nil dereference instead of a
// test failure; and the final row-string mismatch called panic()
// rather than reporting through the testing framework.
func TestLazySortAsc(testEnv *testing.T) {
	inst, err := ParseCSVToInstances("../examples/datasets/iris_headers.csv", true)
	if err != nil {
		testEnv.Error(err)
		return
	}
	as1 := ResolveAllAttributes(inst)
	if isSortedAsc(inst, as1[0]) {
		testEnv.Error("Can't test ascending sort on something ascending already")
	}

	insts, err := LazySort(inst, Ascending, as1)
	if err != nil {
		testEnv.Error(err)
		return
	}
	if !isSortedAsc(insts, as1[0]) {
		testEnv.Error("Instances are not sorted in ascending order")
		testEnv.Error(insts)
	}

	inst2, err := ParseCSVToInstances("../examples/datasets/iris_sorted_asc.csv", true)
	if err != nil {
		testEnv.Error(err)
		return
	}
	as2 := ResolveAllAttributes(inst2)
	if !isSortedAsc(inst2, as2[0]) {
		testEnv.Error("This file should be sorted in ascending order")
	}

	if !inst2.Equal(insts) {
		testEnv.Error("Instances don't match")
		testEnv.Error(inst)
		testEnv.Error(inst2)
	}

	rowStr := insts.RowString(0)
	ref := "4.30 3.00 1.10 0.10 Iris-setosa"
	if rowStr != ref {
		// Report through the framework so remaining tests still run.
		testEnv.Errorf("'%s' != '%s'", rowStr, ref)
	}
}

122
base/pond.go Normal file
View File

@ -0,0 +1,122 @@
package base
import (
"bytes"
"fmt"
)
// Pond contains a particular number of rows of
// a particular number of Attributes, all of a given type.
type Pond struct {
	threadNo   uint32      // EDF thread this Pond's storage belongs to
	parent     DataGrid    // Grid this Pond stores rows for
	attributes []Attribute // Attributes (columns) held, all of one type
	size       int         // Bytes per Attribute value
	alloc      [][]byte    // Backing allocations; rows never span blocks
	maxRow     int         // Highest row index written so far, plus one
}
// String returns a human-readable summary of this Pond. When the Pond
// is backed by exactly one allocation, a prefix of the raw storage is
// included for debugging.
//
// Fixes: the old version took the detailed branch whenever
// len(p.alloc) <= 1 and indexed p.alloc[0][0:60] unconditionally,
// which panics when there are no allocations at all or when the first
// allocation is shorter than 60 bytes.
func (p *Pond) String() string {
	if len(p.alloc) == 1 {
		// Cap the debug prefix at the block's actual length.
		n := len(p.alloc[0])
		if n > 60 {
			n = 60
		}
		return fmt.Sprintf("Pond(%d attributes\n thread: %d\n size: %d\n %d \n)", len(p.attributes), p.threadNo, p.size, p.alloc[0][0:n])
	}
	return fmt.Sprintf("Pond(%d attributes\n thread: %d\n size: %d\n)", len(p.attributes), p.threadNo, p.size)
}
// PondStorageRef is a reference to a particular set
// of allocated rows within a Pond.
type PondStorageRef struct {
	Storage []byte // Raw backing bytes for these rows
	Rows    int    // Number of whole rows the Storage can hold
}
// RowSize returns the size of each row in bytes
// (number of Attributes times the per-value size).
func (p *Pond) RowSize() int {
	return len(p.attributes) * p.size
}
// Attributes returns a slice of Attributes in this Pond.
// The slice aliases internal state and must not be modified.
func (p *Pond) Attributes() []Attribute {
	return p.attributes
}
// Storage returns a slice of PondStorageRefs which can
// be used to access the memory in this pond, one per
// backing allocation.
func (p *Pond) Storage() []PondStorageRef {
	ret := make([]PondStorageRef, len(p.alloc))
	rowSize := p.RowSize()
	for i, b := range p.alloc {
		// Integer division: partial trailing rows are not counted.
		ret[i] = PondStorageRef{b, len(b) / rowSize}
	}
	return ret
}
// resolveBlock locates the value at (col, row) within the Pond's
// backing allocations, returning the allocation index and the byte
// offset within that allocation. Rows never span allocations, so any
// trailing partial-row space in a block is skipped.
//
// Panics if there are no allocations, or not enough of them to reach
// the requested coordinate.
func (p *Pond) resolveBlock(col int, row int) (int, int) {
	if len(p.alloc) == 0 {
		panic("No blocks to resolve")
	}

	// Find where in the pond the byte is
	byteOffset := row*p.RowSize() + col*p.size
	curOffset := 0
	curBlock := 0
	blockOffset := 0
	for {
		if curBlock >= len(p.alloc) {
			panic("Don't have enough blocks to fulfill")
		}

		// Rows are not allowed to span blocks
		blockAdd := len(p.alloc[curBlock])
		blockAdd -= blockAdd % p.RowSize()

		// Case 1: we need to skip this allocation
		// NOTE(review): when byteOffset lands exactly on a block
		// boundary (curOffset+blockAdd == byteOffset) this resolves
		// to an offset one row past the block's usable area — confirm
		// whether the comparison should be <=.
		if curOffset+blockAdd < byteOffset {
			curOffset += blockAdd
			curBlock++
		} else {
			blockOffset = byteOffset - curOffset
			break
		}
	}

	return curBlock, blockOffset
}
// set writes val (exactly p.size bytes) into the cell at (col, row),
// updating maxRow so the Pond knows how many rows are in use.
//
// Panics if val has the wrong length or if the copy straddles the end
// of an allocation (which indicates an EDF allocation bug).
func (p *Pond) set(col int, row int, val []byte) {
	// Double-check the length
	if len(val) != p.size {
		panic(fmt.Sprintf("Tried to call set() with %d bytes, should be %d", len(val), p.size))
	}

	// Find where in the pond the byte is
	curBlock, blockOffset := p.resolveBlock(col, row)

	// Copy the value in
	copied := copy(p.alloc[curBlock][blockOffset:], val)
	if copied != p.size {
		panic(fmt.Sprintf("set() terminated by only copying %d bytes into the current block (should be %d). Check EDF allocation", copied, p.size))
	}

	// Track the row count (maxRow is the highest row index + 1).
	row++
	if row > p.maxRow {
		p.maxRow = row
	}
}
// get returns the p.size bytes stored at (col, row).
// The returned slice aliases the backing allocation.
func (p *Pond) get(col int, row int) []byte {
	curBlock, blockOffset := p.resolveBlock(col, row)
	return p.alloc[curBlock][blockOffset : blockOffset+p.size]
}
// appendToRowBuf writes the human-readable form of every Attribute value
// on the given row into buffer, separated by single spaces.
func (p *Pond) appendToRowBuf(row int, buffer *bytes.Buffer) {
	last := len(p.attributes) - 1
	for i, attr := range p.attributes {
		sep := " "
		if i == last {
			sep = ""
		}
		buffer.WriteString(attr.GetStringFromSysVal(p.get(i, row)))
		buffer.WriteString(sep)
	}
}

168
base/sort.go Normal file
View File

@ -0,0 +1,168 @@
package base
import (
"bytes"
"encoding/binary"
)
// sortXorOp returns a copy of b with the most significant bit of the
// first byte inverted; the input slice is left unmodified.
func sortXorOp(b []byte) []byte {
	out := append([]byte(nil), b...)
	out[0] ^= 0x80
	return out
}
// sortSpec records a single row swap; applying each spec's
// (r1, r2) swap in sequence sorts the underlying rows.
type sortSpec struct {
r1 int
r2 int
}
// createSortSpec returns the sequence of row swaps which, applied in
// order, arranges inst ascending on the given AttributeSpecs (the last
// spec in attrsArg becomes the most significant key after the reversal
// below).
func createSortSpec(inst FixedDataGrid, attrsArg []AttributeSpec) []sortSpec {
attrs := make([]AttributeSpec, len(attrsArg))
copy(attrs, attrsArg)
// Reverse attribute order to be more intuitive
for i, j := 0, len(attrs)-1; i < j; i, j = i+1, j-1 {
attrs[i], attrs[j] = attrs[j], attrs[i]
}
_, rows := inst.Size()
ret := make([]sortSpec, 0)
// Create a buffer
buf := bytes.NewBuffer(nil)
// ds[i] holds the concatenated key bytes for row i; rs[i] tracks which
// original row currently occupies position i during the sort passes.
ds := make([][]byte, rows)
rs := make([]int, rows)
rowSize := 0
inst.MapOverRows(attrs, func(row [][]byte, rowNo int) (bool, error) {
if rowSize == 0 {
// Allocate a row buffer: key width is the sum of all value widths
for _, r := range row {
rowSize += len(r)
}
}
byteBuf := make([]byte, rowSize)
for i, r := range row {
if i == 0 {
// Flip the top bit of the first key so its sign bit orders
// correctly under unsigned byte comparison.
// NOTE(review): only the first value gets this treatment —
// confirm the remaining keys are unsigned-comparable.
binary.Write(buf, binary.LittleEndian, sortXorOp(r))
} else {
binary.Write(buf, binary.LittleEndian, r)
}
}
buf.Read(byteBuf)
ds[rowNo] = byteBuf
rs[rowNo] = rowNo
return true, nil
})
// Sort values: byte-at-a-time bucket (radix) passes over the keys.
// NOTE(review): passes run from byte 0 upward, so the byte processed
// last dominates the final order — verify the key byte layout matches
// this ordering.
valueBins := make([][][]byte, 256)
rowBins := make([][]int, 256)
for i := 0; i < rowSize; i++ {
for j := 0; j < len(ds); j++ {
// Address each row value by it's ith byte
b := ds[j]
valueBins[b[i]] = append(valueBins[b[i]], b)
rowBins[b[i]] = append(rowBins[b[i]], rs[j])
}
// Gather the 256 buckets back into ds/rs in order, then reset each
// bucket (keeping its capacity) for the next pass
j := 0
for k := 0; k < 256; k++ {
bs := valueBins[k]
rc := rowBins[k]
copy(ds[j:], bs)
copy(rs[j:], rc)
j += len(bs)
valueBins[k] = bs[:0]
rowBins[k] = rc[:0]
}
}
// Decompose the resulting permutation into pairwise swaps by walking
// each cycle exactly once
done := make([]bool, rows)
for index := range rs {
if done[index] {
continue
}
j := index
for {
done[j] = true
if rs[j] != index {
ret = append(ret, sortSpec{j, rs[j]})
j = rs[j]
} else {
break
}
}
}
return ret
}
// Sort does a radix sort of DenseInstances, using SortDirection
// direction (Ascending or Descending) with attrs as a slice of Attribute
// indices that you want to sort by.
//
// IMPORTANT: Radix sort is not stable, so ordering outside
// the attributes used for sorting is arbitrary.
func Sort(inst FixedDataGrid, direction SortDirection, attrs []AttributeSpec) (FixedDataGrid, error) {
	swaps := createSortSpec(inst, attrs)
	dense, ok := inst.(*DenseInstances)
	if !ok {
		panic("Sort is not supported for this yet!")
	}
	// Apply the swap sequence, leaving the rows in ascending order
	for _, s := range swaps {
		dense.swapRows(s.r1, s.r2)
	}
	// For a descending sort, reverse the whole row order afterwards
	if direction == Descending {
		_, rows := inst.Size()
		for i, j := 0, rows-1; i < j; i, j = i+1, j-1 {
			dense.swapRows(i, j)
		}
	}
	return dense, nil
}
// LazySort also does a sort, but returns an InstanceView and doesn't actually
// reorder the rows, just makes it look like they've been reordered
// See also: Sort
func LazySort(inst FixedDataGrid, direction SortDirection, attrs []AttributeSpec) (FixedDataGrid, error) {
	// Compute the swap sequence for an ascending sort
	swaps := createSortSpec(inst, attrs)
	// Start from the identity row ordering
	_, rows := inst.Size()
	order := make([]int, rows)
	for i := range order {
		order[i] = i
	}
	// Apply each swap to the ordering instead of the data
	for _, s := range swaps {
		order[s.r1], order[s.r2] = order[s.r2], order[s.r1]
	}
	// Reverse the ordering for a descending sort
	if direction == Descending {
		for i, j := 0, rows-1; i < j; i, j = i+1, j-1 {
			order[i], order[j] = order[j], order[i]
		}
	}
	// Record only the positions which actually moved
	rowMap := make(map[int]int)
	for pos, src := range order {
		if pos != src {
			rowMap[pos] = src
		}
	}
	return NewInstancesViewFromRows(inst, rowMap), nil
}

View File

@ -2,10 +2,11 @@ package base
import "testing"
func isSortedAsc(inst *Instances, attrIndex int) bool {
func isSortedAsc(inst FixedDataGrid, attr AttributeSpec) bool {
valPrev := 0.0
for i := 0; i < inst.Rows; i++ {
cur := inst.Get(i, attrIndex)
_, rows := inst.Size()
for i := 0; i < rows; i++ {
cur := UnpackBytesToFloat(inst.Get(attr, i))
if i > 0 {
if valPrev > cur {
return false
@ -16,10 +17,11 @@ func isSortedAsc(inst *Instances, attrIndex int) bool {
return true
}
func isSortedDesc(inst *Instances, attrIndex int) bool {
func isSortedDesc(inst FixedDataGrid, attr AttributeSpec) bool {
valPrev := 0.0
for i := 0; i < inst.Rows; i++ {
cur := inst.Get(i, attrIndex)
_, rows := inst.Size()
for i := 0; i < rows; i++ {
cur := UnpackBytesToFloat(inst.Get(attr, i))
if i > 0 {
if valPrev < cur {
return false
@ -42,25 +44,25 @@ func TestSortDesc(testEnv *testing.T) {
return
}
if isSortedDesc(inst1, 0) {
as1 := ResolveAllAttributes(inst1)
as2 := ResolveAllAttributes(inst2)
if isSortedDesc(inst1, as1[0]) {
testEnv.Error("Can't test descending sort order")
}
if !isSortedDesc(inst2, 0) {
if !isSortedDesc(inst2, as2[0]) {
testEnv.Error("Reference data not sorted in descending order!")
}
attrs := make([]int, 4)
attrs[0] = 3
attrs[1] = 2
attrs[2] = 1
attrs[3] = 0
inst1.Sort(Descending, attrs)
if !isSortedDesc(inst1, 0) {
Sort(inst1, Descending, as1[0:len(as1)-1])
if err != nil {
testEnv.Error(err)
}
if !isSortedDesc(inst1, as1[0]) {
testEnv.Error("Instances are not sorted in descending order")
testEnv.Error(inst1)
}
if !inst2.Equal(inst1) {
inst1.storage.Sub(inst1.storage, inst2.storage)
testEnv.Error(inst1.storage)
testEnv.Error("Instances don't match")
testEnv.Error(inst1)
testEnv.Error(inst2)
@ -69,20 +71,16 @@ func TestSortDesc(testEnv *testing.T) {
func TestSortAsc(testEnv *testing.T) {
inst, err := ParseCSVToInstances("../examples/datasets/iris_headers.csv", true)
if isSortedAsc(inst, 0) {
as1 := ResolveAllAttributes(inst)
if isSortedAsc(inst, as1[0]) {
testEnv.Error("Can't test ascending sort on something ascending already")
}
if err != nil {
testEnv.Error(err)
return
}
attrs := make([]int, 4)
attrs[0] = 3
attrs[1] = 2
attrs[2] = 1
attrs[3] = 0
inst.Sort(Ascending, attrs)
if !isSortedAsc(inst, 0) {
Sort(inst, Ascending, as1[0:1])
if !isSortedAsc(inst, as1[0]) {
testEnv.Error("Instances are not sorted in ascending order")
testEnv.Error(inst)
}
@ -92,13 +90,12 @@ func TestSortAsc(testEnv *testing.T) {
testEnv.Error(err)
return
}
if !isSortedAsc(inst2, 0) {
as2 := ResolveAllAttributes(inst2)
if !isSortedAsc(inst2, as2[0]) {
testEnv.Error("This file should be sorted in ascending order")
}
if !inst2.Equal(inst) {
inst.storage.Sub(inst.storage, inst2.storage)
testEnv.Error(inst.storage)
testEnv.Error("Instances don't match")
testEnv.Error(inst)
testEnv.Error(inst2)

25
base/spec.go Normal file
View File

@ -0,0 +1,25 @@
package base
import (
"fmt"
)
// AttributeSpec is a pointer to a particular Attribute
// within a particular Instance structure and encodes position
// and storage information associated with that Attribute.
type AttributeSpec struct {
// pond identifies which storage Pond holds this Attribute.
pond int
// position is the Attribute's offset within that Pond.
position int
// attr is the Attribute definition itself.
attr Attribute
}
// GetAttribute returns the Attribute this specification points at.
func (a *AttributeSpec) GetAttribute() Attribute {
	attr := a.attr
	return attr
}
// String returns a human-readable description of this AttributeSpec.
func (a *AttributeSpec) String() string {
	desc := fmt.Sprintf("AttributeSpec(Attribute: '%s', Pond: %d/%d)", a.attr, a.pond, a.position)
	return desc
}

98
base/util.go Normal file
View File

@ -0,0 +1,98 @@
package base
import (
"bytes"
"encoding/binary"
"fmt"
"math"
"unsafe"
)
// PackU64ToBytesInline fills ret with the little-endian byte
// representation of val. ret must have length at least 8.
func PackU64ToBytesInline(val uint64, ret []byte) {
	// binary.LittleEndian.PutUint64 produces exactly the byte order the
	// previous hand-rolled shift/mask code did (ret[0] = least
	// significant byte) without eight lines of per-byte arithmetic.
	binary.LittleEndian.PutUint64(ret, val)
}
// PackFloatToBytesInline fills ret with the byte values of
// the float64 argument. ret must be at least 8 bytes in size.
func PackFloatToBytesInline(val float64, ret []byte) {
	bits := math.Float64bits(val)
	PackU64ToBytesInline(bits, ret)
}
// PackU64ToBytes allocates and returns an 8-byte slice containing
// the little-endian byte representation of val.
func PackU64ToBytes(val uint64) []byte {
	// Same byte order as the previous hand-rolled shift/mask version
	// (index 0 holds the least significant byte), via the stdlib.
	ret := make([]byte, 8)
	binary.LittleEndian.PutUint64(ret, val)
	return ret
}
// UnpackBytesToU64 converts a little-endian byte slice (at least 8
// bytes long) back into a uint64, inverting PackU64ToBytes.
func UnpackBytesToU64(val []byte) uint64 {
	// The previous implementation reinterpreted the bytes through an
	// unsafe.Pointer, which silently assumed a little-endian host and an
	// architecture tolerant of unaligned loads. Decoding explicitly
	// matches PackU64ToBytes on every platform and drops unsafe.
	return binary.LittleEndian.Uint64(val)
}
// PackFloatToBytes returns a 8-byte slice containing
// the byte values of a float64.
func PackFloatToBytes(val float64) []byte {
	bits := math.Float64bits(val)
	return PackU64ToBytes(bits)
}
// UnpackBytesToFloat converts a little-endian byte slice (at least 8
// bytes long) into the float64 it encodes, inverting PackFloatToBytes.
func UnpackBytesToFloat(val []byte) float64 {
	// Replaces an unsafe.Pointer reinterpretation which assumed a
	// little-endian host and tolerated unaligned loads only on some
	// architectures; this decodes the same bytes portably and stays
	// consistent with PackFloatToBytes.
	return math.Float64frombits(binary.LittleEndian.Uint64(val))
}
// xorFloatOp flips the sign bit of item's IEEE-754 bit pattern
// (XOR with 1<<63), i.e. negates any non-NaN value.
func xorFloatOp(item float64) float64 {
	// The original achieved this by round-tripping the value through a
	// bytes.Buffer with encoding/binary four times; direct bit
	// manipulation is equivalent and allocation-free.
	return math.Float64frombits(math.Float64bits(item) ^ (1 << 63))
}
// printFloatByteArr prints the float64 encoded in each byte slice of
// arr, undoing the sign-bit flip applied by xorFloatOp beforehand.
func printFloatByteArr(arr [][]byte) {
	scratch := bytes.NewBuffer(nil)
	var decoded float64
	for _, raw := range arr {
		scratch.Write(raw)
		binary.Read(scratch, binary.LittleEndian, &decoded)
		decoded = xorFloatOp(decoded)
		fmt.Println(decoded)
	}
}
// byteSeqEqual reports whether a and b have identical contents.
func byteSeqEqual(a, b []byte) bool {
	// bytes.Equal performs the same length check and element-wise
	// comparison as the previous manual loop.
	return bytes.Equal(a, b)
}

148
base/util_attributes.go Normal file
View File

@ -0,0 +1,148 @@
package base
import (
"fmt"
)
// This file contains utility functions relating to Attributes and Attribute specifications.
// NonClassFloatAttributes returns all FloatAttributes which
// aren't designated as a class Attribute.
func NonClassFloatAttributes(d DataGrid) []Attribute {
	classAttrs := d.AllClassAttributes()
	ret := make([]Attribute, 0)
	for _, candidate := range d.AllAttributes() {
		// Only FloatAttributes qualify
		if _, isFloat := candidate.(*FloatAttribute); !isFloat {
			continue
		}
		// Skip anything designated as a class Attribute
		isClass := false
		for _, c := range classAttrs {
			if candidate.Equals(c) {
				isClass = true
				break
			}
		}
		if !isClass {
			ret = append(ret, candidate)
		}
	}
	return ret
}
// NonClassAttributes returns all Attributes which aren't designated as a
// class Attribute.
func NonClassAttributes(d DataGrid) []Attribute {
	classAttrs := d.AllClassAttributes()
	return AttributeDifferenceReferences(d.AllAttributes(), classAttrs)
}
// ResolveAttributes returns AttributeSpecs describing
// all of the Attributes.
func ResolveAttributes(d DataGrid, attrs []Attribute) []AttributeSpec {
	specs := make([]AttributeSpec, len(attrs))
	for i := range attrs {
		s, err := d.GetAttribute(attrs[i])
		if err != nil {
			panic(fmt.Errorf("Error resolving Attribute %s: %s", attrs[i], err))
		}
		specs[i] = s
	}
	return specs
}
// ResolveAllAttributes returns an AttributeSpec for every Attribute in d.
func ResolveAllAttributes(d DataGrid) []AttributeSpec {
	all := d.AllAttributes()
	return ResolveAttributes(d, all)
}
// buildAttrSet converts an Attribute slice into a set keyed by the
// Attribute references themselves.
func buildAttrSet(attrs []Attribute) map[Attribute]bool {
	set := make(map[Attribute]bool)
	for _, attr := range attrs {
		set[attr] = true
	}
	return set
}
// AttributeIntersect returns the intersection of two Attribute slices.
//
// IMPORTANT: result is ordered in order of the first []Attribute argument.
//
// IMPORTANT: result contains only Attributes from a1.
func AttributeIntersect(a1, a2 []Attribute) []Attribute {
	ret := make([]Attribute, 0)
	for _, candidate := range a1 {
		for _, other := range a2 {
			if candidate.Equals(other) {
				ret = append(ret, candidate)
				break
			}
		}
	}
	return ret
}
// AttributeIntersectReferences returns the intersection of two Attribute slices.
//
// IMPORTANT: result is not guaranteed to be ordered.
//
// IMPORTANT: done using pointers for speed, use AttributeIntersect
// if the Attributes originate from different DataGrids.
func AttributeIntersectReferences(a1, a2 []Attribute) []Attribute {
	inSecond := buildAttrSet(a2)
	ret := make([]Attribute, 0)
	for candidate := range buildAttrSet(a1) {
		if inSecond[candidate] {
			ret = append(ret, candidate)
		}
	}
	return ret
}
// AttributeDifference returns the difference between two Attribute
// slices: i.e. all the values in a1 which do not occur in a2.
//
// IMPORTANT: result is ordered the same as a1.
//
// IMPORTANT: result only contains values from a1.
func AttributeDifference(a1, a2 []Attribute) []Attribute {
	ret := make([]Attribute, 0)
	for _, candidate := range a1 {
		found := false
		for _, other := range a2 {
			if candidate.Equals(other) {
				found = true
				break
			}
		}
		if found {
			continue
		}
		ret = append(ret, candidate)
	}
	return ret
}
// AttributeDifferenceReferences returns the difference between two Attribute
// slices: i.e. all the values in a1 which do not occur in a2.
//
// IMPORTANT: result is not guaranteed to be ordered.
//
// IMPORTANT: done using pointers for speed, use AttributeDifference
// if the Attributes originate from different DataGrids.
func AttributeDifferenceReferences(a1, a2 []Attribute) []Attribute {
	inSecond := buildAttrSet(a2)
	ret := make([]Attribute, 0)
	for candidate := range buildAttrSet(a1) {
		if !inSecond[candidate] {
			ret = append(ret, candidate)
		}
	}
	return ret
}

254
base/util_instances.go Normal file
View File

@ -0,0 +1,254 @@
package base
import (
"fmt"
"math/rand"
)
// This file contains utility functions relating to efficiently
// generating predictions and instantiating DataGrid implementations.
// GeneratePredictionVector selects the class Attributes from a given
// FixedDataGrid and returns something which can hold the predictions.
func GeneratePredictionVector(from FixedDataGrid) UpdatableDataGrid {
	classAttrs := from.AllClassAttributes()
	_, rows := from.Size()
	out := NewDenseInstances()
	for _, classAttr := range classAttrs {
		out.AddAttribute(classAttr)
		out.AddClassAttribute(classAttr)
	}
	// Match the source's row count
	out.Extend(rows)
	return out
}
// GetClass is a shortcut for returning the string value of the current
// class on a given row.
//
// IMPORTANT: GetClass will panic if the number of ClassAttributes is
// set to anything other than one.
func GetClass(from DataGrid, row int) string {
	classAttrs := from.AllClassAttributes()
	// Exactly one class Attribute must be defined
	switch {
	case len(classAttrs) > 1:
		panic("More than one class defined")
	case len(classAttrs) == 0:
		panic("No class defined!")
	}
	classAttr := classAttrs[0]
	// Resolve the Attribute and fetch the stored value
	spec, err := from.GetAttribute(classAttr)
	if err != nil {
		panic(fmt.Errorf("Can't resolve class Attribute %s", err))
	}
	raw := from.Get(spec, row)
	if raw == nil {
		panic("Class values shouldn't be missing")
	}
	return classAttr.GetStringFromSysVal(raw)
}
// SetClass is a shortcut for updating the given class of a row.
//
// IMPORTANT: SetClass will panic if the number of class Attributes
// is anything other than one.
func SetClass(at UpdatableDataGrid, row int, class string) {
	classAttrs := at.AllClassAttributes()
	// Exactly one class Attribute must be defined
	switch {
	case len(classAttrs) > 1:
		panic("More than one class defined")
	case len(classAttrs) == 0:
		panic("No class Attributes are defined")
	}
	classAttr := classAttrs[0]
	// Resolve the Attribute, convert the value, and store it
	spec, err := at.GetAttribute(classAttr)
	if err != nil {
		panic(fmt.Errorf("Can't resolve class Attribute %s", err))
	}
	at.Set(spec, row, classAttr.GetSysValFromString(class))
}
// GetClassDistribution returns a map containing the count of each
// class type (indexed by the class' string representation).
func GetClassDistribution(inst FixedDataGrid) map[string]int {
	counts := make(map[string]int)
	_, rows := inst.Size()
	for row := 0; row < rows; row++ {
		counts[GetClass(inst, row)]++
	}
	return counts
}
// GetClassDistributionAfterSplit returns the class distribution
// after a speculative split on a given Attribute: the result maps each
// string value of at to a map from class value to row count.
//
// Panics if at cannot be resolved on inst.
func GetClassDistributionAfterSplit(inst FixedDataGrid, at Attribute) map[string]map[string]int {
	ret := make(map[string]map[string]int)
	// Find the attribute we're decomposing on
	attrSpec, err := inst.GetAttribute(at)
	if err != nil {
		panic(fmt.Sprintf("Invalid attribute %s (%s)", at, err))
	}
	_, rows := inst.Size()
	for i := 0; i < rows; i++ {
		splitVar := at.GetStringFromSysVal(inst.Get(attrSpec, i))
		classVar := GetClass(inst, i)
		// Create the inner map on first sight of this split value.
		// (Previously this decremented i and re-ran the iteration,
		// re-fetching both values for no benefit.)
		if _, ok := ret[splitVar]; !ok {
			ret[splitVar] = make(map[string]int)
		}
		ret[splitVar][classVar]++
	}
	return ret
}
// DecomposeOnAttributeValues divides the instance set depending on the
// value of a given Attribute, constructs child instances, and returns
// them in a map keyed on the string value of that Attribute.
//
// IMPORTANT: calls panic() if the AttributeSpec of at cannot be determined.
func DecomposeOnAttributeValues(inst FixedDataGrid, at Attribute) map[string]FixedDataGrid {
// Find the Attribute we're decomposing on
attrSpec, err := inst.GetAttribute(at)
if err != nil {
panic(fmt.Sprintf("Invalid Attribute index %s", at))
}
// Construct the new Attribute set: everything except at itself
newAttrs := make([]Attribute, 0)
for _, a := range inst.AllAttributes() {
if a.Equals(at) {
continue
}
newAttrs = append(newAttrs, a)
}
// Create the return map
ret := make(map[string]FixedDataGrid)
// Create the return row mapping: split value -> source row numbers
rowMaps := make(map[string][]int)
// Build full Attribute set: the retained Attributes with at appended
// last, so each mapped row carries its split value in the final slot
fullAttrSpec := ResolveAttributes(inst, newAttrs)
fullAttrSpec = append(fullAttrSpec, attrSpec)
// Decompose: bucket each row number under its split value
inst.MapOverRows(fullAttrSpec, func(row [][]byte, rowNo int) (bool, error) {
// The split value sits in the last slot (see fullAttrSpec above)
targetBytes := row[len(row)-1]
targetAttr := fullAttrSpec[len(fullAttrSpec)-1].attr
targetSet := targetAttr.GetStringFromSysVal(targetBytes)
if _, ok := rowMaps[targetSet]; !ok {
rowMaps[targetSet] = make([]int, 0)
}
rowMap := rowMaps[targetSet]
rowMaps[targetSet] = append(rowMap, rowNo)
return true, nil
})
// Wrap each row bucket in a view which also hides the split Attribute
for a := range rowMaps {
ret[a] = NewInstancesViewFromVisible(inst, rowMaps[a], newAttrs)
}
return ret
}
// InstancesTrainTestSplit takes a given Instances (src) and a train-test fraction
// (prop) and returns an array of two new Instances, one containing approximately
// that fraction and the other containing what's left.
//
// IMPORTANT: this function is only meaningful when prop is between 0.0 and 1.0.
// Using any other values may result in odd behaviour.
func InstancesTrainTestSplit(src FixedDataGrid, prop float64) (FixedDataGrid, FixedDataGrid) {
trainingRows := make([]int, 0)
testingRows := make([]int, 0)
// Shuffle before splitting; for DenseInstances this reorders src
// in place (see Shuffle)
src = Shuffle(src)
// Create the return structure
_, rows := src.Size()
for i := 0; i < rows; i++ {
// rand.Intn(101) yields 0..100 inclusive; rows above the
// 100*prop threshold land in the first returned grid.
// NOTE(review): this routes roughly (1-prop) of rows to the first
// grid and prop to the second — confirm which one callers treat
// as "training".
trainOrTest := rand.Intn(101)
if trainOrTest > int(100*prop) {
trainingRows = append(trainingRows, i)
} else {
testingRows = append(testingRows, i)
}
}
allAttrs := src.AllAttributes()
return NewInstancesViewFromVisible(src, trainingRows, allAttrs), NewInstancesViewFromVisible(src, testingRows, allAttrs)
}
// LazyShuffle randomizes the row order without re-ordering the rows
// via an InstancesView.
func LazyShuffle(from FixedDataGrid) FixedDataGrid {
	_, rows := from.Size()
	// Build a genuine permutation. The previous implementation wrote
	// rowMap[i] = j and rowMap[j] = i without reading what those slots
	// already held, so later writes clobbered earlier ones: the mapping
	// could present some source rows twice and drop others entirely.
	rowMap := make(map[int]int)
	for viewRow, srcRow := range rand.Perm(rows) {
		rowMap[viewRow] = srcRow
	}
	return NewInstancesViewFromRows(from, rowMap)
}
// Shuffle randomizes the row order either in place (if DenseInstances)
// or using LazyShuffle.
func Shuffle(from FixedDataGrid) FixedDataGrid {
	dense, ok := from.(*DenseInstances)
	if !ok {
		return LazyShuffle(from)
	}
	// Swap each row with a randomly chosen earlier (or same) row
	_, rows := from.Size()
	for i := 0; i < rows; i++ {
		dense.swapRows(i, rand.Intn(i+1))
	}
	return dense
}
// SampleWithReplacement returns a new FixedDataGrid containing
// an equal number of random rows drawn from the original FixedDataGrid
//
// IMPORTANT: There's a high chance of seeing duplicate rows
// whenever size is close to the row count.
func SampleWithReplacement(from FixedDataGrid, size int) FixedDataGrid {
	_, rows := from.Size()
	selection := make(map[int]int)
	for i := 0; i < size; i++ {
		selection[i] = rand.Intn(rows)
	}
	return NewInstancesViewFromRows(from, selection)
}
// CheckCompatable checks whether two DataGrids have the same Attributes
// and if they do, it returns them; otherwise nil.
func CheckCompatable(s1 FixedDataGrid, s2 FixedDataGrid) []Attribute {
	attrs1 := s1.AllAttributes()
	attrs2 := s2.AllAttributes()
	common := AttributeIntersect(attrs1, attrs2)
	// Identical sets iff the intersection covers both sides
	if len(common) != len(attrs1) || len(common) != len(attrs2) {
		return nil
	}
	return common
}

66
base/util_test.go Normal file
View File

@ -0,0 +1,66 @@
package base
import (
. "github.com/smartystreets/goconvey/convey"
"testing"
)
// TestClassDistributionAfterSplit checks GetClassDistributionAfterSplit
// against expected per-outlook class counts from the tennis dataset.
func TestClassDistributionAfterSplit(t *testing.T) {
Convey("Given the PlayTennis dataset", t, func() {
inst, err := ParseCSVToInstances("../examples/datasets/tennis.csv", true)
So(err, ShouldEqual, nil)
Convey("Splitting on Sunny should give the right result...", func() {
// Split on the first Attribute (outlook) and verify each bucket
result := GetClassDistributionAfterSplit(inst, inst.AllAttributes()[0])
So(result["sunny"]["no"], ShouldEqual, 3)
So(result["sunny"]["yes"], ShouldEqual, 2)
So(result["overcast"]["yes"], ShouldEqual, 4)
So(result["rainy"]["yes"], ShouldEqual, 3)
So(result["rainy"]["no"], ShouldEqual, 2)
})
})
}
// TestPackAndUnpack verifies that PackU64ToBytes and UnpackBytesToU64
// round-trip sample uint64 values unchanged.
func TestPackAndUnpack(t *testing.T) {
Convey("Given some uint64", t, func() {
x := uint64(0xDEADBEEF)
Convey("When the integer is packed", func() {
packed := PackU64ToBytes(x)
Convey("And then unpacked", func() {
unpacked := UnpackBytesToU64(packed)
Convey("The unpacked version should be the same", func() {
So(x, ShouldEqual, unpacked)
})
})
})
})
Convey("Given another uint64", t, func() {
x := uint64(1)
Convey("When the integer is packed", func() {
packed := PackU64ToBytes(x)
Convey("And then unpacked", func() {
unpacked := UnpackBytesToU64(packed)
Convey("The unpacked version should be the same", func() {
So(x, ShouldEqual, unpacked)
})
})
})
})
}
// TestPackAndUnpackFloat verifies that PackFloatToBytes and
// UnpackBytesToFloat round-trip a sample float64 unchanged.
func TestPackAndUnpackFloat(t *testing.T) {
Convey("Given some float", t, func() {
x := 1.2011
Convey("When the float gets packed", func() {
packed := PackFloatToBytes(x)
Convey("And then unpacked", func() {
unpacked := UnpackBytesToFloat(packed)
Convey("The unpacked version should be the same", func() {
So(unpacked, ShouldEqual, x)
})
})
})
})
}

320
base/view.go Normal file
View File

@ -0,0 +1,320 @@
package base
import (
"bytes"
"fmt"
)
// InstancesViews hide or re-order Attributes and rows from
// a given DataGrid to make it appear that they've been deleted.
type InstancesView struct {
// src is the underlying FixedDataGrid this view presents.
src FixedDataGrid
// attrs restricts the visible Attributes (nil means all of src's).
attrs []AttributeSpec
// rows maps view row numbers to src row numbers (nil means identity).
rows map[int]int
// classAttrs records which Attributes count as class Attributes.
classAttrs map[Attribute]bool
// maskRows, when true, hides any row absent from the rows map.
maskRows bool
}
// addClassAttrsFromSrc marks as class Attributes every class Attribute
// of src which is still visible through this view.
func (v *InstancesView) addClassAttrsFromSrc(src FixedDataGrid) {
	for _, classAttr := range src.AllClassAttributes() {
		// With no Attribute filter, everything is visible
		visible := v.attrs == nil
		if !visible {
			for _, spec := range v.attrs {
				if spec.attr.Equals(classAttr) {
					visible = true
					break
				}
			}
		}
		if visible {
			v.classAttrs[classAttr] = true
		}
	}
}
// resolveRow maps a row number visible through this view onto the
// corresponding row of the underlying source, returning -1 when the
// row is masked out.
func (v *InstancesView) resolveRow(origRow int) int {
	if v.rows == nil {
		return origRow
	}
	if srcRow, ok := v.rows[origRow]; ok {
		return srcRow
	}
	if v.maskRows {
		return -1
	}
	return origRow
}
// NewInstancesViewFromRows creates a new InstancesView from a source
// FixedDataGrid and row -> row mapping. The key of the rows map is the
// row as it exists within this mapping: for example an entry like 5 -> 1
// means that row 1 in src will appear at row 5 in the Instancesview.
//
// Rows are not masked in this implementation, meaning that all rows which
// are left unspecified appear as normal.
func NewInstancesViewFromRows(src FixedDataGrid, rows map[int]int) *InstancesView {
	view := &InstancesView{
		src:        src,
		attrs:      nil,
		rows:       rows,
		classAttrs: make(map[Attribute]bool),
		maskRows:   false,
	}
	view.addClassAttrsFromSrc(src)
	return view
}
// NewInstancesViewFromVisible creates a new InstancesView from a source
// FixedDataGrid, a slice of row numbers and a slice of Attributes.
//
// Only the rows specified will appear in this InstancesView, and they will
// appear in the same order they appear within the rows array.
//
// Only the Attributes specified will appear in this InstancesView. Retrieving
// Attribute specifications from this InstancesView will maintain their order.
func NewInstancesViewFromVisible(src FixedDataGrid, rows []int, attrs []Attribute) *InstancesView {
	view := &InstancesView{
		src:        src,
		attrs:      ResolveAttributes(src, attrs),
		rows:       make(map[int]int),
		classAttrs: make(map[Attribute]bool),
		maskRows:   true,
	}
	// View position i shows source row rows[i]
	for viewRow, srcRow := range rows {
		view.rows[viewRow] = srcRow
	}
	view.addClassAttrsFromSrc(src)
	return view
}
// NewInstancesViewFromAttrs creates a new InstancesView from a source
// FixedDataGrid and a slice of Attributes.
//
// Only the Attributes specified will appear in this InstancesView.
func NewInstancesViewFromAttrs(src FixedDataGrid, attrs []Attribute) *InstancesView {
	view := &InstancesView{
		src:        src,
		attrs:      ResolveAttributes(src, attrs),
		rows:       nil,
		classAttrs: make(map[Attribute]bool),
		maskRows:   false,
	}
	view.addClassAttrsFromSrc(src)
	return view
}
// GetAttribute returns an Attribute specification matching an Attribute
// if it has not been filtered.
//
// The AttributeSpecs returned are the same as those returned by the
// source FixedDataGrid.
func (v *InstancesView) GetAttribute(a Attribute) (AttributeSpec, error) {
	if a == nil {
		return AttributeSpec{}, fmt.Errorf("Attribute can't be nil")
	}
	// With no Attribute filter in place, defer to the source
	if v.attrs == nil {
		return v.src.GetAttribute(a)
	}
	// Otherwise only filtered-in Attributes can be resolved
	for _, spec := range v.attrs {
		if spec.GetAttribute().Equals(a) {
			return spec, nil
		}
	}
	return AttributeSpec{}, fmt.Errorf("Requested Attribute has been filtered")
}
// AllAttributes returns every Attribute which hasn't been filtered.
func (v *InstancesView) AllAttributes() []Attribute {
	if v.attrs == nil {
		return v.src.AllAttributes()
	}
	out := make([]Attribute, len(v.attrs))
	for i := range v.attrs {
		out[i] = v.attrs[i].GetAttribute()
	}
	return out
}
// AddClassAttribute adds the given Attribute to the set of defined
// class Attributes, if it hasn't been filtered.
func (v *InstancesView) AddClassAttribute(a Attribute) error {
	// The Attribute must still be visible through this view
	for _, visible := range v.AllAttributes() {
		if visible.Equals(a) {
			v.classAttrs[a] = true
			return nil
		}
	}
	return fmt.Errorf("Attribute has been filtered")
}
// RemoveClassAttribute removes the given Attribute from the set of
// class Attributes. Always succeeds.
func (v *InstancesView) RemoveClassAttribute(a Attribute) error {
	// Deleting the key (rather than storing false) keeps the map from
	// accumulating dead entries; every reader of classAttrs in this file
	// treats a missing key and a false value identically.
	delete(v.classAttrs, a)
	return nil
}
// AllClassAttributes returns all the Attributes currently defined
// as being class Attributes.
func (v *InstancesView) AllClassAttributes() []Attribute {
	out := make([]Attribute, 0)
	for attr, isClass := range v.classAttrs {
		if isClass {
			out = append(out, attr)
		}
	}
	return out
}
// Get returns a sequence of bytes stored under a given Attribute
// on a given row.
//
// IMPORTANT: The AttributeSpec is unverified, meaning it's possible
// to return values from Attributes filtered by this InstancesView
// if the underlying AttributeSpec is known.
func (v *InstancesView) Get(as AttributeSpec, row int) []byte {
	// Translate the view row into a source row
	srcRow := v.resolveRow(row)
	if srcRow == -1 {
		panic("Out of range")
	}
	return v.src.Get(as, srcRow)
}
// MapOverRows, see DenseInstances.MapOverRows.
//
// IMPORTANT: MapOverRows is not guaranteed to be ordered, but this one
// especially so.
func (v *InstancesView) MapOverRows(as []AttributeSpec, rowFunc func([][]byte, int) (bool, error)) error {
	// Without masking, the source's own iteration is already correct
	if !v.maskRows {
		return v.src.MapOverRows(as, rowFunc)
	}
	// Masked: visit only the mapped rows, in map iteration order
	rowBuf := make([][]byte, len(as))
	for viewRow, srcRow := range v.rows {
		for i, spec := range as {
			rowBuf[i] = v.src.Get(spec, srcRow)
		}
		ok, err := rowFunc(rowBuf, viewRow)
		if err != nil {
			return err
		}
		if !ok {
			break
		}
	}
	return nil
}
// Size returns the number of Attributes and rows this InstancesView
// contains.
func (v *InstancesView) Size() (int, int) {
	// Start from the source's dimensions
	cols, rows := v.src.Size()
	// An Attribute filter fixes the column count
	if v.attrs != nil {
		cols = len(v.attrs)
	}
	// A row mapping overrides the row count when masking, or when it
	// maps positions beyond the source's height
	if v.rows != nil {
		mapped := len(v.rows)
		if v.maskRows || mapped > rows {
			rows = mapped
		}
	}
	return cols, rows
}
// String returns a human-readable summary of this InstancesView.
func (v *InstancesView) String() string {
var buffer bytes.Buffer
// Display at most this many data rows
maxRows := 30
// Get all Attribute information
as := ResolveAllAttributes(v)
// Print header
cols, rows := v.Size()
buffer.WriteString("InstancesView with ")
buffer.WriteString(fmt.Sprintf("%d row(s) ", rows))
buffer.WriteString(fmt.Sprintf("%d attribute(s)\n", cols))
if v.attrs != nil {
buffer.WriteString(fmt.Sprintf("With defined Attribute view\n"))
}
if v.rows != nil {
buffer.WriteString(fmt.Sprintf("With defined Row view\n"))
}
if v.maskRows {
buffer.WriteString("Row masking on.\n")
}
// List the Attributes, starring those designated as class Attributes
buffer.WriteString(fmt.Sprintf("Attributes:\n"))
for _, a := range as {
prefix := "\t"
if v.classAttrs[a.attr] {
prefix = "*\t"
}
buffer.WriteString(fmt.Sprintf("%s%s\n", prefix, a.attr))
}
// Print data
if rows < maxRows {
maxRows = rows
}
buffer.WriteString("Data:")
for i := 0; i < maxRows; i++ {
buffer.WriteString("\t")
for _, a := range as {
val := v.Get(a, i)
buffer.WriteString(fmt.Sprintf("%s ", a.attr.GetStringFromSysVal(val)))
}
buffer.WriteString("\n")
}
// Note how many rows were omitted, if any
missingRows := rows - maxRows
if missingRows != 0 {
buffer.WriteString(fmt.Sprintf("\t...\n%d row(s) undisplayed", missingRows))
} else {
buffer.WriteString("All rows displayed")
}
return buffer.String()
}
// RowString returns a string representation of a given row.
func (v *InstancesView) RowString(row int) string {
	var out bytes.Buffer
	for i, spec := range ResolveAllAttributes(v) {
		// Space-separate all values after the first
		if i > 0 {
			out.WriteString(" ")
		}
		out.WriteString(spec.attr.GetStringFromSysVal(v.Get(spec, row)))
	}
	return out.String()
}

119
base/view_test.go Normal file
View File

@ -0,0 +1,119 @@
package base
import (
. "github.com/smartystreets/goconvey/convey"
"testing"
)
// TestInstancesViewRows checks that a row-remapping view presents the
// mapped source row and, without masking, keeps the source's size.
func TestInstancesViewRows(t *testing.T) {
Convey("Given Iris", t, func() {
instOrig, err := ParseCSVToInstances("../examples/datasets/iris_headers.csv", true)
So(err, ShouldEqual, nil)
Convey("Given a new row map containing only row 5", func() {
rMap := make(map[int]int)
rMap[0] = 5
instView := NewInstancesViewFromRows(instOrig, rMap)
Convey("The internal structure should be right...", func() {
So(instView.rows[0], ShouldEqual, 5)
})
Convey("The reconstructed values should be correct...", func() {
str := "5.40 3.90 1.70 0.40 Iris-setosa"
row := instView.RowString(0)
So(row, ShouldEqual, str)
})
Convey("And the size should be correct...", func() {
// Non-masked views keep the source's full height (see Size)
width, height := instView.Size()
So(width, ShouldEqual, 5)
So(height, ShouldEqual, 150)
})
})
})
}
// TestInstancesViewFromVisible checks that views built from explicit
// row subsets present the same row contents as the underlying grid.
func TestInstancesViewFromVisible(t *testing.T) {
	Convey("Given Iris", t, func() {
		instOrig, err := ParseCSVToInstances("../examples/datasets/iris_headers.csv", true)
		So(err, ShouldEqual, nil)
		Convey("Generate something that says every other row should be visible", func() {
			rowVisiblex1 := make([]int, 0)
			_, totalRows := instOrig.Size()
			for i := 0; i < totalRows; i += 2 {
				rowVisiblex1 = append(rowVisiblex1, i)
			}
			instViewx1 := NewInstancesViewFromVisible(instOrig, rowVisiblex1, instOrig.AllAttributes())
			for i, a := range rowVisiblex1 {
				rowStr1 := instViewx1.RowString(i)
				rowStr2 := instOrig.RowString(a)
				So(rowStr1, ShouldEqual, rowStr2)
			}
			Convey("And then generate something that says that every other row than that should be visible", func() {
				rowVisiblex2 := make([]int, 0)
				for i := 0; i < totalRows; i += 4 {
					// BUG FIX: previously appended to rowVisiblex1, so each
					// pass rebuilt rowVisiblex2 from a stale copy of the
					// first slice instead of accumulating the intended set.
					rowVisiblex2 = append(rowVisiblex2, i)
				}
				instViewx2 := NewInstancesViewFromVisible(instOrig, rowVisiblex2, instOrig.AllAttributes())
				for i, a := range rowVisiblex2 {
					rowStr1 := instViewx2.RowString(i)
					rowStr2 := instOrig.RowString(a)
					So(rowStr1, ShouldEqual, rowStr2)
				}
			})
		})
	})
}
// TestInstancesViewAttrs checks that a view restricted to a subset of
// Attributes reports the right size, headers, class Attributes, and
// per-Attribute types, and that filtered-out Attributes are inaccessible.
func TestInstancesViewAttrs(t *testing.T) {
	Convey("Given Iris", t, func() {
		instOrig, err := ParseCSVToInstances("../examples/datasets/iris_headers.csv", true)
		So(err, ShouldEqual, nil)
		Convey("Given a new Attribute vector with the last 4...", func() {
			keptAttrs := instOrig.AllAttributes()[1:]
			instView := NewInstancesViewFromAttrs(instOrig, keptAttrs)
			Convey("The size should be correct", func() {
				width, height := instView.Size()
				So(width, ShouldEqual, 4)
				_, origHeight := instOrig.Size()
				So(height, ShouldEqual, origHeight)
			})
			Convey("There should be 4 Attributes...", func() {
				So(len(instView.AllAttributes()), ShouldEqual, 4)
			})
			Convey("There should be 4 Attributes with the right headers...", func() {
				expectedNames := []string{"Sepal width", "Petal length", "Petal width", "Species"}
				attrs := instView.AllAttributes()
				for i, name := range expectedNames {
					So(attrs[i].GetName(), ShouldEqual, name)
				}
			})
			Convey("There should be a class Attribute...", func() {
				So(len(instView.AllClassAttributes()), ShouldEqual, 1)
			})
			Convey("The class Attribute should be preserved...", func() {
				classAttrs := instView.AllClassAttributes()
				So(classAttrs[0].GetName(), ShouldEqual, "Species")
			})
			Convey("Attempts to get the filtered Attribute should fail...", func() {
				_, err := instView.GetAttribute(instOrig.AllAttributes()[0])
				So(err, ShouldNotEqual, nil)
			})
			Convey("The filtered Attribute should not appear in the RowString", func() {
				expected := "3.90 1.70 0.40 Iris-setosa"
				So(instView.RowString(5), ShouldEqual, expected)
			})
			Convey("The filtered Attributes should all be the same type...", func() {
				attrs := instView.AllAttributes()
				for i := 0; i < 3; i++ {
					_, isFloat := attrs[i].(*FloatAttribute)
					So(isFloat, ShouldEqual, true)
				}
				_, isCategorical := attrs[3].(*CategoricalAttribute)
				So(isCategorical, ShouldEqual, true)
			})
		})
	})
}

View File

@ -1,13 +1,12 @@
/*
//
//
// Ensemble contains classifiers which combine other classifiers.
//
// RandomForest:
// Generates ForestSize bagged decision trees (currently ID3-based)
// each considering a fixed number of random features.
//
// Built on meta.Bagging
//
Ensemble contains classifiers which combine other classifiers.
RandomForest:
Generates ForestSize bagged decision trees (currently ID3-based)
each considering a fixed number of random features.
Built on meta.Bagging
*/
package ensemble
package ensemble

View File

@ -8,7 +8,7 @@ import (
)
// RandomForest classifies instances using an ensemble
// of bagged random decision trees
// of bagged random decision trees.
type RandomForest struct {
base.BaseClassifier
ForestSize int
@ -18,7 +18,7 @@ type RandomForest struct {
// NewRandomForest generates and return a new random forests
// forestSize controls the number of trees that get built
// features controls the number of features used to build each tree
// features controls the number of features used to build each tree.
func NewRandomForest(forestSize int, features int) *RandomForest {
ret := &RandomForest{
base.BaseClassifier{},
@ -30,7 +30,7 @@ func NewRandomForest(forestSize int, features int) *RandomForest {
}
// Fit builds the RandomForest on the specified instances
func (f *RandomForest) Fit(on *base.Instances) {
func (f *RandomForest) Fit(on base.FixedDataGrid) {
f.Model = new(meta.BaggedModel)
f.Model.RandomFeatures = f.Features
for i := 0; i < f.ForestSize; i++ {
@ -40,11 +40,12 @@ func (f *RandomForest) Fit(on *base.Instances) {
f.Model.Fit(on)
}
// Predict generates predictions from a trained RandomForest
func (f *RandomForest) Predict(with *base.Instances) *base.Instances {
// Predict generates predictions from a trained RandomForest.
func (f *RandomForest) Predict(with base.FixedDataGrid) base.FixedDataGrid {
return f.Model.Predict(with)
}
// String returns a human-readable representation of this tree.
func (f *RandomForest) String() string {
return fmt.Sprintf("RandomForest(ForestSize: %d, Features:%d, %s\n)", f.ForestSize, f.Features, f.Model)
}

View File

@ -13,12 +13,16 @@ func TestRandomForest1(testEnv *testing.T) {
if err != nil {
panic(err)
}
trainData, testData := base.InstancesTrainTestSplit(inst, 0.60)
filt := filters.NewChiMergeFilter(trainData, 0.90)
filt.AddAllNumericAttributes()
filt.Build()
filt.Run(testData)
filt.Run(trainData)
filt := filters.NewChiMergeFilter(inst, 0.90)
for _, a := range base.NonClassFloatAttributes(inst) {
filt.AddAttribute(a)
}
filt.Train()
instf := base.NewLazilyFilteredInstances(inst, filt)
trainData, testData := base.InstancesTrainTestSplit(instf, 0.60)
rf := NewRandomForest(10, 3)
rf.Fit(trainData)
predictions := rf.Predict(testData)

View File

@ -11,19 +11,22 @@ type ConfusionMatrix map[string]map[string]int
// GetConfusionMatrix builds a ConfusionMatrix from a set of reference (`ref')
// and generate (`gen') Instances.
func GetConfusionMatrix(ref *base.Instances, gen *base.Instances) map[string]map[string]int {
func GetConfusionMatrix(ref base.FixedDataGrid, gen base.FixedDataGrid) map[string]map[string]int {
if ref.Rows != gen.Rows {
_, refRows := ref.Size()
_, genRows := gen.Size()
if refRows != genRows {
panic("Row counts should match")
}
ret := make(map[string]map[string]int)
for i := 0; i < ref.Rows; i++ {
referenceClass := ref.GetClass(i)
predictedClass := gen.GetClass(i)
for i := 0; i < int(refRows); i++ {
referenceClass := base.GetClass(ref, i)
predictedClass := base.GetClass(gen, i)
if _, ok := ret[referenceClass]; ok {
ret[referenceClass][predictedClass]++
ret[referenceClass][predictedClass] += 1
} else {
ret[referenceClass] = make(map[string]int)
ret[referenceClass][predictedClass] = 1

View File

@ -1,152 +1,151 @@
Sepal length,Sepal width,Petal length,Petal width,Species
2,3.5,1.4,0.2,Iris-setosa
1,3,1.4,0.2,Iris-setosa
1,3.2,1.3,0.2,Iris-setosa
0,3.1,1.5,0.2,Iris-setosa
1,3.6,1.4,0.2,Iris-setosa
3,3.9,1.7,0.4,Iris-setosa
0,3.4,1.4,0.3,Iris-setosa
1,3.4,1.5,0.2,Iris-setosa
0,2.9,1.4,0.2,Iris-setosa
1,3.1,1.5,0.1,Iris-setosa
3,3.7,1.5,0.2,Iris-setosa
1,3.4,1.6,0.2,Iris-setosa
1,3,1.4,0.1,Iris-setosa
0,3,1.1,0.1,Iris-setosa
4,4,1.2,0.2,Iris-setosa
3,4.4,1.5,0.4,Iris-setosa
3,3.9,1.3,0.4,Iris-setosa
2,3.5,1.4,0.3,Iris-setosa
3,3.8,1.7,0.3,Iris-setosa
2,3.8,1.5,0.3,Iris-setosa
3,3.4,1.7,0.2,Iris-setosa
2,3.7,1.5,0.4,Iris-setosa
0,3.6,1,0.2,Iris-setosa
2,3.3,1.7,0.5,Iris-setosa
1,3.4,1.9,0.2,Iris-setosa
1,3,1.6,0.2,Iris-setosa
1,3.4,1.6,0.4,Iris-setosa
2,3.5,1.5,0.2,Iris-setosa
2,3.4,1.4,0.2,Iris-setosa
1,3.2,1.6,0.2,Iris-setosa
1,3.1,1.6,0.2,Iris-setosa
3,3.4,1.5,0.4,Iris-setosa
2,4.1,1.5,0.1,Iris-setosa
3,4.2,1.4,0.2,Iris-setosa
1,3.1,1.5,0.1,Iris-setosa
1,3.2,1.2,0.2,Iris-setosa
3,3.5,1.3,0.2,Iris-setosa
1,3.1,1.5,0.1,Iris-setosa
0,3,1.3,0.2,Iris-setosa
2,3.4,1.5,0.2,Iris-setosa
1,3.5,1.3,0.3,Iris-setosa
0,2.3,1.3,0.3,Iris-setosa
0,3.2,1.3,0.2,Iris-setosa
1,3.5,1.6,0.6,Iris-setosa
2,3.8,1.9,0.4,Iris-setosa
1,3,1.4,0.3,Iris-setosa
2,3.8,1.6,0.2,Iris-setosa
0,3.2,1.4,0.2,Iris-setosa
2,3.7,1.5,0.2,Iris-setosa
1,3.3,1.4,0.2,Iris-setosa
7,3.2,4.7,1.4,Iris-versicolor
5,3.2,4.5,1.5,Iris-versicolor
7,3.1,4.9,1.5,Iris-versicolor
3,2.3,4,1.3,Iris-versicolor
6,2.8,4.6,1.5,Iris-versicolor
3,2.8,4.5,1.3,Iris-versicolor
5,3.3,4.7,1.6,Iris-versicolor
1,2.4,3.3,1,Iris-versicolor
6,2.9,4.6,1.3,Iris-versicolor
2,2.7,3.9,1.4,Iris-versicolor
1,2,3.5,1,Iris-versicolor
4,3,4.2,1.5,Iris-versicolor
4,2.2,4,1,Iris-versicolor
5,2.9,4.7,1.4,Iris-versicolor
3,2.9,3.6,1.3,Iris-versicolor
6,3.1,4.4,1.4,Iris-versicolor
3,3,4.5,1.5,Iris-versicolor
4,2.7,4.1,1,Iris-versicolor
5,2.2,4.5,1.5,Iris-versicolor
3,2.5,3.9,1.1,Iris-versicolor
4,3.2,4.8,1.8,Iris-versicolor
5,2.8,4,1.3,Iris-versicolor
5,2.5,4.9,1.5,Iris-versicolor
5,2.8,4.7,1.2,Iris-versicolor
5,2.9,4.3,1.3,Iris-versicolor
6,3,4.4,1.4,Iris-versicolor
6,2.8,4.8,1.4,Iris-versicolor
6,3,5,1.7,Iris-versicolor
4,2.9,4.5,1.5,Iris-versicolor
3,2.6,3.5,1,Iris-versicolor
3,2.4,3.8,1.1,Iris-versicolor
3,2.4,3.7,1,Iris-versicolor
4,2.7,3.9,1.2,Iris-versicolor
4,2.7,5.1,1.6,Iris-versicolor
3,3,4.5,1.5,Iris-versicolor
4,3.4,4.5,1.6,Iris-versicolor
6,3.1,4.7,1.5,Iris-versicolor
5,2.3,4.4,1.3,Iris-versicolor
3,3,4.1,1.3,Iris-versicolor
3,2.5,4,1.3,Iris-versicolor
3,2.6,4.4,1.2,Iris-versicolor
5,3,4.6,1.4,Iris-versicolor
4,2.6,4,1.2,Iris-versicolor
1,2.3,3.3,1,Iris-versicolor
3,2.7,4.2,1.3,Iris-versicolor
3,3,4.2,1.2,Iris-versicolor
3,2.9,4.2,1.3,Iris-versicolor
5,2.9,4.3,1.3,Iris-versicolor
2,2.5,3,1.1,Iris-versicolor
3,2.8,4.1,1.3,Iris-versicolor
5,3.3,6,2.5,Iris-virginica
4,2.7,5.1,1.9,Iris-virginica
7,3,5.9,2.1,Iris-virginica
5,2.9,5.6,1.8,Iris-virginica
6,3,5.8,2.2,Iris-virginica
9,3,6.6,2.1,Iris-virginica
1,2.5,4.5,1.7,Iris-virginica
8,2.9,6.3,1.8,Iris-virginica
6,2.5,5.8,1.8,Iris-virginica
8,3.6,6.1,2.5,Iris-virginica
6,3.2,5.1,2,Iris-virginica
5,2.7,5.3,1.9,Iris-virginica
6,3,5.5,2.1,Iris-virginica
3,2.5,5,2,Iris-virginica
4,2.8,5.1,2.4,Iris-virginica
5,3.2,5.3,2.3,Iris-virginica
6,3,5.5,1.8,Iris-virginica
9,3.8,6.7,2.2,Iris-virginica
9,2.6,6.9,2.3,Iris-virginica
4,2.2,5,1.5,Iris-virginica
7,3.2,5.7,2.3,Iris-virginica
3,2.8,4.9,2,Iris-virginica
9,2.8,6.7,2,Iris-virginica
5,2.7,4.9,1.8,Iris-virginica
6,3.3,5.7,2.1,Iris-virginica
8,3.2,6,1.8,Iris-virginica
5,2.8,4.8,1.8,Iris-virginica
5,3,4.9,1.8,Iris-virginica
5,2.8,5.6,2.1,Iris-virginica
8,3,5.8,1.6,Iris-virginica
8,2.8,6.1,1.9,Iris-virginica
9,3.8,6.4,2,Iris-virginica
5,2.8,5.6,2.2,Iris-virginica
5,2.8,5.1,1.5,Iris-virginica
5,2.6,5.6,1.4,Iris-virginica
9,3,6.1,2.3,Iris-virginica
5,3.4,5.6,2.4,Iris-virginica
5,3.1,5.5,1.8,Iris-virginica
4,3,4.8,1.8,Iris-virginica
7,3.1,5.4,2.1,Iris-virginica
6,3.1,5.6,2.4,Iris-virginica
7,3.1,5.1,2.3,Iris-virginica
4,2.7,5.1,1.9,Iris-virginica
6,3.2,5.9,2.3,Iris-virginica
6,3.3,5.7,2.5,Iris-virginica
6,3,5.2,2.3,Iris-virginica
5,2.5,5,1.9,Iris-virginica
6,3,5.2,2,Iris-virginica
5,3.4,5.4,2.3,Iris-virginica
4,3,5.1,1.8,Iris-virginica
Sepal length,Sepal width,Petal length, Petal width,Species
5.02,3.5,1.4,0.2,Iris-setosa
4.66,3,1.4,0.2,Iris-setosa
4.66,3.2,1.3,0.2,Iris-setosa
4.3,3.1,1.5,0.2,Iris-setosa
4.66,3.6,1.4,0.2,Iris-setosa
5.38,3.9,1.7,0.4,Iris-setosa
4.3,3.4,1.4,0.3,Iris-setosa
4.66,3.4,1.5,0.2,Iris-setosa
4.3,2.9,1.4,0.2,Iris-setosa
4.66,3.1,1.5,0.1,Iris-setosa
5.38,3.7,1.5,0.2,Iris-setosa
4.66,3.4,1.6,0.2,Iris-setosa
4.66,3,1.4,0.1,Iris-setosa
4.3,3,1.1,0.1,Iris-setosa
5.74,4,1.2,0.2,Iris-setosa
5.38,4.4,1.5,0.4,Iris-setosa
5.38,3.9,1.3,0.4,Iris-setosa
5.02,3.5,1.4,0.3,Iris-setosa
5.38,3.8,1.7,0.3,Iris-setosa
5.02,3.8,1.5,0.3,Iris-setosa
5.38,3.4,1.7,0.2,Iris-setosa
5.02,3.7,1.5,0.4,Iris-setosa
4.3,3.6,1,0.2,Iris-setosa
5.02,3.3,1.7,0.5,Iris-setosa
4.66,3.4,1.9,0.2,Iris-setosa
4.66,3,1.6,0.2,Iris-setosa
4.66,3.4,1.6,0.4,Iris-setosa
5.02,3.5,1.5,0.2,Iris-setosa
5.02,3.4,1.4,0.2,Iris-setosa
4.66,3.2,1.6,0.2,Iris-setosa
4.66,3.1,1.6,0.2,Iris-setosa
5.38,3.4,1.5,0.4,Iris-setosa
5.02,4.1,1.5,0.1,Iris-setosa
5.38,4.2,1.4,0.2,Iris-setosa
4.66,3.1,1.5,0.1,Iris-setosa
4.66,3.2,1.2,0.2,Iris-setosa
5.38,3.5,1.3,0.2,Iris-setosa
4.66,3.1,1.5,0.1,Iris-setosa
4.3,3,1.3,0.2,Iris-setosa
5.02,3.4,1.5,0.2,Iris-setosa
4.66,3.5,1.3,0.3,Iris-setosa
4.3,2.3,1.3,0.3,Iris-setosa
4.3,3.2,1.3,0.2,Iris-setosa
4.66,3.5,1.6,0.6,Iris-setosa
5.02,3.8,1.9,0.4,Iris-setosa
4.66,3,1.4,0.3,Iris-setosa
5.02,3.8,1.6,0.2,Iris-setosa
4.3,3.2,1.4,0.2,Iris-setosa
5.02,3.7,1.5,0.2,Iris-setosa
4.66,3.3,1.4,0.2,Iris-setosa
6.82,3.2,4.7,1.4,Iris-versicolor
6.1,3.2,4.5,1.5,Iris-versicolor
6.82,3.1,4.9,1.5,Iris-versicolor
5.38,2.3,4,1.3,Iris-versicolor
6.46,2.8,4.6,1.5,Iris-versicolor
5.38,2.8,4.5,1.3,Iris-versicolor
6.1,3.3,4.7,1.6,Iris-versicolor
4.66,2.4,3.3,1,Iris-versicolor
6.46,2.9,4.6,1.3,Iris-versicolor
5.02,2.7,3.9,1.4,Iris-versicolor
4.66,2,3.5,1,Iris-versicolor
5.74,3,4.2,1.5,Iris-versicolor
5.74,2.2,4,1,Iris-versicolor
6.1,2.9,4.7,1.4,Iris-versicolor
5.38,2.9,3.6,1.3,Iris-versicolor
6.46,3.1,4.4,1.4,Iris-versicolor
5.38,3,4.5,1.5,Iris-versicolor
5.74,2.7,4.1,1,Iris-versicolor
6.1,2.2,4.5,1.5,Iris-versicolor
5.38,2.5,3.9,1.1,Iris-versicolor
5.74,3.2,4.8,1.8,Iris-versicolor
6.1,2.8,4,1.3,Iris-versicolor
6.1,2.5,4.9,1.5,Iris-versicolor
6.1,2.8,4.7,1.2,Iris-versicolor
6.1,2.9,4.3,1.3,Iris-versicolor
6.46,3,4.4,1.4,Iris-versicolor
6.46,2.8,4.8,1.4,Iris-versicolor
6.46,3,5,1.7,Iris-versicolor
5.74,2.9,4.5,1.5,Iris-versicolor
5.38,2.6,3.5,1,Iris-versicolor
5.38,2.4,3.8,1.1,Iris-versicolor
5.38,2.4,3.7,1,Iris-versicolor
5.74,2.7,3.9,1.2,Iris-versicolor
5.74,2.7,5.1,1.6,Iris-versicolor
5.38,3,4.5,1.5,Iris-versicolor
5.74,3.4,4.5,1.6,Iris-versicolor
6.46,3.1,4.7,1.5,Iris-versicolor
6.1,2.3,4.4,1.3,Iris-versicolor
5.38,3,4.1,1.3,Iris-versicolor
5.38,2.5,4,1.3,Iris-versicolor
5.38,2.6,4.4,1.2,Iris-versicolor
6.1,3,4.6,1.4,Iris-versicolor
5.74,2.6,4,1.2,Iris-versicolor
4.66,2.3,3.3,1,Iris-versicolor
5.38,2.7,4.2,1.3,Iris-versicolor
5.38,3,4.2,1.2,Iris-versicolor
5.38,2.9,4.2,1.3,Iris-versicolor
6.1,2.9,4.3,1.3,Iris-versicolor
5.02,2.5,3,1.1,Iris-versicolor
5.38,2.8,4.1,1.3,Iris-versicolor
6.1,3.3,6,2.5,Iris-virginica
5.74,2.7,5.1,1.9,Iris-virginica
6.82,3,5.9,2.1,Iris-virginica
6.1,2.9,5.6,1.8,Iris-virginica
6.46,3,5.8,2.2,Iris-virginica
7.54,3,6.6,2.1,Iris-virginica
4.66,2.5,4.5,1.7,Iris-virginica
7.18,2.9,6.3,1.8,Iris-virginica
6.46,2.5,5.8,1.8,Iris-virginica
7.18,3.6,6.1,2.5,Iris-virginica
6.46,3.2,5.1,2,Iris-virginica
6.1,2.7,5.3,1.9,Iris-virginica
6.46,3,5.5,2.1,Iris-virginica
5.38,2.5,5,2,Iris-virginica
5.74,2.8,5.1,2.4,Iris-virginica
6.1,3.2,5.3,2.3,Iris-virginica
6.46,3,5.5,1.8,Iris-virginica
7.54,3.8,6.7,2.2,Iris-virginica
7.54,2.6,6.9,2.3,Iris-virginica
5.74,2.2,5,1.5,Iris-virginica
6.82,3.2,5.7,2.3,Iris-virginica
5.38,2.8,4.9,2,Iris-virginica
7.54,2.8,6.7,2,Iris-virginica
6.1,2.7,4.9,1.8,Iris-virginica
6.46,3.3,5.7,2.1,Iris-virginica
7.18,3.2,6,1.8,Iris-virginica
6.1,2.8,4.8,1.8,Iris-virginica
6.1,3,4.9,1.8,Iris-virginica
6.1,2.8,5.6,2.1,Iris-virginica
7.18,3,5.8,1.6,Iris-virginica
7.18,2.8,6.1,1.9,Iris-virginica
7.9,3.8,6.4,2,Iris-virginica
6.1,2.8,5.6,2.2,Iris-virginica
6.1,2.8,5.1,1.5,Iris-virginica
6.1,2.6,5.6,1.4,Iris-virginica
7.54,3,6.1,2.3,Iris-virginica
6.1,3.4,5.6,2.4,Iris-virginica
6.1,3.1,5.5,1.8,Iris-virginica
5.74,3,4.8,1.8,Iris-virginica
6.82,3.1,5.4,2.1,Iris-virginica
6.46,3.1,5.6,2.4,Iris-virginica
6.82,3.1,5.1,2.3,Iris-virginica
5.74,2.7,5.1,1.9,Iris-virginica
6.46,3.2,5.9,2.3,Iris-virginica
6.46,3.3,5.7,2.5,Iris-virginica
6.46,3,5.2,2.3,Iris-virginica
6.1,2.5,5,1.9,Iris-virginica
6.46,3,5.2,2,Iris-virginica
6.1,3.4,5.4,2.3,Iris-virginica
5.74,3,5.1,1.8,Iris-virginica

1 Sepal length Sepal width Petal length Petal width Species
2 2 5.02 3.5 1.4 0.2 Iris-setosa
3 1 4.66 3 1.4 0.2 Iris-setosa
4 1 4.66 3.2 1.3 0.2 Iris-setosa
5 0 4.3 3.1 1.5 0.2 Iris-setosa
6 1 4.66 3.6 1.4 0.2 Iris-setosa
7 3 5.38 3.9 1.7 0.4 Iris-setosa
8 0 4.3 3.4 1.4 0.3 Iris-setosa
9 1 4.66 3.4 1.5 0.2 Iris-setosa
10 0 4.3 2.9 1.4 0.2 Iris-setosa
11 1 4.66 3.1 1.5 0.1 Iris-setosa
12 3 5.38 3.7 1.5 0.2 Iris-setosa
13 1 4.66 3.4 1.6 0.2 Iris-setosa
14 1 4.66 3 1.4 0.1 Iris-setosa
15 0 4.3 3 1.1 0.1 Iris-setosa
16 4 5.74 4 1.2 0.2 Iris-setosa
17 3 5.38 4.4 1.5 0.4 Iris-setosa
18 3 5.38 3.9 1.3 0.4 Iris-setosa
19 2 5.02 3.5 1.4 0.3 Iris-setosa
20 3 5.38 3.8 1.7 0.3 Iris-setosa
21 2 5.02 3.8 1.5 0.3 Iris-setosa
22 3 5.38 3.4 1.7 0.2 Iris-setosa
23 2 5.02 3.7 1.5 0.4 Iris-setosa
24 0 4.3 3.6 1 0.2 Iris-setosa
25 2 5.02 3.3 1.7 0.5 Iris-setosa
26 1 4.66 3.4 1.9 0.2 Iris-setosa
27 1 4.66 3 1.6 0.2 Iris-setosa
28 1 4.66 3.4 1.6 0.4 Iris-setosa
29 2 5.02 3.5 1.5 0.2 Iris-setosa
30 2 5.02 3.4 1.4 0.2 Iris-setosa
31 1 4.66 3.2 1.6 0.2 Iris-setosa
32 1 4.66 3.1 1.6 0.2 Iris-setosa
33 3 5.38 3.4 1.5 0.4 Iris-setosa
34 2 5.02 4.1 1.5 0.1 Iris-setosa
35 3 5.38 4.2 1.4 0.2 Iris-setosa
36 1 4.66 3.1 1.5 0.1 Iris-setosa
37 1 4.66 3.2 1.2 0.2 Iris-setosa
38 3 5.38 3.5 1.3 0.2 Iris-setosa
39 1 4.66 3.1 1.5 0.1 Iris-setosa
40 0 4.3 3 1.3 0.2 Iris-setosa
41 2 5.02 3.4 1.5 0.2 Iris-setosa
42 1 4.66 3.5 1.3 0.3 Iris-setosa
43 0 4.3 2.3 1.3 0.3 Iris-setosa
44 0 4.3 3.2 1.3 0.2 Iris-setosa
45 1 4.66 3.5 1.6 0.6 Iris-setosa
46 2 5.02 3.8 1.9 0.4 Iris-setosa
47 1 4.66 3 1.4 0.3 Iris-setosa
48 2 5.02 3.8 1.6 0.2 Iris-setosa
49 0 4.3 3.2 1.4 0.2 Iris-setosa
50 2 5.02 3.7 1.5 0.2 Iris-setosa
51 1 4.66 3.3 1.4 0.2 Iris-setosa
52 7 6.82 3.2 4.7 1.4 Iris-versicolor
53 5 6.1 3.2 4.5 1.5 Iris-versicolor
54 7 6.82 3.1 4.9 1.5 Iris-versicolor
55 3 5.38 2.3 4 1.3 Iris-versicolor
56 6 6.46 2.8 4.6 1.5 Iris-versicolor
57 3 5.38 2.8 4.5 1.3 Iris-versicolor
58 5 6.1 3.3 4.7 1.6 Iris-versicolor
59 1 4.66 2.4 3.3 1 Iris-versicolor
60 6 6.46 2.9 4.6 1.3 Iris-versicolor
61 2 5.02 2.7 3.9 1.4 Iris-versicolor
62 1 4.66 2 3.5 1 Iris-versicolor
63 4 5.74 3 4.2 1.5 Iris-versicolor
64 4 5.74 2.2 4 1 Iris-versicolor
65 5 6.1 2.9 4.7 1.4 Iris-versicolor
66 3 5.38 2.9 3.6 1.3 Iris-versicolor
67 6 6.46 3.1 4.4 1.4 Iris-versicolor
68 3 5.38 3 4.5 1.5 Iris-versicolor
69 4 5.74 2.7 4.1 1 Iris-versicolor
70 5 6.1 2.2 4.5 1.5 Iris-versicolor
71 3 5.38 2.5 3.9 1.1 Iris-versicolor
72 4 5.74 3.2 4.8 1.8 Iris-versicolor
73 5 6.1 2.8 4 1.3 Iris-versicolor
74 5 6.1 2.5 4.9 1.5 Iris-versicolor
75 5 6.1 2.8 4.7 1.2 Iris-versicolor
76 5 6.1 2.9 4.3 1.3 Iris-versicolor
77 6 6.46 3 4.4 1.4 Iris-versicolor
78 6 6.46 2.8 4.8 1.4 Iris-versicolor
79 6 6.46 3 5 1.7 Iris-versicolor
80 4 5.74 2.9 4.5 1.5 Iris-versicolor
81 3 5.38 2.6 3.5 1 Iris-versicolor
82 3 5.38 2.4 3.8 1.1 Iris-versicolor
83 3 5.38 2.4 3.7 1 Iris-versicolor
84 4 5.74 2.7 3.9 1.2 Iris-versicolor
85 4 5.74 2.7 5.1 1.6 Iris-versicolor
86 3 5.38 3 4.5 1.5 Iris-versicolor
87 4 5.74 3.4 4.5 1.6 Iris-versicolor
88 6 6.46 3.1 4.7 1.5 Iris-versicolor
89 5 6.1 2.3 4.4 1.3 Iris-versicolor
90 3 5.38 3 4.1 1.3 Iris-versicolor
91 3 5.38 2.5 4 1.3 Iris-versicolor
92 3 5.38 2.6 4.4 1.2 Iris-versicolor
93 5 6.1 3 4.6 1.4 Iris-versicolor
94 4 5.74 2.6 4 1.2 Iris-versicolor
95 1 4.66 2.3 3.3 1 Iris-versicolor
96 3 5.38 2.7 4.2 1.3 Iris-versicolor
97 3 5.38 3 4.2 1.2 Iris-versicolor
98 3 5.38 2.9 4.2 1.3 Iris-versicolor
99 5 6.1 2.9 4.3 1.3 Iris-versicolor
100 2 5.02 2.5 3 1.1 Iris-versicolor
101 3 5.38 2.8 4.1 1.3 Iris-versicolor
102 5 6.1 3.3 6 2.5 Iris-virginica
103 4 5.74 2.7 5.1 1.9 Iris-virginica
104 7 6.82 3 5.9 2.1 Iris-virginica
105 5 6.1 2.9 5.6 1.8 Iris-virginica
106 6 6.46 3 5.8 2.2 Iris-virginica
107 9 7.54 3 6.6 2.1 Iris-virginica
108 1 4.66 2.5 4.5 1.7 Iris-virginica
109 8 7.18 2.9 6.3 1.8 Iris-virginica
110 6 6.46 2.5 5.8 1.8 Iris-virginica
111 8 7.18 3.6 6.1 2.5 Iris-virginica
112 6 6.46 3.2 5.1 2 Iris-virginica
113 5 6.1 2.7 5.3 1.9 Iris-virginica
114 6 6.46 3 5.5 2.1 Iris-virginica
115 3 5.38 2.5 5 2 Iris-virginica
116 4 5.74 2.8 5.1 2.4 Iris-virginica
117 5 6.1 3.2 5.3 2.3 Iris-virginica
118 6 6.46 3 5.5 1.8 Iris-virginica
119 9 7.54 3.8 6.7 2.2 Iris-virginica
120 9 7.54 2.6 6.9 2.3 Iris-virginica
121 4 5.74 2.2 5 1.5 Iris-virginica
122 7 6.82 3.2 5.7 2.3 Iris-virginica
123 3 5.38 2.8 4.9 2 Iris-virginica
124 9 7.54 2.8 6.7 2 Iris-virginica
125 5 6.1 2.7 4.9 1.8 Iris-virginica
126 6 6.46 3.3 5.7 2.1 Iris-virginica
127 8 7.18 3.2 6 1.8 Iris-virginica
128 5 6.1 2.8 4.8 1.8 Iris-virginica
129 5 6.1 3 4.9 1.8 Iris-virginica
130 5 6.1 2.8 5.6 2.1 Iris-virginica
131 8 7.18 3 5.8 1.6 Iris-virginica
132 8 7.18 2.8 6.1 1.9 Iris-virginica
133 9 7.9 3.8 6.4 2 Iris-virginica
134 5 6.1 2.8 5.6 2.2 Iris-virginica
135 5 6.1 2.8 5.1 1.5 Iris-virginica
136 5 6.1 2.6 5.6 1.4 Iris-virginica
137 9 7.54 3 6.1 2.3 Iris-virginica
138 5 6.1 3.4 5.6 2.4 Iris-virginica
139 5 6.1 3.1 5.5 1.8 Iris-virginica
140 4 5.74 3 4.8 1.8 Iris-virginica
141 7 6.82 3.1 5.4 2.1 Iris-virginica
142 6 6.46 3.1 5.6 2.4 Iris-virginica
143 7 6.82 3.1 5.1 2.3 Iris-virginica
144 4 5.74 2.7 5.1 1.9 Iris-virginica
145 6 6.46 3.2 5.9 2.3 Iris-virginica
146 6 6.46 3.3 5.7 2.5 Iris-virginica
147 6 6.46 3 5.2 2.3 Iris-virginica
148 5 6.1 2.5 5 1.9 Iris-virginica
149 6 6.46 3 5.2 2 Iris-virginica
150 5 6.1 3.4 5.4 2.3 Iris-virginica
151 4 5.74 3 5.1 1.8 Iris-virginica

View File

@ -34,7 +34,7 @@ func main() {
// If two decimal places isn't enough, you can update the
// Precision field on any FloatAttribute
if attr, ok := rawData.GetAttr(0).(*base.FloatAttribute); !ok {
if attr, ok := rawData.AllAttributes()[0].(*base.FloatAttribute); !ok {
panic("Invalid cast")
} else {
attr.Precision = 4
@ -44,8 +44,15 @@ func main() {
// We can update the set of Instances, although the API
// for doing so is not very sophisticated.
rawData.SetAttrStr(0, 0, "1.00")
rawData.SetAttrStr(0, rawData.ClassIndex, "Iris-unusual")
// First, have to resolve Attribute Specifications
as := base.ResolveAttributes(rawData, rawData.AllAttributes())
// Attribute Specifications describe where a given column lives
rawData.Set(as[0], 0, as[0].GetAttribute().GetSysValFromString("1.00"))
// A SetClass method exists as a shortcut
base.SetClass(rawData, 0, "Iris-unusual")
fmt.Println(rawData)
// There is a way of creating new Instances from scratch.
@ -64,6 +71,21 @@ func main() {
attrs[1].GetSysValFromString("A")
// Now let's create the final instances set
newInst := base.NewInstancesFromRaw(attrs, 1, newData)
newInst := base.NewDenseInstances()
// Add the attributes
newSpecs := make([]base.AttributeSpec, len(attrs))
for i, a := range attrs {
newSpecs[i] = newInst.AddAttribute(a)
}
// Allocate space
newInst.Extend(1)
// Write the data
newInst.Set(newSpecs[0], 0, newSpecs[0].GetAttribute().GetSysValFromString("1.0"))
newInst.Set(newSpecs[1], 0, newSpecs[1].GetAttribute().GetSysValFromString("A"))
fmt.Println(newInst)
}

View File

@ -12,7 +12,7 @@ func main() {
if err != nil {
panic(err)
}
rawData.Shuffle()
//Initialises a new KNN classifier
cls := knn.NewKnnClassifier("euclidean", 2)

View File

@ -5,15 +5,15 @@ package main
import (
"fmt"
base "github.com/sjwhitworth/golearn/base"
ensemble "github.com/sjwhitworth/golearn/ensemble"
eval "github.com/sjwhitworth/golearn/evaluation"
filters "github.com/sjwhitworth/golearn/filters"
ensemble "github.com/sjwhitworth/golearn/ensemble"
trees "github.com/sjwhitworth/golearn/trees"
"math/rand"
"time"
)
func main () {
func main() {
var tree base.Classifier
@ -27,12 +27,14 @@ func main () {
// Discretise the iris dataset with Chi-Merge
filt := filters.NewChiMergeFilter(iris, 0.99)
filt.AddAllNumericAttributes()
filt.Build()
filt.Run(iris)
for _, a := range base.NonClassFloatAttributes(iris) {
filt.AddAttribute(a)
}
filt.Train()
irisf := base.NewLazilyFilteredInstances(iris, filt)
// Create a 60-40 training-test split
trainData, testData := base.InstancesTrainTestSplit(iris, 0.60)
trainData, testData := base.InstancesTrainTestSplit(irisf, 0.60)
//
// First up, use ID3

151
filters/binary.go Normal file
View File

@ -0,0 +1,151 @@
package filters
import (
"fmt"
"github.com/sjwhitworth/golearn/base"
)
// BinaryConvertFilter converts a given DataGrid into one which
// only contains BinaryAttributes.
//
// FloatAttributes are discretised into either 0 (if the value is 0)
// or 1 (if the value is not 0).
//
// CategoricalAttributes are discretised into one or more new
// BinaryAttributes.
type BinaryConvertFilter struct {
	attrs     []base.Attribute         // Attributes queued for conversion via AddAttribute
	converted []base.FilteredAttribute // old/new Attribute pairs computed by Train
	// Categorical Attributes observed by Train to have at most two values
	twoValuedCategoricalAttributes map[base.Attribute]bool
	// For categorical Attributes with more than two values: maps each
	// value index to the BinaryAttribute generated for it by Train
	nValuedCategoricalAttributeMap map[base.Attribute]map[uint64]base.Attribute
}
// NewBinaryConvertFilter creates a blank BinaryConvertFilter with all
// internal storage initialised and no Attributes added.
func NewBinaryConvertFilter() *BinaryConvertFilter {
	return &BinaryConvertFilter{
		attrs:                          make([]base.Attribute, 0),
		converted:                      make([]base.FilteredAttribute, 0),
		twoValuedCategoricalAttributes: make(map[base.Attribute]bool),
		nValuedCategoricalAttributeMap: make(map[base.Attribute]map[uint64]base.Attribute),
	}
}
// AddAttribute adds a new Attribute to this Filter for conversion.
// It always succeeds and returns nil; the error return value
// presumably exists to satisfy a shared Filter interface — confirm
// against the interface declaration elsewhere in the package.
func (b *BinaryConvertFilter) AddAttribute(a base.Attribute) error {
	b.attrs = append(b.attrs, a)
	return nil
}
// GetAttributesAfterFiltering returns the before/after Attribute pairs
// previously computed via Train(). If Train() has not been called yet,
// the returned slice is empty.
func (b *BinaryConvertFilter) GetAttributesAfterFiltering() []base.FilteredAttribute {
	return b.converted
}
// String returns a human-readable summary of this Filter, reporting the
// number of Attributes added so far.
func (b *BinaryConvertFilter) String() string {
	return fmt.Sprintf("BinaryConvertFilter(%d Attribute(s))", len(b.attrs))
}
// Transform converts the given byte sequence using the old Attribute into the new
// byte sequence.
//
// If the old Attribute has a categorical value of at most two items, then a zero or
// non-zero byte sequence is returned.
//
// If the old Attribute has a categorical value of at most n-items, then a non-zero
// or zero byte sequence is returned based on the value of the new Attribute passed in.
//
// If the old Attribute is a float, its value's unpacked and we check for non-zeroness.
//
// If the old Attribute is a BinaryAttribute, just return the input.
func (b *BinaryConvertFilter) Transform(a base.Attribute, n base.Attribute, attrBytes []byte) []byte {
	ret := make([]byte, 1)
	// Check for CategoricalAttribute
	if _, ok := a.(*base.CategoricalAttribute); ok {
		// Unpack byte value
		val := base.UnpackBytesToU64(attrBytes)
		// If it's a two-valued one, check for non-zero
		if b.twoValuedCategoricalAttributes[a] {
			if val != 0 {
				ret[0] = 1
			} else {
				ret[0] = 0
			}
		} else if an, ok := b.nValuedCategoricalAttributeMap[a]; ok {
			// If it's an n-valued one, check the new Attribute maps onto
			// the unpacked value
			if af, ok := an[val]; ok {
				if af.Equals(n) {
					ret[0] = 1
				} else {
					ret[0] = 0
				}
			} else {
				panic("Categorical value not defined!")
			}
		} else {
			panic(fmt.Sprintf("Not a recognised Attribute %v", a))
		}
	} else if _, ok := a.(*base.BinaryAttribute); ok {
		// Binary: just return the original value
		ret = attrBytes
	} else if _, ok := a.(*base.FloatAttribute); ok {
		// Float: the documented contract is 0 -> 0, anything else -> 1.
		// BUG FIX: this previously tested `val > 0`, which mapped
		// negative values to 0 in contradiction with the contract;
		// use a non-zero test instead.
		val := base.UnpackBytesToFloat(attrBytes)
		if val != 0 {
			ret[0] = 1
		} else {
			ret[0] = 0
		}
	} else {
		panic(fmt.Sprintf("Unrecognised Attribute: %v", a))
	}
	return ret
}
// Train converts the FloatAttributes into equivalently named BinaryAttributes,
// leaves BinaryAttributes unmodified and processes
// CategoricalAttributes as follows.
//
// If the CategoricalAttribute has two values, one of them is
// designated 0 and the other 1, and a single identically-named
// binary Attribute is returned.
//
// If the CategoricalAttribute has more than two (n) values, the Filter
// generates n BinaryAttributes and sets each of them if the value's observed.
func (b *BinaryConvertFilter) Train() error {
	for _, attr := range b.attrs {
		switch a := attr.(type) {
		case *base.CategoricalAttribute:
			vals := a.GetValues()
			if len(vals) <= 2 {
				// Two (or fewer) values: a single identically-named binary Attribute
				replacement := base.NewBinaryAttribute(a.GetName())
				b.converted = append(b.converted, base.FilteredAttribute{a, replacement})
				b.twoValuedCategoricalAttributes[attr] = true
			} else {
				// More than two values: one binary Attribute per categorical value
				if _, ok := b.nValuedCategoricalAttributeMap[attr]; !ok {
					b.nValuedCategoricalAttributeMap[attr] = make(map[uint64]base.Attribute)
				}
				for i, v := range vals {
					newName := fmt.Sprintf("%s_%s", a.GetName(), v)
					newAttr := base.NewBinaryAttribute(newName)
					b.converted = append(b.converted, base.FilteredAttribute{a, newAttr})
					b.nValuedCategoricalAttributeMap[attr][uint64(i)] = newAttr
				}
			}
		case *base.BinaryAttribute:
			// Binary Attributes pass through unchanged
			b.converted = append(b.converted, base.FilteredAttribute{a, a})
		case *base.FloatAttribute:
			// Floats become identically-named binary Attributes
			newAttr := base.NewBinaryAttribute(a.GetName())
			b.converted = append(b.converted, base.FilteredAttribute{a, newAttr})
		default:
			return fmt.Errorf("Unsupported Attribute type: %v", attr)
		}
	}
	return nil
}

4
filters/binary_test.csv Normal file
View File

@ -0,0 +1,4 @@
floatAttr,shouldBe1Binary,shouldBe3Binary
1.0,true,stoicism
1.0,false,heroism
0.0,false,romanticism
1 floatAttr shouldBe1Binary shouldBe3Binary
2 1.0 true stoicism
3 1.0 false heroism
4 0.0 false romanticism

View File

@ -9,113 +9,114 @@ import (
// BinningFilter does equal-width binning for numeric
// Attributes (aka "histogram binning")
type BinningFilter struct {
Attributes []int
Instances *base.Instances
BinCount int
MinVals map[int]float64
MaxVals map[int]float64
trained bool
AbstractDiscretizeFilter
bins int
minVals map[base.Attribute]float64
maxVals map[base.Attribute]float64
}
// NewBinningFilter creates a BinningFilter structure
// with some helpful default initialisations.
func NewBinningFilter(inst *base.Instances, bins int) BinningFilter {
return BinningFilter{
make([]int, 0),
inst,
func NewBinningFilter(d base.FixedDataGrid, bins int) *BinningFilter {
return &BinningFilter{
AbstractDiscretizeFilter{
make(map[base.Attribute]bool),
false,
d,
},
bins,
make(map[int]float64),
make(map[int]float64),
false,
make(map[base.Attribute]float64),
make(map[base.Attribute]float64),
}
}
// AddAttribute adds the index of the given attribute `a'
// to the BinningFilter for discretisation.
func (b *BinningFilter) AddAttribute(a base.Attribute) {
attrIndex := b.Instances.GetAttrIndex(a)
if attrIndex == -1 {
panic("invalid attribute")
}
b.Attributes = append(b.Attributes, attrIndex)
func (b *BinningFilter) String() string {
return fmt.Sprintf("BinningFilter(%d Attribute(s), %d bin(s)", b.attrs, b.bins)
}
// AddAllNumericAttributes adds every suitable attribute
// to the BinningFilter for discretiation
func (b *BinningFilter) AddAllNumericAttributes() {
for i := 0; i < b.Instances.Cols; i++ {
if i == b.Instances.ClassIndex {
continue
}
attr := b.Instances.GetAttr(i)
if attr.GetType() != base.Float64Type {
continue
}
b.Attributes = append(b.Attributes, i)
}
}
// Build computes and stores the bin values
// Train computes and stores the bin values
// for the training instances.
func (b *BinningFilter) Build() {
for _, attr := range b.Attributes {
maxVal := math.Inf(-1)
minVal := math.Inf(1)
for i := 0; i < b.Instances.Rows; i++ {
val := b.Instances.Get(i, attr)
if val > maxVal {
maxVal = val
func (b *BinningFilter) Train() error {
as := b.getAttributeSpecs()
// Set up the AttributeSpecs, and values
for attr := range b.attrs {
if !b.attrs[attr] {
continue
}
b.minVals[attr] = float64(math.Inf(1))
b.maxVals[attr] = float64(math.Inf(-1))
}
err := b.train.MapOverRows(as, func(row [][]byte, rowNo int) (bool, error) {
for i, a := range row {
attr := as[i].GetAttribute()
attrf := attr.(*base.FloatAttribute)
val := float64(attrf.GetFloatFromSysVal(a))
if val > b.maxVals[attr] {
b.maxVals[attr] = val
}
if val < minVal {
minVal = val
if val < b.minVals[attr] {
b.minVals[attr] = val
}
}
b.MaxVals[attr] = maxVal
b.MinVals[attr] = minVal
b.trained = true
return true, nil
})
if err != nil {
return fmt.Errorf("Training error: %s", err)
}
b.trained = true
return nil
}
// Run applies a trained BinningFilter to a set of Instances,
// discretising any numeric attributes added.
//
// IMPORTANT: Run discretises in-place, so make sure to take
// a copy if the original instances are still needed
//
// IMPORTANT: This function panic()s if the filter has not been
// trained. Call Build() before running this function
//
// IMPORTANT: Call Build() after adding any additional attributes.
// Otherwise, the training structure will be out of date from
// the values expected and could cause a panic.
func (b *BinningFilter) Run(on *base.Instances) {
if !b.trained {
panic("Call Build() beforehand")
// Transform takes an Attribute and byte sequence and returns
// the transformed byte sequence.
func (b *BinningFilter) Transform(a base.Attribute, n base.Attribute, field []byte) []byte {
if !b.attrs[a] {
return field
}
for attr := range b.Attributes {
minVal := b.MinVals[attr]
maxVal := b.MaxVals[attr]
disc := 0
// Casts to float32 to replicate a floating point precision error
delta := float32(maxVal - minVal)
delta /= float32(b.BinCount)
for i := 0; i < on.Rows; i++ {
val := on.Get(i, attr)
if val <= minVal {
disc = 0
} else {
disc = int(math.Floor(float64(float32(val-minVal) / delta)))
if disc >= b.BinCount {
disc = b.BinCount - 1
}
}
on.Set(i, attr, float64(disc))
}
newAttribute := new(base.CategoricalAttribute)
newAttribute.SetName(on.GetAttr(attr).GetName())
for i := 0; i < b.BinCount; i++ {
newAttribute.GetSysValFromString(fmt.Sprintf("%d", i))
}
on.ReplaceAttr(attr, newAttribute)
af, ok := a.(*base.FloatAttribute)
if !ok {
panic("Attribute is the wrong type")
}
minVal := b.minVals[a]
maxVal := b.maxVals[a]
disc := 0
// Casts to float64 to replicate a floating point precision error
delta := float64(maxVal-minVal) / float64(b.bins)
val := float64(af.GetFloatFromSysVal(field))
if val <= minVal {
disc = 0
} else {
disc = int(math.Floor(float64(float64(val-minVal)/delta + 0.0001)))
}
return base.PackU64ToBytes(uint64(disc))
}
// GetAttributesAfterFiltering gets a list of before/after
// Attributes as base.FilteredAttributes: each Attribute selected
// for binning is paired with a CategoricalAttribute whose values
// label the bin boundaries, and every other Attribute is paired
// with itself.
func (b *BinningFilter) GetAttributesAfterFiltering() []base.FilteredAttribute {
	oldAttrs := b.train.AllAttributes()
	ret := make([]base.FilteredAttribute, len(oldAttrs))
	for attrIndex, attr := range oldAttrs {
		if !b.attrs[attr] {
			// Not selected for discretisation: pass through unchanged
			ret[attrIndex] = base.FilteredAttribute{attr, attr}
			continue
		}
		// Width of each bin for this Attribute
		delta := (b.maxVals[attr] - b.minVals[attr]) / float64(b.bins)
		// Format string matching the original Attribute's precision
		fmtStr := fmt.Sprintf("%%.%df", attr.(*base.FloatAttribute).Precision)
		replacement := new(base.CategoricalAttribute)
		replacement.SetName(attr.GetName())
		// Register one categorical value per bin boundary
		for bin := 0; bin <= b.bins; bin++ {
			boundary := float64(bin)*delta + b.minVals[attr]
			replacement.GetSysValFromString(fmt.Sprintf(fmtStr, boundary))
		}
		ret[attrIndex] = base.FilteredAttribute{attr, replacement}
	}
	return ret
}

View File

@ -2,27 +2,55 @@ package filters
import (
base "github.com/sjwhitworth/golearn/base"
"math"
. "github.com/smartystreets/goconvey/convey"
"testing"
)
func TestBinning(testEnv *testing.T) {
inst1, err := base.ParseCSVToInstances("../examples/datasets/iris_headers.csv", true)
inst2, err := base.ParseCSVToInstances("../examples/datasets/iris_binned.csv", true)
inst3, err := base.ParseCSVToInstances("../examples/datasets/iris_headers.csv", true)
if err != nil {
panic(err)
}
filt := NewBinningFilter(inst1, 10)
filt.AddAttribute(inst1.GetAttr(0))
filt.Build()
filt.Run(inst1)
for i := 0; i < inst1.Rows; i++ {
val1 := inst1.Get(i, 0)
val2 := inst2.Get(i, 0)
val3 := inst3.Get(i, 0)
if math.Abs(val1-val2) >= 1 {
testEnv.Error(val1, val2, val3, i)
func TestBinning(t *testing.T) {
Convey("Given some data and a reference", t, func() {
// Read the data
inst1, err := base.ParseCSVToInstances("../examples/datasets/iris_headers.csv", true)
if err != nil {
panic(err)
}
}
inst2, err := base.ParseCSVToInstances("../examples/datasets/iris_binned.csv", true)
if err != nil {
panic(err)
}
//
// Construct the binning filter
binAttr := inst1.AllAttributes()[0]
filt := NewBinningFilter(inst1, 10)
filt.AddAttribute(binAttr)
filt.Train()
inst1f := base.NewLazilyFilteredInstances(inst1, filt)
// Retrieve the categorical version of the original Attribute
var cAttr base.Attribute
for _, a := range inst1f.AllAttributes() {
if a.GetName() == binAttr.GetName() {
cAttr = a
}
}
cAttrSpec, err := inst1f.GetAttribute(cAttr)
So(err, ShouldEqual, nil)
binAttrSpec, err := inst2.GetAttribute(binAttr)
So(err, ShouldEqual, nil)
//
// Create the LazilyFilteredInstances
// and check the values
Convey("Discretized version should match reference", func() {
_, rows := inst1.Size()
for i := 0; i < rows; i++ {
val1 := inst1f.Get(cAttrSpec, i)
val2 := inst2.Get(binAttrSpec, i)
val1s := cAttr.GetStringFromSysVal(val1)
val2s := binAttr.GetStringFromSysVal(val2)
So(val1s, ShouldEqual, val2s)
}
})
})
}

View File

@ -12,301 +12,30 @@ import (
// See Bramer, "Principles of Data Mining", 2nd Edition
// pp 105--115
type ChiMergeFilter struct {
Attributes []int
Instances *base.Instances
Tables map[int][]*FrequencyTableEntry
AbstractDiscretizeFilter
tables map[base.Attribute][]*FrequencyTableEntry
Significance float64
MinRows int
MaxRows int
_Trained bool
}
// NewChiMergeFilter creates a ChiMergeFilter with some helpful initialisations.
func NewChiMergeFilter(inst *base.Instances, significance float64) ChiMergeFilter {
return ChiMergeFilter{
make([]int, 0),
inst,
make(map[int][]*FrequencyTableEntry),
// NewChiMergeFilter creates a ChiMergeFilter with some helpful initialisations.
func NewChiMergeFilter(d base.FixedDataGrid, significance float64) *ChiMergeFilter {
_, rows := d.Size()
return &ChiMergeFilter{
AbstractDiscretizeFilter{
make(map[base.Attribute]bool),
false,
d,
},
make(map[base.Attribute][]*FrequencyTableEntry),
significance,
0,
0,
false,
}
}
// Build trains a ChiMergeFilter on the ChiMergeFilter.Instances given
func (c *ChiMergeFilter) Build() {
for _, attr := range c.Attributes {
tab := chiMerge(c.Instances, attr, c.Significance, c.MinRows, c.MaxRows)
c.Tables[attr] = tab
c._Trained = true
}
}
// AddAllNumericAttributes adds every suitable attribute
// to the ChiMergeFilter for discretisation
func (c *ChiMergeFilter) AddAllNumericAttributes() {
for i := 0; i < c.Instances.Cols; i++ {
if i == c.Instances.ClassIndex {
continue
}
attr := c.Instances.GetAttr(i)
if attr.GetType() != base.Float64Type {
continue
}
c.Attributes = append(c.Attributes, i)
}
}
// Run discretises the set of Instances `on'
//
// IMPORTANT: ChiMergeFilter discretises in place.
func (c *ChiMergeFilter) Run(on *base.Instances) {
if !c._Trained {
panic("Call Build() beforehand")
}
for attr := range c.Tables {
table := c.Tables[attr]
for i := 0; i < on.Rows; i++ {
val := on.Get(i, attr)
dis := 0
for j, k := range table {
if k.Value < val {
dis = j
continue
}
break
}
on.Set(i, attr, float64(dis))
}
newAttribute := new(base.CategoricalAttribute)
newAttribute.SetName(on.GetAttr(attr).GetName())
for _, k := range table {
newAttribute.GetSysValFromString(fmt.Sprintf("%f", k.Value))
}
on.ReplaceAttr(attr, newAttribute)
}
}
// AddAttribute add a given numeric Attribute `attr' to the
// filter.
//
// IMPORTANT: This function panic()s if it can't locate the
// attribute in the Instances set.
func (c *ChiMergeFilter) AddAttribute(attr base.Attribute) {
if attr.GetType() != base.Float64Type {
panic("ChiMerge only works on Float64Attributes")
}
attrIndex := c.Instances.GetAttrIndex(attr)
if attrIndex == -1 {
panic("Invalid attribute!")
}
c.Attributes = append(c.Attributes, attrIndex)
}
type FrequencyTableEntry struct {
Value float64
Frequency map[string]int
}
func (t *FrequencyTableEntry) String() string {
return fmt.Sprintf("%.2f %v", t.Value, t.Frequency)
}
func ChiMBuildFrequencyTable(attr int, inst *base.Instances) []*FrequencyTableEntry {
ret := make([]*FrequencyTableEntry, 0)
var attribute *base.FloatAttribute
attribute, ok := inst.GetAttr(attr).(*base.FloatAttribute)
if !ok {
panic("only use Chi-M on numeric stuff")
}
for i := 0; i < inst.Rows; i++ {
value := inst.Get(i, attr)
valueConv := attribute.GetUsrVal(value)
class := inst.GetClass(i)
// Search the frequency table for the value
found := false
for _, entry := range ret {
if entry.Value == valueConv {
found = true
entry.Frequency[class]++
}
}
if !found {
newEntry := &FrequencyTableEntry{
valueConv,
make(map[string]int),
}
newEntry.Frequency[class] = 1
ret = append(ret, newEntry)
}
}
return ret
}
func chiSquaredPdf(k float64, x float64) float64 {
if x < 0 {
return 0
}
top := math.Pow(x, (k/2)-1) * math.Exp(-x/2)
bottom := math.Pow(2, k/2) * math.Gamma(k/2)
return top / bottom
}
func chiSquaredPercentile(k int, x float64) float64 {
// Implements Yahya et al.'s "A Numerical Procedure
// for Computing Chi-Square Percentage Points"
// InterStat Journal 01/2007; April 25:page:1-8.
steps := 32
intervals := 4 * steps
w := x / (4.0 * float64(steps))
values := make([]float64, intervals+1)
for i := 0; i < intervals+1; i++ {
c := w * float64(i)
v := chiSquaredPdf(float64(k), c)
values[i] = v
}
ret1 := values[0] + values[len(values)-1]
ret2 := 0.0
ret3 := 0.0
ret4 := 0.0
for i := 2; i < intervals-1; i += 4 {
ret2 += values[i]
}
for i := 4; i < intervals-3; i += 4 {
ret3 += values[i]
}
for i := 1; i < intervals; i += 2 {
ret4 += values[i]
}
return (2.0 * w / 45) * (7*ret1 + 12*ret2 + 14*ret3 + 32*ret4)
}
func chiCountClasses(entries []*FrequencyTableEntry) map[string]int {
classCounter := make(map[string]int)
for _, e := range entries {
for k := range e.Frequency {
classCounter[k] += e.Frequency[k]
}
}
return classCounter
}
func chiComputeStatistic(entry1 *FrequencyTableEntry, entry2 *FrequencyTableEntry) float64 {
// Sum the number of things observed per class
classCounter := make(map[string]int)
for k := range entry1.Frequency {
classCounter[k] += entry1.Frequency[k]
}
for k := range entry2.Frequency {
classCounter[k] += entry2.Frequency[k]
}
// Sum the number of things observed per value
entryObservations1 := 0
entryObservations2 := 0
for k := range entry1.Frequency {
entryObservations1 += entry1.Frequency[k]
}
for k := range entry2.Frequency {
entryObservations2 += entry2.Frequency[k]
}
totalObservations := entryObservations1 + entryObservations2
// Compute the expected values per class
expectedClassValues1 := make(map[string]float64)
expectedClassValues2 := make(map[string]float64)
for k := range classCounter {
expectedClassValues1[k] = float64(classCounter[k])
expectedClassValues1[k] *= float64(entryObservations1)
expectedClassValues1[k] /= float64(totalObservations)
}
for k := range classCounter {
expectedClassValues2[k] = float64(classCounter[k])
expectedClassValues2[k] *= float64(entryObservations2)
expectedClassValues2[k] /= float64(totalObservations)
}
// Compute chi-squared value
chiSum := 0.0
for k := range expectedClassValues1 {
numerator := float64(entry1.Frequency[k])
numerator -= expectedClassValues1[k]
numerator = math.Pow(numerator, 2)
denominator := float64(expectedClassValues1[k])
if denominator < 0.5 {
denominator = 0.5
}
chiSum += numerator / denominator
}
for k := range expectedClassValues2 {
numerator := float64(entry2.Frequency[k])
numerator -= expectedClassValues2[k]
numerator = math.Pow(numerator, 2)
denominator := float64(expectedClassValues2[k])
if denominator < 0.5 {
denominator = 0.5
}
chiSum += numerator / denominator
}
return chiSum
}
func chiMergeMergeZipAdjacent(freq []*FrequencyTableEntry, minIndex int) []*FrequencyTableEntry {
mergeEntry1 := freq[minIndex]
mergeEntry2 := freq[minIndex+1]
classCounter := make(map[string]int)
for k := range mergeEntry1.Frequency {
classCounter[k] += mergeEntry1.Frequency[k]
}
for k := range mergeEntry2.Frequency {
classCounter[k] += mergeEntry2.Frequency[k]
}
newVal := freq[minIndex].Value
newEntry := &FrequencyTableEntry{
newVal,
classCounter,
}
lowerSlice := freq
upperSlice := freq
if minIndex > 0 {
lowerSlice = freq[0:minIndex]
upperSlice = freq[minIndex+1:]
} else {
lowerSlice = make([]*FrequencyTableEntry, 0)
upperSlice = freq[1:]
}
upperSlice[0] = newEntry
freq = append(lowerSlice, upperSlice...)
return freq
}
func chiMergePrintTable(freq []*FrequencyTableEntry) {
classes := chiCountClasses(freq)
fmt.Printf("Attribute value\t")
for k := range classes {
fmt.Printf("\t%s", k)
}
fmt.Printf("\tTotal\n")
for _, f := range freq {
fmt.Printf("%.2f\t", f.Value)
total := 0
for k := range classes {
fmt.Printf("\t%d", f.Frequency[k])
total += f.Frequency[k]
}
fmt.Printf("\t%d\n", total)
2,
rows,
}
}
// Train computes and stores the
// Produces a value mapping table
// inst: The base.Instances which need discretising
// sig: The significance level (e.g. 0.95)
@ -316,7 +45,7 @@ func chiMergePrintTable(freq []*FrequencyTableEntry) {
// adjacent rows will be merged
// precision: internal number of decimal places to round E value to
// (useful for verification)
func chiMerge(inst *base.Instances, attr int, sig float64, minrows int, maxrows int) []*FrequencyTableEntry {
func chiMerge(inst base.FixedDataGrid, attr base.Attribute, sig float64, minrows int, maxrows int) []*FrequencyTableEntry {
// Parameter sanity checking
if !(2 <= minrows) {
@ -329,12 +58,17 @@ func chiMerge(inst *base.Instances, attr int, sig float64, minrows int, maxrows
sig = 10
}
// Check that the attribute is numeric
_, ok := attr.(*base.FloatAttribute)
if !ok {
panic("only use Chi-M on numeric stuff")
}
// Build a frequency table
freq := ChiMBuildFrequencyTable(attr, inst)
// Count the number of classes
classes := chiCountClasses(freq)
for {
// chiMergePrintTable(freq) DEBUG
if len(freq) <= minrows {
break
}
@ -378,3 +112,77 @@ func chiMerge(inst *base.Instances, attr int, sig float64, minrows int, maxrows
}
return freq
}
// Train computes the ChiMerge frequency table for every Attribute
// added via AddAttribute, using the training grid this filter was
// constructed with. It must be called before the filter is used to
// transform values.
func (c *ChiMergeFilter) Train() error {
	as := c.getAttributeSpecs()
	for _, a := range as {
		attr := a.GetAttribute()
		// Skip if not set
		if !c.attrs[attr] {
			continue
		}
		// ChiMerge requires the values in ascending order, so
		// build a sort order over this Attribute alone
		sortOrder := []base.AttributeSpec{a}
		// Sort (lazily, to avoid copying the training data)
		sorted, err := base.LazySort(c.train, base.Ascending, sortOrder)
		if err != nil {
			// Return the failure instead of panicking: Train already
			// has an error return for callers to check
			return fmt.Errorf("sorting error: %s", err)
		}
		// Perform ChiMerge and store the resulting table
		freq := chiMerge(sorted, attr, c.Significance, c.MinRows, c.MaxRows)
		c.tables[attr] = freq
	}
	// Mark the filter trained, consistent with BinningFilter
	c.trained = true
	return nil
}
// GetAttributesAfterFiltering gets a list of before/after
// Attributes as base.FilteredAttributes: each discretised
// Attribute is replaced by a CategoricalAttribute whose values
// are the interval boundaries computed during Train(), and every
// other Attribute maps to itself.
func (c *ChiMergeFilter) GetAttributesAfterFiltering() []base.FilteredAttribute {
	oldAttrs := c.train.AllAttributes()
	ret := make([]base.FilteredAttribute, len(oldAttrs))
	for i, attr := range oldAttrs {
		if !c.attrs[attr] {
			// Not discretised: pass through unchanged
			ret[i] = base.FilteredAttribute{attr, attr}
			continue
		}
		replacement := new(base.CategoricalAttribute)
		replacement.SetName(attr.GetName())
		// One categorical value per merged interval
		for _, entry := range c.tables[attr] {
			replacement.GetSysValFromString(fmt.Sprintf("%f", entry.Value))
		}
		ret[i] = base.FilteredAttribute{attr, replacement}
	}
	return ret
}
// Transform returns the byte sequence after discretisation:
// `field' (the system representation of Attribute `a') is replaced
// by the index of the ChiMerge interval the value falls into,
// packed as a uint64. Attributes not selected for discretisation
// pass through unchanged. The replacement Attribute `n' is unused.
func (c *ChiMergeFilter) Transform(a base.Attribute, n base.Attribute, field []byte) []byte {
	// Do we use this Attribute?
	if !c.attrs[a] {
		return field
	}
	// Find the Attribute value in the table computed by Train()
	table := c.tables[a]
	dis := 0
	val := base.UnpackBytesToFloat(field)
	// Walk the table (values ascend, since Train sorts first):
	// dis ends up as the index of the last entry whose Value is
	// strictly below val, or 0 if there is no such entry.
	for j, k := range table {
		if k.Value < val {
			dis = j
			continue
		}
		break
	}
	return base.PackU64ToBytes(uint64(dis))
}
// String returns a human-readable summary of this ChiMergeFilter.
func (c *ChiMergeFilter) String() string {
	const format = "ChiMergeFilter(%d Attributes, %.2f Significance)"
	return fmt.Sprintf(format, len(c.tables), c.Significance)
}

14
filters/chimerge_freq.go Normal file
View File

@ -0,0 +1,14 @@
package filters
import (
"fmt"
)
// FrequencyTableEntry records the class distribution observed
// for one distinct continuous Attribute value.
type FrequencyTableEntry struct {
	// Value is the continuous value this entry describes
	Value float64
	// Frequency counts each observed class label at this value
	Frequency map[string]int
}

// String returns a human-readable summary of this entry.
func (t *FrequencyTableEntry) String() string {
	// %v (not %s): maps have no String method, so %s would render
	// each count as "%!s(int=...)"
	return fmt.Sprintf("%.2f %v", t.Value, t.Frequency)
}

205
filters/chimerge_funcs.go Normal file
View File

@ -0,0 +1,205 @@
package filters
import (
"github.com/sjwhitworth/golearn/base"
"fmt"
"math"
)
// ChiMBuildFrequencyTable scans the given Attribute in inst and
// builds a frequency table: one entry per distinct continuous value,
// recording how often each class label occurs at that value.
//
// Panics if attr is not a FloatAttribute, if attr cannot be resolved
// against inst, or if row iteration fails.
func ChiMBuildFrequencyTable(attr base.Attribute, inst base.FixedDataGrid) []*FrequencyTableEntry {
	ret := make([]*FrequencyTableEntry, 0)
	// ChiMerge is only defined over continuous attributes
	attribute, ok := attr.(*base.FloatAttribute)
	if !ok {
		panic("ChiMBuildFrequencyTable: attr must be a FloatAttribute")
	}
	attrSpec, err := inst.GetAttribute(attr)
	if err != nil {
		panic(err)
	}
	attrSpecs := []base.AttributeSpec{attrSpec}
	err = inst.MapOverRows(attrSpecs, func(row [][]byte, rowNo int) (bool, error) {
		value := row[0]
		valueConv := attribute.GetFloatFromSysVal(value)
		class := base.GetClass(inst, rowNo)
		// Linear search for an existing entry with this value
		found := false
		for _, entry := range ret {
			if entry.Value == valueConv {
				found = true
				entry.Frequency[class]++
			}
		}
		if !found {
			newEntry := &FrequencyTableEntry{
				valueConv,
				make(map[string]int),
			}
			newEntry.Frequency[class] = 1
			ret = append(ret, newEntry)
		}
		return true, nil
	})
	if err != nil {
		// Previously ignored: an iteration failure would have
		// silently produced a truncated table
		panic(err)
	}
	return ret
}
// chiSquaredPdf evaluates the probability density function of the
// chi-squared distribution with k degrees of freedom at x.
// Returns 0 for negative x (the density's support is [0, inf)).
func chiSquaredPdf(k float64, x float64) float64 {
	if x < 0 {
		return 0
	}
	numerator := math.Exp(-x/2) * math.Pow(x, k/2-1)
	denominator := math.Gamma(k/2) * math.Pow(2, k/2)
	return numerator / denominator
}
// chiSquaredPercentile approximates the CDF of the chi-squared
// distribution with k degrees of freedom at x, i.e. P(X <= x),
// by numerically integrating the density with a composite
// Boole's rule over 128 sub-intervals.
// Implements Yahya et al.'s "A Numerical Procedure for Computing
// Chi-Square Percentage Points", InterStat Journal, April 2007.
func chiSquaredPercentile(k int, x float64) float64 {
	const steps = 32
	intervals := 4 * steps
	h := x / (4.0 * float64(steps))
	// Sample the density at each quadrature node
	samples := make([]float64, intervals+1)
	for i := range samples {
		samples[i] = chiSquaredPdf(float64(k), h*float64(i))
	}
	// Accumulate the four weight groups of the composite rule
	endpoints := samples[0] + samples[len(samples)-1]
	group12 := 0.0
	group14 := 0.0
	group32 := 0.0
	for i := 2; i < intervals-1; i += 4 {
		group12 += samples[i]
	}
	for i := 4; i < intervals-3; i += 4 {
		group14 += samples[i]
	}
	for i := 1; i < intervals; i += 2 {
		group32 += samples[i]
	}
	return (2.0 * h / 45) * (7*endpoints + 12*group12 + 14*group14 + 32*group32)
}
// chiCountClasses totals the per-class observation counts across
// every entry of the given frequency table.
func chiCountClasses(entries []*FrequencyTableEntry) map[string]int {
	totals := make(map[string]int)
	for _, entry := range entries {
		for class, count := range entry.Frequency {
			totals[class] += count
		}
	}
	return totals
}
// chiComputeStatistic computes the chi-squared test statistic for a
// pair of adjacent frequency-table entries: how far the observed
// class counts of each entry deviate from what would be expected if
// both entries shared the same class distribution. Used by ChiMerge
// to decide whether two adjacent intervals may be merged.
func chiComputeStatistic(entry1 *FrequencyTableEntry, entry2 *FrequencyTableEntry) float64 {
	// Sum the number of things observed per class
	classCounter := make(map[string]int)
	for k := range entry1.Frequency {
		classCounter[k] += entry1.Frequency[k]
	}
	for k := range entry2.Frequency {
		classCounter[k] += entry2.Frequency[k]
	}
	// Sum the number of things observed per value
	entryObservations1 := 0
	entryObservations2 := 0
	for k := range entry1.Frequency {
		entryObservations1 += entry1.Frequency[k]
	}
	for k := range entry2.Frequency {
		entryObservations2 += entry2.Frequency[k]
	}
	totalObservations := entryObservations1 + entryObservations2
	// Compute the expected values per class:
	// (row total * class total) / grand total
	expectedClassValues1 := make(map[string]float64)
	expectedClassValues2 := make(map[string]float64)
	for k := range classCounter {
		expectedClassValues1[k] = float64(classCounter[k])
		expectedClassValues1[k] *= float64(entryObservations1)
		expectedClassValues1[k] /= float64(totalObservations)
	}
	for k := range classCounter {
		expectedClassValues2[k] = float64(classCounter[k])
		expectedClassValues2[k] *= float64(entryObservations2)
		expectedClassValues2[k] /= float64(totalObservations)
	}
	// Compute chi-squared value: sum of (observed - expected)^2 / expected
	chiSum := 0.0
	for k := range expectedClassValues1 {
		numerator := float64(entry1.Frequency[k])
		numerator -= expectedClassValues1[k]
		numerator = math.Pow(numerator, 2)
		denominator := float64(expectedClassValues1[k])
		// Clamp small expectations to 0.5 to avoid dividing by
		// (near-)zero when a class is rare in this pair
		if denominator < 0.5 {
			denominator = 0.5
		}
		chiSum += numerator / denominator
	}
	for k := range expectedClassValues2 {
		numerator := float64(entry2.Frequency[k])
		numerator -= expectedClassValues2[k]
		numerator = math.Pow(numerator, 2)
		denominator := float64(expectedClassValues2[k])
		if denominator < 0.5 {
			denominator = 0.5
		}
		chiSum += numerator / denominator
	}
	return chiSum
}
// chiMergeMergeZipAdjacent merges the frequency-table entries at
// minIndex and minIndex+1 into a single entry (summing their class
// counts and keeping the lower entry's Value) and returns the
// table, now one entry shorter.
//
// NOTE(review): the merged entry is written into the slice before
// the append, so the input slice's backing array is mutated in
// place; callers must use only the returned slice afterwards.
func chiMergeMergeZipAdjacent(freq []*FrequencyTableEntry, minIndex int) []*FrequencyTableEntry {
	mergeEntry1 := freq[minIndex]
	mergeEntry2 := freq[minIndex+1]
	// Combine the class counts of the two entries being merged
	classCounter := make(map[string]int)
	for k := range mergeEntry1.Frequency {
		classCounter[k] += mergeEntry1.Frequency[k]
	}
	for k := range mergeEntry2.Frequency {
		classCounter[k] += mergeEntry2.Frequency[k]
	}
	// The merged interval keeps the lower boundary value
	newVal := freq[minIndex].Value
	newEntry := &FrequencyTableEntry{
		newVal,
		classCounter,
	}
	// Split the table around the merged pair
	lowerSlice := freq
	upperSlice := freq
	if minIndex > 0 {
		lowerSlice = freq[0:minIndex]
		upperSlice = freq[minIndex+1:]
	} else {
		lowerSlice = make([]*FrequencyTableEntry, 0)
		upperSlice = freq[1:]
	}
	// Overwrite the first upper entry with the merged entry,
	// then rejoin the two halves
	upperSlice[0] = newEntry
	freq = append(lowerSlice, upperSlice...)
	return freq
}
// chiMergePrintTable writes a human-readable rendering of the
// frequency table to standard output (debugging aid). Column order
// follows Go's randomised map iteration and is not deterministic.
func chiMergePrintTable(freq []*FrequencyTableEntry) {
	classes := chiCountClasses(freq)
	// Header row: one column per class, plus a total column
	fmt.Printf("Attribute value\t")
	for class := range classes {
		fmt.Printf("\t%s", class)
	}
	fmt.Printf("\tTotal\n")
	// One body row per distinct attribute value
	for _, entry := range freq {
		fmt.Printf("%.2f\t", entry.Value)
		rowTotal := 0
		for class := range classes {
			fmt.Printf("\t%d", entry.Frequency[class])
			rowTotal += entry.Frequency[class]
		}
		fmt.Printf("\t%d\n", rowTotal)
	}
}

View File

@ -14,7 +14,7 @@ func TestChiMFreqTable(testEnv *testing.T) {
panic(err)
}
freq := ChiMBuildFrequencyTable(0, inst)
freq := ChiMBuildFrequencyTable(inst.AllAttributes()[0], inst)
if freq[0].Frequency["c1"] != 1 {
testEnv.Error("Wrong frequency")
@ -32,7 +32,7 @@ func TestChiClassCounter(testEnv *testing.T) {
if err != nil {
panic(err)
}
freq := ChiMBuildFrequencyTable(0, inst)
freq := ChiMBuildFrequencyTable(inst.AllAttributes()[0], inst)
classes := chiCountClasses(freq)
if classes["c1"] != 27 {
testEnv.Error(classes)
@ -50,7 +50,7 @@ func TestStatisticValues(testEnv *testing.T) {
if err != nil {
panic(err)
}
freq := ChiMBuildFrequencyTable(0, inst)
freq := ChiMBuildFrequencyTable(inst.AllAttributes()[0], inst)
chiVal := chiComputeStatistic(freq[5], freq[6])
if math.Abs(chiVal-1.89) > 0.01 {
testEnv.Error(chiVal)
@ -78,12 +78,15 @@ func TestChiSquareDistValues(testEnv *testing.T) {
}
func TestChiMerge1(testEnv *testing.T) {
// See Bramer, Principles of Machine Learning
// Read the data
inst, err := base.ParseCSVToInstances("../examples/datasets/chim.csv", true)
if err != nil {
panic(err)
}
freq := chiMerge(inst, 0, 0.90, 0, inst.Rows)
_, rows := inst.Size()
freq := chiMerge(inst, inst.AllAttributes()[0], 0.90, 0, rows)
if len(freq) != 3 {
testEnv.Error("Wrong length")
}
@ -106,10 +109,18 @@ func TestChiMerge2(testEnv *testing.T) {
if err != nil {
panic(err)
}
attrs := make([]int, 1)
attrs[0] = 0
inst.Sort(base.Ascending, attrs)
freq := chiMerge(inst, 0, 0.90, 0, inst.Rows)
// Sort the instances
allAttrs := inst.AllAttributes()
sortAttrSpecs := base.ResolveAttributes(inst, allAttrs)[0:1]
instSorted, err := base.Sort(inst, base.Ascending, sortAttrSpecs)
if err != nil {
panic(err)
}
// Perform Chi-Merge
_, rows := inst.Size()
freq := chiMerge(instSorted, allAttrs[0], 0.90, 0, rows)
if len(freq) != 5 {
testEnv.Errorf("Wrong length (%d)", len(freq))
testEnv.Error(freq)
@ -131,6 +142,7 @@ func TestChiMerge2(testEnv *testing.T) {
}
}
/*
func TestChiMerge3(testEnv *testing.T) {
// See http://sci2s.ugr.es/keel/pdf/algorithm/congreso/1992-Kerber-ChimErge-AAAI92.pdf
// Randy Kerber, ChiMerge: Discretisation of Numeric Attributes, 1992
@ -138,12 +150,52 @@ func TestChiMerge3(testEnv *testing.T) {
if err != nil {
panic(err)
}
attrs := make([]int, 1)
attrs[0] = 0
inst.Sort(base.Ascending, attrs)
insts, err := base.LazySort(inst, base.Ascending, base.ResolveAllAttributes(inst, inst.AllAttributes()))
if err != nil {
testEnv.Error(err)
}
filt := NewChiMergeFilter(inst, 0.90)
filt.AddAttribute(inst.GetAttr(0))
filt.Build()
filt.Run(inst)
fmt.Println(inst)
filt.AddAttribute(inst.AllAttributes()[0])
filt.Train()
instf := base.NewLazilyFilteredInstances(insts, filt)
fmt.Println(instf)
fmt.Println(instf.String())
rowStr := instf.RowString(0)
ref := "4.300000 3.00 1.10 0.10 Iris-setosa"
if rowStr != ref {
panic(fmt.Sprintf("'%s' != '%s'", rowStr, ref))
}
clsAttrs := instf.AllClassAttributes()
if len(clsAttrs) != 1 {
panic(fmt.Sprintf("%d != %d", len(clsAttrs), 1))
}
if clsAttrs[0].GetName() != "Species" {
panic("Class Attribute wrong!")
}
}
*/
// TestChiMerge4 trains a ChiMergeFilter on the first two Attributes
// of the iris dataset, applies it lazily, and checks that exactly
// one class Attribute ("Species") survives filtering untouched.
func TestChiMerge4(testEnv *testing.T) {
	// See http://sci2s.ugr.es/keel/pdf/algorithm/congreso/1992-Kerber-ChimErge-AAAI92.pdf
	// Randy Kerber, ChiMerge: Discretisation of Numeric Attributes, 1992
	inst, err := base.ParseCSVToInstances("../examples/datasets/iris_headers.csv", true)
	if err != nil {
		panic(err)
	}
	// Discretise the first two (numeric) Attributes at 0.90 significance
	filt := NewChiMergeFilter(inst, 0.90)
	filt.AddAttribute(inst.AllAttributes()[0])
	filt.AddAttribute(inst.AllAttributes()[1])
	filt.Train()
	// Wrap the original data in a lazily-filtered view
	instf := base.NewLazilyFilteredInstances(inst, filt)
	fmt.Println(instf)
	fmt.Println(instf.String())
	// The class Attribute must survive filtering unchanged
	clsAttrs := instf.AllClassAttributes()
	if len(clsAttrs) != 1 {
		panic(fmt.Sprintf("%d != %d", len(clsAttrs), 1))
	}
	if clsAttrs[0].GetName() != "Species" {
		panic("Class Attribute wrong!")
	}
}

62
filters/disc.go Normal file
View File

@ -0,0 +1,62 @@
package filters
import (
"fmt"
base "github.com/sjwhitworth/golearn/base"
)
// AbstractDiscretizeFilter holds the state shared by the
// discretisation filters in this package: which Attributes to
// discretise and the data grid they were trained against.
type AbstractDiscretizeFilter struct {
	// attrs records which Attributes are selected for
	// discretisation (true = selected)
	attrs map[base.Attribute]bool
	// trained is intended to be set once training completes
	trained bool
	// train is the FixedDataGrid the filter was constructed with
	train base.FixedDataGrid
}
// AddAttribute selects the given Attribute `a' for discretisation.
// Returns an error if `a' is not a FloatAttribute or cannot be
// resolved against the training set.
func (d *AbstractDiscretizeFilter) AddAttribute(a base.Attribute) error {
	if _, ok := a.(*base.FloatAttribute); !ok {
		return fmt.Errorf("%s is not a FloatAttribute", a)
	}
	// Verify the Attribute actually exists in the training grid
	if _, err := d.train.GetAttribute(a); err != nil {
		// Include the underlying cause instead of discarding it
		return fmt.Errorf("invalid attribute: %s", err)
	}
	d.attrs[a] = true
	return nil
}
// GetAttributesAfterFiltering gets a list of before/after
// Attributes as base.FilteredAttributes: each selected Attribute
// maps to a fresh CategoricalAttribute of the same name, and every
// other Attribute maps to itself.
func (d *AbstractDiscretizeFilter) GetAttributesAfterFiltering() []base.FilteredAttribute {
	oldAttrs := d.train.AllAttributes()
	ret := make([]base.FilteredAttribute, len(oldAttrs))
	for i, attr := range oldAttrs {
		if !d.attrs[attr] {
			// Unselected Attributes pass through unchanged
			ret[i] = base.FilteredAttribute{attr, attr}
			continue
		}
		// Selected Attributes become categorical under the same name
		replacement := new(base.CategoricalAttribute)
		replacement.SetName(attr.GetName())
		ret[i] = base.FilteredAttribute{attr, replacement}
	}
	return ret
}
// getAttributeSpecs resolves every currently-selected Attribute
// against the training set and returns the resulting
// AttributeSpecs. Panics if a selected Attribute can no longer
// be resolved.
func (d *AbstractDiscretizeFilter) getAttributeSpecs() []base.AttributeSpec {
	specs := make([]base.AttributeSpec, 0, len(d.attrs))
	for attr, enabled := range d.attrs {
		// Skip anything that's been switched off since being added
		if !enabled {
			continue
		}
		// Get the AttributeSpec for the training set
		spec, err := d.train.GetAttribute(attr)
		if err != nil {
			panic(fmt.Errorf("Attribute resolution error: %s", err))
		}
		specs = append(specs, spec)
	}
	return specs
}

View File

@ -14,7 +14,7 @@ import (
// The accepted distance functions at this time are 'euclidean' and 'manhattan'.
type KNNClassifier struct {
base.BaseEstimator
TrainingData *base.Instances
TrainingData base.FixedDataGrid
DistanceFunc string
NearestNeighbours int
}
@ -28,20 +28,12 @@ func NewKnnClassifier(distfunc string, neighbours int) *KNNClassifier {
}
// Fit stores the training data for later
func (KNN *KNNClassifier) Fit(trainingData *base.Instances) {
func (KNN *KNNClassifier) Fit(trainingData base.FixedDataGrid) {
KNN.TrainingData = trainingData
}
// PredictOne returns a classification for the vector, based on a vector input, using the KNN algorithm.
// See http://en.wikipedia.org/wiki/K-nearest_neighbors_algorithm.
func (KNN *KNNClassifier) PredictOne(vector []float64) string {
rows := KNN.TrainingData.Rows
rownumbers := make(map[int]float64)
labels := make([]string, 0)
maxmap := make(map[string]int)
convertedVector := util.FloatsToMatrix(vector)
// Predict returns a classification for the vector, based on a vector input, using the KNN algorithm.
func (KNN *KNNClassifier) Predict(what base.FixedDataGrid) base.FixedDataGrid {
// Check what distance function we are using
var distanceFunc pairwiseMetrics.PairwiseDistanceFunc
@ -52,40 +44,96 @@ func (KNN *KNNClassifier) PredictOne(vector []float64) string {
distanceFunc = pairwiseMetrics.NewManhattan()
default:
panic("unsupported distance function")
}
// Check compatibility
allAttrs := base.CheckCompatable(what, KNN.TrainingData)
if allAttrs == nil {
// Don't have the same Attributes
return nil
}
for i := 0; i < rows; i++ {
row := KNN.TrainingData.GetRowVectorWithoutClass(i)
rowMat := util.FloatsToMatrix(row)
distance := distanceFunc.Distance(rowMat, convertedVector)
rownumbers[i] = distance
}
sorted := util.SortIntMap(rownumbers)
values := sorted[:KNN.NearestNeighbours]
for _, elem := range values {
label := KNN.TrainingData.GetClass(elem)
labels = append(labels, label)
if _, ok := maxmap[label]; ok {
maxmap[label]++
} else {
maxmap[label] = 1
// Remove the Attributes which aren't numeric
allNumericAttrs := make([]base.Attribute, 0)
for _, a := range allAttrs {
if fAttr, ok := a.(*base.FloatAttribute); ok {
allNumericAttrs = append(allNumericAttrs, fAttr)
}
}
sortedlabels := util.SortStringMap(maxmap)
label := sortedlabels[0]
// Generate return vector
ret := base.GeneratePredictionVector(what)
return label
}
// Resolve Attribute specifications for both
whatAttrSpecs := base.ResolveAttributes(what, allNumericAttrs)
trainAttrSpecs := base.ResolveAttributes(KNN.TrainingData, allNumericAttrs)
// Reserve storage for most the most similar items
distances := make(map[int]float64)
// Reserve storage for voting map
maxmap := make(map[string]int)
// Reserve storage for row computations
trainRowBuf := make([]float64, len(allNumericAttrs))
predRowBuf := make([]float64, len(allNumericAttrs))
// Iterate over all outer rows
what.MapOverRows(whatAttrSpecs, func(predRow [][]byte, predRowNo int) (bool, error) {
// Read the float values out
for i, _ := range allNumericAttrs {
predRowBuf[i] = base.UnpackBytesToFloat(predRow[i])
}
predMat := util.FloatsToMatrix(predRowBuf)
// Find the closest match in the training data
KNN.TrainingData.MapOverRows(trainAttrSpecs, func(trainRow [][]byte, srcRowNo int) (bool, error) {
// Read the float values out
for i, _ := range allNumericAttrs {
trainRowBuf[i] = base.UnpackBytesToFloat(trainRow[i])
}
// Compute the distance
trainMat := util.FloatsToMatrix(trainRowBuf)
distances[srcRowNo] = distanceFunc.Distance(predMat, trainMat)
return true, nil
})
sorted := util.SortIntMap(distances)
values := sorted[:KNN.NearestNeighbours]
// Reset maxMap
for a := range maxmap {
maxmap[a] = 0
}
// Refresh maxMap
for _, elem := range values {
label := base.GetClass(KNN.TrainingData, elem)
if _, ok := maxmap[label]; ok {
maxmap[label]++
} else {
maxmap[label] = 1
}
}
// Sort the maxMap
var maxClass string
maxVal := -1
for a := range maxmap {
if maxmap[a] > maxVal {
maxVal = maxmap[a]
maxClass = a
}
}
base.SetClass(ret, predRowNo, maxClass)
return true, nil
})
func (KNN *KNNClassifier) Predict(what *base.Instances) *base.Instances {
ret := what.GeneratePredictionVector()
for i := 0; i < what.Rows; i++ {
ret.SetAttrStr(i, 0, KNN.PredictOne(what.GetRowVectorWithoutClass(i)))
}
return ret
}

View File

@ -24,16 +24,17 @@ func TestKnnClassifier(t *testing.T) {
cls := NewKnnClassifier("euclidean", 2)
cls.Fit(trainingData)
predictions := cls.Predict(testingData)
So(predictions, ShouldNotEqual, nil)
Convey("When predicting the label for our first vector", func() {
result := predictions.GetClass(0)
result := base.GetClass(predictions, 0)
Convey("The result should be 'blue", func() {
So(result, ShouldEqual, "blue")
})
})
Convey("When predicting the label for our first vector", func() {
result2 := predictions.GetClass(1)
result2 := base.GetClass(predictions, 1)
Convey("The result should be 'red", func() {
So(result2, ShouldEqual, "red")
})

View File

@ -8,23 +8,26 @@ import (
func TestLogisticRegression(t *testing.T) {
Convey("Given labels, a classifier and data", t, func() {
// Load data
X, err := base.ParseCSVToInstances("train.csv", false)
So(err, ShouldEqual, nil)
Y, err := base.ParseCSVToInstances("test.csv", false)
So(err, ShouldEqual, nil)
// Setup the problem
lr := NewLogisticRegression("l2", 1.0, 1e-6)
lr.Fit(X)
Convey("When predicting the label of first vector", func() {
Z := lr.Predict(Y)
Convey("The result should be 1", func() {
So(Z.Get(0, 0), ShouldEqual, 1.0)
So(Z.RowString(0), ShouldEqual, "1.00")
})
})
Convey("When predicting the label of second vector", func() {
Z := lr.Predict(Y)
Convey("The result should be -1", func() {
So(Z.Get(1, 0), ShouldEqual, -1.0)
So(Z.RowString(1), ShouldEqual, "-1.00")
})
})
})

View File

@ -5,6 +5,7 @@ import (
"github.com/sjwhitworth/golearn/base"
"fmt"
_ "github.com/gonum/blas"
"github.com/gonum/blas/cblas"
"github.com/gonum/matrix/mat64"
@ -19,6 +20,8 @@ type LinearRegression struct {
fitted bool
disturbance float64
regressionCoefficients []float64
attrs []base.Attribute
cls base.Attribute
}
func init() {
@ -29,31 +32,59 @@ func NewLinearRegression() *LinearRegression {
return &LinearRegression{fitted: false}
}
func (lr *LinearRegression) Fit(inst *base.Instances) error {
if inst.Rows < inst.GetAttributeCount() {
return NotEnoughDataError
func (lr *LinearRegression) Fit(inst base.FixedDataGrid) error {
// Retrieve row size
_, rows := inst.Size()
// Validate class Attribute count
classAttrs := inst.AllClassAttributes()
if len(classAttrs) != 1 {
return fmt.Errorf("Only 1 class variable is permitted")
}
classAttrSpecs := base.ResolveAttributes(inst, classAttrs)
// Split into two matrices, observed results (dependent variable y)
// and the explanatory variables (X) - see http://en.wikipedia.org/wiki/Linear_regression
observed := mat64.NewDense(inst.Rows, 1, nil)
explVariables := mat64.NewDense(inst.Rows, inst.GetAttributeCount(), nil)
for i := 0; i < inst.Rows; i++ {
observed.Set(i, 0, inst.Get(i, inst.ClassIndex)) // Set observed data
for j := 0; j < inst.GetAttributeCount(); j++ {
if j == 0 {
// Set intercepts to 1.0
// Could / should be done better: http://www.theanalysisfactor.com/interpret-the-intercept/
explVariables.Set(i, 0, 1.0)
} else {
explVariables.Set(i, j, inst.Get(i, j-1))
}
// Retrieve relevant Attributes
allAttrs := base.NonClassAttributes(inst)
attrs := make([]base.Attribute, 0)
for _, a := range allAttrs {
if _, ok := a.(*base.FloatAttribute); ok {
attrs = append(attrs, a)
}
}
n := inst.GetAttributeCount()
cols := len(attrs) + 1
if rows < cols {
return NotEnoughDataError
}
// Retrieve relevant Attribute specifications
attrSpecs := base.ResolveAttributes(inst, attrs)
// Split into two matrices, observed results (dependent variable y)
// and the explanatory variables (X) - see http://en.wikipedia.org/wiki/Linear_regression
observed := mat64.NewDense(rows, 1, nil)
explVariables := mat64.NewDense(rows, cols, nil)
// Build the observed matrix
inst.MapOverRows(classAttrSpecs, func(row [][]byte, i int) (bool, error) {
val := base.UnpackBytesToFloat(row[0])
observed.Set(i, 0, val)
return true, nil
})
// Build the explainatory variables
inst.MapOverRows(attrSpecs, func(row [][]byte, i int) (bool, error) {
// Set intercepts to 1.0
explVariables.Set(i, 0, 1.0)
for j, r := range row {
explVariables.Set(i, j+1, base.UnpackBytesToFloat(r))
}
return true, nil
})
n := cols
qr := mat64.QR(explVariables)
q := qr.Q()
reg := qr.R()
@ -74,25 +105,32 @@ func (lr *LinearRegression) Fit(inst *base.Instances) error {
lr.disturbance = regressionCoefficients[0]
lr.regressionCoefficients = regressionCoefficients[1:]
lr.fitted = true
lr.attrs = attrs
lr.cls = classAttrs[0]
return nil
}
func (lr *LinearRegression) Predict(X *base.Instances) (*base.Instances, error) {
func (lr *LinearRegression) Predict(X base.FixedDataGrid) (base.FixedDataGrid, error) {
if !lr.fitted {
return nil, NoTrainingDataError
}
ret := X.GeneratePredictionVector()
for i := 0; i < X.Rows; i++ {
var prediction float64 = lr.disturbance
for j := 0; j < X.Cols; j++ {
if j != X.ClassIndex {
prediction += X.Get(i, j) * lr.regressionCoefficients[j]
}
}
ret.Set(i, 0, prediction)
ret := base.GeneratePredictionVector(X)
attrSpecs := base.ResolveAttributes(X, lr.attrs)
clsSpec, err := ret.GetAttribute(lr.cls)
if err != nil {
return nil, err
}
X.MapOverRows(attrSpecs, func(row [][]byte, i int) (bool, error) {
var prediction float64 = lr.disturbance
for j, r := range row {
prediction += base.UnpackBytesToFloat(r) * lr.regressionCoefficients[j]
}
ret.Set(clsSpec, i, base.PackFloatToBytes(prediction))
return true, nil
})
return ret, nil
}

View File

@ -54,8 +54,10 @@ func TestLinearRegression(t *testing.T) {
t.Fatal(err)
}
for i := 0; i < predictions.Rows; i++ {
fmt.Printf("Expected: %f || Predicted: %f\n", testData.Get(i, testData.ClassIndex), predictions.Get(i, predictions.ClassIndex))
_, rows := predictions.Size()
for i := 0; i < rows; i++ {
fmt.Printf("Expected: %s || Predicted: %s\n", base.GetClass(testData, i), base.GetClass(predictions, i))
}
}

View File

@ -27,51 +27,85 @@ func NewLogisticRegression(penalty string, C float64, eps float64) *LogisticRegr
return &lr
}
func convertInstancesToProblemVec(X *base.Instances) [][]float64 {
problemVec := make([][]float64, X.Rows)
for i := 0; i < X.Rows; i++ {
problemVecCounter := 0
problemVec[i] = make([]float64, X.Cols-1)
for j := 0; j < X.Cols; j++ {
if j == X.ClassIndex {
continue
}
problemVec[i][problemVecCounter] = X.Get(i, j)
problemVecCounter++
func convertInstancesToProblemVec(X base.FixedDataGrid) [][]float64 {
// Allocate problem array
_, rows := X.Size()
problemVec := make([][]float64, rows)
// Retrieve numeric non-class Attributes
numericAttrs := base.NonClassFloatAttributes(X)
numericAttrSpecs := base.ResolveAttributes(X, numericAttrs)
// Convert each row
X.MapOverRows(numericAttrSpecs, func(row [][]byte, rowNo int) (bool, error) {
// Allocate a new row
probRow := make([]float64, len(numericAttrSpecs))
// Read out the row
for i, _ := range numericAttrSpecs {
probRow[i] = base.UnpackBytesToFloat(row[i])
}
}
base.Logger.Println(problemVec, X)
// Add the row
problemVec[rowNo] = probRow
return true, nil
})
return problemVec
}
func convertInstancesToLabelVec(X *base.Instances) []float64 {
labelVec := make([]float64, X.Rows)
for i := 0; i < X.Rows; i++ {
labelVec[i] = X.Get(i, X.ClassIndex)
func convertInstancesToLabelVec(X base.FixedDataGrid) []float64 {
// Get the class Attributes
classAttrs := X.AllClassAttributes()
// Only support 1 class Attribute
if len(classAttrs) != 1 {
panic(fmt.Sprintf("%d ClassAttributes (1 expected)", len(classAttrs)))
}
// ClassAttribute must be numeric
if _, ok := classAttrs[0].(*base.FloatAttribute); !ok {
panic(fmt.Sprintf("%s: ClassAttribute must be a FloatAttribute", classAttrs[0]))
}
// Allocate return structure
_, rows := X.Size()
labelVec := make([]float64, rows)
// Resolve class Attribute specification
classAttrSpecs := base.ResolveAttributes(X, classAttrs)
X.MapOverRows(classAttrSpecs, func(row [][]byte, rowNo int) (bool, error) {
labelVec[rowNo] = base.UnpackBytesToFloat(row[0])
return true, nil
})
return labelVec
}
func (lr *LogisticRegression) Fit(X *base.Instances) {
func (lr *LogisticRegression) Fit(X base.FixedDataGrid) {
problemVec := convertInstancesToProblemVec(X)
labelVec := convertInstancesToLabelVec(X)
prob := NewProblem(problemVec, labelVec, 0)
lr.model = Train(prob, lr.param)
}
func (lr *LogisticRegression) Predict(X *base.Instances) *base.Instances {
ret := X.GeneratePredictionVector()
row := make([]float64, X.Cols-1)
for i := 0; i < X.Rows; i++ {
rowCounter := 0
for j := 0; j < X.Cols; j++ {
if j != X.ClassIndex {
row[rowCounter] = X.Get(i, j)
rowCounter++
}
}
base.Logger.Println(Predict(lr.model, row), row)
ret.Set(i, 0, Predict(lr.model, row))
func (lr *LogisticRegression) Predict(X base.FixedDataGrid) base.FixedDataGrid {
// Only support 1 class Attribute
classAttrs := X.AllClassAttributes()
if len(classAttrs) != 1 {
panic(fmt.Sprintf("%d Wrong number of classes", len(classAttrs)))
}
// Generate return structure
ret := base.GeneratePredictionVector(X)
classAttrSpecs := base.ResolveAttributes(ret, classAttrs)
// Retrieve numeric non-class Attributes
numericAttrs := base.NonClassFloatAttributes(X)
numericAttrSpecs := base.ResolveAttributes(X, numericAttrs)
// Allocate row storage
row := make([]float64, len(numericAttrSpecs))
X.MapOverRows(numericAttrSpecs, func(rowBytes [][]byte, rowNo int) (bool, error) {
for i, r := range rowBytes {
row[i] = base.UnpackBytesToFloat(r)
}
val := Predict(lr.model, row)
vals := base.PackFloatToBytes(val)
ret.Set(classAttrSpecs[0], rowNo, vals)
return true, nil
})
return ret
}

View File

@ -21,23 +21,18 @@ type BaggedModel struct {
// generateTrainingAttrs selects RandomFeatures number of base.Attributes from
// the provided base.Instances.
func (b *BaggedModel) generateTrainingAttrs(model int, from *base.Instances) []base.Attribute {
func (b *BaggedModel) generateTrainingAttrs(model int, from base.FixedDataGrid) []base.Attribute {
ret := make([]base.Attribute, 0)
attrs := base.NonClassAttributes(from)
if b.RandomFeatures == 0 {
for j := 0; j < from.Cols; j++ {
attr := from.GetAttr(j)
ret = append(ret, attr)
}
ret = attrs
} else {
for {
if len(ret) >= b.RandomFeatures {
break
}
attrIndex := rand.Intn(from.Cols)
if attrIndex == from.ClassIndex {
continue
}
attr := from.GetAttr(attrIndex)
attrIndex := rand.Intn(len(attrs))
attr := attrs[attrIndex]
matched := false
for _, a := range ret {
if a.Equals(attr) {
@ -50,7 +45,9 @@ func (b *BaggedModel) generateTrainingAttrs(model int, from *base.Instances) []b
}
}
}
ret = append(ret, from.GetClassAttr())
for _, a := range from.AllClassAttributes() {
ret = append(ret, a)
}
b.lock.Lock()
b.selectedAttributes[model] = ret
b.lock.Unlock()
@ -60,18 +57,19 @@ func (b *BaggedModel) generateTrainingAttrs(model int, from *base.Instances) []b
// generatePredictionInstances returns a modified version of the
// requested base.Instances with only the base.Attributes selected
// for training the model.
func (b *BaggedModel) generatePredictionInstances(model int, from *base.Instances) *base.Instances {
func (b *BaggedModel) generatePredictionInstances(model int, from base.FixedDataGrid) base.FixedDataGrid {
selected := b.selectedAttributes[model]
return from.SelectAttributes(selected)
return base.NewInstancesViewFromAttrs(from, selected)
}
// generateTrainingInstances generates RandomFeatures number of
// attributes and returns a modified version of base.Instances
// for training the model
func (b *BaggedModel) generateTrainingInstances(model int, from *base.Instances) *base.Instances {
insts := from.SampleWithReplacement(from.Rows)
func (b *BaggedModel) generateTrainingInstances(model int, from base.FixedDataGrid) base.FixedDataGrid {
_, rows := from.Size()
insts := base.SampleWithReplacement(from, rows)
selected := b.generateTrainingAttrs(model, from)
return insts.SelectAttributes(selected)
return base.NewInstancesViewFromAttrs(insts, selected)
}
// AddModel adds a base.Classifier to the current model
@ -81,12 +79,12 @@ func (b *BaggedModel) AddModel(m base.Classifier) {
// Fit generates and trains each model on a randomised subset of
// Instances.
func (b *BaggedModel) Fit(from *base.Instances) {
func (b *BaggedModel) Fit(from base.FixedDataGrid) {
var wait sync.WaitGroup
b.selectedAttributes = make(map[int][]base.Attribute)
for i, m := range b.Models {
wait.Add(1)
go func(c base.Classifier, f *base.Instances, model int) {
go func(c base.Classifier, f base.FixedDataGrid, model int) {
l := b.generateTrainingInstances(model, f)
c.Fit(l)
wait.Done()
@ -100,10 +98,10 @@ func (b *BaggedModel) Fit(from *base.Instances) {
//
// IMPORTANT: in the event of a tie, the first class which
// achieved the tie value is output.
func (b *BaggedModel) Predict(from *base.Instances) *base.Instances {
func (b *BaggedModel) Predict(from base.FixedDataGrid) base.FixedDataGrid {
n := runtime.NumCPU()
// Channel to receive the results as they come in
votes := make(chan *base.Instances, n)
votes := make(chan base.DataGrid, n)
// Count the votes for each class
voting := make(map[int](map[string]int))
@ -111,21 +109,20 @@ func (b *BaggedModel) Predict(from *base.Instances) *base.Instances {
var votingwait sync.WaitGroup
votingwait.Add(1)
go func() {
for {
for { // Need to resolve the voting problem
incoming, ok := <-votes
if ok {
// Step through each prediction
for j := 0; j < incoming.Rows; j++ {
cSpecs := base.ResolveAttributes(incoming, incoming.AllClassAttributes())
incoming.MapOverRows(cSpecs, func(row [][]byte, predRow int) (bool, error) {
// Check if we've seen this class before...
if _, ok := voting[j]; !ok {
if _, ok := voting[predRow]; !ok {
// If we haven't, create an entry
voting[j] = make(map[string]int)
voting[predRow] = make(map[string]int)
// Continue on the current row
j--
continue
}
voting[j][incoming.GetClass(j)]++
}
voting[predRow][base.GetClass(incoming, predRow)]++
return true, nil
})
} else {
votingwait.Done()
break
@ -162,7 +159,7 @@ func (b *BaggedModel) Predict(from *base.Instances) *base.Instances {
votingwait.Wait() // All the votes are in
// Generate the overall consensus
ret := from.GeneratePredictionVector()
ret := base.GeneratePredictionVector(from)
for i := range voting {
maxClass := ""
maxCount := 0
@ -174,7 +171,7 @@ func (b *BaggedModel) Predict(from *base.Instances) *base.Instances {
maxCount = votes
}
}
ret.SetAttrStr(i, 0, maxClass)
base.SetClass(ret, i, maxClass)
}
return ret
}

View File

@ -19,16 +19,18 @@ func BenchmarkBaggingRandomForestFit(testEnv *testing.B) {
rand.Seed(time.Now().UnixNano())
filt := filters.NewChiMergeFilter(inst, 0.90)
filt.AddAllNumericAttributes()
filt.Build()
filt.Run(inst)
for _, a := range base.NonClassFloatAttributes(inst) {
filt.AddAttribute(a)
}
filt.Train()
instf := base.NewLazilyFilteredInstances(inst, filt)
rf := new(BaggedModel)
for i := 0; i < 10; i++ {
rf.AddModel(trees.NewRandomTree(2))
}
testEnv.ResetTimer()
for i := 0; i < 20; i++ {
rf.Fit(inst)
rf.Fit(instf)
}
}
@ -40,17 +42,19 @@ func BenchmarkBaggingRandomForestPredict(testEnv *testing.B) {
rand.Seed(time.Now().UnixNano())
filt := filters.NewChiMergeFilter(inst, 0.90)
filt.AddAllNumericAttributes()
filt.Build()
filt.Run(inst)
for _, a := range base.NonClassFloatAttributes(inst) {
filt.AddAttribute(a)
}
filt.Train()
instf := base.NewLazilyFilteredInstances(inst, filt)
rf := new(BaggedModel)
for i := 0; i < 10; i++ {
rf.AddModel(trees.NewRandomTree(2))
}
rf.Fit(inst)
rf.Fit(instf)
testEnv.ResetTimer()
for i := 0; i < 20; i++ {
rf.Predict(inst)
rf.Predict(instf)
}
}
@ -63,19 +67,21 @@ func TestRandomForest1(testEnv *testing.T) {
rand.Seed(time.Now().UnixNano())
trainData, testData := base.InstancesTrainTestSplit(inst, 0.6)
filt := filters.NewChiMergeFilter(inst, 0.90)
filt.AddAllNumericAttributes()
filt.Build()
filt.Run(testData)
filt.Run(trainData)
for _, a := range base.NonClassFloatAttributes(inst) {
filt.AddAttribute(a)
}
filt.Train()
trainDataf := base.NewLazilyFilteredInstances(trainData, filt)
testDataf := base.NewLazilyFilteredInstances(testData, filt)
rf := new(BaggedModel)
for i := 0; i < 10; i++ {
rf.AddModel(trees.NewRandomTree(2))
}
rf.Fit(trainData)
rf.Fit(trainDataf)
fmt.Println(rf)
predictions := rf.Predict(testData)
predictions := rf.Predict(testDataf)
fmt.Println(predictions)
confusionMat := eval.GetConfusionMatrix(testData, predictions)
confusionMat := eval.GetConfusionMatrix(testDataf, predictions)
fmt.Println(confusionMat)
fmt.Println(eval.GetMacroPrecision(confusionMat))
fmt.Println(eval.GetMacroRecall(confusionMat))

13
meta/meta.go Normal file
View File

@ -0,0 +1,13 @@
/*
Meta contains base.Classifier implementations which
combine the outputs of others defined elsewhere.
Bagging:
Bootstraps samples of the original training set
with a number of selected attributes, and uses
that to train an ensemble of models. Predictions
are generated via majority voting.
*/
package meta

View File

@ -1,8 +1,9 @@
package naive
import (
"math"
base "github.com/sjwhitworth/golearn/base"
"fmt"
base "github.com/sjwhitworth/golearn/base"
"math"
)
// A Bernoulli Naive Bayes Classifier. Naive Bayes classifiers assumes
@ -37,91 +38,108 @@ import (
// Information Retrieval. Cambridge University Press, pp. 234-265.
// http://nlp.stanford.edu/IR-book/html/htmledition/the-bernoulli-model-1.html
type BernoulliNBClassifier struct {
base.BaseEstimator
// Conditional probability for each term. This vector should be
// accessed in the following way: p(f|c) = condProb[c][f].
// Logarithm is used in order to avoid underflow.
condProb map[string][]float64
// Number of instances in each class. This is necessary in order to
// calculate the laplace smooth value during the Predict step.
classInstances map[string]int
// Number of instances used in training.
trainingInstances int
// Number of features in the training set
features int
base.BaseEstimator
// Conditional probability for each term. This vector should be
// accessed in the following way: p(f|c) = condProb[c][f].
// Logarithm is used in order to avoid underflow.
condProb map[string][]float64
// Number of instances in each class. This is necessary in order to
// calculate the laplace smooth value during the Predict step.
classInstances map[string]int
// Number of instances used in training.
trainingInstances int
// Number of features used in training
features int
// Attributes used to Train
attrs []base.Attribute
}
// Create a new Bernoulli Naive Bayes Classifier. The argument 'classes'
// is the number of possible labels in the classification task.
func NewBernoulliNBClassifier() *BernoulliNBClassifier {
nb := BernoulliNBClassifier{}
nb.condProb = make(map[string][]float64)
nb.features = 0
nb.trainingInstances = 0
return &nb
nb := BernoulliNBClassifier{}
nb.condProb = make(map[string][]float64)
nb.features = 0
nb.trainingInstances = 0
return &nb
}
// Fill data matrix with Bernoulli Naive Bayes model. All values
// necessary for calculating prior probability and p(f_i)
func (nb *BernoulliNBClassifier) Fit(X *base.Instances) {
func (nb *BernoulliNBClassifier) Fit(X base.FixedDataGrid) {
// Number of features and instances in this training set
nb.trainingInstances = X.Rows
nb.features = 0
if X.Rows > 0 {
nb.features = len(X.GetRowVectorWithoutClass(0))
}
// Check that all Attributes are binary
classAttrs := X.AllClassAttributes()
allAttrs := X.AllAttributes()
featAttrs := base.AttributeDifference(allAttrs, classAttrs)
for i := range featAttrs {
if _, ok := featAttrs[i].(*base.BinaryAttribute); !ok {
panic(fmt.Sprintf("%v: Should be BinaryAttribute", featAttrs[i]))
}
}
featAttrSpecs := base.ResolveAttributes(X, featAttrs)
// Number of instances in class
nb.classInstances = make(map[string]int)
// Check that only one classAttribute is defined
if len(classAttrs) != 1 {
panic("Only one class Attribute can be used")
}
// Number of documents with given term (by class)
docsContainingTerm := make(map[string][]int)
// Number of features and instances in this training set
_, nb.trainingInstances = X.Size()
nb.attrs = featAttrs
nb.features = len(featAttrs)
// This algorithm could be vectorized after binarizing the data
// matrix. Since mat64 doesn't have this function, a iterative
// version is used.
for r := 0; r < X.Rows; r++ {
class := X.GetClass(r)
docVector := X.GetRowVectorWithoutClass(r)
// Number of instances in class
nb.classInstances = make(map[string]int)
// increment number of instances in class
t, ok := nb.classInstances[class]
if !ok { t = 0 }
nb.classInstances[class] = t + 1
// Number of documents with given term (by class)
docsContainingTerm := make(map[string][]int)
// This algorithm could be vectorized after binarizing the data
// matrix. Since mat64 doesn't have this function, a iterative
// version is used.
X.MapOverRows(featAttrSpecs, func(docVector [][]byte, r int) (bool, error) {
class := base.GetClass(X, r)
for feat := 0; feat < len(docVector); feat++ {
v := docVector[feat]
// In Bernoulli Naive Bayes the presence and absence of
// features are considered. All non-zero values are
// treated as presence.
if v > 0 {
// Update number of times this feature appeared within
// given label.
t, ok := docsContainingTerm[class]
if !ok {
t = make([]int, nb.features)
docsContainingTerm[class] = t
}
t[feat] += 1
}
}
}
// increment number of instances in class
t, ok := nb.classInstances[class]
if !ok {
t = 0
}
nb.classInstances[class] = t + 1
// Pre-calculate conditional probabilities for each class
for c, _ := range nb.classInstances {
nb.condProb[c] = make([]float64, nb.features)
for feat := 0; feat < nb.features; feat++ {
classTerms, _ := docsContainingTerm[c]
numDocs := classTerms[feat]
docsInClass, _ := nb.classInstances[c]
for feat := 0; feat < len(docVector); feat++ {
v := docVector[feat]
// In Bernoulli Naive Bayes the presence and absence of
// features are considered. All non-zero values are
// treated as presence.
if v[0] > 0 {
// Update number of times this feature appeared within
// given label.
t, ok := docsContainingTerm[class]
if !ok {
t = make([]int, nb.features)
docsContainingTerm[class] = t
}
t[feat] += 1
}
}
return true, nil
})
classCondProb, _ := nb.condProb[c]
// Calculate conditional probability with laplace smoothing
classCondProb[feat] = float64(numDocs + 1) / float64(docsInClass + 1)
}
}
// Pre-calculate conditional probabilities for each class
for c, _ := range nb.classInstances {
nb.condProb[c] = make([]float64, nb.features)
for feat := 0; feat < nb.features; feat++ {
classTerms, _ := docsContainingTerm[c]
numDocs := classTerms[feat]
docsInClass, _ := nb.classInstances[c]
classCondProb, _ := nb.condProb[c]
// Calculate conditional probability with laplace smoothing
classCondProb[feat] = float64(numDocs+1) / float64(docsInClass+1)
}
}
}
// Use trained model to predict test vector's class. The following
@ -133,54 +151,61 @@ func (nb *BernoulliNBClassifier) Fit(X *base.Instances) {
//
// IMPORTANT: PredictOne panics if Fit was not called or if the
// document vector and train matrix have a different number of columns.
func (nb *BernoulliNBClassifier) PredictOne(vector []float64) string {
if nb.features == 0 {
panic("Fit should be called before predicting")
}
func (nb *BernoulliNBClassifier) PredictOne(vector [][]byte) string {
if nb.features == 0 {
panic("Fit should be called before predicting")
}
if len(vector) != nb.features {
panic("Different dimensions in Train and Test sets")
}
if len(vector) != nb.features {
panic("Different dimensions in Train and Test sets")
}
// Currently only the predicted class is returned.
bestScore := -math.MaxFloat64
bestClass := ""
// Currently only the predicted class is returned.
bestScore := -math.MaxFloat64
bestClass := ""
for class, classCount := range nb.classInstances {
// Init classScore with log(prior)
classScore := math.Log((float64(classCount))/float64(nb.trainingInstances))
for f := 0; f < nb.features; f++ {
if vector[f] > 0 {
// Test document has feature c
classScore += math.Log(nb.condProb[class][f])
} else {
if nb.condProb[class][f] == 1.0 {
// special case when prob = 1.0, consider laplace
// smooth
classScore += math.Log(1.0 / float64(nb.classInstances[class] + 1))
} else {
classScore += math.Log(1.0 - nb.condProb[class][f])
}
}
}
for class, classCount := range nb.classInstances {
// Init classScore with log(prior)
classScore := math.Log((float64(classCount)) / float64(nb.trainingInstances))
for f := 0; f < nb.features; f++ {
if vector[f][0] > 0 {
// Test document has feature c
classScore += math.Log(nb.condProb[class][f])
} else {
if nb.condProb[class][f] == 1.0 {
// special case when prob = 1.0, consider laplace
// smooth
classScore += math.Log(1.0 / float64(nb.classInstances[class]+1))
} else {
classScore += math.Log(1.0 - nb.condProb[class][f])
}
}
}
if classScore > bestScore {
bestScore = classScore
bestClass = class
}
}
if classScore > bestScore {
bestScore = classScore
bestClass = class
}
}
return bestClass
return bestClass
}
// Predict is just a wrapper for the PredictOne function.
//
// IMPORTANT: Predict panics if Fit was not called or if the
// document vector and train matrix have a different number of columns.
func (nb *BernoulliNBClassifier) Predict(what *base.Instances) *base.Instances {
ret := what.GeneratePredictionVector()
for i := 0; i < what.Rows; i++ {
ret.SetAttrStr(i, 0, nb.PredictOne(what.GetRowVectorWithoutClass(i)))
}
return ret
func (nb *BernoulliNBClassifier) Predict(what base.FixedDataGrid) base.FixedDataGrid {
// Generate return vector
ret := base.GeneratePredictionVector(what)
// Get the features
featAttrSpecs := base.ResolveAttributes(what, nb.attrs)
what.MapOverRows(featAttrSpecs, func(row [][]byte, i int) (bool, error) {
base.SetClass(ret, i, nb.PredictOne(row))
return true, nil
})
return ret
}

View File

@ -1,96 +1,109 @@
package naive
import (
"github.com/sjwhitworth/golearn/base"
"testing"
. "github.com/smartystreets/goconvey/convey"
"github.com/sjwhitworth/golearn/base"
"github.com/sjwhitworth/golearn/filters"
. "github.com/smartystreets/goconvey/convey"
"testing"
)
func TestNoFit(t *testing.T) {
Convey("Given an empty BernoulliNaiveBayes", t, func() {
nb := NewBernoulliNBClassifier()
Convey("Given an empty BernoulliNaiveBayes", t, func() {
nb := NewBernoulliNBClassifier()
Convey("PredictOne should panic if Fit was not called", func() {
testDoc := []float64{0.0, 1.0}
So(func() { nb.PredictOne(testDoc) }, ShouldPanic)
})
})
Convey("PredictOne should panic if Fit was not called", func() {
testDoc := [][]byte{[]byte{0}, []byte{1}}
So(func() { nb.PredictOne(testDoc) }, ShouldPanic)
})
})
}
func convertToBinary(src base.FixedDataGrid) base.FixedDataGrid {
// Convert to binary
b := filters.NewBinaryConvertFilter()
attrs := base.NonClassAttributes(src)
for _, a := range attrs {
b.AddAttribute(a)
}
b.Train()
ret := base.NewLazilyFilteredInstances(src, b)
return ret
}
func TestSimple(t *testing.T) {
Convey("Given a simple training data", t, func() {
trainingData, err1 := base.ParseCSVToInstances("test/simple_train.csv", false)
if err1 != nil {
t.Error(err1)
}
Convey("Given a simple training data", t, func() {
trainingData, err1 := base.ParseCSVToInstances("test/simple_train.csv", false)
if err1 != nil {
t.Error(err1)
}
nb := NewBernoulliNBClassifier()
nb.Fit(trainingData)
nb := NewBernoulliNBClassifier()
nb.Fit(convertToBinary(trainingData))
Convey("Check if Fit is working as expected", func() {
Convey("All data needed for prior should be correctly calculated", func() {
So(nb.classInstances["blue"], ShouldEqual, 2)
So(nb.classInstances["red"], ShouldEqual, 2)
So(nb.trainingInstances, ShouldEqual, 4)
})
Convey("Check if Fit is working as expected", func() {
Convey("All data needed for prior should be correctly calculated", func() {
So(nb.classInstances["blue"], ShouldEqual, 2)
So(nb.classInstances["red"], ShouldEqual, 2)
So(nb.trainingInstances, ShouldEqual, 4)
})
Convey("'red' conditional probabilities should be correct", func() {
logCondProbTok0 := nb.condProb["red"][0]
logCondProbTok1 := nb.condProb["red"][1]
logCondProbTok2 := nb.condProb["red"][2]
Convey("'red' conditional probabilities should be correct", func() {
logCondProbTok0 := nb.condProb["red"][0]
logCondProbTok1 := nb.condProb["red"][1]
logCondProbTok2 := nb.condProb["red"][2]
So(logCondProbTok0, ShouldAlmostEqual, 1.0)
So(logCondProbTok1, ShouldAlmostEqual, 1.0/3.0)
So(logCondProbTok2, ShouldAlmostEqual, 1.0)
})
So(logCondProbTok0, ShouldAlmostEqual, 1.0)
So(logCondProbTok1, ShouldAlmostEqual, 1.0/3.0)
So(logCondProbTok2, ShouldAlmostEqual, 1.0)
})
Convey("'blue' conditional probabilities should be correct", func() {
logCondProbTok0 := nb.condProb["blue"][0]
logCondProbTok1 := nb.condProb["blue"][1]
logCondProbTok2 := nb.condProb["blue"][2]
Convey("'blue' conditional probabilities should be correct", func() {
logCondProbTok0 := nb.condProb["blue"][0]
logCondProbTok1 := nb.condProb["blue"][1]
logCondProbTok2 := nb.condProb["blue"][2]
So(logCondProbTok0, ShouldAlmostEqual, 1.0)
So(logCondProbTok1, ShouldAlmostEqual, 1.0)
So(logCondProbTok2, ShouldAlmostEqual, 1.0/3.0)
})
})
So(logCondProbTok0, ShouldAlmostEqual, 1.0)
So(logCondProbTok1, ShouldAlmostEqual, 1.0)
So(logCondProbTok2, ShouldAlmostEqual, 1.0/3.0)
})
})
Convey("PredictOne should work as expected", func() {
Convey("Using a document with different number of cols should panic", func() {
testDoc := []float64{0.0, 2.0}
So(func() { nb.PredictOne(testDoc) }, ShouldPanic)
})
Convey("PredictOne should work as expected", func() {
Convey("Using a document with different number of cols should panic", func() {
testDoc := [][]byte{[]byte{0}, []byte{2}}
So(func() { nb.PredictOne(testDoc) }, ShouldPanic)
})
Convey("Token 1 should be a good predictor of the blue class", func() {
testDoc := []float64{0.0, 123.0, 0.0}
So(nb.PredictOne(testDoc), ShouldEqual, "blue")
Convey("Token 1 should be a good predictor of the blue class", func() {
testDoc := [][]byte{[]byte{0}, []byte{1}, []byte{0}}
So(nb.PredictOne(testDoc), ShouldEqual, "blue")
testDoc = []float64{120.0, 123.0, 0.0}
So(nb.PredictOne(testDoc), ShouldEqual, "blue")
})
testDoc = [][]byte{[]byte{1}, []byte{1}, []byte{0}}
So(nb.PredictOne(testDoc), ShouldEqual, "blue")
})
Convey("Token 2 should be a good predictor of the red class", func() {
testDoc := []float64{0.0, 0.0, 120.0}
So(nb.PredictOne(testDoc), ShouldEqual, "red")
Convey("Token 2 should be a good predictor of the red class", func() {
testDoc := [][]byte{[]byte{0}, []byte{0}, []byte{1}}
So(nb.PredictOne(testDoc), ShouldEqual, "red")
testDoc = [][]byte{[]byte{1}, []byte{0}, []byte{1}}
So(nb.PredictOne(testDoc), ShouldEqual, "red")
})
})
testDoc = []float64{10.0, 0.0, 120.0}
So(nb.PredictOne(testDoc), ShouldEqual, "red")
})
})
Convey("Predict should work as expected", func() {
testData, err := base.ParseCSVToInstances("test/simple_test.csv", false)
if err != nil {
t.Error(err)
}
Convey("Predict should work as expected", func() {
testData, err := base.ParseCSVToInstances("test/simple_test.csv", false)
if err != nil {
t.Error(err)
}
predictions := nb.Predict(testData)
predictions := nb.Predict(convertToBinary(testData))
Convey("All simple predicitions should be correct", func() {
So(predictions.GetClass(0), ShouldEqual, "blue")
So(predictions.GetClass(1), ShouldEqual, "red")
So(predictions.GetClass(2), ShouldEqual, "blue")
So(predictions.GetClass(3), ShouldEqual, "red")
})
})
})
Convey("All simple predicitions should be correct", func() {
So(base.GetClass(predictions, 0), ShouldEqual, "blue")
So(base.GetClass(predictions, 1), ShouldEqual, "red")
So(base.GetClass(predictions, 2), ShouldEqual, "blue")
So(base.GetClass(predictions, 3), ShouldEqual, "red")
})
})
})
}

View File

@ -17,35 +17,40 @@ type InformationGainRuleGenerator struct {
//
// IMPORTANT: passing a base.Instances with no Attributes other than the class
// variable will panic()
func (r *InformationGainRuleGenerator) GenerateSplitAttribute(f *base.Instances) base.Attribute {
allAttributes := make([]int, 0)
for i := 0; i < f.Cols; i++ {
if i != f.ClassIndex {
allAttributes = append(allAttributes, i)
}
}
return r.GetSplitAttributeFromSelection(allAttributes, f)
func (r *InformationGainRuleGenerator) GenerateSplitAttribute(f base.FixedDataGrid) base.Attribute {
attrs := f.AllAttributes()
classAttrs := f.AllClassAttributes()
candidates := base.AttributeDifferenceReferences(attrs, classAttrs)
return r.GetSplitAttributeFromSelection(candidates, f)
}
// GetSplitAttributeFromSelection returns the class Attribute which maximises
// the information gain amongst consideredAttributes
//
// IMPORTANT: passing a zero-length consideredAttributes parameter will panic()
func (r *InformationGainRuleGenerator) GetSplitAttributeFromSelection(consideredAttributes []int, f *base.Instances) base.Attribute {
func (r *InformationGainRuleGenerator) GetSplitAttributeFromSelection(consideredAttributes []base.Attribute, f base.FixedDataGrid) base.Attribute {
var selectedAttribute base.Attribute
// Parameter check
if len(consideredAttributes) == 0 {
panic("More Attributes should be considered")
}
// Next step is to compute the information gain at this node
// for each randomly chosen attribute, and pick the one
// which maximises it
maxGain := math.Inf(-1)
selectedAttribute := -1
// Compute the base entropy
classDist := f.GetClassDistribution()
classDist := base.GetClassDistribution(f)
baseEntropy := getBaseEntropy(classDist)
// Compute the information gain for each attribute
for _, s := range consideredAttributes {
proposedClassDist := f.GetClassDistributionAfterSplit(f.GetAttr(s))
proposedClassDist := base.GetClassDistributionAfterSplit(f, s)
localEntropy := getSplitEntropy(proposedClassDist)
informationGain := baseEntropy - localEntropy
if informationGain > maxGain {
@ -55,7 +60,7 @@ func (r *InformationGainRuleGenerator) GetSplitAttributeFromSelection(considered
}
// Pick the one which maximises IG
return f.GetAttr(selectedAttribute)
return selectedAttribute
}
//

View File

@ -21,7 +21,7 @@ const (
// RuleGenerator implementations analyse instances and determine
// the best value to split on
type RuleGenerator interface {
GenerateSplitAttribute(*base.Instances) base.Attribute
GenerateSplitAttribute(base.FixedDataGrid) base.Attribute
}
// DecisionTreeNode represents a given portion of a decision tree
@ -31,14 +31,19 @@ type DecisionTreeNode struct {
SplitAttr base.Attribute
ClassDist map[string]int
Class string
ClassAttr *base.Attribute
ClassAttr base.Attribute
}
func getClassAttr(from base.FixedDataGrid) base.Attribute {
allClassAttrs := from.AllClassAttributes()
return allClassAttrs[0]
}
// InferID3Tree builds a decision tree using a RuleGenerator
// from a set of Instances (implements the ID3 algorithm)
func InferID3Tree(from *base.Instances, with RuleGenerator) *DecisionTreeNode {
func InferID3Tree(from base.FixedDataGrid, with RuleGenerator) *DecisionTreeNode {
// Count the number of classes at this node
classes := from.CountClassValues()
classes := base.GetClassDistribution(from)
// If there's only one class, return a DecisionTreeLeaf with
// the only class available
if len(classes) == 1 {
@ -52,7 +57,7 @@ func InferID3Tree(from *base.Instances, with RuleGenerator) *DecisionTreeNode {
nil,
classes,
maxClass,
from.GetClassAttrPtr(),
getClassAttr(from),
}
return ret
}
@ -69,28 +74,29 @@ func InferID3Tree(from *base.Instances, with RuleGenerator) *DecisionTreeNode {
// If there are no more Attributes left to split on,
// return a DecisionTreeLeaf with the majority class
if from.GetAttributeCount() == 2 {
cols, _ := from.Size()
if cols == 2 {
ret := &DecisionTreeNode{
LeafNode,
nil,
nil,
classes,
maxClass,
from.GetClassAttrPtr(),
getClassAttr(from),
}
return ret
}
// Generate a return structure
ret := &DecisionTreeNode{
RuleNode,
nil,
nil,
classes,
maxClass,
from.GetClassAttrPtr(),
getClassAttr(from),
}
// Generate a return structure
// Generate the splitting attribute
splitOnAttribute := with.GenerateSplitAttribute(from)
if splitOnAttribute == nil {
@ -98,7 +104,7 @@ func InferID3Tree(from *base.Instances, with RuleGenerator) *DecisionTreeNode {
return ret
}
// Split the attributes based on this attribute's value
splitInstances := from.DecomposeOnAttributeValues(splitOnAttribute)
splitInstances := base.DecomposeOnAttributeValues(from, splitOnAttribute)
// Create new children from these attributes
ret.Children = make(map[string]*DecisionTreeNode)
for k := range splitInstances {
@ -146,13 +152,13 @@ func (d *DecisionTreeNode) String() string {
}
// computeAccuracy is a helper method for Prune()
func computeAccuracy(predictions *base.Instances, from *base.Instances) float64 {
func computeAccuracy(predictions base.FixedDataGrid, from base.FixedDataGrid) float64 {
cf := eval.GetConfusionMatrix(from, predictions)
return eval.GetAccuracy(cf)
}
// Prune eliminates branches which hurt accuracy
func (d *DecisionTreeNode) Prune(using *base.Instances) {
func (d *DecisionTreeNode) Prune(using base.FixedDataGrid) {
// If you're a leaf, you're already pruned
if d.Children == nil {
return
@ -162,11 +168,15 @@ func (d *DecisionTreeNode) Prune(using *base.Instances) {
}
// Recursively prune children of this node
sub := using.DecomposeOnAttributeValues(d.SplitAttr)
sub := base.DecomposeOnAttributeValues(using, d.SplitAttr)
for k := range d.Children {
if sub[k] == nil {
continue
}
subH, subV := sub[k].Size()
if subH == 0 || subV == 0 {
continue
}
d.Children[k].Prune(sub[k])
}
@ -185,24 +195,30 @@ func (d *DecisionTreeNode) Prune(using *base.Instances) {
}
// Predict outputs a base.Instances containing predictions from this tree
func (d *DecisionTreeNode) Predict(what *base.Instances) *base.Instances {
outputAttrs := make([]base.Attribute, 1)
outputAttrs[0] = what.GetClassAttr()
predictions := base.NewInstances(outputAttrs, what.Rows)
for i := 0; i < what.Rows; i++ {
func (d *DecisionTreeNode) Predict(what base.FixedDataGrid) base.FixedDataGrid {
predictions := base.GeneratePredictionVector(what)
classAttr := getClassAttr(predictions)
classAttrSpec, err := predictions.GetAttribute(classAttr)
if err != nil {
panic(err)
}
predAttrs := base.AttributeDifferenceReferences(what.AllAttributes(), predictions.AllClassAttributes())
predAttrSpecs := base.ResolveAttributes(what, predAttrs)
what.MapOverRows(predAttrSpecs, func(row [][]byte, rowNo int) (bool, error) {
cur := d
for {
if cur.Children == nil {
predictions.SetAttrStr(i, 0, cur.Class)
predictions.Set(classAttrSpec, rowNo, classAttr.GetSysValFromString(cur.Class))
break
} else {
at := cur.SplitAttr
j := what.GetAttrIndex(at)
if j == -1 {
predictions.SetAttrStr(i, 0, cur.Class)
ats, err := what.GetAttribute(at)
if err != nil {
predictions.Set(classAttrSpec, rowNo, classAttr.GetSysValFromString(cur.Class))
break
}
classVar := at.GetStringFromSysVal(what.Get(i, j))
classVar := ats.GetAttribute().GetStringFromSysVal(what.Get(ats, rowNo))
if next, ok := cur.Children[classVar]; ok {
cur = next
} else {
@ -217,7 +233,8 @@ func (d *DecisionTreeNode) Predict(what *base.Instances) *base.Instances {
}
}
}
}
return true, nil
})
return predictions
}
@ -245,7 +262,7 @@ func NewID3DecisionTree(prune float64) *ID3DecisionTree {
}
// Fit builds the ID3 decision tree
func (t *ID3DecisionTree) Fit(on *base.Instances) {
func (t *ID3DecisionTree) Fit(on base.FixedDataGrid) {
rule := new(InformationGainRuleGenerator)
if t.PruneSplit > 0.001 {
trainData, testData := base.InstancesTrainTestSplit(on, t.PruneSplit)
@ -257,7 +274,7 @@ func (t *ID3DecisionTree) Fit(on *base.Instances) {
}
// Predict outputs predictions from the ID3 decision tree
func (t *ID3DecisionTree) Predict(what *base.Instances) *base.Instances {
func (t *ID3DecisionTree) Predict(what base.FixedDataGrid) base.FixedDataGrid {
return t.Root.Predict(what)
}

View File

@ -14,32 +14,32 @@ type RandomTreeRuleGenerator struct {
// GenerateSplitAttribute returns the best attribute out of those randomly chosen
// which maximises Information Gain
func (r *RandomTreeRuleGenerator) GenerateSplitAttribute(f *base.Instances) base.Attribute {
func (r *RandomTreeRuleGenerator) GenerateSplitAttribute(f base.FixedDataGrid) base.Attribute {
// First step is to generate the random attributes that we'll consider
maximumAttribute := f.GetAttributeCount()
consideredAttributes := make([]int, r.Attributes)
allAttributes := base.AttributeDifferenceReferences(f.AllAttributes(), f.AllClassAttributes())
maximumAttribute := len(allAttributes)
consideredAttributes := make([]base.Attribute, 0)
attrCounter := 0
for {
if len(consideredAttributes) >= r.Attributes {
break
}
selectedAttribute := rand.Intn(maximumAttribute)
base.Logger.Println(selectedAttribute, attrCounter, consideredAttributes, len(consideredAttributes))
if selectedAttribute != f.ClassIndex {
matched := false
for _, a := range consideredAttributes {
if a == selectedAttribute {
matched = true
break
}
selectedAttrIndex := rand.Intn(maximumAttribute)
selectedAttribute := allAttributes[selectedAttrIndex]
matched := false
for _, a := range consideredAttributes {
if a.Equals(selectedAttribute) {
matched = true
break
}
if matched {
continue
}
consideredAttributes = append(consideredAttributes, selectedAttribute)
attrCounter++
}
if matched {
continue
}
consideredAttributes = append(consideredAttributes, selectedAttribute)
attrCounter++
}
return r.internalRule.GetSplitAttributeFromSelection(consideredAttributes, f)
@ -67,12 +67,12 @@ func NewRandomTree(attrs int) *RandomTree {
}
// Fit builds a RandomTree suitable for prediction
func (rt *RandomTree) Fit(from *base.Instances) {
func (rt *RandomTree) Fit(from base.FixedDataGrid) {
rt.Root = InferID3Tree(from, rt.Rule)
}
// Predict returns a set of Instances containing predictions
func (rt *RandomTree) Predict(from *base.Instances) *base.Instances {
func (rt *RandomTree) Predict(from base.FixedDataGrid) base.FixedDataGrid {
return rt.Root.Predict(from)
}
@ -83,6 +83,6 @@ func (rt *RandomTree) String() string {
// Prune removes nodes from the tree which are detrimental
// to determining the accuracy of the test set (with)
func (rt *RandomTree) Prune(with *base.Instances) {
func (rt *RandomTree) Prune(with base.FixedDataGrid) {
rt.Root.Prune(with)
}

View File

@ -14,15 +14,17 @@ func TestRandomTree(testEnv *testing.T) {
if err != nil {
panic(err)
}
filt := filters.NewChiMergeFilter(inst, 0.90)
filt.AddAllNumericAttributes()
filt.Build()
filt.Run(inst)
fmt.Println(inst)
for _, a := range base.NonClassFloatAttributes(inst) {
filt.AddAttribute(a)
}
filt.Train()
instf := base.NewLazilyFilteredInstances(inst, filt)
r := new(RandomTreeRuleGenerator)
r.Attributes = 2
root := InferID3Tree(inst, r)
fmt.Println(instf)
root := InferID3Tree(instf, r)
fmt.Println(root)
}
@ -33,18 +35,20 @@ func TestRandomTreeClassification(testEnv *testing.T) {
}
trainData, testData := base.InstancesTrainTestSplit(inst, 0.6)
filt := filters.NewChiMergeFilter(inst, 0.90)
filt.AddAllNumericAttributes()
filt.Build()
filt.Run(trainData)
filt.Run(testData)
fmt.Println(inst)
for _, a := range base.NonClassFloatAttributes(inst) {
filt.AddAttribute(a)
}
filt.Train()
trainDataF := base.NewLazilyFilteredInstances(trainData, filt)
testDataF := base.NewLazilyFilteredInstances(testData, filt)
r := new(RandomTreeRuleGenerator)
r.Attributes = 2
root := InferID3Tree(trainData, r)
root := InferID3Tree(trainDataF, r)
fmt.Println(root)
predictions := root.Predict(testData)
predictions := root.Predict(testDataF)
fmt.Println(predictions)
confusionMat := eval.GetConfusionMatrix(testData, predictions)
confusionMat := eval.GetConfusionMatrix(testDataF, predictions)
fmt.Println(confusionMat)
fmt.Println(eval.GetMacroPrecision(confusionMat))
fmt.Println(eval.GetMacroRecall(confusionMat))
@ -58,17 +62,19 @@ func TestRandomTreeClassification2(testEnv *testing.T) {
}
trainData, testData := base.InstancesTrainTestSplit(inst, 0.4)
filt := filters.NewChiMergeFilter(inst, 0.90)
filt.AddAllNumericAttributes()
filt.Build()
fmt.Println(testData)
filt.Run(testData)
filt.Run(trainData)
for _, a := range base.NonClassFloatAttributes(inst) {
filt.AddAttribute(a)
}
filt.Train()
trainDataF := base.NewLazilyFilteredInstances(trainData, filt)
testDataF := base.NewLazilyFilteredInstances(testData, filt)
root := NewRandomTree(2)
root.Fit(trainData)
root.Fit(trainDataF)
fmt.Println(root)
predictions := root.Predict(testData)
predictions := root.Predict(testDataF)
fmt.Println(predictions)
confusionMat := eval.GetConfusionMatrix(testData, predictions)
confusionMat := eval.GetConfusionMatrix(testDataF, predictions)
fmt.Println(confusionMat)
fmt.Println(eval.GetMacroPrecision(confusionMat))
fmt.Println(eval.GetMacroRecall(confusionMat))
@ -82,19 +88,21 @@ func TestPruning(testEnv *testing.T) {
}
trainData, testData := base.InstancesTrainTestSplit(inst, 0.6)
filt := filters.NewChiMergeFilter(inst, 0.90)
filt.AddAllNumericAttributes()
filt.Build()
fmt.Println(testData)
filt.Run(testData)
filt.Run(trainData)
for _, a := range base.NonClassFloatAttributes(inst) {
filt.AddAttribute(a)
}
filt.Train()
trainDataF := base.NewLazilyFilteredInstances(trainData, filt)
testDataF := base.NewLazilyFilteredInstances(testData, filt)
root := NewRandomTree(2)
fittrainData, fittestData := base.InstancesTrainTestSplit(trainData, 0.6)
fittrainData, fittestData := base.InstancesTrainTestSplit(trainDataF, 0.6)
root.Fit(fittrainData)
root.Prune(fittestData)
fmt.Println(root)
predictions := root.Predict(testData)
predictions := root.Predict(testDataF)
fmt.Println(predictions)
confusionMat := eval.GetConfusionMatrix(testData, predictions)
confusionMat := eval.GetConfusionMatrix(testDataF, predictions)
fmt.Println(confusionMat)
fmt.Println(eval.GetMacroPrecision(confusionMat))
fmt.Println(eval.GetMacroRecall(confusionMat))
@ -142,6 +150,7 @@ func TestID3Inference(testEnv *testing.T) {
testEnv.Error(sunnyChild)
}
if rainyChild.SplitAttr.GetName() != "windy" {
fmt.Println(rainyChild.SplitAttr)
testEnv.Error(rainyChild)
}
if overcastChild.SplitAttr != nil {
@ -156,7 +165,6 @@ func TestID3Inference(testEnv *testing.T) {
if sunnyLeafNormal.Class != "yes" {
testEnv.Error(sunnyLeafNormal)
}
windyLeafFalse := rainyChild.Children["false"]
windyLeafTrue := rainyChild.Children["true"]
if windyLeafFalse.Class != "yes" {
@ -176,12 +184,18 @@ func TestID3Classification(testEnv *testing.T) {
if err != nil {
panic(err)
}
filt := filters.NewBinningFilter(inst, 10)
filt.AddAllNumericAttributes()
filt.Build()
filt.Run(inst)
fmt.Println(inst)
trainData, testData := base.InstancesTrainTestSplit(inst, 0.70)
filt := filters.NewBinningFilter(inst, 10)
for _, a := range base.NonClassFloatAttributes(inst) {
filt.AddAttribute(a)
}
filt.Train()
fmt.Println(filt)
instf := base.NewLazilyFilteredInstances(inst, filt)
fmt.Println("INSTFA", instf.AllAttributes())
fmt.Println("INSTF", instf)
trainData, testData := base.InstancesTrainTestSplit(instf, 0.70)
// Build the decision tree
rule := new(InformationGainRuleGenerator)
root := InferID3Tree(trainData, rule)
@ -199,6 +213,7 @@ func TestID3(testEnv *testing.T) {
// Import the "PlayTennis" dataset
inst, err := base.ParseCSVToInstances("../examples/datasets/tennis.csv", true)
fmt.Println(inst)
if err != nil {
panic(err)
}