diff options
author | jadep <jade.philipoom@gmail.com> | 2017-04-23 16:58:58 -0400 |
---|---|---|
committer | jadephilipoom <jade.philipoom@gmail.com> | 2017-05-01 14:34:48 -0400 |
commit | f5043a0a3210edf80b266f20998ad5b0f3153c0d (patch) | |
tree | c6f089ecec88a438d5539d32ca57889bc9a95359 | |
parent | 5475629802b34912b2d280b3a1ae54187a721124 (diff) |
proved freeze, removed initial carry step (the correctness proof of that step needs bounds-checker)
-rw-r--r-- | src/Arithmetic/Saturated.v | 122 |
1 files changed, 65 insertions, 57 deletions
diff --git a/src/Arithmetic/Saturated.v b/src/Arithmetic/Saturated.v index addfe08e5..a2a751b7e 100644 --- a/src/Arithmetic/Saturated.v +++ b/src/Arithmetic/Saturated.v @@ -108,7 +108,7 @@ Module Columns. {weight_positive : forall i, weight i > 0} {weight_multiples : forall i, weight (S i) mod weight i = 0} {weight_divides : forall i : nat, weight (S i) / weight i > 0} - (* add_get_carry takes in a number at which to split output *) + (* add_get_carry takes in a number at which to split output *) {add_get_carry: Z ->Z -> Z -> (Z * Z)} {add_get_carry_mod : forall s x y, fst (add_get_carry s x y) = (x + y) mod s} @@ -418,27 +418,30 @@ Module Columns. rewrite eval_cons_to_nth by omega. nsatz. Qed. - Definition mul_cps {n m} (p q : Z^n) {T} (f : (list Z)^m->T) := - B.Positional.to_associational_cps weight p - (fun P => B.Positional.to_associational_cps weight q - (fun Q => B.Associational.mul_cps P Q - (fun PQ => from_associational_cps m PQ f))). + End Columns. - Definition add_cps {n} (p q : Z^n) {T} (f : (list Z)^n->T) := - B.Positional.to_associational_cps weight p - (fun P => B.Positional.to_associational_cps weight q - (fun Q => from_associational_cps n (P++Q) f)). + Section Wrappers. + Context (weight : nat->Z) + {add_get_carry: Z ->Z -> Z -> (Z * Z)} + {div modulo : Z -> Z -> Z}. - Definition sub_cps {n} (p q : Z^n) {T} (f : (list Z)^n->T) := + Definition add_cps {n} (p q : Z^n) {T} (f : (Z*Z^n)->T) := B.Positional.to_associational_cps weight p (fun P => B.Positional.to_associational_cps weight q - (fun Q => from_associational_cps n (P++Q) f)). + (fun Q => from_associational_cps weight n (P++Q) + (fun R => compact_cps (div:=div) (modulo:=modulo) (add_get_carry:=add_get_carry) weight R f))). - End Columns. + Definition sub_cps {n} (p q : Z^n) {T} (f : (Z*Z^n)->T) := + B.Positional.to_associational_cps weight p + (fun P => B.Positional.negate_snd_cps weight q + (fun nq => B.Positional.to_associational_cps weight nq + (fun Q => from_associational_cps weight n (P++Q) + (fun R => compact_cps (div:=div) (modulo:=modulo) (add_get_carry:=add_get_carry) weight R f)))). + + End Wrappers. End Columns. Hint Unfold Columns.add_cps - Columns.mul_cps Columns.sub_cps. Hint Rewrite @Columns.compact_digit_id @@ -493,8 +496,7 @@ Section Freeze. Definition conditional_add_cps {n} mask cond (p q : Z^n) {T} (f:_->T) := conditional_mask_cps mask cond q - (fun qq => Columns.add_cps weight p qq - (fun R => Columns.compact_cps (div:=div) (modulo:=modulo) (add_get_carry:=add_get_carry) weight R f)). + (fun qq => Columns.add_cps (div:=div) (modulo:=modulo) (add_get_carry:=add_get_carry) weight p qq f). Definition conditional_add {n} mask cond p q := @conditional_add_cps n mask cond p q _ id. Lemma conditional_add_id {n} mask cond p q T f: @@ -509,33 +511,31 @@ Section Freeze. (* + The input to [freeze] should be less than 2*m (this can probably + be accomplished by a single carry_reduce step, for most moduli). + [freeze] has the following steps: - (1) pseudomersenne reduction using [carry_reduce] - (2) subtract modulus in a carrying loop (in our framework, this + (1) subtract modulus in a carrying loop (in our framework, this consists of two steps; [Columns.sub_cps] combines the input p and the modulus m such that the ith limb in the output is the list [p[i];-m[i]]. We can then call [Columns.compact].) - (3) look at the final carry, which should be either 0 or -1. If + (2) look at the final carry, which should be either 0 or -1. If it's -1, then we add the modulus back in. Otherwise we add 0 for constant-timeness. - (4) discard the carry after this last addition; it should be 1 if + (3) discard the carry after this last addition; it should be 1 if the carry in step 3 was -1, so they cancel out. *) - Definition freeze_cps - {n} (mask:Z) (s:Z) (c:list B.limb) (m:Z^n) (p:Z^n) - {T} (f : Z^n->T) := - B.Positional.carry_reduce_cps - (div:=div) (modulo:=modulo) weight s c p - (fun P => Columns.sub_cps weight P m - (fun Q => Columns.compact_cps (div:=div) (modulo:=modulo) (add_get_carry:=add_get_carry) weight Q - (fun carry_q => conditional_add_cps mask (fst carry_q) (snd carry_q) m - (fun carry_r => f (snd carry_r))))) + Definition freeze_cps {n} (mask:Z) (m:Z^n) (p:Z^n) {T} (f : Z^n->T) := + Columns.sub_cps (div:=div) (modulo:=modulo) + (add_get_carry:=add_get_carry) weight p m + (fun carry_p => conditional_add_cps mask (fst carry_p) (snd carry_p) m + (fun carry_r => f (snd carry_r))) . - Definition freeze {n} mask s c m p := - @freeze_cps n mask s c m p _ id. - Lemma freeze_id {n} mask s c m p T f: - @freeze_cps n mask s c m p T f = f (freeze mask s c m p). + Definition freeze {n} mask m p := + @freeze_cps n mask m p _ id. + Lemma freeze_id {n} mask m p T f: + @freeze_cps n mask m p T f = f (freeze mask m p). Proof. cbv [freeze_cps freeze]; repeat progress autounfold; autorewrite with uncps push_id; reflexivity. @@ -590,13 +590,17 @@ Section Freeze. y0 = y - m -> z = y0 mod s -> c0 = y0 / s -> - c0 <> 0 -> z0 = z + (if (dec (c0 = 0)) then 0 else m) -> a = z0 mod s -> a mod m = y0 mod m. Proof. clear. intros. subst. break_match. - { rewrite Z.add_0_r, Z.mod_mod, !Z.mod_small by omega. + { rewrite Z.add_0_r, Z.mod_mod by omega. + assert (-(s-c) <= y - (s-c) < s-c) by omega. + match goal with H : s <> 0 |- _ => + rewrite (proj2 (Z.mod_small_iff _ s H)) + by (apply Z.div_small_iff; assumption) + end. reflexivity. } { rewrite <-Z.add_mod_l, Z.sub_mod_full. rewrite Z.mod_same, Z.sub_0_r, Z.mod_mod by omega. @@ -604,36 +608,40 @@ Section Freeze. by (pose proof (Z.div_small (y - (s-c)) s); omega). f_equal. ring. } Qed. - - Lemma eval_freeze {n} mask s c m p - (n_nonzero:n<>0%nat) (s_nonzero:s<>0) - (Hweight : weight (S (pred n)) / weight (pred n) <> 0) + + Lemma eval_freeze {n} mask c m p + (n_nonzero:n<>0%nat) (Hmask : Tuple.map (Z.land mask) m = m) + (Hc : 0 < B.Associational.eval c < weight n) modulus (Hm : B.Positional.eval weight m = Z.pos modulus) - (Hsc : Z.pos modulus = s - B.Associational.eval c) + (Hp : 0 <= B.Positional.eval weight p < 2*(Z.pos modulus)) + (Hsc : Z.pos modulus = weight n - B.Associational.eval c) : mod_eq modulus - (B.Positional.eval weight (@freeze n mask s c m p)) + (B.Positional.eval weight (@freeze n mask m p)) (B.Positional.eval weight p). Proof. - cbv [freeze_cps freeze mod_eq conditional_add_cps]. + cbv [freeze_cps freeze conditional_add_cps]. repeat progress autounfold. - autorewrite with uncps push_id. - - match goal with |- context [B.Associational.reduce ?s ?c ?p] => - remember (B.Associational.reduce s c p) as y end. - match goal with |- context [fst ?x] => - remember x as carry_q end. - match goal with |- context [conditional_mask ?mask ?cond ?p] => - remember (conditional_mask mask cond p) as m0 end. - match goal with |- context [Columns.compact ?w ?p] => - remember (Columns.compact w p) as carry_r end. - destruct carry_q as [c0 z]. - destruct carry_r as [c1 a]. - - Admitted. + autorewrite with uncps push_id push_basesystem_eval. + + pose proof (weight_nonzero n). + + remember (B.Positional.eval weight p) as y. + remember (y + -B.Positional.eval weight m) as y0. + rewrite Hm in *. + + transitivity y0; cbv [mod_eq]. + { eapply (freezeZ (Z.pos modulus) (weight n) (B.Associational.eval c) y y0); + try assumption; reflexivity. } + { subst y0. + assert (Z.pos modulus <> 0) by auto using Z.positive_is_nonzero, Zgt_pos_0. + rewrite Z.add_mod by assumption. + rewrite Z.mod_opp_l_z by auto using Z.mod_same. + rewrite Z.add_0_r, Z.mod_mod by assumption. + reflexivity. } + Qed. End Freeze. - (* |