1
0
mirror of https://github.com/sjwhitworth/golearn.git synced 2025-04-26 13:49:14 +08:00

Add heap.go and heap_test.go for the heap used in kdtree

This commit is contained in:
FrozenKP 2017-04-16 00:48:23 +08:00
parent 2c8b3be961
commit 759ee645c5
3 changed files with 183 additions and 56 deletions

88
kdtree/heap.go Normal file
View File

@ -0,0 +1,88 @@
package kdtree
import (
"errors"
)
// heapNode is one entry of the heap: a point together with the key
// the heap orders its entries by.
type heapNode struct {
// value holds the point's coordinates; insert stores a private copy.
value []float64
// length is the ordering key: the entry with the largest length is the
// heap's maximum. NOTE(review): presumably a distance computed by the
// kdtree search — confirm against the caller.
length float64
}
// heap is a binary max-heap of heapNode values keyed on heapNode.length.
type heap struct {
// tree stores the heap as an implicit binary tree in a slice; the
// maximum entry, when present, is tree[0].
tree []heapNode
}
// newHeap returns a pointer to an empty heap whose backing slice is
// already allocated, ready for insert.
func newHeap() *heap {
	// Return the heap that was actually initialised; previously a second,
	// untouched &heap{} was returned and the initialised one was discarded.
	return &heap{tree: make([]heapNode, 0)}
}
// maximum reports the heapNode with the largest length without removing it.
// When the heap holds no entries it returns a zero heapNode and the error
// produced by errEmpty.
func (h *heap) maximum() (heapNode, error) {
	if len(h.tree) > 0 {
		// The root of the implicit tree is always the maximum.
		return h.tree[0], nil
	}

	var none heapNode
	return none, h.errEmpty()
}
// extractMax removes the heapNode with the largest length from the heap.
// It is a no-op on an empty heap.
//
// The last leaf is moved to the root and then sifted down: at each step it
// is swapped with its larger child until neither child is larger.
func (h *heap) extractMax() {
	n := len(h.tree)
	if n == 0 {
		return
	}

	// Overwrite the root with the last leaf, clear the vacated slot so its
	// value slice can be collected, and shrink the tree by one.
	h.tree[0] = h.tree[n-1]
	h.tree[n-1] = heapNode{}
	h.tree = h.tree[:n-1]
	n--

	for parent := 0; ; {
		largest := parent
		left, right := 2*parent+1, 2*parent+2

		// Each child is compared against the current best candidate, and a
		// lone left child is still considered (both were handled wrongly
		// before, which could leave a smaller node above a larger one).
		if left < n && h.tree[left].length > h.tree[largest].length {
			largest = left
		}
		if right < n && h.tree[right].length > h.tree[largest].length {
			largest = right
		}
		if largest == parent {
			return
		}

		h.tree[parent], h.tree[largest] = h.tree[largest], h.tree[parent]
		parent = largest
	}
}
// insert adds a new entry to the heap. The coordinates in value are copied,
// so later changes to the caller's slice do not affect the stored entry;
// length is the key the entry is ordered by.
func (h *heap) insert(value []float64, length float64) {
	point := make([]float64, len(value))
	copy(point, value)
	h.tree = append(h.tree, heapNode{value: point, length: length})

	// Sift the new leaf up while it is strictly larger than its parent.
	for child := len(h.tree) - 1; child > 0; {
		parent := (child - 1) / 2
		if h.tree[parent].length >= h.tree[child].length {
			return
		}
		h.tree[child], h.tree[parent] = h.tree[parent], h.tree[child]
		child = parent
	}
}
// errEmptyHeap is the single error value reported when an operation needs
// at least one entry but the heap is empty.
var errEmptyHeap = errors.New("empty heap")

// errEmpty returns the error reported when the heap is empty. The same
// value is returned on every call, so callers can compare the result with
// == or errors.Is instead of matching on the message.
func (h *heap) errEmpty() error {
	return errEmptyHeap
}

40
kdtree/heap_test.go Normal file
View File

@ -0,0 +1,40 @@
package kdtree
import (
"testing"
. "github.com/smartystreets/goconvey/convey"
)
func TestHeap(t *testing.T) {
h := newHeap()
Convey("Given a heap", t, func() {
Convey("When heap is empty", func() {
_, err := h.maximum()
Convey("The err should be errEmpty", func() {
So(err, ShouldEqual, h.errEmpty())
})
})
Convey("When insert 5 nodes", func() {
for i := 0; i < 5; i++ {
h.insert([]float64{}, float64(i))
}
max1, _ := h.maximum()
h.extractMax()
max2, _ := h.maximum()
Convey("The max1.value should be 4", func() {
So(max1.value, ShouldEqual, 4)
})
Convey("The max2.value should be 3", func() {
So(max2.value, ShouldEqual, 3)
})
})
})
}

View File

@ -1,81 +1,80 @@
package kdtree
import(
"sort"
"errors"
import (
"errors"
"sort"
)
type node struct{
feature int
value []float64
left *node
right *node
type node struct {
feature int
value []float64
left *node
right *node
}
// Tree is a kdtree.
type Tree struct{
firstDiv *node
type Tree struct {
firstDiv *node
}
// New return a Tree pointer.
func New()*Tree{
return &Tree{}
func New() *Tree {
return &Tree{}
}
// Build builds the kdtree with specific data.
func (t *Tree) Build(data [][]float64)err{
if len(data)==0 {
return errors.New("no input data")
}
size := len(data[0])
for _, v := range data {
if len(v) != size {
return errors.New("amounts of features are not the same")
}
}
func (t *Tree) Build(data [][]float64) error {
if len(data) == 0 {
return errors.New("no input data")
}
size := len(data[0])
for _, v := range data {
if len(v) != size {
return errors.New("amounts of features are not the same")
}
}
t.firstDiv = t.buildHandle(data, 0)
t.firstDiv = t.buildHandle(data, 0)
return nil
return nil
}
// buildHandle builds the kdtree recursively.
func (t *tree) buildHandle(data [][]float64, featureIndex int)*node{
n := &node{feature:featureIndex}
func (t *Tree) buildHandle(data [][]float64, featureIndex int) *node {
n := &node{feature: featureIndex}
sort.Slice(data, func(i, j int)bool{
return data[i][featureIndex]<data[j][featureIndex]
})
middle:= len(data)/2
sort.Slice(data, func(i, j int) bool {
return data[i][featureIndex] < data[j][featureIndex]
})
middle := len(data) / 2
n.value = make([]float64, len(data[middle]))
copy(n.value, data[middle])
n.value = make([]float64, len(data[middle]))
copy(n.value, data[middle])
divPoint := middle
for i:=middle+1 ; i<len(data) ; i++ {
if data[i][featureIndex] == data[middle][featureIndex] {
divPoint=i;
}else{
break
}
}
divPoint := middle
for i := middle + 1; i < len(data); i++ {
if data[i][featureIndex] == data[middle][featureIndex] {
divPoint = i
} else {
break
}
}
if divPoint==1 {
n.Left = &node{feature:-1}
n.Left.value = make([]float64, len(data[0]))
copy(n.Left.value, data[0])
}else{
n.Left = t.buildHandle(data[:divPoint])
}
if divPoint == 1 {
n.left = &node{feature: -1}
n.left.value = make([]float64, len(data[0]))
copy(n.left.value, data[0])
} else {
n.left = t.buildHandle(data[:divPoint], (featureIndex+1)%len(data[0]))
}
if divPoint==(len(data)-2) {
n.Right = &node{feature:-1}
n.Right.value = make([]float64, len(data[divPoint+1]))
copy(n.Right.value, data[divPoint+1])
}else if divPoint!=(len(data)-1){
n.Right = t.buildHandle(data[divPoint+1:])
}
if divPoint == (len(data) - 2) {
n.right = &node{feature: -1}
n.right.value = make([]float64, len(data[divPoint+1]))
copy(n.right.value, data[divPoint+1])
} else if divPoint != (len(data) - 1) {
n.right = t.buildHandle(data[divPoint+1:], (featureIndex+1)%len(data[0]))
}
return n
return n
}