Commit 2a8db73d by Dave Syer

Extract header manipulation into a helper and add test

Ribbon and simple host reverse proxy now share quite a lot of code (everything to do with headers). Fixes gh-103
parent 73e7adbe
......@@ -15,6 +15,7 @@ import org.springframework.cloud.netflix.zuul.filters.post.SendResponseFilter;
import org.springframework.cloud.netflix.zuul.filters.pre.DebugFilter;
import org.springframework.cloud.netflix.zuul.filters.pre.PreDecorationFilter;
import org.springframework.cloud.netflix.zuul.filters.pre.Servlet30WrapperFilter;
import org.springframework.cloud.netflix.zuul.filters.route.ProxyRequestHelper;
import org.springframework.cloud.netflix.zuul.filters.route.RibbonRoutingFilter;
import org.springframework.cloud.netflix.zuul.filters.route.SimpleHostRoutingFilter;
import org.springframework.context.annotation.Bean;
......@@ -97,16 +98,21 @@ public class ZuulConfiguration {
// route filters
@Bean
public RibbonRoutingFilter ribbonRoutingFilter() {
RibbonRoutingFilter filter = new RibbonRoutingFilter(clientFactory);
ProxyRequestHelper helper = new ProxyRequestHelper();
if (traces != null) {
filter.setTraces(traces);
helper.setTraces(traces);
}
RibbonRoutingFilter filter = new RibbonRoutingFilter(helper , clientFactory);
return filter;
}
@Bean
public SimpleHostRoutingFilter simpleHostRoutingFilter() {
return new SimpleHostRoutingFilter();
ProxyRequestHelper helper = new ProxyRequestHelper();
if (traces != null) {
helper.setTraces(traces);
}
return new SimpleHostRoutingFilter(helper);
}
// post filters
......
......@@ -31,7 +31,6 @@ import org.apache.commons.io.IOUtils;
import org.springframework.boot.actuate.trace.TraceRepository;
import org.springframework.util.StringUtils;
import com.netflix.zuul.ZuulFilter;
import com.netflix.zuul.context.RequestContext;
import com.netflix.zuul.util.HTTPRequestUtils;
import com.sun.jersey.core.util.MultivaluedMapImpl;
......@@ -40,7 +39,7 @@ import com.sun.jersey.core.util.MultivaluedMapImpl;
* @author Dave Syer
*
*/
public abstract class BaseProxyFilter extends ZuulFilter {
public class ProxyRequestHelper {
public static final String CONTENT_ENCODING = "Content-Encoding";
......@@ -50,7 +49,7 @@ public abstract class BaseProxyFilter extends ZuulFilter {
this.traces = traces;
}
protected MultivaluedMap<String, String> buildZuulRequestQueryParams(
public MultivaluedMap<String, String> buildZuulRequestQueryParams(
HttpServletRequest request) {
Map<String, List<String>> map = HTTPRequestUtils.getInstance().getQueryParams();
......@@ -68,7 +67,7 @@ public abstract class BaseProxyFilter extends ZuulFilter {
return params;
}
protected MultivaluedMap<String, String> buildZuulRequestHeaders(
public MultivaluedMap<String, String> buildZuulRequestHeaders(
HttpServletRequest request) {
RequestContext context = RequestContext.getCurrentContext();
......@@ -94,7 +93,46 @@ public abstract class BaseProxyFilter extends ZuulFilter {
return headers;
}
private boolean isIncludedHeader(String headerName) {
public void setResponse(int status, InputStream entity,
Map<String, Collection<String>> headers) throws IOException {
RequestContext context = RequestContext.getCurrentContext();
RequestContext.getCurrentContext().setResponseStatusCode(status);
if (entity != null) {
RequestContext.getCurrentContext().setResponseDataStream(entity);
}
boolean isOriginResponseGzipped = false;
if (headers.containsKey(CONTENT_ENCODING)) {
Collection<String> collection = headers.get(CONTENT_ENCODING);
for (String header : collection) {
if (HTTPRequestUtils.getInstance().isGzipped(header)) {
isOriginResponseGzipped = true;
break;
}
}
}
context.setResponseGZipped(isOriginResponseGzipped);
for (Entry<String, Collection<String>> header : headers.entrySet()) {
RequestContext ctx = RequestContext.getCurrentContext();
String name = header.getKey();
for (String value : header.getValue()) {
ctx.addOriginResponseHeader(name, value);
if (name.equalsIgnoreCase("content-length"))
ctx.setOriginContentLength(value);
if (isIncludedHeader(name)) {
ctx.addZuulResponseHeader(name, value);
}
}
}
}
public boolean isIncludedHeader(String headerName) {
switch (headerName.toLowerCase()) {
case "host":
case "connection":
......@@ -108,7 +146,7 @@ public abstract class BaseProxyFilter extends ZuulFilter {
}
}
protected Map<String, Object> debug(String verb, String uri,
public Map<String, Object> debug(String verb, String uri,
MultivaluedMap<String, String> headers,
MultivaluedMap<String, String> params, InputStream requestEntity)
throws IOException {
......@@ -156,7 +194,7 @@ public abstract class BaseProxyFilter extends ZuulFilter {
return info;
}
protected void appendDebug(Map<String, Object> info, int status,
public void appendDebug(Map<String, Object> info, int status,
Map<String, Collection<String>> headers) {
if (traces != null) {
@SuppressWarnings("unchecked")
......@@ -175,7 +213,7 @@ public abstract class BaseProxyFilter extends ZuulFilter {
}
}
protected void appendDebug(Map<String, Object> info, int status,
public void appendDebug(Map<String, Object> info, int status,
MultivaluedMap<String, String> headers) {
if (traces != null) {
Map<String, Collection<String>> map = new LinkedHashMap<String, Collection<String>>();
......
......@@ -18,11 +18,12 @@ import com.netflix.client.http.HttpRequest.Verb;
import com.netflix.client.http.HttpResponse;
import com.netflix.hystrix.exception.HystrixRuntimeException;
import com.netflix.niws.client.http.RestClient;
import com.netflix.zuul.ZuulFilter;
import com.netflix.zuul.context.RequestContext;
import com.netflix.zuul.exception.ZuulException;
import com.netflix.zuul.util.HTTPRequestUtils;
public class RibbonRoutingFilter extends BaseProxyFilter {
public class RibbonRoutingFilter extends ZuulFilter {
private static final Logger LOG = LoggerFactory.getLogger(RibbonRoutingFilter.class);
......@@ -30,10 +31,18 @@ public class RibbonRoutingFilter extends BaseProxyFilter {
private SpringClientFactory clientFactory;
public RibbonRoutingFilter(SpringClientFactory clientFactory) {
private ProxyRequestHelper helper;
public RibbonRoutingFilter(ProxyRequestHelper helper,
SpringClientFactory clientFactory) {
this.helper = helper;
this.clientFactory = clientFactory;
}
public RibbonRoutingFilter(SpringClientFactory clientFactory) {
this(new ProxyRequestHelper(), clientFactory);
}
@Override
public String filterType() {
return "route";
......@@ -54,8 +63,9 @@ public class RibbonRoutingFilter extends BaseProxyFilter {
RequestContext context = RequestContext.getCurrentContext();
HttpServletRequest request = context.getRequest();
MultivaluedMap<String, String> headers = buildZuulRequestHeaders(request);
MultivaluedMap<String, String> params = buildZuulRequestQueryParams(request);
MultivaluedMap<String, String> headers = helper.buildZuulRequestHeaders(request);
MultivaluedMap<String, String> params = helper
.buildZuulRequestQueryParams(request);
Verb verb = getVerb(request);
InputStream requestEntity = getRequestBody(request);
......@@ -88,13 +98,14 @@ public class RibbonRoutingFilter extends BaseProxyFilter {
MultivaluedMap<String, String> params, InputStream requestEntity)
throws Exception {
Map<String, Object> info = debug(verb.verb(), uri, headers, params, requestEntity);
Map<String, Object> info = helper.debug(verb.verb(), uri, headers, params,
requestEntity);
RibbonCommand command = new RibbonCommand(restClient, verb, uri, headers, params,
requestEntity);
try {
HttpResponse response = command.execute();
appendDebug(info, response.getStatus(), response.getHeaders());
helper.appendDebug(info, response.getStatus(), response.getHeaders());
return response;
}
catch (HystrixRuntimeException e) {
......@@ -155,57 +166,9 @@ public class RibbonRoutingFilter extends BaseProxyFilter {
return Verb.GET;
}
void setResponse(HttpResponse resp) throws ClientException, IOException {
RequestContext context = RequestContext.getCurrentContext();
context.setResponseStatusCode(resp.getStatus());
if (resp.hasEntity()) {
context.setResponseDataStream(resp.getInputStream());
}
String contentEncoding = null;
Collection<String> contentEncodingHeader = resp.getHeaders()
.get(CONTENT_ENCODING);
if (contentEncodingHeader != null && !contentEncodingHeader.isEmpty()) {
contentEncoding = contentEncodingHeader.iterator().next();
}
if (contentEncoding != null
&& HTTPRequestUtils.getInstance().isGzipped(contentEncoding)) {
context.setResponseGZipped(true);
}
else {
context.setResponseGZipped(false);
}
for (String key : resp.getHeaders().keySet()) {
boolean isValidHeader = isIncludedHeader(key);
Collection<java.lang.String> list = resp.getHeaders().get(key);
for (String header : list) {
context.addOriginResponseHeader(key, header);
if (key.equalsIgnoreCase("content-length"))
context.setOriginContentLength(header);
if (isValidHeader) {
context.addZuulResponseHeader(key, header);
}
}
}
}
private boolean isIncludedHeader(String headerName) {
switch (headerName.toLowerCase()) {
case "connection":
case "content-length":
case "content-encoding":
case "server":
case "transfer-encoding":
return false;
default:
return true;
}
private void setResponse(HttpResponse resp) throws ClientException, IOException {
helper.setResponse(resp.getStatus(),
!resp.hasEntity() ? null : resp.getInputStream(), resp.getHeaders());
}
}
......@@ -14,6 +14,8 @@ import java.security.UnrecoverableKeyException;
import java.security.cert.CertificateException;
import java.security.cert.X509Certificate;
import java.util.ArrayList;
import java.util.Collection;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.Timer;
......@@ -57,17 +59,19 @@ import org.springframework.util.StringUtils;
import com.netflix.config.DynamicIntProperty;
import com.netflix.config.DynamicPropertyFactory;
import com.netflix.zuul.ZuulFilter;
import com.netflix.zuul.constants.ZuulConstants;
import com.netflix.zuul.context.Debug;
import com.netflix.zuul.context.RequestContext;
import com.netflix.zuul.util.HTTPRequestUtils;
public class SimpleHostRoutingFilter extends BaseProxyFilter {
public class SimpleHostRoutingFilter extends ZuulFilter {
public static final String CONTENT_ENCODING = "Content-Encoding";
private static final Logger LOG = LoggerFactory
.getLogger(SimpleHostRoutingFilter.class);
private static final Runnable CLIENTLOADER = new Runnable() {
@Override
public void run() {
......@@ -128,6 +132,16 @@ public class SimpleHostRoutingFilter extends BaseProxyFilter {
return cm;
}
private ProxyRequestHelper helper;
public SimpleHostRoutingFilter() {
this(new ProxyRequestHelper());
}
public SimpleHostRoutingFilter(ProxyRequestHelper helper) {
this.helper = helper;
}
@Override
public String filterType() {
return "route";
......@@ -201,8 +215,9 @@ public class SimpleHostRoutingFilter extends BaseProxyFilter {
public Object run() {
RequestContext context = RequestContext.getCurrentContext();
HttpServletRequest request = context.getRequest();
MultivaluedMap<String, String> headers = buildZuulRequestHeaders(request);
MultivaluedMap<String, String> params = buildZuulRequestQueryParams(request);
MultivaluedMap<String, String> headers = helper.buildZuulRequestHeaders(request);
MultivaluedMap<String, String> params = helper
.buildZuulRequestQueryParams(request);
String verb = getVerb(request);
InputStream requestEntity = getRequestBody(request);
HttpClient httpclient = CLIENT.get();
......@@ -229,7 +244,8 @@ public class SimpleHostRoutingFilter extends BaseProxyFilter {
MultivaluedMap<String, String> params, InputStream requestEntity)
throws Exception {
Map<String, Object> info = debug(verb, uri, headers, params, requestEntity);
Map<String, Object> info = helper
.debug(verb, uri, headers, params, requestEntity);
URL host = RequestContext.getCurrentContext().getRouteHost();
HttpHost httpHost = getHttpHost(host);
......@@ -237,7 +253,7 @@ public class SimpleHostRoutingFilter extends BaseProxyFilter {
HttpRequest httpRequest;
switch (verb) {
switch (verb.toUpperCase()) {
case "POST":
HttpPost httpPost = new HttpPost(uri + getQueryString());
httpRequest = httpPost;
......@@ -260,6 +276,8 @@ public class SimpleHostRoutingFilter extends BaseProxyFilter {
LOG.debug(httpHost.getHostName() + " " + httpHost.getPort() + " "
+ httpHost.getSchemeName());
HttpResponse zuulResponse = forwardRequest(httpclient, httpHost, httpRequest);
helper.appendDebug(info, zuulResponse.getStatusLine().getStatusCode(),
revertHeaders(zuulResponse.getAllHeaders()));
return zuulResponse;
}
finally {
......@@ -271,6 +289,18 @@ public class SimpleHostRoutingFilter extends BaseProxyFilter {
}
private Map<String, Collection<String>> revertHeaders(Header[] headers) {
Map<String, Collection<String>> map = new LinkedHashMap<String, Collection<String>>();
for (Header header : headers) {
String name = header.getName();
if (!map.containsKey(name)) {
map.put(name, new ArrayList<String>());
}
map.get(name).add(header.getValue());
}
return map;
}
private Header[] convertHeaders(MultivaluedMap<String, String> headers) {
List<Header> list = new ArrayList<>();
for (String name : headers.keySet()) {
......@@ -286,13 +316,13 @@ public class SimpleHostRoutingFilter extends BaseProxyFilter {
return httpclient.execute(httpHost, httpRequest);
}
String getQueryString() {
private String getQueryString() {
HttpServletRequest request = RequestContext.getCurrentContext().getRequest();
String query = request.getQueryString();
return (query != null) ? "?" + query : "";
}
HttpHost getHttpHost(URL host) {
private HttpHost getHttpHost(URL host) {
HttpHost httpHost = new HttpHost(host.getHost(), host.getPort(),
host.getProtocol());
return httpHost;
......@@ -309,91 +339,15 @@ public class SimpleHostRoutingFilter extends BaseProxyFilter {
return requestEntity;
}
private boolean isIncludedHeader(String name) {
if (name.toLowerCase().contains("content-length"))
return false;
if (!RequestContext.getCurrentContext().getResponseGZipped()) {
if (name.toLowerCase().contains("accept-encoding"))
return false;
}
return true;
}
private String getVerb(HttpServletRequest request) {
String sMethod = request.getMethod();
return sMethod.toUpperCase();
}
private void setResponse(HttpResponse response) throws IOException {
RequestContext context = RequestContext.getCurrentContext();
RequestContext.getCurrentContext().set("hostZuulResponse", response);
RequestContext.getCurrentContext().setResponseStatusCode(
response.getStatusLine().getStatusCode());
if (response.getEntity() != null) {
RequestContext.getCurrentContext().setResponseDataStream(
response.getEntity().getContent());
}
boolean isOriginResponseGzipped = false;
for (Header h : response.getHeaders(CONTENT_ENCODING)) {
if (HTTPRequestUtils.getInstance().isGzipped(h.getValue())) {
isOriginResponseGzipped = true;
break;
}
}
context.setResponseGZipped(isOriginResponseGzipped);
if (Debug.debugRequest()) {
for (Header header : response.getAllHeaders()) {
if (isValidHeader(header)) {
RequestContext.getCurrentContext().addZuulResponseHeader(
header.getName(), header.getValue());
Debug.addRequestDebug("ORIGIN_RESPONSE:: < " + header.getName() + ","
+ header.getValue());
}
}
if (context.getResponseDataStream() != null) {
byte[] origBytes = IOUtils.toByteArray(context.getResponseDataStream());
ByteArrayInputStream byteStream = new ByteArrayInputStream(origBytes);
InputStream inputStream = byteStream;
if (RequestContext.getCurrentContext().getResponseGZipped()) {
inputStream = new GZIPInputStream(byteStream);
}
context.setResponseDataStream(new ByteArrayInputStream(origBytes));
}
}
else {
for (Header header : response.getAllHeaders()) {
RequestContext ctx = RequestContext.getCurrentContext();
ctx.addOriginResponseHeader(header.getName(), header.getValue());
if (header.getName().equalsIgnoreCase("content-length"))
ctx.setOriginContentLength(header.getValue());
if (isValidHeader(header)) {
ctx.addZuulResponseHeader(header.getName(), header.getValue());
}
}
}
}
boolean isValidHeader(Header header) {
switch (header.getName().toLowerCase()) {
case "connection":
case "content-length":
case "content-encoding":
case "server":
case "transfer-encoding":
return false;
default:
return true;
}
helper.setResponse(response.getStatusLine().getStatusCode(),
response.getEntity() == null ? null : response.getEntity().getContent(),
revertHeaders(response.getAllHeaders()));
}
public static class MySSLSocketFactory extends SSLSocketFactory {
......
package org.springframework.cloud.netflix.zuul;
import java.util.Arrays;
import org.springframework.boot.SpringApplication;
import org.springframework.boot.autoconfigure.SpringBootApplication;
import org.springframework.boot.autoconfigure.EnableAutoConfiguration;
import org.springframework.cloud.client.discovery.EnableDiscoveryClient;
import org.springframework.cloud.netflix.ribbon.RibbonClient;
import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.Configuration;
import org.springframework.web.bind.annotation.RequestMapping;
import org.springframework.web.bind.annotation.RequestMethod;
import org.springframework.web.bind.annotation.RestController;
import com.netflix.appinfo.EurekaInstanceConfig;
import com.netflix.loadbalancer.BaseLoadBalancer;
import com.netflix.loadbalancer.ILoadBalancer;
import com.netflix.loadbalancer.Server;
import com.netflix.zuul.ZuulFilter;
@SpringBootApplication
// Don't use @SpringBootApplication because we don't want to component scan
@Configuration
@EnableAutoConfiguration
@RestController
@EnableZuulProxy
@EnableDiscoveryClient
@RibbonClient(name = "simple", configuration = SimpleRibbonClientConfiguration.class)
public class SampleZuulProxyApplication {
@RequestMapping("/testing123")
public String testing123() {
throw new RuntimeException("myerror");
}
@RequestMapping("/testing123")
public String testing123() {
throw new RuntimeException("myerror");
}
@RequestMapping("/local")
public String local() {
return "Hello local";
}
@RequestMapping("/local")
public String local() {
return "Hello local";
}
@RequestMapping(value="/local/{id}", method=RequestMethod.DELETE)
public String delete() {
return "Deleted!";
}
@RequestMapping(value = "/local/{id}", method = RequestMethod.DELETE)
public String delete() {
return "Deleted!";
}
@RequestMapping(value="/local/{id}", method=RequestMethod.GET)
public String get() {
return "Gotten!";
}
@RequestMapping(value = "/local/{id}", method = RequestMethod.GET)
public String get() {
return "Gotten!";
}
@RequestMapping("/")
public String home() {
return "Hello world";
}
@Bean
public ZuulFilter sampleFilter() {
return new ZuulFilter() {
......@@ -48,23 +59,38 @@ public class SampleZuulProxyApplication {
public String filterType() {
return "pre";
}
@Override
public boolean shouldFilter() {
return true;
}
@Override
public Object run() {
return null;
}
@Override
public int filterOrder() {
return 0;
}
};
}
public static void main(String[] args) {
SpringApplication.run(SampleZuulProxyApplication.class, args);
}
}
// Load balancer with fixed server list for "simple" pointing to localhost
@Configuration
class SimpleRibbonClientConfiguration {
@Bean
public ILoadBalancer ribbonLoadBalancer(EurekaInstanceConfig instance) {
BaseLoadBalancer balancer = new BaseLoadBalancer();
balancer.setServersList(Arrays.asList(new Server("localhost", instance
.getNonSecurePort())));
return balancer;
}
}
\ No newline at end of file
......@@ -46,6 +46,16 @@ public class SampleZuulProxyApplicationTests {
}
@Test
public void getOnSelfViaRibbonRoutingFilter() {
mapping.reset();
ResponseEntity<String> result = new TestRestTemplate().exchange(
"http://localhost:" + port + "/simple/local/1", HttpMethod.GET,
new HttpEntity<Void>((Void) null), String.class);
assertEquals(HttpStatus.OK, result.getStatusCode());
assertEquals("Gotten!", result.getBody());
}
@Test
public void deleteOnSelfViaSimpleHostRoutingFilter() {
routes.addRoute("/self/**", "http://localhost:" + port + "/local");
mapping.reset();
......
server:
port: 9999
logging:
level:
org.springframework.web: DEBUG
spring:
application:
name: testclient
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment