aboutsummaryrefslogtreecommitdiff
path: root/src
diff options
context:
space:
mode:
authorGravatar jadep <jade.philipoom@gmail.com>2017-05-14 15:53:46 -0400
committerGravatar jadep <jade.philipoom@gmail.com>2017-05-14 15:56:56 -0400
commit0697253755e9816f5599fbf77c04b5f3db795e16 (patch)
treef4b1d7adc61f65722e305123d34ab219aabfc87b /src
parent7df5033c871aef6172d4e98d42ce00005e24f73e (diff)
make freeze use the correct versions of add_get_carry and zselect
Diffstat (limited to 'src')
-rw-r--r--src/Arithmetic/Core.v78
-rw-r--r--src/Arithmetic/Saturated.v149
-rw-r--r--src/Specific/ArithmeticSynthesisTest.v24
3 files changed, 110 insertions, 141 deletions
diff --git a/src/Arithmetic/Core.v b/src/Arithmetic/Core.v
index 06d0409e7..f294c06fc 100644
--- a/src/Arithmetic/Core.v
+++ b/src/Arithmetic/Core.v
@@ -248,6 +248,7 @@ Require Import Crypto.Algebra.Nsatz.
Require Import Crypto.Util.Decidable Crypto.Util.LetIn.
Require Import Crypto.Util.ZUtil Crypto.Util.ListUtil Crypto.Util.Sigma.
Require Import Crypto.Util.CPSUtil Crypto.Util.Prod.
+Require Import Crypto.Util.ZUtil.Zselect.
Require Import Crypto.Arithmetic.PrimeFieldTheorems.
Require Import Crypto.Util.Tactics.BreakMatch.
Require Import Crypto.Util.Tactics.UniquePose.
@@ -833,39 +834,43 @@ Module B.
End EvalHelpers.
Section Select.
- Context {weight : nat -> Z}
- {select_single : Z -> Z -> Z}
- {select_single_correct : forall cond x,
- select_single cond x = if dec (cond = 0) then 0 else x}
- .
+ Context {weight : nat -> Z}.
- Definition select_cps {n} cond (p : tuple Z n) {T} (f:_->T) :=
- Tuple.map_cps (select_single cond) p f.
+ Definition select_cps {n} (mask cond:Z) (p:tuple Z n)
+ {T} (f:tuple Z n->T) :=
+ dlet t := Z.zselect cond 0 mask in Tuple.map_cps (runtime_and t) p f.
- Definition select {n} cond p := @select_cps n cond p _ id.
- Lemma select_id {n} cond p T f :
- @select_cps n cond p T f = f (select cond p).
+ Definition select {n} mask cond p := @select_cps n mask cond p _ id.
+ Lemma select_id {n} mask cond p T f :
+ @select_cps n mask cond p T f = f (select mask cond p).
Proof.
- cbv [select_cps select]. autorewrite with uncps push_id.
- reflexivity.
+ cbv [select select_cps Let_In]; autorewrite with uncps push_id;
+ reflexivity.
Qed.
Hint Opaque select : uncps.
- Hint Rewrite @select_id : uncps.
- Lemma eval_select {n} cond p :
- eval weight (@select n cond p) = if dec (cond = 0) then 0 else eval weight p.
+ Lemma map_and_0 {n} (p:tuple Z n) : Tuple.map (Z.land 0) p = zeros n.
Proof.
- cbv [select select_cps]; autorewrite with uncps push_id.
- induction n; [destruct p|].
- { break_match; reflexivity. }
- { rewrite (Tuple.subst_left_append p).
- rewrite Tuple.map_left_append, !eval_left_append.
- rewrite select_single_correct, IHn.
- break_match; ring. }
- Qed. Hint Rewrite @eval_select : push_basesystem_eval.
+ induction n; [destruct p; reflexivity | ].
+ rewrite (Tuple.subst_append p), Tuple.map_append, Z.land_0_l, IHn.
+ reflexivity.
+ Qed.
+
+ Lemma eval_select {n} mask cond x (H:Tuple.map (Z.land mask) x = x) :
+ B.Positional.eval weight (@select n mask cond x) =
+ if dec (cond = 0) then 0 else B.Positional.eval weight x.
+ Proof.
+ cbv [select select_cps Let_In].
+ autorewrite with uncps push_id.
+ rewrite Z.zselect_correct; break_match.
+ { rewrite map_and_0. apply B.Positional.eval_zeros. }
+ { change runtime_and with Z.land. rewrite H; reflexivity. }
+ Qed.
+
End Select.
End Positional.
+
Hint Unfold
Positional.add_cps
Positional.mul_cps
@@ -942,33 +947,6 @@ Section DivMod.
Qed.
End DivMod.
-Section ZSelect.
-
- Definition mask width cond :=
- if dec (cond = 0) then 0 else Z.ones width.
-
- Definition zselect bitwidth (cond x : Z) : Z :=
- if (dec (x <= 0))
- then (if dec (cond = 0) then 0 else x)
- else (let width := Z.max (Z.log2 x + 1) bitwidth in
- dlet t := mask width cond in x &' t).
-
- Lemma zselect_correct bw cond x :
- zselect bw cond x = if dec (cond = 0) then 0 else x.
- Proof.
- cbv [zselect mask Let_In]; break_match;
- rewrite ?Z.land_0_r; try reflexivity; [ ].
- pose proof (Z.log2_nonneg x).
- pose proof (Z.log2_spec x) as Hlog2.
- rewrite <-Z.add_1_r in Hlog2.
- apply Z.max_case_strong; intros; rewrite Z.land_ones by omega;
- apply Z.mod_small; split; try omega.
- apply Z.lt_le_trans with (m:=2 ^ (Z.log2 x + 1));
- [|apply Z.pow_le_mono_r]; omega.
- Qed.
-
-End ZSelect.
-
Import B.
Ltac basesystem_partial_evaluation_RHS :=
diff --git a/src/Arithmetic/Saturated.v b/src/Arithmetic/Saturated.v
index 50dccd8c7..2c2dc60d0 100644
--- a/src/Arithmetic/Saturated.v
+++ b/src/Arithmetic/Saturated.v
@@ -9,6 +9,9 @@ Require Import Crypto.Util.LetIn Crypto.Util.CPSUtil.
Require Import Crypto.Util.Tuple Crypto.Util.ListUtil.
Require Import Crypto.Util.Tactics.BreakMatch.
Require Import Crypto.Util.Decidable Crypto.Util.ZUtil.
+Require Import Crypto.Util.ZUtil.Definitions.
+Require Import Crypto.Util.ZUtil.AddGetCarry.
+Require Import Crypto.Util.ZUtil.Zselect.
Local Notation "A ^ n" := (tuple A n) : type_scope.
(***
@@ -449,62 +452,80 @@ Section Freeze.
{weight_positive : forall i, weight i > 0}
{weight_multiples : forall i, weight (S i) mod weight i = 0}
{weight_divides : forall i : nat, weight (S i) / weight i > 0}
- (* add_get_carry takes in a number at which to split output *)
- {add_get_carry: Z ->Z -> Z -> (Z * Z)}
- {add_get_carry_mod : forall s x y,
- fst (add_get_carry s x y) = (x + y) mod s}
- {add_get_carry_div : forall s x y,
- snd (add_get_carry s x y) = (x + y) / s}
{div modulo : Z -> Z -> Z}
{div_correct : forall a b, div a b = a / b}
{modulo_correct : forall a b, modulo a b = a mod b}
- {select_cps : forall n, Z -> Z^n -> forall {T}, (Z^n->T) -> T}
- {select : forall n, Z -> Z^n -> Z^n}
- {select_id : forall n cond x T f, @select_cps n cond x T f = f (select n cond x)}
- {eval_select : forall n cond x,
- B.Positional.eval weight (select n cond x) = if dec (cond = 0) then 0 else B.Positional.eval weight x}
.
- Hint Rewrite select_id : uncps.
- Hint Rewrite eval_select : push_basesystem_eval.
+ Definition select_cps {n} (mask cond:Z) (p:Z^n) {T} (f:Z^n->T) :=
+ dlet t := Z.zselect cond 0 mask in Tuple.map_cps (runtime_and t) p f.
- (*
- (* adds p and q if cond is 0, else adds 0 to p*)
- Definition conditional_mask_cps {n} (mask:Z) (cond:Z) (p:Z^n)
- {T} (f:_->T) :=
- dlet and_term := if (dec (cond = 0)) then 0 else mask in
- f (Tuple.map (Z.land and_term) p).
-
- Definition conditional_mask {n} mask cond p :=
- @conditional_mask_cps n mask cond p _ id.
- Lemma conditional_mask_id {n} mask cond p T f:
- @conditional_mask_cps n mask cond p T f
- = f (conditional_mask mask cond p).
- Proof.
- cbv [conditional_mask_cps conditional_mask Let_In]; break_match;
- autounfold; autorewrite with uncps push_id; reflexivity.
- Qed.
- Hint Opaque conditional_mask : uncps.
- Hint Rewrite @conditional_mask_id : uncps.
+ Definition select {n} mask cond p := @select_cps n mask cond p _ id.
+ Lemma select_id {n} mask cond p T f :
+ @select_cps n mask cond p T f = f (select mask cond p).
+ Proof.
+ cbv [select select_cps Let_In]; autorewrite with uncps push_id;
+ reflexivity.
+ Qed.
+ Hint Opaque select : uncps.
+ Hint Rewrite @select_id : uncps.
- *)
+ Lemma map_and_0 {n} (p:Z^n) :
+ Tuple.map (Z.land 0) p = B.Positional.zeros n.
+ Proof.
+ induction n; [destruct p; reflexivity | ].
+ rewrite (subst_append p), map_append, Z.land_0_l, IHn.
+ reflexivity.
+ Qed.
- Definition conditional_add_cps {n} cond (p q : Z^n) {T} (f:_->T) :=
- select_cps n cond q _
- (fun qq => Columns.add_cps (div:=div) (modulo:=modulo) (add_get_carry:=add_get_carry) weight p qq f).
- Definition conditional_add {n} cond p q :=
- @conditional_add_cps n cond p q _ id.
- Lemma conditional_add_id {n} cond p q T f:
- @conditional_add_cps n cond p q T f
- = f (conditional_add cond p q).
+ Lemma eval_select {n} mask cond x (H:Tuple.map (Z.land mask) x = x) :
+ B.Positional.eval weight (@select n mask cond x) =
+ if dec (cond = 0) then 0 else B.Positional.eval weight x.
+ Proof.
+ cbv [select select_cps Let_In].
+ autorewrite with uncps push_id.
+ rewrite Z.zselect_correct; break_match.
+ { rewrite map_and_0. apply B.Positional.eval_zeros. }
+ { change runtime_and with Z.land. rewrite H; reflexivity. }
+ Qed.
+ Hint Rewrite @eval_select using assumption : push_basesystem_eval.
+
+ Definition conditional_add_cps {n} mask cond (p q : Z^n)
+ {T} (f:_->T) :=
+ select_cps mask cond q
+ (fun qq => Columns.add_cps (div:=div) (modulo:=modulo) (add_get_carry:=Z.add_get_carry_full) weight p qq f).
+ Definition conditional_add {n} mask cond p q :=
+ @conditional_add_cps n mask cond p q _ id.
+ Lemma conditional_add_id {n} mask cond p q T f:
+ @conditional_add_cps n mask cond p q T f
+ = f (conditional_add mask cond p q).
+ Proof.
+ cbv [conditional_add_cps conditional_add]; autounfold.
+ autorewrite with uncps push_id; reflexivity.
+ Qed.
+ Hint Opaque conditional_add : uncps.
+ Hint Rewrite @conditional_add_id : uncps.
+
+ Lemma eval_conditional_add {n} mask cond p q (n_nonzero:n<>0%nat)
+ (H:Tuple.map (Z.land mask) q = q) :
+ B.Positional.eval weight (snd (@conditional_add n mask cond p q))
+ = B.Positional.eval weight p + (if (dec (cond = 0)) then 0 else B.Positional.eval weight q) - weight n * (fst (conditional_add mask cond p q)).
Proof.
- cbv [conditional_add_cps conditional_add]; autounfold.
- autorewrite with uncps push_id; reflexivity.
+ cbv [conditional_add_cps conditional_add];
+ repeat progress autounfold in *.
+ pose proof Z.add_get_carry_full_mod.
+ pose proof Z.add_get_carry_full_div.
+ autorewrite with uncps push_id push_basesystem_eval.
+ break_match;
+ match goal with
+ |- context [weight ?n * (?x / weight ?n)] =>
+ pose proof (Z.div_mod x (weight n) (weight_nonzero n))
+ end; omega.
Qed.
- Hint Opaque conditional_add : uncps.
- Hint Rewrite @conditional_add_id : uncps.
-
-
+ Hint Rewrite @eval_conditional_add using (omega || assumption)
+ : push_basesystem_eval.
+
+
(*
The input to [freeze] should be less than 2*m (this can probably
be accomplished by a single carry_reduce step, for most moduli).
@@ -520,17 +541,17 @@ Section Freeze.
(3) discard the carry after this last addition; it should be 1 if
the carry in step 3 was -1, so they cancel out.
*)
- Definition freeze_cps {n} (m:Z^n) (p:Z^n) {T} (f : Z^n->T) :=
+ Definition freeze_cps {n} mask (m:Z^n) (p:Z^n) {T} (f : Z^n->T) :=
Columns.sub_cps (div:=div) (modulo:=modulo)
- (add_get_carry:=add_get_carry) weight p m
- (fun carry_p => conditional_add_cps (fst carry_p) (snd carry_p) m
+ (add_get_carry:=Z.add_get_carry_full) weight p m
+ (fun carry_p => conditional_add_cps mask (fst carry_p) (snd carry_p) m
(fun carry_r => f (snd carry_r)))
.
- Definition freeze {n} m p :=
- @freeze_cps n m p _ id.
- Lemma freeze_id {n} m p T f:
- @freeze_cps n m p T f = f (freeze m p).
+ Definition freeze {n} mask m p :=
+ @freeze_cps n mask m p _ id.
+ Lemma freeze_id {n} mask m p T f:
+ @freeze_cps n mask m p T f = f (freeze mask m p).
Proof.
cbv [freeze_cps freeze]; repeat progress autounfold;
autorewrite with uncps push_id; reflexivity.
@@ -538,21 +559,6 @@ Section Freeze.
Hint Opaque freeze : uncps.
Hint Rewrite @freeze_id : uncps.
- Lemma eval_conditional_add {n} cond p q (n_nonzero:n<>0%nat):
- B.Positional.eval weight (snd (@conditional_add n cond p q))
- = B.Positional.eval weight p + (if (dec (cond = 0)) then 0 else B.Positional.eval weight q) - weight n * (fst (conditional_add cond p q)).
- Proof.
- cbv [conditional_add_cps conditional_add];
- repeat progress autounfold in *.
- autorewrite with uncps push_id push_basesystem_eval.
- match goal with
- |- context [weight ?n * (?x / weight ?n)] =>
- pose proof (Z.div_mod x (weight n) (weight_nonzero n))
- end; omega.
- Qed.
- Hint Rewrite @eval_conditional_add using (omega || assumption)
- : push_basesystem_eval.
-
Lemma freezeZ m s c y y0 z z0 c0 a :
m = s - c ->
0 < c < s ->
@@ -580,19 +586,22 @@ Section Freeze.
f_equal. ring. }
Qed.
- Lemma eval_freeze {n} c m p
+ Lemma eval_freeze {n} c mask m p
(n_nonzero:n<>0%nat)
(Hc : 0 < B.Associational.eval c < weight n)
+ (Hmask : Tuple.map (Z.land mask) m = m)
modulus (Hm : B.Positional.eval weight m = Z.pos modulus)
(Hp : 0 <= B.Positional.eval weight p < 2*(Z.pos modulus))
(Hsc : Z.pos modulus = weight n - B.Associational.eval c)
:
mod_eq modulus
- (B.Positional.eval weight (@freeze n m p))
+ (B.Positional.eval weight (@freeze n mask m p))
(B.Positional.eval weight p).
Proof.
cbv [freeze_cps freeze conditional_add_cps].
repeat progress autounfold.
+ pose proof Z.add_get_carry_full_mod.
+ pose proof Z.add_get_carry_full_div.
autorewrite with uncps push_id push_basesystem_eval.
pose proof (weight_nonzero n).
diff --git a/src/Specific/ArithmeticSynthesisTest.v b/src/Specific/ArithmeticSynthesisTest.v
index 8f1829347..94530a414 100644
--- a/src/Specific/ArithmeticSynthesisTest.v
+++ b/src/Specific/ArithmeticSynthesisTest.v
@@ -190,23 +190,6 @@ Section Ops51.
Require Import Crypto.Arithmetic.Saturated.
Section PreFreeze.
- Definition add_get_carry (s x y : Z) : Z * Z :=
- dlet z := (x + y)%RT in (Core.modulo z s, Core.div z s).
-
- Lemma add_get_carry_mod s x y :
- fst (add_get_carry s x y) = (x + y) mod s.
- Proof.
- cbv [add_get_carry]; autorewrite with cancel_pair.
- apply modulo_correct.
- Qed.
-
- Lemma add_get_carry_div s x y :
- snd (add_get_carry s x y) = (x + y) / s.
- Proof.
- cbv [add_get_carry]; autorewrite with cancel_pair.
- apply div_correct.
- Qed.
-
Lemma wt_pos i : wt i > 0.
Proof.
apply Z.lt_gt.
@@ -238,10 +221,9 @@ Section Ops51.
pose proof wt_nonzero. pose proof wt_pos.
pose proof div_mod. pose proof wt_divides_full_pos.
pose proof wt_multiples.
- pose proof add_get_carry_mod.
- pose proof add_get_carry_div.
pose proof div_correct. pose proof modulo_correct.
- let x := constr:(freeze (n:=5) (add_get_carry:=add_get_carry) (div:=div) (modulo:=modulo) (select_cps:=(@B.Positional.select_cps (zselect bitwidth))) wt m_enc a) in
+ About freeze.
+ let x := constr:(freeze (n:=sz) (div:=div) (modulo:=modulo) wt (Z.ones bitwidth) m_enc a) in
F_mod_eq;
transitivity (Positional.eval wt x); repeat autounfold;
[ | autorewrite with uncps push_id push_basesystem_eval;
@@ -252,7 +234,7 @@ Section Ops51.
vm_decide].
cbv[mod_eq]; apply f_equal2;
[ | reflexivity ]; apply f_equal.
- cbv - [runtime_opp runtime_add runtime_mul runtime_shr runtime_and Let_In add_get_carry zselect].
+ cbv - [runtime_opp runtime_add runtime_mul runtime_shr runtime_and Let_In Z.add_get_carry Z.zselect].
reflexivity.
Defined.