
How to simplify mathematical formulas with Rust macros?

I must admit I'm a bit lost with macros. I want to build a macro that does the following task and I'm not sure how to do it. I want to perform a scalar product of two arrays, say x and y, which have the same length N. The result I want to compute is of the form:

z = sum_{i=0}^{N-1} x[i] * y[i].

x is a const whose elements are 0, 1, or -1 and are known at compile time, while y's elements are determined at runtime. Because of the structure of x, many computations are useless (terms multiplied by 0 can be removed from the sum, and multiplications of the form 1 * y[i] and -1 * y[i] can be replaced by y[i] and -y[i] respectively).

As an example if x = [-1, 1, 0], the scalar product above would be

z = -1 * y[0] + 1 * y[1] + 0 * y[2]

To speed up my computation I can unroll the loop by hand and rewrite the whole thing without x[i], and I could hard code the above formula as

z = -y[0] + y[1]

But this procedure is inelegant, error-prone, and very tedious when N becomes large.

I'm pretty sure I can do that with a macro, but I don't know where to start (the different books I read are not going too deep into macros and I'm stuck)...

Would any of you have an idea how to solve this problem using macros (if it is possible)?

Thank you in advance for your help!

Edit: As pointed out in many of the answers, the compiler is smart enough to optimize the loop away in the case of integers. I am not only using integers but also floats (the x array is i32s, but in general y is f64s), so the compiler is not smart enough (and rightfully so) to optimize the loop. The following code produces the assembly shown right after it.

const X: [i32; 8] = [0, 1, -1, 0, 0, 1, 0, -1];

pub fn dot_x(y: [f64; 8]) -> f64 {
    X.iter().zip(y.iter()).map(|(i, j)| (*i as f64) * j).sum()
}

playground::dot_x:
    xorpd   %xmm0, %xmm0
    movsd   (%rdi), %xmm1
    mulsd   %xmm0, %xmm1
    addsd   %xmm0, %xmm1
    addsd   8(%rdi), %xmm1
    subsd   16(%rdi), %xmm1
    movupd  24(%rdi), %xmm2
    xorpd   %xmm3, %xmm3
    mulpd   %xmm2, %xmm3
    addsd   %xmm3, %xmm1
    unpckhpd    %xmm3, %xmm3
    addsd   %xmm1, %xmm3
    addsd   40(%rdi), %xmm3
    mulsd   48(%rdi), %xmm0
    addsd   %xmm3, %xmm0
    subsd   56(%rdi), %xmm0
    retq
Jean-Paul Dax asked Apr 05 '19 20:04


1 Answer

First of all, a (proc) macro simply cannot look inside your array x. All it gets are the tokens you pass it, without any context. If you want it to know about the values (0, 1, -1), you need to pass those directly to your macro:

let result = your_macro!(y, -1, 0, 1, -1);
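For illustration, here is a minimal sketch of what such a macro could look like as a declarative macro_rules! macro. The name signed_dot! and the internal @first/@acc rules are hypothetical, chosen only for this example. It assumes every coefficient is written as the literal token 0, 1 or -1, that there is no trailing comma, and that y is an identifier naming an indexable array of floats.

```rust
// Hypothetical macro `signed_dot!` (not from the answer): builds the sum term
// by term from literal coefficients, skipping the zeros entirely.
macro_rules! signed_dot {
    // @first: no non-zero coefficient seen yet. Nothing left: empty sum.
    (@first $y:ident, $i:expr;) => { 0.0 };
    // Coefficient 0: drop the term and advance the index.
    (@first $y:ident, $i:expr; 0, $($rest:tt)*) => {
        signed_dot!(@first $y, $i + 1usize; $($rest)*)
    };
    // First non-zero coefficient becomes the initial accumulator,
    // so the expression never starts with `0.0 + ...`.
    (@first $y:ident, $i:expr; 1, $($rest:tt)*) => {
        signed_dot!(@acc $y, $i + 1usize, $y[$i]; $($rest)*)
    };
    (@first $y:ident, $i:expr; -1, $($rest:tt)*) => {
        signed_dot!(@acc $y, $i + 1usize, -$y[$i]; $($rest)*)
    };
    // @acc: an accumulated expression exists. Nothing left: emit it.
    (@acc $y:ident, $i:expr, $acc:expr;) => { $acc };
    (@acc $y:ident, $i:expr, $acc:expr; 0, $($rest:tt)*) => {
        signed_dot!(@acc $y, $i + 1usize, $acc; $($rest)*)
    };
    (@acc $y:ident, $i:expr, $acc:expr; 1, $($rest:tt)*) => {
        signed_dot!(@acc $y, $i + 1usize, $acc + $y[$i]; $($rest)*)
    };
    (@acc $y:ident, $i:expr, $acc:expr; -1, $($rest:tt)*) => {
        signed_dot!(@acc $y, $i + 1usize, $acc - $y[$i]; $($rest)*)
    };
    // Public entry point: start at index 0 and append a comma so that
    // every coefficient is followed by one.
    ($y:ident, $($coeffs:tt)+) => {
        signed_dot!(@first $y, 0usize; $($coeffs)+ ,)
    };
}
```

With this sketch, signed_dot!(y, -1, 1, 0) expands to an expression equivalent to -y[0] + y[1]. Each coefficient costs one level of macro recursion, so a long coefficient list may need a larger #![recursion_limit].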

But you don't really need a macro for this. The compiler optimizes a lot, as also shown in the other answers. However, it will not, as you already mention in your edit, optimize away 0.0 * y[i], as the result of that is not always 0.0. (It could be -0.0 or NaN for example.) What we can do here is simply help the optimizer a bit by using a match or if, to make sure it does nothing for the 0.0 * y case:

const X: [i32; 8] = [0, -1, 0, 0, 0, 0, 1, 0];

fn foobar(y: [f64; 8]) -> f64 {
    let mut sum = 0.0;
    for (&x, &y) in X.iter().zip(&y) {
        if x != 0 {
            sum += x as f64 * y;
        }
    }
    sum
}

In release mode, the loop is unrolled and the values of X inlined, resulting in most iterations being thrown away as they don't do anything. All that is left in the resulting binary (on x86_64) is:

foobar:
 xorpd   xmm0, xmm0
 subsd   xmm0, qword ptr [rdi + 8]
 addsd   xmm0, qword ptr [rdi + 48]
 ret
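The reason the explicit x != 0 check is needed at all can be verified directly. The small helper below (zero_times is a hypothetical name used only for this demonstration) multiplies its argument by 0.0, and the comments list inputs for which the product is not a plain 0.0 under IEEE 754 arithmetic.

```rust
/// Multiplies `y` by 0.0, as the unoptimized dot product does for every
/// zero coefficient. Hypothetical helper, only for demonstration.
pub fn zero_times(y: f64) -> f64 {
    0.0 * y
}

// 0.0 * NaN       is NaN
// 0.0 * infinity  is NaN
// 0.0 * -3.0      is -0.0 (compares equal to 0.0, but has its sign bit set)
```

Because of these cases the compiler has to keep the multiplication unless the source code itself skips the zero terms.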

(As suggested by @lu-zero, this can also be done using filter_map. That will look like this: X.iter().zip(&y).filter_map(|(&x, &y)| match x { 0 => None, _ => Some(x as f64 * y) }).sum(), and gives the exact same generated assembly. Or even without a match, by using filter and map separately: .filter(|(&x, _)| x != 0).map(|(&x, &y)| x as f64 * y).sum().)
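Written out as complete functions, the two iterator variants could look like the sketch below. The function names are made up for this example. The filter closure uses the fully explicit pattern &(&x, _), since I believe the shorter (&x, _) form is rejected on the 2024 edition.

```rust
const X: [i32; 8] = [0, -1, 0, 0, 0, 0, 1, 0];

/// Variant using `filter_map`: zero coefficients map to `None` and vanish.
pub fn foobar_filter_map(y: [f64; 8]) -> f64 {
    X.iter()
        .zip(&y)
        .filter_map(|(&x, &y)| match x {
            0 => None,
            _ => Some(x as f64 * y),
        })
        .sum()
}

/// Variant using `filter` and `map` separately.
pub fn foobar_filter(y: [f64; 8]) -> f64 {
    X.iter()
        .zip(&y)
        .filter(|&(&x, _)| x != 0) // keep only non-zero coefficients
        .map(|(&x, &y)| x as f64 * y)
        .sum()
}
```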

Pretty good! However, this function calculates 0.0 - y[1] + y[6], since sum started at 0.0 and we only subtract and add things to it. The optimizer is again not willing to optimize away a 0.0: for y[1] equal to 0.0, 0.0 - y[1] is 0.0 while -y[1] is -0.0, so the two are not interchangeable. We can help it a bit more by not starting at 0.0, but starting with None:

fn foobar(y: [f64; 8]) -> f64 {
    let mut sum = None;
    for (&x, &y) in X.iter().zip(&y) {
        if x != 0 {
            let p = x as f64 * y;
            sum = Some(sum.map_or(p, |s| s + p));
        }
    }
    sum.unwrap_or(0.0)
}

This results in:

foobar:
 movsd   xmm0, qword ptr [rdi + 48]
 subsd   xmm0, qword ptr [rdi + 8]
 ret

Which simply does y[6] - y[1]. Bingo!
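The same "start from nothing" idea can also be sketched with iterators, using Iterator::reduce, which returns None for an empty iterator instead of beginning from 0.0. This is an alternative formulation rather than the code above, foobar_reduce is a hypothetical name, and I have not checked that it compiles to the same two instructions.

```rust
const X: [i32; 8] = [0, -1, 0, 0, 0, 0, 1, 0];

/// Sums the non-zero terms without an initial 0.0: `reduce` uses the first
/// term as the starting accumulator and yields `None` if there are no terms.
pub fn foobar_reduce(y: [f64; 8]) -> f64 {
    X.iter()
        .zip(&y)
        .filter(|&(&x, _)| x != 0) // skip zero coefficients
        .map(|(&x, &y)| x as f64 * y)
        .reduce(|a, b| a + b)
        .unwrap_or(0.0) // only reached when every coefficient is zero
}
```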

Mara Bos answered Dec 26 '22 05:12
