1
0
mirror of https://github.com/sjwhitworth/golearn.git synced 2025-04-26 13:49:14 +08:00

Rewrite ParseCSV, add DataFrame Structure.

This commit is contained in:
Bert Chang 2014-05-05 23:14:31 +08:00
parent 1e21802d2e
commit c2111752a0
3 changed files with 83 additions and 16 deletions

View File

@ -4,33 +4,39 @@ package data
import (
"encoding/csv"
"fmt"
"io"
"os"
"strconv"
"github.com/gonum/matrix/mat64"
)
//Parses a CSV file, returning the number of columns and rows, the headers, the labels associated with
//classification, and the data that will be used for training.
func ParseCsv(filepath string, label int, columns []int) (int, int, []string, []string, []float64) {
labels := make([]string, 0)
data := make([]float64, 0)
func ParseCSV(filepath string, featureCols []int, labelCols []int, header bool) *DataFrame {
headers := make([]string, 0)
data := make([]float64, 0)
labels := make([]string, 0)
rows := 0
file, err := os.Open(filepath)
if err != nil {
fmt.Println("Error:", err)
panic(err)
}
defer file.Close()
reader := csv.NewReader(file)
headerrow, _ := reader.Read()
if header {
record, err := reader.Read()
if err != nil {
panic(err)
}
for _, col := range columns {
entry := headerrow[col]
headers = append(headers, entry)
for _, col := range append(featureCols, labelCols...) {
headers = append(headers, record[col])
}
rows += 1
}
for {
@ -38,20 +44,29 @@ func ParseCsv(filepath string, label int, columns []int) (int, int, []string, []
if err == io.EOF {
break
} else if err != nil {
fmt.Println("Error:", err)
panic(err)
}
//
labels = append(labels, record[label])
//Iterate over our rows and append the values to a slice
for _, col := range columns {
for _, col := range featureCols {
entry := record[col]
number, _ := strconv.ParseFloat(entry, 64)
data = append(data, number)
}
for _, col := range labelCols {
labels = append(labels, record[col])
}
rows += 1
}
cols := len(columns)
return cols, rows, headers, labels, data
return &DataFrame{
Headers: headers,
Labels: labels,
Values: mat64.NewDense(rows, len(featureCols), data),
NRow: rows,
NFeature: len(featureCols),
NLabel: len(labelCols),
}
}

14
data/data.go Normal file
View File

@ -0,0 +1,14 @@
/* Data - consists of helper functions for parsing different data formats */
package data
import "github.com/gonum/matrix/mat64"
type DataFrame struct {
Headers []string
Labels []string
Values *mat64.Dense // We first focus on numeric values for now
NRow int
NFeature int
NLabel int
}

38
data/data_test.go Normal file
View File

@ -0,0 +1,38 @@
package data
import (
"testing"
. "github.com/smartystreets/goconvey/convey"
)
func TestParseCSV(t *testing.T) {
Convey("Parse IRIS dataset", t, func() {
dataFrame := ParseCSV("../examples/datasets/iris.csv", []int{0, 1, 3}, []int{4}, false)
Convey("First row should be {5.1, 3.5, 0.2}", func() {
So(dataFrame.Values.RowView(0), ShouldResemble, []float64{5.1, 3.5, 0.2})
})
Convey("First label should be Iris-setosa", func() {
So(dataFrame.Labels[0], ShouldEqual, "Iris-setosa")
})
Convey("Headers should be empty", func() {
So(dataFrame.Headers, ShouldResemble, []string{})
})
Convey("Number of features should be 3", func() {
So(dataFrame.NFeature, ShouldEqual, 3)
})
Convey("Number of labels should be 1", func() {
So(dataFrame.NLabel, ShouldEqual, 1)
})
Convey("Number of rows should be 150", func() {
So(dataFrame.NRow, ShouldEqual, 150)
})
})
}