Commit 73e7adbe by Dave Syer

Extract base class with params and headers for proxy filters

Fixes gh-103
parent 096a8fa7
sudo: false
cache:
directories:
- $HOME/.m2
language: java
before_install:
- git config user.name "$GIT_NAME"
......
/*
* Copyright 2013-2014 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.cloud.netflix.zuul.filters.route;
import java.io.IOException;
import java.io.InputStream;
import java.util.Collection;
import java.util.Enumeration;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.Map.Entry;
import javax.servlet.http.HttpServletRequest;
import javax.ws.rs.core.MultivaluedMap;
import org.apache.commons.io.IOUtils;
import org.springframework.boot.actuate.trace.TraceRepository;
import org.springframework.util.StringUtils;
import com.netflix.zuul.ZuulFilter;
import com.netflix.zuul.context.RequestContext;
import com.netflix.zuul.util.HTTPRequestUtils;
import com.sun.jersey.core.util.MultivaluedMapImpl;
/**
* @author Dave Syer
*
*/
public abstract class BaseProxyFilter extends ZuulFilter {
public static final String CONTENT_ENCODING = "Content-Encoding";
private TraceRepository traces;
public void setTraces(TraceRepository traces) {
this.traces = traces;
}
protected MultivaluedMap<String, String> buildZuulRequestQueryParams(
HttpServletRequest request) {
Map<String, List<String>> map = HTTPRequestUtils.getInstance().getQueryParams();
MultivaluedMap<String, String> params = new MultivaluedMapImpl();
if (map == null)
return params;
for (String key : map.keySet()) {
for (String value : map.get(key)) {
params.add(key, value);
}
}
return params;
}
protected MultivaluedMap<String, String> buildZuulRequestHeaders(
HttpServletRequest request) {
RequestContext context = RequestContext.getCurrentContext();
MultivaluedMap<String, String> headers = new MultivaluedMapImpl();
Enumeration<?> headerNames = request.getHeaderNames();
if (headerNames != null) {
while (headerNames.hasMoreElements()) {
String name = (String) headerNames.nextElement();
String value = request.getHeader(name);
if (isIncludedHeader(name))
headers.putSingle(name, value);
}
}
Map<String, String> zuulRequestHeaders = context.getZuulRequestHeaders();
for (String header : zuulRequestHeaders.keySet()) {
headers.putSingle(header, zuulRequestHeaders.get(header));
}
headers.putSingle("accept-encoding", "deflate, gzip");
return headers;
}
private boolean isIncludedHeader(String headerName) {
switch (headerName.toLowerCase()) {
case "host":
case "connection":
case "content-length":
case "content-encoding":
case "server":
case "transfer-encoding":
return false;
default:
return true;
}
}
protected Map<String, Object> debug(String verb, String uri,
MultivaluedMap<String, String> headers,
MultivaluedMap<String, String> params, InputStream requestEntity)
throws IOException {
Map<String, Object> info = new LinkedHashMap<String, Object>();
if (traces != null) {
RequestContext context = RequestContext.getCurrentContext();
info.put("remote", true);
info.put("serviceId", context.get("serviceId"));
Map<String, Object> trace = new LinkedHashMap<String, Object>();
Map<String, Object> input = new LinkedHashMap<String, Object>();
trace.put("request", input);
info.put("headers", trace);
for (Entry<String, List<String>> entry : headers.entrySet()) {
Collection<String> collection = entry.getValue();
Object value = collection;
if (collection.size() < 2) {
value = collection.isEmpty() ? "" : collection.iterator().next();
}
input.put(entry.getKey(), value);
}
StringBuilder query = new StringBuilder();
for (String param : params.keySet()) {
for (String value : params.get(param)) {
query.append(param);
query.append("=");
query.append(value);
query.append("&");
}
}
info.put("method", verb);
info.put("uri", uri);
info.put("query", query.toString());
RequestContext ctx = RequestContext.getCurrentContext();
if (!ctx.isChunkedRequestBody()) {
if (requestEntity != null) {
debugRequestEntity(info, ctx.getRequest().getInputStream());
}
}
traces.add(info);
return info;
}
return info;
}
protected void appendDebug(Map<String, Object> info, int status,
Map<String, Collection<String>> headers) {
if (traces != null) {
@SuppressWarnings("unchecked")
Map<String, Object> trace = (Map<String, Object>) info.get("headers");
Map<String, Object> output = new LinkedHashMap<String, Object>();
trace.put("response", output);
info.put("status", "" + status);
for (Entry<String, Collection<String>> key : headers.entrySet()) {
Collection<String> collection = key.getValue();
Object value = collection;
if (collection.size() < 2) {
value = collection.isEmpty() ? "" : collection.iterator().next();
}
output.put(key.getKey(), value);
}
}
}
protected void appendDebug(Map<String, Object> info, int status,
MultivaluedMap<String, String> headers) {
if (traces != null) {
Map<String, Collection<String>> map = new LinkedHashMap<String, Collection<String>>();
for (Entry<String, List<String>> key : headers.entrySet()) {
Collection<String> collection = key.getValue();
map.put(key.getKey(), collection);
}
appendDebug(info, status, map);
}
}
private void debugRequestEntity(Map<String, Object> info, InputStream inputStream)
throws IOException {
String entity = IOUtils.toString(inputStream);
if (StringUtils.hasText(entity)) {
info.put("body", entity);
}
}
}
\ No newline at end of file
package org.springframework.cloud.netflix.zuul.filters.route;
import java.io.InputStream;
import java.net.URI;
import java.net.URISyntaxException;
import java.util.List;
import javax.ws.rs.core.MultivaluedMap;
import com.netflix.client.http.HttpRequest;
import com.netflix.client.http.HttpRequest.Builder;
import com.netflix.client.http.HttpRequest.Verb;
import com.netflix.client.http.HttpResponse;
import com.netflix.config.DynamicPropertyFactory;
import com.netflix.hystrix.HystrixCommand;
......@@ -10,15 +19,6 @@ import com.netflix.niws.client.http.RestClient;
import com.netflix.zuul.constants.ZuulConstants;
import com.netflix.zuul.context.RequestContext;
import javax.ws.rs.core.MultivaluedMap;
import java.io.InputStream;
import java.net.URI;
import java.net.URISyntaxException;
import java.util.List;
import static com.netflix.client.http.HttpRequest.Builder;
import static com.netflix.client.http.HttpRequest.Verb;
/**
* Hystrix wrapper around Eureka Ribbon command
*
......@@ -26,12 +26,12 @@ import static com.netflix.client.http.HttpRequest.Verb;
*/
public class RibbonCommand extends HystrixCommand<HttpResponse> {
RestClient restClient;
Verb verb;
URI uri;
MultivaluedMap<String, String> headers;
MultivaluedMap<String, String> params;
InputStream requestEntity;
private RestClient restClient;
private Verb verb;
private URI uri;
private MultivaluedMap<String, String> headers;
private MultivaluedMap<String, String> params;
private InputStream requestEntity;
public RibbonCommand(RestClient restClient,
Verb verb,
......
......@@ -3,52 +3,37 @@ package org.springframework.cloud.netflix.zuul.filters.route;
import java.io.IOException;
import java.io.InputStream;
import java.util.Collection;
import java.util.Enumeration;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.Map.Entry;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse;
import javax.ws.rs.core.MultivaluedMap;
import org.apache.commons.io.IOUtils;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.boot.actuate.trace.TraceRepository;
import org.springframework.cloud.netflix.ribbon.SpringClientFactory;
import org.springframework.util.StringUtils;
import com.netflix.client.ClientException;
import com.netflix.client.http.HttpRequest.Verb;
import com.netflix.client.http.HttpResponse;
import com.netflix.hystrix.exception.HystrixRuntimeException;
import com.netflix.niws.client.http.RestClient;
import com.netflix.zuul.ZuulFilter;
import com.netflix.zuul.context.RequestContext;
import com.netflix.zuul.exception.ZuulException;
import com.netflix.zuul.util.HTTPRequestUtils;
import com.sun.jersey.core.util.MultivaluedMapImpl;
public class RibbonRoutingFilter extends ZuulFilter {
public class RibbonRoutingFilter extends BaseProxyFilter {
private static final Logger LOG = LoggerFactory.getLogger(RibbonRoutingFilter.class);
public static final String CONTENT_ENCODING = "Content-Encoding";
private TraceRepository traces;
private SpringClientFactory clientFactory;
public RibbonRoutingFilter(SpringClientFactory clientFactory) {
this.clientFactory = clientFactory;
}
public void setTraces(TraceRepository traces) {
this.traces = traces;
}
@Override
public String filterType() {
return "route";
......@@ -93,94 +78,23 @@ public class RibbonRoutingFilter extends ZuulFilter {
}
catch (Exception e) {
context.set("error.status_code", HttpServletResponse.SC_INTERNAL_SERVER_ERROR);
context.set("error.exception", e);
}
return null;
}
private Map<String, Object> debug(Verb verb, String uri,
MultivaluedMap<String, String> headers,
MultivaluedMap<String, String> params, InputStream requestEntity)
throws IOException {
Map<String, Object> info = new LinkedHashMap<String, Object>();
if (traces != null) {
RequestContext context = RequestContext.getCurrentContext();
info.put("remote", true);
info.put("serviceId", context.get("serviceId"));
Map<String, Object> trace = new LinkedHashMap<String, Object>();
Map<String, Object> input = new LinkedHashMap<String, Object>();
trace.put("request", input);
info.put("headers", trace);
for (Entry<String, List<String>> entry : headers.entrySet()) {
Collection<String> collection = entry.getValue();
Object value = collection;
if (collection.size() < 2) {
value = collection.isEmpty() ? "" : collection.iterator().next();
}
input.put(entry.getKey(), value);
}
StringBuilder query = new StringBuilder();
for (String param : params.keySet()) {
for (String value : params.get(param)) {
query.append(param);
query.append("=");
query.append(value);
query.append("&");
}
}
info.put("method", verb.verb());
info.put("uri", uri);
info.put("query", query.toString());
RequestContext ctx = RequestContext.getCurrentContext();
if (!ctx.isChunkedRequestBody()) {
if (requestEntity != null) {
debugRequestEntity(info, ctx.getRequest().getInputStream());
}
}
traces.add(info);
return info;
}
return info;
}
private void debugRequestEntity(Map<String, Object> info, InputStream inputStream)
throws IOException {
String entity = IOUtils.toString(inputStream);
if (StringUtils.hasText(entity)) {
info.put("body", entity);
context.set("error.exception", e);
}
return null;
}
private HttpResponse forward(RestClient restClient, Verb verb, String uri,
MultivaluedMap<String, String> headers,
MultivaluedMap<String, String> params, InputStream requestEntity)
throws Exception {
Map<String, Object> info = debug(verb, uri, headers, params, requestEntity);
Map<String, Object> info = debug(verb.verb(), uri, headers, params, requestEntity);
RibbonCommand command = new RibbonCommand(restClient, verb, uri, headers, params,
requestEntity);
try {
HttpResponse response = command.execute();
if (traces != null) {
@SuppressWarnings("unchecked")
Map<String, Object> trace = (Map<String, Object>) info.get("headers");
Map<String, Object> output = new LinkedHashMap<String, Object>();
trace.put("response", output);
info.put("status", ""+response.getStatus());
for (Entry<String, Collection<String>> key : response.getHeaders()
.entrySet()) {
Collection<String> collection = key.getValue();
Object value = collection;
if (collection.size() < 2) {
value = collection.isEmpty() ? "" : collection.iterator().next();
}
output.put(key.getKey(), value);
}
}
appendDebug(info, response.getStatus(), response.getHeaders());
return response;
}
catch (HystrixRuntimeException e) {
......@@ -201,10 +115,10 @@ public class RibbonRoutingFilter extends ZuulFilter {
private InputStream getRequestBody(HttpServletRequest request) {
InputStream requestEntity = null;
//ApacheHttpClient4Handler does not support body in delete requests
if (request.getMethod().equals("DELETE")) {
return null;
}
// ApacheHttpClient4Handler does not support body in delete requests
if (request.getMethod().equals("DELETE")) {
return null;
}
try {
requestEntity = (InputStream) RequestContext.getCurrentContext().get(
"requestEntity");
......@@ -219,56 +133,6 @@ public class RibbonRoutingFilter extends ZuulFilter {
return requestEntity;
}
private MultivaluedMap<String, String> buildZuulRequestQueryParams(
HttpServletRequest request) {
Map<String, List<String>> map = HTTPRequestUtils.getInstance().getQueryParams();
MultivaluedMap<String, String> params = new MultivaluedMapImpl();
if (map == null)
return params;
for (String key : map.keySet()) {
for (String value : map.get(key)) {
params.add(key, value);
}
}
return params;
}
private MultivaluedMap<String, String> buildZuulRequestHeaders(
HttpServletRequest request) {
RequestContext context = RequestContext.getCurrentContext();
MultivaluedMap<String, String> headers = new MultivaluedMapImpl();
Enumeration<?> headerNames = request.getHeaderNames();
if (headerNames != null) {
while (headerNames.hasMoreElements()) {
String name = (String) headerNames.nextElement();
String value = request.getHeader(name);
if (!name.toLowerCase().contains("content-length"))
headers.putSingle(name, value);
}
}
Map<String, String> zuulRequestHeaders = context.getZuulRequestHeaders();
for (String header : zuulRequestHeaders.keySet()) {
headers.putSingle(header, zuulRequestHeaders.get(header));
}
headers.putSingle("accept-encoding", "deflate, gzip");
if (headers.containsKey("transfer-encoding"))
headers.remove("transfer-encoding");
if (headers.containsKey("host"))
headers.remove("host");
return headers;
}
Verb getVerb(HttpServletRequest request) {
String sMethod = request.getMethod();
return getVerb(sMethod);
......
......@@ -14,19 +14,19 @@ import java.security.UnrecoverableKeyException;
import java.security.cert.CertificateException;
import java.security.cert.X509Certificate;
import java.util.ArrayList;
import java.util.Enumeration;
import java.util.List;
import java.util.Map;
import java.util.Timer;
import java.util.TimerTask;
import java.util.concurrent.atomic.AtomicReference;
import java.util.zip.GZIPInputStream;
import javax.annotation.Nullable;
import javax.net.ssl.SSLContext;
import javax.net.ssl.TrustManager;
import javax.net.ssl.X509TrustManager;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse;
import javax.ws.rs.core.MultivaluedMap;
import org.apache.commons.io.IOUtils;
import org.apache.http.Header;
......@@ -55,18 +55,14 @@ import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.util.StringUtils;
import com.google.common.base.Optional;
import com.google.common.base.Predicate;
import com.google.common.collect.Iterables;
import com.netflix.config.DynamicIntProperty;
import com.netflix.config.DynamicPropertyFactory;
import com.netflix.zuul.ZuulFilter;
import com.netflix.zuul.constants.ZuulConstants;
import com.netflix.zuul.context.Debug;
import com.netflix.zuul.context.RequestContext;
import com.netflix.zuul.util.HTTPRequestUtils;
public class SimpleHostRoutingFilter extends ZuulFilter {
public class SimpleHostRoutingFilter extends BaseProxyFilter {
public static final String CONTENT_ENCODING = "Content-Encoding";
......@@ -205,7 +201,8 @@ public class SimpleHostRoutingFilter extends ZuulFilter {
public Object run() {
RequestContext context = RequestContext.getCurrentContext();
HttpServletRequest request = context.getRequest();
Header[] headers = buildZuulRequestHeaders(request);
MultivaluedMap<String, String> headers = buildZuulRequestHeaders(request);
MultivaluedMap<String, String> params = buildZuulRequestQueryParams(request);
String verb = getVerb(request);
InputStream requestEntity = getRequestBody(request);
HttpClient httpclient = CLIENT.get();
......@@ -217,20 +214,23 @@ public class SimpleHostRoutingFilter extends ZuulFilter {
try {
HttpResponse response = forward(httpclient, verb, uri, request, headers,
requestEntity);
params, requestEntity);
setResponse(response);
}
catch (Exception e) {
context.set("error.status_code", HttpServletResponse.SC_INTERNAL_SERVER_ERROR);
context.set("error.exception", e);
context.set("error.exception", e);
}
return null;
}
private HttpResponse forward(HttpClient httpclient, String verb, String uri,
HttpServletRequest request, Header[] headers, InputStream requestEntity)
HttpServletRequest request, MultivaluedMap<String, String> headers,
MultivaluedMap<String, String> params, InputStream requestEntity)
throws Exception {
Map<String, Object> info = debug(verb, uri, headers, params, requestEntity);
URL host = RequestContext.getCurrentContext().getRouteHost();
HttpHost httpHost = getHttpHost(host);
uri = StringUtils.cleanPath(host.getPath() + uri);
......@@ -256,7 +256,7 @@ public class SimpleHostRoutingFilter extends ZuulFilter {
}
try {
httpRequest.setHeaders(headers);
httpRequest.setHeaders(convertHeaders(headers));
LOG.debug(httpHost.getHostName() + " " + httpHost.getPort() + " "
+ httpHost.getSchemeName());
HttpResponse zuulResponse = forwardRequest(httpclient, httpHost, httpRequest);
......@@ -271,6 +271,16 @@ public class SimpleHostRoutingFilter extends ZuulFilter {
}
private Header[] convertHeaders(MultivaluedMap<String, String> headers) {
List<Header> list = new ArrayList<>();
for (String name : headers.keySet()) {
for (String value : headers.get(name)) {
list.add(new BasicHeader(name, value));
}
}
return list.toArray(new BasicHeader[0]);
}
private HttpResponse forwardRequest(HttpClient httpclient, HttpHost httpHost,
HttpRequest httpRequest) throws IOException {
return httpclient.execute(httpHost, httpRequest);
......@@ -309,62 +319,11 @@ public class SimpleHostRoutingFilter extends ZuulFilter {
return true;
}
private Header[] buildZuulRequestHeaders(HttpServletRequest request) {
ArrayList<Header> headers = new ArrayList<>();
Enumeration headerNames = request.getHeaderNames();
while (headerNames.hasMoreElements()) {
String name = (String) headerNames.nextElement();
String value = request.getHeader(name);
if (isIncludedHeader(name))
headers.add(new BasicHeader(name, value));
}
Map<String, String> zuulRequestHeaders = RequestContext.getCurrentContext()
.getZuulRequestHeaders();
for (String it : zuulRequestHeaders.keySet()) {
final String name = it.toLowerCase();
Optional<Header> h = Iterables.tryFind(headers, new Predicate<Header>() {
@Override
public boolean apply(@Nullable Header input) {
return input.getName().equals(name);
}
});
if (h.isPresent()) {
headers.remove(h);
}
headers.add(new BasicHeader(it, zuulRequestHeaders.get(it)));
}
if (RequestContext.getCurrentContext().getResponseGZipped()) {
headers.add(new BasicHeader("accept-encoding", "deflate, gzip"));
}
return headers.toArray(new Header[0]);
}
private String getVerb(HttpServletRequest request) {
String sMethod = request.getMethod();
return sMethod.toUpperCase();
}
private String getVerb(String sMethod) {
if (sMethod == null)
return "GET";
sMethod = sMethod.toLowerCase();
if (sMethod.equalsIgnoreCase("post"))
return "POST";
if (sMethod.equalsIgnoreCase("put"))
return "PUT";
if (sMethod.equalsIgnoreCase("delete"))
return "DELETE";
if (sMethod.equalsIgnoreCase("options"))
return "OPTIONS";
if (sMethod.equalsIgnoreCase("head"))
return "HEAD";
return "GET";
}
private void setResponse(HttpResponse response) throws IOException {
RequestContext context = RequestContext.getCurrentContext();
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment