Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix bug related with context propagation (CoroutineServerInterceptor) #4894

Merged
merged 14 commits into from
Jun 14, 2023
1 change: 1 addition & 0 deletions grpc-kotlin/build.gradle.kts
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ dependencies {
implementation(project(":grpc"))

implementation(libs.grpc.kotlin)
implementation(libs.kotlin.reflect)
implementation(libs.kotlin.coroutines.jdk8)

testImplementation(libs.kotlin.coroutines.test)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,15 +20,21 @@ import com.linecorp.armeria.common.annotation.UnstableApi
import com.linecorp.armeria.internal.common.kotlin.ArmeriaRequestCoroutineContext
import com.linecorp.armeria.internal.server.grpc.AbstractServerCall
import com.linecorp.armeria.server.grpc.AsyncServerInterceptor
import io.grpc.Context
import io.grpc.Metadata
import io.grpc.ServerCall
import io.grpc.ServerCallHandler
import io.grpc.ServerInterceptor
import kotlinx.coroutines.DelicateCoroutinesApi
import kotlinx.coroutines.GlobalScope
import io.grpc.kotlin.CoroutineContextServerInterceptor
import io.grpc.kotlin.GrpcContextElement
import kotlinx.coroutines.CoroutineScope
import kotlinx.coroutines.asCoroutineDispatcher
import kotlinx.coroutines.future.future
import java.util.concurrent.CompletableFuture
import kotlin.coroutines.CoroutineContext
import kotlin.reflect.full.companionObject
import kotlin.reflect.full.companionObjectInstance
import kotlin.reflect.full.memberProperties

/**
* A [ServerInterceptor] that is able to suspend the interceptor without blocking the
Expand All @@ -54,7 +60,6 @@ import java.util.concurrent.CompletableFuture
@UnstableApi
interface CoroutineServerInterceptor : AsyncServerInterceptor {

@OptIn(DelicateCoroutinesApi::class)
override fun <I : Any, O : Any> asyncInterceptCall(
call: ServerCall<I, O>,
headers: Metadata,
Expand All @@ -67,7 +72,17 @@ interface CoroutineServerInterceptor : AsyncServerInterceptor {
}
val executor = call.blockingExecutor() ?: call.eventLoop()

return GlobalScope.future(executor.asCoroutineDispatcher() + ArmeriaRequestCoroutineContext(call.ctx())) {
// COROUTINE_CONTEXT_KEY.get():
// It is necessary to propagate the CoroutineContext set by the previous CoroutineContextServerInterceptor.
// (The ArmeriaRequestCoroutineContext is also propagated by CoroutineContextServerInterceptor)
// GrpcContextElement.current():
// In gRPC-kotlin, the Coroutine Context is propagated using the gRPC Context.
return CoroutineScope(
trustin marked this conversation as resolved.
Show resolved Hide resolved
executor.asCoroutineDispatcher() +
ArmeriaRequestCoroutineContext(call.ctx()) +
COROUTINE_CONTEXT_KEY.get() +
GrpcContextElement.current()
).future {
suspendedInterceptCall(call, headers, next)
}
}
Expand All @@ -87,4 +102,14 @@ interface CoroutineServerInterceptor : AsyncServerInterceptor {
headers: Metadata,
next: ServerCallHandler<ReqT, RespT>
): ServerCall.Listener<ReqT>

companion object {
@Suppress("UNCHECKED_CAST")
internal val COROUTINE_CONTEXT_KEY: Context.Key<CoroutineContext> =
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I used reflection... There is no other way.

ikhoon marked this conversation as resolved.
Show resolved Hide resolved
CoroutineContextServerInterceptor::class.let { kclass ->
val companionObject = checkNotNull(kclass.companionObject)
val property = companionObject.memberProperties.single { it.name == "COROUTINE_CONTEXT_KEY" }
checkNotNull(property.getter.call(kclass.companionObjectInstance)) as Context.Key<CoroutineContext>
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@

package com.linecorp.armeria.server.grpc.kotlin

import com.google.common.util.concurrent.ThreadFactoryBuilder
import com.google.protobuf.ByteString
import com.linecorp.armeria.client.grpc.GrpcClients
import com.linecorp.armeria.common.RequestContext
Expand All @@ -33,14 +34,22 @@ import com.linecorp.armeria.internal.testing.AnticipatedException
import com.linecorp.armeria.server.ServerBuilder
import com.linecorp.armeria.server.ServiceRequestContext
import com.linecorp.armeria.server.auth.Authorizer
import com.linecorp.armeria.server.grpc.AsyncServerInterceptor
import com.linecorp.armeria.server.grpc.GrpcService
import com.linecorp.armeria.testing.junit5.server.ServerExtension
import io.grpc.Context
import io.grpc.Contexts
import io.grpc.Metadata
import io.grpc.ServerCall
import io.grpc.ServerCallHandler
import io.grpc.Status
import io.grpc.StatusException
import io.grpc.kotlin.CoroutineContextServerInterceptor
import kotlinx.coroutines.CoroutineName
import kotlinx.coroutines.ExperimentalCoroutinesApi
import kotlinx.coroutines.asContextElement
import kotlinx.coroutines.asCoroutineDispatcher
import kotlinx.coroutines.currentCoroutineContext
import kotlinx.coroutines.delay
import kotlinx.coroutines.flow.Flow
import kotlinx.coroutines.flow.asFlow
Expand All @@ -51,13 +60,16 @@ import kotlinx.coroutines.flow.toList
import kotlinx.coroutines.future.await
import kotlinx.coroutines.runBlocking
import kotlinx.coroutines.test.runTest
import kotlinx.coroutines.withContext
import org.assertj.core.api.Assertions.assertThat
import org.assertj.core.api.Assertions.assertThatThrownBy
import org.junit.jupiter.api.extension.RegisterExtension
import org.junit.jupiter.params.ParameterizedTest
import org.junit.jupiter.params.provider.ValueSource
import java.util.concurrent.CompletableFuture
import java.util.concurrent.Executors
import java.util.concurrent.TimeUnit
import kotlin.coroutines.CoroutineContext

internal class CoroutineServerInterceptorTest {

Expand Down Expand Up @@ -205,19 +217,30 @@ internal class CoroutineServerInterceptorTest {
@RegisterExtension
val server: ServerExtension = object : ServerExtension() {
override fun configure(sb: ServerBuilder) {
val statusFunction = GrpcStatusFunction { _: RequestContext, throwable: Throwable, _: Metadata ->
if (throwable is AnticipatedException && throwable.message == "Invalid access") {
return@GrpcStatusFunction Status.UNAUTHENTICATED
val statusFunction =
GrpcStatusFunction { _: RequestContext, throwable: Throwable, _: Metadata ->
if (throwable is AnticipatedException && throwable.message == "Invalid access") {
return@GrpcStatusFunction Status.UNAUTHENTICATED
}
// Fallback to the default.
null
}
// Fallback to the default.
null
}
val threadLocalInterceptor = ThreadLocalInterceptor()
val authInterceptor = AuthInterceptor()
val coroutineNameInterceptor = CoroutineNameInterceptor()
sb.serviceUnder(
"/non-blocking",
GrpcService.builder()
.exceptionMapping(statusFunction)
.intercept(authInterceptor)
// applying order is "MyAsyncInterceptor -> coroutineNameInterceptor ->
// authInterceptor -> threadLocalInterceptor -> MyAsyncInterceptor"
.intercept(
MyAsyncInterceptor(),
threadLocalInterceptor,
authInterceptor,
coroutineNameInterceptor,
MyAsyncInterceptor(),
)
.addService(TestService())
.build()
)
Expand All @@ -226,7 +249,15 @@ internal class CoroutineServerInterceptorTest {
GrpcService.builder()
.addService(TestService())
.exceptionMapping(statusFunction)
.intercept(authInterceptor)
// applying order is "MyAsyncInterceptor -> coroutineNameInterceptor ->
// authInterceptor -> threadLocalInterceptor -> MyAsyncInterceptor"
.intercept(
MyAsyncInterceptor(),
threadLocalInterceptor,
authInterceptor,
coroutineNameInterceptor,
MyAsyncInterceptor(),
)
.useBlockingTaskExecutor(true)
.build()
)
Expand All @@ -236,6 +267,10 @@ internal class CoroutineServerInterceptorTest {
private const val username = "Armeria"
private const val token = "token-1234"

private val executorDispatcher = Executors.newSingleThreadExecutor(
ThreadFactoryBuilder().setNameFormat("my-executor").build()
).asCoroutineDispatcher()

private class AuthInterceptor : CoroutineServerInterceptor {
private val authorizer = Authorizer { ctx: ServiceRequestContext, _: Metadata ->
val future = CompletableFuture<Boolean>()
Expand All @@ -254,28 +289,98 @@ internal class CoroutineServerInterceptorTest {
headers: Metadata,
next: ServerCallHandler<ReqT, RespT>
): ServerCall.Listener<ReqT> {
assertContextPropagation()

delay(100)
assertContextPropagation() // OK even if resume from suspend.

withContext(executorDispatcher) {
// OK even if the dispatcher is switched
assertContextPropagation()
assertThat(Thread.currentThread().name).contains("my-executor")
}

val result = authorizer.authorize(ServiceRequestContext.current(), headers).await()

if (result) {
return next.startCall(call, headers)
val ctx = Context.current().withValue(AUTHORIZATION_RESULT_GRPC_CONTEXT_KEY, "OK")
return Contexts.interceptCall(ctx, call, headers, next)
} else {
throw AnticipatedException("Invalid access")
}
}

private suspend fun assertContextPropagation() {
assertThat(ServiceRequestContext.currentOrNull()).isNotNull()
assertThat(currentCoroutineContext()[CoroutineName]?.name).isEqualTo("my-coroutine-name")
}

companion object {
val AUTHORIZATION_RESULT_GRPC_CONTEXT_KEY: Context.Key<String> =
Context.key("authorization-result")
}
}

private class CoroutineNameInterceptor : CoroutineContextServerInterceptor() {
override fun coroutineContext(call: ServerCall<*, *>, headers: Metadata): CoroutineContext {
return CoroutineName("my-coroutine-name")
}
}

private class ThreadLocalInterceptor : CoroutineContextServerInterceptor() {
override fun coroutineContext(call: ServerCall<*, *>, headers: Metadata): CoroutineContext {
return THREAD_LOCAL.asContextElement(value = "thread-local-value")
}

companion object {
val THREAD_LOCAL = ThreadLocal<String>()
}
}

private class MyAsyncInterceptor : AsyncServerInterceptor {
override fun <I : Any, O : Any> asyncInterceptCall(
call: ServerCall<I, O>,
headers: Metadata,
next: ServerCallHandler<I, O>
): CompletableFuture<ServerCall.Listener<I>> {
val context = Context.current();
return CompletableFuture.supplyAsync({
// NB: When the current thread invoking `startCall` is different from the thread which
// started `asyncInterceptCall`, `next.startCall()` should be wrapped with `context.call()`
// to propagate the context to the next interceptor.
context.call { next.startCall(call, headers) }
}, EXECUTOR)
}

companion object {
private val EXECUTOR = Executors.newSingleThreadExecutor()
}
}

private class TestService : TestServiceGrpcKt.TestServiceCoroutineImplBase() {
override suspend fun unaryCall(request: SimpleRequest): SimpleResponse {
assertContextPropagation()

delay(100)
assertContextPropagation() // OK even if resume from suspend.

withContext(executorDispatcher) {
// OK even if the dispatcher is switched
assertContextPropagation()
assertThat(Thread.currentThread().name).contains("my-executor")
}

if (request.fillUsername) {
return SimpleResponse.newBuilder().setUsername(username).build()
}

return SimpleResponse.getDefaultInstance()
}

override fun streamingOutputCall(request: StreamingOutputCallRequest): Flow<StreamingOutputCallResponse> {
return flow {
for (i in 1..5) {
delay(500)
assertContextPropagation()
emit(buildReply(username))
}
}
Expand All @@ -284,16 +389,27 @@ internal class CoroutineServerInterceptorTest {
override suspend fun streamingInputCall(requests: Flow<StreamingInputCallRequest>): StreamingInputCallResponse {
val names = requests.map { it.payload.body.toString() }.toList()

assertContextPropagation()

return buildReply(names)
}

override fun fullDuplexCall(requests: Flow<StreamingOutputCallRequest>): Flow<StreamingOutputCallResponse> {
return flow {
requests.collect {
delay(500)
assertContextPropagation()
emit(buildReply(username))
}
}
}

private suspend fun assertContextPropagation() {
assertThat(ServiceRequestContext.currentOrNull()).isNotNull()
assertThat(currentCoroutineContext()[CoroutineName]?.name).isEqualTo("my-coroutine-name")
assertThat(ThreadLocalInterceptor.THREAD_LOCAL.get()).isEqualTo("thread-local-value")
assertThat(AuthInterceptor.AUTHORIZATION_RESULT_GRPC_CONTEXT_KEY.get()).isEqualTo("OK")
}
}

private fun buildReply(message: String): StreamingOutputCallResponse =
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,11 @@

package com.linecorp.armeria.server.grpc;

import static com.google.common.base.Preconditions.checkState;

import java.util.concurrent.ScheduledExecutorService;

import com.linecorp.armeria.internal.server.grpc.AbstractServerCall;
import com.linecorp.armeria.server.ServiceRequestContext;

import io.grpc.Metadata;
Expand All @@ -36,7 +39,10 @@ final class ArmeriaCoroutineContextInterceptor extends CoroutineContextServerInt

@Override
public CoroutineContext coroutineContext(ServerCall<?, ?> serverCall, Metadata metadata) {
final ServiceRequestContext ctx = ServiceRequestContext.current();
checkState(serverCall instanceof AbstractServerCall,
"Cannot use %s with a non-Armeria gRPC server",
ArmeriaCoroutineContextInterceptor.class.getName());
final ServiceRequestContext ctx = ((AbstractServerCall<?, ?>) serverCall).ctx();
final ArmeriaRequestCoroutineContext coroutineContext = new ArmeriaRequestCoroutineContext(ctx);
final ScheduledExecutorService executor;
if (useBlockingTaskExecutor) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -36,10 +36,12 @@
* @Override
* <I, O> CompletableFuture<Listener<I>> asyncInterceptCall(
* ServerCall<I, O> call, Metadata headers, ServerCallHandler<I, O> next) {
*
* Context grpcContext = Context.current();
* return authorizer.authorize(headers).thenApply(result -> {
* if (result) {
* return next.startCall(call, headers);
* // `next.startCall()` should be wrapped with `grpcContext.call()` if you want to propagate
* // the context to the next interceptor.
* return grpcContext.call(() -> next.startCall(call, headers));
* } else {
* throw new AuthenticationException("Invalid access");
* }
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -948,6 +948,11 @@ static GrpcStatusFunction toGrpcStatusFunction(
private ImmutableList.Builder<ServerInterceptor> interceptors() {
if (interceptors == null) {
interceptors = ImmutableList.builder();
if (USE_COROUTINE_CONTEXT_INTERCEPTOR) {
final ServerInterceptor coroutineContextInterceptor =
new ArmeriaCoroutineContextInterceptor(useBlockingTaskExecutor);
interceptors.add(coroutineContextInterceptor);
}
}
return interceptors;
}
Expand All @@ -961,11 +966,6 @@ private ImmutableList.Builder<ServerInterceptor> interceptors() {
*/
public GrpcService build() {
final HandlerRegistry handlerRegistry;
if (USE_COROUTINE_CONTEXT_INTERCEPTOR) {
final ServerInterceptor coroutineContextInterceptor =
new ArmeriaCoroutineContextInterceptor(useBlockingTaskExecutor);
interceptors().add(coroutineContextInterceptor);
}
if (!enableUnframedRequests && unframedGrpcErrorHandler != null) {
throw new IllegalStateException(
"'unframedGrpcErrorHandler' can only be set if unframed requests are enabled");
Expand Down