diff options
author | jadep <jade.philipoom@gmail.com> | 2017-05-14 15:53:46 -0400 |
---|---|---|
committer | jadep <jade.philipoom@gmail.com> | 2017-05-14 15:56:56 -0400 |
commit | 0697253755e9816f5599fbf77c04b5f3db795e16 (patch) | |
tree | f4b1d7adc61f65722e305123d34ab219aabfc87b | |
parent | 7df5033c871aef6172d4e98d42ce00005e24f73e (diff) |
make freeze use the correct versions of add_get_carry and zselect
-rw-r--r-- | src/Arithmetic/Core.v | 78 | ||||
-rw-r--r-- | src/Arithmetic/Saturated.v | 149 | ||||
-rw-r--r-- | src/Specific/ArithmeticSynthesisTest.v | 24 |
3 files changed, 110 insertions, 141 deletions
diff --git a/src/Arithmetic/Core.v b/src/Arithmetic/Core.v index 06d0409e7..f294c06fc 100644 --- a/src/Arithmetic/Core.v +++ b/src/Arithmetic/Core.v @@ -248,6 +248,7 @@ Require Import Crypto.Algebra.Nsatz. Require Import Crypto.Util.Decidable Crypto.Util.LetIn. Require Import Crypto.Util.ZUtil Crypto.Util.ListUtil Crypto.Util.Sigma. Require Import Crypto.Util.CPSUtil Crypto.Util.Prod. +Require Import Crypto.Util.ZUtil.Zselect. Require Import Crypto.Arithmetic.PrimeFieldTheorems. Require Import Crypto.Util.Tactics.BreakMatch. Require Import Crypto.Util.Tactics.UniquePose. @@ -833,39 +834,43 @@ Module B. End EvalHelpers. Section Select. - Context {weight : nat -> Z} - {select_single : Z -> Z -> Z} - {select_single_correct : forall cond x, - select_single cond x = if dec (cond = 0) then 0 else x} - . + Context {weight : nat -> Z}. - Definition select_cps {n} cond (p : tuple Z n) {T} (f:_->T) := - Tuple.map_cps (select_single cond) p f. + Definition select_cps {n} (mask cond:Z) (p:tuple Z n) + {T} (f:tuple Z n->T) := + dlet t := Z.zselect cond 0 mask in Tuple.map_cps (runtime_and t) p f. - Definition select {n} cond p := @select_cps n cond p _ id. - Lemma select_id {n} cond p T f : - @select_cps n cond p T f = f (select cond p). + Definition select {n} mask cond p := @select_cps n mask cond p _ id. + Lemma select_id {n} mask cond p T f : + @select_cps n mask cond p T f = f (select mask cond p). Proof. - cbv [select_cps select]. autorewrite with uncps push_id. - reflexivity. + cbv [select select_cps Let_In]; autorewrite with uncps push_id; + reflexivity. Qed. Hint Opaque select : uncps. - Hint Rewrite @select_id : uncps. - Lemma eval_select {n} cond p : - eval weight (@select n cond p) = if dec (cond = 0) then 0 else eval weight p. + Lemma map_and_0 {n} (p:tuple Z n) : Tuple.map (Z.land 0) p = zeros n. Proof. - cbv [select select_cps]; autorewrite with uncps push_id. - induction n; [destruct p|]. - { break_match; reflexivity. } - { rewrite (Tuple.subst_left_append p). - rewrite Tuple.map_left_append, !eval_left_append. - rewrite select_single_correct, IHn. - break_match; ring. } - Qed. Hint Rewrite @eval_select : push_basesystem_eval. + induction n; [destruct p; reflexivity | ]. + rewrite (Tuple.subst_append p), Tuple.map_append, Z.land_0_l, IHn. + reflexivity. + Qed. + + Lemma eval_select {n} mask cond x (H:Tuple.map (Z.land mask) x = x) : + B.Positional.eval weight (@select n mask cond x) = + if dec (cond = 0) then 0 else B.Positional.eval weight x. + Proof. + cbv [select select_cps Let_In]. + autorewrite with uncps push_id. + rewrite Z.zselect_correct; break_match. + { rewrite map_and_0. apply B.Positional.eval_zeros. } + { change runtime_and with Z.land. rewrite H; reflexivity. } + Qed. + End Select. End Positional. + Hint Unfold Positional.add_cps Positional.mul_cps @@ -942,33 +947,6 @@ Section DivMod. Qed. End DivMod. -Section ZSelect. - - Definition mask width cond := - if dec (cond = 0) then 0 else Z.ones width. - - Definition zselect bitwidth (cond x : Z) : Z := - if (dec (x <= 0)) - then (if dec (cond = 0) then 0 else x) - else (let width := Z.max (Z.log2 x + 1) bitwidth in - dlet t := mask width cond in x &' t). - - Lemma zselect_correct bw cond x : - zselect bw cond x = if dec (cond = 0) then 0 else x. - Proof. - cbv [zselect mask Let_In]; break_match; - rewrite ?Z.land_0_r; try reflexivity; [ ]. - pose proof (Z.log2_nonneg x). - pose proof (Z.log2_spec x) as Hlog2. - rewrite <-Z.add_1_r in Hlog2. - apply Z.max_case_strong; intros; rewrite Z.land_ones by omega; - apply Z.mod_small; split; try omega. - apply Z.lt_le_trans with (m:=2 ^ (Z.log2 x + 1)); - [|apply Z.pow_le_mono_r]; omega. - Qed. - -End ZSelect. - Import B. Ltac basesystem_partial_evaluation_RHS := diff --git a/src/Arithmetic/Saturated.v b/src/Arithmetic/Saturated.v index 50dccd8c7..2c2dc60d0 100644 --- a/src/Arithmetic/Saturated.v +++ b/src/Arithmetic/Saturated.v @@ -9,6 +9,9 @@ Require Import Crypto.Util.LetIn Crypto.Util.CPSUtil. Require Import Crypto.Util.Tuple Crypto.Util.ListUtil. Require Import Crypto.Util.Tactics.BreakMatch. Require Import Crypto.Util.Decidable Crypto.Util.ZUtil. +Require Import Crypto.Util.ZUtil.Definitions. +Require Import Crypto.Util.ZUtil.AddGetCarry. +Require Import Crypto.Util.ZUtil.Zselect. Local Notation "A ^ n" := (tuple A n) : type_scope. (*** @@ -449,62 +452,80 @@ Section Freeze. {weight_positive : forall i, weight i > 0} {weight_multiples : forall i, weight (S i) mod weight i = 0} {weight_divides : forall i : nat, weight (S i) / weight i > 0} - (* add_get_carry takes in a number at which to split output *) - {add_get_carry: Z ->Z -> Z -> (Z * Z)} - {add_get_carry_mod : forall s x y, - fst (add_get_carry s x y) = (x + y) mod s} - {add_get_carry_div : forall s x y, - snd (add_get_carry s x y) = (x + y) / s} {div modulo : Z -> Z -> Z} {div_correct : forall a b, div a b = a / b} {modulo_correct : forall a b, modulo a b = a mod b} - {select_cps : forall n, Z -> Z^n -> forall {T}, (Z^n->T) -> T} - {select : forall n, Z -> Z^n -> Z^n} - {select_id : forall n cond x T f, @select_cps n cond x T f = f (select n cond x)} - {eval_select : forall n cond x, - B.Positional.eval weight (select n cond x) = if dec (cond = 0) then 0 else B.Positional.eval weight x} . - Hint Rewrite select_id : uncps. - Hint Rewrite eval_select : push_basesystem_eval. + Definition select_cps {n} (mask cond:Z) (p:Z^n) {T} (f:Z^n->T) := + dlet t := Z.zselect cond 0 mask in Tuple.map_cps (runtime_and t) p f. - (* - (* adds p and q if cond is 0, else adds 0 to p*) - Definition conditional_mask_cps {n} (mask:Z) (cond:Z) (p:Z^n) - {T} (f:_->T) := - dlet and_term := if (dec (cond = 0)) then 0 else mask in - f (Tuple.map (Z.land and_term) p). - - Definition conditional_mask {n} mask cond p := - @conditional_mask_cps n mask cond p _ id. - Lemma conditional_mask_id {n} mask cond p T f: - @conditional_mask_cps n mask cond p T f - = f (conditional_mask mask cond p). - Proof. - cbv [conditional_mask_cps conditional_mask Let_In]; break_match; - autounfold; autorewrite with uncps push_id; reflexivity. - Qed. - Hint Opaque conditional_mask : uncps. - Hint Rewrite @conditional_mask_id : uncps. + Definition select {n} mask cond p := @select_cps n mask cond p _ id. + Lemma select_id {n} mask cond p T f : + @select_cps n mask cond p T f = f (select mask cond p). + Proof. + cbv [select select_cps Let_In]; autorewrite with uncps push_id; + reflexivity. + Qed. + Hint Opaque select : uncps. + Hint Rewrite @select_id : uncps. - *) + Lemma map_and_0 {n} (p:Z^n) : + Tuple.map (Z.land 0) p = B.Positional.zeros n. + Proof. + induction n; [destruct p; reflexivity | ]. + rewrite (subst_append p), map_append, Z.land_0_l, IHn. + reflexivity. + Qed. - Definition conditional_add_cps {n} cond (p q : Z^n) {T} (f:_->T) := - select_cps n cond q _ - (fun qq => Columns.add_cps (div:=div) (modulo:=modulo) (add_get_carry:=add_get_carry) weight p qq f). - Definition conditional_add {n} cond p q := - @conditional_add_cps n cond p q _ id. - Lemma conditional_add_id {n} cond p q T f: - @conditional_add_cps n cond p q T f - = f (conditional_add cond p q). + Lemma eval_select {n} mask cond x (H:Tuple.map (Z.land mask) x = x) : + B.Positional.eval weight (@select n mask cond x) = + if dec (cond = 0) then 0 else B.Positional.eval weight x. + Proof. + cbv [select select_cps Let_In]. + autorewrite with uncps push_id. + rewrite Z.zselect_correct; break_match. + { rewrite map_and_0. apply B.Positional.eval_zeros. } + { change runtime_and with Z.land. rewrite H; reflexivity. } + Qed. + Hint Rewrite @eval_select using assumption : push_basesystem_eval. + + Definition conditional_add_cps {n} mask cond (p q : Z^n) + {T} (f:_->T) := + select_cps mask cond q + (fun qq => Columns.add_cps (div:=div) (modulo:=modulo) (add_get_carry:=Z.add_get_carry_full) weight p qq f). + Definition conditional_add {n} mask cond p q := + @conditional_add_cps n mask cond p q _ id. + Lemma conditional_add_id {n} mask cond p q T f: + @conditional_add_cps n mask cond p q T f + = f (conditional_add mask cond p q). + Proof. + cbv [conditional_add_cps conditional_add]; autounfold. + autorewrite with uncps push_id; reflexivity. + Qed. + Hint Opaque conditional_add : uncps. + Hint Rewrite @conditional_add_id : uncps. + + Lemma eval_conditional_add {n} mask cond p q (n_nonzero:n<>0%nat) + (H:Tuple.map (Z.land mask) q = q) : + B.Positional.eval weight (snd (@conditional_add n mask cond p q)) + = B.Positional.eval weight p + (if (dec (cond = 0)) then 0 else B.Positional.eval weight q) - weight n * (fst (conditional_add mask cond p q)). Proof. - cbv [conditional_add_cps conditional_add]; autounfold. - autorewrite with uncps push_id; reflexivity. + cbv [conditional_add_cps conditional_add]; + repeat progress autounfold in *. + pose proof Z.add_get_carry_full_mod. + pose proof Z.add_get_carry_full_div. + autorewrite with uncps push_id push_basesystem_eval. + break_match; + match goal with + |- context [weight ?n * (?x / weight ?n)] => + pose proof (Z.div_mod x (weight n) (weight_nonzero n)) + end; omega. Qed. - Hint Opaque conditional_add : uncps. - Hint Rewrite @conditional_add_id : uncps. - - + Hint Rewrite @eval_conditional_add using (omega || assumption) + : push_basesystem_eval. + + (* The input to [freeze] should be less than 2*m (this can probably be accomplished by a single carry_reduce step, for most moduli). @@ -520,17 +541,17 @@ Section Freeze. (3) discard the carry after this last addition; it should be 1 if the carry in step 3 was -1, so they cancel out. *) - Definition freeze_cps {n} (m:Z^n) (p:Z^n) {T} (f : Z^n->T) := + Definition freeze_cps {n} mask (m:Z^n) (p:Z^n) {T} (f : Z^n->T) := Columns.sub_cps (div:=div) (modulo:=modulo) - (add_get_carry:=add_get_carry) weight p m - (fun carry_p => conditional_add_cps (fst carry_p) (snd carry_p) m + (add_get_carry:=Z.add_get_carry_full) weight p m + (fun carry_p => conditional_add_cps mask (fst carry_p) (snd carry_p) m (fun carry_r => f (snd carry_r))) . - Definition freeze {n} m p := - @freeze_cps n m p _ id. - Lemma freeze_id {n} m p T f: - @freeze_cps n m p T f = f (freeze m p). + Definition freeze {n} mask m p := + @freeze_cps n mask m p _ id. + Lemma freeze_id {n} mask m p T f: + @freeze_cps n mask m p T f = f (freeze mask m p). Proof. cbv [freeze_cps freeze]; repeat progress autounfold; autorewrite with uncps push_id; reflexivity. @@ -538,21 +559,6 @@ Section Freeze. Hint Opaque freeze : uncps. Hint Rewrite @freeze_id : uncps. - Lemma eval_conditional_add {n} cond p q (n_nonzero:n<>0%nat): - B.Positional.eval weight (snd (@conditional_add n cond p q)) - = B.Positional.eval weight p + (if (dec (cond = 0)) then 0 else B.Positional.eval weight q) - weight n * (fst (conditional_add cond p q)). - Proof. - cbv [conditional_add_cps conditional_add]; - repeat progress autounfold in *. - autorewrite with uncps push_id push_basesystem_eval. - match goal with - |- context [weight ?n * (?x / weight ?n)] => - pose proof (Z.div_mod x (weight n) (weight_nonzero n)) - end; omega. - Qed. - Hint Rewrite @eval_conditional_add using (omega || assumption) - : push_basesystem_eval. - Lemma freezeZ m s c y y0 z z0 c0 a : m = s - c -> 0 < c < s -> @@ -580,19 +586,22 @@ Section Freeze. f_equal. ring. } Qed. - Lemma eval_freeze {n} c m p + Lemma eval_freeze {n} c mask m p (n_nonzero:n<>0%nat) (Hc : 0 < B.Associational.eval c < weight n) + (Hmask : Tuple.map (Z.land mask) m = m) modulus (Hm : B.Positional.eval weight m = Z.pos modulus) (Hp : 0 <= B.Positional.eval weight p < 2*(Z.pos modulus)) (Hsc : Z.pos modulus = weight n - B.Associational.eval c) : mod_eq modulus - (B.Positional.eval weight (@freeze n m p)) + (B.Positional.eval weight (@freeze n mask m p)) (B.Positional.eval weight p). Proof. cbv [freeze_cps freeze conditional_add_cps]. repeat progress autounfold. + pose proof Z.add_get_carry_full_mod. + pose proof Z.add_get_carry_full_div. autorewrite with uncps push_id push_basesystem_eval. pose proof (weight_nonzero n). diff --git a/src/Specific/ArithmeticSynthesisTest.v b/src/Specific/ArithmeticSynthesisTest.v index 8f1829347..94530a414 100644 --- a/src/Specific/ArithmeticSynthesisTest.v +++ b/src/Specific/ArithmeticSynthesisTest.v @@ -190,23 +190,6 @@ Section Ops51. Require Import Crypto.Arithmetic.Saturated. Section PreFreeze. - Definition add_get_carry (s x y : Z) : Z * Z := - dlet z := (x + y)%RT in (Core.modulo z s, Core.div z s). - - Lemma add_get_carry_mod s x y : - fst (add_get_carry s x y) = (x + y) mod s. - Proof. - cbv [add_get_carry]; autorewrite with cancel_pair. - apply modulo_correct. - Qed. - - Lemma add_get_carry_div s x y : - snd (add_get_carry s x y) = (x + y) / s. - Proof. - cbv [add_get_carry]; autorewrite with cancel_pair. - apply div_correct. - Qed. - Lemma wt_pos i : wt i > 0. Proof. apply Z.lt_gt. @@ -238,10 +221,9 @@ Section Ops51. pose proof wt_nonzero. pose proof wt_pos. pose proof div_mod. pose proof wt_divides_full_pos. pose proof wt_multiples. - pose proof add_get_carry_mod. - pose proof add_get_carry_div. pose proof div_correct. pose proof modulo_correct. - let x := constr:(freeze (n:=5) (add_get_carry:=add_get_carry) (div:=div) (modulo:=modulo) (select_cps:=(@B.Positional.select_cps (zselect bitwidth))) wt m_enc a) in + About freeze. + let x := constr:(freeze (n:=sz) (div:=div) (modulo:=modulo) wt (Z.ones bitwidth) m_enc a) in F_mod_eq; transitivity (Positional.eval wt x); repeat autounfold; [ | autorewrite with uncps push_id push_basesystem_eval; @@ -252,7 +234,7 @@ Section Ops51. vm_decide]. cbv[mod_eq]; apply f_equal2; [ | reflexivity ]; apply f_equal. - cbv - [runtime_opp runtime_add runtime_mul runtime_shr runtime_and Let_In add_get_carry zselect]. + cbv - [runtime_opp runtime_add runtime_mul runtime_shr runtime_and Let_In Z.add_get_carry Z.zselect]. reflexivity. Defined. |