From 5dd6d684b83d4f01fee033bc89a1edc5ec74e3fb Mon Sep 17 00:00:00 2001 From: jadep Date: Fri, 10 Nov 2017 11:58:03 -0500 Subject: changes to parameter-generation script --- generate_parameters.py | 155 ++++++++++++++++++++++++++++++++----------------- 1 file changed, 103 insertions(+), 52 deletions(-) (limited to 'generate_parameters.py') diff --git a/generate_parameters.py b/generate_parameters.py index 1cd031424..770d91b66 100644 --- a/generate_parameters.py +++ b/generate_parameters.py @@ -172,7 +172,7 @@ def get_params_montgomery(prime, bitwidth): p = parse_prime(prime) sanity_check(p) sz = int(math.ceil(num_bits(p) / float(bitwidth))) - return { + return [{ "modulus" : prime, "base" : str(bitwidth), "sz" : str(sz), @@ -181,28 +181,70 @@ def get_params_montgomery(prime, bitwidth): "extra_files" : ["montgomery%s/fesquare.c" % str(bitwidth)], "compiler" : COMPILER_MONT + get_extra_compiler_params(prime, bitwidth, bitwidth, sz), "compilerxx" : COMPILERXX_MONT + get_extra_compiler_params(prime, bitwidth, bitwidth, sz) - } - -# given a parsed prime, pick a number of (unsaturated) limbs -def get_num_limbs(p, bitwidth): + }] + +def place(weight, nlimbs, wt): + for i in range(nlimbs): + if weight(i) <= wt and weight(i+1) > wt: + return i + return None + +def solinas_reduce(p, pprods): + out = [] + for wt, x in pprods: + if wt >= num_bits(p): + for coef, exp in p[1:]: + out.append((wt - num_bits(p) + exp, -coef * x)) + else: + out.append((wt, x)) + return out + +# check if the suggested number of limbs will overflow when adding partial +# products after a multiplication and then doing solinas reduction +def overflow_free(p, bitwidth, nlimbs): + # weight (exponent only) + weight = lambda n : math.ceil(n * (num_bits(p) / nlimbs)) + # bit widths in canonical form + width = lambda i : weight(i + 1) - weight(i) + + # num of bits in each term after 1 addition of things with bounds at 1.125 * width + start = [(2**width(i))*1.125*2-1 for i in range(nlimbs)] + + # get partial products in (weight, # bits) pairs + pp = [(weight(i) + weight(j), start[i] * start[j]) for i in range(nlimbs) for j in range(nlimbs)] + + # reduction step + ppr = pp + while max(ppr, key=lambda t:t[0])[0] >= num_bits(p): + ppr = solinas_reduce(p, ppr) + + # accumulate partial products + cols = [[] for _ in range(nlimbs)] + for wt, x in ppr: + i = place(weight, nlimbs, wt) + if i == None: + raise LimbPickingException("Could not place weight %s (%s limbs, p=%s)" %(wt, nlimbs, p)) + cols[i].append(x * (2**(wt - weight(i)))) + + # add partial products together at each position + final = [math.log2(sum(ls)) if sum(ls) > 0 else 0 for ls in cols] + #print(nlimbs, list(map(lambda x: round(x,1), final))) + + result = all(map(lambda x:x < 2*bitwidth, final)) + return result + +# given a parsed prime, pick out all plausible numbers of (unsaturated) limbs +def get_possible_limbs(p, bitwidth): # we want to leave enough bits unused to do a full solinas reduction # without carrying; the number of bits necessary is the sum of the bits in # the negative coefficients of p (other than the most significant digit) unused_bits = sum(map(lambda t: math.ceil(math.log(-t[0], 2)) if t[0] < 0 else 0, p[1:])) - # print(p,unused_bits) min_limbs = int(math.ceil(num_bits(p) / (bitwidth - unused_bits))) - choices = [] - for n in range(min_limbs, 2 * min_limbs): # don't search past 2x as many limbs as saturated representation; that's just wasteful - # check that the number of 'extra' bits needed fits in this number of limbs - min_bits = int(num_bits(p) / n) - extra = num_bits(p) % n - if (extra == 0 or n % extra == 0) and min_bits + 1 < bitwidth: - choices.append((n, num_bits(p) / n)) - break - if len(choices) == 0: - raise LimbPickingException("Unable to pick a number of limbs for prime %s and bitwidth %s in range %s-%s limbs" %(p,bitwidth,min_limbs,5*min_limbs)) - # print (p,choices,min_limbs) - return choices[0][0] + + # don't search past 2x as many limbs as saturated representation; that's just wasteful + result = list(filter(lambda n : overflow_free(p, bitwidth, n), range(min_limbs, 2*min_limbs))) + # print("for prime %s, %s / %s limb choices were successful" %(p, len(result), min_limbs)) + return result def is_goldilocks(p): return p[0][1] == 2 * p[1][1] @@ -230,38 +272,44 @@ def remove_duplicates(l): def get_params_solinas(prime, bitwidth): p = parse_prime(prime) sanity_check(p) - sz = get_num_limbs(p, bitwidth) - base = format_base(num_bits(p), sz) - - # Uncomment to pretty-print primes/bases - # print(" ".join(map(str, [prime, " "*(35-len(prime)), bitwidth, round(base,1), sz]))) - - if len(p) > 2: - # do interleaved carry chains, starting at where the taps are - starts = [(int(t[1] / (num_bits(p) / sz)) - 1) % sz for t in p[1:]] - chain2 = [] - for n in range(1,sz): - for j in starts: - chain2.append((j + n) % sz) - chain2 = remove_duplicates(chain2) - chain3 = list(map(lambda x:(x+1)%sz,starts)) - carry_chains = [starts,chain2,chain3] - else: - carry_chains = "default" - output = { - "modulus": prime, - "base" : str(base), - "sz" : str(sz), - "bitwidth" : bitwidth, - "carry_chains" : carry_chains, - "coef_div_modulus" : str(2), - "operations" : ["femul", "feadd", "fesub", "fesquare", "freeze"], - "compiler" : COMPILER_SOLI + get_extra_compiler_params(prime, base, bitwidth, sz), - "compilerxx" : COMPILERXX_SOLI + get_extra_compiler_params(prime, base, bitwidth, sz) - } - if is_goldilocks(p): - output["goldilocks"] = True - return output + out = [] + l = get_possible_limbs(p, bitwidth) + if len(l) == 0: + raise LimbPickingException("Could not find a good number of limbs for prime %s and bitwidth %s" %(prime, bitwidth)) + # only use the top 2 choices + for sz in l[:2]: + base = format_base(num_bits(p), sz) + + # Uncomment to pretty-print primes/bases + # print(" ".join(map(str, [prime, " "*(35-len(prime)), bitwidth, base, sz]))) + + if len(p) > 2: + # do interleaved carry chains, starting at where the taps are + starts = [(int(t[1] / (num_bits(p) / sz)) - 1) % sz for t in p[1:]] + chain2 = [] + for n in range(1,sz): + for j in starts: + chain2.append((j + n) % sz) + chain2 = remove_duplicates(chain2) + chain3 = list(map(lambda x:(x+1)%sz,starts)) + carry_chains = [starts,chain2,chain3] + else: + carry_chains = "default" + params = { + "modulus": prime, + "base" : str(base), + "sz" : str(sz), + "bitwidth" : bitwidth, + "carry_chains" : carry_chains, + "coef_div_modulus" : str(2), + "operations" : ["femul", "feadd", "fesub", "fesquare", "freeze"], + "compiler" : COMPILER_SOLI + get_extra_compiler_params(prime, base, bitwidth, sz), + "compilerxx" : COMPILERXX_SOLI + get_extra_compiler_params(prime, base, bitwidth, sz) + } + if is_goldilocks(p): + params["goldilocks"] = True + out.append(params) + return out def write_if_changed(filename, contents): if os.path.isfile(filename): @@ -289,7 +337,8 @@ def format_json(params): def write_output(name, params): prime = params["modulus"] - filename = (name + "_" + prime + ".json").replace("^","e").replace(" ","").replace("-","m").replace("+","p").replace("*","x") + nlimbs = params["sz"] + filename = (name + "_" + prime + "_" + nlimbs + "limbs" + ".json").replace("^","e").replace(" ","").replace("-","m").replace("+","p").replace("*","x") write_if_changed(os.path.join(JSON_DIRECTORY, filename), format_json(params)) @@ -297,7 +346,9 @@ def write_output(name, params): def try_write_output(name, get_params, prime, bitwidth): try: - write_output(name, get_params(prime, bitwidth)) + all_params = get_params(prime, bitwidth) + for params in all_params: + write_output(name, params) except (LimbPickingException, NonBase2Exception, UnexpectedPrimeException) as e: print(e) except Exception as e: -- cgit v1.2.3