diff --git a/grpc-kotlin/build.gradle.kts b/grpc-kotlin/build.gradle.kts index bdfd375c5d1..d43aa525b05 100644 --- a/grpc-kotlin/build.gradle.kts +++ b/grpc-kotlin/build.gradle.kts @@ -3,6 +3,7 @@ dependencies { implementation(project(":grpc")) implementation(libs.grpc.kotlin) + implementation(libs.kotlin.reflect) implementation(libs.kotlin.coroutines.jdk8) testImplementation(libs.kotlin.coroutines.test) diff --git a/grpc-kotlin/src/main/kotlin/com/linecorp/armeria/server/grpc/kotlin/CoroutineServerInterceptor.kt b/grpc-kotlin/src/main/kotlin/com/linecorp/armeria/server/grpc/kotlin/CoroutineServerInterceptor.kt index 6952725efa5..4e2496d6731 100644 --- a/grpc-kotlin/src/main/kotlin/com/linecorp/armeria/server/grpc/kotlin/CoroutineServerInterceptor.kt +++ b/grpc-kotlin/src/main/kotlin/com/linecorp/armeria/server/grpc/kotlin/CoroutineServerInterceptor.kt @@ -17,18 +17,21 @@ package com.linecorp.armeria.server.grpc.kotlin import com.linecorp.armeria.common.annotation.UnstableApi -import com.linecorp.armeria.internal.common.kotlin.ArmeriaRequestCoroutineContext -import com.linecorp.armeria.internal.server.grpc.AbstractServerCall import com.linecorp.armeria.server.grpc.AsyncServerInterceptor +import io.grpc.Context import io.grpc.Metadata import io.grpc.ServerCall import io.grpc.ServerCallHandler import io.grpc.ServerInterceptor -import kotlinx.coroutines.DelicateCoroutinesApi -import kotlinx.coroutines.GlobalScope -import kotlinx.coroutines.asCoroutineDispatcher +import io.grpc.kotlin.CoroutineContextServerInterceptor +import io.grpc.kotlin.GrpcContextElement +import kotlinx.coroutines.CoroutineScope import kotlinx.coroutines.future.future import java.util.concurrent.CompletableFuture +import kotlin.coroutines.CoroutineContext +import kotlin.reflect.full.companionObject +import kotlin.reflect.full.companionObjectInstance +import kotlin.reflect.full.memberProperties /** * A [ServerInterceptor] that is able to suspend the interceptor without blocking the @@ -54,20 +57,19 @@ import java.util.concurrent.CompletableFuture @UnstableApi interface CoroutineServerInterceptor : AsyncServerInterceptor { - @OptIn(DelicateCoroutinesApi::class) override fun asyncInterceptCall( call: ServerCall, headers: Metadata, next: ServerCallHandler ): CompletableFuture> { - check(call is AbstractServerCall) { - throw IllegalArgumentException( - "Cannot use ${AsyncServerInterceptor::class.java.name} with a non-Armeria gRPC server" - ) - } - val executor = call.blockingExecutor() ?: call.eventLoop() - - return GlobalScope.future(executor.asCoroutineDispatcher() + ArmeriaRequestCoroutineContext(call.ctx())) { + // COROUTINE_CONTEXT_KEY.get(): + // It is necessary to propagate the CoroutineContext set by the previous CoroutineContextServerInterceptor. + // (The ArmeriaRequestCoroutineContext is also propagated by CoroutineContextServerInterceptor) + // GrpcContextElement.current(): + // In gRPC-kotlin, the Coroutine Context is propagated using the gRPC Context. + return CoroutineScope( + COROUTINE_CONTEXT_KEY.get() + GrpcContextElement.current() + ).future { suspendedInterceptCall(call, headers, next) } } @@ -87,4 +89,14 @@ interface CoroutineServerInterceptor : AsyncServerInterceptor { headers: Metadata, next: ServerCallHandler ): ServerCall.Listener + + companion object { + @Suppress("UNCHECKED_CAST") + internal val COROUTINE_CONTEXT_KEY: Context.Key = + CoroutineContextServerInterceptor::class.let { kclass -> + val companionObject = checkNotNull(kclass.companionObject) + val property = companionObject.memberProperties.single { it.name == "COROUTINE_CONTEXT_KEY" } + checkNotNull(property.getter.call(kclass.companionObjectInstance)) as Context.Key + } + } } diff --git a/grpc-kotlin/src/test/kotlin/com/linecorp/armeria/server/grpc/kotlin/CoroutineServerInterceptorTest.kt b/grpc-kotlin/src/test/kotlin/com/linecorp/armeria/server/grpc/kotlin/CoroutineServerInterceptorTest.kt index 08faa79fe10..0df23891f5e 100644 --- a/grpc-kotlin/src/test/kotlin/com/linecorp/armeria/server/grpc/kotlin/CoroutineServerInterceptorTest.kt +++ b/grpc-kotlin/src/test/kotlin/com/linecorp/armeria/server/grpc/kotlin/CoroutineServerInterceptorTest.kt @@ -16,6 +16,7 @@ package com.linecorp.armeria.server.grpc.kotlin +import com.google.common.util.concurrent.ThreadFactoryBuilder import com.google.protobuf.ByteString import com.linecorp.armeria.client.grpc.GrpcClients import com.linecorp.armeria.common.RequestContext @@ -33,14 +34,22 @@ import com.linecorp.armeria.internal.testing.AnticipatedException import com.linecorp.armeria.server.ServerBuilder import com.linecorp.armeria.server.ServiceRequestContext import com.linecorp.armeria.server.auth.Authorizer +import com.linecorp.armeria.server.grpc.AsyncServerInterceptor import com.linecorp.armeria.server.grpc.GrpcService import com.linecorp.armeria.testing.junit5.server.ServerExtension +import io.grpc.Context +import io.grpc.Contexts import io.grpc.Metadata import io.grpc.ServerCall import io.grpc.ServerCallHandler import io.grpc.Status import io.grpc.StatusException +import io.grpc.kotlin.CoroutineContextServerInterceptor +import kotlinx.coroutines.CoroutineName import kotlinx.coroutines.ExperimentalCoroutinesApi +import kotlinx.coroutines.asContextElement +import kotlinx.coroutines.asCoroutineDispatcher +import kotlinx.coroutines.currentCoroutineContext import kotlinx.coroutines.delay import kotlinx.coroutines.flow.Flow import kotlinx.coroutines.flow.asFlow @@ -51,13 +60,16 @@ import kotlinx.coroutines.flow.toList import kotlinx.coroutines.future.await import kotlinx.coroutines.runBlocking import kotlinx.coroutines.test.runTest +import kotlinx.coroutines.withContext import org.assertj.core.api.Assertions.assertThat import org.assertj.core.api.Assertions.assertThatThrownBy import org.junit.jupiter.api.extension.RegisterExtension import org.junit.jupiter.params.ParameterizedTest import org.junit.jupiter.params.provider.ValueSource import java.util.concurrent.CompletableFuture +import java.util.concurrent.Executors import java.util.concurrent.TimeUnit +import kotlin.coroutines.CoroutineContext internal class CoroutineServerInterceptorTest { @@ -205,19 +217,30 @@ internal class CoroutineServerInterceptorTest { @RegisterExtension val server: ServerExtension = object : ServerExtension() { override fun configure(sb: ServerBuilder) { - val statusFunction = GrpcStatusFunction { _: RequestContext, throwable: Throwable, _: Metadata -> - if (throwable is AnticipatedException && throwable.message == "Invalid access") { - return@GrpcStatusFunction Status.UNAUTHENTICATED + val statusFunction = + GrpcStatusFunction { _: RequestContext, throwable: Throwable, _: Metadata -> + if (throwable is AnticipatedException && throwable.message == "Invalid access") { + return@GrpcStatusFunction Status.UNAUTHENTICATED + } + // Fallback to the default. + null } - // Fallback to the default. - null - } + val threadLocalInterceptor = ThreadLocalInterceptor() val authInterceptor = AuthInterceptor() + val coroutineNameInterceptor = CoroutineNameInterceptor() sb.serviceUnder( "/non-blocking", GrpcService.builder() .exceptionMapping(statusFunction) - .intercept(authInterceptor) + // applying order is "MyAsyncInterceptor -> coroutineNameInterceptor -> + // authInterceptor -> threadLocalInterceptor -> MyAsyncInterceptor" + .intercept( + MyAsyncInterceptor(), + threadLocalInterceptor, + authInterceptor, + coroutineNameInterceptor, + MyAsyncInterceptor() + ) .addService(TestService()) .build() ) @@ -226,7 +249,15 @@ internal class CoroutineServerInterceptorTest { GrpcService.builder() .addService(TestService()) .exceptionMapping(statusFunction) - .intercept(authInterceptor) + // applying order is "MyAsyncInterceptor -> coroutineNameInterceptor -> + // authInterceptor -> threadLocalInterceptor -> MyAsyncInterceptor" + .intercept( + MyAsyncInterceptor(), + threadLocalInterceptor, + authInterceptor, + coroutineNameInterceptor, + MyAsyncInterceptor() + ) .useBlockingTaskExecutor(true) .build() ) @@ -236,6 +267,10 @@ internal class CoroutineServerInterceptorTest { private const val username = "Armeria" private const val token = "token-1234" + private val executorDispatcher = Executors.newSingleThreadExecutor( + ThreadFactoryBuilder().setNameFormat("my-executor").build() + ).asCoroutineDispatcher() + private class AuthInterceptor : CoroutineServerInterceptor { private val authorizer = Authorizer { ctx: ServiceRequestContext, _: Metadata -> val future = CompletableFuture() @@ -254,21 +289,90 @@ internal class CoroutineServerInterceptorTest { headers: Metadata, next: ServerCallHandler ): ServerCall.Listener { + assertContextPropagation() + + delay(100) + assertContextPropagation() // OK even if resume from suspend. + + withContext(executorDispatcher) { + // OK even if the dispatcher is switched + assertContextPropagation() + assertThat(Thread.currentThread().name).contains("my-executor") + } + val result = authorizer.authorize(ServiceRequestContext.current(), headers).await() + if (result) { - return next.startCall(call, headers) + val ctx = Context.current().withValue(AUTHORIZATION_RESULT_GRPC_CONTEXT_KEY, "OK") + return Contexts.interceptCall(ctx, call, headers, next) } else { throw AnticipatedException("Invalid access") } } + + private suspend fun assertContextPropagation() { + assertThat(ServiceRequestContext.currentOrNull()).isNotNull() + assertThat(currentCoroutineContext()[CoroutineName]?.name).isEqualTo("my-coroutine-name") + } + + companion object { + val AUTHORIZATION_RESULT_GRPC_CONTEXT_KEY: Context.Key = + Context.key("authorization-result") + } + } + + private class CoroutineNameInterceptor : CoroutineContextServerInterceptor() { + override fun coroutineContext(call: ServerCall<*, *>, headers: Metadata): CoroutineContext { + return CoroutineName("my-coroutine-name") + } + } + + private class ThreadLocalInterceptor : CoroutineContextServerInterceptor() { + override fun coroutineContext(call: ServerCall<*, *>, headers: Metadata): CoroutineContext { + return THREAD_LOCAL.asContextElement(value = "thread-local-value") + } + + companion object { + val THREAD_LOCAL = ThreadLocal() + } + } + + private class MyAsyncInterceptor : AsyncServerInterceptor { + override fun asyncInterceptCall( + call: ServerCall, + headers: Metadata, + next: ServerCallHandler + ): CompletableFuture> { + val context = Context.current() + return CompletableFuture.supplyAsync({ + // NB: When the current thread invoking `startCall` is different from the thread which + // started `asyncInterceptCall`, `next.startCall()` should be wrapped with `context.call()` + // to propagate the context to the next interceptor. + context.call { next.startCall(call, headers) } + }, EXECUTOR) + } + + companion object { + private val EXECUTOR = Executors.newSingleThreadExecutor() + } } private class TestService : TestServiceGrpcKt.TestServiceCoroutineImplBase() { override suspend fun unaryCall(request: SimpleRequest): SimpleResponse { + assertContextPropagation() + + delay(100) + assertContextPropagation() // OK even if resume from suspend. + + withContext(executorDispatcher) { + // OK even if the dispatcher is switched + assertContextPropagation() + assertThat(Thread.currentThread().name).contains("my-executor") + } + if (request.fillUsername) { return SimpleResponse.newBuilder().setUsername(username).build() } - return SimpleResponse.getDefaultInstance() } @@ -276,6 +380,7 @@ internal class CoroutineServerInterceptorTest { return flow { for (i in 1..5) { delay(500) + assertContextPropagation() emit(buildReply(username)) } } @@ -284,16 +389,27 @@ internal class CoroutineServerInterceptorTest { override suspend fun streamingInputCall(requests: Flow): StreamingInputCallResponse { val names = requests.map { it.payload.body.toString() }.toList() + assertContextPropagation() + return buildReply(names) } override fun fullDuplexCall(requests: Flow): Flow { return flow { requests.collect { + delay(500) + assertContextPropagation() emit(buildReply(username)) } } } + + private suspend fun assertContextPropagation() { + assertThat(ServiceRequestContext.currentOrNull()).isNotNull() + assertThat(currentCoroutineContext()[CoroutineName]?.name).isEqualTo("my-coroutine-name") + assertThat(ThreadLocalInterceptor.THREAD_LOCAL.get()).isEqualTo("thread-local-value") + assertThat(AuthInterceptor.AUTHORIZATION_RESULT_GRPC_CONTEXT_KEY.get()).isEqualTo("OK") + } } private fun buildReply(message: String): StreamingOutputCallResponse = diff --git a/grpc/src/main/java/com/linecorp/armeria/server/grpc/ArmeriaCoroutineContextInterceptor.java b/grpc/src/main/java/com/linecorp/armeria/server/grpc/ArmeriaCoroutineContextInterceptor.java index 223afad4894..94f4a5562e7 100644 --- a/grpc/src/main/java/com/linecorp/armeria/server/grpc/ArmeriaCoroutineContextInterceptor.java +++ b/grpc/src/main/java/com/linecorp/armeria/server/grpc/ArmeriaCoroutineContextInterceptor.java @@ -16,6 +16,8 @@ package com.linecorp.armeria.server.grpc; +import static com.google.common.base.Preconditions.checkState; + import java.util.concurrent.ScheduledExecutorService; import com.linecorp.armeria.server.ServiceRequestContext; @@ -36,7 +38,9 @@ final class ArmeriaCoroutineContextInterceptor extends CoroutineContextServerInt @Override public CoroutineContext coroutineContext(ServerCall serverCall, Metadata metadata) { - final ServiceRequestContext ctx = ServiceRequestContext.current(); + final ServiceRequestContext ctx = ServerCallUtil.findRequestContext(serverCall); + checkState(ctx != null, "Failed to find the current %s from %s", + ServiceRequestContext.class.getSimpleName(), serverCall); final ArmeriaRequestCoroutineContext coroutineContext = new ArmeriaRequestCoroutineContext(ctx); final ScheduledExecutorService executor; if (useBlockingTaskExecutor) { diff --git a/grpc/src/main/java/com/linecorp/armeria/server/grpc/AsyncServerInterceptor.java b/grpc/src/main/java/com/linecorp/armeria/server/grpc/AsyncServerInterceptor.java index c796655ae0e..3975d8fac62 100644 --- a/grpc/src/main/java/com/linecorp/armeria/server/grpc/AsyncServerInterceptor.java +++ b/grpc/src/main/java/com/linecorp/armeria/server/grpc/AsyncServerInterceptor.java @@ -36,10 +36,12 @@ * @Override * CompletableFuture> asyncInterceptCall( * ServerCall call, Metadata headers, ServerCallHandler next) { - * + * Context grpcContext = Context.current(); * return authorizer.authorize(headers).thenApply(result -> { * if (result) { - * return next.startCall(call, headers); + * // `next.startCall()` should be wrapped with `grpcContext.call()` if you want to propagate + * // the context to the next interceptor. + * return grpcContext.call(() -> next.startCall(call, headers)); * } else { * throw new AuthenticationException("Invalid access"); * } diff --git a/grpc/src/main/java/com/linecorp/armeria/server/grpc/DeferredListener.java b/grpc/src/main/java/com/linecorp/armeria/server/grpc/DeferredListener.java index 2a5147dec45..3f24a173430 100644 --- a/grpc/src/main/java/com/linecorp/armeria/server/grpc/DeferredListener.java +++ b/grpc/src/main/java/com/linecorp/armeria/server/grpc/DeferredListener.java @@ -48,11 +48,9 @@ final class DeferredListener extends ServerCall.Listener { private boolean callClosed; DeferredListener(ServerCall serverCall, CompletableFuture> listenerFuture) { - checkState(serverCall instanceof AbstractServerCall, "Cannot use %s with a non-Armeria gRPC server", - AsyncServerInterceptor.class.getName()); - @SuppressWarnings("unchecked") - final AbstractServerCall armeriaServerCall = (AbstractServerCall) serverCall; - + final AbstractServerCall armeriaServerCall = ServerCallUtil.findArmeriaServerCall(serverCall); + checkState(armeriaServerCall != null, "Cannot use %s with a non-Armeria gRPC server. ServerCall: %s", + AsyncServerInterceptor.class.getName(), serverCall); // As per `ServerCall.Listener`'s Javadoc, the caller should call one simultaneously. `blockingExecutor` // is a sequential executor which is wrapped by `MoreExecutors.newSequentialExecutor()`. So both // `blockingExecutor` and `eventLoop` guarantees the execution order. diff --git a/grpc/src/main/java/com/linecorp/armeria/server/grpc/ServerCallUtil.java b/grpc/src/main/java/com/linecorp/armeria/server/grpc/ServerCallUtil.java new file mode 100644 index 00000000000..66965d7c038 --- /dev/null +++ b/grpc/src/main/java/com/linecorp/armeria/server/grpc/ServerCallUtil.java @@ -0,0 +1,74 @@ +/* + * Copyright 2023 LINE Corporation + * + * LINE Corporation licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ + +package com.linecorp.armeria.server.grpc; + +import java.lang.invoke.MethodHandle; +import java.lang.invoke.MethodHandles; +import java.lang.invoke.MethodType; + +import com.linecorp.armeria.common.annotation.Nullable; +import com.linecorp.armeria.internal.server.grpc.AbstractServerCall; +import com.linecorp.armeria.server.ServiceRequestContext; + +import io.grpc.ForwardingServerCall; +import io.grpc.ServerCall; + +final class ServerCallUtil { + + @Nullable + private static MethodHandle delegateMH; + + static { + try { + delegateMH = MethodHandles.lookup().findVirtual(ForwardingServerCall.class, "delegate", + MethodType.methodType(ServerCall.class)); + } catch (NoSuchMethodException | IllegalAccessException e) { + delegateMH = null; + } + } + + @Nullable + static ServiceRequestContext findRequestContext(ServerCall serverCall) { + final AbstractServerCall armeriaServerCall = findArmeriaServerCall(serverCall); + if (armeriaServerCall != null) { + return armeriaServerCall.ctx(); + } + + return ServiceRequestContext.currentOrNull(); + } + + @Nullable + static AbstractServerCall findArmeriaServerCall(ServerCall serverCall) { + if (delegateMH != null) { + while (serverCall instanceof ForwardingServerCall) { + try { + //noinspection unchecked + serverCall = (ServerCall) delegateMH.invoke(serverCall); + } catch (Throwable e) { + break; + } + } + } + if (serverCall instanceof AbstractServerCall) { + return (AbstractServerCall) serverCall; + } else { + return null; + } + } + + private ServerCallUtil() {} +} diff --git a/grpc/src/test/java/com/linecorp/armeria/server/grpc/DeferredListenerTest.java b/grpc/src/test/java/com/linecorp/armeria/server/grpc/DeferredListenerTest.java index 4a8d3d76069..4233b9c47e9 100644 --- a/grpc/src/test/java/com/linecorp/armeria/server/grpc/DeferredListenerTest.java +++ b/grpc/src/test/java/com/linecorp/armeria/server/grpc/DeferredListenerTest.java @@ -54,7 +54,7 @@ class DeferredListenerTest { void shouldHaveRequestContextInThread() { assertThatThrownBy(() -> new DeferredListener<>(mock(ServerCall.class), null)) .isInstanceOf(IllegalStateException.class) - .hasMessage("Cannot use %s with a non-Armeria gRPC server", + .hasMessageContaining("Cannot use %s with a non-Armeria gRPC server", AsyncServerInterceptor.class.getName()); }