Commit 053296de by Dave Syer

Ensure request body is not buffered when direct to ZuulServlet

A change in Zuul in NetflixOSS means that it buffers all requests unless you switch it off at the level of the servlet. And even when you do that you need to ensure that Spring Cloud doesn't re-introduce a buffer through the Servlet30RequestWrapper. This change restores the feature in Spring Cloud Angel that multipart files posted to /zuul/* go straight through the servlet and are not buffered in memory. Fixes gh-773
parent 249f524a
......@@ -95,9 +95,14 @@ public class ZuulConfiguration {
}
@Bean
@ConditionalOnMissingBean(name = "zuulServlet")
public ServletRegistrationBean zuulServlet() {
return new ServletRegistrationBean(new ZuulServlet(),
ServletRegistrationBean servlet = new ServletRegistrationBean(new ZuulServlet(),
this.zuulProperties.getServletPattern());
// The whole point of exposing this servlet is to provide a route that doesn't
// buffer requests.
servlet.addInitParameter("buffer-requests", "false");
return servlet;
}
// pre filters
......
......@@ -41,11 +41,11 @@ import org.springframework.web.util.WebUtils;
import com.netflix.zuul.context.RequestContext;
import com.netflix.zuul.util.HTTPRequestUtils;
import lombok.extern.apachecommons.CommonsLog;
import static org.springframework.http.HttpHeaders.CONTENT_ENCODING;
import static org.springframework.http.HttpHeaders.CONTENT_LENGTH;
import lombok.extern.apachecommons.CommonsLog;
/**
* @author Dave Syer
*/
......@@ -123,7 +123,7 @@ public class ProxyRequestHelper {
}
public void setResponse(int status, InputStream entity,
MultiValueMap<String, String> headers) throws IOException {
MultiValueMap<String, String> headers) throws IOException {
RequestContext context = RequestContext.getCurrentContext();
context.setResponseStatusCode(status);
if (entity != null) {
......@@ -185,21 +185,21 @@ public class ProxyRequestHelper {
}
}
switch (name) {
case "host":
case "connection":
case "content-length":
case "content-encoding":
case "server":
case "transfer-encoding":
return false;
default:
return true;
case "host":
case "connection":
case "content-length":
case "content-encoding":
case "server":
case "transfer-encoding":
return false;
default:
return true;
}
}
public Map<String, Object> debug(String verb, String uri,
MultiValueMap<String, String> headers, MultiValueMap<String, String> params,
InputStream requestEntity) throws IOException {
MultiValueMap<String, String> headers, MultiValueMap<String, String> params,
InputStream requestEntity) throws IOException {
Map<String, Object> info = new LinkedHashMap<String, Object>();
if (this.traces != null) {
RequestContext context = RequestContext.getCurrentContext();
......@@ -230,7 +230,8 @@ public class ProxyRequestHelper {
input.put(entry.getKey(), value);
}
RequestContext ctx = RequestContext.getCurrentContext();
if (!ctx.isChunkedRequestBody()) {
if (shouldDebugBody(ctx)) {
// Prevent input stream from being read if it needs to go downstream
if (requestEntity != null) {
debugRequestEntity(info, ctx.getRequest().getInputStream());
}
......@@ -241,8 +242,19 @@ public class ProxyRequestHelper {
return info;
}
private boolean shouldDebugBody(RequestContext ctx) {
HttpServletRequest request = ctx.getRequest();
if (ctx.isChunkedRequestBody()) {
return false;
}
if (request == null || request.getContentType() == null) {
return true;
}
return !request.getContentType().toLowerCase().contains("multipart");
}
public void appendDebug(Map<String, Object> info, int status,
MultiValueMap<String, String> headers) {
MultiValueMap<String, String> headers) {
if (this.traces != null) {
@SuppressWarnings("unchecked")
Map<String, Object> trace = (Map<String, Object>) info.get("headers");
......@@ -276,4 +288,3 @@ public class ProxyRequestHelper {
}
}
......@@ -22,6 +22,7 @@ import javax.servlet.http.HttpServletRequest;
import org.springframework.util.Assert;
import org.springframework.util.ReflectionUtils;
import org.springframework.web.servlet.DispatcherServlet;
import com.netflix.zuul.ZuulFilter;
import com.netflix.zuul.context.RequestContext;
......@@ -37,7 +38,8 @@ public class Servlet30WrapperFilter extends ZuulFilter {
public Servlet30WrapperFilter() {
this.requestField = ReflectionUtils.findField(HttpServletRequestWrapper.class,
"req", HttpServletRequest.class);
Assert.notNull(this.requestField, "HttpServletRequestWrapper.req field not found");
Assert.notNull(this.requestField,
"HttpServletRequestWrapper.req field not found");
this.requestField.setAccessible(true);
}
......@@ -67,9 +69,18 @@ public class Servlet30WrapperFilter extends ZuulFilter {
if (request instanceof HttpServletRequestWrapper) {
request = (HttpServletRequest) ReflectionUtils.getField(this.requestField,
request);
ctx.setRequest(new Servlet30RequestWrapper(request));
}
else if (isDispatcherServletRequest(request)) {
// If it's going through the dispatcher we need to buffer the body
ctx.setRequest(new Servlet30RequestWrapper(request));
}
ctx.setRequest(new Servlet30RequestWrapper(request));
return null;
}
private boolean isDispatcherServletRequest(HttpServletRequest request) {
return request.getAttribute(
DispatcherServlet.WEB_APPLICATION_CONTEXT_ATTRIBUTE) != null;
}
}
......@@ -23,8 +23,6 @@ import java.util.Map;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse;
import lombok.extern.apachecommons.CommonsLog;
import org.springframework.cloud.netflix.zuul.filters.ProxyRequestHelper;
import org.springframework.http.client.ClientHttpResponse;
import org.springframework.util.MultiValueMap;
......@@ -35,6 +33,8 @@ import com.netflix.zuul.ZuulFilter;
import com.netflix.zuul.context.RequestContext;
import com.netflix.zuul.exception.ZuulException;
import lombok.extern.apachecommons.CommonsLog;
@CommonsLog
public class RibbonRoutingFilter extends ZuulFilter {
......@@ -64,8 +64,8 @@ public class RibbonRoutingFilter extends ZuulFilter {
@Override
public boolean shouldFilter() {
RequestContext ctx = RequestContext.getCurrentContext();
return (ctx.getRouteHost() == null && ctx.get("serviceId") != null && ctx
.sendZuulResponse());
return (ctx.getRouteHost() == null && ctx.get("serviceId") != null
&& ctx.sendZuulResponse());
}
@Override
......@@ -78,7 +78,8 @@ public class RibbonRoutingFilter extends ZuulFilter {
return response;
}
catch (Exception ex) {
context.set("error.status_code", HttpServletResponse.SC_INTERNAL_SERVER_ERROR);
context.set("error.status_code",
HttpServletResponse.SC_INTERNAL_SERVER_ERROR);
context.set("error.exception", ex);
}
return null;
......@@ -93,6 +94,9 @@ public class RibbonRoutingFilter extends ZuulFilter {
.buildZuulRequestQueryParams(request);
String verb = getVerb(request);
InputStream requestEntity = getRequestBody(request);
if (request.getContentLength() < 0) {
context.setChunkedRequestBody();
}
String serviceId = (String) context.get("serviceId");
Boolean retryable = (Boolean) context.get("retryable");
......@@ -102,15 +106,15 @@ public class RibbonRoutingFilter extends ZuulFilter {
// remove double slashes
uri = uri.replace("//", "/");
return new RibbonCommandContext(serviceId, verb, uri, retryable,
headers, params, requestEntity);
return new RibbonCommandContext(serviceId, verb, uri, retryable, headers, params,
requestEntity);
}
private ClientHttpResponse forward(RibbonCommandContext context) throws Exception {
Map<String, Object> info = this.helper.debug(context.getVerb(), context.getUri(),
context.getHeaders(), context.getParams(), context.getRequestEntity());
RibbonCommand command = ribbonCommandFactory.create(context);
RibbonCommand command = this.ribbonCommandFactory.create(context);
try {
ClientHttpResponse response = command.execute();
this.helper.appendDebug(info, response.getStatusCode().value(),
......@@ -124,11 +128,11 @@ public class RibbonRoutingFilter extends ZuulFilter {
&& ex.getFallbackException().getCause() instanceof ClientException) {
ClientException cause = (ClientException) ex.getFallbackException()
.getCause();
throw new ZuulException(cause, "Forwarding error", 500, cause
.getErrorType().toString());
throw new ZuulException(cause, "Forwarding error", 500,
cause.getErrorType().toString());
}
throw new ZuulException(ex, "Forwarding error", 500, ex.getFailureType()
.toString());
throw new ZuulException(ex, "Forwarding error", 500,
ex.getFailureType().toString());
}
}
......@@ -140,8 +144,8 @@ public class RibbonRoutingFilter extends ZuulFilter {
return null;
}
try {
requestEntity = (InputStream) RequestContext.getCurrentContext().get(
"requestEntity");
requestEntity = (InputStream) RequestContext.getCurrentContext()
.get("requestEntity");
if (requestEntity == null) {
requestEntity = request.getInputStream();
}
......@@ -154,12 +158,14 @@ public class RibbonRoutingFilter extends ZuulFilter {
private String getVerb(HttpServletRequest request) {
String method = request.getMethod();
if (method == null)
if (method == null) {
return "GET";
}
return method;
}
private void setResponse(ClientHttpResponse resp) throws ClientException, IOException {
private void setResponse(ClientHttpResponse resp)
throws ClientException, IOException {
this.helper.setResponse(resp.getStatusCode().value(),
resp.getBody() == null ? null : resp.getBody(), resp.getHeaders());
}
......
......@@ -16,14 +16,10 @@
package org.springframework.cloud.netflix.zuul;
import static org.junit.Assert.assertEquals;
import java.io.IOException;
import java.io.InputStream;
import java.util.Map;
import lombok.extern.slf4j.Slf4j;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.springframework.beans.factory.annotation.Value;
......@@ -60,6 +56,10 @@ import com.netflix.loadbalancer.Server;
import com.netflix.loadbalancer.ServerList;
import com.netflix.zuul.ZuulFilter;
import static org.junit.Assert.assertEquals;
import lombok.extern.slf4j.Slf4j;
@RunWith(SpringJUnit4ClassRunner.class)
@SpringApplicationConfiguration(classes = FormZuulServletProxyApplication.class)
@WebAppConfiguration
......@@ -107,6 +107,8 @@ public class FormZuulServletProxyApplicationTests {
form.set("foo", new HttpEntity<byte[]>("bar".getBytes(), part));
HttpHeaders headers = new HttpHeaders();
headers.setContentType(MediaType.MULTIPART_FORM_DATA);
headers.set("Transfer-Encoding", "chunked");
headers.setContentLength(-1);
ResponseEntity<String> result = new TestRestTemplate().exchange(
"http://localhost:" + this.port + "/zuul/simple/file", HttpMethod.POST,
new HttpEntity<MultiValueMap<String, Object>>(form, headers),
......@@ -120,8 +122,8 @@ public class FormZuulServletProxyApplicationTests {
MultiValueMap<String, String> form = new LinkedMultiValueMap<String, String>();
form.set("foo", "bar");
HttpHeaders headers = new HttpHeaders();
headers.setContentType(MediaType
.valueOf(MediaType.APPLICATION_FORM_URLENCODED_VALUE + "; charset=UTF-8"));
headers.setContentType(MediaType.valueOf(
MediaType.APPLICATION_FORM_URLENCODED_VALUE + "; charset=UTF-8"));
ResponseEntity<String> result = new TestRestTemplate().exchange(
"http://localhost:" + this.port + "/zuul/simple/form", HttpMethod.POST,
new HttpEntity<MultiValueMap<String, String>>(form, headers),
......@@ -209,12 +211,12 @@ class FormZuulServletProxyApplication {
}
public static void main(String[] args) {
new SpringApplicationBuilder(FormZuulProxyApplication.class).properties(
"zuul.routes.simple:/zuul/simple/**",
"zuul.routes.direct.url:http://localhost:9999",
"zuul.routes.direct.path:/zuul/direct/**",
"multipart.maxFileSize:4096MB", "multipart.maxRequestSize:4096MB").run(
args);
new SpringApplicationBuilder(FormZuulProxyApplication.class)
.properties("zuul.routes.simple:/zuul/simple/**",
"zuul.routes.direct.url:http://localhost:9999",
"zuul.routes.direct.path:/zuul/direct/**",
"multipart.maxFileSize:4096MB", "multipart.maxRequestSize:4096MB")
.run(args);
}
}
......@@ -228,7 +230,7 @@ class ServletFormRibbonClientConfiguration {
@Bean
public ServerList<Server> ribbonServerList() {
return new StaticServerList<>(new Server("localhost", port));
return new StaticServerList<>(new Server("localhost", this.port));
}
}
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment