From 6ee58fa9739062fc894dc8011a2275a485f864c2 Mon Sep 17 00:00:00 2001 From: Yuri Schimke Date: Fri, 23 Dec 2022 19:02:51 +1000 Subject: [PATCH 1/3] Add tests for multiple routes --- .../kotlin/mockwebserver3/MockWebServer.kt | 14 +- .../internal/connection/RouteDatabase.kt | 12 +- .../jvmTest/java/okhttp3/RouteFailureTest.kt | 240 ++++++++++++++++++ 3 files changed, 258 insertions(+), 8 deletions(-) create mode 100644 okhttp/src/jvmTest/java/okhttp3/RouteFailureTest.kt diff --git a/mockwebserver/src/main/kotlin/mockwebserver3/MockWebServer.kt b/mockwebserver/src/main/kotlin/mockwebserver3/MockWebServer.kt index 4f5a65ad60a5..4c2732385e6f 100644 --- a/mockwebserver/src/main/kotlin/mockwebserver3/MockWebServer.kt +++ b/mockwebserver/src/main/kotlin/mockwebserver3/MockWebServer.kt @@ -152,10 +152,16 @@ class MockWebServer : Closeable { val hostName: String get() { before() - return inetSocketAddress!!.address.canonicalHostName + return _inetSocketAddress!!.address.canonicalHostName } - private var inetSocketAddress: InetSocketAddress? = null + private var _inetSocketAddress: InetSocketAddress? = null + + val inetSocketAddress: InetSocketAddress + get() { + before() + return InetSocketAddress(hostName, portField) + } /** * True if ALPN is used on incoming HTTPS connections to negotiate a protocol like HTTP/1.1 or @@ -206,7 +212,7 @@ class MockWebServer : Closeable { fun toProxyAddress(): Proxy { before() - val address = InetSocketAddress(inetSocketAddress!!.address.canonicalHostName, port) + val address = InetSocketAddress(_inetSocketAddress!!.address.canonicalHostName, port) return Proxy(Proxy.Type.HTTP, address) } @@ -383,7 +389,7 @@ class MockWebServer : Closeable { if (started) return started = true - this.inetSocketAddress = inetSocketAddress + this._inetSocketAddress = inetSocketAddress serverSocket = serverSocketFactory!!.createServerSocket() diff --git a/okhttp/src/jvmMain/kotlin/okhttp3/internal/connection/RouteDatabase.kt b/okhttp/src/jvmMain/kotlin/okhttp3/internal/connection/RouteDatabase.kt index c7c7422fc39e..4685de667166 100644 --- a/okhttp/src/jvmMain/kotlin/okhttp3/internal/connection/RouteDatabase.kt +++ b/okhttp/src/jvmMain/kotlin/okhttp3/internal/connection/RouteDatabase.kt @@ -24,18 +24,22 @@ import okhttp3.Route * preferred. */ class RouteDatabase { - private val failedRoutes = mutableSetOf() + private val _failedRoutes = mutableSetOf() + + val failedRoutes: Set + @Synchronized get() = _failedRoutes.toSet() + /** Records a failure connecting to [failedRoute]. */ @Synchronized fun failed(failedRoute: Route) { - failedRoutes.add(failedRoute) + _failedRoutes.add(failedRoute) } /** Records success connecting to [route]. */ @Synchronized fun connected(route: Route) { - failedRoutes.remove(route) + _failedRoutes.remove(route) } /** Returns true if [route] has failed recently and should be avoided. */ - @Synchronized fun shouldPostpone(route: Route): Boolean = route in failedRoutes + @Synchronized fun shouldPostpone(route: Route): Boolean = route in _failedRoutes } diff --git a/okhttp/src/jvmTest/java/okhttp3/RouteFailureTest.kt b/okhttp/src/jvmTest/java/okhttp3/RouteFailureTest.kt new file mode 100644 index 000000000000..a4664038ef8d --- /dev/null +++ b/okhttp/src/jvmTest/java/okhttp3/RouteFailureTest.kt @@ -0,0 +1,240 @@ +package okhttp3 + + +import java.io.IOException +import java.net.InetAddress +import java.net.InetSocketAddress +import java.net.Socket +import java.net.SocketAddress +import mockwebserver3.MockResponse +import mockwebserver3.MockWebServer +import mockwebserver3.SocketPolicy +import mockwebserver3.junit5.internal.MockWebServerInstance +import okhttp3.internal.http2.ErrorCode +import okhttp3.testing.Flaky +import okhttp3.testing.PlatformRule +import okhttp3.tls.internal.TlsUtil.localhost +import org.assertj.core.api.Assertions.assertThat +import org.junit.jupiter.api.BeforeEach +import org.junit.jupiter.api.Test +import org.junit.jupiter.api.extension.RegisterExtension + +@Flaky +class RouteFailureTest { + private lateinit var socketFactory: SpecificHostSocketFactory + private lateinit var client: OkHttpClient + + @RegisterExtension + val platform = PlatformRule() + + @RegisterExtension + val clientTestRule = OkHttpClientTestRule() + + private lateinit var server1: MockWebServer + private lateinit var server2: MockWebServer + + private var listener = RecordingEventListener() + + private val handshakeCertificates = localhost() + + val dns = FakeDns() + + val ipv4 = InetAddress.getByName("192.168.1.1") + val ipv6 = InetAddress.getByName("2606:2800:220:1:248:1893:25c8:1946") + + val refusedStream = MockResponse() + .setHttp2ErrorCode(ErrorCode.REFUSED_STREAM.httpCode) + .setSocketPolicy(SocketPolicy.RESET_STREAM_AT_START) + val bodyResponse = MockResponse().setBody("body") + + @BeforeEach + fun setUp( + server: MockWebServer, + @MockWebServerInstance("server2") server2: MockWebServer + ) { + this.server1 = server + this.server2 = server2 + + socketFactory = SpecificHostSocketFactory(InetSocketAddress(server.hostName, server.port)) + + client = clientTestRule.newClientBuilder() + .dns(dns) + .socketFactory(socketFactory) + .eventListenerFactory(clientTestRule.wrap(listener)) + .build() + } + + @Test + fun http2OneBadHostOneGoodNoRetryOnConnectionFailure() { + enableProtocol(Protocol.HTTP_2) + + val request = Request(server1.url("/")) + + server1.enqueue(refusedStream) + server2.enqueue(bodyResponse) + + dns[server1.hostName] = listOf(ipv6, ipv4) + socketFactory[ipv6] = server1.inetSocketAddress + socketFactory[ipv4] = server2.inetSocketAddress + + client = client.newBuilder() + .fastFallback(false) + .apply { + retryOnConnectionFailure = false + } + .build() + + executeSynchronously(request) + .assertFailureMatches("stream was reset: REFUSED_STREAM") + + assertThat(client.routeDatabase.failedRoutes).isEmpty() + assertThat(server1.requestCount).isEqualTo(1) + assertThat(server2.requestCount).isEqualTo(0) + } + + @Test + fun http2OneBadHostOneGoodRetryOnConnectionFailure() { + enableProtocol(Protocol.HTTP_2) + + val request = Request(server1.url("/")) + + server1.enqueue(refusedStream) + server1.enqueue(refusedStream) + server2.enqueue(bodyResponse) + + dns[server1.hostName] = listOf(ipv6, ipv4) + socketFactory[ipv6] = server1.inetSocketAddress + socketFactory[ipv4] = server2.inetSocketAddress + + client = client.newBuilder() + .fastFallback(false) + .apply { + retryOnConnectionFailure = true + } + .build() + + executeSynchronously(request) + .assertBody("body") + + assertThat(client.routeDatabase.failedRoutes).isEmpty() + // TODO check if we expect a second request to server1, before attempting server2 + assertThat(server1.requestCount).isEqualTo(2) + assertThat(server2.requestCount).isEqualTo(1) + } + + @Test + fun http2OneBadHostOneGoodNoRetryOnConnectionFailureFastFallback() { + enableProtocol(Protocol.HTTP_2) + + val request = Request(server1.url("/")) + + server1.enqueue(refusedStream) + server2.enqueue(bodyResponse) + + dns[server1.hostName] = listOf(ipv6, ipv4) + socketFactory[ipv6] = server1.inetSocketAddress + socketFactory[ipv4] = server2.inetSocketAddress + + client = client.newBuilder() + .fastFallback(true) + .apply { + retryOnConnectionFailure = false + } + .build() + + executeSynchronously(request) + .assertFailureMatches("stream was reset: REFUSED_STREAM") + + assertThat(client.routeDatabase.failedRoutes).isEmpty() + assertThat(server1.requestCount).isEqualTo(1) + assertThat(server2.requestCount).isEqualTo(0) + } + + @Test + fun http2OneBadHostOneGoodRetryOnConnectionFailureFastFallback() { + enableProtocol(Protocol.HTTP_2) + + val request = Request(server1.url("/")) + + server1.enqueue(refusedStream) + server1.enqueue(refusedStream) + server2.enqueue(bodyResponse) + + dns[server1.hostName] = listOf(ipv6, ipv4) + socketFactory[ipv6] = server1.inetSocketAddress + socketFactory[ipv4] = server2.inetSocketAddress + + client = client.newBuilder() + .fastFallback(true) + .apply { + retryOnConnectionFailure = true + } + .build() + + executeSynchronously(request) + .assertBody("body") + + assertThat(client.routeDatabase.failedRoutes).isEmpty() + // TODO check if we expect a second request to server1, before attempting server2 + assertThat(server1.requestCount).isEqualTo(2) + assertThat(server2.requestCount).isEqualTo(1) + } + + /** + * Tests that use this will fail unless boot classpath is set. Ex. `-Xbootclasspath/p:/tmp/alpn-boot-8.0.0.v20140317` + */ + private fun enableProtocol(protocol: Protocol) { + enableTls() + client = client.newBuilder() + .protocols(listOf(protocol, Protocol.HTTP_1_1)) + .build() + server1.protocols = client.protocols + server2.protocols = client.protocols + } + + private fun enableTls() { + client = client.newBuilder() + .sslSocketFactory( + handshakeCertificates.sslSocketFactory(), handshakeCertificates.trustManager + ) + .hostnameVerifier(RecordingHostnameVerifier()) + .build() + server1.useHttps(handshakeCertificates.sslSocketFactory()) + server2.useHttps(handshakeCertificates.sslSocketFactory()) + } + + private fun executeSynchronously(request: Request): RecordedResponse { + val call = client.newCall(request) + return try { + val response = call.execute() + val bodyString = response.body.string() + RecordedResponse(request, response, null, bodyString, null) + } catch (e: IOException) { + RecordedResponse(request, null, null, null, e) + } + } +} + +class SpecificHostSocketFactory( + val defaultAddress: InetSocketAddress? +) : DelegatingSocketFactory(getDefault()) { + private val hostMapping = mutableMapOf() + + /** Sets the results for `hostname`. */ + operator fun set( + requested: InetAddress, + real: InetSocketAddress + ) { + hostMapping[requested] = real + } + + override fun createSocket(): Socket { + return object : Socket() { + override fun connect(endpoint: SocketAddress?, timeout: Int) { + val requested = (endpoint as InetSocketAddress) + val inetSocketAddress = hostMapping[requested.address] ?: defaultAddress ?: requested + super.connect(inetSocketAddress, timeout) + } + } + } +} From 3e5a2507a9160574b0d138120cc6eb3c0104b7a2 Mon Sep 17 00:00:00 2001 From: Yuri Schimke Date: Fri, 23 Dec 2022 19:11:22 +1000 Subject: [PATCH 2/3] Add tests for multiple routes --- mockwebserver/api/mockwebserver3.api | 1 + 1 file changed, 1 insertion(+) diff --git a/mockwebserver/api/mockwebserver3.api b/mockwebserver/api/mockwebserver3.api index 6c0b263acb26..2e32e3c41a82 100644 --- a/mockwebserver/api/mockwebserver3.api +++ b/mockwebserver/api/mockwebserver3.api @@ -86,6 +86,7 @@ public final class mockwebserver3/MockWebServer : java/io/Closeable { public final fun getBodyLimit ()J public final fun getDispatcher ()Lmockwebserver3/Dispatcher; public final fun getHostName ()Ljava/lang/String; + public final fun getInetSocketAddress ()Ljava/net/InetSocketAddress; public final fun getPort ()I public final fun getProtocolNegotiationEnabled ()Z public final fun getRequestCount ()I From 6bb66607d9be77b1e06c6facb54d5d9c4556ea9f Mon Sep 17 00:00:00 2001 From: Yuri Schimke Date: Sat, 24 Dec 2022 10:12:07 +1000 Subject: [PATCH 3/3] Review comments --- .../okhttp3/SpecificHostSocketFactory.kt | 48 ++++++++ .../jvmTest/java/okhttp3/RouteFailureTest.kt | 103 ++++++++++++------ 2 files changed, 117 insertions(+), 34 deletions(-) create mode 100644 okhttp-testing-support/src/main/kotlin/okhttp3/SpecificHostSocketFactory.kt diff --git a/okhttp-testing-support/src/main/kotlin/okhttp3/SpecificHostSocketFactory.kt b/okhttp-testing-support/src/main/kotlin/okhttp3/SpecificHostSocketFactory.kt new file mode 100644 index 000000000000..d58d08ee50f1 --- /dev/null +++ b/okhttp-testing-support/src/main/kotlin/okhttp3/SpecificHostSocketFactory.kt @@ -0,0 +1,48 @@ +/* + * Copyright (C) 2022 Square, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package okhttp3 + +import java.net.InetAddress +import java.net.InetSocketAddress +import java.net.Socket +import java.net.SocketAddress + +/** + * A [SocketFactory] that redirects connections to [defaultAddress] or specific overridden address via [set]. + */ +class SpecificHostSocketFactory( + val defaultAddress: InetSocketAddress? +) : DelegatingSocketFactory(getDefault()) { + private val hostMapping = mutableMapOf() + + /** Sets the [real] address for [requested]. */ + operator fun set( + requested: InetAddress, + real: InetSocketAddress + ) { + hostMapping[requested] = real + } + + override fun createSocket(): Socket { + return object : Socket() { + override fun connect(endpoint: SocketAddress?, timeout: Int) { + val requested = (endpoint as InetSocketAddress) + val inetSocketAddress = hostMapping[requested.address] ?: defaultAddress ?: requested + super.connect(inetSocketAddress, timeout) + } + } + } +} diff --git a/okhttp/src/jvmTest/java/okhttp3/RouteFailureTest.kt b/okhttp/src/jvmTest/java/okhttp3/RouteFailureTest.kt index a4664038ef8d..21789543772f 100644 --- a/okhttp/src/jvmTest/java/okhttp3/RouteFailureTest.kt +++ b/okhttp/src/jvmTest/java/okhttp3/RouteFailureTest.kt @@ -1,17 +1,28 @@ +/* + * Copyright (C) 2022 Square, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ package okhttp3 - import java.io.IOException import java.net.InetAddress import java.net.InetSocketAddress -import java.net.Socket -import java.net.SocketAddress import mockwebserver3.MockResponse import mockwebserver3.MockWebServer import mockwebserver3.SocketPolicy import mockwebserver3.junit5.internal.MockWebServerInstance import okhttp3.internal.http2.ErrorCode -import okhttp3.testing.Flaky import okhttp3.testing.PlatformRule import okhttp3.tls.internal.TlsUtil.localhost import org.assertj.core.api.Assertions.assertThat @@ -19,7 +30,6 @@ import org.junit.jupiter.api.BeforeEach import org.junit.jupiter.api.Test import org.junit.jupiter.api.extension.RegisterExtension -@Flaky class RouteFailureTest { private lateinit var socketFactory: SpecificHostSocketFactory private lateinit var client: OkHttpClient @@ -39,8 +49,8 @@ class RouteFailureTest { val dns = FakeDns() - val ipv4 = InetAddress.getByName("192.168.1.1") - val ipv6 = InetAddress.getByName("2606:2800:220:1:248:1893:25c8:1946") + val ipv4 = InetAddress.getByName("203.0.113.1") + val ipv6 = InetAddress.getByName("2001:db8:ffff:ffff:ffff:ffff:ffff:1") val refusedStream = MockResponse() .setHttp2ErrorCode(ErrorCode.REFUSED_STREAM.httpCode) @@ -180,9 +190,58 @@ class RouteFailureTest { assertThat(server2.requestCount).isEqualTo(1) } - /** - * Tests that use this will fail unless boot classpath is set. Ex. `-Xbootclasspath/p:/tmp/alpn-boot-8.0.0.v20140317` - */ + @Test + fun http2OneBadHostRetryOnConnectionFailure() { + enableProtocol(Protocol.HTTP_2) + + val request = Request(server1.url("/")) + + server1.enqueue(refusedStream) + server1.enqueue(refusedStream) + + dns[server1.hostName] = listOf(ipv6) + socketFactory[ipv6] = server1.inetSocketAddress + + client = client.newBuilder() + .fastFallback(false) + .apply { + retryOnConnectionFailure = true + } + .build() + + executeSynchronously(request) + .assertFailureMatches("stream was reset: REFUSED_STREAM") + + assertThat(client.routeDatabase.failedRoutes).isEmpty() + assertThat(server1.requestCount).isEqualTo(1) + } + + @Test + fun http2OneBadHostRetryOnConnectionFailureFastFallback() { + enableProtocol(Protocol.HTTP_2) + + val request = Request(server1.url("/")) + + server1.enqueue(refusedStream) + server1.enqueue(refusedStream) + + dns[server1.hostName] = listOf(ipv6) + socketFactory[ipv6] = server1.inetSocketAddress + + client = client.newBuilder() + .fastFallback(true) + .apply { + retryOnConnectionFailure = true + } + .build() + + executeSynchronously(request) + .assertFailureMatches("stream was reset: REFUSED_STREAM") + + assertThat(client.routeDatabase.failedRoutes).isEmpty() + assertThat(server1.requestCount).isEqualTo(1) + } + private fun enableProtocol(protocol: Protocol) { enableTls() client = client.newBuilder() @@ -214,27 +273,3 @@ class RouteFailureTest { } } } - -class SpecificHostSocketFactory( - val defaultAddress: InetSocketAddress? -) : DelegatingSocketFactory(getDefault()) { - private val hostMapping = mutableMapOf() - - /** Sets the results for `hostname`. */ - operator fun set( - requested: InetAddress, - real: InetSocketAddress - ) { - hostMapping[requested] = real - } - - override fun createSocket(): Socket { - return object : Socket() { - override fun connect(endpoint: SocketAddress?, timeout: Int) { - val requested = (endpoint as InetSocketAddress) - val inetSocketAddress = hostMapping[requested.address] ?: defaultAddress ?: requested - super.connect(inetSocketAddress, timeout) - } - } - } -}