From 94e5843bcf5fe3023f4ad2b65bc540e1f4a64f15 Mon Sep 17 00:00:00 2001 From: Amit Kumar Gupta Date: Fri, 22 Aug 2014 06:55:20 +0000 Subject: [PATCH 01/11] go fmt ./... --- base/edf/thread_test.go | 10 +-- base/float.go | 72 +++++++-------- ext/make.go | 2 +- filters/chimerge_freq.go | 2 +- filters/chimerge_funcs.go | 2 +- filters/float.go | 182 +++++++++++++++++++------------------- meta/meta.go | 2 +- 7 files changed, 136 insertions(+), 136 deletions(-) diff --git a/base/edf/thread_test.go b/base/edf/thread_test.go index e38f801..e259a2a 100644 --- a/base/edf/thread_test.go +++ b/base/edf/thread_test.go @@ -2,8 +2,8 @@ package edf import ( . "github.com/smartystreets/goconvey/convey" - "testing" "os" + "testing" ) func TestThreadDeserialize(T *testing.T) { @@ -34,20 +34,20 @@ func TestThreadSerialize(T *testing.T) { func TestThreadFindAndWrite(T *testing.T) { Convey("Creating a non-existent file should succeed", T, func() { - tempFile, err := os.OpenFile("hello.db", os.O_RDWR | os.O_TRUNC | os.O_CREATE, 0700) //ioutil.TempFile(os.TempDir(), "TestFileCreate") + tempFile, err := os.OpenFile("hello.db", os.O_RDWR|os.O_TRUNC|os.O_CREATE, 0700) //ioutil.TempFile(os.TempDir(), "TestFileCreate") So(err, ShouldEqual, nil) Convey("Mapping the file should suceed", func() { mapping, err := EdfMap(tempFile, EDF_CREATE) So(err, ShouldEqual, nil) - Convey("Writing the thread should succeed", func () { + Convey("Writing the thread should succeed", func() { t := NewThread(mapping, "MyNameISWhat") - Convey("Thread number should be 3", func () { + Convey("Thread number should be 3", func() { So(t.id, ShouldEqual, 3) }) Convey("Writing the thread should succeed", func() { err := mapping.WriteThread(t) So(err, ShouldEqual, nil) - Convey("Should be able to find the thread again later", func() { + Convey("Should be able to find the thread again later", func() { id, err := mapping.FindThread("MyNameISWhat") So(err, ShouldEqual, nil) So(id, ShouldEqual, 3) diff --git a/base/float.go b/base/float.go index 879114e..fe2d66c 100644 --- a/base/float.go +++ b/base/float.go @@ -1,28 +1,28 @@ package base import ( - "fmt" - "strconv" + "fmt" + "strconv" ) // FloatAttribute is an implementation which stores floating point // representations of numbers. type FloatAttribute struct { - Name string - Precision int + Name string + Precision int } // NewFloatAttribute returns a new FloatAttribute with a default // precision of 2 decimal places func NewFloatAttribute(name string) *FloatAttribute { - return &FloatAttribute{name, 2} + return &FloatAttribute{name, 2} } // Compatable checks whether this FloatAttribute can be ponded with another // Attribute (checks if they're both FloatAttributes) func (Attr *FloatAttribute) Compatable(other Attribute) bool { - _, ok := other.(*FloatAttribute) - return ok + _, ok := other.(*FloatAttribute) + return ok } // Equals tests a FloatAttribute for equality with another Attribute. @@ -30,50 +30,50 @@ func (Attr *FloatAttribute) Compatable(other Attribute) bool { // Returns false if the other Attribute has a different name // or if the other Attribute is not a FloatAttribute. 
func (Attr *FloatAttribute) Equals(other Attribute) bool { - // Check whether this FloatAttribute is equal to another - _, ok := other.(*FloatAttribute) - if !ok { - // Not the same type, so can't be equal - return false - } - if Attr.GetName() != other.GetName() { - return false - } - return true + // Check whether this FloatAttribute is equal to another + _, ok := other.(*FloatAttribute) + if !ok { + // Not the same type, so can't be equal + return false + } + if Attr.GetName() != other.GetName() { + return false + } + return true } // GetName returns this FloatAttribute's human-readable name. func (Attr *FloatAttribute) GetName() string { - return Attr.Name + return Attr.Name } // SetName sets this FloatAttribute's human-readable name. func (Attr *FloatAttribute) SetName(name string) { - Attr.Name = name + Attr.Name = name } // GetType returns Float64Type. func (Attr *FloatAttribute) GetType() int { - return Float64Type + return Float64Type } // String returns a human-readable summary of this Attribute. // e.g. "FloatAttribute(Sepal Width)" func (Attr *FloatAttribute) String() string { - return fmt.Sprintf("FloatAttribute(%s)", Attr.Name) + return fmt.Sprintf("FloatAttribute(%s)", Attr.Name) } // CheckSysValFromString confirms whether a given rawVal can // be converted into a valid system representation. If it can't, // the returned value is nil. func (Attr *FloatAttribute) CheckSysValFromString(rawVal string) ([]byte, error) { - f, err := strconv.ParseFloat(rawVal, 64) - if err != nil { - return nil, err - } + f, err := strconv.ParseFloat(rawVal, 64) + if err != nil { + return nil, err + } - ret := PackFloatToBytes(f) - return ret, nil + ret := PackFloatToBytes(f) + return ret, nil } // GetSysValFromString parses the given rawVal string to a float64 and returns it. @@ -82,22 +82,22 @@ func (Attr *FloatAttribute) CheckSysValFromString(rawVal string) ([]byte, error) // IMPORTANT: This function panic()s if rawVal is not a valid float. // Use CheckSysValFromString to confirm. func (Attr *FloatAttribute) GetSysValFromString(rawVal string) []byte { - f, err := Attr.CheckSysValFromString(rawVal) - if err != nil { - panic(err) - } - return f + f, err := Attr.CheckSysValFromString(rawVal) + if err != nil { + panic(err) + } + return f } // GetFloatFromSysVal converts a given system value to a float func (Attr *FloatAttribute) GetFloatFromSysVal(rawVal []byte) float64 { - return UnpackBytesToFloat(rawVal) + return UnpackBytesToFloat(rawVal) } // GetStringFromSysVal converts a given system value to to a string with two decimal // places of precision. 
func (Attr *FloatAttribute) GetStringFromSysVal(rawVal []byte) string { - f := UnpackBytesToFloat(rawVal) - formatString := fmt.Sprintf("%%.%df", Attr.Precision) - return fmt.Sprintf(formatString, f) + f := UnpackBytesToFloat(rawVal) + formatString := fmt.Sprintf("%%.%df", Attr.Precision) + return fmt.Sprintf(formatString, f) } diff --git a/ext/make.go b/ext/make.go index 87ec96b..c45e5de 100644 --- a/ext/make.go +++ b/ext/make.go @@ -45,7 +45,7 @@ func main() { return } - os.Mkdir("lib", os.ModeDir | 0777) + os.Mkdir("lib", os.ModeDir|0777) log.Println("Installing libs") if runtime.GOOS == "windows" { diff --git a/filters/chimerge_freq.go b/filters/chimerge_freq.go index 1f1dfb8..553c959 100644 --- a/filters/chimerge_freq.go +++ b/filters/chimerge_freq.go @@ -11,4 +11,4 @@ type FrequencyTableEntry struct { func (t *FrequencyTableEntry) String() string { return fmt.Sprintf("%.2f %s", t.Value, t.Frequency) -} \ No newline at end of file +} diff --git a/filters/chimerge_funcs.go b/filters/chimerge_funcs.go index 3265037..4b0e3f3 100644 --- a/filters/chimerge_funcs.go +++ b/filters/chimerge_funcs.go @@ -1,8 +1,8 @@ package filters import ( - "github.com/sjwhitworth/golearn/base" "fmt" + "github.com/sjwhitworth/golearn/base" "math" ) diff --git a/filters/float.go b/filters/float.go index 64910bb..f80bf0b 100644 --- a/filters/float.go +++ b/filters/float.go @@ -1,8 +1,8 @@ package filters import ( - "fmt" - "github.com/sjwhitworth/golearn/base" + "fmt" + "github.com/sjwhitworth/golearn/base" ) // FloatConvertFilters convert a given DataGrid into one which @@ -14,84 +14,84 @@ import ( // CategoricalAttributes are discretised into one or more new // BinaryAttributes. type FloatConvertFilter struct { - attrs []base.Attribute - converted []base.FilteredAttribute - twoValuedCategoricalAttributes map[base.Attribute]bool // Two-valued categorical Attributes - nValuedCategoricalAttributeMap map[base.Attribute]map[uint64]base.Attribute + attrs []base.Attribute + converted []base.FilteredAttribute + twoValuedCategoricalAttributes map[base.Attribute]bool // Two-valued categorical Attributes + nValuedCategoricalAttributeMap map[base.Attribute]map[uint64]base.Attribute } // NewFloatConvertFilter creates a blank FloatConvertFilter func NewFloatConvertFilter() *FloatConvertFilter { - ret := &FloatConvertFilter{ - make([]base.Attribute, 0), - make([]base.FilteredAttribute, 0), - make(map[base.Attribute]bool), - make(map[base.Attribute]map[uint64]base.Attribute), - } - return ret + ret := &FloatConvertFilter{ + make([]base.Attribute, 0), + make([]base.FilteredAttribute, 0), + make(map[base.Attribute]bool), + make(map[base.Attribute]map[uint64]base.Attribute), + } + return ret } // AddAttribute adds a new Attribute to this Filter func (f *FloatConvertFilter) AddAttribute(a base.Attribute) error { - f.attrs = append(f.attrs, a) - return nil + f.attrs = append(f.attrs, a) + return nil } // GetAttributesAfterFiltering returns the Attributes previously computed via Train() func (f *FloatConvertFilter) GetAttributesAfterFiltering() []base.FilteredAttribute { - return f.converted + return f.converted } // String gets a human-readable string func (f *FloatConvertFilter) String() string { - return fmt.Sprintf("FloatConvertFilter(%d Attribute(s))", len(f.attrs)) + return fmt.Sprintf("FloatConvertFilter(%d Attribute(s))", len(f.attrs)) } // Transform converts the given byte sequence using the old Attribute into the new // byte sequence. 
func (f *FloatConvertFilter) Transform(a base.Attribute, n base.Attribute, attrBytes []byte) []byte { - ret := make([]byte, 8) - // Check for CategoricalAttribute - if _, ok := a.(*base.CategoricalAttribute); ok { - // Unpack byte value - val := base.UnpackBytesToU64(attrBytes) - // If it's a two-valued one, check for non-zero - if f.twoValuedCategoricalAttributes[a] { - if val > 0 { - ret = base.PackFloatToBytes(1.0) - } else { - ret = base.PackFloatToBytes(0.0) - } - } else if an, ok := f.nValuedCategoricalAttributeMap[a]; ok { - // If it's an n-valued one, check the new Attribute maps onto - // the unpacked value - if af, ok := an[val]; ok { - if af.Equals(n) { - ret = base.PackFloatToBytes(1.0) - } else { - ret = base.PackFloatToBytes(0.0) - } - } else { - panic("Categorical value not defined!") - } - } else { - panic(fmt.Sprintf("Not a recognised Attribute %v", a)) - } - } else if _, ok := a.(*base.FloatAttribute); ok { - // Binary: just return the original value - ret = attrBytes - } else if _, ok := a.(*base.BinaryAttribute); ok { - // Float: check for non-zero - if attrBytes[0] > 0 { - ret = base.PackFloatToBytes(1.0) - } else { - ret = base.PackFloatToBytes(0.0) - } - } else { - panic(fmt.Sprintf("Unrecognised Attribute: %v", a)) - } - return ret + ret := make([]byte, 8) + // Check for CategoricalAttribute + if _, ok := a.(*base.CategoricalAttribute); ok { + // Unpack byte value + val := base.UnpackBytesToU64(attrBytes) + // If it's a two-valued one, check for non-zero + if f.twoValuedCategoricalAttributes[a] { + if val > 0 { + ret = base.PackFloatToBytes(1.0) + } else { + ret = base.PackFloatToBytes(0.0) + } + } else if an, ok := f.nValuedCategoricalAttributeMap[a]; ok { + // If it's an n-valued one, check the new Attribute maps onto + // the unpacked value + if af, ok := an[val]; ok { + if af.Equals(n) { + ret = base.PackFloatToBytes(1.0) + } else { + ret = base.PackFloatToBytes(0.0) + } + } else { + panic("Categorical value not defined!") + } + } else { + panic(fmt.Sprintf("Not a recognised Attribute %v", a)) + } + } else if _, ok := a.(*base.FloatAttribute); ok { + // Binary: just return the original value + ret = attrBytes + } else if _, ok := a.(*base.BinaryAttribute); ok { + // Float: check for non-zero + if attrBytes[0] > 0 { + ret = base.PackFloatToBytes(1.0) + } else { + ret = base.PackFloatToBytes(0.0) + } + } else { + panic(fmt.Sprintf("Unrecognised Attribute: %v", a)) + } + return ret } // Train converts the Attributes into equivalently named FloatAttributes, @@ -105,37 +105,37 @@ func (f *FloatConvertFilter) Transform(a base.Attribute, n base.Attribute, attrB // If the CategoricalAttribute has more than two (n) values, the Filter // generates n FloatAttributes and sets each of them if the value's observed. 
func (f *FloatConvertFilter) Train() error { - for _, a := range f.attrs { - if ac, ok := a.(*base.CategoricalAttribute); ok { - vals := ac.GetValues() - if len(vals) <= 2 { - nAttr := base.NewFloatAttribute(ac.GetName()) - fAttr := base.FilteredAttribute{ac, nAttr} - f.converted = append(f.converted, fAttr) - f.twoValuedCategoricalAttributes[a] = true - } else { - if _, ok := f.nValuedCategoricalAttributeMap[a]; !ok { - f.nValuedCategoricalAttributeMap[a] = make(map[uint64]base.Attribute) - } - for i := uint64(0); i < uint64(len(vals)); i++ { - v := vals[i] - newName := fmt.Sprintf("%s_%s", ac.GetName(), v) - newAttr := base.NewFloatAttribute(newName) - fAttr := base.FilteredAttribute{ac, newAttr} - f.converted = append(f.converted, fAttr) - f.nValuedCategoricalAttributeMap[a][i] = newAttr - } - } - } else if ab, ok := a.(*base.FloatAttribute); ok { - fAttr := base.FilteredAttribute{ab, ab} - f.converted = append(f.converted, fAttr) - } else if af, ok := a.(*base.BinaryAttribute); ok { - newAttr := base.NewFloatAttribute(af.GetName()) - fAttr := base.FilteredAttribute{af, newAttr} - f.converted = append(f.converted, fAttr) - } else { - return fmt.Errorf("Unsupported Attribute type: %v", a) - } - } - return nil + for _, a := range f.attrs { + if ac, ok := a.(*base.CategoricalAttribute); ok { + vals := ac.GetValues() + if len(vals) <= 2 { + nAttr := base.NewFloatAttribute(ac.GetName()) + fAttr := base.FilteredAttribute{ac, nAttr} + f.converted = append(f.converted, fAttr) + f.twoValuedCategoricalAttributes[a] = true + } else { + if _, ok := f.nValuedCategoricalAttributeMap[a]; !ok { + f.nValuedCategoricalAttributeMap[a] = make(map[uint64]base.Attribute) + } + for i := uint64(0); i < uint64(len(vals)); i++ { + v := vals[i] + newName := fmt.Sprintf("%s_%s", ac.GetName(), v) + newAttr := base.NewFloatAttribute(newName) + fAttr := base.FilteredAttribute{ac, newAttr} + f.converted = append(f.converted, fAttr) + f.nValuedCategoricalAttributeMap[a][i] = newAttr + } + } + } else if ab, ok := a.(*base.FloatAttribute); ok { + fAttr := base.FilteredAttribute{ab, ab} + f.converted = append(f.converted, fAttr) + } else if af, ok := a.(*base.BinaryAttribute); ok { + newAttr := base.NewFloatAttribute(af.GetName()) + fAttr := base.FilteredAttribute{af, newAttr} + f.converted = append(f.converted, fAttr) + } else { + return fmt.Errorf("Unsupported Attribute type: %v", a) + } + } + return nil } diff --git a/meta/meta.go b/meta/meta.go index 2cb7938..72cc426 100644 --- a/meta/meta.go +++ b/meta/meta.go @@ -10,4 +10,4 @@ are generated via majority voting. 
*/ -package meta \ No newline at end of file +package meta From 688a82babb5611f7c280faa974a2f673a6522125 Mon Sep 17 00:00:00 2001 From: Amit Kumar Gupta Date: Fri, 22 Aug 2014 07:00:39 +0000 Subject: [PATCH 02/11] fix typo suceed -> succeed --- base/edf/alloc_test.go | 16 ++++++++-------- base/edf/map_test.go | 8 ++++---- base/edf/thread_test.go | 2 +- 3 files changed, 13 insertions(+), 13 deletions(-) diff --git a/base/edf/alloc_test.go b/base/edf/alloc_test.go index 2efdfea..1cf3859 100644 --- a/base/edf/alloc_test.go +++ b/base/edf/alloc_test.go @@ -11,18 +11,18 @@ func TestAllocFixed(t *testing.T) { Convey("Creating a non-existent file should succeed", t, func() { tempFile, err := ioutil.TempFile(os.TempDir(), "TestFileCreate") So(err, ShouldEqual, nil) - Convey("Mapping the file should suceed", func() { + Convey("Mapping the file should succeed", func() { mapping, err := EdfMap(tempFile, EDF_CREATE) So(err, ShouldEqual, nil) - Convey("Allocation should suceed", func() { + Convey("Allocation should succeed", func() { r, err := mapping.AllocPages(1, 2) So(err, ShouldEqual, nil) So(r.Start.Byte, ShouldEqual, 4*os.Getpagesize()) So(r.Start.Segment, ShouldEqual, 0) - Convey("Unmapping the file should suceed", func() { + Convey("Unmapping the file should succeed", func() { err = mapping.Unmap(EDF_UNMAP_SYNC) So(err, ShouldEqual, nil) - Convey("Remapping the file should suceed", func() { + Convey("Remapping the file should succeed", func() { mapping, err = EdfMap(tempFile, EDF_READ_ONLY) Convey("Should get the same allocations back", func() { rr, err := mapping.GetThreadBlocks(2) @@ -41,20 +41,20 @@ func TestAllocWithExtraContentsBlock(t *testing.T) { Convey("Creating a non-existent file should succeed", t, func() { tempFile, err := ioutil.TempFile(os.TempDir(), "TestFileCreate") So(err, ShouldEqual, nil) - Convey("Mapping the file should suceed", func() { + Convey("Mapping the file should succeed", func() { mapping, err := EdfMap(tempFile, EDF_CREATE) So(err, ShouldEqual, nil) - Convey("Allocation of 10 pages should suceed", func() { + Convey("Allocation of 10 pages should succeed", func() { allocated := make([]EdfRange, 10) for i := 0; i < 10; i++ { r, err := mapping.AllocPages(1, 2) So(err, ShouldEqual, nil) allocated[i] = r } - Convey("Unmapping the file should suceed", func() { + Convey("Unmapping the file should succeed", func() { err = mapping.Unmap(EDF_UNMAP_SYNC) So(err, ShouldEqual, nil) - Convey("Remapping the file should suceed", func() { + Convey("Remapping the file should succeed", func() { mapping, err = EdfMap(tempFile, EDF_READ_ONLY) Convey("Should get the same allocations back", func() { rr, err := mapping.GetThreadBlocks(2) diff --git a/base/edf/map_test.go b/base/edf/map_test.go index cfb35ec..c5c30c6 100644 --- a/base/edf/map_test.go +++ b/base/edf/map_test.go @@ -8,7 +8,7 @@ import ( ) func TestAnonMap(t *testing.T) { - Convey("Anonymous mapping should suceed", t, func() { + Convey("Anonymous mapping should succeed", t, func() { mapping, err := EdfAnonMap() So(err, ShouldEqual, nil) bytes := mapping.m[0] @@ -39,10 +39,10 @@ func TestFileCreate(t *testing.T) { Convey("Creating a non-existent file should succeed", t, func() { tempFile, err := ioutil.TempFile(os.TempDir(), "TestFileCreate") So(err, ShouldEqual, nil) - Convey("Mapping the file should suceed", func() { + Convey("Mapping the file should succeed", func() { mapping, err := EdfMap(tempFile, EDF_CREATE) So(err, ShouldEqual, nil) - Convey("Unmapping the file should suceed", func() { + Convey("Unmapping the file 
should succeed", func() { err = mapping.Unmap(EDF_UNMAP_SYNC) So(err, ShouldEqual, nil) }) @@ -90,7 +90,7 @@ func TestFileThreadCounter(t *testing.T) { Convey("Creating a non-existent file should succeed", t, func() { tempFile, err := ioutil.TempFile(os.TempDir(), "TestFileCreate") So(err, ShouldEqual, nil) - Convey("Mapping the file should suceed", func() { + Convey("Mapping the file should succeed", func() { mapping, err := EdfMap(tempFile, EDF_CREATE) So(err, ShouldEqual, nil) Convey("The file should have two threads to start with", func() { diff --git a/base/edf/thread_test.go b/base/edf/thread_test.go index e259a2a..3671dc6 100644 --- a/base/edf/thread_test.go +++ b/base/edf/thread_test.go @@ -36,7 +36,7 @@ func TestThreadFindAndWrite(T *testing.T) { Convey("Creating a non-existent file should succeed", T, func() { tempFile, err := os.OpenFile("hello.db", os.O_RDWR|os.O_TRUNC|os.O_CREATE, 0700) //ioutil.TempFile(os.TempDir(), "TestFileCreate") So(err, ShouldEqual, nil) - Convey("Mapping the file should suceed", func() { + Convey("Mapping the file should succeed", func() { mapping, err := EdfMap(tempFile, EDF_CREATE) So(err, ShouldEqual, nil) Convey("Writing the thread should succeed", func() { From b8e0a36f73ecb501d885768c85f936ed95bb0b3d Mon Sep 17 00:00:00 2001 From: Amit Kumar Gupta Date: Fri, 22 Aug 2014 07:03:49 +0000 Subject: [PATCH 03/11] Remove unused untested private function from chimerge_funcs --- filters/chimerge_funcs.go | 19 ------------------- 1 file changed, 19 deletions(-) diff --git a/filters/chimerge_funcs.go b/filters/chimerge_funcs.go index 4b0e3f3..6b649c1 100644 --- a/filters/chimerge_funcs.go +++ b/filters/chimerge_funcs.go @@ -1,7 +1,6 @@ package filters import ( - "fmt" "github.com/sjwhitworth/golearn/base" "math" ) @@ -185,21 +184,3 @@ func chiMergeMergeZipAdjacent(freq []*FrequencyTableEntry, minIndex int) []*Freq freq = append(lowerSlice, upperSlice...) 
return freq } - -func chiMergePrintTable(freq []*FrequencyTableEntry) { - classes := chiCountClasses(freq) - fmt.Printf("Attribute value\t") - for k := range classes { - fmt.Printf("\t%s", k) - } - fmt.Printf("\tTotal\n") - for _, f := range freq { - fmt.Printf("%.2f\t", f.Value) - total := 0 - for k := range classes { - fmt.Printf("\t%d", f.Frequency[k]) - total += f.Frequency[k] - } - fmt.Printf("\t%d\n", total) - } -} From 25f59e2d6bee53d13ede123c35d8f877d360ceba Mon Sep 17 00:00:00 2001 From: Amit Kumar Gupta Date: Fri, 22 Aug 2014 07:07:13 +0000 Subject: [PATCH 04/11] Remove unused private functions from base/util --- base/util.go | 26 -------------------------- 1 file changed, 26 deletions(-) diff --git a/base/util.go b/base/util.go index fcf0da3..dc1c661 100644 --- a/base/util.go +++ b/base/util.go @@ -1,9 +1,6 @@ package base import ( - "bytes" - "encoding/binary" - "fmt" "math" "unsafe" ) @@ -62,29 +59,6 @@ func UnpackBytesToFloat(val []byte) float64 { return *(*float64)(pb) } -func xorFloatOp(item float64) float64 { - var ret float64 - var tmp int64 - buf := bytes.NewBuffer(nil) - binary.Write(buf, binary.LittleEndian, item) - binary.Read(buf, binary.LittleEndian, &tmp) - tmp ^= -1 << 63 - binary.Write(buf, binary.LittleEndian, tmp) - binary.Read(buf, binary.LittleEndian, &ret) - return ret -} - -func printFloatByteArr(arr [][]byte) { - buf := bytes.NewBuffer(nil) - var f float64 - for _, b := range arr { - buf.Write(b) - binary.Read(buf, binary.LittleEndian, &f) - f = xorFloatOp(f) - fmt.Println(f) - } -} - func byteSeqEqual(a, b []byte) bool { if len(a) != len(b) { return false From 8f5a5f49623c813140aa65b88e5c69bc81e1ec11 Mon Sep 17 00:00:00 2001 From: Amit Kumar Gupta Date: Fri, 22 Aug 2014 07:09:16 +0000 Subject: [PATCH 05/11] Add headings and improve formatting of ConfusionMatrix GetSummary --- evaluation/confusion.go | 14 ++++++++++---- 1 file changed, 10 insertions(+), 4 deletions(-) diff --git a/evaluation/confusion.go b/evaluation/confusion.go index 87fe2b1..89eb880 100644 --- a/evaluation/confusion.go +++ b/evaluation/confusion.go @@ -3,6 +3,8 @@ package evaluation import ( "bytes" "fmt" + "text/tabwriter" + "github.com/sjwhitworth/golearn/base" ) @@ -176,18 +178,22 @@ func GetMacroRecall(c ConfusionMatrix) float64 { // ConfusionMatrix func GetSummary(c ConfusionMatrix) string { var buffer bytes.Buffer + w := new(tabwriter.Writer) + w.Init(&buffer, 0, 8, 0, '\t', 0) + + fmt.Fprintln(w, "Reference Class\tTrue Positives\tFalse Positives\tTrue Negatives\tPrecision\tRecall\tF1 Score") + fmt.Fprintln(w, "---------------\t--------------\t---------------\t--------------\t---------\t------\t--------") for k := range c { - buffer.WriteString(k) - buffer.WriteString("\t") tp := GetTruePositives(k, c) fp := GetFalsePositives(k, c) tn := GetTrueNegatives(k, c) prec := GetPrecision(k, c) rec := GetRecall(k, c) f1 := GetF1Score(k, c) - buffer.WriteString(fmt.Sprintf("%.0f\t%.0f\t%.0f\t%.4f\t%.4f\t%.4f\n", tp, fp, tn, prec, rec, f1)) - } + fmt.Fprintf(w, "%s\t%.0f\t%.0f\t%.0f\t%.4f\t%.4f\t%.4f\n", k, tp, fp, tn, prec, rec, f1) + } + w.Flush() buffer.WriteString(fmt.Sprintf("Overall accuracy: %.4f\n", GetAccuracy(c))) return buffer.String() From 21bb2fc9fa49f1dca01fcd4156f951dc872bc18a Mon Sep 17 00:00:00 2001 From: Amit Kumar Gupta Date: Fri, 22 Aug 2014 07:21:24 +0000 Subject: [PATCH 06/11] Remove redundant import renames --- ensemble/randomforest.go | 6 +++--- ensemble/randomforest_test.go | 4 ++-- examples/instances/instances.go | 2 +- examples/knnclassifier/knnclassifier_iris.go 
| 6 +++--- examples/trees/trees.go | 8 ++++---- filters/binning.go | 2 +- filters/binning_test.go | 2 +- filters/chimerge.go | 2 +- filters/chimerge_test.go | 2 +- filters/disc.go | 2 +- knn/knn.go | 2 +- linear_models/logistic.go | 2 +- meta/bagging.go | 2 +- meta/bagging_test.go | 6 +++--- naive/bernoulli_nb.go | 2 +- trees/entropy.go | 2 +- trees/id3.go | 2 +- trees/random.go | 2 +- trees/tree_test.go | 4 ++-- 19 files changed, 30 insertions(+), 30 deletions(-) diff --git a/ensemble/randomforest.go b/ensemble/randomforest.go index cb3c374..3cdd596 100644 --- a/ensemble/randomforest.go +++ b/ensemble/randomforest.go @@ -2,9 +2,9 @@ package ensemble import ( "fmt" - base "github.com/sjwhitworth/golearn/base" - meta "github.com/sjwhitworth/golearn/meta" - trees "github.com/sjwhitworth/golearn/trees" + "github.com/sjwhitworth/golearn/base" + "github.com/sjwhitworth/golearn/meta" + "github.com/sjwhitworth/golearn/trees" ) // RandomForest classifies instances using an ensemble diff --git a/ensemble/randomforest_test.go b/ensemble/randomforest_test.go index 1524cc1..32463ac 100644 --- a/ensemble/randomforest_test.go +++ b/ensemble/randomforest_test.go @@ -2,9 +2,9 @@ package ensemble import ( "fmt" - base "github.com/sjwhitworth/golearn/base" + "github.com/sjwhitworth/golearn/base" eval "github.com/sjwhitworth/golearn/evaluation" - filters "github.com/sjwhitworth/golearn/filters" + "github.com/sjwhitworth/golearn/filters" "testing" ) diff --git a/examples/instances/instances.go b/examples/instances/instances.go index 29b1120..c6e8627 100644 --- a/examples/instances/instances.go +++ b/examples/instances/instances.go @@ -4,7 +4,7 @@ package main import ( "fmt" - base "github.com/sjwhitworth/golearn/base" + "github.com/sjwhitworth/golearn/base" ) func main() { diff --git a/examples/knnclassifier/knnclassifier_iris.go b/examples/knnclassifier/knnclassifier_iris.go index 8aa6f9a..f713e9a 100644 --- a/examples/knnclassifier/knnclassifier_iris.go +++ b/examples/knnclassifier/knnclassifier_iris.go @@ -2,9 +2,9 @@ package main import ( "fmt" - base "github.com/sjwhitworth/golearn/base" - evaluation "github.com/sjwhitworth/golearn/evaluation" - knn "github.com/sjwhitworth/golearn/knn" + "github.com/sjwhitworth/golearn/base" + "github.com/sjwhitworth/golearn/evaluation" + "github.com/sjwhitworth/golearn/knn" ) func main() { diff --git a/examples/trees/trees.go b/examples/trees/trees.go index 893d875..73116ae 100644 --- a/examples/trees/trees.go +++ b/examples/trees/trees.go @@ -4,11 +4,11 @@ package main import ( "fmt" - base "github.com/sjwhitworth/golearn/base" - ensemble "github.com/sjwhitworth/golearn/ensemble" + "github.com/sjwhitworth/golearn/base" + "github.com/sjwhitworth/golearn/ensemble" eval "github.com/sjwhitworth/golearn/evaluation" - filters "github.com/sjwhitworth/golearn/filters" - trees "github.com/sjwhitworth/golearn/trees" + "github.com/sjwhitworth/golearn/filters" + "github.com/sjwhitworth/golearn/trees" "math/rand" "time" ) diff --git a/filters/binning.go b/filters/binning.go index ee2a650..d11b996 100644 --- a/filters/binning.go +++ b/filters/binning.go @@ -2,7 +2,7 @@ package filters import ( "fmt" - base "github.com/sjwhitworth/golearn/base" + "github.com/sjwhitworth/golearn/base" "math" ) diff --git a/filters/binning_test.go b/filters/binning_test.go index b706a4c..f655717 100644 --- a/filters/binning_test.go +++ b/filters/binning_test.go @@ -1,7 +1,7 @@ package filters import ( - base "github.com/sjwhitworth/golearn/base" + "github.com/sjwhitworth/golearn/base" . 
"github.com/smartystreets/goconvey/convey" "testing" ) diff --git a/filters/chimerge.go b/filters/chimerge.go index 7f75ffc..826e8d0 100644 --- a/filters/chimerge.go +++ b/filters/chimerge.go @@ -2,7 +2,7 @@ package filters import ( "fmt" - base "github.com/sjwhitworth/golearn/base" + "github.com/sjwhitworth/golearn/base" "math" ) diff --git a/filters/chimerge_test.go b/filters/chimerge_test.go index 94a4615..eeb4f7a 100644 --- a/filters/chimerge_test.go +++ b/filters/chimerge_test.go @@ -2,7 +2,7 @@ package filters import ( "fmt" - base "github.com/sjwhitworth/golearn/base" + "github.com/sjwhitworth/golearn/base" "math" "testing" ) diff --git a/filters/disc.go b/filters/disc.go index fa3f599..940829e 100644 --- a/filters/disc.go +++ b/filters/disc.go @@ -2,7 +2,7 @@ package filters import ( "fmt" - base "github.com/sjwhitworth/golearn/base" + "github.com/sjwhitworth/golearn/base" ) type AbstractDiscretizeFilter struct { diff --git a/knn/knn.go b/knn/knn.go index 2b7a2b5..16863b8 100644 --- a/knn/knn.go +++ b/knn/knn.go @@ -5,7 +5,7 @@ package knn import ( "github.com/gonum/matrix/mat64" - base "github.com/sjwhitworth/golearn/base" + "github.com/sjwhitworth/golearn/base" pairwiseMetrics "github.com/sjwhitworth/golearn/metrics/pairwise" util "github.com/sjwhitworth/golearn/utilities" ) diff --git a/linear_models/logistic.go b/linear_models/logistic.go index d3af914..56199ec 100644 --- a/linear_models/logistic.go +++ b/linear_models/logistic.go @@ -2,7 +2,7 @@ package linear_models import ( "fmt" - base "github.com/sjwhitworth/golearn/base" + "github.com/sjwhitworth/golearn/base" ) type LogisticRegression struct { diff --git a/meta/bagging.go b/meta/bagging.go index 8d4b94f..60414ad 100644 --- a/meta/bagging.go +++ b/meta/bagging.go @@ -2,7 +2,7 @@ package meta import ( "fmt" - base "github.com/sjwhitworth/golearn/base" + "github.com/sjwhitworth/golearn/base" "math/rand" "runtime" "strings" diff --git a/meta/bagging_test.go b/meta/bagging_test.go index 3b3f3ed..91bc2c3 100644 --- a/meta/bagging_test.go +++ b/meta/bagging_test.go @@ -2,10 +2,10 @@ package meta import ( "fmt" - base "github.com/sjwhitworth/golearn/base" + "github.com/sjwhitworth/golearn/base" eval "github.com/sjwhitworth/golearn/evaluation" - filters "github.com/sjwhitworth/golearn/filters" - trees "github.com/sjwhitworth/golearn/trees" + "github.com/sjwhitworth/golearn/filters" + "github.com/sjwhitworth/golearn/trees" "math/rand" "testing" "time" diff --git a/naive/bernoulli_nb.go b/naive/bernoulli_nb.go index f995d1e..cbf7a08 100644 --- a/naive/bernoulli_nb.go +++ b/naive/bernoulli_nb.go @@ -2,7 +2,7 @@ package naive import ( "fmt" - base "github.com/sjwhitworth/golearn/base" + "github.com/sjwhitworth/golearn/base" "math" ) diff --git a/trees/entropy.go b/trees/entropy.go index 1d7d254..9339fcd 100644 --- a/trees/entropy.go +++ b/trees/entropy.go @@ -1,7 +1,7 @@ package trees import ( - base "github.com/sjwhitworth/golearn/base" + "github.com/sjwhitworth/golearn/base" "math" ) diff --git a/trees/id3.go b/trees/id3.go index bfa9904..ef7188c 100644 --- a/trees/id3.go +++ b/trees/id3.go @@ -3,7 +3,7 @@ package trees import ( "bytes" "fmt" - base "github.com/sjwhitworth/golearn/base" + "github.com/sjwhitworth/golearn/base" eval "github.com/sjwhitworth/golearn/evaluation" "sort" ) diff --git a/trees/random.go b/trees/random.go index 9b6d1d5..2d1f92e 100644 --- a/trees/random.go +++ b/trees/random.go @@ -2,7 +2,7 @@ package trees import ( "fmt" - base "github.com/sjwhitworth/golearn/base" + "github.com/sjwhitworth/golearn/base" 
"math/rand" ) diff --git a/trees/tree_test.go b/trees/tree_test.go index 26b643f..7b4a31d 100644 --- a/trees/tree_test.go +++ b/trees/tree_test.go @@ -2,9 +2,9 @@ package trees import ( "fmt" - base "github.com/sjwhitworth/golearn/base" + "github.com/sjwhitworth/golearn/base" eval "github.com/sjwhitworth/golearn/evaluation" - filters "github.com/sjwhitworth/golearn/filters" + "github.com/sjwhitworth/golearn/filters" "math" "testing" ) From d835081de9c0bf632bb69b341b7bbfd53f13f6a2 Mon Sep 17 00:00:00 2001 From: Amit Kumar Gupta Date: Fri, 22 Aug 2014 07:27:16 +0000 Subject: [PATCH 07/11] Favor idiomatic error return over panic when parsing non-existent CSV file --- base/csv.go | 18 +++++++++++------- base/csv_test.go | 27 ++++++++++++++++++++++++--- 2 files changed, 35 insertions(+), 10 deletions(-) diff --git a/base/csv.go b/base/csv.go index 868db83..307cfc2 100644 --- a/base/csv.go +++ b/base/csv.go @@ -11,10 +11,10 @@ import ( ) // ParseCSVGetRows returns the number of rows in a given file. -func ParseCSVGetRows(filepath string) int { +func ParseCSVGetRows(filepath string) (int, error) { file, err := os.Open(filepath) if err != nil { - panic(err) + return 0, err } defer file.Close() @@ -25,11 +25,11 @@ func ParseCSVGetRows(filepath string) int { if err == io.EOF { break } else if err != nil { - panic(err) + return 0, err } counter++ } - return counter + return counter, nil } // ParseCSVGetAttributes returns an ordered slice of appropriate-ly typed @@ -157,7 +157,11 @@ func ParseCSVBuildInstances(filepath string, hasHeaders bool, u UpdatableDataGri func ParseCSVToInstances(filepath string, hasHeaders bool) (instances *DenseInstances, err error) { // Read the number of rows in the file - rowCount := ParseCSVGetRows(filepath) + rowCount, err := ParseCSVGetRows(filepath) + if err != nil { + return nil, err + } + if hasHeaders { rowCount-- } @@ -176,7 +180,7 @@ func ParseCSVToInstances(filepath string, hasHeaders bool) (instances *DenseInst // Read the input file, err := os.Open(filepath) if err != nil { - panic(err) + return nil, err } defer file.Close() reader := csv.NewReader(file) @@ -188,7 +192,7 @@ func ParseCSVToInstances(filepath string, hasHeaders bool) (instances *DenseInst if err == io.EOF { break } else if err != nil { - panic(err) + return nil, err } if rowCounter == 0 { if hasHeaders { diff --git a/base/csv_test.go b/base/csv_test.go index 95509de..c7b6907 100644 --- a/base/csv_test.go +++ b/base/csv_test.go @@ -5,18 +5,32 @@ import ( ) func TestParseCSVGetRows(testEnv *testing.T) { - lineCount := ParseCSVGetRows("../examples/datasets/iris.csv") + lineCount, err := ParseCSVGetRows("../examples/datasets/iris.csv") + if err != nil { + testEnv.Fatalf("Unable to parse CSV to get number of rows: %s", err.Error()) + } if lineCount != 150 { testEnv.Errorf("Should have %d lines, has %d", 150, lineCount) } - lineCount = ParseCSVGetRows("../examples/datasets/iris_headers.csv") + lineCount, err = ParseCSVGetRows("../examples/datasets/iris_headers.csv") + if err != nil { + testEnv.Fatalf("Unable to parse CSV to get number of rows: %s", err.Error()) + } + if lineCount != 151 { testEnv.Errorf("Should have %d lines, has %d", 151, lineCount) } } +func TestParseCSVGetRowsWithMissingFile(testEnv *testing.T) { + _, err := ParseCSVGetRows("../examples/datasets/non-existent.csv") + if err == nil { + testEnv.Fatal("Expected ParseCSVGetRows to return error when given path to non-existent file") + } +} + func TestParseCCSVGetAttributes(testEnv *testing.T) { attrs := 
ParseCSVGetAttributes("../examples/datasets/iris_headers.csv", true) if attrs[0].GetType() != Float64Type { @@ -72,7 +86,7 @@ func TestParseCSVSniffAttributeNamesWithHeaders(testEnv *testing.T) { } } -func TestReadInstances(testEnv *testing.T) { +func TestParseCSVToInstances(testEnv *testing.T) { inst, err := ParseCSVToInstances("../examples/datasets/iris_headers.csv", true) if err != nil { testEnv.Error(err) @@ -93,6 +107,13 @@ func TestReadInstances(testEnv *testing.T) { } } +func TestParseCSVToInstancesWithMissingFile(testEnv *testing.T) { + _, err := ParseCSVToInstances("../examples/datasets/non-existent.csv", true) + if err == nil { + testEnv.Fatal("Expected ParseCSVToInstances to return error when given path to non-existent file") + } +} + func TestReadAwkwardInsatnces(testEnv *testing.T) { inst, err := ParseCSVToInstances("../examples/datasets/chim.csv", true) if err != nil { From 66ad866cb33659d46fa563293bee3276698133c1 Mon Sep 17 00:00:00 2001 From: Amit Kumar Gupta Date: Fri, 22 Aug 2014 07:39:14 +0000 Subject: [PATCH 08/11] RandomForest panics when trying to fit data with too few features instead of just hanging forever --- ensemble/randomforest.go | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/ensemble/randomforest.go b/ensemble/randomforest.go index 3cdd596..785e8ec 100644 --- a/ensemble/randomforest.go +++ b/ensemble/randomforest.go @@ -31,6 +31,15 @@ func NewRandomForest(forestSize int, features int) *RandomForest { // Fit builds the RandomForest on the specified instances func (f *RandomForest) Fit(on base.FixedDataGrid) { + numNonClassAttributes := len(base.NonClassAttributes(on)) + if numNonClassAttributes < f.Features { + panic(fmt.Sprintf( + "Random forest with %d features cannot fit data grid with %d non-class attributes", + f.Features, + numNonClassAttributes, + )) + } + f.Model = new(meta.BaggedModel) f.Model.RandomFeatures = f.Features for i := 0; i < f.ForestSize; i++ { From 45545d6ebd90d0c24f0367ec943b73f42d2adc45 Mon Sep 17 00:00:00 2001 From: Amit Kumar Gupta Date: Fri, 22 Aug 2014 07:58:01 +0000 Subject: [PATCH 09/11] Remove Println's from automated test suite since they aren't assertions --- base/lazy_sort_test.go | 3 +- ensemble/randomforest_test.go | 5 +-- filters/binary_test.go | 6 +-- filters/chimerge_test.go | 6 +-- filters/float_test.go | 3 +- linear_models/linear_regression_test.go | 7 +--- meta/bagging_test.go | 16 +++---- neural/layered_test.go | 8 ---- neural/network_test.go | 2 - trees/tree_test.go | 55 ++++++++----------------- 10 files changed, 31 insertions(+), 80 deletions(-) diff --git a/base/lazy_sort_test.go b/base/lazy_sort_test.go index 5f5e23f..043d802 100644 --- a/base/lazy_sort_test.go +++ b/base/lazy_sort_test.go @@ -1,7 +1,6 @@ package base import ( - "fmt" "testing" ) @@ -81,7 +80,7 @@ func TestLazySortAsc(testEnv *testing.T) { rowStr := insts.RowString(0) ref := "4.30 3.00 1.10 0.10 Iris-setosa" if rowStr != ref { - panic(fmt.Sprintf("'%s' != '%s'", rowStr, ref)) + testEnv.Fatalf("'%s' != '%s'", rowStr, ref) } } diff --git a/ensemble/randomforest_test.go b/ensemble/randomforest_test.go index 32463ac..762125d 100644 --- a/ensemble/randomforest_test.go +++ b/ensemble/randomforest_test.go @@ -1,7 +1,6 @@ package ensemble import ( - "fmt" "github.com/sjwhitworth/golearn/base" eval "github.com/sjwhitworth/golearn/evaluation" "github.com/sjwhitworth/golearn/filters" @@ -26,8 +25,6 @@ func TestRandomForest1(testEnv *testing.T) { rf := NewRandomForest(10, 3) rf.Fit(trainData) predictions := rf.Predict(testData) - 
fmt.Println(predictions) confusionMat := eval.GetConfusionMatrix(testData, predictions) - fmt.Println(confusionMat) - fmt.Println(eval.GetSummary(confusionMat)) + _ = eval.GetSummary(confusionMat) } diff --git a/filters/binary_test.go b/filters/binary_test.go index 1b875c5..c70a544 100644 --- a/filters/binary_test.go +++ b/filters/binary_test.go @@ -1,7 +1,6 @@ package filters import ( - "fmt" "github.com/sjwhitworth/golearn/base" . "github.com/smartystreets/goconvey/convey" "testing" @@ -38,9 +37,6 @@ func TestBinaryFilterClassPreservation(t *testing.T) { So(attrMap["arbitraryClass_there"], ShouldEqual, true) So(attrMap["arbitraryClass_world"], ShouldEqual, true) }) - - fmt.Println(instF) - }) } @@ -91,7 +87,7 @@ func TestBinaryFilter(t *testing.T) { name := a.GetName() _, ok := origMap[name] if !ok { - t.Error(fmt.Sprintf("Weird: %s", name)) + t.Errorf("Weird: %s", name) } origMap[name] = true } diff --git a/filters/chimerge_test.go b/filters/chimerge_test.go index eeb4f7a..72ae5be 100644 --- a/filters/chimerge_test.go +++ b/filters/chimerge_test.go @@ -1,14 +1,12 @@ package filters import ( - "fmt" "github.com/sjwhitworth/golearn/base" "math" "testing" ) func TestChiMFreqTable(testEnv *testing.T) { - inst, err := base.ParseCSVToInstances("../examples/datasets/chim.csv", true) if err != nil { panic(err) @@ -189,11 +187,9 @@ func TestChiMerge4(testEnv *testing.T) { filt.AddAttribute(inst.AllAttributes()[1]) filt.Train() instf := base.NewLazilyFilteredInstances(inst, filt) - fmt.Println(instf) - fmt.Println(instf.String()) clsAttrs := instf.AllClassAttributes() if len(clsAttrs) != 1 { - panic(fmt.Sprintf("%d != %d", len(clsAttrs), 1)) + testEnv.Fatalf("%d != %d", len(clsAttrs), 1) } if clsAttrs[0].GetName() != "Species" { panic("Class Attribute wrong!") diff --git a/filters/float_test.go b/filters/float_test.go index a4906e4..3c1f67d 100644 --- a/filters/float_test.go +++ b/filters/float_test.go @@ -1,7 +1,6 @@ package filters import ( - "fmt" "github.com/sjwhitworth/golearn/base" . 
"github.com/smartystreets/goconvey/convey" "testing" @@ -54,7 +53,7 @@ func TestFloatFilter(t *testing.T) { name := a.GetName() _, ok := origMap[name] if !ok { - t.Error(fmt.Sprintf("Weird: %s", name)) + t.Errorf("Weird: %s", name) } origMap[name] = true } diff --git a/linear_models/linear_regression_test.go b/linear_models/linear_regression_test.go index 2c6a6dc..e302715 100644 --- a/linear_models/linear_regression_test.go +++ b/linear_models/linear_regression_test.go @@ -1,7 +1,6 @@ package linear_models import ( - "fmt" "testing" "github.com/sjwhitworth/golearn/base" @@ -54,11 +53,7 @@ func TestLinearRegression(t *testing.T) { t.Fatal(err) } - _, rows := predictions.Size() - - for i := 0; i < rows; i++ { - fmt.Printf("Expected: %s || Predicted: %s\n", base.GetClass(testData, i), base.GetClass(predictions, i)) - } + _, _ = predictions.Size() } func BenchmarkLinearRegressionOneRow(b *testing.B) { diff --git a/meta/bagging_test.go b/meta/bagging_test.go index 91bc2c3..651f4f8 100644 --- a/meta/bagging_test.go +++ b/meta/bagging_test.go @@ -1,7 +1,6 @@ package meta import ( - "fmt" "github.com/sjwhitworth/golearn/base" eval "github.com/sjwhitworth/golearn/evaluation" "github.com/sjwhitworth/golearn/filters" @@ -24,10 +23,12 @@ func BenchmarkBaggingRandomForestFit(testEnv *testing.B) { } filt.Train() instf := base.NewLazilyFilteredInstances(inst, filt) + rf := new(BaggedModel) for i := 0; i < 10; i++ { rf.AddModel(trees.NewRandomTree(2)) } + testEnv.ResetTimer() for i := 0; i < 20; i++ { rf.Fit(instf) @@ -47,10 +48,12 @@ func BenchmarkBaggingRandomForestPredict(testEnv *testing.B) { } filt.Train() instf := base.NewLazilyFilteredInstances(inst, filt) + rf := new(BaggedModel) for i := 0; i < 10; i++ { rf.AddModel(trees.NewRandomTree(2)) } + rf.Fit(instf) testEnv.ResetTimer() for i := 0; i < 20; i++ { @@ -63,9 +66,9 @@ func TestRandomForest1(testEnv *testing.T) { if err != nil { panic(err) } + trainData, testData := base.InstancesTrainTestSplit(inst, 0.6) rand.Seed(time.Now().UnixNano()) - trainData, testData := base.InstancesTrainTestSplit(inst, 0.6) filt := filters.NewChiMergeFilter(inst, 0.90) for _, a := range base.NonClassFloatAttributes(inst) { filt.AddAttribute(a) @@ -73,17 +76,14 @@ func TestRandomForest1(testEnv *testing.T) { filt.Train() trainDataf := base.NewLazilyFilteredInstances(trainData, filt) testDataf := base.NewLazilyFilteredInstances(testData, filt) + rf := new(BaggedModel) for i := 0; i < 10; i++ { rf.AddModel(trees.NewRandomTree(2)) } + rf.Fit(trainDataf) - fmt.Println(rf) predictions := rf.Predict(testDataf) - fmt.Println(predictions) confusionMat := eval.GetConfusionMatrix(testDataf, predictions) - fmt.Println(confusionMat) - fmt.Println(eval.GetMacroPrecision(confusionMat)) - fmt.Println(eval.GetMacroRecall(confusionMat)) - fmt.Println(eval.GetSummary(confusionMat)) + _ = eval.GetSummary(confusionMat) } diff --git a/neural/layered_test.go b/neural/layered_test.go index 6fdeb72..e9c52f6 100644 --- a/neural/layered_test.go +++ b/neural/layered_test.go @@ -1,7 +1,6 @@ package neural import ( - "fmt" "github.com/gonum/matrix/mat64" "github.com/sjwhitworth/golearn/base" . 
"github.com/smartystreets/goconvey/convey" @@ -13,7 +12,6 @@ func TestLayerStructureNoHidden(t *testing.T) { Convey("Creating a network...", t, func() { XORData, err := base.ParseCSVToInstances("xor.csv", false) So(err, ShouldEqual, nil) - fmt.Println(XORData) Convey("Create a MultiLayerNet with no layers...", func() { net := NewMultiLayerNet(make([]int, 0)) net.MaxIterations = 0 @@ -73,8 +71,6 @@ func TestLayerStructureNoHidden(t *testing.T) { }) Convey("The right nodes should be connected in the network...", func() { - - fmt.Println(net.network) So(net.network.GetWeight(1, 1), ShouldAlmostEqual, 1.000) So(net.network.GetWeight(2, 2), ShouldAlmostEqual, 1.000) So(net.network.GetWeight(1, 3), ShouldNotAlmostEqual, 0.000) @@ -118,7 +114,6 @@ func TestLayeredXOR(t *testing.T) { XORData, err := base.ParseCSVToInstances("xor.csv", false) So(err, ShouldEqual, nil) - fmt.Println(XORData) net := NewMultiLayerNet([]int{3}) net.MaxIterations = 20000 net.Fit(XORData) @@ -126,8 +121,6 @@ func TestLayeredXOR(t *testing.T) { Convey("After running for 20000 iterations, should have some predictive power...", func() { Convey("The right nodes should be connected in the network...", func() { - - fmt.Println(net.network) So(net.network.GetWeight(1, 1), ShouldAlmostEqual, 1.000) So(net.network.GetWeight(2, 2), ShouldAlmostEqual, 1.000) @@ -138,7 +131,6 @@ func TestLayeredXOR(t *testing.T) { }) out := mat64.NewDense(6, 1, []float64{1.0, 0.0, 0.0, 0.0, 0.0, 0.0}) net.network.Activate(out, 2) - fmt.Println(out) So(out.At(5, 0), ShouldAlmostEqual, 1.0, 0.1) Convey("And Predict() should do OK too...", func() { diff --git a/neural/network_test.go b/neural/network_test.go index 2043f67..fee5a29 100644 --- a/neural/network_test.go +++ b/neural/network_test.go @@ -1,7 +1,6 @@ package neural import ( - "fmt" "github.com/gonum/matrix/mat64" . 
"github.com/smartystreets/goconvey/convey" "testing" @@ -61,7 +60,6 @@ func TestNetworkWith1Layer(t *testing.T) { for i := 1; i <= 6; i++ { for j := 1; j <= 6; j++ { v := n.GetWeight(i, j) - fmt.Println(i, j, v) switch i { case 1: switch j { diff --git a/trees/tree_test.go b/trees/tree_test.go index 7b4a31d..1912177 100644 --- a/trees/tree_test.go +++ b/trees/tree_test.go @@ -1,7 +1,6 @@ package trees import ( - "fmt" "github.com/sjwhitworth/golearn/base" eval "github.com/sjwhitworth/golearn/evaluation" "github.com/sjwhitworth/golearn/filters" @@ -14,6 +13,7 @@ func TestRandomTree(testEnv *testing.T) { if err != nil { panic(err) } + filt := filters.NewChiMergeFilter(inst, 0.90) for _, a := range base.NonClassFloatAttributes(inst) { filt.AddAttribute(a) @@ -23,9 +23,8 @@ func TestRandomTree(testEnv *testing.T) { r := new(RandomTreeRuleGenerator) r.Attributes = 2 - fmt.Println(instf) - root := InferID3Tree(instf, r) - fmt.Println(root) + + _ = InferID3Tree(instf, r) } func TestRandomTreeClassification(testEnv *testing.T) { @@ -34,6 +33,7 @@ func TestRandomTreeClassification(testEnv *testing.T) { panic(err) } trainData, testData := base.InstancesTrainTestSplit(inst, 0.6) + filt := filters.NewChiMergeFilter(inst, 0.90) for _, a := range base.NonClassFloatAttributes(inst) { filt.AddAttribute(a) @@ -44,15 +44,12 @@ func TestRandomTreeClassification(testEnv *testing.T) { r := new(RandomTreeRuleGenerator) r.Attributes = 2 + root := InferID3Tree(trainDataF, r) - fmt.Println(root) + predictions := root.Predict(testDataF) - fmt.Println(predictions) confusionMat := eval.GetConfusionMatrix(testDataF, predictions) - fmt.Println(confusionMat) - fmt.Println(eval.GetMacroPrecision(confusionMat)) - fmt.Println(eval.GetMacroRecall(confusionMat)) - fmt.Println(eval.GetSummary(confusionMat)) + _ = eval.GetSummary(confusionMat) } func TestRandomTreeClassification2(testEnv *testing.T) { @@ -61,6 +58,7 @@ func TestRandomTreeClassification2(testEnv *testing.T) { panic(err) } trainData, testData := base.InstancesTrainTestSplit(inst, 0.4) + filt := filters.NewChiMergeFilter(inst, 0.90) for _, a := range base.NonClassFloatAttributes(inst) { filt.AddAttribute(a) @@ -71,14 +69,10 @@ func TestRandomTreeClassification2(testEnv *testing.T) { root := NewRandomTree(2) root.Fit(trainDataF) - fmt.Println(root) + predictions := root.Predict(testDataF) - fmt.Println(predictions) confusionMat := eval.GetConfusionMatrix(testDataF, predictions) - fmt.Println(confusionMat) - fmt.Println(eval.GetMacroPrecision(confusionMat)) - fmt.Println(eval.GetMacroRecall(confusionMat)) - fmt.Println(eval.GetSummary(confusionMat)) + _ = eval.GetSummary(confusionMat) } func TestPruning(testEnv *testing.T) { @@ -87,6 +81,7 @@ func TestPruning(testEnv *testing.T) { panic(err) } trainData, testData := base.InstancesTrainTestSplit(inst, 0.6) + filt := filters.NewChiMergeFilter(inst, 0.90) for _, a := range base.NonClassFloatAttributes(inst) { filt.AddAttribute(a) @@ -99,14 +94,10 @@ func TestPruning(testEnv *testing.T) { fittrainData, fittestData := base.InstancesTrainTestSplit(trainDataF, 0.6) root.Fit(fittrainData) root.Prune(fittestData) - fmt.Println(root) + predictions := root.Predict(testDataF) - fmt.Println(predictions) confusionMat := eval.GetConfusionMatrix(testDataF, predictions) - fmt.Println(confusionMat) - fmt.Println(eval.GetMacroPrecision(confusionMat)) - fmt.Println(eval.GetMacroRecall(confusionMat)) - fmt.Println(eval.GetSummary(confusionMat)) + _ = eval.GetSummary(confusionMat) } func TestInformationGain(testEnv *testing.T) { @@ 
-127,8 +118,6 @@ func TestInformationGain(testEnv *testing.T) { } func TestID3Inference(testEnv *testing.T) { - - // Import the "PlayTennis" dataset inst, err := base.ParseCSVToInstances("../examples/datasets/tennis.csv", true) if err != nil { panic(err) @@ -150,7 +139,6 @@ func TestID3Inference(testEnv *testing.T) { testEnv.Error(sunnyChild) } if rainyChild.SplitAttr.GetName() != "windy" { - fmt.Println(rainyChild.SplitAttr) testEnv.Error(rainyChild) } if overcastChild.SplitAttr != nil { @@ -184,36 +172,27 @@ func TestID3Classification(testEnv *testing.T) { if err != nil { panic(err) } - fmt.Println(inst) + filt := filters.NewBinningFilter(inst, 10) for _, a := range base.NonClassFloatAttributes(inst) { filt.AddAttribute(a) } filt.Train() - fmt.Println(filt) instf := base.NewLazilyFilteredInstances(inst, filt) - fmt.Println("INSTFA", instf.AllAttributes()) - fmt.Println("INSTF", instf) + trainData, testData := base.InstancesTrainTestSplit(instf, 0.70) // Build the decision tree rule := new(InformationGainRuleGenerator) root := InferID3Tree(trainData, rule) - fmt.Println(root) + predictions := root.Predict(testData) - fmt.Println(predictions) confusionMat := eval.GetConfusionMatrix(testData, predictions) - fmt.Println(confusionMat) - fmt.Println(eval.GetMacroPrecision(confusionMat)) - fmt.Println(eval.GetMacroRecall(confusionMat)) - fmt.Println(eval.GetSummary(confusionMat)) + _ = eval.GetSummary(confusionMat) } func TestID3(testEnv *testing.T) { - - // Import the "PlayTennis" dataset inst, err := base.ParseCSVToInstances("../examples/datasets/tennis.csv", true) - fmt.Println(inst) if err != nil { panic(err) } From 695aec6eb68a113cd59aa9a81a2565e0073ccd39 Mon Sep 17 00:00:00 2001 From: Amit Kumar Gupta Date: Fri, 22 Aug 2014 08:07:55 +0000 Subject: [PATCH 10/11] Favor idiomatic t.Fatalf over panic for test failures --- base/bag_test.go | 4 ++-- ensemble/randomforest_test.go | 2 +- filters/binning_test.go | 4 ++-- filters/chimerge_test.go | 22 +++++++++++----------- meta/bagging_test.go | 6 +++--- neural/layered_test.go | 2 +- trees/tree_test.go | 14 +++++++------- 7 files changed, 27 insertions(+), 27 deletions(-) diff --git a/base/bag_test.go b/base/bag_test.go index f5366ae..485d79c 100644 --- a/base/bag_test.go +++ b/base/bag_test.go @@ -35,7 +35,7 @@ func TestBAGSimple(t *testing.T) { } else if name == "2" { attrSpecs[2] = a } else { - panic(name) + t.Fatalf("Unexpected attribute name '%s'", name) } } @@ -102,7 +102,7 @@ func TestBAG(t *testing.T) { } else if name == "2" { attrSpecs[2] = a } else { - panic(name) + t.Fatalf("Unexpected attribute name '%s'", name) } } diff --git a/ensemble/randomforest_test.go b/ensemble/randomforest_test.go index 762125d..3828e8c 100644 --- a/ensemble/randomforest_test.go +++ b/ensemble/randomforest_test.go @@ -10,7 +10,7 @@ import ( func TestRandomForest1(testEnv *testing.T) { inst, err := base.ParseCSVToInstances("../examples/datasets/iris_headers.csv", true) if err != nil { - panic(err) + testEnv.Fatal("Unable to parse CSV to instances: %s", err.Error()) } filt := filters.NewChiMergeFilter(inst, 0.90) diff --git a/filters/binning_test.go b/filters/binning_test.go index f655717..dfa2518 100644 --- a/filters/binning_test.go +++ b/filters/binning_test.go @@ -11,12 +11,12 @@ func TestBinning(t *testing.T) { // Read the data inst1, err := base.ParseCSVToInstances("../examples/datasets/iris_headers.csv", true) if err != nil { - panic(err) + t.Fatal("Unable to parse CSV to instances: %s", err.Error()) } inst2, err := 
base.ParseCSVToInstances("../examples/datasets/iris_binned.csv", true) if err != nil { - panic(err) + t.Fatal("Unable to parse CSV to instances: %s", err.Error()) } // // Construct the binning filter diff --git a/filters/chimerge_test.go b/filters/chimerge_test.go index 72ae5be..677d5aa 100644 --- a/filters/chimerge_test.go +++ b/filters/chimerge_test.go @@ -9,7 +9,7 @@ import ( func TestChiMFreqTable(testEnv *testing.T) { inst, err := base.ParseCSVToInstances("../examples/datasets/chim.csv", true) if err != nil { - panic(err) + testEnv.Fatal("Unable to parse CSV to instances: %s", err.Error()) } freq := ChiMBuildFrequencyTable(inst.AllAttributes()[0], inst) @@ -28,7 +28,7 @@ func TestChiMFreqTable(testEnv *testing.T) { func TestChiClassCounter(testEnv *testing.T) { inst, err := base.ParseCSVToInstances("../examples/datasets/chim.csv", true) if err != nil { - panic(err) + testEnv.Fatal("Unable to parse CSV to instances: %s", err.Error()) } freq := ChiMBuildFrequencyTable(inst.AllAttributes()[0], inst) classes := chiCountClasses(freq) @@ -46,7 +46,7 @@ func TestChiClassCounter(testEnv *testing.T) { func TestStatisticValues(testEnv *testing.T) { inst, err := base.ParseCSVToInstances("../examples/datasets/chim.csv", true) if err != nil { - panic(err) + testEnv.Fatal("Unable to parse CSV to instances: %s", err.Error()) } freq := ChiMBuildFrequencyTable(inst.AllAttributes()[0], inst) chiVal := chiComputeStatistic(freq[5], freq[6]) @@ -76,11 +76,9 @@ func TestChiSquareDistValues(testEnv *testing.T) { } func TestChiMerge1(testEnv *testing.T) { - - // Read the data inst, err := base.ParseCSVToInstances("../examples/datasets/chim.csv", true) if err != nil { - panic(err) + testEnv.Fatal("Unable to parse CSV to instances: %s", err.Error()) } _, rows := inst.Size() @@ -105,7 +103,7 @@ func TestChiMerge2(testEnv *testing.T) { // Randy Kerber, ChiMerge: Discretisation of Numeric Attributes, 1992 inst, err := base.ParseCSVToInstances("../examples/datasets/iris_headers.csv", true) if err != nil { - panic(err) + testEnv.Fatal("Unable to parse CSV to instances: %s", err.Error()) } // Sort the instances @@ -113,7 +111,7 @@ func TestChiMerge2(testEnv *testing.T) { sortAttrSpecs := base.ResolveAttributes(inst, allAttrs)[0:1] instSorted, err := base.Sort(inst, base.Ascending, sortAttrSpecs) if err != nil { - panic(err) + testEnv.Fatalf("Sort failed: %s", err.Error()) } // Perform Chi-Merge @@ -179,7 +177,7 @@ func TestChiMerge4(testEnv *testing.T) { // Randy Kerber, ChiMerge: Discretisation of Numeric Attributes, 1992 inst, err := base.ParseCSVToInstances("../examples/datasets/iris_headers.csv", true) if err != nil { - panic(err) + testEnv.Fatal("Unable to parse CSV to instances: %s", err.Error()) } filt := NewChiMergeFilter(inst, 0.90) @@ -191,7 +189,9 @@ func TestChiMerge4(testEnv *testing.T) { if len(clsAttrs) != 1 { testEnv.Fatalf("%d != %d", len(clsAttrs), 1) } - if clsAttrs[0].GetName() != "Species" { - panic("Class Attribute wrong!") + firstClassAttributeName := clsAttrs[0].GetName() + expectedClassAttributeName := "Species" + if firstClassAttributeName != expectedClassAttributeName { + testEnv.Fatalf("Expected class attribute '%s'; actual class attribute '%s'", expectedClassAttributeName, firstClassAttributeName) } } diff --git a/meta/bagging_test.go b/meta/bagging_test.go index 651f4f8..9443cf7 100644 --- a/meta/bagging_test.go +++ b/meta/bagging_test.go @@ -13,7 +13,7 @@ import ( func BenchmarkBaggingRandomForestFit(testEnv *testing.B) { inst, err := 
base.ParseCSVToInstances("../examples/datasets/iris_headers.csv", true) if err != nil { - panic(err) + testEnv.Fatal("Unable to parse CSV to instances: %s", err.Error()) } rand.Seed(time.Now().UnixNano()) @@ -38,7 +38,7 @@ func BenchmarkBaggingRandomForestFit(testEnv *testing.B) { func BenchmarkBaggingRandomForestPredict(testEnv *testing.B) { inst, err := base.ParseCSVToInstances("../examples/datasets/iris_headers.csv", true) if err != nil { - panic(err) + testEnv.Fatal("Unable to parse CSV to instances: %s", err.Error()) } rand.Seed(time.Now().UnixNano()) @@ -64,7 +64,7 @@ func BenchmarkBaggingRandomForestPredict(testEnv *testing.B) { func TestRandomForest1(testEnv *testing.T) { inst, err := base.ParseCSVToInstances("../examples/datasets/iris_headers.csv", true) if err != nil { - panic(err) + testEnv.Fatal("Unable to parse CSV to instances: %s", err.Error()) } trainData, testData := base.InstancesTrainTestSplit(inst, 0.6) diff --git a/neural/layered_test.go b/neural/layered_test.go index e9c52f6..98443c3 100644 --- a/neural/layered_test.go +++ b/neural/layered_test.go @@ -140,7 +140,7 @@ func TestLayeredXOR(t *testing.T) { for _, a := range pred.AllAttributes() { af, ok := a.(*base.FloatAttribute) if !ok { - panic("All of these should be FloatAttributes!") + t.Fatalf("Expected all attributes to be FloatAttributes; actually some were not") } af.Precision = 1 } diff --git a/trees/tree_test.go b/trees/tree_test.go index 1912177..d20089f 100644 --- a/trees/tree_test.go +++ b/trees/tree_test.go @@ -11,7 +11,7 @@ import ( func TestRandomTree(testEnv *testing.T) { inst, err := base.ParseCSVToInstances("../examples/datasets/iris_headers.csv", true) if err != nil { - panic(err) + testEnv.Fatal("Unable to parse CSV to instances: %s", err.Error()) } filt := filters.NewChiMergeFilter(inst, 0.90) @@ -30,7 +30,7 @@ func TestRandomTree(testEnv *testing.T) { func TestRandomTreeClassification(testEnv *testing.T) { inst, err := base.ParseCSVToInstances("../examples/datasets/iris_headers.csv", true) if err != nil { - panic(err) + testEnv.Fatal("Unable to parse CSV to instances: %s", err.Error()) } trainData, testData := base.InstancesTrainTestSplit(inst, 0.6) @@ -55,7 +55,7 @@ func TestRandomTreeClassification(testEnv *testing.T) { func TestRandomTreeClassification2(testEnv *testing.T) { inst, err := base.ParseCSVToInstances("../examples/datasets/iris_headers.csv", true) if err != nil { - panic(err) + testEnv.Fatal("Unable to parse CSV to instances: %s", err.Error()) } trainData, testData := base.InstancesTrainTestSplit(inst, 0.4) @@ -78,7 +78,7 @@ func TestRandomTreeClassification2(testEnv *testing.T) { func TestPruning(testEnv *testing.T) { inst, err := base.ParseCSVToInstances("../examples/datasets/iris_headers.csv", true) if err != nil { - panic(err) + testEnv.Fatal("Unable to parse CSV to instances: %s", err.Error()) } trainData, testData := base.InstancesTrainTestSplit(inst, 0.6) @@ -120,7 +120,7 @@ func TestInformationGain(testEnv *testing.T) { func TestID3Inference(testEnv *testing.T) { inst, err := base.ParseCSVToInstances("../examples/datasets/tennis.csv", true) if err != nil { - panic(err) + testEnv.Fatal("Unable to parse CSV to instances: %s", err.Error()) } // Build the decision tree @@ -170,7 +170,7 @@ func TestID3Inference(testEnv *testing.T) { func TestID3Classification(testEnv *testing.T) { inst, err := base.ParseCSVToInstances("../examples/datasets/iris_headers.csv", true) if err != nil { - panic(err) + testEnv.Fatal("Unable to parse CSV to instances: %s", err.Error()) } filt := 
filters.NewBinningFilter(inst, 10) @@ -194,7 +194,7 @@ func TestID3Classification(testEnv *testing.T) { func TestID3(testEnv *testing.T) { inst, err := base.ParseCSVToInstances("../examples/datasets/tennis.csv", true) if err != nil { - panic(err) + testEnv.Fatal("Unable to parse CSV to instances: %s", err.Error()) } // Build the decision tree From 14aad31821d8636813cb7b453f707bab8f30be5f Mon Sep 17 00:00:00 2001 From: Amit Kumar Gupta Date: Fri, 22 Aug 2014 08:13:19 +0000 Subject: [PATCH 11/11] Consistently use (t *testing.T) instead of T or testEnv --- base/csv_test.go | 70 +++++++++++++++--------------- base/edf/thread_test.go | 26 ++++++------ base/lazy_sort_test.go | 46 ++++++++++---------- base/sort_test.go | 42 +++++++++--------- ensemble/randomforest_test.go | 4 +- evaluation/confusion_test.go | 38 ++++++++--------- filters/chimerge_test.go | 80 +++++++++++++++++------------------ meta/bagging_test.go | 16 +++---- trees/tree_test.go | 68 ++++++++++++++--------------- 9 files changed, 195 insertions(+), 195 deletions(-) diff --git a/base/csv_test.go b/base/csv_test.go index c7b6907..6ed9dca 100644 --- a/base/csv_test.go +++ b/base/csv_test.go @@ -4,92 +4,92 @@ import ( "testing" ) -func TestParseCSVGetRows(testEnv *testing.T) { +func TestParseCSVGetRows(t *testing.T) { lineCount, err := ParseCSVGetRows("../examples/datasets/iris.csv") if err != nil { - testEnv.Fatalf("Unable to parse CSV to get number of rows: %s", err.Error()) + t.Fatalf("Unable to parse CSV to get number of rows: %s", err.Error()) } if lineCount != 150 { - testEnv.Errorf("Should have %d lines, has %d", 150, lineCount) + t.Errorf("Should have %d lines, has %d", 150, lineCount) } lineCount, err = ParseCSVGetRows("../examples/datasets/iris_headers.csv") if err != nil { - testEnv.Fatalf("Unable to parse CSV to get number of rows: %s", err.Error()) + t.Fatalf("Unable to parse CSV to get number of rows: %s", err.Error()) } if lineCount != 151 { - testEnv.Errorf("Should have %d lines, has %d", 151, lineCount) + t.Errorf("Should have %d lines, has %d", 151, lineCount) } } -func TestParseCSVGetRowsWithMissingFile(testEnv *testing.T) { +func TestParseCSVGetRowsWithMissingFile(t *testing.T) { _, err := ParseCSVGetRows("../examples/datasets/non-existent.csv") if err == nil { - testEnv.Fatal("Expected ParseCSVGetRows to return error when given path to non-existent file") + t.Fatal("Expected ParseCSVGetRows to return error when given path to non-existent file") } } -func TestParseCCSVGetAttributes(testEnv *testing.T) { +func TestParseCCSVGetAttributes(t *testing.T) { attrs := ParseCSVGetAttributes("../examples/datasets/iris_headers.csv", true) if attrs[0].GetType() != Float64Type { - testEnv.Errorf("First attribute should be a float, %s", attrs[0]) + t.Errorf("First attribute should be a float, %s", attrs[0]) } if attrs[0].GetName() != "Sepal length" { - testEnv.Errorf(attrs[0].GetName()) + t.Errorf(attrs[0].GetName()) } if attrs[4].GetType() != CategoricalType { - testEnv.Errorf("Final attribute should be categorical, %s", attrs[4]) + t.Errorf("Final attribute should be categorical, %s", attrs[4]) } if attrs[4].GetName() != "Species" { - testEnv.Error(attrs[4]) + t.Error(attrs[4]) } } -func TestParseCsvSniffAttributeTypes(testEnv *testing.T) { +func TestParseCsvSniffAttributeTypes(t *testing.T) { attrs := ParseCSVSniffAttributeTypes("../examples/datasets/iris_headers.csv", true) if attrs[0].GetType() != Float64Type { - testEnv.Errorf("First attribute should be a float, %s", attrs[0]) + t.Errorf("First attribute should be a 
float, %s", attrs[0]) } if attrs[1].GetType() != Float64Type { - testEnv.Errorf("Second attribute should be a float, %s", attrs[1]) + t.Errorf("Second attribute should be a float, %s", attrs[1]) } if attrs[2].GetType() != Float64Type { - testEnv.Errorf("Third attribute should be a float, %s", attrs[2]) + t.Errorf("Third attribute should be a float, %s", attrs[2]) } if attrs[3].GetType() != Float64Type { - testEnv.Errorf("Fourth attribute should be a float, %s", attrs[3]) + t.Errorf("Fourth attribute should be a float, %s", attrs[3]) } if attrs[4].GetType() != CategoricalType { - testEnv.Errorf("Final attribute should be categorical, %s", attrs[4]) + t.Errorf("Final attribute should be categorical, %s", attrs[4]) } } -func TestParseCSVSniffAttributeNamesWithHeaders(testEnv *testing.T) { +func TestParseCSVSniffAttributeNamesWithHeaders(t *testing.T) { attrs := ParseCSVSniffAttributeNames("../examples/datasets/iris_headers.csv", true) if attrs[0] != "Sepal length" { - testEnv.Error(attrs[0]) + t.Error(attrs[0]) } if attrs[1] != "Sepal width" { - testEnv.Error(attrs[1]) + t.Error(attrs[1]) } if attrs[2] != "Petal length" { - testEnv.Error(attrs[2]) + t.Error(attrs[2]) } if attrs[3] != "Petal width" { - testEnv.Error(attrs[3]) + t.Error(attrs[3]) } if attrs[4] != "Species" { - testEnv.Error(attrs[4]) + t.Error(attrs[4]) } } -func TestParseCSVToInstances(testEnv *testing.T) { +func TestParseCSVToInstances(t *testing.T) { inst, err := ParseCSVToInstances("../examples/datasets/iris_headers.csv", true) if err != nil { - testEnv.Error(err) + t.Error(err) return } row1 := inst.RowString(0) @@ -97,34 +97,34 @@ func TestParseCSVToInstances(testEnv *testing.T) { row3 := inst.RowString(100) if row1 != "5.10 3.50 1.40 0.20 Iris-setosa" { - testEnv.Error(row1) + t.Error(row1) } if row2 != "7.00 3.20 4.70 1.40 Iris-versicolor" { - testEnv.Error(row2) + t.Error(row2) } if row3 != "6.30 3.30 6.00 2.50 Iris-virginica" { - testEnv.Error(row3) + t.Error(row3) } } -func TestParseCSVToInstancesWithMissingFile(testEnv *testing.T) { +func TestParseCSVToInstancesWithMissingFile(t *testing.T) { _, err := ParseCSVToInstances("../examples/datasets/non-existent.csv", true) if err == nil { - testEnv.Fatal("Expected ParseCSVToInstances to return error when given path to non-existent file") + t.Fatal("Expected ParseCSVToInstances to return error when given path to non-existent file") } } -func TestReadAwkwardInsatnces(testEnv *testing.T) { +func TestReadAwkwardInsatnces(t *testing.T) { inst, err := ParseCSVToInstances("../examples/datasets/chim.csv", true) if err != nil { - testEnv.Error(err) + t.Error(err) return } attrs := inst.AllAttributes() if attrs[0].GetType() != Float64Type { - testEnv.Error("Should be float!") + t.Error("Should be float!") } if attrs[1].GetType() != CategoricalType { - testEnv.Error("Should be discrete!") + t.Error("Should be discrete!") } } diff --git a/base/edf/thread_test.go b/base/edf/thread_test.go index 3671dc6..6649cdc 100644 --- a/base/edf/thread_test.go +++ b/base/edf/thread_test.go @@ -6,13 +6,13 @@ import ( "testing" ) -func TestThreadDeserialize(T *testing.T) { +func TestThreadDeserialize(t *testing.T) { bytes := []byte{0, 0, 0, 6, 83, 89, 83, 84, 69, 77, 0, 0, 0, 1} - Convey("Given a byte slice", T, func() { - var t Thread - size := t.Deserialize(bytes) + Convey("Given a byte slice", t, func() { + var thread Thread + size := thread.Deserialize(bytes) Convey("Decoded name should be SYSTEM", func() { - So(t.name, ShouldEqual, "SYSTEM") + So(thread.name, ShouldEqual, "SYSTEM") }) 
Convey("Size should be the same as the array", func() { So(size, ShouldEqual, len(bytes)) @@ -20,20 +20,20 @@ func TestThreadDeserialize(T *testing.T) { }) } -func TestThreadSerialize(T *testing.T) { - var t Thread +func TestThreadSerialize(t *testing.T) { + var thread Thread refBytes := []byte{0, 0, 0, 6, 83, 89, 83, 84, 69, 77, 0, 0, 0, 1} - t.name = "SYSTEM" - t.id = 1 + thread.name = "SYSTEM" + thread.id = 1 toBytes := make([]byte, len(refBytes)) - Convey("Should serialize correctly", T, func() { - t.Serialize(toBytes) + Convey("Should serialize correctly", t, func() { + thread.Serialize(toBytes) So(toBytes, ShouldResemble, refBytes) }) } -func TestThreadFindAndWrite(T *testing.T) { - Convey("Creating a non-existent file should succeed", T, func() { +func TestThreadFindAndWrite(t *testing.T) { + Convey("Creating a non-existent file should succeed", t, func() { tempFile, err := os.OpenFile("hello.db", os.O_RDWR|os.O_TRUNC|os.O_CREATE, 0700) //ioutil.TempFile(os.TempDir(), "TestFileCreate") So(err, ShouldEqual, nil) Convey("Mapping the file should succeed", func() { diff --git a/base/lazy_sort_test.go b/base/lazy_sort_test.go index 043d802..47eeb42 100644 --- a/base/lazy_sort_test.go +++ b/base/lazy_sort_test.go @@ -4,15 +4,15 @@ import ( "testing" ) -func TestLazySortDesc(testEnv *testing.T) { +func TestLazySortDesc(t *testing.T) { inst1, err := ParseCSVToInstances("../examples/datasets/iris_headers.csv", true) if err != nil { - testEnv.Error(err) + t.Error(err) return } inst2, err := ParseCSVToInstances("../examples/datasets/iris_sorted_desc.csv", true) if err != nil { - testEnv.Error(err) + t.Error(err) return } @@ -20,67 +20,67 @@ func TestLazySortDesc(testEnv *testing.T) { as2 := ResolveAllAttributes(inst2) if isSortedDesc(inst1, as1[0]) { - testEnv.Error("Can't test descending sort order") + t.Error("Can't test descending sort order") } if !isSortedDesc(inst2, as2[0]) { - testEnv.Error("Reference data not sorted in descending order!") + t.Error("Reference data not sorted in descending order!") } inst, err := LazySort(inst1, Descending, as1[0:len(as1)-1]) if err != nil { - testEnv.Error(err) + t.Error(err) } if !isSortedDesc(inst, as1[0]) { - testEnv.Error("Instances are not sorted in descending order") - testEnv.Error(inst1) + t.Error("Instances are not sorted in descending order") + t.Error(inst1) } if !inst2.Equal(inst) { - testEnv.Error("Instances don't match") - testEnv.Error(inst) - testEnv.Error(inst2) + t.Error("Instances don't match") + t.Error(inst) + t.Error(inst2) } } -func TestLazySortAsc(testEnv *testing.T) { +func TestLazySortAsc(t *testing.T) { inst, err := ParseCSVToInstances("../examples/datasets/iris_headers.csv", true) as1 := ResolveAllAttributes(inst) if isSortedAsc(inst, as1[0]) { - testEnv.Error("Can't test ascending sort on something ascending already") + t.Error("Can't test ascending sort on something ascending already") } if err != nil { - testEnv.Error(err) + t.Error(err) return } insts, err := LazySort(inst, Ascending, as1) if err != nil { - testEnv.Error(err) + t.Error(err) return } if !isSortedAsc(insts, as1[0]) { - testEnv.Error("Instances are not sorted in ascending order") - testEnv.Error(insts) + t.Error("Instances are not sorted in ascending order") + t.Error(insts) } inst2, err := ParseCSVToInstances("../examples/datasets/iris_sorted_asc.csv", true) if err != nil { - testEnv.Error(err) + t.Error(err) return } as2 := ResolveAllAttributes(inst2) if !isSortedAsc(inst2, as2[0]) { - testEnv.Error("This file should be sorted in ascending order") + 
t.Error("This file should be sorted in ascending order") } if !inst2.Equal(insts) { - testEnv.Error("Instances don't match") - testEnv.Error(inst) - testEnv.Error(inst2) + t.Error("Instances don't match") + t.Error(inst) + t.Error(inst2) } rowStr := insts.RowString(0) ref := "4.30 3.00 1.10 0.10 Iris-setosa" if rowStr != ref { - testEnv.Fatalf("'%s' != '%s'", rowStr, ref) + t.Fatalf("'%s' != '%s'", rowStr, ref) } } diff --git a/base/sort_test.go b/base/sort_test.go index cbddb83..ab9eeca 100644 --- a/base/sort_test.go +++ b/base/sort_test.go @@ -32,15 +32,15 @@ func isSortedDesc(inst FixedDataGrid, attr AttributeSpec) bool { return true } -func TestSortDesc(testEnv *testing.T) { +func TestSortDesc(t *testing.T) { inst1, err := ParseCSVToInstances("../examples/datasets/iris_headers.csv", true) if err != nil { - testEnv.Error(err) + t.Error(err) return } inst2, err := ParseCSVToInstances("../examples/datasets/iris_sorted_desc.csv", true) if err != nil { - testEnv.Error(err) + t.Error(err) return } @@ -48,57 +48,57 @@ func TestSortDesc(testEnv *testing.T) { as2 := ResolveAllAttributes(inst2) if isSortedDesc(inst1, as1[0]) { - testEnv.Error("Can't test descending sort order") + t.Error("Can't test descending sort order") } if !isSortedDesc(inst2, as2[0]) { - testEnv.Error("Reference data not sorted in descending order!") + t.Error("Reference data not sorted in descending order!") } Sort(inst1, Descending, as1[0:len(as1)-1]) if err != nil { - testEnv.Error(err) + t.Error(err) } if !isSortedDesc(inst1, as1[0]) { - testEnv.Error("Instances are not sorted in descending order") - testEnv.Error(inst1) + t.Error("Instances are not sorted in descending order") + t.Error(inst1) } if !inst2.Equal(inst1) { - testEnv.Error("Instances don't match") - testEnv.Error(inst1) - testEnv.Error(inst2) + t.Error("Instances don't match") + t.Error(inst1) + t.Error(inst2) } } -func TestSortAsc(testEnv *testing.T) { +func TestSortAsc(t *testing.T) { inst, err := ParseCSVToInstances("../examples/datasets/iris_headers.csv", true) as1 := ResolveAllAttributes(inst) if isSortedAsc(inst, as1[0]) { - testEnv.Error("Can't test ascending sort on something ascending already") + t.Error("Can't test ascending sort on something ascending already") } if err != nil { - testEnv.Error(err) + t.Error(err) return } Sort(inst, Ascending, as1[0:1]) if !isSortedAsc(inst, as1[0]) { - testEnv.Error("Instances are not sorted in ascending order") - testEnv.Error(inst) + t.Error("Instances are not sorted in ascending order") + t.Error(inst) } inst2, err := ParseCSVToInstances("../examples/datasets/iris_sorted_asc.csv", true) if err != nil { - testEnv.Error(err) + t.Error(err) return } as2 := ResolveAllAttributes(inst2) if !isSortedAsc(inst2, as2[0]) { - testEnv.Error("This file should be sorted in ascending order") + t.Error("This file should be sorted in ascending order") } if !inst2.Equal(inst) { - testEnv.Error("Instances don't match") - testEnv.Error(inst) - testEnv.Error(inst2) + t.Error("Instances don't match") + t.Error(inst) + t.Error(inst2) } } diff --git a/ensemble/randomforest_test.go b/ensemble/randomforest_test.go index 3828e8c..97e0e58 100644 --- a/ensemble/randomforest_test.go +++ b/ensemble/randomforest_test.go @@ -7,10 +7,10 @@ import ( "testing" ) -func TestRandomForest1(testEnv *testing.T) { +func TestRandomForest1(t *testing.T) { inst, err := base.ParseCSVToInstances("../examples/datasets/iris_headers.csv", true) if err != nil { - testEnv.Fatal("Unable to parse CSV to instances: %s", err.Error()) + t.Fatal("Unable to parse 
CSV to instances: %s", err.Error()) } filt := filters.NewChiMergeFilter(inst, 0.90) diff --git a/evaluation/confusion_test.go b/evaluation/confusion_test.go index 7017a43..047d01d 100644 --- a/evaluation/confusion_test.go +++ b/evaluation/confusion_test.go @@ -5,7 +5,7 @@ import ( "testing" ) -func TestMetrics(testEnv *testing.T) { +func TestMetrics(t *testing.T) { confusionMat := make(ConfusionMatrix) confusionMat["a"] = make(map[string]int) confusionMat["b"] = make(map[string]int) @@ -16,89 +16,89 @@ func TestMetrics(testEnv *testing.T) { tp := GetTruePositives("a", confusionMat) if math.Abs(tp-75) >= 1 { - testEnv.Error(tp) + t.Error(tp) } tp = GetTruePositives("b", confusionMat) if math.Abs(tp-10) >= 1 { - testEnv.Error(tp) + t.Error(tp) } fn := GetFalseNegatives("a", confusionMat) if math.Abs(fn-5) >= 1 { - testEnv.Error(fn) + t.Error(fn) } fn = GetFalseNegatives("b", confusionMat) if math.Abs(fn-10) >= 1 { - testEnv.Error(fn) + t.Error(fn) } tn := GetTrueNegatives("a", confusionMat) if math.Abs(tn-10) >= 1 { - testEnv.Error(tn) + t.Error(tn) } tn = GetTrueNegatives("b", confusionMat) if math.Abs(tn-75) >= 1 { - testEnv.Error(tn) + t.Error(tn) } fp := GetFalsePositives("a", confusionMat) if math.Abs(fp-10) >= 1 { - testEnv.Error(fp) + t.Error(fp) } fp = GetFalsePositives("b", confusionMat) if math.Abs(fp-5) >= 1 { - testEnv.Error(fp) + t.Error(fp) } precision := GetPrecision("a", confusionMat) recall := GetRecall("a", confusionMat) if math.Abs(precision-0.88) >= 0.01 { - testEnv.Error(precision) + t.Error(precision) } if math.Abs(recall-0.94) >= 0.01 { - testEnv.Error(recall) + t.Error(recall) } precision = GetPrecision("b", confusionMat) recall = GetRecall("b", confusionMat) if math.Abs(precision-0.666) >= 0.01 { - testEnv.Error(precision) + t.Error(precision) } if math.Abs(recall-0.50) >= 0.01 { - testEnv.Error(recall) + t.Error(recall) } precision = GetMicroPrecision(confusionMat) if math.Abs(precision-0.85) >= 0.01 { - testEnv.Error(precision) + t.Error(precision) } recall = GetMicroRecall(confusionMat) if math.Abs(recall-0.85) >= 0.01 { - testEnv.Error(recall) + t.Error(recall) } precision = GetMacroPrecision(confusionMat) if math.Abs(precision-0.775) >= 0.01 { - testEnv.Error(precision) + t.Error(precision) } recall = GetMacroRecall(confusionMat) if math.Abs(recall-0.719) > 0.01 { - testEnv.Error(recall) + t.Error(recall) } fmeasure := GetF1Score("a", confusionMat) if math.Abs(fmeasure-0.91) >= 0.1 { - testEnv.Error(fmeasure) + t.Error(fmeasure) } accuracy := GetAccuracy(confusionMat) if math.Abs(accuracy-0.85) >= 0.1 { - testEnv.Error(accuracy) + t.Error(accuracy) } } diff --git a/filters/chimerge_test.go b/filters/chimerge_test.go index 677d5aa..0727c6a 100644 --- a/filters/chimerge_test.go +++ b/filters/chimerge_test.go @@ -6,104 +6,104 @@ import ( "testing" ) -func TestChiMFreqTable(testEnv *testing.T) { +func TestChiMFreqTable(t *testing.T) { inst, err := base.ParseCSVToInstances("../examples/datasets/chim.csv", true) if err != nil { - testEnv.Fatal("Unable to parse CSV to instances: %s", err.Error()) + t.Fatal("Unable to parse CSV to instances: %s", err.Error()) } freq := ChiMBuildFrequencyTable(inst.AllAttributes()[0], inst) if freq[0].Frequency["c1"] != 1 { - testEnv.Error("Wrong frequency") + t.Error("Wrong frequency") } if freq[0].Frequency["c3"] != 4 { - testEnv.Errorf("Wrong frequency %s", freq[1]) + t.Errorf("Wrong frequency %s", freq[1]) } if freq[10].Frequency["c2"] != 1 { - testEnv.Error("Wrong frequency") + t.Error("Wrong frequency") } } -func 
 
-func TestChiClassCounter(testEnv *testing.T) {
+func TestChiClassCounter(t *testing.T) {
 	inst, err := base.ParseCSVToInstances("../examples/datasets/chim.csv", true)
 	if err != nil {
-		testEnv.Fatal("Unable to parse CSV to instances: %s", err.Error())
+		t.Fatal("Unable to parse CSV to instances: %s", err.Error())
 	}
 	freq := ChiMBuildFrequencyTable(inst.AllAttributes()[0], inst)
 	classes := chiCountClasses(freq)
 	if classes["c1"] != 27 {
-		testEnv.Error(classes)
+		t.Error(classes)
 	}
 	if classes["c2"] != 12 {
-		testEnv.Error(classes)
+		t.Error(classes)
 	}
 	if classes["c3"] != 21 {
-		testEnv.Error(classes)
+		t.Error(classes)
 	}
 }
 
-func TestStatisticValues(testEnv *testing.T) {
+func TestStatisticValues(t *testing.T) {
 	inst, err := base.ParseCSVToInstances("../examples/datasets/chim.csv", true)
 	if err != nil {
-		testEnv.Fatal("Unable to parse CSV to instances: %s", err.Error())
+		t.Fatal("Unable to parse CSV to instances: %s", err.Error())
 	}
 	freq := ChiMBuildFrequencyTable(inst.AllAttributes()[0], inst)
 	chiVal := chiComputeStatistic(freq[5], freq[6])
 	if math.Abs(chiVal-1.89) > 0.01 {
-		testEnv.Error(chiVal)
+		t.Error(chiVal)
 	}
 	chiVal = chiComputeStatistic(freq[1], freq[2])
 	if math.Abs(chiVal-1.08) > 0.01 {
-		testEnv.Error(chiVal)
+		t.Error(chiVal)
 	}
 }
 
-func TestChiSquareDistValues(testEnv *testing.T) {
+func TestChiSquareDistValues(t *testing.T) {
 	chiVal1 := chiSquaredPercentile(2, 4.61)
 	chiVal2 := chiSquaredPercentile(3, 7.82)
 	chiVal3 := chiSquaredPercentile(4, 13.28)
 	if math.Abs(chiVal1-0.90) > 0.001 {
-		testEnv.Error(chiVal1)
+		t.Error(chiVal1)
 	}
 	if math.Abs(chiVal2-0.95) > 0.001 {
-		testEnv.Error(chiVal2)
+		t.Error(chiVal2)
 	}
 	if math.Abs(chiVal3-0.99) > 0.001 {
-		testEnv.Error(chiVal3)
+		t.Error(chiVal3)
 	}
 }
 
-func TestChiMerge1(testEnv *testing.T) {
+func TestChiMerge1(t *testing.T) {
 	inst, err := base.ParseCSVToInstances("../examples/datasets/chim.csv", true)
 	if err != nil {
-		testEnv.Fatal("Unable to parse CSV to instances: %s", err.Error())
+		t.Fatal("Unable to parse CSV to instances: %s", err.Error())
 	}
 	_, rows := inst.Size()
 	freq := chiMerge(inst, inst.AllAttributes()[0], 0.90, 0, rows)
 	if len(freq) != 3 {
-		testEnv.Error("Wrong length")
+		t.Error("Wrong length")
 	}
 	if freq[0].Value != 1.3 {
-		testEnv.Error(freq[0])
+		t.Error(freq[0])
 	}
 	if freq[1].Value != 56.2 {
-		testEnv.Error(freq[1])
+		t.Error(freq[1])
 	}
 	if freq[2].Value != 87.1 {
-		testEnv.Error(freq[2])
+		t.Error(freq[2])
 	}
 }
 
-func TestChiMerge2(testEnv *testing.T) {
+func TestChiMerge2(t *testing.T) {
 	//
 	// See http://sci2s.ugr.es/keel/pdf/algorithm/congreso/1992-Kerber-ChimErge-AAAI92.pdf
 	// Randy Kerber, ChiMerge: Discretisation of Numeric Attributes, 1992
 	inst, err := base.ParseCSVToInstances("../examples/datasets/iris_headers.csv", true)
 	if err != nil {
-		testEnv.Fatal("Unable to parse CSV to instances: %s", err.Error())
+		t.Fatal("Unable to parse CSV to instances: %s", err.Error())
 	}
 
 	// Sort the instances
@@ -111,35 +111,35 @@ func TestChiMerge2(testEnv *testing.T) {
 	sortAttrSpecs := base.ResolveAttributes(inst, allAttrs)[0:1]
 	instSorted, err := base.Sort(inst, base.Ascending, sortAttrSpecs)
 	if err != nil {
-		testEnv.Fatalf("Sort failed: %s", err.Error())
+		t.Fatalf("Sort failed: %s", err.Error())
 	}
 
 	// Perform Chi-Merge
 	_, rows := inst.Size()
 	freq := chiMerge(instSorted, allAttrs[0], 0.90, 0, rows)
 	if len(freq) != 5 {
-		testEnv.Errorf("Wrong length (%d)", len(freq))
-		testEnv.Error(freq)
+		t.Errorf("Wrong length (%d)", len(freq))
+		t.Error(freq)
 	}
 	if freq[0].Value != 4.3 {
-		testEnv.Error(freq[0])
+		t.Error(freq[0])
 	}
 	if freq[1].Value != 5.5 {
-		testEnv.Error(freq[1])
+		t.Error(freq[1])
 	}
 	if freq[2].Value != 5.8 {
-		testEnv.Error(freq[2])
+		t.Error(freq[2])
 	}
 	if freq[3].Value != 6.3 {
-		testEnv.Error(freq[3])
+		t.Error(freq[3])
 	}
 	if freq[4].Value != 7.1 {
-		testEnv.Error(freq[4])
+		t.Error(freq[4])
 	}
 }
 
 /*
-func TestChiMerge3(testEnv *testing.T) {
+func TestChiMerge3(t *testing.T) {
 	// See http://sci2s.ugr.es/keel/pdf/algorithm/congreso/1992-Kerber-ChimErge-AAAI92.pdf
 	// Randy Kerber, ChiMerge: Discretisation of Numeric Attributes, 1992
 	inst, err := base.ParseCSVToInstances("../examples/datasets/iris_headers.csv", true)
@@ -149,7 +149,7 @@ func TestChiMerge3(testEnv *testing.T) {
 
 	insts, err := base.LazySort(inst, base.Ascending, base.ResolveAllAttributes(inst, inst.AllAttributes()))
 	if err != nil {
-		testEnv.Error(err)
+		t.Error(err)
 	}
 	filt := NewChiMergeFilter(inst, 0.90)
 	filt.AddAttribute(inst.AllAttributes()[0])
@@ -172,12 +172,12 @@ func TestChiMerge3(testEnv *testing.T) {
 	}
 }
 */
 
-func TestChiMerge4(testEnv *testing.T) {
+func TestChiMerge4(t *testing.T) {
 	// See http://sci2s.ugr.es/keel/pdf/algorithm/congreso/1992-Kerber-ChimErge-AAAI92.pdf
 	// Randy Kerber, ChiMerge: Discretisation of Numeric Attributes, 1992
 	inst, err := base.ParseCSVToInstances("../examples/datasets/iris_headers.csv", true)
 	if err != nil {
-		testEnv.Fatal("Unable to parse CSV to instances: %s", err.Error())
+		t.Fatal("Unable to parse CSV to instances: %s", err.Error())
 	}
 
 	filt := NewChiMergeFilter(inst, 0.90)
@@ -187,11 +187,11 @@ func TestChiMerge4(testEnv *testing.T) {
 	instf := base.NewLazilyFilteredInstances(inst, filt)
 	clsAttrs := instf.AllClassAttributes()
 	if len(clsAttrs) != 1 {
-		testEnv.Fatalf("%d != %d", len(clsAttrs), 1)
+		t.Fatalf("%d != %d", len(clsAttrs), 1)
 	}
 	firstClassAttributeName := clsAttrs[0].GetName()
 	expectedClassAttributeName := "Species"
 	if firstClassAttributeName != expectedClassAttributeName {
-		testEnv.Fatalf("Expected class attribute '%s'; actual class attribute '%s'", expectedClassAttributeName, firstClassAttributeName)
+		t.Fatalf("Expected class attribute '%s'; actual class attribute '%s'", expectedClassAttributeName, firstClassAttributeName)
 	}
 }
diff --git a/meta/bagging_test.go b/meta/bagging_test.go
index 9443cf7..918cbbe 100644
--- a/meta/bagging_test.go
+++ b/meta/bagging_test.go
@@ -10,10 +10,10 @@ import (
 	"time"
 )
 
-func BenchmarkBaggingRandomForestFit(testEnv *testing.B) {
+func BenchmarkBaggingRandomForestFit(t *testing.B) {
 	inst, err := base.ParseCSVToInstances("../examples/datasets/iris_headers.csv", true)
 	if err != nil {
-		testEnv.Fatal("Unable to parse CSV to instances: %s", err.Error())
+		t.Fatal("Unable to parse CSV to instances: %s", err.Error())
 	}
 
 	rand.Seed(time.Now().UnixNano())
@@ -29,16 +29,16 @@ func BenchmarkBaggingRandomForestFit(testEnv *testing.B) {
 		rf.AddModel(trees.NewRandomTree(2))
 	}
 
-	testEnv.ResetTimer()
+	t.ResetTimer()
 	for i := 0; i < 20; i++ {
 		rf.Fit(instf)
 	}
 }
 
-func BenchmarkBaggingRandomForestPredict(testEnv *testing.B) {
+func BenchmarkBaggingRandomForestPredict(t *testing.B) {
 	inst, err := base.ParseCSVToInstances("../examples/datasets/iris_headers.csv", true)
 	if err != nil {
-		testEnv.Fatal("Unable to parse CSV to instances: %s", err.Error())
+		t.Fatal("Unable to parse CSV to instances: %s", err.Error())
 	}
 
 	rand.Seed(time.Now().UnixNano())
@@ -55,16 +55,16 @@ func BenchmarkBaggingRandomForestPredict(testEnv *testing.B) {
 	}
 	rf.Fit(instf)
 
-	testEnv.ResetTimer()
+	t.ResetTimer()
 	for i := 0; i < 20; i++ {
 		rf.Predict(instf)
 	}
 }
 
-func TestRandomForest1(testEnv *testing.T) {
+func TestRandomForest1(t *testing.T) {
 	inst, err := base.ParseCSVToInstances("../examples/datasets/iris_headers.csv", true)
 	if err != nil {
-		testEnv.Fatal("Unable to parse CSV to instances: %s", err.Error())
+		t.Fatal("Unable to parse CSV to instances: %s", err.Error())
 	}
 
 	trainData, testData := base.InstancesTrainTestSplit(inst, 0.6)
diff --git a/trees/tree_test.go b/trees/tree_test.go
index d20089f..5e8b11a 100644
--- a/trees/tree_test.go
+++ b/trees/tree_test.go
@@ -8,10 +8,10 @@ import (
 	"testing"
 )
 
-func TestRandomTree(testEnv *testing.T) {
+func TestRandomTree(t *testing.T) {
 	inst, err := base.ParseCSVToInstances("../examples/datasets/iris_headers.csv", true)
 	if err != nil {
-		testEnv.Fatal("Unable to parse CSV to instances: %s", err.Error())
+		t.Fatal("Unable to parse CSV to instances: %s", err.Error())
 	}
 
 	filt := filters.NewChiMergeFilter(inst, 0.90)
@@ -27,10 +27,10 @@ func TestRandomTree(testEnv *testing.T) {
 	_ = InferID3Tree(instf, r)
 }
 
-func TestRandomTreeClassification(testEnv *testing.T) {
+func TestRandomTreeClassification(t *testing.T) {
 	inst, err := base.ParseCSVToInstances("../examples/datasets/iris_headers.csv", true)
 	if err != nil {
-		testEnv.Fatal("Unable to parse CSV to instances: %s", err.Error())
+		t.Fatal("Unable to parse CSV to instances: %s", err.Error())
 	}
 
 	trainData, testData := base.InstancesTrainTestSplit(inst, 0.6)
@@ -52,10 +52,10 @@ func TestRandomTreeClassification(testEnv *testing.T) {
 	_ = eval.GetSummary(confusionMat)
 }
 
-func TestRandomTreeClassification2(testEnv *testing.T) {
+func TestRandomTreeClassification2(t *testing.T) {
 	inst, err := base.ParseCSVToInstances("../examples/datasets/iris_headers.csv", true)
 	if err != nil {
-		testEnv.Fatal("Unable to parse CSV to instances: %s", err.Error())
+		t.Fatal("Unable to parse CSV to instances: %s", err.Error())
 	}
 
 	trainData, testData := base.InstancesTrainTestSplit(inst, 0.4)
@@ -75,10 +75,10 @@ func TestRandomTreeClassification2(testEnv *testing.T) {
 	_ = eval.GetSummary(confusionMat)
 }
 
-func TestPruning(testEnv *testing.T) {
+func TestPruning(t *testing.T) {
 	inst, err := base.ParseCSVToInstances("../examples/datasets/iris_headers.csv", true)
 	if err != nil {
-		testEnv.Fatal("Unable to parse CSV to instances: %s", err.Error())
+		t.Fatal("Unable to parse CSV to instances: %s", err.Error())
 	}
 
 	trainData, testData := base.InstancesTrainTestSplit(inst, 0.6)
@@ -100,7 +100,7 @@ func TestPruning(testEnv *testing.T) {
 	_ = eval.GetSummary(confusionMat)
 }
 
-func TestInformationGain(testEnv *testing.T) {
+func TestInformationGain(t *testing.T) {
 	outlook := make(map[string]map[string]int)
 	outlook["sunny"] = make(map[string]int)
 	outlook["overcast"] = make(map[string]int)
@@ -113,14 +113,14 @@ func TestInformationGain(testEnv *testing.T) {
 
 	entropy := getSplitEntropy(outlook)
 	if math.Abs(entropy-0.694) > 0.001 {
-		testEnv.Error(entropy)
+		t.Error(entropy)
 	}
 }
 
-func TestID3Inference(testEnv *testing.T) {
+func TestID3Inference(t *testing.T) {
 	inst, err := base.ParseCSVToInstances("../examples/datasets/tennis.csv", true)
 	if err != nil {
-		testEnv.Fatal("Unable to parse CSV to instances: %s", err.Error())
+		t.Fatal("Unable to parse CSV to instances: %s", err.Error())
 	}
 
 	// Build the decision tree
@@ -130,47 +130,47 @@ func TestID3Inference(testEnv *testing.T) {
 	// Verify the tree
 	// First attribute should be "outlook"
 	if root.SplitAttr.GetName() != "outlook" {
-		testEnv.Error(root)
+		t.Error(root)
 	}
 	sunnyChild := root.Children["sunny"]
 	overcastChild := root.Children["overcast"]
 	rainyChild := root.Children["rainy"]
 	if sunnyChild.SplitAttr.GetName() != "humidity" {
-		testEnv.Error(sunnyChild)
+		t.Error(sunnyChild)
 	}
 	if rainyChild.SplitAttr.GetName() != "windy" {
-		testEnv.Error(rainyChild)
+		t.Error(rainyChild)
 	}
 	if overcastChild.SplitAttr != nil {
-		testEnv.Error(overcastChild)
+		t.Error(overcastChild)
 	}
 	sunnyLeafHigh := sunnyChild.Children["high"]
 	sunnyLeafNormal := sunnyChild.Children["normal"]
 	if sunnyLeafHigh.Class != "no" {
-		testEnv.Error(sunnyLeafHigh)
+		t.Error(sunnyLeafHigh)
 	}
 	if sunnyLeafNormal.Class != "yes" {
-		testEnv.Error(sunnyLeafNormal)
+		t.Error(sunnyLeafNormal)
 	}
 	windyLeafFalse := rainyChild.Children["false"]
 	windyLeafTrue := rainyChild.Children["true"]
 	if windyLeafFalse.Class != "yes" {
-		testEnv.Error(windyLeafFalse)
+		t.Error(windyLeafFalse)
 	}
 	if windyLeafTrue.Class != "no" {
-		testEnv.Error(windyLeafTrue)
+		t.Error(windyLeafTrue)
 	}
 	if overcastChild.Class != "yes" {
-		testEnv.Error(overcastChild)
+		t.Error(overcastChild)
 	}
 }
 
-func TestID3Classification(testEnv *testing.T) {
+func TestID3Classification(t *testing.T) {
 	inst, err := base.ParseCSVToInstances("../examples/datasets/iris_headers.csv", true)
 	if err != nil {
-		testEnv.Fatal("Unable to parse CSV to instances: %s", err.Error())
+		t.Fatal("Unable to parse CSV to instances: %s", err.Error())
 	}
 
 	filt := filters.NewBinningFilter(inst, 10)
@@ -191,10 +191,10 @@ func TestID3Classification(testEnv *testing.T) {
 	_ = eval.GetSummary(confusionMat)
 }
 
-func TestID3(testEnv *testing.T) {
+func TestID3(t *testing.T) {
 	inst, err := base.ParseCSVToInstances("../examples/datasets/tennis.csv", true)
 	if err != nil {
-		testEnv.Fatal("Unable to parse CSV to instances: %s", err.Error())
+		t.Fatal("Unable to parse CSV to instances: %s", err.Error())
 	}
 
 	// Build the decision tree
@@ -205,40 +205,40 @@ func TestID3(testEnv *testing.T) {
 	// Verify the tree
 	// First attribute should be "outlook"
 	if root.SplitAttr.GetName() != "outlook" {
-		testEnv.Error(root)
+		t.Error(root)
 	}
 	sunnyChild := root.Children["sunny"]
 	overcastChild := root.Children["overcast"]
 	rainyChild := root.Children["rainy"]
 	if sunnyChild.SplitAttr.GetName() != "humidity" {
-		testEnv.Error(sunnyChild)
+		t.Error(sunnyChild)
 	}
 	if rainyChild.SplitAttr.GetName() != "windy" {
-		testEnv.Error(rainyChild)
+		t.Error(rainyChild)
 	}
 	if overcastChild.SplitAttr != nil {
-		testEnv.Error(overcastChild)
+		t.Error(overcastChild)
 	}
 	sunnyLeafHigh := sunnyChild.Children["high"]
 	sunnyLeafNormal := sunnyChild.Children["normal"]
 	if sunnyLeafHigh.Class != "no" {
-		testEnv.Error(sunnyLeafHigh)
+		t.Error(sunnyLeafHigh)
 	}
 	if sunnyLeafNormal.Class != "yes" {
-		testEnv.Error(sunnyLeafNormal)
+		t.Error(sunnyLeafNormal)
 	}
 	windyLeafFalse := rainyChild.Children["false"]
 	windyLeafTrue := rainyChild.Children["true"]
 	if windyLeafFalse.Class != "yes" {
-		testEnv.Error(windyLeafFalse)
+		t.Error(windyLeafFalse)
 	}
 	if windyLeafTrue.Class != "no" {
-		testEnv.Error(windyLeafTrue)
+		t.Error(windyLeafTrue)
 	}
 	if overcastChild.Class != "yes" {
-		testEnv.Error(overcastChild)
+		t.Error(overcastChild)
 	}
 }
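
For reference, a minimal sketch of the test style the final commit converges on (the test name below is invented for illustration and is not part of the patches): the *testing.T parameter is named t, and failure messages that carry printf-style verbs are best routed through Fatalf/Errorf, which format their arguments, whereas Fatal/Error print them verbatim.

package base

import "testing"

func TestParseIrisHeadersCSVSketch(t *testing.T) {
	// ParseCSVToInstances is the loader used throughout the suite above.
	inst, err := ParseCSVToInstances("../examples/datasets/iris_headers.csv", true)
	if err != nil {
		// Fatalf interprets the %s verb; plain Fatal would print it literally.
		t.Fatalf("Unable to parse CSV to instances: %s", err.Error())
	}
	// Size() returns (columns, rows), as used elsewhere in these tests.
	if _, rows := inst.Size(); rows == 0 {
		t.Error("expected at least one instance")
	}
}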