package meta

import (
    "fmt"

    "github.com/sjwhitworth/golearn/base"
)

// OneVsAllModel replaces class Attributes with numeric versions
// and trains n wrapped classifiers. The actual class is chosen
// by whichever is most confident. Only one CategoricalAttribute
// class variable is supported.
type OneVsAllModel struct {
    NewClassifierFunction func(string) base.Classifier
    filters               []*oneVsAllFilter
    classifiers           []base.Classifier
    maxClassVal           uint64
    fitOn                 base.FixedDataGrid
    classValues           []string
}
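
// Illustration: with class values {"a", "b", "c"}, Fit builds three filtered
// views of the training data, each replacing the class with a FloatAttribute
// that is 1.0 for the target value and 0.0 otherwise, and trains one wrapped
// classifier per value; Predict then reports the class whose classifier
// outputs the highest value.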

// NewOneVsAllModel creates a new OneVsAllModel. The argument
// must be a function which returns a base.Classifier ready for training.
func NewOneVsAllModel(f func(string) base.Classifier) *OneVsAllModel {
    return &OneVsAllModel{
        f,
        nil,
        nil,
        0,
        nil,
        nil,
    }
}
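
// A minimal usage sketch (newBinaryClassifier is a hypothetical helper that
// returns a fresh, untrained base.Classifier; trainData and testData are
// base.FixedDataGrid values):
//
//    model := NewOneVsAllModel(func(classValue string) base.Classifier {
//        return newBinaryClassifier()
//    })
//    model.Fit(trainData)
//    predictions, err := model.Predict(testData)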

// Fit creates n filtered datasets (where n is the number of values
// a CategoricalAttribute can take) and uses them to train the
// underlying classifiers.
func (m *OneVsAllModel) Fit(using base.FixedDataGrid) {
    var classAttr *base.CategoricalAttribute
    // Do some validation
    classAttrs := using.AllClassAttributes()
    for _, a := range classAttrs {
        if c, ok := a.(*base.CategoricalAttribute); !ok {
            panic("Unsupported ClassAttribute type")
        } else {
            classAttr = c
        }
    }
    attrs := m.generateAttributes(using)

    // Find the highest stored value
    val := uint64(0)
    classVals := classAttr.GetValues()
    for _, s := range classVals {
        cur := base.UnpackBytesToU64(classAttr.GetSysValFromString(s))
        if cur > val {
            val = cur
        }
    }
    if val == 0 {
        panic("Must have more than one class!")
    }
    m.maxClassVal = val

    // Create individual filtered instances for training
    filters := make([]*oneVsAllFilter, val+1)
    classifiers := make([]base.Classifier, val+1)
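    // One filter and one wrapped classifier per class value, indexed by the
    // class value's packed uint64 system representation.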
    for i := uint64(0); i <= val; i++ {
        f := &oneVsAllFilter{
            attrs,
            classAttr,
            i,
        }
        filters[i] = f
        classifiers[i] = m.NewClassifierFunction(classVals[int(i)])
        classifiers[i].Fit(base.NewLazilyFilteredInstances(using, f))
    }

    m.filters = filters
    m.classifiers = classifiers
    m.fitOn = base.NewStructuralCopy(using)
    m.classValues = classVals
}

// Predict issues predictions. Each class-specific classifier is expected
// to output a value between 0 (indicating that a given instance is not
// a given class) and 1 (indicating that the given instance is definitely
// that class). For each instance, the class with the highest value is chosen.
// The result is undefined if several underlying models output the same value.
func (m *OneVsAllModel) Predict(what base.FixedDataGrid) (base.FixedDataGrid, error) {
    ret := base.GeneratePredictionVector(what)
    vecs := make([]base.FixedDataGrid, m.maxClassVal+1)
    specs := make([]base.AttributeSpec, m.maxClassVal+1)
    for i := uint64(0); i <= m.maxClassVal; i++ {
        f := m.filters[i]
        c := base.NewLazilyFilteredInstances(what, f)
        p, err := m.classifiers[i].Predict(c)
        if err != nil {
            return nil, err
        }
        vecs[i] = p
        specs[i] = base.ResolveAttributes(p, p.AllClassAttributes())[0]
    }
    _, rows := ret.Size()
    spec := base.ResolveAttributes(ret, ret.AllClassAttributes())[0]
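    // For each row, take the argmax over the per-class outputs: the class
    // whose one-vs-rest classifier reported the highest value wins the row.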
    for i := 0; i < rows; i++ {
        class := uint64(0)
        best := 0.0
        for j := uint64(0); j <= m.maxClassVal; j++ {
            val := base.UnpackBytesToFloat(vecs[j].Get(specs[j], i))
            if val > best {
                class = j
                best = val
            }
        }
        ret.Set(spec, i, base.PackU64ToBytes(class))
    }
    return ret, nil
}

// Load restores a previously saved OneVsAllModel from the file at filePath.
func (m *OneVsAllModel) Load(filePath string) error {
    reader, err := base.ReadSerializedClassifierStub(filePath)
    if err != nil {
        return err
    }

    err = m.LoadWithPrefix(reader, "")
    reader.Close()
    return err
}

// LoadWithPrefix restores the model's state from reader, reading every key
// relative to the given prefix.
func (m *OneVsAllModel) LoadWithPrefix(reader *base.ClassifierDeserializer, prefix string) error {

    // pI builds the prefixed key for the i-th item stored under n.
    pI := func(n string, i int) string {
        return reader.Prefix(prefix, reader.Prefix(n, fmt.Sprintf("%d", i)))
    }

    // Reload the instances
    fitOn, err := reader.GetInstancesForKey(reader.Prefix(prefix, "INSTANCE_STRUCTURE"))
    if err != nil {
        return base.DescribeError("Can't load INSTANCE_STRUCTURE", err)
    }
    m.fitOn = fitOn

    // Reload the filters
    numFiltersU64, err := reader.GetU64ForKey(reader.Prefix(prefix, "FILTER_COUNT"))
    if err != nil {
        return base.DescribeError("Can't load FILTER_COUNT", err)
    }
    m.filters = make([]*oneVsAllFilter, 0)
    numFilters := int(numFiltersU64)
    for i := 0; i < numFilters; i++ {
        f := oneVsAllFilter{}

        mapPrefix := pI(reader.Prefix(prefix, "FILTER"), i)
        mapCountKey := reader.Prefix(mapPrefix, "COUNT")
        numAttrsInMapU64, err := reader.GetU64ForKey(mapCountKey)
        if err != nil {
            return base.FormatError(err, "Unable to read %s", mapCountKey)
        }

        attrMap := make(map[base.Attribute]base.Attribute)

        for j := 0; j < int(numAttrsInMapU64); j++ {
            mapTupleKey := pI(mapPrefix, j)
            mapKeyKeyKey := reader.Prefix(mapTupleKey, "KEY")
            mapKeyValKey := reader.Prefix(mapTupleKey, "VAL")

            keyAttrRaw, err := reader.GetAttributeForKey(mapKeyKeyKey)
            if err != nil {
                return base.FormatError(err, "Unable to read Attr from %s", mapKeyKeyKey)
            }
            valAttrRaw, err := reader.GetAttributeForKey(mapKeyValKey)
            if err != nil {
                return base.FormatError(err, "Unable to read Attr from %s", mapKeyValKey)
            }

            keyAttr, err := base.ReplaceDeserializedAttributeWithVersionFromInstances(keyAttrRaw, m.fitOn)
            if err != nil {
                return base.FormatError(err, "Can't resolve this attribute: %s", keyAttrRaw)
            }

            attrMap[keyAttr] = valAttrRaw
        }
        f.attrs = attrMap
        mapClassKey := reader.Prefix(mapPrefix, "CLASS_ATTR")
        classAttrRaw, err := reader.GetAttributeForKey(mapClassKey)
        if err != nil {
            return base.FormatError(err, "Can't read from: %s", mapClassKey)
        }
        classAttr, err := base.ReplaceDeserializedAttributeWithVersionFromInstances(classAttrRaw, m.fitOn)
        if err != nil {
            return base.FormatError(err, "Can't resolve: %s", classAttrRaw)
        }
        f.classAttr = classAttr

        classAttrValKey := reader.Prefix(mapPrefix, "CLASS_VAL")
        classVal, err := reader.GetU64ForKey(classAttrValKey)
        if err != nil {
            return base.FormatError(err, "Can't read from: %s", classAttrValKey)
        }
        f.classAttrVal = classVal
        m.filters = append(m.filters, &f)
    }
    // Reload the class values
    var classVals = make([]string, 0)
    err = reader.GetJSONForKey(reader.Prefix(prefix, "CLASS_VALUES"), &classVals)
    if err != nil {
        return base.DescribeError("Can't read CLASS_VALUES", err)
    }
    m.classValues = classVals

    // Reload the classifiers
    m.classifiers = make([]base.Classifier, 0)
    for i, c := range classVals {
        cls := m.NewClassifierFunction(c)
        clsPrefix := pI(reader.Prefix(prefix, "CLASSIFIERS"), i)

        err = cls.LoadWithPrefix(reader, clsPrefix)
        if err != nil {
            return base.FormatError(err, "Could not reload classifier at: %s", clsPrefix)
        }
        m.classifiers = append(m.classifiers, cls)
    }

    return nil
}

// GetMetadata returns the serialization metadata for this model.
func (m *OneVsAllModel) GetMetadata() base.ClassifierMetadataV1 {
    return base.ClassifierMetadataV1{
        FormatVersion:      1,
        ClassifierName:     "OneVsAllModel",
        ClassifierVersion:  "1.0",
        ClassifierMetadata: nil,
    }
}

// Save writes this model and its wrapped classifiers to the file at filePath.
func (m *OneVsAllModel) Save(filePath string) error {
    writer, err := base.CreateSerializedClassifierStub(filePath, m.GetMetadata())
    if err != nil {
        return err
    }
    return m.SaveWithPrefix(writer, "")
}
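
// Key layout shared by SaveWithPrefix and LoadWithPrefix (keys are composed
// with the serializer's Prefix helper relative to the caller's prefix, so the
// exact separator is up to that helper):
//
//    INSTANCE_STRUCTURE        structural copy of the training set
//    CLASS_VALUES              JSON-encoded list of class value strings
//    FILTER_COUNT              number of per-class filters
//    FILTER <i> COUNT          size of filter i's attribute map
//    FILTER <i> <j> KEY/VAL    the j-th attribute mapping in filter i
//    FILTER <i> CLASS_ATTR     filter i's class Attribute
//    FILTER <i> CLASS_VAL      filter i's target class value
//    CLASSIFIERS <i>           prefix under which classifier i saves itself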

// SaveWithPrefix writes the model's state under the given prefix, then closes
// the writer.
func (m *OneVsAllModel) SaveWithPrefix(writer *base.ClassifierSerializer, prefix string) error {

    // pI builds the prefixed key for the i-th item stored under n.
    pI := func(n string, i int) string {
        return writer.Prefix(prefix, writer.Prefix(n, fmt.Sprintf("%d", i)))
    }

    // Save the instances
    err := writer.WriteInstancesForKey(writer.Prefix(prefix, "INSTANCE_STRUCTURE"), m.fitOn, false)
    if err != nil {
        return base.DescribeError("Unable to write INSTANCE_STRUCTURE", err)
    }

    // Write the class values
    err = writer.WriteJSONForKey(writer.Prefix(prefix, "CLASS_VALUES"), m.classValues)
    if err != nil {
        return base.DescribeError("Can't write CLASS_VALUES", err)
    }

    // Save the filters
    err = writer.WriteU64ForKey(writer.Prefix(prefix, "FILTER_COUNT"), uint64(len(m.filters)))
    if err != nil {
        return base.DescribeError("Unable to write FILTER_COUNT", err)
    }
    for i, f := range m.filters {
        mapPrefix := pI(writer.Prefix(prefix, "FILTER"), i)
        mapCountKey := writer.Prefix(mapPrefix, "COUNT")
        err := writer.WriteU64ForKey(mapCountKey, uint64(len(f.attrs)))
        if err != nil {
            return base.DescribeError("Unable to write the size of the filter map", err)
        }
        j := 0
        for key := range f.attrs {
            mapTupleKey := pI(mapPrefix, j)
            mapKeyKeyKey := writer.Prefix(mapTupleKey, "KEY")
            mapKeyValKey := writer.Prefix(mapTupleKey, "VAL")

            err = writer.WriteAttributeForKey(mapKeyKeyKey, key)
            if err != nil {
                return base.DescribeError("Unable to write filter map key", err)
            }
            err = writer.WriteAttributeForKey(mapKeyValKey, f.attrs[key])
            if err != nil {
                return base.DescribeError("Unable to write filter map value", err)
            }
            j++
        }
        mapClassKey := writer.Prefix(mapPrefix, "CLASS_ATTR")
        err = writer.WriteAttributeForKey(mapClassKey, f.classAttr)
        if err != nil {
            return base.DescribeError("Unable to write CLASS_ATTR", err)
        }
        classAttrValKey := writer.Prefix(mapPrefix, "CLASS_VAL")
        err = writer.WriteU64ForKey(classAttrValKey, f.classAttrVal)
        if err != nil {
            return base.DescribeError("Can't write CLASS_VAL", err)
        }
    }

    // Save the classifiers
    for i, c := range m.classifiers {
        clsPrefix := pI(writer.Prefix(prefix, "CLASSIFIERS"), i)
        err = c.SaveWithPrefix(writer, clsPrefix)
        if err != nil {
            // Propagate the error rather than silently discarding it.
            return base.FormatError(err, "Can't save classifier for class %s", m.classValues[i])
        }
    }

    return writer.Close()
}

// generateAttributes maps every Attribute in from to itself, except the single
// class Attribute, which is replaced by a FloatAttribute of the same name.
func (m *OneVsAllModel) generateAttributes(from base.FixedDataGrid) map[base.Attribute]base.Attribute {
    attrs := from.AllAttributes()
    classAttrs := from.AllClassAttributes()
    if len(classAttrs) != 1 {
        panic("Only 1 class Attribute is supported!")
    }
    ret := make(map[base.Attribute]base.Attribute)
    for _, a := range attrs {
        ret[a] = a
        for _, b := range classAttrs {
            if a.Equals(b) {
                cur := base.NewFloatAttribute(b.GetName())
                ret[a] = cur
            }
        }
    }
    return ret
}

//
// Filter implementation
//

// oneVsAllFilter rewrites the categorical class Attribute as a FloatAttribute
// which is 1.0 for its target class value and 0.0 for every other value; all
// other Attributes pass through unchanged.
type oneVsAllFilter struct {
    attrs        map[base.Attribute]base.Attribute
    classAttr    base.Attribute
    classAttrVal uint64
}

// AddAttribute is not supported by this filter.
func (f *oneVsAllFilter) AddAttribute(a base.Attribute) error {
    return fmt.Errorf("Not supported")
}

// GetAttributesAfterFiltering returns the Attribute substitutions this
// filter performs.
func (f *oneVsAllFilter) GetAttributesAfterFiltering() []base.FilteredAttribute {
    ret := make([]base.FilteredAttribute, len(f.attrs))
    cnt := 0
    for i := range f.attrs {
        ret[cnt] = base.FilteredAttribute{i, f.attrs[i]}
        cnt++
    }
    return ret
}

// String returns a human-readable name for this filter.
func (f *oneVsAllFilter) String() string {
    return "oneVsAllFilter"
}

// Transform passes non-class values through unchanged and binarizes the class
// value: 1.0 if it matches the filter's target class value, 0.0 otherwise.
func (f *oneVsAllFilter) Transform(old, to base.Attribute, seq []byte) []byte {
    if !old.Equals(f.classAttr) {
        return seq
    }
    val := base.UnpackBytesToU64(seq)
    if val == f.classAttrVal {
        return base.PackFloatToBytes(1.0)
    }
    return base.PackFloatToBytes(0.0)
}

// Train is not supported by this filter.
func (f *oneVsAllFilter) Train() error {
    return fmt.Errorf("Unsupported")
}