From af0909a40e13ddec5033ac2688e5ae0a95babec8 Mon Sep 17 00:00:00 2001 From: FrozenKP Date: Mon, 17 Apr 2017 08:33:05 +0800 Subject: [PATCH 1/4] update to latest and change searchAllNode to searchAllNodes --- kdtree/kdtree.go | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/kdtree/kdtree.go b/kdtree/kdtree.go index 59251a5..e379979 100644 --- a/kdtree/kdtree.go +++ b/kdtree/kdtree.go @@ -159,22 +159,22 @@ func (t *Tree) searchHandle(k int, disType pairwise.PairwiseDistanceFunc, target if k > h.size() { h.insert(n.value, length, n.srcRowNo) if dir { - t.searchAllNode(k, disType, target, h, n.right) + t.searchAllNodes(k, disType, target, h, n.right) } else { - t.searchAllNode(k, disType, target, h, n.left) + t.searchAllNodes(k, disType, target, h, n.left) } } else if h.maximum().length > length { h.extractMax() h.insert(n.value, length, n.srcRowNo) if dir { - t.searchAllNode(k, disType, target, h, n.right) + t.searchAllNodes(k, disType, target, h, n.right) } else { - t.searchAllNode(k, disType, target, h, n.left) + t.searchAllNodes(k, disType, target, h, n.left) } } } -func (t *Tree) searchAllNode(k int, disType pairwise.PairwiseDistanceFunc, target []float64, h *heap, n *node) { +func (t *Tree) searchAllNodes(k int, disType pairwise.PairwiseDistanceFunc, target []float64, h *heap, n *node) { vectorX := mat64.NewDense(len(target), 1, target) vectorY := mat64.NewDense(len(target), 1, n.value) length := disType.Distance(vectorX, vectorY) @@ -187,9 +187,9 @@ func (t *Tree) searchAllNode(k int, disType pairwise.PairwiseDistanceFunc, targe } if n.left != nil { - t.searchAllNode(k, disType, target, h, n.left) + t.searchAllNodes(k, disType, target, h, n.left) } if n.right != nil { - t.searchAllNode(k, disType, target, h, n.right) + t.searchAllNodes(k, disType, target, h, n.right) } } From ac0a2e1fc2e886aff39c098ca6932fd7134c5291 Mon Sep 17 00:00:00 2001 From: FrozenKP Date: Mon, 17 Apr 2017 12:01:57 +0800 Subject: [PATCH 2/4] find bug when testing knn with kdtree -> fixed --- kdtree/kdtree.go | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/kdtree/kdtree.go b/kdtree/kdtree.go index e379979..589a857 100644 --- a/kdtree/kdtree.go +++ b/kdtree/kdtree.go @@ -105,6 +105,8 @@ func (t *Tree) buildHandle(data []int, featureIndex int) *node { n.left.srcRowNo = data[divPoint+1] } else if divPoint != (len(data) - 1) { n.right = t.buildHandle(data[divPoint+1:], (featureIndex+1)%len(t.data[data[0]])) + } else { + n.right = &node{feature: -2} } return n @@ -142,6 +144,8 @@ func (t *Tree) searchHandle(k int, disType pairwise.PairwiseDistanceFunc, target length := disType.Distance(vectorX, vectorY) h.insert(n.value, length, n.srcRowNo) return + } else if n.feature == -2 { + return } dir := true From 041b0b25904d9f12e1f8acd457ec5814d8985501 Mon Sep 17 00:00:00 2001 From: FrozenKP Date: Mon, 17 Apr 2017 14:13:01 +0800 Subject: [PATCH 3/4] find a bug of srcRowNo and fixed --- kdtree/kdtree.go | 11 +++++------ 1 file changed, 5 insertions(+), 6 deletions(-) diff --git a/kdtree/kdtree.go b/kdtree/kdtree.go index 589a857..a88a4ab 100644 --- a/kdtree/kdtree.go +++ b/kdtree/kdtree.go @@ -76,10 +76,6 @@ func (t *Tree) buildHandle(data []int, featureIndex int) *node { sort.Sort(tmp) middle := len(data) / 2 - n.srcRowNo = data[middle] - n.value = make([]float64, len(t.data[data[middle]])) - copy(n.value, t.data[data[middle]]) - divPoint := middle for i := middle + 1; i < len(data); i++ { if t.data[data[i]][featureIndex] == t.data[data[middle]][featureIndex] { @@ -89,6 +85,10 @@ func (t *Tree) buildHandle(data []int, featureIndex int) *node { } } + n.srcRowNo = data[divPoint] + n.value = make([]float64, len(t.data[data[divPoint]])) + copy(n.value, t.data[data[divPoint]]) + if divPoint == 1 { n.left = &node{feature: -1} n.left.value = make([]float64, len(t.data[data[0]])) @@ -102,13 +102,12 @@ func (t *Tree) buildHandle(data []int, featureIndex int) *node { n.right = &node{feature: -1} n.right.value = make([]float64, len(t.data[data[divPoint+1]])) copy(n.right.value, t.data[data[divPoint+1]]) - n.left.srcRowNo = data[divPoint+1] + n.right.srcRowNo = data[divPoint+1] } else if divPoint != (len(data) - 1) { n.right = t.buildHandle(data[divPoint+1:], (featureIndex+1)%len(t.data[data[0]])) } else { n.right = &node{feature: -2} } - return n } From 543fe7a48490df748f7d248e5233d26f20b8f920 Mon Sep 17 00:00:00 2001 From: FrozenKP Date: Mon, 17 Apr 2017 19:54:17 +0800 Subject: [PATCH 4/4] Search return length for weightedKNN --- kdtree/kdtree.go | 18 ++++++++++-------- kdtree/kdtree_test.go | 4 ++-- 2 files changed, 12 insertions(+), 10 deletions(-) diff --git a/kdtree/kdtree.go b/kdtree/kdtree.go index a88a4ab..2797719 100644 --- a/kdtree/kdtree.go +++ b/kdtree/kdtree.go @@ -111,29 +111,31 @@ func (t *Tree) buildHandle(data []int, featureIndex int) *node { return n } -// Search return []int contained k nearest neighbor from -// specific distance function. -func (t *Tree) Search(k int, disType pairwise.PairwiseDistanceFunc, target []float64) ([]int, error) { +// Search return srcRowNo([]int) and length([]float64) contained +// k nearest neighbors from specific distance function. +func (t *Tree) Search(k int, disType pairwise.PairwiseDistanceFunc, target []float64) ([]int, []float64, error) { if k > len(t.data) { - return []int{}, errors.New("k is largerer than amount of trainData") + return []int{}, []float64{}, errors.New("k is largerer than amount of trainData") } if len(target) != len(t.data[0]) { - return []int{}, errors.New("amount of features is not equal") + return []int{}, []float64{}, errors.New("amount of features is not equal") } h := newHeap() t.searchHandle(k, disType, target, h, t.firstDiv) - out := make([]int, k) + srcRowNo := make([]int, k) + length := make([]float64, k) i := k - 1 for h.size() != 0 { - out[i] = h.maximum().srcRowNo + srcRowNo[i] = h.maximum().srcRowNo + length[i] = h.maximum().length i-- h.extractMax() } - return out, nil + return srcRowNo, length, nil } func (t *Tree) searchHandle(k int, disType pairwise.PairwiseDistanceFunc, target []float64, h *heap, n *node) { diff --git a/kdtree/kdtree_test.go b/kdtree/kdtree_test.go index c8f80ba..8d2a565 100644 --- a/kdtree/kdtree_test.go +++ b/kdtree/kdtree_test.go @@ -16,7 +16,7 @@ func TestKdtree(t *testing.T) { euclidean := pairwise.NewEuclidean() Convey("When k is 3 with euclidean", func() { - result, _ := kd.Search(3, euclidean, []float64{7, 3}) + result, _, _ := kd.Search(3, euclidean, []float64{7, 3}) Convey("The result[0] should be 4", func() { So(result[0], ShouldEqual, 4) @@ -30,7 +30,7 @@ func TestKdtree(t *testing.T) { }) Convey("When k is 2 with euclidean", func() { - result, _ := kd.Search(2, euclidean, []float64{7, 3}) + result, _, _ := kd.Search(2, euclidean, []float64{7, 3}) Convey("The result[0] should be 4", func() { So(result[0], ShouldEqual, 4)