diff --git a/Src/Idioms/GuardClauseAssertion.cs b/Src/Idioms/GuardClauseAssertion.cs index e638bb372..31f68c4a8 100644 --- a/Src/Idioms/GuardClauseAssertion.cs +++ b/Src/Idioms/GuardClauseAssertion.cs @@ -1,4 +1,5 @@ using System; +using System.Collections; using System.Collections.Generic; using System.Collections.ObjectModel; using System.Globalization; @@ -99,7 +100,7 @@ public override void Verify(ConstructorInfo constructorInfo) constructorInfo = this.ResolveUnclosedGenericType(constructorInfo); var method = new ConstructorMethod(constructorInfo); - this.DoVerify(method, false, false); + this.DoVerify(method, isReturnValueDeferable: false, isReturnValueTask: false); } /// @@ -128,52 +129,25 @@ public override void Verify(MethodInfo methodInfo) methodInfo = this.ResolveUnclosedGenericMethod(methodInfo); var method = this.CreateMethod(methodInfo); + var returnType = methodInfo.ReturnType; - var isReturnValueIterator = - typeof(System.Collections.IEnumerable).IsAssignableFrom(methodInfo.ReturnType) || - typeof(System.Collections.IEnumerator).IsAssignableFrom(methodInfo.ReturnType); + // According to MSDN method with yield could have only 4 possible types: + // https://docs.microsoft.com/en-us/dotnet/csharp/language-reference/keywords/yield + var isReturnTypePossibleDefferable = returnType == typeof(IEnumerable) + || returnType == typeof(IEnumerator) + || (returnType.IsGenericType && returnType.GetGenericTypeDefinition() == typeof(IEnumerable<>)) + || (returnType.IsGenericType && returnType.GetGenericTypeDefinition() == typeof(IEnumerator<>)); - var isReturnValueNonDeferred = IsNonDeferredEnumerable(methodInfo.ReturnType); - var isReturnValueDeferable = isReturnValueIterator && !isReturnValueNonDeferred; + var containsByRefArgs = methodInfo.GetParameters().Select(p => p.ParameterType).Any(t => t.IsByRef); + + var isReturnValueDeferable = isReturnTypePossibleDefferable && !containsByRefArgs; var isReturnValueTask = - typeof(System.Threading.Tasks.Task).IsAssignableFrom(methodInfo.ReturnType); + typeof(System.Threading.Tasks.Task).IsAssignableFrom(returnType); this.DoVerify(method, isReturnValueDeferable, isReturnValueTask); } - private static bool IsNonDeferredEnumerable(Type t) - { - var nonGenericCollectionTypes = new[] - { - typeof(System.Collections.ICollection), - typeof(System.Collections.IList), - typeof(System.Collections.IDictionary) - }; - - var genericCollectionTypeGtds = new[] - { - typeof(IList<>), - typeof(ICollection<>), - typeof(IDictionary<,>) - }; - - var isGeneric = t.IsGenericType; - - var gtdInterfaces = (isGeneric && !t.IsInterface) - ? t.GetInterfaces() - .Where(i => i.IsGenericType) - .Select(i => i.GetGenericTypeDefinition()) - .ToArray() - : (isGeneric && t.IsInterface) - ? new[] { t.GetGenericTypeDefinition() } - : null; - - return t.IsArray || - nonGenericCollectionTypes.Any(gt => gt.IsAssignableFrom(t)) || - (isGeneric && genericCollectionTypeGtds.Any(gtd => gtdInterfaces.Contains(gtd))); - } - /// /// Verifies that a property has appropriate Guard Clauses in place. /// diff --git a/Src/Idioms/MethodInfoExtensions.cs b/Src/Idioms/MethodInfoExtensions.cs index 7e7bf0d01..1ddf9014e 100644 --- a/Src/Idioms/MethodInfoExtensions.cs +++ b/Src/Idioms/MethodInfoExtensions.cs @@ -7,28 +7,28 @@ internal static class MethodInfoExtensions { internal static bool IsEqualsMethod(this MethodInfo method) { - return string.Equals(method.Name, "Equals", StringComparison.Ordinal) + return string.Equals(method.Name, nameof(Equals), StringComparison.Ordinal) && method.GetParameters().Length == 1 && method.ReturnType == typeof(bool); } internal static bool IsGetHashCodeMethod(this MethodInfo method) { - return string.Equals(method.Name, "GetHashCode", StringComparison.Ordinal) + return string.Equals(method.Name, nameof(GetHashCode), StringComparison.Ordinal) && method.GetParameters().Length == 0 && method.ReturnType == typeof(int); } internal static bool IsToString(this MethodInfo method) { - return string.Equals(method.Name, "ToString", StringComparison.Ordinal) + return string.Equals(method.Name, nameof(ToString), StringComparison.Ordinal) && method.GetParameters().Length == 0 && method.ReturnType == typeof(string); } internal static bool IsGetType(this MethodInfo method) { - return string.Equals(method.Name, "GetType", StringComparison.Ordinal) + return string.Equals(method.Name, nameof(GetType), StringComparison.Ordinal) && method.GetParameters().Length == 0 && method.ReturnType == typeof(Type); } diff --git a/Src/IdiomsUnitTest/GuardClauseAssertionTest.cs b/Src/IdiomsUnitTest/GuardClauseAssertionTest.cs index f849f0248..2c3ff768e 100644 --- a/Src/IdiomsUnitTest/GuardClauseAssertionTest.cs +++ b/Src/IdiomsUnitTest/GuardClauseAssertionTest.cs @@ -487,6 +487,7 @@ public IEnumerator GetValues(Guid someGuid) [InlineData(typeof(ClassWithEnumerableNonDeferredArrayListMissingGuard))] [InlineData(typeof(ClassWithEnumerableNonDeferredStackMissingGuard))] [InlineData(typeof(ClassWithEnumerableNonDeferredReadOnlyCollectionBaseMissingGuard))] + [InlineData(typeof(ClassWithEnumerableNonDeferredStringMissingGuard))] public void VerifyMethodWithNonDeferredMissingGuardThrowsExceptionWithoutDeferredMessage( Type type) { @@ -632,6 +633,14 @@ private class ClassWithEnumerableNonDeferredGenericDictionaryMissingGuard } } + private class ClassWithEnumerableNonDeferredStringMissingGuard + { + public string GetValues(string someString) + { + return someString; + } + } + private interface IHaveNoImplementers { } @@ -1104,8 +1113,10 @@ public void VerifyNonProperlyGuardedPropertyThrowsException() [Theory] [InlineData(nameof(NonProperlyGuardedClass.Method), "Guard Clause prevented it, however")] - [InlineData(nameof(NonProperlyGuardedClass.DeferredMethod), "deferred")] - [InlineData(nameof(NonProperlyGuardedClass.AnotherDeferredMethod), "deferred")] + [InlineData(nameof(NonProperlyGuardedClass.DeferredMethodReturningGenericEnumerable), "deferred")] + [InlineData(nameof(NonProperlyGuardedClass.DeferredMethodReturningGenericEnumerator), "deferred")] + [InlineData(nameof(NonProperlyGuardedClass.DeferredMethodReturningNonGenericEnumerable), "deferred")] + [InlineData(nameof(NonProperlyGuardedClass.DeferredMethodReturningNonGenericEnumerator), "deferred")] public void VerifyNonProperlyGuardedMethodThrowsException(string methodName, string expectedMessage) { var sut = new GuardClauseAssertion(new Fixture()); @@ -1626,11 +1637,9 @@ public NestedGenericParameterTestType(Func>, T1[ private class NonProperlyGuardedClass { - private const string InvalidParamName = "invalidParamName"; - public NonProperlyGuardedClass(object argument) { - if (argument == null) throw new ArgumentNullException(InvalidParamName); + if (argument == null) throw new ArgumentNullException("invalid parameter name"); } public object Property @@ -1638,25 +1647,39 @@ public object Property get => null; set { - if (value == null) throw new ArgumentNullException(InvalidParamName); + if (value == null) throw new ArgumentNullException("invalid parameter name"); } } public void Method(object argument) { - if (argument == null) throw new ArgumentNullException(InvalidParamName); + if (argument == null) throw new ArgumentNullException("invalid parameter name"); + } + + public IEnumerable DeferredMethodReturningGenericEnumerable(object argument) + { + if (argument == null) throw new ArgumentNullException(nameof(argument)); + + yield return argument; + } + + public IEnumerator DeferredMethodReturningGenericEnumerator(object argument) + { + if (argument == null) throw new ArgumentNullException(nameof(argument)); + + yield return argument; } - public IEnumerable DeferredMethod(object argument) + public IEnumerable DeferredMethodReturningNonGenericEnumerable(object argument) { - if (argument == null) throw new ArgumentNullException(InvalidParamName); + if (argument == null) throw new ArgumentNullException(nameof(argument)); yield return argument; } - public IEnumerator AnotherDeferredMethod(object argument) + public IEnumerator DeferredMethodReturningNonGenericEnumerator(object argument) { - if (argument == null) throw new ArgumentNullException(InvalidParamName); + if (argument == null) throw new ArgumentNullException(nameof(argument)); yield return argument; }