From 6f7326b6ff5fe1682de5fee5c36b89da652939c6 Mon Sep 17 00:00:00 2001 From: Richard Townsend Date: Sat, 14 May 2016 23:07:04 +0100 Subject: [PATCH] neural: check that the new dense instances type works... --- base/mat.go | 1 + neural/layered_test.go | 57 ++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 58 insertions(+) diff --git a/base/mat.go b/base/mat.go index 8258b07..c608558 100644 --- a/base/mat.go +++ b/base/mat.go @@ -21,6 +21,7 @@ func InstancesFromMat64(rows, cols int, data *mat64.Dense) *Mat64Instances { i))) } + ret.classAttrs = make(map[int]bool) ret.Data = data ret.rows = rows return &ret diff --git a/neural/layered_test.go b/neural/layered_test.go index 6081e17..0a4c496 100644 --- a/neural/layered_test.go +++ b/neural/layered_test.go @@ -155,3 +155,60 @@ func TestLayeredXOR(t *testing.T) { }) } + +func TestLayeredXORInline(t *testing.T) { + + Convey("Given an inline XOR dataset...", t, func() { + + data := mat64.NewDense(4, 3, []float64{ + 1, 0, 1, + 0, 1, 1, + 0, 0, 0, + 1, 1, 0, + }) + + XORData := base.InstancesFromMat64(4, 3, data) + classAttr := base.GetAttributeByName(XORData, "2") + XORData.AddClassAttribute(classAttr) + + net := NewMultiLayerNet([]int{3}) + net.MaxIterations = 20000 + net.Fit(XORData) + + Convey("After running for 20000 iterations, should have some predictive power...", func() { + + Convey("The right nodes should be connected in the network...", func() { + So(net.network.GetWeight(1, 1), ShouldAlmostEqual, 1.000) + So(net.network.GetWeight(2, 2), ShouldAlmostEqual, 1.000) + + for i := 1; i <= 6; i++ { + So(net.network.GetWeight(6, i), ShouldAlmostEqual, 0.000) + } + + }) + out := mat64.NewDense(6, 1, []float64{1.0, 0.0, 0.0, 0.0, 0.0, 0.0}) + net.network.Activate(out, 2) + So(out.At(5, 0), ShouldAlmostEqual, 1.0, 0.1) + + Convey("And Predict() should do OK too...", func() { + + pred := net.Predict(XORData) + + for _, a := range pred.AllAttributes() { + af, ok := a.(*base.FloatAttribute) + So(ok, ShouldBeTrue) + + af.Precision = 1 + } + + So(base.GetClass(pred, 0), ShouldEqual, "1.0") + So(base.GetClass(pred, 1), ShouldEqual, "1.0") + So(base.GetClass(pred, 2), ShouldEqual, "0.0") + So(base.GetClass(pred, 3), ShouldEqual, "0.0") + + }) + }) + + }) + +}