
Struggling with closures and lifetimes in Rust

Tags:

rust

I'm trying to port a little benchmark from F# to Rust. The F# code looks like this:

let inline iterNeighbors f (i, j) =
  f (i-1, j)
  f (i+1, j)
  f (i, j-1)
  f (i, j+1)

let rec nthLoop n (s1: HashSet<_>) (s2: HashSet<_>) =
  match n with
  | 0 -> s1
  | n ->
      let s0 = HashSet(HashIdentity.Structural)
      let add p =
        if not(s1.Contains p || s2.Contains p) then
          ignore(s0.Add p)
      Seq.iter (fun p -> iterNeighbors add p) s1
      nthLoop (n-1) s0 s1

let nth n p =
  nthLoop n (HashSet([p], HashIdentity.Structural)) (HashSet(HashIdentity.Structural))

(nth 2000 (0, 0)).Count

It computes the nth-nearest neighbor shells from an initial vertex in a potentially infinite graph. I used something similar during my PhD to study amorphous materials.

I've spent many hours trying and failing to port this to Rust. I have managed to get one version working but only by manually inlining the closure and converting the recursion into a loop with local mutables (yuk!).

I tried writing the iterNeighbors function like this:

use std::collections::HashSet;

fn iterNeighbors<F>(f: &F, (i, j): (i32, i32)) -> ()
where
    F: Fn((i32, i32)) -> (),
{
    f((i - 1, j));
    f((i + 1, j));
    f((i, j - 1));
    f((i, j + 1));
}

I think that is a function that accepts a closure (that itself accepts a pair and returns unit) and a pair and returns unit. I seem to have to double bracket things: is that correct?

I tried writing a recursive version like this:

fn nthLoop(n: i32, s1: HashSet<(i32, i32)>, s2: HashSet<(i32, i32)>) -> HashSet<(i32, i32)> {
    if n == 0 {
        return &s1;
    } else {
        let mut s0 = HashSet::new();
        for &p in s1 {
            if !(s1.contains(&p) || s2.contains(&p)) {
                s0.insert(p);
            }
        }
        return &nthLoop(n - 1, s0, s1);
    }
}

Note that I haven't even bothered with the call to iterNeighbors yet.

I think I'm struggling to get the lifetimes of the arguments correct because they are rotated in the recursive call. How should I annotate the lifetimes if I want s2 to be deallocated just before the returns and I want s1 to survive either when returned or into the recursive call?

The caller would look something like this:

fn nth<'a>(n: i32, p: (i32, i32)) -> &'a HashSet<(i32, i32)> {
    let s0 = HashSet::new();
    let mut s1 = HashSet::new();
    s1.insert(p);
    return &nthLoop(n, &s1, s0);
}

I gave up on that and wrote it as a while loop with mutable locals instead:

fn nth<'a>(n: i32, p: (i32, i32)) -> HashSet<(i32, i32)> {
    let mut n = n;
    let mut s0 = HashSet::new();
    let mut s1 = HashSet::new();
    let mut s2 = HashSet::new();
    s1.insert(p);
    while n > 0 {
        for &p in &s1 {
            let add = &|p| {
                if !(s1.contains(&p) || s2.contains(&p)) {
                    s0.insert(p);
                }
            };
            iterNeighbors(&add, p);
        }
        std::mem::swap(&mut s0, &mut s1);
        std::mem::swap(&mut s0, &mut s2);
        s0.clear();
        n -= 1;
    }
    return s1;
}

This works if I inline the closure by hand, but I cannot figure out how to invoke the closure. Ideally, I'd like static dispatch here.

The main function is then:

fn main() {
    let s = nth(2000, (0, 0));
    println!("{}", s.len());
}

So... what am I doing wrong? :-)

Also, I only used HashSet in the F# because I assume Rust doesn't provide a purely functional Set with efficient set-theoretic operations (union, intersection and difference). Am I correct in assuming that?

J D asked Mar 23 '16 01:03


2 Answers

I think that is a function that accepts a closure (that itself accepts a pair and returns unit) and a pair and returns unit. I seem to have to double bracket things: is that correct?

You need the double brackets because you're passing a 2-tuple to the closure, which matches your original F# code.

I think I'm struggling to get the lifetimes of the arguments correct because they are rotated in the recursive call. How should I annotate the lifetimes if I want s2 to be deallocated just before the returns and I want s1 to survive either when returned or into the recursive call?

The problem is that you're using references to HashSets when you should just use HashSets directly. Your signature for nthLoop is already correct; you just need to remove a few occurrences of &.

To deallocate s2, you can write drop(s2). Note that Rust doesn't have guaranteed tail calls, so each recursive call will still take some stack space (mem::size_of::<HashSet<(i32, i32)>>() gives a rough idea of how much each set handle occupies in the frame), but the drop call will free the data on the heap.
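To illustrate the point about drop, here is a minimal sketch showing that the destructor runs at the drop call and not at the end of the enclosing scope. The Tracked type and the drop_order function are hypothetical names made up for this example; they are not part of the question's code.

```rust
use std::cell::RefCell;
use std::rc::Rc;

// Hypothetical type for illustration: records its name in a shared log when dropped.
struct Tracked {
    name: &'static str,
    log: Rc<RefCell<Vec<&'static str>>>,
}

impl Drop for Tracked {
    fn drop(&mut self) {
        self.log.borrow_mut().push(self.name);
    }
}

fn drop_order() -> Vec<&'static str> {
    let log = Rc::new(RefCell::new(Vec::new()));
    {
        let a = Tracked { name: "a", log: Rc::clone(&log) };
        let _b = Tracked { name: "b", log: Rc::clone(&log) };
        drop(a); // a's destructor runs here, before the scope ends
        log.borrow_mut().push("after drop(a)");
    } // _b is dropped here, at the end of its scope
    let result = log.borrow().clone();
    result
}

fn main() {
    println!("{:?}", drop_order());
}
```

In nthLoop the same mechanism applies to s2: calling drop(s2) before the recursive call releases its heap storage early instead of keeping it alive until the whole recursion unwinds.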

The caller would look something like this:

Again, you just need to remove the &'s here.

Note that I haven't even bothered with the call to iterNeighbors yet.


This works if I inline the closure by hand but I cannot figure out how to invoke the closure. Ideally, I'd like static dispatch here.

There are three closure traits in Rust: Fn, FnMut and FnOnce. They differ by the type of their self argument (&self, &mut self and self, respectively). The distinction is important because it puts restrictions on what the closure is allowed to do and on how the caller can use the closure. The Rust book has a chapter on closures that already explains this well.

Your closure needs to mutate s0. However, iterNeighbors is defined as expecting an Fn closure. Your closure cannot implement Fn because Fn receives &self, but to mutate s0, you need &mut self. iterNeighbors cannot use FnOnce, since it needs to call the closure more than once. Therefore, you need to use FnMut.

Also, it's not necessary to pass the closure by reference to iterNeighbors. You can just pass it by value; each call to the closure will only borrow the closure, not consume it.
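As a small standalone sketch of the two points above (FnMut for a closure that mutates captured state, and passing the closure by value), consider the following. The names call_three_times and collect_calls are hypothetical and exist only for this illustration.

```rust
// Hypothetical helper for illustration: calls `f` three times.
// `F: FnMut` lets the closure mutate what it captures. Taking `f` by value
// (declared `mut`) is enough, because each call only borrows it mutably.
fn call_three_times<F>(mut f: F)
where
    F: FnMut(i32),
{
    f(1);
    f(2);
    f(3);
}

fn collect_calls() -> Vec<i32> {
    let mut seen = Vec::new();
    // The closure captures `seen` by mutable reference, so it implements
    // FnMut (and FnOnce) but not Fn.
    call_three_times(|x| seen.push(x));
    // The closure was moved into the call and is gone by now,
    // so `seen` can be used freely again.
    seen
}

fn main() {
    println!("{:?}", collect_calls());
}
```

Changing the bound to F: Fn(i32) would make the compiler reject the closure, which is the same situation as the original iterNeighbors.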

Also, I only used HashSet in the F# because I assume Rust doesn't provide a purely functional Set with efficient set-theoretic operations (union, intersection and difference). Am I correct in assuming that?

There's no purely functional set implementation in the standard library (maybe there's one on crates.io?). While Rust embraces functional programming, it also takes advantage of its ownership and borrowing system to make imperative programming safer. A functional set would probably impose using some form of reference counting or garbage collection in order to share items across sets.

However, HashSet does implement set-theoretic operations. There are two ways to use them: iterators (difference, symmetric_difference, intersection, union), which generate the sequence lazily, or operators (|, &, ^, -, as listed in the trait implementations for HashSet), which produce new sets containing clones of the values from the source sets.
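A brief sketch of both styles, assuming two small sets of integers (the set_ops function and its sample values are made up for this example):

```rust
use std::collections::HashSet;

// Returns (intersection, union, difference) as sorted vectors so the
// result is easy to inspect; HashSet iteration order is unspecified.
fn set_ops() -> (Vec<i32>, Vec<i32>, Vec<i32>) {
    let a: HashSet<i32> = [1, 2, 3].iter().copied().collect();
    let b: HashSet<i32> = [2, 3, 4].iter().copied().collect();

    // Iterator form: lazy, yields references into the source sets.
    let mut inter: Vec<i32> = a.intersection(&b).copied().collect();

    // Operator form: takes references and builds a new HashSet of clones.
    let union_set: HashSet<i32> = &a | &b;
    let diff_set: HashSet<i32> = &a - &b;

    let mut union: Vec<i32> = union_set.into_iter().collect();
    let mut diff: Vec<i32> = diff_set.into_iter().collect();
    inter.sort();
    union.sort();
    diff.sort();
    (inter, union, diff)
}

fn main() {
    let (inter, union, diff) = set_ops();
    println!("{:?} {:?} {:?}", inter, union, diff);
}
```

The iterator form avoids allocating when the result is only traversed once; the operator form is convenient when a new owned set is needed.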


Here's the working code:

use std::collections::HashSet;

fn iterNeighbors<F>(mut f: F, (i, j): (i32, i32)) -> ()
where
    F: FnMut((i32, i32)) -> (),
{
    f((i - 1, j));
    f((i + 1, j));
    f((i, j - 1));
    f((i, j + 1));
}

fn nthLoop(n: i32, s1: HashSet<(i32, i32)>, s2: HashSet<(i32, i32)>) -> HashSet<(i32, i32)> {
    if n == 0 {
        return s1;
    } else {
        let mut s0 = HashSet::new();
        for &p in &s1 {
            let add = |p| {
                if !(s1.contains(&p) || s2.contains(&p)) {
                    s0.insert(p);
                }
            };
            iterNeighbors(add, p);
        }
        drop(s2);
        return nthLoop(n - 1, s0, s1);
    }
}

fn nth(n: i32, p: (i32, i32)) -> HashSet<(i32, i32)> {
    let mut s1 = HashSet::new();
    s1.insert(p);
    let s2 = HashSet::new();
    return nthLoop(n, s1, s2);
}

fn main() {
    let s = nth(2000, (0, 0));
    println!("{}", s.len());
}
Francis Gagné answered Oct 19 '22 20:10


I seem to have to double bracket things: is that correct?

No: the double brackets are there because you've chosen to use tuples, and calling a function that takes a tuple requires creating the tuple first. Closures can also take multiple arguments, like F: Fn(i32, i32). That is, one could write that function as:

fn iterNeighbors<F>(i: i32, j: i32, f: F)
where
    F: Fn(i32, i32),
{
    f(i - 1, j);
    f(i + 1, j);
    f(i, j - 1);
    f(i, j + 1);
}

However, it seems that retaining the tuples makes sense for this case.

I think I'm struggling to get the lifetimes of the arguments correct because they are rotated in the recursive call. How should I annotate the lifetimes if I want s2 to be deallocated just before the returns and I want s1 to survive either when returned or into the recursive call?

No need for references (and hence no need for lifetimes), just pass the data through directly:

fn nthLoop(n: i32, s1: HashSet<(i32, i32)>, s2: HashSet<(i32, i32)>) -> HashSet<(i32, i32)> {
    if n == 0 {
        return s1;
    } else {
        let mut s0 = HashSet::new();
        for &p in &s1 {
            iterNeighbors(p, |p| {
                if !(s1.contains(&p) || s2.contains(&p)) {
                    s0.insert(p);
                }
            })
        }
        drop(s2); // guarantees timely deallocation
        return nthLoop(n - 1, s0, s1);
    }
}

The key here is you can do everything by value, and things passed around by value will of course keep their values around.

However, this fails to compile:

error[E0387]: cannot borrow data mutably in a captured outer variable in an `Fn` closure
  --> src/main.rs:21:21
   |
21 |                     s0.insert(p);
   |                     ^^
   |
help: consider changing this closure to take self by mutable reference
  --> src/main.rs:19:30
   |
19 |               iterNeighbors(p, |p| {
   |  ______________________________^
20 | |                 if !(s1.contains(&p) || s2.contains(&p)) {
21 | |                     s0.insert(p);
22 | |                 }
23 | |             })
   | |_____________^

That is to say, the closure is trying to mutate values it captures (s0), but the Fn closure trait doesn't allow this. That trait can be called in a more flexible manner (when shared), but this imposes more restrictions on what the closure can do internally. (If you're interested, I've written more about this)

Fortunately there's an easy fix: using the FnMut trait, which requires that the closure can only be called when one has unique access to it, but allows the internals to mutate things.

fn iterNeighbors<F>((i, j): (i32, i32), mut f: F)
where
    F: FnMut((i32, i32)),
{
    f((i - 1, j));
    f((i + 1, j));
    f((i, j - 1));
    f((i, j + 1));
}

The caller would look something like this:

Values work here too: returning a reference in that case would be returning a pointer to s0, which lives in the stack frame that is being destroyed as the function returns. That is, the reference would point to dead data.

The fix is just not using references:

fn nth(n: i32, p: (i32, i32)) -> HashSet<(i32, i32)> {
    let s0 = HashSet::new();
    let mut s1 = HashSet::new();
    s1.insert(p);
    return nthLoop(n, s1, s0);
}

This works if I inline the closure by hand but I cannot figure out how to invoke the closure. Ideally, I'd like static dispatch here.

(I don't understand what this means. Including the compiler error messages you're having trouble with helps us help you.)

Also, I only used HashSet in the F# because I assume Rust doesn't provide a purely functional Set with efficient set-theoretic operations (union, intersection and difference). Am I correct in assuming that?

Depending on exactly what you want, no, e.g. both HashSet and BTreeSet provide various set-theoretic operations as methods which return iterators.


Some small points:

  • explicit/named lifetimes allow the compiler to reason about the static validity of data; they don't control it (i.e. they allow the compiler to point out when you do something wrong, but the language still has the same sort of static resource usage/life-cycle guarantees as C++)
  • the version with a loop is likely to be more efficient as written, as it reuses memory directly (swapping the sets, plus the s0.clear()); however, the same benefit can be realised with a recursive version by passing s2 down for reuse instead of dropping it.
  • the while loop could be for _ in 0..n
  • there's no need to pass closures by reference, but with or without the reference, there's still static dispatch (the closure is a type parameter, not a trait object).
  • conventionally, closure arguments are last, and not taken by reference, because it makes defining & passing them inline easier to read (e.g. foo(x, |y| bar(y + 1)) instead of foo(&|y| bar(y + 1), x))
  • the return keyword isn't necessary for trailing returns (if the ; is omitted):

    fn nth(n: i32, p: (i32, i32)) -> HashSet<(i32, i32)> {
        let s0 = HashSet::new();
        let mut s1 = HashSet::new();
        s1.insert(p);
        nthLoop(n, s1, s0)
    }
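
As a rough sketch of how several of the points above could combine in the loop version (an FnMut closure passed last and by value, for _ in 0..n, and reuse of the three sets), something like the following should work. The snake_case names and the Point alias are illustrative choices, not taken from the question.

```rust
use std::collections::HashSet;

// Illustrative alias, not part of the original code.
type Point = (i32, i32);

// The closure is the last argument and is taken by value. F is a type
// parameter, so every call is statically dispatched.
fn iter_neighbors<F>((i, j): Point, mut f: F)
where
    F: FnMut(Point),
{
    f((i - 1, j));
    f((i + 1, j));
    f((i, j - 1));
    f((i, j + 1));
}

fn nth(n: i32, p: Point) -> HashSet<Point> {
    let mut s0: HashSet<Point> = HashSet::new(); // shell being built
    let mut s1: HashSet<Point> = HashSet::new(); // current shell
    let mut s2: HashSet<Point> = HashSet::new(); // previous shell
    s1.insert(p);
    for _ in 0..n {
        for &q in &s1 {
            // Captures s1 and s2 by shared reference and s0 by mutable reference.
            iter_neighbors(q, |r| {
                if !(s1.contains(&r) || s2.contains(&r)) {
                    s0.insert(r);
                }
            });
        }
        // Rotate the sets: s1 becomes the new shell, s2 the old current
        // shell, and the old previous shell is cleared and reused as s0.
        std::mem::swap(&mut s0, &mut s1);
        std::mem::swap(&mut s0, &mut s2);
        s0.clear();
    }
    s1
}

fn main() {
    let s = nth(2000, (0, 0));
    println!("{}", s.len());
}
```

Because clear() keeps the allocated capacity, the three sets are recycled across iterations instead of being reallocated for every shell.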
    
huon answered Oct 19 '22 20:10