mirror of
https://github.com/sjwhitworth/golearn.git
synced 2025-04-26 13:49:14 +08:00
commit
8547a4335e
83
kdtree/heap.go
Normal file
83
kdtree/heap.go
Normal file
@ -0,0 +1,83 @@
|
||||
package kdtree
|
||||
|
||||
// Storage types for the max-heap used while collecting nearest neighbours.
type (
	// heapNode is one heap entry: a point, its ordering key and the row it
	// came from.
	heapNode struct {
		value    []float64 // coordinates of the stored point
		length   float64   // ordering key; the largest length sits at the root
		srcRowNo int       // row number carried along with the point
	}

	// heap is a binary max-heap of heapNode values keyed by length and laid
	// out as an implicit tree inside a slice.
	heap struct {
		tree []heapNode
	}
)
|
||||
|
||||
// newHeap return a pointer of heap.
|
||||
func newHeap() *heap {
|
||||
h := &heap{}
|
||||
h.tree = make([]heapNode, 0)
|
||||
return &heap{}
|
||||
}
|
||||
|
||||
// maximum return the max heapNode in the heap.
|
||||
func (h *heap) maximum() heapNode {
|
||||
if len(h.tree) == 0 {
|
||||
return heapNode{}
|
||||
}
|
||||
|
||||
return h.tree[0]
|
||||
}
|
||||
|
||||
// extractMax remove the Max heapNode in the heap.
|
||||
func (h *heap) extractMax() {
|
||||
if len(h.tree) == 0 {
|
||||
return
|
||||
}
|
||||
|
||||
h.tree[0] = h.tree[len(h.tree)-1]
|
||||
h.tree = h.tree[:len(h.tree)-1]
|
||||
|
||||
target := 1
|
||||
for true {
|
||||
largest := target
|
||||
if target*2-1 >= len(h.tree) {
|
||||
break
|
||||
}
|
||||
if h.tree[target*2-1].length > h.tree[target-1].length {
|
||||
largest = target * 2
|
||||
}
|
||||
|
||||
if target*2 < len(h.tree) {
|
||||
if h.tree[target*2].length > h.tree[largest-1].length {
|
||||
largest = target*2 + 1
|
||||
}
|
||||
}
|
||||
|
||||
if largest == target {
|
||||
break
|
||||
}
|
||||
h.tree[largest-1], h.tree[target-1] = h.tree[target-1], h.tree[largest-1]
|
||||
target = largest
|
||||
}
|
||||
}
|
||||
|
||||
// insert put a new heapNode into heap.
|
||||
func (h *heap) insert(value []float64, length float64, srcRowNo int) {
|
||||
node := heapNode{}
|
||||
node.length = length
|
||||
node.srcRowNo = srcRowNo
|
||||
node.value = make([]float64, len(value))
|
||||
copy(node.value, value)
|
||||
h.tree = append(h.tree, node)
|
||||
|
||||
target := len(h.tree)
|
||||
for target != 1 {
|
||||
if h.tree[(target/2)-1].length >= h.tree[target-1].length {
|
||||
break
|
||||
}
|
||||
h.tree[target-1], h.tree[(target/2)-1] = h.tree[(target/2)-1], h.tree[target-1]
|
||||
target /= 2
|
||||
}
|
||||
}
|
||||
|
||||
func (h *heap) size() int {
|
||||
return len(h.tree)
|
||||
}
|
41
kdtree/heap_test.go
Normal file
41
kdtree/heap_test.go
Normal file
@ -0,0 +1,41 @@
|
||||
package kdtree
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
. "github.com/smartystreets/goconvey/convey"
|
||||
)
|
||||
|
||||
func TestHeap(t *testing.T) {
|
||||
h := newHeap()
|
||||
|
||||
Convey("Given a heap", t, func() {
|
||||
|
||||
Convey("When heap is empty", func() {
|
||||
size := h.size()
|
||||
|
||||
Convey("The size should be 0", func() {
|
||||
So(size, ShouldEqual, 0)
|
||||
})
|
||||
})
|
||||
|
||||
Convey("When insert 10 nodes", func() {
|
||||
for i := 0; i < 10; i++ {
|
||||
h.insert([]float64{}, float64(i), i)
|
||||
}
|
||||
max1 := h.maximum()
|
||||
h.extractMax()
|
||||
h.extractMax()
|
||||
h.extractMax()
|
||||
max2 := h.maximum()
|
||||
|
||||
Convey("The max1.length should be 9", func() {
|
||||
So(max1.length, ShouldEqual, 9)
|
||||
})
|
||||
Convey("The max2.length should be 6", func() {
|
||||
So(max2.length, ShouldEqual, 6)
|
||||
})
|
||||
})
|
||||
|
||||
})
|
||||
}
|
195
kdtree/kdtree.go
Normal file
195
kdtree/kdtree.go
Normal file
@ -0,0 +1,195 @@
|
||||
package kdtree
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"github.com/gonum/matrix/mat64"
|
||||
"github.com/sjwhitworth/golearn/metrics/pairwise"
|
||||
"sort"
|
||||
)
|
||||
|
||||
type (
	// node is one vertex of the kd-tree.
	node struct {
		feature  int       // index of the feature this node splits on; -1 marks a leaf
		value    []float64 // copy of the data row held by this node
		srcRowNo int       // row number of that data row
		left     *node     // subtree searched when target[feature] <= value[feature]
		right    *node     // subtree searched otherwise
	}

	// Tree is a kdtree.
	Tree struct {
		firstDiv *node       // root node, set by Build
		data     [][]float64 // rows passed to Build
	}
)
|
||||
|
||||
// SortData implements sort.Interface over a list of row numbers, ordering
// them by a single feature column of RowData.
type SortData struct {
	RowData [][]float64 // rows addressed by the numbers in Data
	Data    []int       // row numbers being ordered; reordered in place by Swap
	Feature int         // column of RowData used as the sort key
}

// Len returns the number of row numbers held in Data.
func (d SortData) Len() int {
	return len(d.Data)
}

// Less reports whether row Data[i] has a smaller Feature value than row
// Data[j].
func (d SortData) Less(i, j int) bool {
	lhs := d.RowData[d.Data[i]]
	rhs := d.RowData[d.Data[j]]
	return lhs[d.Feature] < rhs[d.Feature]
}

// Swap exchanges the row numbers at positions i and j of Data.
func (d SortData) Swap(i, j int) {
	held := d.Data[i]
	d.Data[i] = d.Data[j]
	d.Data[j] = held
}
|
||||
|
||||
// New return a Tree pointer.
|
||||
func New() *Tree {
|
||||
return &Tree{}
|
||||
}
|
||||
|
||||
// Build builds the kdtree with specific data.
|
||||
func (t *Tree) Build(data [][]float64) error {
|
||||
if len(data) == 0 {
|
||||
return errors.New("no input data")
|
||||
}
|
||||
size := len(data[0])
|
||||
for _, v := range data {
|
||||
if len(v) != size {
|
||||
return errors.New("amounts of features are not the same")
|
||||
}
|
||||
}
|
||||
|
||||
t.data = data
|
||||
|
||||
newData := make([]int, len(data))
|
||||
for k, _ := range newData {
|
||||
newData[k] = k
|
||||
}
|
||||
|
||||
if len(data) == 1 {
|
||||
t.firstDiv = &node{feature: -1, srcRowNo: 0}
|
||||
t.firstDiv.value = make([]float64, len(data[0]))
|
||||
copy(t.firstDiv.value, data[0])
|
||||
} else {
|
||||
t.firstDiv = t.buildHandle(newData, 0)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// buildHandle builds the kdtree recursively.
|
||||
func (t *Tree) buildHandle(data []int, featureIndex int) *node {
|
||||
n := &node{feature: featureIndex}
|
||||
|
||||
tmp := SortData{RowData: t.data, Data: data, Feature: featureIndex}
|
||||
sort.Sort(tmp)
|
||||
middle := len(data) / 2
|
||||
|
||||
n.srcRowNo = data[middle]
|
||||
n.value = make([]float64, len(t.data[data[middle]]))
|
||||
copy(n.value, t.data[data[middle]])
|
||||
|
||||
divPoint := middle
|
||||
for i := middle + 1; i < len(data); i++ {
|
||||
if t.data[data[i]][featureIndex] == t.data[data[middle]][featureIndex] {
|
||||
divPoint = i
|
||||
} else {
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
if divPoint == 1 {
|
||||
n.left = &node{feature: -1}
|
||||
n.left.value = make([]float64, len(t.data[data[0]]))
|
||||
copy(n.left.value, t.data[data[0]])
|
||||
n.left.srcRowNo = data[0]
|
||||
} else {
|
||||
n.left = t.buildHandle(data[:divPoint], (featureIndex+1)%len(t.data[data[0]]))
|
||||
}
|
||||
|
||||
if divPoint == (len(data) - 2) {
|
||||
n.right = &node{feature: -1}
|
||||
n.right.value = make([]float64, len(t.data[data[divPoint+1]]))
|
||||
copy(n.right.value, t.data[data[divPoint+1]])
|
||||
n.left.srcRowNo = data[divPoint+1]
|
||||
} else if divPoint != (len(data) - 1) {
|
||||
n.right = t.buildHandle(data[divPoint+1:], (featureIndex+1)%len(t.data[data[0]]))
|
||||
}
|
||||
|
||||
return n
|
||||
}
|
||||
|
||||
// Search return []int contained k nearest neighbor from
|
||||
// specific distance function.
|
||||
func (t *Tree) Search(k int, disType pairwise.PairwiseDistanceFunc, target []float64) ([]int, error) {
|
||||
if k > len(t.data) {
|
||||
return []int{}, errors.New("k is largerer than amount of trainData")
|
||||
}
|
||||
|
||||
if len(target) != len(t.data[0]) {
|
||||
return []int{}, errors.New("amount of features is not equal")
|
||||
}
|
||||
|
||||
h := newHeap()
|
||||
t.searchHandle(k, disType, target, h, t.firstDiv)
|
||||
|
||||
out := make([]int, k)
|
||||
i := k - 1
|
||||
for h.size() != 0 {
|
||||
out[i] = h.maximum().srcRowNo
|
||||
i--
|
||||
h.extractMax()
|
||||
}
|
||||
|
||||
return out, nil
|
||||
}
|
||||
|
||||
func (t *Tree) searchHandle(k int, disType pairwise.PairwiseDistanceFunc, target []float64, h *heap, n *node) {
|
||||
if n.feature == -1 {
|
||||
vectorX := mat64.NewDense(len(target), 1, target)
|
||||
vectorY := mat64.NewDense(len(target), 1, n.value)
|
||||
length := disType.Distance(vectorX, vectorY)
|
||||
h.insert(n.value, length, n.srcRowNo)
|
||||
return
|
||||
}
|
||||
|
||||
dir := true
|
||||
if target[n.feature] <= n.value[n.feature] {
|
||||
t.searchHandle(k, disType, target, h, n.left)
|
||||
} else {
|
||||
dir = false
|
||||
t.searchHandle(k, disType, target, h, n.right)
|
||||
}
|
||||
|
||||
vectorX := mat64.NewDense(len(target), 1, target)
|
||||
vectorY := mat64.NewDense(len(target), 1, n.value)
|
||||
length := disType.Distance(vectorX, vectorY)
|
||||
|
||||
if k > h.size() {
|
||||
h.insert(n.value, length, n.srcRowNo)
|
||||
if dir {
|
||||
t.searchAllNode(k, disType, target, h, n.right)
|
||||
} else {
|
||||
t.searchAllNode(k, disType, target, h, n.left)
|
||||
}
|
||||
} else if h.maximum().length > length {
|
||||
h.extractMax()
|
||||
h.insert(n.value, length, n.srcRowNo)
|
||||
if dir {
|
||||
t.searchAllNode(k, disType, target, h, n.right)
|
||||
} else {
|
||||
t.searchAllNode(k, disType, target, h, n.left)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (t *Tree) searchAllNode(k int, disType pairwise.PairwiseDistanceFunc, target []float64, h *heap, n *node) {
|
||||
vectorX := mat64.NewDense(len(target), 1, target)
|
||||
vectorY := mat64.NewDense(len(target), 1, n.value)
|
||||
length := disType.Distance(vectorX, vectorY)
|
||||
|
||||
if k > h.size() {
|
||||
h.insert(n.value, length, n.srcRowNo)
|
||||
} else if h.maximum().length > length {
|
||||
h.extractMax()
|
||||
h.insert(n.value, length, n.srcRowNo)
|
||||
}
|
||||
|
||||
if n.left != nil {
|
||||
t.searchAllNode(k, disType, target, h, n.left)
|
||||
}
|
||||
if n.right != nil {
|
||||
t.searchAllNode(k, disType, target, h, n.right)
|
||||
}
|
||||
}
|
44
kdtree/kdtree_test.go
Normal file
44
kdtree/kdtree_test.go
Normal file
@ -0,0 +1,44 @@
|
||||
package kdtree
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/sjwhitworth/golearn/metrics/pairwise"
|
||||
. "github.com/smartystreets/goconvey/convey"
|
||||
)
|
||||
|
||||
func TestKdtree(t *testing.T) {
|
||||
kd := New()
|
||||
|
||||
Convey("Given a kdtree", t, func() {
|
||||
data := [][]float64{{2, 3}, {5, 4}, {4, 7}, {8, 1}, {7, 2}, {9, 6}}
|
||||
kd.Build(data)
|
||||
euclidean := pairwise.NewEuclidean()
|
||||
|
||||
Convey("When k is 3 with euclidean", func() {
|
||||
result, _ := kd.Search(3, euclidean, []float64{7, 3})
|
||||
|
||||
Convey("The result[0] should be 4", func() {
|
||||
So(result[0], ShouldEqual, 4)
|
||||
})
|
||||
Convey("The result[1] should be 3", func() {
|
||||
So(result[1], ShouldEqual, 3)
|
||||
})
|
||||
Convey("The result[2] should be 1", func() {
|
||||
So(result[2], ShouldEqual, 1)
|
||||
})
|
||||
})
|
||||
|
||||
Convey("When k is 2 with euclidean", func() {
|
||||
result, _ := kd.Search(2, euclidean, []float64{7, 3})
|
||||
|
||||
Convey("The result[0] should be 4", func() {
|
||||
So(result[0], ShouldEqual, 4)
|
||||
})
|
||||
Convey("The result[1] should be 1", func() {
|
||||
So(result[1], ShouldEqual, 1)
|
||||
})
|
||||
})
|
||||
|
||||
})
|
||||
}
|
Loading…
x
Reference in New Issue
Block a user