Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix Asynchronous Dispatch Logic in AwsAsyncContext with Spring's DispatcherServlet #631

Merged
merged 18 commits into from
Jan 30, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -32,23 +32,23 @@
public class AwsAsyncContext implements AsyncContext {
private HttpServletRequest req;
private HttpServletResponse res;
private AwsLambdaServletContainerHandler handler;
private List<AsyncListenerHolder> listeners;
private long timeout;
private AtomicBoolean dispatched;
private AtomicBoolean completed;
private AtomicBoolean dispatchStarted;

private Logger log = LoggerFactory.getLogger(AwsAsyncContext.class);

public AwsAsyncContext(HttpServletRequest request, HttpServletResponse response, AwsLambdaServletContainerHandler servletHandler) {
public AwsAsyncContext(HttpServletRequest request, HttpServletResponse response) {
log.debug("Initializing async context for request: " + SecurityUtils.crlf(request.getPathInfo()) + " - " + SecurityUtils.crlf(request.getMethod()));
req = request;
res = response;
handler = servletHandler;
listeners = new ArrayList<>();
timeout = 3000;
dispatched = new AtomicBoolean(false);
completed = new AtomicBoolean(false);
dispatchStarted = new AtomicBoolean(false);
}

@Override
Expand All @@ -68,16 +68,14 @@ public boolean hasOriginalRequestAndResponse() {

@Override
public void dispatch() {
try {
log.debug("Dispatching request");
if (dispatched.get()) {
throw new IllegalStateException("Dispatching already started");
}
log.debug("Dispatching request");

if (dispatched.get()) {
throw new IllegalStateException("Dispatching already started");
}
if (dispatchStarted.getAndSet(true)) {
dispatched.set(true);
handler.doFilter(req, res, ((AwsServletContext)req.getServletContext()).getServletForPath(req.getRequestURI()));
notifyListeners(NotificationType.START_ASYNC, null);
} catch (ServletException | IOException e) {
notifyListeners(NotificationType.ERROR, e);
}
}

Expand Down Expand Up @@ -154,6 +152,10 @@ public boolean isCompleted() {
return completed.get();
}

public boolean isDispatchStarted() {
return dispatchStarted.get();
}

private void notifyListeners(NotificationType type, Throwable t) {
listeners.forEach((h) -> {
try {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -442,15 +442,15 @@ public boolean isAsyncStarted() {

@Override
public AsyncContext startAsync() throws IllegalStateException {
asyncContext = new AwsAsyncContext(this, response, containerHandler);
asyncContext = new AwsAsyncContext(this, response);
setAttribute(DISPATCHER_TYPE_ATTRIBUTE, DispatcherType.ASYNC);
log.debug("Starting async context for request: " + SecurityUtils.crlf(request.getRequestContext().getRequestId()));
return asyncContext;
}

@Override
public AsyncContext startAsync(ServletRequest servletRequest, ServletResponse servletResponse) throws IllegalStateException {
asyncContext = new AwsAsyncContext((HttpServletRequest) servletRequest, (HttpServletResponse) servletResponse, containerHandler);
asyncContext = new AwsAsyncContext((HttpServletRequest) servletRequest, (HttpServletResponse) servletResponse);
setAttribute(DISPATCHER_TYPE_ATTRIBUTE, DispatcherType.ASYNC);
log.debug("Starting async context for request: " + SecurityUtils.crlf(request.getRequestContext().getRequestId()));
return asyncContext;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -152,13 +152,25 @@ protected void doFilter(HttpServletRequest request, HttpServletResponse response

FilterChain chain = getFilterChain(request, servlet);
chain.doFilter(request, response);

if(requiresAsyncReDispatch(request)) {
chain = getFilterChain(request, servlet);
chain.doFilter(request, response);
}
// if for some reason the response wasn't flushed yet, we force it here unless it's being processed asynchronously (WebFlux)
if (!response.isCommitted() && request.getDispatcherType() != DispatcherType.ASYNC) {
response.flushBuffer();
}
}

private boolean requiresAsyncReDispatch(HttpServletRequest request) {
if (request.isAsyncStarted()) {
AsyncContext asyncContext = request.getAsyncContext();
return asyncContext instanceof AwsAsyncContext
&& ((AwsAsyncContext) asyncContext).isDispatchStarted();
}
return false;
}

@Override
public void initialize() throws ContainerInitializationException {
// we expect all servlets to be wrapped in an AwsServletRegistration
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -495,7 +495,7 @@ public boolean isAsyncStarted() {
@Override
public AsyncContext startAsync()
throws IllegalStateException {
asyncContext = new AwsAsyncContext(this, response, containerHandler);
asyncContext = new AwsAsyncContext(this, response);
setAttribute(DISPATCHER_TYPE_ATTRIBUTE, DispatcherType.ASYNC);
log.debug("Starting async context for request: " + SecurityUtils.crlf(request.getRequestContext().getRequestId()));
return asyncContext;
Expand All @@ -506,7 +506,7 @@ public AsyncContext startAsync()
public AsyncContext startAsync(ServletRequest servletRequest, ServletResponse servletResponse)
throws IllegalStateException {
servletRequest.setAttribute(DISPATCHER_TYPE_ATTRIBUTE, DispatcherType.ASYNC);
asyncContext = new AwsAsyncContext((HttpServletRequest) servletRequest, (HttpServletResponse) servletResponse, containerHandler);
asyncContext = new AwsAsyncContext((HttpServletRequest) servletRequest, (HttpServletResponse) servletResponse);
log.debug("Starting async context for request: " + SecurityUtils.crlf(request.getRequestContext().getRequestId()));
return asyncContext;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
import com.amazonaws.serverless.proxy.model.AwsProxyRequest;
import com.amazonaws.serverless.proxy.model.AwsProxyResponse;
import com.amazonaws.services.lambda.runtime.Context;
import org.junit.jupiter.api.Disabled;
import org.junit.jupiter.api.Test;

import jakarta.servlet.AsyncContext;
Expand All @@ -32,48 +33,20 @@ public class AwsAsyncContextTest {
private AwsServletContextTest.TestServlet srv2 = new AwsServletContextTest.TestServlet("srv2");
private AwsServletContext ctx = getCtx();

@Test
void dispatch_sendsToCorrectServlet() {
AwsProxyHttpServletRequest req = new AwsProxyHttpServletRequest(new AwsProxyRequestBuilder("/srv1/hello", "GET").build(), lambdaCtx, null);
req.setResponse(handler.getContainerResponse(req, new CountDownLatch(1)));
req.setServletContext(ctx);
req.setContainerHandler(handler);

AsyncContext asyncCtx = req.startAsync();
handler.setDesiredStatus(201);
asyncCtx.dispatch();
assertNotNull(handler.getSelectedServlet());
assertEquals(srv1, handler.getSelectedServlet());
assertEquals(201, handler.getResponse().getStatus());

req = new AwsProxyHttpServletRequest(new AwsProxyRequestBuilder("/srv5/hello", "GET").build(), lambdaCtx, null);
req.setResponse(handler.getContainerResponse(req, new CountDownLatch(1)));
req.setServletContext(ctx);
req.setContainerHandler(handler);
asyncCtx = req.startAsync();
handler.setDesiredStatus(202);
asyncCtx.dispatch();
assertNotNull(handler.getSelectedServlet());
assertEquals(srv2, handler.getSelectedServlet());
assertEquals(202, handler.getResponse().getStatus());
}

@Test
void dispatchNewPath_sendsToCorrectServlet() throws InvalidRequestEventException {
void dispatch_amendsPath() throws InvalidRequestEventException {
AwsProxyHttpServletRequest req = (AwsProxyHttpServletRequest)reader.readRequest(new AwsProxyRequestBuilder("/srv1/hello", "GET").build(), null, lambdaCtx, LambdaContainerHandler.getContainerConfig());
req.setResponse(handler.getContainerResponse(req, new CountDownLatch(1)));
req.setServletContext(ctx);
req.setContainerHandler(handler);

AsyncContext asyncCtx = req.startAsync();
handler.setDesiredStatus(301);
asyncCtx.dispatch("/srv4/hello");
assertNotNull(handler.getSelectedServlet());
assertEquals(srv2, handler.getSelectedServlet());
assertNotNull(handler.getResponse());
assertEquals(301, handler.getResponse().getStatus());
assertEquals("/srv1/hello", req.getRequestURI());
}


private AwsServletContext getCtx() {
AwsServletContext ctx = new AwsServletContext(handler);
handler.setServletContext(ctx);
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
package com.amazonaws.serverless.proxy.spring;

import com.amazonaws.serverless.exceptions.ContainerInitializationException;
import com.amazonaws.serverless.proxy.internal.testutils.AwsProxyRequestBuilder;
import com.amazonaws.serverless.proxy.internal.testutils.MockLambdaContext;
import com.amazonaws.serverless.proxy.model.AwsProxyRequest;
import com.amazonaws.serverless.proxy.model.AwsProxyResponse;
import com.amazonaws.serverless.proxy.spring.springapp.LambdaHandler;
import com.amazonaws.serverless.proxy.spring.springapp.MessageController;
import org.junit.jupiter.api.BeforeAll;
import org.junit.jupiter.api.Test;

import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.fail;

public class AsyncAppTest {

private static LambdaHandler handler;

@BeforeAll
public static void setUp() {
try {
handler = new LambdaHandler();
} catch (ContainerInitializationException e) {
e.printStackTrace();
fail();
}
}

@Test
void springApp_helloRequest_returnsCorrect() {
AwsProxyRequest req = new AwsProxyRequestBuilder("/hello", "GET").build();
AwsProxyResponse resp = handler.handleRequest(req, new MockLambdaContext());
assertEquals(200, resp.getStatusCode());
assertEquals(MessageController.HELLO_MESSAGE, resp.getBody());
}

@Test
void springApp_asyncRequest_returnsCorrect() {
AwsProxyRequest req = new AwsProxyRequestBuilder("/async", "GET").build();
AwsProxyResponse resp = handler.handleRequest(req, new MockLambdaContext());
assertEquals(200, resp.getStatusCode());
assertEquals(MessageController.HELLO_MESSAGE, resp.getBody());
}

}
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
package com.amazonaws.serverless.proxy.spring.springapp;

import org.springframework.context.annotation.Configuration;
import org.springframework.context.annotation.Import;

@Configuration
@Import({MessageController.class})
public class AppConfig { }
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
package com.amazonaws.serverless.proxy.spring.springapp;

import com.amazonaws.serverless.exceptions.ContainerInitializationException;
import com.amazonaws.serverless.proxy.model.AwsProxyRequest;
import com.amazonaws.serverless.proxy.model.AwsProxyResponse;
import com.amazonaws.serverless.proxy.spring.SpringLambdaContainerHandler;
import com.amazonaws.serverless.proxy.spring.SpringProxyHandlerBuilder;
import com.amazonaws.services.lambda.runtime.Context;
import com.amazonaws.services.lambda.runtime.RequestHandler;

public class LambdaHandler implements RequestHandler<AwsProxyRequest, AwsProxyResponse> {
private SpringLambdaContainerHandler<AwsProxyRequest, AwsProxyResponse> handler;

public LambdaHandler() throws ContainerInitializationException {
handler = new SpringProxyHandlerBuilder<AwsProxyRequest>()
.defaultProxy()
.asyncInit()
.configurationClasses(AppConfig.class)
.buildAndInitialize();
}

@Override
public AwsProxyResponse handleRequest(AwsProxyRequest awsProxyRequest, Context context) {
return handler.proxy(awsProxyRequest, context);
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
package com.amazonaws.serverless.proxy.spring.springapp;

import org.springframework.web.bind.annotation.RequestMapping;
import org.springframework.web.bind.annotation.RequestMethod;
import org.springframework.web.bind.annotation.RestController;
import org.springframework.web.context.request.async.DeferredResult;
import org.springframework.web.servlet.config.annotation.EnableWebMvc;

@RestController
@EnableWebMvc
public class MessageController {
public static final String HELLO_MESSAGE = "Hello";

@RequestMapping(path="/hello", method= RequestMethod.GET)
public String hello() {
return HELLO_MESSAGE;
}

@RequestMapping(path="/async", method= RequestMethod.GET)
public DeferredResult<String> asyncHello() {
DeferredResult<String> result = new DeferredResult<>();
result.setResult(HELLO_MESSAGE);
return result;
}
}
48 changes: 48 additions & 0 deletions aws-serverless-java-container-springboot3/pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -191,6 +191,46 @@
<scope>test</scope>
</dependency>

<dependency>
<groupId>org.springframework.boot</groupId>
<artifactId>spring-boot-starter-data-jpa</artifactId>
<version>3.2.1</version>
<scope>test</scope>
<exclusions>
<exclusion>
<groupId>org.springframework.boot</groupId>
<artifactId>spring-boot-starter-aop</artifactId>
</exclusion>
<exclusion>
<groupId>org.springframework.boot</groupId>
<artifactId>spring-boot-starter-web</artifactId>
</exclusion>
<exclusion>
<groupId>org.springframework.boot</groupId>
<artifactId>spring-boot-starter-logging</artifactId>
</exclusion>
<exclusion>
<groupId>org.springframework.boot</groupId>
<artifactId>spring-boot-starter-tomcat</artifactId>
</exclusion>
<exclusion>
<groupId>org.apache.tomcat.embed</groupId>
<artifactId>tomcat-embed-core</artifactId>
</exclusion>
<exclusion>
<groupId>org.apache.tomcat.embed</groupId>
<artifactId>tomcat-embed-websocket</artifactId>
</exclusion>
</exclusions>
</dependency>
<dependency>
<groupId>com.h2database</groupId>
<artifactId>h2</artifactId>
<version>2.2.222</version>
<scope>test</scope>
</dependency>


</dependencies>

<build>
Expand Down Expand Up @@ -284,6 +324,14 @@
</execution>
</executions>
</plugin>
<plugin>
<groupId>org.apache.maven.plugins</groupId>
<artifactId>maven-compiler-plugin</artifactId>
<configuration>
<source>10</source>
<target>10</target>
</configuration>
</plugin>
</plugins>
</build>
<repositories>
Expand Down
Loading
Loading