aboutsummaryrefslogtreecommitdiff
path: root/src/Specific/ArithmeticSynthesisTest.v
diff options
context:
space:
mode:
authorGravatar jadep <jade.philipoom@gmail.com>2017-04-27 14:33:36 -0400
committerGravatar jadephilipoom <jade.philipoom@gmail.com>2017-05-01 14:34:48 -0400
commitbe54c2704f7bac666a3d4ca200f0622d6b38d7cb (patch)
tree9f03575d79c3430201630d331574b199706b6556 /src/Specific/ArithmeticSynthesisTest.v
parent3834198c658c888f33e572ee2eb6524073d9fbbe (diff)
first synthesis of freeze code
Diffstat (limited to 'src/Specific/ArithmeticSynthesisTest.v')
-rw-r--r--src/Specific/ArithmeticSynthesisTest.v87
1 files changed, 82 insertions, 5 deletions
diff --git a/src/Specific/ArithmeticSynthesisTest.v b/src/Specific/ArithmeticSynthesisTest.v
index 4750fd12c..5decf17da 100644
--- a/src/Specific/ArithmeticSynthesisTest.v
+++ b/src/Specific/ArithmeticSynthesisTest.v
@@ -20,6 +20,7 @@ Section Ops51.
(* These definitions will need to be passed as Ltac arguments (or
cleverly inferred) when things are eventually automated *)
Definition sz := 5%nat.
+ Definition bitwidth := 64.
Definition s : Z := 2^255.
Definition c : list B.limb := [(1, 19)].
Definition coef_div_modulus : nat := 2. (* add 2*modulus before subtracting *)
@@ -33,15 +34,14 @@ Section Ops51.
let si := Z.log2 s * i in
2 ^ ((si/sz) + (if dec ((si/sz)*sz=si) then 0 else 1)).
Definition sz2 := Eval vm_compute in ((sz * 2) - 1)%nat.
+ Definition m_enc :=
+ Eval vm_compute in (Positional.encode (modulo:=modulo) (div:=div) (n:=sz) wt (s-Associational.eval c)).
Definition coef := (* subtraction coefficient *)
Eval vm_compute in
- ( let p := Positional.encode
- (modulo:=modulo) (div:=div) (n:=sz)
- wt (s-Associational.eval c) in
- (fix addp (acc: Z^sz) (ctr : nat) : Z^sz :=
+ ((fix addm (acc: Z^sz) (ctr : nat) : Z^sz :=
match ctr with
| O => acc
- | S n => addp (Positional.add_cps wt acc p id) n
+ | S n => addm (Positional.add_cps wt acc m_enc id) n
end) (Positional.zeros sz) coef_div_modulus).
Definition coef_mod : mod_eq m (Positional.eval (n:=sz) wt coef) 0 := eq_refl.
@@ -187,6 +187,83 @@ Section Ops51.
solve_op_F wt x. reflexivity.
Defined.
+ Require Import Crypto.Arithmetic.Saturated.
+
+ Section PreFreeze.
+ Definition add_get_carry (s x y : Z) : Z * Z :=
+ dlet z := (x + y)%RT in (Core.modulo z s, Core.div z s).
+
+ Lemma add_get_carry_mod s x y :
+ fst (add_get_carry s x y) = (x + y) mod s.
+ Proof.
+ cbv [add_get_carry]; autorewrite with cancel_pair.
+ apply modulo_correct.
+ Qed.
+
+ Lemma add_get_carry_div s x y :
+ snd (add_get_carry s x y) = (x + y) / s.
+ Proof.
+ cbv [add_get_carry]; autorewrite with cancel_pair.
+ apply div_correct.
+ Qed.
+
+ Lemma wt_pos i : wt i > 0.
+ Proof.
+ apply Z.lt_gt.
+ apply Z.pow_pos_nonneg; zero_bounds; try break_match; vm_decide.
+ Qed.
+
+ Lemma wt_multiples i : wt (S i) mod (wt i) = 0.
+ Admitted.
+
+ Lemma wt_divides_full_pos i : wt (S i) / wt i > 0.
+ Proof.
+ pose proof (wt_divides_full i).
+ apply Z.div_positive_gt_0; auto using wt_pos.
+ apply wt_multiples.
+ Qed.
+
+ Context {select_cps : forall n, Z -> Z^n -> forall {T}, (Z^n->T) -> T}
+ {select : forall n, Z -> Z^n -> Z^n}
+ {select_id : forall n cond x T f, @select_cps n cond x T f = f (select n cond x)}
+ {eval_select : forall n cond x,
+ B.Positional.eval wt (select n cond x) = if dec (cond = 0) then 0 else B.Positional.eval wt x}
+ .
+
+ Hint Rewrite select_id : uncps.
+ Hint Rewrite eval_select : push_basesystem_eval.
+
+ Hint Opaque freeze : uncps.
+ Hint Rewrite freeze_id : uncps.
+ End PreFreeze.
+
+ Definition freeze_sig :
+ {freeze : (Z^sz -> Z^sz)%type |
+ forall a : Z^sz,
+ (0 <= Positional.eval wt a < 2 * Z.pos m)->
+ let eval := Positional.Fdecode (m := m) wt in
+ eval (freeze a) = eval a}.
+ Proof.
+ eexists; cbv beta zeta; intros.
+ pose proof wt_nonzero. pose proof wt_pos.
+ pose proof div_mod. pose proof wt_divides_full_pos.
+ pose proof wt_multiples.
+ pose proof add_get_carry_mod.
+ pose proof add_get_carry_div.
+ pose proof div_correct. pose proof modulo_correct.
+ pose proof select_id. pose proof eval_select.
+ let x := constr:(freeze (n:=5) (add_get_carry:=add_get_carry) (div:=div) (modulo:=modulo) (select_cps:=select_cps) wt m_enc a) in
+ F_mod_eq;
+ transitivity (Positional.eval wt x); repeat autounfold;
+ [ | autorewrite with uncps push_id push_basesystem_eval;
+ rewrite eval_freeze with (c:=c);
+ try eassumption; try omega; try reflexivity; vm_decide].
+ cbv[mod_eq]; apply f_equal2;
+ [ | reflexivity ]; apply f_equal.
+ cbv - [runtime_opp runtime_add runtime_mul runtime_shr runtime_and Let_In add_get_carry].
+ reflexivity.
+ Defined.
+
Definition ring_51 :=
(Ring.ring_by_isomorphism
(F := F m)