From 005eac45c099126b403eeacc67f6e029896a44ca Mon Sep 17 00:00:00 2001 From: filifa Date: Thu, 21 Aug 2025 20:17:11 -0400 Subject: [PATCH] refactor general crt algorithm into separate function --- cmd/crt.go | 29 +++++++++-------------------- internal/lib/lib.go | 17 +++++++++++++++++ 2 files changed, 26 insertions(+), 20 deletions(-) diff --git a/cmd/crt.go b/cmd/crt.go index 9ab9986..f3ce45f 100644 --- a/cmd/crt.go +++ b/cmd/crt.go @@ -34,12 +34,18 @@ func crt(cmd *cobra.Command, args []string) { } ns := make([]*big.Int, len(moduli)) + rs := make([]*big.Int, len(remainders)) for i := range moduli { var ok bool ns[i], ok = new(big.Int).SetString(moduli[i], 10) if !ok { log.Fatal("invalid input " + moduli[i]) } + + rs[i], ok = new(big.Int).SetString(remainders[i], 10) + if !ok { + log.Fatal("invalid input " + remainders[i]) + } } // TODO: support non-pairwise coprime moduli @@ -47,27 +53,10 @@ func crt(cmd *cobra.Command, args []string) { log.Fatalf("moduli %v are not pairwise coprime", moduli) } - n1 := new(big.Int) - a1 := new(big.Int) - for i, n2 := range ns { - a2, ok := new(big.Int).SetString(remainders[i], 10) - if !ok { - log.Fatal("invalid input " + remainders[i]) - } + x, N := lib.CRTSolutionGeneral(rs, ns) - if i == 0 { - a1.Set(a2) - n1.Set(n2) - continue - } - - x, N := lib.CRTSolution(a1, n1, a2, n2) - a1.Set(x) - n1.Set(N) - } - - fmt.Println(a1) - fmt.Println(n1) + fmt.Println(x) + fmt.Println(N) } // crtCmd represents the crt command diff --git a/internal/lib/lib.go b/internal/lib/lib.go index d1db9d8..1390734 100644 --- a/internal/lib/lib.go +++ b/internal/lib/lib.go @@ -63,6 +63,23 @@ func CRTSolution(a1, n1, a2, n2 *big.Int) (*big.Int, *big.Int) { return x, N } +func CRTSolutionGeneral(remainders, moduli []*big.Int) (*big.Int, *big.Int) { + n1 := new(big.Int) + a1 := new(big.Int) + for i, n2 := range moduli { + a2 := remainders[i] + if i == 0 { + a1.Set(a2) + n1.Set(n2) + continue + } + + a1, n1 = CRTSolution(a1, n1, a2, n2) + } + + return a1, n1 +} + func ArePairwiseCoprime(moduli []*big.Int) bool { z := new(big.Int) for i, a := range moduli {