
How can I improve the performance of element-wise multiplication in Rust?

I will be doing element-wise multiplication on multiple vectors with 10^6+ elements. This is being flagged in profiling as one of the slowest parts of my code, so how can I improve it?

/// element-wise multiplication for vecs
pub fn vec_mul<T>(v1: &Vec<T>, v2: &Vec<T>) -> Vec<T>
where
    T: std::ops::Mul<Output = T> + Copy,
{
    if v1.len() != v2.len() {
        panic!("Cannot multiply vectors of different lengths!")
    }
    let mut out: Vec<T> = Vec::with_capacity(v1.len());
    for i in 0..(v1.len()) {
        out.push(v1[i] * v2[i]);
    }
    out
}
asked Feb 09 '19 at 04:02 by ryn1x



1 Answer

When you use the indexing operator on a Vec or a slice, the compiler has to insert a runtime check that the index is in bounds (panicking if it is not), unless the optimizer can prove that the check is unnecessary.
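If you want to keep the indexed loop, a commonly used trick is to re-slice the second input to the length of the first before the loop. That gives the optimizer a chance to prove every index is in range and drop the per-iteration checks. This is only a sketch: whether the checks are actually elided depends on the compiler version and optimization level, so inspect the generated assembly before relying on it. The name `vec_mul_indexed` is hypothetical.

```rust
/// Indexed element-wise multiplication with a re-sliced second input.
/// Hypothetical helper for illustration; bounds-check elision is not guaranteed.
pub fn vec_mul_indexed<T>(v1: &[T], v2: &[T]) -> Vec<T>
where
    T: std::ops::Mul<Output = T> + Copy,
{
    assert_eq!(v1.len(), v2.len(), "Cannot multiply vectors of different lengths!");
    let n = v1.len();
    // Re-slicing makes it explicit that `v2` has exactly `n` elements,
    // so the optimizer may prove that `v2[i]` with `i < n` is in bounds.
    let v2 = &v2[..n];
    let mut out = Vec::with_capacity(n);
    for i in 0..n {
        out.push(v1[i] * v2[i]);
    }
    out
}
```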

However, when you use iterators, these bounds checks are omitted, because the iterators have been carefully written to ensure that they never read out of bounds. Furthermore, due to how borrowing works in Rust, a data structure cannot be mutated while an iterator exists over that data structure (except via that iterator itself), so it's impossible for the valid bounds to change during iteration.
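The last point also means you can mutate through the iterator itself. If you don't need to keep the original `v1`, a sketch like the following multiplies in place via `iter_mut`, which also avoids allocating a new vector of 10^6+ elements. `vec_mul_assign` is a hypothetical name, and the `MulAssign` bound is an assumption that fits the built-in numeric types.

```rust
use std::ops::MulAssign;

/// Multiplies `v1` element-wise by `v2`, in place.
/// Hypothetical helper: reuses `v1`'s storage instead of allocating.
pub fn vec_mul_assign<T>(v1: &mut [T], v2: &[T])
where
    T: MulAssign + Copy,
{
    assert_eq!(v1.len(), v2.len(), "Cannot multiply vectors of different lengths!");
    // `iter_mut` yields `&mut T`; while it exists, `v1` can only be changed through it.
    for (a, &b) in v1.iter_mut().zip(v2) {
        *a *= b;
    }
}
```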

Since you are iterating over two different data structures concurrently, you'll want to use the zip iterator adapter. zip stops as soon as one iterator is exhausted, so it's still important to validate that both vectors have the same length. zip produces an iterator of tuples, where each tuple contains the items at the same position in the two original iterators. Then you can use map to transform each tuple into the product of the two values. Finally, collect the iterator produced by map into a Vec, which you can then return from your function. collect uses the iterator's size_hint to preallocate the vector's buffer, much as Vec::with_capacity would, so the vector is not repeatedly reallocated as it grows.
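To see why the length check still matters, and where collect gets its capacity from, here is a small illustration with mismatched lengths. `zip_products` is a hypothetical helper used only for this demonstration.

```rust
/// Multiplies the pairs produced by `zip`; silently truncates to the shorter input.
/// Hypothetical helper for illustration only.
pub fn zip_products(a: &[i32], b: &[i32]) -> (usize, Option<usize>, Vec<i32>) {
    let it = a.iter().zip(b).map(|(&x, &y)| x * y);
    // For slice iterators the hint is exact: the length of the shorter slice.
    let (lower, upper) = it.size_hint();
    (lower, upper, it.collect())
}
```

Calling `zip_products(&[1, 2, 3, 4], &[10, 20, 30])` reports a size hint of `(3, Some(3))` and drops the fourth element of the first slice without any error, which is why the function above keeps its explicit length check.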

/// element-wise multiplication for vecs
pub fn vec_mul<T>(v1: &[T], v2: &[T]) -> Vec<T>
where
    T: std::ops::Mul<Output = T> + Copy,
{
    if v1.len() != v2.len() {
        panic!("Cannot multiply vectors of different lengths!")
    }

    v1.iter().zip(v2).map(|(&i1, &i2)| i1 * i2).collect()
}
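Since you'll be doing this on multiple large vectors, it may also help to avoid allocating a fresh Vec on every call. A sketch, assuming you can keep a reusable output buffer around: `vec_mul_into` (a hypothetical name) writes into `out` and reuses its capacity once it has grown large enough.

```rust
/// Element-wise multiplication into a caller-provided buffer.
/// Hypothetical helper: `out` is cleared and refilled, so its allocation is reused.
pub fn vec_mul_into<T>(out: &mut Vec<T>, v1: &[T], v2: &[T])
where
    T: std::ops::Mul<Output = T> + Copy,
{
    assert_eq!(v1.len(), v2.len(), "Cannot multiply vectors of different lengths!");
    out.clear(); // drops the old elements but keeps the capacity
    // `extend` reserves space from the iterator's size hint before filling.
    out.extend(v1.iter().zip(v2).map(|(&a, &b)| a * b));
}
```

The same `buf` can then be passed to every call, so after the first large multiplication no further allocation is needed for inputs of the same or smaller size.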

Note: I've changed the signature to take slices instead of references to vectors. See Why is it discouraged to accept a reference to a String (&String), Vec (&Vec), or Box (&Box) as a function argument? for more information.
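Because the function takes &[T], callers can pass a &Vec<T>, a reference to an array, or a sub-slice without converting anything; &Vec<T> coerces to &[T] automatically. A small sketch that repeats the answer's function so the block stands alone; `call_sites` is a hypothetical name used only to group the examples.

```rust
/// element-wise multiplication for slices (same as the answer's version)
pub fn vec_mul<T>(v1: &[T], v2: &[T]) -> Vec<T>
where
    T: std::ops::Mul<Output = T> + Copy,
{
    if v1.len() != v2.len() {
        panic!("Cannot multiply vectors of different lengths!")
    }
    v1.iter().zip(v2).map(|(&i1, &i2)| i1 * i2).collect()
}

/// Hypothetical helper showing the different argument types a slice parameter accepts.
pub fn call_sites() -> (Vec<f64>, Vec<f64>, Vec<f64>) {
    let a: Vec<f64> = vec![1.0, 2.0, 3.0, 4.0];
    let b: Vec<f64> = vec![0.5, 0.5, 2.0, 2.0];
    let from_vecs = vec_mul(&a, &b); // &Vec<f64> coerces to &[f64]
    let from_arrays = vec_mul(&[1.0, 2.0], &[3.0, 4.0]); // array references coerce too
    let from_subslices = vec_mul(&a[2..], &b[2..]); // only the last two elements
    (from_vecs, from_arrays, from_subslices)
}
```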

answered Sep 28 '22 at 09:09 by Francis Gagné
