Skip to content

Commit

Permalink
feat(pubsub): refactor channel result flow
Browse files Browse the repository at this point in the history
  • Loading branch information
mitsos1os committed Aug 8, 2024
1 parent 4964e0a commit 8c3bebc
Showing 1 changed file with 34 additions and 41 deletions.
75 changes: 34 additions & 41 deletions pubsub/pubsub.go
Original file line number Diff line number Diff line change
Expand Up @@ -562,37 +562,31 @@ func (s *Subscription) Receive(ctx context.Context) (_ *Message, err error) {
// log.Printf("BATCH SIZE %d", batchSize)

go func() {
defer func() {
close(s.waitc)
s.waitc = nil
}()
if s.preReceiveBatchHook != nil {
s.preReceiveBatchHook(batchSize)
}
resultChannel := s.getNextBatch(batchSize)
for {
select {
case msgs, ok := <-resultChannel.msgs:
if !ok {
// batch reception finished
return
} else if len(msgs) > 0 {
// messages received from channel
s.mu.Lock()
s.q = append(s.q, msgs...)
s.mu.Unlock()
}
case err := <-resultChannel.err:
for msgsOrError := range resultChannel {
if msgsOrError.msgs != nil && len(msgsOrError.msgs) > 0 {
// messages received from channel
s.mu.Lock()
s.q = append(s.q, msgsOrError.msgs...)
s.mu.Unlock()
// notify that queue should now have messages
s.waitc <- struct{}{}
} else if msgsOrError.err != nil {
// err can receive message only after batch group completes
if err != nil {
// Non-retryable error from ReceiveBatch -> permanent error.
s.mu.Lock()
s.err = err
s.mu.Unlock()
}
return
// Non-retryable error from ReceiveBatch -> permanent error
s.mu.Lock()
s.err = msgsOrError.err
s.mu.Unlock()
}
}
// batch reception finished
s.mu.Lock()
close(s.waitc)
s.waitc = nil
s.mu.Unlock()
}()
}
if len(s.q) > 0 {
Expand Down Expand Up @@ -640,33 +634,33 @@ func (s *Subscription) Receive(ctx context.Context) (_ *Message, err error) {
})
return m2, nil
}
// A call to ReceiveBatch must be in flight. Wait for it.
waitc := s.waitc
s.mu.Unlock() // unlock to allow message or error processing from background goroutine
select {
case <-waitc:
// Continue to top of loop.
s.mu.Lock()
case <-ctx.Done():
s.mu.Lock()
return nil, ctx.Err()
default:
// Continue to top of loop.
s.mu.Lock()
}
}
}

type batchChannelResult struct {
msgs chan []*driver.Message
err chan error
type msgsOrError struct {
msgs []*driver.Message
err error
}

// getNextBatch gets the next batch of messages from the server and returns it.
func (s *Subscription) getNextBatch(nMessages int) *batchChannelResult {
// getNextBatch gets the next batch of messages from the server. It will return a channel that will itself return the
// messages as they come from each independent batch, or an operation error
func (s *Subscription) getNextBatch(nMessages int) chan msgsOrError {
// Split nMessages into batches based on recvBatchOpts; we'll make a
// separate ReceiveBatch call for each batch, and aggregate the results in
// msgs.
batches := batcher.Split(nMessages, s.recvBatchOpts)
result := batchChannelResult{
msgs: make(chan []*driver.Message, len(batches)),
err: make(chan error),
}
result := make(chan msgsOrError, len(batches))
g, ctx := errgroup.WithContext(s.backgroundCtx)
for _, maxMessagesInBatch := range batches {
// Make a copy of the loop variable since it will be used by a goroutine.
Expand All @@ -683,19 +677,18 @@ func (s *Subscription) getNextBatch(nMessages int) *batchChannelResult {
if err != nil {
return wrapError(s.driver, err)
}
result.msgs <- msgs
result <- msgsOrError{msgs: msgs}
return nil
})
}
go func() {
// wait on group completion on the background and proper channel closing
if err := g.Wait(); err != nil {
result.err <- err
result <- msgsOrError{err: err}
}
close(result.err)
close(result.msgs)
close(result)
}()
return &result
return result
}

var errSubscriptionShutdown = gcerr.Newf(gcerr.FailedPrecondition, nil, "pubsub: Subscription has been Shutdown")
Expand Down

0 comments on commit 8c3bebc

Please sign in to comment.