Look if a method is called inside a method using reflection

I'm working with reflection and currently have a MethodBody. How do I check if a specific method is called inside the MethodBody?

Assembly assembly = Assembly.Load("Module1");
Type type = assembly.GetType("Module1.ModuleInit");
MethodInfo mi = type.GetMethod("Initialize");
MethodBody mb = mi.GetMethodBody();
Peter asked Apr 21 '11 08:04

1 Answer

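Use Mono.Cecil. System.Reflection's MethodBody only hands you the raw IL bytes (GetILAsByteArray()), which you would have to decode yourself, whereas Cecil exposes the instructions and their operands as objects. It is a single standalone assembly that works on Microsoft .NET as well as Mono. (I think I used version 0.6 or thereabouts back when I wrote the code below.)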

Say you have a number of assemblies

IEnumerable<AssemblyDefinition> assemblies;

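Get these using AssemblyFactory (AssemblyFactory.GetAssembly(path) loads one from disk; newer Cecil versions use AssemblyDefinition.ReadAssembly(path) instead).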

The following snippet enumerates all usages of methods in all types of these assemblies:

var methodUsages = assemblies
            .SelectMany(assembly => assembly.MainModule.Types.Cast<TypeDefinition>())
            .SelectMany(type => type.Methods.Cast<MethodDefinition>())
            .Where(method => null != method.Body) // skip methods without a body (e.g. abstract ones)
            .SelectMany(method => method.Body.Instructions.Cast<Instruction>())
            .Select(instr => instr.Operand)       // operand of each IL instruction
            .OfType<MethodReference>();           // keep only operands that refer to a method

This returns every method reference that appears as an instruction operand. Besides actual calls, that includes references made for reflection, for delegates, or to construct expressions (ldtoken, ldftn and the like), which may or may not ever be executed. As such, this is probably not very useful on its own, except to show what can be done with the Cecil API without too much effort :)

Note that this sample assumes a somewhat older version of Cecil (the one in mainstream mono versions). Newer versions are

  • more succinct (by using strong typed generic collections)
  • faster
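
For comparison, here is a minimal sketch of the same query against a newer Cecil (0.9 or later), assuming the assemblies are loaded from file paths. The names MethodUsageScanner and FindMethodUsages are made up for illustration; AssemblyDefinition.ReadAssembly, GetTypes() and HasBody are the newer API's counterparts of what the snippet above does by hand.

```csharp
using System.Collections.Generic;
using System.Linq;
using Mono.Cecil;

static class MethodUsageScanner // hypothetical name, not part of Cecil
{
    // Enumerates every method reference used as an instruction operand
    // in the given assemblies (calls, but also ldftn/ldtoken and the like).
    public static IEnumerable<MethodReference> FindMethodUsages(IEnumerable<string> assemblyPaths)
    {
        return assemblyPaths
            .Select(path => AssemblyDefinition.ReadAssembly(path))
            .SelectMany(assembly => assembly.MainModule.GetTypes()) // also walks nested types
            .SelectMany(type => type.Methods)                       // typed collections, no Cast<> needed
            .Where(method => method.HasBody)                        // skip abstract/extern methods
            .SelectMany(method => method.Body.Instructions)
            .Select(instr => instr.Operand)
            .OfType<MethodReference>();
    }
}
```

The query is lazy, so an assembly is only read once the result is enumerated.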

Of course in your case you could have a single method reference as starting point. Say you want to detect when 'mytargetmethod' can actually be called directly inside 'startingpoint':

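MethodReference startingpoint; // get it somewhere using Cecil
MethodReference mytargetmethod; // what you are looking for

bool isCalled = startingpoint
    .GetOriginalMethod() // jump to original (for generics e.g.)
    .Resolve()           // get the definition from the IL image
    .Body.Instructions.Cast<Instruction>()
    .Where(i => i.OpCode == OpCodes.Call || i.OpCode == OpCodes.Callvirt) // static and virtual calls
    .Select(i => i.Operand).OfType<MethodReference>()
    // compare by name: operands are usually distinct MethodReference instances,
    // so reference equality (==) would rarely match; note that this ignores overloads
    .Any(m => m.DeclaringType.FullName == mytargetmethod.DeclaringType.FullName
           && m.Name == mytargetmethod.Name);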
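
With a newer Cecil, a direct-call check for the method from the question might look like the sketch below. DirectCallCheck and CallsDirectly are made-up names, and matching on FullName is a simplification that may need refining for generic instantiations.

```csharp
using System.Linq;
using Mono.Cecil;
using Mono.Cecil.Cil;

static class DirectCallCheck // hypothetical helper, not part of Cecil
{
    // True if 'caller' contains a call or callvirt whose operand names 'target'.
    public static bool CallsDirectly(MethodDefinition caller, MethodReference target)
    {
        if (!caller.HasBody)
            return false; // abstract or extern: nothing to inspect

        return caller.Body.Instructions
            .Where(i => i.OpCode.Code == Code.Call || i.OpCode.Code == Code.Callvirt)
            .Select(i => i.Operand)
            .OfType<MethodReference>()
            .Any(called => called.FullName == target.FullName);
    }
}
```

Usage, assuming the assembly from the question lives in a file called Module1.dll (the file name is a guess) and mytargetmethod has been obtained elsewhere:

```csharp
var module = ModuleDefinition.ReadModule("Module1.dll"); // assumed file name
var type = module.GetType("Module1.ModuleInit");
var initialize = type.Methods.First(m => m.Name == "Initialize");
bool isCalled = DirectCallCheck.CallsDirectly(initialize, mytargetmethod);
```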

Call Tree Search

Here is a working snippet that lets you recursively search through (selected) methods that call each other, directly or indirectly.

using System;
using System.Collections;
using System.Collections.Generic;
using System.Linq;
using Mono.Cecil;
using Mono.Cecil.Cil;

namespace StackOverflow
{
    /*
     * breadth-first lazy search across a subset of the call tree rooting in startingPoint
     * 
     * methodSelect selects the methods to recurse into
     * resultGen generates the result objects to be returned by the enumerator
     * 
     */
    class CallTreeSearch<T> : BaseCodeVisitor, IEnumerable<T> where T : class
    {
        private readonly Func<MethodReference, bool> _methodSelect;
        private readonly Func<Instruction, Stack<MethodReference>, T> _transform;

        private readonly IEnumerable<MethodDefinition> _startingPoints;
        private readonly IDictionary<MethodDefinition, Stack<MethodReference>> _chain = new Dictionary<MethodDefinition, Stack<MethodReference>>();
        private readonly ICollection<MethodDefinition> _seen = new HashSet<MethodDefinition>(new CompareMembers<MethodDefinition>());
        private readonly ICollection<T> _results = new HashSet<T>();
        private Stack<MethodReference> _currentStack;

        private const int InfiniteRecursion = -1;
        private readonly int _maxrecursiondepth;
        private bool _busy;

        public CallTreeSearch(IEnumerable<MethodDefinition> startingPoints,
                              Func<MethodReference, bool> methodSelect,
                              Func<Instruction, Stack<MethodReference>, T> resultGen)
            : this(startingPoints, methodSelect, resultGen, InfiniteRecursion)
        {

        }

        public CallTreeSearch(IEnumerable<MethodDefinition> startingPoints,
                              Func<MethodReference, bool> methodSelect,
                              Func<Instruction, Stack<MethodReference>, T> resultGen,
                              int maxrecursiondepth)
        {
            _startingPoints = startingPoints.ToList();

            _methodSelect = methodSelect;
            _maxrecursiondepth = maxrecursiondepth;
            _transform = resultGen;
        }

        public override void VisitMethodBody(MethodBody body)
        {
            _seen.Add(body.Method); // avoid infinite recursion
            base.VisitMethodBody(body);
        }

        public override void VisitInstructionCollection(InstructionCollection instructions)
        {
            foreach (Instruction instr in instructions)
                VisitInstruction(instr);

            base.VisitInstructionCollection(instructions);
        }

        public override void VisitInstruction(Instruction instr)
        {
            T result = _transform(instr, _currentStack);
            if (result != null)
                _results.Add(result);

            var methodRef = instr.Operand as MethodReference; // TODO select calls only?
            if (methodRef != null && _methodSelect(methodRef))
            {
                var resolve = methodRef.Resolve();
                if (null != resolve && !(_chain.ContainsKey(resolve) || _seen.Contains(resolve)))
                    _chain.Add(resolve, new Stack<MethodReference>(_currentStack.Reverse()));
            }

            base.VisitInstruction(instr);
        }

        public IEnumerator<T> GetEnumerator()
        {
            lock (this) // not multithread safe
            {
                if (_busy)
                    throw new InvalidOperationException("CallTreeSearch enumerator is not reentrant");
                _busy = true;

                try
                {
                    int recursionLevel = 0;
                    ResetToStartingPoints();

                    while (_chain.Count > 0 &&
                           ((InfiniteRecursion == _maxrecursiondepth) || recursionLevel++ <= _maxrecursiondepth))
                    {

                        // swapout the collection because Visitor will modify
                        var clone = new Dictionary<MethodDefinition, Stack<MethodReference>>(_chain);
                        _chain.Clear();

                        foreach (var call in clone.Where(call => HasBody(call.Key)))
                        {
//                          Console.Error.Write("\rCallTreeSearch: level #{0}, scanning {1,-20}\r", recursionLevel, call.Key.Name + new string(' ',21));
                            _currentStack = call.Value;
                            _currentStack.Push(call.Key);
                            try
                            {
                                _results.Clear();
                                call.Key.Body.Accept(this); // grows _chain and _results
                            }
                            finally
                            {
                                _currentStack.Pop();
                            }
                            _currentStack = null;

                            foreach (var result in _results)
                                yield return result;
                        }
                    }
                }
                finally
                {
                    _busy = false;
                }
            }
        }

        private void ResetToStartingPoints()
        {
            _chain.Clear();
            _seen.Clear();
            foreach (var startingPoint in _startingPoints)
            {
                _chain.Add(startingPoint, new Stack<MethodReference>());
                _seen.Add(startingPoint);
            }
        }

        private static bool HasBody(MethodDefinition methodDefinition)
        {
            return !(methodDefinition.IsAbstract || methodDefinition.Body == null);
        }

        IEnumerator IEnumerable.GetEnumerator()
        {
            return GetEnumerator();
        }
    }

    internal class CompareMembers<T> : IComparer<T>, IEqualityComparer<T>
        where T: class, IMemberReference
    {
        public int Compare(T x, T y)
        { return StringComparer.InvariantCultureIgnoreCase.Compare(KeyFor(x), KeyFor(y)); }

        public bool Equals(T x, T y)
        { return KeyFor(x).Equals(KeyFor(y)); }

        // note: the key ignores parameter types, so overloads of a method compare as equal
        private static string KeyFor(T mr)
        { return null == mr ? "" : String.Format("{0}::{1}", mr.DeclaringType.FullName, mr.Name); }

        public int GetHashCode(T obj)
        { return KeyFor(obj).GetHashCode(); }
    }
}

Notes

  • do some error handling around Resolve(), which can fail or return null when the target cannot be found (I have an extension method TryResolve() for the purpose)
  • optionally select usages of MethodReferences in a call operation (call, calli, callvirt ...) only (see the // TODO in VisitInstruction)
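
Both notes could be covered by small extension methods along the following lines. This is only a sketch against a newer Cecil (0.9 or later), not the TryResolve() mentioned above, whose implementation is not shown here; the class name CecilExtensions and the choice to count newobj as a call are assumptions.

```csharp
using Mono.Cecil;
using Mono.Cecil.Cil;

static class CecilExtensions // hypothetical helper class
{
    // Returns null instead of throwing when the assembly that declares
    // the referenced method cannot be located.
    public static MethodDefinition TryResolve(this MethodReference reference)
    {
        try
        {
            return reference.Resolve();
        }
        catch (AssemblyResolutionException)
        {
            return null;
        }
    }

    // True for instructions that invoke their method operand, as opposed to
    // merely referencing it (ldftn, ldtoken, ...).
    public static bool IsCall(this Instruction instruction)
    {
        switch (instruction.OpCode.Code)
        {
            case Code.Call:
            case Code.Callvirt:
            case Code.Newobj: // constructor invocation
                return true;
            default:
                return false;
        }
    }
}
```

In VisitInstruction, the // TODO line would then only consider instr.Operand when instr.IsCall() is true, and methodRef.Resolve() would become methodRef.TryResolve().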

Typical usage (the extension methods below need to live in a static class):

public static IEnumerable<T> SearchCallTree<T>(this TypeDefinition startingClass,
                                               Func<MethodReference, bool> methodSelect,
                                               Func<Instruction, Stack<MethodReference>, T> resultFunc,
                                               int maxdepth)
    where T : class
{
    return new CallTreeSearch<T>(startingClass.Methods.Cast<MethodDefinition>(), methodSelect, resultFunc, maxdepth);
}

public static IEnumerable<T> SearchCallTree<T>(this MethodDefinition startingMethod,
                                               Func<MethodReference, bool> methodSelect,
                                               Func<Instruction, Stack<MethodReference>, T> resultFunc,
                                               int maxdepth)
    where T : class
{
    return new CallTreeSearch<T>(new[] { startingMethod }, methodSelect, resultFunc, maxdepth); 
}

// Actual usage:
private static IEnumerable<TypeUsage> SearchMessages(TypeDefinition uiType, bool onlyConstructions)
{
    return uiType.SearchCallTree(IsBusinessCall,
           (instruction, stack) => DetectRequestUsage(instruction, stack, onlyConstructions),
           -1); // maxdepth: -1 means no depth limit (InfiniteRecursion)
}

Note that completing a function like DetectRequestUsage to suit your needs is entirely up to you (edit: but see here). You can do whatever you want, and don't forget: you have the complete statically analyzed call stack at your disposal, so you can do pretty neat things with all that information!
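
As an illustration of what such a result function could look like (this is not DetectRequestUsage, just a hedged sketch with made-up names), the following returns a readable call chain whenever an instruction references a method with a given name, and null otherwise. CallTreeSearch skips null results.

```csharp
using System.Collections.Generic;
using System.Linq;
using Mono.Cecil;
using Mono.Cecil.Cil;

static class CallChains // hypothetical helper
{
    // Returns a readable call chain when 'instruction' references a method
    // named 'targetName', or null to report nothing for this instruction.
    public static string Describe(Instruction instruction,
                                  Stack<MethodReference> stack,
                                  string targetName)
    {
        var called = instruction.Operand as MethodReference;
        if (called == null || called.Name != targetName)
            return null;

        // A Stack<T> enumerates from the most recently pushed item down,
        // so this lists the current method first and the starting point last.
        var chain = new[] { called }.Concat(stack)
            .Select(m => m.DeclaringType.FullName + "::" + m.Name);
        return string.Join(" <- ", chain.ToArray());
    }
}
```

It would plug in as startingMethod.SearchCallTree(m => true, (instr, stack) => CallChains.Describe(instr, stack, "Save"), 5), where "Save" and the depth of 5 are placeholders.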

sehe answered Sep 20 '22 00:09