diff --git a/cmd/crt.go b/cmd/crt.go index f3ce45f..b3fa80a 100644 --- a/cmd/crt.go +++ b/cmd/crt.go @@ -53,7 +53,7 @@ func crt(cmd *cobra.Command, args []string) { log.Fatalf("moduli %v are not pairwise coprime", moduli) } - x, N := lib.CRTSolutionGeneral(rs, ns) + x, N := lib.CRTSolution(rs, ns) fmt.Println(x) fmt.Println(N) diff --git a/internal/lib/continuedFrac.go b/internal/lib/continuedFrac.go new file mode 100644 index 0000000..66d2eb6 --- /dev/null +++ b/internal/lib/continuedFrac.go @@ -0,0 +1,38 @@ +package lib + +import ( + "errors" + "math/big" +) + +func SqrtRepetend(x *big.Int) ([]*big.Int, error) { + m := big.NewInt(0) + d := big.NewInt(1) + a0 := new(big.Int).Sqrt(x) + + s := new(big.Int).Exp(a0, big.NewInt(2), nil) + if x.Cmp(s) == 0 { + return nil, errors.New("input is a perfect square") + } + + repetend := make([]*big.Int, 0) + + a := new(big.Int).Set(a0) + twoa0 := new(big.Int).Mul(big.NewInt(2), a0) + for a.Cmp(twoa0) != 0 { + // m = d * a - m + tmp := new(big.Int) + m.Sub(tmp.Mul(d, a), m) + + // d = (x - m^2) // d + tmp.Exp(m, big.NewInt(2), nil) + d.Div(tmp.Sub(x, tmp), d) + + // a = (a0 + m) // d + a.Div(tmp.Add(a0, m), d) + + repetend = append(repetend, new(big.Int).Set(a)) + } + + return repetend, nil +} diff --git a/internal/lib/crt.go b/internal/lib/crt.go new file mode 100644 index 0000000..2728749 --- /dev/null +++ b/internal/lib/crt.go @@ -0,0 +1,62 @@ +package lib + +import ( + "math/big" +) + +func solveCRT(a1, n1, a2, n2 *big.Int) (*big.Int, *big.Int) { + // use Bezout's identity to find m1, m2 such that m1*n1 + m2*n2 = 1 + m1 := new(big.Int) + m2 := new(big.Int) + tmp := new(big.Int) + tmp.GCD(m1, m2, n1, n2) + + // x = a1*m2*n2 + a2*m1*n1 + x := new(big.Int).Set(a1) + x.Mul(x, m2) + x.Mul(x, n2) + + tmp.Set(a2) + tmp.Mul(tmp, m1) + tmp.Mul(tmp, n1) + + x.Add(x, tmp) + + N := new(big.Int).Set(n1) + N.Mul(N, n2) + + x.Mod(x, N) + + return x, N +} + +func CRTSolution(remainders, moduli []*big.Int) (*big.Int, *big.Int) { + n1 := new(big.Int) + a1 := new(big.Int) + for i, n2 := range moduli { + a2 := remainders[i] + if i == 0 { + a1.Set(a2) + n1.Set(n2) + continue + } + + a1, n1 = solveCRT(a1, n1, a2, n2) + } + + return a1, n1 +} + +func ArePairwiseCoprime(moduli []*big.Int) bool { + z := new(big.Int) + for i, a := range moduli { + for _, b := range moduli[i+1:] { + z.GCD(nil, nil, a, b) + if z.Cmp(big.NewInt(1)) != 0 { + return false + } + } + } + + return true +} diff --git a/internal/lib/lib.go b/internal/lib/primitiveRoot.go similarity index 57% rename from internal/lib/lib.go rename to internal/lib/primitiveRoot.go index 1390734..68669ec 100644 --- a/internal/lib/lib.go +++ b/internal/lib/primitiveRoot.go @@ -5,95 +5,6 @@ import ( "math/big" ) -func SqrtRepetend(x *big.Int) ([]*big.Int, error) { - m := big.NewInt(0) - d := big.NewInt(1) - a0 := new(big.Int).Sqrt(x) - - s := new(big.Int).Exp(a0, big.NewInt(2), nil) - if x.Cmp(s) == 0 { - return nil, errors.New("input is a perfect square") - } - - repetend := make([]*big.Int, 0) - - a := new(big.Int).Set(a0) - twoa0 := new(big.Int).Mul(big.NewInt(2), a0) - for a.Cmp(twoa0) != 0 { - // m = d * a - m - tmp := new(big.Int) - m.Sub(tmp.Mul(d, a), m) - - // d = (x - m^2) // d - tmp.Exp(m, big.NewInt(2), nil) - d.Div(tmp.Sub(x, tmp), d) - - // a = (a0 + m) // d - a.Div(tmp.Add(a0, m), d) - - repetend = append(repetend, new(big.Int).Set(a)) - } - - return repetend, nil -} - -func CRTSolution(a1, n1, a2, n2 *big.Int) (*big.Int, *big.Int) { - // use Bezout's identity to find m1, m2 such that m1*n1 + m2*n2 = 1 - m1 := new(big.Int) - m2 := new(big.Int) - tmp := new(big.Int) - tmp.GCD(m1, m2, n1, n2) - - // x = a1*m2*n2 + a2*m1*n1 - x := new(big.Int).Set(a1) - x.Mul(x, m2) - x.Mul(x, n2) - - tmp.Set(a2) - tmp.Mul(tmp, m1) - tmp.Mul(tmp, n1) - - x.Add(x, tmp) - - N := new(big.Int).Set(n1) - N.Mul(N, n2) - - x.Mod(x, N) - - return x, N -} - -func CRTSolutionGeneral(remainders, moduli []*big.Int) (*big.Int, *big.Int) { - n1 := new(big.Int) - a1 := new(big.Int) - for i, n2 := range moduli { - a2 := remainders[i] - if i == 0 { - a1.Set(a2) - n1.Set(n2) - continue - } - - a1, n1 = CRTSolution(a1, n1, a2, n2) - } - - return a1, n1 -} - -func ArePairwiseCoprime(moduli []*big.Int) bool { - z := new(big.Int) - for i, a := range moduli { - for _, b := range moduli[i+1:] { - z.GCD(nil, nil, a, b) - if z.Cmp(big.NewInt(1)) != 0 { - return false - } - } - } - - return true -} - func Totient(n *big.Int) *big.Int { N := new(big.Int).Set(n)