1
0
mirror of https://github.com/sjwhitworth/golearn.git synced 2025-04-25 13:48:49 +08:00

meta: tests passing

This commit is contained in:
Richard Townsend 2017-09-10 17:43:17 +01:00
parent 768d2cd19f
commit c18d50d217
5 changed files with 153 additions and 42 deletions

View File

@ -293,10 +293,6 @@ func SerializeInstancesToTarWriter(inst FixedDataGrid, tw *tar.Writer, prefix st
return fmt.Errorf("Could not write ATTRS: %s", err)
}
if !includeData {
return nil
}
// Data must be written out in the same order as the Attributes
allAttrs := make([]Attribute, attrCount)
normCount := copy(allAttrs, normalAttrs)
@ -323,6 +319,11 @@ func SerializeInstancesToTarWriter(inst FixedDataGrid, tw *tar.Writer, prefix st
if err := tw.WriteHeader(hdr); err != nil {
return fmt.Errorf("Could not write DATA: %s", err)
}
tw.Flush()
if !includeData {
return nil
}
// Then write the actual data
writtenLength := int64(0)

View File

@ -4,6 +4,8 @@ package linear_models
#include "linear.h"
*/
import "C"
import "fmt"
import "unsafe"
type Problem struct {
c_prob C.struct_problem
@ -14,7 +16,7 @@ type Parameter struct {
}
type Model struct {
c_model *C.struct_model
c_model unsafe.Pointer
}
const (
@ -58,12 +60,30 @@ func NewProblem(X [][]float64, y []float64, bias float64) *Problem {
func Train(prob *Problem, param *Parameter) *Model {
libLinearHookPrintFunc() // Sets up logging
return &Model{C.train(&prob.c_prob, &param.c_param)}
tmpCProb := &prob.c_prob
tmpCParam := &param.c_param
return &Model{unsafe.Pointer(C.train(tmpCProb, tmpCParam))}
}
func Export(model *Model, filePath string) error {
status := C.save_model(C.CString(filePath), (*C.struct_model)(model.c_model))
if status != 0 {
return fmt.Errorf("Problem occured during export to %s (status was %d)", filePath, status)
}
return nil
}
func Load(model *Model, filePath string) error {
model.c_model = unsafe.Pointer(C.load_model(C.CString(filePath)))
if model.c_model == nil {
return fmt.Errorf("Something went wrong")
}
return nil
}
func Predict(model *Model, x []float64) float64 {
c_x := convert_vector(x, 0)
c_y := C.predict(model.c_model, c_x)
c_y := C.predict((*C.struct_model)(model.c_model), c_x)
y := float64(c_y)
return y
}

View File

@ -1,9 +1,9 @@
package linear_models
import (
"github.com/sjwhitworth/golearn/base"
. "github.com/smartystreets/goconvey/convey"
"testing"
"github.com/sjwhitworth/golearn/base"
. "github.com/smartystreets/goconvey/convey"
"testing"
)
func TestLogisticRegression(t *testing.T) {

View File

@ -6,6 +6,9 @@ import (
"fmt"
"github.com/sjwhitworth/golearn/base"
"unsafe"
"encoding/json"
"io/ioutil"
"os"
)
// LinearSVCParams represnts all available LinearSVC options.
@ -211,6 +214,119 @@ func (lr *LinearSVC) Predict(X base.FixedDataGrid) (base.FixedDataGrid, error) {
return ret, nil
}
func (lr *LinearSVC) GetMetadata() base.ClassifierMetadataV1 {
params, err := json.Marshal(lr.Param)
if err != nil {
panic(err)
}
classifierParams := make(map[string]interface{})
err = json.Unmarshal(params, &classifierParams)
if err != nil {
panic(err)
}
return base.ClassifierMetadataV1{
FormatVersion: 1,
ClassifierName: "LinearSVC",
ClassifierVersion: "1.0",
ClassifierMetadata: classifierParams,
}
}
// Save outputs this classifier
func (lr *LinearSVC) Save(filePath string) error {
writer, err := base.CreateSerializedClassifierStub(filePath, lr.GetMetadata())
if err != nil {
return err
}
defer func() {
writer.Close()
}()
fmt.Printf("writer: %v", writer)
return lr.SaveWithPrefix(writer, "")
}
func (lr *LinearSVC) SaveWithPrefix(writer *base.ClassifierSerializer, prefix string) error {
params, err := json.Marshal(lr.Param)
if err != nil {
return base.DescribeError("Error marshalling parameters", err)
}
classifierParams := make(map[string]interface{})
err = json.Unmarshal(params, &classifierParams)
if err != nil {
return base.DescribeError("Error marshalling parameters", err)
}
f, err := ioutil.TempFile(os.TempDir(), "liblinear")
defer func() {
f.Close()
}()
err = Export(lr.model, f.Name())
if err != nil {
return base.DescribeError("Error exporting model", err)
}
f.Seek(0, os.SEEK_SET)
bytes, err := ioutil.ReadAll(f)
if err != nil {
return base.DescribeError("Error reading model in again", err)
}
err = writer.WriteBytesForKey(writer.Prefix(prefix, "PARAMETERS"), params)
if err != nil {
return base.DescribeError("Error writing model parameters", err)
}
writer.WriteBytesForKey(writer.Prefix(prefix, "MODEL"), bytes)
if err != nil {
return base.DescribeError("Error writing model", err)
}
return nil
}
func (lr *LinearSVC) Load(filePath string) error {
reader, err := base.ReadSerializedClassifierStub(filePath)
if err != nil {
return err
}
defer func() {
reader.Close()
}()
return lr.LoadWithPrefix(reader, "")
}
func (lr *LinearSVC) LoadWithPrefix(reader *base.ClassifierDeserializer, prefix string) error {
err := reader.GetJSONForKey(reader.Prefix(prefix, "PARAMETERS"), &lr.Param)
if err != nil {
return base.DescribeError("Error reading PARAMETERS", err)
}
modelBytes, err := reader.GetBytesForKey(reader.Prefix(prefix, "MODEL"))
if err != nil {
return base.DescribeError("Error reading MODEL", err)
}
f, err := ioutil.TempFile(os.TempDir(), "linear")
defer func() {
f.Close()
}()
f.WriteAt(modelBytes, 0)
lr.model = &Model{}
err = Load(lr.model, f.Name())
if err != nil {
return base.DescribeError("Unable to reload the model", err)
}
return nil
}
// String return a humaan-readable version.
func (lr *LinearSVC) String() string {
return "LogisticSVC"

View File

@ -178,11 +178,8 @@ func (m *OneVsAllModel) LoadWithPrefix(reader *base.ClassifierDeserializer, pref
if err != nil {
return base.FormatError(err, "Can't resolve this attribute: %s", keyAttrRaw)
}
valAttr, err := base.ReplaceDeserializedAttributeWithVersionFromInstances(valAttrRaw, m.fitOn)
if err != nil {
return base.FormatError(err, "Can't resolve this attribute: %s", valAttrRaw)
}
attrMap[keyAttr] = valAttr
attrMap[keyAttr] = valAttrRaw
}
f.attrs = attrMap
mapClassKey := reader.Prefix(mapPrefix, "CLASS_ATTR")
@ -206,7 +203,7 @@ func (m *OneVsAllModel) LoadWithPrefix(reader *base.ClassifierDeserializer, pref
}
// Reload the class values
var classVals = make([]string, 0)
err = reader.GetJSONForKey(reader.Prefix(prefix, "CLASS_VALUES"), classVals)
err = reader.GetJSONForKey(reader.Prefix(prefix, "CLASS_VALUES"), &classVals)
if err != nil {
return base.DescribeError("Can't read CLASS_VALUES", err)
}
@ -217,34 +214,10 @@ func (m *OneVsAllModel) LoadWithPrefix(reader *base.ClassifierDeserializer, pref
for i, c := range classVals {
cls := m.NewClassifierFunction(c)
clsPrefix := pI(reader.Prefix(prefix, "CLASSIFIERS"), i)
clsBodyPrefix := reader.Prefix(clsPrefix, "CLS")
err = cls.LoadWithPrefix(reader, clsBodyPrefix)
err = cls.LoadWithPrefix(reader, clsPrefix)
if err != nil {
return base.FormatError(err, "Could not reload classifier at: %s", clsBodyPrefix)
}
m.classifiers = append(m.classifiers, cls)
}
numClassifiersU64, err := reader.GetU64ForKey(reader.Prefix(prefix, "CLASSIFIER_COUNT"))
if err != nil {
return base.DescribeError("Can't load CLASSIFIER_COUNT", err)
}
numClassifiers := int(numClassifiersU64)
for i := 0; i < numClassifiers; i++ {
clsPrefix := pI(reader.Prefix(prefix, "CLASSIFIERS"), i)
clsStringPrefix := reader.Prefix(clsPrefix, "STRING")
clsBodyPrefix := reader.Prefix(clsPrefix, "CLS")
str, err := reader.GetStringForKey(clsStringPrefix)
if err != nil {
return base.FormatError(err, "Could not read class from: %s", clsStringPrefix)
}
cls := m.NewClassifierFunction(str)
err = cls.LoadWithPrefix(reader, clsBodyPrefix)
if err != nil {
return base.FormatError(err, "Could not reload classifier at: %s", clsBodyPrefix)
return base.FormatError(err, "Could not reload classifier at: %s", clsPrefix)
}
m.classifiers = append(m.classifiers, cls)
}
@ -313,6 +286,7 @@ func (m *OneVsAllModel) SaveWithPrefix(writer *base.ClassifierSerializer, prefix
if err != nil {
return base.DescribeError("Unable to write filter map value", err)
}
j++
}
mapClassKey := writer.Prefix(mapPrefix, "CLASS_ATTR")
err = writer.WriteAttributeForKey(mapClassKey, f.classAttr)