aboutsummaryrefslogtreecommitdiff
path: root/src/Specific/Framework/make_curve.py
diff options
context:
space:
mode:
authorGravatar Jason Gross <jgross@mit.edu>2017-10-06 01:50:59 -0400
committerGravatar Jason Gross <jasongross9@gmail.com>2017-10-18 23:01:29 -0400
commit28359fcb5be530da65d5049846927a84a880b919 (patch)
tree8f0d8b6fc8ea4f109a9540c35869fd1d2adf759e /src/Specific/Framework/make_curve.py
parenta3a6eb12e7652e40b573372217f0771368ad50cb (diff)
Build curve-specific files from json
The X25519 curves are now generated from `.json` files. This code only works in >= 8.7, because it makes use of the recently-merged-from-fiat `transparent_abstract` tactic to allow defining things in tactics without massive slowdown. The structure is as follows: 0. The module types and tactic definitions that set up the infrastructure live in `src/Specific/Framework/` 1. There are `.json` files in `src/Specific/CurveParameters/` that specify curve characteristics. A simple example is `x2555_130.json`, which is: ```json { "modulus" : "2^255-5", "base" : "130", "a24" : "121665 (* XXX TODO(andreser) FIXME? Is this right for this curve? *)", "sz" : "3", "bitwidth" : "128", "carry_chain1" : "default", "carry_chain2" : ["0", "1"], "coef_div_modulus" : "2", "operations" : ["ladderstep"] } ``` A more complicated example is `x25519_c64.json`: ```json { "modulus" : "2^255-19", "base" : "51", "a24" : "121665", "sz" : "5", "bitwidth" : "64", "carry_chain1" : "default", "carry_chain2" : ["0", "1"], "coef_div_modulus" : "2", "operations" : ["femul", "fesquare", "freeze", "ladderstep"], "extra_files" : ["X25519_C64/scalarmult.c"], "compiler" : "gcc -march=native -mtune=native -std=gnu11 -O3 -flto -fomit-frame-pointer -fwrapv -Wno-attributes", "mul_header" : "(* Micro-optimized form from curve25519-donna-c64 by Adam Langley (Google) and Daniel Bernstein. See <https://github.com/agl/curve25519-donna/blob/master/LICENSE.md>;. *)", "mul_code" : " uint128_t t[5]; limb r0,r1,r2,r3,r4,s0,s1,s2,s3,s4,c; r0 = in[0]; r1 = in[1]; r2 = in[2]; r3 = in[3]; r4 = in[4]; s0 = in2[0]; s1 = in2[1]; s2 = in2[2]; s3 = in2[3]; s4 = in2[4]; t[0] = ((uint128_t) r0) * s0; t[1] = ((uint128_t) r0) * s1 + ((uint128_t) r1) * s0; t[2] = ((uint128_t) r0) * s2 + ((uint128_t) r2) * s0 + ((uint128_t) r1) * s1; t[3] = ((uint128_t) r0) * s3 + ((uint128_t) r3) * s0 + ((uint128_t) r1) * s2 + ((uint128_t) r2) * s1; t[4] = ((uint128_t) r0) * s4 + ((uint128_t) r4) * s0 + ((uint128_t) r3) * s1 + ((uint128_t) r1) * s3 + ((uint128_t) r2) * s2; r4 *= 19; r1 *= 19; r2 *= 19; r3 *= 19; t[0] += ((uint128_t) r4) * s1 + ((uint128_t) r1) * s4 + ((uint128_t) r2) * s3 + ((uint128_t) r3) * s2; t[1] += ((uint128_t) r4) * s2 + ((uint128_t) r2) * s4 + ((uint128_t) r3) * s3; t[2] += ((uint128_t) r4) * s3 + ((uint128_t) r3) * s4; t[3] += ((uint128_t) r4) * s4; ", "square_header" : "(* Micro-optimized form from curve25519-donna-c64 by Adam Langley (Google) and Daniel Bernstein. See <https://github.com/agl/curve25519-donna/blob/master/LICENSE.md>;. *)", "square_code" : " uint128_t t[5]; limb r0,r1,r2,r3,r4,c; limb d0,d1,d2,d4,d419; r0 = in[0]; r1 = in[1]; r2 = in[2]; r3 = in[3]; r4 = in[4]; do { d0 = r0 * 2; d1 = r1 * 2; d2 = r2 * 2 * 19; d419 = r4 * 19; d4 = d419 * 2; t[0] = ((uint128_t) r0) * r0 + ((uint128_t) d4) * r1 + (((uint128_t) d2) * (r3 )); t[1] = ((uint128_t) d0) * r1 + ((uint128_t) d4) * r2 + (((uint128_t) r3) * (r3 * 19)); t[2] = ((uint128_t) d0) * r2 + ((uint128_t) r1) * r1 + (((uint128_t) d4) * (r3 )); t[3] = ((uint128_t) d0) * r3 + ((uint128_t) d1) * r2 + (((uint128_t) r4) * (d419 )); t[4] = ((uint128_t) d0) * r4 + ((uint128_t) d1) * r3 + (((uint128_t) r2) * (r2 )); " } ``` 3. The `src/Specific/CurveParameters/remake_curves.sh` script holds a list of curves to be made, what directories they should end up living in, and it invokes `src/Specific/Framework/make_curve.py` to transform these files into outputs. The Python script fills in a few defaults (such as computing `s` and `c` from the modulus, if you don't pass them explicitly), and does a lot of processing on the C code that is pasted verbatim from donna to get it to be in the right format for Coq. This Python script creates the files: - `CurveParameters.v` (the Coq-ified version of the json file, which instantiates an appropriate module type); - `Synthesis.v`, which instantiates a `MakeSynthesisTactics` with the curve parameter modules, invokes a tactic from the applied module functor to synthesize all of the relevant non-reflective bits (basically, what used to live in @jadephilipoom 's `ArithmeticSynthesisTest.v`), and then instantiates another module functor `PackageSynthesis` which defines notations via tactics in terms to access the names of the various fields defined by the synthesis tactic; - any other files you ask it for, such as `compiler.sh`, `femul.v`, `femulDisplay.v`. All of the `*Display.v` files are simple, and all the the operation synthesis files have a single `Definition` (with the appropriate type), and solve the definition by invoking a single tactic defined in `PackageSynthesis`, e.g., `synthesize_mul` or `synthesize_ladderstep`.
Diffstat (limited to 'src/Specific/Framework/make_curve.py')
-rwxr-xr-xsrc/Specific/Framework/make_curve.py52
1 files changed, 43 insertions, 9 deletions
diff --git a/src/Specific/Framework/make_curve.py b/src/Specific/Framework/make_curve.py
index 58ffdc31a..b530a72ee 100755
--- a/src/Specific/Framework/make_curve.py
+++ b/src/Specific/Framework/make_curve.py
@@ -1,11 +1,11 @@
#!/usr/bin/env python
from __future__ import with_statement
-import json, sys, os, math, re
+import json, sys, os, math, re, shutil
def compute_bitwidth(base):
return 2**int(math.ceil(math.log(base, 2)))
-def compute_sz(modulus, bitwidth):
- return 1 + int(math.ceil(math.log(modulus, 2) / bitwidth))
+def compute_sz(modulus, base):
+ return 1 + int(math.ceil(math.log(modulus, 2) / base))
def default_carry_chain(cc):
assert(cc == 'carry_chain1')
return 'Some (seq 0 (pred sz))'
@@ -67,6 +67,9 @@ def format_c_code(header, code, numargs, sz, indent=' ', closing_indent='
lines = [repeat_until_unchanged((lambda line: re.sub(r'\(([A-Za-z0-9_]+)\)', r'\1', line)),
line)
for line in lines]
+ lines = [repeat_until_unchanged((lambda line: re.sub(r'\(([A-Za-z0-9_]+\[[0-9]+\])\)', r'\1', line)),
+ line)
+ for line in lines]
out_match = re.match(r'^\s*u?int[0-9]+_t ([A-Za-z_][A-Za-z_0-9]*)\[([0-9]+)\]$', lines[0])
if out_match is not None:
out_var, out_count = out_match.groups()
@@ -81,8 +84,8 @@ def format_c_code(header, code, numargs, sz, indent=' ', closing_indent='
limb_match = re.match(r'^\s*limb [a-zA-Z0-9, ]+$', line)
in_match = re.match(r'^\s*([A-Za-z_][A-Za-z0-9_]*)\s*=\s*in([0-9]*)\[([0-9]+)\]$', line)
fixed_line = do_fix(line)
- normal_match = re.match(r'^(\s*)([A-Za-z_][A-Za-z0-9_]*)(\s*)=(\s*)([A-Za-z_0-9\(\) *+-]+)$', fixed_line)
- upd_match = re.match(r'^(\s*)([A-Za-z_][A-Za-z0-9_]*)(\s*)([*+])=(\s*)([A-Za-z_0-9\(\) *+-]+)$', fixed_line)
+ normal_match = re.match(r'^(\s*)([A-Za-z_][A-Za-z0-9_]*)(\s*)=(\s*)([A-Za-z_0-9\(\)\s<>*+-]+)$', fixed_line)
+ upd_match = re.match(r'^(\s*)([A-Za-z_][A-Za-z0-9_]*)(\s*)([*+])=(\s*)([A-Za-z_0-9\(\)\s<>*+-]+)$', fixed_line)
if line == '':
ret_code.append(line)
elif out_match or limb_match: pass
@@ -104,13 +107,25 @@ def format_c_code(header, code, numargs, sz, indent=' ', closing_indent='
else:
print('Unhandled line:')
raw_input(line)
- main_code = '\n'.join((indent + i.strip(' \n')).rstrip() for i in ' '.join(ret_code).strip().split('\n'))
+ ret_code = ' '.join(ret_code).strip().split('\n')
+ ret_code = [((indent + i.strip(' \n')) if i.strip()[:len('dlet ')] == 'dlet '
+ else (indent + ' ' + i.rstrip(' \n'))).rstrip()
+ for i in ret_code]
+ main_code = '\n'.join(ret_code)
arg_code = []
for in_count in sorted(input_map.keys()):
arg_code.append("%slet '(%s) := %s in"
% (indent,
', '.join(v for k, v in sorted(input_map[in_count].items(), reverse=True)),
ARGS[in_count]))
+ if len(input_map.keys()) == 0:
+ for in_count in range(numargs):
+ in_str = str(in_count + 1)
+ if in_count == 0: in_str = ''
+ arg_code.append("%slet '(%s) := %s in"
+ % (indent,
+ ', '.join(do_fix('in%s[%d]' % (in_str, v)) for v in reversed(range(sz))),
+ ARGS[in_count]))
ret += '\n%s\n' % '\n'.join(arg_code)
ret += main_code
ret += '\n%s(%s)' % (indent, ', '.join(do_fix('%s[%d]' % (out_var, i)) for i in reversed(range(sz))))
@@ -121,10 +136,10 @@ def make_curve_parameters(parameters):
replacements = dict(parameters)
assert(all(ch in '0123456789^+- ' for ch in parameters['modulus']))
modulus = eval(parameters['modulus'].replace('^', '**'))
- base = int(parameters['base'])
+ base = float(parameters['base'])
replacements['bitwidth'] = parameters.get('bitwidth', str(compute_bitwidth(base)))
bitwidth = int(replacements['bitwidth'])
- replacements['sz'] = parameters.get('sz', str(compute_sz(modulus, bitwidth)))
+ replacements['sz'] = parameters.get('sz', str(compute_sz(modulus, base)))
sz = int(replacements['sz'])
for cc in ('carry_chain1', 'carry_chain2'):
if cc in replacements.keys() and isinstance(replacements[cc], list):
@@ -293,6 +308,13 @@ Require Import Crypto.Specific.Framework.IntegrationTestDisplayCommon.
Check display %(function_name)s.
""" % locals()
+def make_compiler(compiler):
+ return r"""#!/bin/sh
+set -eu
+
+%s "$@"
+""" % compiler
+
def main(*args):
if '--help' in args[1:] or '-h' in args[1:]: usage(0)
@@ -302,7 +324,7 @@ def main(*args):
with open(args[1], 'r') as f:
parameters = f.read()
output_folder = os.path.realpath(args[2])
- assert('|' not in parameters)
+ parameters_folder = os.path.dirname(os.path.realpath(args[1]))
parameters = json.loads(parameters, strict=False)
root = get_file_root(folder=output_folder)
output_prefix = 'Crypto.' + os.path.normpath(os.path.relpath(output_folder, os.path.join(root, 'src'))).replace(os.sep, '.')
@@ -312,6 +334,10 @@ def main(*args):
for arg in parameters['operations']:
outputs[arg + '.v'] = make_synthesized_arg(arg, output_prefix)
outputs[arg + 'Display.v'] = make_display_arg(arg, output_prefix)
+ for fname in parameters.get('extra_files', []):
+ outputs[os.path.basename(fname)] = open(os.path.join(parameters_folder, fname), 'r').read()
+ if 'compiler' in parameters.keys():
+ outputs['compiler.sh'] = make_compiler(parameters['compiler'])
file_list = tuple((k, os.path.join(output_folder, k)) for k in sorted(outputs.keys()))
if not force:
extant_files = [os.path.relpath(fname, os.getcwd())
@@ -322,12 +348,20 @@ def main(*args):
sys.exit(1)
if not os.path.isdir(output_folder):
os.makedirs(output_folder)
+ new_files = []
for k, fname in file_list:
if os.path.isfile(fname):
if open(fname, 'r').read() == outputs[k]:
continue
+ new_files.append(fname)
with open(fname, 'w') as f:
f.write(outputs[k])
+ if fname[-len('compiler.sh'):] == 'compiler.sh':
+ mode = os.fstat(f.fileno()).st_mode
+ mode |= 0o111
+ os.fchmod(f.fileno(), mode & 0o7777)
+ if len(new_files) > 0:
+ print('git add ' + ' '.join('"%s"' % i for i in new_files))
if __name__ == '__main__':
main(*sys.argv)