aboutsummaryrefslogtreecommitdiff
path: root/generate_parameters.py
diff options
context:
space:
mode:
Diffstat (limited to 'generate_parameters.py')
-rw-r--r--generate_parameters.py372
1 files changed, 0 insertions, 372 deletions
diff --git a/generate_parameters.py b/generate_parameters.py
deleted file mode 100644
index 60ee5bce0..000000000
--- a/generate_parameters.py
+++ /dev/null
@@ -1,372 +0,0 @@
-
-'''
-EXAMPLES (handwritten):
-
-
-# p256 - amd128
-{
- "modulus" : "2^256-2^224+2^192+2^96-1",
- "base" : "128",
- "sz" : "2",
- "bitwidth" : "128",
- "montgomery" : "true",
- "operations" : ["fenz", "feadd", "femul", "feopp", "fesub"],
- "compiler" : "gcc -fno-peephole2 `#GCC BUG 81300` -march=native -mbmi2 -mtune=native -std=gnu11 -O3 -flto -fomit-frame-pointer -fwrapv -Wno-attributes -Wno-incompatible-pointer-types -fno-strict-aliasing"
-}
-
-# p256 - amd64
-{
- "modulus" : "2^256-2^224+2^192+2^96-1",
- "base" : "64",
- "sz" : "4",
- "bitwidth" : "64",
- "montgomery" : "true",
- "operations" : ["fenz", "feadd", "femul", "feopp", "fesub"],
- "compiler" : "gcc -fno-peephole2 `#GCC BUG 81300` -march=native -mbmi2 -mtune=native -std=gnu11 -O3 -flto -fomit-frame-pointer -fwrapv -Wno-attributes -Wno-incompatible-pointer-types -fno-strict-aliasing"
-}
-
-
-# p448 - c64
-{
- "modulus" : "2^448-2^224-1",
- "base" : "56",
- "goldilocks" : "true",
- "sz" : "8",
- "bitwidth" : "64",
- "carry_chains" : [[3, 7],
- [0, 4, 1, 5, 2, 6, 3, 7],
- [4, 0]],
- "coef_div_modulus" : "2",
- "operations" : ["femul"]
-}
-
-# curve25519 - c64
-{
- "modulus" : "2^255-19",
- "base" : "51",
- "sz" : "5",
- "bitwidth" : "64",
- "carry_chains" : "default",
- "coef_div_modulus" : "2",
- "operations" : ["femul", "fesquare", "freeze"],
- "compiler" : "gcc -march=native -mbmi2 -mtune=native -std=gnu11 -O3 -flto -fomit-frame-pointer -fwrapv -Wno-attributes",
-}
-
-# curve25519 - c32
-{
- "modulus" : "2^255-19",
- "base" : "25.5",
- "sz" : "10",
- "bitwidth" : "32",
- "carry_chains" : "default",
- "coef_div_modulus" : "2",
- "operations" : ["femul", "fesquare", "freeze"],
- "compiler" : "gcc -march=native -mbmi2 -mtune=native -std=gnu11 -O3 -flto -fomit-frame-pointer -fwrapv -Wno-attributes",
-}
-
-'''
-
-import math,json,sys,os,traceback,re,textwrap
-from fractions import Fraction
-
-CC = "clang -fbracket-depth=999999 -march=native -mbmi2 -mtune=native -std=gnu11 -O3 -flto -fuse-ld=lld -fomit-frame-pointer -fwrapv -Wno-attributes -fno-strict-aliasing"
-CCX = "clang++ -fbracket-depth=999999 -march=native -mbmi2 -mtune=native -std=gnu++11 -O3 -flto -fuse-ld=lld -fomit-frame-pointer -fwrapv -Wno-attributes -fno-strict-aliasing"
-
-# for montgomery
-COMPILER_MONT = CC
-COMPILERXX_MONT = CCX
-# for solinas
-COMPILER_SOLI = CC
-COMPILERXX_SOLI = CCX
-CUR_PATH = os.path.dirname(os.path.realpath(__file__))
-JSON_DIRECTORY = os.path.join(CUR_PATH, "src/Specific/CurveParameters")
-REMAKE_CURVES = os.path.join(JSON_DIRECTORY, 'remake_curves.sh')
-
-class LimbPickingException(Exception): pass
-class NonBase2Exception(Exception): pass
-class UnexpectedPrimeException(Exception): pass
-
-# given a string representing one term or "tap" in a prime, returns a pair of
-# integers representing the weight and coefficient of that tap
-# "2 ^ y" -> [1, y]
-# "x * 2 ^ y" -> [x, y]
-# "x * y" -> [x*y,0]
-# "x" -> [x,0]
-def parse_term(t) :
- if "*" not in t and "^" not in t:
- return [int(t),0]
-
- if "*" in t:
- if len(t.split("*")) > 2: # this occurs when e.g. [w - x * y] has been turned into [w + -1 * x * y]
- a1,a2,b = t.split("*")
- a = int(a1) * int(a2)
- else:
- a,b = t.split("*")
- if "^" not in b:
- return [int(a) * int(b),0]
- else:
- a,b = (1,t)
-
- b,e = b.split("^")
- if int(b) != 2:
- raise NonBase2Exception("Could not parse term, power with base other than 2: %s" %t)
- return [int(a),int(e)]
-
-
-# expects prime to be a string and expressed as sum/difference of products of
-# two with small coefficients (e.g. '2^448 - 2^224 - 1', '2^255 - 19')
-def parse_prime(prime):
- prime = prime.replace("-", "+ -").replace(' ', '').replace('+-2^', '+-1*2^')
- terms = prime.split("+")
- return list(map(parse_term, terms))
-
-# check that the parsed prime makes sense
-def sanity_check(p):
- if not all([
- # are there at least 2 terms?
- len(p) > 1,
- # do all terms have 2 elements?
- all(map(lambda t:len(t) == 2, p)),
- # are terms are in order (most to least significant)?
- p == list(sorted(p,reverse=True,key=lambda t:t[1])),
- # does the least significant term have weight 2^0=1?
- p[-1][1] == 0,
- # are all the exponents positive and the coefficients nonzero?
- all(map(lambda t:t[0] != 0 and t[1] >= 0, p)),
- # is second-most-significant term negative?
- p[1][0] < 0,
- # are any exponents repeated?
- len(set(map(lambda t:t[1], p))) == len(p)]) :
- raise UnexpectedPrimeException("Parsed prime %s has unexpected format" %p)
-
-
-def eval_numexpr(numexpr):
- # copying from https://stackoverflow.com/a/25437733/377022
- numexpr = re.sub(r"\.(?![0-9])", "", numexpr) # purge any instance of '.' not followed by a number
- return eval(numexpr, {'__builtins__':None})
-
-def get_extra_compiler_params(q, base, bitwidth, sz):
- def log_wt(i):
- return int(math.ceil(sum(map(Fraction, map(str.strip, str(base).split('+')))) * i))
- q_int = eval_numexpr(q.replace('^', '**'))
- a24 = 12345 # TODO
- modulus_bytes = (q_int.bit_length()+7)//8
- limb_widths = repr('{%s}' % ','.join(str(int(log_wt(i + 1) - log_wt(i))) for i in range(sz)))
- defs = {
- 'q_mpz' : repr(re.sub(r'2(\s*)\^(\s*)([0-9]+)', r'(1_mpz\1<<\2\3)', str(q))),
- 'modulus_bytes_val' : repr(str(modulus_bytes)),
- 'modulus_array' : repr('{%s}' % ','.join(reversed(list('0x%02x' % ((q_int >> 8*i)&0xff) for i in range(modulus_bytes))))),
- 'a_minus_two_over_four_array' : repr('{%s}' % ','.join(reversed(list('0x%02x' % ((a24 >> 8*i)&0xff) for i in range(modulus_bytes))))),
- 'a24_val' : repr(str(a24)),
- 'a24_hex' : repr(hex(a24)),
- 'bitwidth' : repr(str(bitwidth)),
- 'modulus_limbs' : repr(str(sz)),
- 'limb_weight_gaps_array' : limb_widths
- }
- return ' ' + ' '.join('-D%s=%s' % (k, v) for k, v in sorted(defs.items()))
-
-def num_bits(p):
- return p[0][1]
-
-def get_params_montgomery(prime, bitwidth):
- p = parse_prime(prime)
- sanity_check(p)
- sz = int(math.ceil(num_bits(p) / float(bitwidth)))
- return [{
- "modulus" : prime,
- "base" : str(bitwidth),
- "sz" : str(sz),
- "montgomery" : True,
- "operations" : ["fenz", "feadd", "femul", "feopp", "fesub"],
- "extra_files" : ["montgomery%s/fesquare.c" % str(bitwidth)],
- "compiler" : COMPILER_MONT + get_extra_compiler_params(prime, bitwidth, bitwidth, sz),
- "compilerxx" : COMPILERXX_MONT + get_extra_compiler_params(prime, bitwidth, bitwidth, sz)
- }]
-
-def place(weight, nlimbs, wt):
- for i in range(nlimbs):
- if weight(i) <= wt and weight(i+1) > wt:
- return i
- return None
-
-def solinas_reduce(p, pprods):
- out = []
- for wt, x in pprods:
- if wt >= num_bits(p):
- for coef, exp in p[1:]:
- out.append((wt - num_bits(p) + exp, -coef * x))
- else:
- out.append((wt, x))
- return out
-
-# check if the suggested number of limbs will overflow when adding partial
-# products after a multiplication and then doing solinas reduction
-def overflow_free(p, bitwidth, nlimbs):
- # weight (exponent only)
- weight = lambda n : math.ceil(n * (num_bits(p) / nlimbs))
- # bit widths in canonical form
- width = lambda i : weight(i + 1) - weight(i)
-
- # num of bits in each term after 1 addition of things with bounds at 1.125 * width
- start = [(2**width(i))*1.125*2-1 for i in range(nlimbs)]
-
- # get partial products in (weight, # bits) pairs
- pp = [(weight(i) + weight(j), start[i] * start[j]) for i in range(nlimbs) for j in range(nlimbs)]
-
- # reduction step
- ppr = pp
- while max(ppr, key=lambda t:t[0])[0] >= num_bits(p):
- ppr = solinas_reduce(p, ppr)
-
- # accumulate partial products
- cols = [[] for _ in range(nlimbs)]
- for wt, x in ppr:
- i = place(weight, nlimbs, wt)
- if i == None:
- raise LimbPickingException("Could not place weight %s (%s limbs, p=%s)" %(wt, nlimbs, p))
- cols[i].append(x * (2**(wt - weight(i))))
-
- # add partial products together at each position
- final = [math.log2(sum(ls)) if sum(ls) > 0 else 0 for ls in cols]
- #print(nlimbs, list(map(lambda x: round(x,1), final)))
-
- result = all(map(lambda x:x < 2*bitwidth, final))
- return result
-
-# given a parsed prime, pick out all plausible numbers of (unsaturated) limbs
-def get_possible_limbs(p, bitwidth):
- # we want to leave enough bits unused to do a full solinas reduction
- # without carrying; the number of bits necessary is the sum of the bits in
- # the negative coefficients of p (other than the most significant digit)
- unused_bits = sum(map(lambda t: math.ceil(math.log(-t[0], 2)) if t[0] < 0 else 0, p[1:]))
- min_limbs = int(math.ceil(num_bits(p) / (bitwidth - unused_bits)))
-
- # don't search past 2x as many limbs as saturated representation; that's just wasteful
- result = list(filter(lambda n : overflow_free(p, bitwidth, n), range(min_limbs, 2*min_limbs)))
- # print("for prime %s, %s / %s limb choices were successful" %(p, len(result), min_limbs))
- return result
-
-def is_goldilocks(p):
- return p[0][1] == 2 * p[1][1]
-
-def format_base(numerator, denominator):
- if numerator % denominator == 0:
- base = int(numerator / denominator)
- else:
- base = Fraction(numerator=numerator, denominator=denominator)
- if base.denominator in (1, 2, 4, 5, 8, 10):
- base = float(base)
- else:
- base_int, base_frac = int(base), base - int(base)
- base = '%d + %s' % (base_int, str(base_frac))
- return base
-
-# removes latest occurences, preserves order
-def remove_duplicates(l):
- seen = []
- for x in l:
- if x not in seen:
- seen.append(x)
- return seen
-
-def get_params_solinas(prime, bitwidth):
- p = parse_prime(prime)
- sanity_check(p)
- out = []
- l = get_possible_limbs(p, bitwidth)
- if len(l) == 0:
- raise LimbPickingException("Could not find a good number of limbs for prime %s and bitwidth %s" %(prime, bitwidth))
- # only use the top 2 choices
- for sz in l[:2]:
- base = format_base(num_bits(p), sz)
-
- # Uncomment to pretty-print primes/bases
- # print(" ".join(map(str, [prime, " "*(35-len(prime)), bitwidth, base, sz])))
-
- if len(p) > 2:
- # do interleaved carry chains, starting at where the taps are
- starts = [(int(t[1] / (num_bits(p) / sz)) - 1) % sz for t in p[1:]]
- chain2 = []
- for n in range(1,sz):
- for j in starts:
- chain2.append((j + n) % sz)
- chain2 = remove_duplicates(chain2)
- chain3 = list(map(lambda x:(x+1)%sz,starts))
- carry_chains = [starts,chain2,chain3]
- else:
- carry_chains = "default"
- params = {
- "modulus": prime,
- "base" : str(base),
- "sz" : str(sz),
- "bitwidth" : bitwidth,
- "carry_chains" : carry_chains,
- "coef_div_modulus" : str(2),
- "operations" : ["femul", "feadd", "fesub", "fesquare", "fecarry", "freeze"],
- "compiler" : COMPILER_SOLI + get_extra_compiler_params(prime, base, bitwidth, sz),
- "compilerxx" : COMPILERXX_SOLI + get_extra_compiler_params(prime, base, bitwidth, sz)
- }
- if is_goldilocks(p):
- params["goldilocks"] = True
- out.append(params)
- return out
-
-def write_if_changed(filename, contents):
- if os.path.isfile(filename):
- with open(filename, 'r') as f:
- old = f.read()
- if old == contents: return
- with open(filename, 'w') as f:
- f.write(contents)
-
-def update_remake_curves(filename):
- with open(REMAKE_CURVES, 'r') as f:
- lines = f.readlines()
- new_line = '${MAKE} "$@" %s ../%s/\n' % (filename, filename[:-len('.json')])
- if new_line in lines: return
- if any(filename in line for line in lines):
- lines = [(line if filename not in line else new_line)
- for line in lines]
- else:
- lines.append(new_line)
- write_if_changed(REMAKE_CURVES, ''.join(lines))
-
-def format_json(params):
- return json.dumps(params, indent=4, separators=(',', ': '), sort_keys=True) + '\n'
-
-
-def write_output(name, params):
- prime = params["modulus"]
- nlimbs = params["sz"]
- filename = (name + "_" + prime + "_" + nlimbs + "limbs" + ".json").replace("^","e").replace(" ","").replace("-","m").replace("+","p").replace("*","x")
-
- write_if_changed(os.path.join(JSON_DIRECTORY, filename),
- format_json(params))
- update_remake_curves(filename)
-
-def try_write_output(name, get_params, prime, bitwidth):
- try:
- all_params = get_params(prime, bitwidth)
- for params in all_params:
- write_output(name, params)
- except (LimbPickingException, NonBase2Exception, UnexpectedPrimeException) as e:
- print(e)
- except Exception as e:
- traceback.print_exc()
-
-USAGE = "python generate_parameters.py input_file"
-if __name__ == "__main__":
- if len(sys.argv) < 2:
- print(USAGE)
- sys.exit()
- f = open(sys.argv[1])
- for line in f:
- # skip comments and empty lines
- if line.strip().startswith("#") or len(line.strip()) == 0:
- continue
- prime = line.split("#")[0].strip() # remove trailing comments and trailing/leading whitespace
- try_write_output("montgomery32", get_params_montgomery, prime, 32)
- try_write_output("montgomery64", get_params_montgomery, prime, 64)
- try_write_output("solinas32", get_params_solinas, prime, 32)
- try_write_output("solinas64", get_params_solinas, prime, 64)
- f.close()