From b07a272ca6fd664bcb15298a4928eed94075b9fe Mon Sep 17 00:00:00 2001 From: st-user Date: Tue, 25 May 2021 10:42:44 +0900 Subject: [PATCH] Dji Tello Halt does not terminate all the related goroutines and may wait forever when it is called multiple times Fix the issue. --- platforms/dji/tello/driver.go | 37 +++++++++++++++++++++++++----- platforms/dji/tello/driver_test.go | 14 +++++++++++ 2 files changed, 45 insertions(+), 6 deletions(-) diff --git a/platforms/dji/tello/driver.go b/platforms/dji/tello/driver.go index 77451b24..180d02f8 100644 --- a/platforms/dji/tello/driver.go +++ b/platforms/dji/tello/driver.go @@ -10,6 +10,7 @@ import ( "net" "strconv" "sync" + "sync/atomic" "time" "gobot.io/x/gobot" @@ -191,7 +192,8 @@ type Driver struct { throttle int bouncing bool gobot.Eventer - doneCh chan struct{} + doneCh chan struct{} + doneChReaderCount int32 } // NewDriver creates a driver for the Tello drone. Pass in the UDP port to use for the responses @@ -280,7 +282,10 @@ func (d *Driver) Start() error { d.cmdConn = cmdConn // handle responses + d.addDoneChReaderCount(1) go func() { + defer d.addDoneChReaderCount(-1) + d.On(d.Event(ConnectedEvent), func(interface{}) { d.SendDateTime() d.processVideo() @@ -304,13 +309,22 @@ func (d *Driver) Start() error { d.SendCommand(d.connectionString()) // send stick commands + d.addDoneChReaderCount(1) go func() { + defer d.addDoneChReaderCount(-1) + + stickCmdLoop: for { - err := d.SendStickCommand() - if err != nil { - fmt.Println("stick command error:", err) + select { + case <-d.doneCh: + break stickCmdLoop + default: + err := d.SendStickCommand() + if err != nil { + fmt.Println("stick command error:", err) + } + time.Sleep(20 * time.Millisecond) } - time.Sleep(20 * time.Millisecond) } }() @@ -321,7 +335,11 @@ func (d *Driver) Start() error { func (d *Driver) Halt() (err error) { // send a landing command when we disconnect, and give it 500ms to be received before we shutdown d.Land() - d.doneCh <- struct{}{} + readerCount := atomic.LoadInt32(&d.doneChReaderCount) + for i := 0; i < int(readerCount); i++ { + d.doneCh <- struct{}{} + } + time.Sleep(500 * time.Millisecond) d.cmdConn.Close() @@ -946,7 +964,10 @@ func (d *Driver) processVideo() error { return err } + d.addDoneChReaderCount(1) go func() { + defer d.addDoneChReaderCount(-1) + videoConnLoop: for { select { @@ -989,6 +1010,10 @@ func (d *Driver) connectionString() string { return res } +func (d *Driver) addDoneChReaderCount(delta int32) { + atomic.AddInt32(&d.doneChReaderCount, delta) +} + func (f *FlightData) AirSpeed() float64 { return math.Sqrt( math.Pow(float64(f.NorthSpeed), 2) + diff --git a/platforms/dji/tello/driver_test.go b/platforms/dji/tello/driver_test.go index 18a584d6..5d8cfc34 100644 --- a/platforms/dji/tello/driver_test.go +++ b/platforms/dji/tello/driver_test.go @@ -151,17 +151,29 @@ func TestHaltShouldTerminateAllTheRelatedGoroutines(t *testing.T) { var wg sync.WaitGroup wg.Add(3) + + d.addDoneChReaderCount(1) go func() { + defer d.addDoneChReaderCount(-1) + <-d.doneCh wg.Done() fmt.Println("Done routine 1.") }() + + d.addDoneChReaderCount(1) go func() { + defer d.addDoneChReaderCount(-1) + <-d.doneCh wg.Done() fmt.Println("Done routine 2.") }() + + d.addDoneChReaderCount(1) go func() { + defer d.addDoneChReaderCount(-1) + <-d.doneCh wg.Done() fmt.Println("Done routine 3.") @@ -169,6 +181,8 @@ func TestHaltShouldTerminateAllTheRelatedGoroutines(t *testing.T) { d.Halt() wg.Wait() + + gobottest.Assert(t, d.doneChReaderCount, int32(0)) } func TestHaltNotWaitForeverWhenCalledMultipleTimes(t *testing.T) {