Commit a43d9d86 by Dave Syer

Add null check in annotation processing for Feign

Allows POST with a @RequestBody Fixes gh-689
parent 57206c63
......@@ -33,12 +33,12 @@ import org.springframework.cloud.netflix.feign.annotation.RequestParamParameterP
import org.springframework.util.Assert;
import org.springframework.web.bind.annotation.RequestMapping;
import feign.Contract;
import feign.MethodMetadata;
import static feign.Util.checkState;
import static feign.Util.emptyToNull;
import feign.Contract;
import feign.MethodMetadata;
/**
* @author Spencer Gibb
*/
......@@ -51,16 +51,19 @@ public class SpringMvcContract extends Contract.BaseContract {
private final Map<Class<? extends Annotation>, AnnotatedParameterProcessor> annotatedArgumentProcessors;
public SpringMvcContract() {
this(Collections.<AnnotatedParameterProcessor>emptyList());
this(Collections.<AnnotatedParameterProcessor> emptyList());
}
public SpringMvcContract(List<AnnotatedParameterProcessor> annotatedParameterProcessors) {
Assert.notNull(annotatedParameterProcessors, "Parameter processors can not be null.");
public SpringMvcContract(
List<AnnotatedParameterProcessor> annotatedParameterProcessors) {
Assert.notNull(annotatedParameterProcessors,
"Parameter processors can not be null.");
List<AnnotatedParameterProcessor> processors;
if(!annotatedParameterProcessors.isEmpty()) {
if (!annotatedParameterProcessors.isEmpty()) {
processors = new ArrayList<>(annotatedParameterProcessors);
} else {
}
else {
processors = getDefaultAnnotatedArgumentsProcessors();
}
this.annotatedArgumentProcessors = toAnnotatedArgumentProcessorMap(processors);
......@@ -75,7 +78,8 @@ public class SpringMvcContract extends Contract.BaseContract {
// Prepend path from class annotation if specified
if (classAnnotation.value().length > 0) {
String pathValue = emptyToNull(classAnnotation.value()[0]);
checkState(pathValue != null, "RequestMapping.value() was empty on type %s",
checkState(pathValue != null,
"RequestMapping.value() was empty on type %s",
method.getDeclaringClass().getName());
if (!pathValue.startsWith("/")) {
pathValue = "/" + pathValue;
......@@ -84,16 +88,17 @@ public class SpringMvcContract extends Contract.BaseContract {
}
// produces - use from class annotation only if method has not specified this
if(!md.template().headers().containsKey(ACCEPT)) {
if (!md.template().headers().containsKey(ACCEPT)) {
parseProduces(md, method, classAnnotation);
}
// consumes -- use from class annotation only if method has not specified this
if(!md.template().headers().containsKey(CONTENT_TYPE)) {
if (!md.template().headers().containsKey(CONTENT_TYPE)) {
parseConsumes(md, method, classAnnotation);
}
// headers -- class annotation is inherited to methods, always write these if present
// headers -- class annotation is inherited to methods, always write these if
// present
parseHeaders(md, method, classAnnotation);
}
return md;
......@@ -113,11 +118,12 @@ public class SpringMvcContract extends Contract.BaseContract {
// path
checkAtMostOne(method, methodMapping.value(), "value");
if(methodMapping.value().length > 0) {
if (methodMapping.value().length > 0) {
String pathValue = emptyToNull(methodMapping.value()[0]);
if (pathValue != null) {
// Append path from @RequestMapping if value is present on method
if (!pathValue.startsWith("/") && !data.template().toString().endsWith("/")) {
if (!pathValue.startsWith("/")
&& !data.template().toString().endsWith("/")) {
pathValue = "/" + pathValue;
}
data.template().append(pathValue);
......@@ -134,7 +140,6 @@ public class SpringMvcContract extends Contract.BaseContract {
parseHeaders(data, method, methodMapping);
}
private void checkAtMostOne(Method method, Object[] values, String fieldName) {
checkState(values != null && (values.length == 0 || values.length == 1),
"Method %s can only contain at most 1 %s field. Found: %s",
......@@ -149,20 +154,25 @@ public class SpringMvcContract extends Contract.BaseContract {
}
@Override
protected boolean processAnnotationsOnParameter(MethodMetadata data, Annotation[] annotations, int paramIndex) {
protected boolean processAnnotationsOnParameter(MethodMetadata data,
Annotation[] annotations, int paramIndex) {
boolean isHttpAnnotation = false;
AnnotatedParameterProcessor.AnnotatedParameterContext context =
new SimpleAnnotatedParameterContext(data, paramIndex);
AnnotatedParameterProcessor.AnnotatedParameterContext context = new SimpleAnnotatedParameterContext(
data, paramIndex);
for (Annotation parameterAnnotation : annotations) {
AnnotatedParameterProcessor processor =
annotatedArgumentProcessors.get(parameterAnnotation.annotationType());
isHttpAnnotation |= processor.processArgument(context, parameterAnnotation);
AnnotatedParameterProcessor processor = this.annotatedArgumentProcessors
.get(parameterAnnotation.annotationType());
if (processor != null) {
isHttpAnnotation |= processor.processArgument(context,
parameterAnnotation);
}
}
return isHttpAnnotation;
}
private void parseProduces(MethodMetadata md, Method method, RequestMapping annotation) {
private void parseProduces(MethodMetadata md, Method method,
RequestMapping annotation) {
checkAtMostOne(method, annotation.produces(), "produces");
String[] serverProduces = annotation.produces();
String clientAccepts = serverProduces.length == 0 ? null
......@@ -172,7 +182,8 @@ public class SpringMvcContract extends Contract.BaseContract {
}
}
private void parseConsumes(MethodMetadata md, Method method, RequestMapping annotation) {
private void parseConsumes(MethodMetadata md, Method method,
RequestMapping annotation) {
checkAtMostOne(method, annotation.consumes(), "consumes");
String[] serverConsumes = annotation.consumes();
String clientProduces = serverConsumes.length == 0 ? null
......@@ -182,7 +193,8 @@ public class SpringMvcContract extends Contract.BaseContract {
}
}
private void parseHeaders(MethodMetadata md, Method method, RequestMapping annotation) {
private void parseHeaders(MethodMetadata md, Method method,
RequestMapping annotation) {
// TODO: only supports one header value per key
if (annotation.headers() != null && annotation.headers().length > 0) {
for (String header : annotation.headers()) {
......@@ -193,9 +205,10 @@ public class SpringMvcContract extends Contract.BaseContract {
}
}
private Map<Class<? extends Annotation>, AnnotatedParameterProcessor> toAnnotatedArgumentProcessorMap(List<AnnotatedParameterProcessor> processors) {
private Map<Class<? extends Annotation>, AnnotatedParameterProcessor> toAnnotatedArgumentProcessorMap(
List<AnnotatedParameterProcessor> processors) {
Map<Class<? extends Annotation>, AnnotatedParameterProcessor> result = new HashMap<>();
for(AnnotatedParameterProcessor processor : processors) {
for (AnnotatedParameterProcessor processor : processors) {
result.put(processor.getAnnotationType(), processor);
}
return result;
......@@ -212,34 +225,37 @@ public class SpringMvcContract extends Contract.BaseContract {
return annotatedArgumentResolvers;
}
private class SimpleAnnotatedParameterContext implements AnnotatedParameterProcessor.AnnotatedParameterContext {
private class SimpleAnnotatedParameterContext
implements AnnotatedParameterProcessor.AnnotatedParameterContext {
private final MethodMetadata methodMetadata;
private final int parameterIndex;
public SimpleAnnotatedParameterContext(MethodMetadata methodMetadata, int parameterIndex) {
public SimpleAnnotatedParameterContext(MethodMetadata methodMetadata,
int parameterIndex) {
this.methodMetadata = methodMetadata;
this.parameterIndex = parameterIndex;
}
@Override
public MethodMetadata getMethodMetadata() {
return methodMetadata;
return this.methodMetadata;
}
@Override
public int getParameterIndex() {
return parameterIndex;
return this.parameterIndex;
}
@Override
public void setParameterName(String name) {
nameParam(methodMetadata, name, parameterIndex);
nameParam(this.methodMetadata, name, this.parameterIndex);
}
@Override
public Collection<String> setTemplateParameter(String name, Collection<String> rest) {
public Collection<String> setTemplateParameter(String name,
Collection<String> rest) {
return addTemplatedParam(rest, name);
}
}
......
......@@ -8,6 +8,7 @@ import org.springframework.http.MediaType;
import org.springframework.http.ResponseEntity;
import org.springframework.web.bind.annotation.ExceptionHandler;
import org.springframework.web.bind.annotation.PathVariable;
import org.springframework.web.bind.annotation.RequestBody;
import org.springframework.web.bind.annotation.RequestHeader;
import org.springframework.web.bind.annotation.RequestMapping;
import org.springframework.web.bind.annotation.RequestMethod;
......@@ -15,13 +16,13 @@ import org.springframework.web.bind.annotation.RequestParam;
import com.fasterxml.jackson.annotation.JsonAutoDetect;
import static org.junit.Assert.assertEquals;
import feign.MethodMetadata;
import lombok.AllArgsConstructor;
import lombok.NoArgsConstructor;
import lombok.ToString;
import static org.junit.Assert.assertEquals;
/**
* @author chadjaros
*/
......@@ -31,84 +32,118 @@ public class SpringMvcContractTest {
@Before
public void setup() {
contract = new SpringMvcContract();
this.contract = new SpringMvcContract();
}
@Test
public void testProcessAnnotationOnMethod_Simple() throws Exception {
Method method = TestTemplate_Simple.class.getDeclaredMethod("getTest", String.class);
MethodMetadata data = contract.parseAndValidateMetadata(method.getDeclaringClass(), method);
Method method = TestTemplate_Simple.class.getDeclaredMethod("getTest",
String.class);
MethodMetadata data = this.contract
.parseAndValidateMetadata(method.getDeclaringClass(), method);
assertEquals("/test/{id}", data.template().url());
assertEquals("GET", data.template().method());
assertEquals(MediaType.APPLICATION_JSON_VALUE, data.template().headers().get("Accept").iterator().next());
assertEquals(MediaType.APPLICATION_JSON_VALUE,
data.template().headers().get("Accept").iterator().next());
}
@Test
public void testProcessAnnotations_Simple() throws Exception {
Method method = TestTemplate_Simple.class.getDeclaredMethod("getTest", String.class);
MethodMetadata data = contract.parseAndValidateMetadata(method.getDeclaringClass(), method);
Method method = TestTemplate_Simple.class.getDeclaredMethod("getTest",
String.class);
MethodMetadata data = this.contract
.parseAndValidateMetadata(method.getDeclaringClass(), method);
assertEquals("/test/{id}", data.template().url());
assertEquals("GET", data.template().method());
assertEquals(MediaType.APPLICATION_JSON_VALUE, data.template().headers().get("Accept").iterator().next());
assertEquals(MediaType.APPLICATION_JSON_VALUE,
data.template().headers().get("Accept").iterator().next());
assertEquals("id", data.indexToName().get(0).iterator().next());
}
@Test
public void testProcessAnnotations_SimplePost() throws Exception {
Method method = TestTemplate_Simple.class.getDeclaredMethod("postTest",
TestObject.class);
MethodMetadata data = this.contract
.parseAndValidateMetadata(method.getDeclaringClass(), method);
assertEquals("", data.template().url());
assertEquals("POST", data.template().method());
assertEquals(MediaType.APPLICATION_JSON_VALUE,
data.template().headers().get("Accept").iterator().next());
}
@Test
public void testProcessAnnotationsOnMethod_Advanced() throws Exception {
Method method = TestTemplate_Advanced.class.getDeclaredMethod("getTest", String.class, String.class, Integer.class);
MethodMetadata data = contract.parseAndValidateMetadata(method.getDeclaringClass(), method);
Method method = TestTemplate_Advanced.class.getDeclaredMethod("getTest",
String.class, String.class, Integer.class);
MethodMetadata data = this.contract
.parseAndValidateMetadata(method.getDeclaringClass(), method);
assertEquals("/advanced/test/{id}", data.template().url());
assertEquals("PUT", data.template().method());
assertEquals(MediaType.APPLICATION_JSON_VALUE, data.template().headers().get("Accept").iterator().next());
assertEquals(MediaType.APPLICATION_JSON_VALUE,
data.template().headers().get("Accept").iterator().next());
}
@Test
public void testProcessAnnotationsOnMethod_Advanced_UnknownAnnotation() throws Exception {
Method method = TestTemplate_Advanced.class.getDeclaredMethod("getTest", String.class, String.class, Integer.class);
contract.parseAndValidateMetadata(method.getDeclaringClass(), method);
public void testProcessAnnotationsOnMethod_Advanced_UnknownAnnotation()
throws Exception {
Method method = TestTemplate_Advanced.class.getDeclaredMethod("getTest",
String.class, String.class, Integer.class);
this.contract.parseAndValidateMetadata(method.getDeclaringClass(), method);
// Don't throw an exception and this passes
}
@Test
public void testProcessAnnotations_Advanced() throws Exception {
Method method = TestTemplate_Advanced.class.getDeclaredMethod("getTest", String.class, String.class, Integer.class);
MethodMetadata data = contract.parseAndValidateMetadata(method.getDeclaringClass(), method);
Method method = TestTemplate_Advanced.class.getDeclaredMethod("getTest",
String.class, String.class, Integer.class);
MethodMetadata data = this.contract
.parseAndValidateMetadata(method.getDeclaringClass(), method);
assertEquals("/advanced/test/{id}", data.template().url());
assertEquals("PUT", data.template().method());
assertEquals(MediaType.APPLICATION_JSON_VALUE, data.template().headers().get("Accept").iterator().next());
assertEquals(MediaType.APPLICATION_JSON_VALUE,
data.template().headers().get("Accept").iterator().next());
assertEquals("Authorization", data.indexToName().get(0).iterator().next());
assertEquals("id", data.indexToName().get(1).iterator().next());
assertEquals("amount", data.indexToName().get(2).iterator().next());
assertEquals("{Authorization}", data.template().headers().get("Authorization").iterator().next());
assertEquals("{amount}", data.template().queries().get("amount").iterator().next());
assertEquals("{Authorization}",
data.template().headers().get("Authorization").iterator().next());
assertEquals("{amount}",
data.template().queries().get("amount").iterator().next());
}
@Test
public void testProcessAnnotations_Advanced2() throws Exception {
Method method = TestTemplate_Advanced.class.getDeclaredMethod("getTest");
MethodMetadata data = contract.parseAndValidateMetadata(method.getDeclaringClass(), method);
MethodMetadata data = this.contract
.parseAndValidateMetadata(method.getDeclaringClass(), method);
assertEquals("/advanced", data.template().url());
assertEquals("GET", data.template().method());
assertEquals(MediaType.APPLICATION_JSON_VALUE, data.template().headers().get("Accept").iterator().next());
assertEquals(MediaType.APPLICATION_JSON_VALUE,
data.template().headers().get("Accept").iterator().next());
}
@Test
public void testProcessAnnotations_Advanced3() throws Exception {
Method method = TestTemplate_Simple.class.getDeclaredMethod("getTest");
MethodMetadata data = contract.parseAndValidateMetadata(method.getDeclaringClass(), method);
MethodMetadata data = this.contract
.parseAndValidateMetadata(method.getDeclaringClass(), method);
assertEquals("", data.template().url());
assertEquals("GET", data.template().method());
assertEquals(MediaType.APPLICATION_JSON_VALUE, data.template().headers().get("Accept").iterator().next());
assertEquals(MediaType.APPLICATION_JSON_VALUE,
data.template().headers().get("Accept").iterator().next());
}
public interface TestTemplate_Simple {
......@@ -117,6 +152,9 @@ public class SpringMvcContractTest {
@RequestMapping(method = RequestMethod.GET, produces = MediaType.APPLICATION_JSON_VALUE)
TestObject getTest();
@RequestMapping(method = RequestMethod.POST, produces = MediaType.APPLICATION_JSON_VALUE)
TestObject postTest(@RequestBody TestObject object);
}
@JsonAutoDetect
......@@ -125,7 +163,8 @@ public class SpringMvcContractTest {
@ExceptionHandler
@RequestMapping(value = "/test/{id}", method = RequestMethod.PUT, produces = MediaType.APPLICATION_JSON_VALUE)
ResponseEntity<TestObject> getTest(@RequestHeader("Authorization") String auth, @PathVariable("id") String id, @RequestParam("amount") Integer amount );
ResponseEntity<TestObject> getTest(@RequestHeader("Authorization") String auth,
@PathVariable("id") String id, @RequestParam("amount") Integer amount);
@RequestMapping(method = RequestMethod.GET, produces = MediaType.APPLICATION_JSON_VALUE)
TestObject getTest();
......@@ -142,21 +181,31 @@ public class SpringMvcContractTest {
@Override
public boolean equals(Object o) {
if (this == o) return true;
if (o == null || getClass() != o.getClass()) return false;
if (this == o) {
return true;
}
if (o == null || getClass() != o.getClass()) {
return false;
}
TestObject that = (TestObject) o;
if (number != null ? !number.equals(that.number) : that.number != null) return false;
if (something != null ? !something.equals(that.something) : that.something != null) return false;
if (this.number != null ? !this.number.equals(that.number)
: that.number != null) {
return false;
}
if (this.something != null ? !this.something.equals(that.something)
: that.something != null) {
return false;
}
return true;
}
@Override
public int hashCode() {
int result = (something != null ? something.hashCode() : 0);
result = 31 * result + (number != null ? number.hashCode() : 0);
int result = (this.something != null ? this.something.hashCode() : 0);
result = 31 * result + (this.number != null ? this.number.hashCode() : 0);
return result;
}
}
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment