Commit 053296de by Dave Syer

Ensure request body is not buffered when direct to ZuulServlet

A change in Zuul in NetflixOSS means that it buffers all requests unless you switch it off at the level of the servlet. And even when you do that you need to ensure that Spring Cloud doesn't re-introduce a buffer through the Servlet30RequestWrapper. This change restores the feature in Spring Cloud Angel that multipart files posted to /zuul/* go straight through the servlet and are not buffered in memory. Fixes gh-773
parent 249f524a
...@@ -95,9 +95,14 @@ public class ZuulConfiguration { ...@@ -95,9 +95,14 @@ public class ZuulConfiguration {
} }
@Bean @Bean
@ConditionalOnMissingBean(name = "zuulServlet")
public ServletRegistrationBean zuulServlet() { public ServletRegistrationBean zuulServlet() {
return new ServletRegistrationBean(new ZuulServlet(), ServletRegistrationBean servlet = new ServletRegistrationBean(new ZuulServlet(),
this.zuulProperties.getServletPattern()); this.zuulProperties.getServletPattern());
// The whole point of exposing this servlet is to provide a route that doesn't
// buffer requests.
servlet.addInitParameter("buffer-requests", "false");
return servlet;
} }
// pre filters // pre filters
......
...@@ -41,11 +41,11 @@ import org.springframework.web.util.WebUtils; ...@@ -41,11 +41,11 @@ import org.springframework.web.util.WebUtils;
import com.netflix.zuul.context.RequestContext; import com.netflix.zuul.context.RequestContext;
import com.netflix.zuul.util.HTTPRequestUtils; import com.netflix.zuul.util.HTTPRequestUtils;
import lombok.extern.apachecommons.CommonsLog;
import static org.springframework.http.HttpHeaders.CONTENT_ENCODING; import static org.springframework.http.HttpHeaders.CONTENT_ENCODING;
import static org.springframework.http.HttpHeaders.CONTENT_LENGTH; import static org.springframework.http.HttpHeaders.CONTENT_LENGTH;
import lombok.extern.apachecommons.CommonsLog;
/** /**
* @author Dave Syer * @author Dave Syer
*/ */
...@@ -123,7 +123,7 @@ public class ProxyRequestHelper { ...@@ -123,7 +123,7 @@ public class ProxyRequestHelper {
} }
public void setResponse(int status, InputStream entity, public void setResponse(int status, InputStream entity,
MultiValueMap<String, String> headers) throws IOException { MultiValueMap<String, String> headers) throws IOException {
RequestContext context = RequestContext.getCurrentContext(); RequestContext context = RequestContext.getCurrentContext();
context.setResponseStatusCode(status); context.setResponseStatusCode(status);
if (entity != null) { if (entity != null) {
...@@ -185,21 +185,21 @@ public class ProxyRequestHelper { ...@@ -185,21 +185,21 @@ public class ProxyRequestHelper {
} }
} }
switch (name) { switch (name) {
case "host": case "host":
case "connection": case "connection":
case "content-length": case "content-length":
case "content-encoding": case "content-encoding":
case "server": case "server":
case "transfer-encoding": case "transfer-encoding":
return false; return false;
default: default:
return true; return true;
} }
} }
public Map<String, Object> debug(String verb, String uri, public Map<String, Object> debug(String verb, String uri,
MultiValueMap<String, String> headers, MultiValueMap<String, String> params, MultiValueMap<String, String> headers, MultiValueMap<String, String> params,
InputStream requestEntity) throws IOException { InputStream requestEntity) throws IOException {
Map<String, Object> info = new LinkedHashMap<String, Object>(); Map<String, Object> info = new LinkedHashMap<String, Object>();
if (this.traces != null) { if (this.traces != null) {
RequestContext context = RequestContext.getCurrentContext(); RequestContext context = RequestContext.getCurrentContext();
...@@ -230,7 +230,8 @@ public class ProxyRequestHelper { ...@@ -230,7 +230,8 @@ public class ProxyRequestHelper {
input.put(entry.getKey(), value); input.put(entry.getKey(), value);
} }
RequestContext ctx = RequestContext.getCurrentContext(); RequestContext ctx = RequestContext.getCurrentContext();
if (!ctx.isChunkedRequestBody()) { if (shouldDebugBody(ctx)) {
// Prevent input stream from being read if it needs to go downstream
if (requestEntity != null) { if (requestEntity != null) {
debugRequestEntity(info, ctx.getRequest().getInputStream()); debugRequestEntity(info, ctx.getRequest().getInputStream());
} }
...@@ -241,8 +242,19 @@ public class ProxyRequestHelper { ...@@ -241,8 +242,19 @@ public class ProxyRequestHelper {
return info; return info;
} }
private boolean shouldDebugBody(RequestContext ctx) {
HttpServletRequest request = ctx.getRequest();
if (ctx.isChunkedRequestBody()) {
return false;
}
if (request == null || request.getContentType() == null) {
return true;
}
return !request.getContentType().toLowerCase().contains("multipart");
}
public void appendDebug(Map<String, Object> info, int status, public void appendDebug(Map<String, Object> info, int status,
MultiValueMap<String, String> headers) { MultiValueMap<String, String> headers) {
if (this.traces != null) { if (this.traces != null) {
@SuppressWarnings("unchecked") @SuppressWarnings("unchecked")
Map<String, Object> trace = (Map<String, Object>) info.get("headers"); Map<String, Object> trace = (Map<String, Object>) info.get("headers");
...@@ -276,4 +288,3 @@ public class ProxyRequestHelper { ...@@ -276,4 +288,3 @@ public class ProxyRequestHelper {
} }
} }
...@@ -22,6 +22,7 @@ import javax.servlet.http.HttpServletRequest; ...@@ -22,6 +22,7 @@ import javax.servlet.http.HttpServletRequest;
import org.springframework.util.Assert; import org.springframework.util.Assert;
import org.springframework.util.ReflectionUtils; import org.springframework.util.ReflectionUtils;
import org.springframework.web.servlet.DispatcherServlet;
import com.netflix.zuul.ZuulFilter; import com.netflix.zuul.ZuulFilter;
import com.netflix.zuul.context.RequestContext; import com.netflix.zuul.context.RequestContext;
...@@ -37,7 +38,8 @@ public class Servlet30WrapperFilter extends ZuulFilter { ...@@ -37,7 +38,8 @@ public class Servlet30WrapperFilter extends ZuulFilter {
public Servlet30WrapperFilter() { public Servlet30WrapperFilter() {
this.requestField = ReflectionUtils.findField(HttpServletRequestWrapper.class, this.requestField = ReflectionUtils.findField(HttpServletRequestWrapper.class,
"req", HttpServletRequest.class); "req", HttpServletRequest.class);
Assert.notNull(this.requestField, "HttpServletRequestWrapper.req field not found"); Assert.notNull(this.requestField,
"HttpServletRequestWrapper.req field not found");
this.requestField.setAccessible(true); this.requestField.setAccessible(true);
} }
...@@ -67,9 +69,18 @@ public class Servlet30WrapperFilter extends ZuulFilter { ...@@ -67,9 +69,18 @@ public class Servlet30WrapperFilter extends ZuulFilter {
if (request instanceof HttpServletRequestWrapper) { if (request instanceof HttpServletRequestWrapper) {
request = (HttpServletRequest) ReflectionUtils.getField(this.requestField, request = (HttpServletRequest) ReflectionUtils.getField(this.requestField,
request); request);
ctx.setRequest(new Servlet30RequestWrapper(request));
}
else if (isDispatcherServletRequest(request)) {
// If it's going through the dispatcher we need to buffer the body
ctx.setRequest(new Servlet30RequestWrapper(request));
} }
ctx.setRequest(new Servlet30RequestWrapper(request));
return null; return null;
} }
private boolean isDispatcherServletRequest(HttpServletRequest request) {
return request.getAttribute(
DispatcherServlet.WEB_APPLICATION_CONTEXT_ATTRIBUTE) != null;
}
} }
...@@ -23,8 +23,6 @@ import java.util.Map; ...@@ -23,8 +23,6 @@ import java.util.Map;
import javax.servlet.http.HttpServletRequest; import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse; import javax.servlet.http.HttpServletResponse;
import lombok.extern.apachecommons.CommonsLog;
import org.springframework.cloud.netflix.zuul.filters.ProxyRequestHelper; import org.springframework.cloud.netflix.zuul.filters.ProxyRequestHelper;
import org.springframework.http.client.ClientHttpResponse; import org.springframework.http.client.ClientHttpResponse;
import org.springframework.util.MultiValueMap; import org.springframework.util.MultiValueMap;
...@@ -35,6 +33,8 @@ import com.netflix.zuul.ZuulFilter; ...@@ -35,6 +33,8 @@ import com.netflix.zuul.ZuulFilter;
import com.netflix.zuul.context.RequestContext; import com.netflix.zuul.context.RequestContext;
import com.netflix.zuul.exception.ZuulException; import com.netflix.zuul.exception.ZuulException;
import lombok.extern.apachecommons.CommonsLog;
@CommonsLog @CommonsLog
public class RibbonRoutingFilter extends ZuulFilter { public class RibbonRoutingFilter extends ZuulFilter {
...@@ -64,8 +64,8 @@ public class RibbonRoutingFilter extends ZuulFilter { ...@@ -64,8 +64,8 @@ public class RibbonRoutingFilter extends ZuulFilter {
@Override @Override
public boolean shouldFilter() { public boolean shouldFilter() {
RequestContext ctx = RequestContext.getCurrentContext(); RequestContext ctx = RequestContext.getCurrentContext();
return (ctx.getRouteHost() == null && ctx.get("serviceId") != null && ctx return (ctx.getRouteHost() == null && ctx.get("serviceId") != null
.sendZuulResponse()); && ctx.sendZuulResponse());
} }
@Override @Override
...@@ -78,7 +78,8 @@ public class RibbonRoutingFilter extends ZuulFilter { ...@@ -78,7 +78,8 @@ public class RibbonRoutingFilter extends ZuulFilter {
return response; return response;
} }
catch (Exception ex) { catch (Exception ex) {
context.set("error.status_code", HttpServletResponse.SC_INTERNAL_SERVER_ERROR); context.set("error.status_code",
HttpServletResponse.SC_INTERNAL_SERVER_ERROR);
context.set("error.exception", ex); context.set("error.exception", ex);
} }
return null; return null;
...@@ -93,6 +94,9 @@ public class RibbonRoutingFilter extends ZuulFilter { ...@@ -93,6 +94,9 @@ public class RibbonRoutingFilter extends ZuulFilter {
.buildZuulRequestQueryParams(request); .buildZuulRequestQueryParams(request);
String verb = getVerb(request); String verb = getVerb(request);
InputStream requestEntity = getRequestBody(request); InputStream requestEntity = getRequestBody(request);
if (request.getContentLength() < 0) {
context.setChunkedRequestBody();
}
String serviceId = (String) context.get("serviceId"); String serviceId = (String) context.get("serviceId");
Boolean retryable = (Boolean) context.get("retryable"); Boolean retryable = (Boolean) context.get("retryable");
...@@ -102,15 +106,15 @@ public class RibbonRoutingFilter extends ZuulFilter { ...@@ -102,15 +106,15 @@ public class RibbonRoutingFilter extends ZuulFilter {
// remove double slashes // remove double slashes
uri = uri.replace("//", "/"); uri = uri.replace("//", "/");
return new RibbonCommandContext(serviceId, verb, uri, retryable, return new RibbonCommandContext(serviceId, verb, uri, retryable, headers, params,
headers, params, requestEntity); requestEntity);
} }
private ClientHttpResponse forward(RibbonCommandContext context) throws Exception { private ClientHttpResponse forward(RibbonCommandContext context) throws Exception {
Map<String, Object> info = this.helper.debug(context.getVerb(), context.getUri(), Map<String, Object> info = this.helper.debug(context.getVerb(), context.getUri(),
context.getHeaders(), context.getParams(), context.getRequestEntity()); context.getHeaders(), context.getParams(), context.getRequestEntity());
RibbonCommand command = ribbonCommandFactory.create(context); RibbonCommand command = this.ribbonCommandFactory.create(context);
try { try {
ClientHttpResponse response = command.execute(); ClientHttpResponse response = command.execute();
this.helper.appendDebug(info, response.getStatusCode().value(), this.helper.appendDebug(info, response.getStatusCode().value(),
...@@ -124,11 +128,11 @@ public class RibbonRoutingFilter extends ZuulFilter { ...@@ -124,11 +128,11 @@ public class RibbonRoutingFilter extends ZuulFilter {
&& ex.getFallbackException().getCause() instanceof ClientException) { && ex.getFallbackException().getCause() instanceof ClientException) {
ClientException cause = (ClientException) ex.getFallbackException() ClientException cause = (ClientException) ex.getFallbackException()
.getCause(); .getCause();
throw new ZuulException(cause, "Forwarding error", 500, cause throw new ZuulException(cause, "Forwarding error", 500,
.getErrorType().toString()); cause.getErrorType().toString());
} }
throw new ZuulException(ex, "Forwarding error", 500, ex.getFailureType() throw new ZuulException(ex, "Forwarding error", 500,
.toString()); ex.getFailureType().toString());
} }
} }
...@@ -140,8 +144,8 @@ public class RibbonRoutingFilter extends ZuulFilter { ...@@ -140,8 +144,8 @@ public class RibbonRoutingFilter extends ZuulFilter {
return null; return null;
} }
try { try {
requestEntity = (InputStream) RequestContext.getCurrentContext().get( requestEntity = (InputStream) RequestContext.getCurrentContext()
"requestEntity"); .get("requestEntity");
if (requestEntity == null) { if (requestEntity == null) {
requestEntity = request.getInputStream(); requestEntity = request.getInputStream();
} }
...@@ -154,12 +158,14 @@ public class RibbonRoutingFilter extends ZuulFilter { ...@@ -154,12 +158,14 @@ public class RibbonRoutingFilter extends ZuulFilter {
private String getVerb(HttpServletRequest request) { private String getVerb(HttpServletRequest request) {
String method = request.getMethod(); String method = request.getMethod();
if (method == null) if (method == null) {
return "GET"; return "GET";
}
return method; return method;
} }
private void setResponse(ClientHttpResponse resp) throws ClientException, IOException { private void setResponse(ClientHttpResponse resp)
throws ClientException, IOException {
this.helper.setResponse(resp.getStatusCode().value(), this.helper.setResponse(resp.getStatusCode().value(),
resp.getBody() == null ? null : resp.getBody(), resp.getHeaders()); resp.getBody() == null ? null : resp.getBody(), resp.getHeaders());
} }
......
...@@ -16,14 +16,10 @@ ...@@ -16,14 +16,10 @@
package org.springframework.cloud.netflix.zuul; package org.springframework.cloud.netflix.zuul;
import static org.junit.Assert.assertEquals;
import java.io.IOException; import java.io.IOException;
import java.io.InputStream; import java.io.InputStream;
import java.util.Map; import java.util.Map;
import lombok.extern.slf4j.Slf4j;
import org.junit.Test; import org.junit.Test;
import org.junit.runner.RunWith; import org.junit.runner.RunWith;
import org.springframework.beans.factory.annotation.Value; import org.springframework.beans.factory.annotation.Value;
...@@ -60,6 +56,10 @@ import com.netflix.loadbalancer.Server; ...@@ -60,6 +56,10 @@ import com.netflix.loadbalancer.Server;
import com.netflix.loadbalancer.ServerList; import com.netflix.loadbalancer.ServerList;
import com.netflix.zuul.ZuulFilter; import com.netflix.zuul.ZuulFilter;
import static org.junit.Assert.assertEquals;
import lombok.extern.slf4j.Slf4j;
@RunWith(SpringJUnit4ClassRunner.class) @RunWith(SpringJUnit4ClassRunner.class)
@SpringApplicationConfiguration(classes = FormZuulServletProxyApplication.class) @SpringApplicationConfiguration(classes = FormZuulServletProxyApplication.class)
@WebAppConfiguration @WebAppConfiguration
...@@ -107,6 +107,8 @@ public class FormZuulServletProxyApplicationTests { ...@@ -107,6 +107,8 @@ public class FormZuulServletProxyApplicationTests {
form.set("foo", new HttpEntity<byte[]>("bar".getBytes(), part)); form.set("foo", new HttpEntity<byte[]>("bar".getBytes(), part));
HttpHeaders headers = new HttpHeaders(); HttpHeaders headers = new HttpHeaders();
headers.setContentType(MediaType.MULTIPART_FORM_DATA); headers.setContentType(MediaType.MULTIPART_FORM_DATA);
headers.set("Transfer-Encoding", "chunked");
headers.setContentLength(-1);
ResponseEntity<String> result = new TestRestTemplate().exchange( ResponseEntity<String> result = new TestRestTemplate().exchange(
"http://localhost:" + this.port + "/zuul/simple/file", HttpMethod.POST, "http://localhost:" + this.port + "/zuul/simple/file", HttpMethod.POST,
new HttpEntity<MultiValueMap<String, Object>>(form, headers), new HttpEntity<MultiValueMap<String, Object>>(form, headers),
...@@ -120,8 +122,8 @@ public class FormZuulServletProxyApplicationTests { ...@@ -120,8 +122,8 @@ public class FormZuulServletProxyApplicationTests {
MultiValueMap<String, String> form = new LinkedMultiValueMap<String, String>(); MultiValueMap<String, String> form = new LinkedMultiValueMap<String, String>();
form.set("foo", "bar"); form.set("foo", "bar");
HttpHeaders headers = new HttpHeaders(); HttpHeaders headers = new HttpHeaders();
headers.setContentType(MediaType headers.setContentType(MediaType.valueOf(
.valueOf(MediaType.APPLICATION_FORM_URLENCODED_VALUE + "; charset=UTF-8")); MediaType.APPLICATION_FORM_URLENCODED_VALUE + "; charset=UTF-8"));
ResponseEntity<String> result = new TestRestTemplate().exchange( ResponseEntity<String> result = new TestRestTemplate().exchange(
"http://localhost:" + this.port + "/zuul/simple/form", HttpMethod.POST, "http://localhost:" + this.port + "/zuul/simple/form", HttpMethod.POST,
new HttpEntity<MultiValueMap<String, String>>(form, headers), new HttpEntity<MultiValueMap<String, String>>(form, headers),
...@@ -209,12 +211,12 @@ class FormZuulServletProxyApplication { ...@@ -209,12 +211,12 @@ class FormZuulServletProxyApplication {
} }
public static void main(String[] args) { public static void main(String[] args) {
new SpringApplicationBuilder(FormZuulProxyApplication.class).properties( new SpringApplicationBuilder(FormZuulProxyApplication.class)
"zuul.routes.simple:/zuul/simple/**", .properties("zuul.routes.simple:/zuul/simple/**",
"zuul.routes.direct.url:http://localhost:9999", "zuul.routes.direct.url:http://localhost:9999",
"zuul.routes.direct.path:/zuul/direct/**", "zuul.routes.direct.path:/zuul/direct/**",
"multipart.maxFileSize:4096MB", "multipart.maxRequestSize:4096MB").run( "multipart.maxFileSize:4096MB", "multipart.maxRequestSize:4096MB")
args); .run(args);
} }
} }
...@@ -228,7 +230,7 @@ class ServletFormRibbonClientConfiguration { ...@@ -228,7 +230,7 @@ class ServletFormRibbonClientConfiguration {
@Bean @Bean
public ServerList<Server> ribbonServerList() { public ServerList<Server> ribbonServerList() {
return new StaticServerList<>(new Server("localhost", port)); return new StaticServerList<>(new Server("localhost", this.port));
} }
} }
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment