Mirror of https://github.com/sjwhitworth/golearn.git (synced 2025-04-26 13:49:14 +08:00)
Merge pull request #5 from sjwhitworth/master
Bringing master up to date
Commit 6aa37aca00
1460
examples/datasets/boston_house_prices.csv
Normal file
File diff suppressed because it is too large
889
examples/datasets/titanic.csv
Normal file
@@ -0,0 +1,889 @@
[889 rows of four comma-separated integer values, e.g. "3,1,2,0"; the last column (0/1) serves as the class label in the CART example below. Full file contents omitted because the diff is too large.]
92
examples/trees/cart.go
Normal file
@@ -0,0 +1,92 @@
// Example of how to use CART trees for both classification and regression

package main

import (
	"fmt"

	"github.com/sjwhitworth/golearn/base"
	"github.com/sjwhitworth/golearn/trees"
)

func main() {
	/* Performance of the CART algorithm:

	Training time for the Titanic dataset ≈ 611 µs
	Prediction time for the Titanic dataset ≈ 101 µs

	Complexity analysis:
		1x dataset   -- x ms
		2x dataset   -- 1.7x ms
		128x dataset -- 74x ms

	Complexity is sub-linear.

	scikit-learn:
		Training time for the Titanic dataset ≈ 8.8 µs
		Prediction time for the Titanic dataset ≈ 7.87 µs

	This implementation and scikit-learn produce exactly the same tree for the same dataset,
	and predictions on the same test set yield exactly the same accuracy.

	This implementation is optimized to prevent redundant iterations over the dataset, but it
	is not completely optimized. Also, sklearn makes use of numpy to access a column directly,
	whereas here a complete iteration is required.

	In terms of hyperparameters, this implementation gives you the ability to choose the
	impurity function and the maxDepth. Many of the other hyperparameters used in sklearn
	are not here, but pruning and impurity selection are included.
	*/

	// Load Titanic data for classification
	classificationData, err := base.ParseCSVToInstances("../datasets/titanic.csv", false)
	if err != nil {
		panic(err)
	}
	trainData, testData := base.InstancesTrainTestSplit(classificationData, 0.5)

	// Create a new classification tree.
	// Hyperparameters: loss function, max depth (-1 will split until pure), list of unique labels
	decTree := trees.NewDecisionTreeClassifier("entropy", -1, []int64{0, 1})

	// Train the tree
	err = decTree.Fit(trainData)
	if err != nil {
		panic(err)
	}
	// Print out the tree for visualization - shows splits, features and predictions
	fmt.Println(decTree.String())

	// Access predictions
	classificationPreds := decTree.Predict(testData)

	fmt.Println("Titanic Predictions")
	fmt.Println(classificationPreds)

	// Evaluate accuracy on the test data
	fmt.Println(decTree.Evaluate(testData))

	// Load house-price data for regression
	regressionData, err := base.ParseCSVToInstances("../datasets/boston_house_prices.csv", false)
	if err != nil {
		panic(err)
	}
	trainRegData, testRegData := base.InstancesTrainTestSplit(regressionData, 0.5)

	// Hyperparameters: loss function, max depth (-1 will split until pure)
	regTree := trees.NewDecisionTreeRegressor("mse", -1)

	// Train the tree
	err = regTree.Fit(trainRegData)
	if err != nil {
		panic(err)
	}

	// Print out the tree for visualization
	fmt.Println(regTree.String())

	// Access predictions
	regressionPreds := regTree.Predict(testRegData)

	fmt.Println("Boston House Price Predictions")
	fmt.Println(regressionPreds)
}
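For reference, decTree.String() renders the fitted tree in the format produced by classifierPrintTreeFromNode later in this diff; with illustrative values for a single split it looks like:

Feature 1 < 5.000
---> True
 PREDICT 0
---> False
 PREDICT 1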
441
trees/cart_classifier.go
Normal file
@@ -0,0 +1,441 @@
package trees

import (
	"errors"
	"fmt"
	"math"
	"sort"
	"strconv"
	"strings"

	"github.com/sjwhitworth/golearn/base"
)

const (
	GINI    string = "gini"
	ENTROPY string = "entropy"
)

// classifierNode is the node struct for the Decision Tree Classifier.
// It holds the information for each split: which feature to use, the threshold,
// and which label to assign on each side of the split.
type classifierNode struct {
	Left         *classifierNode
	Right        *classifierNode
	Threshold    float64
	Feature      int64
	LeftLabel    int64
	RightLabel   int64
	isNodeNeeded bool
}

// CARTDecisionTreeClassifier is the tree struct for the Decision Tree Classifier.
// It contains the root node as well as all of the hyperparameters chosen by the user,
// and keeps track of all splits tried at the tree level.
type CARTDecisionTreeClassifier struct {
	RootNode    *classifierNode
	criterion   string
	maxDepth    int64
	labels      []int64
	triedSplits [][]float64
}

// Convert a series of labels to a frequency map for efficient impurity calculation
func convertToMap(y []int64, labels []int64) map[int64]int {
	labelCount := make(map[int64]int)
	for _, label := range labels {
		labelCount[label] = 0
	}
	for _, value := range y {
		labelCount[value]++
	}
	return labelCount
}

// Calculate the Gini impurity of the target labels, and the mode label
func computeGiniImpurityAndModeLabel(y []int64, labels []int64) (float64, int64) {
	nInstances := len(y)
	gini := 0.0
	var maxLabel int64 = 0

	labelCount := convertToMap(y, labels)
	for _, label := range labels {
		if labelCount[label] > labelCount[maxLabel] {
			maxLabel = label
		}
		p := float64(labelCount[label]) / float64(nInstances)
		gini += p * (1 - p)
	}
	return gini, maxLabel
}

// Calculate the entropy of the target labels, and the mode label
func computeEntropyAndModeLabel(y []int64, labels []int64) (float64, int64) {
	nInstances := len(y)
	entropy := 0.0
	var maxLabel int64 = 0

	labelCount := convertToMap(y, labels)
	for _, label := range labels {
		if labelCount[label] > labelCount[maxLabel] {
			maxLabel = label
		}
		p := float64(labelCount[label]) / float64(nInstances)
		// Guard against log2(0); a class with p = 0 contributes nothing
		if p != 0 {
			entropy += -p * math.Log2(p)
		}
	}
	return entropy, maxLabel
}

func calculateClassificationLoss(y []int64, labels []int64, criterion string) (float64, int64, error) {
	if len(y) == 0 {
		return 0, 0, errors.New("need at least 1 value to compute impurity")
	}
	if criterion == GINI {
		loss, modeLabel := computeGiniImpurityAndModeLabel(y, labels)
		return loss, modeLabel, nil
	} else if criterion == ENTROPY {
		loss, modeLabel := computeEntropyAndModeLabel(y, labels)
		return loss, modeLabel, nil
	} else {
		return 0, 0, errors.New("invalid impurity function, choose from GINI or ENTROPY")
	}
}

// Split the data into a left node and a right node based on feature and threshold
func classifierCreateSplit(data [][]float64, feature int64, y []int64, threshold float64) ([][]float64, [][]float64, []int64, []int64) {
	var left [][]float64
	var right [][]float64
	var lefty []int64
	var righty []int64

	for i := range data {
		example := data[i]
		if example[feature] < threshold {
			left = append(left, example)
			lefty = append(lefty, y[i])
		} else {
			right = append(right, example)
			righty = append(righty, y[i])
		}
	}

	return left, right, lefty, righty
}

// NewDecisionTreeClassifier creates a new Decision Tree Classifier.
// It assigns all of the hyperparameters chosen by the user to the tree's attributes.
func NewDecisionTreeClassifier(criterion string, maxDepth int64, labels []int64) *CARTDecisionTreeClassifier {
	var tree CARTDecisionTreeClassifier
	tree.criterion = strings.ToLower(criterion)
	tree.maxDepth = maxDepth
	tree.labels = labels

	return &tree
}

// Reorder the data by the feature being considered.
// Reduces the number of times we have to loop over the data when splitting.
func classifierReOrderData(featureVal []float64, data [][]float64, y []int64) ([][]float64, []int64) {
	s := NewSlice(featureVal)
	sort.Sort(s)

	indexes := s.Idx

	var dataSorted [][]float64
	var ySorted []int64

	for _, index := range indexes {
		dataSorted = append(dataSorted, data[index])
		ySorted = append(ySorted, y[index])
	}

	return dataSorted, ySorted
}

// Update the left and right side of the split based on the new threshold.
func classifierUpdateSplit(left [][]float64, leftY []int64, right [][]float64, rightY []int64, feature int64, threshold float64) ([][]float64, []int64, [][]float64, []int64) {
	for right[0][feature] < threshold {
		left = append(left, right[0])
		right = right[1:]
		leftY = append(leftY, rightY[0])
		rightY = rightY[1:]
	}

	return left, leftY, right, rightY
}

// Fit creates an empty root node and
// trains the tree by calling the recursive function classifierBestSplit.
func (tree *CARTDecisionTreeClassifier) Fit(X base.FixedDataGrid) error {
	var emptyNode classifierNode
	var err error

	data := convertInstancesToProblemVec(X)
	y, err := classifierConvertInstancesToLabelVec(X)
	if err != nil {
		return err
	}

	emptyNode, err = classifierBestSplit(*tree, data, y, tree.labels, emptyNode, tree.criterion, tree.maxDepth, 0)
	if err != nil {
		return err
	}
	tree.RootNode = &emptyNode
	return nil
}

// Iteratively find and record the best split.
// Stops if depth reaches maxDepth or the nodes are pure.
func classifierBestSplit(tree CARTDecisionTreeClassifier, data [][]float64, y []int64, labels []int64, upperNode classifierNode, criterion string, maxDepth int64, depth int64) (classifierNode, error) {
	// Ensure that we have not reached maxDepth. maxDepth = -1 means split until nodes are pure.
	depth++

	if maxDepth != -1 && depth > maxDepth {
		return upperNode, nil
	}

	numFeatures := len(data[0])
	var bestGini, origGini float64
	var err error
	// Calculate the loss based on the criterion specified by the user
	origGini, upperNode.LeftLabel, err = calculateClassificationLoss(y, labels, criterion)
	if err != nil {
		return upperNode, err
	}

	bestGini = origGini

	bestLeft, bestRight, bestLefty, bestRighty := data, data, y, y

	numData := len(data)

	bestLeftGini, bestRightGini := bestGini, bestGini

	upperNode.isNodeNeeded = true

	var leftN, rightN classifierNode

	// Iterate over all features
	for i := 0; i < numFeatures; i++ {
		featureVal := getFeature(data, int64(i))
		unique := findUnique(featureVal)
		sort.Float64s(unique)

		sortData, sortY := classifierReOrderData(featureVal, data, y)

		firstTime := true

		var left, right [][]float64
		var leftY, rightY []int64
		// Iterate over all possible thresholds for that feature
		for j := 0; j < len(unique)-1; j++ {
			threshold := (unique[j] + unique[j+1]) / 2
			// Ensure that the same split has not been made before
			if validate(tree.triedSplits, int64(i), threshold) {
				// We need to split the data from fresh when considering a new feature for the first time.
				// Otherwise, we can update the split by moving data points from right to left.
				if firstTime {
					left, right, leftY, rightY = classifierCreateSplit(sortData, int64(i), sortY, threshold)
					firstTime = false
				} else {
					left, leftY, right, rightY = classifierUpdateSplit(left, leftY, right, rightY, int64(i), threshold)
				}

				var leftGini, rightGini float64
				var leftLabels, rightLabels int64

				leftGini, leftLabels, _ = calculateClassificationLoss(leftY, labels, criterion)
				rightGini, rightLabels, _ = calculateClassificationLoss(rightY, labels, criterion)

				// Calculate the weighted impurity of the child nodes
				subGini := (leftGini * float64(len(left)) / float64(numData)) + (rightGini * float64(len(right)) / float64(numData))

				// If we find a split that reduces impurity
				if subGini < bestGini {
					bestGini = subGini

					bestLeft, bestRight = left, right
					bestLefty, bestRighty = leftY, rightY

					upperNode.Threshold, upperNode.Feature = threshold, int64(i)
					upperNode.LeftLabel, upperNode.RightLabel = leftLabels, rightLabels

					bestLeftGini, bestRightGini = leftGini, rightGini
				}
			}
		}
	}
	// If no split was found, we don't want to use this node, so flag it
	if bestGini == origGini {
		upperNode.isNodeNeeded = false
		return upperNode, nil
	}
	// Keep splitting while the node is still impure
	if bestGini > 0 {
		// Only split the left side further if it is not yet pure
		if bestLeftGini > 0 {
			tree.triedSplits = append(tree.triedSplits, []float64{float64(upperNode.Feature), upperNode.Threshold})
			// Recursive splitting logic
			leftN, err = classifierBestSplit(tree, bestLeft, bestLefty, labels, leftN, criterion, maxDepth, depth)
			if err != nil {
				return upperNode, err
			}
			if leftN.isNodeNeeded {
				upperNode.Left = &leftN
			}
		}
		// Only split the right side further if it is not yet pure
		if bestRightGini > 0 {
			tree.triedSplits = append(tree.triedSplits, []float64{float64(upperNode.Feature), upperNode.Threshold})
			// Recursive splitting logic
			rightN, err = classifierBestSplit(tree, bestRight, bestRighty, labels, rightN, criterion, maxDepth, depth)
			if err != nil {
				return upperNode, err
			}
			if rightN.isNodeNeeded {
				upperNode.Right = &rightN
			}
		}
	}
	// Return the node - it contains all information regarding feature and threshold.
	return upperNode, nil
}

// String prints out the entire tree for visualization.
// Calls a recursive function to print the tree - classifierPrintTreeFromNode.
func (tree *CARTDecisionTreeClassifier) String() string {
	rootNode := *tree.RootNode
	return classifierPrintTreeFromNode(rootNode, "")
}

func classifierPrintTreeFromNode(tree classifierNode, spacing string) string {
	returnString := ""
	returnString += spacing + "Feature "
	returnString += strconv.FormatInt(tree.Feature, 10)
	returnString += " < "
	returnString += fmt.Sprintf("%.3f", tree.Threshold)
	returnString += "\n"

	if tree.Left == nil {
		returnString += spacing + "---> True" + "\n"
		returnString += " " + spacing + "PREDICT "
		returnString += strconv.FormatInt(tree.LeftLabel, 10) + "\n"
	}
	if tree.Right == nil {
		returnString += spacing + "---> False" + "\n"
		returnString += " " + spacing + "PREDICT "
		returnString += strconv.FormatInt(tree.RightLabel, 10) + "\n"
	}

	if tree.Left != nil {
		returnString += spacing + "---> True" + "\n"
		returnString += classifierPrintTreeFromNode(*tree.Left, spacing+" ")
	}

	if tree.Right != nil {
		returnString += spacing + "---> False" + "\n"
		returnString += classifierPrintTreeFromNode(*tree.Right, spacing+" ")
	}

	return returnString
}

// Predict a single data point by traversing the entire tree.
// Uses recursive logic to navigate the tree.
func classifierPredictSingle(tree classifierNode, instance []float64) int64 {
	if instance[tree.Feature] < tree.Threshold {
		if tree.Left == nil {
			return tree.LeftLabel
		} else {
			return classifierPredictSingle(*tree.Left, instance)
		}
	} else {
		if tree.Right == nil {
			return tree.RightLabel
		} else {
			return classifierPredictSingle(*tree.Right, instance)
		}
	}
}

// Predict returns predictions for every data point in the test data. Calls classifierPredictFromNode.
func (tree *CARTDecisionTreeClassifier) Predict(X_test base.FixedDataGrid) []int64 {
	root := *tree.RootNode
	test := convertInstancesToProblemVec(X_test)
	return classifierPredictFromNode(root, test)
}

// This function uses the root node from Predict.
// It iterates through every data point, calls the recursive function to make a prediction, and collects the results.
func classifierPredictFromNode(tree classifierNode, test [][]float64) []int64 {
	var preds []int64
	for i := range test {
		iPred := classifierPredictSingle(tree, test[i])
		preds = append(preds, iPred)
	}
	return preds
}

// Evaluate returns the accuracy of the classifier, given test data and labels.
// First it retrieves predictions from the data, then compares them for accuracy.
// Calls classifierEvaluateFromNode.
func (tree *CARTDecisionTreeClassifier) Evaluate(test base.FixedDataGrid) (float64, error) {
	rootNode := *tree.RootNode
	xTest := convertInstancesToProblemVec(test)
	yTest, err := classifierConvertInstancesToLabelVec(test)
	if err != nil {
		return 0, err
	}
	return classifierEvaluateFromNode(rootNode, xTest, yTest), nil
}

// Retrieve predictions and then calculate the accuracy.
func classifierEvaluateFromNode(tree classifierNode, xTest [][]float64, yTest []int64) float64 {
	preds := classifierPredictFromNode(tree, xTest)
	accuracy := 0.0
	for i := range preds {
		if preds[i] == yTest[i] {
			accuracy++
		}
	}
	accuracy /= float64(len(yTest))
	return accuracy
}

// Helper function to convert base.FixedDataGrid into the required format. Called in Fit, Predict.
func classifierConvertInstancesToLabelVec(X base.FixedDataGrid) ([]int64, error) {
	// Get the class Attributes
	classAttrs := X.AllClassAttributes()
	// Only one class Attribute is supported
	if len(classAttrs) != 1 {
		return []int64{0}, fmt.Errorf("%d ClassAttributes (1 expected)", len(classAttrs))
	}
	// The ClassAttribute must be numeric
	if _, ok := classAttrs[0].(*base.FloatAttribute); !ok {
		return []int64{0}, fmt.Errorf("%s: ClassAttribute must be a FloatAttribute", classAttrs[0])
	}
	// Allocate the return structure
	_, rows := X.Size()

	labelVec := make([]int64, rows)
	// Resolve the class Attribute specification
	classAttrSpecs := base.ResolveAttributes(X, classAttrs)
	X.MapOverRows(classAttrSpecs, func(row [][]byte, rowNo int) (bool, error) {
		labelVec[rowNo] = int64(base.UnpackBytesToFloat(row[0]))
		return true, nil
	})
	return labelVec, nil
}
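As a quick sanity check on the two impurity helpers above, here is a minimal sketch (the function exampleImpurity is hypothetical, not part of the diff). For y = {1, 0, 0, 1} each class has p = 0.5, so Gini = 0.5*0.5 + 0.5*0.5 = 0.5 and entropy = -2 * 0.5 * log2(0.5) = 1.0, matching the assertions in cart_test.go later in this diff.

package trees

import "fmt"

// Hypothetical sketch - verifies the impurity helpers on a toy label set.
func exampleImpurity() {
	y := []int64{1, 0, 0, 1}
	labels := []int64{0, 1}
	gini, giniMode := computeGiniImpurityAndModeLabel(y, labels)
	entropy, entropyMode := computeEntropyAndModeLabel(y, labels)
	fmt.Println(gini, giniMode)       // 0.5 0
	fmt.Println(entropy, entropyMode) // 1 0
}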
412
trees/cart_regressor.go
Normal file
@@ -0,0 +1,412 @@
package trees

import (
	"errors"
	"fmt"
	"math"
	"sort"
	"strconv"
	"strings"

	"github.com/sjwhitworth/golearn/base"
)

const (
	MAE string = "mae"
	MSE string = "mse"
)

// regressorNode is the node struct for the Decision Tree Regressor.
// It holds the information for each split:
// which feature to use, the threshold, and the left and right predictions.
type regressorNode struct {
	Left         *regressorNode
	Right        *regressorNode
	Threshold    float64
	Feature      int64
	LeftPred     float64
	RightPred    float64
	isNodeNeeded bool
}

// CARTDecisionTreeRegressor is the tree struct for the Decision Tree Regressor.
// It contains the root node as well as the hyperparameters chosen by the user,
// and keeps track of the splits tried at the tree level.
type CARTDecisionTreeRegressor struct {
	RootNode    *regressorNode
	criterion   string
	maxDepth    int64
	triedSplits [][]float64
}

// Find the average of the target values
func average(y []float64) float64 {
	mean := 0.0
	for _, value := range y {
		mean += value
	}
	mean /= float64(len(y))
	return mean
}

// Calculate the Mean Absolute Error for a constant prediction
func meanAbsoluteError(y []float64, yBar float64) float64 {
	totalError := 0.0
	for _, target := range y {
		totalError += math.Abs(target - yBar)
	}
	totalError /= float64(len(y))
	return totalError
}

// Turn Mean Absolute Error into an impurity function for decision trees.
func computeMaeImpurityAndAverage(y []float64) (float64, float64) {
	yHat := average(y)
	return meanAbsoluteError(y, yHat), yHat
}

// Calculate the Mean Squared Error for a constant prediction
func meanSquaredError(y []float64, yBar float64) float64 {
	totalError := 0.0
	for _, target := range y {
		itemError := target - yBar
		totalError += math.Pow(itemError, 2)
	}
	totalError /= float64(len(y))
	return totalError
}

// Turn Mean Squared Error into an impurity function for decision trees
func computeMseImpurityAndAverage(y []float64) (float64, float64) {
	yHat := average(y)
	return meanSquaredError(y, yHat), yHat
}

func calculateRegressionLoss(y []float64, criterion string) (float64, float64, error) {
	if criterion == MAE {
		loss, avg := computeMaeImpurityAndAverage(y)
		return loss, avg, nil
	} else if criterion == MSE {
		loss, avg := computeMseImpurityAndAverage(y)
		return loss, avg, nil
	} else {
		// Return an error rather than panicking, matching the classifier's behaviour
		return 0, 0, errors.New("invalid impurity function, choose from MAE or MSE")
	}
}

// Split the data into left and right based on threshold and feature.
func regressorCreateSplit(data [][]float64, feature int64, y []float64, threshold float64) ([][]float64, [][]float64, []float64, []float64) {
	var left [][]float64
	var lefty []float64
	var right [][]float64
	var righty []float64

	for i := range data {
		example := data[i]
		if example[feature] < threshold {
			left = append(left, example)
			lefty = append(lefty, y[i])
		} else {
			right = append(right, example)
			righty = append(righty, y[i])
		}
	}

	return left, right, lefty, righty
}

// NewDecisionTreeRegressor creates a new Decision Tree Regressor
func NewDecisionTreeRegressor(criterion string, maxDepth int64) *CARTDecisionTreeRegressor {
	var tree CARTDecisionTreeRegressor
	tree.maxDepth = maxDepth
	tree.criterion = strings.ToLower(criterion)
	return &tree
}

// Reorder the data based on the feature being considered.
// Helps in updating splits without re-iterating over the entire dataset.
func regressorReOrderData(featureVal []float64, data [][]float64, y []float64) ([][]float64, []float64) {
	s := NewSlice(featureVal)
	sort.Sort(s)

	indexes := s.Idx

	var dataSorted [][]float64
	var ySorted []float64

	for _, index := range indexes {
		dataSorted = append(dataSorted, data[index])
		ySorted = append(ySorted, y[index])
	}

	return dataSorted, ySorted
}

// Update the left and right data based on a change in threshold
func regressorUpdateSplit(left [][]float64, leftY []float64, right [][]float64, rightY []float64, feature int64, threshold float64) ([][]float64, []float64, [][]float64, []float64) {
	for right[0][feature] < threshold {
		left = append(left, right[0])
		right = right[1:]
		leftY = append(leftY, rightY[0])
		rightY = rightY[1:]
	}

	return left, leftY, right, rightY
}

// Fit builds the tree from the data.
// Creates an empty root node and builds the tree by calling regressorBestSplit.
func (tree *CARTDecisionTreeRegressor) Fit(X base.FixedDataGrid) error {
	var emptyNode regressorNode
	var err error

	data := regressorConvertInstancesToProblemVec(X)
	y, err := regressorConvertInstancesToLabelVec(X)
	if err != nil {
		return err
	}

	emptyNode, err = regressorBestSplit(*tree, data, y, emptyNode, tree.criterion, tree.maxDepth, 0)
	if err != nil {
		return err
	}
	tree.RootNode = &emptyNode
	return nil
}

// Builds the tree by iteratively finding the best split.
// Recursive function - stops if maxDepth is reached or the nodes are pure.
func regressorBestSplit(tree CARTDecisionTreeRegressor, data [][]float64, y []float64, upperNode regressorNode, criterion string, maxDepth int64, depth int64) (regressorNode, error) {
	// Ensure that we have not reached maxDepth. maxDepth = -1 means split until nodes are pure.
	depth++

	if depth > maxDepth && maxDepth != -1 {
		return upperNode, nil
	}

	numFeatures := len(data[0])
	var bestLoss, origLoss float64
	var err error
	origLoss, upperNode.LeftPred, err = calculateRegressionLoss(y, criterion)
	if err != nil {
		return upperNode, err
	}

	bestLoss = origLoss

	bestLeft, bestRight, bestLefty, bestRighty := data, data, y, y

	numData := len(data)

	bestLeftLoss, bestRightLoss := bestLoss, bestLoss

	upperNode.isNodeNeeded = true

	var leftN, rightN regressorNode
	// Iterate over all features
	for i := 0; i < numFeatures; i++ {
		featureVal := getFeature(data, int64(i))
		unique := findUnique(featureVal)
		sort.Float64s(unique)

		sortData, sortY := regressorReOrderData(featureVal, data, y)

		firstTime := true

		var left, right [][]float64
		var leftY, rightY []float64

		// Iterate over all candidate thresholds for this feature
		for j := 0; j < len(unique)-1; j++ {
			threshold := (unique[j] + unique[j+1]) / 2
			if validate(tree.triedSplits, int64(i), threshold) {
				if firstTime {
					left, right, leftY, rightY = regressorCreateSplit(sortData, int64(i), sortY, threshold)
					firstTime = false
				} else {
					left, leftY, right, rightY = regressorUpdateSplit(left, leftY, right, rightY, int64(i), threshold)
				}

				var leftLoss, rightLoss float64
				var leftPred, rightPred float64

				leftLoss, leftPred, _ = calculateRegressionLoss(leftY, criterion)
				rightLoss, rightPred, _ = calculateRegressionLoss(rightY, criterion)

				// Weighted loss of the child nodes
				subLoss := (leftLoss * float64(len(left)) / float64(numData)) + (rightLoss * float64(len(right)) / float64(numData))

				if subLoss < bestLoss {
					bestLoss = subLoss

					bestLeft, bestRight = left, right
					bestLefty, bestRighty = leftY, rightY

					upperNode.Threshold, upperNode.Feature = threshold, int64(i)

					upperNode.LeftPred, upperNode.RightPred = leftPred, rightPred

					bestLeftLoss, bestRightLoss = leftLoss, rightLoss
				}
			}
		}
	}

	// If no split improved the loss, flag this node as unneeded
	if bestLoss == origLoss {
		upperNode.isNodeNeeded = false
		return upperNode, nil
	}

	if bestLoss > 0 {
		if bestLeftLoss > 0 {
			tree.triedSplits = append(tree.triedSplits, []float64{float64(upperNode.Feature), upperNode.Threshold})
			leftN, err = regressorBestSplit(tree, bestLeft, bestLefty, leftN, criterion, maxDepth, depth)
			if err != nil {
				return upperNode, err
			}
			if leftN.isNodeNeeded {
				upperNode.Left = &leftN
			}
		}

		if bestRightLoss > 0 {
			tree.triedSplits = append(tree.triedSplits, []float64{float64(upperNode.Feature), upperNode.Threshold})
			rightN, err = regressorBestSplit(tree, bestRight, bestRighty, rightN, criterion, maxDepth, depth)
			if err != nil {
				return upperNode, err
			}
			if rightN.isNodeNeeded {
				upperNode.Right = &rightN
			}
		}
	}
	return upperNode, nil
}

// String prints the tree for visualization - calls regressorPrintTreeFromNode()
func (tree *CARTDecisionTreeRegressor) String() string {
	rootNode := *tree.RootNode
	return regressorPrintTreeFromNode(rootNode, "")
}

// Recursively explore the entire tree and print out all details such as threshold, feature and prediction
func regressorPrintTreeFromNode(tree regressorNode, spacing string) string {
	returnString := ""
	returnString += spacing + "Feature "
	returnString += strconv.FormatInt(tree.Feature, 10)
	returnString += " < "
	returnString += fmt.Sprintf("%.3f", tree.Threshold)
	returnString += "\n"

	if tree.Left == nil {
		returnString += spacing + "---> True" + "\n"
		returnString += " " + spacing + "PREDICT "
		returnString += fmt.Sprintf("%.3f", tree.LeftPred) + "\n"
	}
	if tree.Right == nil {
		returnString += spacing + "---> False" + "\n"
		returnString += " " + spacing + "PREDICT "
		returnString += fmt.Sprintf("%.3f", tree.RightPred) + "\n"
	}

	if tree.Left != nil {
		returnString += spacing + "---> True" + "\n"
		returnString += regressorPrintTreeFromNode(*tree.Left, spacing+" ")
	}

	if tree.Right != nil {
		returnString += spacing + "---> False" + "\n"
		returnString += regressorPrintTreeFromNode(*tree.Right, spacing+" ")
	}

	return returnString
}

// Predict a single data point by traversing the tree from the root node.
// Uses recursive logic.
func regressorPredictSingle(tree regressorNode, instance []float64) float64 {
	if instance[tree.Feature] < tree.Threshold {
		if tree.Left == nil {
			return tree.LeftPred
		} else {
			return regressorPredictSingle(*tree.Left, instance)
		}
	} else {
		if tree.Right == nil {
			return tree.RightPred
		} else {
			return regressorPredictSingle(*tree.Right, instance)
		}
	}
}

// Predict makes predictions for multiple data points.
// First converts the input data into a usable format, then calls regressorPredictFromNode.
func (tree *CARTDecisionTreeRegressor) Predict(X_test base.FixedDataGrid) []float64 {
	root := *tree.RootNode
	test := regressorConvertInstancesToProblemVec(X_test)
	return regressorPredictFromNode(root, test)
}

// Uses the tree's root node to predict for the entire test set.
// Iterates over all data points and calls regressorPredictSingle on each one.
func regressorPredictFromNode(tree regressorNode, test [][]float64) []float64 {
	var preds []float64
	for i := range test {
		iPred := regressorPredictSingle(tree, test[i])
		preds = append(preds, iPred)
	}
	return preds
}

// Helper function to convert base.FixedDataGrid into the required format. Called in Fit, Predict.
func regressorConvertInstancesToProblemVec(X base.FixedDataGrid) [][]float64 {
	// Allocate the problem array
	_, rows := X.Size()
	problemVec := make([][]float64, rows)

	// Retrieve the numeric non-class Attributes
	numericAttrs := base.NonClassFloatAttributes(X)
	numericAttrSpecs := base.ResolveAttributes(X, numericAttrs)

	// Convert each row
	X.MapOverRows(numericAttrSpecs, func(row [][]byte, rowNo int) (bool, error) {
		// Allocate a new row
		probRow := make([]float64, len(numericAttrSpecs))
		// Read out the row
		for i := range numericAttrSpecs {
			probRow[i] = base.UnpackBytesToFloat(row[i])
		}
		// Add the row
		problemVec[rowNo] = probRow
		return true, nil
	})
	return problemVec
}

// Helper function to convert base.FixedDataGrid into the required format. Called in Fit, Predict.
func regressorConvertInstancesToLabelVec(X base.FixedDataGrid) ([]float64, error) {
	// Get the class Attributes
	classAttrs := X.AllClassAttributes()
	// Only one class Attribute is supported
	if len(classAttrs) != 1 {
		return []float64{0}, fmt.Errorf("%d ClassAttributes (1 expected)", len(classAttrs))
	}
	// The ClassAttribute must be numeric
	if _, ok := classAttrs[0].(*base.FloatAttribute); !ok {
		return []float64{0}, fmt.Errorf("%s: ClassAttribute must be a FloatAttribute", classAttrs[0])
	}
	// Allocate the return structure
	_, rows := X.Size()

	labelVec := make([]float64, rows)
	// Resolve the class Attribute specification
	classAttrSpecs := base.ResolveAttributes(X, classAttrs)
	X.MapOverRows(classAttrSpecs, func(row [][]byte, rowNo int) (bool, error) {
		labelVec[rowNo] = base.UnpackBytesToFloat(row[0])
		return true, nil
	})
	return labelVec, nil
}
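To make the incremental-split optimisation concrete, here is a small sketch (the function exampleUpdateSplit is hypothetical, not part of the diff): once rows are sorted by the feature, raising the threshold only moves the affected rows from the right partition to the left, rather than re-splitting from scratch.

package trees

import "fmt"

// Hypothetical sketch - shows how regressorUpdateSplit advances a split
// incrementally as the candidate threshold grows.
func exampleUpdateSplit() {
	data := [][]float64{{0, 2}, {0, 5}, {0, 9}} // already sorted by feature 1
	y := []float64{1, 2, 3}
	left, right, leftY, rightY := regressorCreateSplit(data, 1, y, 2.5)
	fmt.Println(len(left), len(right)) // 1 2
	left, leftY, right, rightY = regressorUpdateSplit(left, leftY, right, rightY, 1, 7.5)
	fmt.Println(len(left), len(right)) // 2 1
	fmt.Println(leftY, rightY)         // [1 2] [3]
}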
104
trees/cart_test.go
Normal file
@@ -0,0 +1,104 @@
package trees

import (
	"testing"

	. "github.com/smartystreets/goconvey/convey"
)

func TestRegressor(t *testing.T) {

	Convey("Doing a CART Test", t, func() {
		// For Classification Trees:

		// Is Gini being calculated correctly
		gini, giniMaxLabel := computeGiniImpurityAndModeLabel([]int64{1, 0, 0, 1}, []int64{0, 1})
		So(gini, ShouldEqual, 0.5)
		So(giniMaxLabel, ShouldNotBeNil)

		// Is entropy being calculated correctly
		entropy, entropyMaxLabel := computeEntropyAndModeLabel([]int64{1, 0, 0, 1}, []int64{0, 1})
		So(entropy, ShouldEqual, 1.0)
		So(entropyMaxLabel, ShouldNotBeNil)

		// Is data being split into left and right properly
		classifierData := [][]float64{[]float64{1, 3, 6},
			[]float64{1, 2, 3},
			[]float64{1, 9, 6},
			[]float64{1, 11, 1}}

		classifiery := []int64{0, 1, 0, 0}

		leftdata, rightdata, lefty, righty := classifierCreateSplit(classifierData, 1, classifiery, 5.0)

		So(len(leftdata), ShouldEqual, 2)
		So(len(lefty), ShouldEqual, 2)
		So(len(rightdata), ShouldEqual, 2)
		So(len(righty), ShouldEqual, 2)

		// Is isolating unique values working properly
		So(len(findUnique([]float64{10, 1, 1})), ShouldEqual, 2)

		// Is data reordered correctly
		orderedData, orderedY := classifierReOrderData(getFeature(classifierData, 1), classifierData, classifiery)

		So(orderedData[1][1], ShouldEqual, 3.0)
		So(orderedY[0], ShouldEqual, 1)

		// Is the split being updated properly based on the threshold
		leftdata, lefty, rightdata, righty = classifierUpdateSplit(leftdata, lefty, rightdata, righty, 1, 9.5)
		So(len(leftdata), ShouldEqual, 3)
		So(len(rightdata), ShouldEqual, 1)

		// Is the root node nil when the tree is not trained?
		tree := NewDecisionTreeClassifier("gini", -1, []int64{0, 1})
		So(tree.RootNode, ShouldBeNil)
		So(tree.triedSplits, ShouldBeEmpty)

		// ------------------------------------------
		// For Regression Trees

		// Is MAE being calculated correctly
		mae, maeAverage := computeMaeImpurityAndAverage([]float64{1, 3, 5})
		So(mae, ShouldEqual, (4.0 / 3.0))
		So(maeAverage, ShouldNotBeNil)

		// Is MSE being calculated correctly
		mse, mseAverage := computeMseImpurityAndAverage([]float64{1, 3, 5})
		So(mse, ShouldEqual, (8.0 / 3.0))
		So(mseAverage, ShouldNotBeNil)

		// Is data being split into left and right properly
		data := [][]float64{[]float64{1, 3, 6},
			[]float64{1, 2, 3},
			[]float64{1, 9, 6},
			[]float64{1, 11, 1}}

		y := []float64{1, 2, 3, 4}

		leftData, rightData, leftY, rightY := regressorCreateSplit(data, 1, y, 5.0)

		So(len(leftData), ShouldEqual, 2)
		So(len(leftY), ShouldEqual, 2)
		So(len(rightData), ShouldEqual, 2)
		So(len(rightY), ShouldEqual, 2)

		// Is data reordered correctly
		regressorOrderedData, regressorOrderedY := regressorReOrderData(getFeature(data, 1), data, y)

		So(regressorOrderedData[1][1], ShouldEqual, 3.0)
		So(regressorOrderedY[0], ShouldEqual, 2)

		// Is the split being updated properly based on the threshold
		leftData, leftY, rightData, rightY = regressorUpdateSplit(leftData, leftY, rightData, rightY, 1, 9.5)
		So(len(leftData), ShouldEqual, 3)
		So(len(rightData), ShouldEqual, 1)

		// Is the root node nil when the tree is not trained?
		regressorTree := NewDecisionTreeRegressor("mae", -1)
		So(regressorTree.RootNode, ShouldBeNil)
		So(regressorTree.triedSplits, ShouldBeEmpty)
	})
}
65
trees/cart_utils.go
Normal file
@@ -0,0 +1,65 @@
package trees

import (
	"github.com/sjwhitworth/golearn/base"
)

// Isolate only the unique values. This way we try only unique splits, not redundant ones.
func findUnique(data []float64) []float64 {
	keys := make(map[float64]bool)
	unique := []float64{}
	for _, entry := range data {
		if _, value := keys[entry]; !value {
			keys[entry] = true
			unique = append(unique, entry)
		}
	}
	return unique
}

// Isolate only the feature being considered for splitting.
// Reduces the complexity of managing splits.
func getFeature(data [][]float64, feature int64) []float64 {
	var featureVals []float64
	for i := range data {
		featureVals = append(featureVals, data[i][feature])
	}
	return featureVals
}

// Make sure that the split being considered has not been tried before;
// otherwise we would needlessly retry splits that cannot improve the impurity.
// Returns true if the split is new.
func validate(triedSplits [][]float64, feature int64, threshold float64) bool {
	for i := range triedSplits {
		split := triedSplits[i]
		featureTried, thresholdTried := split[0], split[1]
		if int64(featureTried) == feature && thresholdTried == threshold {
			return false
		}
	}
	return true
}

// Helper function to convert base.FixedDataGrid into the required format. Called in Fit, Predict.
func convertInstancesToProblemVec(X base.FixedDataGrid) [][]float64 {
	// Allocate the problem array
	_, rows := X.Size()
	problemVec := make([][]float64, rows)

	// Retrieve the numeric non-class Attributes
	numericAttrs := base.NonClassFloatAttributes(X)
	numericAttrSpecs := base.ResolveAttributes(X, numericAttrs)

	// Convert each row
	X.MapOverRows(numericAttrSpecs, func(row [][]byte, rowNo int) (bool, error) {
		// Allocate a new row
		probRow := make([]float64, len(numericAttrSpecs))
		// Read out the row
		for i := range numericAttrSpecs {
			probRow[i] = base.UnpackBytesToFloat(row[i])
		}
		// Add the row
		problemVec[rowNo] = probRow
		return true, nil
	})
	return problemVec
}
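A short sketch (the function exampleThresholds is hypothetical, not part of the diff) of how getFeature and findUnique feed the split search: candidate thresholds are the midpoints between consecutive unique feature values, exactly the loop used in classifierBestSplit and regressorBestSplit.

package trees

import (
	"fmt"
	"sort"
)

// Hypothetical sketch - enumerates the candidate thresholds for one feature.
func exampleThresholds() {
	data := [][]float64{{10, 0}, {1, 0}, {1, 0}, {4, 0}}
	featureVal := getFeature(data, 0) // [10 1 1 4]
	unique := findUnique(featureVal)  // [10 1 4]
	sort.Float64s(unique)             // [1 4 10]
	for j := 0; j < len(unique)-1; j++ {
		fmt.Println((unique[j] + unique[j+1]) / 2) // prints 2.5, then 7
	}
}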
24
trees/sorter.go
Normal file
@@ -0,0 +1,24 @@
package trees

import (
	"sort"
)

// Slice attaches an index slice to sort.Float64Slice so that sorting
// also records the permutation of the original positions.
type Slice struct {
	sort.Float64Slice
	Idx []int
}

// Swap swaps both the values and their recorded original positions.
func (s Slice) Swap(i, j int) {
	s.Float64Slice.Swap(i, j)
	s.Idx[i], s.Idx[j] = s.Idx[j], s.Idx[i]
}

// NewSlice wraps n in a Slice with indexes initialised to 0..len(n)-1.
func NewSlice(n []float64) *Slice {
	s := &Slice{Float64Slice: sort.Float64Slice(n), Idx: make([]int, len(n))}

	for i := range s.Idx {
		s.Idx[i] = i
	}
	return s
}
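A minimal usage sketch (a hypothetical standalone program, not part of the diff) showing how NewSlice recovers the sort permutation that classifierReOrderData and regressorReOrderData rely on:

package main

import (
	"fmt"
	"sort"

	"github.com/sjwhitworth/golearn/trees"
)

func main() {
	feature := []float64{9, 2, 11, 3}
	s := trees.NewSlice(feature) // pairs each value with its original index
	sort.Sort(s)
	fmt.Println(feature) // [2 3 9 11] - sorted in place
	fmt.Println(s.Idx)   // [1 3 0 2]  - original positions in sorted order
}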
@@ -11,6 +11,14 @@
  present, so discretise beforehand (see
  filters)

CART (Classification and Regression Trees):
  Builds a binary decision tree with the CART algorithm,
  using a greedy approach to find the best split at each node.

  Can be used for regression and classification.
  Attributes have to be FloatAttributes even for classification;
  hence, convert to integer labels beforehand for classification.

RandomTree:
  Builds a decision tree using the ID3 algorithm
  by picking the Attribute amongst those