1
0
mirror of https://github.com/sjwhitworth/golearn.git synced 2025-04-26 13:49:14 +08:00

cela2 changed libliner.go and linear_regression.go

This commit is contained in:
cela2 2022-05-26 21:49:38 +08:00
parent a8b69c276c
commit e91c24e7b8
2 changed files with 15 additions and 14 deletions

View File

@ -30,11 +30,11 @@ const (
L2R_LR_DUAL = C.L2R_LR_DUAL
)
func NewParameter(solver_type int, C float64, eps float64) *Parameter {
func NewParameter(solver_type int, f float64, eps float64) *Parameter {
param := Parameter{}
param.c_param.solver_type = C.int(solver_type)
param.c_param.eps = C.double(eps)
param.c_param.C = C.double(C)
param.c_param.C = C.double(f)
param.c_param.nr_weight = C.int(0)
param.c_param.weight_label = nil
param.c_param.weight = nil

View File

@ -6,6 +6,7 @@ import (
"github.com/sjwhitworth/golearn/base"
"fmt"
_ "github.com/gonum/blas"
"gonum.org/v1/gonum/mat"
)
@ -17,10 +18,10 @@ var (
type LinearRegression struct {
fitted bool
disturbance float64
regressionCoefficients []float64
attrs []base.Attribute
cls base.Attribute
Disturbance float64
RegressionCoefficients []float64
Attrs []base.Attribute
Cls base.Attribute
}
func NewLinearRegression() *LinearRegression {
@ -99,11 +100,11 @@ func (lr *LinearRegression) Fit(inst base.FixedDataGrid) error {
regressionCoefficients[i] /= reg.At(i, i)
}
lr.disturbance = regressionCoefficients[0]
lr.regressionCoefficients = regressionCoefficients[1:]
lr.Disturbance = regressionCoefficients[0]
lr.RegressionCoefficients = regressionCoefficients[1:]
lr.fitted = true
lr.attrs = attrs
lr.cls = classAttrs[0]
lr.Attrs = attrs
lr.Cls = classAttrs[0]
return nil
}
@ -113,16 +114,16 @@ func (lr *LinearRegression) Predict(X base.FixedDataGrid) (base.FixedDataGrid, e
}
ret := base.GeneratePredictionVector(X)
attrSpecs := base.ResolveAttributes(X, lr.attrs)
clsSpec, err := ret.GetAttribute(lr.cls)
attrSpecs := base.ResolveAttributes(X, lr.Attrs)
clsSpec, err := ret.GetAttribute(lr.Cls)
if err != nil {
return nil, err
}
X.MapOverRows(attrSpecs, func(row [][]byte, i int) (bool, error) {
var prediction float64 = lr.disturbance
var prediction float64 = lr.Disturbance
for j, r := range row {
prediction += base.UnpackBytesToFloat(r) * lr.regressionCoefficients[j]
prediction += base.UnpackBytesToFloat(r) * lr.RegressionCoefficients[j]
}
ret.Set(clsSpec, i, base.PackFloatToBytes(prediction))