aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorGravatar jadep <jade.philipoom@gmail.com>2017-04-23 16:58:58 -0400
committerGravatar jadephilipoom <jade.philipoom@gmail.com>2017-05-01 14:34:48 -0400
commitf5043a0a3210edf80b266f20998ad5b0f3153c0d (patch)
treec6f089ecec88a438d5539d32ca57889bc9a95359
parent5475629802b34912b2d280b3a1ae54187a721124 (diff)
proved freeze, removed initial carry step (the correctness proof of that step needs bounds-checker)
-rw-r--r--src/Arithmetic/Saturated.v122
1 files changed, 65 insertions, 57 deletions
diff --git a/src/Arithmetic/Saturated.v b/src/Arithmetic/Saturated.v
index addfe08e5..a2a751b7e 100644
--- a/src/Arithmetic/Saturated.v
+++ b/src/Arithmetic/Saturated.v
@@ -108,7 +108,7 @@ Module Columns.
{weight_positive : forall i, weight i > 0}
{weight_multiples : forall i, weight (S i) mod weight i = 0}
{weight_divides : forall i : nat, weight (S i) / weight i > 0}
- (* add_get_carry takes in a number at which to split output *)
+ (* add_get_carry takes in a number at which to split output *)
{add_get_carry: Z ->Z -> Z -> (Z * Z)}
{add_get_carry_mod : forall s x y,
fst (add_get_carry s x y) = (x + y) mod s}
@@ -418,27 +418,30 @@ Module Columns.
rewrite eval_cons_to_nth by omega. nsatz.
Qed.
- Definition mul_cps {n m} (p q : Z^n) {T} (f : (list Z)^m->T) :=
- B.Positional.to_associational_cps weight p
- (fun P => B.Positional.to_associational_cps weight q
- (fun Q => B.Associational.mul_cps P Q
- (fun PQ => from_associational_cps m PQ f))).
+ End Columns.
- Definition add_cps {n} (p q : Z^n) {T} (f : (list Z)^n->T) :=
- B.Positional.to_associational_cps weight p
- (fun P => B.Positional.to_associational_cps weight q
- (fun Q => from_associational_cps n (P++Q) f)).
+ Section Wrappers.
+ Context (weight : nat->Z)
+ {add_get_carry: Z ->Z -> Z -> (Z * Z)}
+ {div modulo : Z -> Z -> Z}.
- Definition sub_cps {n} (p q : Z^n) {T} (f : (list Z)^n->T) :=
+ Definition add_cps {n} (p q : Z^n) {T} (f : (Z*Z^n)->T) :=
B.Positional.to_associational_cps weight p
(fun P => B.Positional.to_associational_cps weight q
- (fun Q => from_associational_cps n (P++Q) f)).
+ (fun Q => from_associational_cps weight n (P++Q)
+ (fun R => compact_cps (div:=div) (modulo:=modulo) (add_get_carry:=add_get_carry) weight R f))).
- End Columns.
+ Definition sub_cps {n} (p q : Z^n) {T} (f : (Z*Z^n)->T) :=
+ B.Positional.to_associational_cps weight p
+ (fun P => B.Positional.negate_snd_cps weight q
+ (fun nq => B.Positional.to_associational_cps weight nq
+ (fun Q => from_associational_cps weight n (P++Q)
+ (fun R => compact_cps (div:=div) (modulo:=modulo) (add_get_carry:=add_get_carry) weight R f)))).
+
+ End Wrappers.
End Columns.
Hint Unfold
Columns.add_cps
- Columns.mul_cps
Columns.sub_cps.
Hint Rewrite
@Columns.compact_digit_id
@@ -493,8 +496,7 @@ Section Freeze.
Definition conditional_add_cps {n} mask cond (p q : Z^n) {T} (f:_->T) :=
conditional_mask_cps mask cond q
- (fun qq => Columns.add_cps weight p qq
- (fun R => Columns.compact_cps (div:=div) (modulo:=modulo) (add_get_carry:=add_get_carry) weight R f)).
+ (fun qq => Columns.add_cps (div:=div) (modulo:=modulo) (add_get_carry:=add_get_carry) weight p qq f).
Definition conditional_add {n} mask cond p q :=
@conditional_add_cps n mask cond p q _ id.
Lemma conditional_add_id {n} mask cond p q T f:
@@ -509,33 +511,31 @@ Section Freeze.
(*
+ The input to [freeze] should be less than 2*m (this can probably
+ be accomplished by a single carry_reduce step, for most moduli).
+
[freeze] has the following steps:
- (1) pseudomersenne reduction using [carry_reduce]
- (2) subtract modulus in a carrying loop (in our framework, this
+ (1) subtract modulus in a carrying loop (in our framework, this
consists of two steps; [Columns.sub_cps] combines the input p and
the modulus m such that the ith limb in the output is the list
[p[i];-m[i]]. We can then call [Columns.compact].)
- (3) look at the final carry, which should be either 0 or -1. If
+ (2) look at the final carry, which should be either 0 or -1. If
it's -1, then we add the modulus back in. Otherwise we add 0 for
constant-timeness.
- (4) discard the carry after this last addition; it should be 1 if
+ (3) discard the carry after this last addition; it should be 1 if
the carry in step 3 was -1, so they cancel out.
*)
- Definition freeze_cps
- {n} (mask:Z) (s:Z) (c:list B.limb) (m:Z^n) (p:Z^n)
- {T} (f : Z^n->T) :=
- B.Positional.carry_reduce_cps
- (div:=div) (modulo:=modulo) weight s c p
- (fun P => Columns.sub_cps weight P m
- (fun Q => Columns.compact_cps (div:=div) (modulo:=modulo) (add_get_carry:=add_get_carry) weight Q
- (fun carry_q => conditional_add_cps mask (fst carry_q) (snd carry_q) m
- (fun carry_r => f (snd carry_r)))))
+ Definition freeze_cps {n} (mask:Z) (m:Z^n) (p:Z^n) {T} (f : Z^n->T) :=
+ Columns.sub_cps (div:=div) (modulo:=modulo)
+ (add_get_carry:=add_get_carry) weight p m
+ (fun carry_p => conditional_add_cps mask (fst carry_p) (snd carry_p) m
+ (fun carry_r => f (snd carry_r)))
.
- Definition freeze {n} mask s c m p :=
- @freeze_cps n mask s c m p _ id.
- Lemma freeze_id {n} mask s c m p T f:
- @freeze_cps n mask s c m p T f = f (freeze mask s c m p).
+ Definition freeze {n} mask m p :=
+ @freeze_cps n mask m p _ id.
+ Lemma freeze_id {n} mask m p T f:
+ @freeze_cps n mask m p T f = f (freeze mask m p).
Proof.
cbv [freeze_cps freeze]; repeat progress autounfold;
autorewrite with uncps push_id; reflexivity.
@@ -590,13 +590,17 @@ Section Freeze.
y0 = y - m ->
z = y0 mod s ->
c0 = y0 / s ->
- c0 <> 0 ->
z0 = z + (if (dec (c0 = 0)) then 0 else m) ->
a = z0 mod s ->
a mod m = y0 mod m.
Proof.
clear. intros. subst. break_match.
- { rewrite Z.add_0_r, Z.mod_mod, !Z.mod_small by omega.
+ { rewrite Z.add_0_r, Z.mod_mod by omega.
+ assert (-(s-c) <= y - (s-c) < s-c) by omega.
+ match goal with H : s <> 0 |- _ =>
+ rewrite (proj2 (Z.mod_small_iff _ s H))
+ by (apply Z.div_small_iff; assumption)
+ end.
reflexivity. }
{ rewrite <-Z.add_mod_l, Z.sub_mod_full.
rewrite Z.mod_same, Z.sub_0_r, Z.mod_mod by omega.
@@ -604,36 +608,40 @@ Section Freeze.
by (pose proof (Z.div_small (y - (s-c)) s); omega).
f_equal. ring. }
Qed.
-
- Lemma eval_freeze {n} mask s c m p
- (n_nonzero:n<>0%nat) (s_nonzero:s<>0)
- (Hweight : weight (S (pred n)) / weight (pred n) <> 0)
+
+ Lemma eval_freeze {n} mask c m p
+ (n_nonzero:n<>0%nat)
(Hmask : Tuple.map (Z.land mask) m = m)
+ (Hc : 0 < B.Associational.eval c < weight n)
modulus (Hm : B.Positional.eval weight m = Z.pos modulus)
- (Hsc : Z.pos modulus = s - B.Associational.eval c)
+ (Hp : 0 <= B.Positional.eval weight p < 2*(Z.pos modulus))
+ (Hsc : Z.pos modulus = weight n - B.Associational.eval c)
:
mod_eq modulus
- (B.Positional.eval weight (@freeze n mask s c m p))
+ (B.Positional.eval weight (@freeze n mask m p))
(B.Positional.eval weight p).
Proof.
- cbv [freeze_cps freeze mod_eq conditional_add_cps].
+ cbv [freeze_cps freeze conditional_add_cps].
repeat progress autounfold.
- autorewrite with uncps push_id.
-
- match goal with |- context [B.Associational.reduce ?s ?c ?p] =>
- remember (B.Associational.reduce s c p) as y end.
- match goal with |- context [fst ?x] =>
- remember x as carry_q end.
- match goal with |- context [conditional_mask ?mask ?cond ?p] =>
- remember (conditional_mask mask cond p) as m0 end.
- match goal with |- context [Columns.compact ?w ?p] =>
- remember (Columns.compact w p) as carry_r end.
- destruct carry_q as [c0 z].
- destruct carry_r as [c1 a].
-
- Admitted.
+ autorewrite with uncps push_id push_basesystem_eval.
+
+ pose proof (weight_nonzero n).
+
+ remember (B.Positional.eval weight p) as y.
+ remember (y + -B.Positional.eval weight m) as y0.
+ rewrite Hm in *.
+
+ transitivity y0; cbv [mod_eq].
+ { eapply (freezeZ (Z.pos modulus) (weight n) (B.Associational.eval c) y y0);
+ try assumption; reflexivity. }
+ { subst y0.
+ assert (Z.pos modulus <> 0) by auto using Z.positive_is_nonzero, Zgt_pos_0.
+ rewrite Z.add_mod by assumption.
+ rewrite Z.mod_opp_l_z by auto using Z.mod_same.
+ rewrite Z.add_0_r, Z.mod_mod by assumption.
+ reflexivity. }
+ Qed.
End Freeze.
-
(*