diff --git a/client.go b/client.go index eeca09a..41a69d0 100644 --- a/client.go +++ b/client.go @@ -125,9 +125,9 @@ func (disc *Client) jsonDecodeLoop(in io.Reader, outChan chan<- *discoveryMessag closeAndReportError := func(err error) { disc.statusMutex.Lock() disc.incomingMessagesError = err - disc.statusMutex.Unlock() disc.stopSync() disc.killProcess() + disc.statusMutex.Unlock() close(outChan) if err != nil { disc.logger.Errorf("Stopped decode loop: %v", err) @@ -138,11 +138,7 @@ func (disc *Client) jsonDecodeLoop(in io.Reader, outChan chan<- *discoveryMessag for { var msg discoveryMessage - if err := decoder.Decode(&msg); errors.Is(err, io.EOF) { - // This is fine :flames: we exit gracefully - closeAndReportError(nil) - return - } else if err != nil { + if err := decoder.Decode(&msg); err != nil { closeAndReportError(err) return } @@ -184,7 +180,10 @@ func (disc *Client) waitMessage(timeout time.Duration) (*discoveryMessage, error select { case msg := <-disc.incomingMessagesChan: if msg == nil { - return nil, disc.incomingMessagesError + disc.statusMutex.Lock() + err := disc.incomingMessagesError + disc.statusMutex.Unlock() + return nil, err } return msg, nil case <-time.After(timeout): @@ -239,9 +238,6 @@ func (disc *Client) runProcess() error { } func (disc *Client) killProcess() { - disc.statusMutex.Lock() - defer disc.statusMutex.Unlock() - disc.logger.Debugf("Killing discovery process") if process := disc.process; process != nil { disc.process = nil @@ -270,7 +266,9 @@ func (disc *Client) Run() (err error) { if err == nil { return } + disc.statusMutex.Lock() disc.killProcess() + disc.statusMutex.Unlock() }() if err = disc.sendCommand("HELLO 1 \"arduino-cli " + disc.userAgent + "\"\n"); err != nil { @@ -287,8 +285,6 @@ func (disc *Client) Run() (err error) { } else if msg.ProtocolVersion > 1 { return fmt.Errorf("protocol version not supported: requested 1, got %d", msg.ProtocolVersion) } - disc.statusMutex.Lock() - defer disc.statusMutex.Unlock() return nil } @@ -307,8 +303,6 @@ func (disc *Client) Start() error { } else if strings.ToUpper(msg.Message) != "OK" { return fmt.Errorf("communication out of sync, expected 'OK', received '%s'", msg.Message) } - disc.statusMutex.Lock() - defer disc.statusMutex.Unlock() return nil } @@ -348,8 +342,10 @@ func (disc *Client) Quit() { if _, err := disc.waitMessage(time.Second * 5); err != nil { disc.logger.Errorf("Quitting discovery: %s", err) } + disc.statusMutex.Lock() disc.stopSync() disc.killProcess() + disc.statusMutex.Unlock() } // List executes an enumeration of the ports and returns a list of the available @@ -377,9 +373,6 @@ func (disc *Client) List() ([]*Port, error) { // The event channel must be consumed as quickly as possible since it may block the // discovery if it becomes full. The channel size is configurable. func (disc *Client) StartSync(size int) (<-chan *Event, error) { - disc.statusMutex.Lock() - defer disc.statusMutex.Unlock() - if err := disc.sendCommand("START_SYNC\n"); err != nil { return nil, err } @@ -395,6 +388,8 @@ func (disc *Client) StartSync(size int) (<-chan *Event, error) { } // In case there is already an existing event channel in use we close it before creating a new one. + disc.statusMutex.Lock() + defer disc.statusMutex.Unlock() disc.stopSync() c := make(chan *Event, size) disc.eventChan = c diff --git a/client_test.go b/client_test.go index ce945ee..abf08ce 100644 --- a/client_test.go +++ b/client_test.go @@ -19,6 +19,7 @@ package discovery import ( "fmt" + "io" "net" "testing" "time" @@ -93,3 +94,58 @@ func TestDiscoveryStdioHandling(t *testing.T) { require.False(t, disc.Alive()) } + +func TestClient(t *testing.T) { + // Build dummy-discovery + builder, err := paths.NewProcess(nil, "go", "build") + require.NoError(t, err) + builder.SetDir("dummy-discovery") + require.NoError(t, builder.Run()) + + t.Run("WithDiscoveryCrashingOnStartup", func(t *testing.T) { + // Run client with discovery crashing on startup + cl := NewClient("1", "dummy-discovery/dummy-discovery", "--invalid") + require.ErrorIs(t, cl.Run(), io.EOF) + }) + + t.Run("WithDiscoveryCrashingWhileSendingCommands", func(t *testing.T) { + // Run client with crashing discovery after 1 second + cl := NewClient("1", "dummy-discovery/dummy-discovery", "-k") + require.NoError(t, cl.Run()) + + time.Sleep(time.Second) + + ch, err := cl.StartSync(20) + require.Error(t, err) + require.Nil(t, ch) + }) + + t.Run("WithDiscoveryCrashingWhileStreamingEvents", func(t *testing.T) { + // Run client with crashing discovery after 1 second + cl := NewClient("1", "dummy-discovery/dummy-discovery", "-k") + require.NoError(t, cl.Run()) + + ch, err := cl.StartSync(20) + require.NoError(t, err) + + time.Sleep(time.Second) + + loop: + for { + select { + case msg, ok := <-ch: + if !ok { + // Channel closed: Test passed + fmt.Println("Event channel closed") + break loop + } + fmt.Println("Recv: ", msg) + case <-time.After(time.Second): + t.Error("Crashing client did not close event channel") + break loop + } + } + + cl.Quit() + }) +} diff --git a/dummy-discovery/args/args.go b/dummy-discovery/args/args.go index 8b1423e..f7b8195 100644 --- a/dummy-discovery/args/args.go +++ b/dummy-discovery/args/args.go @@ -20,6 +20,7 @@ package args import ( "fmt" "os" + "time" ) // Tag is the current git tag @@ -38,6 +39,14 @@ func Parse() { fmt.Printf("dummy-discovery %s (build timestamp: %s)\n", Tag, Timestamp) os.Exit(0) } + if arg == "-k" { + // Emulate crashing discovery + go func() { + time.Sleep(time.Millisecond * 500) + os.Exit(1) + }() + continue + } fmt.Fprintf(os.Stderr, "invalid argument: %s\n", arg) os.Exit(1) }