1
0
mirror of https://github.com/sjwhitworth/golearn.git synced 2025-04-28 13:48:56 +08:00

Search return length for weightedKNN

This commit is contained in:
FrozenKP 2017-04-17 19:47:13 +08:00
parent 7b765a2f18
commit 3a2782ffec
2 changed files with 12 additions and 10 deletions

View File

@ -111,29 +111,31 @@ func (t *Tree) buildHandle(data []int, featureIndex int) *node {
return n return n
} }
// Search return []int contained k nearest neighbor from // Search return srcRowNo([]int) and length([]float64) contained
// specific distance function. // k nearest neighbors from specific distance function.
func (t *Tree) Search(k int, disType pairwise.PairwiseDistanceFunc, target []float64) ([]int, error) { func (t *Tree) Search(k int, disType pairwise.PairwiseDistanceFunc, target []float64) ([]int, []float64, error) {
if k > len(t.data) { if k > len(t.data) {
return []int{}, errors.New("k is largerer than amount of trainData") return []int{}, []float64{}, errors.New("k is largerer than amount of trainData")
} }
if len(target) != len(t.data[0]) { if len(target) != len(t.data[0]) {
return []int{}, errors.New("amount of features is not equal") return []int{}, []float64{}, errors.New("amount of features is not equal")
} }
h := newHeap() h := newHeap()
t.searchHandle(k, disType, target, h, t.firstDiv) t.searchHandle(k, disType, target, h, t.firstDiv)
out := make([]int, k) srcRowNo := make([]int, k)
length := make([]float64, k)
i := k - 1 i := k - 1
for h.size() != 0 { for h.size() != 0 {
out[i] = h.maximum().srcRowNo srcRowNo[i] = h.maximum().srcRowNo
length[i] = h.maximum().length
i-- i--
h.extractMax() h.extractMax()
} }
return out, nil return srcRowNo, length, nil
} }
func (t *Tree) searchHandle(k int, disType pairwise.PairwiseDistanceFunc, target []float64, h *heap, n *node) { func (t *Tree) searchHandle(k int, disType pairwise.PairwiseDistanceFunc, target []float64, h *heap, n *node) {

View File

@ -16,7 +16,7 @@ func TestKdtree(t *testing.T) {
euclidean := pairwise.NewEuclidean() euclidean := pairwise.NewEuclidean()
Convey("When k is 3 with euclidean", func() { Convey("When k is 3 with euclidean", func() {
result, _ := kd.Search(3, euclidean, []float64{7, 3}) result, _, _ := kd.Search(3, euclidean, []float64{7, 3})
Convey("The result[0] should be 4", func() { Convey("The result[0] should be 4", func() {
So(result[0], ShouldEqual, 4) So(result[0], ShouldEqual, 4)
@ -30,7 +30,7 @@ func TestKdtree(t *testing.T) {
}) })
Convey("When k is 2 with euclidean", func() { Convey("When k is 2 with euclidean", func() {
result, _ := kd.Search(2, euclidean, []float64{7, 3}) result, _, _ := kd.Search(2, euclidean, []float64{7, 3})
Convey("The result[0] should be 4", func() { Convey("The result[0] should be 4", func() {
So(result[0], ShouldEqual, 4) So(result[0], ShouldEqual, 4)