aboutsummaryrefslogtreecommitdiff
path: root/src/Arithmetic
diff options
context:
space:
mode:
authorGravatar jadep <jade.philipoom@gmail.com>2017-04-11 23:57:30 -0400
committerGravatar jadephilipoom <jade.philipoom@gmail.com>2017-05-01 14:34:48 -0400
commit373bea4640df5c0d3858b4b628df171783a0812a (patch)
treee1e014a5cfaff0b42e4989cd56d67710c61a3139 /src/Arithmetic
parent08be7fa27881cf4bef5bede9d07feaaa9025b9a4 (diff)
first attempts at freeze
Diffstat (limited to 'src/Arithmetic')
-rw-r--r--src/Arithmetic/Saturated.v223
1 files changed, 219 insertions, 4 deletions
diff --git a/src/Arithmetic/Saturated.v b/src/Arithmetic/Saturated.v
index ec1280213..87c0e5ec9 100644
--- a/src/Arithmetic/Saturated.v
+++ b/src/Arithmetic/Saturated.v
@@ -8,6 +8,7 @@ Require Import Crypto.Arithmetic.Core.
Require Import Crypto.Util.LetIn Crypto.Util.CPSUtil.
Require Import Crypto.Util.Tuple Crypto.Util.ListUtil.
Require Import Crypto.Util.Tactics.BreakMatch.
+Require Import Crypto.Util.Decidable Crypto.Util.ZUtil.
Local Notation "A ^ n" := (tuple A n) : type_scope.
(***
@@ -101,7 +102,7 @@ check confirms our result.
Module Columns.
Section Columns.
- Context {weight : nat->Z}
+ Context (weight : nat->Z)
{weight_0 : weight 0%nat = 1}
{weight_nonzero : forall i, weight i <> 0}
{weight_multiples : forall i, weight (S i) mod weight i = 0}
@@ -231,12 +232,12 @@ Module Columns.
Qed.
Lemma eval_compact {n} (xs : tuple (list Z) n) :
- B.Positional.eval weight (snd (compact xs)) + (weight n * fst (compact xs)) = eval xs.
+ B.Positional.eval weight (snd (compact xs)) = eval xs - (weight n * fst (compact xs)).
Proof using Type*.
pose proof (compact_invariant_end 0 xs) as Hinv.
cbv [compact_invariant] in Hinv.
simpl in Hinv. autorewrite with zsimplify natsimplify in Hinv.
- rewrite eval_from_0, B.Positional.eval_from_0 in Hinv; apply Hinv.
+ rewrite eval_from_0, B.Positional.eval_from_0 in Hinv. nsatz.
Qed.
Definition cons_to_nth_cps {n} i (x:Z) (t:(list Z)^n)
@@ -339,9 +340,223 @@ Module Columns.
(fun P => B.Positional.to_associational_cps weight q
(fun Q => from_associational_cps n (P++Q) f)).
+ Definition sub_cps {n} (p q : Z^n) {T} (f : (list Z)^n->T) :=
+ B.Positional.to_associational_cps weight p
+ (fun P => B.Positional.to_associational_cps weight q
+ (fun Q => from_associational_cps n (P++Q) f)).
+
End Columns.
End Columns.
-
+Hint Unfold
+ Columns.add_cps
+ Columns.mul_cps
+ Columns.sub_cps.
+Hint Rewrite
+ @Columns.compact_digit_id
+ @Columns.compact_step_id
+ @Columns.compact_id
+ @Columns.cons_to_nth_id
+ @Columns.from_associational_id
+ : uncps.
+Hint Rewrite
+ @Columns.eval_compact
+ @Columns.eval_cons_to_nth
+ @Columns.eval_from_associational
+ @Columns.eval_nils
+ using (assumption || omega): push_basesystem_eval.
+
+Section Freeze.
+ Context (weight : nat->Z)
+ {weight_0 : weight 0%nat = 1}
+ {weight_nonzero : forall i, weight i <> 0}
+ {weight_multiples : forall i, weight (S i) mod weight i = 0}
+ (* add_get_carry takes in a number at which to split output *)
+ {add_get_carry: Z ->Z -> Z -> (Z * Z)}
+ {add_get_carry_correct : forall s x y,
+ fst (add_get_carry s x y) = x + y - s * snd (add_get_carry s x y)}
+ .
+
+ (* adds p and q if cond is 0, else adds 0 to p*)
+ Definition conditional_mask_cps {n} (mask:Z) (cond:Z) (p:Z^n)
+ {T} (f:_->T) :=
+ dlet and_term := if (dec (cond = 0)) then 0 else mask in
+ f (Tuple.map (Z.land and_term) p).
+
+ Definition conditional_mask {n} mask cond p :=
+ @conditional_mask_cps n mask cond p _ id.
+ Lemma conditional_mask_id {n} mask cond p T f:
+ @conditional_mask_cps n mask cond p T f
+ = f (conditional_mask mask cond p).
+ Proof.
+ cbv [conditional_mask_cps conditional_mask Let_In]; break_match;
+ autounfold; autorewrite with uncps push_id; reflexivity.
+ Qed.
+ Hint Opaque conditional_mask : uncps.
+ Hint Rewrite @conditional_mask_id : uncps.
+
+ Definition conditional_add_cps {n} mask cond (p q : Z^n) {T} (f:_->T) :=
+ conditional_mask_cps mask cond q
+ (fun qq => Columns.add_cps weight p qq
+ (fun R => Columns.compact_cps (add_get_carry:=add_get_carry) weight R f)).
+ Definition conditional_add {n} mask cond p q :=
+ @conditional_add_cps n mask cond p q _ id.
+ Lemma conditional_add_id {n} mask cond p q T f:
+ @conditional_add_cps n mask cond p q T f
+ = f (conditional_add mask cond p q).
+ Proof.
+ cbv [conditional_add_cps conditional_add]; autounfold;
+ autorewrite with uncps push_id; reflexivity.
+ Qed.
+ Hint Opaque conditional_add : uncps.
+ Hint Rewrite @conditional_add_id : uncps.
+
+
+ (*
+ [freeze] has the following steps:
+ (1) pseudomersenne reduction using [carry_reduce]
+ (2) subtract modulus in a carrying loop (in our framework, this
+ consists of two steps; [Columns.sub_cps] combines the input p and
+ the modulus m such that the ith limb in the output is the list
+ [p[i];-m[i]]. We can then call [Columns.compact].)
+ (3) look at the final carry, which should be either 0 or -1. If
+ it's -1, then we add the modulus back in. Otherwise we add 0 for
+ constant-timeness.
+ (4) discard the carry after this last addition; it should be 1 if
+ the carry in step 3 was -1, so they cancel out.
+ *)
+ Definition freeze_cps
+ {n} (mask:Z) (s:Z) (c:list B.limb) (m:Z^n) (p:Z^n)
+ {T} (f : Z^n->T) :=
+ B.Positional.carry_reduce_cps
+ (div:=div) (modulo:=modulo) weight s c p
+ (fun P => Columns.sub_cps weight P m
+ (fun Q => Columns.compact_cps (add_get_carry:=add_get_carry) weight Q
+ (fun carry_q => conditional_add_cps mask (fst carry_q) (snd carry_q) m
+ (fun carry_r => f (snd carry_r)))))
+ .
+
+ SearchAbout (((_ mod _) mod _)).
+
+ Definition freeze {n} mask s c m p :=
+ @freeze_cps n mask s c m p _ id.
+ Lemma freeze_id {n} mask s c m p T f:
+ @freeze_cps n mask s c m p T f = f (freeze mask s c m p).
+ Proof.
+ cbv [freeze_cps freeze]; repeat progress autounfold;
+ autorewrite with uncps push_id; reflexivity.
+ Qed.
+ Hint Opaque freeze : uncps.
+ Hint Rewrite @freeze_id : uncps.
+
+ Lemma map_land_zero {n} (p:Z^n):
+ Tuple.map (Z.land 0) p = B.Positional.zeros n.
+ Proof.
+ induction n; [ destruct p; reflexivity | ].
+ replace p with (append (hd p) (tl p)) by
+ (simpl in p; destruct n; destruct p; reflexivity).
+ rewrite map_append, IHn, Z.land_0_l; reflexivity.
+ Qed.
+
+ Lemma eval_conditional_mask {n} mask cond p (n_nonzero:n<>0%nat)
+ (Hmask : Tuple.map (Z.land mask) p = p):
+ B.Positional.eval weight (@conditional_mask n mask cond p)
+ = if (dec (cond = 0)) then 0 else B.Positional.eval weight p.
+ Proof.
+ cbv [conditional_mask_cps conditional_mask Let_In];
+ repeat progress autounfold; break_match;
+ rewrite ?Hmask, ?map_land_zero;
+ autorewrite with uncps push_id push_basesystem_eval; ring.
+ Qed.
+ Hint Rewrite @eval_conditional_mask using (omega || assumption)
+ : push_basesystem_eval.
+
+ Lemma eval_conditional_add {n} mask cond p q (n_nonzero:n<>0%nat)
+ (Hmask : Tuple.map (Z.land mask) q = q):
+ B.Positional.eval weight (snd (@conditional_add n mask cond p q))
+ = B.Positional.eval weight p + (if (dec (cond = 0)) then 0 else B.Positional.eval weight q) - weight n * (fst (conditional_add mask cond p q)).
+ Proof.
+ cbv [conditional_add_cps conditional_add];
+ repeat progress autounfold; rewrite ?Hmask, ?map_land_zero;
+ autorewrite with uncps push_id push_basesystem_eval;
+ break_match; ring.
+ Qed.
+ Hint Rewrite @eval_conditional_add using (omega || assumption)
+ : push_basesystem_eval.
+
+ Lemma freezeZ m s c y0 z z0 c0 a :
+ m = s - c ->
+ 0 < c < s ->
+ s <> 0 ->
+ -m <= y0 < m ->
+ z = y0 mod s ->
+ c0 = y0 / s ->
+ c0 <> 0 ->
+ z0 = z + m ->
+ a = z0 mod s ->
+ a mod m = y0 mod m.
+ Proof.
+ clear. intros. subst.
+ rewrite Z.add_mod by assumption.
+ rewrite Z.mod_mod by assumption.
+ rewrite <-Z.add_mod by assumption.
+ assert (~ (0 <= y0 < s)) by (pose proof (Z.div_small y0 s); tauto).
+ assert (-(s-c) <= y0 < 0) by omega.
+ rewrite Z.mod_small with (b := s) by omega.
+ rewrite Z.add_mod, Z.mod_same, Z.add_0_r, Z.mod_mod by omega.
+ reflexivity.
+ Qed.
+
+ Lemma eval_freeze {n} mask s c m p
+ (n_nonzero:n<>0%nat) (s_nonzero:s<>0)
+ (Hweight : weight (S (pred n)) / weight (pred n) <> 0)
+ (Hmask : Tuple.map (Z.land mask) m = m)
+ modulus (Hm : B.Positional.eval weight m = Z.pos modulus)
+ (Hsc : Z.pos modulus = s - B.Associational.eval c)
+ :
+ mod_eq modulus
+ (B.Positional.eval weight (@freeze n mask s c m p))
+ (B.Positional.eval weight p).
+ Proof.
+ cbv [freeze_cps freeze mod_eq]; repeat progress autounfold;
+ autorewrite with uncps push_id.
+
+ assert (Z.pos modulus <> 0) by (pose proof Pos2Z.is_pos modulus; omega).
+ pose proof div_mod.
+ break_match; subst.
+
+ rewrite Z.mul_0_r, Z.add_0_r, Z.sub_0_r.
+ (* TODO : how to prove second carry is 0? *)
+ rewrite Z.add_mod, B.Associational.eval_reduce by assumption.
+ autorewrite with uncps push_id push_basesystem_eval.
+ rewrite Hm. autorewrite with zsimplify.
+ reflexivity.
+
+
+ ring_simplify.
+ let p := fresh "P" in
+ let carry := fresh "carry" in
+ let result := fresh "result" in
+ match goal with H:fst ?x <> 0 |- _ =>
+ remember x as p; destruct p as [carry result];
+ autorewrite with cancel_pair in *
+ end.
+ rewrite Columns.eval_from_associational by assumption.
+ autorewrite with uncps push_id push_basesystem_eval.
+ rewrite Hm.
+ Check Z.add_mod_full.
+ match goal with |- context [(?a + ?b - ?c + ?d) mod ?m] =>
+ replace (a + b - c + d) with (b + (a-c) + d) by ring;
+ rewrite (Z.add_mod_full _ d), (Z.add_mod_full _ (a-c))
+ end.
+ rewrite !Z.mod_same by assumption.
+ rewrite !Z.add_0_r, !Z.add_0_l.
+ rewrite !Z.mod_mod by assumption.
+ rewrite Z.sub_mod_full.
+ rewrite B.Associational.eval_reduce by assumption.
+ autorewrite with uncps push_id push_basesystem_eval.
+
+
+
(*
(* Just some pretty-printing *)
Local Notation "fst~ a" := (let (x,_) := a in x) (at level 40, only printing).