
Modular Exponentiation for high numbers in C++

So I've been working recently on an implementation of the Miller-Rabin primality test. I am limiting it to a scope of all 32-bit numbers, because this is a just-for-fun project that I am doing to familiarize myself with C++, and I don't want to have to work with anything 64-bit for a while. An added bonus is that the algorithm is deterministic for all 32-bit numbers, so I can significantly increase efficiency because I know exactly what witnesses to test for.

So for low numbers, the algorithm works exceptionally well. However, part of the process relies upon modular exponentiation, that is, (num ^ pow) % mod. So, for example,

3 ^ 2 % 5 = 
9 % 5 = 
4

Here is the code I have been using for this modular exponentiation:

unsigned mod_pow(unsigned num, unsigned pow, unsigned mod)
{
    unsigned test;
    for(test = 1; pow; pow >>= 1)
    {
        if (pow & 1)
            test = (test * num) % mod;
        num = (num * num) % mod;
    }

    return test;

}

As you might have already guessed, problems arise when the arguments are all exceptionally large numbers. For example, if I want to test the number 673109 for primality, I will at one point have to find:

(2 ^ 168277) % 673109

Now 2 ^ 168277 is an exceptionally large number, and somewhere in the process it overflows test, which results in an incorrect evaluation.

On the reverse side, arguments such as

4000111222 ^ 3 % 1608

also evaluate incorrectly, for much the same reason.

Does anyone have suggestions for modular exponentiation in a way that can prevent this overflow and/or manipulate it to produce the correct result? (the way I see it, overflow is just another form of modulo, that is num % (UINT_MAX+1))

Axel Magnuson asked Feb 05 '10 12:02


3 Answers

Exponentiation by squaring still "works" for modulo exponentiation. Your problem isn't that 2 ^ 168277 is an exceptionally large number, it's that one of your intermediate results is a fairly large number (bigger than 2^32), because 673109 is bigger than 2^16.
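To make that concrete, here is a minimal sketch of a single squaring step, assuming a platform where unsigned arithmetic is 32 bits wide. The function names are made up for illustration and are not from the question. It also shows why wraparound cannot simply be treated as "another modulo": reducing modulo 2^32 first and modulo 673109 afterwards does not, in general, give the same residue as reducing the true product.

```cpp
#include <cassert>
#include <cstdint>

// Squares x using only 32-bit arithmetic. If x * x exceeds 2^32 - 1 the
// product silently wraps modulo 2^32 before the % mod is applied.
std::uint32_t square_mod_wrapped(std::uint32_t x, std::uint32_t mod)
{
    assert(mod != 0);
    return (x * x) % mod;
}

// Same computation, but the product is formed in a 64-bit integer,
// so nothing is lost before the reduction.
std::uint32_t square_mod_exact(std::uint32_t x, std::uint32_t mod)
{
    assert(mod != 0);
    return static_cast<std::uint32_t>((static_cast<std::uint64_t>(x) * x) % mod);
}
```

For x = 673108 and mod = 673109 the exact version returns 1 (since 673108 is -1 modulo 673109), while the wrapped version returns something else. For small x such as 1000 the two agree, which is why the original code works for low numbers.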

So I think the following will do. It's possible I've missed a detail, but the basic idea works, and this is how "real" crypto code might do large mod-exponentiation (although not with 32 and 64 bit numbers, rather with bignums whose intermediate values never need more than about 2 * log2(modulus) bits):

  • Start with exponentiation by squaring, as you have.
  • Perform the actual squaring in a 64-bit unsigned integer.
  • Reduce modulo 673109 at each step to get back within the 32-bit range, as you do.
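The three steps above can be sketched as follows. This is only a sketch, assuming std::uint64_t from <cstdint> is available; the name mod_pow64 is hypothetical and chosen to avoid clashing with the question's mod_pow.

```cpp
#include <cassert>
#include <cstdint>

// Computes (base ^ exp) % mod for 32-bit inputs.
// Every product is formed in 64 bits and reduced immediately, so each
// intermediate value stays below mod * mod, which is less than 2^64.
std::uint32_t mod_pow64(std::uint32_t base, std::uint32_t exp, std::uint32_t mod)
{
    assert(mod != 0);
    std::uint64_t result = 1 % mod;   // yields 0 when mod == 1
    std::uint64_t b = base % mod;     // reduce the base before the loop
    for (; exp != 0; exp >>= 1) {
        if (exp & 1u)
            result = (result * b) % mod;
        b = (b * b) % mod;
    }
    return static_cast<std::uint32_t>(result);  // result < mod, so it fits
}
```

Because the base is reduced up front, the same function also handles inputs like 4000111222 ^ 3 % 1608.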

Obviously that's a bit awkward if your C++ implementation doesn't have a 64 bit integer, although you can always fake one.
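If no 64-bit type is available, one alternative to emulating a wide integer is to avoid the wide product altogether and do the modular multiplication by double-and-add. Note that this is a different technique from faking a 64-bit integer, and the names add_mod, mul_mod and mod_pow32 below are hypothetical. A minimal sketch, assuming only a 32-bit unsigned type:

```cpp
#include <cassert>
#include <cstdint>

// (a + b) % mod without overflow. Requires a < mod and b < mod.
std::uint32_t add_mod(std::uint32_t a, std::uint32_t b, std::uint32_t mod)
{
    // a + b >= mod exactly when a >= mod - b; subtract instead of adding then.
    return (a >= mod - b) ? a - (mod - b) : a + b;
}

// (a * b) % mod using only additions: walk the bits of b, doubling a each step.
std::uint32_t mul_mod(std::uint32_t a, std::uint32_t b, std::uint32_t mod)
{
    std::uint32_t result = 0;
    a %= mod;
    for (; b != 0; b >>= 1) {
        if (b & 1u)
            result = add_mod(result, a, mod);
        a = add_mod(a, a, mod);
    }
    return result;
}

// Exponentiation by squaring built on mul_mod; never needs more than 32 bits.
std::uint32_t mod_pow32(std::uint32_t base, std::uint32_t exp, std::uint32_t mod)
{
    assert(mod != 0);
    std::uint32_t result = 1 % mod;
    base %= mod;
    for (; exp != 0; exp >>= 1) {
        if (exp & 1u)
            result = mul_mod(result, base, mod);
        base = mul_mod(base, base, mod);
    }
    return result;
}
```

The trade-off is speed: each multiplication now costs up to 32 loop iterations, so this is mainly useful when a wider type really is unavailable.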

There's an example on slide 22 here: http://www.cs.princeton.edu/courses/archive/spr05/cos126/lectures/22.pdf, although it uses very small numbers (less than 2^16), so it may not illustrate anything you don't already know.

Your other example, 4000111222 ^ 3 % 1608 would work in your current code if you just reduce 4000111222 modulo 1608 before you start. 1608 is small enough that you can safely multiply any two mod-1608 numbers in a 32 bit int.
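As a sketch of that last point, the question's loop only needs one extra reduction of the base. This version deliberately stays in 32 bits and is therefore only valid for moduli up to 65536, where the product of any two residues still fits; the name mod_pow_small and the assertion are illustrative additions.

```cpp
#include <cassert>
#include <cstdint>

// 32-bit-only modular exponentiation for small moduli.
// With mod <= 65536 every residue is at most 65535, and
// 65535 * 65535 = 4294836225 still fits in 32 bits.
std::uint32_t mod_pow_small(std::uint32_t base, std::uint32_t exp, std::uint32_t mod)
{
    assert(mod != 0 && mod <= 65536u);
    std::uint32_t result = 1 % mod;
    base %= mod;                      // the reduction the original code lacked
    for (; exp != 0; exp >>= 1) {
        if (exp & 1u)
            result = (result * base) % mod;
        base = (base * base) % mod;
    }
    return result;
}
```

Here 4000111222 reduces to 574 modulo 1608, and mod_pow_small(4000111222u, 3, 1608) returns 736.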

Steve Jessop answered Oct 16 '22 09:10


I wrote something for this recently for RSA in C++, bit messy though.

#include "BigInteger.h"
#include <iostream>
#include <sstream>
#include <stack>

BigInteger::BigInteger() {
    digits.push_back(0);
    negative = false;
}

BigInteger::~BigInteger() {
}

void BigInteger::addWithoutSign(BigInteger& c, const BigInteger& a, const BigInteger& b) {
    int sum_n_carry = 0;
    int n = (int)a.digits.size();
    if (n < (int)b.digits.size()) {
        n = b.digits.size();
    }
    c.digits.resize(n);
    for (int i = 0; i < n; ++i) {
        unsigned short a_digit = 0;
        unsigned short b_digit = 0;
        if (i < (int)a.digits.size()) {
            a_digit = a.digits[i];
        }
        if (i < (int)b.digits.size()) {
            b_digit = b.digits[i];
        }
        sum_n_carry += a_digit + b_digit;
        c.digits[i] = (sum_n_carry & 0xFFFF);
        sum_n_carry >>= 16;
    }
    if (sum_n_carry != 0) {
        putCarryInfront(c, sum_n_carry);
    }
    while (c.digits.size() > 1 && c.digits.back() == 0) {
        c.digits.pop_back();
    }
    //std::cout << a.toString() << " + " << b.toString() << " == " << c.toString() << std::endl;
}

void BigInteger::subWithoutSign(BigInteger& c, const BigInteger& a, const BigInteger& b) {
    int sub_n_borrow = 0;
    int n = a.digits.size();
    if (n < (int)b.digits.size())
        n = (int)b.digits.size();
    c.digits.resize(n);
    for (int i = 0; i < n; ++i) {
        unsigned short a_digit = 0;
        unsigned short b_digit = 0;
        if (i < (int)a.digits.size())
            a_digit = a.digits[i];
        if (i < (int)b.digits.size())
            b_digit = b.digits[i];
        sub_n_borrow += a_digit - b_digit;
        if (sub_n_borrow >= 0) {
            c.digits[i] = sub_n_borrow;
            sub_n_borrow = 0;
        } else {
            c.digits[i] = 0x10000 + sub_n_borrow;
            sub_n_borrow = -1;
        }
    }
    while (c.digits.size() > 1 && c.digits.back() == 0) {
        c.digits.pop_back();
    }
    //std::cout << a.toString() << " - " << b.toString() << " == " << c.toString() << std::endl;
}

int BigInteger::cmpWithoutSign(const BigInteger& a, const BigInteger& b) {
    int n = (int)a.digits.size();
    if (n < (int)b.digits.size())
        n = (int)b.digits.size();
    //std::cout << "cmp(" << a.toString() << ", " << b.toString() << ") == ";
    for (int i = n-1; i >= 0; --i) {
        unsigned short a_digit = 0;
        unsigned short b_digit = 0;
        if (i < (int)a.digits.size())
            a_digit = a.digits[i];
        if (i < (int)b.digits.size())
            b_digit = b.digits[i];
        if (a_digit < b_digit) {
            //std::cout << "-1" << std::endl;
            return -1;
        } else if (a_digit > b_digit) {
            //std::cout << "+1" << std::endl;
            return +1;
        }
    }
    //std::cout << "0" << std::endl;
    return 0;
}

void BigInteger::multByDigitWithoutSign(BigInteger& c, const BigInteger& a, unsigned short b) {
    unsigned int mult_n_carry = 0;
    c.digits.clear();
    c.digits.resize(a.digits.size());
    for (int i = 0; i < (int)a.digits.size(); ++i) {
        unsigned short a_digit = 0;
        unsigned short b_digit = b;
        if (i < (int)a.digits.size())
            a_digit = a.digits[i];
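        // Cast first: unsigned short operands promote to int, and 0xFFFF * 0xFFFF overflows a 32-bit int.
        mult_n_carry += (unsigned int)a_digit * b_digit;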
        c.digits[i] = (mult_n_carry & 0xFFFF);
        mult_n_carry >>= 16;
    }
    if (mult_n_carry != 0) {
        putCarryInfront(c, mult_n_carry);
    }
    //std::cout << a.toString() << " x " << b << " == " << c.toString() << std::endl;
}

void BigInteger::shiftLeftByBase(BigInteger& b, const BigInteger& a, int times) {
    b.digits.resize(a.digits.size() + times);
    for (int i = 0; i < times; ++i) {
        b.digits[i] = 0;
    }
    for (int i = 0; i < (int)a.digits.size(); ++i) {
        b.digits[i + times] = a.digits[i];
    }
}

void BigInteger::shiftRight(BigInteger& a) {
    //std::cout << "shr " << a.toString() << " == ";
    for (int i = 0; i < (int)a.digits.size(); ++i) {
        a.digits[i] >>= 1;
        if (i+1 < (int)a.digits.size()) {
            if ((a.digits[i+1] & 0x1) != 0) {
                a.digits[i] |= 0x8000;
            }
        }
    }
    //std::cout << a.toString() << std::endl;
}

void BigInteger::shiftLeft(BigInteger& a) {
    bool lastBit = false;
    for (int i = 0; i < (int)a.digits.size(); ++i) {
        bool bit = (a.digits[i] & 0x8000) != 0;
        a.digits[i] <<= 1;
        if (lastBit)
            a.digits[i] |= 1;
        lastBit = bit;
    }
    if (lastBit) {
        a.digits.push_back(1);
    }
}

void BigInteger::putCarryInfront(BigInteger& a, unsigned short carry) {
    BigInteger b;
    b.negative = a.negative;
    b.digits.resize(a.digits.size() + 1);
    b.digits[a.digits.size()] = carry;
    for (int i = 0; i < (int)a.digits.size(); ++i) {
        b.digits[i] = a.digits[i];
    }
    a.digits.swap(b.digits);
}

void BigInteger::divideWithoutSign(BigInteger& c, BigInteger& d, const BigInteger& a, const BigInteger& b) {
    c.digits.clear();
    c.digits.push_back(0);
    BigInteger two("2");
    BigInteger e = b;
    BigInteger f("1");
    BigInteger g = a;
    BigInteger one("1");
    while (cmpWithoutSign(g, e) >= 0) {
        shiftLeft(e);
        shiftLeft(f);
    }
    shiftRight(e);
    shiftRight(f);
    while (cmpWithoutSign(g, b) >= 0) {
        g -= e;
        c += f;
        while (cmpWithoutSign(g, e) < 0) {
            shiftRight(e);
            shiftRight(f);
        }
    }
    e = c;
    e *= b;
    f = a;
    f -= e;
    d = f;
}

BigInteger::BigInteger(const BigInteger& other) {
    digits = other.digits;
    negative = other.negative;
}

BigInteger::BigInteger(const char* other) {
    digits.push_back(0);
    negative = false;
    BigInteger ten;
    ten.digits[0] = 10;
    const char* c = other;
    bool make_negative = false;
    if (*c == '-') {
        make_negative = true;
        ++c;
    }
    while (*c != 0) {
        BigInteger digit;
        digit.digits[0] = *c - '0';
        *this *= ten;
        *this += digit;
        ++c;
    }
    negative = make_negative;
}

bool BigInteger::isOdd() const {
    return (digits[0] & 0x1) != 0;
}

BigInteger& BigInteger::operator=(const BigInteger& other) {
    if (this == &other) // handle self assignment
        return *this;
    digits = other.digits;
    negative = other.negative;
    return *this;
}

BigInteger& BigInteger::operator+=(const BigInteger& other) {
    BigInteger result;
    if (negative) {
        if (other.negative) {
            result.negative = true;
            addWithoutSign(result, *this, other);
        } else {
            int a = cmpWithoutSign(*this, other);
            if (a < 0) {
                result.negative = false;
                subWithoutSign(result, other, *this);
            } else if (a > 0) {
                result.negative = true;
                subWithoutSign(result, *this, other);
            } else {
                result.negative = false;
                result.digits.clear();
                result.digits.push_back(0);
            }
        }
    } else {
        if (other.negative) {
            int a = cmpWithoutSign(*this, other);
            if (a < 0) {
                result.negative = true;
                subWithoutSign(result, other, *this);
            } else if (a > 0) {
                result.negative = false;
                subWithoutSign(result, *this, other);
            } else {
                result.negative = false;
                result.digits.clear();
                result.digits.push_back(0);
            }
        } else {
            result.negative = false;
            addWithoutSign(result, *this, other);
        }
    }
    negative = result.negative;
    digits.swap(result.digits);
    return *this;
}

BigInteger& BigInteger::operator-=(const BigInteger& other) {
    BigInteger neg_other = other;
    neg_other.negative = !neg_other.negative;
    return *this += neg_other;
}

BigInteger& BigInteger::operator*=(const BigInteger& other) {
    BigInteger result;
    for (int i = 0; i < (int)digits.size(); ++i) {
        BigInteger mult;
        multByDigitWithoutSign(mult, other, digits[i]);
        BigInteger shift;
        shiftLeftByBase(shift, mult, i);
        BigInteger add;
        addWithoutSign(add, result, shift);
        result = add;
    }
    if (negative != other.negative) {
        result.negative = true;
    } else {
        result.negative = false;
    }
    //std::cout << toString() << " x " << other.toString() << " == " << result.toString() << std::endl;
    negative = result.negative;
    digits.swap(result.digits);
    return *this;
}

BigInteger& BigInteger::operator/=(const BigInteger& other) {
    BigInteger result, tmp;
    divideWithoutSign(result, tmp, *this, other);
    result.negative = (negative != other.negative);
    negative = result.negative;
    digits.swap(result.digits);
    return *this;
}

BigInteger& BigInteger::operator%=(const BigInteger& other) {
    BigInteger c, d;
    divideWithoutSign(c, d, *this, other);
    *this = d;
    return *this;
}

bool BigInteger::operator>(const BigInteger& other) const {
    if (negative) {
        if (other.negative) {
            return cmpWithoutSign(*this, other) < 0;
        } else {
            return false;
        }
    } else {
        if (other.negative) {
            return true;
        } else {
            return cmpWithoutSign(*this, other) > 0;
        }
    }
}

BigInteger& BigInteger::powAssignUnderMod(const BigInteger& exponent, const BigInteger& modulus) {
    BigInteger zero("0");
    BigInteger one("1");
    BigInteger e = exponent;
    BigInteger base = *this;
    *this = one;
    while (cmpWithoutSign(e, zero) != 0) {
        //std::cout << e.toString() << " : " << toString() << " : " << base.toString() << std::endl;
        if (e.isOdd()) {
            *this *= base;
            *this %= modulus;
        }
        shiftRight(e);
        base *= BigInteger(base);
        base %= modulus;
    }
    return *this;
}

std::string BigInteger::toString() const {
    std::ostringstream os;
    if (negative)
        os << "-";
    BigInteger tmp = *this;
    BigInteger zero("0");
    BigInteger ten("10");
    tmp.negative = false;
    std::stack<char> s;
    while (cmpWithoutSign(tmp, zero) != 0) {
        BigInteger tmp2, tmp3;
        divideWithoutSign(tmp2, tmp3, tmp, ten);
        s.push((char)(tmp3.digits[0] + '0'));
        tmp = tmp2;
    }
    while (!s.empty()) {
        os << s.top();
        s.pop();
    }
    /*
    for (int i = digits.size()-1; i >= 0; --i) {
        os << digits[i];
        if (i != 0) {
            os << ",";
        }
    }
    */
    return os.str();
}

And an example usage.

BigInteger a("87682374682734687"), b("435983748957348957349857345"), c("2348927349872344");

// Calculates pow(87682374682734687, 435983748957348957349857345) % 2348927349872344
a.powAssignUnderMod(b, c);

It's fast too, and has an unlimited number of digits.

clinux answered Oct 16 '22 09:10


Two things:

  • Are you using the appropriate data type? In other words, does UINT_MAX allow you to have 673109 as an argument?

Your code does not work because at one point you have num = 2^16, and the num = ... line then causes overflow. Use a bigger data type to hold this intermediate value.

  • How about taking modulo at every possible overflow opportunity, such as:

    test = ((test % mod) * (num % mod)) % mod;

Edit:

#include <cstdio>

unsigned mod_pow(unsigned num, unsigned pow, unsigned mod)
{
    unsigned long long test;
    unsigned long long n = num;
    for(test = 1; pow; pow >>= 1)
    {
        if (pow & 1)
            test = ((test % mod) * (n % mod)) % mod;
        n = ((n % mod) * (n % mod)) % mod;
    }

    return test; /* test is below mod (or is 1), so narrowing to unsigned loses nothing */
}

int main(int argc, char* argv[])
{

    /* (2 ^ 168277) % 673109 */
    printf("%u\n", mod_pow(2, 168277, 673109));
    return 0;
}

dirkgently answered Oct 16 '22 07:10