diff --git a/platforms/sphero/sphero_driver.go b/platforms/sphero/sphero_driver.go index 6d018940..d528dceb 100644 --- a/platforms/sphero/sphero_driver.go +++ b/platforms/sphero/sphero_driver.go @@ -106,11 +106,16 @@ func (s *SpheroDriver) Start() bool { header := s.readHeader() if header != nil && len(header) != 0 { body := s.readBody(header[4]) - if header[1] == 0xFE { - async := append(header, body...) - s.asyncResponse = append(s.asyncResponse, async) - } else { - s.responseChannel <- append(header, body...) + data := append(header, body...) + checksum := data[len(data)-1] + if checksum != calculateChecksum(data[2:len(data)-1]) { + continue + } + switch header[1] { + case 0xFE: + s.asyncResponse = append(s.asyncResponse, data) + case 0xFF: + s.responseChannel <- data } } } @@ -230,7 +235,10 @@ func (s *SpheroDriver) write(packet *packet) { func (s *SpheroDriver) calculateChecksum(packet *packet) uint8 { buf := append(packet.header, packet.body...) - buf = buf[2:] + return calculateChecksum(buf[2:]) +} + +func calculateChecksum(buf []byte) byte { var calculatedChecksum uint16 for i := range buf { calculatedChecksum += uint16(buf[i]) diff --git a/platforms/sphero/sphero_driver_test.go b/platforms/sphero/sphero_driver_test.go index 47f34b2e..3d7678db 100644 --- a/platforms/sphero/sphero_driver_test.go +++ b/platforms/sphero/sphero_driver_test.go @@ -1,8 +1,9 @@ package sphero import ( - "github.com/hybridgroup/gobot" "testing" + + "github.com/hybridgroup/gobot" ) func initTestSpheroDriver() *SpheroDriver { @@ -20,3 +21,20 @@ func TestSpheroDriverHalt(t *testing.T) { d := initTestSpheroDriver() gobot.Assert(t, d.Halt(), true) } + +func TestCalculateChecksum(t *testing.T) { + tests := []struct { + data []byte + checksum byte + }{ + {[]byte{0x00}, 0xff}, + {[]byte{0xf0, 0x0f}, 0x00}, + } + + for _, tt := range tests { + actual := calculateChecksum(tt.data) + if actual != tt.checksum { + t.Errorf("Expected %x, got %x for data %x.", tt.checksum, actual, tt.data) + } + } +}