Logo Questions Linux Laravel Mysql Ubuntu Git Menu
 

Using RSA for modulo-multiplication leads to error on Java Card

Hello I'm working on a project on Java Card which implies a lot of modulo-multiplication. I managed to implement an modulo-multiplication on this platform using RSA cryptosystem but it seems to work for certain numbers.

public byte[] modMultiply(byte[] x, short xOffset, short xLength, byte[] y,
        short yOffset, short yLength, short tempOutoffset) {

    //copy x value to temporary rambuffer
    Util.arrayCopy(x, xOffset, tempBuffer, tempOutoffset, xLength);


    // copy the y value to match th size of rsa_object
    Util.arrayFillNonAtomic(eempromTempBuffer, (short)0, (byte) (Configuration.LENGTH_RSAOBJECT_MODULUS-1),(byte)0x00);
    Util.arrayCopy(y,yOffset,eempromTempBuffer,(short)(Configuration.LENGTH_RSAOBJECT_MODULUS - yLength),yLength);

    // x+y
    if (JBigInteger.add(x,xOffset,xLength, eempromTempBuffer,
            (short)0,Configuration.LENGTH_MODULUS)) ;
    if(this.isGreater(x, xOffset, xLength, tempBuffer,Configuration.TEMP_OFFSET_MODULUS, Configuration.LENGTH_MODULUS)>0)
    {
        JBigInteger.subtract(x,xOffset,xLength, tempBuffer,
                Configuration.TEMP_OFFSET_MODULUS, Configuration.LENGTH_MODULUS);
    }

    //(x+y)2
    mRsaCipherForSquaring.init(mRsaPublicKekForSquare, Cipher.MODE_ENCRYPT);

    mRsaCipherForSquaring.doFinal(x, xOffset, Configuration.LENGTH_RSAOBJECT_MODULUS, x,
            xOffset); // OK

    mRsaCipherForSquaring.doFinal(tempBuffer, tempOutoffset, Configuration.LENGTH_RSAOBJECT_MODULUS, tempBuffer, tempOutoffset); // OK


    if (JBigInteger.subtract(x, xOffset, Configuration.LENGTH_MODULUS, tempBuffer, tempOutoffset,
            Configuration.LENGTH_MODULUS)) {
        JBigInteger.add(x, xOffset, Configuration.LENGTH_MODULUS, tempBuffer,
                Configuration.TEMP_OFFSET_MODULUS, Configuration.LENGTH_MODULUS);
    } 

    mRsaCipherForSquaring.doFinal(eempromTempBuffer, yOffset, Configuration.LENGTH_RSAOBJECT_MODULUS, eempromTempBuffer, yOffset); //OK 


    if (JBigInteger.subtract(x, xOffset, Configuration.LENGTH_MODULUS, eempromTempBuffer, yOffset,
            Configuration.LENGTH_MODULUS)) {

        JBigInteger.add(x, xOffset, Configuration.LENGTH_MODULUS, tempBuffer,
                Configuration.TEMP_OFFSET_MODULUS, Configuration.LENGTH_MODULUS);

    }
    // ((x+y)^2 - x^2 -y^2)/2
    JBigInteger.modular_division_by_2(x, xOffset,Configuration. LENGTH_MODULUS, tempBuffer, Configuration.TEMP_OFFSET_MODULUS, Configuration.LENGTH_MODULUS);
    return x;
}


public static boolean add(byte[] x, short xOffset, short xLength, byte[] y,
        short yOffset, short yLength) {
    short digit_mask = 0xff;
    short digit_len = 0x08;
    short result = 0;
    short i = (short) (xLength + xOffset - 1);
    short j = (short) (yLength + yOffset - 1);

    for (; i >= xOffset; i--, j--) {
        result = (short) (result + (short) (x[i] & digit_mask) + (short) (y[j] & digit_mask));

        x[i] = (byte) (result & digit_mask);
        result = (short) ((result >> digit_len) & digit_mask);
    }
    while (result > 0 && i >= xOffset) {
        result = (short) (result + (short) (x[i] & digit_mask));
        x[i] = (byte) (result & digit_mask);
        result = (short) ((result >> digit_len) & digit_mask);
        i--;
    }

    return result != 0;
}
public static boolean subtract(byte[] x, short xOffset, short xLength, byte[] y,
        short yOffset, short yLength) {
    short digit_mask = 0xff;
    short i = (short) (xLength + xOffset - 1);
    short j = (short) (yLength + yOffset - 1);
    short carry = 0;
    short subtraction_result = 0;

    for (; i >= xOffset && j >= yOffset; i--, j--) {
        subtraction_result = (short) ((x[i] & digit_mask)
                - (y[j] & digit_mask) - carry);
        x[i] = (byte) (subtraction_result & digit_mask);
        carry = (short) (subtraction_result < 0 ? 1 : 0);
    }
    for (; i >= xOffset && carry > 0; i--) {
        if (x[i] != 0)
            carry = 0;
        x[i] -= 1;
    }

    return carry > 0;
}



 public short isGreater(byte[] x,short xOffset,short xLength,byte[] y ,short yOffset,short yLength)
    {
        if(xLength > yLength)
            return (short)1;
        if(xLength < yLength)
            return (short)(-1);
        short digit_mask = 0xff;
        short digit_len = 0x08;
        short result = 0;
        short i = (short) (xLength + xOffset - 1);
        short j = (short) (yLength + yOffset - 1);

        for (; i >= xOffset; i--, j--) {
            result = (short) (result + (short) (x[i] & digit_mask) - (short) (y[j] & digit_mask));
            if(result > 0)
                return (short)1;
            if(result < 0)
                return (short)-1;
        }
        return 0;
    }

The code works well for little number but fails on bigger one

like image 466
Alberto12 Avatar asked Nov 08 '22 15:11

Alberto12


1 Answers

Below is a very simple unit test with a (hopefully) working variant of your code:

package test.java.so;

import java.math.BigInteger;
import java.util.Random;

import javacard.framework.JCSystem;
import javacard.framework.Util;
import javacard.security.KeyBuilder;
import javacard.security.RSAPublicKey;
import javacardx.crypto.Cipher;

import org.apache.commons.lang3.ArrayUtils;
import org.bouncycastle.util.Arrays;
import org.junit.Assert;
import org.junit.Test;

import sutil.test.AbstractTest;

public class So36966764_Test extends AbstractTest {

    private static final int NUM_BITS = 1024;

    // Dummy
    static class Configuration {
        public static final short LENGTH_MODULUS = NUM_BITS/8;
        public static final short LENGTH_RSAOBJECT_MODULUS = LENGTH_MODULUS;
        public static final short TEMP_OFFSET_MODULUS = 0;
        public static final short TEMP_OFFSET_RESULT = LENGTH_MODULUS;
    }

    private byte[] tempBuffer = JCSystem.makeTransientByteArray((short)(Configuration.TEMP_OFFSET_RESULT+Configuration.LENGTH_MODULUS), JCSystem.CLEAR_ON_DESELECT);
    private byte[] eempromTempBuffer = new byte[Configuration.LENGTH_MODULUS]; // Why EEPROM?
    private RSAPublicKey mRsaPublicKekForSquare = (RSAPublicKey)KeyBuilder.buildKey(KeyBuilder.TYPE_RSA_PUBLIC, (short)NUM_BITS, false);
    private Cipher mRsaCipherForSquaring = Cipher.getInstance(Cipher.ALG_RSA_NOPAD, false);

    // Assuming xLength==yLength==LENGTH_MODULUS
    public byte[] modMultiply(byte[] x, short xOffset, short xLength, byte[] y, short yOffset, short yLength, short tempOutoffset) {

        //copy x value to temporary rambuffer
        Util.arrayCopy(x, xOffset, tempBuffer, tempOutoffset, xLength);

        // copy the y value to match th size of rsa_object
        Util.arrayFillNonAtomic(eempromTempBuffer, (short)0, (short) (Configuration.LENGTH_RSAOBJECT_MODULUS-1),(byte)0x00);
        Util.arrayCopy(y,yOffset,eempromTempBuffer,(short)(Configuration.LENGTH_RSAOBJECT_MODULUS - yLength),yLength);

        // x+y
        if(add(x,xOffset,xLength, eempromTempBuffer, (short)0,Configuration.LENGTH_MODULUS)) {
            subtract(x,xOffset,xLength, tempBuffer, Configuration.TEMP_OFFSET_MODULUS, Configuration.LENGTH_MODULUS);
        }
        while(isGreater(x, xOffset, xLength, tempBuffer,Configuration.TEMP_OFFSET_MODULUS, Configuration.LENGTH_MODULUS)>0) {
            subtract(x,xOffset,xLength, tempBuffer,Configuration.TEMP_OFFSET_MODULUS, Configuration.LENGTH_MODULUS);
        }

        //(x+y)2
        mRsaCipherForSquaring.init(mRsaPublicKekForSquare, Cipher.MODE_ENCRYPT);
        mRsaCipherForSquaring.doFinal(x, xOffset, Configuration.LENGTH_RSAOBJECT_MODULUS, x, xOffset); // OK

        mRsaCipherForSquaring.doFinal(tempBuffer, tempOutoffset, Configuration.LENGTH_RSAOBJECT_MODULUS, tempBuffer, tempOutoffset); // OK

        if (subtract(x, xOffset, Configuration.LENGTH_MODULUS, tempBuffer, tempOutoffset,
                Configuration.LENGTH_MODULUS)) {
            add(x, xOffset, Configuration.LENGTH_MODULUS, tempBuffer,
                    Configuration.TEMP_OFFSET_MODULUS, Configuration.LENGTH_MODULUS);
        }

        /*WRONG OFFSET mRsaCipherForSquaring.doFinal(eempromTempBuffer, yOffset, Configuration.LENGTH_RSAOBJECT_MODULUS, eempromTempBuffer, yOffset); */
        mRsaCipherForSquaring.doFinal(eempromTempBuffer, (short)0, Configuration.LENGTH_RSAOBJECT_MODULUS, eempromTempBuffer, (short)0); //OK

        /*WRONG OFFSET if (subtract(x, xOffset, Configuration.LENGTH_MODULUS, eempromTempBuffer, yOffset,*/
        if (subtract(x, xOffset, Configuration.LENGTH_MODULUS, eempromTempBuffer, (short)0,Configuration.LENGTH_MODULUS)) {
            add(x, xOffset, Configuration.LENGTH_MODULUS, tempBuffer,
                    Configuration.TEMP_OFFSET_MODULUS, Configuration.LENGTH_MODULUS);
        }
        // ((x+y)^2 - x^2 -y^2)/2
        modular_division_by_2(x, xOffset,Configuration. LENGTH_MODULUS, tempBuffer, Configuration.TEMP_OFFSET_MODULUS, Configuration.LENGTH_MODULUS);
        return x;
    }

    public static boolean add(byte[] x, short xOffset, short xLength, byte[] y, short yOffset, short yLength) {
        short digit_mask = 0xff;
        short digit_len = 0x08;
        short result = 0;
        short i = (short) (xLength + xOffset - 1);
        short j = (short) (yLength + yOffset - 1);

        for (; i >= xOffset; i--, j--) {
            result = (short) (result + (short) (x[i] & digit_mask) + (short) (y[j] & digit_mask));

            x[i] = (byte) (result & digit_mask);
            result = (short) ((result >> digit_len) & digit_mask);
        }
        while (result > 0 && i >= xOffset) {
            result = (short) (result + (short) (x[i] & digit_mask));
            x[i] = (byte) (result & digit_mask);
            result = (short) ((result >> digit_len) & digit_mask);
            i--;
        }

        return result != 0;
    }

    public static boolean subtract(byte[] x, short xOffset, short xLength, byte[] y, short yOffset, short yLength) {
        short digit_mask = 0xff;
        short i = (short) (xLength + xOffset - 1);
        short j = (short) (yLength + yOffset - 1);
        short carry = 0;
        short subtraction_result = 0;

        for (; i >= xOffset && j >= yOffset; i--, j--) {
            subtraction_result = (short) ((x[i] & digit_mask)
                    - (y[j] & digit_mask) - carry);
            x[i] = (byte) (subtraction_result & digit_mask);
            carry = (short) (subtraction_result < 0 ? 1 : 0);
        }
        for (; i >= xOffset && carry > 0; i--) {
            if (x[i] != 0)
                carry = 0;
            x[i] -= 1;
        }

        return carry > 0;
    }

    public short isGreater(byte[] x,short xOffset,short xLength,byte[] y ,short yOffset,short yLength)
    {
        // Beware: this part is not tested
        while(xLength>yLength) {
            if(x[xOffset++]!=0x00) {
                return 1; // x is greater
            }
            xLength--;
        }
        while(yLength>xLength) {
            if(y[yOffset++]!=0x00) {
                return -1; // y is greater
            }
            yLength--;
        }
        // Beware: this part is not tested END
        for(short i = 0; i < xLength; i++) {
            if (x[xOffset] != y[yOffset]) {
                short srcShort = (short)(x[xOffset]&(short)0xFF);
                short dstShort = (short)(y[yOffset]&(short)0xFF);
                return ( ((srcShort > dstShort) ? (byte)1 : (byte)-1));
            }
            xOffset++;
            yOffset++;
        }
        return 0;
    }

    private void modular_division_by_2(byte[] input, short inOffset, short inLength, byte[] modulus, short modulusOffset, short modulusLength) {
        short carry = 0;
        short digit_mask = 0xff;
        short digit_first_bit_mask = 0x80;
        short lastIndex = (short) (inOffset + inLength - 1);

        short i = inOffset;
        if ((byte) (input[lastIndex] & 0x01) != 0) {
            if (add(input, inOffset, inLength, modulus, modulusOffset,
                    modulusLength)) {
                carry = digit_first_bit_mask;
            }
        }

        for (; i <= lastIndex; i++) {
            if ((input[i] & 0x01) == 0) {
                input[i] = (byte) (((input[i] & digit_mask) >> 1) | carry);
                carry = 0;
            } else {
                input[i] = (byte) (((input[i] & digit_mask) >> 1) | carry);
                carry = digit_first_bit_mask;
            }
        }
    }

    @Test
    public void testModMultiply() {
        Random r = new Random(12345L);
        for(int iiii=0;iiii<10;iiii++) {
            BigInteger modulus = BigInteger.probablePrime(NUM_BITS, r);
            System.out.println(" M = " + modulus);
            byte[] modulusBytes = normalize(modulus.toByteArray());
            Util.arrayCopyNonAtomic(modulusBytes, (short)0, tempBuffer, Configuration.TEMP_OFFSET_MODULUS, Configuration.LENGTH_MODULUS);

            mRsaPublicKekForSquare.setModulus(modulusBytes, (short)0, (short)modulusBytes.length);
            mRsaPublicKekForSquare.setExponent(new byte[] {0x02}, (short)0, (short)1);

            for(int iii=0;iii<1000;iii++) {
                BigInteger x = new BigInteger(NUM_BITS, r).mod(modulus);
                System.out.println(" x = " + x);
                BigInteger y = new BigInteger(NUM_BITS, r).mod(modulus);
                System.out.println(" y = " + y);
                BigInteger accResult;
                {
                    byte[] xBytes = normalize(x.toByteArray());
                    byte[] yBytes = normalize(y.toByteArray());
                    byte[] accResultBytes = modMultiply(xBytes, (short)0, (short)xBytes.length, yBytes, (short)0, (short)yBytes.length, Configuration.TEMP_OFFSET_RESULT);
                    accResult = new BigInteger(1, accResultBytes);
                }
                System.out.println(" Qr = " + accResult);
                BigInteger realResult = x.multiply(y).mod(modulus);
                System.out.println(" Rr = " + realResult);
                Assert.assertEquals(realResult, accResult);
            }
        }
    }

    private byte[] normalize(byte[] xBytes) {
        if(xBytes.length<Configuration.LENGTH_MODULUS) {
            xBytes = ArrayUtils.addAll(new byte[Configuration.LENGTH_MODULUS-xBytes.length], xBytes);
        }
        if(xBytes.length>Configuration.LENGTH_MODULUS) {
            Assert.assertEquals(xBytes[0], 0x00);
            xBytes=Arrays.copyOfRange(xBytes, 1, xBytes.length);
        }
        return xBytes;
    }
}

What was (IMHO) wrong:

  1. The isGreater() method -- although it is possible to use subtraction to compare numbers, it is much easier (and faster) to compare corresponding bytes starting from the most significant one and stop on the first mismatch. (In the subtraction case you would need to complete the subtraction and return the sign of the final result -- your original code ends on first "mismatch".)

  2. x+y overflow -- you should have kept the modulus subtraction for the addition overflow case (i.e. when add() returns true) in your last edit.

  3. Offsets into eempromTempBuffer -- on two places you used yOffset and should have used 0 (commented out with a "WRONG OFFSET").

  4. Casting Configuration.LENGTH_RSAOBJECT_MODULUS-1 to byte is not a good idea for larger values of modulus length

Some (random) comments:

  • the test uses already mentioned jcardsim to work

  • the code assumes that lengths of x and y are both LENGTH_MODULUS (as well as LENGTH_RSAOBJECT_MODULUS being equal to LENGTH_MODULUS)

  • it is not a good idea to have eempromTempBuffer in a non-volatile memory

  • your code is VERY similar to this code which is interesting

  • an interesting read regarding this topic is here (section 4.2.3).

Good luck!

Disclaimer: I am not a crypto expert nor mathematician so please do validate my thoughts

like image 85
vlp Avatar answered Nov 14 '22 22:11

vlp