diff --git a/pubsub/pubsub.go b/pubsub/pubsub.go
index f68b17e03e..077637fb4e 100644
--- a/pubsub/pubsub.go
+++ b/pubsub/pubsub.go
@@ -565,17 +565,28 @@ func (s *Subscription) Receive(ctx context.Context) (_ *Message, err error) {
 				if s.preReceiveBatchHook != nil {
 					s.preReceiveBatchHook(batchSize)
 				}
-				msgs, err := s.getNextBatch(batchSize)
-				s.mu.Lock()
-				defer s.mu.Unlock()
-				if err != nil {
-					// Non-retryable error from ReceiveBatch -> permanent error.
-					s.err = err
-				} else if len(msgs) > 0 {
-					s.q = append(s.q, msgs...)
+				results := s.getNextBatch(batchSize)
+				for res := range results {
+					if len(res.msgs) > 0 {
+						// Queue each batch's messages as soon as that batch arrives.
+						s.mu.Lock()
+						s.q = append(s.q, res.msgs...)
+						s.mu.Unlock()
+						// Wake a Receive waiting on waitc. NOTE(review): if waitc is unbuffered this send blocks until a Receive is selecting on it -- confirm.
+						s.waitc <- struct{}{}
+					} else if res.err != nil {
+						// getNextBatch sends an error only after every batch goroutine has returned.
+						// Non-retryable error from ReceiveBatch -> permanent error.
+						s.mu.Lock()
+						s.err = res.err
+						s.mu.Unlock()
+					}
 				}
+				// The result channel is closed: wake any remaining waiter and clear waitc.
+				s.mu.Lock()
 				close(s.waitc)
 				s.waitc = nil
+				s.mu.Unlock()
 			}()
 		}
 		if len(s.q) > 0 {
@@ -625,11 +636,11 @@ func (s *Subscription) Receive(ctx context.Context) (_ *Message, err error) {
 		}
 		// A call to ReceiveBatch must be in flight. Wait for it.
 		waitc := s.waitc
-		s.mu.Unlock()
+		s.mu.Unlock() // release the lock so the background goroutine can queue messages or record an error
 		select {
 		case <-waitc:
-			s.mu.Lock()
 			// Continue to top of loop.
+			s.mu.Lock()
 		case <-ctx.Done():
 			s.mu.Lock()
 			return nil, ctx.Err()
@@ -637,16 +648,20 @@ func (s *Subscription) Receive(ctx context.Context) (_ *Message, err error) {
 	}
 }
 
-// getNextBatch gets the next batch of messages from the server and returns it.
-func (s *Subscription) getNextBatch(nMessages int) ([]*driver.Message, error) {
-	var mu sync.Mutex
-	var q []*driver.Message
+// msgsOrError is one value sent on the channel returned by getNextBatch.
+type msgsOrError struct {
+	msgs []*driver.Message // messages fetched for one batch
+	err  error             // error reported after all batch goroutines have returned
+}
 
+// getNextBatch splits a request for nMessages into batches and fetches each batch in its own goroutine.
+// The returned channel yields each batch's messages as it completes, then at most one error, and is then closed.
+func (s *Subscription) getNextBatch(nMessages int) <-chan msgsOrError {
 	// Split nMessages into batches based on recvBatchOpts; we'll make a
 	// separate ReceiveBatch call for each batch, and aggregate the results in
 	// msgs.
 	batches := batcher.Split(nMessages, s.recvBatchOpts)
-
+	result := make(chan msgsOrError, len(batches)) // one buffer slot per batch
 	g, ctx := errgroup.WithContext(s.backgroundCtx)
 	for _, maxMessagesInBatch := range batches {
 		// Make a copy of the loop variable since it will be used by a goroutine.
@@ -663,16 +678,18 @@ func (s *Subscription) getNextBatch(nMessages int) ([]*driver.Message, error) {
 			if err != nil {
 				return wrapError(s.driver, err)
 			}
-			mu.Lock()
-			defer mu.Unlock()
-			q = append(q, msgs...)
+			result <- msgsOrError{msgs: msgs}
 			return nil
 		})
 	}
-	if err := g.Wait(); err != nil {
-		return nil, err
-	}
-	return q, nil
+	go func() {
+		// Once every batch goroutine has returned, forward the group's error (if any) and close the channel.
+		if err := g.Wait(); err != nil {
+			result <- msgsOrError{err: err}
+		}
+		close(result)
+	}()
+	return result
 }
 
 var errSubscriptionShutdown = gcerr.Newf(gcerr.FailedPrecondition, nil, "pubsub: Subscription has been Shutdown")
diff --git a/pubsub/pubsub_test.go b/pubsub/pubsub_test.go
index 3c42442c68..ddce3147b9 100644
--- a/pubsub/pubsub_test.go
+++ b/pubsub/pubsub_test.go
@@ -20,6 +20,7 @@ import (
 	"net/url"
 	"strings"
 	"sync"
+	"sync/atomic"
 	"testing"
 	"time"
 
@@ -281,6 +282,43 @@ func TestCancelTwoReceives(t *testing.T) {
 	}
 }
 
+// secondReceiveBlockedDriverSub is a driver.Subscription stub whose first ReceiveBatch call returns
+// immediately and whose later calls block until their context is done.
+type secondReceiveBlockedDriverSub struct {
+	driver.Subscription
+	receiveCounter atomic.Uint64
+}
+
+func (s *secondReceiveBlockedDriverSub) ReceiveBatch(ctx context.Context, _ int) ([]*driver.Message, error) {
+	n := s.receiveCounter.Add(1) // use Add's result so concurrent calls each see a distinct call number
+	if n > 1 {
+		// Every call after the first waits for its context to finish before returning its batch.
+		<-ctx.Done()
+	}
+	msg := &driver.Message{Body: []byte(fmt.Sprintf("message #%d", n))}
+	return []*driver.Message{msg}, nil
+}
+func (*secondReceiveBlockedDriverSub) CanNack() bool          { return false }
+func (*secondReceiveBlockedDriverSub) IsRetryable(error) bool { return false }
+func (*secondReceiveBlockedDriverSub) Close() error           { return nil }
+
+func TestIndependentBatchReturn(t *testing.T) {
+	// When several batch requests are in flight, Subscription.Receive must return as soon as one of
+	// them succeeds instead of waiting for all of them.
+	s := NewSubscription(
+		&secondReceiveBlockedDriverSub{},
+		&batcher.Options{MaxBatchSize: 1, MaxHandlers: 2}, // force 2 batches, by allowing 2 handlers and 1 msg per batch
+		nil,
+	)
+	// Override the computed running batch size so that two batches are requested.
+	s.runningBatchSize = 2
+	ctx := context.Background()
+	defer s.Shutdown(ctx)
+	if _, err := s.Receive(ctx); err != nil {
+		t.Fatal("Receive should not fail", err)
+	}
+}
+
 func TestRetryTopic(t *testing.T) {
 	// Test that Send is retried if the driver returns a retryable error.
 	ctx := context.Background()