
How can I replace a type parameter in an expression tree?

I'd like to be able to write a generic expression that a user can use to describe how he wants to do a conversion across a family of types.

The expression might look something like:

Expression<Func<PlaceHolder,object>> sample =
    x => (object)EqualityComparer<PlaceHolder>.Default.GetHashCode(x);

I would like to convert it into:

Expression<Func<Foo,object>> sample =
    x => (object)EqualityComparer<Foo>.Default.GetHashCode(x);
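As a point of comparison, when the concrete type is only available as a System.Type at runtime, the desired target expression can also be assembled by hand with the System.Linq.Expressions factory methods. This is a minimal sketch that hard-codes the single hashing rule rather than rewriting a user-supplied expression; the HashExpressions class and BuildHashLambda method are hypothetical names.

```csharp
using System;
using System.Collections.Generic;
using System.Linq.Expressions;
using System.Reflection;

public static class HashExpressions
{
    // Builds the equivalent of x => (object)EqualityComparer<T>.Default.GetHashCode(x)
    // for a type T that is only known at runtime.
    public static LambdaExpression BuildHashLambda(Type type)
    {
        ParameterExpression x = Expression.Parameter(type, "x");
        Type comparerType = typeof(EqualityComparer<>).MakeGenericType(type);

        // EqualityComparer<T>.Default is a static property, so the instance expression is null.
        MemberExpression comparer = Expression.Property(null, comparerType, "Default");

        // Pick the GetHashCode(T) overload rather than the inherited object.GetHashCode().
        MethodInfo getHashCode = comparerType.GetMethod("GetHashCode", new[] { type });

        MethodCallExpression call = Expression.Call(comparer, getHashCode, x);
        UnaryExpression body = Expression.Convert(call, typeof(object)); // boxes the int hash code

        Type funcType = typeof(Func<,>).MakeGenericType(type, typeof(object));
        return Expression.Lambda(funcType, body, x);
    }

    public static void Main()
    {
        LambdaExpression lambda = BuildHashLambda(typeof(string));
        Console.WriteLine(lambda);

        object hash = lambda.Compile().DynamicInvoke("hello");
        Console.WriteLine(hash.Equals(EqualityComparer<string>.Default.GetHashCode("hello")));
    }
}
```

This avoids the placeholder entirely, but only because the rule is fixed in code; it does not help when the rule itself comes from the user.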

I can visit the expression and replace the PlaceHolder parameter with x, but then I can't resolve the generic type call.

The expression is user provided, and you can't assign a generic method to an expression.

The end result always returns an object, and the expression is always of the form T => object. I will compile a new expression for each type that the default rule has to replace.
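For the part where the generic type call cannot be resolved, one possible alternative to matching candidate methods by name and parameter list is to compare metadata tokens: a method on a constructed generic type reports the same MemberInfo.MetadataToken (within the same Module) as the corresponding method on any other instantiation of that generic type definition. The sketch below assumes that behaviour; Placeholder, Rebinder, SubstituteType and RebindMethod are hypothetical names, and generic methods would additionally need GetGenericMethodDefinition and MakeGenericMethod.

```csharp
using System;
using System.Collections.Generic;
using System.Linq;
using System.Reflection;

// Hypothetical placeholder type, playing the role of TEnum / PlaceHolder in the question.
public enum Placeholder : long { }

public static class Rebinder
{
    // Recursively swaps `placeholder` for `actual` inside arrays, by-ref types and generic arguments.
    public static Type SubstituteType(Type type, Type placeholder, Type actual)
    {
        if (type == placeholder) return actual;

        if (type.IsByRef)
            return SubstituteType(type.GetElementType(), placeholder, actual).MakeByRefType();

        if (type.IsArray)
        {
            Type element = SubstituteType(type.GetElementType(), placeholder, actual);
            int rank = type.GetArrayRank();
            return rank == 1 ? element.MakeArrayType() : element.MakeArrayType(rank);
        }

        if (type.IsGenericType && !type.IsGenericTypeDefinition)
        {
            Type[] args = type.GetGenericArguments()
                .Select(a => SubstituteType(a, placeholder, actual))
                .ToArray();
            return type.GetGenericTypeDefinition().MakeGenericType(args);
        }

        return type;
    }

    // Finds the method on `newDeclaringType` that comes from the same metadata definition as `method`.
    // Assumes all instantiations of a generic type share their methods' metadata tokens.
    // Generic methods are not handled here.
    public static MethodInfo RebindMethod(MethodInfo method, Type newDeclaringType)
    {
        const BindingFlags all = BindingFlags.Public | BindingFlags.NonPublic |
                                 BindingFlags.Static | BindingFlags.Instance |
                                 BindingFlags.DeclaredOnly;

        return newDeclaringType.GetMethods(all)
            .Single(m => m.MetadataToken == method.MetadataToken && m.Module == method.Module);
    }

    public static void Main()
    {
        // The method the user's expression calls on the placeholder type...
        MethodInfo original = typeof(EqualityComparer<Placeholder>)
            .GetMethod("GetHashCode", new[] { typeof(Placeholder) });

        // ...rebound to a concrete type chosen at runtime (DayOfWeek is just an example).
        Type newType = SubstituteType(original.DeclaringType, typeof(Placeholder), typeof(DayOfWeek));
        MethodInfo rebound = RebindMethod(original, newType);

        Console.WriteLine(rebound.DeclaringType);
        Console.WriteLine(rebound.GetParameters()[0].ParameterType);
    }
}
```

Because the token identifies the exact overload the C# compiler picked, this approach would also sidestep the `C<T>.Foo` overload ambiguity described in the code comments of the existing implementation.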

Here is my existing code which works, but it seems very complicated.

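using System;
using System.Collections;
using System.Collections.Generic;
using System.Collections.ObjectModel;
using System.Linq;
using System.Linq.Expressions;
using System.Reflection;

// ReSharper disable once InconsistentNaming
// By design this is supposed to look like a generic parameter.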
public enum TEnum : long
{
}

internal sealed class EnumReplacer : ExpressionVisitor
{
    private Type ReplacePlaceHolder(Type type)
    {
        if (type.IsByRef)
        {
            return ReplacePlaceHolder(type.GetElementType()).MakeByRefType();
        }

        if (type.IsArray)
        {
            // expressionTrees can only deal with 1d arrays.
            return ReplacePlaceHolder(type.GetElementType()).MakeArrayType();
        }

        if (type.IsGenericType)
        {
            var typeDef = type.GetGenericTypeDefinition();
            var args = Array.ConvertAll(type.GetGenericArguments(), t => ReplacePlaceHolder(t));
            return typeDef.MakeGenericType(args);
        }

        if (type == typeof(TEnum))
        {
            return _enumParam.Type;
        }

        return type;
    }

    private MethodBase ReplacePlaceHolder(MethodBase method)
    {
        var newCandidate = method;
        var currentParams = method.IsGenericMethod ? ((MethodInfo)method).GetGenericMethodDefinition().GetParameters() : method.GetParameters();
        // ReSharper disable once PossibleNullReferenceException
        if (method.DeclaringType.IsGenericType)
        {
            var newType = ReplacePlaceHolder(method.DeclaringType);
            var methodCandidates = newType.GetMembers()
                .OfType<MethodBase>()
                .Where(x => x.Name == method.Name
                            && x.IsStatic == method.IsStatic
                            && x.IsGenericMethod == method.IsGenericMethod).ToArray();

            // Grab the first method that matches. Not 100% correct, but close enough.
            // Yes, an evil person could define a class like this:
            // class C<T>{
            //     public object Foo<T>(T b){return null;}
            //     public object Foo(PlaceHolderEnum b){return new object();}
            // }
            // My code would prefer the former, whereas C# 6 prefers the latter.
            newCandidate = methodCandidates.First(m => TestParameters(m, currentParams));
        }

        if (method.IsGenericMethod)
        {
            var genericArgs = method.GetGenericArguments();
            genericArgs = Array.ConvertAll(genericArgs, temp => ReplacePlaceHolder(temp));
            newCandidate = ((MethodInfo)newCandidate).GetGenericMethodDefinition().MakeGenericMethod(genericArgs);
        }
        return newCandidate;
    }
    private Expression ReplacePlaceHolder(MethodBase method, Expression target, ReadOnlyCollection<Expression> arguments)
    {
        // no point in not doing this.
        var newArgs = Visit(arguments);

        if (target != null)
        {
            target = Visit(target);
        }

        var newCandidate = ReplacePlaceHolder(method);

        MethodInfo info = newCandidate as MethodInfo;
        if (info != null)
        {
            return Expression.Call(target, info, newArgs);
        }
        return Expression.New((ConstructorInfo)newCandidate, newArgs);
    }

    private bool TestParameters(MethodBase candidate, ParameterInfo[] currentParams)
    {
        var candidateParams = candidate.GetParameters();
        if (candidateParams.Length != currentParams.Length) return false;
        for (int i = 0; i < currentParams.Length; i++)
        {
            // the names should match.
            if (currentParams[i].Name != candidateParams[i].Name) return false;

            var curType = currentParams[i].ParameterType;
            var candidateType = candidateParams[i].ParameterType;

            // Either they are the same generic type arg, or they are the same type after replacements.
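            // GenericParameterPosition throws unless the type is a generic parameter, so check both sides.
            if (!((curType.IsGenericParameter &&
                   candidateType.IsGenericParameter &&
                   curType.GenericParameterPosition == candidateType.GenericParameterPosition)
                  || ReplacePlaceHolder(curType) == candidateType))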
            {
                return false;
            }
        }
        return true;
    }

    private readonly ParameterExpression _enumParam;

    public EnumReplacer(ParameterExpression enumParam)
    {
        _enumParam = enumParam;
    }

    protected override Expression VisitParameter(ParameterExpression node)
    {
        if (node.Type == typeof(TEnum))
        {
            return _enumParam;
        }

        if (node.Type == typeof(TypeCode))
        {
            return Expression.Constant(Type.GetTypeCode(_enumParam.Type));
        }

        return base.VisitParameter(node);
    }

    protected override Expression VisitUnary(UnaryExpression node)
    {
        if (node.NodeType == ExpressionType.Convert || node.NodeType == ExpressionType.ConvertChecked)
        {
            var t = ReplacePlaceHolder(node.Type);
            // this isn't perfect. The compiler loves inserting random casts. To be protective and offer the most range, TEnum should be a long.
            var method = node.Method == null ? null : ReplacePlaceHolder(node.Method);
            return node.NodeType == ExpressionType.ConvertChecked
                ? Expression.ConvertChecked(Visit(node.Operand), t, (MethodInfo) method)
                : Expression.Convert(Visit(node.Operand), t, (MethodInfo) method);
        }
        if (node.Operand.Type == typeof(TEnum))
        {
            var operand = Visit(node.Operand);

            return node.Update(operand);
        }

        return base.VisitUnary(node);
    }

    private MemberInfo ReplacePlaceHolder(MemberInfo member)
    {
        if (member.MemberType == MemberTypes.Method || member.MemberType == MemberTypes.Constructor)
        {
            return ReplacePlaceHolder((MethodBase) member);
        }
        var newType = ReplacePlaceHolder(member.DeclaringType);
        var newMember = newType.GetMembers().First(x => x.Name == member.Name);
        return newMember;
    }

    protected override Expression VisitNewArray(NewArrayExpression node)
    {
        var children = Visit(node.Expressions);
        // Despite returning T[], it expects T.
        var type = ReplacePlaceHolder(node.Type.GetElementType());
        return Expression.NewArrayInit(type, children);
    }

    protected override MemberMemberBinding VisitMemberMemberBinding(MemberMemberBinding node)
    {
        var newMember = ReplacePlaceHolder(node.Member);
        var bindings = node.Bindings.Select(x => VisitMemberBinding(x));
        return Expression.MemberBind(newMember, bindings);
    }

    protected override MemberListBinding VisitMemberListBinding(MemberListBinding node)
    {
        var prop = ReplacePlaceHolder(node.Member);
        var inits = node.Initializers.Select(x => VisitElementInit(x));
        return Expression.ListBind(prop, inits);
    }

    protected override Expression VisitMethodCall(MethodCallExpression node)
    {
        return ReplacePlaceHolder(node.Method, node.Object, node.Arguments);
    }

    protected override MemberAssignment VisitMemberAssignment(MemberAssignment node)
    {
        var expr = Visit(node.Expression);
        var prop = ReplacePlaceHolder(node.Member);
        return Expression.Bind(prop, expr);
    }

    protected override ElementInit VisitElementInit(ElementInit node)
    {
        var method = ReplacePlaceHolder(node.AddMethod);
        var args = Visit(node.Arguments);
        return Expression.ElementInit((MethodInfo)method, args);
    }
    protected override Expression VisitNew(NewExpression node)
    {
        return ReplacePlaceHolder(node.Constructor, null, node.Arguments);
    }

    protected override Expression VisitConstant(ConstantExpression node)
    {
        // replace typeof expression
        if (node.Type == typeof(Type) && (Type)node.Value == typeof(TEnum))
        {
            return Expression.Constant(_enumParam.Type);
        }
        // explicit usage of default(TEnum) or (TEnum)456
        if (node.Type == typeof(TEnum))
        {
            return Expression.Constant(Enum.ToObject(_enumParam.Type, node.Value));
        }

        return base.VisitConstant(node);
    }
}

Usage looks like this:

class Program
{
    public class Holder
    {
        public int Foo { get; set; }
    }
    public class Foo<T1,T2> : IEnumerable
    {
        public object GenericMethod<TM, TM2>(TM2 blarg) => blarg.ToString();

        public IList<Foo<T1, T2>> T { get; set; } = new List<Foo<T1, T2>>();

        public T1 Prop { get; set; }
        public void Add(int i) { }
        public Holder Holder { get; set; } = new Holder {};

        public IEnumerator GetEnumerator()
        {
            throw new NotImplementedException();
        }
    }

    public enum LongEnum:ulong
    {
    }

    static void Main(string[] args)
    {
        Expression<Func<TEnum, TypeCode, object>> evilTest = (x,t) =>
                TypeCode.UInt64 == t
                    ? (object)new Dictionary<TEnum, TypeCode>().TryGetValue(checked((x - 407)), out t)
                    : new Foo<string, TEnum> { Holder = {Foo =6}, T = new []
                    {
                        new Foo<string, TEnum>
                        {
                            T = {
                                new Foo<string, TEnum>{1,2,3,4,5,6,7,8,9,10,11,12}
                            }
                        },
                        new Foo<string, TEnum>
                        {
                            Prop = $"What up hello? {args}"
                        }
                    }}.GenericMethod<string, TEnum>(x);
        Console.WriteLine(evilTest);
        var p = Expression.Parameter(typeof(LongEnum), "long");
        var expressionBody = new EnumReplacer(p).Visit(evilTest.Body);

        var q = Expression.Lambda<Func<LongEnum, object>>(expressionBody, p);
        var func = q.Compile();
        var res = func.Invoke((LongEnum)1234567890123UL);
        Console.WriteLine(res);
    }
}

asked Nov 08 '22 09:11 by Michael B


1 Answer

Modifying an existing expression tree just to change a type used within it seems like a fool's errand. You're going to run into more problems with that, particularly if you ever try to do operations on objects of that type.

But why are you getting all caught up in changing an existing tree? You're trying to parameterize a type which calls for generics. Just create a generic method (where the type is your parameter) that returns the desired expression.

Expression<Func<T, object>> CreateConverter<T>() =>
    (T x) => EqualityComparer<T>.Default.GetHashCode(x);

There's no need to create fake placeholder types; the generic type parameter is your placeholder.
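If the type is only known at runtime, as in the question where a converter is compiled per type, the generic factory can still be reached through MethodInfo.MakeGenericMethod, with the compiled delegate cached per type. A minimal sketch; the Converters class, the Wrap and GetConverter helpers, the cache, and the untyped Func<object, object> signature are all assumptions rather than part of the answer.

```csharp
using System;
using System.Collections.Concurrent;
using System.Collections.Generic;
using System.Linq.Expressions;
using System.Reflection;

public static class Converters
{
    // The answer's factory: T itself is the placeholder.
    public static Expression<Func<T, object>> CreateConverter<T>() =>
        (T x) => EqualityComparer<T>.Default.GetHashCode(x);

    // Hypothetical cache: one compiled converter per runtime type.
    private static readonly ConcurrentDictionary<Type, Func<object, object>> Cache =
        new ConcurrentDictionary<Type, Func<object, object>>();

    private static readonly MethodInfo WrapMethod =
        typeof(Converters).GetMethod(nameof(Wrap), BindingFlags.NonPublic | BindingFlags.Static);

    // Compiles CreateConverter<T>() and wraps it so callers can pass plain objects.
    private static Func<object, object> Wrap<T>()
    {
        Func<T, object> typed = CreateConverter<T>().Compile();
        return o => typed((T)o);
    }

    // Closes Wrap<T> over the runtime type once, then reuses the cached delegate.
    public static Func<object, object> GetConverter(Type type) =>
        Cache.GetOrAdd(type, t => (Func<object, object>)WrapMethod.MakeGenericMethod(t).Invoke(null, null));

    public static void Main()
    {
        Func<object, object> convert = GetConverter(typeof(DayOfWeek));
        Console.WriteLine(convert(DayOfWeek.Friday));
    }
}
```

The reflection cost is paid once per type; after that, each call goes straight to the compiled delegate.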

If you need this to be pluggable, place the method in an interface and users would provide implementations that do the conversion.
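A sketch of that pluggable shape, assuming a hypothetical IConverterProvider interface with a generic CreateConverter<T>() method. Each user-supplied rule is written once against T, and the caller closes it over whatever type it needs; the two provider classes are illustrative only.

```csharp
using System;
using System.Collections.Generic;
using System.Linq.Expressions;

// Hypothetical plug-in contract: an implementation describes its conversion once, generically.
public interface IConverterProvider
{
    Expression<Func<T, object>> CreateConverter<T>();
}

// One user-supplied rule: hash the value with the default comparer.
public sealed class HashCodeConverterProvider : IConverterProvider
{
    public Expression<Func<T, object>> CreateConverter<T>() =>
        x => EqualityComparer<T>.Default.GetHashCode(x);
}

// Another rule: convert the value to a string (Convert.ToString(object) returns "" for null).
public sealed class ToStringConverterProvider : IConverterProvider
{
    public Expression<Func<T, object>> CreateConverter<T>() =>
        x => Convert.ToString(x);
}

public static class Demo
{
    public static void Main()
    {
        IConverterProvider provider = new HashCodeConverterProvider();

        // The same provider instance serves every type.
        Func<int, object> forInt = provider.CreateConverter<int>().Compile();
        Func<string, object> forString = provider.CreateConverter<string>().Compile();

        Console.WriteLine(forInt(42));
        Console.WriteLine(forString("hello"));
    }
}
```

Putting the type parameter on the method rather than on the interface is what lets a single provider instance serve every type.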

answered Nov 14 '22 21:11 by Jeff Mercado