From ae588a4db5f6165945a93eea81d2b003662f76f7 Mon Sep 17 00:00:00 2001 From: Tony Hallett Date: Thu, 22 Jul 2021 15:26:59 +0100 Subject: [PATCH] fix #1188 --- src/Moq/Protected/ProtectedMock.cs | 64 ++++++++++++++++++------- tests/Moq.Tests/ProtectedMockFixture.cs | 25 +++++++++- 2 files changed, 71 insertions(+), 18 deletions(-) diff --git a/src/Moq/Protected/ProtectedMock.cs b/src/Moq/Protected/ProtectedMock.cs index 362e03583..c8d26b5b6 100644 --- a/src/Moq/Protected/ProtectedMock.cs +++ b/src/Moq/Protected/ProtectedMock.cs @@ -467,26 +467,15 @@ private static Type[] ToArgTypes(object[] args) { types[index] = ((MethodCallExpression)expr).Method.ReturnType; } + else if(ItRefAnyField(expr) is var itRefAnyField && itRefAnyField != null) + { + types[index] = itRefAnyField.FieldType.MakeByRefType(); + } else if (expr.NodeType == ExpressionType.MemberAccess) { var member = (MemberExpression)expr; if (member.Member is FieldInfo field) { - // Test for special case: `It.Ref.IsAny` - if (field.Name == nameof(It.Ref.IsAny)) - { - var fieldDeclaringType = field.DeclaringType; - if (fieldDeclaringType.IsGenericType) - { - var fieldDeclaringTypeDefinition = fieldDeclaringType.GetGenericTypeDefinition(); - if (fieldDeclaringTypeDefinition == typeof(It.Ref<>)) - { - types[index] = field.FieldType.MakeByRefType(); - continue; - } - } - } - types[index] = field.FieldType; } else if (member.Member is PropertyInfo property) @@ -510,16 +499,57 @@ private static Type[] ToArgTypes(object[] args) return types; } + private static FieldInfo ItRefAnyField(Expression expr) + { + FieldInfo itRefAnyField = null; + + if (expr.NodeType == ExpressionType.MemberAccess) + { + var member = (MemberExpression)expr; + if (member.Member is FieldInfo field) + { + if (field.Name == nameof(It.Ref.IsAny)) + { + var fieldDeclaringType = field.DeclaringType; + if (fieldDeclaringType.IsGenericType) + { + var fieldDeclaringTypeDefinition = fieldDeclaringType.GetGenericTypeDefinition(); + if (fieldDeclaringTypeDefinition == typeof(It.Ref<>)) + { + itRefAnyField = field; + } + } + } + } + } + + return itRefAnyField; + } + private static Expression ToExpressionArg(Type type, object arg) { - if (arg is LambdaExpression lambda) + if (arg is LambdaExpression lambda && !typeof(LambdaExpression).IsAssignableFrom(type)) { return lambda.Body; } if (arg is Expression expression) { - return expression; + if (!typeof(Expression).IsAssignableFrom(type)) + { + return expression; + } + + if (expression.IsMatch(out _)) + { + return expression; + } + + if (ItRefAnyField(expression) != null) + { + return expression; + } + } return Expression.Constant(arg, type); diff --git a/tests/Moq.Tests/ProtectedMockFixture.cs b/tests/Moq.Tests/ProtectedMockFixture.cs index bc3af2ec5..95fd57399 100644 --- a/tests/Moq.Tests/ProtectedMockFixture.cs +++ b/tests/Moq.Tests/ProtectedMockFixture.cs @@ -2,7 +2,7 @@ // All rights reserved. Licensed under the BSD 3-Clause License; see License.txt. using System; - +using System.Linq.Expressions; using Moq.Protected; using Xunit; @@ -855,6 +855,23 @@ public void DoesNotThrowIfVerifySetPropertyTimesReached() mock.Protected().VerifySet("ProtectedValue", Times.Exactly(2), ItExpr.IsAny()); } + [Fact] + public void SetupShouldWorkWithExpressionTypes() + { + var mock = new Mock(); + + var expression = Expression.Constant(1); + ConstantExpression setExpression = null; + mock.Protected().SetupSet("ExpressionProperty", expression).Callback(expr => setExpression = expr); + mock.Object.SetExpressionProperty(expression); + Assert.Equal(expression, setExpression); + + ConstantExpression setExpression2 = null; + mock.Protected().SetupSet("ExpressionProperty", ItExpr.Is( e => (int)e.Value == 2)).Callback(expr => setExpression2 = expr); + mock.Object.SetExpressionProperty(Expression.Constant(2)); + Assert.NotNull(setExpression2); + } + public class MethodOverloads { public void ExecuteDo(int a, int b) @@ -952,6 +969,12 @@ protected virtual FooBase Overloaded(MyDerived myBase) public class FooBase { + protected virtual ConstantExpression ExpressionProperty { get; set; } + public void SetExpressionProperty(ConstantExpression expression) + { + ExpressionProperty = expression; + } + public virtual string PublicValue { get; set; } protected internal virtual string ProtectedInternalValue { get; set; }