diff --git a/modules/compute.js b/modules/compute.js index db3b88c..7e953ee 100644 --- a/modules/compute.js +++ b/modules/compute.js @@ -1,4 +1,4 @@ -import { tonelliShanks, modinv, modpow, isprime } from "./math.js"; +import { modinv, modpow, modsqrt } from "./math.js"; function binaryOpPop(stack) { const b = stack.pop(); @@ -53,12 +53,8 @@ function compute(queue, modulus) { const c = modpow(a, b, modulus); stack.push(c); } else if (token === "sqrt") { - if (!isprime(modulus)) { - throw new Error("modulus must be prime to compute square root"); - } - const a = stack.pop(); - const s = tonelliShanks(a, modulus); + const s = modsqrt(a, modulus); stack.push(s); } } diff --git a/modules/math.js b/modules/math.js index 30edfee..96e5a0f 100644 --- a/modules/math.js +++ b/modules/math.js @@ -96,8 +96,76 @@ function isprime(n) { return true; } -function tonelliShanks(n, p) { - throw new Error("not implemented"); +function quadraticNonResidue(p) { + // TODO: consider randomizing this + for (let a = 2n; a < p; a++) { + if (modpow(a, (p-1n)/2n, p) === p - 1n) { + return a; + } + } } -export { tonelliShanks, modinv, modpow, isprime }; +function tonelliShanks(n, p) { + let q = p - 1n; + let s = 0; + while (q % 2n === 0n) { + q /= 2n; + s++; + } + + const z = quadraticNonResidue(p); + + let m = s; + let c = modpow(z, q, p); + let t = modpow(n, q, p); + let r = modpow(n, (q+1n)/2n, p); + + while (true) { + if (t === 0n) { + return 0n; + } else if (t === 1n) { + return r; + } + + let k = t; + let i = null; + for (i = 1; i < m; i++) { + k = modpow(k, 2n, p); + if (k === 1n) { + break; + } + } + + if (i === m) { + throw new Error("radicand is not a quadratic residue of the modulus"); + } + + const e = BigInt(Math.pow(2, m - i - 1)); + const b = modpow(c, e, p); + m = i; + c = modpow(b, 2n, p); + t = (t * c) % p; + r = (r * b) % p; + } +} + +function modsqrt(n, modulus) { + // TODO: add support for prime power modulus (Hensel's lemma) + if (!isprime(modulus)) { + throw new Error("modulus must be prime to compute square root"); + } + + // TODO: add special case for modulus = 3 (mod 4) + + if (n % modulus === 0n) { + return 0n; + } else if (modpow(n, (modulus-1n)/2n, modulus) !== 1n) { + throw new Error("radicand is not a quadratic residue of the modulus"); + } else if (modulus === 2n) { + return n % 2n; + } + + return tonelliShanks(n, modulus); +} + +export { modsqrt, modinv, modpow };