Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add tests for multiple routes #7563

Merged
merged 3 commits into from Dec 24, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
1 change: 1 addition & 0 deletions mockwebserver/api/mockwebserver3.api
Expand Up @@ -86,6 +86,7 @@ public final class mockwebserver3/MockWebServer : java/io/Closeable {
public final fun getBodyLimit ()J
public final fun getDispatcher ()Lmockwebserver3/Dispatcher;
public final fun getHostName ()Ljava/lang/String;
public final fun getInetSocketAddress ()Ljava/net/InetSocketAddress;
public final fun getPort ()I
public final fun getProtocolNegotiationEnabled ()Z
public final fun getRequestCount ()I
Expand Down
14 changes: 10 additions & 4 deletions mockwebserver/src/main/kotlin/mockwebserver3/MockWebServer.kt
Expand Up @@ -152,10 +152,16 @@ class MockWebServer : Closeable {
val hostName: String
get() {
before()
return inetSocketAddress!!.address.canonicalHostName
return _inetSocketAddress!!.address.canonicalHostName
}

private var inetSocketAddress: InetSocketAddress? = null
private var _inetSocketAddress: InetSocketAddress? = null

val inetSocketAddress: InetSocketAddress
get() {
before()
return InetSocketAddress(hostName, portField)
}

/**
* True if ALPN is used on incoming HTTPS connections to negotiate a protocol like HTTP/1.1 or
Expand Down Expand Up @@ -206,7 +212,7 @@ class MockWebServer : Closeable {

fun toProxyAddress(): Proxy {
before()
val address = InetSocketAddress(inetSocketAddress!!.address.canonicalHostName, port)
val address = InetSocketAddress(_inetSocketAddress!!.address.canonicalHostName, port)
return Proxy(Proxy.Type.HTTP, address)
}

Expand Down Expand Up @@ -383,7 +389,7 @@ class MockWebServer : Closeable {
if (started) return
started = true

this.inetSocketAddress = inetSocketAddress
this._inetSocketAddress = inetSocketAddress

serverSocket = serverSocketFactory!!.createServerSocket()

Expand Down
@@ -0,0 +1,48 @@
/*
* Copyright (C) 2022 Square, Inc.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package okhttp3

import java.net.InetAddress
import java.net.InetSocketAddress
import java.net.Socket
import java.net.SocketAddress

/**
* A [SocketFactory] that redirects connections to [defaultAddress] or specific overridden address via [set].
*/
class SpecificHostSocketFactory(
val defaultAddress: InetSocketAddress?
) : DelegatingSocketFactory(getDefault()) {
private val hostMapping = mutableMapOf<InetAddress, InetSocketAddress>()

/** Sets the [real] address for [requested]. */
operator fun set(
requested: InetAddress,
real: InetSocketAddress
) {
hostMapping[requested] = real
}

override fun createSocket(): Socket {
return object : Socket() {
override fun connect(endpoint: SocketAddress?, timeout: Int) {
val requested = (endpoint as InetSocketAddress)
val inetSocketAddress = hostMapping[requested.address] ?: defaultAddress ?: requested
super.connect(inetSocketAddress, timeout)
}
}
}
}
Expand Up @@ -24,18 +24,22 @@ import okhttp3.Route
* preferred.
*/
class RouteDatabase {
private val failedRoutes = mutableSetOf<Route>()
private val _failedRoutes = mutableSetOf<Route>()

val failedRoutes: Set<Route>
@Synchronized get() = _failedRoutes.toSet()


/** Records a failure connecting to [failedRoute]. */
@Synchronized fun failed(failedRoute: Route) {
failedRoutes.add(failedRoute)
_failedRoutes.add(failedRoute)
}

/** Records success connecting to [route]. */
@Synchronized fun connected(route: Route) {
failedRoutes.remove(route)
_failedRoutes.remove(route)
}

/** Returns true if [route] has failed recently and should be avoided. */
@Synchronized fun shouldPostpone(route: Route): Boolean = route in failedRoutes
@Synchronized fun shouldPostpone(route: Route): Boolean = route in _failedRoutes
}
275 changes: 275 additions & 0 deletions okhttp/src/jvmTest/java/okhttp3/RouteFailureTest.kt
@@ -0,0 +1,275 @@
/*
* Copyright (C) 2022 Square, Inc.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package okhttp3
yschimke marked this conversation as resolved.
Show resolved Hide resolved

import java.io.IOException
import java.net.InetAddress
import java.net.InetSocketAddress
import mockwebserver3.MockResponse
import mockwebserver3.MockWebServer
import mockwebserver3.SocketPolicy
import mockwebserver3.junit5.internal.MockWebServerInstance
import okhttp3.internal.http2.ErrorCode
import okhttp3.testing.PlatformRule
import okhttp3.tls.internal.TlsUtil.localhost
import org.assertj.core.api.Assertions.assertThat
import org.junit.jupiter.api.BeforeEach
import org.junit.jupiter.api.Test
import org.junit.jupiter.api.extension.RegisterExtension

class RouteFailureTest {
private lateinit var socketFactory: SpecificHostSocketFactory
private lateinit var client: OkHttpClient

@RegisterExtension
val platform = PlatformRule()

@RegisterExtension
val clientTestRule = OkHttpClientTestRule()

private lateinit var server1: MockWebServer
private lateinit var server2: MockWebServer

private var listener = RecordingEventListener()

private val handshakeCertificates = localhost()

val dns = FakeDns()

val ipv4 = InetAddress.getByName("203.0.113.1")
val ipv6 = InetAddress.getByName("2001:db8:ffff:ffff:ffff:ffff:ffff:1")

val refusedStream = MockResponse()
.setHttp2ErrorCode(ErrorCode.REFUSED_STREAM.httpCode)
.setSocketPolicy(SocketPolicy.RESET_STREAM_AT_START)
val bodyResponse = MockResponse().setBody("body")

@BeforeEach
fun setUp(
server: MockWebServer,
@MockWebServerInstance("server2") server2: MockWebServer
) {
this.server1 = server
this.server2 = server2

socketFactory = SpecificHostSocketFactory(InetSocketAddress(server.hostName, server.port))

client = clientTestRule.newClientBuilder()
.dns(dns)
.socketFactory(socketFactory)
.eventListenerFactory(clientTestRule.wrap(listener))
.build()
}

@Test
fun http2OneBadHostOneGoodNoRetryOnConnectionFailure() {
enableProtocol(Protocol.HTTP_2)

val request = Request(server1.url("/"))

server1.enqueue(refusedStream)
server2.enqueue(bodyResponse)

dns[server1.hostName] = listOf(ipv6, ipv4)
socketFactory[ipv6] = server1.inetSocketAddress
socketFactory[ipv4] = server2.inetSocketAddress

client = client.newBuilder()
.fastFallback(false)
.apply {
retryOnConnectionFailure = false
}
.build()

executeSynchronously(request)
.assertFailureMatches("stream was reset: REFUSED_STREAM")

assertThat(client.routeDatabase.failedRoutes).isEmpty()
assertThat(server1.requestCount).isEqualTo(1)
assertThat(server2.requestCount).isEqualTo(0)
}

@Test
fun http2OneBadHostOneGoodRetryOnConnectionFailure() {
enableProtocol(Protocol.HTTP_2)

val request = Request(server1.url("/"))

server1.enqueue(refusedStream)
server1.enqueue(refusedStream)
server2.enqueue(bodyResponse)

dns[server1.hostName] = listOf(ipv6, ipv4)
socketFactory[ipv6] = server1.inetSocketAddress
socketFactory[ipv4] = server2.inetSocketAddress

client = client.newBuilder()
.fastFallback(false)
.apply {
retryOnConnectionFailure = true
}
.build()

executeSynchronously(request)
.assertBody("body")

assertThat(client.routeDatabase.failedRoutes).isEmpty()
// TODO check if we expect a second request to server1, before attempting server2
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yep

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

(I think that’s the behavior we should test for, though we could create an issue and change that)

assertThat(server1.requestCount).isEqualTo(2)
assertThat(server2.requestCount).isEqualTo(1)
}

@Test
fun http2OneBadHostOneGoodNoRetryOnConnectionFailureFastFallback() {
enableProtocol(Protocol.HTTP_2)

val request = Request(server1.url("/"))

server1.enqueue(refusedStream)
server2.enqueue(bodyResponse)

dns[server1.hostName] = listOf(ipv6, ipv4)
socketFactory[ipv6] = server1.inetSocketAddress
socketFactory[ipv4] = server2.inetSocketAddress

client = client.newBuilder()
.fastFallback(true)
.apply {
retryOnConnectionFailure = false
}
.build()

executeSynchronously(request)
.assertFailureMatches("stream was reset: REFUSED_STREAM")

assertThat(client.routeDatabase.failedRoutes).isEmpty()
assertThat(server1.requestCount).isEqualTo(1)
assertThat(server2.requestCount).isEqualTo(0)
}

@Test
fun http2OneBadHostOneGoodRetryOnConnectionFailureFastFallback() {
enableProtocol(Protocol.HTTP_2)

val request = Request(server1.url("/"))

server1.enqueue(refusedStream)
server1.enqueue(refusedStream)
server2.enqueue(bodyResponse)

dns[server1.hostName] = listOf(ipv6, ipv4)
socketFactory[ipv6] = server1.inetSocketAddress
socketFactory[ipv4] = server2.inetSocketAddress

client = client.newBuilder()
.fastFallback(true)
.apply {
retryOnConnectionFailure = true
}
.build()

executeSynchronously(request)
.assertBody("body")

assertThat(client.routeDatabase.failedRoutes).isEmpty()
// TODO check if we expect a second request to server1, before attempting server2
assertThat(server1.requestCount).isEqualTo(2)
assertThat(server2.requestCount).isEqualTo(1)
}

@Test
fun http2OneBadHostRetryOnConnectionFailure() {
enableProtocol(Protocol.HTTP_2)

val request = Request(server1.url("/"))

server1.enqueue(refusedStream)
server1.enqueue(refusedStream)

dns[server1.hostName] = listOf(ipv6)
socketFactory[ipv6] = server1.inetSocketAddress

client = client.newBuilder()
.fastFallback(false)
.apply {
retryOnConnectionFailure = true
}
.build()

executeSynchronously(request)
.assertFailureMatches("stream was reset: REFUSED_STREAM")

assertThat(client.routeDatabase.failedRoutes).isEmpty()
assertThat(server1.requestCount).isEqualTo(1)
}

@Test
fun http2OneBadHostRetryOnConnectionFailureFastFallback() {
enableProtocol(Protocol.HTTP_2)

val request = Request(server1.url("/"))

server1.enqueue(refusedStream)
server1.enqueue(refusedStream)

dns[server1.hostName] = listOf(ipv6)
socketFactory[ipv6] = server1.inetSocketAddress

client = client.newBuilder()
.fastFallback(true)
.apply {
retryOnConnectionFailure = true
}
.build()

executeSynchronously(request)
.assertFailureMatches("stream was reset: REFUSED_STREAM")

assertThat(client.routeDatabase.failedRoutes).isEmpty()
assertThat(server1.requestCount).isEqualTo(1)
}

private fun enableProtocol(protocol: Protocol) {
enableTls()
client = client.newBuilder()
.protocols(listOf(protocol, Protocol.HTTP_1_1))
.build()
server1.protocols = client.protocols
server2.protocols = client.protocols
}

private fun enableTls() {
client = client.newBuilder()
.sslSocketFactory(
handshakeCertificates.sslSocketFactory(), handshakeCertificates.trustManager
)
.hostnameVerifier(RecordingHostnameVerifier())
.build()
server1.useHttps(handshakeCertificates.sslSocketFactory())
server2.useHttps(handshakeCertificates.sslSocketFactory())
}

private fun executeSynchronously(request: Request): RecordedResponse {
val call = client.newCall(request)
return try {
val response = call.execute()
val bodyString = response.body.string()
RecordedResponse(request, response, null, bodyString, null)
} catch (e: IOException) {
RecordedResponse(request, null, null, null, e)
}
}
}