Skip to content

Commit

Permalink
Fix data race (#17)
Browse files Browse the repository at this point in the history
  • Loading branch information
zhangxu19830126 authored Nov 23, 2021
1 parent b353cad commit 280f85f
Show file tree
Hide file tree
Showing 4 changed files with 32 additions and 33 deletions.
18 changes: 4 additions & 14 deletions session.go
Original file line number Diff line number Diff line change
Expand Up @@ -95,8 +95,6 @@ func newBaseIOWithOptions(id uint64, conn net.Conn, opts *options) IOSession {
bio := &baseIO{
id: id,
opts: opts,
in: buf.NewByteBuf(opts.readBufSize),
out: buf.NewByteBuf(opts.writeBufSize),
}

if conn != nil {
Expand Down Expand Up @@ -138,7 +136,6 @@ func (bio *baseIO) Connect(addr string, timeout time.Duration) (bool, error) {
return false, fmt.Errorf("the session is closing or connecting is other goroutine")
}

bio.resetToRead()
conn, err := net.DialTimeout("tcp", addr, timeout)
if nil != err {
atomic.StoreInt32(&bio.state, stateReadyToConnect)
Expand Down Expand Up @@ -178,10 +175,7 @@ func (bio *baseIO) Close() error {

bio.stopWriteLoop()
bio.closeConn()
if bio.disableConnect {
bio.in.Release()
bio.out.Release()
}
bio.out.Release()
atomic.StoreInt32(&bio.state, stateReadyToConnect)
return nil
}
Expand All @@ -208,7 +202,7 @@ func (bio *baseIO) Read() (interface{}, error) {
}

if nil != err {
bio.in.Clear()
bio.in.Release()
return nil, err
}

Expand Down Expand Up @@ -382,12 +376,6 @@ func (bio *baseIO) closeConn() {
}
}

func (bio *baseIO) resetToRead() {
bio.in.Clear()
bio.out.Clear()
bio.remoteAddr = ""
}

func (bio *baseIO) getState() int32 {
return atomic.LoadInt32(&bio.state)
}
Expand All @@ -396,6 +384,8 @@ func (bio *baseIO) initConn(conn net.Conn) {
bio.conn = conn
bio.remoteAddr = conn.RemoteAddr().String()
bio.localAddr = conn.LocalAddr().String()
bio.in = buf.NewByteBuf(bio.opts.readBufSize)
bio.out = buf.NewByteBuf(bio.opts.writeBufSize)

bio.logger = adjustLogger(bio.opts.logger).Named("io-session").With(zap.Uint64("id", bio.id),
zap.String("local-address", bio.localAddr),
Expand Down
5 changes: 3 additions & 2 deletions session_test.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package goetty

import (
"sync/atomic"
"testing"
"time"

Expand All @@ -16,7 +17,7 @@ func TestNormal(t *testing.T) {
app := newTestTCPApp(t, func(rs IOSession, msg interface{}, received uint64) error {
cs = rs
rs.WriteAndFlush(msg)
cnt = received
atomic.StoreUint64(&cnt, received)
return nil
})
app.Start()
Expand All @@ -32,7 +33,7 @@ func TestNormal(t *testing.T) {
reply, err := client.Read()
assert.NoError(t, err)
assert.Equal(t, "hello", reply)
assert.Equal(t, uint64(1), cnt)
assert.Equal(t, uint64(1), atomic.LoadUint64(&cnt))

v, err := app.GetSession(cs.ID())
assert.NoError(t, err)
Expand Down
38 changes: 23 additions & 15 deletions timewheel/timewheel.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ const (

const (
// states of the TimeoutWheel
stopped = iota
stopped int32 = iota
stopping
running
)
Expand Down Expand Up @@ -125,7 +125,7 @@ type TimeoutWheel struct {
buckets []timeoutList
freelists []timeoutList

state int
state int32
calloutCh chan timeoutList
done chan struct{}
}
Expand Down Expand Up @@ -196,13 +196,21 @@ func NewTimeoutWheel(options ...Option) *TimeoutWheel {
return t
}

func (t *TimeoutWheel) getState() int32 {
return atomic.LoadInt32(&t.state)
}

func (t *TimeoutWheel) updateState(state int32) {
atomic.StoreInt32(&t.state, state)
}

// Start starts a stopped timeout wheel. Subsequent calls to Start panic.
func (t *TimeoutWheel) Start() {
t.lockAllBuckets()
defer t.unlockAllBuckets()

for t.state != stopped {
switch t.state {
for t.getState() != stopped {
switch t.getState() {
case stopping:
t.unlockAllBuckets()
<-t.done
Expand All @@ -212,7 +220,7 @@ func (t *TimeoutWheel) Start() {
}
}

t.state = running
t.updateState(running)
t.done = make(chan struct{})
t.calloutCh = make(chan timeoutList)

Expand All @@ -224,8 +232,8 @@ func (t *TimeoutWheel) Start() {
func (t *TimeoutWheel) Stop() {
t.lockAllBuckets()

if t.state == running {
t.state = stopping
if t.getState() == running {
t.updateState(stopping)
close(t.calloutCh)
for i := range t.buckets {
t.freeBucketLocked(t.buckets[i])
Expand All @@ -249,7 +257,7 @@ func (t *TimeoutWheel) Schedule(
deadline := atomic.LoadUint64(&t.ticks) + uint64(dTicks)
timeout := t.getTimeoutLocked(deadline)

if t.state != running {
if t.getState() != running {
t.putTimeoutLocked(timeout)
timeout.mtx.Unlock()
return Timeout{}, ErrSystemStopped
Expand Down Expand Up @@ -281,22 +289,22 @@ func (t *TimeoutWheel) doTick() {

ticker := time.NewTicker(t.tickInterval)
for range ticker.C {
atomic.AddUint64(&t.ticks, 1)
v := atomic.AddUint64(&t.ticks, 1)

mtx := t.lockBucket(t.ticks)
if t.state != running {
mtx := t.lockBucket(v)
if t.getState() != running {
mtx.Unlock()
break
}

bucket := &t.buckets[t.ticks&t.bucketMask]
bucket := &t.buckets[v&t.bucketMask]
timeout := bucket.head
bucket.lastTick = t.ticks
bucket.lastTick = v

// find all the expired timeouts in the bucket.
for timeout != nil {
next := timeout.next
if timeout.deadline <= t.ticks {
if timeout.deadline <= v {
timeout.state = timeoutExpired
timeout.removeLocked()
timeout.prependLocked(&expiredList)
Expand Down Expand Up @@ -386,7 +394,7 @@ func (t *TimeoutWheel) doExpired() {
}

t.lockAllBuckets()
t.state = stopped
t.updateState(stopped)
t.unlockAllBuckets()
close(t.done)
}
Expand Down
4 changes: 2 additions & 2 deletions timewheel/timewheel_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -138,8 +138,8 @@ func TestScheduleExpired(t *testing.T) {
ch := make(chan struct{})
tw := NewTimeoutWheel()
tw.Stop()
tw.ticks = 0
tw.state = running
atomic.StoreUint64(&tw.ticks, 0)
tw.updateState(running)

tw.buckets[0].lastTick = 1
timeout, _ := tw.Schedule(0, func(_ interface{}) { close(ch) }, nil)
Expand Down

0 comments on commit 280f85f

Please sign in to comment.