refactor general crt algorithm into separate function

This commit is contained in:
filifa 2025-08-21 20:17:11 -04:00
parent e1dedd9c33
commit 005eac45c0
2 changed files with 26 additions and 20 deletions

View File

@ -34,12 +34,18 @@ func crt(cmd *cobra.Command, args []string) {
} }
ns := make([]*big.Int, len(moduli)) ns := make([]*big.Int, len(moduli))
rs := make([]*big.Int, len(remainders))
for i := range moduli { for i := range moduli {
var ok bool var ok bool
ns[i], ok = new(big.Int).SetString(moduli[i], 10) ns[i], ok = new(big.Int).SetString(moduli[i], 10)
if !ok { if !ok {
log.Fatal("invalid input " + moduli[i]) log.Fatal("invalid input " + moduli[i])
} }
rs[i], ok = new(big.Int).SetString(remainders[i], 10)
if !ok {
log.Fatal("invalid input " + remainders[i])
}
} }
// TODO: support non-pairwise coprime moduli // TODO: support non-pairwise coprime moduli
@ -47,27 +53,10 @@ func crt(cmd *cobra.Command, args []string) {
log.Fatalf("moduli %v are not pairwise coprime", moduli) log.Fatalf("moduli %v are not pairwise coprime", moduli)
} }
n1 := new(big.Int) x, N := lib.CRTSolutionGeneral(rs, ns)
a1 := new(big.Int)
for i, n2 := range ns {
a2, ok := new(big.Int).SetString(remainders[i], 10)
if !ok {
log.Fatal("invalid input " + remainders[i])
}
if i == 0 { fmt.Println(x)
a1.Set(a2) fmt.Println(N)
n1.Set(n2)
continue
}
x, N := lib.CRTSolution(a1, n1, a2, n2)
a1.Set(x)
n1.Set(N)
}
fmt.Println(a1)
fmt.Println(n1)
} }
// crtCmd represents the crt command // crtCmd represents the crt command

View File

@ -63,6 +63,23 @@ func CRTSolution(a1, n1, a2, n2 *big.Int) (*big.Int, *big.Int) {
return x, N return x, N
} }
func CRTSolutionGeneral(remainders, moduli []*big.Int) (*big.Int, *big.Int) {
n1 := new(big.Int)
a1 := new(big.Int)
for i, n2 := range moduli {
a2 := remainders[i]
if i == 0 {
a1.Set(a2)
n1.Set(n2)
continue
}
a1, n1 = CRTSolution(a1, n1, a2, n2)
}
return a1, n1
}
func ArePairwiseCoprime(moduli []*big.Int) bool { func ArePairwiseCoprime(moduli []*big.Int) bool {
z := new(big.Int) z := new(big.Int)
for i, a := range moduli { for i, a := range moduli {