aboutsummaryrefslogtreecommitdiff
path: root/generate_parameters.py
diff options
context:
space:
mode:
authorGravatar jadep <jade.philipoom@gmail.com>2017-11-10 11:58:03 -0500
committerGravatar jadep <jade.philipoom@gmail.com>2017-11-12 14:46:30 -0500
commit5dd6d684b83d4f01fee033bc89a1edc5ec74e3fb (patch)
tree679ef30b3507b4518fd2c97d5232042710635a8b /generate_parameters.py
parentb104d505f330e4330a23f86890612d9039500462 (diff)
changes to parameter-generation script
Diffstat (limited to 'generate_parameters.py')
-rw-r--r--generate_parameters.py155
1 files changed, 103 insertions, 52 deletions
diff --git a/generate_parameters.py b/generate_parameters.py
index 1cd031424..770d91b66 100644
--- a/generate_parameters.py
+++ b/generate_parameters.py
@@ -172,7 +172,7 @@ def get_params_montgomery(prime, bitwidth):
p = parse_prime(prime)
sanity_check(p)
sz = int(math.ceil(num_bits(p) / float(bitwidth)))
- return {
+ return [{
"modulus" : prime,
"base" : str(bitwidth),
"sz" : str(sz),
@@ -181,28 +181,70 @@ def get_params_montgomery(prime, bitwidth):
"extra_files" : ["montgomery%s/fesquare.c" % str(bitwidth)],
"compiler" : COMPILER_MONT + get_extra_compiler_params(prime, bitwidth, bitwidth, sz),
"compilerxx" : COMPILERXX_MONT + get_extra_compiler_params(prime, bitwidth, bitwidth, sz)
- }
-
-# given a parsed prime, pick a number of (unsaturated) limbs
-def get_num_limbs(p, bitwidth):
+ }]
+
+def place(weight, nlimbs, wt):
+ for i in range(nlimbs):
+ if weight(i) <= wt and weight(i+1) > wt:
+ return i
+ return None
+
+def solinas_reduce(p, pprods):
+ out = []
+ for wt, x in pprods:
+ if wt >= num_bits(p):
+ for coef, exp in p[1:]:
+ out.append((wt - num_bits(p) + exp, -coef * x))
+ else:
+ out.append((wt, x))
+ return out
+
+# check if the suggested number of limbs will overflow when adding partial
+# products after a multiplication and then doing solinas reduction
+def overflow_free(p, bitwidth, nlimbs):
+ # weight (exponent only)
+ weight = lambda n : math.ceil(n * (num_bits(p) / nlimbs))
+ # bit widths in canonical form
+ width = lambda i : weight(i + 1) - weight(i)
+
+ # num of bits in each term after 1 addition of things with bounds at 1.125 * width
+ start = [(2**width(i))*1.125*2-1 for i in range(nlimbs)]
+
+ # get partial products in (weight, # bits) pairs
+ pp = [(weight(i) + weight(j), start[i] * start[j]) for i in range(nlimbs) for j in range(nlimbs)]
+
+ # reduction step
+ ppr = pp
+ while max(ppr, key=lambda t:t[0])[0] >= num_bits(p):
+ ppr = solinas_reduce(p, ppr)
+
+ # accumulate partial products
+ cols = [[] for _ in range(nlimbs)]
+ for wt, x in ppr:
+ i = place(weight, nlimbs, wt)
+ if i == None:
+ raise LimbPickingException("Could not place weight %s (%s limbs, p=%s)" %(wt, nlimbs, p))
+ cols[i].append(x * (2**(wt - weight(i))))
+
+ # add partial products together at each position
+ final = [math.log2(sum(ls)) if sum(ls) > 0 else 0 for ls in cols]
+ #print(nlimbs, list(map(lambda x: round(x,1), final)))
+
+ result = all(map(lambda x:x < 2*bitwidth, final))
+ return result
+
+# given a parsed prime, pick out all plausible numbers of (unsaturated) limbs
+def get_possible_limbs(p, bitwidth):
# we want to leave enough bits unused to do a full solinas reduction
# without carrying; the number of bits necessary is the sum of the bits in
# the negative coefficients of p (other than the most significant digit)
unused_bits = sum(map(lambda t: math.ceil(math.log(-t[0], 2)) if t[0] < 0 else 0, p[1:]))
- # print(p,unused_bits)
min_limbs = int(math.ceil(num_bits(p) / (bitwidth - unused_bits)))
- choices = []
- for n in range(min_limbs, 2 * min_limbs): # don't search past 2x as many limbs as saturated representation; that's just wasteful
- # check that the number of 'extra' bits needed fits in this number of limbs
- min_bits = int(num_bits(p) / n)
- extra = num_bits(p) % n
- if (extra == 0 or n % extra == 0) and min_bits + 1 < bitwidth:
- choices.append((n, num_bits(p) / n))
- break
- if len(choices) == 0:
- raise LimbPickingException("Unable to pick a number of limbs for prime %s and bitwidth %s in range %s-%s limbs" %(p,bitwidth,min_limbs,5*min_limbs))
- # print (p,choices,min_limbs)
- return choices[0][0]
+
+ # don't search past 2x as many limbs as saturated representation; that's just wasteful
+ result = list(filter(lambda n : overflow_free(p, bitwidth, n), range(min_limbs, 2*min_limbs)))
+ # print("for prime %s, %s / %s limb choices were successful" %(p, len(result), min_limbs))
+ return result
def is_goldilocks(p):
return p[0][1] == 2 * p[1][1]
@@ -230,38 +272,44 @@ def remove_duplicates(l):
def get_params_solinas(prime, bitwidth):
p = parse_prime(prime)
sanity_check(p)
- sz = get_num_limbs(p, bitwidth)
- base = format_base(num_bits(p), sz)
-
- # Uncomment to pretty-print primes/bases
- # print(" ".join(map(str, [prime, " "*(35-len(prime)), bitwidth, round(base,1), sz])))
-
- if len(p) > 2:
- # do interleaved carry chains, starting at where the taps are
- starts = [(int(t[1] / (num_bits(p) / sz)) - 1) % sz for t in p[1:]]
- chain2 = []
- for n in range(1,sz):
- for j in starts:
- chain2.append((j + n) % sz)
- chain2 = remove_duplicates(chain2)
- chain3 = list(map(lambda x:(x+1)%sz,starts))
- carry_chains = [starts,chain2,chain3]
- else:
- carry_chains = "default"
- output = {
- "modulus": prime,
- "base" : str(base),
- "sz" : str(sz),
- "bitwidth" : bitwidth,
- "carry_chains" : carry_chains,
- "coef_div_modulus" : str(2),
- "operations" : ["femul", "feadd", "fesub", "fesquare", "freeze"],
- "compiler" : COMPILER_SOLI + get_extra_compiler_params(prime, base, bitwidth, sz),
- "compilerxx" : COMPILERXX_SOLI + get_extra_compiler_params(prime, base, bitwidth, sz)
- }
- if is_goldilocks(p):
- output["goldilocks"] = True
- return output
+ out = []
+ l = get_possible_limbs(p, bitwidth)
+ if len(l) == 0:
+ raise LimbPickingException("Could not find a good number of limbs for prime %s and bitwidth %s" %(prime, bitwidth))
+ # only use the top 2 choices
+ for sz in l[:2]:
+ base = format_base(num_bits(p), sz)
+
+ # Uncomment to pretty-print primes/bases
+ # print(" ".join(map(str, [prime, " "*(35-len(prime)), bitwidth, base, sz])))
+
+ if len(p) > 2:
+ # do interleaved carry chains, starting at where the taps are
+ starts = [(int(t[1] / (num_bits(p) / sz)) - 1) % sz for t in p[1:]]
+ chain2 = []
+ for n in range(1,sz):
+ for j in starts:
+ chain2.append((j + n) % sz)
+ chain2 = remove_duplicates(chain2)
+ chain3 = list(map(lambda x:(x+1)%sz,starts))
+ carry_chains = [starts,chain2,chain3]
+ else:
+ carry_chains = "default"
+ params = {
+ "modulus": prime,
+ "base" : str(base),
+ "sz" : str(sz),
+ "bitwidth" : bitwidth,
+ "carry_chains" : carry_chains,
+ "coef_div_modulus" : str(2),
+ "operations" : ["femul", "feadd", "fesub", "fesquare", "freeze"],
+ "compiler" : COMPILER_SOLI + get_extra_compiler_params(prime, base, bitwidth, sz),
+ "compilerxx" : COMPILERXX_SOLI + get_extra_compiler_params(prime, base, bitwidth, sz)
+ }
+ if is_goldilocks(p):
+ params["goldilocks"] = True
+ out.append(params)
+ return out
def write_if_changed(filename, contents):
if os.path.isfile(filename):
@@ -289,7 +337,8 @@ def format_json(params):
def write_output(name, params):
prime = params["modulus"]
- filename = (name + "_" + prime + ".json").replace("^","e").replace(" ","").replace("-","m").replace("+","p").replace("*","x")
+ nlimbs = params["sz"]
+ filename = (name + "_" + prime + "_" + nlimbs + "limbs" + ".json").replace("^","e").replace(" ","").replace("-","m").replace("+","p").replace("*","x")
write_if_changed(os.path.join(JSON_DIRECTORY, filename),
format_json(params))
@@ -297,7 +346,9 @@ def write_output(name, params):
def try_write_output(name, get_params, prime, bitwidth):
try:
- write_output(name, get_params(prime, bitwidth))
+ all_params = get_params(prime, bitwidth)
+ for params in all_params:
+ write_output(name, params)
except (LimbPickingException, NonBase2Exception, UnexpectedPrimeException) as e:
print(e)
except Exception as e: