///////////////////////////////////////////////////////////////////////////////
//
// Microsoft Research Singularity
//
// Copyright (c) Microsoft Corporation. All rights reserved.
//
///////////////////////////////////////////////////////////////////////////////

using System;
using System.Diagnostics;
namespace Microsoft.Singularity.Crypto.PublicKey {
    /// <summary>
    /// A modulus together with the precomputed data needed for modular
    /// multiplication, exponentiation, and conversion between plain integers
    /// and this class's internal "modular" representation.
    /// </summary>
    /// <remarks>
    /// The reduction strategy is fixed at construction time:
    /// <list type="bullet">
    /// <item><description>From the left (<c>fromRight == false</c>): products
    /// are formed in full and reduced by long division
    /// (<see cref="_MulFromLeft"/>). Values are stored unscaled.
    /// </description></item>
    /// <item><description>From the right (<c>fromRight == true</c>): products
    /// are reduced digit by digit from the low end
    /// (<see cref="_MulFromRight"/>). Values are stored multiplied by
    /// 2^(Digit.BitN * digitN) modulo the modulus, and the modulus must be
    /// odd.</description></item>
    /// </list>
    /// Use <see cref="_ToModular"/> and <see cref="_FromModular"/> to convert
    /// into and out of the internal representation.
    /// </remarks>
    class Modulus {
        // Length of the modulus in digits; the top digit is nonzero.
        internal readonly int _digitN;

        // Number of bits by which values are scaled in modular form:
        // 0 when reducing from the left, Digit.BitN * _digitN from the right.
        readonly int _scaleBitN;

        // Reduction direction chosen at construction.
        readonly bool _fromRight;

        // Division precondition data for _mod (filled in by
        // Digits.DivPrecondition); used by Digits.Div and by the quotient
        // estimate in _Shift.
        readonly Reciprocal _leftRecip;

        // Inverse of _mod[0] modulo 2^Digit.BitN, or 0 when _mod is even.
        readonly Digit _rightRecip;

        // The modulus digits (the caller's array is referenced, not copied).
        internal readonly Digits _mod;

        // NOTE(review): _multiplier1/_multiplier2 are only filled in for
        // right-to-left reduction and are never read in this file; confirm
        // whether other code depends on them.
        readonly Digits _multiplier1;
        readonly Digits _multiplier2;

        // The value 1 in modular form.
        internal readonly Digits _one;

        // Multiplication routine selected by the constructor.
        readonly MulAlgorithm _algorithm;

        // Exponent bit-length thresholds used by _Exp to pick its bucket
        // width: width w + 1 is used once the exponent has more than
        // _widthCutoffs[w - 1] significant bits. The width is capped at 5,
        // so the last entry is never consulted.
        static readonly ushort[] _widthCutoffs = new ushort[] { 6, 24, 80, 240, 672 };

        /// <summary>
        /// Precomputes reduction data for the modulus <paramref name="mod"/>.
        /// </summary>
        /// <param name="mod">Modulus digits, least significant digit at index
        /// 0. The array is retained, not copied.</param>
        /// <param name="digitN">Number of digits in the modulus; must be
        /// positive and <c>mod[digitN - 1]</c> must be nonzero.</param>
        /// <param name="fromRight">true to reduce from the right (requires an
        /// odd modulus); false to reduce by division from the left.</param>
        /// <exception cref="ArgumentException">digitN is not positive, the
        /// leading digit is zero, or right-to-left reduction is requested for
        /// an even modulus.</exception>
        internal Modulus(Digits mod, int digitN, bool fromRight) {
            if (digitN <= 0) {
                throw new ArgumentException(
                    "modulus length must be positive", "digitN");
            }
            if (mod[digitN - 1] == 0) {
                throw new ArgumentException(
                    "leading digit of the modulus must be nonzero", "mod");
            }

            _mod = mod;
            _digitN = digitN;
            _fromRight = fromRight;
            _one = new Digits(digitN);
            _multiplier1 = new Digits(digitN);
            _multiplier2 = new Digits(digitN);

            _leftRecip = new Reciprocal();
            Digits.DivPrecondition(mod, digitN, _leftRecip);

            // Only odd moduli have an inverse modulo 2^Digit.BitN.
            Digit mod0inv = 0;
            if ((mod[0] & 1) != 0) {
                mod0inv = Digit.TwoAdicInverse(mod[0]);
            }
            _rightRecip = mod0inv;

            // Right-to-left reduction divides by powers of two modulo mod,
            // which is impossible for an even modulus; reject it before doing
            // any more work.
            if (fromRight && mod0inv == 0) {
                throw new ArgumentException(
                    "right-to-left reduction requires an odd modulus", "mod");
            }

            int digitN2 = (digitN + 1) / 2;
            Digits temp = new Digits(digitN + digitN2);

            if (!fromRight) {
                _algorithm = new MulAlgorithm(_MulFromLeft);
                _scaleBitN = 0;

                // Dividend: all ones over digitN + digitN2 digits, with the
                // top digit trimmed by the normalization shift.
                int dividendN = digitN + digitN2;
                for (int i = 0; i != dividendN; i++) {
                    temp[i] = Digit.MaxValue;
                }
                temp[dividendN - 1] = Digit.MaxValue >> _leftRecip._shiftBitN;

                Digits q = new Digits(digitN2 + 1);
                Digits r = new Digits(digitN);
                Digits.Div(temp, dividendN, mod, digitN, _leftRecip, q, r);
                Debug.Assert(q[digitN2] == 1, "internal error");
                Digits.Add(r, 1U, r, digitN);
                Digits.Sub(mod, r, r, digitN);
                // NOTE(review): q and r are not stored anywhere, so apart
                // from the debug assertion this work has no lasting effect;
                // confirm whether it was meant to initialize _multiplier1.
            } else {
                _algorithm = new MulAlgorithm(_MulFromRight);
                _scaleBitN = Digit.BitN * digitN;

                // temp = mod * mod0inv, whose lowest digit is 1. Each
                // following step adds mul * mod (shifted i digits) so that
                // digit i becomes zero (see the asserts). The multipliers are
                // recorded in _multiplier2, and the upper digits left in temp
                // are copied to _multiplier1.
                _multiplier2[0] = mod0inv;
                temp[digitN] = Digits.Mul(mod, mod0inv, temp, digitN);
                Debug.Assert(temp[0] == 1, "internal error");
                for (int i = 1; i != digitN2; i++) {
                    Digit mul = unchecked(0 - mod0inv * temp[i]);
                    _multiplier2[i] = mul;
                    temp[i + digitN]
                        = Digits.Accumulate(mod, mul, temp + i, digitN);
                    Debug.Assert(temp[i] == 0, "internal error");
                }
                _multiplier1._Set(temp + digitN2, digitN);
            }

            _ToModular(new Digits(new Digit[] { 1 }), 1, _one);
        }

        // basepower := basepower^2 in modular form, in place.
        void _BasePowerSquaring(Digits basepower) {
            _Mul(basepower, basepower, basepower);
        }

        // Scratch state for _Exp. _bucket[i] points at the storage for
        // bucket i (1..maxBucket; maxBucket <= 31 because the bucket width
        // is capped at 5). _bucketBusy[i] records whether bucket i has
        // received a factor yet.
        class Temps2000 {
            internal Modulus _modulus;
            internal Digits[] _bucket = new Digits[1L << 5];
            internal bool[] _bucketBusy = new bool[1L << 5];
        }

        // Multiplies bucket ibucket by mult. The first factor a bucket
        // receives is copied in rather than multiplied, so the bucket's
        // initial contents are never used in a product.
        void _BucketMul(UInt32 ibucket, Digits mult, Temps2000 temps) {
            Digits bloc = temps._bucket[ibucket];
            if (temps._bucketBusy[ibucket]) {
                _Mul(bloc, mult, bloc);
            } else {
                temps._bucketBusy[ibucket] = true;
                bloc._Set(mult, _digitN);
            }
        }

        /// <summary>
        /// Computes <paramref name="answer"/> = <paramref name="base"/> raised
        /// to the power <paramref name="exp"/>, in modular form.
        /// </summary>
        /// <param name="base">Base, already in modular form (validated against
        /// the modulus).</param>
        /// <param name="exp">Exponent as a plain digit array.</param>
        /// <param name="expDigitN">Number of digits of exp to read.</param>
        /// <param name="answer">Receives the result. In the general path it is
        /// also used as scratch for the running power of the base.
        /// NOTE(review): that path overwrites answer before all of exp has
        /// been read, so answer presumably must not alias exp; confirm with
        /// callers.</param>
        /// <remarks>
        /// When the base equals 2 (in modular form) and the exponent is
        /// nonzero, the power is built left to right from a single shift
        /// followed by modular squarings and doublings. Otherwise a
        /// right-to-left bucket method is used: the exponent is scanned from
        /// its least significant bit, the running power is squared once per
        /// bit, the running power is multiplied into buckets indexed by the
        /// value of small bit windows, and the buckets are combined at the
        /// end. A zero exponent yields 1.
        /// </remarks>
        internal void _Exp(Digits @base, Digits exp, int expDigitN, Digits answer) {
            int expBitNUsed = Digits.SigBitN(exp, expDigitN);
            Digits basepower = answer;

            // Larger exponents amortize more buckets.
            int bucketWidth = 1;
            while (
                bucketWidth < 5 && _widthCutoffs[bucketWidth - 1] < expBitNUsed
            ) {
                bucketWidth++;
            }

            Modular.ValidateData(@base, _mod, _digitN);

            UInt32 bucketMask = (1U << bucketWidth) - 1;
            UInt32 maxBucket = bucketMask;
            Digits bucketData = new Digits((int)(_digitN * maxBucket));
            Temps2000 temps = new Temps2000();
            temps._modulus = this;
            temps._bucket[0] = null;

            // The first _digitN digits of bucketData := 2 in modular form.
            Modular.Add(_one, _one, bucketData, _mod, _digitN);
            bool base2 = Digits.Compare(@base, bucketData, _digitN) == 0;

            if (base2 && expBitNUsed != 0) {
                // Gather the leading exponent bits into one shift amount,
                // bounded so the shift stays at most 1024 bits (and at most
                // the modulus width).
                int shiftMax
                    = Digit.BitN * _digitN > 1024 ? 1024 : Digit.BitN * _digitN;
                int highExponBitN = 0;
                bool highBitNProcessed = false;
                Digits temp = bucketData;
                for (int i = expBitNUsed; i-- != 0; ) {
                    Digit expBit = Digits.GetBit(exp, i);
                    if (highBitNProcessed) {
                        // Left-to-right: square, then double if the bit is set.
                        _Mul(temp, temp, temp);
                        if (expBit != 0) {
                            Modular.Add(temp, temp, temp, _mod, _digitN);
                        }
                    } else {
                        highExponBitN = (int)(2 * highExponBitN + expBit);
                        if (i == 0 || 2 * highExponBitN >= shiftMax) {
                            highBitNProcessed = true;
                            _Shift(_one, highExponBitN, temp);
                        }
                    }
                }
                temps._bucket[1] = temp;
                Debug.Assert(highBitNProcessed, "internal error");
            } else {
                UInt32 ibucket;

                // Carve bucketData into maxBucket slots, odd buckets first,
                // each starting at the value 1.
                for (ibucket = 1; ibucket <= maxBucket; ibucket++) {
                    Digits bloc = bucketData
                        + (int)(_digitN
                                * (ibucket - 1
                                   + ((ibucket & 1) == 0 ? maxBucket : 0))
                                / 2);
                    temps._bucket[ibucket] = bloc;
                    temps._bucketBusy[ibucket] = false;
                    bloc._Set(_one, _digitN);
                }

                basepower._Set(@base, _digitN);

                // carried holds the exponent bits not yet charged to a
                // bucket, relative to the current power of the base;
                // ndoubling counts how many bits it spans.
                Digit carried = 0;
                int ndoubling = 0;
                for (int i = 0; i != expBitNUsed; i++) {
                    Digit bitNow = Digits.GetBit(exp, i);
                    Debug.Assert((carried >> (bucketWidth + 2)) == 0
                        , "internal error");
                    if (bitNow != 0) {
                        // Flush low bits until the new bit fits in the window.
                        while (ndoubling >= bucketWidth + 1) {
                            if ((carried & 1) != 0) {
                                ibucket = carried & bucketMask;
                                carried -= ibucket;
                                temps._modulus
                                    ._BucketMul(ibucket, basepower, temps);
                            }
                            temps._modulus._BasePowerSquaring(basepower);
                            carried /= 2;
                            ndoubling--;
                        }
                        carried |= 1U << ndoubling;
                    }
                    ndoubling++;
                }

                // Drain whatever is still carried.
                while (carried != 0) {
                    bool squareNow = false;
                    if (carried <= maxBucket) {
                        ibucket = carried;
                    } else if ((carried & 1) == 0) {
                        squareNow = true;
                    } else if (carried <= 3 * maxBucket) {
                        ibucket = maxBucket;
                    } else {
                        Debug.Assert(false, "untested code");
                        ibucket = carried & bucketMask;
                    }

                    if (squareNow) {
                        carried /= 2;
                        temps._modulus._BasePowerSquaring(basepower);
                    } else {
                        carried -= ibucket;
                        temps._modulus._BucketMul(ibucket, basepower, temps);
                    }
                }

                // Fold every used bucket i >= 2 into two smaller buckets j
                // and i - j. Because j + (i - j) = i, bucket i's product ends
                // up raised to the i-th power once everything reaches
                // bucket 1. Prefer a split where both halves are already in
                // use.
                for (ibucket = maxBucket; ibucket >= 2; ibucket--) {
                    if (temps._bucketBusy[ibucket]) {
                        bool found = false;
                        UInt32 jbucket, jbucketMax, kbucket;
                        if ((ibucket & 1) == 0) {
                            jbucketMax = ibucket / 2;
                        } else {
                            jbucketMax = 1;
                        }
                        for (jbucket = ibucket >> 1;
                             jbucket != ibucket && !found;
                             jbucket++) {
                            if (temps._bucketBusy[jbucket]) {
                                jbucketMax = jbucket;
                                found = temps._bucketBusy[ibucket - jbucket];
                            }
                        }
                        jbucket = jbucketMax;
                        kbucket = ibucket - jbucket;
                        Digits bloci = temps._bucket[ibucket];
                        temps._modulus._BucketMul(jbucket, bloci, temps);
                        temps._modulus._BucketMul(kbucket, bloci, temps);
                    }
                }
            }

            answer._Set(temps._bucket[1], _digitN);
        }

        // c := a * b mod _mod, computed as the full 2 * _digitN digit product
        // followed by a division whose remainder is written to c.
        void _MulFromLeft(Digits a, Digits b, Digits c) {
            Digits product = new Digits(2 * _digitN);
            Digits.Mul(a, _digitN, b, _digitN, product);
            Digits.Div(product, 2 * _digitN, _mod, _digitN, _leftRecip, null, c);
        }

        // Right-to-left (Montgomery-style) multiplication:
        // c := a * b / 2^(Digit.BitN * _digitN) mod _mod.
        // Two running sums are kept: temp1 accumulates a * b[j], and temp2
        // accumulates mul2 * _mod, where mul2 is chosen so that the lowest
        // digits of both sums agree (checked by the asserts). That common low
        // digit is dropped every round, and the answer is temp1 - temp2
        // reduced modulo _mod. Only selected for odd moduli.
        void _MulFromRight(Digits a, Digits b, Digits c) {
            Digit minv = _rightRecip;
            Digit minva0 = unchecked(minv * a[0]);
            Digit mul1 = b[0];
            Digit mul2 = unchecked(minva0 * mul1);
            Digit carry1 = Digit2.Hi((UInt64)mul1 * a[0]);
            Digit carry2 = Digit2.Hi((UInt64)mul2 * _mod[0]);
            UInt64 prod1, prod2;

            Debug.Assert(unchecked(mul1 * a[0]) == unchecked(mul2 * _mod[0])
                , "internal error");

            Digits temp1 = new Digits(_digitN);
            Digits temp2 = new Digits(_digitN);

            // Round for b[0]: the low digit cancels and is dropped.
            for (int i = 1; i != _digitN; i++) {
                prod1 = (UInt64)mul1 * a[i] + carry1;
                prod2 = (UInt64)mul2 * _mod[i] + carry2;
                temp1[i - 1] = Digit2.Lo(prod1);
                temp2[i - 1] = Digit2.Lo(prod2);
                carry1 = Digit2.Hi(prod1);
                carry2 = Digit2.Hi(prod2);
            }
            temp1[_digitN - 1] = carry1;
            temp2[_digitN - 1] = carry2;

            // Remaining rounds for b[1] .. b[_digitN - 1].
            for (int j = 1; j != _digitN; j++) {
                mul1 = b[j];
                mul2 = unchecked(minva0 * mul1 + minv * (temp1[0] - temp2[0]));
                prod1 = (UInt64)mul1 * a[0] + temp1[0];
                prod2 = (UInt64)mul2 * _mod[0] + temp2[0];
                Debug.Assert(Digit2.Lo(prod1) == Digit2.Lo(prod2)
                    , "internal error");
                carry1 = Digit2.Hi(prod1);
                carry2 = Digit2.Hi(prod2);
                for (int i = 1; i != _digitN; i++) {
                    prod1 = (UInt64)mul1 * a[i] + temp1[i] + carry1;
                    prod2 = (UInt64)mul2 * _mod[i] + temp2[i] + carry2;
                    temp1[i - 1] = Digit2.Lo(prod1);
                    temp2[i - 1] = Digit2.Lo(prod2);
                    carry1 = Digit2.Hi(prod1);
                    carry2 = Digit2.Hi(prod2);
                }
                temp1[_digitN - 1] = carry1;
                temp2[_digitN - 1] = carry2;
            }

            Modular.Sub(temp1, temp2, c, _mod, _digitN);
        }

        /// <summary>
        /// Converts <paramref name="a"/> out of modular form into
        /// <paramref name="b"/> by undoing the 2^_scaleBitN scaling.
        /// </summary>
        internal void _FromModular(Digits a, Digits b) {
            Modular.ValidateData(a, _mod, _digitN);
            _Shift(a, -_scaleBitN, b);
        }

        /// <summary>
        /// c := a * b using the multiplication algorithm chosen at
        /// construction. Both operands are validated against the modulus.
        /// </summary>
        internal void _Mul(Digits a, Digits b, Digits c) {
            Modular.ValidateData(a, _mod, _digitN);
            if (a != b) {
                Modular.ValidateData(b, _mod, _digitN);
            }
            _algorithm(a, b, c);
        }

        // b := a * 2^n mod _mod. A negative n divides by 2^-n modulo _mod,
        // which requires an odd modulus. a and b may be the same Digits.
        void _Shift(Digits a, int n, Digits b) {
            // Validate before touching b so that a rejected call leaves the
            // destination unmodified.
            Modular.ValidateData(a, _mod, _digitN);
            if (n < 0 && (_mod[0] & 1) == 0) {
                throw new ArgumentException(
                    "negative shifts require an odd modulus", "n");
            }
            if (a != b) {
                b._Set(a, _digitN);
            }

            // Left shifts: shift at most one digit at a time, then subtract
            // an estimated multiple of the modulus (corrected once if needed).
            while (n > 0) {
                int shiftNow = n > Digit.BitN ? Digit.BitN : n;
                Digit carryOut = Digits.ShiftLost(b, shiftNow, b, _digitN);
                Digit qest = _leftRecip.EstQuotient(
                    carryOut,
                    b[_digitN - 1],
                    _digitN >= 2 ? b[_digitN - 2] : 0);
                carryOut -= Digits.Decumulate(_mod, qest, b, _digitN);
                if (carryOut != 0 || Digits.Compare(b, _mod, _digitN) >= 0) {
                    carryOut -= Digits.Sub(b, _mod, b, _digitN);
                }
                Debug.Assert(carryOut == 0, "internal error");
                n -= shiftNow;
            }

            // Right shifts: add the multiple of the modulus that clears the
            // low shiftNow bits, then shift them out and put the carry back
            // in at the top.
            while (n < 0) {
                int shiftNow = -n > Digit.BitN ? Digit.BitN : -n;
                Digit mul = unchecked(0 - _rightRecip * b[0])
                    & (Digit.MaxValue >> (Digit.BitN - shiftNow));
                Digit carry = Digits.Accumulate(_mod, mul, b, _digitN);
                Digit lowBitNLost = Digits.ShiftLost(b, -shiftNow, b, _digitN);
                b[_digitN - 1] |= carry << (Digit.BitN - shiftNow);
                Debug.Assert(lowBitNLost == 0, "internal error");
                n += shiftNow;
            }
        }

        /// <summary>
        /// Converts the aDigitN-digit integer <paramref name="a"/> into
        /// modular form in <paramref name="b"/>: reduces it modulo _mod when
        /// it is not already smaller, then applies the 2^_scaleBitN scaling.
        /// </summary>
        internal void _ToModular(Digits a, int aDigitN, Digits b) {
            Digits aR;
            int aRDigitN;
            if (Digits.Compare(a, aDigitN, _mod, _digitN) >= 0) {
                aR = new Digits(_digitN);
                Digits.Div(a, aDigitN, _mod, _digitN, _leftRecip, null, aR);
                aRDigitN = _digitN;
            } else {
                aR = a;
                aRDigitN = aDigitN;
            }
            aRDigitN = Digits.SigDigitN(aR, aRDigitN);
            Digits.Set(aR, aRDigitN, b, _digitN);
            _Shift(b, _scaleBitN, b);
        }

        // Signature shared by _MulFromLeft and _MulFromRight:
        // (a, b, c) computes c from a * b.
        delegate void MulAlgorithm(Digits arg0, Digits arg1, Digits arg2);
    }
}
|