1
0
mirror of https://github.com/sjwhitworth/golearn.git synced 2025-04-26 13:49:14 +08:00
golearn/meta/one_v_all.go
2017-09-10 19:30:02 +01:00

392 lines
12 KiB
Go

package meta
import (
"fmt"
"github.com/sjwhitworth/golearn/base"
"log"
)
// OneVsAllModel replaces class Attributes with numeric versions
// and trains n wrapped classifiers. The actual class is chosen
// by whichever is most confident. Only one CategoricalAttribute
// class variable is supported.
type OneVsAllModel struct {
NewClassifierFunction func(string) base.Classifier
filters []*oneVsAllFilter
classifiers []base.Classifier
maxClassVal uint64
fitOn base.FixedDataGrid
classValues []string
}
// NewOneVsAllModel creates a new OneVsAllModel. The argument
// must be a function which returns a base.Classifier ready for training.
func NewOneVsAllModel(f func(string) base.Classifier) *OneVsAllModel {
return &OneVsAllModel{
f,
nil,
nil,
0,
nil,
nil,
}
}
// Fit creates n filtered datasets (where n is the number of values
// a CategoricalAttribute can take) and uses them to train the
// underlying classifiers.
func (m *OneVsAllModel) Fit(using base.FixedDataGrid) {
var classAttr *base.CategoricalAttribute
// Do some validation
classAttrs := using.AllClassAttributes()
for _, a := range classAttrs {
if c, ok := a.(*base.CategoricalAttribute); !ok {
panic("Unsupported ClassAttribute type")
} else {
classAttr = c
}
}
attrs := m.generateAttributes(using)
// Find the highest stored value
val := uint64(0)
classVals := classAttr.GetValues()
for _, s := range classVals {
cur := base.UnpackBytesToU64(classAttr.GetSysValFromString(s))
if cur > val {
val = cur
}
}
if val == 0 {
panic("Must have more than one class!")
}
m.maxClassVal = val
// If we're reloading, we may just be fitting to the structure
_, srcRows := using.Size()
fittingToStructure := srcRows == 0
// Create individual filtered instances for training
filters := make([]*oneVsAllFilter, val+1)
classifiers := make([]base.Classifier, val+1)
for i := uint64(0); i <= val; i++ {
f := &oneVsAllFilter{
attrs,
classAttr,
i,
}
filters[i] = f
classifiers[i] = m.NewClassifierFunction(classVals[int(i)])
if !fittingToStructure {
classifiers[i].Fit(base.NewLazilyFilteredInstances(using, f))
}
}
m.filters = filters
m.classifiers = classifiers
m.fitOn = base.NewStructuralCopy(using)
m.classValues = classVals
}
// Predict issues predictions. Each class-specific classifier is expected
// to output a value between 0 (indicating that a given instance is not
// a given class) and 1 (indicating that the given instance is definitely
// that class). For each instance, the class with the highest value is chosen.
// The result is undefined if several underlying models output the same value.
func (m *OneVsAllModel) Predict(what base.FixedDataGrid) (base.FixedDataGrid, error) {
ret := base.GeneratePredictionVector(what)
vecs := make([]base.FixedDataGrid, m.maxClassVal+1)
specs := make([]base.AttributeSpec, m.maxClassVal+1)
if int(m.maxClassVal) > len(m.filters) || (m.maxClassVal == 0 && len(m.filters) == 0) {
return nil, base.WrapError(fmt.Errorf("Internal error: m.Filter len = %d, maxClassVal = %d", len(m.filters), m.maxClassVal))
}
for i := uint64(0); i <= m.maxClassVal; i++ {
//log.Printf("i = %d, m.Filter len = %d, maxClassVal = %d", i, len(m.filters), m.maxClassVal)
f := m.filters[i]
c := base.NewLazilyFilteredInstances(what, f)
p, err := m.classifiers[i].Predict(c)
if err != nil {
return nil, err
}
vecs[i] = p
specs[i] = base.ResolveAttributes(p, p.AllClassAttributes())[0]
}
_, rows := ret.Size()
spec := base.ResolveAttributes(ret, ret.AllClassAttributes())[0]
for i := 0; i < rows; i++ {
class := uint64(0)
best := 0.0
for j := uint64(0); j <= m.maxClassVal; j++ {
val := base.UnpackBytesToFloat(vecs[j].Get(specs[j], i))
if val > best {
class = j
best = val
}
}
ret.Set(spec, i, base.PackU64ToBytes(class))
}
return ret, nil
}
func (m *OneVsAllModel) Load(filePath string) error {
reader, err := base.ReadSerializedClassifierStub(filePath)
if err != nil {
return err
}
err = m.LoadWithPrefix(reader, "")
reader.Close()
return err
}
func (m *OneVsAllModel) LoadWithPrefix(reader *base.ClassifierDeserializer, prefix string) error {
pI := func(n string, i int) string {
return reader.Prefix(prefix, reader.Prefix(n, fmt.Sprintf("%d", i)))
}
// Reload the instances
fitOn, err := reader.GetInstancesForKey(reader.Prefix(prefix, "INSTANCE_STRUCTURE"))
if err != nil {
return base.DescribeError("Can't load INSTANCE_STRUCTURE", err)
}
m.Fit(fitOn)
/*if err != nil {
base.DescribeError("Could not fit reloaded classifier to the structure", err)
}*/
// Reload the filters
numFiltersU64, err := reader.GetU64ForKey(reader.Prefix(prefix, "FILTER_COUNT"))
if err != nil {
return base.DescribeError("Can't load FILTER_COUNT", err)
}
m.filters = make([]*oneVsAllFilter, 0)
numFilters := int(numFiltersU64)
for i := 0; i < numFilters; i++ {
f := oneVsAllFilter{}
mapPrefix := pI("FILTER", i)
mapCountKey := reader.Prefix(mapPrefix, "COUNT")
numAttrsInMapU64, err := reader.GetU64ForKey(mapCountKey)
if err != nil {
return base.FormatError(err, "Unable to read %s", mapCountKey)
}
attrMap := make(map[base.Attribute]base.Attribute)
for j := 0; j < int(numAttrsInMapU64); j++ {
mapTupleKey := reader.Prefix(mapPrefix, fmt.Sprintf("%d"))
mapKeyKeyKey := reader.Prefix(mapTupleKey, "KEY")
mapKeyValKey := reader.Prefix(mapTupleKey, "VAL")
keyAttrRaw, err := reader.GetAttributeForKey(mapKeyKeyKey)
if err != nil {
return base.FormatError(err, "Unable to read Attr from %s", mapKeyKeyKey)
}
valAttrRaw, err := reader.GetAttributeForKey(mapKeyValKey)
if err != nil {
return base.FormatError(err, "Unable to read Attr from %s", mapKeyValKey)
}
keyAttr, err := base.ReplaceDeserializedAttributeWithVersionFromInstances(keyAttrRaw, m.fitOn)
if err != nil {
return base.FormatError(err, "Can't resolve this attribute: %s", keyAttrRaw)
}
attrMap[keyAttr] = valAttrRaw
}
f.attrs = attrMap
mapClassKey := reader.Prefix(mapPrefix, "CLASS_ATTR")
classAttrRaw, err := reader.GetAttributeForKey(mapClassKey)
if err != nil {
return base.FormatError(err, "Can't read from: %s", mapClassKey)
}
classAttr, err := base.ReplaceDeserializedAttributeWithVersionFromInstances(classAttrRaw, m.fitOn)
if err != nil {
return base.FormatError(err, "Can't resolve: %s", classAttr)
}
f.classAttr = classAttr
classAttrValKey := reader.Prefix(mapPrefix, "CLASS_VAL")
classVal, err := reader.GetU64ForKey(classAttrValKey)
if err != nil {
return base.FormatError(err, "Can't read from: %s", classAttrValKey)
}
f.classAttrVal = classVal
m.filters = append(m.filters, &f)
}
// Reload the class values
var classVals = make([]string, 0)
err = reader.GetJSONForKey(reader.Prefix(prefix, "CLASS_VALUES"), &classVals)
if err != nil {
return base.DescribeError("Can't read CLASS_VALUES", err)
}
m.classValues = classVals
// Reload the classifiers
m.classifiers = make([]base.Classifier, 0)
for i, c := range classVals {
cls := m.NewClassifierFunction(c)
clsPrefix := pI("CLASSIFIERS", i)
err = cls.LoadWithPrefix(reader, clsPrefix)
if err != nil {
return base.FormatError(err, "Could not reload classifier at: %s", clsPrefix)
}
m.classifiers = append(m.classifiers, cls)
}
return nil
}
func (m *OneVsAllModel) GetMetadata() base.ClassifierMetadataV1 {
return base.ClassifierMetadataV1{
FormatVersion: 1,
ClassifierName: "OneVsAllModel",
ClassifierVersion: "1.0",
ClassifierMetadata: nil,
}
}
func (m *OneVsAllModel) Save(filePath string) error {
writer, err := base.CreateSerializedClassifierStub(filePath, m.GetMetadata())
if err != nil {
return err
}
return m.SaveWithPrefix(writer, "")
}
func (m *OneVsAllModel) SaveWithPrefix(writer *base.ClassifierSerializer, prefix string) error {
pI := func(n string, i int) string {
return writer.Prefix(prefix, writer.Prefix(n, fmt.Sprintf("%d", i)))
}
// Save the instances
err := writer.WriteInstancesForKey(writer.Prefix(prefix, "INSTANCE_STRUCTURE"), m.fitOn, false)
if err != nil {
return base.DescribeError("Unable to write INSTANCE_STRUCTURE", err)
}
// Write the class values
err = writer.WriteJSONForKey(writer.Prefix(prefix, "CLASS_VALUES"), m.classValues)
if err != nil {
return base.DescribeError("Can't write CLASS_VALUES", err)
}
// Save the filters
err = writer.WriteU64ForKey(writer.Prefix(prefix, "FILTER_COUNT"), uint64(len(m.filters)))
if err != nil {
return base.DescribeError("Unable to write FILTER_COUNT", err)
}
for i, f := range m.filters {
mapPrefix := pI("FILTER", i)
mapCountKey := writer.Prefix(mapPrefix, "COUNT")
err := writer.WriteU64ForKey(mapCountKey, uint64(len(f.attrs)))
if err != nil {
return base.DescribeError("Unable to write the size of the filter map", err)
}
j := 0
for key := range f.attrs {
mapTupleKey := writer.Prefix(mapPrefix, fmt.Sprintf("%d"))
mapKeyKeyKey := writer.Prefix(mapTupleKey, "KEY")
mapKeyValKey := writer.Prefix(mapTupleKey, "VAL")
err = writer.WriteAttributeForKey(mapKeyKeyKey, key)
if err != nil {
return base.DescribeError("Unable to write filter map key", err)
}
err = writer.WriteAttributeForKey(mapKeyValKey, f.attrs[key])
if err != nil {
return base.DescribeError("Unable to write filter map value", err)
}
j++
}
mapClassKey := writer.Prefix(mapPrefix, "CLASS_ATTR")
err = writer.WriteAttributeForKey(mapClassKey, f.classAttr)
if err != nil {
return base.DescribeError("Unable to write CLASS_ATTR", err)
}
classAttrValKey := writer.Prefix(mapPrefix, "CLASS_VAL")
err = writer.WriteU64ForKey(classAttrValKey, f.classAttrVal)
if err != nil {
return base.DescribeError("Can't write CLASS_VAL", err)
}
}
// Save the classifiers
for i, c := range m.classifiers {
clsPrefix := pI("CLASSIFIERS", i)
err = c.SaveWithPrefix(writer, clsPrefix)
if err != nil {
return base.FormatError(err, "Can't save classifier for class %s", m.classValues[i])
}
}
return writer.Close()
}
func (m *OneVsAllModel) generateAttributes(from base.FixedDataGrid) map[base.Attribute]base.Attribute {
attrs := from.AllAttributes()
classAttrs := from.AllClassAttributes()
if len(classAttrs) != 1 {
panic(fmt.Errorf("Only 1 class Attribute is supported, had %d", len(classAttrs)))
}
ret := make(map[base.Attribute]base.Attribute)
for _, a := range attrs {
ret[a] = a
for _, b := range classAttrs {
if a.Equals(b) {
cur := base.NewFloatAttribute(b.GetName())
ret[a] = cur
}
}
}
return ret
}
//
// Filter implementation
//
type oneVsAllFilter struct {
attrs map[base.Attribute]base.Attribute
classAttr base.Attribute
classAttrVal uint64
}
func (f *oneVsAllFilter) AddAttribute(a base.Attribute) error {
return fmt.Errorf("Not supported")
}
func (f *oneVsAllFilter) GetAttributesAfterFiltering() []base.FilteredAttribute {
ret := make([]base.FilteredAttribute, len(f.attrs))
cnt := 0
for i := range f.attrs {
ret[cnt] = base.FilteredAttribute{i, f.attrs[i]}
cnt++
}
return ret
}
func (f *oneVsAllFilter) String() string {
return "oneVsAllFilter"
}
func (f *oneVsAllFilter) Transform(old, to base.Attribute, seq []byte) []byte {
if !old.Equals(f.classAttr) {
return seq
}
val := base.UnpackBytesToU64(seq)
if val == f.classAttrVal {
return base.PackFloatToBytes(1.0)
}
return base.PackFloatToBytes(0.0)
}
func (f *oneVsAllFilter) Train() error {
return fmt.Errorf("Unsupported")
}