MethodHandles or LambdaMetafactory?

At my job we have a DSL for specifying mathematical formulas, which we later apply to a lot of points (in the millions).

As of today, we build an AST of the formula and visit each node to produce what we call an "Evaluator". We then pass that evaluator the arguments of the formula, and it does the computation for each point.

For instance, take the formula x * (3 + y):

           ┌────┐
     ┌─────┤mult├─────┐
     │     └────┘     │
     │                │
  ┌──v──┐          ┌──v──┐
  │  x  │      ┌───┤ add ├──┐
  └─────┘      │   └─────┘  │
               │            │
            ┌──v──┐      ┌──v──┐
            │  3  │      │  y  │
            └─────┘      └─────┘

Our evaluator will emit "Evaluate" objects for each step.
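For readers unfamiliar with this pattern, here is a minimal sketch of what such an evaluator tree might look like for x * (3 + y). It is an illustration under stated assumptions, not the asker's actual code: the names Evaluator, Const, Var, Add and Mul are hypothetical, variables are assumed to be passed as a double[] point (x at index 0, y at index 1), and records require Java 16+.

```java
// Hypothetical sketch: Evaluator, Const, Var, Add and Mul are illustrative names,
// not the asker's actual classes. Requires Java 16+ (records).
public class EvaluatorTree {

    /** Computes the formula for one point, given the values of its variables. */
    interface Evaluator {
        double evaluate(double[] point);
    }

    /** Leaf: a constant value. */
    record Const(double value) implements Evaluator {
        public double evaluate(double[] point) { return value; }
    }

    /** Leaf: reads one variable from the point (x is index 0, y is index 1). */
    record Var(int index) implements Evaluator {
        public double evaluate(double[] point) { return point[index]; }
    }

    /** Inner node: evaluates both children, then adds the results. */
    record Add(Evaluator left, Evaluator right) implements Evaluator {
        public double evaluate(double[] point) {
            return left.evaluate(point) + right.evaluate(point);
        }
    }

    /** Inner node: evaluates both children, then multiplies the results. */
    record Mul(Evaluator left, Evaluator right) implements Evaluator {
        public double evaluate(double[] point) {
            return left.evaluate(point) * right.evaluate(point);
        }
    }

    public static void main(String[] args) {
        // The tree a visitor over the AST of x * (3 + y) would produce.
        Evaluator formula = new Mul(new Var(0), new Add(new Const(3), new Var(1)));
        double[][] points = { {5, 4}, {1, 0}, {2, -3} };
        for (double[] p : points) {
            System.out.println(formula.evaluate(p)); // 35.0, 3.0, 0.0
        }
    }
}
```

Each call to evaluate is a virtual dispatch per node; whether HotSpot inlines those calls depends on how many node types show up at each call site.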

This method is easy to program, but not very efficient.

Lately, I started looking into method handles, building a single "composed" method handle to speed things up.

Something along these lines: I have my Arithmetics class with:

public class Arithmetics {

  public static double add(double a, double b){
      return a+b;
  }

  public static double mult(double a, double b){
      return a*b;
  }

}

When building my AST, I use MethodHandles.lookup() to get handles on those methods directly and compose them. Something along these lines, but in a tree:

MethodHandles.Lookup lookup = MethodHandles.lookup();
Method add = Arithmetics.class.getDeclaredMethod("add", double.class, double.class);
Method mult = Arithmetics.class.getDeclaredMethod("mult", double.class, double.class);
MethodHandle mh_add = lookup.unreflect(add);
MethodHandle mh_mult = lookup.unreflect(mult);
MethodHandle mh_add_3 = MethodHandles.insertArguments(mh_add, 0, 3.0); // binds the first argument: y -> 3 + y
MethodHandle formula = MethodHandles.collectArguments(mh_mult, 1, mh_add_3); // formula is f(x,y) = x * (3 + y)
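The snippet above composes handles for one fixed shape. A minimal sketch of doing the same recursively for an arbitrary tree might look like this, assuming every sub-tree is normalized to a handle of type (double x, double y)double. The helper names constant, varX, varY and binary are hypothetical; dropArguments, collectArguments and permuteArguments are standard MethodHandles combinators.

```java
import java.lang.invoke.MethodHandle;
import java.lang.invoke.MethodHandles;
import java.lang.invoke.MethodType;

// Sketch: every sub-tree becomes a handle of type (double x, double y)double.
// Helper names (constant, varX, varY, binary) are hypothetical.
public class HandleTree {
    static final MethodType XY = MethodType.methodType(double.class, double.class, double.class);

    public static double add(double a, double b) { return a + b; }
    public static double mult(double a, double b) { return a * b; }

    /** Leaf: (x, y) -> value, by dropping both arguments in front of a constant handle. */
    static MethodHandle constant(double value) {
        return MethodHandles.dropArguments(
            MethodHandles.constant(double.class, value), 0, double.class, double.class);
    }

    /** Leaf: (x, y) -> x. */
    static MethodHandle varX() {
        return MethodHandles.dropArguments(MethodHandles.identity(double.class), 1, double.class);
    }

    /** Leaf: (x, y) -> y. */
    static MethodHandle varY() {
        return MethodHandles.dropArguments(MethodHandles.identity(double.class), 0, double.class);
    }

    /** Inner node: (x, y) -> op(left(x, y), right(x, y)). */
    static MethodHandle binary(MethodHandle op, MethodHandle left, MethodHandle right) {
        // op(a, b) becomes op(left(x, y), b): type (double, double, double)double
        MethodHandle h = MethodHandles.collectArguments(op, 0, left);
        // then op(left(x, y), right(x2, y2)): type (double, double, double, double)double
        h = MethodHandles.collectArguments(h, 2, right);
        // feed the same x and y to both sub-trees
        return MethodHandles.permuteArguments(h, XY, 0, 1, 0, 1);
    }

    public static void main(String[] args) throws Throwable {
        MethodHandles.Lookup lookup = MethodHandles.lookup();
        MethodHandle add = lookup.findStatic(HandleTree.class, "add", XY);
        MethodHandle mult = lookup.findStatic(HandleTree.class, "mult", XY);

        // x * (3 + y)
        MethodHandle formula = binary(mult, varX(), binary(add, constant(3), varY()));
        double result = (double) formula.invokeExact(5.0, 4.0);
        System.out.println(result); // 35.0
    }
}
```

Note that each inner node costs three combinator calls at construction time, so building the handle grows with the size of the tree.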

Sadly, I'm quite disappointed by the results. For instance, constructing the method handle takes a very long time (due to calls to MethodHandles::insertArguments and other such composition functions), and the evaluation speedup only starts to make a difference after more than 600k iterations.

At 10M iterations, the method handle really starts to shine, but millions of iterations are not (yet?) a typical use case. We are more in the 10k to 1M range, where the results are mixed.

Also, the computation itself is sped up, but not by as much as I hoped (~2-10 times). I was expecting it to run faster.

So anyway, I started scouring StackOverflow again and saw LambdaMetafactory threads like this one: https://stackoverflow.com/a/19563000/389405

And I'm itching to start trying this. But before that, I'd like your input on some questions:

  • I need to be able to compose all those lambdas. MethodHandles provides lots of (admittedly slowish) ways to do it, but I feel like lambdas have a stricter "interface", and I can't yet wrap my head around how to do that. Do you know how?

  • Lambdas and method handles are closely interconnected, and I'm not sure I will get a significant speedup. For simple lambdas I see these results: direct: 0.02 s, lambda: 0.02 s, MH: 0.35 s, reflection: 0.40 s. But what about composed lambdas?

Thanks guys!

Gui13 asked Jun 08 '16 10:06

1 Answer

I think, for most practical cases, an immutable evaluation tree consisting of nodes that implement a particular interface or inherit from a common evaluator base class is unbeatable. HotSpot is capable of performing (aggressive) inlining, at least for subtrees, but has the freedom to decide how many nodes it will inline.

In contrast, generating explicit code for an entire tree runs the risk of exceeding the JVM’s thresholds; then you have code that certainly has no dispatch overhead, but might run interpreted all the time.

A tree of adapted MethodHandles starts like any other tree, but with a higher overhead. Whether its own optimization can beat HotSpot’s own inlining strategy is debatable. And as you noticed, it takes a large number of invocations before that self-tuning kicks in. It seems the thresholds accumulate in an unfortunate way for composed method handles.

To name one prominent example of the evaluation tree pattern: when you use Pattern.compile to prepare a regex matching operation, neither bytecode nor native code is generated, even though the method’s name might suggest otherwise. The internal representation is just an immutable tree of nodes, representing the combinations of the different kinds of operations. It’s up to the JVM’s optimizer to generate flattened code for it where it is considered beneficial.

Lambda expressions don’t change the game. They allow you to generate (small) classes that implement an interface and invoke a target method. You can use them to build an immutable evaluation tree, and while this is unlikely to perform differently from explicitly programmed evaluation node classes, it allows much simpler code:

import java.lang.invoke.LambdaMetafactory;
import java.lang.invoke.MethodHandle;
import java.lang.invoke.MethodHandles;
import java.lang.invoke.MethodType;
import java.lang.reflect.Method;
import java.util.Collections;
import java.util.HashMap;
import java.util.Map;
import java.util.function.DoubleBinaryOperator;
import java.util.function.DoubleUnaryOperator;

public class Arithmetics {
    public static void main(String[] args) {
        // x * (3 + y)
        DoubleBinaryOperator func=op(MUL, X, op(ADD, constant(3), Y));
        System.out.println(func.applyAsDouble(5, 4));
        PREDEFINED_UNARY_FUNCTIONS.forEach((name, f) ->
            System.out.println(name+"(0.42) = "+f.applyAsDouble(0.42)));
        PREDEFINED_BINARY_FUNCTIONS.forEach((name, f) ->
            System.out.println(name+"(0.42,0.815) = "+f.applyAsDouble(0.42,0.815)));
        // sin(x)+cos(y)
        func=op(ADD,
            op(PREDEFINED_UNARY_FUNCTIONS.get("sin"), X),
            op(PREDEFINED_UNARY_FUNCTIONS.get("cos"), Y));
        System.out.println("sin(0.6)+cos(0.5) = "+func.applyAsDouble(0.6, 0.5));
    }
    // the basic arithmetic operators
    public static final DoubleBinaryOperator ADD = Double::sum;
    public static final DoubleBinaryOperator SUB = (a,b) -> a-b;
    public static final DoubleBinaryOperator MUL = (a,b) -> a*b;
    public static final DoubleBinaryOperator DIV = (a,b) -> a/b;
    public static final DoubleBinaryOperator REM = (a,b) -> a%b;

    // applies a unary function to a sub-expression: (x,y) -> op(arg1(x,y))
    public static DoubleBinaryOperator op(
        DoubleUnaryOperator op, DoubleBinaryOperator arg1) {
        return (x,y) -> op.applyAsDouble(arg1.applyAsDouble(x,y));
    }
    // combines two sub-expressions: (x,y) -> op(arg1(x,y), arg2(x,y))
    public static DoubleBinaryOperator op(
        DoubleBinaryOperator op, DoubleBinaryOperator arg1, DoubleBinaryOperator arg2) {
        return (x,y)->op.applyAsDouble(arg1.applyAsDouble(x,y),arg2.applyAsDouble(x,y));
    }
    // leaves of the tree: the variables x and y, and constants
    public static final DoubleBinaryOperator X = (x,y) -> x, Y = (x,y) -> y;
    public static DoubleBinaryOperator constant(double value) {
        return (x,y) -> value;
    }

    // all java.lang.Math methods of type (double)double resp. (double,double)double
    public static final Map<String,DoubleUnaryOperator> PREDEFINED_UNARY_FUNCTIONS
        = getPredefinedFunctions(DoubleUnaryOperator.class,
            MethodType.methodType(double.class, double.class));
    public static final Map<String,DoubleBinaryOperator> PREDEFINED_BINARY_FUNCTIONS
        = getPredefinedFunctions(DoubleBinaryOperator.class,
            MethodType.methodType(double.class, double.class, double.class));

    private static <T> Map<String,T> getPredefinedFunctions(Class<T> t, MethodType mt) {
        Map<String,T> result=new HashMap<>();
        MethodHandles.Lookup l=MethodHandles.lookup();
        for(Method m:Math.class.getMethods()) try {
            MethodHandle mh=l.unreflect(m);
            if(!mh.type().equals(mt)) continue; // skip methods with a different signature
            // create an instance of the functional interface t delegating to mh
            result.put(m.getName(), t.cast(LambdaMetafactory.metafactory(
            l, "applyAsDouble", MethodType.methodType(t),
            mt, mh, mt) .getTarget().invoke()));
        }
        catch(RuntimeException|Error ex) { throw ex; }
        catch(Throwable ex) { throw new AssertionError(ex); }
        return Collections.unmodifiableMap(result);
    }
}

This is everything you need to compose evaluators for expressions made of basic arithmetic operators and functions found in java.lang.Math, the latter collected dynamically, to address that aspect of your question.
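The op helpers above are needed because DoubleBinaryOperator has no composition methods of its own. For pure chains of unary functions, java.util.function.DoubleUnaryOperator already provides the default methods andThen and compose. A small sketch (the class and variable names are illustrative):

```java
import java.util.function.DoubleUnaryOperator;

// Sketch: chaining unary functions with the default methods of DoubleUnaryOperator.
public class UnaryChain {
    public static void main(String[] args) {
        DoubleUnaryOperator plus3 = v -> v + 3;
        DoubleUnaryOperator square = v -> v * v;

        // andThen: apply plus3 first, then square, i.e. (v + 3)^2
        DoubleUnaryOperator plus3ThenSquare = plus3.andThen(square);
        // compose: apply square first, then plus3, i.e. v^2 + 3
        DoubleUnaryOperator squareThenPlus3 = plus3.compose(square);

        System.out.println(plus3ThenSquare.applyAsDouble(2)); // 25.0
        System.out.println(squareThenPlus3.applyAsDouble(2)); // 7.0

        // Math functions compose the same way, e.g. via a method reference.
        DoubleUnaryOperator sinOfSquare = square.andThen(Math::sin);
        System.out.println(sinOfSquare.applyAsDouble(0.5));
    }
}
```

Like the op helpers, andThen and compose simply return a new lambda that captures the two operands, so the result is again a small tree of function objects.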

Note that technically,

public static DoubleBinaryOperator MUL = (a,b) -> a*b;

is just a short-hand for

public static DoubleBinaryOperator MUL = Arithmetics::mul;
public static double mul(double a, double b){
    return a*b;
}
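To connect this with the LambdaMetafactory threads mentioned in the question: for a method reference like Arithmetics::mul, the compiler emits an invokedynamic call site bootstrapped by LambdaMetafactory.metafactory. The following sketch does roughly the same by hand for a single static method; the class name ManualMetafactory and the helper fromStatic are hypothetical, and getPredefinedFunctions above applies the same call to every matching Math method.

```java
import java.lang.invoke.CallSite;
import java.lang.invoke.LambdaMetafactory;
import java.lang.invoke.MethodHandle;
import java.lang.invoke.MethodHandles;
import java.lang.invoke.MethodType;
import java.util.function.DoubleBinaryOperator;

// Sketch: turning one static method into a DoubleBinaryOperator by hand,
// roughly what the compiler arranges for a method reference.
public class ManualMetafactory {
    public static double mul(double a, double b) { return a * b; }

    static DoubleBinaryOperator fromStatic(String name) throws Throwable {
        MethodHandles.Lookup lookup = MethodHandles.lookup();
        MethodType sig = MethodType.methodType(double.class, double.class, double.class);
        MethodHandle impl = lookup.findStatic(ManualMetafactory.class, name, sig);
        CallSite site = LambdaMetafactory.metafactory(
            lookup,
            "applyAsDouble",                                   // name of the interface method
            MethodType.methodType(DoubleBinaryOperator.class), // factory type: no captured values
            sig,                                               // erased signature of applyAsDouble
            impl,                                              // method to delegate to
            sig);                                              // signature enforced at invocation
        // The call site's target is a factory returning the functional-interface instance.
        return (DoubleBinaryOperator) site.getTarget().invokeExact();
    }

    public static void main(String[] args) throws Throwable {
        DoubleBinaryOperator mul = fromStatic("mul");
        System.out.println(mul.applyAsDouble(6, 7)); // 42.0
    }
}
```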

I added a main method containing some examples. Keep in mind that these functions behave like compiled code right from the first invocation, as they in fact consist only of compiled code, just composed of multiple functions.
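Regarding the first question (how to compose the lambdas from an AST): a minimal sketch, assuming a hypothetical AST made of the records Num, Var and Bin (not types from either post), walks the tree once and returns composed DoubleBinaryOperator instances in the style of op, X, Y and constant above. Requires Java 16+ for records and instanceof patterns.

```java
import java.util.function.DoubleBinaryOperator;

// Sketch: translating a formula AST into composed lambdas, once, up front.
// The AST types (Node, Num, Var, Bin) are hypothetical. Java 16+.
public class AstToLambda {
    interface Node {}
    record Num(double value) implements Node {}
    record Var(char name) implements Node {}                   // 'x' or 'y'
    record Bin(char op, Node left, Node right) implements Node {}

    static DoubleBinaryOperator compile(Node node) {
        if (node instanceof Num n) {
            double c = n.value();
            return (x, y) -> c;
        }
        if (node instanceof Var vr) {
            switch (vr.name()) {
                case 'x': return (x, y) -> x;
                case 'y': return (x, y) -> y;
                default: throw new IllegalArgumentException("Unknown variable " + vr.name());
            }
        }
        if (node instanceof Bin b) {
            // Compile the children first; the returned lambda only captures the results.
            DoubleBinaryOperator l = compile(b.left()), r = compile(b.right());
            switch (b.op()) {
                case '+': return (x, y) -> l.applyAsDouble(x, y) + r.applyAsDouble(x, y);
                case '-': return (x, y) -> l.applyAsDouble(x, y) - r.applyAsDouble(x, y);
                case '*': return (x, y) -> l.applyAsDouble(x, y) * r.applyAsDouble(x, y);
                case '/': return (x, y) -> l.applyAsDouble(x, y) / r.applyAsDouble(x, y);
                default: throw new IllegalArgumentException("Unknown operator " + b.op());
            }
        }
        throw new IllegalArgumentException("Unknown node " + node);
    }

    public static void main(String[] args) {
        // x * (3 + y)
        Node ast = new Bin('*', new Var('x'), new Bin('+', new Num(3), new Var('y')));
        DoubleBinaryOperator f = compile(ast);
        System.out.println(f.applyAsDouble(5, 4)); // 35.0
    }
}
```

The AST is visited only once; afterwards each point costs one applyAsDouble call per node, with no switch on node kinds at evaluation time.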

Holger answered Sep 20 '22 02:09