mirror of
https://github.com/sjwhitworth/golearn.git
synced 2025-04-26 13:49:14 +08:00
Rewrite ParseCSV, add DataFrame Structure.
This commit is contained in:
parent
1e21802d2e
commit
c2111752a0
47
data/csv.go
47
data/csv.go
@ -4,33 +4,39 @@ package data
|
||||
|
||||
import (
|
||||
"encoding/csv"
|
||||
"fmt"
|
||||
"io"
|
||||
"os"
|
||||
"strconv"
|
||||
|
||||
"github.com/gonum/matrix/mat64"
|
||||
)
|
||||
|
||||
//Parses a CSV file, returning the number of columns and rows, the headers, the labels associated with
|
||||
//classification, and the data that will be used for training.
|
||||
func ParseCsv(filepath string, label int, columns []int) (int, int, []string, []string, []float64) {
|
||||
labels := make([]string, 0)
|
||||
data := make([]float64, 0)
|
||||
func ParseCSV(filepath string, featureCols []int, labelCols []int, header bool) *DataFrame {
|
||||
headers := make([]string, 0)
|
||||
data := make([]float64, 0)
|
||||
labels := make([]string, 0)
|
||||
rows := 0
|
||||
|
||||
file, err := os.Open(filepath)
|
||||
if err != nil {
|
||||
fmt.Println("Error:", err)
|
||||
panic(err)
|
||||
}
|
||||
defer file.Close()
|
||||
|
||||
reader := csv.NewReader(file)
|
||||
|
||||
headerrow, _ := reader.Read()
|
||||
if header {
|
||||
record, err := reader.Read()
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
for _, col := range columns {
|
||||
entry := headerrow[col]
|
||||
headers = append(headers, entry)
|
||||
for _, col := range append(featureCols, labelCols...) {
|
||||
headers = append(headers, record[col])
|
||||
}
|
||||
rows += 1
|
||||
}
|
||||
|
||||
for {
|
||||
@ -38,20 +44,29 @@ func ParseCsv(filepath string, label int, columns []int) (int, int, []string, []
|
||||
if err == io.EOF {
|
||||
break
|
||||
} else if err != nil {
|
||||
fmt.Println("Error:", err)
|
||||
panic(err)
|
||||
}
|
||||
|
||||
//
|
||||
labels = append(labels, record[label])
|
||||
|
||||
//Iterate over our rows and append the values to a slice
|
||||
for _, col := range columns {
|
||||
for _, col := range featureCols {
|
||||
entry := record[col]
|
||||
number, _ := strconv.ParseFloat(entry, 64)
|
||||
data = append(data, number)
|
||||
}
|
||||
|
||||
for _, col := range labelCols {
|
||||
labels = append(labels, record[col])
|
||||
}
|
||||
|
||||
rows += 1
|
||||
}
|
||||
cols := len(columns)
|
||||
return cols, rows, headers, labels, data
|
||||
|
||||
return &DataFrame{
|
||||
Headers: headers,
|
||||
Labels: labels,
|
||||
Values: mat64.NewDense(rows, len(featureCols), data),
|
||||
NRow: rows,
|
||||
NFeature: len(featureCols),
|
||||
NLabel: len(labelCols),
|
||||
}
|
||||
}
|
||||
|
14
data/data.go
Normal file
14
data/data.go
Normal file
@ -0,0 +1,14 @@
|
||||
/* Data - consists of helper functions for parsing different data formats */
|
||||
|
||||
package data
|
||||
|
||||
import "github.com/gonum/matrix/mat64"
|
||||
|
||||
type DataFrame struct {
|
||||
Headers []string
|
||||
Labels []string
|
||||
Values *mat64.Dense // We first focus on numeric values for now
|
||||
NRow int
|
||||
NFeature int
|
||||
NLabel int
|
||||
}
|
38
data/data_test.go
Normal file
38
data/data_test.go
Normal file
@ -0,0 +1,38 @@
|
||||
package data
|
||||
|
||||
import (
|
||||
"testing"
|
||||
. "github.com/smartystreets/goconvey/convey"
|
||||
)
|
||||
|
||||
func TestParseCSV(t *testing.T) {
|
||||
|
||||
Convey("Parse IRIS dataset", t, func() {
|
||||
dataFrame := ParseCSV("../examples/datasets/iris.csv", []int{0, 1, 3}, []int{4}, false)
|
||||
|
||||
Convey("First row should be {5.1, 3.5, 0.2}", func() {
|
||||
So(dataFrame.Values.RowView(0), ShouldResemble, []float64{5.1, 3.5, 0.2})
|
||||
})
|
||||
|
||||
Convey("First label should be Iris-setosa", func() {
|
||||
So(dataFrame.Labels[0], ShouldEqual, "Iris-setosa")
|
||||
})
|
||||
|
||||
Convey("Headers should be empty", func() {
|
||||
So(dataFrame.Headers, ShouldResemble, []string{})
|
||||
})
|
||||
|
||||
Convey("Number of features should be 3", func() {
|
||||
So(dataFrame.NFeature, ShouldEqual, 3)
|
||||
})
|
||||
|
||||
Convey("Number of labels should be 1", func() {
|
||||
So(dataFrame.NLabel, ShouldEqual, 1)
|
||||
})
|
||||
|
||||
Convey("Number of rows should be 150", func() {
|
||||
So(dataFrame.NRow, ShouldEqual, 150)
|
||||
})
|
||||
|
||||
})
|
||||
}
|
Loading…
x
Reference in New Issue
Block a user