gtsocial-umbx

Unnamed repository; edit this file 'description' to name the repository.
Log | Files | Refs | README | LICENSE

fermat.go (4081B)


      1 package bigfft
      2 
      3 import (
      4 	"math/big"
      5 )
      6 
      7 // Arithmetic modulo 2^n+1.
      8 
// A fermat of length w+1 represents a number modulo 2^(w*_W) + 1. The last
// word is zero or one. A number has at most two representatives satisfying the
// 0-1 last word constraint.
//
// Words are stored least significant first: for a fermat z of length n+1,
// z[:n] holds the low n words and z[n] is the top word, whose weight
// 2^(n*_W) is congruent to -1 modulo 2^(n*_W) + 1.
type fermat nat
     13 
     14 func (n fermat) String() string { return nat(n).String() }
     15 
     16 func (z fermat) norm() {
     17 	n := len(z) - 1
     18 	c := z[n]
     19 	if c == 0 {
     20 		return
     21 	}
     22 	if z[0] >= c {
     23 		z[n] = 0
     24 		z[0] -= c
     25 		return
     26 	}
     27 	// z[0] < z[n].
     28 	subVW(z, z, c) // Substract c
     29 	if c > 1 {
     30 		z[n] -= c - 1
     31 		c = 1
     32 	}
     33 	// Add back c.
     34 	if z[n] == 1 {
     35 		z[n] = 0
     36 		return
     37 	} else {
     38 		addVW(z, z, 1)
     39 	}
     40 }
     41 
     42 // Shift computes (x << k) mod (2^n+1).
     43 func (z fermat) Shift(x fermat, k int) {
     44 	if len(z) != len(x) {
     45 		panic("len(z) != len(x) in Shift")
     46 	}
     47 	n := len(x) - 1
     48 	// Shift by n*_W is taking the opposite.
     49 	k %= 2 * n * _W
     50 	if k < 0 {
     51 		k += 2 * n * _W
     52 	}
     53 	neg := false
     54 	if k >= n*_W {
     55 		k -= n * _W
     56 		neg = true
     57 	}
     58 
     59 	kw, kb := k/_W, k%_W
     60 
     61 	z[n] = 1 // Add (-1)
     62 	if !neg {
     63 		for i := 0; i < kw; i++ {
     64 			z[i] = 0
     65 		}
     66 		// Shift left by kw words.
     67 		// x = a·2^(n-k) + b
     68 		// x<<k = (b<<k) - a
     69 		copy(z[kw:], x[:n-kw])
     70 		b := subVV(z[:kw+1], z[:kw+1], x[n-kw:])
     71 		if z[kw+1] > 0 {
     72 			z[kw+1] -= b
     73 		} else {
     74 			subVW(z[kw+1:], z[kw+1:], b)
     75 		}
     76 	} else {
     77 		for i := kw + 1; i < n; i++ {
     78 			z[i] = 0
     79 		}
     80 		// Shift left and negate, by kw words.
     81 		copy(z[:kw+1], x[n-kw:n+1])            // z_low = x_high
     82 		b := subVV(z[kw:n], z[kw:n], x[:n-kw]) // z_high -= x_low
     83 		z[n] -= b
     84 	}
     85 	// Add back 1.
     86 	if z[n] > 0 {
     87 		z[n]--
     88 	} else if z[0] < ^big.Word(0) {
     89 		z[0]++
     90 	} else {
     91 		addVW(z, z, 1)
     92 	}
     93 	// Shift left by kb bits
     94 	shlVU(z, z, uint(kb))
     95 	z.norm()
     96 }
     97 
     98 // ShiftHalf shifts x by k/2 bits the left. Shifting by 1/2 bit
     99 // is multiplication by sqrt(2) mod 2^n+1 which is 2^(3n/4) - 2^(n/4).
    100 // A temporary buffer must be provided in tmp.
    101 func (z fermat) ShiftHalf(x fermat, k int, tmp fermat) {
    102 	n := len(z) - 1
    103 	if k%2 == 0 {
    104 		z.Shift(x, k/2)
    105 		return
    106 	}
    107 	u := (k - 1) / 2
    108 	a := u + (3*_W/4)*n
    109 	b := u + (_W/4)*n
    110 	z.Shift(x, a)
    111 	tmp.Shift(x, b)
    112 	z.Sub(z, tmp)
    113 }
    114 
    115 // Add computes addition mod 2^n+1.
    116 func (z fermat) Add(x, y fermat) fermat {
    117 	if len(z) != len(x) {
    118 		panic("Add: len(z) != len(x)")
    119 	}
    120 	addVV(z, x, y) // there cannot be a carry here.
    121 	z.norm()
    122 	return z
    123 }
    124 
    125 // Sub computes substraction mod 2^n+1.
    126 func (z fermat) Sub(x, y fermat) fermat {
    127 	if len(z) != len(x) {
    128 		panic("Add: len(z) != len(x)")
    129 	}
    130 	n := len(y) - 1
    131 	b := subVV(z[:n], x[:n], y[:n])
    132 	b += y[n]
    133 	// If b > 0, we need to subtract b<<n, which is the same as adding b.
    134 	z[n] = x[n]
    135 	if z[0] <= ^big.Word(0)-b {
    136 		z[0] += b
    137 	} else {
    138 		addVW(z, z, b)
    139 	}
    140 	z.norm()
    141 	return z
    142 }
    143 
    144 func (z fermat) Mul(x, y fermat) fermat {
    145 	if len(x) != len(y) {
    146 		panic("Mul: len(x) != len(y)")
    147 	}
    148 	n := len(x) - 1
    149 	if n < 30 {
    150 		z = z[:2*n+2]
    151 		basicMul(z, x, y)
    152 		z = z[:2*n+1]
    153 	} else {
    154 		var xi, yi, zi big.Int
    155 		xi.SetBits(x)
    156 		yi.SetBits(y)
    157 		zi.SetBits(z)
    158 		zb := zi.Mul(&xi, &yi).Bits()
    159 		if len(zb) <= n {
    160 			// Short product.
    161 			copy(z, zb)
    162 			for i := len(zb); i < len(z); i++ {
    163 				z[i] = 0
    164 			}
    165 			return z
    166 		}
    167 		z = zb
    168 	}
    169 	// len(z) is at most 2n+1.
    170 	if len(z) > 2*n+1 {
    171 		panic("len(z) > 2n+1")
    172 	}
    173 	// We now have
    174 	// z = z[:n] + 1<<(n*W) * z[n:2n+1]
    175 	// which normalizes to:
    176 	// z = z[:n] - z[n:2n] + z[2n]
    177 	c1 := big.Word(0)
    178 	if len(z) > 2*n {
    179 		c1 = addVW(z[:n], z[:n], z[2*n])
    180 	}
    181 	c2 := big.Word(0)
    182 	if len(z) >= 2*n {
    183 		c2 = subVV(z[:n], z[:n], z[n:2*n])
    184 	} else {
    185 		m := len(z) - n
    186 		c2 = subVV(z[:m], z[:m], z[n:])
    187 		c2 = subVW(z[m:n], z[m:n], c2)
    188 	}
    189 	// Restore carries.
    190 	// Substracting z[n] -= c2 is the same
    191 	// as z[0] += c2
    192 	z = z[:n+1]
    193 	z[n] = c1
    194 	c := addVW(z, z, c2)
    195 	if c != 0 {
    196 		panic("impossible")
    197 	}
    198 	z.norm()
    199 	return z
    200 }
    201 
    202 // copied from math/big
    203 //
    204 // basicMul multiplies x and y and leaves the result in z.
    205 // The (non-normalized) result is placed in z[0 : len(x) + len(y)].
    206 func basicMul(z, x, y fermat) {
    207 	// initialize z
    208 	for i := 0; i < len(z); i++ {
    209 		z[i] = 0
    210 	}
    211 	for i, d := range y {
    212 		if d != 0 {
    213 			z[len(x)+i] = addMulVVW(z[i:i+len(x)], x, d)
    214 		}
    215 	}
    216 }