diff --git a/java/src/org/openqa/selenium/support/decorators/DefaultDecorated.java b/java/src/org/openqa/selenium/support/decorators/DefaultDecorated.java index 0eace50a2e200..cc5735a776f1f 100644 --- a/java/src/org/openqa/selenium/support/decorators/DefaultDecorated.java +++ b/java/src/org/openqa/selenium/support/decorators/DefaultDecorated.java @@ -23,9 +23,9 @@ public class DefaultDecorated implements Decorated { private final T original; - private final WebDriverDecorator decorator; + private final WebDriverDecorator decorator; - public DefaultDecorated(final T original, final WebDriverDecorator decorator) { + public DefaultDecorated(final T original, final WebDriverDecorator decorator) { this.original = original; this.decorator = decorator; } @@ -34,7 +34,7 @@ public final T getOriginal() { return original; } - public final WebDriverDecorator getDecorator() { + public final WebDriverDecorator getDecorator() { return decorator; } diff --git a/java/src/org/openqa/selenium/support/decorators/WebDriverDecorator.java b/java/src/org/openqa/selenium/support/decorators/WebDriverDecorator.java index 1a832ce37300b..6320288f10e4f 100644 --- a/java/src/org/openqa/selenium/support/decorators/WebDriverDecorator.java +++ b/java/src/org/openqa/selenium/support/decorators/WebDriverDecorator.java @@ -34,12 +34,15 @@ import java.lang.reflect.InvocationHandler; import java.lang.reflect.InvocationTargetException; import java.lang.reflect.Method; +import java.util.Arrays; import java.util.HashMap; import java.util.HashSet; import java.util.List; import java.util.Map; +import java.util.Optional; import java.util.Set; import java.util.stream.Collectors; +import java.util.stream.Stream; /** * This class helps to create decorators for instances of {@link WebDriver} and @@ -179,22 +182,35 @@ * */ @Beta -public class WebDriverDecorator { +public class WebDriverDecorator { - private Decorated decorated; + private final Class targetWebDriverClass; - public final WebDriver decorate(WebDriver original) { + private Decorated decorated; + + @SuppressWarnings("unchecked") + public WebDriverDecorator() + { + this((Class) WebDriver.class); + } + + public WebDriverDecorator(Class targetClass) + { + this.targetWebDriverClass = targetClass; + } + + public final T decorate(T original) { Require.nonNull("WebDriver", original); decorated = createDecorated(original); - return createProxy(decorated, WebDriver.class); + return createProxy(decorated, targetWebDriverClass); } - public Decorated getDecoratedDriver() { + public Decorated getDecoratedDriver() { return decorated; } - public Decorated createDecorated(WebDriver driver) { + public Decorated createDecorated(T driver) { return new DefaultDecorated<>(driver, this); } @@ -248,7 +264,7 @@ public Object onError( private Object decorateResult(Object toDecorate) { if (toDecorate instanceof WebDriver) { - return createProxy(getDecoratedDriver(), WebDriver.class); + return createProxy(getDecoratedDriver(), targetWebDriverClass); } if (toDecorate instanceof WebElement) { return createProxy(createDecorated((WebElement) toDecorate), WebElement.class); @@ -293,7 +309,8 @@ protected final Z createProxy(final Decorated decorated, Class clazz) || decoratedInterfaces.contains(method.getDeclaringClass())) { return method.invoke(decorated, args); } - if (originalInterfaces.contains(method.getDeclaringClass())) { + if (originalInterfaces.contains(method.getDeclaringClass()) || findInterfaceByMethod(originalInterfaces, + method).isPresent()) { decorated.beforeCall(method, args); Object result = decorated.call(method, args); decorated.afterCall(method, result, args); @@ -316,7 +333,7 @@ protected final Z createProxy(final Decorated decorated, Class clazz) Class[] allInterfacesArray = allInterfaces.toArray(new Class[0]); Class proxy = new ByteBuddy() - .subclass(Object.class) + .subclass(clazz.isInterface() ? Object.class : clazz) .implement(allInterfacesArray) .method(ElementMatchers.any()) .intercept(InvocationHandlerAdapter.of(handler)) @@ -332,6 +349,18 @@ protected final Z createProxy(final Decorated decorated, Class clazz) } } + private Optional> findInterfaceByMethod(Set> interfaces, Method method) { + String methodName = method.getName(); + Class[] methodParameterTypes = method.getParameterTypes(); + return interfaces.stream().filter(i -> doesMethodBelongToClass(i, methodName, methodParameterTypes)).findFirst(); + } + + private boolean doesMethodBelongToClass(Class clazz, String methodName, Class[] methodParameterTypes) { + return Stream.of(clazz.getMethods()).anyMatch( + m -> m.getName().equals(methodName) && Arrays.equals(m.getParameterTypes(), methodParameterTypes) + ); + } + static Set> extractInterfaces(final Object object) { return extractInterfaces(object.getClass()); } diff --git a/java/src/org/openqa/selenium/support/events/EventFiringDecorator.java b/java/src/org/openqa/selenium/support/events/EventFiringDecorator.java index f0f2cd63ecc62..df1f16348b90d 100644 --- a/java/src/org/openqa/selenium/support/events/EventFiringDecorator.java +++ b/java/src/org/openqa/selenium/support/events/EventFiringDecorator.java @@ -155,7 +155,7 @@ * extending {@link WebDriverDecorator}, not by creating sophisticated listeners. */ @Beta -public class EventFiringDecorator extends WebDriverDecorator { +public class EventFiringDecorator extends WebDriverDecorator { private static final Logger logger = Logger.getLogger(EventFiringDecorator.class.getName()); diff --git a/java/test/org/openqa/selenium/remote/AugmenterTest.java b/java/test/org/openqa/selenium/remote/AugmenterTest.java index 24fea5ae019b3..28ff61f9bd0e7 100644 --- a/java/test/org/openqa/selenium/remote/AugmenterTest.java +++ b/java/test/org/openqa/selenium/remote/AugmenterTest.java @@ -429,7 +429,7 @@ public Capabilities getCapabilities() { } } - private static class ModifyTitleWebDriverDecorator extends WebDriverDecorator { + private static class ModifyTitleWebDriverDecorator extends WebDriverDecorator { @Override public Object call(Decorated target, Method method, Object[] args) throws Throwable { diff --git a/java/test/org/openqa/selenium/support/decorators/DecoratedAlertTest.java b/java/test/org/openqa/selenium/support/decorators/DecoratedAlertTest.java index ea98a1c2f3797..e63db49a9ec3b 100644 --- a/java/test/org/openqa/selenium/support/decorators/DecoratedAlertTest.java +++ b/java/test/org/openqa/selenium/support/decorators/DecoratedAlertTest.java @@ -49,7 +49,7 @@ public Fixture() { originalDriver = mock(WebDriver.class); when(originalSwitch.alert()).thenReturn(original); when(originalDriver.switchTo()).thenReturn(originalSwitch); - decoratedDriver = new WebDriverDecorator().decorate(originalDriver); + decoratedDriver = new WebDriverDecorator<>().decorate(originalDriver); decorated = decoratedDriver.switchTo().alert(); } } diff --git a/java/test/org/openqa/selenium/support/decorators/DecoratedNavigationTest.java b/java/test/org/openqa/selenium/support/decorators/DecoratedNavigationTest.java index 7ce29e3a11853..77c9f22a83ee2 100644 --- a/java/test/org/openqa/selenium/support/decorators/DecoratedNavigationTest.java +++ b/java/test/org/openqa/selenium/support/decorators/DecoratedNavigationTest.java @@ -45,7 +45,7 @@ public Fixture() { original = mock(WebDriver.Navigation.class); originalDriver = mock(WebDriver.class); when(originalDriver.navigate()).thenReturn(original); - decoratedDriver = new WebDriverDecorator().decorate(originalDriver); + decoratedDriver = new WebDriverDecorator<>().decorate(originalDriver); decorated = decoratedDriver.navigate(); } } diff --git a/java/test/org/openqa/selenium/support/decorators/DecoratedOptionsTest.java b/java/test/org/openqa/selenium/support/decorators/DecoratedOptionsTest.java index 9f3e1d09a6eea..f7dcf69f9932d 100644 --- a/java/test/org/openqa/selenium/support/decorators/DecoratedOptionsTest.java +++ b/java/test/org/openqa/selenium/support/decorators/DecoratedOptionsTest.java @@ -50,7 +50,7 @@ public Fixture() { original = mock(WebDriver.Options.class); originalDriver = mock(WebDriver.class); when(originalDriver.manage()).thenReturn(original); - decoratedDriver = new WebDriverDecorator().decorate(originalDriver); + decoratedDriver = new WebDriverDecorator<>().decorate(originalDriver); decorated = decoratedDriver.manage(); } } diff --git a/java/test/org/openqa/selenium/support/decorators/DecoratedRemoteWebDriverTest.java b/java/test/org/openqa/selenium/support/decorators/DecoratedRemoteWebDriverTest.java index fc1af450dc397..aa7816d651dd2 100644 --- a/java/test/org/openqa/selenium/support/decorators/DecoratedRemoteWebDriverTest.java +++ b/java/test/org/openqa/selenium/support/decorators/DecoratedRemoteWebDriverTest.java @@ -49,10 +49,11 @@ public void shouldImplementWrapsDriverToProvideAccessToUnderlyingDriver() { RemoteWebDriver originalDriver = mock(RemoteWebDriver.class); when(originalDriver.getSessionId()).thenReturn(sessionId); - WebDriver decoratedDriver = new WebDriverDecorator().decorate(originalDriver); + RemoteWebDriver decoratedDriver = new WebDriverDecorator<>(RemoteWebDriver.class).decorate(originalDriver); - RemoteWebDriver underlying = (RemoteWebDriver) ((WrapsDriver) decoratedDriver).getWrappedDriver(); + assertThat(decoratedDriver.getSessionId()).isEqualTo(sessionId); + RemoteWebDriver underlying = (RemoteWebDriver) ((WrapsDriver) decoratedDriver).getWrappedDriver(); assertThat(underlying.getSessionId()).isEqualTo(sessionId); } @@ -60,7 +61,7 @@ public void shouldImplementWrapsDriverToProvideAccessToUnderlyingDriver() { public void cannotConvertDecoratedToRemoteWebDriver() { RemoteWebDriver originalDriver = mock(RemoteWebDriver.class); - WebDriver decorated = new WebDriverDecorator().decorate(originalDriver); + WebDriver decorated = new WebDriverDecorator<>().decorate(originalDriver); assertThat(decorated).isNotInstanceOf(RemoteWebDriver.class); } @@ -69,7 +70,7 @@ public void cannotConvertDecoratedToRemoteWebDriver() { public void decoratedDriversShouldImplementWrapsDriver() { RemoteWebDriver originalDriver = mock(RemoteWebDriver.class); - WebDriver decorated = new WebDriverDecorator().decorate(originalDriver); + WebDriver decorated = new WebDriverDecorator<>().decorate(originalDriver); assertThat(decorated).isInstanceOf(WrapsDriver.class); } @@ -84,7 +85,7 @@ public void decoratedElementsShouldImplementWrapsElement() { when(originalDriver.findElement(any())).thenReturn(originalElement); - WebDriver decoratedDriver = new WebDriverDecorator().decorate(originalDriver); + WebDriver decoratedDriver = new WebDriverDecorator<>().decorate(originalDriver); WebElement element = decoratedDriver.findElement(By.id("test")); assertThat(element).isInstanceOf(WrapsElement.class); @@ -100,7 +101,7 @@ public void canConvertDecoratedRemoteWebElementToJson() { when(originalDriver.findElement(any())).thenReturn(originalElement); - WebDriver decoratedDriver = new WebDriverDecorator().decorate(originalDriver); + WebDriver decoratedDriver = new WebDriverDecorator<>().decorate(originalDriver); WebElement element = decoratedDriver.findElement(By.id("test")); diff --git a/java/test/org/openqa/selenium/support/decorators/DecoratedSwitchToTest.java b/java/test/org/openqa/selenium/support/decorators/DecoratedSwitchToTest.java index 682209b41b919..8595e37599583 100644 --- a/java/test/org/openqa/selenium/support/decorators/DecoratedSwitchToTest.java +++ b/java/test/org/openqa/selenium/support/decorators/DecoratedSwitchToTest.java @@ -48,7 +48,7 @@ public Fixture() { original = mock(WebDriver.TargetLocator.class); originalDriver = mock(WebDriver.class); when(originalDriver.switchTo()).thenReturn(original); - decoratedDriver = new WebDriverDecorator().decorate(originalDriver); + decoratedDriver = new WebDriverDecorator<>().decorate(originalDriver); decorated = decoratedDriver.switchTo(); } } diff --git a/java/test/org/openqa/selenium/support/decorators/DecoratedTimeoutsTest.java b/java/test/org/openqa/selenium/support/decorators/DecoratedTimeoutsTest.java index 4ed0fc630b3af..d26d7e177dc6e 100644 --- a/java/test/org/openqa/selenium/support/decorators/DecoratedTimeoutsTest.java +++ b/java/test/org/openqa/selenium/support/decorators/DecoratedTimeoutsTest.java @@ -48,7 +48,7 @@ public Fixture() { originalDriver = mock(WebDriver.class); when(originalOptions.timeouts()).thenReturn(original); when(originalDriver.manage()).thenReturn(originalOptions); - decoratedDriver = new WebDriverDecorator().decorate(originalDriver); + decoratedDriver = new WebDriverDecorator<>().decorate(originalDriver); decorated = decoratedDriver.manage().timeouts(); } } diff --git a/java/test/org/openqa/selenium/support/decorators/DecoratedVirtualAuthenticatorTest.java b/java/test/org/openqa/selenium/support/decorators/DecoratedVirtualAuthenticatorTest.java index 1b87f047f9db5..e4641d3f26a26 100644 --- a/java/test/org/openqa/selenium/support/decorators/DecoratedVirtualAuthenticatorTest.java +++ b/java/test/org/openqa/selenium/support/decorators/DecoratedVirtualAuthenticatorTest.java @@ -53,7 +53,7 @@ public Fixture() { originalDriver = mock( WebDriver.class, withSettings().extraInterfaces(HasVirtualAuthenticator.class)); when(((HasVirtualAuthenticator) originalDriver).addVirtualAuthenticator(any())).thenReturn(original); - decoratedDriver = new WebDriverDecorator().decorate(originalDriver); + decoratedDriver = new WebDriverDecorator<>().decorate(originalDriver); decorated = ((HasVirtualAuthenticator) decoratedDriver) .addVirtualAuthenticator(new VirtualAuthenticatorOptions()); } diff --git a/java/test/org/openqa/selenium/support/decorators/DecoratedWebDriverTest.java b/java/test/org/openqa/selenium/support/decorators/DecoratedWebDriverTest.java index cfc7e5cc10229..2b7dc0aafd0a4 100644 --- a/java/test/org/openqa/selenium/support/decorators/DecoratedWebDriverTest.java +++ b/java/test/org/openqa/selenium/support/decorators/DecoratedWebDriverTest.java @@ -62,7 +62,7 @@ public Fixture() { .extraInterfaces(JavascriptExecutor.class, TakesScreenshot.class, Interactive.class, HasVirtualAuthenticator.class)); originalAuth = mock(VirtualAuthenticator.class); - decorated = new WebDriverDecorator().decorate(original); + decorated = new WebDriverDecorator<>().decorate(original); when(((HasVirtualAuthenticator) original).addVirtualAuthenticator(any())).thenReturn(originalAuth); } } @@ -85,9 +85,9 @@ public void canCompareDecorated() { WebDriver original1 = mock(WebDriver.class); WebDriver original2 = mock(WebDriver.class); - WebDriver decorated1 = new WebDriverDecorator().decorate(original1); - WebDriver decorated2 = new WebDriverDecorator().decorate(original1); - WebDriver decorated3 = new WebDriverDecorator().decorate(original2); + WebDriver decorated1 = new WebDriverDecorator<>().decorate(original1); + WebDriver decorated2 = new WebDriverDecorator<>().decorate(original1); + WebDriver decorated3 = new WebDriverDecorator<>().decorate(original2); assertThat(decorated1).isEqualTo(decorated2); assertThat(decorated1).isNotEqualTo(decorated3); @@ -100,7 +100,7 @@ public void canCompareDecorated() { @Test public void testHashCode() { WebDriver original = mock(WebDriver.class); - WebDriver decorated = new WebDriverDecorator().decorate(original); + WebDriver decorated = new WebDriverDecorator<>().decorate(original); assertThat(decorated.hashCode()).isEqualTo(original.hashCode()); } diff --git a/java/test/org/openqa/selenium/support/decorators/DecoratedWebElementTest.java b/java/test/org/openqa/selenium/support/decorators/DecoratedWebElementTest.java index fcee8732420d1..f8a28cbb5c5d0 100644 --- a/java/test/org/openqa/selenium/support/decorators/DecoratedWebElementTest.java +++ b/java/test/org/openqa/selenium/support/decorators/DecoratedWebElementTest.java @@ -56,7 +56,7 @@ public Fixture() { original = mock(WebElement.class); originalDriver = mock(WebDriver.class); when(originalDriver.findElement(any())).thenReturn(original); - decoratedDriver = new WebDriverDecorator().decorate(originalDriver); + decoratedDriver = new WebDriverDecorator<>().decorate(originalDriver); decorated = decoratedDriver.findElement(By.id("test")); } } diff --git a/java/test/org/openqa/selenium/support/decorators/DecoratedWindowTest.java b/java/test/org/openqa/selenium/support/decorators/DecoratedWindowTest.java index 3901feac4e9a5..611f5fd07517d 100644 --- a/java/test/org/openqa/selenium/support/decorators/DecoratedWindowTest.java +++ b/java/test/org/openqa/selenium/support/decorators/DecoratedWindowTest.java @@ -50,7 +50,7 @@ public Fixture() { originalDriver = mock(WebDriver.class); when(originalOptions.window()).thenReturn(original); when(originalDriver.manage()).thenReturn(originalOptions); - decoratedDriver = new WebDriverDecorator().decorate(originalDriver); + decoratedDriver = new WebDriverDecorator<>().decorate(originalDriver); decorated = decoratedDriver.manage().window(); } } diff --git a/java/test/org/openqa/selenium/support/decorators/IntegrationTest.java b/java/test/org/openqa/selenium/support/decorators/IntegrationTest.java index c77d0b863b9bc..79a86e276874e 100644 --- a/java/test/org/openqa/selenium/support/decorators/IntegrationTest.java +++ b/java/test/org/openqa/selenium/support/decorators/IntegrationTest.java @@ -35,7 +35,7 @@ @Category(UnitTests.class) public class IntegrationTest { - static class CountCalls extends WebDriverDecorator { + static class CountCalls extends WebDriverDecorator { int counterBefore = 0; int counterAfter = 0; diff --git a/java/test/org/openqa/selenium/support/decorators/InterfacesTest.java b/java/test/org/openqa/selenium/support/decorators/InterfacesTest.java index 91a448b3d34fb..4e03aeac977ca 100644 --- a/java/test/org/openqa/selenium/support/decorators/InterfacesTest.java +++ b/java/test/org/openqa/selenium/support/decorators/InterfacesTest.java @@ -37,7 +37,7 @@ public void shouldNotAddInterfacesNotAvailableInTheOriginalDriver() { WebDriver driver = mock(WebDriver.class); assertThat(driver).isNotInstanceOf(SomeOtherInterface.class); - WebDriver decorated = new WebDriverDecorator().decorate(driver); + WebDriver decorated = new WebDriverDecorator<>().decorate(driver); assertThat(decorated).isNotInstanceOf(SomeOtherInterface.class); } @@ -46,7 +46,7 @@ public void shouldRespectInterfacesAvailableInTheOriginalDriver() { WebDriver driver = mock(ExtendedDriver.class); assertThat(driver).isInstanceOf(SomeOtherInterface.class); - WebDriver decorated = new WebDriverDecorator().decorate(driver); + WebDriver decorated = new WebDriverDecorator<>().decorate(driver); assertThat(decorated).isInstanceOf(SomeOtherInterface.class); } }