From d2dc7fb2ce34779dfee298bdea3bcf4c626ae808 Mon Sep 17 00:00:00 2001 From: martonp Date: Wed, 25 Sep 2024 18:14:43 +0200 Subject: [PATCH] AutoReconnect and other fixes --- client/cmd/testbinance/harness_test.go | 3 +- client/comms/wsconn.go | 79 +++++++---- client/comms/wsconn_test.go | 3 +- client/core/core.go | 6 +- client/core/core_test.go | 4 +- client/mm/libxc/binance.go | 175 +++++++++--------------- server/noderelay/cmd/sourcenode/main.go | 3 +- server/noderelay/noderelay_test.go | 3 +- tatanka/tcp/client/client.go | 3 +- tatanka/tcp/server.go | 3 +- 10 files changed, 132 insertions(+), 150 deletions(-) diff --git a/client/cmd/testbinance/harness_test.go b/client/cmd/testbinance/harness_test.go index 3d13f77ec1..fbf5e12a6e 100644 --- a/client/cmd/testbinance/harness_test.go +++ b/client/cmd/testbinance/harness_test.go @@ -97,7 +97,6 @@ func printThing(name string, thing interface{}) { func newWebsocketClient(ctx context.Context, uri string, handler func([]byte)) (comms.WsConn, *dex.ConnectionMaster, error) { wsClient, err := comms.NewWsConn(&comms.WsCfg{ - URL: uri, PingWait: pongWait, Logger: dex.StdOutLogger("W", log.Level()), RawHandler: handler, @@ -106,7 +105,7 @@ func newWebsocketClient(ctx context.Context, uri string, handler func([]byte)) ( if err != nil { return nil, nil, fmt.Errorf("Error creating websocket client: %v", err) } - wsCM := dex.NewConnectionMaster(wsClient) + wsCM := dex.NewConnectionMaster(uri, wsClient) if err = wsCM.ConnectOnce(ctx); err != nil { return nil, nil, fmt.Errorf("Error connecting websocket client: %v", err) } diff --git a/client/comms/wsconn.go b/client/comms/wsconn.go index 96ffabd082..b46b6ece75 100644 --- a/client/comms/wsconn.go +++ b/client/comms/wsconn.go @@ -98,6 +98,7 @@ type WsConn interface { RequestWithTimeout(msg *msgjson.Message, respHandler func(*msgjson.Message), expireTime time.Duration, expire func()) error Connect(ctx context.Context) (*sync.WaitGroup, error) MessageSource() <-chan *msgjson.Message + UpdateURL(string) } // When the DEX sends a request to the client, a responseHandler is created @@ -110,14 +111,15 @@ type responseHandler struct { // WsCfg is the configuration struct for initializing a WsConn. type WsCfg struct { - // URL is the websocket endpoint URL. - URL string - // The maximum time in seconds to wait for a ping from the server. This // should be larger than the server's ping interval to allow for network // latency. PingWait time.Duration + // AutoReconnect, if non-nil, will reconnect to the server after each + // interval of the amount of time specified. + AutoReconnect *time.Duration + // The server's certificate. Cert []byte @@ -161,6 +163,7 @@ type wsConn struct { cfg *WsCfg tlsCfg *tls.Config readCh chan *msgjson.Message + URL atomic.Value // string wsMtx sync.Mutex ws *websocket.Conn @@ -176,12 +179,12 @@ type wsConn struct { var _ WsConn = (*wsConn)(nil) // NewWsConn creates a client websocket connection. -func NewWsConn(cfg *WsCfg) (WsConn, error) { +func NewWsConn(rawURL string, cfg *WsCfg) (WsConn, error) { if cfg.PingWait < 0 { return nil, fmt.Errorf("ping wait cannot be negative") } - uri, err := url.Parse(cfg.URL) + uri, err := url.Parse(rawURL) if err != nil { return nil, fmt.Errorf("error parsing URL: %w", err) } @@ -203,14 +206,22 @@ func NewWsConn(cfg *WsCfg) (WsConn, error) { ServerName: uri.Hostname(), } - return &wsConn{ + conn := &wsConn{ cfg: cfg, log: cfg.Logger, tlsCfg: tlsConfig, readCh: make(chan *msgjson.Message, readBuffSize), respHandlers: make(map[uint64]*responseHandler), reconnectCh: make(chan struct{}, 1), - }, nil + } + conn.URL.Store(rawURL) + + return conn, nil +} + +// UpdateURL updates the URL that the connection uses when reconnecting. +func (conn *wsConn) UpdateURL(rawURL string) { + conn.URL.Store(rawURL) } // IsDown indicates if the connection is known to be down. @@ -240,7 +251,7 @@ func (conn *wsConn) connect(ctx context.Context) error { dialer.Proxy = http.ProxyFromEnvironment } - ws, _, err := dialer.DialContext(ctx, conn.cfg.URL, conn.cfg.ConnectHeaders) + ws, _, err := dialer.DialContext(ctx, conn.URL.Load().(string), conn.cfg.ConnectHeaders) if err != nil { if isErrorInvalidCert(err) { conn.setConnectionStatus(InvalidCert) @@ -303,9 +314,9 @@ func (conn *wsConn) connect(ctx context.Context) error { go func() { defer conn.wg.Done() if conn.cfg.RawHandler != nil { - conn.readRaw(ctx) + conn.readRaw(ctx, ws) } else { - conn.read(ctx) + conn.read(ctx, ws) } }() @@ -331,7 +342,7 @@ func (conn *wsConn) handleReadError(err error) { var netErr net.Error if errors.As(err, &netErr) && netErr.Timeout() { - conn.log.Errorf("Read timeout on connection to %s.", conn.cfg.URL) + conn.log.Errorf("Read timeout on connection to %s.", conn.URL.Load().(string)) reconnect() return } @@ -372,12 +383,21 @@ func (conn *wsConn) close() { conn.ws.Close() } -func (conn *wsConn) readRaw(ctx context.Context) { +func (conn *wsConn) readRaw(ctx context.Context, ws *websocket.Conn) { + var reconnectTimer <-chan time.Time + if conn.cfg.AutoReconnect != nil { + reconnectTimer = time.After(*conn.cfg.AutoReconnect) + } + for { - // Lock since conn.ws may be set by connect. - conn.wsMtx.Lock() - ws := conn.ws - conn.wsMtx.Unlock() + if conn.cfg.AutoReconnect != nil { + select { + case <-reconnectTimer: + conn.reconnectCh <- struct{}{} + return + default: + } + } // Block until a message is received or an error occurs. _, msgBytes, err := ws.ReadMessage() @@ -389,20 +409,30 @@ func (conn *wsConn) readRaw(ctx context.Context) { conn.handleReadError(err) return } + conn.cfg.RawHandler(msgBytes) } } // read fetches and parses incoming messages for processing. This should be // run as a goroutine. Increment the wg before calling read. -func (conn *wsConn) read(ctx context.Context) { +func (conn *wsConn) read(ctx context.Context, ws *websocket.Conn) { + var reconnectTimer <-chan time.Time + if conn.cfg.AutoReconnect != nil { + reconnectTimer = time.After(*conn.cfg.AutoReconnect) + } + for { - msg := new(msgjson.Message) + if conn.cfg.AutoReconnect != nil { + select { + case <-reconnectTimer: + conn.reconnectCh <- struct{}{} + return + default: + } + } - // Lock since conn.ws may be set by connect. - conn.wsMtx.Lock() - ws := conn.ws - conn.wsMtx.Unlock() + msg := new(msgjson.Message) // The read itself does not require locking since only this goroutine // uses read functions that are not safe for concurrent use. @@ -457,11 +487,11 @@ func (conn *wsConn) keepAlive(ctx context.Context) { return } - conn.log.Infof("Attempting to reconnect to %s...", conn.cfg.URL) + conn.log.Infof("Attempting to reconnect to %s...", conn.URL.Load().(string)) err := conn.connect(ctx) if err != nil { conn.log.Errorf("Reconnect failed. Scheduling reconnect to %s in %.1f seconds.", - conn.cfg.URL, rcInt.Seconds()) + conn.URL.Load().(string), rcInt.Seconds()) time.AfterFunc(rcInt, func() { conn.reconnectCh <- struct{}{} }) @@ -479,7 +509,6 @@ func (conn *wsConn) keepAlive(ctx context.Context) { if conn.cfg.ReconnectSync != nil { conn.cfg.ReconnectSync() } - case <-ctx.Done(): return } diff --git a/client/comms/wsconn_test.go b/client/comms/wsconn_test.go index 95c44b277f..1d6c7817ed 100644 --- a/client/comms/wsconn_test.go +++ b/client/comms/wsconn_test.go @@ -242,12 +242,11 @@ func TestWsConn(t *testing.T) { const pingWait = 500 * time.Millisecond setupWsConn := func(cert []byte) (*wsConn, error) { cfg := &WsCfg{ - URL: "wss://" + host + "/ws", PingWait: pingWait, Cert: cert, Logger: tLogger, } - conn, err := NewWsConn(cfg) + conn, err := NewWsConn("wss://"+host+"/ws", cfg) if err != nil { return nil, err } diff --git a/client/core/core.go b/client/core/core.go index 01e135e2d1..fca15ec900 100644 --- a/client/core/core.go +++ b/client/core/core.go @@ -1471,7 +1471,7 @@ type Core struct { seedGenerationTime uint64 - wsConstructor func(*comms.WsCfg) (comms.WsConn, error) + wsConstructor func(string, *comms.WsCfg) (comms.WsConn, error) newCrypter func([]byte) encrypt.Crypter reCrypter func([]byte, []byte) (encrypt.Crypter, error) latencyQ *wait.TickerQueue @@ -8110,7 +8110,6 @@ func (c *Core) newDEXConnection(acctInfo *db.AccountInfo, flag connectDEXFlag) ( } wsCfg := comms.WsCfg{ - URL: wsURL.String(), PingWait: 20 * time.Second, // larger than server's pingPeriod (server/comms/server.go) Cert: acctInfo.Cert, Logger: c.log.SubLogger(wsURL.String()), @@ -8126,7 +8125,6 @@ func (c *Core) newDEXConnection(acctInfo *db.AccountInfo, flag connectDEXFlag) ( proxyAddr = c.cfg.Onion wsURL.Scheme = "ws" - wsCfg.URL = wsURL.String() } proxy := &socks.Proxy{ Addr: proxyAddr, @@ -8143,7 +8141,7 @@ func (c *Core) newDEXConnection(acctInfo *db.AccountInfo, flag connectDEXFlag) ( } // Create a websocket "connection" to the server. (Don't actually connect.) - conn, err := c.wsConstructor(&wsCfg) + conn, err := c.wsConstructor(wsURL.String(), &wsCfg) if err != nil { return nil, err } diff --git a/client/core/core_test.go b/client/core/core_test.go index 8f6fa9c5b0..4b92b56d9a 100644 --- a/client/core/core_test.go +++ b/client/core/core_test.go @@ -1325,7 +1325,7 @@ func newTestRig() *testRig { blockWaiters: make(map[string]*blockWaiter), sentCommits: make(map[order.Commitment]chan struct{}), tickSched: make(map[order.OrderID]*time.Timer), - wsConstructor: func(*comms.WsCfg) (comms.WsConn, error) { + wsConstructor: func(string, *comms.WsCfg) (comms.WsConn, error) { // This is not very realistic since it doesn't start a fresh // one, and (*Core).connectDEX always gets the same TWebsocket, // which may have been previously "disconnected". @@ -2592,7 +2592,7 @@ func TestConnectDEX(t *testing.T) { // Constructor error. ogConstructor := tCore.wsConstructor - tCore.wsConstructor = func(*comms.WsCfg) (comms.WsConn, error) { + tCore.wsConstructor = func(string, *comms.WsCfg) (comms.WsConn, error) { return nil, tErr } _, err = tCore.connectDEX(ai) diff --git a/client/mm/libxc/binance.go b/client/mm/libxc/binance.go index 020d8a4818..80bd011e23 100644 --- a/client/mm/libxc/binance.go +++ b/client/mm/libxc/binance.go @@ -88,7 +88,7 @@ func newBinanceOrderBook( quoteConversionFactor: quoteConversionFactor, log: log, getSnapshot: getSnapshot, - connectedChan: make(chan bool), + connectedChan: make(chan bool, 4), } } @@ -164,7 +164,7 @@ func (b *binanceOrderBook) Connect(ctx context.Context) (*sync.WaitGroup, error resyncChan := make(chan struct{}, 1) - desync := func() { + desync := func(resync bool) { // clear the sync cache, set the special ID, trigger a book refresh. syncMtx.Lock() defer syncMtx.Unlock() @@ -173,7 +173,9 @@ func (b *binanceOrderBook) Connect(ctx context.Context) (*sync.WaitGroup, error if updateID != updateIDUnsynced { b.synced.Store(false) updateID = updateIDUnsynced - resyncChan <- struct{}{} + if resync { + resyncChan <- struct{}{} + } } } @@ -268,7 +270,7 @@ func (b *binanceOrderBook) Connect(ctx context.Context) (*sync.WaitGroup, error case update := <-b.updateQueue: if !processUpdate(update) { b.log.Tracef("Bad %s update with ID %d", b.mktID, update.LastUpdateID) - desync() + desync(true) } case <-ctx.Done(): return @@ -288,13 +290,10 @@ func (b *binanceOrderBook) Connect(ctx context.Context) (*sync.WaitGroup, error select { case <-retry: case <-resyncChan: - if retry != nil { // don't hammer - continue - } case connected := <-b.connectedChan: if !connected { - b.log.Debugf("Unsyncing %s orderbook due to disconnect.", b.mktID, retryFrequency) - desync() + b.log.Debugf("Unsyncing %s orderbook due to disconnect.", b.mktID) + desync(false) retry = nil continue } @@ -307,7 +306,7 @@ func (b *binanceOrderBook) Connect(ctx context.Context) (*sync.WaitGroup, error retry = nil } else { b.log.Infof("Failed to sync %s orderbook. Trying again in %s", b.mktID, retryFrequency) - desync() // Clears the syncCache + desync(false) // Clears the syncCache retry = time.After(retryFrequency) } } @@ -1419,8 +1418,7 @@ func (bnc *binance) getUserDataStream(ctx context.Context) (err error) { return nil, err } - conn, err := comms.NewWsConn(&comms.WsCfg{ - URL: bnc.wsURL + "/ws/" + listenKey, + conn, err := comms.NewWsConn(bnc.wsURL+"/ws/"+listenKey, &comms.WsCfg{ PingWait: time.Minute * 4, EchoPingData: true, ReconnectSync: func() { @@ -1647,12 +1645,13 @@ func (bnc *binance) subscribeToAdditionalMarketDataStream(ctx context.Context, b bnc.books[mktID] = book book.sync(ctx) + bnc.marketStream.UpdateURL(bnc.marketStreamsURL()) + return nil } +// bnc.booksMtx MUST be read locked when calling this function. func (bnc *binance) streams() []string { - bnc.booksMtx.RLock() - defer bnc.booksMtx.RUnlock() streamNames := make([]string, 0, len(bnc.books)) for mktID := range bnc.books { streamNames = append(streamNames, marketDataStreamID(mktID)) @@ -1660,13 +1659,22 @@ func (bnc *binance) streams() []string { return streamNames } +// bnc.booksMtx MUST be read locked when calling this function. +func (bnc *binance) marketStreamsURL() string { + return fmt.Sprintf("%s/stream?streams=%s", bnc.wsURL, strings.Join(bnc.streams(), "/")) +} + // checkSubs will query binance for current market subscriptions and compare // that to what subscriptions we should have. If there is a discrepancy a // warning is logged and the market subbed or unsubbed. func (bnc *binance) checkSubs(ctx context.Context) error { bnc.marketStreamMtx.Lock() defer bnc.marketStreamMtx.Unlock() + + bnc.booksMtx.RLock() streams := bnc.streams() + bnc.booksMtx.RUnlock() + if len(streams) == 0 { return nil } @@ -1746,61 +1754,9 @@ out: } // connectToMarketDataStream is called when the first market is subscribed to. -// It creates a connection to the market data stream and starts a goroutine -// to reconnect every 12 hours, as Binance will close the stream every 24 -// hours. Additional markets are subscribed to by calling +// Additional markets are subscribed to by calling // subscribeToAdditionalMarketDataStream. func (bnc *binance) connectToMarketDataStream(ctx context.Context, baseID, quoteID uint32) error { - reconnectC := make(chan struct{}) - - newConnection := func() (comms.WsConn, *dex.ConnectionMaster, error) { - addr := fmt.Sprintf("%s/stream?streams=%s", bnc.wsURL, strings.Join(bnc.streams(), "/")) - // Need to send key but not signature - connectEventFunc := func(cs comms.ConnectionStatus) { - if cs != comms.Disconnected && cs != comms.Connected { - return - } - // If disconnected, set all books to unsynced so bots - // will not place new orders. - connected := cs == comms.Connected - bnc.booksMtx.RLock() - defer bnc.booksMtx.RLock() - for _, b := range bnc.books { - b.connectedChan <- connected - } - } - conn, err := comms.NewWsConn(&comms.WsCfg{ - URL: addr, - // Binance Docs: The websocket server will send a ping frame every 3 - // minutes. If the websocket server does not receive a pong frame - // back from the connection within a 10 minute period, the connection - // will be disconnected. Unsolicited pong frames are allowed. - PingWait: time.Minute * 4, - EchoPingData: true, - ReconnectSync: func() { - bnc.log.Debugf("Binance reconnected") - select { - case reconnectC <- struct{}{}: - default: - } - }, - ConnectEventFunc: connectEventFunc, - Logger: bnc.log.SubLogger("BNCBOOK"), - RawHandler: bnc.handleMarketDataNote, - }) - if err != nil { - return nil, nil, err - } - - bnc.marketStream = conn - cm := dex.NewConnectionMaster(conn) - if err = cm.ConnectOnce(ctx); err != nil { - return nil, nil, fmt.Errorf("websocketHandler remote connect: %v", err) - } - - return conn, cm, nil - } - // Add the initial book to the books map baseCfg, quoteCfg, err := bncAssetCfgs(baseID, quoteID) if err != nil { @@ -1813,60 +1769,64 @@ func (bnc *binance) connectToMarketDataStream(ctx context.Context, baseID, quote } book := newBinanceOrderBook(baseCfg.conversionFactor, quoteCfg.conversionFactor, mktID, getSnapshot, bnc.log) bnc.books[mktID] = book + marketStreamsURL := bnc.marketStreamsURL() bnc.booksMtx.Unlock() - // Create initial connection to the market data stream - conn, cm, err := newConnection() + // Need to send key but not signature + connectEventFunc := func(cs comms.ConnectionStatus) { + if cs != comms.Disconnected && cs != comms.Connected { + return + } + + // If disconnected, set all books to unsynced so bots + // will not place new orders. + connected := cs == comms.Connected + + bnc.booksMtx.RLock() + defer bnc.booksMtx.RUnlock() + + for _, b := range bnc.books { + select { + case b.connectedChan <- connected: + default: // don't block + } + } + } + + reconnectInterval := time.Hour * 12 + conn, err := comms.NewWsConn(marketStreamsURL, &comms.WsCfg{ + // Binance Docs: The websocket server will send a ping frame every 3 + // minutes. If the websocket server does not receive a pong frame + // back from the connection within a 10 minute period, the connection + // will be disconnected. Unsolicited pong frames are allowed. + PingWait: time.Minute * 4, + EchoPingData: true, + ReconnectSync: func() { + bnc.log.Debugf("Binance reconnected") + }, + ConnectEventFunc: connectEventFunc, + Logger: bnc.log.SubLogger("BNCBOOK"), + RawHandler: bnc.handleMarketDataNote, + AutoReconnect: &reconnectInterval, + }) if err != nil { - return fmt.Errorf("error connecting to market data stream : %v", err) + return err + } + + cm := dex.NewConnectionMaster(conn) + if err = cm.ConnectOnce(ctx); err != nil { + return fmt.Errorf("websocketHandler remote connect: %v", err) } bnc.marketStream = conn book.sync(ctx) - // Start a goroutine to reconnect every 12 hours go func() { - reconnect := func() error { - bnc.marketStreamMtx.Lock() - defer bnc.marketStreamMtx.Unlock() - - oldCm := cm - conn, cm, err = newConnection() - if err != nil { - return err - } - - if oldCm != nil { - oldCm.Disconnect() - } - - bnc.marketStream = conn - return nil - } - checkSubsInterval := time.Minute checkSubs := time.After(checkSubsInterval) - reconnectTimer := time.After(time.Hour * 12) for { select { - case <-reconnectC: - if err := reconnect(); err != nil { - bnc.log.Errorf("Error reconnecting: %v", err) - reconnectTimer = time.After(time.Second * 30) - checkSubs = make(<-chan time.Time) - continue - } - checkSubs = time.After(checkSubsInterval) - case <-reconnectTimer: - if err := reconnect(); err != nil { - bnc.log.Errorf("Error refreshing connection: %v", err) - reconnectTimer = time.After(time.Second * 30) - checkSubs = make(<-chan time.Time) - continue - } - reconnectTimer = time.After(time.Hour * 12) - checkSubs = time.After(checkSubsInterval) case <-checkSubs: if err := bnc.checkSubs(ctx); err != nil { bnc.log.Errorf("Error checking subscriptions: %v", err) @@ -1934,6 +1894,7 @@ func (bnc *binance) UnsubscribeMarket(baseID, quoteID uint32) (err error) { unsubscribe = true delete(bnc.books, mktID) closer = book.cm + bnc.marketStream.UpdateURL(bnc.marketStreamsURL()) } book.mtx.Unlock() diff --git a/server/noderelay/cmd/sourcenode/main.go b/server/noderelay/cmd/sourcenode/main.go index 50a909b874..15399d77d2 100644 --- a/server/noderelay/cmd/sourcenode/main.go +++ b/server/noderelay/cmd/sourcenode/main.go @@ -196,8 +196,7 @@ func mainErr() (err error) { } } - cl, err = comms.NewWsConn(&comms.WsCfg{ - URL: "wss://" + nexusAddr, + cl, err = comms.NewWsConn("wss://"+nexusAddr, &comms.WsCfg{ PingWait: noderelay.PingPeriod * 2, Cert: certB, // On a disconnect, wsConn will attempt to reconnect immediately. If diff --git a/server/noderelay/noderelay_test.go b/server/noderelay/noderelay_test.go index bd0d7e65b8..a20f01a8d7 100644 --- a/server/noderelay/noderelay_test.go +++ b/server/noderelay/noderelay_test.go @@ -73,8 +73,7 @@ func TestNexus(t *testing.T) { var rawHandlerErr error var cl comms.WsConn - cl, err = comms.NewWsConn(&comms.WsCfg{ - URL: "wss://" + addr, + cl, err = comms.NewWsConn("wss://"+addr, &comms.WsCfg{ PingWait: 20 * time.Second, Cert: certB, ReconnectSync: func() { diff --git a/tatanka/tcp/client/client.go b/tatanka/tcp/client/client.go index be59aef2b3..cf5e0224df 100644 --- a/tatanka/tcp/client/client.go +++ b/tatanka/tcp/client/client.go @@ -54,8 +54,7 @@ func New(cfg *Config) (*Client, error) { func (c *Client) Connect(ctx context.Context) (_ *sync.WaitGroup, err error) { c.ctx = ctx - if c.cl, err = comms.NewWsConn(&comms.WsCfg{ - URL: c.url.String(), + if c.cl, err = comms.NewWsConn(c.url.String(), &comms.WsCfg{ PingWait: 20 * time.Second, Cert: c.cert, ReconnectSync: func() { diff --git a/tatanka/tcp/server.go b/tatanka/tcp/server.go index 1f925f1a77..3953fdaa2c 100644 --- a/tatanka/tcp/server.go +++ b/tatanka/tcp/server.go @@ -161,8 +161,7 @@ func (s *Server) ConnectBootNode( uri.Path = "/ws" } - cl, err := clientcomms.NewWsConn(&clientcomms.WsCfg{ - URL: uri.String(), + cl, err := clientcomms.NewWsConn(uri.String(), &clientcomms.WsCfg{ PingWait: 20 * time.Second, Cert: n.Cert, ConnectEventFunc: func(status clientcomms.ConnectionStatus) {