Skip to content

Commit

Permalink
feat(pubsub): make batch requests provide results independently
Browse files Browse the repository at this point in the history
  • Loading branch information
mitsos1os committed Aug 8, 2024
1 parent 765f8d5 commit 4964e0a
Showing 1 changed file with 51 additions and 28 deletions.
79 changes: 51 additions & 28 deletions pubsub/pubsub.go
Original file line number Diff line number Diff line change
Expand Up @@ -562,20 +562,37 @@ func (s *Subscription) Receive(ctx context.Context) (_ *Message, err error) {
// log.Printf("BATCH SIZE %d", batchSize)

go func() {
defer func() {
close(s.waitc)
s.waitc = nil
}()
if s.preReceiveBatchHook != nil {
s.preReceiveBatchHook(batchSize)
}
msgs, err := s.getNextBatch(batchSize)
s.mu.Lock()
defer s.mu.Unlock()
if err != nil {
// Non-retryable error from ReceiveBatch -> permanent error.
s.err = err
} else if len(msgs) > 0 {
s.q = append(s.q, msgs...)
resultChannel := s.getNextBatch(batchSize)
for {
select {
case msgs, ok := <-resultChannel.msgs:
if !ok {
// batch reception finished
return
} else if len(msgs) > 0 {
// messages received from channel
s.mu.Lock()
s.q = append(s.q, msgs...)
s.mu.Unlock()
}
case err := <-resultChannel.err:
// err can receive message only after batch group completes
if err != nil {
// Non-retryable error from ReceiveBatch -> permanent error.
s.mu.Lock()
s.err = err
s.mu.Unlock()
}
return
}
}
close(s.waitc)
s.waitc = nil
}()
}
if len(s.q) > 0 {
Expand Down Expand Up @@ -623,30 +640,33 @@ func (s *Subscription) Receive(ctx context.Context) (_ *Message, err error) {
})
return m2, nil
}
// A call to ReceiveBatch must be in flight. Wait for it.
waitc := s.waitc
s.mu.Unlock()
s.mu.Unlock() // unlock to allow message or error processing from background goroutine
select {
case <-waitc:
s.mu.Lock()
// Continue to top of loop.
case <-ctx.Done():
s.mu.Lock()
return nil, ctx.Err()
default:
// Continue to top of loop.
s.mu.Lock()
}
}
}

// getNextBatch gets the next batch of messages from the server and returns it.
func (s *Subscription) getNextBatch(nMessages int) ([]*driver.Message, error) {
var mu sync.Mutex
var q []*driver.Message
type batchChannelResult struct {
msgs chan []*driver.Message
err chan error
}

// getNextBatch gets the next batch of messages from the server and returns it.
func (s *Subscription) getNextBatch(nMessages int) *batchChannelResult {
// Split nMessages into batches based on recvBatchOpts; we'll make a
// separate ReceiveBatch call for each batch, and aggregate the results in
// msgs.
batches := batcher.Split(nMessages, s.recvBatchOpts)

result := batchChannelResult{
msgs: make(chan []*driver.Message, len(batches)),
err: make(chan error),
}
g, ctx := errgroup.WithContext(s.backgroundCtx)
for _, maxMessagesInBatch := range batches {
// Make a copy of the loop variable since it will be used by a goroutine.
Expand All @@ -663,16 +683,19 @@ func (s *Subscription) getNextBatch(nMessages int) ([]*driver.Message, error) {
if err != nil {
return wrapError(s.driver, err)
}
mu.Lock()
defer mu.Unlock()
q = append(q, msgs...)
result.msgs <- msgs
return nil
})
}
if err := g.Wait(); err != nil {
return nil, err
}
return q, nil
go func() {
// wait on group completion on the background and proper channel closing
if err := g.Wait(); err != nil {
result.err <- err
}
close(result.err)
close(result.msgs)
}()
return &result
}

var errSubscriptionShutdown = gcerr.Newf(gcerr.FailedPrecondition, nil, "pubsub: Subscription has been Shutdown")
Expand Down

0 comments on commit 4964e0a

Please sign in to comment.