mirror of
https://github.com/sjwhitworth/golearn.git
synced 2025-04-25 13:48:49 +08:00
Optimised version of KNN for Euclidean distances
This patch also: * Completes removal of the edf/ package * Corrects an erroneous print statement * Introduces two new CSV functions * ParseCSVToTemplatedInstances makes sure that reading a second CSV file maintains strict Attribute compatibility with an existing DenseInstances * ParseCSVToInstancesWithAttributeGroups gives more control over where Attributes end up in memory, important for gaining predictable control over the KNN optimisation * Decouples BinaryAttributeGroup from FixedAttributeGroup for better casting support
This commit is contained in:
parent
8f1bc62401
commit
527c6476e1
85
base/bag.go
85
base/bag.go
@ -1,33 +1,75 @@
|
||||
package base
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"fmt"
|
||||
)
|
||||
|
||||
// BinaryAttributeGroups contain only BinaryAttributes
|
||||
// Compact each Attribute to a bit for better storage
|
||||
type BinaryAttributeGroup struct {
|
||||
FixedAttributeGroup
|
||||
parent DataGrid
|
||||
attributes []Attribute
|
||||
size int
|
||||
alloc []byte
|
||||
maxRow int
|
||||
}
|
||||
|
||||
// String returns a human-readable summary.
|
||||
func (b *BinaryAttributeGroup) String() string {
|
||||
return "BinaryAttributeGroup"
|
||||
}
|
||||
|
||||
func (b *BinaryAttributeGroup) RowSize() int {
|
||||
// RowSizeInBytes returns the size of each row in bytes
|
||||
// (rounded up to nearest byte).
|
||||
func (b *BinaryAttributeGroup) RowSizeInBytes() int {
|
||||
return (len(b.attributes) + 7) / 8
|
||||
}
|
||||
|
||||
// Attributes returns a slice of Attributes in this BinaryAttributeGroup.
|
||||
func (b *BinaryAttributeGroup) Attributes() []Attribute {
|
||||
ret := make([]Attribute, len(b.attributes))
|
||||
for i, a := range b.attributes {
|
||||
ret[i] = a
|
||||
}
|
||||
return ret
|
||||
}
|
||||
|
||||
// AddAttribute adds an Attribute to this BinaryAttributeGroup
|
||||
func (b *BinaryAttributeGroup) AddAttribute(a Attribute) error {
|
||||
b.attributes = append(b.attributes, a)
|
||||
return nil
|
||||
}
|
||||
|
||||
// Storage returns a reference to the underlying storage.
|
||||
//
|
||||
// IMPORTANT: don't modify
|
||||
func (b *BinaryAttributeGroup) Storage() []byte {
|
||||
return b.alloc
|
||||
}
|
||||
|
||||
//
|
||||
// internal methods
|
||||
//
|
||||
|
||||
func (b *BinaryAttributeGroup) setStorage(a []byte) {
|
||||
b.alloc = a
|
||||
}
|
||||
|
||||
func (b *BinaryAttributeGroup) getByteOffset(col, row int) int {
|
||||
return row*b.RowSize() + col/8
|
||||
return row*b.RowSizeInBytes() + col/8
|
||||
}
|
||||
|
||||
func (b *BinaryAttributeGroup) set(col, row int, val []byte) {
|
||||
// Resolve the block
|
||||
curBlock, blockOffset := b.resolveBlock(col, row)
|
||||
|
||||
offset := b.getByteOffset(col, row)
|
||||
|
||||
// If the value is 1, OR it
|
||||
if val[0] > 0 {
|
||||
b.alloc[curBlock][blockOffset] |= (1 << (uint(col) % 8))
|
||||
b.alloc[offset] |= (1 << (uint(col) % 8))
|
||||
} else {
|
||||
// Otherwise, AND its complement
|
||||
b.alloc[curBlock][blockOffset] &= ^(1 << (uint(col) % 8))
|
||||
b.alloc[offset] &= ^(1 << (uint(col) % 8))
|
||||
}
|
||||
|
||||
row++
|
||||
@ -36,19 +78,28 @@ func (b *BinaryAttributeGroup) set(col, row int, val []byte) {
|
||||
}
|
||||
}
|
||||
|
||||
func (b *BinaryAttributeGroup) resolveBlock(col, row int) (int, int) {
|
||||
|
||||
byteOffset := row*b.RowSize() + (col / 3)
|
||||
rowSize := b.RowSize()
|
||||
return b.FixedAttributeGroup.resolveBlockFromByteOffset(byteOffset, rowSize)
|
||||
|
||||
}
|
||||
|
||||
func (b *BinaryAttributeGroup) get(col, row int) []byte {
|
||||
curBlock, blockOffset := b.resolveBlock(col, row)
|
||||
if b.alloc[curBlock][blockOffset]&(1<<(uint(col%8))) > 0 {
|
||||
offset := b.getByteOffset(col, row)
|
||||
if b.alloc[offset]&(1<<(uint(col%8))) > 0 {
|
||||
return []byte{1}
|
||||
} else {
|
||||
return []byte{0}
|
||||
}
|
||||
}
|
||||
|
||||
func (b *BinaryAttributeGroup) appendToRowBuf(row int, buffer *bytes.Buffer) {
|
||||
for i, a := range b.attributes {
|
||||
postfix := " "
|
||||
if i == len(b.attributes)-1 {
|
||||
postfix = ""
|
||||
}
|
||||
buffer.WriteString(fmt.Sprintf("%s%s",
|
||||
a.GetStringFromSysVal(b.get(i, row)), postfix))
|
||||
}
|
||||
}
|
||||
|
||||
func (b *BinaryAttributeGroup) resize(add int) {
|
||||
newAlloc := make([]byte, len(b.alloc)+add)
|
||||
copy(newAlloc, b.alloc)
|
||||
b.alloc = newAlloc
|
||||
}
|
||||
|
218
base/csv.go
218
base/csv.go
@ -210,3 +210,221 @@ func ParseCSVToInstances(filepath string, hasHeaders bool) (instances *DenseInst
|
||||
|
||||
return instances, nil
|
||||
}
|
||||
|
||||
// ParseCSVToInstancesTemplated reads the CSV file given by filepath and returns
|
||||
// the read Instances, using another already read DenseInstances as a template.
|
||||
func ParseCSVToTemplatedInstances(filepath string, hasHeaders bool, template *DenseInstances) (instances *DenseInstances, err error) {
|
||||
|
||||
// Read the number of rows in the file
|
||||
rowCount, err := ParseCSVGetRows(filepath)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if hasHeaders {
|
||||
rowCount--
|
||||
}
|
||||
|
||||
// Read the row headers
|
||||
attrs := ParseCSVGetAttributes(filepath, hasHeaders)
|
||||
templateAttrs := template.AllAttributes()
|
||||
for i, a := range attrs {
|
||||
for _, b := range templateAttrs {
|
||||
if a.Equals(b) {
|
||||
attrs[i] = b
|
||||
} else if a.GetName() == b.GetName() {
|
||||
attrs[i] = b
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
specs := make([]AttributeSpec, len(attrs))
|
||||
// Allocate the Instances to return
|
||||
instances = NewDenseInstances()
|
||||
|
||||
templateAgs := template.AllAttributeGroups()
|
||||
for ag := range templateAgs {
|
||||
agTemplate := templateAgs[ag]
|
||||
if _, ok := agTemplate.(*BinaryAttributeGroup); ok {
|
||||
instances.CreateAttributeGroup(ag, 0)
|
||||
} else {
|
||||
instances.CreateAttributeGroup(ag, 8)
|
||||
}
|
||||
}
|
||||
|
||||
for i, a := range templateAttrs {
|
||||
s, err := template.GetAttribute(a)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
if ag, ok := template.agRevMap[s.pond]; !ok {
|
||||
panic(ag)
|
||||
} else {
|
||||
spec, err := instances.AddAttributeToAttributeGroup(a, ag)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
specs[i] = spec
|
||||
}
|
||||
}
|
||||
|
||||
instances.Extend(rowCount)
|
||||
|
||||
// Read the input
|
||||
file, err := os.Open(filepath)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer file.Close()
|
||||
reader := csv.NewReader(file)
|
||||
|
||||
rowCounter := 0
|
||||
|
||||
for {
|
||||
record, err := reader.Read()
|
||||
if err == io.EOF {
|
||||
break
|
||||
} else if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if rowCounter == 0 {
|
||||
if hasHeaders {
|
||||
hasHeaders = false
|
||||
continue
|
||||
}
|
||||
}
|
||||
for i, v := range record {
|
||||
v = strings.Trim(v, " ")
|
||||
instances.Set(specs[i], rowCounter, attrs[i].GetSysValFromString(v))
|
||||
}
|
||||
rowCounter++
|
||||
}
|
||||
|
||||
for _, a := range template.AllClassAttributes() {
|
||||
instances.AddClassAttribute(a)
|
||||
}
|
||||
|
||||
return instances, nil
|
||||
}
|
||||
|
||||
// ParseCSVToInstancesWithAttributeGroups reads the CSV file given by filepath,
|
||||
// and returns the read DenseInstances, but also makes sure to group any Attributes
|
||||
// specified in the first argument and also any class Attributes specified in the second
|
||||
func ParseCSVToInstancesWithAttributeGroups(filepath string, attrGroups, classAttrGroups map[string]string, attrOverrides map[int]Attribute, hasHeaders bool) (instances *DenseInstances, err error) {
|
||||
|
||||
// Read row count
|
||||
rowCount, err := ParseCSVGetRows(filepath)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Read the row headers
|
||||
attrs := ParseCSVGetAttributes(filepath, hasHeaders)
|
||||
for i := range attrs {
|
||||
if a, ok := attrOverrides[i]; ok {
|
||||
attrs[i] = a
|
||||
}
|
||||
}
|
||||
|
||||
specs := make([]AttributeSpec, len(attrs))
|
||||
// Allocate the Instances to return
|
||||
instances = NewDenseInstances()
|
||||
|
||||
//
|
||||
// Create all AttributeGroups
|
||||
agsToCreate := make(map[string]int)
|
||||
combinedAgs := make(map[string]string)
|
||||
for a := range attrGroups {
|
||||
agsToCreate[attrGroups[a]] = 0
|
||||
combinedAgs[a] = attrGroups[a]
|
||||
}
|
||||
for a := range classAttrGroups {
|
||||
agsToCreate[classAttrGroups[a]] = 8
|
||||
combinedAgs[a] = classAttrGroups[a]
|
||||
}
|
||||
|
||||
// Decide the sizes
|
||||
for _, a := range attrs {
|
||||
if ag, ok := combinedAgs[a.GetName()]; ok {
|
||||
if _, ok := a.(*BinaryAttribute); ok {
|
||||
agsToCreate[ag] = 0
|
||||
} else {
|
||||
agsToCreate[ag] = 8
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Create them
|
||||
for i := range agsToCreate {
|
||||
size := agsToCreate[i]
|
||||
err = instances.CreateAttributeGroup(i, size)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
}
|
||||
|
||||
// Add the Attributes to them
|
||||
for i, a := range attrs {
|
||||
var spec AttributeSpec
|
||||
if ag, ok := combinedAgs[a.GetName()]; ok {
|
||||
spec, err = instances.AddAttributeToAttributeGroup(a, ag)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
specs[i] = spec
|
||||
} else {
|
||||
spec = instances.AddAttribute(a)
|
||||
}
|
||||
specs[i] = spec
|
||||
if _, ok := classAttrGroups[a.GetName()]; ok {
|
||||
err = instances.AddClassAttribute(a)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
}
|
||||
}
|
||||
// Allocate
|
||||
instances.Extend(rowCount)
|
||||
|
||||
// Read the input
|
||||
file, err := os.Open(filepath)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer file.Close()
|
||||
reader := csv.NewReader(file)
|
||||
|
||||
rowCounter := 0
|
||||
|
||||
for {
|
||||
record, err := reader.Read()
|
||||
if err == io.EOF {
|
||||
break
|
||||
} else if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if rowCounter == 0 {
|
||||
// Skip header row
|
||||
rowCounter++
|
||||
continue
|
||||
}
|
||||
for i, v := range record {
|
||||
v = strings.Trim(v, " ")
|
||||
instances.Set(specs[i], rowCounter, attrs[i].GetSysValFromString(v))
|
||||
}
|
||||
rowCounter++
|
||||
}
|
||||
|
||||
// Add class Attributes
|
||||
for _, a := range instances.AllAttributes() {
|
||||
name := a.GetName() // classAttrGroups
|
||||
if _, ok := classAttrGroups[name]; ok {
|
||||
err = instances.AddClassAttribute(a)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
}
|
||||
}
|
||||
return instances, nil
|
||||
|
||||
}
|
||||
|
@ -11,6 +11,7 @@ import (
|
||||
// in a large grid.
|
||||
type DenseInstances struct {
|
||||
agMap map[string]int
|
||||
agRevMap map[int]string
|
||||
ags []AttributeGroup
|
||||
lock sync.Mutex
|
||||
fixed bool
|
||||
@ -29,6 +30,7 @@ type DenseInstances struct {
|
||||
func NewDenseInstances() *DenseInstances {
|
||||
return &DenseInstances{
|
||||
make(map[string]int),
|
||||
make(map[int]string),
|
||||
make([]AttributeGroup, 0),
|
||||
sync.Mutex{},
|
||||
false,
|
||||
@ -62,17 +64,18 @@ func (inst *DenseInstances) createAttributeGroup(name string, size int) {
|
||||
ag.parent = inst
|
||||
ag.attributes = make([]Attribute, 0)
|
||||
ag.size = size
|
||||
ag.alloc = make([][]byte, 0)
|
||||
ag.alloc = make([]byte, 0)
|
||||
agAdd = ag
|
||||
} else {
|
||||
ag := new(BinaryAttributeGroup)
|
||||
ag.parent = inst
|
||||
ag.attributes = make([]Attribute, 0)
|
||||
ag.size = size
|
||||
ag.alloc = make([][]byte, 0)
|
||||
ag.alloc = make([]byte, 0)
|
||||
agAdd = ag
|
||||
}
|
||||
inst.agMap[name] = len(inst.ags)
|
||||
inst.agRevMap[len(inst.ags)] = name
|
||||
inst.ags = append(inst.ags, agAdd)
|
||||
}
|
||||
|
||||
@ -97,6 +100,15 @@ func (inst *DenseInstances) CreateAttributeGroup(name string, size int) (err err
|
||||
return nil
|
||||
}
|
||||
|
||||
// AllAttributeGroups returns a copy of the available AttributeGroups
|
||||
func (inst *DenseInstances) AllAttributeGroups() map[string]AttributeGroup {
|
||||
ret := make(map[string]AttributeGroup)
|
||||
for a := range inst.agMap {
|
||||
ret[a] = inst.ags[inst.agMap[a]]
|
||||
}
|
||||
return ret
|
||||
}
|
||||
|
||||
// GetAttributeGroup returns a reference to the AttributeGroup with the given name.
|
||||
func (inst *DenseInstances) GetAttributeGroup(name string) (AttributeGroup, error) {
|
||||
inst.lock.Lock()
|
||||
@ -167,14 +179,14 @@ func (inst *DenseInstances) AddAttribute(a Attribute) AttributeSpec {
|
||||
return AttributeSpec{id, len(p.Attributes()) - 1, a}
|
||||
}
|
||||
|
||||
// addAttributeToAttributeGroup adds an Attribute to a given ag
|
||||
func (inst *DenseInstances) addAttributeToAttributeGroup(newAttribute Attribute, ag string) (AttributeSpec, error) {
|
||||
// AddAttributeToAttributeGroup adds an Attribute to a given ag
|
||||
func (inst *DenseInstances) AddAttributeToAttributeGroup(newAttribute Attribute, ag string) (AttributeSpec, error) {
|
||||
inst.lock.Lock()
|
||||
defer inst.lock.Unlock()
|
||||
|
||||
// Check if the ag exists
|
||||
if _, ok := inst.agMap[ag]; !ok {
|
||||
return AttributeSpec{-1, 0, nil}, fmt.Errorf("Pond '%s' doesn't exist. Call CreatePond() first", ag)
|
||||
return AttributeSpec{-1, 0, nil}, fmt.Errorf("AttributeGroup '%s' doesn't exist. Call CreateAttributeGroup() first", ag)
|
||||
}
|
||||
|
||||
id := inst.agMap[ag]
|
||||
@ -341,14 +353,12 @@ func (inst *DenseInstances) Extend(rows int) error {
|
||||
for _, p := range inst.ags {
|
||||
|
||||
// Compute ag row storage requirements
|
||||
rowSize := p.RowSize()
|
||||
rowSize := p.RowSizeInBytes()
|
||||
|
||||
// How many bytes?
|
||||
allocSize := rows * rowSize
|
||||
|
||||
bytes := make([]byte, allocSize)
|
||||
|
||||
p.addStorage(bytes)
|
||||
p.resize(allocSize)
|
||||
|
||||
}
|
||||
inst.fixed = true
|
||||
|
35
base/dense_test.go
Normal file
35
base/dense_test.go
Normal file
@ -0,0 +1,35 @@
|
||||
package base
|
||||
|
||||
import (
|
||||
. "github.com/smartystreets/goconvey/convey"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestHighDimensionalInstancesLoad(t *testing.T) {
|
||||
Convey("Given a high-dimensional dataset...", t, func() {
|
||||
_, err := ParseCSVToInstances("../examples/datasets/mnist_train.csv", true)
|
||||
So(err, ShouldEqual, nil)
|
||||
})
|
||||
}
|
||||
func TestHighDimensionalInstancesLoad2(t *testing.T) {
|
||||
Convey("Given a high-dimensional dataset...", t, func() {
|
||||
// Create the class Attribute
|
||||
classAttrs := make(map[int]Attribute)
|
||||
classAttrs[0] = NewCategoricalAttribute()
|
||||
classAttrs[0].SetName("Number")
|
||||
// Setup the class Attribute to be in its own group
|
||||
classAttrGroups := make(map[string]string)
|
||||
classAttrGroups["Number"] = "ClassGroup"
|
||||
// The rest can go in a default group
|
||||
attrGroups := make(map[string]string)
|
||||
|
||||
_, err := ParseCSVToInstancesWithAttributeGroups(
|
||||
"../examples/datasets/mnist_train.csv",
|
||||
attrGroups,
|
||||
classAttrGroups,
|
||||
classAttrs,
|
||||
true,
|
||||
)
|
||||
So(err, ShouldEqual, nil)
|
||||
})
|
||||
}
|
@ -11,7 +11,7 @@ type FixedAttributeGroup struct {
|
||||
parent DataGrid
|
||||
attributes []Attribute
|
||||
size int
|
||||
alloc [][]byte
|
||||
alloc []byte
|
||||
maxRow int
|
||||
}
|
||||
|
||||
@ -20,14 +20,19 @@ func (f *FixedAttributeGroup) String() string {
|
||||
return "FixedAttributeGroup"
|
||||
}
|
||||
|
||||
// RowSize returns the size of each row in bytes
|
||||
func (f *FixedAttributeGroup) RowSize() int {
|
||||
// RowSizeInBytes returns the size of each row in bytes
|
||||
func (f *FixedAttributeGroup) RowSizeInBytes() int {
|
||||
return len(f.attributes) * f.size
|
||||
}
|
||||
|
||||
// Attributes returns a slice of Attributes in this FixedAttributeGroup
|
||||
func (f *FixedAttributeGroup) Attributes() []Attribute {
|
||||
return f.attributes
|
||||
ret := make([]Attribute, len(f.attributes))
|
||||
// Add Attributes
|
||||
for i, a := range f.attributes {
|
||||
ret[i] = a
|
||||
}
|
||||
return ret
|
||||
}
|
||||
|
||||
// AddAttribute adds an attribute to this FixedAttributeGroup
|
||||
@ -37,56 +42,18 @@ func (f *FixedAttributeGroup) AddAttribute(a Attribute) error {
|
||||
}
|
||||
|
||||
// addStorage appends the given storage reference to this FixedAttributeGroup
|
||||
func (f *FixedAttributeGroup) addStorage(a []byte) {
|
||||
f.alloc = append(f.alloc, a)
|
||||
func (f *FixedAttributeGroup) setStorage(a []byte) {
|
||||
f.alloc = a
|
||||
}
|
||||
|
||||
// Storage returns a slice of FixedAttributeGroupStorageRefs which can
|
||||
// be used to access the memory in this pond.
|
||||
func (f *FixedAttributeGroup) Storage() []AttributeGroupStorageRef {
|
||||
ret := make([]AttributeGroupStorageRef, len(f.alloc))
|
||||
rowSize := f.RowSize()
|
||||
for i, b := range f.alloc {
|
||||
ret[i] = AttributeGroupStorageRef{b, len(b) / rowSize}
|
||||
}
|
||||
return ret
|
||||
func (f *FixedAttributeGroup) Storage() []byte {
|
||||
return f.alloc
|
||||
}
|
||||
|
||||
func (f *FixedAttributeGroup) resolveBlock(col int, row int) (int, int) {
|
||||
|
||||
if len(f.alloc) == 0 {
|
||||
panic("No blocks to resolve")
|
||||
}
|
||||
|
||||
// Find where in the pond the byte is
|
||||
byteOffset := row*f.RowSize() + col*f.size
|
||||
return f.resolveBlockFromByteOffset(byteOffset, f.RowSize())
|
||||
}
|
||||
|
||||
func (f *FixedAttributeGroup) resolveBlockFromByteOffset(byteOffset, rowSize int) (int, int) {
|
||||
curOffset := 0
|
||||
curBlock := 0
|
||||
blockOffset := 0
|
||||
for {
|
||||
if curBlock >= len(f.alloc) {
|
||||
panic("Don't have enough blocks to fulfill")
|
||||
}
|
||||
|
||||
// Rows are not allowed to span blocks
|
||||
blockAdd := len(f.alloc[curBlock])
|
||||
blockAdd -= blockAdd % rowSize
|
||||
|
||||
// Case 1: we need to skip this allocation
|
||||
if curOffset+blockAdd < byteOffset {
|
||||
curOffset += blockAdd
|
||||
curBlock++
|
||||
} else {
|
||||
blockOffset = byteOffset - curOffset
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
return curBlock, blockOffset
|
||||
func (f *FixedAttributeGroup) offset(col, row int) int {
|
||||
return row*f.RowSizeInBytes() + col*f.size
|
||||
}
|
||||
|
||||
func (f *FixedAttributeGroup) set(col int, row int, val []byte) {
|
||||
@ -97,12 +64,12 @@ func (f *FixedAttributeGroup) set(col int, row int, val []byte) {
|
||||
}
|
||||
|
||||
// Find where in the pond the byte is
|
||||
curBlock, blockOffset := f.resolveBlock(col, row)
|
||||
offset := f.offset(col, row)
|
||||
|
||||
// Copy the value in
|
||||
copied := copy(f.alloc[curBlock][blockOffset:], val)
|
||||
copied := copy(f.alloc[offset:], val)
|
||||
if copied != f.size {
|
||||
panic(fmt.Sprintf("set() terminated by only copying %d bytes into the current block (should be %d). Check EDF allocation", copied, f.size))
|
||||
panic(fmt.Sprintf("set() terminated by only copying %d bytes", copied, f.size))
|
||||
}
|
||||
|
||||
row++
|
||||
@ -112,8 +79,8 @@ func (f *FixedAttributeGroup) set(col int, row int, val []byte) {
|
||||
}
|
||||
|
||||
func (f *FixedAttributeGroup) get(col int, row int) []byte {
|
||||
curBlock, blockOffset := f.resolveBlock(col, row)
|
||||
return f.alloc[curBlock][blockOffset : blockOffset+f.size]
|
||||
offset := f.offset(col, row)
|
||||
return f.alloc[offset : offset+f.size]
|
||||
}
|
||||
|
||||
func (f *FixedAttributeGroup) appendToRowBuf(row int, buffer *bytes.Buffer) {
|
||||
@ -125,3 +92,9 @@ func (f *FixedAttributeGroup) appendToRowBuf(row int, buffer *bytes.Buffer) {
|
||||
buffer.WriteString(fmt.Sprintf("%s%s", a.GetStringFromSysVal(f.get(i, row)), postfix))
|
||||
}
|
||||
}
|
||||
|
||||
func (f *FixedAttributeGroup) resize(add int) {
|
||||
newAlloc := make([]byte, len(f.alloc)+add)
|
||||
copy(newAlloc, f.alloc)
|
||||
f.alloc = newAlloc
|
||||
}
|
||||
|
@ -7,7 +7,6 @@ import (
|
||||
// AttributeGroups store related sequences of system values
|
||||
// in memory for the DenseInstances structure.
|
||||
type AttributeGroup interface {
|
||||
addStorage(a []byte)
|
||||
// Used for printing
|
||||
appendToRowBuf(row int, buffer *bytes.Buffer)
|
||||
// Adds a new Attribute
|
||||
@ -18,17 +17,14 @@ type AttributeGroup interface {
|
||||
get(int, int) []byte
|
||||
// Stores the byte slice at a given column, row offset
|
||||
set(int, int, []byte)
|
||||
// Sets the reference to underlying memory
|
||||
setStorage([]byte)
|
||||
// Gets the size of each row in bytes (rounded up)
|
||||
RowSize() int
|
||||
// Gets references to underlying memory
|
||||
Storage() []AttributeGroupStorageRef
|
||||
RowSizeInBytes() int
|
||||
// Adds some storage to this group
|
||||
resize(int)
|
||||
// Gets a reference to underlying memory
|
||||
Storage() []byte
|
||||
// Returns a human-readable summary
|
||||
String() string
|
||||
}
|
||||
|
||||
// AttributeGroupStorageRef is a reference to a particular set
|
||||
// of allocated rows within a FixedAttributeGroup
|
||||
type AttributeGroupStorageRef struct {
|
||||
Storage []byte
|
||||
Rows int
|
||||
}
|
||||
|
@ -252,3 +252,53 @@ func CheckCompatible(s1 FixedDataGrid, s2 FixedDataGrid) []Attribute {
|
||||
}
|
||||
return interAttrs
|
||||
}
|
||||
|
||||
// CheckStrictlyCompatible checks whether two DenseInstances have
|
||||
// AttributeGroups with the same Attributes, in the same order,
|
||||
// enabling optimisations.
|
||||
func CheckStrictlyCompatible(s1 FixedDataGrid, s2 FixedDataGrid) bool {
|
||||
// Cast
|
||||
d1, ok1 := s1.(*DenseInstances)
|
||||
d2, ok2 := s2.(*DenseInstances)
|
||||
if !ok1 || !ok2 {
|
||||
return false
|
||||
}
|
||||
|
||||
// Retrieve AttributeGroups
|
||||
d1ags := d1.AllAttributeGroups()
|
||||
d2ags := d2.AllAttributeGroups()
|
||||
|
||||
// Check everything in d1 is in d2
|
||||
for a := range d1ags {
|
||||
_, ok := d2ags[a]
|
||||
if !ok {
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
// Check everything in d2 is in d1
|
||||
for a := range d2ags {
|
||||
_, ok := d1ags[a]
|
||||
if !ok {
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
// Check that everything has the same number
|
||||
// of equivalent Attributes, in the same order
|
||||
for a := range d1ags {
|
||||
ag1 := d1ags[a]
|
||||
ag2 := d2ags[a]
|
||||
a1 := ag1.Attributes()
|
||||
a2 := ag2.Attributes()
|
||||
for i := range a1 {
|
||||
at1 := a1[i]
|
||||
at2 := a2[i]
|
||||
if !at1.Equals(at2) {
|
||||
return false
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return true
|
||||
}
|
||||
|
@ -64,3 +64,48 @@ func TestPackAndUnpackFloat(t *testing.T) {
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
func TestStrictlyCompatable(t *testing.T) {
|
||||
Convey("Given two datasets...", t, func() {
|
||||
Convey("Given two identical datasets", func() {
|
||||
// Both files are expected to yield the same Attributes in the
|
||||
// same order, so the strict check should succeed
|
||||
d1, err := ParseCSVToInstances("../examples/datasets/exam.csv", true)
|
||||
So(err, ShouldEqual, nil)
|
||||
d2, err := ParseCSVToInstances("../examples/datasets/exams.csv", true)
|
||||
So(err, ShouldEqual, nil)
|
||||
So(CheckStrictlyCompatible(d1, d2), ShouldEqual, true)
|
||||
})
|
||||
Convey("Given two identical datasets (apart from sorting)", func() {
|
||||
// Violates the requirement that both CategoricalAttributes
|
||||
// must have values in the same order
|
||||
d1, err := ParseCSVToInstances("../examples/datasets/iris_sorted_asc.csv", true)
|
||||
So(err, ShouldEqual, nil)
|
||||
d2, err := ParseCSVToInstances("../examples/datasets/iris_sorted_desc.csv", true)
|
||||
So(err, ShouldEqual, nil)
|
||||
So(CheckStrictlyCompatible(d1, d2), ShouldEqual, false)
|
||||
})
|
||||
Convey("Given two different datasets...", func() {
|
||||
// Violates everything
|
||||
d1, err := ParseCSVToInstances("../examples/datasets/tennis.csv", true)
|
||||
So(err, ShouldEqual, nil)
|
||||
d2, err := ParseCSVToInstances("../examples/datasets/iris_sorted_desc.csv", true)
|
||||
So(err, ShouldEqual, nil)
|
||||
So(CheckStrictlyCompatible(d1, d2), ShouldEqual, false)
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
func TestCategoricalEquality(t *testing.T) {
|
||||
Convey("Given two outwardly identical class Attributes...", t, func() {
|
||||
d1, err := ParseCSVToInstances("../examples/datasets/iris_sorted_asc.csv", true)
|
||||
So(err, ShouldEqual, nil)
|
||||
d2, err := ParseCSVToInstances("../examples/datasets/iris_sorted_desc.csv", true)
|
||||
So(err, ShouldEqual, nil)
|
||||
c1 := d1.AllClassAttributes()[0]
|
||||
c2 := d2.AllClassAttributes()[0]
|
||||
So(c1.GetName(), ShouldEqual, c2.GetName())
|
||||
So(c1.Equals(c2), ShouldBeFalse)
|
||||
So(c2.Equals(c1), ShouldBeFalse) // Violates the fact that Attributes must appear in the same order
|
||||
})
|
||||
}
|
||||
|
30
knn/euclidean.c
Normal file
30
knn/euclidean.c
Normal file
@ -0,0 +1,30 @@
|
||||
// #cgo CFLAGS: -Og -march=native -ffast-math
|
||||
|
||||
#include <stdio.h>
|
||||
#include <string.h>
|
||||
#include "knn.h"
|
||||
|
||||
/* Works out the Euclidean distance (not square-rooted) for a given
|
||||
* AttributeGroup */
|
||||
void euclidean_distance (
|
||||
struct dist *out, /* Output distance vector, needs to be initially zero */
|
||||
int max_row, /* Size of the output vector */
|
||||
int max_col, /* Number of columns */
|
||||
int row, /* Current row */
|
||||
double *train, /* Pointer to first element of training AttributeGroup */
|
||||
double *pred /* Pointer to first element of equivalent prediction AttributeGroup */
|
||||
)
|
||||
{
|
||||
int i, j;
|
||||
for (i = 0; i < max_row; i++) {
|
||||
out[i].p = i;
|
||||
for (j = 0; j < max_col; j++) {
|
||||
double tmp;
|
||||
tmp = *(pred + row * max_col + j);
|
||||
tmp -= *(train + i * max_col + j);
|
||||
tmp *= tmp; /* Square */
|
||||
out[i].dist += tmp;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
136
knn/knn.go
136
knn/knn.go
@ -4,6 +4,7 @@
|
||||
package knn
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"github.com/gonum/matrix/mat64"
|
||||
"github.com/sjwhitworth/golearn/base"
|
||||
"github.com/sjwhitworth/golearn/metrics/pairwise"
|
||||
@ -12,11 +13,14 @@ import (
|
||||
|
||||
// A KNNClassifier consists of a data matrix, associated labels in the same order as the matrix, and a distance function.
|
||||
// The accepted distance functions at this time are 'euclidean' and 'manhattan'.
|
||||
// Optimisations only occur when things are identically grouped into identical
|
||||
// AttributeGroups, which don't include the class variable, in the same order.
|
||||
type KNNClassifier struct {
|
||||
base.BaseEstimator
|
||||
TrainingData base.FixedDataGrid
|
||||
DistanceFunc string
|
||||
NearestNeighbours int
|
||||
TrainingData base.FixedDataGrid
|
||||
DistanceFunc string
|
||||
NearestNeighbours int
|
||||
AllowOptimisations bool
|
||||
}
|
||||
|
||||
// NewKnnClassifier returns a new classifier
|
||||
@ -24,6 +28,7 @@ func NewKnnClassifier(distfunc string, neighbours int) *KNNClassifier {
|
||||
KNN := KNNClassifier{}
|
||||
KNN.DistanceFunc = distfunc
|
||||
KNN.NearestNeighbours = neighbours
|
||||
KNN.AllowOptimisations = true
|
||||
return &KNN
|
||||
}
|
||||
|
||||
@ -32,9 +37,58 @@ func (KNN *KNNClassifier) Fit(trainingData base.FixedDataGrid) {
|
||||
KNN.TrainingData = trainingData
|
||||
}
|
||||
|
||||
func (KNN *KNNClassifier) canUseOptimisations(what base.FixedDataGrid) bool {
|
||||
// Check that the two have exactly the same layout
|
||||
if !base.CheckStrictlyCompatible(what, KNN.TrainingData) {
|
||||
return false
|
||||
}
|
||||
// Check that the two are DenseInstances
|
||||
whatd, ok1 := what.(*base.DenseInstances)
|
||||
_, ok2 := KNN.TrainingData.(*base.DenseInstances)
|
||||
if !ok1 || !ok2 {
|
||||
return false
|
||||
}
|
||||
// Check that no Class Attributes are mixed in with the data
|
||||
classAttrs := whatd.AllClassAttributes()
|
||||
normalAttrs := base.NonClassAttributes(whatd)
|
||||
// Retrieve all the AGs
|
||||
ags := whatd.AllAttributeGroups()
|
||||
classAttrGroups := make([]base.AttributeGroup, 0)
|
||||
for agName := range ags {
|
||||
ag := ags[agName]
|
||||
attrs := ag.Attributes()
|
||||
matched := false
|
||||
for _, a := range attrs {
|
||||
for _, c := range classAttrs {
|
||||
if a.Equals(c) {
|
||||
matched = true
|
||||
}
|
||||
}
|
||||
}
|
||||
if matched {
|
||||
classAttrGroups = append(classAttrGroups, ag)
|
||||
}
|
||||
}
|
||||
for _, cag := range classAttrGroups {
|
||||
attrs := cag.Attributes()
|
||||
common := base.AttributeIntersect(normalAttrs, attrs)
|
||||
if len(common) != 0 {
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
// Check that all of the Attributes are numeric
|
||||
for _, a := range normalAttrs {
|
||||
if _, ok := a.(*base.FloatAttribute); !ok {
|
||||
return false
|
||||
}
|
||||
}
|
||||
// If that's fine, return true
|
||||
return true
|
||||
}
|
||||
|
||||
// Predict returns a classification for the vector, based on a vector input, using the KNN algorithm.
|
||||
func (KNN *KNNClassifier) Predict(what base.FixedDataGrid) base.FixedDataGrid {
|
||||
|
||||
// Check what distance function we are using
|
||||
var distanceFunc pairwise.PairwiseDistanceFunc
|
||||
switch KNN.DistanceFunc {
|
||||
@ -44,7 +98,6 @@ func (KNN *KNNClassifier) Predict(what base.FixedDataGrid) base.FixedDataGrid {
|
||||
distanceFunc = pairwise.NewManhattan()
|
||||
default:
|
||||
panic("unsupported distance function")
|
||||
|
||||
}
|
||||
// Check Compatibility
|
||||
allAttrs := base.CheckCompatible(what, KNN.TrainingData)
|
||||
@ -53,6 +106,16 @@ func (KNN *KNNClassifier) Predict(what base.FixedDataGrid) base.FixedDataGrid {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Use optimised version if permitted
|
||||
if KNN.AllowOptimisations {
|
||||
if KNN.DistanceFunc == "euclidean" {
|
||||
if KNN.canUseOptimisations(what) {
|
||||
return KNN.optimisedEuclideanPredict(what.(*base.DenseInstances))
|
||||
}
|
||||
}
|
||||
}
|
||||
fmt.Println("Optimisations are switched off")
|
||||
|
||||
// Remove the Attributes which aren't numeric
|
||||
allNumericAttrs := make([]base.Attribute, 0)
|
||||
for _, a := range allAttrs {
|
||||
@ -78,8 +141,17 @@ func (KNN *KNNClassifier) Predict(what base.FixedDataGrid) base.FixedDataGrid {
|
||||
trainRowBuf := make([]float64, len(allNumericAttrs))
|
||||
predRowBuf := make([]float64, len(allNumericAttrs))
|
||||
|
||||
_, maxRow := what.Size()
|
||||
curRow := 0
|
||||
|
||||
// Iterate over all outer rows
|
||||
what.MapOverRows(whatAttrSpecs, func(predRow [][]byte, predRowNo int) (bool, error) {
|
||||
|
||||
if (curRow%1) == 0 && curRow > 0 {
|
||||
fmt.Printf("KNN: %.2f %% done\n", float64(curRow)*100.0/float64(maxRow))
|
||||
}
|
||||
curRow++
|
||||
|
||||
// Read the float values out
|
||||
for i, _ := range allNumericAttrs {
|
||||
predRowBuf[i] = base.UnpackBytesToFloat(predRow[i])
|
||||
@ -89,7 +161,6 @@ func (KNN *KNNClassifier) Predict(what base.FixedDataGrid) base.FixedDataGrid {
|
||||
|
||||
// Find the closest match in the training data
|
||||
KNN.TrainingData.MapOverRows(trainAttrSpecs, func(trainRow [][]byte, srcRowNo int) (bool, error) {
|
||||
|
||||
// Read the float values out
|
||||
for i, _ := range allNumericAttrs {
|
||||
trainRowBuf[i] = base.UnpackBytesToFloat(trainRow[i])
|
||||
@ -104,30 +175,7 @@ func (KNN *KNNClassifier) Predict(what base.FixedDataGrid) base.FixedDataGrid {
|
||||
sorted := utilities.SortIntMap(distances)
|
||||
values := sorted[:KNN.NearestNeighbours]
|
||||
|
||||
// Reset maxMap
|
||||
for a := range maxmap {
|
||||
maxmap[a] = 0
|
||||
}
|
||||
|
||||
// Refresh maxMap
|
||||
for _, elem := range values {
|
||||
label := base.GetClass(KNN.TrainingData, elem)
|
||||
if _, ok := maxmap[label]; ok {
|
||||
maxmap[label]++
|
||||
} else {
|
||||
maxmap[label] = 1
|
||||
}
|
||||
}
|
||||
|
||||
// Sort the maxMap
|
||||
var maxClass string
|
||||
maxVal := -1
|
||||
for a := range maxmap {
|
||||
if maxmap[a] > maxVal {
|
||||
maxVal = maxmap[a]
|
||||
maxClass = a
|
||||
}
|
||||
}
|
||||
maxClass := KNN.vote(maxmap, values)
|
||||
|
||||
base.SetClass(ret, predRowNo, maxClass)
|
||||
return true, nil
|
||||
@ -137,6 +185,34 @@ func (KNN *KNNClassifier) Predict(what base.FixedDataGrid) base.FixedDataGrid {
|
||||
return ret
|
||||
}
|
||||
|
||||
func (KNN *KNNClassifier) vote(maxmap map[string]int, values []int) string {
|
||||
// Reset maxMap
|
||||
for a := range maxmap {
|
||||
maxmap[a] = 0
|
||||
}
|
||||
|
||||
// Refresh maxMap
|
||||
for _, elem := range values {
|
||||
label := base.GetClass(KNN.TrainingData, elem)
|
||||
if _, ok := maxmap[label]; ok {
|
||||
maxmap[label]++
|
||||
} else {
|
||||
maxmap[label] = 1
|
||||
}
|
||||
}
|
||||
|
||||
// Sort the maxMap
|
||||
var maxClass string
|
||||
maxVal := -1
|
||||
for a := range maxmap {
|
||||
if maxmap[a] > maxVal {
|
||||
maxVal = maxmap[a]
|
||||
maxClass = a
|
||||
}
|
||||
}
|
||||
return maxClass
|
||||
}
|
||||
|
||||
// A KNNRegressor consists of a data matrix, associated result variables in the same order as the matrix, and a name.
|
||||
type KNNRegressor struct {
|
||||
base.BaseEstimator
|
||||
|
21
knn/knn.h
Normal file
21
knn/knn.h
Normal file
@ -0,0 +1,21 @@
|
||||
/* Include guard. The name avoids the leading-underscore-plus-capital form,
 * which C reserves for the implementation. */
#ifndef GOLEARN_KNN_H
#define GOLEARN_KNN_H

#include <stdint.h>

/* One record of the distance vector shared with the Go caller. */
struct dist {
    float dist;  /* Distance value (not square-rooted) */
    uint32_t p;  /* Read by the Go side as a training-row index */
};

/* Works out the Euclidean distance (not square-rooted) for a given
 * AttributeGroup */
void euclidean_distance (
    struct dist *out, /* Output distance vector, needs to be initially zero */
    int max_row,      /* Size of the output vector */
    int max_col,      /* Number of columns */
    int row,          /* Current prediction row */
    double *train,    /* Pointer to first element of training AttributeGroup */
    double *pred      /* Pointer to first element of equivalent prediction AttributeGroup */
);

#endif /* GOLEARN_KNN_H */
|
70
knn/knn_bench_test.go
Normal file
70
knn/knn_bench_test.go
Normal file
@ -0,0 +1,70 @@
|
||||
package knn
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"github.com/sjwhitworth/golearn/base"
|
||||
"github.com/sjwhitworth/golearn/evaluation"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func readMnist() (*base.DenseInstances, *base.DenseInstances) {
|
||||
// Create the class Attribute
|
||||
classAttrs := make(map[int]base.Attribute)
|
||||
classAttrs[0] = base.NewCategoricalAttribute()
|
||||
classAttrs[0].SetName("label")
|
||||
// Setup the class Attribute to be in its own group
|
||||
classAttrGroups := make(map[string]string)
|
||||
classAttrGroups["label"] = "ClassGroup"
|
||||
// The rest can go in a default group
|
||||
attrGroups := make(map[string]string)
|
||||
|
||||
inst1, err := base.ParseCSVToInstancesWithAttributeGroups(
|
||||
"../examples/datasets/mnist_train.csv",
|
||||
attrGroups,
|
||||
classAttrGroups,
|
||||
classAttrs,
|
||||
true,
|
||||
)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
inst2, err := base.ParseCSVToTemplatedInstances(
|
||||
"../examples/datasets/mnist_test.csv",
|
||||
true,
|
||||
inst1,
|
||||
)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
return inst1, inst2
|
||||
}
|
||||
|
||||
func BenchmarkKNNWithOpts(b *testing.B) {
|
||||
// Load
|
||||
train, test := readMnist()
|
||||
cls := NewKnnClassifier("euclidean", 1)
|
||||
cls.AllowOptimisations = true
|
||||
cls.Fit(train)
|
||||
predictions := cls.Predict(test)
|
||||
c, err := evaluation.GetConfusionMatrix(test, predictions)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
fmt.Println(evaluation.GetSummary(c))
|
||||
fmt.Println(evaluation.GetAccuracy(c))
|
||||
}
|
||||
|
||||
func BenchmarkKNNWithNoOpts(b *testing.B) {
|
||||
// Load
|
||||
train, test := readMnist()
|
||||
cls := NewKnnClassifier("euclidean", 1)
|
||||
cls.AllowOptimisations = false
|
||||
cls.Fit(train)
|
||||
predictions := cls.Predict(test)
|
||||
c, err := evaluation.GetConfusionMatrix(test, predictions)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
fmt.Println(evaluation.GetSummary(c))
|
||||
fmt.Println(evaluation.GetAccuracy(c))
|
||||
}
|
86
knn/knn_opt_euclidean.go
Normal file
86
knn/knn_opt_euclidean.go
Normal file
@ -0,0 +1,86 @@
|
||||
package knn
|
||||
|
||||
// #include "knn.h"
|
||||
import "C"
|
||||
|
||||
import (
|
||||
"github.com/sjwhitworth/golearn/base"
|
||||
"sort"
|
||||
"unsafe"
|
||||
)
|
||||
|
||||
type dist _Ctype_struct_dist
|
||||
|
||||
type distanceRecs []_Ctype_struct_dist
|
||||
|
||||
func (d distanceRecs) Len() int { return len(d) }
|
||||
func (d distanceRecs) Swap(i, j int) { d[i], d[j] = d[j], d[i] }
|
||||
func (d distanceRecs) Less(i, j int) bool { return d[i].dist < d[j].dist }
|
||||
|
||||
func (KNN *KNNClassifier) optimisedEuclideanPredict(d *base.DenseInstances) base.FixedDataGrid {
|
||||
|
||||
// Create return vector
|
||||
ret := base.GeneratePredictionVector(d)
|
||||
// Type-assert training data
|
||||
tr := KNN.TrainingData.(*base.DenseInstances)
|
||||
// Enumeration of AttributeGroups
|
||||
agPos := make(map[string]int)
|
||||
agTrain := tr.AllAttributeGroups()
|
||||
agPred := d.AllAttributeGroups()
|
||||
classAttrs := tr.AllClassAttributes()
|
||||
counter := 0
|
||||
for ag := range agTrain {
|
||||
// Detect whether the AttributeGroup has any classes in it
|
||||
attrs := agTrain[ag].Attributes()
|
||||
//matched := false
|
||||
if len(base.AttributeIntersect(classAttrs, attrs)) == 0 {
|
||||
agPos[ag] = counter
|
||||
}
|
||||
counter++
|
||||
}
|
||||
// Pointers to the start of each prediction row
|
||||
rowPointers := make([]*C.double, len(agPred))
|
||||
trainPointers := make([]*C.double, len(agPred))
|
||||
rowSizes := make([]int, len(agPred))
|
||||
for ag := range agPred {
|
||||
if ap, ok := agPos[ag]; ok {
|
||||
|
||||
rowPointers[ap] = (*C.double)(unsafe.Pointer(&(agPred[ag].Storage()[0])))
|
||||
trainPointers[ap] = (*C.double)(unsafe.Pointer(&(agTrain[ag].Storage()[0])))
|
||||
rowSizes[ap] = agPred[ag].RowSizeInBytes() / 8
|
||||
}
|
||||
}
|
||||
_, predRows := d.Size()
|
||||
_, trainRows := tr.Size()
|
||||
// Crete the distance vector
|
||||
distanceVec := distanceRecs(make([]_Ctype_struct_dist, trainRows))
|
||||
// Additional datastructures
|
||||
voteVec := make([]int, KNN.NearestNeighbours)
|
||||
maxMap := make(map[string]int)
|
||||
|
||||
for row := 0; row < predRows; row++ {
|
||||
for i := 0; i < trainRows; i++ {
|
||||
distanceVec[i].dist = 0
|
||||
}
|
||||
for ag := range agPred {
|
||||
if ap, ok := agPos[ag]; ok {
|
||||
C.euclidean_distance(
|
||||
&(distanceVec[0]),
|
||||
C.int(trainRows),
|
||||
C.int(len(agPred[ag].Attributes())),
|
||||
C.int(row),
|
||||
trainPointers[ap],
|
||||
rowPointers[ap],
|
||||
)
|
||||
}
|
||||
}
|
||||
sort.Sort(distanceVec)
|
||||
votes := distanceVec[:KNN.NearestNeighbours]
|
||||
for i, v := range votes {
|
||||
voteVec[i] = int(v.p)
|
||||
}
|
||||
maxClass := KNN.vote(maxMap, voteVec)
|
||||
base.SetClass(ret, row, maxClass)
|
||||
}
|
||||
return ret
|
||||
}
|
@ -6,7 +6,7 @@ import (
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestKnnClassifier(t *testing.T) {
|
||||
func TestKnnClassifierWithoutOptimisations(t *testing.T) {
|
||||
Convey("Given labels, a classifier and data", t, func() {
|
||||
trainingData, err := base.ParseCSVToInstances("knn_train.csv", false)
|
||||
So(err, ShouldBeNil)
|
||||
@ -15,6 +15,37 @@ func TestKnnClassifier(t *testing.T) {
|
||||
So(err, ShouldBeNil)
|
||||
|
||||
cls := NewKnnClassifier("euclidean", 2)
|
||||
cls.AllowOptimisations = false
|
||||
cls.Fit(trainingData)
|
||||
predictions := cls.Predict(testingData)
|
||||
So(predictions, ShouldNotEqual, nil)
|
||||
|
||||
Convey("When predicting the label for our first vector", func() {
|
||||
result := base.GetClass(predictions, 0)
|
||||
Convey("The result should be 'blue", func() {
|
||||
So(result, ShouldEqual, "blue")
|
||||
})
|
||||
})
|
||||
|
||||
Convey("When predicting the label for our second vector", func() {
|
||||
result2 := base.GetClass(predictions, 1)
|
||||
Convey("The result should be 'red", func() {
|
||||
So(result2, ShouldEqual, "red")
|
||||
})
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
func TestKnnClassifierWithOptimisations(t *testing.T) {
|
||||
Convey("Given labels, a classifier and data", t, func() {
|
||||
trainingData, err := base.ParseCSVToInstances("knn_train.csv", false)
|
||||
So(err, ShouldBeNil)
|
||||
|
||||
testingData, err := base.ParseCSVToInstances("knn_test.csv", false)
|
||||
So(err, ShouldBeNil)
|
||||
|
||||
cls := NewKnnClassifier("euclidean", 2)
|
||||
cls.AllowOptimisations = true
|
||||
cls.Fit(trainingData)
|
||||
predictions := cls.Predict(testingData)
|
||||
So(predictions, ShouldNotEqual, nil)
|
||||
|
Loading…
x
Reference in New Issue
Block a user