diff options
author | jadep <jade.philipoom@gmail.com> | 2017-04-11 23:57:30 -0400 |
---|---|---|
committer | jadephilipoom <jade.philipoom@gmail.com> | 2017-05-01 14:34:48 -0400 |
commit | 373bea4640df5c0d3858b4b628df171783a0812a (patch) | |
tree | e1e014a5cfaff0b42e4989cd56d67710c61a3139 /src/Arithmetic | |
parent | 08be7fa27881cf4bef5bede9d07feaaa9025b9a4 (diff) |
first attempts at freeze
Diffstat (limited to 'src/Arithmetic')
-rw-r--r-- | src/Arithmetic/Saturated.v | 223 |
1 files changed, 219 insertions, 4 deletions
diff --git a/src/Arithmetic/Saturated.v b/src/Arithmetic/Saturated.v index ec1280213..87c0e5ec9 100644 --- a/src/Arithmetic/Saturated.v +++ b/src/Arithmetic/Saturated.v @@ -8,6 +8,7 @@ Require Import Crypto.Arithmetic.Core. Require Import Crypto.Util.LetIn Crypto.Util.CPSUtil. Require Import Crypto.Util.Tuple Crypto.Util.ListUtil. Require Import Crypto.Util.Tactics.BreakMatch. +Require Import Crypto.Util.Decidable Crypto.Util.ZUtil. Local Notation "A ^ n" := (tuple A n) : type_scope. (*** @@ -101,7 +102,7 @@ check confirms our result. Module Columns. Section Columns. - Context {weight : nat->Z} + Context (weight : nat->Z) {weight_0 : weight 0%nat = 1} {weight_nonzero : forall i, weight i <> 0} {weight_multiples : forall i, weight (S i) mod weight i = 0} @@ -231,12 +232,12 @@ Module Columns. Qed. Lemma eval_compact {n} (xs : tuple (list Z) n) : - B.Positional.eval weight (snd (compact xs)) + (weight n * fst (compact xs)) = eval xs. + B.Positional.eval weight (snd (compact xs)) = eval xs - (weight n * fst (compact xs)). Proof using Type*. pose proof (compact_invariant_end 0 xs) as Hinv. cbv [compact_invariant] in Hinv. simpl in Hinv. autorewrite with zsimplify natsimplify in Hinv. - rewrite eval_from_0, B.Positional.eval_from_0 in Hinv; apply Hinv. + rewrite eval_from_0, B.Positional.eval_from_0 in Hinv. nsatz. Qed. Definition cons_to_nth_cps {n} i (x:Z) (t:(list Z)^n) @@ -339,9 +340,223 @@ Module Columns. (fun P => B.Positional.to_associational_cps weight q (fun Q => from_associational_cps n (P++Q) f)). + Definition sub_cps {n} (p q : Z^n) {T} (f : (list Z)^n->T) := + B.Positional.to_associational_cps weight p + (fun P => B.Positional.to_associational_cps weight q + (fun Q => from_associational_cps n (P++Q) f)). + End Columns. End Columns. - +Hint Unfold + Columns.add_cps + Columns.mul_cps + Columns.sub_cps. +Hint Rewrite + @Columns.compact_digit_id + @Columns.compact_step_id + @Columns.compact_id + @Columns.cons_to_nth_id + @Columns.from_associational_id + : uncps. +Hint Rewrite + @Columns.eval_compact + @Columns.eval_cons_to_nth + @Columns.eval_from_associational + @Columns.eval_nils + using (assumption || omega): push_basesystem_eval. + +Section Freeze. + Context (weight : nat->Z) + {weight_0 : weight 0%nat = 1} + {weight_nonzero : forall i, weight i <> 0} + {weight_multiples : forall i, weight (S i) mod weight i = 0} + (* add_get_carry takes in a number at which to split output *) + {add_get_carry: Z ->Z -> Z -> (Z * Z)} + {add_get_carry_correct : forall s x y, + fst (add_get_carry s x y) = x + y - s * snd (add_get_carry s x y)} + . + + (* adds p and q if cond is 0, else adds 0 to p*) + Definition conditional_mask_cps {n} (mask:Z) (cond:Z) (p:Z^n) + {T} (f:_->T) := + dlet and_term := if (dec (cond = 0)) then 0 else mask in + f (Tuple.map (Z.land and_term) p). + + Definition conditional_mask {n} mask cond p := + @conditional_mask_cps n mask cond p _ id. + Lemma conditional_mask_id {n} mask cond p T f: + @conditional_mask_cps n mask cond p T f + = f (conditional_mask mask cond p). + Proof. + cbv [conditional_mask_cps conditional_mask Let_In]; break_match; + autounfold; autorewrite with uncps push_id; reflexivity. + Qed. + Hint Opaque conditional_mask : uncps. + Hint Rewrite @conditional_mask_id : uncps. + + Definition conditional_add_cps {n} mask cond (p q : Z^n) {T} (f:_->T) := + conditional_mask_cps mask cond q + (fun qq => Columns.add_cps weight p qq + (fun R => Columns.compact_cps (add_get_carry:=add_get_carry) weight R f)). + Definition conditional_add {n} mask cond p q := + @conditional_add_cps n mask cond p q _ id. + Lemma conditional_add_id {n} mask cond p q T f: + @conditional_add_cps n mask cond p q T f + = f (conditional_add mask cond p q). + Proof. + cbv [conditional_add_cps conditional_add]; autounfold; + autorewrite with uncps push_id; reflexivity. + Qed. + Hint Opaque conditional_add : uncps. + Hint Rewrite @conditional_add_id : uncps. + + + (* + [freeze] has the following steps: + (1) pseudomersenne reduction using [carry_reduce] + (2) subtract modulus in a carrying loop (in our framework, this + consists of two steps; [Columns.sub_cps] combines the input p and + the modulus m such that the ith limb in the output is the list + [p[i];-m[i]]. We can then call [Columns.compact].) + (3) look at the final carry, which should be either 0 or -1. If + it's -1, then we add the modulus back in. Otherwise we add 0 for + constant-timeness. + (4) discard the carry after this last addition; it should be 1 if + the carry in step 3 was -1, so they cancel out. + *) + Definition freeze_cps + {n} (mask:Z) (s:Z) (c:list B.limb) (m:Z^n) (p:Z^n) + {T} (f : Z^n->T) := + B.Positional.carry_reduce_cps + (div:=div) (modulo:=modulo) weight s c p + (fun P => Columns.sub_cps weight P m + (fun Q => Columns.compact_cps (add_get_carry:=add_get_carry) weight Q + (fun carry_q => conditional_add_cps mask (fst carry_q) (snd carry_q) m + (fun carry_r => f (snd carry_r))))) + . + + SearchAbout (((_ mod _) mod _)). + + Definition freeze {n} mask s c m p := + @freeze_cps n mask s c m p _ id. + Lemma freeze_id {n} mask s c m p T f: + @freeze_cps n mask s c m p T f = f (freeze mask s c m p). + Proof. + cbv [freeze_cps freeze]; repeat progress autounfold; + autorewrite with uncps push_id; reflexivity. + Qed. + Hint Opaque freeze : uncps. + Hint Rewrite @freeze_id : uncps. + + Lemma map_land_zero {n} (p:Z^n): + Tuple.map (Z.land 0) p = B.Positional.zeros n. + Proof. + induction n; [ destruct p; reflexivity | ]. + replace p with (append (hd p) (tl p)) by + (simpl in p; destruct n; destruct p; reflexivity). + rewrite map_append, IHn, Z.land_0_l; reflexivity. + Qed. + + Lemma eval_conditional_mask {n} mask cond p (n_nonzero:n<>0%nat) + (Hmask : Tuple.map (Z.land mask) p = p): + B.Positional.eval weight (@conditional_mask n mask cond p) + = if (dec (cond = 0)) then 0 else B.Positional.eval weight p. + Proof. + cbv [conditional_mask_cps conditional_mask Let_In]; + repeat progress autounfold; break_match; + rewrite ?Hmask, ?map_land_zero; + autorewrite with uncps push_id push_basesystem_eval; ring. + Qed. + Hint Rewrite @eval_conditional_mask using (omega || assumption) + : push_basesystem_eval. + + Lemma eval_conditional_add {n} mask cond p q (n_nonzero:n<>0%nat) + (Hmask : Tuple.map (Z.land mask) q = q): + B.Positional.eval weight (snd (@conditional_add n mask cond p q)) + = B.Positional.eval weight p + (if (dec (cond = 0)) then 0 else B.Positional.eval weight q) - weight n * (fst (conditional_add mask cond p q)). + Proof. + cbv [conditional_add_cps conditional_add]; + repeat progress autounfold; rewrite ?Hmask, ?map_land_zero; + autorewrite with uncps push_id push_basesystem_eval; + break_match; ring. + Qed. + Hint Rewrite @eval_conditional_add using (omega || assumption) + : push_basesystem_eval. + + Lemma freezeZ m s c y0 z z0 c0 a : + m = s - c -> + 0 < c < s -> + s <> 0 -> + -m <= y0 < m -> + z = y0 mod s -> + c0 = y0 / s -> + c0 <> 0 -> + z0 = z + m -> + a = z0 mod s -> + a mod m = y0 mod m. + Proof. + clear. intros. subst. + rewrite Z.add_mod by assumption. + rewrite Z.mod_mod by assumption. + rewrite <-Z.add_mod by assumption. + assert (~ (0 <= y0 < s)) by (pose proof (Z.div_small y0 s); tauto). + assert (-(s-c) <= y0 < 0) by omega. + rewrite Z.mod_small with (b := s) by omega. + rewrite Z.add_mod, Z.mod_same, Z.add_0_r, Z.mod_mod by omega. + reflexivity. + Qed. + + Lemma eval_freeze {n} mask s c m p + (n_nonzero:n<>0%nat) (s_nonzero:s<>0) + (Hweight : weight (S (pred n)) / weight (pred n) <> 0) + (Hmask : Tuple.map (Z.land mask) m = m) + modulus (Hm : B.Positional.eval weight m = Z.pos modulus) + (Hsc : Z.pos modulus = s - B.Associational.eval c) + : + mod_eq modulus + (B.Positional.eval weight (@freeze n mask s c m p)) + (B.Positional.eval weight p). + Proof. + cbv [freeze_cps freeze mod_eq]; repeat progress autounfold; + autorewrite with uncps push_id. + + assert (Z.pos modulus <> 0) by (pose proof Pos2Z.is_pos modulus; omega). + pose proof div_mod. + break_match; subst. + + rewrite Z.mul_0_r, Z.add_0_r, Z.sub_0_r. + (* TODO : how to prove second carry is 0? *) + rewrite Z.add_mod, B.Associational.eval_reduce by assumption. + autorewrite with uncps push_id push_basesystem_eval. + rewrite Hm. autorewrite with zsimplify. + reflexivity. + + + ring_simplify. + let p := fresh "P" in + let carry := fresh "carry" in + let result := fresh "result" in + match goal with H:fst ?x <> 0 |- _ => + remember x as p; destruct p as [carry result]; + autorewrite with cancel_pair in * + end. + rewrite Columns.eval_from_associational by assumption. + autorewrite with uncps push_id push_basesystem_eval. + rewrite Hm. + Check Z.add_mod_full. + match goal with |- context [(?a + ?b - ?c + ?d) mod ?m] => + replace (a + b - c + d) with (b + (a-c) + d) by ring; + rewrite (Z.add_mod_full _ d), (Z.add_mod_full _ (a-c)) + end. + rewrite !Z.mod_same by assumption. + rewrite !Z.add_0_r, !Z.add_0_l. + rewrite !Z.mod_mod by assumption. + rewrite Z.sub_mod_full. + rewrite B.Associational.eval_reduce by assumption. + autorewrite with uncps push_id push_basesystem_eval. + + + (* (* Just some pretty-printing *) Local Notation "fst~ a" := (let (x,_) := a in x) (at level 40, only printing). |