Skip to content

Commit

Permalink
Merge pull request #947 from Netflix/dgs-improvements
Browse files Browse the repository at this point in the history
Support @RequestHeader with HttpHeader/Map/MultiValueMap types.
  • Loading branch information
srinivasankavitha committed Mar 29, 2022
2 parents 12cc193 + cecd707 commit 0823675
Show file tree
Hide file tree
Showing 2 changed files with 109 additions and 0 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,8 @@ import kotlinx.coroutines.future.asCompletableFuture
import org.slf4j.LoggerFactory
import org.springframework.core.DefaultParameterNameDiscoverer
import org.springframework.core.annotation.AnnotationUtils
import org.springframework.http.HttpHeaders
import org.springframework.util.MultiValueMap
import org.springframework.util.ReflectionUtils
import org.springframework.web.bind.annotation.CookieValue
import org.springframework.web.bind.annotation.RequestHeader
Expand Down Expand Up @@ -174,6 +176,13 @@ class DataFetcherInvoker(
val annotation = AnnotationUtils.getAnnotation(parameter, RequestHeader::class.java)!!
val name: String = AnnotationUtils.getAnnotationAttributes(annotation)["name"] as String
val parameterName = name.ifBlank { parameterNames[idx] }

if (parameter.type.isAssignableFrom(Map::class.java)) {
return getValueAsOptional(requestData?.headers?.toSingleValueMap(), parameter)
} else if (parameter.type.isAssignableFrom(HttpHeaders::class.java) || parameter.type.isAssignableFrom(MultiValueMap::class.java)) {
return getValueAsOptional(requestData?.headers, parameter)
}

val value = requestData?.headers?.get(parameterName)?.let {
if (parameter.type.isAssignableFrom(List::class.java)) {
it
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,7 @@ import org.springframework.context.ApplicationContext
import org.springframework.http.HttpHeaders
import org.springframework.http.MediaType
import org.springframework.mock.web.MockMultipartFile
import org.springframework.util.MultiValueMap
import org.springframework.web.bind.annotation.RequestHeader
import org.springframework.web.bind.annotation.RequestParam
import org.springframework.web.context.request.WebRequest
Expand Down Expand Up @@ -1068,6 +1069,105 @@ internal class InputArgumentTest {
verify { applicationContextMock.getBeansWithAnnotation(DgsComponent::class.java) }
}

@Test
fun `A @RequestHeader argument with map should be supported`() {
val fetcher = object : Any() {
@DgsData(parentType = "Query", field = "hello")
fun someFetcher(@RequestHeader headers: Map<String, String>): String {
val header = headers.get("Referer")
return "From, $header"
}
}

every { applicationContextMock.getBeansWithAnnotation(DgsComponent::class.java) } returns mapOf(
Pair(
"helloFetcher",
fetcher
)
)
every { applicationContextMock.getBeansWithAnnotation(DgsScalar::class.java) } returns emptyMap()
every { applicationContextMock.getBeansWithAnnotation(DgsDirective::class.java) } returns emptyMap()

val provider = DgsSchemaProvider(applicationContextMock, Optional.empty(), Optional.empty(), Optional.empty())
val schema = provider.schema()

val build = GraphQL.newGraphQL(schema).build()
val httpHeaders = HttpHeaders()
httpHeaders.add("Referer", "localhost")
val executionResult = build.execute(ExecutionInput.newExecutionInput("""{hello}""").context(DgsContext(null, DgsWebMvcRequestData(emptyMap(), httpHeaders))))
Assertions.assertTrue(executionResult.isDataPresent)
val data = executionResult.getData<Map<String, *>>()
Assertions.assertEquals("From, localhost", data["hello"])

verify { applicationContextMock.getBeansWithAnnotation(DgsComponent::class.java) }
}

@Test
fun `A @RequestHeader argument with multi-value map should be supported`() {
val fetcher = object : Any() {
@DgsData(parentType = "Query", field = "hello")
fun someFetcher(@RequestHeader headers: MultiValueMap<String, String>): String {
val header = headers.getFirst("Referer")
return "From, $header"
}
}

every { applicationContextMock.getBeansWithAnnotation(DgsComponent::class.java) } returns mapOf(
Pair(
"helloFetcher",
fetcher
)
)
every { applicationContextMock.getBeansWithAnnotation(DgsScalar::class.java) } returns emptyMap()
every { applicationContextMock.getBeansWithAnnotation(DgsDirective::class.java) } returns emptyMap()

val provider = DgsSchemaProvider(applicationContextMock, Optional.empty(), Optional.empty(), Optional.empty())
val schema = provider.schema()

val build = GraphQL.newGraphQL(schema).build()
val httpHeaders = HttpHeaders()
httpHeaders.add("Referer", "localhost")
val executionResult = build.execute(ExecutionInput.newExecutionInput("""{hello}""").context(DgsContext(null, DgsWebMvcRequestData(emptyMap(), httpHeaders))))
Assertions.assertTrue(executionResult.isDataPresent)
val data = executionResult.getData<Map<String, *>>()
Assertions.assertEquals("From, localhost", data["hello"])

verify { applicationContextMock.getBeansWithAnnotation(DgsComponent::class.java) }
}

@Test
fun `A @RequestHeader argument with HttpHeaders should be supported`() {
val fetcher = object : Any() {
@DgsData(parentType = "Query", field = "hello")
fun someFetcher(@RequestHeader headers: HttpHeaders): String {
val header = headers.getFirst("Referer")
return "From, $header"
}
}

every { applicationContextMock.getBeansWithAnnotation(DgsComponent::class.java) } returns mapOf(
Pair(
"helloFetcher",
fetcher
)
)
every { applicationContextMock.getBeansWithAnnotation(DgsScalar::class.java) } returns emptyMap()
every { applicationContextMock.getBeansWithAnnotation(DgsDirective::class.java) } returns emptyMap()

val provider = DgsSchemaProvider(applicationContextMock, Optional.empty(), Optional.empty(), Optional.empty())
val schema = provider.schema()

val build = GraphQL.newGraphQL(schema).build()
val httpHeaders = HttpHeaders()
httpHeaders.add("Referer", "localhost")
val executionResult = build.execute(ExecutionInput.newExecutionInput("""{hello}""").context(DgsContext(null, DgsWebMvcRequestData(emptyMap(), httpHeaders))))
Assertions.assertTrue(executionResult.isDataPresent)
val data = executionResult.getData<Map<String, *>>()
Assertions.assertEquals("From, localhost", data["hello"])

verify { applicationContextMock.getBeansWithAnnotation(DgsComponent::class.java) }
}

@Test
fun `A @RequestHeader argument with name should be supported`() {
val fetcher = object : Any() {
Expand Down

0 comments on commit 0823675

Please sign in to comment.