diff --git a/cmd/crt.go b/cmd/crt.go index 50d7438..d6c2e09 100644 --- a/cmd/crt.go +++ b/cmd/crt.go @@ -22,51 +22,12 @@ import ( "math/big" "github.com/spf13/cobra" + "scm.dairydemon.net/filifa/mathtools/internal/lib" ) var remainders []string var moduli []string -func crtSolve(a1, n1, a2, n2 *big.Int) (*big.Int, *big.Int) { - // use Bezout's identity to find m1, m2 such that m1*n1 + m2*n2 = 1 - m1 := new(big.Int) - m2 := new(big.Int) - tmp := new(big.Int) - tmp.GCD(m1, m2, n1, n2) - - // x = a1*m2*n2 + a2*m1*n1 - x := new(big.Int).Set(a1) - x.Mul(x, m2) - x.Mul(x, n2) - - tmp.Set(a2) - tmp.Mul(tmp, m1) - tmp.Mul(tmp, n1) - - x.Add(x, tmp) - - N := new(big.Int).Set(n1) - N.Mul(N, n2) - - x.Mod(x, N) - - return x, N -} - -func arePairwiseCoprime(moduli []*big.Int) bool { - z := new(big.Int) - for i, a := range moduli { - for _, b := range moduli[i+1:] { - z.GCD(nil, nil, a, b) - if z.Cmp(big.NewInt(1)) != 0 { - return false - } - } - } - - return true -} - func crt(cmd *cobra.Command, args []string) { if len(remainders) != len(moduli) { log.Fatal("number of remainders and moduli do not match") @@ -83,7 +44,7 @@ func crt(cmd *cobra.Command, args []string) { } // TODO: support non-pairwise coprime moduli - if !arePairwiseCoprime(ns) { + if !lib.ArePairwiseCoprime(ns) { log.Fatalf("moduli %v are not pairwise coprime", moduli) } @@ -101,7 +62,7 @@ func crt(cmd *cobra.Command, args []string) { continue } - x, N := crtSolve(a1, n1, a2, n2) + x, N := lib.CRTSolution(a1, n1, a2, n2) a1.Set(x) n1.Set(N) } diff --git a/internal/lib/lib.go b/internal/lib/lib.go index 66d2eb6..c7f5159 100644 --- a/internal/lib/lib.go +++ b/internal/lib/lib.go @@ -36,3 +36,43 @@ func SqrtRepetend(x *big.Int) ([]*big.Int, error) { return repetend, nil } + +func CRTSolution(a1, n1, a2, n2 *big.Int) (*big.Int, *big.Int) { + // use Bezout's identity to find m1, m2 such that m1*n1 + m2*n2 = 1 + m1 := new(big.Int) + m2 := new(big.Int) + tmp := new(big.Int) + tmp.GCD(m1, m2, n1, n2) + + // x = a1*m2*n2 + a2*m1*n1 + x := new(big.Int).Set(a1) + x.Mul(x, m2) + x.Mul(x, n2) + + tmp.Set(a2) + tmp.Mul(tmp, m1) + tmp.Mul(tmp, n1) + + x.Add(x, tmp) + + N := new(big.Int).Set(n1) + N.Mul(N, n2) + + x.Mod(x, N) + + return x, N +} + +func ArePairwiseCoprime(moduli []*big.Int) bool { + z := new(big.Int) + for i, a := range moduli { + for _, b := range moduli[i+1:] { + z.GCD(nil, nil, a, b) + if z.Cmp(big.NewInt(1)) != 0 { + return false + } + } + } + + return true +}