From 280f85feb21761f16e416db12a77e30035e62ec2 Mon Sep 17 00:00:00 2001 From: fagongzi Date: Tue, 23 Nov 2021 11:37:26 +0800 Subject: [PATCH] Fix data race (#17) --- session.go | 18 ++++-------------- session_test.go | 5 +++-- timewheel/timewheel.go | 38 ++++++++++++++++++++++--------------- timewheel/timewheel_test.go | 4 ++-- 4 files changed, 32 insertions(+), 33 deletions(-) diff --git a/session.go b/session.go index 34374c7..4a7c3b8 100644 --- a/session.go +++ b/session.go @@ -95,8 +95,6 @@ func newBaseIOWithOptions(id uint64, conn net.Conn, opts *options) IOSession { bio := &baseIO{ id: id, opts: opts, - in: buf.NewByteBuf(opts.readBufSize), - out: buf.NewByteBuf(opts.writeBufSize), } if conn != nil { @@ -138,7 +136,6 @@ func (bio *baseIO) Connect(addr string, timeout time.Duration) (bool, error) { return false, fmt.Errorf("the session is closing or connecting is other goroutine") } - bio.resetToRead() conn, err := net.DialTimeout("tcp", addr, timeout) if nil != err { atomic.StoreInt32(&bio.state, stateReadyToConnect) @@ -178,10 +175,7 @@ func (bio *baseIO) Close() error { bio.stopWriteLoop() bio.closeConn() - if bio.disableConnect { - bio.in.Release() - bio.out.Release() - } + bio.out.Release() atomic.StoreInt32(&bio.state, stateReadyToConnect) return nil } @@ -208,7 +202,7 @@ func (bio *baseIO) Read() (interface{}, error) { } if nil != err { - bio.in.Clear() + bio.in.Release() return nil, err } @@ -382,12 +376,6 @@ func (bio *baseIO) closeConn() { } } -func (bio *baseIO) resetToRead() { - bio.in.Clear() - bio.out.Clear() - bio.remoteAddr = "" -} - func (bio *baseIO) getState() int32 { return atomic.LoadInt32(&bio.state) } @@ -396,6 +384,8 @@ func (bio *baseIO) initConn(conn net.Conn) { bio.conn = conn bio.remoteAddr = conn.RemoteAddr().String() bio.localAddr = conn.LocalAddr().String() + bio.in = buf.NewByteBuf(bio.opts.readBufSize) + bio.out = buf.NewByteBuf(bio.opts.writeBufSize) bio.logger = adjustLogger(bio.opts.logger).Named("io-session").With(zap.Uint64("id", bio.id), zap.String("local-address", bio.localAddr), diff --git a/session_test.go b/session_test.go index 5676cd4..cbcee72 100644 --- a/session_test.go +++ b/session_test.go @@ -1,6 +1,7 @@ package goetty import ( + "sync/atomic" "testing" "time" @@ -16,7 +17,7 @@ func TestNormal(t *testing.T) { app := newTestTCPApp(t, func(rs IOSession, msg interface{}, received uint64) error { cs = rs rs.WriteAndFlush(msg) - cnt = received + atomic.StoreUint64(&cnt, received) return nil }) app.Start() @@ -32,7 +33,7 @@ func TestNormal(t *testing.T) { reply, err := client.Read() assert.NoError(t, err) assert.Equal(t, "hello", reply) - assert.Equal(t, uint64(1), cnt) + assert.Equal(t, uint64(1), atomic.LoadUint64(&cnt)) v, err := app.GetSession(cs.ID()) assert.NoError(t, err) diff --git a/timewheel/timewheel.go b/timewheel/timewheel.go index 96e06f1..9e19907 100644 --- a/timewheel/timewheel.go +++ b/timewheel/timewheel.go @@ -22,7 +22,7 @@ const ( const ( // states of the TimeoutWheel - stopped = iota + stopped int32 = iota stopping running ) @@ -125,7 +125,7 @@ type TimeoutWheel struct { buckets []timeoutList freelists []timeoutList - state int + state int32 calloutCh chan timeoutList done chan struct{} } @@ -196,13 +196,21 @@ func NewTimeoutWheel(options ...Option) *TimeoutWheel { return t } +func (t *TimeoutWheel) getState() int32 { + return atomic.LoadInt32(&t.state) +} + +func (t *TimeoutWheel) updateState(state int32) { + atomic.StoreInt32(&t.state, state) +} + // Start starts a stopped timeout wheel. Subsequent calls to Start panic. func (t *TimeoutWheel) Start() { t.lockAllBuckets() defer t.unlockAllBuckets() - for t.state != stopped { - switch t.state { + for t.getState() != stopped { + switch t.getState() { case stopping: t.unlockAllBuckets() <-t.done @@ -212,7 +220,7 @@ func (t *TimeoutWheel) Start() { } } - t.state = running + t.updateState(running) t.done = make(chan struct{}) t.calloutCh = make(chan timeoutList) @@ -224,8 +232,8 @@ func (t *TimeoutWheel) Start() { func (t *TimeoutWheel) Stop() { t.lockAllBuckets() - if t.state == running { - t.state = stopping + if t.getState() == running { + t.updateState(stopping) close(t.calloutCh) for i := range t.buckets { t.freeBucketLocked(t.buckets[i]) @@ -249,7 +257,7 @@ func (t *TimeoutWheel) Schedule( deadline := atomic.LoadUint64(&t.ticks) + uint64(dTicks) timeout := t.getTimeoutLocked(deadline) - if t.state != running { + if t.getState() != running { t.putTimeoutLocked(timeout) timeout.mtx.Unlock() return Timeout{}, ErrSystemStopped @@ -281,22 +289,22 @@ func (t *TimeoutWheel) doTick() { ticker := time.NewTicker(t.tickInterval) for range ticker.C { - atomic.AddUint64(&t.ticks, 1) + v := atomic.AddUint64(&t.ticks, 1) - mtx := t.lockBucket(t.ticks) - if t.state != running { + mtx := t.lockBucket(v) + if t.getState() != running { mtx.Unlock() break } - bucket := &t.buckets[t.ticks&t.bucketMask] + bucket := &t.buckets[v&t.bucketMask] timeout := bucket.head - bucket.lastTick = t.ticks + bucket.lastTick = v // find all the expired timeouts in the bucket. for timeout != nil { next := timeout.next - if timeout.deadline <= t.ticks { + if timeout.deadline <= v { timeout.state = timeoutExpired timeout.removeLocked() timeout.prependLocked(&expiredList) @@ -386,7 +394,7 @@ func (t *TimeoutWheel) doExpired() { } t.lockAllBuckets() - t.state = stopped + t.updateState(stopped) t.unlockAllBuckets() close(t.done) } diff --git a/timewheel/timewheel_test.go b/timewheel/timewheel_test.go index fb65f18..a0fd93a 100644 --- a/timewheel/timewheel_test.go +++ b/timewheel/timewheel_test.go @@ -138,8 +138,8 @@ func TestScheduleExpired(t *testing.T) { ch := make(chan struct{}) tw := NewTimeoutWheel() tw.Stop() - tw.ticks = 0 - tw.state = running + atomic.StoreUint64(&tw.ticks, 0) + tw.updateState(running) tw.buckets[0].lastTick = 1 timeout, _ := tw.Schedule(0, func(_ interface{}) { close(ch) }, nil)