mirror of
https://github.com/sjwhitworth/golearn.git
synced 2025-04-26 13:49:14 +08:00
commit
222b0ab33d
@ -98,7 +98,7 @@ func ParseCSVSniffAttributeTypes(filepath string, hasHeaders bool) []Attribute {
|
||||
for _, entry := range columns {
|
||||
entry = strings.Trim(entry, " ")
|
||||
matched, err := regexp.MatchString("^[-+]?[0-9]*\\.?[0-9]+([eE][-+]?[0-9]+)?$", entry)
|
||||
fmt.Println(entry, matched)
|
||||
//fmt.Println(entry, matched)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
2
examples/datasets/exam.csv
Normal file
2
examples/datasets/exam.csv
Normal file
@ -0,0 +1,2 @@
|
||||
EXAM1,EXAM2,EXAM3,FINAL
|
||||
73,80,75,152
|
|
26
examples/datasets/exams.csv
Normal file
26
examples/datasets/exams.csv
Normal file
@ -0,0 +1,26 @@
|
||||
EXAM1,EXAM2,EXAM3,FINAL
|
||||
73,80,75,152
|
||||
93,88,93,185
|
||||
89,91,90,180
|
||||
96,98,100,196
|
||||
73,66,70,142
|
||||
53,46,55,101
|
||||
69,74,77,149
|
||||
47,56,60,115
|
||||
87,79,90,175
|
||||
79,70,88,164
|
||||
69,70,73,141
|
||||
70,65,74,141
|
||||
93,95,91,184
|
||||
79,80,73,152
|
||||
70,73,78,148
|
||||
93,89,96,192
|
||||
78,75,68,147
|
||||
81,90,93,183
|
||||
88,92,86,177
|
||||
78,83,77,159
|
||||
82,86,90,177
|
||||
86,82,89,175
|
||||
78,83,85,175
|
||||
76,83,71,149
|
||||
96,93,95,192
|
|
5
linear_models/doc.go
Normal file
5
linear_models/doc.go
Normal file
@ -0,0 +1,5 @@
|
||||
/*
|
||||
Package linear_models implements linear
|
||||
and logistic regression models.
|
||||
*/
|
||||
package linear_models
|
98
linear_models/linear_regression.go
Normal file
98
linear_models/linear_regression.go
Normal file
@ -0,0 +1,98 @@
|
||||
package linear_models
|
||||
|
||||
import (
|
||||
"errors"
|
||||
|
||||
"github.com/sjwhitworth/golearn/base"
|
||||
|
||||
_ "github.com/gonum/blas"
|
||||
"github.com/gonum/blas/cblas"
|
||||
"github.com/gonum/matrix/mat64"
|
||||
)
|
||||
|
||||
var (
|
||||
NotEnoughDataError = errors.New("not enough rows to support this many variables.")
|
||||
NoTrainingDataError = errors.New("you need to Fit() before you can Predict()")
|
||||
)
|
||||
|
||||
type LinearRegression struct {
|
||||
fitted bool
|
||||
disturbance float64
|
||||
regressionCoefficients []float64
|
||||
}
|
||||
|
||||
func init() {
|
||||
mat64.Register(cblas.Blas{})
|
||||
}
|
||||
|
||||
func NewLinearRegression() *LinearRegression {
|
||||
return &LinearRegression{fitted: false}
|
||||
}
|
||||
|
||||
func (lr *LinearRegression) Fit(inst *base.Instances) error {
|
||||
if inst.Rows < inst.GetAttributeCount() {
|
||||
return NotEnoughDataError
|
||||
}
|
||||
|
||||
// Split into two matrices, observed results (dependent variable y)
|
||||
// and the explanatory variables (X) - see http://en.wikipedia.org/wiki/Linear_regression
|
||||
observed := mat64.NewDense(inst.Rows, 1, nil)
|
||||
explVariables := mat64.NewDense(inst.Rows, inst.GetAttributeCount(), nil)
|
||||
|
||||
for i := 0; i < inst.Rows; i++ {
|
||||
observed.Set(i, 0, inst.Get(i, inst.ClassIndex)) // Set observed data
|
||||
|
||||
for j := 0; j < inst.GetAttributeCount(); j++ {
|
||||
if j == 0 {
|
||||
// Set intercepts to 1.0
|
||||
// Could / should be done better: http://www.theanalysisfactor.com/interpret-the-intercept/
|
||||
explVariables.Set(i, 0, 1.0)
|
||||
} else {
|
||||
explVariables.Set(i, j, inst.Get(i, j-1))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
n := inst.GetAttributeCount()
|
||||
qr := mat64.QR(explVariables)
|
||||
q := qr.Q()
|
||||
reg := qr.R()
|
||||
|
||||
var transposed, qty mat64.Dense
|
||||
transposed.TCopy(q)
|
||||
qty.Mul(&transposed, observed)
|
||||
|
||||
regressionCoefficients := make([]float64, n)
|
||||
for i := n - 1; i >= 0; i-- {
|
||||
regressionCoefficients[i] = qty.At(i, 0)
|
||||
for j := i + 1; j < n; j++ {
|
||||
regressionCoefficients[i] -= regressionCoefficients[j] * reg.At(i, j)
|
||||
}
|
||||
regressionCoefficients[i] /= reg.At(i, i)
|
||||
}
|
||||
|
||||
lr.disturbance = regressionCoefficients[0]
|
||||
lr.regressionCoefficients = regressionCoefficients[1:]
|
||||
lr.fitted = true
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (lr *LinearRegression) Predict(X *base.Instances) (*base.Instances, error) {
|
||||
if !lr.fitted {
|
||||
return nil, NoTrainingDataError
|
||||
}
|
||||
|
||||
ret := X.GeneratePredictionVector()
|
||||
for i := 0; i < X.Rows; i++ {
|
||||
var prediction float64 = lr.disturbance
|
||||
for j := 0; j < X.Cols; j++ {
|
||||
if j != X.ClassIndex {
|
||||
prediction += X.Get(i, j) * lr.regressionCoefficients[j]
|
||||
}
|
||||
}
|
||||
ret.Set(i, 0, prediction)
|
||||
}
|
||||
|
||||
return ret, nil
|
||||
}
|
60
linear_models/linear_regression_test.go
Normal file
60
linear_models/linear_regression_test.go
Normal file
@ -0,0 +1,60 @@
|
||||
package linear_models
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"testing"
|
||||
|
||||
"github.com/sjwhitworth/golearn/base"
|
||||
)
|
||||
|
||||
func TestNoTrainingData(t *testing.T) {
|
||||
lr := NewLinearRegression()
|
||||
|
||||
rawData, err := base.ParseCSVToInstances("../examples/datasets/exams.csv", true)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
_, err = lr.Predict(rawData)
|
||||
if err != NoTrainingDataError {
|
||||
t.Fatal("failed to error out even if no training data exists")
|
||||
}
|
||||
}
|
||||
|
||||
func TestNotEnoughTrainingData(t *testing.T) {
|
||||
lr := NewLinearRegression()
|
||||
|
||||
rawData, err := base.ParseCSVToInstances("../examples/datasets/exam.csv", true)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
err = lr.Fit(rawData)
|
||||
if err != NotEnoughDataError {
|
||||
t.Fatal("failed to error out even though there was not enough data")
|
||||
}
|
||||
}
|
||||
|
||||
func TestLinearRegression(t *testing.T) {
|
||||
lr := NewLinearRegression()
|
||||
|
||||
rawData, err := base.ParseCSVToInstances("../examples/datasets/exams.csv", true)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
trainData, testData := base.InstancesTrainTestSplit(rawData, 0.1)
|
||||
err = lr.Fit(trainData)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
predictions, err := lr.Predict(testData)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
for i := 0; i < predictions.Rows; i++ {
|
||||
fmt.Printf("Expected: %f || Predicted: %f\n", testData.Get(i, testData.ClassIndex), predictions.Get(i, predictions.ClassIndex))
|
||||
}
|
||||
}
|
@ -1 +0,0 @@
|
||||
package lm
|
Loading…
x
Reference in New Issue
Block a user