mirror of
https://github.com/sjwhitworth/golearn.git
synced 2025-04-26 13:49:14 +08:00
cela2 changed libliner.go and linear_regression.go
This commit is contained in:
parent
a8b69c276c
commit
e91c24e7b8
@ -30,11 +30,11 @@ const (
|
||||
L2R_LR_DUAL = C.L2R_LR_DUAL
|
||||
)
|
||||
|
||||
func NewParameter(solver_type int, C float64, eps float64) *Parameter {
|
||||
func NewParameter(solver_type int, f float64, eps float64) *Parameter {
|
||||
param := Parameter{}
|
||||
param.c_param.solver_type = C.int(solver_type)
|
||||
param.c_param.eps = C.double(eps)
|
||||
param.c_param.C = C.double(C)
|
||||
param.c_param.C = C.double(f)
|
||||
param.c_param.nr_weight = C.int(0)
|
||||
param.c_param.weight_label = nil
|
||||
param.c_param.weight = nil
|
||||
|
@ -6,6 +6,7 @@ import (
|
||||
"github.com/sjwhitworth/golearn/base"
|
||||
|
||||
"fmt"
|
||||
|
||||
_ "github.com/gonum/blas"
|
||||
"gonum.org/v1/gonum/mat"
|
||||
)
|
||||
@ -17,10 +18,10 @@ var (
|
||||
|
||||
type LinearRegression struct {
|
||||
fitted bool
|
||||
disturbance float64
|
||||
regressionCoefficients []float64
|
||||
attrs []base.Attribute
|
||||
cls base.Attribute
|
||||
Disturbance float64
|
||||
RegressionCoefficients []float64
|
||||
Attrs []base.Attribute
|
||||
Cls base.Attribute
|
||||
}
|
||||
|
||||
func NewLinearRegression() *LinearRegression {
|
||||
@ -99,11 +100,11 @@ func (lr *LinearRegression) Fit(inst base.FixedDataGrid) error {
|
||||
regressionCoefficients[i] /= reg.At(i, i)
|
||||
}
|
||||
|
||||
lr.disturbance = regressionCoefficients[0]
|
||||
lr.regressionCoefficients = regressionCoefficients[1:]
|
||||
lr.Disturbance = regressionCoefficients[0]
|
||||
lr.RegressionCoefficients = regressionCoefficients[1:]
|
||||
lr.fitted = true
|
||||
lr.attrs = attrs
|
||||
lr.cls = classAttrs[0]
|
||||
lr.Attrs = attrs
|
||||
lr.Cls = classAttrs[0]
|
||||
return nil
|
||||
}
|
||||
|
||||
@ -113,16 +114,16 @@ func (lr *LinearRegression) Predict(X base.FixedDataGrid) (base.FixedDataGrid, e
|
||||
}
|
||||
|
||||
ret := base.GeneratePredictionVector(X)
|
||||
attrSpecs := base.ResolveAttributes(X, lr.attrs)
|
||||
clsSpec, err := ret.GetAttribute(lr.cls)
|
||||
attrSpecs := base.ResolveAttributes(X, lr.Attrs)
|
||||
clsSpec, err := ret.GetAttribute(lr.Cls)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
X.MapOverRows(attrSpecs, func(row [][]byte, i int) (bool, error) {
|
||||
var prediction float64 = lr.disturbance
|
||||
var prediction float64 = lr.Disturbance
|
||||
for j, r := range row {
|
||||
prediction += base.UnpackBytesToFloat(r) * lr.regressionCoefficients[j]
|
||||
prediction += base.UnpackBytesToFloat(r) * lr.RegressionCoefficients[j]
|
||||
}
|
||||
|
||||
ret.Set(clsSpec, i, base.PackFloatToBytes(prediction))
|
||||
|
Loading…
x
Reference in New Issue
Block a user