// golearn/linear_models/linear_regression.go
package linear_models

import (
    "errors"
    "fmt"

    _ "github.com/gonum/blas"
    "github.com/gonum/matrix/mat64"

    "github.com/sjwhitworth/golearn/base"
)

var (
    NotEnoughDataError  = errors.New("not enough rows to support this many variables.")
    NoTrainingDataError = errors.New("you need to Fit() before you can Predict()")
)

// LinearRegression fits an ordinary least-squares model using a QR decomposition.
type LinearRegression struct {
    fitted                 bool
    disturbance            float64
    regressionCoefficients []float64
    attrs                  []base.Attribute
    cls                    base.Attribute
}

// NewLinearRegression returns an unfitted LinearRegression.
func NewLinearRegression() *LinearRegression {
    return &LinearRegression{fitted: false}
}

// Fit estimates the regression coefficients from the FloatAttributes in inst
// against its single class Attribute.
func (lr *LinearRegression) Fit(inst base.FixedDataGrid) error {
    // Retrieve the row count
    _, rows := inst.Size()

    // Validate the class Attribute count
    classAttrs := inst.AllClassAttributes()
    if len(classAttrs) != 1 {
        return fmt.Errorf("Only 1 class variable is permitted")
    }
    classAttrSpecs := base.ResolveAttributes(inst, classAttrs)

    // Retrieve the relevant Attributes (only FloatAttributes are used)
    allAttrs := base.NonClassAttributes(inst)
    attrs := make([]base.Attribute, 0)
    for _, a := range allAttrs {
        if _, ok := a.(*base.FloatAttribute); ok {
            attrs = append(attrs, a)
        }
    }

    cols := len(attrs) + 1
    if rows < cols {
        return NotEnoughDataError
    }

    // Retrieve the relevant Attribute specifications
    attrSpecs := base.ResolveAttributes(inst, attrs)

    // Split into two matrices: the observed results (dependent variable y)
    // and the explanatory variables (X) - see http://en.wikipedia.org/wiki/Linear_regression
    observed := mat64.NewDense(rows, 1, nil)
    explVariables := mat64.NewDense(rows, cols, nil)

    // Build the observed matrix
    inst.MapOverRows(classAttrSpecs, func(row [][]byte, i int) (bool, error) {
        val := base.UnpackBytesToFloat(row[0])
        observed.Set(i, 0, val)
        return true, nil
    })

    // Build the explanatory variables matrix
    inst.MapOverRows(attrSpecs, func(row [][]byte, i int) (bool, error) {
        // Set the intercept column to 1.0
        explVariables.Set(i, 0, 1.0)
        for j, r := range row {
            explVariables.Set(i, j+1, base.UnpackBytesToFloat(r))
        }
        return true, nil
    })
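
    // With X = QR (Q orthonormal, R upper triangular), the least-squares
    // coefficients b satisfy R*b = Q^T * y, which is solved below by
    // back-substitution.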
    n := cols
    qr := mat64.QR(explVariables)
    q := qr.Q()
    reg := qr.R()

    var transposed, qty mat64.Dense
    transposed.TCopy(q)
    qty.Mul(&transposed, observed)

    regressionCoefficients := make([]float64, n)
    for i := n - 1; i >= 0; i-- {
        regressionCoefficients[i] = qty.At(i, 0)
        for j := i + 1; j < n; j++ {
            regressionCoefficients[i] -= regressionCoefficients[j] * reg.At(i, j)
        }
        regressionCoefficients[i] /= reg.At(i, i)
    }

    // The coefficient for the leading column of ones is the intercept (disturbance)
    lr.disturbance = regressionCoefficients[0]
    lr.regressionCoefficients = regressionCoefficients[1:]
    lr.fitted = true
    lr.attrs = attrs
    lr.cls = classAttrs[0]
    return nil
}

// Predict returns a prediction vector for the rows of X, computed from the
// fitted coefficients.
func (lr *LinearRegression) Predict(X base.FixedDataGrid) (base.FixedDataGrid, error) {
    if !lr.fitted {
        return nil, NoTrainingDataError
    }

    ret := base.GeneratePredictionVector(X)
    attrSpecs := base.ResolveAttributes(X, lr.attrs)
    clsSpec, err := ret.GetAttribute(lr.cls)
    if err != nil {
        return nil, err
    }

    X.MapOverRows(attrSpecs, func(row [][]byte, i int) (bool, error) {
        // Start from the intercept, then add each weighted explanatory variable
        prediction := lr.disturbance
        for j, r := range row {
            prediction += base.UnpackBytesToFloat(r) * lr.regressionCoefficients[j]
        }
        ret.Set(clsSpec, i, base.PackFloatToBytes(prediction))
        return true, nil
    })

    return ret, nil
}
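
// Example usage (a minimal sketch, not part of this package): the CSV file
// name and column layout are hypothetical, and it assumes golearn's
// base.ParseCSVToInstances and base.InstancesTrainTestSplit helpers.
//
//     data, err := base.ParseCSVToInstances("house_prices.csv", true)
//     if err != nil {
//         panic(err)
//     }
//     trainData, testData := base.InstancesTrainTestSplit(data, 0.5)
//
//     lr := NewLinearRegression()
//     if err := lr.Fit(trainData); err != nil {
//         panic(err)
//     }
//     predictions, err := lr.Predict(testData)
//     if err != nil {
//         panic(err)
//     }
//     fmt.Println(predictions)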