I'm trying to implement a fast primality test for Rust's `u32` and `u64` datatypes. As part of it, I need to compute `(n*n)%d`, where `n` and `d` are `u32` (or `u64`, respectively).
While the result can easily fit in the datatype, I'm at a loss for how to compute this. As far as I know there is no processor primitive for this.
For `u32` we can fake it: cast up to `u64` so that the product won't overflow, take the modulus, then cast back down to `u32`, knowing this won't overflow. However, since I don't have a `u128` datatype (as far as I know), this trick won't work for `u64`.
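Here is a minimal sketch of the `u32` widening trick described above. The original question does not cover the following point: Rust has had a stable `u128` type since Rust 1.26, so on a current toolchain the same trick also works for `u64`. The function names `mul_mod32` and `mul_mod64_wide` are hypothetical, and both assume `m > 0`.

```rust
/// (a * b) % m for u32, computed in u64 so the product cannot overflow.
/// (2^32 - 1)^2 < 2^64, and the remainder is < m, so casting back to u32 is lossless.
fn mul_mod32(a: u32, b: u32, m: u32) -> u32 {
    (u64::from(a) * u64::from(b) % u64::from(m)) as u32
}

/// The same trick one size up: (a * b) % m for u64, computed in u128.
/// Assumes Rust 1.26 or later (stable u128).
fn mul_mod64_wide(a: u64, b: u64, m: u64) -> u64 {
    (u128::from(a) * u128::from(b) % u128::from(m)) as u64
}
```

On targets without native 128-bit division, the `u128` remainder is implemented in software. It is still usually simpler than hand-written alternatives, but it is not free.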
So for `u64`, the most obvious way I can think of to accomplish this is to somehow compute `x*y` to get a pair `(carry, product)` of `u64`s, so we capture the amount of overflow instead of just losing it (or panicking, or whatever).
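The `(carry, product)` idea can be sketched without `u128` by using schoolbook multiplication on 32-bit halves. Every 32x32-bit partial product fits in a `u64`. The function name `widening_mul` is hypothetical. This sketch only produces the 128-bit product as `(high, low)`. Reducing that pair modulo `m` still needs a 128-by-64-bit division, which is not shown here.

```rust
/// Full 64x64 -> 128-bit product of x and y, returned as (high, low) u64 words.
fn widening_mul(x: u64, y: u64) -> (u64, u64) {
    const MASK: u64 = 0xFFFF_FFFF;
    let (x_lo, x_hi) = (x & MASK, x >> 32);
    let (y_lo, y_hi) = (y & MASK, y >> 32);

    // Four 32x32 -> 64-bit partial products; none can overflow a u64.
    let ll = x_lo * y_lo;
    let lh = x_lo * y_hi;
    let hl = x_hi * y_lo;
    let hh = x_hi * y_hi;

    // Middle column: at most 3 * (2^32 - 1), so this sum fits easily.
    let mid = (ll >> 32) + (lh & MASK) + (hl & MASK);
    // Low word: low 32 bits of `mid` on top, low 32 bits of `ll` below.
    let low = (mid << 32) | (ll & MASK);
    // High word: everything carried out of the low 64 bits.
    // This never overflows, because the full product is < 2^128.
    let high = hh + (lh >> 32) + (hl >> 32) + (mid >> 32);
    (high, low)
}
```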
Is there a way to do this? Or another standard way to solve the problem?
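We can multiply recursively to overcome the difficulty of overflow. To compute `a*b mod m`, first calculate `a*(b/2) mod m`, add it to itself (reducing mod `m`), and add `a` once more if `b` is odd. Calculating `a*(b/2) mod m` in turn needs `a*(b/4) mod m`, and so on, similar to the O(log n) exponentiation-by-squaring algorithm. Each of these additions must itself be reduced mod `m` without overflowing, which matters once `m` exceeds 2^63.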
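The recursive halving idea above can be sketched as follows. The helper `add_mod` is an addition not spelled out in the answer. It keeps each doubling and addition overflow-free for the full `u64` range by comparing against `m - b` instead of forming `a + b` first. Both names are hypothetical, and `m > 0` is assumed.

```rust
/// (a + b) % m without overflow, assuming a < m and b < m.
fn add_mod(a: u64, b: u64, m: u64) -> u64 {
    // a + b >= m  <=>  a >= m - b; m - b cannot underflow because b < m.
    if a >= m - b { a - (m - b) } else { a + b }
}

/// (a * b) % m by recursive halving of b.
/// The recursion depth is at most 64, one level per bit of b.
fn mul_mod_rec(a: u64, b: u64, m: u64) -> u64 {
    if b == 0 {
        return 0;
    }
    // a * (b / 2) mod m, always < m
    let half = mul_mod_rec(a, b / 2, m);
    // 2 * (a * (b / 2)) mod m
    let mut r = add_mod(half, half, m);
    if b % 2 == 1 {
        // Odd b: add the leftover a once more.
        r = add_mod(r, a % m, m);
    }
    r
}
```

This sketch does one recursive call per bit of `b`, so it costs about 64 modular additions for full-width inputs. That is the same order of work as the iterative shift-and-add loop.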
For signed integer types, overflow can also occur during a modulo operation: this happens when the dividend is equal to the minimum (negative) value of the type and the divisor is -1.
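Here is a small illustration of that signed edge case, using standard-library remainder methods on `i64`. In Rust, plain `%` panics for `i64::MIN % -1`, so the checked, wrapping, and overflowing variants are used instead. Unsigned types such as `u64` cannot hit this case, because their `%` only panics on a zero divisor. The function name `min_rem_neg_one` is hypothetical.

```rust
/// Evaluates i64::MIN % -1 through the std helpers that report overflow
/// instead of panicking.
fn min_rem_neg_one() -> (Option<i64>, i64, (i64, bool)) {
    let a = i64::MIN;
    let b = -1;
    (
        a.checked_rem(b),    // None: the operation would overflow
        a.wrapping_rem(b),   // 0: the mathematically correct remainder
        a.overflowing_rem(b), // (0, true): remainder plus an overflow flag
    )
}
```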
Richard Rast pointed out that the Wikipedia version works only with 63-bit integers. I extended the code provided by Boiethios to work with the full range of 64-bit unsigned integers.
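```rust
/// Computes (x * y) % m for the full range of u64 values without overflow,
/// using shift-and-add over the bits of x (most significant bit first).
fn mul_mod64(mut x: u64, mut y: u64, m: u64) -> u64 {
    let msb = 0x8000_0000_0000_0000;
    let mut d = 0;
    let mp2 = m >> 1;
    x %= m;
    y %= m;
    if m & msb == 0 {
        // m < 2^63: d < m at the top of each iteration, so 2 * d and d + y fit in u64.
        for _ in 0..64 {
            // d = 2 * d (mod m)
            d = if d > mp2 {
                (d << 1) - m
            } else {
                d << 1
            };
            // If the current top bit of x is set, d += y (mod m).
            if x & msb != 0 {
                d += y;
            }
            if d >= m {
                d -= m;
            }
            x <<= 1;
        }
        d
    } else {
        // m >= 2^63: 2 * d or d + y may exceed u64, so use wrapping/overflowing ops.
        for _ in 0..64 {
            d = if d > mp2 {
                // 2 * d - m fits in u64 even though 2 * d alone may not.
                d.wrapping_shl(1).wrapping_sub(m)
            } else {
                // This may yield d == m, a representation of 0 (mod m);
                // that case is normalized after the end of the loop.
                d << 1
            };
            if x & msb != 0 {
                let (mut d1, overflow) = d.overflowing_add(y);
                if overflow {
                    // The true sum is >= 2^64 > m; wrapping_sub recovers sum - m.
                    d1 = d1.wrapping_sub(m);
                }
                d = if d1 >= m { d1 - m } else { d1 };
            }
            x <<= 1;
        }
        if d >= m { d - m } else { d }
    }
}
```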
```rust
#[test]
fn test_mul_mod64() {
    let half = 1 << 16;
    let max = u64::MAX;
    assert_eq!(mul_mod64(0, 0, 2), 0);
    assert_eq!(mul_mod64(1, 0, 2), 0);
    assert_eq!(mul_mod64(0, 1, 2), 0);
    assert_eq!(mul_mod64(1, 1, 2), 1);
    assert_eq!(mul_mod64(42, 1, 2), 0);
    assert_eq!(mul_mod64(1, 42, 2), 0);
    assert_eq!(mul_mod64(42, 42, 2), 0);
    assert_eq!(mul_mod64(42, 42, 42), 0);
    assert_eq!(mul_mod64(42, 42, 41), 1);
    assert_eq!(mul_mod64(1239876, 2948635, 234897), 163320);
    assert_eq!(mul_mod64(1239876, 2948635, half), 18476);
    assert_eq!(mul_mod64(half, half, half), 0);
    assert_eq!(mul_mod64(half + 1, half + 1, half), 1);
    assert_eq!(mul_mod64(max, max, max), 0);
    assert_eq!(mul_mod64(1239876, 2948635, max), 3655941769260);
    assert_eq!(mul_mod64(1239876, max, max), 0);
    assert_eq!(mul_mod64(1239876, max - 1, max), max - 1239876);
    assert_eq!(mul_mod64(max, 2948635, max), 0);
    assert_eq!(mul_mod64(max - 1, 2948635, max), max - 2948635);
    assert_eq!(mul_mod64(max - 1, max - 1, max), 1);
    assert_eq!(mul_mod64(2, max / 2, max - 1), 0);
}
```
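Since the goal in the question is a primality test, here is a sketch of how such a modular multiply is typically used. Square-and-multiply modular exponentiation is the core step of Fermat and Miller-Rabin tests, and its squaring step is exactly the `(n*n)%d` from the question. To stay self-contained, this sketch uses a `u128`-based `mul_mod`, which assumes Rust 1.26 or later; `mul_mod64` above could be dropped in instead. The names `mul_mod` and `pow_mod` are hypothetical.

```rust
/// (a * b) % m via u128; assumes Rust 1.26+ (stable u128) and m > 0.
fn mul_mod(a: u64, b: u64, m: u64) -> u64 {
    (u128::from(a) * u128::from(b) % u128::from(m)) as u64
}

/// base^exp % m by square-and-multiply. Every intermediate value stays < m,
/// so nothing overflows as long as mul_mod itself does not.
fn pow_mod(mut base: u64, mut exp: u64, m: u64) -> u64 {
    let mut result = 1 % m; // 1 % m makes the m == 1 case return 0
    base %= m;
    while exp > 0 {
        if exp & 1 == 1 {
            result = mul_mod(result, base, m);
        }
        base = mul_mod(base, base, m); // the (n*n) % d step
        exp >>= 1;
    }
    result
}
```

For example, by Fermat's little theorem, `pow_mod(2, p - 1, p)` is 1 for any odd prime `p`. A full primality test would check several bases (or use Miller-Rabin), which is beyond this sketch.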