From e91c24e7b83107e3ed08ab35c97ac98c43ddb713 Mon Sep 17 00:00:00 2001 From: cela2 <1306089453@qq.com> Date: Thu, 26 May 2022 21:49:38 +0800 Subject: [PATCH] cela2 changed libliner.go and linear_regression.go --- linear_models/liblinear.go | 4 ++-- linear_models/linear_regression.go | 25 +++++++++++++------------ 2 files changed, 15 insertions(+), 14 deletions(-) diff --git a/linear_models/liblinear.go b/linear_models/liblinear.go index 47e7456..4319d1b 100644 --- a/linear_models/liblinear.go +++ b/linear_models/liblinear.go @@ -30,11 +30,11 @@ const ( L2R_LR_DUAL = C.L2R_LR_DUAL ) -func NewParameter(solver_type int, C float64, eps float64) *Parameter { +func NewParameter(solver_type int, f float64, eps float64) *Parameter { param := Parameter{} param.c_param.solver_type = C.int(solver_type) param.c_param.eps = C.double(eps) - param.c_param.C = C.double(C) + param.c_param.C = C.double(f) param.c_param.nr_weight = C.int(0) param.c_param.weight_label = nil param.c_param.weight = nil diff --git a/linear_models/linear_regression.go b/linear_models/linear_regression.go index eb52c4e..fe3daf7 100644 --- a/linear_models/linear_regression.go +++ b/linear_models/linear_regression.go @@ -6,6 +6,7 @@ import ( "github.com/sjwhitworth/golearn/base" "fmt" + _ "github.com/gonum/blas" "gonum.org/v1/gonum/mat" ) @@ -17,10 +18,10 @@ var ( type LinearRegression struct { fitted bool - disturbance float64 - regressionCoefficients []float64 - attrs []base.Attribute - cls base.Attribute + Disturbance float64 + RegressionCoefficients []float64 + Attrs []base.Attribute + Cls base.Attribute } func NewLinearRegression() *LinearRegression { @@ -99,11 +100,11 @@ func (lr *LinearRegression) Fit(inst base.FixedDataGrid) error { regressionCoefficients[i] /= reg.At(i, i) } - lr.disturbance = regressionCoefficients[0] - lr.regressionCoefficients = regressionCoefficients[1:] + lr.Disturbance = regressionCoefficients[0] + lr.RegressionCoefficients = regressionCoefficients[1:] lr.fitted = true - lr.attrs = attrs - lr.cls = classAttrs[0] + lr.Attrs = attrs + lr.Cls = classAttrs[0] return nil } @@ -113,16 +114,16 @@ func (lr *LinearRegression) Predict(X base.FixedDataGrid) (base.FixedDataGrid, e } ret := base.GeneratePredictionVector(X) - attrSpecs := base.ResolveAttributes(X, lr.attrs) - clsSpec, err := ret.GetAttribute(lr.cls) + attrSpecs := base.ResolveAttributes(X, lr.Attrs) + clsSpec, err := ret.GetAttribute(lr.Cls) if err != nil { return nil, err } X.MapOverRows(attrSpecs, func(row [][]byte, i int) (bool, error) { - var prediction float64 = lr.disturbance + var prediction float64 = lr.Disturbance for j, r := range row { - prediction += base.UnpackBytesToFloat(r) * lr.regressionCoefficients[j] + prediction += base.UnpackBytesToFloat(r) * lr.RegressionCoefficients[j] } ret.Set(clsSpec, i, base.PackFloatToBytes(prediction))