diff --git a/src/math/big/int.go b/src/math/big/int.go index 51dc6f78ff..a2c1b580f5 100644 --- a/src/math/big/int.go +++ b/src/math/big/int.go @@ -924,3 +924,14 @@ func (z *Int) Not(x *Int) *Int { z.neg = true // z cannot be zero if x is positive return z } + +// Sqrt sets z to ⌊√x⌋, the largest integer such that z² ≤ x, and returns z. +// It panics if x is negative. +func (z *Int) Sqrt(x *Int) *Int { + if x.neg { + panic("square root of negative number") + } + z.neg = false + z.abs = z.abs.sqrt(x.abs) + return z +} diff --git a/src/math/big/int_test.go b/src/math/big/int_test.go index 18f5be749d..b8e0778ca3 100644 --- a/src/math/big/int_test.go +++ b/src/math/big/int_test.go @@ -9,6 +9,7 @@ import ( "encoding/hex" "fmt" "math/rand" + "strings" "testing" "testing/quick" ) @@ -1453,3 +1454,44 @@ func TestIssue2607(t *testing.T) { n := NewInt(10) n.Rand(rand.New(rand.NewSource(9)), n) } + +func TestSqrt(t *testing.T) { + root := 0 + r := new(Int) + for i := 0; i < 10000; i++ { + if (root+1)*(root+1) <= i { + root++ + } + n := NewInt(int64(i)) + r.SetInt64(-2) + r.Sqrt(n) + if r.Cmp(NewInt(int64(root))) != 0 { + t.Errorf("Sqrt(%v) = %v, want %v", n, r, root) + } + } + + for i := 0; i < 1000; i += 10 { + n, _ := new(Int).SetString("1"+strings.Repeat("0", i), 10) + r := new(Int).Sqrt(n) + root, _ := new(Int).SetString("1"+strings.Repeat("0", i/2), 10) + if r.Cmp(root) != 0 { + t.Errorf("Sqrt(1e%d) = %v, want 1e%d", i, r, i/2) + } + } + + // Test aliasing. + r.SetInt64(100) + r.Sqrt(r) + if r.Int64() != 10 { + t.Errorf("Sqrt(100) = %v, want 10 (aliased output)", r.Int64()) + } +} + +func BenchmarkSqrt(b *testing.B) { + n, _ := new(Int).SetString("1"+strings.Repeat("0", 1001), 10) + b.ResetTimer() + t := new(Int) + for i := 0; i < b.N; i++ { + t.Sqrt(n) + } +} diff --git a/src/math/big/nat.go b/src/math/big/nat.go index 4a3b7ae33f..9b1a626c4c 100644 --- a/src/math/big/nat.go +++ b/src/math/big/nat.go @@ -1223,3 +1223,37 @@ func (z nat) setBytes(buf []byte) nat { return z.norm() } + +// sqrt sets z = ⌊√x⌋ +func (z nat) sqrt(x nat) nat { + if x.cmp(natOne) <= 0 { + return z.set(x) + } + if alias(z, x) { + z = nil + } + + // Start with value known to be too large and repeat "z = ⌊(z + ⌊x/z⌋)/2⌋" until it stops getting smaller. + // See Brent and Zimmermann, Modern Computer Arithmetic, Algorithm 1.13 (SqrtInt). + // https://members.loria.fr/PZimmermann/mca/pub226.html + // If x is one less than a perfect square, the sequence oscillates between the correct z and z+1; + // otherwise it converges to the correct z and stays there. + var z1, z2 nat + z1 = z + z1 = z1.setUint64(1) + z1 = z1.shl(z1, uint(x.bitLen()/2+1)) // must be ≥ √x + for n := 0; ; n++ { + z2, _ = z2.div(nil, x, z1) + z2 = z2.add(z2, z1) + z2 = z2.shr(z2, 1) + if z2.cmp(z1) >= 0 { + // z1 is answer. + // Figure out whether z1 or z2 is currently aliased to z by looking at loop count. + if n&1 == 0 { + return z1 + } + return z.set(z1) + } + z1, z2 = z2, z1 + } +}