// singrdk/base/Libraries/Crypto/Rsa/Rsa.cs
//
// 283 lines
// 13 KiB
// C#

// ----------------------------------------------------------------------------
//
// Copyright (c) Microsoft Corporation. All rights reserved.
//
// ----------------------------------------------------------------------------
using System;
using System.Diagnostics;
namespace Microsoft.Singularity.Crypto.PublicKey
{
public class Rsa {
    /// <summary>
    /// An RSA key pair with public exponent 2^16 + 1, generated from two
    /// random primes. Offers raw ("textbook") RSA block operations:
    /// <c>_Encrypt</c> and <c>_Check</c> apply the public exponent, and
    /// <c>_Decrypt</c> and <c>_Sign</c> apply the private key through
    /// Chinese-Remainder recombination.
    /// </summary>
    /// <remarks>
    /// NOTE(review): no padding scheme (OAEP, PKCS#1 v1.5, PSS, ...) is
    /// applied. The only processing is zero-filling the final block.
    /// Textbook RSA is malleable and deterministic, so callers need their
    /// own padding before using this for real confidentiality or signatures.
    /// </remarks>
    public class Key {
        // Public exponent e: bit length, byte form and Digits form.
        int _eBitN;
        byte[] _eBytes;
        Digits _eDigits;
        int _eDigitN {
            get { return _BitsToDigitN(_eBitN); }
        }

        // Public modulus n = p1 * p2.
        int _modBitN;
        byte[] _modBytes;
        int _modByteN { get { return _BitsToByteN(_modBitN); } }
        Digits _modDigits;
        int _modDigitN {
            get { return _BitsToDigitN(_modBitN); }
        }

        // Per-prime private data. Index 0 refers to p1, index 1 to p2.
        int[] _primeBitN;
        byte[][] _primeBytes, _dBytes;
        Modulus _modulus;
        Modulus[] _privateModulus;
        Digits[] _dDigits, _chineseDigits;

        // Number of Digit words needed to hold bitN bits.
        static int _BitsToDigitN(int bitN) {
            return (bitN + (Digit.BitN - 1)) / Digit.BitN;
        }

        // Number of bytes needed to hold bitN bits.
        static int _BitsToByteN(int bitN) {
            return (bitN + 7) / 8;
        }

        /// <summary>
        /// Generates a new key pair whose modulus is intended to have exactly
        /// <paramref name="bitN"/> significant bits (checked by a Debug.Assert).
        /// </summary>
        /// <param name="bitN">
        /// Modulus size in bits. It is split into primes of (bitN + 1) / 2
        /// and bitN / 2 bits.
        /// </param>
        /// <param name="generator">
        /// Randomness source handed to Prime.NewPrime.
        /// NOTE(review): if this is System.Random it is not a cryptographically
        /// secure generator and must not be used to create real keys.
        /// </param>
        /// <exception cref="ArgumentException">
        /// The gcd computed for the two primes is not 1.
        /// </exception>
        public Key(int bitN, Random generator) {
            _modBitN = bitN;
            // The public exponent is 2^16 + 1.
            _eBitN = 17;
            _eBytes = new byte[] { 1, 0, 1 };

            int p1BitN = (bitN + 1) / 2;
            int p1ByteN = _BitsToByteN(p1BitN);
            int p1DigitN = _BitsToDigitN(p1BitN);
            int p2BitN = bitN / 2;
            int p2ByteN = _BitsToByteN(p2BitN);
            int p2DigitN = _BitsToDigitN(p2BitN);
            int longerDigitN = p1DigitN > p2DigitN ? p1DigitN : p2DigitN;

            Digits d1 = new Digits(p1DigitN);
            Digits d2 = new Digits(p2DigitN);
            _modDigits = new Digits(_modDigitN);
            _eDigits = new Digits(1);
            Digits gcd = new Digits(longerDigitN);
            Digits temp = new Digits(longerDigitN);
            Digits.BytesToDigits(_eBytes, 0, _eDigits, _eBitN);

            // Draw primes until each one passes two checks:
            // - the gcd computed for (e, p - 1) is 1;
            // - the second prime differs from the first.
            // A rejected candidate is simply redrawn.
            int[] pBitN = new int[] { p1BitN, p2BitN };
            int nPrimeFound = 0;
            Digits p1 = null, p2 = null;
            while (nPrimeFound != 2) {
                int pNowBitN = pBitN[nPrimeFound];
                int pNowDigitN = _BitsToDigitN(pNowBitN);
                Digits pNow = Prime.NewPrime(pNowBitN, generator);
                if (nPrimeFound == 0) {
                    p1 = pNow;
                }
                else {
                    p2 = pNow;
                }

                // temp = pNow - 1.
                // ExtendedGcd writes the gcd into `gcd` and, presumably,
                // e^-1 mod (pNow - 1) into d1 or d2 (the per-prime private
                // exponent). NOTE(review): confirm the ExtendedGcd output
                // convention.
                Digits.Sub(pNow, 1, temp, pNowDigitN);
                int lgcd;
                Digits.ExtendedGcd(_eDigits
                                   , _eDigitN
                                   , temp
                                   , pNowDigitN
                                   , nPrimeFound == 0 ? d1 : d2
                                   , null
                                   , gcd
                                   , out lgcd);
                if (Digits.Compare(gcd, 1, lgcd) != 0) {
                    Debug.Assert(false, "untested code");
                    continue;
                }
                if (nPrimeFound == 1
                    && Digits.Compare(p1, p1DigitN, p2, p2DigitN) == 0) {
                    Debug.Assert(false, "untested code");
                    continue;
                }
                nPrimeFound++;
            }

            // n = p1 * p2.
            Digits.Mul(p1, p1DigitN, p2, p2DigitN, _modDigits);
            int modBitN = Digits.SigBitN(_modDigits, _modDigitN);
            Debug.Assert(modBitN == p1BitN + p2BitN && modBitN == _modBitN
                         , "internal error");

            // Serialize n, the primes and the per-prime exponents to bytes.
            _primeBitN = new int[2] { p1BitN, p2BitN };
            _modBytes = new byte[_modByteN];
            _primeBytes
                = new byte[2][] { new byte[p1ByteN], new byte[p2ByteN] };
            _dBytes
                = new byte[2][] { new byte[p1ByteN], new byte[p2ByteN] };
            Digits.DigitsToBytes(_modDigits, _modBytes, 0, _modBitN);
            for (int ip = 0; ip != 2; ip++) {
                Digits.DigitsToBytes(
                    ip == 0 ? p1 : p2, _primeBytes[ip], 0, pBitN[ip]);
                Digits.DigitsToBytes(
                    ip == 0 ? d1 : d2, _dBytes[ip], 0, pBitN[ip]);
            }

            // Build the Digits/Modulus forms that the block operations use.
            // NOTE(review): the next call repeats the conversion done before
            // the prime loop. It is redundant unless ExtendedGcd can modify
            // its first argument; kept defensively.
            Digits.BytesToDigits(_eBytes, 0, _eDigits, _eBitN);
            _modulus = new Modulus(_modDigits, _modDigitN, true);
            _privateModulus = new Modulus[2];
            _dDigits = new Digits[2];
            _chineseDigits = new Digits[2];
            for (int ip = 0; ip != 2; ip++) {
                // Allocate a fresh buffer for each prime. The Modulus built
                // from it may retain it; its value is read back via ._mod below.
                Digits primeDigits = new Digits(_modDigitN);
                // Both arrays are sized for the longer prime
                // (p1DigitN >= p2DigitN).
                _dDigits[ip] = new Digits(p1DigitN);
                _chineseDigits[ip] = new Digits(p1DigitN);
                Digits.BytesToDigits(
                    _primeBytes[ip], 0, primeDigits, _primeBitN[ip]);
                _privateModulus[ip]
                    = new Modulus(primeDigits
                                  , _BitsToDigitN(_primeBitN[ip])
                                  , true);
                Digits.BytesToDigits(
                    _dBytes[ip], 0, _dDigits[ip], _primeBitN[ip]);
            }

            // CRT coefficients. Following the same ExtendedGcd convention as
            // above, this presumably writes p1^-1 mod p2 into
            // _chineseDigits[1] and p2^-1 mod p1 into _chineseDigits[0].
            // The gcd of the two primes must be 1.
            int lgcd2;
            Digits gcd2 = new Digits(_modDigitN);
            Digits.ExtendedGcd(_privateModulus[0]._mod
                               , p1DigitN
                               , _privateModulus[1]._mod
                               , p2DigitN
                               , _chineseDigits[1]
                               , _chineseDigits[0]
                               , gcd2
                               , out lgcd2);
            if (Digits.Compare(gcd2, 1, lgcd2) != 0) {
                throw new ArgumentException();
            }
        }

        // A read-only window of _n bytes starting at index _i of _bytes.
        class Block {
            readonly byte[] _bytes;
            readonly int _i;
            readonly int _n;

            internal Block(byte[] bytes, int i, int n) {
                _bytes = bytes;
                _i = i;
                _n = n;
            }

            internal byte[] _Bytes { get { return _bytes; } }
            internal int _ByteI { get { return _i; } }
            internal int _ByteN { get { return _n; } }
            internal int _BitN { get { return 8 * _n; } }
            internal int _DigitN {
                get { return _BitsToDigitN(_BitN); }
            }
        }

        // Encrypt selects the public-exponent operation; Decrypt selects
        // the private-key (CRT) operation.
        enum Mode { Encrypt, Decrypt }

        /// <summary>
        /// Applies the public-key operation to
        /// <paramref name="inputByteN"/> bytes.
        /// </summary>
        /// <remarks>
        /// Input is split into blocks of (modulus bytes - 1) bytes. That is
        /// fewer bits than the modulus, so every block value is below n.
        /// Each block produces (modulus bytes) output bytes. The last block
        /// is zero-filled; no padding scheme is applied.
        /// </remarks>
        /// <exception cref="ArgumentNullException">inputBytes is null.</exception>
        /// <exception cref="ArgumentOutOfRangeException">
        /// inputByteN is negative or larger than inputBytes.Length.
        /// </exception>
        public byte[] _Encrypt(byte[] inputBytes, int inputByteN) {
            return _DoBlocks(inputBytes
                             , inputByteN
                             , _modByteN - 1
                             , _modByteN
                             , Mode.Encrypt);
        }

        /// <summary>
        /// Applies the private-key operation.
        /// </summary>
        /// <remarks>
        /// This is the inverse of <c>_Encrypt</c>. It reads blocks of
        /// (modulus bytes) bytes and writes (modulus bytes - 1) bytes per
        /// block. Any zero fill added to the final plaintext block is not
        /// removed, so callers must track the original length.
        /// </remarks>
        /// <exception cref="ArgumentNullException">inputBytes is null.</exception>
        /// <exception cref="ArgumentOutOfRangeException">
        /// inputByteN is negative or larger than inputBytes.Length.
        /// </exception>
        public byte[] _Decrypt(byte[] inputBytes, int inputByteN) {
            return _DoBlocks(inputBytes
                             , inputByteN
                             , _modByteN
                             , _modByteN - 1
                             , Mode.Decrypt);
        }

        /// <summary>
        /// Raw signature: applies the private-key operation to blocks of
        /// (modulus bytes - 1) bytes, producing (modulus bytes) bytes per block.
        /// </summary>
        /// <exception cref="ArgumentNullException">inputBytes is null.</exception>
        /// <exception cref="ArgumentOutOfRangeException">
        /// inputByteN is negative or larger than inputBytes.Length.
        /// </exception>
        public byte[] _Sign(byte[] inputBytes, int inputByteN) {
            return _DoBlocks(inputBytes
                             , inputByteN
                             , _modByteN - 1
                             , _modByteN
                             , Mode.Decrypt);
        }

        /// <summary>
        /// Raw signature check: applies the public-key operation to blocks of
        /// (modulus bytes) bytes, producing (modulus bytes - 1) bytes per
        /// block. The caller compares the result against the expected data.
        /// </summary>
        /// <exception cref="ArgumentNullException">inputBytes is null.</exception>
        /// <exception cref="ArgumentOutOfRangeException">
        /// inputByteN is negative or larger than inputBytes.Length.
        /// </exception>
        public byte[] _Check(byte[] inputBytes, int inputByteN) {
            return _DoBlocks(inputBytes
                             , inputByteN
                             , _modByteN
                             , _modByteN - 1
                             , Mode.Encrypt);
        }

        // Splits the first inputByteN bytes of inputBytes into blocks of
        // inputBlockByteN bytes. Each block is transformed with `mode`, and
        // the results are concatenated in fixed outputBlockByteN slots.
        // An empty input yields an empty array.
        byte[] _DoBlocks(
            byte[] inputBytes
            , int inputByteN
            , int inputBlockByteN
            , int outputBlockByteN
            , Mode mode
            ) {
            if (inputBytes == null) {
                throw new ArgumentNullException("inputBytes");
            }
            if (inputByteN < 0 || inputByteN > inputBytes.Length) {
                throw new ArgumentOutOfRangeException("inputByteN");
            }

            int blockN = (inputByteN + inputBlockByteN - 1) / inputBlockByteN;
            byte[] outputBytes = new byte[blockN * outputBlockByteN];
            if (blockN == 0) {
                return outputBytes;
            }

            // Every block except the last is read in place from inputBytes.
            int lastBlockI = blockN - 1;
            for (int blockI = 0; blockI < lastBlockI; blockI++) {
                _DoBlock(new Block(inputBytes
                                   , blockI * inputBlockByteN
                                   , inputBlockByteN)
                         , new Block(outputBytes
                                     , blockI * outputBlockByteN
                                     , outputBlockByteN)
                         , mode);
            }

            // The last block may be partial. Copy it to the start of a
            // zero-filled buffer of full block size.
            // NOTE(review): whether the zero fill ends up in the high- or
            // low-order bytes depends on Digits.BytesToDigits' byte order.
            int lastByteI = lastBlockI * inputBlockByteN;
            byte[] lastBlockBytes = new byte[inputBlockByteN];
            Array.Copy(inputBytes, lastByteI
                       , lastBlockBytes, 0
                       , inputByteN - lastByteI);
            _DoBlock(new Block(lastBlockBytes, 0, inputBlockByteN)
                     , new Block(outputBytes
                                 , lastBlockI * outputBlockByteN
                                 , outputBlockByteN)
                     , mode);
            return outputBytes;
        }

        // Dispatches one block to the public (Encrypt) or private (Decrypt)
        // operation.
        void _DoBlock(Block input, Block output, Mode mode) {
            if (mode == Mode.Encrypt) {
                _EncryptBlock(input, output);
            }
            else {
                _DecryptBlock(input, output);
            }
        }

        // output = input^e mod n.
        // The steps are:
        // - convert the input bytes to Digits;
        // - validate the value against the modulus (presumably rejecting
        //   values not below n);
        // - convert into the Modulus' working representation;
        // - exponentiate;
        // - convert back and write output._ByteN bytes.
        void _EncryptBlock(Block input, Block output) {
            if (input == null || output == null) {
                throw new ArgumentNullException();
            }
            Digits digits = new Digits(_modDigitN);
            Digits.BytesToDigits(
                input._Bytes, input._ByteI, digits, input._BitN);
            Modular.ValidateData(digits, _modulus._mod, _modulus._digitN);
            _modulus._ToModular(digits, input._DigitN, digits);
            _modulus._Exp(digits, _eDigits, _eDigitN, digits);
            _modulus._FromModular(digits, digits);
            Digits.DigitsToBytes(
                digits, output._Bytes, output._ByteI, output._BitN);
        }

        // Private-key operation via CRT recombination:
        //   term[ip] = ((c mod p_ip)^d_ip * chinese[ip] mod p_ip) * p_other
        //   m        = (term[0] + term[1]) mod n
        void _DecryptBlock(Block input, Block output) {
            if (input == null || output == null) {
                throw new ArgumentNullException();
            }
            // Each term is the product of a p_ip-sized value and the other
            // prime. It must fit in the _modDigitN + 1 digit buffers below.
            if (
                _privateModulus[0]._digitN
                + _privateModulus[1]._digitN
                - _modDigitN
                > 1
                ) {
                throw new ArgumentException();
            }
            Digits[] dmsg = new Digits[] {
                new Digits(_modDigitN + 1)
                , new Digits(_modDigitN + 1)
            };
            int longerDigitN
                = _privateModulus[0]._digitN > _privateModulus[1]._digitN
                ? _privateModulus[0]._digitN
                : _privateModulus[1]._digitN;
            Digits res = new Digits(longerDigitN);

            // dmsg[1] holds the input until the ip == 1 iteration. That
            // iteration reads it in _ToModular before overwriting it with
            // its own term, so the loop order matters.
            Digits.BytesToDigits(
                input._Bytes, input._ByteI, dmsg[1], input._BitN);
            Modular.ValidateData(dmsg[1], _modDigits, _modDigitN);
            for (int ip = 0; ip != 2; ip++) {
                _privateModulus[ip]._ToModular(dmsg[1], _modDigitN, res);
                _privateModulus[ip]
                    ._Exp(res, _dDigits[ip], _privateModulus[ip]._digitN, res);
                _privateModulus[ip]._Mul(res, _chineseDigits[ip], res);
                Digits.Mul(res
                           , _privateModulus[ip]._digitN
                           , _privateModulus[1 - ip]._mod
                           , _privateModulus[1 - ip]._digitN
                           , dmsg[ip]);
            }
            Modular.Add(dmsg[0], dmsg[1], dmsg[0], _modDigits, _modDigitN);
            Digits.DigitsToBytes(
                dmsg[0], output._Bytes, output._ByteI, output._BitN);
        }
    }
}
}