Commit 053296de by Dave Syer

Ensure request body is not buffered when direct to ZuulServlet

A change in Zuul in NetflixOSS means that it buffers all requests unless you switch it off at the level of the servlet. And even when you do that you need to ensure that Spring Cloud doesn't re-introduce a buffer through the Servlet30RequestWrapper. This change restores the feature in Spring Cloud Angel that multipart files posted to /zuul/* go straight through the servlet and are not buffered in memory. Fixes gh-773
parent 249f524a
......@@ -95,9 +95,14 @@ public class ZuulConfiguration {
}
@Bean
@ConditionalOnMissingBean(name = "zuulServlet")
public ServletRegistrationBean zuulServlet() {
return new ServletRegistrationBean(new ZuulServlet(),
ServletRegistrationBean servlet = new ServletRegistrationBean(new ZuulServlet(),
this.zuulProperties.getServletPattern());
// The whole point of exposing this servlet is to provide a route that doesn't
// buffer requests.
servlet.addInitParameter("buffer-requests", "false");
return servlet;
}
// pre filters
......
......@@ -41,11 +41,11 @@ import org.springframework.web.util.WebUtils;
import com.netflix.zuul.context.RequestContext;
import com.netflix.zuul.util.HTTPRequestUtils;
import lombok.extern.apachecommons.CommonsLog;
import static org.springframework.http.HttpHeaders.CONTENT_ENCODING;
import static org.springframework.http.HttpHeaders.CONTENT_LENGTH;
import lombok.extern.apachecommons.CommonsLog;
/**
* @author Dave Syer
*/
......@@ -123,7 +123,7 @@ public class ProxyRequestHelper {
}
public void setResponse(int status, InputStream entity,
MultiValueMap<String, String> headers) throws IOException {
MultiValueMap<String, String> headers) throws IOException {
RequestContext context = RequestContext.getCurrentContext();
context.setResponseStatusCode(status);
if (entity != null) {
......@@ -185,21 +185,21 @@ public class ProxyRequestHelper {
}
}
switch (name) {
case "host":
case "connection":
case "content-length":
case "content-encoding":
case "server":
case "transfer-encoding":
return false;
default:
return true;
case "host":
case "connection":
case "content-length":
case "content-encoding":
case "server":
case "transfer-encoding":
return false;
default:
return true;
}
}
public Map<String, Object> debug(String verb, String uri,
MultiValueMap<String, String> headers, MultiValueMap<String, String> params,
InputStream requestEntity) throws IOException {
MultiValueMap<String, String> headers, MultiValueMap<String, String> params,
InputStream requestEntity) throws IOException {
Map<String, Object> info = new LinkedHashMap<String, Object>();
if (this.traces != null) {
RequestContext context = RequestContext.getCurrentContext();
......@@ -230,7 +230,8 @@ public class ProxyRequestHelper {
input.put(entry.getKey(), value);
}
RequestContext ctx = RequestContext.getCurrentContext();
if (!ctx.isChunkedRequestBody()) {
if (shouldDebugBody(ctx)) {
// Prevent input stream from being read if it needs to go downstream
if (requestEntity != null) {
debugRequestEntity(info, ctx.getRequest().getInputStream());
}
......@@ -241,8 +242,19 @@ public class ProxyRequestHelper {
return info;
}
private boolean shouldDebugBody(RequestContext ctx) {
HttpServletRequest request = ctx.getRequest();
if (ctx.isChunkedRequestBody()) {
return false;
}
if (request == null || request.getContentType() == null) {
return true;
}
return !request.getContentType().toLowerCase().contains("multipart");
}
public void appendDebug(Map<String, Object> info, int status,
MultiValueMap<String, String> headers) {
MultiValueMap<String, String> headers) {
if (this.traces != null) {
@SuppressWarnings("unchecked")
Map<String, Object> trace = (Map<String, Object>) info.get("headers");
......@@ -276,4 +288,3 @@ public class ProxyRequestHelper {
}
}
......@@ -22,6 +22,7 @@ import javax.servlet.http.HttpServletRequest;
import org.springframework.util.Assert;
import org.springframework.util.ReflectionUtils;
import org.springframework.web.servlet.DispatcherServlet;
import com.netflix.zuul.ZuulFilter;
import com.netflix.zuul.context.RequestContext;
......@@ -37,7 +38,8 @@ public class Servlet30WrapperFilter extends ZuulFilter {
public Servlet30WrapperFilter() {
this.requestField = ReflectionUtils.findField(HttpServletRequestWrapper.class,
"req", HttpServletRequest.class);
Assert.notNull(this.requestField, "HttpServletRequestWrapper.req field not found");
Assert.notNull(this.requestField,
"HttpServletRequestWrapper.req field not found");
this.requestField.setAccessible(true);
}
......@@ -67,9 +69,18 @@ public class Servlet30WrapperFilter extends ZuulFilter {
if (request instanceof HttpServletRequestWrapper) {
request = (HttpServletRequest) ReflectionUtils.getField(this.requestField,
request);
ctx.setRequest(new Servlet30RequestWrapper(request));
}
else if (isDispatcherServletRequest(request)) {
// If it's going through the dispatcher we need to buffer the body
ctx.setRequest(new Servlet30RequestWrapper(request));
}
ctx.setRequest(new Servlet30RequestWrapper(request));
return null;
}
private boolean isDispatcherServletRequest(HttpServletRequest request) {
return request.getAttribute(
DispatcherServlet.WEB_APPLICATION_CONTEXT_ATTRIBUTE) != null;
}
}
......@@ -23,8 +23,6 @@ import java.util.Map;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse;
import lombok.extern.apachecommons.CommonsLog;
import org.springframework.cloud.netflix.zuul.filters.ProxyRequestHelper;
import org.springframework.http.client.ClientHttpResponse;
import org.springframework.util.MultiValueMap;
......@@ -35,6 +33,8 @@ import com.netflix.zuul.ZuulFilter;
import com.netflix.zuul.context.RequestContext;
import com.netflix.zuul.exception.ZuulException;
import lombok.extern.apachecommons.CommonsLog;
@CommonsLog
public class RibbonRoutingFilter extends ZuulFilter {
......@@ -64,8 +64,8 @@ public class RibbonRoutingFilter extends ZuulFilter {
@Override
public boolean shouldFilter() {
RequestContext ctx = RequestContext.getCurrentContext();
return (ctx.getRouteHost() == null && ctx.get("serviceId") != null && ctx
.sendZuulResponse());
return (ctx.getRouteHost() == null && ctx.get("serviceId") != null
&& ctx.sendZuulResponse());
}
@Override
......@@ -78,7 +78,8 @@ public class RibbonRoutingFilter extends ZuulFilter {
return response;
}
catch (Exception ex) {
context.set("error.status_code", HttpServletResponse.SC_INTERNAL_SERVER_ERROR);
context.set("error.status_code",
HttpServletResponse.SC_INTERNAL_SERVER_ERROR);
context.set("error.exception", ex);
}
return null;
......@@ -93,6 +94,9 @@ public class RibbonRoutingFilter extends ZuulFilter {
.buildZuulRequestQueryParams(request);
String verb = getVerb(request);
InputStream requestEntity = getRequestBody(request);
if (request.getContentLength() < 0) {
context.setChunkedRequestBody();
}
String serviceId = (String) context.get("serviceId");
Boolean retryable = (Boolean) context.get("retryable");
......@@ -102,15 +106,15 @@ public class RibbonRoutingFilter extends ZuulFilter {
// remove double slashes
uri = uri.replace("//", "/");
return new RibbonCommandContext(serviceId, verb, uri, retryable,
headers, params, requestEntity);
return new RibbonCommandContext(serviceId, verb, uri, retryable, headers, params,
requestEntity);
}
private ClientHttpResponse forward(RibbonCommandContext context) throws Exception {
Map<String, Object> info = this.helper.debug(context.getVerb(), context.getUri(),
context.getHeaders(), context.getParams(), context.getRequestEntity());
RibbonCommand command = ribbonCommandFactory.create(context);
RibbonCommand command = this.ribbonCommandFactory.create(context);
try {
ClientHttpResponse response = command.execute();
this.helper.appendDebug(info, response.getStatusCode().value(),
......@@ -124,11 +128,11 @@ public class RibbonRoutingFilter extends ZuulFilter {
&& ex.getFallbackException().getCause() instanceof ClientException) {
ClientException cause = (ClientException) ex.getFallbackException()
.getCause();
throw new ZuulException(cause, "Forwarding error", 500, cause
.getErrorType().toString());
throw new ZuulException(cause, "Forwarding error", 500,
cause.getErrorType().toString());
}
throw new ZuulException(ex, "Forwarding error", 500, ex.getFailureType()
.toString());
throw new ZuulException(ex, "Forwarding error", 500,
ex.getFailureType().toString());
}
}
......@@ -140,8 +144,8 @@ public class RibbonRoutingFilter extends ZuulFilter {
return null;
}
try {
requestEntity = (InputStream) RequestContext.getCurrentContext().get(
"requestEntity");
requestEntity = (InputStream) RequestContext.getCurrentContext()
.get("requestEntity");
if (requestEntity == null) {
requestEntity = request.getInputStream();
}
......@@ -154,12 +158,14 @@ public class RibbonRoutingFilter extends ZuulFilter {
private String getVerb(HttpServletRequest request) {
String method = request.getMethod();
if (method == null)
if (method == null) {
return "GET";
}
return method;
}
private void setResponse(ClientHttpResponse resp) throws ClientException, IOException {
private void setResponse(ClientHttpResponse resp)
throws ClientException, IOException {
this.helper.setResponse(resp.getStatusCode().value(),
resp.getBody() == null ? null : resp.getBody(), resp.getHeaders());
}
......
......@@ -29,6 +29,7 @@ import java.util.List;
import java.util.Map;
import java.util.Timer;
import java.util.TimerTask;
import javax.annotation.PostConstruct;
import javax.annotation.PreDestroy;
import javax.net.ssl.SSLContext;
......@@ -55,6 +56,7 @@ import org.apache.http.config.RegistryBuilder;
import org.apache.http.conn.socket.ConnectionSocketFactory;
import org.apache.http.conn.socket.PlainConnectionSocketFactory;
import org.apache.http.conn.ssl.SSLConnectionSocketFactory;
import org.apache.http.entity.ContentType;
import org.apache.http.entity.InputStreamEntity;
import org.apache.http.impl.client.CloseableHttpClient;
import org.apache.http.impl.client.DefaultHttpRequestRetryHandler;
......@@ -80,12 +82,12 @@ import lombok.extern.apachecommons.CommonsLog;
public class SimpleHostRoutingFilter extends ZuulFilter {
private static final DynamicIntProperty SOCKET_TIMEOUT = DynamicPropertyFactory
.getInstance().getIntProperty(ZuulConstants.ZUUL_HOST_SOCKET_TIMEOUT_MILLIS,
10000);
.getInstance()
.getIntProperty(ZuulConstants.ZUUL_HOST_SOCKET_TIMEOUT_MILLIS, 10000);
private static final DynamicIntProperty CONNECTION_TIMEOUT = DynamicPropertyFactory
.getInstance().getIntProperty(ZuulConstants.ZUUL_HOST_CONNECT_TIMEOUT_MILLIS,
2000);
.getInstance()
.getIntProperty(ZuulConstants.ZUUL_HOST_CONNECT_TIMEOUT_MILLIS, 2000);
private final Timer connectionManagerTimer = new Timer(
"SimpleHostRoutingFilter.connectionManagerTimer", true);
......@@ -98,11 +100,12 @@ public class SimpleHostRoutingFilter extends ZuulFilter {
@Override
public void run() {
try {
httpClient.close();
} catch (IOException ex) {
SimpleHostRoutingFilter.this.httpClient.close();
}
catch (IOException ex) {
log.error("error closing client", ex);
}
httpClient = newClient();
SimpleHostRoutingFilter.this.httpClient = newClient();
}
};
......@@ -117,22 +120,22 @@ public class SimpleHostRoutingFilter extends ZuulFilter {
@PostConstruct
private void initialize() {
this.httpClient = newClient();
SOCKET_TIMEOUT.addCallback(clientloader);
CONNECTION_TIMEOUT.addCallback(clientloader);
connectionManagerTimer.schedule(new TimerTask() {
SOCKET_TIMEOUT.addCallback(this.clientloader);
CONNECTION_TIMEOUT.addCallback(this.clientloader);
this.connectionManagerTimer.schedule(new TimerTask() {
@Override
public void run() {
if (connectionManager == null) {
if (SimpleHostRoutingFilter.this.connectionManager == null) {
return;
}
connectionManager.closeExpiredConnections();
SimpleHostRoutingFilter.this.connectionManager.closeExpiredConnections();
}
}, 30000, 5000);
}
@PreDestroy
public void stop() {
connectionManagerTimer.cancel();
this.connectionManagerTimer.cancel();
}
@Override
......@@ -161,16 +164,20 @@ public class SimpleHostRoutingFilter extends ZuulFilter {
.buildZuulRequestQueryParams(request);
String verb = getVerb(request);
InputStream requestEntity = getRequestBody(request);
if (request.getContentLength() < 0) {
context.setChunkedRequestBody();
}
String uri = this.helper.buildZuulRequestURI(request);
try {
HttpResponse response = forward(httpClient, verb, uri, request, headers,
HttpResponse response = forward(this.httpClient, verb, uri, request, headers,
params, requestEntity);
setResponse(response);
}
catch (Exception ex) {
context.set("error.status_code", HttpServletResponse.SC_INTERNAL_SERVER_ERROR);
context.set("error.status_code",
HttpServletResponse.SC_INTERNAL_SERVER_ERROR);
context.set("error.exception", ex);
}
return null;
......@@ -179,31 +186,37 @@ public class SimpleHostRoutingFilter extends ZuulFilter {
protected PoolingHttpClientConnectionManager newConnectionManager() {
try {
final SSLContext sslContext = SSLContext.getInstance("SSL");
sslContext.init(null, new TrustManager[]{new X509TrustManager() {
sslContext.init(null, new TrustManager[] { new X509TrustManager() {
@Override
public void checkClientTrusted(X509Certificate[] x509Certificates, String s) throws CertificateException {
public void checkClientTrusted(X509Certificate[] x509Certificates,
String s) throws CertificateException {
}
@Override
public void checkServerTrusted(X509Certificate[] x509Certificates, String s) throws CertificateException {
public void checkServerTrusted(X509Certificate[] x509Certificates,
String s) throws CertificateException {
}
@Override
public X509Certificate[] getAcceptedIssuers() {
return null;
}
}}, new SecureRandom());
} }, new SecureRandom());
final Registry<ConnectionSocketFactory> registry = RegistryBuilder.<ConnectionSocketFactory>create()
final Registry<ConnectionSocketFactory> registry = RegistryBuilder
.<ConnectionSocketFactory> create()
.register("http", PlainConnectionSocketFactory.INSTANCE)
.register("https", new SSLConnectionSocketFactory(sslContext))
.build();
connectionManager = new PoolingHttpClientConnectionManager(registry);
connectionManager.setMaxTotal(Integer.parseInt(System.getProperty("zuul.max.host.connections", "200")));
connectionManager.setDefaultMaxPerRoute(Integer.parseInt(System.getProperty("zuul.max.host.connections", "20")));
return connectionManager;
} catch (Exception ex) {
this.connectionManager = new PoolingHttpClientConnectionManager(registry);
this.connectionManager.setMaxTotal(Integer
.parseInt(System.getProperty("zuul.max.host.connections", "200")));
this.connectionManager.setDefaultMaxPerRoute(Integer
.parseInt(System.getProperty("zuul.max.host.connections", "20")));
return this.connectionManager;
}
catch (Exception ex) {
throw new RuntimeException(ex);
}
}
......@@ -212,55 +225,56 @@ public class SimpleHostRoutingFilter extends ZuulFilter {
final RequestConfig requestConfig = RequestConfig.custom()
.setSocketTimeout(SOCKET_TIMEOUT.get())
.setConnectTimeout(CONNECTION_TIMEOUT.get())
.setCookieSpec(CookieSpecs.IGNORE_COOKIES)
.build();
.setCookieSpec(CookieSpecs.IGNORE_COOKIES).build();
return HttpClients.custom()
.setConnectionManager(newConnectionManager())
return HttpClients.custom().setConnectionManager(newConnectionManager())
.setDefaultRequestConfig(requestConfig)
.setRetryHandler(new DefaultHttpRequestRetryHandler(0, false))
.setRedirectStrategy(new RedirectStrategy() {
@Override
public boolean isRedirected(HttpRequest request, HttpResponse response, HttpContext context) throws ProtocolException {
public boolean isRedirected(HttpRequest request,
HttpResponse response, HttpContext context)
throws ProtocolException {
return false;
}
@Override
public HttpUriRequest getRedirect(HttpRequest request, HttpResponse response, HttpContext context) throws ProtocolException {
public HttpUriRequest getRedirect(HttpRequest request,
HttpResponse response, HttpContext context)
throws ProtocolException {
return null;
}
})
.build();
}).build();
}
private HttpResponse forward(HttpClient httpclient, String verb, String uri,
HttpServletRequest request, MultiValueMap<String, String> headers,
MultiValueMap<String, String> params, InputStream requestEntity)
throws Exception {
HttpServletRequest request, MultiValueMap<String, String> headers,
MultiValueMap<String, String> params, InputStream requestEntity)
throws Exception {
Map<String, Object> info = this.helper.debug(verb, uri, headers, params,
requestEntity);
URL host = RequestContext.getCurrentContext().getRouteHost();
HttpHost httpHost = getHttpHost(host);
uri = StringUtils.cleanPath((host.getPath() + uri).replaceAll("/{2,}", "/"));
HttpRequest httpRequest;
int contentLength = request.getContentLength();
InputStreamEntity entity = new InputStreamEntity(requestEntity, contentLength,
ContentType.create(request.getContentType()));
switch (verb.toUpperCase()) {
case "POST":
HttpPost httpPost = new HttpPost(uri + getQueryString());
httpRequest = httpPost;
httpPost.setEntity(new InputStreamEntity(requestEntity, request
.getContentLength()));
httpPost.setEntity(entity);
break;
case "PUT":
HttpPut httpPut = new HttpPut(uri + getQueryString());
httpRequest = httpPut;
httpPut.setEntity(new InputStreamEntity(requestEntity, request
.getContentLength()));
httpPut.setEntity(entity);
break;
case "PATCH":
HttpPatch httpPatch = new HttpPatch(uri + getQueryString());
httpRequest = httpPatch;
httpPatch.setEntity(new InputStreamEntity(requestEntity, request
.getContentLength()));
httpPatch.setEntity(entity);
break;
default:
httpRequest = new BasicHttpRequest(verb, uri + getQueryString());
......@@ -312,10 +326,11 @@ public class SimpleHostRoutingFilter extends ZuulFilter {
private String getQueryString() throws UnsupportedEncodingException {
HttpServletRequest request = RequestContext.getCurrentContext().getRequest();
MultiValueMap<String, String> params=helper.buildZuulRequestQueryParams(request);
StringBuilder query=new StringBuilder();
MultiValueMap<String, String> params = this.helper
.buildZuulRequestQueryParams(request);
StringBuilder query = new StringBuilder();
for (Map.Entry<String, List<String>> entry : params.entrySet()) {
String key=URLEncoder.encode(entry.getKey(), "UTF-8");
String key = URLEncoder.encode(entry.getKey(), "UTF-8");
for (String value : entry.getValue()) {
query.append("&");
query.append(key);
......@@ -323,7 +338,7 @@ public class SimpleHostRoutingFilter extends ZuulFilter {
query.append(URLEncoder.encode(value, "UTF-8"));
}
}
return (query.length()>0) ? "?" + query.substring(1) : "";
return (query.length() > 0) ? "?" + query.substring(1) : "";
}
private HttpHost getHttpHost(URL host) {
......@@ -359,7 +374,7 @@ public class SimpleHostRoutingFilter extends ZuulFilter {
* @param names
*/
protected void addIgnoredHeaders(String... names) {
helper.addIgnoredHeaders(names);
this.helper.addIgnoredHeaders(names);
}
}
......@@ -16,14 +16,10 @@
package org.springframework.cloud.netflix.zuul;
import static org.junit.Assert.assertEquals;
import java.io.IOException;
import java.io.InputStream;
import java.util.Map;
import lombok.extern.slf4j.Slf4j;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.springframework.beans.factory.annotation.Value;
......@@ -60,6 +56,10 @@ import com.netflix.loadbalancer.Server;
import com.netflix.loadbalancer.ServerList;
import com.netflix.zuul.ZuulFilter;
import static org.junit.Assert.assertEquals;
import lombok.extern.slf4j.Slf4j;
@RunWith(SpringJUnit4ClassRunner.class)
@SpringApplicationConfiguration(classes = FormZuulServletProxyApplication.class)
@WebAppConfiguration
......@@ -107,6 +107,8 @@ public class FormZuulServletProxyApplicationTests {
form.set("foo", new HttpEntity<byte[]>("bar".getBytes(), part));
HttpHeaders headers = new HttpHeaders();
headers.setContentType(MediaType.MULTIPART_FORM_DATA);
headers.set("Transfer-Encoding", "chunked");
headers.setContentLength(-1);
ResponseEntity<String> result = new TestRestTemplate().exchange(
"http://localhost:" + this.port + "/zuul/simple/file", HttpMethod.POST,
new HttpEntity<MultiValueMap<String, Object>>(form, headers),
......@@ -120,8 +122,8 @@ public class FormZuulServletProxyApplicationTests {
MultiValueMap<String, String> form = new LinkedMultiValueMap<String, String>();
form.set("foo", "bar");
HttpHeaders headers = new HttpHeaders();
headers.setContentType(MediaType
.valueOf(MediaType.APPLICATION_FORM_URLENCODED_VALUE + "; charset=UTF-8"));
headers.setContentType(MediaType.valueOf(
MediaType.APPLICATION_FORM_URLENCODED_VALUE + "; charset=UTF-8"));
ResponseEntity<String> result = new TestRestTemplate().exchange(
"http://localhost:" + this.port + "/zuul/simple/form", HttpMethod.POST,
new HttpEntity<MultiValueMap<String, String>>(form, headers),
......@@ -209,12 +211,12 @@ class FormZuulServletProxyApplication {
}
public static void main(String[] args) {
new SpringApplicationBuilder(FormZuulProxyApplication.class).properties(
"zuul.routes.simple:/zuul/simple/**",
"zuul.routes.direct.url:http://localhost:9999",
"zuul.routes.direct.path:/zuul/direct/**",
"multipart.maxFileSize:4096MB", "multipart.maxRequestSize:4096MB").run(
args);
new SpringApplicationBuilder(FormZuulProxyApplication.class)
.properties("zuul.routes.simple:/zuul/simple/**",
"zuul.routes.direct.url:http://localhost:9999",
"zuul.routes.direct.path:/zuul/direct/**",
"multipart.maxFileSize:4096MB", "multipart.maxRequestSize:4096MB")
.run(args);
}
}
......@@ -228,7 +230,7 @@ class ServletFormRibbonClientConfiguration {
@Bean
public ServerList<Server> ribbonServerList() {
return new StaticServerList<>(new Server("localhost", port));
return new StaticServerList<>(new Server("localhost", this.port));
}
}
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment