
Merge pull request #52 from njern/linear_regression

Linear regression
This commit is contained in:
Stephen Whitworth 2014-07-19 09:30:58 -07:00
commit 222b0ab33d
7 changed files with 192 additions and 2 deletions

View File

@@ -98,7 +98,7 @@ func ParseCSVSniffAttributeTypes(filepath string, hasHeaders bool) []Attribute {
 	for _, entry := range columns {
 		entry = strings.Trim(entry, " ")
 		matched, err := regexp.MatchString("^[-+]?[0-9]*\\.?[0-9]+([eE][-+]?[0-9]+)?$", entry)
-		fmt.Println(entry, matched)
+		//fmt.Println(entry, matched)
 		if err != nil {
 			panic(err)
 		}
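
For reference, a minimal standalone sketch of the numeric-sniffing check this hunk touches: the same regular expression applied to a few hypothetical entries. This is only an illustration, not the golearn implementation itself.

package main

import (
	"fmt"
	"regexp"
)

func main() {
	// The float-detection pattern from ParseCSVSniffAttributeTypes.
	numeric := regexp.MustCompile(`^[-+]?[0-9]*\.?[0-9]+([eE][-+]?[0-9]+)?$`)
	// Hypothetical sample entries: matches sniff as numeric attributes, the rest as categorical.
	for _, entry := range []string{"73", "-1.5e3", "3.14", "EXAM1"} {
		fmt.Println(entry, numeric.MatchString(entry))
	}
}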

View File

@@ -0,0 +1,2 @@
EXAM1,EXAM2,EXAM3,FINAL
73,80,75,152

View File

@@ -0,0 +1,26 @@
EXAM1,EXAM2,EXAM3,FINAL
73,80,75,152
93,88,93,185
89,91,90,180
96,98,100,196
73,66,70,142
53,46,55,101
69,74,77,149
47,56,60,115
87,79,90,175
79,70,88,164
69,70,73,141
70,65,74,141
93,95,91,184
79,80,73,152
70,73,78,148
93,89,96,192
78,75,68,147
81,90,93,183
88,92,86,177
78,83,77,159
82,86,90,177
86,82,89,175
78,83,85,175
76,83,71,149
96,93,95,192

linear_models/doc.go Normal file
View File

@@ -0,0 +1,5 @@
/*
Package linear_models implements linear
and logistic regression models.
*/
package linear_models
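
For context, a rough end-to-end use of the API this pull request introduces, mirroring TestLinearRegression further down the diff; the import path for the new package and the dataset location are assumptions based on the repository layout.

package main

import (
	"fmt"

	"github.com/sjwhitworth/golearn/base"
	"github.com/sjwhitworth/golearn/linear_models"
)

func main() {
	// Load the exam dataset added in this PR (path assumed relative to the repo root).
	rawData, err := base.ParseCSVToInstances("examples/datasets/exams.csv", true)
	if err != nil {
		panic(err)
	}

	// Hold out 10% of rows for testing, as in TestLinearRegression.
	trainData, testData := base.InstancesTrainTestSplit(rawData, 0.1)

	lr := linear_models.NewLinearRegression()
	if err := lr.Fit(trainData); err != nil {
		panic(err)
	}

	predictions, err := lr.Predict(testData)
	if err != nil {
		panic(err)
	}
	for i := 0; i < predictions.Rows; i++ {
		fmt.Printf("Expected: %f || Predicted: %f\n",
			testData.Get(i, testData.ClassIndex), predictions.Get(i, predictions.ClassIndex))
	}
}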

View File

@@ -0,0 +1,98 @@
package linear_models

import (
	"errors"
	"github.com/sjwhitworth/golearn/base"
	_ "github.com/gonum/blas"
	"github.com/gonum/blas/cblas"
	"github.com/gonum/matrix/mat64"
)

var (
	NotEnoughDataError  = errors.New("not enough rows to support this many variables.")
	NoTrainingDataError = errors.New("you need to Fit() before you can Predict()")
)

type LinearRegression struct {
	fitted                 bool
	disturbance            float64
	regressionCoefficients []float64
}

func init() {
	mat64.Register(cblas.Blas{})
}

func NewLinearRegression() *LinearRegression {
	return &LinearRegression{fitted: false}
}

func (lr *LinearRegression) Fit(inst *base.Instances) error {
	if inst.Rows < inst.GetAttributeCount() {
		return NotEnoughDataError
	}
	// Split into two matrices, observed results (dependent variable y)
	// and the explanatory variables (X) - see http://en.wikipedia.org/wiki/Linear_regression
	observed := mat64.NewDense(inst.Rows, 1, nil)
	explVariables := mat64.NewDense(inst.Rows, inst.GetAttributeCount(), nil)
	for i := 0; i < inst.Rows; i++ {
		observed.Set(i, 0, inst.Get(i, inst.ClassIndex)) // Set observed data
		for j := 0; j < inst.GetAttributeCount(); j++ {
			if j == 0 {
				// Set intercepts to 1.0
				// Could / should be done better: http://www.theanalysisfactor.com/interpret-the-intercept/
				explVariables.Set(i, 0, 1.0)
			} else {
				explVariables.Set(i, j, inst.Get(i, j-1))
			}
		}
	}
	n := inst.GetAttributeCount()
	qr := mat64.QR(explVariables)
	q := qr.Q()
	reg := qr.R()
	var transposed, qty mat64.Dense
	transposed.TCopy(q)
	qty.Mul(&transposed, observed)
	regressionCoefficients := make([]float64, n)
	for i := n - 1; i >= 0; i-- {
		regressionCoefficients[i] = qty.At(i, 0)
		for j := i + 1; j < n; j++ {
			regressionCoefficients[i] -= regressionCoefficients[j] * reg.At(i, j)
		}
		regressionCoefficients[i] /= reg.At(i, i)
	}
	lr.disturbance = regressionCoefficients[0]
	lr.regressionCoefficients = regressionCoefficients[1:]
	lr.fitted = true
	return nil
}

func (lr *LinearRegression) Predict(X *base.Instances) (*base.Instances, error) {
	if !lr.fitted {
		return nil, NoTrainingDataError
	}
	ret := X.GeneratePredictionVector()
	for i := 0; i < X.Rows; i++ {
		var prediction float64 = lr.disturbance
		for j := 0; j < X.Cols; j++ {
			if j != X.ClassIndex {
				prediction += X.Get(i, j) * lr.regressionCoefficients[j]
			}
		}
		ret.Set(i, 0, prediction)
	}
	return ret, nil
}
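
Fit computes the least-squares coefficients via QR: with X = QR (Q orthogonal, R upper-triangular), the normal equations reduce to R*beta = Qᵀy, which the final loop above recovers by back-substitution. A self-contained sketch of that step on hypothetical 3x3 values, using plain slices with no gonum dependency:

package main

import "fmt"

func main() {
	// Upper-triangular R and the vector Qᵀy (hypothetical values chosen so beta = [1 2 3]).
	r := [][]float64{
		{2, 1, 1},
		{0, 3, 2},
		{0, 0, 4},
	}
	qty := []float64{7, 12, 12}

	// Back-substitution, as in LinearRegression.Fit: solve R*beta = Qᵀy from the last row upwards.
	n := len(qty)
	beta := make([]float64, n)
	for i := n - 1; i >= 0; i-- {
		beta[i] = qty[i]
		for j := i + 1; j < n; j++ {
			beta[i] -= beta[j] * r[i][j]
		}
		beta[i] /= r[i][i]
	}

	fmt.Println(beta) // [1 2 3]
}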

View File

@@ -0,0 +1,60 @@
package linear_models

import (
	"fmt"
	"testing"
	"github.com/sjwhitworth/golearn/base"
)

func TestNoTrainingData(t *testing.T) {
	lr := NewLinearRegression()
	rawData, err := base.ParseCSVToInstances("../examples/datasets/exams.csv", true)
	if err != nil {
		t.Fatal(err)
	}
	_, err = lr.Predict(rawData)
	if err != NoTrainingDataError {
		t.Fatal("failed to error out even if no training data exists")
	}
}

func TestNotEnoughTrainingData(t *testing.T) {
	lr := NewLinearRegression()
	rawData, err := base.ParseCSVToInstances("../examples/datasets/exam.csv", true)
	if err != nil {
		t.Fatal(err)
	}
	err = lr.Fit(rawData)
	if err != NotEnoughDataError {
		t.Fatal("failed to error out even though there was not enough data")
	}
}

func TestLinearRegression(t *testing.T) {
	lr := NewLinearRegression()
	rawData, err := base.ParseCSVToInstances("../examples/datasets/exams.csv", true)
	if err != nil {
		t.Fatal(err)
	}
	trainData, testData := base.InstancesTrainTestSplit(rawData, 0.1)
	err = lr.Fit(trainData)
	if err != nil {
		t.Fatal(err)
	}
	predictions, err := lr.Predict(testData)
	if err != nil {
		t.Fatal(err)
	}
	for i := 0; i < predictions.Rows; i++ {
		fmt.Printf("Expected: %f || Predicted: %f\n", testData.Get(i, testData.ClassIndex), predictions.Get(i, predictions.ClassIndex))
	}
}

View File

@@ -1 +0,0 @@
package lm