Commit 2a8db73d by Dave Syer

Extract header manipulation into a helper and add test

Ribbon and simple host reverse proxy now share quite a lot of code (everything to do with headers). Fixes gh-103
parent 73e7adbe
...@@ -15,6 +15,7 @@ import org.springframework.cloud.netflix.zuul.filters.post.SendResponseFilter; ...@@ -15,6 +15,7 @@ import org.springframework.cloud.netflix.zuul.filters.post.SendResponseFilter;
import org.springframework.cloud.netflix.zuul.filters.pre.DebugFilter; import org.springframework.cloud.netflix.zuul.filters.pre.DebugFilter;
import org.springframework.cloud.netflix.zuul.filters.pre.PreDecorationFilter; import org.springframework.cloud.netflix.zuul.filters.pre.PreDecorationFilter;
import org.springframework.cloud.netflix.zuul.filters.pre.Servlet30WrapperFilter; import org.springframework.cloud.netflix.zuul.filters.pre.Servlet30WrapperFilter;
import org.springframework.cloud.netflix.zuul.filters.route.ProxyRequestHelper;
import org.springframework.cloud.netflix.zuul.filters.route.RibbonRoutingFilter; import org.springframework.cloud.netflix.zuul.filters.route.RibbonRoutingFilter;
import org.springframework.cloud.netflix.zuul.filters.route.SimpleHostRoutingFilter; import org.springframework.cloud.netflix.zuul.filters.route.SimpleHostRoutingFilter;
import org.springframework.context.annotation.Bean; import org.springframework.context.annotation.Bean;
...@@ -97,16 +98,21 @@ public class ZuulConfiguration { ...@@ -97,16 +98,21 @@ public class ZuulConfiguration {
// route filters // route filters
@Bean @Bean
public RibbonRoutingFilter ribbonRoutingFilter() { public RibbonRoutingFilter ribbonRoutingFilter() {
RibbonRoutingFilter filter = new RibbonRoutingFilter(clientFactory); ProxyRequestHelper helper = new ProxyRequestHelper();
if (traces != null) { if (traces != null) {
filter.setTraces(traces); helper.setTraces(traces);
} }
RibbonRoutingFilter filter = new RibbonRoutingFilter(helper , clientFactory);
return filter; return filter;
} }
@Bean @Bean
public SimpleHostRoutingFilter simpleHostRoutingFilter() { public SimpleHostRoutingFilter simpleHostRoutingFilter() {
return new SimpleHostRoutingFilter(); ProxyRequestHelper helper = new ProxyRequestHelper();
if (traces != null) {
helper.setTraces(traces);
}
return new SimpleHostRoutingFilter(helper);
} }
// post filters // post filters
......
...@@ -31,7 +31,6 @@ import org.apache.commons.io.IOUtils; ...@@ -31,7 +31,6 @@ import org.apache.commons.io.IOUtils;
import org.springframework.boot.actuate.trace.TraceRepository; import org.springframework.boot.actuate.trace.TraceRepository;
import org.springframework.util.StringUtils; import org.springframework.util.StringUtils;
import com.netflix.zuul.ZuulFilter;
import com.netflix.zuul.context.RequestContext; import com.netflix.zuul.context.RequestContext;
import com.netflix.zuul.util.HTTPRequestUtils; import com.netflix.zuul.util.HTTPRequestUtils;
import com.sun.jersey.core.util.MultivaluedMapImpl; import com.sun.jersey.core.util.MultivaluedMapImpl;
...@@ -40,7 +39,7 @@ import com.sun.jersey.core.util.MultivaluedMapImpl; ...@@ -40,7 +39,7 @@ import com.sun.jersey.core.util.MultivaluedMapImpl;
* @author Dave Syer * @author Dave Syer
* *
*/ */
public abstract class BaseProxyFilter extends ZuulFilter { public class ProxyRequestHelper {
public static final String CONTENT_ENCODING = "Content-Encoding"; public static final String CONTENT_ENCODING = "Content-Encoding";
...@@ -50,7 +49,7 @@ public abstract class BaseProxyFilter extends ZuulFilter { ...@@ -50,7 +49,7 @@ public abstract class BaseProxyFilter extends ZuulFilter {
this.traces = traces; this.traces = traces;
} }
protected MultivaluedMap<String, String> buildZuulRequestQueryParams( public MultivaluedMap<String, String> buildZuulRequestQueryParams(
HttpServletRequest request) { HttpServletRequest request) {
Map<String, List<String>> map = HTTPRequestUtils.getInstance().getQueryParams(); Map<String, List<String>> map = HTTPRequestUtils.getInstance().getQueryParams();
...@@ -68,7 +67,7 @@ public abstract class BaseProxyFilter extends ZuulFilter { ...@@ -68,7 +67,7 @@ public abstract class BaseProxyFilter extends ZuulFilter {
return params; return params;
} }
protected MultivaluedMap<String, String> buildZuulRequestHeaders( public MultivaluedMap<String, String> buildZuulRequestHeaders(
HttpServletRequest request) { HttpServletRequest request) {
RequestContext context = RequestContext.getCurrentContext(); RequestContext context = RequestContext.getCurrentContext();
...@@ -94,7 +93,46 @@ public abstract class BaseProxyFilter extends ZuulFilter { ...@@ -94,7 +93,46 @@ public abstract class BaseProxyFilter extends ZuulFilter {
return headers; return headers;
} }
private boolean isIncludedHeader(String headerName) { public void setResponse(int status, InputStream entity,
Map<String, Collection<String>> headers) throws IOException {
RequestContext context = RequestContext.getCurrentContext();
RequestContext.getCurrentContext().setResponseStatusCode(status);
if (entity != null) {
RequestContext.getCurrentContext().setResponseDataStream(entity);
}
boolean isOriginResponseGzipped = false;
if (headers.containsKey(CONTENT_ENCODING)) {
Collection<String> collection = headers.get(CONTENT_ENCODING);
for (String header : collection) {
if (HTTPRequestUtils.getInstance().isGzipped(header)) {
isOriginResponseGzipped = true;
break;
}
}
}
context.setResponseGZipped(isOriginResponseGzipped);
for (Entry<String, Collection<String>> header : headers.entrySet()) {
RequestContext ctx = RequestContext.getCurrentContext();
String name = header.getKey();
for (String value : header.getValue()) {
ctx.addOriginResponseHeader(name, value);
if (name.equalsIgnoreCase("content-length"))
ctx.setOriginContentLength(value);
if (isIncludedHeader(name)) {
ctx.addZuulResponseHeader(name, value);
}
}
}
}
public boolean isIncludedHeader(String headerName) {
switch (headerName.toLowerCase()) { switch (headerName.toLowerCase()) {
case "host": case "host":
case "connection": case "connection":
...@@ -108,7 +146,7 @@ public abstract class BaseProxyFilter extends ZuulFilter { ...@@ -108,7 +146,7 @@ public abstract class BaseProxyFilter extends ZuulFilter {
} }
} }
protected Map<String, Object> debug(String verb, String uri, public Map<String, Object> debug(String verb, String uri,
MultivaluedMap<String, String> headers, MultivaluedMap<String, String> headers,
MultivaluedMap<String, String> params, InputStream requestEntity) MultivaluedMap<String, String> params, InputStream requestEntity)
throws IOException { throws IOException {
...@@ -156,7 +194,7 @@ public abstract class BaseProxyFilter extends ZuulFilter { ...@@ -156,7 +194,7 @@ public abstract class BaseProxyFilter extends ZuulFilter {
return info; return info;
} }
protected void appendDebug(Map<String, Object> info, int status, public void appendDebug(Map<String, Object> info, int status,
Map<String, Collection<String>> headers) { Map<String, Collection<String>> headers) {
if (traces != null) { if (traces != null) {
@SuppressWarnings("unchecked") @SuppressWarnings("unchecked")
...@@ -175,7 +213,7 @@ public abstract class BaseProxyFilter extends ZuulFilter { ...@@ -175,7 +213,7 @@ public abstract class BaseProxyFilter extends ZuulFilter {
} }
} }
protected void appendDebug(Map<String, Object> info, int status, public void appendDebug(Map<String, Object> info, int status,
MultivaluedMap<String, String> headers) { MultivaluedMap<String, String> headers) {
if (traces != null) { if (traces != null) {
Map<String, Collection<String>> map = new LinkedHashMap<String, Collection<String>>(); Map<String, Collection<String>> map = new LinkedHashMap<String, Collection<String>>();
......
...@@ -18,11 +18,12 @@ import com.netflix.client.http.HttpRequest.Verb; ...@@ -18,11 +18,12 @@ import com.netflix.client.http.HttpRequest.Verb;
import com.netflix.client.http.HttpResponse; import com.netflix.client.http.HttpResponse;
import com.netflix.hystrix.exception.HystrixRuntimeException; import com.netflix.hystrix.exception.HystrixRuntimeException;
import com.netflix.niws.client.http.RestClient; import com.netflix.niws.client.http.RestClient;
import com.netflix.zuul.ZuulFilter;
import com.netflix.zuul.context.RequestContext; import com.netflix.zuul.context.RequestContext;
import com.netflix.zuul.exception.ZuulException; import com.netflix.zuul.exception.ZuulException;
import com.netflix.zuul.util.HTTPRequestUtils; import com.netflix.zuul.util.HTTPRequestUtils;
public class RibbonRoutingFilter extends BaseProxyFilter { public class RibbonRoutingFilter extends ZuulFilter {
private static final Logger LOG = LoggerFactory.getLogger(RibbonRoutingFilter.class); private static final Logger LOG = LoggerFactory.getLogger(RibbonRoutingFilter.class);
...@@ -30,10 +31,18 @@ public class RibbonRoutingFilter extends BaseProxyFilter { ...@@ -30,10 +31,18 @@ public class RibbonRoutingFilter extends BaseProxyFilter {
private SpringClientFactory clientFactory; private SpringClientFactory clientFactory;
public RibbonRoutingFilter(SpringClientFactory clientFactory) { private ProxyRequestHelper helper;
public RibbonRoutingFilter(ProxyRequestHelper helper,
SpringClientFactory clientFactory) {
this.helper = helper;
this.clientFactory = clientFactory; this.clientFactory = clientFactory;
} }
public RibbonRoutingFilter(SpringClientFactory clientFactory) {
this(new ProxyRequestHelper(), clientFactory);
}
@Override @Override
public String filterType() { public String filterType() {
return "route"; return "route";
...@@ -54,8 +63,9 @@ public class RibbonRoutingFilter extends BaseProxyFilter { ...@@ -54,8 +63,9 @@ public class RibbonRoutingFilter extends BaseProxyFilter {
RequestContext context = RequestContext.getCurrentContext(); RequestContext context = RequestContext.getCurrentContext();
HttpServletRequest request = context.getRequest(); HttpServletRequest request = context.getRequest();
MultivaluedMap<String, String> headers = buildZuulRequestHeaders(request); MultivaluedMap<String, String> headers = helper.buildZuulRequestHeaders(request);
MultivaluedMap<String, String> params = buildZuulRequestQueryParams(request); MultivaluedMap<String, String> params = helper
.buildZuulRequestQueryParams(request);
Verb verb = getVerb(request); Verb verb = getVerb(request);
InputStream requestEntity = getRequestBody(request); InputStream requestEntity = getRequestBody(request);
...@@ -88,13 +98,14 @@ public class RibbonRoutingFilter extends BaseProxyFilter { ...@@ -88,13 +98,14 @@ public class RibbonRoutingFilter extends BaseProxyFilter {
MultivaluedMap<String, String> params, InputStream requestEntity) MultivaluedMap<String, String> params, InputStream requestEntity)
throws Exception { throws Exception {
Map<String, Object> info = debug(verb.verb(), uri, headers, params, requestEntity); Map<String, Object> info = helper.debug(verb.verb(), uri, headers, params,
requestEntity);
RibbonCommand command = new RibbonCommand(restClient, verb, uri, headers, params, RibbonCommand command = new RibbonCommand(restClient, verb, uri, headers, params,
requestEntity); requestEntity);
try { try {
HttpResponse response = command.execute(); HttpResponse response = command.execute();
appendDebug(info, response.getStatus(), response.getHeaders()); helper.appendDebug(info, response.getStatus(), response.getHeaders());
return response; return response;
} }
catch (HystrixRuntimeException e) { catch (HystrixRuntimeException e) {
...@@ -155,57 +166,9 @@ public class RibbonRoutingFilter extends BaseProxyFilter { ...@@ -155,57 +166,9 @@ public class RibbonRoutingFilter extends BaseProxyFilter {
return Verb.GET; return Verb.GET;
} }
void setResponse(HttpResponse resp) throws ClientException, IOException { private void setResponse(HttpResponse resp) throws ClientException, IOException {
RequestContext context = RequestContext.getCurrentContext(); helper.setResponse(resp.getStatus(),
!resp.hasEntity() ? null : resp.getInputStream(), resp.getHeaders());
context.setResponseStatusCode(resp.getStatus());
if (resp.hasEntity()) {
context.setResponseDataStream(resp.getInputStream());
}
String contentEncoding = null;
Collection<String> contentEncodingHeader = resp.getHeaders()
.get(CONTENT_ENCODING);
if (contentEncodingHeader != null && !contentEncodingHeader.isEmpty()) {
contentEncoding = contentEncodingHeader.iterator().next();
}
if (contentEncoding != null
&& HTTPRequestUtils.getInstance().isGzipped(contentEncoding)) {
context.setResponseGZipped(true);
}
else {
context.setResponseGZipped(false);
}
for (String key : resp.getHeaders().keySet()) {
boolean isValidHeader = isIncludedHeader(key);
Collection<java.lang.String> list = resp.getHeaders().get(key);
for (String header : list) {
context.addOriginResponseHeader(key, header);
if (key.equalsIgnoreCase("content-length"))
context.setOriginContentLength(header);
if (isValidHeader) {
context.addZuulResponseHeader(key, header);
}
}
}
}
private boolean isIncludedHeader(String headerName) {
switch (headerName.toLowerCase()) {
case "connection":
case "content-length":
case "content-encoding":
case "server":
case "transfer-encoding":
return false;
default:
return true;
}
} }
} }
...@@ -14,6 +14,8 @@ import java.security.UnrecoverableKeyException; ...@@ -14,6 +14,8 @@ import java.security.UnrecoverableKeyException;
import java.security.cert.CertificateException; import java.security.cert.CertificateException;
import java.security.cert.X509Certificate; import java.security.cert.X509Certificate;
import java.util.ArrayList; import java.util.ArrayList;
import java.util.Collection;
import java.util.LinkedHashMap;
import java.util.List; import java.util.List;
import java.util.Map; import java.util.Map;
import java.util.Timer; import java.util.Timer;
...@@ -57,17 +59,19 @@ import org.springframework.util.StringUtils; ...@@ -57,17 +59,19 @@ import org.springframework.util.StringUtils;
import com.netflix.config.DynamicIntProperty; import com.netflix.config.DynamicIntProperty;
import com.netflix.config.DynamicPropertyFactory; import com.netflix.config.DynamicPropertyFactory;
import com.netflix.zuul.ZuulFilter;
import com.netflix.zuul.constants.ZuulConstants; import com.netflix.zuul.constants.ZuulConstants;
import com.netflix.zuul.context.Debug; import com.netflix.zuul.context.Debug;
import com.netflix.zuul.context.RequestContext; import com.netflix.zuul.context.RequestContext;
import com.netflix.zuul.util.HTTPRequestUtils; import com.netflix.zuul.util.HTTPRequestUtils;
public class SimpleHostRoutingFilter extends BaseProxyFilter { public class SimpleHostRoutingFilter extends ZuulFilter {
public static final String CONTENT_ENCODING = "Content-Encoding"; public static final String CONTENT_ENCODING = "Content-Encoding";
private static final Logger LOG = LoggerFactory private static final Logger LOG = LoggerFactory
.getLogger(SimpleHostRoutingFilter.class); .getLogger(SimpleHostRoutingFilter.class);
private static final Runnable CLIENTLOADER = new Runnable() { private static final Runnable CLIENTLOADER = new Runnable() {
@Override @Override
public void run() { public void run() {
...@@ -128,6 +132,16 @@ public class SimpleHostRoutingFilter extends BaseProxyFilter { ...@@ -128,6 +132,16 @@ public class SimpleHostRoutingFilter extends BaseProxyFilter {
return cm; return cm;
} }
private ProxyRequestHelper helper;
public SimpleHostRoutingFilter() {
this(new ProxyRequestHelper());
}
public SimpleHostRoutingFilter(ProxyRequestHelper helper) {
this.helper = helper;
}
@Override @Override
public String filterType() { public String filterType() {
return "route"; return "route";
...@@ -201,8 +215,9 @@ public class SimpleHostRoutingFilter extends BaseProxyFilter { ...@@ -201,8 +215,9 @@ public class SimpleHostRoutingFilter extends BaseProxyFilter {
public Object run() { public Object run() {
RequestContext context = RequestContext.getCurrentContext(); RequestContext context = RequestContext.getCurrentContext();
HttpServletRequest request = context.getRequest(); HttpServletRequest request = context.getRequest();
MultivaluedMap<String, String> headers = buildZuulRequestHeaders(request); MultivaluedMap<String, String> headers = helper.buildZuulRequestHeaders(request);
MultivaluedMap<String, String> params = buildZuulRequestQueryParams(request); MultivaluedMap<String, String> params = helper
.buildZuulRequestQueryParams(request);
String verb = getVerb(request); String verb = getVerb(request);
InputStream requestEntity = getRequestBody(request); InputStream requestEntity = getRequestBody(request);
HttpClient httpclient = CLIENT.get(); HttpClient httpclient = CLIENT.get();
...@@ -229,7 +244,8 @@ public class SimpleHostRoutingFilter extends BaseProxyFilter { ...@@ -229,7 +244,8 @@ public class SimpleHostRoutingFilter extends BaseProxyFilter {
MultivaluedMap<String, String> params, InputStream requestEntity) MultivaluedMap<String, String> params, InputStream requestEntity)
throws Exception { throws Exception {
Map<String, Object> info = debug(verb, uri, headers, params, requestEntity); Map<String, Object> info = helper
.debug(verb, uri, headers, params, requestEntity);
URL host = RequestContext.getCurrentContext().getRouteHost(); URL host = RequestContext.getCurrentContext().getRouteHost();
HttpHost httpHost = getHttpHost(host); HttpHost httpHost = getHttpHost(host);
...@@ -237,7 +253,7 @@ public class SimpleHostRoutingFilter extends BaseProxyFilter { ...@@ -237,7 +253,7 @@ public class SimpleHostRoutingFilter extends BaseProxyFilter {
HttpRequest httpRequest; HttpRequest httpRequest;
switch (verb) { switch (verb.toUpperCase()) {
case "POST": case "POST":
HttpPost httpPost = new HttpPost(uri + getQueryString()); HttpPost httpPost = new HttpPost(uri + getQueryString());
httpRequest = httpPost; httpRequest = httpPost;
...@@ -260,6 +276,8 @@ public class SimpleHostRoutingFilter extends BaseProxyFilter { ...@@ -260,6 +276,8 @@ public class SimpleHostRoutingFilter extends BaseProxyFilter {
LOG.debug(httpHost.getHostName() + " " + httpHost.getPort() + " " LOG.debug(httpHost.getHostName() + " " + httpHost.getPort() + " "
+ httpHost.getSchemeName()); + httpHost.getSchemeName());
HttpResponse zuulResponse = forwardRequest(httpclient, httpHost, httpRequest); HttpResponse zuulResponse = forwardRequest(httpclient, httpHost, httpRequest);
helper.appendDebug(info, zuulResponse.getStatusLine().getStatusCode(),
revertHeaders(zuulResponse.getAllHeaders()));
return zuulResponse; return zuulResponse;
} }
finally { finally {
...@@ -271,6 +289,18 @@ public class SimpleHostRoutingFilter extends BaseProxyFilter { ...@@ -271,6 +289,18 @@ public class SimpleHostRoutingFilter extends BaseProxyFilter {
} }
private Map<String, Collection<String>> revertHeaders(Header[] headers) {
Map<String, Collection<String>> map = new LinkedHashMap<String, Collection<String>>();
for (Header header : headers) {
String name = header.getName();
if (!map.containsKey(name)) {
map.put(name, new ArrayList<String>());
}
map.get(name).add(header.getValue());
}
return map;
}
private Header[] convertHeaders(MultivaluedMap<String, String> headers) { private Header[] convertHeaders(MultivaluedMap<String, String> headers) {
List<Header> list = new ArrayList<>(); List<Header> list = new ArrayList<>();
for (String name : headers.keySet()) { for (String name : headers.keySet()) {
...@@ -286,13 +316,13 @@ public class SimpleHostRoutingFilter extends BaseProxyFilter { ...@@ -286,13 +316,13 @@ public class SimpleHostRoutingFilter extends BaseProxyFilter {
return httpclient.execute(httpHost, httpRequest); return httpclient.execute(httpHost, httpRequest);
} }
String getQueryString() { private String getQueryString() {
HttpServletRequest request = RequestContext.getCurrentContext().getRequest(); HttpServletRequest request = RequestContext.getCurrentContext().getRequest();
String query = request.getQueryString(); String query = request.getQueryString();
return (query != null) ? "?" + query : ""; return (query != null) ? "?" + query : "";
} }
HttpHost getHttpHost(URL host) { private HttpHost getHttpHost(URL host) {
HttpHost httpHost = new HttpHost(host.getHost(), host.getPort(), HttpHost httpHost = new HttpHost(host.getHost(), host.getPort(),
host.getProtocol()); host.getProtocol());
return httpHost; return httpHost;
...@@ -309,91 +339,15 @@ public class SimpleHostRoutingFilter extends BaseProxyFilter { ...@@ -309,91 +339,15 @@ public class SimpleHostRoutingFilter extends BaseProxyFilter {
return requestEntity; return requestEntity;
} }
private boolean isIncludedHeader(String name) {
if (name.toLowerCase().contains("content-length"))
return false;
if (!RequestContext.getCurrentContext().getResponseGZipped()) {
if (name.toLowerCase().contains("accept-encoding"))
return false;
}
return true;
}
private String getVerb(HttpServletRequest request) { private String getVerb(HttpServletRequest request) {
String sMethod = request.getMethod(); String sMethod = request.getMethod();
return sMethod.toUpperCase(); return sMethod.toUpperCase();
} }
private void setResponse(HttpResponse response) throws IOException { private void setResponse(HttpResponse response) throws IOException {
RequestContext context = RequestContext.getCurrentContext(); helper.setResponse(response.getStatusLine().getStatusCode(),
response.getEntity() == null ? null : response.getEntity().getContent(),
RequestContext.getCurrentContext().set("hostZuulResponse", response); revertHeaders(response.getAllHeaders()));
RequestContext.getCurrentContext().setResponseStatusCode(
response.getStatusLine().getStatusCode());
if (response.getEntity() != null) {
RequestContext.getCurrentContext().setResponseDataStream(
response.getEntity().getContent());
}
boolean isOriginResponseGzipped = false;
for (Header h : response.getHeaders(CONTENT_ENCODING)) {
if (HTTPRequestUtils.getInstance().isGzipped(h.getValue())) {
isOriginResponseGzipped = true;
break;
}
}
context.setResponseGZipped(isOriginResponseGzipped);
if (Debug.debugRequest()) {
for (Header header : response.getAllHeaders()) {
if (isValidHeader(header)) {
RequestContext.getCurrentContext().addZuulResponseHeader(
header.getName(), header.getValue());
Debug.addRequestDebug("ORIGIN_RESPONSE:: < " + header.getName() + ","
+ header.getValue());
}
}
if (context.getResponseDataStream() != null) {
byte[] origBytes = IOUtils.toByteArray(context.getResponseDataStream());
ByteArrayInputStream byteStream = new ByteArrayInputStream(origBytes);
InputStream inputStream = byteStream;
if (RequestContext.getCurrentContext().getResponseGZipped()) {
inputStream = new GZIPInputStream(byteStream);
}
context.setResponseDataStream(new ByteArrayInputStream(origBytes));
}
}
else {
for (Header header : response.getAllHeaders()) {
RequestContext ctx = RequestContext.getCurrentContext();
ctx.addOriginResponseHeader(header.getName(), header.getValue());
if (header.getName().equalsIgnoreCase("content-length"))
ctx.setOriginContentLength(header.getValue());
if (isValidHeader(header)) {
ctx.addZuulResponseHeader(header.getName(), header.getValue());
}
}
}
}
boolean isValidHeader(Header header) {
switch (header.getName().toLowerCase()) {
case "connection":
case "content-length":
case "content-encoding":
case "server":
case "transfer-encoding":
return false;
default:
return true;
}
} }
public static class MySSLSocketFactory extends SSLSocketFactory { public static class MySSLSocketFactory extends SSLSocketFactory {
......
package org.springframework.cloud.netflix.zuul; package org.springframework.cloud.netflix.zuul;
import java.util.Arrays;
import org.springframework.boot.SpringApplication; import org.springframework.boot.SpringApplication;
import org.springframework.boot.autoconfigure.SpringBootApplication; import org.springframework.boot.autoconfigure.EnableAutoConfiguration;
import org.springframework.cloud.client.discovery.EnableDiscoveryClient; import org.springframework.cloud.client.discovery.EnableDiscoveryClient;
import org.springframework.cloud.netflix.ribbon.RibbonClient;
import org.springframework.context.annotation.Bean; import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.Configuration;
import org.springframework.web.bind.annotation.RequestMapping; import org.springframework.web.bind.annotation.RequestMapping;
import org.springframework.web.bind.annotation.RequestMethod; import org.springframework.web.bind.annotation.RequestMethod;
import org.springframework.web.bind.annotation.RestController; import org.springframework.web.bind.annotation.RestController;
import com.netflix.appinfo.EurekaInstanceConfig;
import com.netflix.loadbalancer.BaseLoadBalancer;
import com.netflix.loadbalancer.ILoadBalancer;
import com.netflix.loadbalancer.Server;
import com.netflix.zuul.ZuulFilter; import com.netflix.zuul.ZuulFilter;
@SpringBootApplication // Don't use @SpringBootApplication because we don't want to component scan
@Configuration
@EnableAutoConfiguration
@RestController @RestController
@EnableZuulProxy @EnableZuulProxy
@EnableDiscoveryClient @EnableDiscoveryClient
@RibbonClient(name = "simple", configuration = SimpleRibbonClientConfiguration.class)
public class SampleZuulProxyApplication { public class SampleZuulProxyApplication {
@RequestMapping("/testing123") @RequestMapping("/testing123")
public String testing123() { public String testing123() {
throw new RuntimeException("myerror"); throw new RuntimeException("myerror");
} }
@RequestMapping("/local") @RequestMapping("/local")
public String local() { public String local() {
return "Hello local"; return "Hello local";
} }
@RequestMapping(value="/local/{id}", method=RequestMethod.DELETE) @RequestMapping(value = "/local/{id}", method = RequestMethod.DELETE)
public String delete() { public String delete() {
return "Deleted!"; return "Deleted!";
} }
@RequestMapping(value="/local/{id}", method=RequestMethod.GET) @RequestMapping(value = "/local/{id}", method = RequestMethod.GET)
public String get() { public String get() {
return "Gotten!"; return "Gotten!";
} }
@RequestMapping("/") @RequestMapping("/")
public String home() { public String home() {
return "Hello world"; return "Hello world";
} }
@Bean @Bean
public ZuulFilter sampleFilter() { public ZuulFilter sampleFilter() {
return new ZuulFilter() { return new ZuulFilter() {
...@@ -48,23 +59,38 @@ public class SampleZuulProxyApplication { ...@@ -48,23 +59,38 @@ public class SampleZuulProxyApplication {
public String filterType() { public String filterType() {
return "pre"; return "pre";
} }
@Override @Override
public boolean shouldFilter() { public boolean shouldFilter() {
return true; return true;
} }
@Override @Override
public Object run() { public Object run() {
return null; return null;
} }
@Override @Override
public int filterOrder() { public int filterOrder() {
return 0; return 0;
} }
}; };
} }
public static void main(String[] args) { public static void main(String[] args) {
SpringApplication.run(SampleZuulProxyApplication.class, args); SpringApplication.run(SampleZuulProxyApplication.class, args);
} }
} }
// Load balancer with fixed server list for "simple" pointing to localhost
@Configuration
class SimpleRibbonClientConfiguration {
@Bean
public ILoadBalancer ribbonLoadBalancer(EurekaInstanceConfig instance) {
BaseLoadBalancer balancer = new BaseLoadBalancer();
balancer.setServersList(Arrays.asList(new Server("localhost", instance
.getNonSecurePort())));
return balancer;
}
}
\ No newline at end of file
...@@ -46,6 +46,16 @@ public class SampleZuulProxyApplicationTests { ...@@ -46,6 +46,16 @@ public class SampleZuulProxyApplicationTests {
} }
@Test @Test
public void getOnSelfViaRibbonRoutingFilter() {
mapping.reset();
ResponseEntity<String> result = new TestRestTemplate().exchange(
"http://localhost:" + port + "/simple/local/1", HttpMethod.GET,
new HttpEntity<Void>((Void) null), String.class);
assertEquals(HttpStatus.OK, result.getStatusCode());
assertEquals("Gotten!", result.getBody());
}
@Test
public void deleteOnSelfViaSimpleHostRoutingFilter() { public void deleteOnSelfViaSimpleHostRoutingFilter() {
routes.addRoute("/self/**", "http://localhost:" + port + "/local"); routes.addRoute("/self/**", "http://localhost:" + port + "/local");
mapping.reset(); mapping.reset();
......
server: server:
port: 9999 port: 9999
logging:
level:
org.springframework.web: DEBUG
spring: spring:
application: application:
name: testclient name: testclient
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment