mirror of
https://github.com/sjwhitworth/golearn.git
synced 2025-04-26 13:49:14 +08:00
Search return length for weightedKNN
This commit is contained in:
parent
7b765a2f18
commit
3a2782ffec
@ -111,29 +111,31 @@ func (t *Tree) buildHandle(data []int, featureIndex int) *node {
|
||||
return n
|
||||
}
|
||||
|
||||
// Search return []int contained k nearest neighbor from
|
||||
// specific distance function.
|
||||
func (t *Tree) Search(k int, disType pairwise.PairwiseDistanceFunc, target []float64) ([]int, error) {
|
||||
// Search return srcRowNo([]int) and length([]float64) contained
|
||||
// k nearest neighbors from specific distance function.
|
||||
func (t *Tree) Search(k int, disType pairwise.PairwiseDistanceFunc, target []float64) ([]int, []float64, error) {
|
||||
if k > len(t.data) {
|
||||
return []int{}, errors.New("k is largerer than amount of trainData")
|
||||
return []int{}, []float64{}, errors.New("k is largerer than amount of trainData")
|
||||
}
|
||||
|
||||
if len(target) != len(t.data[0]) {
|
||||
return []int{}, errors.New("amount of features is not equal")
|
||||
return []int{}, []float64{}, errors.New("amount of features is not equal")
|
||||
}
|
||||
|
||||
h := newHeap()
|
||||
t.searchHandle(k, disType, target, h, t.firstDiv)
|
||||
|
||||
out := make([]int, k)
|
||||
srcRowNo := make([]int, k)
|
||||
length := make([]float64, k)
|
||||
i := k - 1
|
||||
for h.size() != 0 {
|
||||
out[i] = h.maximum().srcRowNo
|
||||
srcRowNo[i] = h.maximum().srcRowNo
|
||||
length[i] = h.maximum().length
|
||||
i--
|
||||
h.extractMax()
|
||||
}
|
||||
|
||||
return out, nil
|
||||
return srcRowNo, length, nil
|
||||
}
|
||||
|
||||
func (t *Tree) searchHandle(k int, disType pairwise.PairwiseDistanceFunc, target []float64, h *heap, n *node) {
|
||||
|
@ -16,7 +16,7 @@ func TestKdtree(t *testing.T) {
|
||||
euclidean := pairwise.NewEuclidean()
|
||||
|
||||
Convey("When k is 3 with euclidean", func() {
|
||||
result, _ := kd.Search(3, euclidean, []float64{7, 3})
|
||||
result, _, _ := kd.Search(3, euclidean, []float64{7, 3})
|
||||
|
||||
Convey("The result[0] should be 4", func() {
|
||||
So(result[0], ShouldEqual, 4)
|
||||
@ -30,7 +30,7 @@ func TestKdtree(t *testing.T) {
|
||||
})
|
||||
|
||||
Convey("When k is 2 with euclidean", func() {
|
||||
result, _ := kd.Search(2, euclidean, []float64{7, 3})
|
||||
result, _, _ := kd.Search(2, euclidean, []float64{7, 3})
|
||||
|
||||
Convey("The result[0] should be 4", func() {
|
||||
So(result[0], ShouldEqual, 4)
|
||||
|
Loading…
x
Reference in New Issue
Block a user