diff --git a/kdtree/kdtree.go b/kdtree/kdtree.go index 589a857..a88a4ab 100644 --- a/kdtree/kdtree.go +++ b/kdtree/kdtree.go @@ -76,10 +76,6 @@ func (t *Tree) buildHandle(data []int, featureIndex int) *node { sort.Sort(tmp) middle := len(data) / 2 - n.srcRowNo = data[middle] - n.value = make([]float64, len(t.data[data[middle]])) - copy(n.value, t.data[data[middle]]) - divPoint := middle for i := middle + 1; i < len(data); i++ { if t.data[data[i]][featureIndex] == t.data[data[middle]][featureIndex] { @@ -89,6 +85,10 @@ func (t *Tree) buildHandle(data []int, featureIndex int) *node { } } + n.srcRowNo = data[divPoint] + n.value = make([]float64, len(t.data[data[divPoint]])) + copy(n.value, t.data[data[divPoint]]) + if divPoint == 1 { n.left = &node{feature: -1} n.left.value = make([]float64, len(t.data[data[0]])) @@ -102,13 +102,12 @@ func (t *Tree) buildHandle(data []int, featureIndex int) *node { n.right = &node{feature: -1} n.right.value = make([]float64, len(t.data[data[divPoint+1]])) copy(n.right.value, t.data[data[divPoint+1]]) - n.left.srcRowNo = data[divPoint+1] + n.right.srcRowNo = data[divPoint+1] } else if divPoint != (len(data) - 1) { n.right = t.buildHandle(data[divPoint+1:], (featureIndex+1)%len(t.data[data[0]])) } else { n.right = &node{feature: -2} } - return n }