mirror of
https://github.com/sjwhitworth/golearn.git
synced 2025-04-30 13:48:57 +08:00
Merge pull request #135 from Sentimentron/inline-training-data
Support the use of mat64.Dense as an instance type
This commit is contained in:
commit
855df3a7fa
170
base/mat.go
Normal file
170
base/mat.go
Normal file
@ -0,0 +1,170 @@
|
||||
package base
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"github.com/gonum/matrix/mat64"
|
||||
"bytes"
|
||||
)
|
||||
|
||||
type Mat64Instances struct {
|
||||
attributes []Attribute
|
||||
classAttrs map[int]bool
|
||||
Data *mat64.Dense
|
||||
rows int
|
||||
}
|
||||
|
||||
// InstancesFromMat64 returns a new Mat64Instances from a literal provided.
|
||||
func InstancesFromMat64(rows, cols int, data *mat64.Dense) *Mat64Instances {
|
||||
|
||||
var ret Mat64Instances
|
||||
for i := 0; i < cols; i++ {
|
||||
ret.attributes = append(ret.attributes, NewFloatAttribute(fmt.Sprintf("%d", i)))
|
||||
}
|
||||
|
||||
ret.classAttrs = make(map[int]bool)
|
||||
ret.Data = data
|
||||
ret.rows = rows
|
||||
|
||||
ret.AddClassAttribute(ret.attributes[len(ret.attributes)-1])
|
||||
|
||||
return &ret
|
||||
}
|
||||
|
||||
// GetAttribute returns an AttributeSpec from an Attribute field.
|
||||
func (m *Mat64Instances) GetAttribute(a Attribute) (AttributeSpec, error) {
|
||||
for i, at := range m.attributes {
|
||||
if at.Equals(a) {
|
||||
return AttributeSpec{0, i, at}, nil
|
||||
}
|
||||
}
|
||||
return AttributeSpec{}, fmt.Errorf("Couldn't find a matching attribute")
|
||||
}
|
||||
|
||||
// AllAttributes returns every defined Attribute.
|
||||
func (m *Mat64Instances) AllAttributes() []Attribute {
|
||||
ret := make([]Attribute, len(m.attributes))
|
||||
for i, a := range m.attributes {
|
||||
ret[i] = a
|
||||
}
|
||||
return ret
|
||||
}
|
||||
|
||||
// AddClassAttribute adds an attribute to the class set.
|
||||
func (m *Mat64Instances) AddClassAttribute(a Attribute) error {
|
||||
as, err := m.GetAttribute(a)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
m.classAttrs[as.position] = true
|
||||
return nil
|
||||
}
|
||||
|
||||
// RemoveClassAttribute removes an attribute to the class set.
|
||||
func (m *Mat64Instances) RemoveClassAttribute(a Attribute) error {
|
||||
as, err := m.GetAttribute(a)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
m.classAttrs[as.position] = false
|
||||
return nil
|
||||
}
|
||||
|
||||
// AllClassAttributes returns every class attribute.
|
||||
func (m *Mat64Instances) AllClassAttributes() []Attribute {
|
||||
ret := make([]Attribute, 0)
|
||||
for i := range m.classAttrs {
|
||||
if m.classAttrs[i] {
|
||||
ret = append(ret, m.attributes[i])
|
||||
}
|
||||
}
|
||||
|
||||
return ret
|
||||
}
|
||||
|
||||
// Get returns the bytes at a given position
|
||||
func (m *Mat64Instances) Get(as AttributeSpec, row int) []byte {
|
||||
val := m.Data.At(row, as.position)
|
||||
return PackFloatToBytes(val)
|
||||
}
|
||||
|
||||
// MapOverRows is a convenience function for iteration
|
||||
func (m *Mat64Instances) MapOverRows(as []AttributeSpec, f func([][]byte, int) (bool, error)) error {
|
||||
|
||||
rowData := make([][]byte, len(as))
|
||||
for j, _ := range as {
|
||||
rowData[j] = make([]byte, 8)
|
||||
}
|
||||
for i := 0; i < m.rows; i++ {
|
||||
for j, as := range as {
|
||||
PackFloatToBytesInline(m.Data.At(i, as.position), rowData[j])
|
||||
}
|
||||
stat, err := f(rowData, i)
|
||||
if !stat {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// RowString: should print the values of a row
|
||||
// TODO: make this less half-assed
|
||||
func (m *Mat64Instances) RowString(row int) string {
|
||||
return fmt.Sprintf("%d", row)
|
||||
}
|
||||
|
||||
// Size returns the number of Attributes, then the number of rows
|
||||
func (m *Mat64Instances) Size() (int, int) {
|
||||
return len(m.attributes), m.rows
|
||||
}
|
||||
|
||||
// String returns a human-readable summary of this dataset.
|
||||
func (m *Mat64Instances) String() string {
|
||||
var buffer bytes.Buffer
|
||||
|
||||
// Get all Attribute information
|
||||
as := ResolveAllAttributes(m)
|
||||
|
||||
// Print header
|
||||
cols, rows := m.Size()
|
||||
buffer.WriteString("Instances with ")
|
||||
buffer.WriteString(fmt.Sprintf("%d row(s) ", rows))
|
||||
buffer.WriteString(fmt.Sprintf("%d attribute(s)\n", cols))
|
||||
buffer.WriteString(fmt.Sprintf("Attributes: \n"))
|
||||
|
||||
cnt := 0
|
||||
for _, a := range as {
|
||||
prefix := "\t"
|
||||
if m.classAttrs[cnt] {
|
||||
prefix = "*\t"
|
||||
}
|
||||
cnt++
|
||||
buffer.WriteString(fmt.Sprintf("%s%s\n", prefix, a.attr))
|
||||
}
|
||||
|
||||
buffer.WriteString("\nData:\n")
|
||||
maxRows := 30
|
||||
if rows < maxRows {
|
||||
maxRows = rows
|
||||
}
|
||||
|
||||
for i := 0; i < maxRows; i++ {
|
||||
buffer.WriteString("\t")
|
||||
for _, a := range as {
|
||||
val := m.Get(a, i)
|
||||
buffer.WriteString(fmt.Sprintf("%s ", a.attr.GetStringFromSysVal(val)))
|
||||
}
|
||||
buffer.WriteString("\n")
|
||||
}
|
||||
|
||||
missingRows := rows - maxRows
|
||||
if missingRows != 0 {
|
||||
buffer.WriteString(fmt.Sprintf("\t...\n%d row(s) undisplayed", missingRows))
|
||||
} else {
|
||||
buffer.WriteString("All rows displayed")
|
||||
}
|
||||
|
||||
return buffer.String()
|
||||
}
|
||||
|
39
base/mat_test.go
Normal file
39
base/mat_test.go
Normal file
@ -0,0 +1,39 @@
|
||||
package base
|
||||
|
||||
import (
|
||||
"github.com/gonum/matrix/mat64"
|
||||
. "github.com/smartystreets/goconvey/convey"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestInlineMat64Creation(t *testing.T) {
|
||||
|
||||
Convey("Given a literal array...", t, func() {
|
||||
mat := mat64.NewDense(4, 3, []float64{
|
||||
1, 0, 1,
|
||||
0, 1, 1,
|
||||
0, 0, 0,
|
||||
1, 1, 0,
|
||||
})
|
||||
inst := InstancesFromMat64(4, 3, mat)
|
||||
attrs := inst.AllAttributes()
|
||||
Convey("Attributes should be well-defined...", func() {
|
||||
So(len(attrs), ShouldEqual, 3)
|
||||
})
|
||||
|
||||
Convey("No class variables set by default...", func() {
|
||||
classAttrs := inst.AllClassAttributes()
|
||||
So(len(classAttrs), ShouldEqual, 0)
|
||||
})
|
||||
|
||||
Convey("Getting values should work...", func() {
|
||||
as, err := inst.GetAttribute(attrs[0])
|
||||
So(err, ShouldBeNil)
|
||||
valBytes := inst.Get(as, 3)
|
||||
val := UnpackBytesToFloat(valBytes)
|
||||
So(val, ShouldAlmostEqual, 1.0)
|
||||
})
|
||||
|
||||
})
|
||||
|
||||
}
|
@ -155,3 +155,60 @@ func TestLayeredXOR(t *testing.T) {
|
||||
})
|
||||
|
||||
}
|
||||
|
||||
func TestLayeredXORInline(t *testing.T) {
|
||||
|
||||
Convey("Given an inline XOR dataset...", t, func() {
|
||||
|
||||
data := mat64.NewDense(4, 3, []float64{
|
||||
1, 0, 1,
|
||||
0, 1, 1,
|
||||
0, 0, 0,
|
||||
1, 1, 0,
|
||||
})
|
||||
|
||||
XORData := base.InstancesFromMat64(4, 3, data)
|
||||
classAttr := base.GetAttributeByName(XORData, "2")
|
||||
XORData.AddClassAttribute(classAttr)
|
||||
|
||||
net := NewMultiLayerNet([]int{3})
|
||||
net.MaxIterations = 20000
|
||||
net.Fit(XORData)
|
||||
|
||||
Convey("After running for 20000 iterations, should have some predictive power...", func() {
|
||||
|
||||
Convey("The right nodes should be connected in the network...", func() {
|
||||
So(net.network.GetWeight(1, 1), ShouldAlmostEqual, 1.000)
|
||||
So(net.network.GetWeight(2, 2), ShouldAlmostEqual, 1.000)
|
||||
|
||||
for i := 1; i <= 6; i++ {
|
||||
So(net.network.GetWeight(6, i), ShouldAlmostEqual, 0.000)
|
||||
}
|
||||
|
||||
})
|
||||
out := mat64.NewDense(6, 1, []float64{1.0, 0.0, 0.0, 0.0, 0.0, 0.0})
|
||||
net.network.Activate(out, 2)
|
||||
So(out.At(5, 0), ShouldAlmostEqual, 1.0, 0.1)
|
||||
|
||||
Convey("And Predict() should do OK too...", func() {
|
||||
|
||||
pred := net.Predict(XORData)
|
||||
|
||||
for _, a := range pred.AllAttributes() {
|
||||
af, ok := a.(*base.FloatAttribute)
|
||||
So(ok, ShouldBeTrue)
|
||||
|
||||
af.Precision = 1
|
||||
}
|
||||
|
||||
So(base.GetClass(pred, 0), ShouldEqual, "1.0")
|
||||
So(base.GetClass(pred, 1), ShouldEqual, "1.0")
|
||||
So(base.GetClass(pred, 2), ShouldEqual, "0.0")
|
||||
So(base.GetClass(pred, 3), ShouldEqual, "0.0")
|
||||
|
||||
})
|
||||
})
|
||||
|
||||
})
|
||||
|
||||
}
|
||||
|
Loading…
x
Reference in New Issue
Block a user