mirror of
https://github.com/sjwhitworth/golearn.git
synced 2025-04-25 13:48:49 +08:00
meta: tests passing
This commit is contained in:
parent
768d2cd19f
commit
c18d50d217
@ -293,10 +293,6 @@ func SerializeInstancesToTarWriter(inst FixedDataGrid, tw *tar.Writer, prefix st
|
||||
return fmt.Errorf("Could not write ATTRS: %s", err)
|
||||
}
|
||||
|
||||
if !includeData {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Data must be written out in the same order as the Attributes
|
||||
allAttrs := make([]Attribute, attrCount)
|
||||
normCount := copy(allAttrs, normalAttrs)
|
||||
@ -323,6 +319,11 @@ func SerializeInstancesToTarWriter(inst FixedDataGrid, tw *tar.Writer, prefix st
|
||||
if err := tw.WriteHeader(hdr); err != nil {
|
||||
return fmt.Errorf("Could not write DATA: %s", err)
|
||||
}
|
||||
tw.Flush()
|
||||
|
||||
if !includeData {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Then write the actual data
|
||||
writtenLength := int64(0)
|
||||
|
@ -4,6 +4,8 @@ package linear_models
|
||||
#include <stdlib.h>
#include "linear.h"
|
||||
*/
|
||||
import "C"
|
||||
import "fmt"
|
||||
import "unsafe"
|
||||
|
||||
type Problem struct {
|
||||
c_prob C.struct_problem
|
||||
@ -14,7 +16,7 @@ type Parameter struct {
|
||||
}
|
||||
|
||||
type Model struct {
|
||||
c_model *C.struct_model
|
||||
c_model unsafe.Pointer
|
||||
}
|
||||
|
||||
const (
|
||||
@ -58,12 +60,30 @@ func NewProblem(X [][]float64, y []float64, bias float64) *Problem {
|
||||
|
||||
func Train(prob *Problem, param *Parameter) *Model {
|
||||
libLinearHookPrintFunc() // Sets up logging
|
||||
return &Model{C.train(&prob.c_prob, &param.c_param)}
|
||||
tmpCProb := &prob.c_prob
|
||||
tmpCParam := &param.c_param
|
||||
return &Model{unsafe.Pointer(C.train(tmpCProb, tmpCParam))}
|
||||
}
|
||||
|
||||
func Export(model *Model, filePath string) error {
|
||||
status := C.save_model(C.CString(filePath), (*C.struct_model)(model.c_model))
|
||||
if status != 0 {
|
||||
return fmt.Errorf("Problem occured during export to %s (status was %d)", filePath, status)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func Load(model *Model, filePath string) error {
|
||||
model.c_model = unsafe.Pointer(C.load_model(C.CString(filePath)))
|
||||
if model.c_model == nil {
|
||||
return fmt.Errorf("Something went wrong")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func Predict(model *Model, x []float64) float64 {
|
||||
c_x := convert_vector(x, 0)
|
||||
c_y := C.predict(model.c_model, c_x)
|
||||
c_y := C.predict((*C.struct_model)(model.c_model), c_x)
|
||||
y := float64(c_y)
|
||||
return y
|
||||
}
|
||||
|
@ -1,9 +1,9 @@
|
||||
package linear_models
|
||||
|
||||
import (
|
||||
"github.com/sjwhitworth/golearn/base"
|
||||
. "github.com/smartystreets/goconvey/convey"
|
||||
"testing"
|
||||
"github.com/sjwhitworth/golearn/base"
|
||||
. "github.com/smartystreets/goconvey/convey"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestLogisticRegression(t *testing.T) {
|
||||
|
@ -6,6 +6,9 @@ import (
|
||||
"fmt"
|
||||
"github.com/sjwhitworth/golearn/base"
|
||||
"unsafe"
|
||||
"encoding/json"
|
||||
"io/ioutil"
|
||||
"os"
|
||||
)
|
||||
|
||||
// LinearSVCParams represents all available LinearSVC options.
|
||||
@ -211,6 +214,119 @@ func (lr *LinearSVC) Predict(X base.FixedDataGrid) (base.FixedDataGrid, error) {
|
||||
return ret, nil
|
||||
}
|
||||
|
||||
func (lr *LinearSVC) GetMetadata() base.ClassifierMetadataV1 {
|
||||
|
||||
params, err := json.Marshal(lr.Param)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
classifierParams := make(map[string]interface{})
|
||||
err = json.Unmarshal(params, &classifierParams)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
return base.ClassifierMetadataV1{
|
||||
FormatVersion: 1,
|
||||
ClassifierName: "LinearSVC",
|
||||
ClassifierVersion: "1.0",
|
||||
ClassifierMetadata: classifierParams,
|
||||
}
|
||||
}
|
||||
|
||||
// Save outputs this classifier
|
||||
func (lr *LinearSVC) Save(filePath string) error {
|
||||
writer, err := base.CreateSerializedClassifierStub(filePath, lr.GetMetadata())
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer func() {
|
||||
writer.Close()
|
||||
}()
|
||||
fmt.Printf("writer: %v", writer)
|
||||
return lr.SaveWithPrefix(writer, "")
|
||||
}
|
||||
|
||||
func (lr *LinearSVC) SaveWithPrefix(writer *base.ClassifierSerializer, prefix string) error {
|
||||
params, err := json.Marshal(lr.Param)
|
||||
if err != nil {
|
||||
return base.DescribeError("Error marshalling parameters", err)
|
||||
}
|
||||
|
||||
classifierParams := make(map[string]interface{})
|
||||
err = json.Unmarshal(params, &classifierParams)
|
||||
if err != nil {
|
||||
return base.DescribeError("Error marshalling parameters", err)
|
||||
}
|
||||
|
||||
f, err := ioutil.TempFile(os.TempDir(), "liblinear")
|
||||
defer func() {
|
||||
f.Close()
|
||||
}()
|
||||
|
||||
err = Export(lr.model, f.Name())
|
||||
if err != nil {
|
||||
return base.DescribeError("Error exporting model", err)
|
||||
}
|
||||
|
||||
f.Seek(0, os.SEEK_SET)
|
||||
bytes, err := ioutil.ReadAll(f)
|
||||
if err != nil {
|
||||
return base.DescribeError("Error reading model in again", err)
|
||||
}
|
||||
|
||||
err = writer.WriteBytesForKey(writer.Prefix(prefix, "PARAMETERS"), params)
|
||||
if err != nil {
|
||||
return base.DescribeError("Error writing model parameters", err)
|
||||
}
|
||||
writer.WriteBytesForKey(writer.Prefix(prefix, "MODEL"), bytes)
|
||||
if err != nil {
|
||||
return base.DescribeError("Error writing model", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (lr *LinearSVC) Load(filePath string) error {
|
||||
reader, err := base.ReadSerializedClassifierStub(filePath)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer func() {
|
||||
reader.Close()
|
||||
}()
|
||||
|
||||
return lr.LoadWithPrefix(reader, "")
|
||||
}
|
||||
|
||||
func (lr *LinearSVC) LoadWithPrefix(reader *base.ClassifierDeserializer, prefix string) error {
|
||||
err := reader.GetJSONForKey(reader.Prefix(prefix, "PARAMETERS"), &lr.Param)
|
||||
if err != nil {
|
||||
return base.DescribeError("Error reading PARAMETERS", err)
|
||||
}
|
||||
|
||||
modelBytes, err := reader.GetBytesForKey(reader.Prefix(prefix, "MODEL"))
|
||||
if err != nil {
|
||||
return base.DescribeError("Error reading MODEL", err)
|
||||
}
|
||||
|
||||
f, err := ioutil.TempFile(os.TempDir(), "linear")
|
||||
defer func() {
|
||||
f.Close()
|
||||
}()
|
||||
|
||||
f.WriteAt(modelBytes, 0)
|
||||
|
||||
lr.model = &Model{}
|
||||
|
||||
err = Load(lr.model, f.Name())
|
||||
if err != nil {
|
||||
return base.DescribeError("Unable to reload the model", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// String returns a human-readable version.
|
||||
func (lr *LinearSVC) String() string {
|
||||
return "LogisticSVC"
|
||||
|
@ -178,11 +178,8 @@ func (m *OneVsAllModel) LoadWithPrefix(reader *base.ClassifierDeserializer, pref
|
||||
if err != nil {
|
||||
return base.FormatError(err, "Can't resolve this attribute: %s", keyAttrRaw)
|
||||
}
|
||||
valAttr, err := base.ReplaceDeserializedAttributeWithVersionFromInstances(valAttrRaw, m.fitOn)
|
||||
if err != nil {
|
||||
return base.FormatError(err, "Can't resolve this attribute: %s", valAttrRaw)
|
||||
}
|
||||
attrMap[keyAttr] = valAttr
|
||||
|
||||
attrMap[keyAttr] = valAttrRaw
|
||||
}
|
||||
f.attrs = attrMap
|
||||
mapClassKey := reader.Prefix(mapPrefix, "CLASS_ATTR")
|
||||
@ -206,7 +203,7 @@ func (m *OneVsAllModel) LoadWithPrefix(reader *base.ClassifierDeserializer, pref
|
||||
}
|
||||
// Reload the class values
|
||||
var classVals = make([]string, 0)
|
||||
err = reader.GetJSONForKey(reader.Prefix(prefix, "CLASS_VALUES"), classVals)
|
||||
err = reader.GetJSONForKey(reader.Prefix(prefix, "CLASS_VALUES"), &classVals)
|
||||
if err != nil {
|
||||
return base.DescribeError("Can't read CLASS_VALUES", err)
|
||||
}
|
||||
@ -217,34 +214,10 @@ func (m *OneVsAllModel) LoadWithPrefix(reader *base.ClassifierDeserializer, pref
|
||||
for i, c := range classVals {
|
||||
cls := m.NewClassifierFunction(c)
|
||||
clsPrefix := pI(reader.Prefix(prefix, "CLASSIFIERS"), i)
|
||||
clsBodyPrefix := reader.Prefix(clsPrefix, "CLS")
|
||||
|
||||
err = cls.LoadWithPrefix(reader, clsBodyPrefix)
|
||||
err = cls.LoadWithPrefix(reader, clsPrefix)
|
||||
if err != nil {
|
||||
return base.FormatError(err, "Could not reload classifier at: %s", clsBodyPrefix)
|
||||
}
|
||||
m.classifiers = append(m.classifiers, cls)
|
||||
}
|
||||
|
||||
numClassifiersU64, err := reader.GetU64ForKey(reader.Prefix(prefix, "CLASSIFIER_COUNT"))
|
||||
if err != nil {
|
||||
return base.DescribeError("Can't load CLASSIFIER_COUNT", err)
|
||||
}
|
||||
numClassifiers := int(numClassifiersU64)
|
||||
|
||||
for i := 0; i < numClassifiers; i++ {
|
||||
clsPrefix := pI(reader.Prefix(prefix, "CLASSIFIERS"), i)
|
||||
clsStringPrefix := reader.Prefix(clsPrefix, "STRING")
|
||||
clsBodyPrefix := reader.Prefix(clsPrefix, "CLS")
|
||||
str, err := reader.GetStringForKey(clsStringPrefix)
|
||||
if err != nil {
|
||||
return base.FormatError(err, "Could not read class from: %s", clsStringPrefix)
|
||||
}
|
||||
|
||||
cls := m.NewClassifierFunction(str)
|
||||
err = cls.LoadWithPrefix(reader, clsBodyPrefix)
|
||||
if err != nil {
|
||||
return base.FormatError(err, "Could not reload classifier at: %s", clsBodyPrefix)
|
||||
return base.FormatError(err, "Could not reload classifier at: %s", clsPrefix)
|
||||
}
|
||||
m.classifiers = append(m.classifiers, cls)
|
||||
}
|
||||
@ -313,6 +286,7 @@ func (m *OneVsAllModel) SaveWithPrefix(writer *base.ClassifierSerializer, prefix
|
||||
if err != nil {
|
||||
return base.DescribeError("Unable to write filter map value", err)
|
||||
}
|
||||
j++
|
||||
}
|
||||
mapClassKey := writer.Prefix(mapPrefix, "CLASS_ATTR")
|
||||
err = writer.WriteAttributeForKey(mapClassKey, f.classAttr)
|
||||
|
Loading…
x
Reference in New Issue
Block a user