Hello I'm working on a project on Java Card which implies a lot of modulo-multiplication. I managed to implement an modulo-multiplication on this platform using RSA cryptosystem but it seems to work for certain numbers.
public byte[] modMultiply(byte[] x, short xOffset, short xLength, byte[] y,
short yOffset, short yLength, short tempOutoffset) {
//copy x value to temporary rambuffer
Util.arrayCopy(x, xOffset, tempBuffer, tempOutoffset, xLength);
// copy the y value to match th size of rsa_object
Util.arrayFillNonAtomic(eempromTempBuffer, (short)0, (byte) (Configuration.LENGTH_RSAOBJECT_MODULUS-1),(byte)0x00);
Util.arrayCopy(y,yOffset,eempromTempBuffer,(short)(Configuration.LENGTH_RSAOBJECT_MODULUS - yLength),yLength);
// x+y
if (JBigInteger.add(x,xOffset,xLength, eempromTempBuffer,
(short)0,Configuration.LENGTH_MODULUS)) ;
if(this.isGreater(x, xOffset, xLength, tempBuffer,Configuration.TEMP_OFFSET_MODULUS, Configuration.LENGTH_MODULUS)>0)
{
JBigInteger.subtract(x,xOffset,xLength, tempBuffer,
Configuration.TEMP_OFFSET_MODULUS, Configuration.LENGTH_MODULUS);
}
//(x+y)2
mRsaCipherForSquaring.init(mRsaPublicKekForSquare, Cipher.MODE_ENCRYPT);
mRsaCipherForSquaring.doFinal(x, xOffset, Configuration.LENGTH_RSAOBJECT_MODULUS, x,
xOffset); // OK
mRsaCipherForSquaring.doFinal(tempBuffer, tempOutoffset, Configuration.LENGTH_RSAOBJECT_MODULUS, tempBuffer, tempOutoffset); // OK
if (JBigInteger.subtract(x, xOffset, Configuration.LENGTH_MODULUS, tempBuffer, tempOutoffset,
Configuration.LENGTH_MODULUS)) {
JBigInteger.add(x, xOffset, Configuration.LENGTH_MODULUS, tempBuffer,
Configuration.TEMP_OFFSET_MODULUS, Configuration.LENGTH_MODULUS);
}
mRsaCipherForSquaring.doFinal(eempromTempBuffer, yOffset, Configuration.LENGTH_RSAOBJECT_MODULUS, eempromTempBuffer, yOffset); //OK
if (JBigInteger.subtract(x, xOffset, Configuration.LENGTH_MODULUS, eempromTempBuffer, yOffset,
Configuration.LENGTH_MODULUS)) {
JBigInteger.add(x, xOffset, Configuration.LENGTH_MODULUS, tempBuffer,
Configuration.TEMP_OFFSET_MODULUS, Configuration.LENGTH_MODULUS);
}
// ((x+y)^2 - x^2 -y^2)/2
JBigInteger.modular_division_by_2(x, xOffset,Configuration. LENGTH_MODULUS, tempBuffer, Configuration.TEMP_OFFSET_MODULUS, Configuration.LENGTH_MODULUS);
return x;
}
public static boolean add(byte[] x, short xOffset, short xLength, byte[] y,
short yOffset, short yLength) {
short digit_mask = 0xff;
short digit_len = 0x08;
short result = 0;
short i = (short) (xLength + xOffset - 1);
short j = (short) (yLength + yOffset - 1);
for (; i >= xOffset; i--, j--) {
result = (short) (result + (short) (x[i] & digit_mask) + (short) (y[j] & digit_mask));
x[i] = (byte) (result & digit_mask);
result = (short) ((result >> digit_len) & digit_mask);
}
while (result > 0 && i >= xOffset) {
result = (short) (result + (short) (x[i] & digit_mask));
x[i] = (byte) (result & digit_mask);
result = (short) ((result >> digit_len) & digit_mask);
i--;
}
return result != 0;
}
public static boolean subtract(byte[] x, short xOffset, short xLength, byte[] y,
short yOffset, short yLength) {
short digit_mask = 0xff;
short i = (short) (xLength + xOffset - 1);
short j = (short) (yLength + yOffset - 1);
short carry = 0;
short subtraction_result = 0;
for (; i >= xOffset && j >= yOffset; i--, j--) {
subtraction_result = (short) ((x[i] & digit_mask)
- (y[j] & digit_mask) - carry);
x[i] = (byte) (subtraction_result & digit_mask);
carry = (short) (subtraction_result < 0 ? 1 : 0);
}
for (; i >= xOffset && carry > 0; i--) {
if (x[i] != 0)
carry = 0;
x[i] -= 1;
}
return carry > 0;
}
public short isGreater(byte[] x,short xOffset,short xLength,byte[] y ,short yOffset,short yLength)
{
if(xLength > yLength)
return (short)1;
if(xLength < yLength)
return (short)(-1);
short digit_mask = 0xff;
short digit_len = 0x08;
short result = 0;
short i = (short) (xLength + xOffset - 1);
short j = (short) (yLength + yOffset - 1);
for (; i >= xOffset; i--, j--) {
result = (short) (result + (short) (x[i] & digit_mask) - (short) (y[j] & digit_mask));
if(result > 0)
return (short)1;
if(result < 0)
return (short)-1;
}
return 0;
}
The code works well for little number but fails on bigger one
Below is a very simple unit test with a (hopefully) working variant of your code:
package test.java.so;
import java.math.BigInteger;
import java.util.Random;
import javacard.framework.JCSystem;
import javacard.framework.Util;
import javacard.security.KeyBuilder;
import javacard.security.RSAPublicKey;
import javacardx.crypto.Cipher;
import org.apache.commons.lang3.ArrayUtils;
import org.bouncycastle.util.Arrays;
import org.junit.Assert;
import org.junit.Test;
import sutil.test.AbstractTest;
public class So36966764_Test extends AbstractTest {
private static final int NUM_BITS = 1024;
// Dummy
static class Configuration {
public static final short LENGTH_MODULUS = NUM_BITS/8;
public static final short LENGTH_RSAOBJECT_MODULUS = LENGTH_MODULUS;
public static final short TEMP_OFFSET_MODULUS = 0;
public static final short TEMP_OFFSET_RESULT = LENGTH_MODULUS;
}
private byte[] tempBuffer = JCSystem.makeTransientByteArray((short)(Configuration.TEMP_OFFSET_RESULT+Configuration.LENGTH_MODULUS), JCSystem.CLEAR_ON_DESELECT);
private byte[] eempromTempBuffer = new byte[Configuration.LENGTH_MODULUS]; // Why EEPROM?
private RSAPublicKey mRsaPublicKekForSquare = (RSAPublicKey)KeyBuilder.buildKey(KeyBuilder.TYPE_RSA_PUBLIC, (short)NUM_BITS, false);
private Cipher mRsaCipherForSquaring = Cipher.getInstance(Cipher.ALG_RSA_NOPAD, false);
// Assuming xLength==yLength==LENGTH_MODULUS
public byte[] modMultiply(byte[] x, short xOffset, short xLength, byte[] y, short yOffset, short yLength, short tempOutoffset) {
//copy x value to temporary rambuffer
Util.arrayCopy(x, xOffset, tempBuffer, tempOutoffset, xLength);
// copy the y value to match th size of rsa_object
Util.arrayFillNonAtomic(eempromTempBuffer, (short)0, (short) (Configuration.LENGTH_RSAOBJECT_MODULUS-1),(byte)0x00);
Util.arrayCopy(y,yOffset,eempromTempBuffer,(short)(Configuration.LENGTH_RSAOBJECT_MODULUS - yLength),yLength);
// x+y
if(add(x,xOffset,xLength, eempromTempBuffer, (short)0,Configuration.LENGTH_MODULUS)) {
subtract(x,xOffset,xLength, tempBuffer, Configuration.TEMP_OFFSET_MODULUS, Configuration.LENGTH_MODULUS);
}
while(isGreater(x, xOffset, xLength, tempBuffer,Configuration.TEMP_OFFSET_MODULUS, Configuration.LENGTH_MODULUS)>0) {
subtract(x,xOffset,xLength, tempBuffer,Configuration.TEMP_OFFSET_MODULUS, Configuration.LENGTH_MODULUS);
}
//(x+y)2
mRsaCipherForSquaring.init(mRsaPublicKekForSquare, Cipher.MODE_ENCRYPT);
mRsaCipherForSquaring.doFinal(x, xOffset, Configuration.LENGTH_RSAOBJECT_MODULUS, x, xOffset); // OK
mRsaCipherForSquaring.doFinal(tempBuffer, tempOutoffset, Configuration.LENGTH_RSAOBJECT_MODULUS, tempBuffer, tempOutoffset); // OK
if (subtract(x, xOffset, Configuration.LENGTH_MODULUS, tempBuffer, tempOutoffset,
Configuration.LENGTH_MODULUS)) {
add(x, xOffset, Configuration.LENGTH_MODULUS, tempBuffer,
Configuration.TEMP_OFFSET_MODULUS, Configuration.LENGTH_MODULUS);
}
/*WRONG OFFSET mRsaCipherForSquaring.doFinal(eempromTempBuffer, yOffset, Configuration.LENGTH_RSAOBJECT_MODULUS, eempromTempBuffer, yOffset); */
mRsaCipherForSquaring.doFinal(eempromTempBuffer, (short)0, Configuration.LENGTH_RSAOBJECT_MODULUS, eempromTempBuffer, (short)0); //OK
/*WRONG OFFSET if (subtract(x, xOffset, Configuration.LENGTH_MODULUS, eempromTempBuffer, yOffset,*/
if (subtract(x, xOffset, Configuration.LENGTH_MODULUS, eempromTempBuffer, (short)0,Configuration.LENGTH_MODULUS)) {
add(x, xOffset, Configuration.LENGTH_MODULUS, tempBuffer,
Configuration.TEMP_OFFSET_MODULUS, Configuration.LENGTH_MODULUS);
}
// ((x+y)^2 - x^2 -y^2)/2
modular_division_by_2(x, xOffset,Configuration. LENGTH_MODULUS, tempBuffer, Configuration.TEMP_OFFSET_MODULUS, Configuration.LENGTH_MODULUS);
return x;
}
public static boolean add(byte[] x, short xOffset, short xLength, byte[] y, short yOffset, short yLength) {
short digit_mask = 0xff;
short digit_len = 0x08;
short result = 0;
short i = (short) (xLength + xOffset - 1);
short j = (short) (yLength + yOffset - 1);
for (; i >= xOffset; i--, j--) {
result = (short) (result + (short) (x[i] & digit_mask) + (short) (y[j] & digit_mask));
x[i] = (byte) (result & digit_mask);
result = (short) ((result >> digit_len) & digit_mask);
}
while (result > 0 && i >= xOffset) {
result = (short) (result + (short) (x[i] & digit_mask));
x[i] = (byte) (result & digit_mask);
result = (short) ((result >> digit_len) & digit_mask);
i--;
}
return result != 0;
}
public static boolean subtract(byte[] x, short xOffset, short xLength, byte[] y, short yOffset, short yLength) {
short digit_mask = 0xff;
short i = (short) (xLength + xOffset - 1);
short j = (short) (yLength + yOffset - 1);
short carry = 0;
short subtraction_result = 0;
for (; i >= xOffset && j >= yOffset; i--, j--) {
subtraction_result = (short) ((x[i] & digit_mask)
- (y[j] & digit_mask) - carry);
x[i] = (byte) (subtraction_result & digit_mask);
carry = (short) (subtraction_result < 0 ? 1 : 0);
}
for (; i >= xOffset && carry > 0; i--) {
if (x[i] != 0)
carry = 0;
x[i] -= 1;
}
return carry > 0;
}
public short isGreater(byte[] x,short xOffset,short xLength,byte[] y ,short yOffset,short yLength)
{
// Beware: this part is not tested
while(xLength>yLength) {
if(x[xOffset++]!=0x00) {
return 1; // x is greater
}
xLength--;
}
while(yLength>xLength) {
if(y[yOffset++]!=0x00) {
return -1; // y is greater
}
yLength--;
}
// Beware: this part is not tested END
for(short i = 0; i < xLength; i++) {
if (x[xOffset] != y[yOffset]) {
short srcShort = (short)(x[xOffset]&(short)0xFF);
short dstShort = (short)(y[yOffset]&(short)0xFF);
return ( ((srcShort > dstShort) ? (byte)1 : (byte)-1));
}
xOffset++;
yOffset++;
}
return 0;
}
private void modular_division_by_2(byte[] input, short inOffset, short inLength, byte[] modulus, short modulusOffset, short modulusLength) {
short carry = 0;
short digit_mask = 0xff;
short digit_first_bit_mask = 0x80;
short lastIndex = (short) (inOffset + inLength - 1);
short i = inOffset;
if ((byte) (input[lastIndex] & 0x01) != 0) {
if (add(input, inOffset, inLength, modulus, modulusOffset,
modulusLength)) {
carry = digit_first_bit_mask;
}
}
for (; i <= lastIndex; i++) {
if ((input[i] & 0x01) == 0) {
input[i] = (byte) (((input[i] & digit_mask) >> 1) | carry);
carry = 0;
} else {
input[i] = (byte) (((input[i] & digit_mask) >> 1) | carry);
carry = digit_first_bit_mask;
}
}
}
@Test
public void testModMultiply() {
Random r = new Random(12345L);
for(int iiii=0;iiii<10;iiii++) {
BigInteger modulus = BigInteger.probablePrime(NUM_BITS, r);
System.out.println(" M = " + modulus);
byte[] modulusBytes = normalize(modulus.toByteArray());
Util.arrayCopyNonAtomic(modulusBytes, (short)0, tempBuffer, Configuration.TEMP_OFFSET_MODULUS, Configuration.LENGTH_MODULUS);
mRsaPublicKekForSquare.setModulus(modulusBytes, (short)0, (short)modulusBytes.length);
mRsaPublicKekForSquare.setExponent(new byte[] {0x02}, (short)0, (short)1);
for(int iii=0;iii<1000;iii++) {
BigInteger x = new BigInteger(NUM_BITS, r).mod(modulus);
System.out.println(" x = " + x);
BigInteger y = new BigInteger(NUM_BITS, r).mod(modulus);
System.out.println(" y = " + y);
BigInteger accResult;
{
byte[] xBytes = normalize(x.toByteArray());
byte[] yBytes = normalize(y.toByteArray());
byte[] accResultBytes = modMultiply(xBytes, (short)0, (short)xBytes.length, yBytes, (short)0, (short)yBytes.length, Configuration.TEMP_OFFSET_RESULT);
accResult = new BigInteger(1, accResultBytes);
}
System.out.println(" Qr = " + accResult);
BigInteger realResult = x.multiply(y).mod(modulus);
System.out.println(" Rr = " + realResult);
Assert.assertEquals(realResult, accResult);
}
}
}
private byte[] normalize(byte[] xBytes) {
if(xBytes.length<Configuration.LENGTH_MODULUS) {
xBytes = ArrayUtils.addAll(new byte[Configuration.LENGTH_MODULUS-xBytes.length], xBytes);
}
if(xBytes.length>Configuration.LENGTH_MODULUS) {
Assert.assertEquals(xBytes[0], 0x00);
xBytes=Arrays.copyOfRange(xBytes, 1, xBytes.length);
}
return xBytes;
}
}
What was (IMHO) wrong:
The isGreater()
method -- although it is possible to use subtraction to compare numbers, it is much easier (and faster) to compare corresponding bytes starting from the most significant one and stop on the first mismatch. (In the subtraction case you would need to complete the subtraction and return the sign of the final result -- your original code ends on first "mismatch".)
x+y
overflow -- you should have kept the modulus subtraction for the addition overflow case (i.e. when add()
returns true) in your last edit.
Offsets into eempromTempBuffer
-- on two places you used yOffset
and should have used 0
(commented out with a "WRONG OFFSET").
Casting Configuration.LENGTH_RSAOBJECT_MODULUS-1
to byte
is not a good idea for larger values of modulus length
Some (random) comments:
the test uses already mentioned jcardsim to work
the code assumes that lengths of x
and y
are both LENGTH_MODULUS
(as well as LENGTH_RSAOBJECT_MODULUS
being equal to LENGTH_MODULUS
)
it is not a good idea to have eempromTempBuffer
in a non-volatile memory
your code is VERY similar to this code which is interesting
an interesting read regarding this topic is here (section 4.2.3).
Good luck!
Disclaimer: I am not a crypto expert nor mathematician so please do validate my thoughts
If you love us? You can donate to us via Paypal or buy me a coffee so we can maintain and grow! Thank you!
Donate Us With