diff --git a/kdtree/kdtree.go b/kdtree/kdtree.go index a88a4ab..2797719 100644 --- a/kdtree/kdtree.go +++ b/kdtree/kdtree.go @@ -111,29 +111,31 @@ func (t *Tree) buildHandle(data []int, featureIndex int) *node { return n } -// Search return []int contained k nearest neighbor from -// specific distance function. -func (t *Tree) Search(k int, disType pairwise.PairwiseDistanceFunc, target []float64) ([]int, error) { +// Search return srcRowNo([]int) and length([]float64) contained +// k nearest neighbors from specific distance function. +func (t *Tree) Search(k int, disType pairwise.PairwiseDistanceFunc, target []float64) ([]int, []float64, error) { if k > len(t.data) { - return []int{}, errors.New("k is largerer than amount of trainData") + return []int{}, []float64{}, errors.New("k is largerer than amount of trainData") } if len(target) != len(t.data[0]) { - return []int{}, errors.New("amount of features is not equal") + return []int{}, []float64{}, errors.New("amount of features is not equal") } h := newHeap() t.searchHandle(k, disType, target, h, t.firstDiv) - out := make([]int, k) + srcRowNo := make([]int, k) + length := make([]float64, k) i := k - 1 for h.size() != 0 { - out[i] = h.maximum().srcRowNo + srcRowNo[i] = h.maximum().srcRowNo + length[i] = h.maximum().length i-- h.extractMax() } - return out, nil + return srcRowNo, length, nil } func (t *Tree) searchHandle(k int, disType pairwise.PairwiseDistanceFunc, target []float64, h *heap, n *node) { diff --git a/kdtree/kdtree_test.go b/kdtree/kdtree_test.go index c8f80ba..8d2a565 100644 --- a/kdtree/kdtree_test.go +++ b/kdtree/kdtree_test.go @@ -16,7 +16,7 @@ func TestKdtree(t *testing.T) { euclidean := pairwise.NewEuclidean() Convey("When k is 3 with euclidean", func() { - result, _ := kd.Search(3, euclidean, []float64{7, 3}) + result, _, _ := kd.Search(3, euclidean, []float64{7, 3}) Convey("The result[0] should be 4", func() { So(result[0], ShouldEqual, 4) @@ -30,7 +30,7 @@ func TestKdtree(t *testing.T) { }) Convey("When k is 2 with euclidean", func() { - result, _ := kd.Search(2, euclidean, []float64{7, 3}) + result, _, _ := kd.Search(2, euclidean, []float64{7, 3}) Convey("The result[0] should be 4", func() { So(result[0], ShouldEqual, 4)