
Replace parameter to point to nested parameter in lambda expression

I can successfully replace simple parameter types in a lambda expression, thanks to some answers to a previous question, but I cannot figure out how to replace a parameter of the incoming lambda so that it points to a nested property instead (here, a DomainPerson parameter should become DtoFavouriteColour.Person).

Consider the following objects:

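using System;
using System.Collections.Generic;
using System.Collections.ObjectModel;
using System.Linq;
using System.Linq.Expressions;

public class DtoColour
{
    public DtoColour(string name)
    {
        Name = name;
        // initialised so predicates over colours with no favourites don't hit a null collection
        FavouriteColours = new Collection<DtoFavouriteColour>();
    }

    public string Name { get; set; }

    public ICollection<DtoFavouriteColour> FavouriteColours { get; set; }
}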

public class DtoPerson
{
    public DtoPerson(string firstName, string lastName)
    {
        FirstName = firstName;
        LastName = lastName;
        FavouriteColours = new Collection<DtoFavouriteColour>();
    }

    public string FirstName { get; private set; }

    public string LastName { get; private set; }

    public ICollection<DtoFavouriteColour> FavouriteColours { get; set; }
}

public class DtoFavouriteColour
{
    public DtoColour Colour { get; set; }

    public DtoPerson Person { get; set; }
}

public class DomainColour
{
    public DomainColour(string name)
    {
        Name = name;
    }

    public string Name { get; set; }

    public ICollection<DomainPerson> People { get; set; }
}

public class DomainPerson
{
    public DomainPerson(string firstName, string lastName)
    {
        FirstName = firstName;
        LastName = lastName;
        Colours = new Collection<DomainColour>();
    }

    public string FirstName { get; private set; }

    public string LastName { get; private set; }

    public ICollection<DomainColour> Colours { get; set; }
}

And a repository:

public class ColourRepository
{
    private IList<DtoColour> Colours { get; set; } 

    public ColourRepository()
    {
        var favColours = new Collection<DtoFavouriteColour>
        {
            new DtoFavouriteColour() { Person = new DtoPerson("Peter", "Parker") },
            new DtoFavouriteColour() { Person = new DtoPerson("John", "Smith") },
            new DtoFavouriteColour() { Person = new DtoPerson("Joe", "Blogs") }
        };
        Colours = new List<DtoColour>
        {
            new DtoColour("Red") { FavouriteColours = favColours },
            new DtoColour("Blue"),
            new DtoColour("Yellow")
        };
    }

    public IEnumerable<DomainColour> GetWhere(Expression<Func<DomainColour, bool>> predicate)
    {
        var convertedPred = MyExpressionVisitor.Convert(predicate);
        return Colours.Where(convertedPred).Select(c => new DomainColour(c.Name)).ToList();
    }
}

And finally, an expression visitor that should convert the predicate into the equivalent one for the DTO models:

public class MyExpressionVisitor : ExpressionVisitor
{
    private ReadOnlyCollection<ParameterExpression> _parameters;

    public static Func<DtoColour, bool> Convert<T>(Expression<T> root)
    {
        var visitor = new MyExpressionVisitor();
        var expression = (Expression<Func<DtoColour, bool>>)visitor.Visit(root);
        return expression.Compile();
    }

    protected override Expression VisitParameter(ParameterExpression node)
    {
        var param = _parameters?.FirstOrDefault(p => p.Name == node.Name);

        if (param != null)
        {
            return param;
        }

        if(node.Type == typeof(DomainColour))
        {
            return Expression.Parameter(typeof(DtoColour), node.Name);
        }

        if (node.Type == typeof(DomainPerson))
        {
            return Expression.Parameter(typeof(DtoFavouriteColour), node.Name);
        }

        return node;
    }

    protected override Expression VisitLambda<T>(Expression<T> node)
    {
        _parameters = VisitAndConvert<ParameterExpression>(node.Parameters, "VisitLambda");
        return Expression.Lambda(Visit(node.Body), _parameters);
    }

    protected override Expression VisitMember(MemberExpression node)
    {
        var exp = Visit(node.Expression);

        if (node.Member.DeclaringType == typeof(DomainColour))
        {
            if (node.Type == typeof(ICollection<DomainPerson>))
            {
                return Expression.MakeMemberAccess(exp, typeof(DtoColour).GetProperty("FavouriteColours"));
            }

            return Expression.MakeMemberAccess(exp, typeof(DtoColour).GetProperty(node.Member.Name));
        }

        if (node.Member.DeclaringType == typeof(DomainPerson))
        {
            var nested = Expression.MakeMemberAccess(exp, typeof(DtoFavouriteColour).GetProperty("Person"));
            return Expression.MakeMemberAccess(nested, typeof(DtoPerson).GetProperty(node.Member.Name));
        }

        return base.VisitMember(node);
    }
}

Currently I get the following exception:

[System.ArgumentException: Expression of type 'System.Collections.Generic.ICollection`1[ExpressionVisitorTests.DtoFavouriteColour]' cannot be used for parameter of type 'System.Collections.Generic.IEnumerable`1[ExpressionVisitorTests.DomainPerson]' of method 'Boolean Any[DomainPerson](System.Collections.Generic.IEnumerable`1[ExpressionVisitorTests.DomainPerson], System.Func`2[ExpressionVisitorTests.DomainPerson,System.Boolean])']
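The message says that `Expression.Call` is being handed a `MethodInfo` that is already closed over `DomainPerson`, so it rejects a `DtoFavouriteColour` collection as the first argument; the rewritten tree needs a method closed over the new element type. Below is a minimal, self-contained sketch of that failure and its fix. It does not use the question's types: the `RebindDemo` class and the sample values are hypothetical. It takes `Enumerable.Any<int>` from a compiler-generated tree, shows that reusing it with a `string` source throws `ArgumentException`, and then re-closes the generic definition with `MakeGenericMethod`.

```csharp
using System;
using System.Collections.Generic;
using System.Linq;
using System.Linq.Expressions;
using System.Reflection;

// Hypothetical demo class; not part of the question's code.
public static class RebindDemo
{
    public static bool Run()
    {
        // Closed method Enumerable.Any<int>(IEnumerable<int>, Func<int, bool>),
        // taken from a compiler-generated expression tree.
        Expression<Func<IEnumerable<int>, bool>> original = xs => xs.Any(x => x > 0);
        MethodInfo anyOfInt = ((MethodCallExpression)original.Body).Method;

        // Replacement arguments whose element type is string instead of int.
        Expression source = Expression.Constant(new List<string> { "", "abc" });
        Expression<Func<string, bool>> predicate = s => s.Length > 0;

        bool rejected;
        try
        {
            // Same situation as in the question: old MethodInfo, new argument types.
            Expression.Call(anyOfInt, source, predicate);
            rejected = false;
        }
        catch (ArgumentException)
        {
            rejected = true;
        }

        // Go back to the open definition Any<TSource>(...) and close it over string.
        MethodInfo anyOfString = anyOfInt.GetGenericMethodDefinition()
                                         .MakeGenericMethod(typeof(string));
        MethodCallExpression call = Expression.Call(anyOfString, source, predicate);
        bool result = Expression.Lambda<Func<bool>>(call).Compile()();

        return rejected && result;
    }
}
```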

Here is a dotnetfiddle of it not working.

Thanks in advance for any help.

Jake Aitchison asked Jul 21 '16 08:07


2 Answers

After some more searching I came across this answer by Jon Skeet, which led me to a working solution: I added an override of the VisitMethodCall method to the ExpressionVisitor that replaces the original MethodInfo with a new one constructed for the correct collection element type.

protected override Expression VisitMethodCall(MethodCallExpression node)
{
    if (node.Method.DeclaringType == typeof(Enumerable) && node.Arguments[0].Type == typeof(ICollection<DomainPerson>))
    {
        Expression obj = Visit(node.Object);
        IEnumerable<Expression> args = Visit(node.Arguments);
        if (obj != node.Object || args != node.Arguments)
        {
            var generic = typeof(Enumerable).GetMethods()
                            .Where(m => m.Name == node.Method.Name)
                            .Where(m => m.GetParameters().Length == node.Arguments.Count)
                            .Single();
            var constructed = generic.MakeGenericMethod(typeof(DtoFavouriteColour));
            return Expression.Call(obj, constructed, args);
        }
    }
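    // any other call (e.g. a string method on a mapped member) still needs its
    // object and arguments visited so the Domain parameters get replaced
    return base.VisitMethodCall(node);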
}
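The lookup above (method name plus parameter count, then `Single()`) is unambiguous for `Any`, which has only one two-parameter overload, but `Enumerable.Where` has two (`Func<T, bool>` and the indexed `Func<T, int, bool>`), so `Single()` would throw for it. A small sketch, using a hypothetical `OverloadLookupDemo` class, compares that lookup with `GetGenericMethodDefinition()`, which keeps exactly the overload the compiler picked:

```csharp
using System;
using System.Collections.Generic;
using System.Linq;
using System.Linq.Expressions;
using System.Reflection;

// Hypothetical demo class; not part of the answer's code.
public static class OverloadLookupDemo
{
    // The answer's lookup strategy: Enumerable methods matching name and parameter count.
    public static int CountCandidates(string name, int parameterCount) =>
        typeof(Enumerable).GetMethods()
            .Count(m => m.Name == name && m.GetParameters().Length == parameterCount);

    // Alternative: re-close the generic definition of the method already in the tree.
    public static MethodInfo Rebind(MethodInfo closed, params Type[] typeArguments) =>
        closed.GetGenericMethodDefinition().MakeGenericMethod(typeArguments);

    public static bool IndexedWhereSurvivesRebind()
    {
        // Uses the indexed overload Where<int>(IEnumerable<int>, Func<int, int, bool>).
        Expression<Func<IEnumerable<int>, IEnumerable<int>>> e = xs => xs.Where((x, i) => i > 0);
        MethodInfo closed = ((MethodCallExpression)e.Body).Method;
        MethodInfo rebound = Rebind(closed, typeof(string));
        return rebound.GetParameters()[1].ParameterType == typeof(Func<string, int, bool>);
    }
}
```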

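I also needed to make sure that each lambda is rebuilt with its own parameters. In the original VisitLambda<T>, Visit(node.Body) runs before the _parameters field is read, and a nested call to VisitLambda<T> (for example for the predicate passed to Any) overwrites that field in the meantime, so the outer lambda ended up with the inner lambda's parameters. Capturing the converted parameters in a local variable avoids that: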

protected override Expression VisitLambda<T>(Expression<T> node)
{
    var parameters = VisitAndConvert(node.Parameters, "VisitLambda");

    // keep this lambda's parameters in a local: nested VisitLambda calls made
    // while visiting node.Body overwrite the _parameters field
    _parameters = parameters;

    return Expression.Lambda(Visit(node.Body), parameters);
}
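As far as I can tell, `_parameters` still holds the inner lambda's parameters after the nested lambda has been visited, so a predicate that uses the outer parameter again afterwards (say `c => c.People.Any(p => p.FirstName == "Peter") && c.Name == "Red"`) would get a fresh, unbound `ParameterExpression` for the second `c`. One alternative is to key the replacement parameters by the original `ParameterExpression` instance in a dictionary, so nesting never clobbers the outer mapping. The sketch below assumes hypothetical `DomainItem`/`DtoItem` model types and a hypothetical `RetypingVisitor`. It re-resolves members by name and re-closes static generic calls over the mapped types:

```csharp
using System;
using System.Collections.Generic;
using System.Linq;
using System.Linq.Expressions;

// Hypothetical model types used only for this sketch.
public class DomainTag { public string Name { get; set; } }
public class DomainItem { public string Name { get; set; } public List<DomainTag> Tags { get; set; } }
public class DtoTag { public string Name { get; set; } }
public class DtoItem { public string Name { get; set; } public List<DtoTag> Tags { get; set; } }

// Hypothetical visitor: retypes parameters via a type map, tracking them by identity.
public class RetypingVisitor : ExpressionVisitor
{
    private readonly Dictionary<Type, Type> _types;

    // Keyed by the original ParameterExpression instance, so visiting a nested
    // lambda adds entries instead of replacing the outer lambda's mapping.
    private readonly Dictionary<ParameterExpression, ParameterExpression> _parameters =
        new Dictionary<ParameterExpression, ParameterExpression>();

    public RetypingVisitor(Dictionary<Type, Type> types) => _types = types;

    private Type MapType(Type type) => _types.TryGetValue(type, out var mapped) ? mapped : type;

    protected override Expression VisitParameter(ParameterExpression node)
    {
        if (!_parameters.TryGetValue(node, out var mapped))
        {
            mapped = Expression.Parameter(MapType(node.Type), node.Name);
            _parameters.Add(node, mapped);
        }
        return mapped;
    }

    protected override Expression VisitLambda<T>(Expression<T> node)
    {
        // Map this lambda's own parameters first, then rebuild it with them.
        var parameters = node.Parameters.Select(p => (ParameterExpression)VisitParameter(p)).ToList();
        return Expression.Lambda(Visit(node.Body), parameters);
    }

    protected override Expression VisitMember(MemberExpression node)
    {
        var target = Visit(node.Expression);
        if (target == node.Expression)
            return node; // static member, or an instance that was not remapped
        // Assumes the target type has a member with the same name.
        return Expression.PropertyOrField(target, node.Member.Name);
    }

    protected override Expression VisitMethodCall(MethodCallExpression node)
    {
        if (node.Object == null && node.Method.IsGenericMethod)
        {
            // Static generic call such as Enumerable.Any<DomainTag>: re-close it over mapped types.
            var arguments = Visit(node.Arguments);
            var typeArguments = node.Method.GetGenericArguments().Select(MapType).ToArray();
            var method = node.Method.GetGenericMethodDefinition().MakeGenericMethod(typeArguments);
            return Expression.Call(method, arguments);
        }
        return base.VisitMethodCall(node);
    }
}
```

Keying by identity also lets a nested lambda refer to the outer parameter. The trade-off compared with consolidating parameters by type is that an expression borrowed from elsewhere (such as a member-mapping lambda) brings its own parameter, which would then have to be substituted explicitly.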

See the dotnetfiddle for the fully working solution.

If anyone has a better or more elegant solution, please add an answer for me to mark.

Jake Aitchison answered Oct 16 '22 08:10


You have already solved the concrete issue, so I can't say whether what I'm going to propose is better or more elegant, but it is certainly more generic (it removes the concrete types, properties and assumptions), so it can be reused to translate similar expressions between different model types.

Here is the code:

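using System;
using System.Collections.Generic;
using System.Linq;
using System.Linq.Expressions;
using System.Reflection;

public class ExpressionMap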
{
    private Dictionary<Type, Type> typeMap = new Dictionary<Type, Type>();
    private Dictionary<MemberInfo, Expression> memberMap = new Dictionary<MemberInfo, Expression>();
    public ExpressionMap Add<TFrom, TTo>()
    {
        typeMap.Add(typeof(TFrom), typeof(TTo));
        return this;
    }
    public ExpressionMap Add<TFrom, TFromMember, TTo, TToMember>(Expression<Func<TFrom, TFromMember>> from, Expression<Func<TTo, TToMember>> to)
    {
        memberMap.Add(((MemberExpression)from.Body).Member, to.Body);
        return this;
    }
    public Expression Map(Expression source) => new MapVisitor { map = this }.Visit(source);

    private class MapVisitor : ExpressionVisitor
    {
        public ExpressionMap map;
        private Dictionary<Type, ParameterExpression> parameterMap = new Dictionary<Type, ParameterExpression>();
        protected override Expression VisitLambda<T>(Expression<T> node)
        {
            return Expression.Lambda(Visit(node.Body), node.Parameters.Select(Map));
        }
        protected override Expression VisitParameter(ParameterExpression node) => Map(node);
        protected override Expression VisitMember(MemberExpression node)
        {
            var expression = Visit(node.Expression);
            if (expression == node.Expression)
                return node;
            Expression mappedMember;
            if (map.memberMap.TryGetValue(node.Member, out mappedMember))
                return Visit(mappedMember);
            return Expression.PropertyOrField(expression, node.Member.Name);
        }
        protected override Expression VisitMethodCall(MethodCallExpression node)
        {
            if (node.Object == null && node.Method.IsGenericMethod)
            {
                // Static generic method
                var arguments = Visit(node.Arguments);
                var genericArguments = node.Method.GetGenericArguments().Select(Map).ToArray();
                var method = node.Method.GetGenericMethodDefinition().MakeGenericMethod(genericArguments);
                return Expression.Call(method, arguments);
            }
            return base.VisitMethodCall(node);
        }
        private Type Map(Type type)
        {
            Type mappedType;
            return map.typeMap.TryGetValue(type, out mappedType) ? mappedType : type;
        }
        private ParameterExpression Map(ParameterExpression parameter)
        {
            var mappedType = Map(parameter.Type);
            ParameterExpression mappedParameter;
            if (!parameterMap.TryGetValue(mappedType, out mappedParameter))
                parameterMap.Add(mappedType, mappedParameter = Expression.Parameter(mappedType, parameter.Name));
            return mappedParameter;
        }
    }
}

And the usage for your concrete example:

public IEnumerable<DomainColour> GetWhere(Expression<Func<DomainColour, bool>> predicate)
{
    var map = new ExpressionMap()
        .Add<DomainColour, DtoColour>()
        .Add((DomainColour c) => c.People, (DtoColour c) => c.FavouriteColours.Select(fc => fc.Person))
        .Add<DomainPerson, DtoPerson>();
    var mappedPredicate = ((Expression<Func<DtoColour, bool>>)map.Map(predicate));
    return Colours.Where(mappedPredicate.Compile()).Select(c => new DomainColour(c.Name)).ToList();
}

As you can see, it lets you define a simple mapping from one type to another and, optionally, from a member of one type to a member or expression of another type (as long as they are compatible), using a "fluent" syntax with lambda expressions. Members without a specified mapping are mapped by name, as in the original code.
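The member mapping relies on two things. First, `((MemberExpression)from.Body).Member` yields the same `MemberInfo` that appears in member accesses inside the predicate, so the dictionary lookup in `VisitMember` can match. Second, the `TFromMember` type parameter keeps the selector body a plain member access. Below is a minimal check of both, assuming a hypothetical `Sample` class and `MemberKeyDemo` helper. If the selector were typed `Func<TFrom, object>` instead, an `int` member would be wrapped in a `Convert` node and the cast would fail.

```csharp
using System;
using System.Linq.Expressions;
using System.Reflection;

// Hypothetical type used only for this sketch.
public class Sample
{
    public int Id { get; set; }
    public string Name { get; set; }
}

// Hypothetical helper; KeyOf mirrors the cast in ExpressionMap.Add.
public static class MemberKeyDemo
{
    public static MemberInfo KeyOf<TFrom, TMember>(Expression<Func<TFrom, TMember>> selector) =>
        ((MemberExpression)selector.Body).Member;

    public static bool Run()
    {
        // With an exact TMember (string), the selector body is a plain member access.
        MemberInfo key = KeyOf((Sample s) => s.Name);

        // The member access inside an unrelated predicate yields an equal MemberInfo,
        // so a Dictionary<MemberInfo, Expression> lookup keyed on it would match.
        Expression<Func<Sample, bool>> predicate = s => s.Name == "x";
        var access = (MemberExpression)((BinaryExpression)predicate.Body).Left;
        bool sameKey = key.Equals(access.Member);

        // With an object-typed selector, the int member is boxed, so the body is a
        // Convert node and the (MemberExpression) cast in KeyOf would throw.
        Expression<Func<Sample, object>> boxed = s => s.Id;
        bool boxedIsConvert = boxed.Body.NodeType == ExpressionType.Convert;

        return sameKey && boxedIsConvert;
    }
}
```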

Once the mappings are defined, the actual processing is of course done by a custom ExpressionVisitor, similar to yours. The difference is that it maps and consolidates ParameterExpressions by type, and it re-binds every static generic method to the mapped type arguments, so it should also work with Queryable and similar APIs.
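For `Queryable`, the calls look the same from the visitor's point of view (static generic methods). The one difference is that each lambda argument arrives wrapped in a `Quote` node, and the base `ExpressionVisitor.VisitUnary` visits through it. A quick check with a hypothetical `QueryableShapeDemo` class:

```csharp
using System;
using System.Linq;
using System.Linq.Expressions;

// Hypothetical demo class; not part of the answer's code.
public static class QueryableShapeDemo
{
    public static bool Run()
    {
        IQueryable<int> query = new[] { 1, 2, 3 }.AsQueryable().Where(x => x > 1);
        var call = (MethodCallExpression)query.Expression;

        // Queryable.Where<int> is a static generic method, so a
        // `node.Object == null && node.Method.IsGenericMethod` branch would handle it.
        bool staticGeneric = call.Object == null
                             && call.Method.IsGenericMethod
                             && call.Method.DeclaringType == typeof(Queryable);

        // Unlike Enumerable calls, the predicate is wrapped in a Quote node.
        bool quoted = call.Arguments[1].NodeType == ExpressionType.Quote;

        return staticGeneric && quoted;
    }
}
```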

Ivan Stoev answered Oct 16 '22 07:10