mirror of
https://github.com/sjwhitworth/golearn.git
synced 2025-04-26 13:49:14 +08:00
neural: check that the new dense instances type works...
This commit is contained in:
parent
590d7a8091
commit
6f7326b6ff
@ -21,6 +21,7 @@ func InstancesFromMat64(rows, cols int, data *mat64.Dense) *Mat64Instances {
|
|||||||
i)))
|
i)))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
ret.classAttrs = make(map[int]bool)
|
||||||
ret.Data = data
|
ret.Data = data
|
||||||
ret.rows = rows
|
ret.rows = rows
|
||||||
return &ret
|
return &ret
|
||||||
|
@ -155,3 +155,60 @@ func TestLayeredXOR(t *testing.T) {
|
|||||||
})
|
})
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestLayeredXORInline(t *testing.T) {
|
||||||
|
|
||||||
|
Convey("Given an inline XOR dataset...", t, func() {
|
||||||
|
|
||||||
|
data := mat64.NewDense(4, 3, []float64{
|
||||||
|
1, 0, 1,
|
||||||
|
0, 1, 1,
|
||||||
|
0, 0, 0,
|
||||||
|
1, 1, 0,
|
||||||
|
})
|
||||||
|
|
||||||
|
XORData := base.InstancesFromMat64(4, 3, data)
|
||||||
|
classAttr := base.GetAttributeByName(XORData, "2")
|
||||||
|
XORData.AddClassAttribute(classAttr)
|
||||||
|
|
||||||
|
net := NewMultiLayerNet([]int{3})
|
||||||
|
net.MaxIterations = 20000
|
||||||
|
net.Fit(XORData)
|
||||||
|
|
||||||
|
Convey("After running for 20000 iterations, should have some predictive power...", func() {
|
||||||
|
|
||||||
|
Convey("The right nodes should be connected in the network...", func() {
|
||||||
|
So(net.network.GetWeight(1, 1), ShouldAlmostEqual, 1.000)
|
||||||
|
So(net.network.GetWeight(2, 2), ShouldAlmostEqual, 1.000)
|
||||||
|
|
||||||
|
for i := 1; i <= 6; i++ {
|
||||||
|
So(net.network.GetWeight(6, i), ShouldAlmostEqual, 0.000)
|
||||||
|
}
|
||||||
|
|
||||||
|
})
|
||||||
|
out := mat64.NewDense(6, 1, []float64{1.0, 0.0, 0.0, 0.0, 0.0, 0.0})
|
||||||
|
net.network.Activate(out, 2)
|
||||||
|
So(out.At(5, 0), ShouldAlmostEqual, 1.0, 0.1)
|
||||||
|
|
||||||
|
Convey("And Predict() should do OK too...", func() {
|
||||||
|
|
||||||
|
pred := net.Predict(XORData)
|
||||||
|
|
||||||
|
for _, a := range pred.AllAttributes() {
|
||||||
|
af, ok := a.(*base.FloatAttribute)
|
||||||
|
So(ok, ShouldBeTrue)
|
||||||
|
|
||||||
|
af.Precision = 1
|
||||||
|
}
|
||||||
|
|
||||||
|
So(base.GetClass(pred, 0), ShouldEqual, "1.0")
|
||||||
|
So(base.GetClass(pred, 1), ShouldEqual, "1.0")
|
||||||
|
So(base.GetClass(pred, 2), ShouldEqual, "0.0")
|
||||||
|
So(base.GetClass(pred, 3), ShouldEqual, "0.0")
|
||||||
|
|
||||||
|
})
|
||||||
|
})
|
||||||
|
|
||||||
|
})
|
||||||
|
|
||||||
|
}
|
||||||
|
Loading…
x
Reference in New Issue
Block a user