Commit a43d9d86 by Dave Syer

Add null check in annotation processing for Feign

Allows POST with a @RequestBody Fixes gh-689
parent 57206c63
...@@ -33,12 +33,12 @@ import org.springframework.cloud.netflix.feign.annotation.RequestParamParameterP ...@@ -33,12 +33,12 @@ import org.springframework.cloud.netflix.feign.annotation.RequestParamParameterP
import org.springframework.util.Assert; import org.springframework.util.Assert;
import org.springframework.web.bind.annotation.RequestMapping; import org.springframework.web.bind.annotation.RequestMapping;
import feign.Contract;
import feign.MethodMetadata;
import static feign.Util.checkState; import static feign.Util.checkState;
import static feign.Util.emptyToNull; import static feign.Util.emptyToNull;
import feign.Contract;
import feign.MethodMetadata;
/** /**
* @author Spencer Gibb * @author Spencer Gibb
*/ */
...@@ -51,16 +51,19 @@ public class SpringMvcContract extends Contract.BaseContract { ...@@ -51,16 +51,19 @@ public class SpringMvcContract extends Contract.BaseContract {
private final Map<Class<? extends Annotation>, AnnotatedParameterProcessor> annotatedArgumentProcessors; private final Map<Class<? extends Annotation>, AnnotatedParameterProcessor> annotatedArgumentProcessors;
public SpringMvcContract() { public SpringMvcContract() {
this(Collections.<AnnotatedParameterProcessor>emptyList()); this(Collections.<AnnotatedParameterProcessor> emptyList());
} }
public SpringMvcContract(List<AnnotatedParameterProcessor> annotatedParameterProcessors) { public SpringMvcContract(
Assert.notNull(annotatedParameterProcessors, "Parameter processors can not be null."); List<AnnotatedParameterProcessor> annotatedParameterProcessors) {
Assert.notNull(annotatedParameterProcessors,
"Parameter processors can not be null.");
List<AnnotatedParameterProcessor> processors; List<AnnotatedParameterProcessor> processors;
if(!annotatedParameterProcessors.isEmpty()) { if (!annotatedParameterProcessors.isEmpty()) {
processors = new ArrayList<>(annotatedParameterProcessors); processors = new ArrayList<>(annotatedParameterProcessors);
} else { }
else {
processors = getDefaultAnnotatedArgumentsProcessors(); processors = getDefaultAnnotatedArgumentsProcessors();
} }
this.annotatedArgumentProcessors = toAnnotatedArgumentProcessorMap(processors); this.annotatedArgumentProcessors = toAnnotatedArgumentProcessorMap(processors);
...@@ -75,7 +78,8 @@ public class SpringMvcContract extends Contract.BaseContract { ...@@ -75,7 +78,8 @@ public class SpringMvcContract extends Contract.BaseContract {
// Prepend path from class annotation if specified // Prepend path from class annotation if specified
if (classAnnotation.value().length > 0) { if (classAnnotation.value().length > 0) {
String pathValue = emptyToNull(classAnnotation.value()[0]); String pathValue = emptyToNull(classAnnotation.value()[0]);
checkState(pathValue != null, "RequestMapping.value() was empty on type %s", checkState(pathValue != null,
"RequestMapping.value() was empty on type %s",
method.getDeclaringClass().getName()); method.getDeclaringClass().getName());
if (!pathValue.startsWith("/")) { if (!pathValue.startsWith("/")) {
pathValue = "/" + pathValue; pathValue = "/" + pathValue;
...@@ -84,16 +88,17 @@ public class SpringMvcContract extends Contract.BaseContract { ...@@ -84,16 +88,17 @@ public class SpringMvcContract extends Contract.BaseContract {
} }
// produces - use from class annotation only if method has not specified this // produces - use from class annotation only if method has not specified this
if(!md.template().headers().containsKey(ACCEPT)) { if (!md.template().headers().containsKey(ACCEPT)) {
parseProduces(md, method, classAnnotation); parseProduces(md, method, classAnnotation);
} }
// consumes -- use from class annotation only if method has not specified this // consumes -- use from class annotation only if method has not specified this
if(!md.template().headers().containsKey(CONTENT_TYPE)) { if (!md.template().headers().containsKey(CONTENT_TYPE)) {
parseConsumes(md, method, classAnnotation); parseConsumes(md, method, classAnnotation);
} }
// headers -- class annotation is inherited to methods, always write these if present // headers -- class annotation is inherited to methods, always write these if
// present
parseHeaders(md, method, classAnnotation); parseHeaders(md, method, classAnnotation);
} }
return md; return md;
...@@ -113,11 +118,12 @@ public class SpringMvcContract extends Contract.BaseContract { ...@@ -113,11 +118,12 @@ public class SpringMvcContract extends Contract.BaseContract {
// path // path
checkAtMostOne(method, methodMapping.value(), "value"); checkAtMostOne(method, methodMapping.value(), "value");
if(methodMapping.value().length > 0) { if (methodMapping.value().length > 0) {
String pathValue = emptyToNull(methodMapping.value()[0]); String pathValue = emptyToNull(methodMapping.value()[0]);
if (pathValue != null) { if (pathValue != null) {
// Append path from @RequestMapping if value is present on method // Append path from @RequestMapping if value is present on method
if (!pathValue.startsWith("/") && !data.template().toString().endsWith("/")) { if (!pathValue.startsWith("/")
&& !data.template().toString().endsWith("/")) {
pathValue = "/" + pathValue; pathValue = "/" + pathValue;
} }
data.template().append(pathValue); data.template().append(pathValue);
...@@ -134,7 +140,6 @@ public class SpringMvcContract extends Contract.BaseContract { ...@@ -134,7 +140,6 @@ public class SpringMvcContract extends Contract.BaseContract {
parseHeaders(data, method, methodMapping); parseHeaders(data, method, methodMapping);
} }
private void checkAtMostOne(Method method, Object[] values, String fieldName) { private void checkAtMostOne(Method method, Object[] values, String fieldName) {
checkState(values != null && (values.length == 0 || values.length == 1), checkState(values != null && (values.length == 0 || values.length == 1),
"Method %s can only contain at most 1 %s field. Found: %s", "Method %s can only contain at most 1 %s field. Found: %s",
...@@ -149,20 +154,25 @@ public class SpringMvcContract extends Contract.BaseContract { ...@@ -149,20 +154,25 @@ public class SpringMvcContract extends Contract.BaseContract {
} }
@Override @Override
protected boolean processAnnotationsOnParameter(MethodMetadata data, Annotation[] annotations, int paramIndex) { protected boolean processAnnotationsOnParameter(MethodMetadata data,
Annotation[] annotations, int paramIndex) {
boolean isHttpAnnotation = false; boolean isHttpAnnotation = false;
AnnotatedParameterProcessor.AnnotatedParameterContext context = AnnotatedParameterProcessor.AnnotatedParameterContext context = new SimpleAnnotatedParameterContext(
new SimpleAnnotatedParameterContext(data, paramIndex); data, paramIndex);
for (Annotation parameterAnnotation : annotations) { for (Annotation parameterAnnotation : annotations) {
AnnotatedParameterProcessor processor = AnnotatedParameterProcessor processor = this.annotatedArgumentProcessors
annotatedArgumentProcessors.get(parameterAnnotation.annotationType()); .get(parameterAnnotation.annotationType());
isHttpAnnotation |= processor.processArgument(context, parameterAnnotation); if (processor != null) {
isHttpAnnotation |= processor.processArgument(context,
parameterAnnotation);
}
} }
return isHttpAnnotation; return isHttpAnnotation;
} }
private void parseProduces(MethodMetadata md, Method method, RequestMapping annotation) { private void parseProduces(MethodMetadata md, Method method,
RequestMapping annotation) {
checkAtMostOne(method, annotation.produces(), "produces"); checkAtMostOne(method, annotation.produces(), "produces");
String[] serverProduces = annotation.produces(); String[] serverProduces = annotation.produces();
String clientAccepts = serverProduces.length == 0 ? null String clientAccepts = serverProduces.length == 0 ? null
...@@ -172,7 +182,8 @@ public class SpringMvcContract extends Contract.BaseContract { ...@@ -172,7 +182,8 @@ public class SpringMvcContract extends Contract.BaseContract {
} }
} }
private void parseConsumes(MethodMetadata md, Method method, RequestMapping annotation) { private void parseConsumes(MethodMetadata md, Method method,
RequestMapping annotation) {
checkAtMostOne(method, annotation.consumes(), "consumes"); checkAtMostOne(method, annotation.consumes(), "consumes");
String[] serverConsumes = annotation.consumes(); String[] serverConsumes = annotation.consumes();
String clientProduces = serverConsumes.length == 0 ? null String clientProduces = serverConsumes.length == 0 ? null
...@@ -182,7 +193,8 @@ public class SpringMvcContract extends Contract.BaseContract { ...@@ -182,7 +193,8 @@ public class SpringMvcContract extends Contract.BaseContract {
} }
} }
private void parseHeaders(MethodMetadata md, Method method, RequestMapping annotation) { private void parseHeaders(MethodMetadata md, Method method,
RequestMapping annotation) {
// TODO: only supports one header value per key // TODO: only supports one header value per key
if (annotation.headers() != null && annotation.headers().length > 0) { if (annotation.headers() != null && annotation.headers().length > 0) {
for (String header : annotation.headers()) { for (String header : annotation.headers()) {
...@@ -193,9 +205,10 @@ public class SpringMvcContract extends Contract.BaseContract { ...@@ -193,9 +205,10 @@ public class SpringMvcContract extends Contract.BaseContract {
} }
} }
private Map<Class<? extends Annotation>, AnnotatedParameterProcessor> toAnnotatedArgumentProcessorMap(List<AnnotatedParameterProcessor> processors) { private Map<Class<? extends Annotation>, AnnotatedParameterProcessor> toAnnotatedArgumentProcessorMap(
List<AnnotatedParameterProcessor> processors) {
Map<Class<? extends Annotation>, AnnotatedParameterProcessor> result = new HashMap<>(); Map<Class<? extends Annotation>, AnnotatedParameterProcessor> result = new HashMap<>();
for(AnnotatedParameterProcessor processor : processors) { for (AnnotatedParameterProcessor processor : processors) {
result.put(processor.getAnnotationType(), processor); result.put(processor.getAnnotationType(), processor);
} }
return result; return result;
...@@ -212,34 +225,37 @@ public class SpringMvcContract extends Contract.BaseContract { ...@@ -212,34 +225,37 @@ public class SpringMvcContract extends Contract.BaseContract {
return annotatedArgumentResolvers; return annotatedArgumentResolvers;
} }
private class SimpleAnnotatedParameterContext implements AnnotatedParameterProcessor.AnnotatedParameterContext { private class SimpleAnnotatedParameterContext
implements AnnotatedParameterProcessor.AnnotatedParameterContext {
private final MethodMetadata methodMetadata; private final MethodMetadata methodMetadata;
private final int parameterIndex; private final int parameterIndex;
public SimpleAnnotatedParameterContext(MethodMetadata methodMetadata, int parameterIndex) { public SimpleAnnotatedParameterContext(MethodMetadata methodMetadata,
int parameterIndex) {
this.methodMetadata = methodMetadata; this.methodMetadata = methodMetadata;
this.parameterIndex = parameterIndex; this.parameterIndex = parameterIndex;
} }
@Override @Override
public MethodMetadata getMethodMetadata() { public MethodMetadata getMethodMetadata() {
return methodMetadata; return this.methodMetadata;
} }
@Override @Override
public int getParameterIndex() { public int getParameterIndex() {
return parameterIndex; return this.parameterIndex;
} }
@Override @Override
public void setParameterName(String name) { public void setParameterName(String name) {
nameParam(methodMetadata, name, parameterIndex); nameParam(this.methodMetadata, name, this.parameterIndex);
} }
@Override @Override
public Collection<String> setTemplateParameter(String name, Collection<String> rest) { public Collection<String> setTemplateParameter(String name,
Collection<String> rest) {
return addTemplatedParam(rest, name); return addTemplatedParam(rest, name);
} }
} }
......
...@@ -8,6 +8,7 @@ import org.springframework.http.MediaType; ...@@ -8,6 +8,7 @@ import org.springframework.http.MediaType;
import org.springframework.http.ResponseEntity; import org.springframework.http.ResponseEntity;
import org.springframework.web.bind.annotation.ExceptionHandler; import org.springframework.web.bind.annotation.ExceptionHandler;
import org.springframework.web.bind.annotation.PathVariable; import org.springframework.web.bind.annotation.PathVariable;
import org.springframework.web.bind.annotation.RequestBody;
import org.springframework.web.bind.annotation.RequestHeader; import org.springframework.web.bind.annotation.RequestHeader;
import org.springframework.web.bind.annotation.RequestMapping; import org.springframework.web.bind.annotation.RequestMapping;
import org.springframework.web.bind.annotation.RequestMethod; import org.springframework.web.bind.annotation.RequestMethod;
...@@ -15,13 +16,13 @@ import org.springframework.web.bind.annotation.RequestParam; ...@@ -15,13 +16,13 @@ import org.springframework.web.bind.annotation.RequestParam;
import com.fasterxml.jackson.annotation.JsonAutoDetect; import com.fasterxml.jackson.annotation.JsonAutoDetect;
import static org.junit.Assert.assertEquals;
import feign.MethodMetadata; import feign.MethodMetadata;
import lombok.AllArgsConstructor; import lombok.AllArgsConstructor;
import lombok.NoArgsConstructor; import lombok.NoArgsConstructor;
import lombok.ToString; import lombok.ToString;
import static org.junit.Assert.assertEquals;
/** /**
* @author chadjaros * @author chadjaros
*/ */
...@@ -31,84 +32,118 @@ public class SpringMvcContractTest { ...@@ -31,84 +32,118 @@ public class SpringMvcContractTest {
@Before @Before
public void setup() { public void setup() {
contract = new SpringMvcContract(); this.contract = new SpringMvcContract();
} }
@Test @Test
public void testProcessAnnotationOnMethod_Simple() throws Exception { public void testProcessAnnotationOnMethod_Simple() throws Exception {
Method method = TestTemplate_Simple.class.getDeclaredMethod("getTest", String.class); Method method = TestTemplate_Simple.class.getDeclaredMethod("getTest",
MethodMetadata data = contract.parseAndValidateMetadata(method.getDeclaringClass(), method); String.class);
MethodMetadata data = this.contract
.parseAndValidateMetadata(method.getDeclaringClass(), method);
assertEquals("/test/{id}", data.template().url()); assertEquals("/test/{id}", data.template().url());
assertEquals("GET", data.template().method()); assertEquals("GET", data.template().method());
assertEquals(MediaType.APPLICATION_JSON_VALUE, data.template().headers().get("Accept").iterator().next()); assertEquals(MediaType.APPLICATION_JSON_VALUE,
data.template().headers().get("Accept").iterator().next());
} }
@Test @Test
public void testProcessAnnotations_Simple() throws Exception { public void testProcessAnnotations_Simple() throws Exception {
Method method = TestTemplate_Simple.class.getDeclaredMethod("getTest", String.class); Method method = TestTemplate_Simple.class.getDeclaredMethod("getTest",
MethodMetadata data = contract.parseAndValidateMetadata(method.getDeclaringClass(), method); String.class);
MethodMetadata data = this.contract
.parseAndValidateMetadata(method.getDeclaringClass(), method);
assertEquals("/test/{id}", data.template().url()); assertEquals("/test/{id}", data.template().url());
assertEquals("GET", data.template().method()); assertEquals("GET", data.template().method());
assertEquals(MediaType.APPLICATION_JSON_VALUE, data.template().headers().get("Accept").iterator().next()); assertEquals(MediaType.APPLICATION_JSON_VALUE,
data.template().headers().get("Accept").iterator().next());
assertEquals("id", data.indexToName().get(0).iterator().next()); assertEquals("id", data.indexToName().get(0).iterator().next());
} }
@Test @Test
public void testProcessAnnotations_SimplePost() throws Exception {
Method method = TestTemplate_Simple.class.getDeclaredMethod("postTest",
TestObject.class);
MethodMetadata data = this.contract
.parseAndValidateMetadata(method.getDeclaringClass(), method);
assertEquals("", data.template().url());
assertEquals("POST", data.template().method());
assertEquals(MediaType.APPLICATION_JSON_VALUE,
data.template().headers().get("Accept").iterator().next());
}
@Test
public void testProcessAnnotationsOnMethod_Advanced() throws Exception { public void testProcessAnnotationsOnMethod_Advanced() throws Exception {
Method method = TestTemplate_Advanced.class.getDeclaredMethod("getTest", String.class, String.class, Integer.class); Method method = TestTemplate_Advanced.class.getDeclaredMethod("getTest",
MethodMetadata data = contract.parseAndValidateMetadata(method.getDeclaringClass(), method); String.class, String.class, Integer.class);
MethodMetadata data = this.contract
.parseAndValidateMetadata(method.getDeclaringClass(), method);
assertEquals("/advanced/test/{id}", data.template().url()); assertEquals("/advanced/test/{id}", data.template().url());
assertEquals("PUT", data.template().method()); assertEquals("PUT", data.template().method());
assertEquals(MediaType.APPLICATION_JSON_VALUE, data.template().headers().get("Accept").iterator().next()); assertEquals(MediaType.APPLICATION_JSON_VALUE,
data.template().headers().get("Accept").iterator().next());
} }
@Test @Test
public void testProcessAnnotationsOnMethod_Advanced_UnknownAnnotation() throws Exception { public void testProcessAnnotationsOnMethod_Advanced_UnknownAnnotation()
Method method = TestTemplate_Advanced.class.getDeclaredMethod("getTest", String.class, String.class, Integer.class); throws Exception {
contract.parseAndValidateMetadata(method.getDeclaringClass(), method); Method method = TestTemplate_Advanced.class.getDeclaredMethod("getTest",
String.class, String.class, Integer.class);
this.contract.parseAndValidateMetadata(method.getDeclaringClass(), method);
// Don't throw an exception and this passes // Don't throw an exception and this passes
} }
@Test @Test
public void testProcessAnnotations_Advanced() throws Exception { public void testProcessAnnotations_Advanced() throws Exception {
Method method = TestTemplate_Advanced.class.getDeclaredMethod("getTest", String.class, String.class, Integer.class); Method method = TestTemplate_Advanced.class.getDeclaredMethod("getTest",
MethodMetadata data = contract.parseAndValidateMetadata(method.getDeclaringClass(), method); String.class, String.class, Integer.class);
MethodMetadata data = this.contract
.parseAndValidateMetadata(method.getDeclaringClass(), method);
assertEquals("/advanced/test/{id}", data.template().url()); assertEquals("/advanced/test/{id}", data.template().url());
assertEquals("PUT", data.template().method()); assertEquals("PUT", data.template().method());
assertEquals(MediaType.APPLICATION_JSON_VALUE, data.template().headers().get("Accept").iterator().next()); assertEquals(MediaType.APPLICATION_JSON_VALUE,
data.template().headers().get("Accept").iterator().next());
assertEquals("Authorization", data.indexToName().get(0).iterator().next()); assertEquals("Authorization", data.indexToName().get(0).iterator().next());
assertEquals("id", data.indexToName().get(1).iterator().next()); assertEquals("id", data.indexToName().get(1).iterator().next());
assertEquals("amount", data.indexToName().get(2).iterator().next()); assertEquals("amount", data.indexToName().get(2).iterator().next());
assertEquals("{Authorization}", data.template().headers().get("Authorization").iterator().next()); assertEquals("{Authorization}",
assertEquals("{amount}", data.template().queries().get("amount").iterator().next()); data.template().headers().get("Authorization").iterator().next());
assertEquals("{amount}",
data.template().queries().get("amount").iterator().next());
} }
@Test @Test
public void testProcessAnnotations_Advanced2() throws Exception { public void testProcessAnnotations_Advanced2() throws Exception {
Method method = TestTemplate_Advanced.class.getDeclaredMethod("getTest"); Method method = TestTemplate_Advanced.class.getDeclaredMethod("getTest");
MethodMetadata data = contract.parseAndValidateMetadata(method.getDeclaringClass(), method); MethodMetadata data = this.contract
.parseAndValidateMetadata(method.getDeclaringClass(), method);
assertEquals("/advanced", data.template().url()); assertEquals("/advanced", data.template().url());
assertEquals("GET", data.template().method()); assertEquals("GET", data.template().method());
assertEquals(MediaType.APPLICATION_JSON_VALUE, data.template().headers().get("Accept").iterator().next()); assertEquals(MediaType.APPLICATION_JSON_VALUE,
data.template().headers().get("Accept").iterator().next());
} }
@Test @Test
public void testProcessAnnotations_Advanced3() throws Exception { public void testProcessAnnotations_Advanced3() throws Exception {
Method method = TestTemplate_Simple.class.getDeclaredMethod("getTest"); Method method = TestTemplate_Simple.class.getDeclaredMethod("getTest");
MethodMetadata data = contract.parseAndValidateMetadata(method.getDeclaringClass(), method); MethodMetadata data = this.contract
.parseAndValidateMetadata(method.getDeclaringClass(), method);
assertEquals("", data.template().url()); assertEquals("", data.template().url());
assertEquals("GET", data.template().method()); assertEquals("GET", data.template().method());
assertEquals(MediaType.APPLICATION_JSON_VALUE, data.template().headers().get("Accept").iterator().next()); assertEquals(MediaType.APPLICATION_JSON_VALUE,
data.template().headers().get("Accept").iterator().next());
} }
public interface TestTemplate_Simple { public interface TestTemplate_Simple {
...@@ -117,6 +152,9 @@ public class SpringMvcContractTest { ...@@ -117,6 +152,9 @@ public class SpringMvcContractTest {
@RequestMapping(method = RequestMethod.GET, produces = MediaType.APPLICATION_JSON_VALUE) @RequestMapping(method = RequestMethod.GET, produces = MediaType.APPLICATION_JSON_VALUE)
TestObject getTest(); TestObject getTest();
@RequestMapping(method = RequestMethod.POST, produces = MediaType.APPLICATION_JSON_VALUE)
TestObject postTest(@RequestBody TestObject object);
} }
@JsonAutoDetect @JsonAutoDetect
...@@ -125,7 +163,8 @@ public class SpringMvcContractTest { ...@@ -125,7 +163,8 @@ public class SpringMvcContractTest {
@ExceptionHandler @ExceptionHandler
@RequestMapping(value = "/test/{id}", method = RequestMethod.PUT, produces = MediaType.APPLICATION_JSON_VALUE) @RequestMapping(value = "/test/{id}", method = RequestMethod.PUT, produces = MediaType.APPLICATION_JSON_VALUE)
ResponseEntity<TestObject> getTest(@RequestHeader("Authorization") String auth, @PathVariable("id") String id, @RequestParam("amount") Integer amount ); ResponseEntity<TestObject> getTest(@RequestHeader("Authorization") String auth,
@PathVariable("id") String id, @RequestParam("amount") Integer amount);
@RequestMapping(method = RequestMethod.GET, produces = MediaType.APPLICATION_JSON_VALUE) @RequestMapping(method = RequestMethod.GET, produces = MediaType.APPLICATION_JSON_VALUE)
TestObject getTest(); TestObject getTest();
...@@ -142,21 +181,31 @@ public class SpringMvcContractTest { ...@@ -142,21 +181,31 @@ public class SpringMvcContractTest {
@Override @Override
public boolean equals(Object o) { public boolean equals(Object o) {
if (this == o) return true; if (this == o) {
if (o == null || getClass() != o.getClass()) return false; return true;
}
if (o == null || getClass() != o.getClass()) {
return false;
}
TestObject that = (TestObject) o; TestObject that = (TestObject) o;
if (number != null ? !number.equals(that.number) : that.number != null) return false; if (this.number != null ? !this.number.equals(that.number)
if (something != null ? !something.equals(that.something) : that.something != null) return false; : that.number != null) {
return false;
}
if (this.something != null ? !this.something.equals(that.something)
: that.something != null) {
return false;
}
return true; return true;
} }
@Override @Override
public int hashCode() { public int hashCode() {
int result = (something != null ? something.hashCode() : 0); int result = (this.something != null ? this.something.hashCode() : 0);
result = 31 * result + (number != null ? number.hashCode() : 0); result = 31 * result + (this.number != null ? this.number.hashCode() : 0);
return result; return result;
} }
} }
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment