I have a simple classifier:
struct Clf {
    x: f64,
}
The classifier returns 0 if the observed value is smaller than x and 1 if it is bigger than x.
I want to implement the call operator for this classifier. However, the function should be able to take either a float or a vector as its argument. For a vector, the output is a vector of 0s and 1s with the same length as the input vector:
let c = Clf { x: 0.0 };
let v = vec![-1.0, 0.5, 1.0];
println!("{}", c(0.5)); // prints 1
println!("{:?}", c(v)); // prints [0, 1, 1]
How can I write an implementation of Fn in this case?
impl Fn for Clf {
    extern "rust-call" fn call(/*...*/) {
        // ...
    }
}
The short answer is: You can't. At least it won't work the way you want. I think the best way to show that is to walk through and see what happens, but the general idea is that Rust doesn't support function overloading.
For this example, we will be implementing FnOnce, because Fn requires FnMut, which requires FnOnce. So, if we were to get this all sorted, we could do it for the other function traits.
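Closures implement these traits on stable Rust, so the Fn: FnMut: FnOnce hierarchy can be observed without any feature flags. The sketch below is my own illustration, not part of the answer; the helper names call_once_with and call_twice_with are hypothetical. A closure that only reads its capture implements Fn, so it is accepted where only FnOnce is required, and a function that asks for Fn may call it repeatedly.

```rust
// Accepts anything that can be called at least once.
fn call_once_with<F: FnOnce(f64) -> i32>(f: F, v: f64) -> i32 {
    f(v)
}

// Requires Fn, so the callable may be invoked repeatedly through a shared reference.
fn call_twice_with<F: Fn(f64) -> i32>(f: &F, a: f64, b: f64) -> (i32, i32) {
    (f(a), f(b))
}

fn main() {
    let x = 0.0;
    // Only reads `x`, so this closure implements Fn (and therefore FnMut and FnOnce).
    let classify = move |v: f64| if v > x { 1 } else { 0 };

    // An Fn closure is accepted where only FnOnce is required...
    println!("{}", call_once_with(classify, 0.5)); // prints 1
    // ...and it is still usable afterwards, because a closure capturing only an f64 is Copy.
    println!("{:?}", call_twice_with(&classify, -1.0, 1.0)); // prints (0, 1)
}
```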
First, this is unstable, so we need some feature flags:
#![feature(unboxed_closures, fn_traits)]
Then, let's do the impl for taking an f64:
impl FnOnce<(f64,)> for Clf {
    type Output = i32;
    extern "rust-call" fn call_once(self, args: (f64,)) -> i32 {
        if args.0 > self.x {
            1
        } else {
            0
        }
    }
}
The arguments to the Fn family of traits are supplied via a tuple, so that's the (f64,) syntax; it's a tuple with just one element.
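To see the one-element tuple on its own, here is a stable sketch with a hypothetical free function, first_of_single, that destructures its argument with the same (arg,) pattern a call_once implementation can use. The trailing comma is what makes (0.5,) a tuple; (0.5) without it is just a parenthesized f64.

```rust
// Hypothetical helper: the parameter is a 1-tuple, destructured in place.
fn first_of_single((arg,): (f64,)) -> f64 {
    arg
}

fn main() {
    // The trailing comma makes this a value of type (f64,).
    let args: (f64,) = (0.5,);
    // Fields of a tuple are accessed by position, like args.0 in the impl above.
    println!("{}", args.0); // prints 0.5
    println!("{}", first_of_single(args)); // prints 0.5
}
```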
This is all well and good, and we can now do c(0.5), although it will consume c until we implement the other traits.
Now let's do the same thing for Vecs:
impl FnOnce<(Vec<f64>,)> for Clf {
    type Output = Vec<i32>;
    extern "rust-call" fn call_once(self, args: (Vec<f64>,)) -> Vec<i32> {
        args.0
            .iter()
            .map(|&f| if f > self.x { 1 } else { 0 })
            .collect()
    }
}
Before Rust 1.33 nightly, you could not directly call c(v) or even c(0.5) (which worked before); you would get an error about the type of the function not being known. Basically, those versions of Rust didn't support function overloading. But you can still call the functions using fully qualified syntax, where c(0.5) becomes FnOnce::call_once(c, (0.5,)).
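A rough stable-Rust analogue of the two impls, as a sketch: the hypothetical trait Apply<Args> stands in for FnOnce<Args>, with the argument tuple as a trait type parameter so that Clf can implement it once per argument type. Fully qualified syntax then picks the impl explicitly, much like FnOnce::call_once(c, (0.5,)) does on nightly. Apply and apply are my own names, not part of the standard library, and this does not give you the c(...) call syntax.

```rust
// Hypothetical stand-in for FnOnce<Args>: Args is a trait type parameter,
// so one type may implement the trait several times with different Args.
trait Apply<Args> {
    type Output;
    fn apply(self, args: Args) -> Self::Output;
}

#[derive(Copy, Clone)]
struct Clf {
    x: f64,
}

impl Apply<(f64,)> for Clf {
    type Output = i32;
    fn apply(self, (arg,): (f64,)) -> i32 {
        if arg > self.x { 1 } else { 0 }
    }
}

impl Apply<(Vec<f64>,)> for Clf {
    type Output = Vec<i32>;
    fn apply(self, (args,): (Vec<f64>,)) -> Vec<i32> {
        args.into_iter().map(|f| if f > self.x { 1 } else { 0 }).collect()
    }
}

fn main() {
    let c = Clf { x: 0.0 };
    // Fully qualified syntax names the exact impl to use.
    let s = <Clf as Apply<(f64,)>>::apply(c, (0.5,));
    let v = <Clf as Apply<(Vec<f64>,)>>::apply(c, (vec![-1.0, 0.5, 1.0],));
    println!("{} {:?}", s, v); // prints 1 [0, 1, 1]
}
```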
Not knowing your bigger picture, I would solve this simply by giving Clf two methods, like so:
impl Clf {
    fn classify(&self, val: f64) -> u32 {
        if val > self.x {
            1
        } else {
            0
        }
    }
    fn classify_vec(&self, vals: Vec<f64>) -> Vec<u32> {
        vals.into_iter().map(|v| self.classify(v)).collect()
    }
}
Then your usage example becomes:
let c = Clf { x: 0.0 };
let v = vec![-1.0, 0.5, 1.0];
println!("{}", c.classify(0.5)); // prints 1
println!("{:?}", c.classify_vec(v)); // prints [0, 1, 1]
I would actually make the second function classify_slice and have it take &[f64] to be a bit more general; then you could still use it with Vecs by borrowing them: c.classify_slice(&v).
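A minimal sketch of that slice-taking version, assuming classify stays as written above. Borrowing a slice means the caller keeps ownership of the Vec, and arrays or sub-slices can be passed as well.

```rust
struct Clf {
    x: f64,
}

impl Clf {
    fn classify(&self, val: f64) -> u32 {
        if val > self.x { 1 } else { 0 }
    }

    // Borrows the input instead of consuming a Vec<f64>.
    fn classify_slice(&self, vals: &[f64]) -> Vec<u32> {
        vals.iter().map(|&v| self.classify(v)).collect()
    }
}

fn main() {
    let c = Clf { x: 0.0 };
    let v = vec![-1.0, 0.5, 1.0];
    println!("{:?}", c.classify_slice(&v)); // prints [0, 1, 1]
    // `v` is still usable, and sub-slices and arrays work too.
    println!("{:?}", c.classify_slice(&v[1..])); // prints [1, 1]
    println!("{:?}", c.classify_slice(&[2.0, -2.0])); // prints [1, 0]
}
```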
This is indeed possible, but you need a new trait and a ton of mess.
If you start with the abstraction:
enum VecOrScalar<T> {
    Scalar(T),
    Vector(Vec<T>),
}
use VecOrScalar::*;
You want a way to use the type transformations:
T (hidden) -> VecOrScalar<T> -> T (known)
Vec<T> (hidden) -> VecOrScalar<T> -> Vec<T> (known)
because then you can take a "hidden" type T, wrap it in a VecOrScalar and extract the real type T with a match.
You also want:
T (known) -> bool = T::Output
Vec<T> (known) -> Vec<bool> = Vec<T>::Output
but without higher-kinded types, this is a bit tricky. Instead, you can do
T (known) -> VecOrScalar<T> -> T::Output
Vec<T> (known) -> VecOrScalar<T> -> Vec<T>::Output
if you allow for a branch that can panic.
The trait will thus be:
trait FromVecOrScalar<T> {
    type Output;
    fn put(self) -> VecOrScalar<T>;
    fn get(out: VecOrScalar<bool>) -> Self::Output;
}
with implementations:
impl<T> FromVecOrScalar<T> for T {
    type Output = bool;
    fn put(self) -> VecOrScalar<T> {
        Scalar(self)
    }
    fn get(out: VecOrScalar<bool>) -> Self::Output {
        match out {
            Scalar(val) => val,
            Vector(_) => panic!("Wrong output type!"),
        }
    }
}
impl<T> FromVecOrScalar<T> for Vec<T> {
    type Output = Vec<bool>;
    fn put(self) -> VecOrScalar<T> {
        Vector(self)
    }
    fn get(out: VecOrScalar<bool>) -> Self::Output {
        match out {
            Vector(val) => val,
            Scalar(_) => panic!("Wrong output type!"),
        }
    }
}
Your type
#[derive(Copy, Clone)]
struct Clf {
    x: f64,
}
will first implement the two branches:
impl Clf {
    fn calc_scalar(self, f: f64) -> bool {
        f > self.x
    }
    fn calc_vector(self, v: Vec<f64>) -> Vec<bool> {
        v.into_iter().map(|x| self.calc_scalar(x)).collect()
    }
}
Then it will dispatch by implementing FnOnce for T: FromVecOrScalar<f64>:
impl<T> FnOnce<(T,)> for Clf
where
    T: FromVecOrScalar<f64>,
{
with the types:
    type Output = T::Output;
    extern "rust-call" fn call_once(self, (arg,): (T,)) -> T::Output {
The dispatch first wraps the hidden type in the enum, so you can extract the concrete value with a match, and then calls T::get on the result to hide it again.
        match arg.put() {
            Scalar(scalar) => T::get(Scalar(self.calc_scalar(scalar))),
            Vector(vector) => T::get(Vector(self.calc_vector(vector))),
        }
    }
}
Then, success:
fn main() {
    let c = Clf { x: 0.0 };
    let v = vec![-1.0, 0.5, 1.0];
    println!("{}", c(0.5f64));
    println!("{:?}", c(v));
}
Since the compiler can see through all of this malarkey, it actually compiles down to basically the same assembly as a direct call to the calc_ methods.
That's not to say it's nice to write. Overloading like this is a pain, fragile and most certainly A Bad Idea™. Don't do it, though it's fine to know that you can.
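If you want the same type-directed dispatch on stable Rust, one option (my own adaptation, not part of the answer) is to replace the unstable FnOnce impl with an ordinary generic method. The sketch reuses VecOrScalar, FromVecOrScalar, calc_scalar and calc_vector as written above; the method name call is a hypothetical choice, and you write c.call(v) instead of c(v). I spell out the trait with fully qualified paths inside call to make it explicit which impl is meant.

```rust
enum VecOrScalar<T> {
    Scalar(T),
    Vector(Vec<T>),
}
use VecOrScalar::*;

trait FromVecOrScalar<T> {
    type Output;
    fn put(self) -> VecOrScalar<T>;
    fn get(out: VecOrScalar<bool>) -> Self::Output;
}

impl<T> FromVecOrScalar<T> for T {
    type Output = bool;
    fn put(self) -> VecOrScalar<T> {
        Scalar(self)
    }
    fn get(out: VecOrScalar<bool>) -> Self::Output {
        match out {
            Scalar(val) => val,
            Vector(_) => panic!("Wrong output type!"),
        }
    }
}

impl<T> FromVecOrScalar<T> for Vec<T> {
    type Output = Vec<bool>;
    fn put(self) -> VecOrScalar<T> {
        Vector(self)
    }
    fn get(out: VecOrScalar<bool>) -> Self::Output {
        match out {
            Vector(val) => val,
            Scalar(_) => panic!("Wrong output type!"),
        }
    }
}

#[derive(Copy, Clone)]
struct Clf {
    x: f64,
}

impl Clf {
    fn calc_scalar(self, f: f64) -> bool {
        f > self.x
    }
    fn calc_vector(self, v: Vec<f64>) -> Vec<bool> {
        v.into_iter().map(|x| self.calc_scalar(x)).collect()
    }
    // Hypothetical stable replacement for the FnOnce impl: a plain generic method
    // whose return type is chosen by the argument type through the trait.
    fn call<T: FromVecOrScalar<f64>>(self, arg: T) -> <T as FromVecOrScalar<f64>>::Output {
        match <T as FromVecOrScalar<f64>>::put(arg) {
            Scalar(s) => <T as FromVecOrScalar<f64>>::get(Scalar(self.calc_scalar(s))),
            Vector(v) => <T as FromVecOrScalar<f64>>::get(Vector(self.calc_vector(v))),
        }
    }
}

fn main() {
    let c = Clf { x: 0.0 };
    let v: Vec<f64> = vec![-1.0, 0.5, 1.0];
    println!("{}", c.call(0.5f64)); // prints true
    println!("{:?}", c.call(v)); // prints [false, true, true]
}
```

The panicking branches in get are never reached here, because call always hands back the same variant that put produced.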