I've written a DSL and a compiler that generates a .NET expression tree from it. All expressions within the tree are side-effect-free and the expression is guaranteed to be a "non-statement" expression (no locals, loops, blocks etc.). (Edit: The tree may include literals, property accesses, standard operators and function calls, which may do fancy things like memoization internally but are externally side-effect-free.)
Now I would like to perform the "Common sub-expression elimination" optimization on it.
For example, given a tree corresponding to the C# lambda:
foo => (foo.Bar * 5 + foo.Baz * 2 > 7)
|| (foo.Bar * 5 + foo.Baz * 2 < 3)
|| (foo.Bar * 5 + 3 == foo.Xyz)
...I would like to generate the tree equivalent of the following (never mind that this drops some of the short-circuiting semantics):
foo =>
{
var local1 = foo.Bar * 5;
// Notice that this local depends on the first one.
var local2 = local1 + foo.Baz * 2;
// Notice that no unnecessary locals have been generated.
return local2 > 7 || local2 < 3 || (local1 + 3 == foo.Xyz);
}
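For reference, that target shape can be written directly against System.Linq.Expressions. C# lambda syntax cannot produce it, because a lambda with a statement body cannot be converted to an expression tree, so the block has to be assembled with the factory methods. A minimal sketch, assuming a Foo class with int properties Bar, Baz and Xyz; TargetTree and Build are illustrative names only:

```csharp
using System;
using System.Linq.Expressions;

public class Foo { public int Bar { get; set; } public int Baz { get; set; } public int Xyz { get; set; } }

public static class TargetTree
{
    // Builds, by hand, the tree a CSE pass should produce for the example lambda.
    public static Expression<Func<Foo, bool>> Build()
    {
        var foo = Expression.Parameter(typeof(Foo), "foo");
        var local1 = Expression.Variable(typeof(int), "local1");
        var local2 = Expression.Variable(typeof(int), "local2");
        Expression Prop(string name) => Expression.Property(foo, name);
        var body = Expression.Block(
            new[] { local1, local2 },
            // local1 = foo.Bar * 5
            Expression.Assign(local1, Expression.Multiply(Prop("Bar"), Expression.Constant(5))),
            // local2 = local1 + foo.Baz * 2 (depends on local1)
            Expression.Assign(local2, Expression.Add(local1, Expression.Multiply(Prop("Baz"), Expression.Constant(2)))),
            // The block's last expression is its value: local2 > 7 || local2 < 3 || local1 + 3 == foo.Xyz
            Expression.OrElse(
                Expression.OrElse(
                    Expression.GreaterThan(local2, Expression.Constant(7)),
                    Expression.LessThan(local2, Expression.Constant(3))),
                Expression.Equal(Expression.Add(local1, Expression.Constant(3)), Prop("Xyz"))));
        return Expression.Lambda<Func<Foo, bool>>(body, foo);
    }

    public static void Main()
    {
        var f = Build().Compile();
        Console.WriteLine(f(new Foo { Bar = 1, Baz = 2, Xyz = 3 })); // 5 + 4 = 9 > 7, prints True
    }
}
```

The value of an Expression.Block is its last expression, which is what the return in the pseudo-code corresponds to.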
I'm familiar with writing expression visitors, but the algorithm for this optimization isn't immediately obvious to me. I could of course find "duplicates" within a tree, but there's obviously some trick to analyzing the dependencies within and between subtrees to eliminate sub-expressions efficiently and correctly.
I looked for algorithms on Google but they seem quite complicated to implement quickly. Also, they seem very "general" and don't necessarily take the simplicity of the trees I have into account.
Common subexpression elimination is an optimization that searches for instances of identical expressions and replaces them with a single variable holding the computed value. For instance, consider the following code:
a <- 1 / (8 + 8 + 1 + 9 * 1 ^ 8)
b <- (8 + 8 + 1 + 9 * 1 ^ 8) * 2
Here (8 + 8 + 1 + 9 * 1 ^ 8) occurs twice, so it can be computed once into a temporary that both assignments use.
A directed acyclic graph (DAG) representation of a basic block makes such transformations easier: it is an efficient way to identify common sub-expressions, and it shows how each statement's computed value is used by subsequent statements.
Common sub-expression elimination is one of several classic compiler optimizations, alongside dead code elimination, code motion and strength reduction.
You are doing unnecessary work: common sub-expression elimination is the job of the jitter's optimizer. Let's take your example and look at the generated code. I wrote it like this:
using System;

class Program {
    static void Main(string[] args) {
        // The C# compiler turns this lambda into an ordinary method; the JIT compiles it to the machine code below.
        var lambda = new Func<Foo, bool>(foo =>
            (foo.Bar * 5 + foo.Baz * 2 > 7)
            || (foo.Bar * 5 + foo.Baz * 2 < 3)
            || (foo.Bar * 5 + 3 == foo.Xyz));
        var obj = new Foo() { Bar = 1, Baz = 2, Xyz = 3 };
        var result = lambda(obj);
        Console.WriteLine(result);
    }
}

class Foo {
    public int Bar { get; internal set; }
    public int Baz { get; internal set; }
    public int Xyz { get; internal set; }
}
The x86 jitter generated this machine code for the lambda expression:
006526B8 push ebp ; prologue
006526B9 mov ebp,esp
006526BB push esi
006526BC mov esi,dword ptr [ecx+4] ; esi = foo.Bar
006526BF lea esi,[esi+esi*4] ; esi = 5 * foo.Bar
006526C2 mov edx,dword ptr [ecx+8] ; edx = foo.Baz
006526C5 add edx,edx ; edx = 2 * foo.Baz
006526C7 lea eax,[esi+edx] ; eax = 5 * foo.Bar + 2 * foo.Baz
006526CA cmp eax,7 ; > 7 test
006526CD jg 006526E7 ; > 7 then return true
006526CF add edx,esi ; HERE!!
006526D1 cmp edx,3 ; < 3 test
006526D4 jl 006526E7 ; < 3 then return true
006526D6 add esi,3 ; HERE!!
006526D9 mov eax,esi
006526DB cmp eax,dword ptr [ecx+0Ch] ; == foo.Xyz test
006526DE sete al ; convert to bool
006526E1 movzx eax,al
006526E4 pop esi ; epilogue
006526E5 pop ebp
006526E6 ret
006526E7 mov eax,1
006526EC pop esi
006526ED pop ebp
006526EE ret
I marked the places in the code where the foo.Bar * 5 sub-expression was eliminated with HERE. Notably, it did not eliminate the foo.Bar * 5 + foo.Baz * 2 sub-expression: the addition was performed again at address 006526CF. There is a good reason for that: the x86 jitter doesn't have enough registers available to store the intermediate result. If you look at the machine code generated by the x64 jitter, you do see it eliminated; the r9 register stores it.
This ought to give you enough reason to reconsider your intent. You are doing work that doesn't need to be done. Not only that, you are liable to generate worse code than the jitter would, since you don't have the luxury of estimating the CPU register budget.
Don't do this.
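In the question's setting, following this advice simply means compiling the unoptimized tree. A minimal sketch, assuming the DSL compiler produces an Expression<Func<Foo, bool>>; Program.Tree is an illustrative stand-in for that output:

```csharp
using System;
using System.Linq.Expressions;

public class Foo { public int Bar { get; set; } public int Baz { get; set; } public int Xyz { get; set; } }

public static class Program
{
    // Stand-in for the DSL compiler's output: the tree with its repeated sub-expressions left in.
    public static Expression<Func<Foo, bool>> Tree() => foo =>
        (foo.Bar * 5 + foo.Baz * 2 > 7)
        || (foo.Bar * 5 + foo.Baz * 2 < 3)
        || (foo.Bar * 5 + 3 == foo.Xyz);

    public static void Main()
    {
        // No CSE pass of our own: turn the tree into a delegate and leave optimization to the runtime.
        Func<Foo, bool> f = Tree().Compile();
        Console.WriteLine(f(new Foo { Bar = 1, Baz = 2, Xyz = 3 })); // prints True
    }
}
```

Compile() normally generates IL that the JIT then compiles. That the JIT finds the same common sub-expressions in this code as in the listing above, which came from an ordinary C# lambda, is an assumption worth confirming in a disassembly window; on runtimes where Compile() interprets the tree instead, no JIT optimization applies at all.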
You're correct in noting this is not a trivial problem.
The classical way that compilers handle it is a Directed Acyclic Graph (DAG) representation of the expression. The DAG is built in the same manner as the abstract syntax tree (and can be built by traversing the AST, perhaps a job for the expression visitor; I don't know the C# libraries well), except that a dictionary of previously emitted subgraphs is maintained. Before generating a node of a given type with given children, the dictionary is consulted to see whether one already exists. Only if this check fails is a new one created and added to the dictionary.
Since now a node may descend from multiple parents, the result is a DAG.
Then the DAG is traversed depth-first to generate code. Since common sub-expressions are now represented by a single node, the value is computed only once and stored in a temp for other expressions emitted later in the code generation to use. If the original code contains assignments, this phase gets complicated. Since your trees are side-effect free, the DAG ought to be the most straightforward way to solve your problem.
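The two phases described above can be sketched on top of System.Linq.Expressions: a hash-consing ExpressionVisitor that consults a dictionary before keeping each rebuilt node, so identical subtrees collapse into one shared instance and incoming DAG edges can be counted, followed by a post-order pass that computes every shared non-leaf node once into a block variable. This is an illustrative sketch, not the answerer's code; the names NodeKey, DagBuilder, TempEmitter and Cse are hypothetical, only constants, parameters, member accesses, unary and binary operators and method calls are handled, and all temps are evaluated up front:

```csharp
using System;
using System.Collections.Generic;
using System.Linq;
using System.Linq.Expressions;
using System.Runtime.CompilerServices;

public class Foo { public int Bar { get; set; } public int Baz { get; set; } public int Xyz { get; set; } }

// Hash-consing key: node kind, result type, payload (constant value, member or operator
// method) and children. Children are already canonical, so they are compared by reference.
sealed class NodeKey : IEquatable<NodeKey>
{
    readonly ExpressionType _kind; readonly Type _type; readonly object? _info; public readonly Expression?[] Kids;
    public NodeKey(ExpressionType kind, Type type, object? info, params Expression?[] kids)
    { _kind = kind; _type = type; _info = info; Kids = kids; }
    public bool Equals(NodeKey? o) =>
        o is not null && _kind == o._kind && _type == o._type && object.Equals(_info, o._info)
        && Kids.Length == o.Kids.Length && Kids.Zip(o.Kids).All(p => ReferenceEquals(p.First, p.Second));
    public override bool Equals(object? o) => Equals(o as NodeKey);
    public override int GetHashCode()
    {
        var h = new HashCode();
        h.Add(_kind); h.Add(_type); h.Add(_info);
        foreach (var k in Kids) h.Add(k is null ? 0 : RuntimeHelpers.GetHashCode(k));
        return h.ToHashCode();
    }
}

// Phase 1: rebuild bottom-up, consulting the table before keeping a node, so identical
// subtrees collapse into one shared node (a DAG). Uses[n] = number of DAG edges into n.
sealed class DagBuilder : ExpressionVisitor
{
    readonly Dictionary<NodeKey, Expression> _table = new();
    public readonly Dictionary<Expression, int> Uses = new(); // Expression has reference equality
    public override Expression? Visit(Expression? node)
    {
        if (node is null) return null;
        Expression rebuilt = base.Visit(node)!;          // children are canonical after this
        NodeKey? key = KeyOf(rebuilt);
        if (key is null) return rebuilt;                  // parameters, lambdas, unhandled kinds
        if (_table.TryGetValue(key, out var existing)) return existing;
        _table.Add(key, rebuilt);
        foreach (var kid in key.Kids)                     // only a new DAG node adds edges
            if (kid is not null) Uses[kid] = Uses.GetValueOrDefault(kid) + 1;
        return rebuilt;
    }
    static NodeKey? KeyOf(Expression e) => e switch
    {
        ConstantExpression c => new NodeKey(c.NodeType, c.Type, c.Value),
        MemberExpression m => new NodeKey(m.NodeType, m.Type, m.Member, m.Expression),
        BinaryExpression b when b.Conversion is null => new NodeKey(b.NodeType, b.Type, b.Method, b.Left, b.Right),
        UnaryExpression u => new NodeKey(u.NodeType, u.Type, u.Method, u.Operand),
        MethodCallExpression c => new NodeKey(c.NodeType, c.Type, c.Method, c.Arguments.Prepend(c.Object).ToArray()),
        _ => null,
    };
}

// Phase 2: post-order walk of the DAG. A non-leaf node with two or more incoming edges gets a
// block variable when first reached, so a temp's dependencies are always assigned before it.
sealed class TempEmitter : ExpressionVisitor
{
    readonly Dictionary<Expression, int> _uses;
    readonly Dictionary<Expression, ParameterExpression> _temps = new();
    public readonly List<ParameterExpression> Variables = new();
    public readonly List<Expression> Assignments = new();
    public TempEmitter(Dictionary<Expression, int> uses) => _uses = uses;
    public override Expression? Visit(Expression? node)
    {
        if (node is null) return null;
        if (_temps.TryGetValue(node, out var temp)) return temp;  // already computed: reuse
        Expression rewritten = base.Visit(node)!;
        if (_uses.GetValueOrDefault(node) < 2 || node is ConstantExpression or ParameterExpression) return rewritten;
        temp = Expression.Variable(node.Type, "local" + (Variables.Count + 1));
        Variables.Add(temp);
        Assignments.Add(Expression.Assign(temp, rewritten));
        _temps.Add(node, temp);
        return temp;
    }
}

public static class Cse
{
    public static Expression<T> Eliminate<T>(Expression<T> lambda)
    {
        var builder = new DagBuilder();
        var dag = (Expression<T>)builder.Visit(lambda)!;
        var emitter = new TempEmitter(builder.Uses);
        Expression body = emitter.Visit(dag.Body)!;
        if (emitter.Variables.Count == 0) return dag;
        // Temps are all computed up front, ignoring short-circuiting (as the question allows).
        return Expression.Lambda<T>(Expression.Block(emitter.Variables, emitter.Assignments.Append(body)), dag.Parameters);
    }
    public static void Main()
    {
        Expression<Func<Foo, bool>> original = foo =>
            (foo.Bar * 5 + foo.Baz * 2 > 7) || (foo.Bar * 5 + foo.Baz * 2 < 3) || (foo.Bar * 5 + 3 == foo.Xyz);
        Console.WriteLine(((BlockExpression)Eliminate(original).Body).Variables.Count); // prints 2
    }
}
```

On the question's lambda this yields local1 = foo.Bar * 5 and local2 = local1 + foo.Baz * 2, with the final expression using local2 twice and local1 once more, matching the hand-written target. foo.Baz * 2 occurs twice in the original tree but has only one DAG parent, so it gets no local of its own; counting DAG edges rather than tree occurrences is what keeps unnecessary locals out.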
As I recall, the Dragon Book's coverage of DAGs is particularly good.
As others have noted, if your trees will ultimately be compiled by an existing compiler, it's kind of futile to redo what's already there.
Addition
I had some Java code lying around from a student project (I teach), so I hacked up a little example of how this works. It's too long to post, but see the Gist here.
Running it on your input prints the DAG below. The numbers in parens are (unique id, DAG parent count). The parent count is needed to decide when to compute the local temp variables and when to just use the expression for a node.
Binary OR (27,1)
  lhs:
    Binary OR (19,1)
      lhs:
        Binary GREATER (9,1)
          lhs:
            Binary ADD (7,2)
              lhs:
                Binary MULTIPLY (3,2)
                  lhs:
                    Id 'Bar' (1,1)
                  rhs:
                    Number 5 (2,1)
              rhs:
                Binary MULTIPLY (6,1)
                  lhs:
                    Id 'Baz' (4,1)
                  rhs:
                    Number 2 (5,1)
          rhs:
            Number 7 (8,1)
      rhs:
        Binary LESS (18,1)
          lhs:
            ref to Binary ADD (7,2)
          rhs:
            Number 3 (17,2)
  rhs:
    Binary EQUALS (26,1)
      lhs:
        Binary ADD (24,1)
          lhs:
            ref to Binary MULTIPLY (3,2)
          rhs:
            ref to Number 3 (17,2)
      rhs:
        Id 'Xyz' (25,1)
Then it generates this code:
t3 = (Bar) * (5);
t7 = (t3) + ((Baz) * (2));
return (((t7) > (7)) || ((t7) < (3))) || (((t3) + (3)) == (Xyz));
You can see that the temp var numbers correspond to DAG nodes. You could make the code generator more complex to get rid of the unnecessary parentheses, but I'll leave that for others.
Make a SortedDictionary<Expression, object> whose comparer can compare arbitrary Expressions.
(You can define your own arbitrary comparison function here: for example, you can lexicographically compare the types of the expressions, and if they compare equal, then compare the children one by one.)
Go through all the leaves and add them to the dictionary; if they already exist, then they're duplicates, so merge them.
(This is also a good time to emit code, such as creating a new variable, for this leaf if it's the first instance of it; you can then store the emitted code inside the object value in the dictionary.)
Then go through the parents of all the previous leaves and add them to the dictionary; if they already exist, then they're duplicates, so merge them.
Keep on going up level by level until you reach the root.
Now you know what all the duplicates are, and where they occur, and you've generated code for all of them.
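The steps above can be sketched in C# as follows. The comparer orders expressions by node kind, then result type, then node-specific payload (constant value, member or operator method), then children one by one; parameters are ordered by first appearance. Nodes are inserted in order of height, leaves first, which corresponds to the level-by-level walk. This sketch stops at finding the duplicates and records them as List<Expression> values rather than emitting code; StructuralComparer and DuplicateFinder are hypothetical names, and ordering members by their ToString() text is a simplifying assumption:

```csharp
using System;
using System.Collections.Generic;
using System.Linq;
using System.Linq.Expressions;
using System.Reflection;

public class Foo { public int Bar { get; set; } public int Baz { get; set; } public int Xyz { get; set; } }

// Total order on expressions: node kind, result type, node payload, then children one by
// one, so Compare(x, y) == 0 means x and y are structurally identical.
sealed class StructuralComparer : IComparer<Expression>
{
    readonly Dictionary<ParameterExpression, int> _paramIds = new(); // parameters: by identity
    public int Compare(Expression? x, Expression? y)
    {
        if (ReferenceEquals(x, y)) return 0;
        if (x is null || y is null) return x is null ? -1 : 1;
        int c = x.NodeType.CompareTo(y.NodeType);
        if (c == 0) c = string.CompareOrdinal(x.Type.AssemblyQualifiedName, y.Type.AssemblyQualifiedName);
        if (c == 0) c = ComparePayload(x, y);
        if (c != 0) return c;
        var (xs, ys) = (Children(x), Children(y));
        for (int i = 0; c == 0 && i < Math.Min(xs.Count, ys.Count); i++) c = Compare(xs[i], ys[i]);
        return c != 0 ? c : xs.Count.CompareTo(ys.Count);
    }
    int ComparePayload(Expression x, Expression y) => (x, y) switch
    {
        (ConstantExpression a, ConstantExpression b) => CompareValues(a.Value, b.Value),
        (ParameterExpression a, ParameterExpression b) => Id(a).CompareTo(Id(b)),
        (MemberExpression a, MemberExpression b) => CompareMembers(a.Member, b.Member),
        (BinaryExpression a, BinaryExpression b) => CompareMembers(a.Method, b.Method),
        (UnaryExpression a, UnaryExpression b) => CompareMembers(a.Method, b.Method),
        (MethodCallExpression a, MethodCallExpression b) => CompareMembers(a.Method, b.Method),
        _ => throw new NotSupportedException($"{x.NodeType} is not handled by this sketch"),
    };
    // Numbers parameters in order of first appearance.
    int Id(ParameterExpression p) => _paramIds.TryGetValue(p, out int id) ? id : (_paramIds[p] = _paramIds.Count);
    static int CompareValues(object? a, object? b)
    {
        if (a is null || b is null) return (a is null ? 0 : 1) - (b is null ? 0 : 1);
        int c = string.CompareOrdinal(a.GetType().AssemblyQualifiedName, b.GetType().AssemblyQualifiedName);
        if (c != 0 || object.Equals(a, b)) return c;
        if (a is IComparable ca) return ca.CompareTo(b);
        throw new NotSupportedException("this sketch only orders IComparable constants");
    }
    // Assumption: declaring type plus ToString() (which includes the signature) identify a member.
    static int CompareMembers(MemberInfo? a, MemberInfo? b)
    {
        if (a == b) return 0;
        if (a is null || b is null) return a is null ? -1 : 1;
        int c = string.CompareOrdinal(a.DeclaringType?.AssemblyQualifiedName, b.DeclaringType?.AssemblyQualifiedName);
        return c != 0 ? c : string.CompareOrdinal(a.ToString(), b.ToString());
    }
    public static IReadOnlyList<Expression?> Children(Expression e) => e switch
    {
        MemberExpression m => new[] { m.Expression },
        BinaryExpression b => new[] { b.Left, b.Right },
        UnaryExpression u => new[] { u.Operand },
        MethodCallExpression mc => mc.Arguments.Prepend(mc.Object).ToArray(),
        _ => Array.Empty<Expression?>(),
    };
}

public static class DuplicateFinder
{
    // Leaves first, then level by level up to the root; an insert whose key exists is a duplicate.
    public static SortedDictionary<Expression, List<Expression>> Find(Expression body)
    {
        var levels = new SortedDictionary<int, List<Expression>>();
        Height(body, levels);
        var groups = new SortedDictionary<Expression, List<Expression>>(new StructuralComparer());
        foreach (var e in levels.Values.SelectMany(level => level))
            if (groups.TryGetValue(e, out var duplicates)) duplicates.Add(e);   // merge
            else groups.Add(e, new List<Expression> { e });
        return groups;
    }
    // Height 0 = leaf; otherwise one more than the tallest child.
    static int Height(Expression e, SortedDictionary<int, List<Expression>> levels)
    {
        int h = 0;
        foreach (var kid in StructuralComparer.Children(e))
            if (kid is not null) h = Math.Max(h, Height(kid, levels) + 1);
        if (!levels.TryGetValue(h, out var list)) levels.Add(h, list = new List<Expression>());
        list.Add(e);
        return h;
    }
    public static void Main()
    {
        Expression<Func<Foo, bool>> lambda = foo =>
            (foo.Bar * 5 + foo.Baz * 2 > 7) || (foo.Bar * 5 + foo.Baz * 2 < 3) || (foo.Bar * 5 + 3 == foo.Xyz);
        foreach (var (expr, occurrences) in Find(lambda.Body))
            if (occurrences.Count > 1 && expr is not (ConstantExpression or ParameterExpression))
                Console.WriteLine($"{occurrences.Count}x {expr}");
    }
}
```

This counts occurrences in the tree, so foo.Bar, foo.Baz and foo.Baz * 2 show up as duplicates even though they only recur inside larger duplicated expressions; deciding which duplicates actually deserve a variable still needs something like the parent counts from the DAG answer.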
Disclaimer: I have never tackled a problem like this; I'm just throwing out an idea that seems reasonably efficient:
Give every node in the tree some sort of signature. A hash should do; collisions can be dealt with. The signature must map all Foo.Bar entries to the same value.
Traverse the tree (O(n)), building a list of signatures of INTERNAL nodes (ignore leaves), and sort it on a combined key of expression size and then signature (O(n log n)). Among the smallest expressions in the list, take the most common one (O(n)) and replace each of its occurrences with a local variable. (Check at this point that the occurrences truly match, just in case there was a hash collision.)
Repeat this until a pass replaces nothing. This can't possibly run more than n/2 times, which bounds the whole operation at O(n^2 log n).
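A rough C# sketch of this signature-based loop: the signature is a System.HashCode combination of node kind, type, payload and child signatures, and an exact structural comparison guards against collisions. In place of the explicit sort, candidates are grouped with a hash-based IEqualityComparer and then ordered by size and frequency, which plays the same role. SignatureCse and its helpers are hypothetical names, and treating foo.Bar-style property reads as leaves is an assumption:

```csharp
using System;
using System.Collections.Generic;
using System.Linq;
using System.Linq.Expressions;

public class Foo { public int Bar { get; set; } public int Baz { get; set; } public int Xyz { get; set; } }

public static class SignatureCse
{
    // Assumption: constants, parameters/variables and property reads straight off a
    // parameter (foo.Bar) are leaves; only the other (internal) nodes are candidates.
    static bool IsLeaf(Expression e) =>
        e is ConstantExpression or ParameterExpression or MemberExpression { Expression: ParameterExpression };
    static IEnumerable<Expression> Kids(Expression e) => e switch
    {
        MemberExpression { Expression: { } inner } => new[] { inner },
        BinaryExpression b => new[] { b.Left, b.Right },
        UnaryExpression u => new[] { u.Operand },
        MethodCallExpression c => c.Object is null ? c.Arguments : c.Arguments.Prepend(c.Object),
        _ => Array.Empty<Expression>(),
    };
    // What a node must agree on, besides kind, type and children, to count as a match.
    static object? Payload(Expression e) => e switch
    {
        ConstantExpression c => c.Value,
        ParameterExpression p => p,                 // parameters and variables: by identity
        MemberExpression m => m.Member,
        BinaryExpression b => b.Method,
        UnaryExpression u => u.Method,
        MethodCallExpression c => c.Method,
        _ => throw new NotSupportedException($"{e.NodeType} is not handled by this sketch"),
    };
    // The signature: a hash that gives every structurally identical subtree the same value.
    static int Signature(Expression e)
    {
        var h = new HashCode();
        h.Add(e.NodeType); h.Add(e.Type); h.Add(Payload(e));
        foreach (var k in Kids(e)) h.Add(Signature(k));
        return h.ToHashCode();
    }
    // Exact structural comparison, used to rule out hash collisions.
    static bool Same(Expression a, Expression b) =>
        a.NodeType == b.NodeType && a.Type == b.Type && object.Equals(Payload(a), Payload(b))
        && Kids(a).Count() == Kids(b).Count() && Kids(a).Zip(Kids(b)).All(p => Same(p.First, p.Second));
    static int Size(Expression e) => 1 + Kids(e).Sum(k => Size(k));
    static void CollectInternal(Expression e, List<Expression> into)
    {
        if (IsLeaf(e)) return;
        into.Add(e);
        foreach (var k in Kids(e)) CollectInternal(k, into);
    }
    // Buckets by signature and confirms with Same, so groups stay exact even if signatures collide.
    sealed class Structural : IEqualityComparer<Expression>
    {
        public bool Equals(Expression? a, Expression? b) => a is not null && b is not null && Same(a, b);
        public int GetHashCode(Expression e) => Signature(e);
    }
    sealed class Replacer : ExpressionVisitor
    {
        readonly Expression _target; readonly ParameterExpression _local;
        public Replacer(Expression target, ParameterExpression local) { _target = target; _local = local; }
        public override Expression? Visit(Expression? node) =>
            node is not null && !IsLeaf(node) && Same(node, _target) ? _local : base.Visit(node);
    }
    public static Expression<T> Eliminate<T>(Expression<T> lambda)
    {
        Expression body = lambda.Body;
        var locals = new List<ParameterExpression>();
        var assignments = new List<Expression>();
        while (true)
        {
            var internalNodes = new List<Expression>();
            CollectInternal(body, internalNodes);
            // Smallest repeated expression first; among equal sizes, the most common one.
            var best = internalNodes.GroupBy(n => n, new Structural())
                .Where(g => g.Count() > 1)
                .OrderBy(g => Size(g.Key)).ThenByDescending(g => g.Count())
                .FirstOrDefault();
            if (best is null) break;                        // a pass that replaces nothing: done
            var local = Expression.Variable(best.Key.Type, "local" + (locals.Count + 1));
            locals.Add(local);
            assignments.Add(Expression.Assign(local, best.Key));
            body = new Replacer(best.Key, local).Visit(body)!;
        }
        return locals.Count == 0 ? lambda
            : Expression.Lambda<T>(Expression.Block(locals, assignments.Append(body)), lambda.Parameters);
    }
    public static void Main()
    {
        Expression<Func<Foo, bool>> lambda = foo =>
            (foo.Bar * 5 + foo.Baz * 2 > 7) || (foo.Bar * 5 + foo.Baz * 2 < 3) || (foo.Bar * 5 + 3 == foo.Xyz);
        Console.WriteLine(((BlockExpression)Eliminate(lambda).Body).Variables.Count); // prints 3
    }
}
```

On the question's expression this ends up with three locals (foo.Bar * 5, foo.Baz * 2, and their sum) where the question's target has two: because the smallest repeated expression is hoisted first, foo.Baz * 2 gets its own local even though, once the sum is hoisted, it is used only once.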