From 455e8b1777324bfa774babe78d30804e86ad1b7b Mon Sep 17 00:00:00 2001 From: filifa Date: Mon, 18 Aug 2025 21:43:45 -0400 Subject: [PATCH] check if moduli are pairwise coprime --- cmd/crt.go | 40 ++++++++++++++++++++++++++++++++-------- 1 file changed, 32 insertions(+), 8 deletions(-) diff --git a/cmd/crt.go b/cmd/crt.go index 707b60f..5e06c43 100644 --- a/cmd/crt.go +++ b/cmd/crt.go @@ -27,12 +27,14 @@ import ( var remainders []string var moduli []string -func solveTwo(a1, n1, a2, n2 *big.Int) (*big.Int, *big.Int) { +func crtSolve(a1, n1, a2, n2 *big.Int) (*big.Int, *big.Int) { + // use Bezout's identity to find m1, m2 such that m1*n1 + m2*n2 = 1 m1 := new(big.Int) m2 := new(big.Int) tmp := new(big.Int) tmp.GCD(m1, m2, n1, n2) + // x = a1*m2*n2 + a2*m1*n1 x := new(big.Int).Set(a1) x.Mul(x, m2) x.Mul(x, n2) @@ -51,28 +53,50 @@ func solveTwo(a1, n1, a2, n2 *big.Int) (*big.Int, *big.Int) { return x, N } +func arePairwiseCoprime(moduli []*big.Int) bool { + z := new(big.Int) + for i, a := range moduli { + for _, b := range moduli[i+1:] { + z.GCD(nil, nil, a, b) + if z.Cmp(big.NewInt(1)) != 0 { + return false + } + } + } + + return true +} + func crt(cmd *cobra.Command, args []string) { if len(remainders) != len(moduli) { log.Fatal("number of remainders and moduli do not match") } - // TODO: check for pairwise comprime - - n1 := new(big.Int) - a1 := new(big.Int) - for i := range remainders { - n2, ok := new(big.Int).SetString(moduli[i], 10) + ns := make([]*big.Int, len(moduli)) + for i := range moduli { + n, ok := new(big.Int).SetString(moduli[i], 10) if !ok { log.Fatal("invalid input " + moduli[i]) } + ns[i] = new(big.Int).Set(n) + } + + // TODO: support non-pairwise coprime moduli + if !arePairwiseCoprime(ns) { + log.Fatalf("moduli %v are not pairwise coprime", moduli) + } + + n1 := new(big.Int) + a1 := new(big.Int) + for i, n2 := range ns { a2, ok := new(big.Int).SetString(remainders[i], 10) if !ok { log.Fatal("invalid input " + remainders[i]) } if i != 0 { - x, N := solveTwo(a1, n1, a2, n2) + x, N := crtSolve(a1, n1, a2, n2) n1.Set(N) a1.Set(x) } else {