Skip to content

Commit

Permalink
AutoReconnect and other fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
martonp authored and JoeGruffins committed Sep 26, 2024
1 parent 23f73db commit d2dc7fb
Show file tree
Hide file tree
Showing 10 changed files with 132 additions and 150 deletions.
3 changes: 1 addition & 2 deletions client/cmd/testbinance/harness_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,6 @@ func printThing(name string, thing interface{}) {

func newWebsocketClient(ctx context.Context, uri string, handler func([]byte)) (comms.WsConn, *dex.ConnectionMaster, error) {
wsClient, err := comms.NewWsConn(&comms.WsCfg{
URL: uri,
PingWait: pongWait,
Logger: dex.StdOutLogger("W", log.Level()),
RawHandler: handler,
Expand All @@ -106,7 +105,7 @@ func newWebsocketClient(ctx context.Context, uri string, handler func([]byte)) (
if err != nil {
return nil, nil, fmt.Errorf("Error creating websocket client: %v", err)
}
wsCM := dex.NewConnectionMaster(wsClient)
wsCM := dex.NewConnectionMaster(uri, wsClient)
if err = wsCM.ConnectOnce(ctx); err != nil {
return nil, nil, fmt.Errorf("Error connecting websocket client: %v", err)
}
Expand Down
79 changes: 54 additions & 25 deletions client/comms/wsconn.go
Original file line number Diff line number Diff line change
Expand Up @@ -98,6 +98,7 @@ type WsConn interface {
RequestWithTimeout(msg *msgjson.Message, respHandler func(*msgjson.Message), expireTime time.Duration, expire func()) error
Connect(ctx context.Context) (*sync.WaitGroup, error)
MessageSource() <-chan *msgjson.Message
UpdateURL(string)
}

// When the DEX sends a request to the client, a responseHandler is created
Expand All @@ -110,14 +111,15 @@ type responseHandler struct {

// WsCfg is the configuration struct for initializing a WsConn.
type WsCfg struct {
// URL is the websocket endpoint URL.
URL string

// The maximum time to wait for a ping from the server. This should be
// larger than the server's ping interval to allow for network latency.
PingWait time.Duration

// AutoReconnect, if non-nil, is the interval after which the connection's
// read loop exits and signals a reconnect to the server. The interval
// restarts with each new connection.
AutoReconnect *time.Duration

// The server's certificate.
Cert []byte

Expand Down Expand Up @@ -161,6 +163,7 @@ type wsConn struct {
cfg *WsCfg
tlsCfg *tls.Config
readCh chan *msgjson.Message
URL atomic.Value // string

wsMtx sync.Mutex
ws *websocket.Conn
Expand All @@ -176,12 +179,12 @@ type wsConn struct {
var _ WsConn = (*wsConn)(nil)

// NewWsConn creates a client websocket connection.
func NewWsConn(cfg *WsCfg) (WsConn, error) {
func NewWsConn(rawURL string, cfg *WsCfg) (WsConn, error) {
if cfg.PingWait < 0 {
return nil, fmt.Errorf("ping wait cannot be negative")
}

uri, err := url.Parse(cfg.URL)
uri, err := url.Parse(rawURL)
if err != nil {
return nil, fmt.Errorf("error parsing URL: %w", err)
}
Expand All @@ -203,14 +206,22 @@ func NewWsConn(cfg *WsCfg) (WsConn, error) {
ServerName: uri.Hostname(),
}

return &wsConn{
conn := &wsConn{
cfg: cfg,
log: cfg.Logger,
tlsCfg: tlsConfig,
readCh: make(chan *msgjson.Message, readBuffSize),
respHandlers: make(map[uint64]*responseHandler),
reconnectCh: make(chan struct{}, 1),
}, nil
}
conn.URL.Store(rawURL)

return conn, nil
}

// UpdateURL replaces the URL that the connection dials with rawURL. The value
// is only stored here; it is read the next time the connection is dialed
// (e.g. on a reconnect), so an already-established connection is not
// affected. rawURL is not parsed or validated by this method.
func (conn *wsConn) UpdateURL(rawURL string) {
conn.URL.Store(rawURL)
}

// IsDown indicates if the connection is known to be down.
Expand Down Expand Up @@ -240,7 +251,7 @@ func (conn *wsConn) connect(ctx context.Context) error {
dialer.Proxy = http.ProxyFromEnvironment
}

ws, _, err := dialer.DialContext(ctx, conn.cfg.URL, conn.cfg.ConnectHeaders)
ws, _, err := dialer.DialContext(ctx, conn.URL.Load().(string), conn.cfg.ConnectHeaders)
if err != nil {
if isErrorInvalidCert(err) {
conn.setConnectionStatus(InvalidCert)
Expand Down Expand Up @@ -303,9 +314,9 @@ func (conn *wsConn) connect(ctx context.Context) error {
go func() {
defer conn.wg.Done()
if conn.cfg.RawHandler != nil {
conn.readRaw(ctx)
conn.readRaw(ctx, ws)
} else {
conn.read(ctx)
conn.read(ctx, ws)
}
}()

Expand All @@ -331,7 +342,7 @@ func (conn *wsConn) handleReadError(err error) {

var netErr net.Error
if errors.As(err, &netErr) && netErr.Timeout() {
conn.log.Errorf("Read timeout on connection to %s.", conn.cfg.URL)
conn.log.Errorf("Read timeout on connection to %s.", conn.URL.Load().(string))
reconnect()
return
}
Expand Down Expand Up @@ -372,12 +383,21 @@ func (conn *wsConn) close() {
conn.ws.Close()
}

func (conn *wsConn) readRaw(ctx context.Context) {
func (conn *wsConn) readRaw(ctx context.Context, ws *websocket.Conn) {
var reconnectTimer <-chan time.Time
if conn.cfg.AutoReconnect != nil {
reconnectTimer = time.After(*conn.cfg.AutoReconnect)
}

for {
// Lock since conn.ws may be set by connect.
conn.wsMtx.Lock()
ws := conn.ws
conn.wsMtx.Unlock()
if conn.cfg.AutoReconnect != nil {
select {
case <-reconnectTimer:
conn.reconnectCh <- struct{}{}
return
default:
}
}

// Block until a message is received or an error occurs.
_, msgBytes, err := ws.ReadMessage()
Expand All @@ -389,20 +409,30 @@ func (conn *wsConn) readRaw(ctx context.Context) {
conn.handleReadError(err)
return
}

conn.cfg.RawHandler(msgBytes)
}
}

// read fetches and parses incoming messages for processing. This should be
// run as a goroutine. Increment the wg before calling read.
func (conn *wsConn) read(ctx context.Context) {
func (conn *wsConn) read(ctx context.Context, ws *websocket.Conn) {
var reconnectTimer <-chan time.Time
if conn.cfg.AutoReconnect != nil {
reconnectTimer = time.After(*conn.cfg.AutoReconnect)
}

for {
msg := new(msgjson.Message)
if conn.cfg.AutoReconnect != nil {
select {
case <-reconnectTimer:
conn.reconnectCh <- struct{}{}
return
default:
}
}

// Lock since conn.ws may be set by connect.
conn.wsMtx.Lock()
ws := conn.ws
conn.wsMtx.Unlock()
msg := new(msgjson.Message)

// The read itself does not require locking since only this goroutine
// uses read functions that are not safe for concurrent use.
Expand Down Expand Up @@ -457,11 +487,11 @@ func (conn *wsConn) keepAlive(ctx context.Context) {
return
}

conn.log.Infof("Attempting to reconnect to %s...", conn.cfg.URL)
conn.log.Infof("Attempting to reconnect to %s...", conn.URL.Load().(string))
err := conn.connect(ctx)
if err != nil {
conn.log.Errorf("Reconnect failed. Scheduling reconnect to %s in %.1f seconds.",
conn.cfg.URL, rcInt.Seconds())
conn.URL.Load().(string), rcInt.Seconds())
time.AfterFunc(rcInt, func() {
conn.reconnectCh <- struct{}{}
})
Expand All @@ -479,7 +509,6 @@ func (conn *wsConn) keepAlive(ctx context.Context) {
if conn.cfg.ReconnectSync != nil {
conn.cfg.ReconnectSync()
}

case <-ctx.Done():
return
}
Expand Down
3 changes: 1 addition & 2 deletions client/comms/wsconn_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -242,12 +242,11 @@ func TestWsConn(t *testing.T) {
const pingWait = 500 * time.Millisecond
setupWsConn := func(cert []byte) (*wsConn, error) {
cfg := &WsCfg{
URL: "wss://" + host + "/ws",
PingWait: pingWait,
Cert: cert,
Logger: tLogger,
}
conn, err := NewWsConn(cfg)
conn, err := NewWsConn("wss://"+host+"/ws", cfg)
if err != nil {
return nil, err
}
Expand Down
6 changes: 2 additions & 4 deletions client/core/core.go
Original file line number Diff line number Diff line change
Expand Up @@ -1471,7 +1471,7 @@ type Core struct {

seedGenerationTime uint64

wsConstructor func(*comms.WsCfg) (comms.WsConn, error)
wsConstructor func(string, *comms.WsCfg) (comms.WsConn, error)
newCrypter func([]byte) encrypt.Crypter
reCrypter func([]byte, []byte) (encrypt.Crypter, error)
latencyQ *wait.TickerQueue
Expand Down Expand Up @@ -8110,7 +8110,6 @@ func (c *Core) newDEXConnection(acctInfo *db.AccountInfo, flag connectDEXFlag) (
}

wsCfg := comms.WsCfg{
URL: wsURL.String(),
PingWait: 20 * time.Second, // larger than server's pingPeriod (server/comms/server.go)
Cert: acctInfo.Cert,
Logger: c.log.SubLogger(wsURL.String()),
Expand All @@ -8126,7 +8125,6 @@ func (c *Core) newDEXConnection(acctInfo *db.AccountInfo, flag connectDEXFlag) (
proxyAddr = c.cfg.Onion

wsURL.Scheme = "ws"
wsCfg.URL = wsURL.String()
}
proxy := &socks.Proxy{
Addr: proxyAddr,
Expand All @@ -8143,7 +8141,7 @@ func (c *Core) newDEXConnection(acctInfo *db.AccountInfo, flag connectDEXFlag) (
}

// Create a websocket "connection" to the server. (Don't actually connect.)
conn, err := c.wsConstructor(&wsCfg)
conn, err := c.wsConstructor(wsURL.String(), &wsCfg)
if err != nil {
return nil, err
}
Expand Down
4 changes: 2 additions & 2 deletions client/core/core_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1325,7 +1325,7 @@ func newTestRig() *testRig {
blockWaiters: make(map[string]*blockWaiter),
sentCommits: make(map[order.Commitment]chan struct{}),
tickSched: make(map[order.OrderID]*time.Timer),
wsConstructor: func(*comms.WsCfg) (comms.WsConn, error) {
wsConstructor: func(string, *comms.WsCfg) (comms.WsConn, error) {
// This is not very realistic since it doesn't start a fresh
// one, and (*Core).connectDEX always gets the same TWebsocket,
// which may have been previously "disconnected".
Expand Down Expand Up @@ -2592,7 +2592,7 @@ func TestConnectDEX(t *testing.T) {

// Constructor error.
ogConstructor := tCore.wsConstructor
tCore.wsConstructor = func(*comms.WsCfg) (comms.WsConn, error) {
tCore.wsConstructor = func(string, *comms.WsCfg) (comms.WsConn, error) {
return nil, tErr
}
_, err = tCore.connectDEX(ai)
Expand Down
Loading

0 comments on commit d2dc7fb

Please sign in to comment.