Skip to content

Commit

Permalink
client: handle disconnections better (#80)
Browse files Browse the repository at this point in the history
* add TRACE log level for noisy logs

* cleanly disconnect when checking for protocol version

so server logs are more readable

* gracefully handle disconnections

1. close all the channels upon a disconnection
2. don't try to write to them if we've initiated a disconnection
3. check for ErrCloseSent from gorilla/websocket

* clean up client_test, make goroutine test even more parallel

* use -count 1

---------

Co-authored-by: Andrey Kaipov <[email protected]>
  • Loading branch information
andreykaipov and andreykaipov authored Dec 18, 2023
1 parent de30064 commit c7baaaf
Show file tree
Hide file tree
Showing 5 changed files with 124 additions and 51 deletions.
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,7 @@ Websocket server version: 5.1.0

## advanced configuration

- `GOOBS_LOG` can be set to `debug`, `info`, or `error` to better understand what our client is doing under the hood.
- `GOOBS_LOG` can be set to `trace`, `debug`, `info`, or `error` to better understand what our client is doing under the hood.

- `GOOBS_PROFILE` can be set to enable profiling.
For example, the following will help us find unreleased memory:
Expand Down
2 changes: 1 addition & 1 deletion api/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ func (c *Client) SendRequest(requestBody Params, responseBody interface{}) error
name := requestBody.GetRequestName()
id := uid.String()

c.Log.Printf("[INFO] Sending %s Request with ID %s", name, id)
c.Log.Printf("[TRACE] Sending %s Request with ID %s", name, id)

c.mutex.Lock()
defer c.mutex.Unlock()
Expand Down
77 changes: 62 additions & 15 deletions client.go
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ type Client struct {
requestHeader http.Header
eventSubscriptions int
errors chan error
disconnected chan bool
profiler *profile.Profile
}

Expand Down Expand Up @@ -89,13 +90,19 @@ close when your program terminates or interrupts. But here's a function anyways.
*/
func (c *Client) Disconnect() error {
defer func() {
close(c.errors)
close(c.Opcodes)
close(c.IncomingEvents)
close(c.IncomingResponses)

if c.profiler != nil {
c.Log.Printf("[DEBUG] Ending profiling")
c.profiler.Stop()
}
}()

c.Log.Printf("[DEBUG] Sending disconnect message")
c.disconnected <- true
return c.conn.WriteMessage(
websocket.CloseMessage,
websocket.FormatCloseMessage(websocket.CloseNormalClosure, "Bye"),
Expand All @@ -113,14 +120,15 @@ func New(host string, opts ...Option) (*Client, error) {
requestHeader: http.Header{"User-Agent": []string{"goobs/" + goobs_version}},
eventSubscriptions: subscriptions.All,
errors: make(chan error),
disconnected: make(chan bool, 1),
Client: &api.Client{
IncomingEvents: make(chan interface{}, 100),
IncomingResponses: make(chan *opcodes.RequestResponse),
Opcodes: make(chan opcodes.Opcode),
ResponseTimeout: 10000,
Log: log.New(
&logutils.LevelFilter{
Levels: []logutils.LogLevel{"DEBUG", "INFO", "ERROR", ""},
Levels: []logutils.LogLevel{"TRACE", "DEBUG", "INFO", "ERROR", ""},
MinLevel: logutils.LogLevel(strings.ToUpper(os.Getenv("GOOBS_LOG"))),
Writer: api.LoggerWithWrite(func(p []byte) (int, error) {
return os.Stderr.WriteString(fmt.Sprintf("\033[36m%s\033[0m", p))
Expand Down Expand Up @@ -199,7 +207,15 @@ func (c *Client) checkProtocolVersion() error {
if err != nil {
return err
}
defer conn.Close()
defer func() {
if err := conn.WriteMessage(
websocket.CloseMessage,
websocket.FormatCloseMessage(websocket.CloseNormalClosure, "Protocol check"),
); err != nil {
c.Log.Printf("[ERROR] Force closing initial protocol check connection", err)
_ = conn.Close()
}
}()

_ = conn.WriteMessage(
websocket.TextMessage,
Expand Down Expand Up @@ -246,23 +262,55 @@ func (c *Client) handleRawServerMessages(auth chan<- error) {
c.Log.Printf("[INFO] Closing connection: %s", t.Text)
auth <- err
default:
c.errors <- fmt.Errorf("reading raw message: closed: %w", t)
c.Log.Printf("[ERROR] Unhandled close error: %s", t.Text)
select {
case <-c.disconnected:
default:
c.errors <- fmt.Errorf("Unhandled close error: %s", t.Text)
}
}
return
default:
c.errors <- fmt.Errorf("reading raw message from websocket connection: %w", t)
continue
switch t {
case websocket.ErrCloseSent:
// this seems to only happen with highly concurrent clients reading from
// the websocket server simultaneously. but even then it's not really an
// issue, because the connection is already closed!
c.Log.Printf("[DEBUG] Tried to read from closed connection")
return
default:
select {
case <-c.disconnected:
return
default:
c.errors <- fmt.Errorf("reading raw message from websocket connection: %w", t)
continue
}
}
}
}

c.Log.Printf("[DEBUG] Raw server message: %s", raw)
c.Log.Printf("[TRACE] Raw server message: %s", raw)

select {
case <-c.disconnected:
// This might happen if the server sends messages to us
// after we've already disconnected, e.g.:
//
// 1. client sends ToggleRecordPause request
// 2. client gets the appropriate response for it
// 3. client sends disconnect message immediately after
// 4. client gets RecordStateChanged event
c.Log.Printf("[ERROR] Got %s from the server, but we've already disconnected!", raw)
return
default:
opcode, err := opcodes.ParseRawMessage(raw)
if err != nil {
c.errors <- fmt.Errorf("parse raw message: %w", err)
}

opcode, err := opcodes.ParseRawMessage(raw)
if err != nil {
c.errors <- fmt.Errorf("parse raw message: %w", err)
c.Opcodes <- opcode
}

c.Opcodes <- opcode
}
}

Expand Down Expand Up @@ -306,8 +354,7 @@ func (c *Client) handleOpcodes(auth chan<- error) {
// can't imagine we need this

case *opcodes.Event:
c.Log.Printf("[INFO] Got %s Event", val.Type)
c.Log.Printf("[DEBUG] Event Data: %s", val.Data)
c.Log.Printf("[TRACE] Got %s event: %s", val.Type, val.Data)

event := events.GetType(val.Type)

Expand All @@ -328,15 +375,15 @@ func (c *Client) handleOpcodes(auth chan<- error) {
c.writeEvent(event)

case *opcodes.Request:
c.Log.Printf("[DEBUG] Got %s Request with ID %s", val.Type, val.ID)
c.Log.Printf("[TRACE] Got %s Request with ID %s", val.Type, val.ID)

msg := opcodes.Wrap(val).Bytes()
if err := c.conn.WriteMessage(websocket.TextMessage, msg); err != nil {
c.errors <- fmt.Errorf("sending Request to server `%s`: %w", msg, err)
}

case *opcodes.RequestResponse:
c.Log.Printf("[INFO] Got %s Response for ID %s (%d)", val.Type, val.ID, val.Status.Code)
c.Log.Printf("[TRACE] Got %s Response for ID %s (%d)", val.Type, val.ID, val.Status.Code)

c.IncomingResponses <- val

Expand Down
91 changes: 58 additions & 33 deletions client_test.go
Original file line number Diff line number Diff line change
@@ -1,56 +1,81 @@
package goobs_test

import (
"fmt"
"net"
"net/http"
"os"
"sync"
"testing"
"time"

goobs "github.com/andreykaipov/goobs"
"github.com/gorilla/websocket"
"github.com/stretchr/testify/assert"
)

func Test_client(t *testing.T) {
var err error
_, err = goobs.New(
"localhost:"+os.Getenv("OBS_PORT"),
goobs.WithPassword("wrongpassword"),
goobs.WithRequestHeader(http.Header{"User-Agent": []string{"goobs-e2e/0.0.0"}}),
)
assert.Error(t, err)
assert.IsType(t, &websocket.CloseError{}, err)
assert.Equal(t, err.(*websocket.CloseError).Code, 4009)
_, err = goobs.New(
"localhost:42069",
goobs.WithPassword("wrongpassword"),
goobs.WithRequestHeader(http.Header{"User-Agent": []string{"goobs-e2e/0.0.0"}}),
)
assert.Error(t, err)
assert.IsType(t, &net.OpError{}, err)
t.Run("wrong password", func(t *testing.T) {
_, err := goobs.New(
"localhost:"+os.Getenv("OBS_PORT"),
goobs.WithPassword("wrongpassword"),
goobs.WithRequestHeader(http.Header{"User-Agent": []string{"goobs-e2e/0.0.0"}}),
)
assert.Error(t, err)
assert.IsType(t, &websocket.CloseError{}, err)
assert.Equal(t, err.(*websocket.CloseError).Code, 4009)
})

t.Run("server isn't running", func(t *testing.T) {
_, err := goobs.New(
"localhost:42069",
goobs.WithPassword("wrongpassword"),
goobs.WithRequestHeader(http.Header{"User-Agent": []string{"goobs-e2e/0.0.0"}}),
)
assert.Error(t, err)
assert.IsType(t, &net.OpError{}, err)
})

t.Run("right password", func(t *testing.T) {
client, err := goobs.New(
"localhost:"+os.Getenv("OBS_PORT"),
goobs.WithPassword("goodpassword"),
goobs.WithRequestHeader(http.Header{"User-Agent": []string{"goobs-e2e/0.0.0"}}),
)
assert.NoError(t, err)
t.Cleanup(func() {
client.Disconnect()
})
time.Sleep(1 * time.Second)
})
}

func Test_multi_goroutine(t *testing.T) {
client, err := goobs.New(
"localhost:"+os.Getenv("OBS_PORT"),
goobs.WithPassword("goodpassword"),
goobs.WithRequestHeader(http.Header{"User-Agent": []string{"goobs-e2e/0.0.0"}}),
)
assert.NoError(t, err)
t.Cleanup(func() {
client.Disconnect()
})
wg := sync.WaitGroup{}
for i := 0; i < 1000; i++ {
wg.Add(1)
go func() {
defer wg.Done()
_, err := client.Scenes.GetSceneList()
for i := 1; i <= 10; i++ {
t.Run(fmt.Sprintf("goroutine-%d", i), func(t *testing.T) {
t.Parallel()

client, err := goobs.New(
"localhost:"+os.Getenv("OBS_PORT"),
goobs.WithPassword("goodpassword"),
goobs.WithRequestHeader(http.Header{"User-Agent": []string{"goobs-e2e/0.0.0"}}),
)
assert.NoError(t, err)
}()
t.Cleanup(func() {
client.Disconnect()
})
wg := sync.WaitGroup{}
for i := 0; i < 5_000; i++ {
wg.Add(1)
go func() {
defer wg.Done()
_, err := client.Scenes.GetSceneList()
assert.NoError(t, err)
}()
}
wg.Wait()
})
}
wg.Wait()
}

func Test_profile(t *testing.T) {
Expand Down
3 changes: 2 additions & 1 deletion test.sh
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ setup() {

gotest() {
category="$1"
go test -v -run="^Test_$category$" -coverprofile=cover.out -coverpkg=./... -covermode=$covermode ./...
go test -v -run="^Test_$category$" -count 1 -coverprofile=cover.out -coverpkg=./... -covermode=$covermode ./...
awk 'NR>1' cover.out >>coverall.out
}

Expand All @@ -42,6 +42,7 @@ main() {
categories='
client
multi_goroutine
profile
config
filters
general
Expand Down

0 comments on commit c7baaaf

Please sign in to comment.