// ----------------------------------------------------------------------------
//
// Copyright (c) Microsoft Corporation. All rights reserved.
//
// ----------------------------------------------------------------------------
|
|
|
|
using System;
|
|
using System.Diagnostics;
|
|
|
|
namespace Microsoft.Singularity.Crypto.PublicKey
|
|
{
|
|
public class Rsa {
|
|
public class Key {
|
|
// ---- Public exponent ----------------------------------------------------

// Bit length of the public exponent.
int _eBitN;

// Public exponent as raw bytes.
byte[] _eBytes;

// Public exponent in multi-precision form.
Digits _eDigits;

// Number of Digit words needed to hold _eBitN bits (rounded up).
int _eDigitN {
    get {
        int roundedBitN = _eBitN + (Digit.BitN - 1);
        return roundedBitN / Digit.BitN;
    }
}

// ---- Modulus ------------------------------------------------------------

// Bit length of the modulus.
int _modBitN;

// Modulus as raw bytes.
byte[] _modBytes;

// Number of bytes needed to hold _modBitN bits (rounded up).
int _modByteN {
    get {
        int roundedBitN = _modBitN + 8 - 1;
        return roundedBitN / 8;
    }
}

// Modulus in multi-precision form.
Digits _modDigits;

// Number of Digit words needed to hold _modBitN bits (rounded up).
int _modDigitN {
    get {
        int roundedBitN = _modBitN + (Digit.BitN - 1);
        return roundedBitN / Digit.BitN;
    }
}

// ---- Private-key material (index 0 = first prime, 1 = second prime) -------

// Bit length of each prime.
int[] _primeBitN;

// Each prime as raw bytes.
byte[][] _primeBytes;

// Each per-prime private exponent as raw bytes.
byte[][] _dBytes;

// Modular-arithmetic context for the full modulus.
Modulus _modulus;

// Modular-arithmetic context for each prime.
Modulus[] _privateModulus;

// Each per-prime private exponent in multi-precision form.
Digits[] _dDigits;

// Per-prime coefficients produced by the extended GCD of the two primes.
Digits[] _chineseDigits;
|
|
/// <summary>
/// Generates a new two-prime RSA key with a modulus of
/// <paramref name="bitN"/> bits and the fixed public exponent 2^16 + 1.
/// </summary>
/// <param name="bitN">
/// Modulus size in bits. The first prime is given (bitN + 1) / 2 bits and
/// the second bitN / 2 bits.
/// </param>
/// <param name="generator">
/// Random source handed to Prime.NewPrime for every candidate prime.
/// </param>
/// <exception cref="ArgumentException">
/// The extended GCD of the two prime moduli did not come out as 1.
/// </exception>
public Key(int bitN, Random generator) {
    _modBitN = bitN;

    // The public exponent is 2^16 + 1
    _eBitN = 17;
    _eBytes = new byte[] { 1, 0, 1 };

    // Sizes of the two primes in bits, bytes and Digit words.
    int p1BitN = (bitN + 1) / 2;
    int p1ByteN = (p1BitN + 7) / 8;
    int p1DigitN = (p1BitN + (Digit.BitN - 1)) / Digit.BitN;
    int p2BitN = bitN / 2;
    int p2ByteN = (p2BitN + 7) / 8;
    int p2DigitN = (p2BitN + (Digit.BitN - 1)) / Digit.BitN;
    int longerDigitN = p1DigitN > p2DigitN ? p1DigitN : p2DigitN;

    // d1/d2 receive the per-prime private exponents from ExtendedGcd below.
    Digits d1 = new Digits(p1DigitN), d2 = new Digits(p2DigitN);
    _modDigits = new Digits(_modDigitN);
    _eDigits = new Digits(1);
    Digits gcd = new Digits(longerDigitN);
    Digits temp = new Digits(longerDigitN);
    Digits.BytesToDigits(_eBytes, 0, _eDigits, _eBitN);

    int[] pBitN = new int[] { p1BitN, p2BitN };
    int nPrimeFound = 0;
    Digits p1 = null, p2 = null;

    // Draw primes until two acceptable, distinct ones have been found.
    while (nPrimeFound != 2) {
        int pNowBitN = pBitN[nPrimeFound];
        int pNowDigitN = (pNowBitN + (Digit.BitN - 1)) / Digit.BitN;
        Digits pNow = Prime.NewPrime(pNowBitN, generator);
        if (nPrimeFound == 0) {
            p1 = pNow;
        }
        else {
            p2 = pNow;
        }

        // temp = pNow - 1
        Digits.Sub(pNow, 1, temp, pNowDigitN);

        // Extended GCD of the exponent and (pNow - 1). The result written
        // to d1/d2 is later used as the private exponent for this prime
        // (presumably the inverse of e modulo pNow - 1 -- confirm against
        // Digits.ExtendedGcd).
        int lgcd;
        Digits.ExtendedGcd(_eDigits
                         , _eDigitN
                         , temp
                         , pNowDigitN
                         , nPrimeFound == 0 ? d1 : d2
                         , null
                         , gcd
                         , out lgcd);

        // Reject the candidate unless gcd == 1.
        if (Digits.Compare(gcd, 1, lgcd) != 0) {
            Debug.Assert(false, "untested code");
            continue;
        }

        // Reject a second prime that equals the first.
        if (
            nPrimeFound == 1
         && Digits.Compare(p1, p1DigitN, p2, p2DigitN) == 0
        ) {
            Debug.Assert(false, "untested code");
            continue;
        }
        nPrimeFound++;
    }

    // modulus = p1 * p2
    Digits.Mul(p1, p1DigitN, p2, p2DigitN, _modDigits);
    int modBitN = Digits.SigBitN(_modDigits, _modDigitN);
    Debug.Assert(modBitN == p1BitN + p2BitN && modBitN == _modBitN
               , "internal error");

    // Keep byte-array copies of the modulus, the primes and the private
    // exponents.
    _primeBitN = new int[2] { p1BitN, p2BitN };
    _modBytes = new byte[_modByteN];
    _primeBytes
        = new byte[2][] { new byte[p1ByteN], new byte[p2ByteN] };
    _dBytes
        = new byte[2][] { new byte[p1ByteN], new byte[p2ByteN] };
    Digits.DigitsToBytes(_modDigits, _modBytes, 0, _modBitN);
    for (int ip = 0; ip != 2; ip++) {
        Digits.DigitsToBytes(
            ip == 0 ? p1 : p2, _primeBytes[ip], 0, pBitN[ip]);
        Digits.DigitsToBytes(
            ip == 0 ? d1 : d2, _dBytes[ip], 0, pBitN[ip]);
    }

    // NOTE(review): _eDigits was already filled from _eBytes above; this
    // second conversion looks redundant but is kept in case ExtendedGcd
    // modifies its operands -- confirm before removing.
    Digits.BytesToDigits(_eBytes, 0, _eDigits, _eBitN);

    // Build the modular-arithmetic contexts used by the block transforms.
    _modulus = new Modulus(_modDigits, _modDigitN, true);
    _privateModulus = new Modulus[2];
    _dDigits = new Digits[2];
    _chineseDigits = new Digits[2];
    for (int ip = 0; ip != 2; ip++) {
        // A fresh buffer per prime: the Modulus built from it is read back
        // through _mod below, so the buffers must not be shared.
        Digits temp2 = new Digits(_modDigitN);
        _dDigits[ip] = new Digits(p1DigitN);
        _chineseDigits[ip] = new Digits(p1DigitN);
        Digits.BytesToDigits(
            _primeBytes[ip], 0, temp2, _primeBitN[ip]);
        _privateModulus[ip]
            = new
              Modulus(temp2
                    , (_primeBitN[ip] + (Digit.BitN - 1)) / Digit.BitN
                    , true);
        Digits.BytesToDigits(
            _dBytes[ip], 0, _dDigits[ip], _primeBitN[ip]);
    }

    // Extended GCD of the two primes; its two cofactor outputs become the
    // per-prime coefficients that _DecryptBlock multiplies in.
    int lgcd2 = 0;
    Digits gcd2 = new Digits(_modDigitN);
    Digits.ExtendedGcd(_privateModulus[0]._mod
                     , p1DigitN
                     , _privateModulus[1]._mod
                     , p2DigitN
                     , _chineseDigits[1]
                     , _chineseDigits[0]
                     , gcd2
                     , out lgcd2);
    if (Digits.Compare(gcd2, 1, lgcd2) != 0) {
        throw new ArgumentException();
    }
}
|
|
/// <summary>
/// Read-only descriptor of a run of bytes inside an array: the array
/// itself, the index of the first byte, and the number of bytes.
/// </summary>
class Block
{
    readonly byte[] _buffer;
    readonly int _offset;
    readonly int _length;

    /// <summary>Describes <paramref name="n"/> bytes of
    /// <paramref name="bytes"/> starting at index <paramref name="i"/>.
    /// </summary>
    internal Block(byte[] bytes, int i, int n) {
        _buffer = bytes;
        _offset = i;
        _length = n;
    }

    /// <summary>The underlying array.</summary>
    internal byte[] _Bytes {
        get { return _buffer; }
    }

    /// <summary>Index of the first byte of the block.</summary>
    internal int _ByteI {
        get { return _offset; }
    }

    /// <summary>Block length in bytes.</summary>
    internal int _ByteN {
        get { return _length; }
    }

    /// <summary>Block length in bits (eight per byte).</summary>
    internal int _BitN {
        get { return 8 * _length; }
    }

    /// <summary>Number of Digit words needed to hold the block's bits,
    /// rounded up.</summary>
    internal int _DigitN {
        get {
            int roundedBitN = _BitN + (Digit.BitN - 1);
            return roundedBitN / Digit.BitN;
        }
    }
}
|
|
// Selects which per-block transform _DoBlocks applies.
enum Mode {
    // Use _EncryptBlock (public exponent).
    Encrypt,
    // Use _DecryptBlock (private exponents).
    Decrypt
}
|
|
/// <summary>
/// Applies the public-exponent transform to the first
/// <paramref name="inputByteN"/> bytes of <paramref name="inputBytes"/>.
/// </summary>
/// <remarks>
/// Input is consumed in blocks of _modByteN - 1 bytes and every block
/// produces _modByteN output bytes.
/// </remarks>
/// <param name="inputBytes">Data to transform.</param>
/// <param name="inputByteN">Number of bytes of input to process.</param>
/// <returns>The concatenated output blocks.</returns>
public byte[] _Encrypt(byte[] inputBytes, int inputByteN) {
    int inBlockByteN = _modByteN - 1;
    int outBlockByteN = _modByteN;
    byte[] result = _DoBlocks(
        inputBytes, inputByteN, inBlockByteN, outBlockByteN, Mode.Encrypt);
    return result;
}
|
|
/// <summary>
/// Applies the private-exponent transform to the first
/// <paramref name="inputByteN"/> bytes of <paramref name="inputBytes"/>.
/// </summary>
/// <remarks>
/// Input is consumed in blocks of _modByteN bytes and every block produces
/// _modByteN - 1 output bytes, so the result is always a whole number of
/// output blocks.
/// </remarks>
/// <param name="inputBytes">Data to transform.</param>
/// <param name="inputByteN">Number of bytes of input to process.</param>
/// <returns>The concatenated output blocks.</returns>
public byte[] _Decrypt(byte[] inputBytes, int inputByteN) {
    int inBlockByteN = _modByteN;
    int outBlockByteN = _modByteN - 1;
    byte[] result = _DoBlocks(
        inputBytes, inputByteN, inBlockByteN, outBlockByteN, Mode.Decrypt);
    return result;
}
|
|
/// <summary>
/// Applies the private-exponent transform (the same per-block operation as
/// _Decrypt) using the block sizes of _Encrypt.
/// </summary>
/// <remarks>
/// Input is consumed in blocks of _modByteN - 1 bytes and every block
/// produces _modByteN output bytes.
/// </remarks>
/// <param name="inputBytes">Data to transform.</param>
/// <param name="inputByteN">Number of bytes of input to process.</param>
/// <returns>The concatenated output blocks.</returns>
public byte[] _Sign(byte[] inputBytes, int inputByteN) {
    int inBlockByteN = _modByteN - 1;
    int outBlockByteN = _modByteN;
    byte[] result = _DoBlocks(
        inputBytes, inputByteN, inBlockByteN, outBlockByteN, Mode.Decrypt);
    return result;
}
|
|
/// <summary>
/// Applies the public-exponent transform (the same per-block operation as
/// _Encrypt) using the block sizes of _Decrypt.
/// </summary>
/// <remarks>
/// Input is consumed in blocks of _modByteN bytes and every block produces
/// _modByteN - 1 output bytes.
/// </remarks>
/// <param name="inputBytes">Data to transform.</param>
/// <param name="inputByteN">Number of bytes of input to process.</param>
/// <returns>The concatenated output blocks.</returns>
public byte[] _Check(byte[] inputBytes, int inputByteN) {
    int inBlockByteN = _modByteN;
    int outBlockByteN = _modByteN - 1;
    byte[] result = _DoBlocks(
        inputBytes, inputByteN, inBlockByteN, outBlockByteN, Mode.Encrypt);
    return result;
}
|
|
/// <summary>
/// Splits the input into fixed-size blocks, transforms each block with
/// _EncryptBlock or _DecryptBlock, and concatenates the results.
/// </summary>
/// <remarks>
/// All blocks but the last are read directly from
/// <paramref name="inputBytes"/>. The last block is always copied into a
/// fresh, zero-filled buffer of <paramref name="inputBlockByteN"/> bytes,
/// so a short final block is padded with trailing zero bytes.
/// </remarks>
/// <param name="inputBytes">Source data.</param>
/// <param name="inputByteN">Number of bytes of
/// <paramref name="inputBytes"/> to process.</param>
/// <param name="inputBlockByteN">Bytes consumed per block.</param>
/// <param name="outputBlockByteN">Bytes produced per block.</param>
/// <param name="mode">Which per-block transform to apply.</param>
/// <returns>
/// A new array of (number of blocks) * <paramref name="outputBlockByteN"/>
/// bytes; empty when <paramref name="inputByteN"/> is zero.
/// </returns>
/// <exception cref="ArgumentNullException">
/// <paramref name="inputBytes"/> is null while
/// <paramref name="inputByteN"/> is positive.
/// </exception>
/// <exception cref="ArgumentOutOfRangeException">
/// <paramref name="inputByteN"/> is negative or exceeds the length of
/// <paramref name="inputBytes"/>.
/// </exception>
byte[] _DoBlocks(
    byte[] inputBytes
  , int inputByteN
  , int inputBlockByteN
  , int outputBlockByteN
  , Mode mode
) {
    // Fail fast on bad arguments instead of part-way through processing.
    if (inputByteN < 0) {
        throw new ArgumentOutOfRangeException("inputByteN");
    }
    if (inputByteN > 0) {
        if (inputBytes == null) {
            throw new ArgumentNullException("inputBytes");
        }
        if (inputByteN > inputBytes.Length) {
            throw new ArgumentOutOfRangeException("inputByteN");
        }
    }

    // Number of blocks, rounding a partial final block up to a whole one.
    int blockN = (inputByteN + inputBlockByteN - 1)
               / inputBlockByteN;
    byte[] outputBytes = new byte[blockN * outputBlockByteN];

    // Every block except the last is transformed in place from the input.
    int i;
    for (i = 0; i < blockN - 1; i++) {
        Block inputBlock
            = new Block(
                inputBytes, i * inputBlockByteN, inputBlockByteN);
        Block outputBlock
            = new Block(
                outputBytes, i * outputBlockByteN, outputBlockByteN);
        _TransformBlock(inputBlock, outputBlock, mode);
    }

    // The last block goes through a zero-filled scratch buffer so that a
    // short tail is padded to the full block size.
    int tailByteN = inputByteN - i * inputBlockByteN;
    if (tailByteN > 0) {
        byte[] padded = new byte[inputBlockByteN];
        Array.Copy(inputBytes, i * inputBlockByteN, padded, 0, tailByteN);
        Block inputBlock = new Block(padded, 0, inputBlockByteN);
        Block outputBlock
            = new Block(
                outputBytes, i * outputBlockByteN, outputBlockByteN);
        _TransformBlock(inputBlock, outputBlock, mode);
    }
    return outputBytes;
}

// Applies the per-block transform selected by mode to a single block.
void _TransformBlock(Block input, Block output, Mode mode) {
    if (mode == Mode.Encrypt) {
        _EncryptBlock(input, output);
    }
    else {
        _DecryptBlock(input, output);
    }
}
|
|
/// <summary>
/// Transforms one block with the public exponent: reads the input block as
/// a number, raises it to _eDigits under _modulus, and writes the result
/// into the output block.
/// </summary>
/// <param name="input">Block to read.</param>
/// <param name="output">Block to write.</param>
/// <exception cref="ArgumentNullException">Either block is null.</exception>
void _EncryptBlock(Block input, Block output) {
    if (input == null) {
        throw new ArgumentNullException();
    }
    if (output == null) {
        throw new ArgumentNullException();
    }

    // Load the block bytes into a modulus-sized working value.
    Digits value = new Digits(_modDigitN);
    Digits.BytesToDigits(input._Bytes, input._ByteI, value, input._BitN);
    Modular.ValidateData(value, _modulus._mod, _modulus._digitN);

    // Convert to modular form, exponentiate in place, and convert back.
    _modulus._ToModular(value, input._DigitN, value);
    _modulus._Exp(value, _eDigits, _eDigitN, value);
    _modulus._FromModular(value, value);

    // Store the result in the output block.
    Digits.DigitsToBytes(
        value, output._Bytes, output._ByteI, output._BitN);
}
|
|
/// <summary>
/// Transforms one block with the private key: the message is exponentiated
/// separately under each prime modulus, each result is multiplied by its
/// stored coefficient and by the other prime, and the two products are
/// added modulo the full modulus.
/// </summary>
/// <param name="input">Block to read.</param>
/// <param name="output">Block to write.</param>
/// <exception cref="ArgumentNullException">Either block is null.</exception>
/// <exception cref="ArgumentException">
/// The two prime moduli together exceed the full modulus by more than one
/// Digit word.
/// </exception>
void _DecryptBlock(Block input, Block output) {
    if (input == null) {
        throw new ArgumentNullException();
    }
    if (output == null) {
        throw new ArgumentNullException();
    }

    int firstDigitN = _privateModulus[0]._digitN;
    int secondDigitN = _privateModulus[1]._digitN;

    // The product buffers below have _modDigitN + 1 words, so the primes
    // may overshoot the modulus by at most one word.
    int excessDigitN = firstDigitN + secondDigitN - _modDigitN;
    if (excessDigitN > 1) {
        throw new ArgumentException();
    }

    // terms[1] first holds the message and later the second product;
    // terms[0] receives the first product.
    Digits[] terms = new Digits[2];
    terms[0] = new Digits(_modDigitN + 1);
    terms[1] = new Digits(_modDigitN + 1);

    int longerDigitN;
    if (firstDigitN > secondDigitN) {
        longerDigitN = firstDigitN;
    }
    else {
        longerDigitN = secondDigitN;
    }
    Digits partial = new Digits(longerDigitN);

    Digits.BytesToDigits(
        input._Bytes, input._ByteI, terms[1], input._BitN);
    Modular.ValidateData(terms[1], _modDigits, _modDigitN);

    for (int ip = 0; ip != 2; ip++) {
        Modulus own = _privateModulus[ip];
        Modulus other = _privateModulus[1 - ip];

        // The message is read from terms[1] before anything is written to
        // terms[ip], so the second pass can safely reuse terms[1].
        own._ToModular(terms[1], _modDigitN, partial);
        own._Exp(partial, _dDigits[ip], own._digitN, partial);
        own._Mul(partial, _chineseDigits[ip], partial);

        // terms[ip] = partial * (the other prime)
        Digits.Mul(partial
                 , own._digitN
                 , other._mod
                 , other._digitN
                 , terms[ip]);
    }

    // Combine the two products modulo the full modulus and emit the bytes.
    Modular.Add(terms[0], terms[1], terms[0], _modDigits, _modDigitN);
    Digits.DigitsToBytes(
        terms[0], output._Bytes, output._ByteI, output._BitN);
}
|
|
}
|
|
}
|
|
}
|