mirror of
https://github.com/sjwhitworth/golearn.git
synced 2025-04-26 13:49:14 +08:00
add heap.go and heap_test.go for heap using in kdtree
This commit is contained in:
parent
2c8b3be961
commit
759ee645c5
88
kdtree/heap.go
Normal file
88
kdtree/heap.go
Normal file
@ -0,0 +1,88 @@
|
||||
package kdtree
|
||||
|
||||
import (
|
||||
"errors"
|
||||
)
|
||||
|
||||
type heapNode struct {
|
||||
value []float64
|
||||
length float64
|
||||
}
|
||||
|
||||
type heap struct {
|
||||
tree []heapNode
|
||||
}
|
||||
|
||||
// newHeap return a pointer of heap.
|
||||
func newHeap() *heap {
|
||||
h := &heap{}
|
||||
h.tree = make([]heapNode, 0)
|
||||
return &heap{}
|
||||
}
|
||||
|
||||
// maximum return the max heapNode in the heap.
|
||||
func (h *heap) maximum() (heapNode, error) {
|
||||
if len(h.tree) == 0 {
|
||||
return heapNode{}, h.errEmpty()
|
||||
}
|
||||
|
||||
return h.tree[0], nil
|
||||
}
|
||||
|
||||
// extractMax remove the Max heapNode in the heap.
|
||||
func (h *heap) extractMax() {
|
||||
if len(h.tree) == 0 {
|
||||
return
|
||||
}
|
||||
|
||||
h.tree[0] = h.tree[len(h.tree)-1]
|
||||
h.tree = h.tree[:len(h.tree)-1]
|
||||
|
||||
target := 1
|
||||
for true {
|
||||
largest := target
|
||||
if target*2-1 >= len(h.tree) {
|
||||
break
|
||||
}
|
||||
if h.tree[target*2-1].length > h.tree[target].length {
|
||||
largest = target * 2
|
||||
}
|
||||
|
||||
if target*2 >= len(h.tree) {
|
||||
break
|
||||
}
|
||||
if h.tree[target*2].length > h.tree[largest-1].length {
|
||||
largest = target*2 + 1
|
||||
}
|
||||
|
||||
if largest == target {
|
||||
break
|
||||
}
|
||||
h.tree[largest-1], h.tree[target-1] = h.tree[target-1], h.tree[largest-1]
|
||||
target = largest
|
||||
}
|
||||
}
|
||||
|
||||
// insert put a new heapNode into heap.
|
||||
func (h *heap) insert(value []float64, length float64) {
|
||||
node := heapNode{}
|
||||
node.length = length
|
||||
node.value = make([]float64, len(value))
|
||||
copy(node.value, value)
|
||||
h.tree = append(h.tree, node)
|
||||
|
||||
target := len(h.tree)
|
||||
for target != 1 {
|
||||
if h.tree[(target/2)-1].length >= h.tree[target-1].length {
|
||||
break
|
||||
}
|
||||
h.tree[target-1], h.tree[(target/2)-1] = h.tree[(target/2)-1], h.tree[target-1]
|
||||
target /= 2
|
||||
}
|
||||
}
|
||||
|
||||
// errEmpty is return an error which is returned
|
||||
// when heap is empty.
|
||||
func (h *heap) errEmpty() error {
|
||||
return errors.New("empty heap")
|
||||
}
|
40
kdtree/heap_test.go
Normal file
40
kdtree/heap_test.go
Normal file
@ -0,0 +1,40 @@
|
||||
package kdtree
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
. "github.com/smartystreets/goconvey/convey"
|
||||
)
|
||||
|
||||
func TestHeap(t *testing.T) {
|
||||
h := newHeap()
|
||||
|
||||
Convey("Given a heap", t, func() {
|
||||
|
||||
Convey("When heap is empty", func() {
|
||||
_, err := h.maximum()
|
||||
|
||||
Convey("The err should be errEmpty", func() {
|
||||
So(err, ShouldEqual, h.errEmpty())
|
||||
})
|
||||
})
|
||||
|
||||
Convey("When insert 5 nodes", func() {
|
||||
for i := 0; i < 5; i++ {
|
||||
h.insert([]float64{}, float64(i))
|
||||
}
|
||||
max1, _ := h.maximum()
|
||||
h.extractMax()
|
||||
max2, _ := h.maximum()
|
||||
|
||||
Convey("The max1.value should be 4", func() {
|
||||
So(max1.value, ShouldEqual, 4)
|
||||
})
|
||||
Convey("The max2.value should be 3", func() {
|
||||
So(max2.value, ShouldEqual, 3)
|
||||
})
|
||||
|
||||
})
|
||||
|
||||
})
|
||||
}
|
111
kdtree/kdtree.go
111
kdtree/kdtree.go
@ -1,81 +1,80 @@
|
||||
package kdtree
|
||||
|
||||
import(
|
||||
"sort"
|
||||
"errors"
|
||||
import (
|
||||
"errors"
|
||||
"sort"
|
||||
)
|
||||
|
||||
type node struct{
|
||||
feature int
|
||||
value []float64
|
||||
left *node
|
||||
right *node
|
||||
type node struct {
|
||||
feature int
|
||||
value []float64
|
||||
left *node
|
||||
right *node
|
||||
}
|
||||
|
||||
// Tree is a kdtree.
|
||||
type Tree struct{
|
||||
firstDiv *node
|
||||
type Tree struct {
|
||||
firstDiv *node
|
||||
}
|
||||
|
||||
// New return a Tree pointer.
|
||||
func New()*Tree{
|
||||
return &Tree{}
|
||||
func New() *Tree {
|
||||
return &Tree{}
|
||||
}
|
||||
|
||||
// Build builds the kdtree with specific data.
|
||||
func (t *Tree) Build(data [][]float64)err{
|
||||
if len(data)==0 {
|
||||
return errors.New("no input data")
|
||||
}
|
||||
size := len(data[0])
|
||||
for _, v := range data {
|
||||
if len(v) != size {
|
||||
return errors.New("amounts of features are not the same")
|
||||
}
|
||||
}
|
||||
func (t *Tree) Build(data [][]float64) error {
|
||||
if len(data) == 0 {
|
||||
return errors.New("no input data")
|
||||
}
|
||||
size := len(data[0])
|
||||
for _, v := range data {
|
||||
if len(v) != size {
|
||||
return errors.New("amounts of features are not the same")
|
||||
}
|
||||
}
|
||||
|
||||
t.firstDiv = t.buildHandle(data, 0)
|
||||
t.firstDiv = t.buildHandle(data, 0)
|
||||
|
||||
return nil
|
||||
return nil
|
||||
}
|
||||
|
||||
// buildHandle builds the kdtree recursively.
|
||||
func (t *tree) buildHandle(data [][]float64, featureIndex int)*node{
|
||||
n := &node{feature:featureIndex}
|
||||
func (t *Tree) buildHandle(data [][]float64, featureIndex int) *node {
|
||||
n := &node{feature: featureIndex}
|
||||
|
||||
sort.Slice(data, func(i, j int)bool{
|
||||
return data[i][featureIndex]<data[j][featureIndex]
|
||||
})
|
||||
middle:= len(data)/2
|
||||
sort.Slice(data, func(i, j int) bool {
|
||||
return data[i][featureIndex] < data[j][featureIndex]
|
||||
})
|
||||
middle := len(data) / 2
|
||||
|
||||
n.value = make([]float64, len(data[middle]))
|
||||
copy(n.value, data[middle])
|
||||
n.value = make([]float64, len(data[middle]))
|
||||
copy(n.value, data[middle])
|
||||
|
||||
divPoint := middle
|
||||
for i:=middle+1 ; i<len(data) ; i++ {
|
||||
if data[i][featureIndex] == data[middle][featureIndex] {
|
||||
divPoint=i;
|
||||
}else{
|
||||
break
|
||||
}
|
||||
}
|
||||
divPoint := middle
|
||||
for i := middle + 1; i < len(data); i++ {
|
||||
if data[i][featureIndex] == data[middle][featureIndex] {
|
||||
divPoint = i
|
||||
} else {
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
if divPoint==1 {
|
||||
n.Left = &node{feature:-1}
|
||||
n.Left.value = make([]float64, len(data[0]))
|
||||
copy(n.Left.value, data[0])
|
||||
}else{
|
||||
n.Left = t.buildHandle(data[:divPoint])
|
||||
}
|
||||
if divPoint == 1 {
|
||||
n.left = &node{feature: -1}
|
||||
n.left.value = make([]float64, len(data[0]))
|
||||
copy(n.left.value, data[0])
|
||||
} else {
|
||||
n.left = t.buildHandle(data[:divPoint], (featureIndex+1)%len(data[0]))
|
||||
}
|
||||
|
||||
if divPoint==(len(data)-2) {
|
||||
n.Right = &node{feature:-1}
|
||||
n.Right.value = make([]float64, len(data[divPoint+1]))
|
||||
copy(n.Right.value, data[divPoint+1])
|
||||
}else if divPoint!=(len(data)-1){
|
||||
n.Right = t.buildHandle(data[divPoint+1:])
|
||||
}
|
||||
if divPoint == (len(data) - 2) {
|
||||
n.right = &node{feature: -1}
|
||||
n.right.value = make([]float64, len(data[divPoint+1]))
|
||||
copy(n.right.value, data[divPoint+1])
|
||||
} else if divPoint != (len(data) - 1) {
|
||||
n.right = t.buildHandle(data[divPoint+1:], (featureIndex+1)%len(data[0]))
|
||||
}
|
||||
|
||||
return n
|
||||
return n
|
||||
}
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user