Commit c883495f by Stephen Oakey Committed by Spencer Gibb

Detect https from Ribbon Client Config

RibbonLoadBalancer and RibbonLoadBalancer use IClientConfig IsSecure property for service ID to determine if the service is secure. Replaces scheme of URI of request with https if the request is secure. fixes gh-337
parent ddb09c49
...@@ -21,6 +21,8 @@ import java.net.URI; ...@@ -21,6 +21,8 @@ import java.net.URI;
import java.util.Collection; import java.util.Collection;
import java.util.Map; import java.util.Map;
import org.springframework.web.util.UriComponentsBuilder;
import com.netflix.client.AbstractLoadBalancerAwareClient; import com.netflix.client.AbstractLoadBalancerAwareClient;
import com.netflix.client.ClientException; import com.netflix.client.ClientException;
import com.netflix.client.ClientRequest; import com.netflix.client.ClientRequest;
...@@ -48,11 +50,14 @@ public class RibbonLoadBalancer ...@@ -48,11 +50,14 @@ public class RibbonLoadBalancer
private final IClientConfig clientConfig; private final IClientConfig clientConfig;
private final boolean secure;
public RibbonLoadBalancer(Client delegate, ILoadBalancer lb, public RibbonLoadBalancer(Client delegate, ILoadBalancer lb,
IClientConfig clientConfig) { IClientConfig clientConfig) {
super(lb, clientConfig); super(lb, clientConfig);
this.setRetryHandler(RetryHandler.DEFAULT); this.setRetryHandler(RetryHandler.DEFAULT);
this.clientConfig = clientConfig; this.clientConfig = clientConfig;
this.secure = clientConfig.get(CommonClientConfigKey.IsSecure);
this.delegate = delegate; this.delegate = delegate;
this.connectTimeout = clientConfig.get(CommonClientConfigKey.ConnectTimeout); this.connectTimeout = clientConfig.get(CommonClientConfigKey.ConnectTimeout);
this.readTimeout = clientConfig.get(CommonClientConfigKey.ReadTimeout); this.readTimeout = clientConfig.get(CommonClientConfigKey.ReadTimeout);
...@@ -71,10 +76,19 @@ public class RibbonLoadBalancer ...@@ -71,10 +76,19 @@ public class RibbonLoadBalancer
else { else {
options = new Request.Options(this.connectTimeout, this.readTimeout); options = new Request.Options(this.connectTimeout, this.readTimeout);
} }
if (isSecure(configOverride)) {
URI secureUri = UriComponentsBuilder.fromUri(request.getUri())
.scheme("https").build().toUri();
request = new RibbonRequest(request.toRequest(), secureUri);
}
Response response = this.delegate.execute(request.toRequest(), options); Response response = this.delegate.execute(request.toRequest(), options);
return new RibbonResponse(request.getUri(), response); return new RibbonResponse(request.getUri(), response);
} }
private boolean isSecure(IClientConfig config) {
return (config != null) ? config.get(CommonClientConfigKey.IsSecure) : secure;
}
@Override @Override
public RequestSpecificRetryHandler getRequestSpecificRetryHandler( public RequestSpecificRetryHandler getRequestSpecificRetryHandler(
RibbonRequest request, IClientConfig requestConfig) { RibbonRequest request, IClientConfig requestConfig) {
......
...@@ -25,7 +25,10 @@ import org.springframework.cloud.client.loadbalancer.LoadBalancerClient; ...@@ -25,7 +25,10 @@ import org.springframework.cloud.client.loadbalancer.LoadBalancerClient;
import org.springframework.cloud.client.loadbalancer.LoadBalancerRequest; import org.springframework.cloud.client.loadbalancer.LoadBalancerRequest;
import org.springframework.util.Assert; import org.springframework.util.Assert;
import org.springframework.util.ReflectionUtils; import org.springframework.util.ReflectionUtils;
import org.springframework.web.util.UriComponentsBuilder;
import com.netflix.client.config.CommonClientConfigKey;
import com.netflix.client.config.IClientConfig;
import com.netflix.loadbalancer.ILoadBalancer; import com.netflix.loadbalancer.ILoadBalancer;
import com.netflix.loadbalancer.Server; import com.netflix.loadbalancer.Server;
import com.netflix.loadbalancer.ServerStats; import com.netflix.loadbalancer.ServerStats;
...@@ -50,7 +53,12 @@ public class RibbonLoadBalancerClient implements LoadBalancerClient { ...@@ -50,7 +53,12 @@ public class RibbonLoadBalancerClient implements LoadBalancerClient {
RibbonLoadBalancerContext context = this.clientFactory RibbonLoadBalancerContext context = this.clientFactory
.getLoadBalancerContext(serviceId); .getLoadBalancerContext(serviceId);
Server server = new Server(instance.getHost(), instance.getPort()); Server server = new Server(instance.getHost(), instance.getPort());
return context.reconstructURIWithServer(server, original); boolean secure = isSecure(this.clientFactory, serviceId);
URI uri = original;
if(secure) {
uri = UriComponentsBuilder.fromUri(uri).scheme("https").build().toUri();
}
return context.reconstructURIWithServer(server, uri);
} }
@Override @Override
...@@ -59,7 +67,7 @@ public class RibbonLoadBalancerClient implements LoadBalancerClient { ...@@ -59,7 +67,7 @@ public class RibbonLoadBalancerClient implements LoadBalancerClient {
if (server == null) { if (server == null) {
return null; return null;
} }
return new RibbonServer(serviceId, server); return new RibbonServer(serviceId, server, isSecure(this.clientFactory, serviceId));
} }
@Override @Override
...@@ -68,7 +76,7 @@ public class RibbonLoadBalancerClient implements LoadBalancerClient { ...@@ -68,7 +76,7 @@ public class RibbonLoadBalancerClient implements LoadBalancerClient {
RibbonLoadBalancerContext context = this.clientFactory RibbonLoadBalancerContext context = this.clientFactory
.getLoadBalancerContext(serviceId); .getLoadBalancerContext(serviceId);
Server server = getServer(loadBalancer); Server server = getServer(loadBalancer);
RibbonServer ribbonServer = new RibbonServer(serviceId, server); RibbonServer ribbonServer = new RibbonServer(serviceId, server, isSecure(clientFactory, serviceId));
ServerStats serverStats = context.getServerStats(server); ServerStats serverStats = context.getServerStats(server);
context.noteOpenConnection(serverStats); context.noteOpenConnection(serverStats);
...@@ -86,6 +94,14 @@ public class RibbonLoadBalancerClient implements LoadBalancerClient { ...@@ -86,6 +94,14 @@ public class RibbonLoadBalancerClient implements LoadBalancerClient {
return null; return null;
} }
private boolean isSecure(SpringClientFactory clientFactory, String serviceId) {
IClientConfig config = clientFactory.getClientConfig(serviceId);
if(config != null) {
return config.get(CommonClientConfigKey.IsSecure, false);
}
return false;
}
private void recordStats(RibbonLoadBalancerContext context, Stopwatch tracer, private void recordStats(RibbonLoadBalancerContext context, Stopwatch tracer,
ServerStats serverStats, Object entity, Throwable exception) { ServerStats serverStats, Object entity, Throwable exception) {
tracer.stop(); tracer.stop();
...@@ -111,8 +127,13 @@ public class RibbonLoadBalancerClient implements LoadBalancerClient { ...@@ -111,8 +127,13 @@ public class RibbonLoadBalancerClient implements LoadBalancerClient {
protected static class RibbonServer implements ServiceInstance { protected static class RibbonServer implements ServiceInstance {
private String serviceId; private String serviceId;
private Server server; private Server server;
private boolean secure;
protected RibbonServer(String serviceId, Server server) { protected RibbonServer(String serviceId, Server server) {
this(serviceId, server, false);
}
protected RibbonServer(String serviceId, Server server, boolean secure) {
this.serviceId = serviceId; this.serviceId = serviceId;
this.server = server; this.server = server;
} }
...@@ -134,7 +155,7 @@ public class RibbonLoadBalancerClient implements LoadBalancerClient { ...@@ -134,7 +155,7 @@ public class RibbonLoadBalancerClient implements LoadBalancerClient {
@Override @Override
public boolean isSecure() { public boolean isSecure() {
return false; //TODO: howto determine https from ribbon Server return this.secure;
} }
@Override @Override
......
package org.springframework.cloud.netflix.feign.ribbon;
import static com.netflix.client.config.CommonClientConfigKey.ConnectTimeout;
import static com.netflix.client.config.CommonClientConfigKey.IsSecure;
import static com.netflix.client.config.CommonClientConfigKey.MaxAutoRetries;
import static com.netflix.client.config.CommonClientConfigKey.MaxAutoRetriesNextServer;
import static com.netflix.client.config.CommonClientConfigKey.OkToRetryOnAllOperations;
import static com.netflix.client.config.CommonClientConfigKey.ReadTimeout;
import static com.netflix.client.config.DefaultClientConfigImpl.DEFAULT_MAX_AUTO_RETRIES;
import static com.netflix.client.config.DefaultClientConfigImpl.DEFAULT_MAX_AUTO_RETRIES_NEXT_SERVER;
import static org.hamcrest.Matchers.is;
import static org.junit.Assert.assertThat;
import static org.mockito.Matchers.any;
import static org.mockito.Matchers.anyBoolean;
import static org.mockito.Matchers.eq;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.when;
import java.net.URI;
import java.util.Collection;
import java.util.Collections;
import lombok.SneakyThrows;
import org.junit.Before;
import org.junit.Test;
import org.mockito.Mock;
import org.mockito.MockitoAnnotations;
import org.springframework.cloud.netflix.feign.ribbon.RibbonLoadBalancer.RibbonRequest;
import org.springframework.cloud.netflix.feign.ribbon.RibbonLoadBalancer.RibbonResponse;
import com.netflix.client.config.IClientConfig;
import com.netflix.loadbalancer.ILoadBalancer;
import feign.Client;
import feign.Request;
import feign.Request.Options;
import feign.RequestTemplate;
import feign.Response;
public class RibbonLoadBalancerTests {
@Mock
private Client delegate;
@Mock
private ILoadBalancer lb;
@Mock
private IClientConfig config;
private RibbonLoadBalancer ribbonLoadBalancer;
private Integer defaultConnectTimeout = 10000;
private Integer defaultReadTimeout = 10000;
@Before
public void setup() {
MockitoAnnotations.initMocks(this);
when(config.get(MaxAutoRetries, DEFAULT_MAX_AUTO_RETRIES)).thenReturn(1);
when(config.get(MaxAutoRetriesNextServer, DEFAULT_MAX_AUTO_RETRIES_NEXT_SERVER))
.thenReturn(1);
when(config.get(OkToRetryOnAllOperations, eq(anyBoolean()))).thenReturn(true);
when(config.get(ConnectTimeout)).thenReturn(defaultConnectTimeout);
when(config.get(ReadTimeout)).thenReturn(defaultReadTimeout);
}
@Test
@SneakyThrows
public void testUriInsecure() {
when(config.get(IsSecure)).thenReturn(false);
ribbonLoadBalancer = new RibbonLoadBalancer(delegate, lb, config);
Request request = new RequestTemplate().method("GET").append("http://foo/")
.request();
RibbonRequest ribbonRequest = new RibbonRequest(request, new URI(request.url()));
Response response = Response.create(200, "Test",
Collections.<String, Collection<String>> emptyMap(), new byte[0]);
when(delegate.execute(any(Request.class), any(Options.class))).thenReturn(
response);
RibbonResponse resp = ribbonLoadBalancer.execute(ribbonRequest, null);
assertThat(resp.getRequestedURI(), is(new URI("http://foo/")));
}
@Test
@SneakyThrows
public void testSecureUriFromClientConfig() {
when(config.get(IsSecure)).thenReturn(true);
ribbonLoadBalancer = new RibbonLoadBalancer(delegate, lb, config);
Request request = new RequestTemplate().method("GET").append("http://foo/")
.request();
RibbonRequest ribbonRequest = new RibbonRequest(request, new URI(request.url()));
Response response = Response.create(200, "Test",
Collections.<String, Collection<String>> emptyMap(), new byte[0]);
when(delegate.execute(any(Request.class), any(Options.class))).thenReturn(
response);
RibbonResponse resp = ribbonLoadBalancer.execute(ribbonRequest, null);
assertThat(resp.getRequestedURI(), is(new URI("https://foo/")));
}
@Test
@SneakyThrows
public void testSecureUriFromClientConfigOverride() {
when(config.get(IsSecure)).thenReturn(true);
ribbonLoadBalancer = new RibbonLoadBalancer(delegate, lb, config);
Request request = new RequestTemplate().method("GET").append("http://foo/")
.request();
RibbonRequest ribbonRequest = new RibbonRequest(request, new URI(request.url()));
Response response = Response.create(200, "Test",
Collections.<String, Collection<String>> emptyMap(), new byte[0]);
when(delegate.execute(any(Request.class), any(Options.class))).thenReturn(
response);
IClientConfig override = mock(IClientConfig.class);
when(override.get(ConnectTimeout, defaultConnectTimeout)).thenReturn(5000);
when(override.get(ReadTimeout, defaultConnectTimeout)).thenReturn(5000);
/*
* Override secure value.
*/
when(override.get(IsSecure)).thenReturn(false);
RibbonResponse resp = ribbonLoadBalancer.execute(ribbonRequest, override);
assertThat(resp.getRequestedURI(), is(new URI("http://foo/")));
}
}
...@@ -16,10 +16,23 @@ ...@@ -16,10 +16,23 @@
package org.springframework.cloud.netflix.ribbon; package org.springframework.cloud.netflix.ribbon;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertNotNull;
import static org.junit.Assert.assertNull;
import static org.junit.Assert.fail;
import static org.mockito.BDDMockito.given;
import static org.mockito.Matchers.anyDouble;
import static org.mockito.Matchers.anyObject;
import static org.mockito.Matchers.anyString;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.when;
import java.net.URI; import java.net.URI;
import java.net.URL; import java.net.URL;
import lombok.SneakyThrows; import lombok.SneakyThrows;
import org.junit.Before; import org.junit.Before;
import org.junit.Test; import org.junit.Test;
import org.mockito.Mock; import org.mockito.Mock;
...@@ -28,20 +41,13 @@ import org.springframework.cloud.client.ServiceInstance; ...@@ -28,20 +41,13 @@ import org.springframework.cloud.client.ServiceInstance;
import org.springframework.cloud.client.loadbalancer.LoadBalancerRequest; import org.springframework.cloud.client.loadbalancer.LoadBalancerRequest;
import org.springframework.cloud.netflix.ribbon.RibbonLoadBalancerClient.RibbonServer; import org.springframework.cloud.netflix.ribbon.RibbonLoadBalancerClient.RibbonServer;
import com.netflix.client.config.CommonClientConfigKey;
import com.netflix.client.config.IClientConfig;
import com.netflix.loadbalancer.BaseLoadBalancer; import com.netflix.loadbalancer.BaseLoadBalancer;
import com.netflix.loadbalancer.LoadBalancerStats; import com.netflix.loadbalancer.LoadBalancerStats;
import com.netflix.loadbalancer.Server; import com.netflix.loadbalancer.Server;
import com.netflix.loadbalancer.ServerStats; import com.netflix.loadbalancer.ServerStats;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertNotNull;
import static org.junit.Assert.fail;
import static org.junit.Assert.assertNull;
import static org.mockito.BDDMockito.given;
import static org.mockito.Matchers.anyDouble;
import static org.mockito.Matchers.anyString;
import static org.mockito.Mockito.verify;
/** /**
* @author Spencer Gibb * @author Spencer Gibb
*/ */
...@@ -81,10 +87,27 @@ public class RibbonLoadBalancerClientTests { ...@@ -81,10 +87,27 @@ public class RibbonLoadBalancerClientTests {
RibbonServer server = getRibbonServer(); RibbonServer server = getRibbonServer();
RibbonLoadBalancerClient client = getRibbonLoadBalancerClient(server); RibbonLoadBalancerClient client = getRibbonLoadBalancerClient(server);
ServiceInstance serviceInstance = client.choose(server.getServiceId()); ServiceInstance serviceInstance = client.choose(server.getServiceId());
URI uri = client.reconstructURI(serviceInstance, new URL(scheme +"://" URI uri = client.reconstructURI(serviceInstance,
+ server.getServiceId()).toURI()); new URL(scheme + "://" + server.getServiceId()).toURI());
assertEquals(server.getHost(), uri.getHost());
assertEquals(server.getPort(), uri.getPort());
}
@Test
@SneakyThrows
public void testReconstructUriWithSecureClientConfig() {
RibbonServer server = getRibbonServer();
IClientConfig config = mock(IClientConfig.class);
when(config.get(CommonClientConfigKey.IsSecure, false)).thenReturn(true);
when(clientFactory.getClientConfig(server.getServiceId())).thenReturn(config);
RibbonLoadBalancerClient client = getRibbonLoadBalancerClient(server);
ServiceInstance serviceInstance = client.choose(server.getServiceId());
URI uri = client.reconstructURI(serviceInstance,
new URL("http://" + server.getServiceId()).toURI());
assertEquals(server.getHost(), uri.getHost()); assertEquals(server.getHost(), uri.getHost());
assertEquals(server.getPort(), uri.getPort()); assertEquals(server.getPort(), uri.getPort());
assertEquals("https", uri.getScheme());
} }
@Test @Test
...@@ -166,8 +189,8 @@ public class RibbonLoadBalancerClientTests { ...@@ -166,8 +189,8 @@ public class RibbonLoadBalancerClientTests {
protected RibbonLoadBalancerClient getRibbonLoadBalancerClient( protected RibbonLoadBalancerClient getRibbonLoadBalancerClient(
RibbonServer ribbonServer) { RibbonServer ribbonServer) {
given(this.loadBalancer.getName()).willReturn(ribbonServer.getServiceId()); given(this.loadBalancer.getName()).willReturn(ribbonServer.getServiceId());
given(this.loadBalancer.chooseServer(anyString())) given(this.loadBalancer.chooseServer(anyObject())).willReturn(
.willReturn(ribbonServer.getServer()); ribbonServer.getServer());
given(this.loadBalancer.getLoadBalancerStats()) given(this.loadBalancer.getLoadBalancerStats())
.willReturn(this.loadBalancerStats); .willReturn(this.loadBalancerStats);
given(this.loadBalancerStats.getSingleServerStat(ribbonServer.getServer())) given(this.loadBalancerStats.getSingleServerStat(ribbonServer.getServer()))
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment