mirror of
https://github.com/sjwhitworth/golearn.git
synced 2025-05-03 22:17:14 +08:00
Merge pull request #172 from FrozenKP/master
[kdtree] fixed a bug and change function name
This commit is contained in:
commit
7a9e119010
@ -76,10 +76,6 @@ func (t *Tree) buildHandle(data []int, featureIndex int) *node {
|
||||
sort.Sort(tmp)
|
||||
middle := len(data) / 2
|
||||
|
||||
n.srcRowNo = data[middle]
|
||||
n.value = make([]float64, len(t.data[data[middle]]))
|
||||
copy(n.value, t.data[data[middle]])
|
||||
|
||||
divPoint := middle
|
||||
for i := middle + 1; i < len(data); i++ {
|
||||
if t.data[data[i]][featureIndex] == t.data[data[middle]][featureIndex] {
|
||||
@ -89,6 +85,10 @@ func (t *Tree) buildHandle(data []int, featureIndex int) *node {
|
||||
}
|
||||
}
|
||||
|
||||
n.srcRowNo = data[divPoint]
|
||||
n.value = make([]float64, len(t.data[data[divPoint]]))
|
||||
copy(n.value, t.data[data[divPoint]])
|
||||
|
||||
if divPoint == 1 {
|
||||
n.left = &node{feature: -1}
|
||||
n.left.value = make([]float64, len(t.data[data[0]]))
|
||||
@ -102,37 +102,40 @@ func (t *Tree) buildHandle(data []int, featureIndex int) *node {
|
||||
n.right = &node{feature: -1}
|
||||
n.right.value = make([]float64, len(t.data[data[divPoint+1]]))
|
||||
copy(n.right.value, t.data[data[divPoint+1]])
|
||||
n.left.srcRowNo = data[divPoint+1]
|
||||
n.right.srcRowNo = data[divPoint+1]
|
||||
} else if divPoint != (len(data) - 1) {
|
||||
n.right = t.buildHandle(data[divPoint+1:], (featureIndex+1)%len(t.data[data[0]]))
|
||||
} else {
|
||||
n.right = &node{feature: -2}
|
||||
}
|
||||
|
||||
return n
|
||||
}
|
||||
|
||||
// Search return []int contained k nearest neighbor from
|
||||
// specific distance function.
|
||||
func (t *Tree) Search(k int, disType pairwise.PairwiseDistanceFunc, target []float64) ([]int, error) {
|
||||
// Search return srcRowNo([]int) and length([]float64) contained
|
||||
// k nearest neighbors from specific distance function.
|
||||
func (t *Tree) Search(k int, disType pairwise.PairwiseDistanceFunc, target []float64) ([]int, []float64, error) {
|
||||
if k > len(t.data) {
|
||||
return []int{}, errors.New("k is largerer than amount of trainData")
|
||||
return []int{}, []float64{}, errors.New("k is largerer than amount of trainData")
|
||||
}
|
||||
|
||||
if len(target) != len(t.data[0]) {
|
||||
return []int{}, errors.New("amount of features is not equal")
|
||||
return []int{}, []float64{}, errors.New("amount of features is not equal")
|
||||
}
|
||||
|
||||
h := newHeap()
|
||||
t.searchHandle(k, disType, target, h, t.firstDiv)
|
||||
|
||||
out := make([]int, k)
|
||||
srcRowNo := make([]int, k)
|
||||
length := make([]float64, k)
|
||||
i := k - 1
|
||||
for h.size() != 0 {
|
||||
out[i] = h.maximum().srcRowNo
|
||||
srcRowNo[i] = h.maximum().srcRowNo
|
||||
length[i] = h.maximum().length
|
||||
i--
|
||||
h.extractMax()
|
||||
}
|
||||
|
||||
return out, nil
|
||||
return srcRowNo, length, nil
|
||||
}
|
||||
|
||||
func (t *Tree) searchHandle(k int, disType pairwise.PairwiseDistanceFunc, target []float64, h *heap, n *node) {
|
||||
@ -142,6 +145,8 @@ func (t *Tree) searchHandle(k int, disType pairwise.PairwiseDistanceFunc, target
|
||||
length := disType.Distance(vectorX, vectorY)
|
||||
h.insert(n.value, length, n.srcRowNo)
|
||||
return
|
||||
} else if n.feature == -2 {
|
||||
return
|
||||
}
|
||||
|
||||
dir := true
|
||||
@ -159,22 +164,22 @@ func (t *Tree) searchHandle(k int, disType pairwise.PairwiseDistanceFunc, target
|
||||
if k > h.size() {
|
||||
h.insert(n.value, length, n.srcRowNo)
|
||||
if dir {
|
||||
t.searchAllNode(k, disType, target, h, n.right)
|
||||
t.searchAllNodes(k, disType, target, h, n.right)
|
||||
} else {
|
||||
t.searchAllNode(k, disType, target, h, n.left)
|
||||
t.searchAllNodes(k, disType, target, h, n.left)
|
||||
}
|
||||
} else if h.maximum().length > length {
|
||||
h.extractMax()
|
||||
h.insert(n.value, length, n.srcRowNo)
|
||||
if dir {
|
||||
t.searchAllNode(k, disType, target, h, n.right)
|
||||
t.searchAllNodes(k, disType, target, h, n.right)
|
||||
} else {
|
||||
t.searchAllNode(k, disType, target, h, n.left)
|
||||
t.searchAllNodes(k, disType, target, h, n.left)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (t *Tree) searchAllNode(k int, disType pairwise.PairwiseDistanceFunc, target []float64, h *heap, n *node) {
|
||||
func (t *Tree) searchAllNodes(k int, disType pairwise.PairwiseDistanceFunc, target []float64, h *heap, n *node) {
|
||||
vectorX := mat64.NewDense(len(target), 1, target)
|
||||
vectorY := mat64.NewDense(len(target), 1, n.value)
|
||||
length := disType.Distance(vectorX, vectorY)
|
||||
@ -187,9 +192,9 @@ func (t *Tree) searchAllNode(k int, disType pairwise.PairwiseDistanceFunc, targe
|
||||
}
|
||||
|
||||
if n.left != nil {
|
||||
t.searchAllNode(k, disType, target, h, n.left)
|
||||
t.searchAllNodes(k, disType, target, h, n.left)
|
||||
}
|
||||
if n.right != nil {
|
||||
t.searchAllNode(k, disType, target, h, n.right)
|
||||
t.searchAllNodes(k, disType, target, h, n.right)
|
||||
}
|
||||
}
|
||||
|
@ -16,7 +16,7 @@ func TestKdtree(t *testing.T) {
|
||||
euclidean := pairwise.NewEuclidean()
|
||||
|
||||
Convey("When k is 3 with euclidean", func() {
|
||||
result, _ := kd.Search(3, euclidean, []float64{7, 3})
|
||||
result, _, _ := kd.Search(3, euclidean, []float64{7, 3})
|
||||
|
||||
Convey("The result[0] should be 4", func() {
|
||||
So(result[0], ShouldEqual, 4)
|
||||
@ -30,7 +30,7 @@ func TestKdtree(t *testing.T) {
|
||||
})
|
||||
|
||||
Convey("When k is 2 with euclidean", func() {
|
||||
result, _ := kd.Search(2, euclidean, []float64{7, 3})
|
||||
result, _, _ := kd.Search(2, euclidean, []float64{7, 3})
|
||||
|
||||
Convey("The result[0] should be 4", func() {
|
||||
So(result[0], ShouldEqual, 4)
|
||||
|
Loading…
x
Reference in New Issue
Block a user