1
0
mirror of https://github.com/golang/go synced 2024-11-22 08:04:39 -07:00

big: add Div, Mod, Exp, GcdExt and several other fixes.

R=gri, rsc
CC=go-dev
http://go/go-review/1017036
This commit is contained in:
Adam Langley 2009-11-05 15:55:41 -08:00
parent 5e598c55dc
commit 65063bc61d
6 changed files with 896 additions and 29 deletions

View File

@ -50,10 +50,6 @@ func subWW_g(x, y, c Word) (z1, z0 Word) {
}
// TODO(gri) mulWW_g is not needed anymore. Keep around for
// now since mulAddWWW_g should use some of the
// optimizations from mulWW_g eventually.
// z1<<_W + z0 = x*y
func mulWW_g(x, y Word) (z1, z0 Word) {
// Split x and y into 2 halfWords each, multiply
@ -86,6 +82,9 @@ func mulWW_g(x, y Word) (z1, z0 Word) {
// z = z[1]*_B + z[0] = x*y
z0 = t1<<_W2 + t0;
z1 = (t1 + t0>>_W2)>>_W2;
if z0 < t0 {
z1++;
}
return;
}
@ -98,13 +97,31 @@ func mulWW_g(x, y Word) (z1, z0 Word) {
// x*y = t2*_B2*_B2 + t1*_B2 + t0
t0 := x0*y0;
t1 := x1*y0 + x0*y1;
t2 := x1*y1;
// t1 := x1*y0 + x0*y1;
var c Word;
t1 := x1*y0;
t1a := t1;
t1 += x0*y1;
if t1 < t1a {
c++;
}
t2 := x1*y1 + c*_B2;
// compute result digits but avoid overflow
// z = z[1]*_B + z[0] = x*y
// This may overflow, but that's ok because we also sum t1 and t0 above
// and we take care of the overflow there.
z0 = t1<<_W2 + t0;
z1 = t2 + (t1 + t0>>_W2)>>_W2;
// z1 = t2 + (t1 + t0>>_W2)>>_W2;
var c3 Word;
z1 = t1 + t0>>_W2;
if z1 < t1 {
c3++;
}
z1 >>= _W2;
z1 += c3*_B2;
z1 += t2;
return;
}
@ -126,14 +143,35 @@ func mulAddWWW_g(x, y, c Word) (z1, z0 Word) {
c1, c0 := c>>_W2, c&_M2;
// x*y + c = t2*_B2*_B2 + t1*_B2 + t0
// (1<<32-1)^2 == 1<<64 - 1<<33 + 1, so there's space to add c0 in here.
t0 := x0*y0 + c0;
t1 := x1*y0 + x0*y1 + c1;
t2 := x1*y1;
// t1 := x1*y0 + x0*y1 + c1;
var c2 Word; // extra carry
t1 := x1*y0 + c1;
t1a := t1;
t1 += x0*y1;
if t1 < t1a { // If the number got smaller then we overflowed.
c2++;
}
t2 := x1*y1 + c2*_B2;
// compute result digits but avoid overflow
// z = z[1]*_B + z[0] = x*y
// z0 = t1<<_W2 + t0;
// This may overflow, but that's ok because we also sum t1 and t0 below
// and we take care of the overflow there.
z0 = t1<<_W2 + t0;
z1 = t2 + (t1 + t0>>_W2)>>_W2;
var c3 Word;
z1 = t1 + t0>>_W2;
if z1 < t1 {
c3++;
}
z1 >>= _W2;
z1 += t2 + c3*_B2;
return;
}

View File

@ -268,3 +268,49 @@ func TestFunVWW(t *testing.T) {
}
}
}
type mulWWTest struct {
x, y Word;
q, r Word;
}
var mulWWTests = []mulWWTest{
mulWWTest{_M, _M, _M - 1, 1},
}
func TestMulWWW(t *testing.T) {
for i, test := range mulWWTests {
q, r := mulWW_g(test.x, test.y);
if q != test.q || r != test.r {
t.Errorf("#%d got (%x, %x) want (%x, %x)", i, q, r, test.q, test.r);
}
}
}
type mulAddWWWTest struct {
x, y, c Word;
q, r Word;
}
var mulAddWWWTests = []mulAddWWWTest{
// TODO(agl): These will only work on 64-bit platforms.
// mulAddWWWTest{15064310297182388543, 0xe7df04d2d35d5d80, 13537600649892366549, 13644450054494335067, 10832252001440893781},
// mulAddWWWTest{15064310297182388543, 0xdab2f18048baa68d, 13644450054494335067, 12869334219691522700, 14233854684711418382},
mulAddWWWTest{_M, _M, 0, _M - 1, 1},
mulAddWWWTest{_M, _M, _M, _M, 0},
}
func TestMulAddWWW(t *testing.T) {
for i, test := range mulAddWWWTests {
q, r := mulAddWWW_g(test.x, test.y, test.c);
if q != test.q || r != test.r {
t.Errorf("#%d got (%x, %x) want (%x, %x)", i, q, r, test.q, test.r);
}
}
}

View File

@ -26,6 +26,12 @@ func (z *Int) New(x int64) *Int {
}
// NewInt allocates and returns a new Int set to x.
func NewInt(x int64) *Int {
return new(Int).New(x);
}
// Set sets z to x.
func (z *Int) Set(x *Int) *Int {
z.neg = x.neg;
@ -96,6 +102,66 @@ func (z *Int) Mul(x, y *Int) *Int {
}
// Div calculates q = (x-r)/y where 0 <= r < y. The receiver is set to q.
func (z *Int) Div(x, y *Int) (q, r *Int) {
q = z;
r = new(Int);
div(q, r, x, y);
return;
}
// Mod calculates q = (x-r)/y and returns r.
func (z *Int) Mod(x, y *Int) (r *Int) {
q := new(Int);
r = z;
div(q, r, x, y);
return;
}
func div(q, r, x, y *Int) {
if len(y.abs) == 0 {
panic("Divide by zero undefined");
}
if cmpNN(x.abs, y.abs) < 0 {
q.neg = false;
q.abs = nil;
r.neg = y.neg;
src := x.abs;
dst := x.abs;
if r == x {
dst = nil;
}
r.abs = makeN(dst, len(src), false);
for i, v := range src {
r.abs[i] = v;
}
return;
}
if len(y.abs) == 1 {
var rprime Word;
q.abs, rprime = divNW(q.abs, x.abs, y.abs[0]);
if rprime > 0 {
r.abs = makeN(r.abs, 1, false);
r.abs[0] = rprime;
r.neg = x.neg;
}
q.neg = len(q.abs) > 0 && x.neg != y.neg;
return;
}
q.neg = x.neg != y.neg;
r.neg = x.neg;
q.abs, r.abs = divNN(q.abs, r.abs, x.abs, y.abs);
return;
}
// Neg computes z = -x.
func (z *Int) Neg(x *Int) *Int {
z.abs = setN(z.abs, x.abs);
@ -132,10 +198,234 @@ func CmpInt(x, y *Int) (r int) {
}
func (x *Int) String() string {
func (z *Int) String() string {
s := "";
if x.neg {
if z.neg {
s = "-";
}
return s + stringN(x.abs, 10);
return s + stringN(z.abs, 10);
}
// SetString sets z to the value of s, interpreted in the given base.
// If base is 0 then SetString attempts to detect the base by at the prefix of
// s. '0x' implies base 16, '0' implies base 8. Otherwise base 10 is assumed.
func (z *Int) SetString(s string, base int) (*Int, bool) {
var scanned int;
if base == 1 || base > 16 {
goto Error;
}
if len(s) == 0 {
goto Error;
}
if s[0] == '-' {
z.neg = true;
s = s[1:len(s)];
} else {
z.neg = false;
}
z.abs, _, scanned = scanN(z.abs, s, base);
if scanned != len(s) {
goto Error;
}
return z, true;
Error:
z.neg = false;
z.abs = nil;
return nil, false;
}
// SetBytes interprets b as the bytes of a big-endian, unsigned integer and
// sets x to that value.
func (z *Int) SetBytes(b []byte) *Int {
s := int(_S);
z.abs = makeN(z.abs, (len(b)+s-1)/s, false);
z.neg = false;
j := 0;
for len(b) >= s {
var w Word;
for i := s; i > 0; i-- {
w <<= 8;
w |= Word(b[len(b)-i]);
}
z.abs[j] = w;
j++;
b = b[0 : len(b)-s];
}
if len(b) > 0 {
var w Word;
for i := len(b); i > 0; i-- {
w <<= 8;
w |= Word(b[len(b)-i]);
}
z.abs[j] = w;
}
z.abs = normN(z.abs);
return z;
}
// Bytes returns the absolute value of x as a big-endian byte array.
func (z *Int) Bytes() []byte {
s := int(_S);
b := make([]byte, len(z.abs)*s);
for i, w := range z.abs {
wordBytes := b[(len(z.abs)-i-1)*s : (len(z.abs)-i)*s];
for j := s-1; j >= 0; j-- {
wordBytes[j] = byte(w);
w >>= 8;
}
}
i := 0;
for i < len(b) && b[i] == 0 {
i++;
}
return b[i:len(b)];
}
// Len returns the length of the absolute value of x in bits. Zero is
// considered to have a length of one.
func (z *Int) Len() int {
if len(z.abs) == 0 {
return 0;
}
return len(z.abs)*int(_W) - int(leadingZeros(z.abs[len(z.abs)-1]));
}
// Exp sets z = x**y mod m. If m is nil, z = x**y.
// See Knuth, volume 2, section 4.6.3.
func (z *Int) Exp(x, y, m *Int) *Int {
if y.neg || len(y.abs) == 0 {
z.New(1);
z.neg = x.neg;
return z;
}
z.Set(x);
v := y.abs[len(y.abs)-1];
// It's invalid for the most significant word to be zero, therefore we
// will find a one bit.
shift := leadingZeros(v) + 1;
v <<= shift;
const mask = 1<<(_W-1);
// We walk through the bits of the exponent one by one. Each time we see
// a bit, we square, thus doubling the power. If the bit is a one, we
// also multiply by x, thus adding one to the power.
w := int(_W)-int(shift);
for j := 0; j < w; j++ {
z.Mul(z, z);
if v&mask != 0 {
z.Mul(z, x);
}
if m != nil {
z.Mod(z, m);
}
v <<= 1;
}
for i := len(y.abs)-2; i >= 0; i-- {
v = y.abs[i];
for j := 0; j < int(_W); j++ {
z.Mul(z, z);
if v&mask != 0 {
z.Mul(z, x);
}
if m != nil {
z.Mod(z, m);
}
v <<= 1;
}
}
z.neg = x.neg && y.abs[0] & 1 == 1;
return z;
}
// GcdInt sets d to the greatest common divisor of a and b, which must be
// positive numbers.
// If x and y are not nil, GcdInt sets x and y such that d = a*x + b*y.
// If either a or b is not positive, GcdInt sets d = x = y = 0.
func GcdInt(d, x, y, a, b *Int) {
if a.neg || b.neg {
d.New(0);
if x != nil {
x.New(0);
}
if y != nil {
y.New(0);
}
return;
}
A := new(Int).Set(a);
B := new(Int).Set(b);
X := new(Int);
Y := new(Int).New(1);
lastX := new(Int).New(1);
lastY := new(Int);
q := new(Int);
temp := new(Int);
for len(B.abs) > 0 {
q, r := q.Div(A, B);
A, B = B, r;
temp.Set(X);
X.Mul(X, q);
X.neg = !X.neg;
X.Add(X, lastX);
lastX.Set(temp);
temp.Set(Y);
Y.Mul(Y, q);
Y.neg = !Y.neg;
Y.Add(Y, lastY);
lastY.Set(temp);
}
if x != nil {
*x = *lastX;
}
if y != nil {
*y = *lastY;
}
*d = *A;
}

View File

@ -4,8 +4,12 @@
package big
import "testing"
import (
"bytes";
"encoding/hex";
"testing";
"testing/quick";
)
func newZ(x int64) *Int {
var z Int;
@ -18,6 +22,7 @@ type argZZ struct {
z, x, y *Int;
}
var sumZZ = []argZZ{
argZZ{newZ(0), newZ(0), newZ(0)},
argZZ{newZ(1), newZ(1), newZ(0)},
@ -27,12 +32,13 @@ var sumZZ = []argZZ{
argZZ{newZ(-1111111110), newZ(-123456789), newZ(-987654321)},
}
var prodZZ = []argZZ{
argZZ{newZ(0), newZ(0), newZ(0)},
argZZ{newZ(0), newZ(1), newZ(0)},
argZZ{newZ(1), newZ(1), newZ(1)},
argZZ{newZ(-991 * 991), newZ(991), newZ(-991)},
// TODO(gri) add larger products
// TODO(gri) add larger products
}
@ -57,12 +63,8 @@ func testFunZZ(t *testing.T, msg string, f funZZ, a argZZ) {
func TestSumZZ(t *testing.T) {
AddZZ := func(z, x, y *Int) *Int {
return z.Add(x, y);
};
SubZZ := func(z, x, y *Int) *Int {
return z.Sub(x, y);
};
AddZZ := func(z, x, y *Int) *Int { return z.Add(x, y) };
SubZZ := func(z, x, y *Int) *Int { return z.Sub(x, y) };
for _, a := range sumZZ {
arg := a;
testFunZZ(t, "AddZZ", AddZZ, arg);
@ -80,9 +82,7 @@ func TestSumZZ(t *testing.T) {
func TestProdZZ(t *testing.T) {
MulZZ := func(z, x, y *Int) *Int {
return z.Mul(x, y);
};
MulZZ := func(z, x, y *Int) *Int { return z.Mul(x, y) };
for _, a := range prodZZ {
arg := a;
testFunZZ(t, "MulZZ", MulZZ, arg);
@ -125,3 +125,306 @@ func TestFact(t *testing.T) {
}
}
}
type fromStringTest struct {
in string;
base int;
out int64;
ok bool;
}
var fromStringTests = []fromStringTest{
fromStringTest{in: "", ok: false},
fromStringTest{in: "a", ok: false},
fromStringTest{in: "z", ok: false},
fromStringTest{in: "+", ok: false},
fromStringTest{"0", 0, 0, true},
fromStringTest{"0", 10, 0, true},
fromStringTest{"0", 16, 0, true},
fromStringTest{"10", 0, 10, true},
fromStringTest{"10", 10, 10, true},
fromStringTest{"10", 16, 16, true},
fromStringTest{"-10", 16, -16, true},
fromStringTest{in: "0x", ok: false},
fromStringTest{"0x10", 0, 16, true},
fromStringTest{in: "0x10", base: 16, ok: false},
fromStringTest{"-0x10", 0, -16, true},
}
func TestSetString(t *testing.T) {
for i, test := range fromStringTests {
n, ok := new(Int).SetString(test.in, test.base);
if ok != test.ok {
t.Errorf("#%d (input '%s') ok incorrect (should be %t)", i, test.in, test.ok);
continue;
}
if !ok {
continue;
}
if CmpInt(n, new(Int).New(test.out)) != 0 {
t.Errorf("#%d (input '%s') got: %s want: %d\n", i, test.in, n, test.out);
}
}
}
type divSignsTest struct {
x, y int64;
q, r int64;
}
// These examples taken from the Go Language Spec, section "Arithmetic operators"
var divSignsTests = []divSignsTest{
divSignsTest{5, 3, 1, 2},
divSignsTest{-5, 3, -1, -2},
divSignsTest{5, -3, -1, 2},
divSignsTest{-5, -3, 1, -2},
divSignsTest{1, 2, 0, 1},
}
func TestDivSigns(t *testing.T) {
for i, test := range divSignsTests {
x := new(Int).New(test.x);
y := new(Int).New(test.y);
q, r := new(Int).Div(x, y);
expectedQ := new(Int).New(test.q);
expectedR := new(Int).New(test.r);
if CmpInt(q, expectedQ) != 0 || CmpInt(r, expectedR) != 0 {
t.Errorf("#%d: got (%s, %s) want (%s, %s)", i, q, r, expectedQ, expectedR);
}
}
}
func checkSetBytes(b []byte) bool {
hex1 := hex.EncodeToString(new(Int).SetBytes(b).Bytes());
hex2 := hex.EncodeToString(b);
for len(hex1) < len(hex2) {
hex1 = "0"+hex1;
}
for len(hex1) > len(hex2) {
hex2 = "0"+hex2;
}
return hex1 == hex2;
}
func TestSetBytes(t *testing.T) {
err := quick.Check(checkSetBytes, nil);
if err != nil {
t.Error(err);
}
}
func checkBytes(b []byte) bool {
b2 := new(Int).SetBytes(b).Bytes();
return bytes.Compare(b, b2) == 0;
}
func TestBytes(t *testing.T) {
err := quick.Check(checkSetBytes, nil);
if err != nil {
t.Error(err);
}
}
func checkDiv(x, y []byte) bool {
u := new(Int).SetBytes(x);
v := new(Int).SetBytes(y);
if len(v.abs) == 0 {
return true;
}
q, r := new(Int).Div(u, v);
if CmpInt(r, v) >= 0 {
return false;
}
uprime := new(Int).Set(q);
uprime.Mul(uprime, v);
uprime.Add(uprime, r);
return CmpInt(uprime, u) == 0;
}
func TestDiv(t *testing.T) {
err := quick.Check(checkDiv, nil);
if err != nil {
t.Error(err);
}
}
func TestDivStepD6(t *testing.T) {
// See Knuth, Volume 2, section 4.3.1, exercise 21. This code exercises
// a code path which only triggers 1 in 10^{-19} cases.
u := &Int{false, []Word{0, 0, 0x8000000000000001, 0x7fffffffffffffff}};
v := &Int{false, []Word{5, 0x8000000000000002, 0x8000000000000000}};
q, r := new(Int).Div(u, v);
const expectedQ = "18446744073709551613";
const expectedR = "3138550867693340382088035895064302439801311770021610913807";
if q.String() != expectedQ || r.String() != expectedR {
t.Errorf("got (%s, %s) want (%s, %s)", q, r, expectedQ, expectedR);
}
}
type lenTest struct {
in string;
out int;
}
var lenTests = []lenTest{
lenTest{"0", 0},
lenTest{"1", 1},
lenTest{"2", 2},
lenTest{"4", 3},
lenTest{"0x8000", 16},
lenTest{"0x80000000", 32},
lenTest{"0x800000000000", 48},
lenTest{"0x8000000000000000", 64},
lenTest{"0x80000000000000000000", 80},
}
func TestLen(t *testing.T) {
for i, test := range lenTests {
n, ok := new(Int).SetString(test.in, 0);
if !ok {
t.Errorf("#%d test input invalid: %s", i, test.in);
continue;
}
if n.Len() != test.out {
t.Errorf("#%d got %d want %d\n", i, n.Len(), test.out);
}
}
}
type expTest struct {
x, y, m string;
out string;
}
var expTests = []expTest{
/*expTest{"5", "0", "", "1"},
expTest{"-5", "0", "", "-1"},
expTest{"5", "1", "", "5"},
expTest{"-5", "1", "", "-5"},
expTest{"5", "2", "", "25"},*/
expTest{"1", "65537", "2", "1"},
/*expTest{"0x8000000000000000", "2", "", "0x40000000000000000000000000000000"},
expTest{"0x8000000000000000", "2", "6719", "4944"},
expTest{"0x8000000000000000", "3", "6719", "5447"},
expTest{"0x8000000000000000", "1000", "6719", "1603"},
expTest{"0x8000000000000000", "1000000", "6719", "3199"},
expTest{
"2938462938472983472983659726349017249287491026512746239764525612965293865296239471239874193284792387498274256129746192347",
"298472983472983471903246121093472394872319615612417471234712061",
"29834729834729834729347290846729561262544958723956495615629569234729836259263598127342374289365912465901365498236492183464",
"23537740700184054162508175125554701713153216681790245129157191391322321508055833908509185839069455749219131480588829346291",
},*/
}
func TestExp(t *testing.T) {
for i, test := range expTests {
x, ok1 := new(Int).SetString(test.x, 0);
y, ok2 := new(Int).SetString(test.y, 0);
out, ok3 := new(Int).SetString(test.out, 0);
var ok4 bool;
var m *Int;
if len(test.m) == 0 {
m, ok4 = nil, true;
} else {
m, ok4 = new(Int).SetString(test.m, 0);
}
if !ok1 || !ok2 || !ok3 || !ok4 {
t.Errorf("#%d error in input", i);
continue;
}
z := new(Int).Exp(x, y, m);
if CmpInt(z, out) != 0 {
t.Errorf("#%d got %s want %s", i, z, out);
}
}
}
func checkGcd(aBytes, bBytes []byte) bool {
a := new(Int).SetBytes(aBytes);
b := new(Int).SetBytes(bBytes);
x := new(Int);
y := new(Int);
d := new(Int);
GcdInt(d, x, y, a, b);
x.Mul(x, a);
y.Mul(y, b);
x.Add(x, y);
return CmpInt(x, d) == 0;
}
type gcdTest struct {
a, b int64;
d, x, y int64;
}
var gcdTests = []gcdTest{
gcdTest{120, 23, 1, -9, 47},
}
func TestGcd(t *testing.T) {
for i, test := range gcdTests {
a := new(Int).New(test.a);
b := new(Int).New(test.b);
x := new(Int);
y := new(Int);
d := new(Int);
expectedX := new(Int).New(test.x);
expectedY := new(Int).New(test.y);
expectedD := new(Int).New(test.d);
GcdInt(d, x, y, a, b);
if CmpInt(expectedX, x) != 0 ||
CmpInt(expectedY, y) != 0 ||
CmpInt(expectedD, d) != 0 {
t.Errorf("#%d got (%s %s %s) want (%s %s %s)", i, x, y, d, expectedX, expectedY, expectedD);
}
}
quick.Check(checkGcd, nil);
}

View File

@ -257,6 +257,66 @@ func divNW(z, x []Word, y Word) (q []Word, r Word) {
}
// q = (uIn-r)/v, with 0 <= r < y
// See Knuth, Volume 2, section 4.3.1, Algorithm D.
// Preconditions:
// len(v) >= 2
// len(uIn) >= 1 + len(vIn)
func divNN(z, z2, uIn, v []Word) (q, r []Word) {
n := len(v);
m := len(uIn)-len(v);
u := makeN(z2, len(uIn)+1, false);
qhatv := make([]Word, len(v)+1);
q = makeN(z, m+1, false);
// D1.
shift := leadingZeroBits(v[n-1]);
shiftLeft(v, v, shift);
shiftLeft(u, uIn, shift);
u[len(uIn)] = uIn[len(uIn)-1]>>(uint(_W)-uint(shift));
// D2.
for j := m; j >= 0; j-- {
// D3.
qhat, rhat := divWW_g(u[j+n], u[j+n-1], v[n-1]);
// x1 | x2 = q̂v_{n-2}
x1, x2 := mulWW_g(qhat, v[n-2]);
// test if q̂v_{n-2} > br̂ + u_{j+n-2}
for greaterThan(x1, x2, rhat, u[j+n-2]) {
qhat--;
prevRhat := rhat;
rhat += v[n-1];
// v[n-1] >= 0, so this tests for overflow.
if rhat < prevRhat {
break;
}
x1, x2 = mulWW_g(qhat, v[n-2]);
}
// D4.
qhatv[len(v)] = mulAddVWW(&qhatv[0], &v[0], qhat, 0, len(v));
c := subVV(&u[j], &u[j], &qhatv[0], len(qhatv));
if c != 0 {
c := addVV(&u[j], &u[j], &v[0], len(v));
u[j+len(v)] += c;
qhat--;
}
q[j] = qhat;
}
q = normN(q);
shiftRight(u, u, shift);
shiftRight(v, v, shift);
r = normN(u);
return q, r;
}
// log2 computes the integer binary logarithm of x.
// The result is the integer n for which 2^n <= x < 2^(n+1).
// If x == 0, the result is -1.
@ -314,6 +374,10 @@ func scanN(z []Word, s string, base int) ([]Word, int, int) {
base = 10;
if n > 0 && s[0] == '0' {
if n > 1 && (s[1] == 'x' || s[1] == 'X') {
if n == 2 {
// Reject a string which is just '0x' as nonsense.
return nil, 0, 0;
}
base, i = 16, 2;
} else {
base, i = 8, 1;
@ -362,9 +426,62 @@ func stringN(x []Word, base int) string {
for len(q) > 0 {
i--;
var r Word;
q, r = divNW(q, q, 10);
q, r = divNW(q, q, Word(base));
s[i] = "0123456789abcdef"[r];
}
return string(s[i:len(s)]);
}
// leadingZeroBits returns the number of leading zero bits in x.
func leadingZeroBits(x Word) int {
c := 0;
if x < 1 << (_W/2) {
x <<= _W/2;
c = int(_W/2);
}
for i := 0; x != 0; i++ {
if x&(1<<(_W-1)) != 0 {
return i + c;
}
x <<= 1;
}
return int(_W);
}
func shiftLeft(dst, src []Word, n int) {
if len(src) == 0 {
return;
}
ñ := uint(_W) - uint(n);
for i := len(src)-1; i >= 1; i-- {
dst[i] = src[i]<<uint(n);
dst[i] |= src[i-1]>>ñ;
}
dst[0] = src[0]<<uint(n);
}
func shiftRight(dst, src []Word, n int) {
if len(src) == 0 {
return;
}
ñ := uint(_W) - uint(n);
for i := 0; i < len(src)-1; i++ {
dst[i] = src[i]>>uint(n);
dst[i] |= src[i+1]<<ñ;
}
dst[len(src)-1] = src[len(src)-1]>>uint(n);
}
// greaterThan returns true iff (x1<<_W + x2) > (y1<<_W + y2)
func greaterThan(x1, x2, y1, y2 Word) bool {
return x1 > y1 || x1 == y1 && x2 > y2;
}

View File

@ -7,7 +7,7 @@ package big
import "testing"
func TestCmpNN(t *testing.T) {
// TODO(gri) write this test - all other tests depends on it
// TODO(gri) write this test - all other tests depends on it
}
@ -16,6 +16,7 @@ type argNN struct {
z, x, y []Word;
}
var sumNN = []argNN{
argNN{},
argNN{[]Word{1}, nil, []Word{1}},
@ -25,6 +26,7 @@ var sumNN = []argNN{
argNN{[]Word{0, 0, 0, 1}, []Word{0, 0, _M}, []Word{0, 0, 1}},
}
var prodNN = []argNN{
argNN{},
argNN{nil, nil, nil},
@ -86,6 +88,7 @@ type strN struct {
s string;
}
var tabN = []strN{
strN{nil, 10, "0"},
strN{[]Word{1}, 10, "1"},
@ -93,6 +96,7 @@ var tabN = []strN{
strN{[]Word{1234567890}, 10, "1234567890"},
}
func TestStringN(t *testing.T) {
for _, a := range tabN {
s := stringN(a.x, a.b);
@ -112,3 +116,72 @@ func TestStringN(t *testing.T) {
}
}
}
func TestLeadingZeroBits(t *testing.T) {
var x Word = 1<<(_W-1);
for i := 0; i <= int(_W); i++ {
if leadingZeroBits(x) != i {
t.Errorf("failed at %x: got %d want %d", x, leadingZeroBits(x), i);
}
x >>= 1;
}
}
type shiftTest struct {
in []Word;
shift int;
out []Word;
}
var leftShiftTests = []shiftTest{
shiftTest{nil, 0, nil},
shiftTest{nil, 1, nil},
shiftTest{[]Word{0}, 0, []Word{0}},
shiftTest{[]Word{1}, 0, []Word{1}},
shiftTest{[]Word{1}, 1, []Word{2}},
shiftTest{[]Word{1<<(_W-1)}, 1, []Word{0}},
shiftTest{[]Word{1<<(_W-1), 0}, 1, []Word{0, 1}},
}
func TestShiftLeft(t *testing.T) {
for i, test := range leftShiftTests {
dst := make([]Word, len(test.out));
shiftLeft(dst, test.in, test.shift);
for j, v := range dst {
if test.out[j] != v {
t.Errorf("#%d: got: %v want: %v", i, dst, test.out);
break;
}
}
}
}
var rightShiftTests = []shiftTest{
shiftTest{nil, 0, nil},
shiftTest{nil, 1, nil},
shiftTest{[]Word{0}, 0, []Word{0}},
shiftTest{[]Word{1}, 0, []Word{1}},
shiftTest{[]Word{1}, 1, []Word{0}},
shiftTest{[]Word{2}, 1, []Word{1}},
shiftTest{[]Word{0, 1}, 1, []Word{1<<(_W-1), 0}},
shiftTest{[]Word{2, 1, 1}, 1, []Word{1<<(_W-1) + 1, 1<<(_W-1), 0}},
}
func TestShiftRight(t *testing.T) {
for i, test := range rightShiftTests {
dst := make([]Word, len(test.out));
shiftRight(dst, test.in, test.shift);
for j, v := range dst {
if test.out[j] != v {
t.Errorf("#%d: got: %v want: %v", i, dst, test.out);
break;
}
}
}
}