Commit c883495f by Stephen Oakey Committed by Spencer Gibb

Detect https from Ribbon Client Config

RibbonLoadBalancer and RibbonLoadBalancer use IClientConfig IsSecure property for service ID to determine if the service is secure. Replaces scheme of URI of request with https if the request is secure. fixes gh-337
parent ddb09c49
......@@ -21,6 +21,8 @@ import java.net.URI;
import java.util.Collection;
import java.util.Map;
import org.springframework.web.util.UriComponentsBuilder;
import com.netflix.client.AbstractLoadBalancerAwareClient;
import com.netflix.client.ClientException;
import com.netflix.client.ClientRequest;
......@@ -48,11 +50,14 @@ public class RibbonLoadBalancer
private final IClientConfig clientConfig;
private final boolean secure;
public RibbonLoadBalancer(Client delegate, ILoadBalancer lb,
IClientConfig clientConfig) {
super(lb, clientConfig);
this.setRetryHandler(RetryHandler.DEFAULT);
this.clientConfig = clientConfig;
this.secure = clientConfig.get(CommonClientConfigKey.IsSecure);
this.delegate = delegate;
this.connectTimeout = clientConfig.get(CommonClientConfigKey.ConnectTimeout);
this.readTimeout = clientConfig.get(CommonClientConfigKey.ReadTimeout);
......@@ -71,10 +76,19 @@ public class RibbonLoadBalancer
else {
options = new Request.Options(this.connectTimeout, this.readTimeout);
}
if (isSecure(configOverride)) {
URI secureUri = UriComponentsBuilder.fromUri(request.getUri())
.scheme("https").build().toUri();
request = new RibbonRequest(request.toRequest(), secureUri);
}
Response response = this.delegate.execute(request.toRequest(), options);
return new RibbonResponse(request.getUri(), response);
}
private boolean isSecure(IClientConfig config) {
return (config != null) ? config.get(CommonClientConfigKey.IsSecure) : secure;
}
@Override
public RequestSpecificRetryHandler getRequestSpecificRetryHandler(
RibbonRequest request, IClientConfig requestConfig) {
......
......@@ -25,7 +25,10 @@ import org.springframework.cloud.client.loadbalancer.LoadBalancerClient;
import org.springframework.cloud.client.loadbalancer.LoadBalancerRequest;
import org.springframework.util.Assert;
import org.springframework.util.ReflectionUtils;
import org.springframework.web.util.UriComponentsBuilder;
import com.netflix.client.config.CommonClientConfigKey;
import com.netflix.client.config.IClientConfig;
import com.netflix.loadbalancer.ILoadBalancer;
import com.netflix.loadbalancer.Server;
import com.netflix.loadbalancer.ServerStats;
......@@ -50,7 +53,12 @@ public class RibbonLoadBalancerClient implements LoadBalancerClient {
RibbonLoadBalancerContext context = this.clientFactory
.getLoadBalancerContext(serviceId);
Server server = new Server(instance.getHost(), instance.getPort());
return context.reconstructURIWithServer(server, original);
boolean secure = isSecure(this.clientFactory, serviceId);
URI uri = original;
if(secure) {
uri = UriComponentsBuilder.fromUri(uri).scheme("https").build().toUri();
}
return context.reconstructURIWithServer(server, uri);
}
@Override
......@@ -59,7 +67,7 @@ public class RibbonLoadBalancerClient implements LoadBalancerClient {
if (server == null) {
return null;
}
return new RibbonServer(serviceId, server);
return new RibbonServer(serviceId, server, isSecure(this.clientFactory, serviceId));
}
@Override
......@@ -68,7 +76,7 @@ public class RibbonLoadBalancerClient implements LoadBalancerClient {
RibbonLoadBalancerContext context = this.clientFactory
.getLoadBalancerContext(serviceId);
Server server = getServer(loadBalancer);
RibbonServer ribbonServer = new RibbonServer(serviceId, server);
RibbonServer ribbonServer = new RibbonServer(serviceId, server, isSecure(clientFactory, serviceId));
ServerStats serverStats = context.getServerStats(server);
context.noteOpenConnection(serverStats);
......@@ -85,6 +93,14 @@ public class RibbonLoadBalancerClient implements LoadBalancerClient {
}
return null;
}
private boolean isSecure(SpringClientFactory clientFactory, String serviceId) {
IClientConfig config = clientFactory.getClientConfig(serviceId);
if(config != null) {
return config.get(CommonClientConfigKey.IsSecure, false);
}
return false;
}
private void recordStats(RibbonLoadBalancerContext context, Stopwatch tracer,
ServerStats serverStats, Object entity, Throwable exception) {
......@@ -111,8 +127,13 @@ public class RibbonLoadBalancerClient implements LoadBalancerClient {
protected static class RibbonServer implements ServiceInstance {
private String serviceId;
private Server server;
private boolean secure;
protected RibbonServer(String serviceId, Server server) {
this(serviceId, server, false);
}
protected RibbonServer(String serviceId, Server server, boolean secure) {
this.serviceId = serviceId;
this.server = server;
}
......@@ -134,7 +155,7 @@ public class RibbonLoadBalancerClient implements LoadBalancerClient {
@Override
public boolean isSecure() {
return false; //TODO: howto determine https from ribbon Server
return this.secure;
}
@Override
......
package org.springframework.cloud.netflix.feign.ribbon;
import static com.netflix.client.config.CommonClientConfigKey.ConnectTimeout;
import static com.netflix.client.config.CommonClientConfigKey.IsSecure;
import static com.netflix.client.config.CommonClientConfigKey.MaxAutoRetries;
import static com.netflix.client.config.CommonClientConfigKey.MaxAutoRetriesNextServer;
import static com.netflix.client.config.CommonClientConfigKey.OkToRetryOnAllOperations;
import static com.netflix.client.config.CommonClientConfigKey.ReadTimeout;
import static com.netflix.client.config.DefaultClientConfigImpl.DEFAULT_MAX_AUTO_RETRIES;
import static com.netflix.client.config.DefaultClientConfigImpl.DEFAULT_MAX_AUTO_RETRIES_NEXT_SERVER;
import static org.hamcrest.Matchers.is;
import static org.junit.Assert.assertThat;
import static org.mockito.Matchers.any;
import static org.mockito.Matchers.anyBoolean;
import static org.mockito.Matchers.eq;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.when;
import java.net.URI;
import java.util.Collection;
import java.util.Collections;
import lombok.SneakyThrows;
import org.junit.Before;
import org.junit.Test;
import org.mockito.Mock;
import org.mockito.MockitoAnnotations;
import org.springframework.cloud.netflix.feign.ribbon.RibbonLoadBalancer.RibbonRequest;
import org.springframework.cloud.netflix.feign.ribbon.RibbonLoadBalancer.RibbonResponse;
import com.netflix.client.config.IClientConfig;
import com.netflix.loadbalancer.ILoadBalancer;
import feign.Client;
import feign.Request;
import feign.Request.Options;
import feign.RequestTemplate;
import feign.Response;
public class RibbonLoadBalancerTests {
@Mock
private Client delegate;
@Mock
private ILoadBalancer lb;
@Mock
private IClientConfig config;
private RibbonLoadBalancer ribbonLoadBalancer;
private Integer defaultConnectTimeout = 10000;
private Integer defaultReadTimeout = 10000;
@Before
public void setup() {
MockitoAnnotations.initMocks(this);
when(config.get(MaxAutoRetries, DEFAULT_MAX_AUTO_RETRIES)).thenReturn(1);
when(config.get(MaxAutoRetriesNextServer, DEFAULT_MAX_AUTO_RETRIES_NEXT_SERVER))
.thenReturn(1);
when(config.get(OkToRetryOnAllOperations, eq(anyBoolean()))).thenReturn(true);
when(config.get(ConnectTimeout)).thenReturn(defaultConnectTimeout);
when(config.get(ReadTimeout)).thenReturn(defaultReadTimeout);
}
@Test
@SneakyThrows
public void testUriInsecure() {
when(config.get(IsSecure)).thenReturn(false);
ribbonLoadBalancer = new RibbonLoadBalancer(delegate, lb, config);
Request request = new RequestTemplate().method("GET").append("http://foo/")
.request();
RibbonRequest ribbonRequest = new RibbonRequest(request, new URI(request.url()));
Response response = Response.create(200, "Test",
Collections.<String, Collection<String>> emptyMap(), new byte[0]);
when(delegate.execute(any(Request.class), any(Options.class))).thenReturn(
response);
RibbonResponse resp = ribbonLoadBalancer.execute(ribbonRequest, null);
assertThat(resp.getRequestedURI(), is(new URI("http://foo/")));
}
@Test
@SneakyThrows
public void testSecureUriFromClientConfig() {
when(config.get(IsSecure)).thenReturn(true);
ribbonLoadBalancer = new RibbonLoadBalancer(delegate, lb, config);
Request request = new RequestTemplate().method("GET").append("http://foo/")
.request();
RibbonRequest ribbonRequest = new RibbonRequest(request, new URI(request.url()));
Response response = Response.create(200, "Test",
Collections.<String, Collection<String>> emptyMap(), new byte[0]);
when(delegate.execute(any(Request.class), any(Options.class))).thenReturn(
response);
RibbonResponse resp = ribbonLoadBalancer.execute(ribbonRequest, null);
assertThat(resp.getRequestedURI(), is(new URI("https://foo/")));
}
@Test
@SneakyThrows
public void testSecureUriFromClientConfigOverride() {
when(config.get(IsSecure)).thenReturn(true);
ribbonLoadBalancer = new RibbonLoadBalancer(delegate, lb, config);
Request request = new RequestTemplate().method("GET").append("http://foo/")
.request();
RibbonRequest ribbonRequest = new RibbonRequest(request, new URI(request.url()));
Response response = Response.create(200, "Test",
Collections.<String, Collection<String>> emptyMap(), new byte[0]);
when(delegate.execute(any(Request.class), any(Options.class))).thenReturn(
response);
IClientConfig override = mock(IClientConfig.class);
when(override.get(ConnectTimeout, defaultConnectTimeout)).thenReturn(5000);
when(override.get(ReadTimeout, defaultConnectTimeout)).thenReturn(5000);
/*
* Override secure value.
*/
when(override.get(IsSecure)).thenReturn(false);
RibbonResponse resp = ribbonLoadBalancer.execute(ribbonRequest, override);
assertThat(resp.getRequestedURI(), is(new URI("http://foo/")));
}
}
......@@ -16,10 +16,23 @@
package org.springframework.cloud.netflix.ribbon;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertNotNull;
import static org.junit.Assert.assertNull;
import static org.junit.Assert.fail;
import static org.mockito.BDDMockito.given;
import static org.mockito.Matchers.anyDouble;
import static org.mockito.Matchers.anyObject;
import static org.mockito.Matchers.anyString;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.when;
import java.net.URI;
import java.net.URL;
import lombok.SneakyThrows;
import org.junit.Before;
import org.junit.Test;
import org.mockito.Mock;
......@@ -28,20 +41,13 @@ import org.springframework.cloud.client.ServiceInstance;
import org.springframework.cloud.client.loadbalancer.LoadBalancerRequest;
import org.springframework.cloud.netflix.ribbon.RibbonLoadBalancerClient.RibbonServer;
import com.netflix.client.config.CommonClientConfigKey;
import com.netflix.client.config.IClientConfig;
import com.netflix.loadbalancer.BaseLoadBalancer;
import com.netflix.loadbalancer.LoadBalancerStats;
import com.netflix.loadbalancer.Server;
import com.netflix.loadbalancer.ServerStats;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertNotNull;
import static org.junit.Assert.fail;
import static org.junit.Assert.assertNull;
import static org.mockito.BDDMockito.given;
import static org.mockito.Matchers.anyDouble;
import static org.mockito.Matchers.anyString;
import static org.mockito.Mockito.verify;
/**
* @author Spencer Gibb
*/
......@@ -81,10 +87,27 @@ public class RibbonLoadBalancerClientTests {
RibbonServer server = getRibbonServer();
RibbonLoadBalancerClient client = getRibbonLoadBalancerClient(server);
ServiceInstance serviceInstance = client.choose(server.getServiceId());
URI uri = client.reconstructURI(serviceInstance, new URL(scheme +"://"
+ server.getServiceId()).toURI());
URI uri = client.reconstructURI(serviceInstance,
new URL(scheme + "://" + server.getServiceId()).toURI());
assertEquals(server.getHost(), uri.getHost());
assertEquals(server.getPort(), uri.getPort());
}
@Test
@SneakyThrows
public void testReconstructUriWithSecureClientConfig() {
RibbonServer server = getRibbonServer();
IClientConfig config = mock(IClientConfig.class);
when(config.get(CommonClientConfigKey.IsSecure, false)).thenReturn(true);
when(clientFactory.getClientConfig(server.getServiceId())).thenReturn(config);
RibbonLoadBalancerClient client = getRibbonLoadBalancerClient(server);
ServiceInstance serviceInstance = client.choose(server.getServiceId());
URI uri = client.reconstructURI(serviceInstance,
new URL("http://" + server.getServiceId()).toURI());
assertEquals(server.getHost(), uri.getHost());
assertEquals(server.getPort(), uri.getPort());
assertEquals("https", uri.getScheme());
}
@Test
......@@ -166,8 +189,8 @@ public class RibbonLoadBalancerClientTests {
protected RibbonLoadBalancerClient getRibbonLoadBalancerClient(
RibbonServer ribbonServer) {
given(this.loadBalancer.getName()).willReturn(ribbonServer.getServiceId());
given(this.loadBalancer.chooseServer(anyString()))
.willReturn(ribbonServer.getServer());
given(this.loadBalancer.chooseServer(anyObject())).willReturn(
ribbonServer.getServer());
given(this.loadBalancer.getLoadBalancerStats())
.willReturn(this.loadBalancerStats);
given(this.loadBalancerStats.getSingleServerStat(ribbonServer.getServer()))
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment