mirror of
https://github.com/sjwhitworth/golearn.git
synced 2025-04-26 13:49:14 +08:00
213 lines
5.2 KiB
Go
213 lines
5.2 KiB
Go
package kdtree
|
|
|
|
import (
|
|
"errors"
|
|
"gonum.org/v1/gonum/mat"
|
|
"github.com/sjwhitworth/golearn/metrics/pairwise"
|
|
"sort"
|
|
)
|
|
|
|
// node is a single vertex of the kd-tree.
//
// The feature field doubles as a tag: a non-negative value is the index of
// the feature this node splits on, -1 marks a leaf that holds exactly one
// row, and -2 marks an empty placeholder child that holds no row at all
// (its value is nil).
type node struct {
	feature  int       // split feature index, or -1 (leaf) / -2 (empty placeholder)
	value    []float64 // private copy of the row stored at this node
	srcRowNo int       // index of that row in the data handed to Build
	left     *node     // subtree of rows sorted before this node on the split feature
	right    *node     // subtree of rows sorted after this node on the split feature
}
|
|
|
|
// Tree is a kdtree.
|
|
type Tree struct {
|
|
firstDiv *node
|
|
data [][]float64
|
|
}
|
|
|
|
// SortData implements sort.Interface over a slice of row indices, ordering
// the indices by one feature of the rows they refer to.
type SortData struct {
	RowData [][]float64 // rows being ordered; sorting never modifies them
	Data    []int       // indices into RowData; this is the slice that gets permuted
	Feature int         // column of RowData used as the sort key
}
|
|
|
|
// Len reports how many row indices are being sorted.
func (d SortData) Len() int {
	return len(d.Data)
}
|
|
func (d SortData) Less(i, j int) bool {
|
|
return d.RowData[d.Data[i]][d.Feature] < d.RowData[d.Data[j]][d.Feature]
|
|
}
|
|
// Swap exchanges the row indices at positions i and j. Only Data is
// permuted; RowData is left untouched.
func (d SortData) Swap(i, j int) {
	d.Data[i], d.Data[j] = d.Data[j], d.Data[i]
}
|
|
|
|
// New return a Tree pointer.
|
|
func New() *Tree {
|
|
return &Tree{}
|
|
}
|
|
|
|
// Build builds the kdtree with specific data.
|
|
func (t *Tree) Build(data [][]float64) error {
|
|
if len(data) == 0 {
|
|
return errors.New("no input data")
|
|
}
|
|
size := len(data[0])
|
|
for _, v := range data {
|
|
if len(v) != size {
|
|
return errors.New("amounts of features are not the same")
|
|
}
|
|
}
|
|
|
|
t.data = data
|
|
|
|
newData := make([]int, len(data))
|
|
for k, _ := range newData {
|
|
newData[k] = k
|
|
}
|
|
|
|
if len(data) == 1 {
|
|
t.firstDiv = &node{feature: -1, srcRowNo: 0}
|
|
t.firstDiv.value = make([]float64, len(data[0]))
|
|
copy(t.firstDiv.value, data[0])
|
|
} else {
|
|
t.firstDiv = t.buildHandle(newData, 0)
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
// buildHandle builds the kdtree recursively.
|
|
func (t *Tree) buildHandle(data []int, featureIndex int) *node {
|
|
n := &node{feature: featureIndex}
|
|
|
|
tmp := SortData{RowData: t.data, Data: data, Feature: featureIndex}
|
|
sort.Sort(tmp)
|
|
middle := len(data) / 2
|
|
|
|
divPoint := middle
|
|
for i := middle + 1; i < len(data); i++ {
|
|
if t.data[data[i]][featureIndex] == t.data[data[middle]][featureIndex] {
|
|
divPoint = i
|
|
} else {
|
|
break
|
|
}
|
|
}
|
|
|
|
n.srcRowNo = data[divPoint]
|
|
n.value = make([]float64, len(t.data[data[divPoint]]))
|
|
copy(n.value, t.data[data[divPoint]])
|
|
|
|
if divPoint == 1 {
|
|
n.left = &node{feature: -1}
|
|
n.left.value = make([]float64, len(t.data[data[0]]))
|
|
copy(n.left.value, t.data[data[0]])
|
|
n.left.srcRowNo = data[0]
|
|
} else {
|
|
n.left = t.buildHandle(data[:divPoint], (featureIndex+1)%len(t.data[data[0]]))
|
|
}
|
|
|
|
if divPoint == (len(data) - 2) {
|
|
n.right = &node{feature: -1}
|
|
n.right.value = make([]float64, len(t.data[data[divPoint+1]]))
|
|
copy(n.right.value, t.data[data[divPoint+1]])
|
|
n.right.srcRowNo = data[divPoint+1]
|
|
} else if divPoint != (len(data) - 1) {
|
|
n.right = t.buildHandle(data[divPoint+1:], (featureIndex+1)%len(t.data[data[0]]))
|
|
} else {
|
|
n.right = &node{feature: -2}
|
|
}
|
|
return n
|
|
}
|
|
|
|
// Search return srcRowNo([]int) and length([]float64) contained
|
|
// k nearest neighbors from specific distance function.
|
|
func (t *Tree) Search(k int, disType pairwise.PairwiseDistanceFunc, target []float64) ([]int, []float64, error) {
|
|
if k > len(t.data) {
|
|
return []int{}, []float64{}, errors.New("k is largerer than amount of trainData")
|
|
}
|
|
|
|
if len(target) != len(t.data[0]) {
|
|
return []int{}, []float64{}, errors.New("amount of features is not equal")
|
|
}
|
|
|
|
h := newHeap()
|
|
t.searchHandle(k, disType, target, h, t.firstDiv)
|
|
|
|
srcRowNo := make([]int, k)
|
|
length := make([]float64, k)
|
|
i := k - 1
|
|
for h.size() != 0 {
|
|
srcRowNo[i] = h.maximum().srcRowNo
|
|
length[i] = h.maximum().length
|
|
i--
|
|
h.extractMax()
|
|
}
|
|
|
|
return srcRowNo, length, nil
|
|
}
|
|
|
|
func (t *Tree) searchHandle(k int, disType pairwise.PairwiseDistanceFunc, target []float64, h *heap, n *node) {
|
|
if n.feature == -1 {
|
|
vectorX := mat.NewDense(len(target), 1, target)
|
|
vectorY := mat.NewDense(len(target), 1, n.value)
|
|
length := disType.Distance(vectorX, vectorY)
|
|
h.insert(n.value, length, n.srcRowNo)
|
|
return
|
|
} else if n.feature == -2 {
|
|
return
|
|
}
|
|
|
|
dir := true
|
|
if target[n.feature] <= n.value[n.feature] {
|
|
t.searchHandle(k, disType, target, h, n.left)
|
|
} else {
|
|
dir = false
|
|
t.searchHandle(k, disType, target, h, n.right)
|
|
}
|
|
|
|
vectorX := mat.NewDense(len(target), 1, target)
|
|
vectorY := mat.NewDense(len(target), 1, n.value)
|
|
length := disType.Distance(vectorX, vectorY)
|
|
|
|
if k > h.size() {
|
|
h.insert(n.value, length, n.srcRowNo)
|
|
if dir {
|
|
t.searchAllNodes(k, disType, target, h, n.right)
|
|
} else {
|
|
t.searchAllNodes(k, disType, target, h, n.left)
|
|
}
|
|
} else if h.maximum().length > length {
|
|
h.extractMax()
|
|
h.insert(n.value, length, n.srcRowNo)
|
|
if dir {
|
|
t.searchAllNodes(k, disType, target, h, n.right)
|
|
} else {
|
|
t.searchAllNodes(k, disType, target, h, n.left)
|
|
}
|
|
} else {
|
|
vectorX = mat.NewDense(1, 1, []float64{target[n.feature]})
|
|
vectorY = mat.NewDense(1, 1, []float64{n.value[n.feature]})
|
|
length = disType.Distance(vectorX, vectorY)
|
|
|
|
if h.maximum().length > length {
|
|
if dir {
|
|
t.searchAllNodes(k, disType, target, h, n.right)
|
|
} else {
|
|
t.searchAllNodes(k, disType, target, h, n.left)
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
func (t *Tree) searchAllNodes(k int, disType pairwise.PairwiseDistanceFunc, target []float64, h *heap, n *node) {
|
|
vectorX := mat.NewDense(len(target), 1, target)
|
|
vectorY := mat.NewDense(len(target), 1, n.value)
|
|
length := disType.Distance(vectorX, vectorY)
|
|
|
|
if k > h.size() {
|
|
h.insert(n.value, length, n.srcRowNo)
|
|
} else if h.maximum().length > length {
|
|
h.extractMax()
|
|
h.insert(n.value, length, n.srcRowNo)
|
|
}
|
|
|
|
if n.left != nil {
|
|
t.searchAllNodes(k, disType, target, h, n.left)
|
|
}
|
|
if n.right != nil {
|
|
t.searchAllNodes(k, disType, target, h, n.right)
|
|
}
|
|
}
|