diff options
author | jadep <jade.philipoom@gmail.com> | 2017-06-29 21:32:04 -0400 |
---|---|---|
committer | jadep <jade.philipoom@gmail.com> | 2017-06-29 21:32:04 -0400 |
commit | 2876f7c688590a64189f47b439f7edf26c91c5de (patch) | |
tree | fffe55bc24e83105fca356a81a352e1fa4309999 /src/Arithmetic | |
parent | b291707642db5986240b3e9eb9a80839d81ffe42 (diff) |
Reorganization of saturated arithmetic
Diffstat (limited to 'src/Arithmetic')
-rw-r--r-- | src/Arithmetic/MontgomeryReduction/WordByWord/Definition.v | 6 | ||||
-rw-r--r-- | src/Arithmetic/MontgomeryReduction/WordByWord/Proofs.v | 11 | ||||
-rw-r--r-- | src/Arithmetic/Saturated/AddSub.v | 109 | ||||
-rw-r--r-- | src/Arithmetic/Saturated/Core.v | 993 | ||||
-rw-r--r-- | src/Arithmetic/Saturated/Freeze.v | 122 | ||||
-rw-r--r-- | src/Arithmetic/Saturated/MontgomeryAPI.v | 599 | ||||
-rw-r--r-- | src/Arithmetic/Saturated/MulSplit.v | 73 | ||||
-rw-r--r-- | src/Arithmetic/Saturated/UniformWeight.v | 71 | ||||
-rw-r--r-- | src/Arithmetic/Saturated/Wrappers.v | 53 |
9 files changed, 1036 insertions, 1001 deletions
diff --git a/src/Arithmetic/MontgomeryReduction/WordByWord/Definition.v b/src/Arithmetic/MontgomeryReduction/WordByWord/Definition.v index f344cb7de..9affa82fa 100644 --- a/src/Arithmetic/MontgomeryReduction/WordByWord/Definition.v +++ b/src/Arithmetic/MontgomeryReduction/WordByWord/Definition.v @@ -5,7 +5,7 @@ of the algorithm; note that it may be that none of the algorithms there exactly match what we're doing here. *) Require Import Coq.ZArith.ZArith. -Require Import Crypto.Arithmetic.Saturated. +Require Import Crypto.Arithmetic.Saturated.MontgomeryAPI. Require Import Crypto.Arithmetic.MontgomeryReduction.WordByWord.Abstract.Dependent.Definition. Require Import Crypto.Util.Notations. Require Import Crypto.Util.LetIn. @@ -22,8 +22,8 @@ Section WordByWordMontgomery. (N : T R_numlimbs). Local Notation scmul := (@scmul (Z.pos r)). - Local Notation addT' := (@Saturated.add_S1 (Z.pos r)). - Local Notation addT := (@Saturated.add (Z.pos r)). + Local Notation addT' := (@MontgomeryAPI.add_S1 (Z.pos r)). + Local Notation addT := (@MontgomeryAPI.add (Z.pos r)). Local Notation conditional_sub_cps := (fun V => @conditional_sub_cps (Z.pos r) _ V N _). Local Notation conditional_sub := (fun V => @conditional_sub (Z.pos r) _ V N). Local Notation sub_then_maybe_add_cps := diff --git a/src/Arithmetic/MontgomeryReduction/WordByWord/Proofs.v b/src/Arithmetic/MontgomeryReduction/WordByWord/Proofs.v index 747280fe6..83791ec5f 100644 --- a/src/Arithmetic/MontgomeryReduction/WordByWord/Proofs.v +++ b/src/Arithmetic/MontgomeryReduction/WordByWord/Proofs.v @@ -1,7 +1,8 @@ (*** Word-By-Word Montgomery Multiplication Proofs *) Require Import Coq.ZArith.BinInt. Require Import Coq.micromega.Lia. -Require Import Crypto.Arithmetic.Saturated. +Require Import Crypto.Arithmetic.Saturated.UniformWeight. +Require Import Crypto.Arithmetic.Saturated.MontgomeryAPI. Require Import Crypto.Arithmetic.MontgomeryReduction.WordByWord.Abstract.Dependent.Definition. Require Import Crypto.Arithmetic.MontgomeryReduction.WordByWord.Abstract.Dependent.Proofs. Require Import Crypto.Arithmetic.MontgomeryReduction.WordByWord.Definition. @@ -16,8 +17,8 @@ Section WordByWordMontgomery. (R_numlimbs : nat). Local Notation small := (@small (Z.pos r)). Local Notation eval := (@eval (Z.pos r)). - Local Notation addT' := (@Saturated.add_S1 (Z.pos r)). - Local Notation addT := (@Saturated.add (Z.pos r)). + Local Notation addT' := (@MontgomeryAPI.add_S1 (Z.pos r)). + Local Notation addT := (@MontgomeryAPI.add (Z.pos r)). Local Notation scmul := (@scmul (Z.pos r)). Local Notation eval_zero := (@eval_zero (Z.pos r)). Local Notation small_zero := (@small_zero r (Zorder.Zgt_pos_0 _)). @@ -61,11 +62,11 @@ Section WordByWordMontgomery. Qed. Local Lemma small_addT : forall n a b, small a -> small b -> small (@addT n a b). Proof. - intros; apply Saturated.small_add; auto; lia. + intros; apply MontgomeryAPI.small_add; auto; lia. Qed. Local Lemma small_addT' : forall n a b, small a -> small b -> small (@addT' n a b). Proof. - intros; apply Saturated.small_add_S1; auto; lia. + intros; apply MontgomeryAPI.small_add_S1; auto; lia. Qed. Local Notation conditional_sub_cps := (fun V : T (S R_numlimbs) => @conditional_sub_cps (Z.pos r) _ V N _). diff --git a/src/Arithmetic/Saturated/AddSub.v b/src/Arithmetic/Saturated/AddSub.v new file mode 100644 index 000000000..c6758b865 --- /dev/null +++ b/src/Arithmetic/Saturated/AddSub.v @@ -0,0 +1,109 @@ +Require Import Coq.ZArith.ZArith. +Require Import Coq.Lists.List. +Local Open Scope Z_scope. + +Require Import Crypto.Arithmetic.Core. +Require Import Crypto.Arithmetic.Saturated.Core. +Require Import Crypto.Arithmetic.Saturated.UniformWeight. +Require Import Crypto.Util.ZUtil.Definitions. +Require Import Crypto.Util.Tuple Crypto.Util.LetIn. +Local Notation "A ^ n" := (tuple A n) : type_scope. + +Module B. + Module Positional. + Section Positional. + Context {s:Z}. (* s is bitwidth *) + Let small {n} := @small s n. + Section GenericOp. + Context {op : Z -> Z -> Z} + {op_get_carry : Z -> Z -> Z * Z} (* no carry in, carry out *) + {op_with_carry : Z -> Z -> Z -> Z * Z}. (* carry in, carry out *) + + Fixpoint chain_op'_cps {n}: + option Z->Z^n->Z^n->forall T, (Z*Z^n->T)->T := + match n with + | O => fun c p _ _ f => + let carry := match c with | None => 0 | Some x => x end in + f (carry,p) + | S n' => + fun c p q _ f => + (* for the first call, use op_get_carry, then op_with_carry *) + let op' := match c with + | None => op_get_carry + | Some x => op_with_carry x end in + dlet carry_result := op' (hd p) (hd q) in + chain_op'_cps (Some (snd carry_result)) (tl p) (tl q) _ + (fun carry_pq => + f (fst carry_pq, + append (fst carry_result) (snd carry_pq))) + end. + Definition chain_op' {n} c p q := @chain_op'_cps n c p q _ id. + Definition chain_op_cps {n} p q {T} f := @chain_op'_cps n None p q T f. + Definition chain_op {n} p q : Z * Z^n := chain_op_cps p q id. + + Lemma chain_op'_id {n} : forall c p q T f, + @chain_op'_cps n c p q T f = f (chain_op' c p q). + Proof. + cbv [chain_op']; induction n; intros; destruct c; + simpl chain_op'_cps; cbv [Let_In]; try reflexivity. + { etransitivity; rewrite IHn; reflexivity. } + { etransitivity; rewrite IHn; reflexivity. } + Qed. + + Lemma chain_op_id {n} p q T f : + @chain_op_cps n p q T f = f (chain_op p q). + Proof. apply chain_op'_id. Qed. + End GenericOp. + + Section AddSub. + Let eval {n} := B.Positional.eval (n:=n) (uweight s). + + Definition sat_add_cps {n} p q T (f:Z*Z^n->T) := + chain_op_cps (op_get_carry := Z.add_get_carry_full s) + (op_with_carry := Z.add_with_get_carry_full s) + p q f. + Definition sat_add {n} p q := @sat_add_cps n p q _ id. + + Lemma sat_add_id n p q T f : + @sat_add_cps n p q T f = f (sat_add p q). + Proof. cbv [sat_add sat_add_cps]. rewrite !chain_op_id. reflexivity. Qed. + + Lemma sat_add_mod n p q : + eval (snd (@sat_add n p q)) = (eval p + eval q) mod (uweight s n). + Admitted. + + Lemma sat_add_div n p q : + fst (@sat_add n p q) = (eval p + eval q) / (uweight s n). + Admitted. + + Lemma small_sat_add n p q : small (snd (@sat_add n p q)). + Admitted. + + Definition sat_sub_cps {n} p q T (f:Z*Z^n->T) := + chain_op_cps (op_get_carry := Z.sub_get_borrow_full s) + (op_with_carry := Z.sub_with_get_borrow_full s) + p q f. + Definition sat_sub {n} p q := @sat_sub_cps n p q _ id. + + Lemma sat_sub_id n p q T f : + @sat_sub_cps n p q T f = f (sat_sub p q). + Proof. cbv [sat_sub sat_sub_cps]. rewrite !chain_op_id. reflexivity. Qed. + + Lemma sat_sub_mod n p q : + eval (snd (@sat_sub n p q)) = (eval p - eval q) mod (uweight s n). + Admitted. + + Lemma sat_sub_div n p q : + fst (@sat_sub n p q) = - ((eval p - eval q) / uweight s n). + Admitted. + + Lemma small_sat_sub n p q : small (snd (@sat_sub n p q)). + Admitted. + + End AddSub. + End Positional. + End Positional. +End B. +Hint Opaque B.Positional.sat_sub B.Positional.sat_add B.Positional.chain_op B.Positional.chain_op' : uncps. +Hint Rewrite @B.Positional.sat_sub_id @B.Positional.sat_add_id @B.Positional.chain_op_id @B.Positional.chain_op' : uncps. +Hint Rewrite @B.Positional.sat_sub_mod @B.Positional.sat_sub_div @B.Positional.sat_add_mod @B.Positional.sat_add_div using (omega || assumption) : push_basesystem_eval.
\ No newline at end of file diff --git a/src/Arithmetic/Saturated/Core.v b/src/Arithmetic/Saturated/Core.v index 0c059b93d..27171c741 100644 --- a/src/Arithmetic/Saturated/Core.v +++ b/src/Arithmetic/Saturated/Core.v @@ -11,10 +11,6 @@ Require Import Crypto.Util.Tuple Crypto.Util.ListUtil. Require Import Crypto.Util.Tactics.BreakMatch. Require Import Crypto.Util.Decidable Crypto.Util.ZUtil. Require Import Crypto.Util.NatUtil. -Require Import Crypto.Util.ZUtil.Definitions. -Require Import Crypto.Util.ZUtil.AddGetCarry. -Require Import Crypto.Util.ZUtil.Zselect. -Require Import Crypto.Util.ZUtil.MulSplit. Require Import Crypto.Util.Tactics.SpecializeBy. Local Notation "A ^ n" := (tuple A n) : type_scope. @@ -107,71 +103,6 @@ check confirms our result. ***) -Module Associational. - Section Associational. - Context {mul_split : Z -> Z -> Z -> Z * Z} (* first argument is where to split output; [mul_split s x y] gives ((x * y) mod s, (x * y) / s) *) - {mul_split_mod : forall s x y, - fst (mul_split s x y) = (x * y) mod s} - {mul_split_div : forall s x y, - snd (mul_split s x y) = (x * y) / s} - . - - Definition multerm_cps s (t t' : B.limb) {T} (f:list B.limb ->T) := - dlet xy := mul_split s (snd t) (snd t') in - f ((fst t * fst t', fst xy) :: (fst t * fst t' * s, snd xy) :: nil). - - Definition multerm s t t' := multerm_cps s t t' id. - Lemma multerm_id s t t' T f : - @multerm_cps s t t' T f = f (multerm s t t'). - Proof. reflexivity. Qed. - Hint Opaque multerm : uncps. - Hint Rewrite multerm_id : uncps. - - Definition mul_cps s (p q : list B.limb) {T} (f : list B.limb -> T) := - flat_map_cps (fun t => @flat_map_cps _ _ (multerm_cps s t) q) p f. - - Definition mul s p q := mul_cps s p q id. - Lemma mul_id s p q T f : @mul_cps s p q T f = f (mul s p q). - Proof. cbv [mul mul_cps]. autorewrite with uncps. reflexivity. Qed. - Hint Opaque mul : uncps. - Hint Rewrite mul_id : uncps. - - Lemma eval_map_multerm s a q (s_nonzero:s<>0): - B.Associational.eval (flat_map (multerm s a) q) = fst a * snd a * B.Associational.eval q. - Proof. - cbv [multerm multerm_cps Let_In]; induction q; - repeat match goal with - | _ => progress (autorewrite with uncps push_id cancel_pair push_basesystem_eval in * ) - | _ => progress simpl flat_map - | _ => progress rewrite ?IHq, ?mul_split_mod, ?mul_split_div - | _ => rewrite Z.mod_eq by assumption - | _ => ring_simplify; omega - end. - Qed. - Hint Rewrite eval_map_multerm using (omega || assumption) - : push_basesystem_eval. - - Lemma eval_mul s p q (s_nonzero:s<>0): - B.Associational.eval (mul s p q) = B.Associational.eval p * B.Associational.eval q. - Proof. - cbv [mul mul_cps]; induction p; [reflexivity|]. - repeat match goal with - | _ => progress (autorewrite with uncps push_id push_basesystem_eval in * ) - | _ => progress simpl flat_map - | _ => rewrite IHp - | _ => progress change (fun x => multerm_cps s a x id) with (multerm s a) - | _ => ring_simplify; omega - end. - Qed. - Hint Rewrite eval_mul : push_basesystem_eval. - - End Associational. -End Associational. -Hint Opaque Associational.mul Associational.multerm : uncps. -Hint Rewrite @Associational.mul_id @Associational.multerm_id : uncps. -Hint Rewrite @Associational.eval_mul @Associational.eval_map_multerm using (omega || assumption) : push_basesystem_eval. - - Module Columns. Section Columns. Context (weight : nat->Z) @@ -480,56 +411,7 @@ Module Columns. rewrite eval_cons_to_nth by omega. nsatz. Qed. End Columns. - Hint Rewrite - @Columns.compact_id - @Columns.from_associational_id - : uncps. - Hint Rewrite - @Columns.compact_mod - @Columns.compact_div - @Columns.eval_from_associational - using (assumption || omega): push_basesystem_eval. - - Section Wrappers. - Context (weight : nat->Z). - - Definition add_cps {n1 n2 n3} (p : Z^n1) (q : Z^n2) - {T} (f : (Z*Z^n3)->T) := - B.Positional.to_associational_cps weight p - (fun P => B.Positional.to_associational_cps weight q - (fun Q => from_associational_cps weight n3 (P++Q) - (fun R => compact_cps (div:=div) (modulo:=modulo) (add_get_carry:=Z.add_get_carry_full) weight R f))). - - Definition unbalanced_sub_cps {n1 n2 n3} (p : Z^n1) (q:Z^n2) - {T} (f : (Z*Z^n3)->T) := - B.Positional.to_associational_cps weight p - (fun P => B.Positional.negate_snd_cps weight q - (fun nq => B.Positional.to_associational_cps weight nq - (fun Q => from_associational_cps weight n3 (P++Q) - (fun R => compact_cps (div:=div) (modulo:=modulo) (add_get_carry:=Z.add_get_carry_full) weight R f)))). - - Definition mul_cps {n1 n2 n3} s (p : Z^n1) (q : Z^n2) - {T} (f : (Z*Z^n3)->T) := - B.Positional.to_associational_cps weight p - (fun P => B.Positional.to_associational_cps weight q - (fun Q => Associational.mul_cps (mul_split := Z.mul_split) s P Q - (fun PQ => from_associational_cps weight n3 PQ - (fun R => compact_cps (div:=div) (modulo:=modulo) (add_get_carry:=Z.add_get_carry_full) weight R f)))). - - Definition conditional_add_cps {n1 n2 n3} mask cond (p:Z^n1) (q:Z^n2) - {T} (f:_->T) := - B.Positional.select_cps mask cond q - (fun qq => add_cps (n3:=n3) p qq f). - - End Wrappers. - Hint Unfold add_cps unbalanced_sub_cps mul_cps conditional_add_cps. - End Columns. -Hint Unfold - Columns.conditional_add_cps - Columns.add_cps - Columns.unbalanced_sub_cps - Columns.mul_cps. Hint Rewrite @Columns.compact_digit_id @Columns.compact_step_id @@ -544,878 +426,3 @@ Hint Rewrite @Columns.eval_from_associational @Columns.eval_nils using (assumption || omega): push_basesystem_eval. - -Section Freeze. - Context (weight : nat->Z) - {weight_0 : weight 0%nat = 1} - {weight_nonzero : forall i, weight i <> 0} - {weight_positive : forall i, weight i > 0} - {weight_multiples : forall i, weight (S i) mod weight i = 0} - {weight_divides : forall i : nat, weight (S i) / weight i > 0} - . - - - (* - The input to [freeze] should be less than 2*m (this can probably - be accomplished by a single carry_reduce step, for most moduli). - - [freeze] has the following steps: - (1) subtract modulus in a carrying loop (in our framework, this - consists of two steps; [Columns.unbalanced_sub_cps] combines the - input p and the modulus m such that the ith limb in the output is - the list [p[i];-m[i]]. We can then call [Columns.compact].) - (2) look at the final carry, which should be either 0 or -1. If - it's -1, then we add the modulus back in. Otherwise we add 0 for - constant-timeness. - (3) discard the carry after this last addition; it should be 1 if - the carry in step 3 was -1, so they cancel out. - *) - Definition freeze_cps {n} mask (m:Z^n) (p:Z^n) {T} (f : Z^n->T) := - Columns.unbalanced_sub_cps (n3:=n) weight p m - (fun carry_p => Columns.conditional_add_cps (n3:=n) weight mask (fst carry_p) (snd carry_p) m - (fun carry_r => f (snd carry_r))) - . - - Definition freeze {n} mask m p := - @freeze_cps n mask m p _ id. - Lemma freeze_id {n} mask m p T f: - @freeze_cps n mask m p T f = f (freeze mask m p). - Proof. - cbv [freeze_cps freeze]; repeat progress autounfold; - autorewrite with uncps push_id; reflexivity. - Qed. - Hint Opaque freeze : uncps. - Hint Rewrite @freeze_id : uncps. - - Lemma freezeZ m s c y y0 z z0 c0 a : - m = s - c -> - 0 < c < s -> - s <> 0 -> - 0 <= y < 2*m -> - y0 = y - m -> - z = y0 mod s -> - c0 = y0 / s -> - z0 = z + (if (dec (c0 = 0)) then 0 else m) -> - a = z0 mod s -> - a mod m = y0 mod m. - Proof. - clear. intros. subst. break_match. - { rewrite Z.add_0_r, Z.mod_mod by omega. - assert (-(s-c) <= y - (s-c) < s-c) by omega. - match goal with H : s <> 0 |- _ => - rewrite (proj2 (Z.mod_small_iff _ s H)) - by (apply Z.div_small_iff; assumption) - end. - reflexivity. } - { rewrite <-Z.add_mod_l, Z.sub_mod_full. - rewrite Z.mod_same, Z.sub_0_r, Z.mod_mod by omega. - rewrite Z.mod_small with (b := s) - by (pose proof (Z.div_small (y - (s-c)) s); omega). - f_equal. ring. } - Qed. - - Lemma eval_freeze {n} c mask m p - (n_nonzero:n<>0%nat) - (Hc : 0 < B.Associational.eval c < weight n) - (Hmask : Tuple.map (Z.land mask) m = m) - modulus (Hm : B.Positional.eval weight m = Z.pos modulus) - (Hp : 0 <= B.Positional.eval weight p < 2*(Z.pos modulus)) - (Hsc : Z.pos modulus = weight n - B.Associational.eval c) - : - mod_eq modulus - (B.Positional.eval weight (@freeze n mask m p)) - (B.Positional.eval weight p). - Proof. - cbv [freeze_cps freeze]. - repeat progress autounfold. - pose proof Z.add_get_carry_full_mod. - pose proof Z.add_get_carry_full_div. - pose proof div_correct. pose proof modulo_correct. - autorewrite with uncps push_id push_basesystem_eval. - - pose proof (weight_nonzero n). - - remember (B.Positional.eval weight p) as y. - remember (y + -B.Positional.eval weight m) as y0. - rewrite Hm in *. - - transitivity y0; cbv [mod_eq]. - { eapply (freezeZ (Z.pos modulus) (weight n) (B.Associational.eval c) y y0); - try assumption; reflexivity. } - { subst y0. - assert (Z.pos modulus <> 0) by auto using Z.positive_is_nonzero, Zgt_pos_0. - rewrite Z.add_mod by assumption. - rewrite Z.mod_opp_l_z by auto using Z.mod_same. - rewrite Z.add_0_r, Z.mod_mod by assumption. - reflexivity. } - Qed. -End Freeze. - -Section UniformWeight. - Context (bound : Z) {bound_pos : bound > 0}. - - Definition uweight : nat -> Z := fun i => bound ^ Z.of_nat i. - Lemma uweight_0 : uweight 0%nat = 1. Proof. reflexivity. Qed. - Lemma uweight_positive i : uweight i > 0. - Proof. apply Z.lt_gt, Z.pow_pos_nonneg; omega. Qed. - Lemma uweight_nonzero i : uweight i <> 0. - Proof. auto using Z.positive_is_nonzero, uweight_positive. Qed. - Lemma uweight_multiples i : uweight (S i) mod uweight i = 0. - Proof. apply Z.mod_same_pow; rewrite Nat2Z.inj_succ; omega. Qed. - Lemma uweight_divides i : uweight (S i) / uweight i > 0. - Proof. - cbv [uweight]. rewrite <-Z.pow_sub_r by (rewrite ?Nat2Z.inj_succ; omega). - apply Z.lt_gt, Z.pow_pos_nonneg; rewrite ?Nat2Z.inj_succ; omega. - Qed. - - (* TODO : move to Positional *) - Lemma eval_from_eq {n} (p:Z^n) wt offset : - (forall i, wt i = uweight (i + offset)) -> - B.Positional.eval wt p = B.Positional.eval_from uweight offset p. - Proof. cbv [B.Positional.eval_from]. auto using B.Positional.eval_wt_equiv. Qed. - - Lemma uweight_eval_from {n} (p:Z^n): forall offset, - B.Positional.eval_from uweight offset p = uweight offset * B.Positional.eval uweight p. - Proof. - induction n; intros; cbv [B.Positional.eval_from]; - [|rewrite (subst_append p)]; - repeat match goal with - | _ => destruct p - | _ => rewrite B.Positional.eval_unit; [ ] - | _ => rewrite B.Positional.eval_step; [ ] - | _ => rewrite IHn; [ ] - | _ => rewrite eval_from_eq with (offset0:=S offset) - by (intros; f_equal; omega) - | _ => rewrite eval_from_eq with - (wt:=fun i => uweight (S i)) (offset0:=1%nat) - by (intros; f_equal; omega) - | _ => ring - end. - repeat match goal with - | _ => cbv [uweight]; progress autorewrite with natsimplify - | _ => progress (rewrite ?Nat2Z.inj_succ, ?Nat2Z.inj_0, ?Z.pow_0_r) - | _ => rewrite !Z.pow_succ_r by (try apply Nat2Z.is_nonneg; omega) - | _ => ring - end. - Qed. - - Lemma uweight_eval_step {n} (p:Z^S n): - B.Positional.eval uweight p = hd p + bound * B.Positional.eval uweight (tl p). - Proof. - rewrite (subst_append p) at 1; rewrite B.Positional.eval_step. - rewrite eval_from_eq with (offset := 1%nat) by (intros; f_equal; omega). - rewrite uweight_eval_from. cbv [uweight]; rewrite Z.pow_0_r, Z.pow_1_r. - ring. - Qed. - - Definition small {n} (p : Z^n) : Prop := - forall x, In x (to_list _ p) -> 0 <= x < bound. - -End UniformWeight. - -Module Positional. - Section Positional. - Context {s:Z}. (* s is bitwidth *) - Let small {n} := @small s n. - Section GenericOp. - Context {op : Z -> Z -> Z} - {op_get_carry : Z -> Z -> Z * Z} (* no carry in, carry out *) - {op_with_carry : Z -> Z -> Z -> Z * Z}. (* carry in, carry out *) - - Fixpoint chain_op'_cps {n}: - option Z->Z^n->Z^n->forall T, (Z*Z^n->T)->T := - match n with - | O => fun c p _ _ f => - let carry := match c with | None => 0 | Some x => x end in - f (carry,p) - | S n' => - fun c p q _ f => - (* for the first call, use op_get_carry, then op_with_carry *) - let op' := match c with - | None => op_get_carry - | Some x => op_with_carry x end in - dlet carry_result := op' (hd p) (hd q) in - chain_op'_cps (Some (snd carry_result)) (tl p) (tl q) _ - (fun carry_pq => - f (fst carry_pq, - append (fst carry_result) (snd carry_pq))) - end. - Definition chain_op' {n} c p q := @chain_op'_cps n c p q _ id. - Definition chain_op_cps {n} p q {T} f := @chain_op'_cps n None p q T f. - Definition chain_op {n} p q : Z * Z^n := chain_op_cps p q id. - - Lemma chain_op'_id {n} : forall c p q T f, - @chain_op'_cps n c p q T f = f (chain_op' c p q). - Proof. - cbv [chain_op']; induction n; intros; destruct c; - simpl chain_op'_cps; cbv [Let_In]; try reflexivity. - { etransitivity; rewrite IHn; reflexivity. } - { etransitivity; rewrite IHn; reflexivity. } - Qed. - - Lemma chain_op_id {n} p q T f : - @chain_op_cps n p q T f = f (chain_op p q). - Proof. apply chain_op'_id. Qed. - End GenericOp. - - Section AddSub. - Let eval {n} := B.Positional.eval (n:=n) (uweight s). - - Definition sat_add_cps {n} p q T (f:Z*Z^n->T) := - chain_op_cps (op_get_carry := Z.add_get_carry_full s) - (op_with_carry := Z.add_with_get_carry_full s) - p q f. - Definition sat_add {n} p q := @sat_add_cps n p q _ id. - - Lemma sat_add_id n p q T f : - @sat_add_cps n p q T f = f (sat_add p q). - Proof. cbv [sat_add sat_add_cps]. rewrite !chain_op_id. reflexivity. Qed. - - Lemma sat_add_mod n p q : - eval (snd (@sat_add n p q)) = (eval p + eval q) mod (uweight s n). - Admitted. - - Lemma sat_add_div n p q : - fst (@sat_add n p q) = (eval p + eval q) / (uweight s n). - Admitted. - - Lemma small_sat_add n p q : small (snd (@sat_add n p q)). - Admitted. - - Definition sat_sub_cps {n} p q T (f:Z*Z^n->T) := - chain_op_cps (op_get_carry := Z.sub_get_borrow_full s) - (op_with_carry := Z.sub_with_get_borrow_full s) - p q f. - Definition sat_sub {n} p q := @sat_sub_cps n p q _ id. - - Lemma sat_sub_id n p q T f : - @sat_sub_cps n p q T f = f (sat_sub p q). - Proof. cbv [sat_sub sat_sub_cps]. rewrite !chain_op_id. reflexivity. Qed. - - Lemma sat_sub_mod n p q : - eval (snd (@sat_sub n p q)) = (eval p - eval q) mod (uweight s n). - Admitted. - - Lemma sat_sub_div n p q : - fst (@sat_sub n p q) = - ((eval p - eval q) / uweight s n). - Admitted. - - Lemma small_sat_sub n p q : small (snd (@sat_sub n p q)). - Admitted. - - End AddSub. - End Positional. -End Positional. -Hint Opaque Positional.sat_sub Positional.sat_add Positional.chain_op Positional.chain_op' : uncps. -Hint Rewrite @Positional.sat_sub_id @Positional.sat_add_id @Positional.chain_op_id @Positional.chain_op' : uncps. -Hint Rewrite @Positional.sat_sub_mod @Positional.sat_sub_div @Positional.sat_add_mod @Positional.sat_add_div using (omega || assumption) : push_basesystem_eval. - -Section API. - Context (bound : Z) {bound_pos : bound > 0}. - Definition T : nat -> Type := tuple Z. - - (* lowest limb is less than its bound; this is required for [divmod] - to simply separate the lowest limb from the rest and be equivalent - to normal div/mod with [bound]. *) - Local Notation small := (@small bound). - - Definition zero {n:nat} : T n := B.Positional.zeros n. - - (** Returns 0 iff all limbs are 0 *) - Definition nonzero_cps {n} (p : T n) {cpsT} (f : Z -> cpsT) : cpsT - := CPSUtil.to_list_cps _ p (fun p => CPSUtil.fold_right_cps runtime_lor 0%Z p f). - Definition nonzero {n} (p : T n) : Z - := nonzero_cps p id. - - Definition join0_cps {n:nat} (p : T n) {R} (f:T (S n) -> R) - := Tuple.left_append_cps 0 p f. - Definition join0 {n} p : T (S n) := @join0_cps n p _ id. - - Definition divmod_cps {n} (p : T (S n)) {R} (f:T n * Z->R) : R - := Tuple.tl_cps p (fun d => Tuple.hd_cps p (fun m => f (d, m))). - Definition divmod {n} p : T n * Z := @divmod_cps n p _ id. - - Definition drop_high_cps {n : nat} (p : T (S n)) {R} (f:T n->R) - := Tuple.left_tl_cps p f. - Definition drop_high {n} p : T n := @drop_high_cps n p _ id. - - Definition scmul_cps {n} (c : Z) (p : T n) {R} (f:T (S n)->R) := - Columns.mul_cps (n1:=1) (n3:=S n) (uweight bound) bound c p - (* The carry that comes out of Columns.mul_cps will be 0, since - (S n) limbs is enough to hold the result of the - multiplication, so we can safely discard it. *) - (fun carry_result =>f (snd carry_result)). - Definition scmul {n} c p : T (S n) := @scmul_cps n c p _ id. - - Definition add_cps {n} (p q: T n) {R} (f:T (S n)->R) := - Positional.sat_add_cps (s:=bound) p q _ - (* join the last carry *) - (fun carry_result => Tuple.left_append_cps (fst carry_result) (snd carry_result) f). - Definition add {n} p q : T (S n) := @add_cps n p q _ id. - - (* Wrappers for additions with slightly uneven limb counts *) - Definition add_S1_cps {n} (p: T (S n)) (q: T n) {R} (f:T (S (S n))->R) := - join0_cps q (fun Q => add_cps p Q f). - Definition add_S1 {n} p q := @add_S1_cps n p q _ id. - Definition add_S2_cps {n} (p: T n) (q: T (S n)) {R} (f:T (S (S n))->R) := - join0_cps p (fun P => add_cps P q f). - Definition add_S2 {n} p q := @add_S2_cps n p q _ id. ->>>>>>> addsubchains - - Definition sub_then_maybe_add_cps {n} mask (p q r : T n) - {R} (f:T n -> R) := - Positional.sat_sub_cps (s:=bound) p q _ - (* the carry will be 0 unless we underflow--we do the addition only - in the underflow case *) - (fun carry_result => - B.Positional.select_cps mask (fst carry_result) r - (fun selected => join0_cps selected - (fun selected' => - Positional.sat_sub_cps (s:=bound) (left_append (fst carry_result) (snd carry_result)) selected' _ - (* We can now safely discard the carry and the highest digit. - This relies on the precondition that p - q + r < bound^n. *) - (fun carry_result' => drop_high_cps (snd carry_result') f)))). - Definition sub_then_maybe_add {n} mask (p q r : T n) := - sub_then_maybe_add_cps mask p q r id. - - (* Subtract q if and only if p >= q. We rely on the preconditions - that 0 <= p < 2*q and q < bound^n (this ensures the output is less - than bound^n). *) - Definition conditional_sub_cps {n} (p:Z^S n) (q:Z^n) R (f:Z^n->R) := - join0_cps q - (fun qq => Positional.sat_sub_cps (s:=bound) p qq _ - (* if carry is zero, we select the result of the subtraction, - otherwise the first input *) - (fun carry_result => - Tuple.map2_cps (Z.zselect (fst carry_result)) (snd carry_result) p - (* in either case, since our result must be < q and therefore < - bound^n, we can drop the high digit *) - (fun r => drop_high_cps r f))). - Definition conditional_sub {n} p q := @conditional_sub_cps n p q _ id. - - Hint Opaque join0 divmod drop_high scmul add sub_then_maybe_add conditional_sub : uncps. - - Section CPSProofs. - - Local Ltac prove_id := - repeat autounfold; autorewrite with uncps; reflexivity. - - Lemma nonzero_id n p {cpsT} f : @nonzero_cps n p cpsT f = f (@nonzero n p). - Proof. cbv [nonzero nonzero_cps]. prove_id. Qed. - - Lemma join0_id n p R f : - @join0_cps n p R f = f (join0 p). - Proof. cbv [join0_cps join0]. prove_id. Qed. - - Lemma divmod_id n p R f : - @divmod_cps n p R f = f (divmod p). - Proof. cbv [divmod_cps divmod]; prove_id. Qed. - - Lemma drop_high_id n p R f : - @drop_high_cps n p R f = f (drop_high p). - Proof. cbv [drop_high_cps drop_high]; prove_id. Qed. - Hint Rewrite drop_high_id : uncps. - - Lemma scmul_id n c p R f : - @scmul_cps n c p R f = f (scmul c p). - Proof. cbv [scmul_cps scmul]. prove_id. Qed. - - Lemma add_id n p q R f : - @add_cps n p q R f = f (add p q). - Proof. cbv [add_cps add Let_In]. prove_id. Qed. - Hint Rewrite add_id : uncps. - - Lemma add_S1_id n p q R f : - @add_S1_cps n p q R f = f (add_S1 p q). - Proof. cbv [add_S1_cps add_S1 join0_cps]. prove_id. Qed. - - Lemma add_S2_id n p q R f : - @add_S2_cps n p q R f = f (add_S2 p q). - Proof. cbv [add_S2_cps add_S2 join0_cps]. prove_id. Qed. - - Lemma sub_then_maybe_add_id n mask p q r R f : - @sub_then_maybe_add_cps n mask p q r R f = f (sub_then_maybe_add mask p q r). - Proof. cbv [sub_then_maybe_add_cps sub_then_maybe_add join0_cps Let_In]. prove_id. Qed. - - Lemma conditional_sub_id n p q R f : - @conditional_sub_cps n p q R f = f (conditional_sub p q). - Proof. cbv [conditional_sub_cps conditional_sub join0_cps Let_In]. prove_id. Qed. - - End CPSProofs. - Hint Rewrite nonzero_id join0_id divmod_id drop_high_id scmul_id add_id sub_then_maybe_add_id conditional_sub_id : uncps. - - Section Proofs. - - Definition eval {n} (p : T n) : Z := - B.Positional.eval (uweight bound) p. - - Lemma eval_small n (p : T n) (Hsmall : small p) : - 0 <= eval p < uweight bound n. - Proof. - cbv [small eval] in *; intros. - induction n; cbv [T uweight] in *; [destruct p|rewrite (subst_left_append p)]; - repeat match goal with - | _ => progress autorewrite with push_basesystem_eval - | _ => rewrite Z.pow_0_r - | _ => specialize (IHn (left_tl p)) - | _ => - let H := fresh "H" in - match type of IHn with - ?P -> _ => assert P as H by auto using Tuple.In_to_list_left_tl; - specialize (IHn H) - end - | |- context [?b ^ Z.of_nat (S ?n)] => - replace (b ^ Z.of_nat (S n)) with (b ^ Z.of_nat n * b) by - (rewrite Nat2Z.inj_succ, <-Z.add_1_r, Z.pow_add_r, - Z.pow_1_r by (omega || auto using Nat2Z.is_nonneg); - reflexivity) - | _ => omega - end. - - specialize (Hsmall _ (Tuple.In_left_hd _ p)). - split; [Z.zero_bounds; omega |]. - apply Z.lt_le_trans with (m:=bound^Z.of_nat n * (left_hd p+1)). - { rewrite Z.mul_add_distr_l. - apply Z.add_le_lt_mono; omega. } - { apply Z.mul_le_mono_nonneg; omega. } - Qed. - - Lemma eval_zero n : eval (@zero n) = 0. - Proof. - cbv [eval zero]. - autorewrite with push_basesystem_eval. - reflexivity. - Qed. - - Lemma small_zero n : small (@zero n). - Proof. - cbv [zero small B.Positional.zeros]. destruct n; [simpl;tauto|]. - rewrite to_list_repeat. - intros x H; apply repeat_spec in H; subst x; omega. - Qed. - - Lemma eval_pair n (p : T (S (S n))) : small p -> (snd p = 0 /\ eval (n:=S n) (fst p) = 0) <-> eval p = 0. - Admitted. - - Lemma eval_nonzero n p : small p -> @nonzero n p = 0 <-> eval p = 0. - Proof. - destruct n as [|n]. - { compute; split; trivial. } - induction n as [|n IHn]. - { simpl; rewrite Z.lor_0_r; unfold eval, id. - cbv -[Z.add iff]. - rewrite Z.add_0_r. - destruct p; omega. } - { destruct p as [ps p]; specialize (IHn ps). - unfold nonzero, nonzero_cps in *. - autorewrite with uncps in *. - unfold id in *. - setoid_rewrite to_list_S. - set (k := S n) in *; simpl in *. - intro Hsmall. - rewrite Z.lor_eq_0_iff, IHn - by (hnf in Hsmall |- *; simpl in *; eauto); - clear IHn. - exact (eval_pair n (ps, p) Hsmall). } - Qed. - - Lemma eval_join0 n p - : eval (@join0 n p) = eval p. - Proof. - Admitted. - - Local Ltac pose_uweight bound := - match goal with H : bound > 0 |- _ => - pose proof (uweight_0 bound); - pose proof (@uweight_positive bound H); - pose proof (@uweight_nonzero bound H); - pose proof (@uweight_multiples bound); - pose proof (@uweight_divides bound H) - end. - - Local Ltac pose_all := - pose_uweight bound; - pose proof Z.add_get_carry_full_div; - pose proof Z.add_get_carry_full_mod; - pose proof Z.mul_split_div; pose proof Z.mul_split_mod; - pose proof div_correct; pose proof modulo_correct. - - Lemma eval_add_nz n p q : - n <> 0%nat -> - eval (@add n p q) = eval p + eval q. - Proof. - intros. pose_all. - repeat match goal with - | _ => progress (cbv [add_cps add eval Let_In] in *; repeat autounfold) - | _ => progress autorewrite with uncps push_id cancel_pair push_basesystem_eval - | _ => rewrite B.Positional.eval_left_append - - | _ => progress - (rewrite <-!from_list_default_eq with (d:=0); - erewrite !length_to_list, !from_list_default_eq, - from_list_to_list) - | _ => apply Z.mod_small; omega - end. - Admitted. - - Lemma eval_add_z n p q : - n = 0%nat -> - eval (@add n p q) = eval p + eval q. - Proof. intros; subst; reflexivity. Qed. - - Lemma eval_add n p q - : eval (@add n p q) = eval p + eval q. - Proof. - destruct (Nat.eq_dec n 0%nat); intuition auto using eval_add_z, eval_add_nz. - Qed. - Lemma eval_add_same n p q - : eval (@add n p q) = eval p + eval q. - Proof. apply eval_add; omega. Qed. - Lemma eval_add_S1 n p q - : eval (@add_S1 n p q) = eval p + eval q. - Proof. - cbv [add_S1 add_S1_cps]. autorewrite with uncps push_id. - (*rewrite eval_add; rewrite eval_join0; [reflexivity|assumption].*) - Admitted. - Lemma eval_add_S2 n p q - : eval (@add_S2 n p q) = eval p + eval q. - Proof. - cbv [add_S2 add_S2_cps]. autorewrite with uncps push_id. - (*rewrite eval_add; rewrite eval_join0; [reflexivity|assumption].*) - Admitted. ->>>>>>> addsubchains - Hint Rewrite eval_add_same eval_add_S1 eval_add_S2 using (omega || assumption): push_basesystem_eval. - - Lemma uweight_le_mono n m : (n <= m)%nat -> - uweight bound n <= uweight bound m. - Proof. - unfold uweight; intro; Z.peel_le; omega. - Qed. - - Lemma uweight_lt_mono (bound_gt_1 : bound > 1) n m : (n < m)%nat -> - uweight bound n < uweight bound m. - Proof. - clear bound_pos. - unfold uweight; intro; apply Z.pow_lt_mono_r; omega. - Qed. - - Lemma uweight_succ n : uweight bound (S n) = bound * uweight bound n. - Proof. - unfold uweight. - rewrite Nat2Z.inj_succ, Z.pow_succ_r by auto using Nat2Z.is_nonneg; reflexivity. - Qed. - - Local Definition compact {n} := Columns.compact (n:=n) (add_get_carry:=Z.add_get_carry_full) (div:=div) (modulo:=modulo) (uweight bound). - Local Definition compact_digit := Columns.compact_digit (add_get_carry:=Z.add_get_carry_full) (div:=div) (modulo:=modulo) (uweight bound). - Lemma small_compact {n} (p:(list Z)^n) : small (snd (compact p)). - Proof. - pose_all. - match goal with - |- ?G => assert (G /\ fst (compact p) = fst (compact p)); [|tauto] - end. (* assert a dummy second statement so that fst (compact x) is in context *) - cbv [compact Columns.compact Columns.compact_cps small - Columns.compact_step Columns.compact_step_cps]; - autorewrite with uncps push_id. - change (fun i s a => Columns.compact_digit_cps (uweight bound) i (s :: a) id) - with (fun i s a => compact_digit i (s :: a)). - remember (fun i s a => compact_digit i (s :: a)) as f. - - apply @mapi_with'_linvariant with (n:=n) (f:=f) (inp:=p); - intros; [|simpl; tauto]. split; [|reflexivity]. - let P := fresh "H" in - match goal with H : _ /\ _ |- _ => destruct H end. - destruct n0; subst f. - { cbv [compact_digit uweight to_list to_list' In]. - rewrite Columns.compact_digit_mod by assumption. - rewrite Z.pow_0_r, Z.pow_1_r, Z.div_1_r. intros x ?. - match goal with - H : _ \/ False |- _ => destruct H; [|exfalso; assumption] end. - subst x. apply Z.mod_pos_bound, Z.gt_lt, bound_pos. } - { rewrite Tuple.to_list_left_append. - let H := fresh "H" in - intros x H; apply in_app_or in H; destruct H; - [solve[auto]| cbv [In] in H; destruct H; - [|exfalso; assumption] ]. - subst x. cbv [compact_digit]. - rewrite Columns.compact_digit_mod by assumption. - rewrite !uweight_succ, Z.div_mul by - (apply Z.neq_mul_0; split; auto; omega). - apply Z.mod_pos_bound, Z.gt_lt, bound_pos. } - Qed. - - Lemma small_add n a b : - (2 <= bound) -> - small a -> small b -> small (@add n a b). - Proof. - intros. pose_all. - cbv [add_cps add Let_In]. - autorewrite with uncps push_id. - apply Positional.small_sat_add. - (*apply Positional.small_sat_add.*) - Admitted. - - Lemma small_add_S1 n a b : - (2 <= bound) -> - small a -> small b -> small (@add_S1 n a b). - Proof. - intros. pose_all. - cbv [add_cps add add_S1 Let_In]. - autorewrite with uncps push_id. - (*apply Positional.small_sat_add.*) - Admitted. - - Lemma small_add_S2 n a b : - (2 <= bound) -> - small a -> small b -> small (@add_S2 n a b). - Proof. - intros. pose_all. - cbv [add_cps add add_S2 Let_In]. - autorewrite with uncps push_id. - (*apply Positional.small_sat_add.*) ->>>>>>> addsubchains - Admitted. - - Lemma small_left_tl n (v:T (S n)) : small v -> small (left_tl v). - Proof. cbv [small]. auto using Tuple.In_to_list_left_tl. Qed. - - Lemma small_divmod n (p: T (S n)) (Hsmall : small p) : - left_hd p = eval p / uweight bound n /\ eval (left_tl p) = eval p mod (uweight bound n). - Admitted. - - Lemma eval_drop_high n v : - small v -> eval (@drop_high n v) = eval v mod (uweight bound n). - Proof. - cbv [drop_high drop_high_cps eval]. - rewrite Tuple.left_tl_cps_correct, push_id. (* TODO : for some reason autorewrite with uncps doesn't work here *) - intro H. apply small_left_tl in H. - rewrite (subst_left_append v) at 2. - autorewrite with push_basesystem_eval. - apply eval_small in H. - rewrite Z.mod_add_l' by (pose_uweight bound; auto). - rewrite Z.mod_small; auto. - Qed. - - Lemma small_drop_high n v : small v -> small (@drop_high n v). - Proof. - cbv [drop_high drop_high_cps]. - rewrite Tuple.left_tl_cps_correct, push_id. - apply small_left_tl. - Qed. - - Lemma div_nonzero_neg_iff x y : x < y -> 0 < y -> x / y <> 0 <-> x < 0. - Proof. - repeat match goal with - | _ => progress intros - | _ => rewrite Z.div_small_iff by omega - | _ => split - | _ => omega - end. - Qed. - - Lemma eval_sub_then_maybe_add_nz n mask p q r: - small p -> small q -> small r -> (n<>0)%nat -> - (map (Z.land mask) r = r) -> - (0 <= eval p < eval r) -> (0 <= eval q < eval r) -> - eval (@sub_then_maybe_add n mask p q r) = eval p - eval q + (if eval p - eval q <? 0 then eval r else 0). - Proof. - pose_all. - repeat match goal with - | _ => progress (cbv [sub_then_maybe_add sub_then_maybe_add_cps eval] in *; intros) - | _ => progress autounfold - | _ => progress autorewrite with uncps push_id push_basesystem_eval - | _ => rewrite eval_drop_high - | _ => rewrite eval_join0 - | H : small _ |- _ => apply eval_small in H - | _ => progress break_match - | _ => (rewrite Z.add_opp_r in * ) - | H : _ |- _ => rewrite Z.ltb_lt in H; - rewrite <-div_nonzero_neg_iff with - (y:=uweight bound n) in H by (auto; omega) - | H : _ |- _ => rewrite Z.ltb_ge in H - | _ => rewrite Z.mod_small by omega - | _ => omega - | _ => progress autorewrite with zsimplify; [ ] - end. - Admitted. - - Lemma eval_sub_then_maybe_add n mask p q r : - small p -> small q -> small r -> - (map (Z.land mask) r = r) -> - (0 <= eval p < eval r) -> (0 <= eval q < eval r) -> - eval (@sub_then_maybe_add n mask p q r) = eval p - eval q + (if eval p - eval q <? 0 then eval r else 0). - Proof. - destruct n; [|solve[auto using eval_sub_then_maybe_add_nz]]. - destruct p, q, r; reflexivity. - Qed. - - Lemma small_sub_then_maybe_add n mask (p q r : T n) : - small (sub_then_maybe_add mask p q r). - Proof. - cbv [sub_then_maybe_add_cps sub_then_maybe_add]; intros. - repeat progress autounfold. autorewrite with uncps push_id. - apply small_drop_high, Positional.small_sat_sub. - Qed. - - (* TODO : remove if unneeded when all admits are proven - Lemma small_highest_zero_iff {n} (p: T (S n)) (Hsmall : small p) : - (left_hd p = 0 <-> eval p < uweight bound n). - Proof. - destruct (small_divmod _ p Hsmall) as [Hdiv Hmod]. - pose proof Hsmall as Hsmalltl. apply eval_small in Hsmall. - apply small_left_tl, eval_small in Hsmalltl. rewrite Hdiv. - rewrite (Z.div_small_iff (eval p) (uweight bound n)) - by auto using uweight_nonzero. - split; [|intros; left; omega]. - let H := fresh "H" in intro H; destruct H; [|omega]. - omega. - Qed. - *) - - Lemma map2_zselect n cond x y : - Tuple.map2 (n:=n) (Z.zselect cond) x y = if dec (cond = 0) then x else y. - Proof. - unfold Z.zselect. - break_innermost_match; Z.ltb_to_lt; subst; try omega; - [ rewrite Tuple.map2_fst, Tuple.map_id - | rewrite Tuple.map2_snd, Tuple.map_id ]; - reflexivity. - Qed. - - Lemma eval_conditional_sub_nz n (p:T (S n)) (q:T n) - (n_nonzero: (n <> 0)%nat) (psmall : small p) (qsmall : small q): - 0 <= eval p < eval q + uweight bound n -> - eval (conditional_sub p q) = eval p + (if eval q <=? eval p then - eval q else 0). - Proof. - cbv [conditional_sub conditional_sub_cps]. intros. pose_all. - repeat autounfold. apply eval_small in qsmall. - pose proof psmall; apply eval_small in psmall. - cbv [eval] in *. autorewrite with uncps push_id push_basesystem_eval. - rewrite map2_zselect. - let H := fresh "H" in let X := fresh "P" in - match goal with |- context [?x / ?y] => - pose proof (div_nonzero_neg_iff x y) end; - repeat match type of H with ?P -> _ => - assert P as X by omega; specialize (H X); - clear X end. - - break_match; - repeat match goal with - | _ => progress cbv [eval] - | H : (_ <=? _) = true |- _ => apply Z.leb_le in H - | H : (_ <=? _) = false |- _ => apply Z.leb_gt in H - | _ => rewrite eval_drop_high by auto using Positional.small_sat_sub - | _ => (rewrite eval_join0 in * ) - | _ => progress autorewrite with uncps push_id push_basesystem_eval - | _ => repeat rewrite Z.mod_small; omega - | _ => omega - end. - Admitted. - - Lemma eval_conditional_sub n (p:T (S n)) (q:T n) - (psmall : small p) (qsmall : small q) : - 0 <= eval p < eval q + uweight bound n -> - eval (conditional_sub p q) = eval p + (if eval q <=? eval p then - eval q else 0). - Proof. - destruct n; [|solve[auto using eval_conditional_sub_nz]]. - repeat match goal with - | _ => progress (intros; cbv [T tuple tuple'] in p, q) - | q : unit |- _ => destruct q - | _ => progress (cbv [conditional_sub conditional_sub_cps eval] in * ) - | _ => progress autounfold - | _ => progress (autorewrite with uncps push_id push_basesystem_eval in * ) - | _ => (rewrite uweight_0 in * ) - | _ => assert (p = 0) by omega; subst p; break_match; ring - end. - Qed. - - Lemma small_conditional_sub n (p:T (S n)) (q:T n) - (psmall : small p) (qsmall : small q) : - 0 <= eval p < eval q + uweight bound n -> - small (conditional_sub p q). - Admitted. - - Lemma eval_scmul n a v : small v -> 0 <= a < bound -> - eval (@scmul n a v) = a * eval v. - Proof. - intro Hsmall. pose_all. apply eval_small in Hsmall. - intros. cbv [scmul scmul_cps eval] in *. repeat autounfold. - autorewrite with uncps push_id push_basesystem_eval. - rewrite uweight_0, Z.mul_1_l. apply Z.mod_small. - split; [solve[Z.zero_bounds]|]. cbv [uweight] in *. - rewrite !Nat2Z.inj_succ, Z.pow_succ_r by auto using Nat2Z.is_nonneg. - apply Z.mul_lt_mono_nonneg; omega. - Qed. - - Lemma small_scmul n a v : small (@scmul n a v). - Proof. - cbv [scmul scmul_cps eval] in *. repeat autounfold. - autorewrite with uncps push_id push_basesystem_eval. - apply small_compact. - Qed. - - (* TODO : move to tuple *) - Lemma from_list_tl {A n} (ls : list A) H H': - from_list n (List.tl ls) H = tl (from_list (S n) ls H'). - Proof. - induction ls; distr_length. simpl List.tl. - rewrite from_list_cons, tl_append, <-!(from_list_default_eq a ls). - reflexivity. - Qed. - - Lemma small_hd n p : @small (S n) p -> 0 <= hd p < bound. - Proof. - cbv [small]. let H := fresh "H" in intro H; apply H. - rewrite (subst_append p). rewrite to_list_append, hd_append. - apply in_eq. - Qed. - - - Lemma eval_div n p : small p -> eval (fst (@divmod n p)) = eval p / bound. - Proof. - cbv [divmod divmod_cps eval]. intros. - autorewrite with uncps push_id cancel_pair. - rewrite (subst_append p) at 2. - rewrite uweight_eval_step. rewrite hd_append, tl_append. - rewrite Z.div_add' by omega. rewrite Z.div_small by auto using small_hd. - ring. - Qed. - - Lemma eval_mod n p : small p -> snd (@divmod n p) = eval p mod bound. - Proof. - cbv [divmod divmod_cps eval]. intros. - autorewrite with uncps push_id cancel_pair. - rewrite (subst_append p) at 2. - rewrite uweight_eval_step, Z.mod_add'_full, hd_append. - rewrite Z.mod_small by auto using small_hd. reflexivity. - Qed. - - Lemma small_div n v : small v -> small (fst (@divmod n v)). - Admitted. - - End Proofs. -End API. -Hint Rewrite nonzero_id join0_id divmod_id drop_high_id scmul_id add_id add_S1_id add_S2_id sub_then_maybe_add_id conditional_sub_id : uncps. - -(* -(* Just some pretty-printing *) -Local Notation "fst~ a" := (let (x,_) := a in x) (at level 40, only printing). -Local Notation "snd~ a" := (let (_,y) := a in y) (at level 40, only printing). - -(* Simple example : base 10, multiply two bignums and compact them *) -Definition base10 i := Eval compute in 10^(Z.of_nat i). -Eval cbv -[runtime_add runtime_mul Let_In] in - (fun adc a0 a1 a2 b0 b1 b2 => - Columns.mul_cps (weight := base10) (n:=3) (a2,a1,a0) (b2,b1,b0) (fun ab => Columns.compact (n:=5) (add_get_carry:=adc) (weight:=base10) ab)). - -(* More complex example : base 2^56, 8 limbs *) -Definition base2pow56 i := Eval compute in 2^(56*Z.of_nat i). -Time Eval cbv -[runtime_add runtime_mul Let_In] in - (fun adc a0 a1 a2 a3 a4 a5 a6 a7 b0 b1 b2 b3 b4 b5 b6 b7 => - Columns.mul_cps (weight := base2pow56) (n:=8) (a7,a6,a5,a4,a3,a2,a1,a0) (b7,b6,b5,b4,b3,b2,b1,b0) (fun ab => Columns.compact (n:=15) (add_get_carry:=adc) (weight:=base2pow56) ab)). (* Finished transaction in 151.392 secs *) - -(* Mixed-radix example : base 2^25.5, 10 limbs *) -Definition base2pow25p5 i := Eval compute in 2^(25*Z.of_nat i + ((Z.of_nat i + 1) / 2)). -Time Eval cbv -[runtime_add runtime_mul Let_In] in - (fun adc a0 a1 a2 a3 a4 a5 a6 a7 a8 a9 b0 b1 b2 b3 b4 b5 b6 b7 b8 b9 => - Columns.mul_cps (weight := base2pow25p5) (n:=10) (a9,a8,a7,a6,a5,a4,a3,a2,a1,a0) (b9,b8,b7,b6,b5,b4,b3,b2,b1,b0) (fun ab => Columns.compact (n:=19) (add_get_carry:=adc) (weight:=base2pow25p5) ab)). (* Finished transaction in 97.341 secs *) -*)
\ No newline at end of file diff --git a/src/Arithmetic/Saturated/Freeze.v b/src/Arithmetic/Saturated/Freeze.v new file mode 100644 index 000000000..735663636 --- /dev/null +++ b/src/Arithmetic/Saturated/Freeze.v @@ -0,0 +1,122 @@ +Require Import Coq.ZArith.ZArith. +Require Import Coq.Lists.List. +Local Open Scope Z_scope. + +Require Import Crypto.Arithmetic.Core. +Require Import Crypto.Arithmetic.Saturated.Core. +Require Import Crypto.Arithmetic.Saturated.Wrappers. +Require Import Crypto.Util.ZUtil.AddGetCarry. +Require Import Crypto.Util.ZUtil.Definitions. +Require Import Crypto.Util.Tactics.BreakMatch. +Require Import Crypto.Util.Decidable Crypto.Util.ZUtil. +Require Import Crypto.Util.Tuple Crypto.Util.LetIn. +Local Notation "A ^ n" := (tuple A n) : type_scope. + +(* Canonicalize bignums by fully reducing them modulo p. + This works on unsaturated digits, but uses saturated add/subtract + loops.*) +Section Freeze. + Context (weight : nat->Z) + {weight_0 : weight 0%nat = 1} + {weight_nonzero : forall i, weight i <> 0} + {weight_positive : forall i, weight i > 0} + {weight_multiples : forall i, weight (S i) mod weight i = 0} + {weight_divides : forall i : nat, weight (S i) / weight i > 0} + . + + + (* + The input to [freeze] should be less than 2*m (this can probably + be accomplished by a single carry_reduce step, for most moduli). + + [freeze] has the following steps: + (1) subtract modulus in a carrying loop (in our framework, this + consists of two steps; [Columns.unbalanced_sub_cps] combines the + input p and the modulus m such that the ith limb in the output is + the list [p[i];-m[i]]. We can then call [Columns.compact].) + (2) look at the final carry, which should be either 0 or -1. If + it's -1, then we add the modulus back in. Otherwise we add 0 for + constant-timeness. + (3) discard the carry after this last addition; it should be 1 if + the carry in step 3 was -1, so they cancel out. + *) + Definition freeze_cps {n} mask (m:Z^n) (p:Z^n) {T} (f : Z^n->T) := + Columns.unbalanced_sub_cps (n3:=n) weight p m + (fun carry_p => Columns.conditional_add_cps (n3:=n) weight mask (fst carry_p) (snd carry_p) m + (fun carry_r => f (snd carry_r))) + . + + Definition freeze {n} mask m p := + @freeze_cps n mask m p _ id. + Lemma freeze_id {n} mask m p T f: + @freeze_cps n mask m p T f = f (freeze mask m p). + Proof. + cbv [freeze_cps freeze]; repeat progress autounfold; + autorewrite with uncps push_id; reflexivity. + Qed. + Hint Opaque freeze : uncps. + Hint Rewrite @freeze_id : uncps. + + Lemma freezeZ m s c y y0 z z0 c0 a : + m = s - c -> + 0 < c < s -> + s <> 0 -> + 0 <= y < 2*m -> + y0 = y - m -> + z = y0 mod s -> + c0 = y0 / s -> + z0 = z + (if (dec (c0 = 0)) then 0 else m) -> + a = z0 mod s -> + a mod m = y0 mod m. + Proof. + clear. intros. subst. break_match. + { rewrite Z.add_0_r, Z.mod_mod by omega. + assert (-(s-c) <= y - (s-c) < s-c) by omega. + match goal with H : s <> 0 |- _ => + rewrite (proj2 (Z.mod_small_iff _ s H)) + by (apply Z.div_small_iff; assumption) + end. + reflexivity. } + { rewrite <-Z.add_mod_l, Z.sub_mod_full. + rewrite Z.mod_same, Z.sub_0_r, Z.mod_mod by omega. + rewrite Z.mod_small with (b := s) + by (pose proof (Z.div_small (y - (s-c)) s); omega). + f_equal. ring. } + Qed. + + Lemma eval_freeze {n} c mask m p + (n_nonzero:n<>0%nat) + (Hc : 0 < B.Associational.eval c < weight n) + (Hmask : Tuple.map (Z.land mask) m = m) + modulus (Hm : B.Positional.eval weight m = Z.pos modulus) + (Hp : 0 <= B.Positional.eval weight p < 2*(Z.pos modulus)) + (Hsc : Z.pos modulus = weight n - B.Associational.eval c) + : + mod_eq modulus + (B.Positional.eval weight (@freeze n mask m p)) + (B.Positional.eval weight p). + Proof. + cbv [freeze_cps freeze]. + repeat progress autounfold. + pose proof Z.add_get_carry_full_mod. + pose proof Z.add_get_carry_full_div. + pose proof div_correct. pose proof modulo_correct. + autorewrite with uncps push_id push_basesystem_eval. + + pose proof (weight_nonzero n). + + remember (B.Positional.eval weight p) as y. + remember (y + -B.Positional.eval weight m) as y0. + rewrite Hm in *. + + transitivity y0; cbv [mod_eq]. + { eapply (freezeZ (Z.pos modulus) (weight n) (B.Associational.eval c) y y0); + try assumption; reflexivity. } + { subst y0. + assert (Z.pos modulus <> 0) by auto using Z.positive_is_nonzero, Zgt_pos_0. + rewrite Z.add_mod by assumption. + rewrite Z.mod_opp_l_z by auto using Z.mod_same. + rewrite Z.add_0_r, Z.mod_mod by assumption. + reflexivity. } + Qed. +End Freeze.
\ No newline at end of file diff --git a/src/Arithmetic/Saturated/MontgomeryAPI.v b/src/Arithmetic/Saturated/MontgomeryAPI.v new file mode 100644 index 000000000..0ce1ac265 --- /dev/null +++ b/src/Arithmetic/Saturated/MontgomeryAPI.v @@ -0,0 +1,599 @@ +Require Import Coq.ZArith.ZArith. +Require Import Coq.Lists.List. +Local Open Scope Z_scope. + +Require Import Crypto.Arithmetic.Core. +Require Import Crypto.Arithmetic.Saturated.Core. +Require Import Crypto.Arithmetic.Saturated.UniformWeight. +Require Import Crypto.Arithmetic.Saturated.Wrappers. +Require Import Crypto.Arithmetic.Saturated.AddSub. +Require Import Crypto.Util.LetIn Crypto.Util.CPSUtil. +Require Import Crypto.Util.Tuple Crypto.Util.LetIn. +Require Import Crypto.Util.Tactics Crypto.Util.Decidable. +Require Import Crypto.Util.ZUtil Crypto.Util.ListUtil. +Require Import Crypto.Util.ZUtil.Definitions. +Require Import Crypto.Util.ZUtil.Zselect. +Require Import Crypto.Util.ZUtil.AddGetCarry. +Require Import Crypto.Util.ZUtil.MulSplit. +Local Notation "A ^ n" := (tuple A n) : type_scope. + +Section API. + Context (bound : Z) {bound_pos : bound > 0}. + Definition T : nat -> Type := tuple Z. + + (* lowest limb is less than its bound; this is required for [divmod] + to simply separate the lowest limb from the rest and be equivalent + to normal div/mod with [bound]. *) + Local Notation small := (@small bound). + + Definition zero {n:nat} : T n := B.Positional.zeros n. + + (** Returns 0 iff all limbs are 0 *) + Definition nonzero_cps {n} (p : T n) {cpsT} (f : Z -> cpsT) : cpsT + := CPSUtil.to_list_cps _ p (fun p => CPSUtil.fold_right_cps runtime_lor 0%Z p f). + Definition nonzero {n} (p : T n) : Z + := nonzero_cps p id. + + Definition join0_cps {n:nat} (p : T n) {R} (f:T (S n) -> R) + := Tuple.left_append_cps 0 p f. + Definition join0 {n} p : T (S n) := @join0_cps n p _ id. + + Definition divmod_cps {n} (p : T (S n)) {R} (f:T n * Z->R) : R + := Tuple.tl_cps p (fun d => Tuple.hd_cps p (fun m => f (d, m))). + Definition divmod {n} p : T n * Z := @divmod_cps n p _ id. + + Definition drop_high_cps {n : nat} (p : T (S n)) {R} (f:T n->R) + := Tuple.left_tl_cps p f. + Definition drop_high {n} p : T n := @drop_high_cps n p _ id. + + Definition scmul_cps {n} (c : Z) (p : T n) {R} (f:T (S n)->R) := + Columns.mul_cps (n1:=1) (n3:=S n) (uweight bound) bound c p + (* The carry that comes out of Columns.mul_cps will be 0, since + (S n) limbs is enough to hold the result of the + multiplication, so we can safely discard it. *) + (fun carry_result =>f (snd carry_result)). + Definition scmul {n} c p : T (S n) := @scmul_cps n c p _ id. + + Definition add_cps {n} (p q: T n) {R} (f:T (S n)->R) := + B.Positional.sat_add_cps (s:=bound) p q _ + (* join the last carry *) + (fun carry_result => Tuple.left_append_cps (fst carry_result) (snd carry_result) f). + Definition add {n} p q : T (S n) := @add_cps n p q _ id. + + (* Wrappers for additions with slightly uneven limb counts *) + Definition add_S1_cps {n} (p: T (S n)) (q: T n) {R} (f:T (S (S n))->R) := + join0_cps q (fun Q => add_cps p Q f). + Definition add_S1 {n} p q := @add_S1_cps n p q _ id. + Definition add_S2_cps {n} (p: T n) (q: T (S n)) {R} (f:T (S (S n))->R) := + join0_cps p (fun P => add_cps P q f). + Definition add_S2 {n} p q := @add_S2_cps n p q _ id. + + Definition sub_then_maybe_add_cps {n} mask (p q r : T n) + {R} (f:T n -> R) := + B.Positional.sat_sub_cps (s:=bound) p q _ + (* the carry will be 0 unless we underflow--we do the addition only + in the underflow case *) + (fun carry_result => + B.Positional.select_cps mask (fst carry_result) r + (fun selected => join0_cps selected + (fun selected' => + B.Positional.sat_sub_cps (s:=bound) (left_append (fst carry_result) (snd carry_result)) selected' _ + (* We can now safely discard the carry and the highest digit. + This relies on the precondition that p - q + r < bound^n. *) + (fun carry_result' => drop_high_cps (snd carry_result') f)))). + Definition sub_then_maybe_add {n} mask (p q r : T n) := + sub_then_maybe_add_cps mask p q r id. + + (* Subtract q if and only if p >= q. We rely on the preconditions + that 0 <= p < 2*q and q < bound^n (this ensures the output is less + than bound^n). *) + Definition conditional_sub_cps {n} (p:Z^S n) (q:Z^n) R (f:Z^n->R) := + join0_cps q + (fun qq => B.Positional.sat_sub_cps (s:=bound) p qq _ + (* if carry is zero, we select the result of the subtraction, + otherwise the first input *) + (fun carry_result => + Tuple.map2_cps (Z.zselect (fst carry_result)) (snd carry_result) p + (* in either case, since our result must be < q and therefore < + bound^n, we can drop the high digit *) + (fun r => drop_high_cps r f))). + Definition conditional_sub {n} p q := @conditional_sub_cps n p q _ id. + + Hint Opaque join0 divmod drop_high scmul add sub_then_maybe_add conditional_sub : uncps. + + Section CPSProofs. + + Local Ltac prove_id := + repeat autounfold; autorewrite with uncps; reflexivity. + + Lemma nonzero_id n p {cpsT} f : @nonzero_cps n p cpsT f = f (@nonzero n p). + Proof. cbv [nonzero nonzero_cps]. prove_id. Qed. + + Lemma join0_id n p R f : + @join0_cps n p R f = f (join0 p). + Proof. cbv [join0_cps join0]. prove_id. Qed. + + Lemma divmod_id n p R f : + @divmod_cps n p R f = f (divmod p). + Proof. cbv [divmod_cps divmod]; prove_id. Qed. + + Lemma drop_high_id n p R f : + @drop_high_cps n p R f = f (drop_high p). + Proof. cbv [drop_high_cps drop_high]; prove_id. Qed. + Hint Rewrite drop_high_id : uncps. + + Lemma scmul_id n c p R f : + @scmul_cps n c p R f = f (scmul c p). + Proof. cbv [scmul_cps scmul]. prove_id. Qed. + + Lemma add_id n p q R f : + @add_cps n p q R f = f (add p q). + Proof. cbv [add_cps add Let_In]. prove_id. Qed. + Hint Rewrite add_id : uncps. + + Lemma add_S1_id n p q R f : + @add_S1_cps n p q R f = f (add_S1 p q). + Proof. cbv [add_S1_cps add_S1 join0_cps]. prove_id. Qed. + + Lemma add_S2_id n p q R f : + @add_S2_cps n p q R f = f (add_S2 p q). + Proof. cbv [add_S2_cps add_S2 join0_cps]. prove_id. Qed. + + Lemma sub_then_maybe_add_id n mask p q r R f : + @sub_then_maybe_add_cps n mask p q r R f = f (sub_then_maybe_add mask p q r). + Proof. cbv [sub_then_maybe_add_cps sub_then_maybe_add join0_cps Let_In]. prove_id. Qed. + + Lemma conditional_sub_id n p q R f : + @conditional_sub_cps n p q R f = f (conditional_sub p q). + Proof. cbv [conditional_sub_cps conditional_sub join0_cps Let_In]. prove_id. Qed. + + End CPSProofs. + Hint Rewrite nonzero_id join0_id divmod_id drop_high_id scmul_id add_id sub_then_maybe_add_id conditional_sub_id : uncps. + + Section Proofs. + + Definition eval {n} (p : T n) : Z := + B.Positional.eval (uweight bound) p. + + Lemma eval_small n (p : T n) (Hsmall : small p) : + 0 <= eval p < uweight bound n. + Proof. + cbv [small eval] in *; intros. + induction n; cbv [T uweight] in *; [destruct p|rewrite (subst_left_append p)]; + repeat match goal with + | _ => progress autorewrite with push_basesystem_eval + | _ => rewrite Z.pow_0_r + | _ => specialize (IHn (left_tl p)) + | _ => + let H := fresh "H" in + match type of IHn with + ?P -> _ => assert P as H by auto using Tuple.In_to_list_left_tl; + specialize (IHn H) + end + | |- context [?b ^ Z.of_nat (S ?n)] => + replace (b ^ Z.of_nat (S n)) with (b ^ Z.of_nat n * b) by + (rewrite Nat2Z.inj_succ, <-Z.add_1_r, Z.pow_add_r, + Z.pow_1_r by (omega || auto using Nat2Z.is_nonneg); + reflexivity) + | _ => omega + end. + + specialize (Hsmall _ (Tuple.In_left_hd _ p)). + split; [Z.zero_bounds; omega |]. + apply Z.lt_le_trans with (m:=bound^Z.of_nat n * (left_hd p+1)). + { rewrite Z.mul_add_distr_l. + apply Z.add_le_lt_mono; omega. } + { apply Z.mul_le_mono_nonneg; omega. } + Qed. + + Lemma eval_zero n : eval (@zero n) = 0. + Proof. + cbv [eval zero]. + autorewrite with push_basesystem_eval. + reflexivity. + Qed. + + Lemma small_zero n : small (@zero n). + Proof. + cbv [zero small B.Positional.zeros]. destruct n; [simpl;tauto|]. + rewrite to_list_repeat. + intros x H; apply repeat_spec in H; subst x; omega. + Qed. + + Lemma eval_pair n (p : T (S (S n))) : small p -> (snd p = 0 /\ eval (n:=S n) (fst p) = 0) <-> eval p = 0. + Admitted. + + Lemma eval_nonzero n p : small p -> @nonzero n p = 0 <-> eval p = 0. + Proof. + destruct n as [|n]. + { compute; split; trivial. } + induction n as [|n IHn]. + { simpl; rewrite Z.lor_0_r; unfold eval, id. + cbv -[Z.add iff]. + rewrite Z.add_0_r. + destruct p; omega. } + { destruct p as [ps p]; specialize (IHn ps). + unfold nonzero, nonzero_cps in *. + autorewrite with uncps in *. + unfold id in *. + setoid_rewrite to_list_S. + set (k := S n) in *; simpl in *. + intro Hsmall. + rewrite Z.lor_eq_0_iff, IHn + by (hnf in Hsmall |- *; simpl in *; eauto); + clear IHn. + exact (eval_pair n (ps, p) Hsmall). } + Qed. + + Lemma eval_join0 n p + : eval (@join0 n p) = eval p. + Proof. + Admitted. + + Local Ltac pose_uweight bound := + match goal with H : bound > 0 |- _ => + pose proof (uweight_0 bound); + pose proof (@uweight_positive bound H); + pose proof (@uweight_nonzero bound H); + pose proof (@uweight_multiples bound); + pose proof (@uweight_divides bound H) + end. + + Local Ltac pose_all := + pose_uweight bound; + pose proof Z.add_get_carry_full_div; + pose proof Z.add_get_carry_full_mod; + pose proof Z.mul_split_div; pose proof Z.mul_split_mod; + pose proof div_correct; pose proof modulo_correct. + + Lemma eval_add_nz n p q : + n <> 0%nat -> + eval (@add n p q) = eval p + eval q. + Proof. + intros. pose_all. + repeat match goal with + | _ => progress (cbv [add_cps add eval Let_In] in *; repeat autounfold) + | _ => progress autorewrite with uncps push_id cancel_pair push_basesystem_eval + | _ => rewrite B.Positional.eval_left_append + + | _ => progress + (rewrite <-!from_list_default_eq with (d:=0); + erewrite !length_to_list, !from_list_default_eq, + from_list_to_list) + | _ => apply Z.mod_small; omega + end. + Admitted. + + Lemma eval_add_z n p q : + n = 0%nat -> + eval (@add n p q) = eval p + eval q. + Proof. intros; subst; reflexivity. Qed. + + Lemma eval_add n p q + : eval (@add n p q) = eval p + eval q. + Proof. + destruct (Nat.eq_dec n 0%nat); intuition auto using eval_add_z, eval_add_nz. + Qed. + Lemma eval_add_same n p q + : eval (@add n p q) = eval p + eval q. + Proof. apply eval_add; omega. Qed. + Lemma eval_add_S1 n p q + : eval (@add_S1 n p q) = eval p + eval q. + Proof. + cbv [add_S1 add_S1_cps]. autorewrite with uncps push_id. + (*rewrite eval_add; rewrite eval_join0; [reflexivity|assumption].*) + Admitted. + Lemma eval_add_S2 n p q + : eval (@add_S2 n p q) = eval p + eval q. + Proof. + cbv [add_S2 add_S2_cps]. autorewrite with uncps push_id. + (*rewrite eval_add; rewrite eval_join0; [reflexivity|assumption].*) + Admitted. + Hint Rewrite eval_add_same eval_add_S1 eval_add_S2 using (omega || assumption): push_basesystem_eval. + + Lemma uweight_le_mono n m : (n <= m)%nat -> + uweight bound n <= uweight bound m. + Proof. + unfold uweight; intro; Z.peel_le; omega. + Qed. + + Lemma uweight_lt_mono (bound_gt_1 : bound > 1) n m : (n < m)%nat -> + uweight bound n < uweight bound m. + Proof. + clear bound_pos. + unfold uweight; intro; apply Z.pow_lt_mono_r; omega. + Qed. + + Lemma uweight_succ n : uweight bound (S n) = bound * uweight bound n. + Proof. + unfold uweight. + rewrite Nat2Z.inj_succ, Z.pow_succ_r by auto using Nat2Z.is_nonneg; reflexivity. + Qed. + + Local Definition compact {n} := Columns.compact (n:=n) (add_get_carry:=Z.add_get_carry_full) (div:=div) (modulo:=modulo) (uweight bound). + Local Definition compact_digit := Columns.compact_digit (add_get_carry:=Z.add_get_carry_full) (div:=div) (modulo:=modulo) (uweight bound). + Lemma small_compact {n} (p:(list Z)^n) : small (snd (compact p)). + Proof. + pose_all. + match goal with + |- ?G => assert (G /\ fst (compact p) = fst (compact p)); [|tauto] + end. (* assert a dummy second statement so that fst (compact x) is in context *) + cbv [compact Columns.compact Columns.compact_cps small + Columns.compact_step Columns.compact_step_cps]; + autorewrite with uncps push_id. + change (fun i s a => Columns.compact_digit_cps (uweight bound) i (s :: a) id) + with (fun i s a => compact_digit i (s :: a)). + remember (fun i s a => compact_digit i (s :: a)) as f. + + apply @mapi_with'_linvariant with (n:=n) (f:=f) (inp:=p); + intros; [|simpl; tauto]. split; [|reflexivity]. + let P := fresh "H" in + match goal with H : _ /\ _ |- _ => destruct H end. + destruct n0; subst f. + { cbv [compact_digit uweight to_list to_list' In]. + rewrite Columns.compact_digit_mod by assumption. + rewrite Z.pow_0_r, Z.pow_1_r, Z.div_1_r. intros x ?. + match goal with + H : _ \/ False |- _ => destruct H; [|exfalso; assumption] end. + subst x. apply Z.mod_pos_bound, Z.gt_lt, bound_pos. } + { rewrite Tuple.to_list_left_append. + let H := fresh "H" in + intros x H; apply in_app_or in H; destruct H; + [solve[auto]| cbv [In] in H; destruct H; + [|exfalso; assumption] ]. + subst x. cbv [compact_digit]. + rewrite Columns.compact_digit_mod by assumption. + rewrite !uweight_succ, Z.div_mul by + (apply Z.neq_mul_0; split; auto; omega). + apply Z.mod_pos_bound, Z.gt_lt, bound_pos. } + Qed. + + Lemma small_add n a b : + (2 <= bound) -> + small a -> small b -> small (@add n a b). + Proof. + intros. pose_all. + cbv [add_cps add Let_In]. + autorewrite with uncps push_id. + (*apply Positional.small_sat_add.*) + Admitted. + + Lemma small_add_S1 n a b : + (2 <= bound) -> + small a -> small b -> small (@add_S1 n a b). + Proof. + intros. pose_all. + cbv [add_cps add add_S1 Let_In]. + (*apply Positional.small_sat_add.*) + Admitted. + + Lemma small_add_S2 n a b : + (2 <= bound) -> + small a -> small b -> small (@add_S2 n a b). + Proof. + intros. pose_all. + cbv [add_cps add add_S2 Let_In]. + autorewrite with uncps push_id. + (*apply Positional.small_sat_add.*) + Admitted. + + Lemma small_left_tl n (v:T (S n)) : small v -> small (left_tl v). + Proof. cbv [small]. auto using Tuple.In_to_list_left_tl. Qed. + + Lemma small_divmod n (p: T (S n)) (Hsmall : small p) : + left_hd p = eval p / uweight bound n /\ eval (left_tl p) = eval p mod (uweight bound n). + Admitted. + + Lemma eval_drop_high n v : + small v -> eval (@drop_high n v) = eval v mod (uweight bound n). + Proof. + cbv [drop_high drop_high_cps eval]. + rewrite Tuple.left_tl_cps_correct, push_id. (* TODO : for some reason autorewrite with uncps doesn't work here *) + intro H. apply small_left_tl in H. + rewrite (subst_left_append v) at 2. + autorewrite with push_basesystem_eval. + apply eval_small in H. + rewrite Z.mod_add_l' by (pose_uweight bound; auto). + rewrite Z.mod_small; auto. + Qed. + + Lemma small_drop_high n v : small v -> small (@drop_high n v). + Proof. + cbv [drop_high drop_high_cps]. + rewrite Tuple.left_tl_cps_correct, push_id. + apply small_left_tl. + Qed. + + Lemma div_nonzero_neg_iff x y : x < y -> 0 < y -> x / y <> 0 <-> x < 0. + Proof. + repeat match goal with + | _ => progress intros + | _ => rewrite Z.div_small_iff by omega + | _ => split + | _ => omega + end. + Qed. + + Lemma eval_sub_then_maybe_add_nz n mask p q r: + small p -> small q -> small r -> (n<>0)%nat -> + (map (Z.land mask) r = r) -> + (0 <= eval p < eval r) -> (0 <= eval q < eval r) -> + eval (@sub_then_maybe_add n mask p q r) = eval p - eval q + (if eval p - eval q <? 0 then eval r else 0). + Proof. + pose_all. + repeat match goal with + | _ => progress (cbv [sub_then_maybe_add sub_then_maybe_add_cps eval] in *; intros) + | _ => progress autounfold + | _ => progress autorewrite with uncps push_id push_basesystem_eval + | _ => rewrite eval_drop_high + | _ => rewrite eval_join0 + | H : small _ |- _ => apply eval_small in H + | _ => progress break_match + | _ => (rewrite Z.add_opp_r in * ) + | H : _ |- _ => rewrite Z.ltb_lt in H; + rewrite <-div_nonzero_neg_iff with + (y:=uweight bound n) in H by (auto; omega) + | H : _ |- _ => rewrite Z.ltb_ge in H + | _ => rewrite Z.mod_small by omega + | _ => omega + | _ => progress autorewrite with zsimplify; [ ] + end. + Admitted. + + Lemma eval_sub_then_maybe_add n mask p q r : + small p -> small q -> small r -> + (map (Z.land mask) r = r) -> + (0 <= eval p < eval r) -> (0 <= eval q < eval r) -> + eval (@sub_then_maybe_add n mask p q r) = eval p - eval q + (if eval p - eval q <? 0 then eval r else 0). + Proof. + destruct n; [|solve[auto using eval_sub_then_maybe_add_nz]]. + destruct p, q, r; reflexivity. + Qed. + + Lemma small_sub_then_maybe_add n mask (p q r : T n) : + small (sub_then_maybe_add mask p q r). + Proof. + cbv [sub_then_maybe_add_cps sub_then_maybe_add]; intros. + repeat progress autounfold. autorewrite with uncps push_id. + apply small_drop_high, B.Positional.small_sat_sub. + Qed. + + (* TODO : remove if unneeded when all admits are proven + Lemma small_highest_zero_iff {n} (p: T (S n)) (Hsmall : small p) : + (left_hd p = 0 <-> eval p < uweight bound n). + Proof. + destruct (small_divmod _ p Hsmall) as [Hdiv Hmod]. + pose proof Hsmall as Hsmalltl. apply eval_small in Hsmall. + apply small_left_tl, eval_small in Hsmalltl. rewrite Hdiv. + rewrite (Z.div_small_iff (eval p) (uweight bound n)) + by auto using uweight_nonzero. + split; [|intros; left; omega]. + let H := fresh "H" in intro H; destruct H; [|omega]. + omega. + Qed. + *) + + Lemma map2_zselect n cond x y : + Tuple.map2 (n:=n) (Z.zselect cond) x y = if dec (cond = 0) then x else y. + Proof. + unfold Z.zselect. + break_innermost_match; Z.ltb_to_lt; subst; try omega; + [ rewrite Tuple.map2_fst, Tuple.map_id + | rewrite Tuple.map2_snd, Tuple.map_id ]; + reflexivity. + Qed. + + Lemma eval_conditional_sub_nz n (p:T (S n)) (q:T n) + (n_nonzero: (n <> 0)%nat) (psmall : small p) (qsmall : small q): + 0 <= eval p < eval q + uweight bound n -> + eval (conditional_sub p q) = eval p + (if eval q <=? eval p then - eval q else 0). + Proof. + cbv [conditional_sub conditional_sub_cps]. intros. pose_all. + repeat autounfold. apply eval_small in qsmall. + pose proof psmall; apply eval_small in psmall. + cbv [eval] in *. autorewrite with uncps push_id push_basesystem_eval. + rewrite map2_zselect. + let H := fresh "H" in let X := fresh "P" in + match goal with |- context [?x / ?y] => + pose proof (div_nonzero_neg_iff x y) end; + repeat match type of H with ?P -> _ => + assert P as X by omega; specialize (H X); + clear X end. + + break_match; + repeat match goal with + | _ => progress cbv [eval] + | H : (_ <=? _) = true |- _ => apply Z.leb_le in H + | H : (_ <=? _) = false |- _ => apply Z.leb_gt in H + | _ => rewrite eval_drop_high by auto using B.Positional.small_sat_sub + | _ => (rewrite eval_join0 in * ) + | _ => progress autorewrite with uncps push_id push_basesystem_eval + | _ => repeat rewrite Z.mod_small; omega + | _ => omega + end. + Admitted. + + Lemma eval_conditional_sub n (p:T (S n)) (q:T n) + (psmall : small p) (qsmall : small q) : + 0 <= eval p < eval q + uweight bound n -> + eval (conditional_sub p q) = eval p + (if eval q <=? eval p then - eval q else 0). + Proof. + destruct n; [|solve[auto using eval_conditional_sub_nz]]. + repeat match goal with + | _ => progress (intros; cbv [T tuple tuple'] in p, q) + | q : unit |- _ => destruct q + | _ => progress (cbv [conditional_sub conditional_sub_cps eval] in * ) + | _ => progress autounfold + | _ => progress (autorewrite with uncps push_id push_basesystem_eval in * ) + | _ => (rewrite uweight_0 in * ) + | _ => assert (p = 0) by omega; subst p; break_match; ring + end. + Qed. + + Lemma small_conditional_sub n (p:T (S n)) (q:T n) + (psmall : small p) (qsmall : small q) : + 0 <= eval p < eval q + uweight bound n -> + small (conditional_sub p q). + Admitted. + + Lemma eval_scmul n a v : small v -> 0 <= a < bound -> + eval (@scmul n a v) = a * eval v. + Proof. + intro Hsmall. pose_all. apply eval_small in Hsmall. + intros. cbv [scmul scmul_cps eval] in *. repeat autounfold. + autorewrite with uncps push_id push_basesystem_eval. + rewrite uweight_0, Z.mul_1_l. apply Z.mod_small. + split; [solve[Z.zero_bounds]|]. cbv [uweight] in *. + rewrite !Nat2Z.inj_succ, Z.pow_succ_r by auto using Nat2Z.is_nonneg. + apply Z.mul_lt_mono_nonneg; omega. + Qed. + + Lemma small_scmul n a v : small (@scmul n a v). + Proof. + cbv [scmul scmul_cps eval] in *. repeat autounfold. + autorewrite with uncps push_id push_basesystem_eval. + apply small_compact. + Qed. + + (* TODO : move to tuple *) + Lemma from_list_tl {A n} (ls : list A) H H': + from_list n (List.tl ls) H = tl (from_list (S n) ls H'). + Proof. + induction ls; distr_length. simpl List.tl. + rewrite from_list_cons, tl_append, <-!(from_list_default_eq a ls). + reflexivity. + Qed. + + Lemma small_hd n p : @small (S n) p -> 0 <= hd p < bound. + Proof. + cbv [small]. let H := fresh "H" in intro H; apply H. + rewrite (subst_append p). rewrite to_list_append, hd_append. + apply in_eq. + Qed. + + + Lemma eval_div n p : small p -> eval (fst (@divmod n p)) = eval p / bound. + Proof. + cbv [divmod divmod_cps eval]. intros. + autorewrite with uncps push_id cancel_pair. + rewrite (subst_append p) at 2. + rewrite uweight_eval_step. rewrite hd_append, tl_append. + rewrite Z.div_add' by omega. rewrite Z.div_small by auto using small_hd. + ring. + Qed. + + Lemma eval_mod n p : small p -> snd (@divmod n p) = eval p mod bound. + Proof. + cbv [divmod divmod_cps eval]. intros. + autorewrite with uncps push_id cancel_pair. + rewrite (subst_append p) at 2. + rewrite uweight_eval_step, Z.mod_add'_full, hd_append. + rewrite Z.mod_small by auto using small_hd. reflexivity. + Qed. + + Lemma small_div n v : small v -> small (fst (@divmod n v)). + Admitted. + + End Proofs. +End API. +Hint Rewrite nonzero_id join0_id divmod_id drop_high_id scmul_id add_id add_S1_id add_S2_id sub_then_maybe_add_id conditional_sub_id : uncps.
\ No newline at end of file diff --git a/src/Arithmetic/Saturated/MulSplit.v b/src/Arithmetic/Saturated/MulSplit.v new file mode 100644 index 000000000..45f37ef56 --- /dev/null +++ b/src/Arithmetic/Saturated/MulSplit.v @@ -0,0 +1,73 @@ +Require Import Coq.ZArith.ZArith. +Require Import Coq.Lists.List. +Local Open Scope Z_scope. + +Require Import Crypto.Arithmetic.Core. +Require Import Crypto.Util.LetIn Crypto.Util.CPSUtil. + +(* Defines bignum multiplication using a two-output multiply operation. *) +Module B. + Module Associational. + Section Associational. + Context {mul_split : Z -> Z -> Z -> Z * Z} (* first argument is where to split output; [mul_split s x y] gives ((x * y) mod s, (x * y) / s) *) + {mul_split_mod : forall s x y, + fst (mul_split s x y) = (x * y) mod s} + {mul_split_div : forall s x y, + snd (mul_split s x y) = (x * y) / s} + . + + Definition sat_multerm_cps s (t t' : B.limb) {T} (f:list B.limb ->T) := + dlet xy := mul_split s (snd t) (snd t') in + f ((fst t * fst t', fst xy) :: (fst t * fst t' * s, snd xy) :: nil). + + Definition sat_multerm s t t' := sat_multerm_cps s t t' id. + Lemma sat_multerm_id s t t' T f : + @sat_multerm_cps s t t' T f = f (sat_multerm s t t'). + Proof. reflexivity. Qed. + Hint Opaque sat_multerm : uncps. + Hint Rewrite sat_multerm_id : uncps. + + Definition sat_mul_cps s (p q : list B.limb) {T} (f : list B.limb -> T) := + flat_map_cps (fun t => @flat_map_cps _ _ (sat_multerm_cps s t) q) p f. + + Definition sat_mul s p q := sat_mul_cps s p q id. + Lemma sat_mul_id s p q T f : @sat_mul_cps s p q T f = f (sat_mul s p q). + Proof. cbv [sat_mul sat_mul_cps]. autorewrite with uncps. reflexivity. Qed. + Hint Opaque sat_mul : uncps. + Hint Rewrite sat_mul_id : uncps. + + Lemma eval_map_sat_multerm s a q (s_nonzero:s<>0): + B.Associational.eval (flat_map (sat_multerm s a) q) = fst a * snd a * B.Associational.eval q. + Proof. + cbv [sat_multerm sat_multerm_cps Let_In]; induction q; + repeat match goal with + | _ => progress (autorewrite with uncps push_id cancel_pair push_basesystem_eval in * ) + | _ => progress simpl flat_map + | _ => progress rewrite ?IHq, ?mul_split_mod, ?mul_split_div + | _ => rewrite Z.mod_eq by assumption + | _ => ring_simplify; omega + end. + Qed. + Hint Rewrite eval_map_sat_multerm using (omega || assumption) + : push_basesystem_eval. + + Lemma eval_sat_mul s p q (s_nonzero:s<>0): + B.Associational.eval (sat_mul s p q) = B.Associational.eval p * B.Associational.eval q. + Proof. + cbv [sat_mul sat_mul_cps]; induction p; [reflexivity|]. + repeat match goal with + | _ => progress (autorewrite with uncps push_id push_basesystem_eval in * ) + | _ => progress simpl flat_map + | _ => rewrite IHp + | _ => progress change (fun x => sat_multerm_cps s a x id) with (sat_multerm s a) + | _ => ring_simplify; omega + end. + Qed. + Hint Rewrite eval_sat_mul : push_basesystem_eval. + End Associational. + End Associational. +End B. +Hint Opaque B.Associational.sat_mul B.Associational.sat_multerm : uncps. +Hint Rewrite @B.Associational.sat_mul_id @B.Associational.sat_multerm_id : uncps. +Hint Rewrite @B.Associational.eval_sat_mul @B.Associational.eval_map_sat_multerm using (omega || assumption) : push_basesystem_eval. + diff --git a/src/Arithmetic/Saturated/UniformWeight.v b/src/Arithmetic/Saturated/UniformWeight.v new file mode 100644 index 000000000..51eb71b0b --- /dev/null +++ b/src/Arithmetic/Saturated/UniformWeight.v @@ -0,0 +1,71 @@ +Require Import Coq.ZArith.ZArith. +Require Import Coq.Lists.List. +Local Open Scope Z_scope. + +Require Import Crypto.Arithmetic.Core. +Require Import Crypto.Arithmetic.Saturated.Core. +Require Import Crypto.Util.ZUtil. +Require Import Crypto.Util.LetIn Crypto.Util.Tuple. +Local Notation "A ^ n" := (tuple A n) : type_scope. + +Section UniformWeight. + Context (bound : Z) {bound_pos : bound > 0}. + + Definition uweight : nat -> Z := fun i => bound ^ Z.of_nat i. + Lemma uweight_0 : uweight 0%nat = 1. Proof. reflexivity. Qed. + Lemma uweight_positive i : uweight i > 0. + Proof. apply Z.lt_gt, Z.pow_pos_nonneg; omega. Qed. + Lemma uweight_nonzero i : uweight i <> 0. + Proof. auto using Z.positive_is_nonzero, uweight_positive. Qed. + Lemma uweight_multiples i : uweight (S i) mod uweight i = 0. + Proof. apply Z.mod_same_pow; rewrite Nat2Z.inj_succ; omega. Qed. + Lemma uweight_divides i : uweight (S i) / uweight i > 0. + Proof. + cbv [uweight]. rewrite <-Z.pow_sub_r by (rewrite ?Nat2Z.inj_succ; omega). + apply Z.lt_gt, Z.pow_pos_nonneg; rewrite ?Nat2Z.inj_succ; omega. + Qed. + + (* TODO : move to Positional *) + Lemma eval_from_eq {n} (p:Z^n) wt offset : + (forall i, wt i = uweight (i + offset)) -> + B.Positional.eval wt p = B.Positional.eval_from uweight offset p. + Proof. cbv [B.Positional.eval_from]. auto using B.Positional.eval_wt_equiv. Qed. + + Lemma uweight_eval_from {n} (p:Z^n): forall offset, + B.Positional.eval_from uweight offset p = uweight offset * B.Positional.eval uweight p. + Proof. + induction n; intros; cbv [B.Positional.eval_from]; + [|rewrite (subst_append p)]; + repeat match goal with + | _ => destruct p + | _ => rewrite B.Positional.eval_unit; [ ] + | _ => rewrite B.Positional.eval_step; [ ] + | _ => rewrite IHn; [ ] + | _ => rewrite eval_from_eq with (offset0:=S offset) + by (intros; f_equal; omega) + | _ => rewrite eval_from_eq with + (wt:=fun i => uweight (S i)) (offset0:=1%nat) + by (intros; f_equal; omega) + | _ => ring + end. + repeat match goal with + | _ => cbv [uweight]; progress autorewrite with natsimplify + | _ => progress (rewrite ?Nat2Z.inj_succ, ?Nat2Z.inj_0, ?Z.pow_0_r) + | _ => rewrite !Z.pow_succ_r by (try apply Nat2Z.is_nonneg; omega) + | _ => ring + end. + Qed. + + Lemma uweight_eval_step {n} (p:Z^S n): + B.Positional.eval uweight p = hd p + bound * B.Positional.eval uweight (tl p). + Proof. + rewrite (subst_append p) at 1; rewrite B.Positional.eval_step. + rewrite eval_from_eq with (offset := 1%nat) by (intros; f_equal; omega). + rewrite uweight_eval_from. cbv [uweight]; rewrite Z.pow_0_r, Z.pow_1_r. + ring. + Qed. + + Definition small {n} (p : Z^n) : Prop := + forall x, In x (to_list _ p) -> 0 <= x < bound. + +End UniformWeight.
\ No newline at end of file diff --git a/src/Arithmetic/Saturated/Wrappers.v b/src/Arithmetic/Saturated/Wrappers.v new file mode 100644 index 000000000..e1da74e60 --- /dev/null +++ b/src/Arithmetic/Saturated/Wrappers.v @@ -0,0 +1,53 @@ +Require Import Coq.ZArith.ZArith. +Require Import Coq.Lists.List. +Local Open Scope Z_scope. + +Require Import Crypto.Arithmetic.Core. +Require Import Crypto.Arithmetic.Saturated.Core. +Require Import Crypto.Arithmetic.Saturated.MulSplit. +Require Import Crypto.Util.ZUtil.Definitions. +Require Import Crypto.Util.ZUtil.MulSplit. +Require Import Crypto.Util.Tuple. +Local Notation "A ^ n" := (tuple A n) : type_scope. + +(* Define wrapper definitions that use Columns representation +internally but with input and output in Positonal representation.*) +Module Columns. + Section Wrappers. + Context (weight : nat->Z). + + Definition add_cps {n1 n2 n3} (p : Z^n1) (q : Z^n2) + {T} (f : (Z*Z^n3)->T) := + B.Positional.to_associational_cps weight p + (fun P => B.Positional.to_associational_cps weight q + (fun Q => Columns.from_associational_cps weight n3 (P++Q) + (fun R => Columns.compact_cps (div:=div) (modulo:=modulo) (add_get_carry:=Z.add_get_carry_full) weight R f))). + + Definition unbalanced_sub_cps {n1 n2 n3} (p : Z^n1) (q:Z^n2) + {T} (f : (Z*Z^n3)->T) := + B.Positional.to_associational_cps weight p + (fun P => B.Positional.negate_snd_cps weight q + (fun nq => B.Positional.to_associational_cps weight nq + (fun Q => Columns.from_associational_cps weight n3 (P++Q) + (fun R => Columns.compact_cps (div:=div) (modulo:=modulo) (add_get_carry:=Z.add_get_carry_full) weight R f)))). + + Definition mul_cps {n1 n2 n3} s (p : Z^n1) (q : Z^n2) + {T} (f : (Z*Z^n3)->T) := + B.Positional.to_associational_cps weight p + (fun P => B.Positional.to_associational_cps weight q + (fun Q => B.Associational.sat_mul_cps (mul_split := Z.mul_split) s P Q + (fun PQ => Columns.from_associational_cps weight n3 PQ + (fun R => Columns.compact_cps (div:=div) (modulo:=modulo) (add_get_carry:=Z.add_get_carry_full) weight R f)))). + + Definition conditional_add_cps {n1 n2 n3} mask cond (p:Z^n1) (q:Z^n2) + {T} (f:_->T) := + B.Positional.select_cps mask cond q + (fun qq => add_cps (n3:=n3) p qq f). + + End Wrappers. +End Columns. +Hint Unfold + Columns.conditional_add_cps + Columns.add_cps + Columns.unbalanced_sub_cps + Columns.mul_cps.
\ No newline at end of file |