From 3ca227f1137e6a3b65bc33f5689e1c230d591595 Mon Sep 17 00:00:00 2001 From: Andres Erbsen Date: Tue, 8 Jan 2019 04:21:38 -0500 Subject: remove old pipeline --- src/Arithmetic/Core.v | 1431 -------------------- src/Arithmetic/CoreUnfolder.v | 400 ------ src/Arithmetic/Karatsuba.v | 228 ---- .../WordByWord/Abstract/Definition.v | 61 - .../WordByWord/Abstract/Dependent/Definition.v | 81 -- .../WordByWord/Abstract/Dependent/Proofs.v | 582 -------- .../WordByWord/Abstract/Proofs.v | 497 ------- .../MontgomeryReduction/WordByWord/Definition.v | 108 -- .../MontgomeryReduction/WordByWord/Proofs.v | 329 ----- src/Arithmetic/Saturated/AddSub.v | 285 ---- src/Arithmetic/Saturated/Core.v | 485 ------- src/Arithmetic/Saturated/CoreUnfolder.v | 97 -- src/Arithmetic/Saturated/Freeze.v | 145 -- src/Arithmetic/Saturated/FreezeUnfolder.v | 27 - src/Arithmetic/Saturated/MontgomeryAPI.v | 691 ---------- src/Arithmetic/Saturated/MulSplit.v | 100 -- src/Arithmetic/Saturated/MulSplitUnfolder.v | 45 - src/Arithmetic/Saturated/UniformWeight.v | 93 -- src/Arithmetic/Saturated/UniformWeightInstances.v | 34 - src/Arithmetic/Saturated/Wrappers.v | 68 - src/Arithmetic/Saturated/WrappersUnfolder.v | 45 - 21 files changed, 5832 deletions(-) delete mode 100644 src/Arithmetic/Core.v delete mode 100644 src/Arithmetic/CoreUnfolder.v delete mode 100644 src/Arithmetic/Karatsuba.v delete mode 100644 src/Arithmetic/MontgomeryReduction/WordByWord/Abstract/Definition.v delete mode 100644 src/Arithmetic/MontgomeryReduction/WordByWord/Abstract/Dependent/Definition.v delete mode 100644 src/Arithmetic/MontgomeryReduction/WordByWord/Abstract/Dependent/Proofs.v delete mode 100644 src/Arithmetic/MontgomeryReduction/WordByWord/Abstract/Proofs.v delete mode 100644 src/Arithmetic/MontgomeryReduction/WordByWord/Definition.v delete mode 100644 src/Arithmetic/MontgomeryReduction/WordByWord/Proofs.v delete mode 100644 src/Arithmetic/Saturated/AddSub.v delete mode 100644 src/Arithmetic/Saturated/Core.v delete mode 100644 src/Arithmetic/Saturated/CoreUnfolder.v delete mode 100644 src/Arithmetic/Saturated/Freeze.v delete mode 100644 src/Arithmetic/Saturated/FreezeUnfolder.v delete mode 100644 src/Arithmetic/Saturated/MontgomeryAPI.v delete mode 100644 src/Arithmetic/Saturated/MulSplit.v delete mode 100644 src/Arithmetic/Saturated/MulSplitUnfolder.v delete mode 100644 src/Arithmetic/Saturated/UniformWeight.v delete mode 100644 src/Arithmetic/Saturated/UniformWeightInstances.v delete mode 100644 src/Arithmetic/Saturated/Wrappers.v delete mode 100644 src/Arithmetic/Saturated/WrappersUnfolder.v (limited to 'src/Arithmetic') diff --git a/src/Arithmetic/Core.v b/src/Arithmetic/Core.v deleted file mode 100644 index 48046d7e3..000000000 --- a/src/Arithmetic/Core.v +++ /dev/null @@ -1,1431 +0,0 @@ -(***** - -This file provides a generalized version of arithmetic with "mixed -radix" numerical systems. Later, parameters are entered into the -general functions, and they are partially evaluated until only runtime -basic arithmetic operations remain. - -CPS ---- - -Fuctions are written in continuation passing style (CPS). This means -that each operation is passed a "continuation" function, which it is -expected to call on its own output (like a callback). See the end of -this comment for a motivating example explaining why we do CPS, -despite a fair amount of resulting boilerplate code for each -operation. The code block for an operation called A would look like -this: - -``` -Definition A_cps x y {T} f : T := ... - -Definition A x y := A_cps x y id. -Lemma A_cps_id x y : forall {T} f, @A_cps x y T f = f (A x y). -Hint Opaque A : uncps. -Hint Rewrite A_cps_id : uncps. - -Lemma eval_A x y : eval (A x y) = ... -Hint Rewrite eval_A : push_basesystem_eval. -``` - -`A_cps` is the main, CPS-style definition of the operation (`f` is the -continuation function). `A` is the non-CPS version of `A_cps`, simply -defined by passing an identity function to `A_cps`. `A_cps_id` states -that we can replace the CPS version with the non-cps version. `eval_A` -is the actual correctness lemma for the operation, stating that it has -the correct arithmetic properties. In general, the middle block -containing `A` and `A_cps_id` is boring boilerplate and can be safely -ignored. - -HintDbs -------- - -+ `uncps` : Converts CPS operations to their non-CPS versions. -+ `push_basesystem_eval` : Contains all the correctness lemmas for - operations in this file, which are in terms of the `eval` function. - -Positional/Associational ------------------------- - -We represent mixed-radix numbers in a few different ways: - -+ "Positional" : a tuple of numbers and a weight function (nat->Z), -which is evaluated by multiplying the `i`th element of the tuple by -`weight i`, and then summing the products. -+ "Associational" : a list of pairs of numbers--the first is the -weight, the second is the runtime value. Evaluated by multiplying each -pair and summing the products. - -The associational representation is good for basic operations like -addition and multiplication; for addition, one can simply just append -two associational lists. But the end-result code should use the -positional representation (with each digit representing a machine -word). Since converting to and fro can be easily compiled away once -the weight function is known, we use associational to write most of -the operations and liberally convert back and forth to ensure correct -output. In particular, it is important to convert before carrying. - -Runtime Operations ------------------- - -Since some instances of e.g. Z.add or Z.mul operate on (compile-time) -weights, and some operate on runtime values, we need a way to -differentiate these cases before partial evaluation. We define a -runtime_scope to mark certain additions/multiplications as runtime -values, so they will not be unfolded during partial evaluation. For -instance, if we have: - -``` -Definition f (x y : Z * Z) := (fst x + fst y, (snd x + snd y)%RT). -``` - -then when we are partially evaluating `f`, we can easily exclude the -runtime operations (`cbv - [runtime_add]`) and prevent Coq from trying -to simplify the second addition. - - -Why CPS? --------- - -Let's suppose we want to add corresponding elements of two `list Z`s -(so on inputs `[1,2,3]` and `[2,3,1]`, we get `[3,5,4]`). We might -write our function like this : - -``` -Fixpoint add_lists (p q : list Z) := - match p, q with - | p0 :: p', q0 :: q' => - dlet sum := p0 + q0 in - sum :: add_lists p' q' - | _, _ => nil - end. -``` - -(Note : `dlet` is a notation for `Let_In`, which is just a dumb -wrapper for `let`. This allows us to `cbv - [Let_In]` if we want to -not simplify certain `let`s.) - -A CPS equivalent of `add_lists` would look like this: - -``` -Fixpoint add_lists_cps (p q : list Z) {T} (f:list Z->T) := - match p, q with - | p0 :: p', q0 :: q' => - dlet sum := p0 + q0 in - add_lists_cps p' q' (fun r => f (sum :: r)) - | _, _ => f nil - end. -``` - -Now let's try some partial evaluation. The expression we'll evaluate is: - -``` -Definition x := - (fun a0 a1 a2 b0 b1 b2 => - let r := add_lists [a0;a1;a2] [b0;b1;b2] in - let rr := add_lists r r in - add_lists rr rr). -``` - -Or, using `add_lists_cps`: - -``` -Definition y := - (fun a0 a1 a2 b0 b1 b2 => - add_lists_cps [a0;a1;a2] [b0;b1;b2] - (fun r => add_lists_cps r r - (fun rr => add_lists_cps rr rr id))). -``` - -If we run `Eval cbv -[Z.add] in x` and `Eval cbv -[Z.add] in y`, we get -identical output: - -``` -fun a0 a1 a2 b0 b1 b2 : Z => - [a0 + b0 + (a0 + b0) + (a0 + b0 + (a0 + b0)); - a1 + b1 + (a1 + b1) + (a1 + b1 + (a1 + b1)); - a2 + b2 + (a2 + b2) + (a2 + b2 + (a2 + b2))] -``` - -However, there are a lot of common subexpressions here--this is what -the `dlet` we put into the functions should help us avoid. Let's try -`Eval cbv -[Let_In Z.add] in x`: - -``` -fun a0 a1 a2 b0 b1 b2 : Z => - (fix add_lists (p q : list Z) {struct p} : - list Z := - match p with - | [] => [] - | p0 :: p' => - match q with - | [] => [] - | q0 :: q' => - dlet sum := p0 + q0 in - sum :: add_lists p' q' - end - end) - ((fix add_lists (p q : list Z) {struct p} : - list Z := - match p with - | [] => [] - | p0 :: p' => - match q with - | [] => [] - | q0 :: q' => - dlet sum := p0 + q0 in - sum :: add_lists p' q' - end - end) - (dlet sum := a0 + b0 in - sum - :: (dlet sum0 := a1 + b1 in - sum0 :: (dlet sum1 := a2 + b2 in - [sum1]))) - (dlet sum := a0 + b0 in - sum - :: (dlet sum0 := a1 + b1 in - sum0 :: (dlet sum1 := a2 + b2 in - [sum1])))) - ((fix add_lists (p q : list Z) {struct p} : - list Z := - match p with - | [] => [] - | p0 :: p' => - match q with - | [] => [] - | q0 :: q' => - dlet sum := p0 + q0 in - sum :: add_lists p' q' - end - end) - (dlet sum := a0 + b0 in - sum - :: (dlet sum0 := a1 + b1 in - sum0 :: (dlet sum1 := a2 + b2 in - [sum1]))) - (dlet sum := a0 + b0 in - sum - :: (dlet sum0 := a1 + b1 in - sum0 :: (dlet sum1 := a2 + b2 in - [sum1])))) -``` - -Not so great. Because the `dlet`s are stuck in the inner terms, we -can't simplify the expression very nicely. Let's try that on the CPS -version (`Eval cbv -[Let_In Z.add] in y`): - -``` -fun a0 a1 a2 b0 b1 b2 : Z => - dlet sum := a0 + b0 in - dlet sum0 := a1 + b1 in - dlet sum1 := a2 + b2 in - dlet sum2 := sum + sum in - dlet sum3 := sum0 + sum0 in - dlet sum4 := sum1 + sum1 in - dlet sum5 := sum2 + sum2 in - dlet sum6 := sum3 + sum3 in - dlet sum7 := sum4 + sum4 in - [sum5; sum6; sum7] -``` - -Isn't that lovely? Since we can push continuation functions "under" -the `dlet`s, we can end up with a nice, concise, simplified -expression. - -One might suggest that we could just inline the `dlet`s and do common -subexpression elimination. But some of our terms have so many `dlet`s -that inlining them all would make a term too huge to process in -reasonable time, so this is not really an option. - -*****) - -Require Import Coq.ZArith.ZArith Coq.micromega.Psatz Coq.omega.Omega. -Require Import Coq.ZArith.BinIntDef. -Local Open Scope Z_scope. - -Require Import Crypto.Algebra.Nsatz. -Require Import Crypto.Util.Decidable Crypto.Util.LetIn. -Require Import Crypto.Util.ListUtil Crypto.Util.Sigma. -Require Import Crypto.Util.CPSUtil Crypto.Util.Prod. -Require Import Crypto.Util.ZUtil.Modulo.PullPush. -Require Import Crypto.Util.ZUtil.Zselect. -Require Import Crypto.Util.ZUtil.Tactics.LtbToLt. -Require Import Crypto.Util.ZUtil.Definitions. -Require Import Crypto.Util.ZUtil.CPS. -Require Import Crypto.Arithmetic.PrimeFieldTheorems. -Require Import Crypto.Util.Tactics.BreakMatch. -Require Import Crypto.Util.Tactics.UniquePose. -Require Import Crypto.Util.Tactics.VM. -Require Import Crypto.Util.IdfunWithAlt. -Require Import Crypto.Util.Notations. - -Require Import Coq.Lists.List. Import ListNotations. -Require Crypto.Util.Tuple. Local Notation tuple := Tuple.tuple. - -Local Ltac prove_id := - repeat match goal with - | _ => progress intros - | _ => progress simpl - | _ => progress cbv [Let_In] - | _ => progress (autorewrite with uncps push_id in * ) - | _ => break_innermost_match_step - | _ => contradiction - | _ => reflexivity - | _ => nsatz - | _ => solve [auto] - end. - -Create HintDb push_basesystem_eval discriminated. -Local Ltac prove_eval := - repeat match goal with - | _ => progress intros - | _ => progress simpl - | _ => progress cbv [Let_In] - | _ => progress Z.ltb_to_lt - | _ => progress (autorewrite with push_basesystem_eval uncps push_id cancel_pair in * ) - | _ => break_innermost_match_step - | _ => split - | H : _ /\ _ |- _ => destruct H - | H : Some _ = Some _ |- _ => progress (inversion H; subst) - | _ => discriminate - | _ => reflexivity - | _ => nsatz - end. - -Definition mod_eq (m:positive) a b := a mod m = b mod m. -Global Instance mod_eq_equiv m : RelationClasses.Equivalence (mod_eq m). -Proof. constructor; congruence. Qed. -Definition mod_eq_dec m a b : {mod_eq m a b} + {~ mod_eq m a b} - := Z.eq_dec _ _. -Lemma mod_eq_Z2F_iff m a b : - mod_eq m a b <-> Logic.eq (F.of_Z m a) (F.of_Z m b). -Proof. rewrite <-F.eq_of_Z_iff; reflexivity. Qed. - -Delimit Scope runtime_scope with RT. - -Definition runtime_mul := Z.mul. -Global Notation "a * b" := (runtime_mul a%RT b%RT) : runtime_scope. -Definition runtime_add := Z.add. -Global Notation "a + b" := (runtime_add a%RT b%RT) : runtime_scope. -Definition runtime_opp := Z.opp. -Global Notation "- a" := (runtime_opp a%RT) : runtime_scope. -Definition runtime_and := Z.land. -Global Notation "a &' b" := (runtime_and a%RT b%RT) : runtime_scope. -Definition runtime_shr := Z.shiftr. -Global Notation "a >> b" := (runtime_shr a%RT b%RT) : runtime_scope. -Definition runtime_lor := Z.lor. -Global Arguments runtime_lor (_ _)%RT. - -Ltac cbv_runtime := cbv beta delta [runtime_add runtime_and runtime_lor runtime_mul runtime_opp runtime_shr]. - -Module B. - Definition limb := (Z*Z)%type. (* position coefficient and run-time value *) - Module Associational. - Definition eval (p:list limb) : Z := - List.fold_right Z.add 0%Z (List.map (fun t => fst t * snd t) p). - - Lemma eval_nil : eval nil = 0. Proof. reflexivity. Qed. - Lemma eval_cons p q : eval (p::q) = (fst p) * (snd p) + eval q. Proof. reflexivity. Qed. - Lemma eval_app p q: eval (p++q) = eval p + eval q. - Proof. induction p; simpl eval; rewrite ?eval_nil, ?eval_cons; nsatz. Qed. - Hint Rewrite eval_nil eval_cons eval_app : push_basesystem_eval. - - Definition multerm (t t' : limb) : limb := - (fst t * fst t', (snd t * snd t')%RT). - Lemma eval_map_multerm (a:limb) (q:list limb) - : eval (List.map (multerm a) q) = fst a * snd a * eval q. - Proof. - induction q; cbv [multerm]; simpl List.map; - autorewrite with push_basesystem_eval cancel_pair; nsatz. - Qed. Hint Rewrite eval_map_multerm : push_basesystem_eval. - - Definition mul_cps (p q:list limb) {T} (f : list limb->T) := - flat_map_cps (fun t => @map_cps _ _ (multerm t) q) p f. - - Definition mul (p q:list limb) := mul_cps p q id. - Lemma mul_cps_id p q: forall {T} f, @mul_cps p q T f = f (mul p q). - Proof. cbv [mul_cps mul]; prove_id. Qed. - Hint Opaque mul : uncps. - Hint Rewrite mul_cps_id : uncps. - - Lemma eval_mul p q: eval (mul p q) = eval p * eval q. - Proof. cbv [mul mul_cps]; induction p; prove_eval. Qed. - Hint Rewrite eval_mul : push_basesystem_eval. - - Section split_cps. - Context (s:Z) {T : Type}. - - Fixpoint split_cps (xs:list limb) - (f :list limb*list limb->T) := - match xs with - | nil => f (nil, nil) - | cons x xs' => - split_cps xs' - (fun sxs' => - Z.eqb_cps (fst x mod s) 0 - (fun b => - if b - then f (fst sxs', cons (fst x / s, snd x) (snd sxs')) - else f (cons x (fst sxs'), snd sxs'))) - end. - End split_cps. - - Definition split s xs := split_cps s xs id. - Lemma split_cps_id s p: forall {T} f, - @split_cps s T p f = f (split s p). - Proof. - induction p as [|?? IHp]; - repeat match goal with - | _ => rewrite IHp - | _ => progress (cbv [split]; prove_id) - end. - Qed. - Hint Opaque split : uncps. - Hint Rewrite split_cps_id : uncps. - - Lemma eval_split s p (s_nonzero:s<>0): - eval (fst (split s p)) + s*eval (snd (split s p)) = eval p. - Proof. - cbv [split]; induction p; prove_eval. - match goal with - H:_ |- _ => - unique pose proof (Z_div_exact_full_2 _ _ s_nonzero H) - end; nsatz. - Qed. Hint Rewrite @eval_split using auto : push_basesystem_eval. - - Definition reduce_cps (s:Z) (c:list limb) (p:list limb) - {T} (f : list limb->T) := - split_cps s p - (fun ab => mul_cps c (snd ab) - (fun rr =>f (fst ab ++ rr))). - - Definition reduce s c p := reduce_cps s c p id. - Lemma reduce_cps_id s c p {T} f: - @reduce_cps s c p T f = f (reduce s c p). - Proof. cbv [reduce_cps reduce]; prove_id. Qed. - Hint Opaque reduce : uncps. - Hint Rewrite reduce_cps_id : uncps. - - Lemma reduction_rule a b s c m (m_eq:Z.pos m = s - c): - (a + s * b) mod m = (a + c * b) mod m. - Proof. - rewrite m_eq. pose proof (Pos2Z.is_pos m). - replace (a + s * b) with ((a + c*b) + b*(s-c)) by ring. - rewrite Z.add_mod, Z_mod_mult, Z.add_0_r, Z.mod_mod by omega. - trivial. - Qed. - Lemma eval_reduce s c p (s_nonzero:s<>0) m (m_eq : Z.pos m = s - eval c) : - mod_eq m (eval (reduce s c p)) (eval p). - Proof. - cbv [reduce reduce_cps mod_eq]; prove_eval. - erewrite <-reduction_rule by eauto; prove_eval. - Qed. - Hint Rewrite eval_reduce using (omega || assumption) : push_basesystem_eval. - (* Why TF does this hint get picked up outside the section (while other eval_ hints do not?) *) - - - Definition negate_snd_cps (p:list limb) {T} (f:list limb ->T) := - map_cps (fun cx => (fst cx, (-snd cx)%RT)) p f. - - Definition negate_snd p := negate_snd_cps p id. - Lemma negate_snd_id p {T} f : @negate_snd_cps p T f = f (negate_snd p). - Proof. cbv [negate_snd_cps negate_snd]; prove_id. Qed. - Hint Opaque negate_snd : uncps. - Hint Rewrite negate_snd_id : uncps. - - Lemma eval_negate_snd p : eval (negate_snd p) = - eval p. - Proof. - cbv [negate_snd_cps negate_snd]; induction p; prove_eval. - Qed. Hint Rewrite eval_negate_snd : push_basesystem_eval. - - Section Carries. - Context {modulo_cps div_cps:forall {R},Z->Z->(Z->R)->R}. - Let modulo x y := modulo_cps _ x y id. - Let div x y := div_cps _ x y id. - Context {modulo_cps_id : forall R x y f, modulo_cps R x y f = f (modulo x y)} - {div_cps_id : forall R x y f, div_cps R x y f = f (div x y)}. - Context {div_mod : forall a b:Z, b <> 0 -> - a = b * (div a b) + modulo a b}. - Hint Rewrite modulo_cps_id div_cps_id : uncps. - - Definition carryterm_cps (w fw:Z) (t:limb) {T} (f:list limb->T) := - Z.eqb_cps (fst t) w (fun eqb => - if eqb - then dlet t2 := snd t in - div_cps _ t2 fw (fun d2 => - modulo_cps _ t2 fw (fun m2 => - dlet d2 := d2 in - dlet m2 := m2 in - f ((w*fw, d2) :: (w, m2) :: @nil limb))) - else f [t]). - - Definition carryterm w fw t := carryterm_cps w fw t id. - Lemma carryterm_cps_id w fw t {T} f : - @carryterm_cps w fw t T f - = f (@carryterm w fw t). - Proof using div_cps_id modulo_cps_id. - cbv [carryterm_cps carryterm Let_In]; prove_id. - Qed. - Hint Opaque carryterm : uncps. - Hint Rewrite carryterm_cps_id : uncps. - - - Lemma eval_carryterm w fw (t:limb) (fw_nonzero:fw<>0): - eval (carryterm w fw t) = eval [t]. - Proof using Type*. - cbv [carryterm_cps carryterm Let_In]; prove_eval. - specialize (div_mod (snd t) fw fw_nonzero). - nsatz. - Qed. Hint Rewrite eval_carryterm using auto : push_basesystem_eval. - - Definition carry_cps (w fw:Z) (p:list limb) {T} (f:list limb->T) := - flat_map_cps (carryterm_cps w fw) p f. - - Definition carry w fw p := carry_cps w fw p id. - Lemma carry_cps_id w fw p {T} f: - @carry_cps w fw p T f = f (carry w fw p). - Proof using div_cps_id modulo_cps_id. - cbv [carry_cps carry]; prove_id. - Qed. - Hint Opaque carry : uncps. - Hint Rewrite carry_cps_id : uncps. - - Lemma eval_carry w fw p (fw_nonzero:fw<>0): - eval (carry w fw p) = eval p. - Proof using Type*. cbv [carry_cps carry]; induction p; prove_eval. Qed. - Hint Rewrite eval_carry using auto : push_basesystem_eval. - End Carries. - - End Associational. - - Ltac div_mod_cps_t := - intros; autorewrite with uncps push_id; try reflexivity. - - Hint Rewrite - @Associational.reduce_cps_id - @Associational.split_cps_id - @Associational.mul_cps_id : uncps. - Hint Rewrite - @Associational.carry_cps_id - @Associational.carryterm_cps_id - using div_mod_cps_t : uncps. - - - Module Positional. - Section Positional. - Import Associational. - Context (weight : nat -> Z) (* [weight i] is the weight of position [i] *) - (weight_0 : weight 0%nat = 1%Z) - (weight_nonzero : forall i, weight i <> 0). - - (** Converting from positional to associational *) - Definition to_associational_cps {n:nat} (xs:tuple Z n) - {T} (f:list limb->T) := - map_cps weight (seq 0 n) - (fun r => - to_list_cps n xs (fun rr => combine_cps r rr f)). - - Definition to_associational {n} xs := - @to_associational_cps n xs _ id. - Lemma to_associational_cps_id {n} x {T} f: - @to_associational_cps n x T f = f (to_associational x). - Proof using Type. cbv [to_associational_cps to_associational]; prove_id. Qed. - Hint Opaque to_associational : uncps. - Hint Rewrite @to_associational_cps_id : uncps. - - Definition eval {n} x := - @to_associational_cps n x _ Associational.eval. - - Lemma eval_single (x:Z) : eval (n:=1) x = weight 0%nat * x. - Proof. cbv - [Z.mul Z.add]. ring. Qed. - - Lemma eval_unit : eval (n:=0) tt = 0. - Proof. reflexivity. Qed. - Hint Rewrite eval_unit eval_single : push_basesystem_eval. - - Lemma eval_to_associational {n} x : - Associational.eval (@to_associational n x) = eval x. - Proof using Type. - cbv [to_associational_cps eval to_associational]; prove_eval. - Qed. Hint Rewrite @eval_to_associational : push_basesystem_eval. - - (** (modular) equality that tolerates redundancy **) - Definition eq {sz} m (a b : tuple Z sz) : Prop := - mod_eq m (eval a) (eval b). - - (** Converting from associational to positional *) - - Definition zeros n : tuple Z n := Tuple.repeat 0 n. - Lemma eval_zeros n : eval (zeros n) = 0. - Proof using Type. - cbv [eval Associational.eval to_associational_cps zeros]. - pose proof (seq_length n 0). generalize dependent (seq 0 n). - intro xs; revert n; induction xs as [|?? IHxs]; intros n H; - [autorewrite with uncps; reflexivity|]. - destruct n as [|n]; [distr_length|]. - specialize (IHxs n). autorewrite with uncps in *. - rewrite !@Tuple.to_list_repeat in *. - simpl List.repeat. rewrite map_cons, combine_cons, map_cons. - simpl fold_right. rewrite IHxs by distr_length. ring. - Qed. Hint Rewrite eval_zeros : push_basesystem_eval. - - Definition add_to_nth_cps {n} i x t {T} (f:tuple Z n->T) := - @on_tuple_cps _ _ 0 (update_nth_cps i (runtime_add x)) n n t _ f. - - Definition add_to_nth {n} i x t := @add_to_nth_cps n i x t _ id. - Lemma add_to_nth_cps_id {n} i x xs {T} f: - @add_to_nth_cps n i x xs T f = f (add_to_nth i x xs). - Proof using weight. - cbv [add_to_nth_cps add_to_nth]; erewrite !on_tuple_cps_correct - by (intros; autorewrite with uncps; reflexivity); prove_id. - Unshelve. - intros; subst. autorewrite with uncps push_id. distr_length. - Qed. - Hint Opaque add_to_nth : uncps. - Hint Rewrite @add_to_nth_cps_id : uncps. - - Lemma eval_add_to_nth {n} (i:nat) (x:Z) (H:(i progress (apply Zminus_eq; ring_simplify) - | _ => progress autorewrite with push_basesystem_eval cancel_pair distr_length - | _ => progress rewrite <-?ListUtil.map_nth_default_always, ?map_fst_combine, ?List.firstn_all2, ?ListUtil.map_nth_default_always, ?nth_default_seq_inbouns, ?plus_O_n - end; trivial; lia. - Unshelve. - intros; subst. autorewrite with uncps push_id. distr_length. - Qed. Hint Rewrite @eval_add_to_nth using omega : push_basesystem_eval. - - Section place_cps. - Context {T : Type}. - - Fixpoint place_cps (t:limb) (i:nat) (f:nat * Z->T) := - Z.eqb_cps (fst t mod weight i) 0 (fun eqb => - if eqb - then f (i, let c := fst t / weight i in (c * snd t)%RT) - else match i with S i' => place_cps t i' f | O => f (O, fst t * snd t)%RT end). - End place_cps. - - Definition place t i := place_cps t i id. - Lemma place_cps_id t i {T} f : - @place_cps T t i f = f (place t i). - Proof using Type. cbv [place]; induction i; prove_id. Qed. - Hint Opaque place : uncps. - Hint Rewrite place_cps_id : uncps. - - Lemma place_cps_in_range (t:limb) (n:nat) - : (fst (place_cps t n id) < S n)%nat. - Proof using Type. induction n; simpl; cbv [Z.eqb_cps]; break_match; simpl; omega. Qed. - Lemma weight_place_cps t i - : weight (fst (place_cps t i id)) * snd (place_cps t i id) - = fst t * snd t. - Proof using Type*. - induction i; cbv [id]; simpl place_cps; cbv [Z.eqb_cps]; break_match; - Z.ltb_to_lt; - autorewrite with cancel_pair; - try match goal with [H:_|-_] => apply Z_div_exact_full_2 in H end; - nsatz || auto. - Qed. - - Definition from_associational_cps n (p:list limb) - {T} (f:tuple Z n->T):= - fold_right_cps2 - (fun t st T' f' => - place_cps t (pred n) - (fun p=> add_to_nth_cps (fst p) (snd p) st f')) - (zeros n) p f. - - Definition from_associational n p := from_associational_cps n p id. - Lemma from_associational_cps_id {n} p {T} f: - @from_associational_cps n p T f = f (from_associational n p). - Proof using Type. - cbv [from_associational_cps from_associational]; prove_id. - Qed. - Hint Opaque from_associational : uncps. - Hint Rewrite @from_associational_cps_id : uncps. - - Lemma eval_from_associational {n} p (n_nonzero:n<>O): - eval (from_associational n p) = Associational.eval p. - Proof using Type*. - cbv [from_associational_cps from_associational]; induction p; - [|pose proof (place_cps_in_range a (pred n))]; prove_eval. - cbv [place]; rewrite weight_place_cps. nsatz. - Qed. - Hint Rewrite @eval_from_associational using omega - : push_basesystem_eval. - - - Section Wrappers. - (* Simple wrappers for Associational definitions; convert to - associational, do the operation, convert back. *) - - Definition add_cps {n} (p q : tuple Z n) {T} (f:tuple Z n->T) := - to_associational_cps p - (fun P => to_associational_cps q - (fun Q => from_associational_cps n (P++Q) f)). - - Definition mul_cps {n m} (p q : tuple Z n) {T} (f:tuple Z m->T) := - to_associational_cps p - (fun P => to_associational_cps q - (fun Q => Associational.mul_cps P Q - (fun PQ => from_associational_cps m PQ f))). - - Definition reduce_cps {m n} (s:Z) (c:list B.limb) (p : tuple Z m) - {T} (f:tuple Z n->T) := - to_associational_cps p - (fun P => Associational.reduce_cps s c P - (fun R => from_associational_cps n R f)). - - Definition negate_snd_cps {n} (p : tuple Z n) - {T} (f:tuple Z n->T) := - to_associational_cps p - (fun P => Associational.negate_snd_cps P - (fun R => from_associational_cps n R f)). - - Definition split_cps {n m1 m2} (s:Z) (p : tuple Z n) - {T} (f:(tuple Z m1 * tuple Z m2) -> T) := - to_associational_cps p - (fun P => Associational.split_cps s P - (fun split_P => - from_associational_cps m1 (fst split_P) - (fun m1_P => - from_associational_cps m2 (snd split_P) - (fun m2_P => - f (m1_P, m2_P))))). - - Definition scmul_cps {n} (x : Z) (p: tuple Z n) - {T} (f:tuple Z n->T) := - to_associational_cps p - (fun P => Associational.mul_cps P [(1, x)] - (fun R => from_associational_cps n R f)). - - (* This version of sub does not add balance; bounds must be - carefully handled. *) - Definition unbalanced_sub_cps {n} (p q: tuple Z n) - {T} (f:tuple Z n->T) := - to_associational_cps p - (fun P => to_associational_cps q - (fun Q => Associational.negate_snd_cps Q - (fun negQ => from_associational_cps n (P ++ negQ) f))). - - End Wrappers. - Hint Unfold - Positional.add_cps - Positional.mul_cps - Positional.reduce_cps - Positional.negate_snd_cps - Positional.split_cps - Positional.scmul_cps - Positional.unbalanced_sub_cps - . - - Section Carries. - Context {modulo_cps div_cps:forall {R},Z->Z->(Z->R)->R}. - Let modulo x y := modulo_cps _ x y id. - Let div x y := div_cps _ x y id. - Context {modulo_cps_id : forall R x y f, modulo_cps R x y f = f (modulo x y)} - {div_cps_id : forall R x y f, div_cps R x y f = f (div x y)}. - Context {div_mod : forall a b:Z, b <> 0 -> - a = b * (div a b) + modulo a b}. - Hint Rewrite modulo_cps_id div_cps_id : uncps. - - Definition carry_cps {n m} (index:nat) (p:tuple Z n) - {T} (f:tuple Z m->T) := - to_associational_cps p - (fun P => @Associational.carry_cps - modulo_cps div_cps - (weight index) - (weight (S index) / weight index) - P T - (fun R => from_associational_cps m R f)). - - Definition carry {n m} i p := @carry_cps n m i p _ id. - Lemma carry_cps_id {n m} i p {T} f: - @carry_cps n m i p T f = f (carry i p). - Proof. - cbv [carry_cps carry]; prove_id; rewrite carry_cps_id; reflexivity. - Qed. - Hint Opaque carry : uncps. Hint Rewrite @carry_cps_id : uncps. - - Lemma eval_carry {n m} i p: (n <> 0%nat) -> (m <> 0%nat) -> - weight (S i) / weight i <> 0 -> - eval (carry (n:=n) (m:=m) i p) = eval p. - Proof. - cbv [carry_cps carry]; intros. prove_eval. - rewrite @eval_carry by eauto. - apply eval_to_associational. - Qed. - Hint Rewrite @eval_carry : push_basesystem_eval. - - Definition carry_reduce_cps {n} - (s:Z) (c:list limb) (p : tuple Z n) - {T} (f: tuple Z n ->T) := - carry_cps (n:=n) (m:=S n) (pred n) p - (fun r => reduce_cps (m:=S n) (n:=n) s c r f). - Hint Unfold carry_reduce_cps. - - (* N.B. It is important to reverse [idxs] here. Like - [fold_right], [fold_right_cps2] is written such that the first - terms in the list are actually used last in the computation. For - example, running: - - `Eval cbv - [Z.add] in (fun a b c d => fold_right Z.add d [a;b;c]).` - - will produce [fun a b c d => (a + (b + (c + d)))].*) - Definition chained_carries_cps {n} (p:tuple Z n) (idxs : list nat) - {T} (f:tuple Z n->T) := - fold_right_cps2 carry_cps p (rev idxs) f. - - Definition chained_carries {n} p idxs := @chained_carries_cps n p idxs _ id. - Lemma chained_carries_id {n} p idxs : forall {T} f, - @chained_carries_cps n p idxs T f = f (chained_carries p idxs). - Proof using modulo_cps_id div_cps_id. - cbv [chained_carries_cps chained_carries]; prove_id. - Qed. - Hint Opaque chained_carries : uncps. - Hint Rewrite @chained_carries_id : uncps. - - Lemma eval_chained_carries {n} (p:tuple Z n) idxs : - (forall i, In i idxs -> weight (S i) / weight i <> 0) -> - eval (chained_carries p idxs) = eval p. - Proof using Type*. - cbv [chained_carries chained_carries_cps]; intros; - autorewrite with uncps push_id. - apply fold_right_invariant; [|intro; rewrite <-in_rev]; - destruct n; prove_eval; auto. - Qed. Hint Rewrite @eval_chained_carries : push_basesystem_eval. - - Definition chained_carries_reduce_cps_step {n} (s:Z) (c:list limb) {T} - (chained_carries_reduce_cps : forall (p:tuple Z n) (carry_chains : list (list nat)) (f : tuple Z n -> T), T) - (p : tuple Z n) (carry_chains : list (list nat)) - (f : tuple Z n -> T) - : T - := match carry_chains with - | nil => f p - | carry_chain :: nil - => chained_carries_cps - (n:=n) p carry_chain f - | carry_chain :: carry_chains - => chained_carries_cps - (n:=n) p carry_chain - (fun r => carry_reduce_cps (n:=n) s c r - (fun r' => chained_carries_reduce_cps r' carry_chains f)) - end. - Section chained_carries_reduce_cps. - Context {n:nat} (s:Z) (c:list limb) {T:Type}. - - Fixpoint chained_carries_reduce_cps - (p : tuple Z n) (carry_chains : list (list nat)) - (f : tuple Z n -> T) - : T - := @chained_carries_reduce_cps_step - n s c T - chained_carries_reduce_cps p carry_chains f. - End chained_carries_reduce_cps. - - Lemma step_chained_carries_reduce_cps {n} (s:Z) (c:list limb) {T} p carry_chain carry_chains (f : tuple Z n -> T) - : chained_carries_reduce_cps s c p (carry_chain :: carry_chains) f - = match length carry_chains with - | O => chained_carries_cps - (n:=n) p carry_chain f - | S _ - => chained_carries_cps - (n:=n) p carry_chain - (fun r => carry_reduce_cps (n:=n) s c r - (fun r' => chained_carries_reduce_cps s c r' carry_chains f)) - end. - Proof. - destruct carry_chains; reflexivity. - Qed. - - Definition chained_carries_reduce {n} (s:Z) (c:list limb) (p:tuple Z n) (carry_chains : list (list nat)) - : tuple Z n - := chained_carries_reduce_cps s c p carry_chains id. - - Lemma chained_carries_reduce_id {n} s c {T} p carry_chains f - : @chained_carries_reduce_cps n s c T p carry_chains f - = f (@chained_carries_reduce n s c p carry_chains). - Proof. - destruct carry_chains as [|carry_chain carry_chains]; [ reflexivity | ]. - cbv [chained_carries_reduce]. - revert p carry_chain; induction carry_chains as [|? carry_chains IHcarry_chains]; intros. - { simpl; repeat autounfold; autorewrite with uncps. reflexivity. } - { rewrite !step_chained_carries_reduce_cps. - simpl @length; cbv iota beta. - repeat autounfold; autorewrite with uncps. - rewrite !IHcarry_chains. - reflexivity. } - Qed. - Hint Opaque chained_carries_reduce : uncps. - Hint Rewrite @chained_carries_reduce_id : uncps. - - Lemma eval_chained_carries_reduce {n} (s:Z) (c:list limb) (p:tuple Z n) carry_chains - (Hn : n <> 0%nat) - (s_nonzero:s<>0) m (m_eq : Z.pos m = s - Associational.eval c) - (Hwt : weight (S (Init.Nat.pred n)) / weight (Init.Nat.pred n) <> 0) - : (List.fold_right - and - True - (List.map - (fun idxs - => forall i, In i idxs -> weight (S i) / weight i <> 0) - carry_chains)) -> - mod_eq m (eval (chained_carries_reduce s c p carry_chains)) (eval p). - Proof using Type*. - destruct carry_chains as [|carry_chain carry_chains]; [ reflexivity | ]. - cbv [chained_carries_reduce]. - revert p carry_chain; induction carry_chains as [|? carry_chains IHcarry_chains]; intros. - { cbn in *; prove_eval; auto. } - { rewrite !step_chained_carries_reduce_cps. - simpl @length; cbv iota beta. - repeat autounfold; autorewrite with uncps push_id push_basesystem_eval. - cbv [chained_carries_reduce]. - rewrite !IHcarry_chains by (cbn in *; tauto); clear IHcarry_chains. - cbn in * |- . - prove_eval; auto. } - Qed. - Hint Rewrite @eval_chained_carries_reduce using (omega || assumption) : push_basesystem_eval. - - (* Reverse of [eval]; translate from Z to basesystem by putting - everything in first digit and then carrying. This function, like - [eval], is not defined using CPS. *) - Definition encode {n} (x : Z) : tuple Z n := - chained_carries (from_associational n [(1,x)]) (seq 0 n). - Lemma eval_encode {n} x : (n <> 0%nat) -> - (forall i, In i (seq 0 n) -> weight (S i) / weight i <> 0) -> - eval (@encode n x) = x. - Proof using Type*. cbv [encode]; intros; prove_eval; auto. Qed. - Hint Rewrite @eval_encode : push_basesystem_eval. - - End Carries. - Hint Unfold carry_reduce_cps. - - Section Subtraction. - Context {m n} {coef : tuple Z n} - {coef_mod : mod_eq m (eval coef) 0}. - - Definition sub_cps (p q : tuple Z n) {T} (f:tuple Z n->T):= - add_cps coef p - (fun cp => negate_snd_cps q - (fun _q => add_cps cp _q f)). - - Definition sub p q := sub_cps p q id. - Lemma sub_id p q {T} f : @sub_cps p q T f = f (sub p q). - Proof using Type. cbv [sub_cps sub]; autounfold; prove_id. Qed. - Hint Opaque sub : uncps. - Hint Rewrite sub_id : uncps. - - Lemma eval_sub p q : mod_eq m (eval (sub p q)) (eval p - eval q). - Proof using Type*. - cbv [sub sub_cps]; autounfold; destruct n; prove_eval. - transitivity (eval coef + (eval p - eval q)). - { apply f_equal2; ring. } - { cbv [mod_eq] in *; rewrite Z.add_mod_full, coef_mod, Z.add_0_l, Zmod_mod. reflexivity. } - Qed. - - Definition opp_cps (p : tuple Z n) {T} (f:tuple Z n->T):= - sub_cps (zeros n) p f. - End Subtraction. - - (* Lemmas about converting to/from F. Will be useful in proving - that basesystem is isomorphic to F.commutative_ring_modulo.*) - Section F. - Context {sz:nat} {sz_nonzero : sz<>0%nat} {m :positive}. - Context (weight_divides : forall i : nat, weight (S i) / weight i <> 0). - Context {modulo_cps div_cps:forall {R},Z->Z->(Z->R)->R}. - Let modulo x y := modulo_cps _ x y id. - Let div x y := div_cps _ x y id. - Context {modulo_cps_id : forall R x y f, modulo_cps R x y f = f (modulo x y)} - {div_cps_id : forall R x y f, div_cps R x y f = f (div x y)}. - Context {div_mod : forall a b:Z, b <> 0 -> - a = b * (div a b) + modulo a b}. - Hint Rewrite modulo_cps_id div_cps_id : uncps. - - Definition Fencode (x : F m) : tuple Z sz := - encode (div_cps:=div_cps) (modulo_cps:=modulo_cps) (F.to_Z x). - - Definition Fdecode (x : tuple Z sz) : F m := F.of_Z m (eval x). - - Lemma Fdecode_Fencode_id x : Fdecode (Fencode x) = x. - Proof using div_mod sz_nonzero weight_0 weight_divides weight_nonzero div_cps_id modulo_cps_id. - cbv [Fdecode Fencode]; rewrite @eval_encode by eauto. - apply F.of_Z_to_Z. - Qed. - - Lemma eq_Feq_iff a b : - Logic.eq (Fdecode a) (Fdecode b) <-> eq m a b. - Proof using Type. cbv [Fdecode]; rewrite <-F.eq_of_Z_iff; reflexivity. Qed. - End F. - - - End Positional. - Hint Rewrite eval_unit eval_single : push_basesystem_eval. - - (* Helper lemmas and definitions for [eval] that to be in a - separate section so the weight function can change. *) - Section EvalHelpers. - Lemma eval_step {n} (x:tuple Z n) : forall wt z, - eval wt (Tuple.append z x) = wt 0%nat * z + eval (fun i => wt (S i)) x. - Proof. - destruct n; [reflexivity|]. - intros; cbv [eval to_associational_cps]. - autorewrite with uncps. rewrite map_S_seq. reflexivity. - Qed. - - Lemma eval_left_append {n} : forall wt x xs, - eval wt (Tuple.left_append (n:=n) x xs) - = wt n * x + eval wt xs. - Proof. - induction n as [|n IHn]; intros wt x xs; try destruct xs; - unfold Tuple.left_append; fold @Tuple.left_append; - autorewrite with push_basesystem_eval; [ring|]. - rewrite (Tuple.subst_append xs), Tuple.hd_append, Tuple.tl_append. - rewrite !eval_step, IHn. ring. - Qed. - Hint Rewrite @eval_left_append : push_basesystem_eval. - - Lemma eval_wt_equiv {n} :forall wta wtb (x:tuple Z n), - (forall i, wta i = wtb i) -> eval wta x = eval wtb x. - Proof. - destruct n as [|n]; [reflexivity|]. - induction n as [|n IHn]; intros wta wtb x H; [rewrite !eval_single, H; reflexivity|]. - simpl tuple in *; destruct x. - change (t, z) with (Tuple.append (n:=S n) z t). - rewrite !eval_step. rewrite (H 0%nat). apply Group.cancel_left. - apply IHn; auto. - Qed. - - Definition eval_from {n} weight (offset:nat) (x : tuple Z n) : Z := - eval (fun i => weight (i+offset)%nat) x. - - Lemma eval_from_0 {n} wt x : @eval_from n wt 0 x = eval wt x. - Proof. cbv [eval_from]. auto using eval_wt_equiv. Qed. - End EvalHelpers. - - Section Select. - Context {weight : nat -> Z}. - - Definition select_cps {n} (mask cond:Z) (p:tuple Z n) - {T} (f:tuple Z n->T) := - dlet t := Z.zselect cond 0 mask in Tuple.map_cps (runtime_and t) p f. - - Definition select {n} mask cond p := @select_cps n mask cond p _ id. - Lemma select_id {n} mask cond p T f : - @select_cps n mask cond p T f = f (select mask cond p). - Proof. - cbv [select select_cps Let_In]; autorewrite with uncps push_id; - reflexivity. - Qed. - Hint Opaque select : uncps. - - Lemma map_and_0 {n} (p:tuple Z n) : Tuple.map (Z.land 0) p = zeros n. - Proof. - induction n as [|n IHn]; [destruct p; reflexivity | ]. - rewrite (Tuple.subst_append p), Tuple.map_append, Z.land_0_l, IHn. - reflexivity. - Qed. - - Lemma eval_select {n} mask cond x (H:Tuple.map (Z.land mask) x = x) : - B.Positional.eval weight (@select n mask cond x) = - if dec (cond = 0) then 0 else B.Positional.eval weight x. - Proof. - cbv [select select_cps Let_In]. - autorewrite with uncps push_id. - rewrite Z.zselect_correct; break_match. - { rewrite map_and_0. apply B.Positional.eval_zeros. } - { change runtime_and with Z.land. rewrite H; reflexivity. } - Qed. - - End Select. - - End Positional. - - Hint Unfold - Positional.add_cps - Positional.mul_cps - Positional.reduce_cps - Positional.carry_reduce_cps - Positional.negate_snd_cps - Positional.split_cps - Positional.scmul_cps - Positional.unbalanced_sub_cps - Positional.opp_cps - . - Hint Rewrite - @Associational.reduce_cps_id - @Associational.split_cps_id - @Associational.mul_cps_id - @Positional.from_associational_cps_id - @Positional.place_cps_id - @Positional.add_to_nth_cps_id - @Positional.to_associational_cps_id - @Positional.sub_id - @Positional.select_id - : uncps. - Hint Rewrite - @Associational.carry_cps_id - @Associational.carryterm_cps_id - @Positional.carry_cps_id - @Positional.chained_carries_id - @Positional.chained_carries_reduce_id - using div_mod_cps_t : uncps. - Hint Rewrite - @Associational.eval_mul - @Positional.eval_single - @Positional.eval_unit - @Positional.eval_to_associational - @Positional.eval_left_append - @Associational.eval_carry - @Associational.eval_carryterm - @Associational.eval_reduce - @Associational.eval_split - @Positional.eval_zeros - @Positional.eval_carry - @Positional.eval_from_associational - @Positional.eval_add_to_nth - @Positional.eval_chained_carries - @Positional.eval_chained_carries_reduce - @Positional.eval_sub - @Positional.eval_select - using (assumption || (div_mod_cps_t; auto) || vm_decide) : push_basesystem_eval. -End B. - -(* Modulo and div that do shifts if possible, otherwise normal mod/div *) -Section DivMod. - Definition modulo_cps {T} (a b : Z) (f : Z -> T) : T := - Z.eqb_cps (2 ^ (Z.log2 b)) b (fun eqb => - if eqb - then let x := (Z.ones (Z.log2 b)) in f (a &' x)%RT - else f (Z.modulo a b)). - - Definition div_cps {T} (a b : Z) (f : Z -> T) : T := - Z.eqb_cps (2 ^ (Z.log2 b)) b (fun eqb => - if eqb - then let x := Z.log2 b in f ((a >> x)%RT) - else f (Z.div a b)). - - Definition modulo (a b : Z) : Z := modulo_cps a b id. - Definition div (a b : Z) : Z := div_cps a b id. - - Lemma modulo_id {T} a b f - : @modulo_cps T a b f = f (modulo a b). - Proof. cbv [modulo_cps modulo]; autorewrite with uncps; break_match; reflexivity. Qed. - Hint Opaque modulo : uncps. - Hint Rewrite @modulo_id : uncps. - - Lemma div_id {T} a b f - : @div_cps T a b f = f (div a b). - Proof. cbv [div_cps div]; autorewrite with uncps; break_match; reflexivity. Qed. - Hint Opaque div : uncps. - Hint Rewrite @div_id : uncps. - - Lemma div_cps_correct {T} a b f : @div_cps T a b f = f (Z.div a b). - Proof. - cbv [div_cps Z.eqb_cps]; intros. break_match; try reflexivity. - rewrite Z.shiftr_div_pow2 by apply Z.log2_nonneg. - Z.ltb_to_lt; congruence. - Qed. - - Lemma modulo_cps_correct {T} a b f : @modulo_cps T a b f = f (Z.modulo a b). - Proof. - cbv [modulo_cps Z.eqb_cps]; intros. break_match; try reflexivity. - rewrite Z.land_ones by apply Z.log2_nonneg. - Z.ltb_to_lt; congruence. - Qed. - - Definition div_correct a b : div a b = Z.div a b := div_cps_correct a b id. - Definition modulo_correct a b : modulo a b = Z.modulo a b := modulo_cps_correct a b id. - - Lemma div_mod a b (H:b <> 0) : a = b * div a b + modulo a b. - Proof. - rewrite div_correct, modulo_correct; auto using Z.div_mod. - Qed. -End DivMod. - -Hint Opaque div modulo : uncps. -Hint Rewrite @div_id @modulo_id : uncps. - -Import B. - -Create HintDb basesystem_partial_evaluation_unfolder. - -Hint Unfold - id - Associational.eval - Associational.multerm - Associational.mul_cps - Associational.mul - Associational.split_cps - Associational.split - Associational.reduce_cps - Associational.reduce - Associational.negate_snd_cps - Associational.negate_snd - Associational.carryterm_cps - Associational.carryterm - Associational.carry_cps - Associational.carry - Positional.to_associational_cps - Positional.to_associational - Positional.eval - Positional.zeros - Positional.add_to_nth_cps - Positional.add_to_nth - Positional.place_cps - Positional.place - Positional.from_associational_cps - Positional.from_associational - Positional.carry_cps - Positional.carry - Positional.chained_carries_cps - Positional.chained_carries - Positional.chained_carries_reduce_cps_step - Positional.chained_carries_reduce_cps - Positional.chained_carries_reduce - Positional.encode - Positional.add_cps - Positional.mul_cps - Positional.reduce_cps - Positional.carry_reduce_cps - Positional.negate_snd_cps - Positional.split_cps - Positional.scmul_cps - Positional.unbalanced_sub_cps - Positional.sub_cps - Positional.sub - Positional.opp_cps - Positional.Fencode - Positional.Fdecode - Positional.eval_from - Positional.select_cps - Positional.select - modulo div modulo_cps div_cps - id_tuple_with_alt id_tuple'_with_alt id_tuple_with_alt_cps' - Z.add_get_carry_full Z.add_get_carry_full_cps - : basesystem_partial_evaluation_unfolder. - -Hint Unfold - B.limb ListUtil.sum ListUtil.sum_firstn - CPSUtil.Tuple.mapi_with_cps CPSUtil.Tuple.mapi_with'_cps CPSUtil.flat_map_cps CPSUtil.on_tuple_cps CPSUtil.fold_right_cps2 - Decidable.dec Decidable.dec_eq_Z - id_tuple_with_alt id_tuple'_with_alt id_tuple_with_alt_cps' - Z.add_get_carry_full Z.add_get_carry_full_cps Z.mul_split Z.mul_split_cps Z.mul_split_cps' - : basesystem_partial_evaluation_unfolder. - - -Ltac basesystem_partial_evaluation_unfolder t := - eval - cbv - delta [ - (* this list must contain all definitions referenced by t that reference [Let_In], [runtime_add], [runtime_opp], [runtime_mul], [runtime_shr], or [runtime_and] *) - id - Positional.to_associational_cps Positional.to_associational - Positional.eval Positional.zeros Positional.add_to_nth_cps - Positional.add_to_nth Positional.place_cps Positional.place - Positional.from_associational_cps Positional.from_associational - Positional.carry_cps Positional.carry - Positional.chained_carries_cps Positional.chained_carries - Positional.chained_carries_reduce_cps - Positional.chained_carries_reduce - Positional.chained_carries_reduce_cps_step - Positional.sub_cps Positional.sub Positional.split_cps - Positional.scmul_cps Positional.unbalanced_sub_cps - Positional.negate_snd_cps Positional.add_cps Positional.opp_cps - Associational.eval Associational.multerm Associational.mul_cps - Associational.mul Associational.split_cps Associational.split - Associational.reduce_cps Associational.reduce - Associational.carryterm_cps Associational.carryterm - Associational.carry_cps Associational.carry - Associational.negate_snd_cps Associational.negate_snd div modulo - id_tuple_with_alt id_tuple'_with_alt id_tuple_with_alt_cps' - Z.add_get_carry_full Z.add_get_carry_full_cps - ] in t. - -Ltac pattern_strip t := - let t := (eval pattern @Let_In, - @runtime_mul, @runtime_add, @runtime_opp, @runtime_shr, @runtime_and, @runtime_lor, - @id_with_alt, - @Z.add_get_carry, @Z.zselect - in t) in - let t := match t with ?t _ _ _ _ _ _ _ _ _ _ => t end in - t. - -Ltac apply_patterned t1 := - constr:(t1 - (@Let_In) - (@runtime_mul) - (@runtime_add) - (@runtime_opp) - (@runtime_shr) - (@runtime_and) - (@runtime_lor) - (@id_with_alt) - (@Z.add_get_carry) - (@Z.zselect)). - -Ltac pattern_strip_full t := - let t := (eval pattern - (@Let_In Z (fun _ => Z)), - @Z.add_get_carry_cps, @Z.mul_split_at_bitwidth_cps, - (@Z.eq_dec_cps), (@Z.eqb_cps), - @runtime_mul, @runtime_add, @runtime_opp, @runtime_shr, @runtime_and, @runtime_lor, - (@id_with_alt Z), - @Z.add_get_carry, @Z.zselect, @Z.mul_split_at_bitwidth, - Z.mul, Z.add, Z.opp, Z.shiftr, Z.shiftl, Z.land, Z.lor, - Z.modulo, Z.div, Z.log2, Z.pow, Z.ones, - Z.eq_dec, Z.eqb, - (@ModularArithmetic.F.to_Z), (@ModularArithmetic.F.of_Z), - 2%Z, 1%Z, 0%Z - in t) in - let t := match t with ?t - _ - _ _ - _ _ - _ _ _ _ _ _ - _ - _ _ _ - _ _ _ _ _ _ _ - _ _ _ _ _ - _ _ - _ _ - _ _ _ => t end in - let t := (eval pattern Z, (@Let_In), (@id_with_alt) in t) in - let t := match t with ?t _ _ _ => t end in - t. - -Ltac apply_patterned_full t1 := - constr:(t1 - Z - (@Let_In) (@id_with_alt) - (@Let_In Z (fun _ => Z)) - (@Z.add_get_carry_cps) (@Z.mul_split_at_bitwidth_cps) - (@Z.eq_dec_cps) (@Z.eqb_cps) - (@runtime_mul) (@runtime_add) (@runtime_opp) (@runtime_shr) (@runtime_and) (@runtime_lor) - (@id_with_alt Z) - (@Z.add_get_carry) (@Z.zselect) (@Z.mul_split_at_bitwidth) - Z.mul Z.add Z.opp Z.shiftr Z.shiftl Z.land Z.lor - Z.modulo Z.div Z.log2 Z.pow Z.ones - Z.eq_dec Z.eqb - (@ModularArithmetic.F.to_Z) (@ModularArithmetic.F.of_Z) - 2%Z 1%Z 0%Z). - -Ltac basesystem_partial_evaluation_gen unfold_tac t t1 := - let t := unfold_tac t in - let t := pattern_strip t in - let dummy := match goal with _ => pose t as t1 end in - let t1' := apply_patterned t1 in - t1'. - -Ltac basesystem_partial_evaluation_RHS_gen unfold_tac := - let t := match goal with |- _ _ ?t => t end in - let t1 := fresh "t1" in - let t1' := basesystem_partial_evaluation_gen unfold_tac t t1 in - transitivity t1'; - [replace_with_vm_compute t1; clear t1|reflexivity]. - -Ltac basesystem_partial_evaluation_default_unfolder t := - basesystem_partial_evaluation_unfolder t. - -Ltac basesystem_partial_evaluation_RHS := - basesystem_partial_evaluation_RHS_gen basesystem_partial_evaluation_default_unfolder. -Ltac basesystem_partial_evaluation := - basesystem_partial_evaluation_gen basesystem_partial_evaluation_default_unfolder. - - -(** This block of tactic code works around bug #5434 - (https://coq.inria.fr/bugs/show_bug.cgi?id=5434), that - [vm_compute] breaks an invariant in pretyping/constr_matching.ml. - So we refresh all of the names in match statements in the goal by - crawling it. - - In particular, [replace_with_vm_compute] creates a [vm_compute]d - term which has anonymous binders where pretyping expects there to - be named binders. This shows up when you try to match on the - function (the branch statement of the match) with an Ltac pattern - like [(fun x : ?T => ?C)] rather than [(fun x : ?T => @?C x)]; we - use the former in reification to save the cost of many extra - invocations of [cbv beta]. Luckily, patterns like [(fun x : ?T => - @?C x)] don't trigger this anomaly, so we can walk the term, - fixing all match statements whose branches are functions whose - binder names were eaten by [vm_compute] (note that in a match, - every branch where the corresponding constructor takes arguments - is represented internally as a function (lambda term)). We fix - the match statements by pulling out the branch with the [@?] - pattern that doesn't trigger the anomaly, and then recreating the - match with a destructuring [let] that hasn't been through - [vm_compute], and therefore has name information that - constr_matching is happy with. *) -Ltac replace_match_with_destructuring_match T := - match T with - | ?F ?X - => let F' := replace_match_with_destructuring_match F in - let X' := replace_match_with_destructuring_match X in - constr:(F' X') - (* we must use [@?f a b] here and not [?f], or else we get an anomaly *) - | match ?d with pair a b => @?f a b end - => let d' := replace_match_with_destructuring_match d in - let T' := fresh in - constr:(let '(a, b) := d' in - match f a b with - | T' => ltac:(let v := (eval cbv beta delta [T'] in T') in - let v := replace_match_with_destructuring_match v in - exact v) - end) - | (fun a : ?A => @?f a) - => let T' := fresh in - let T' := fresh T' in - let T' := fresh T' in - constr:(fun a : A - => match f a with - | T' => ltac:(let v := (eval cbv beta delta [T'] in T') in - let v := replace_match_with_destructuring_match v in - exact v) - end) - | ?x => x - end. -Ltac do_replace_match_with_destructuring_match_in_goal := - let G := get_goal in - let G' := replace_match_with_destructuring_match G in - change G'. - -(* TODO : move *) -Lemma F_of_Z_opp {m} x : F.of_Z m (- x) = F.opp (F.of_Z m x). -Proof. - cbv [F.opp]; intros. rewrite F.to_Z_of_Z, <-Z.sub_0_l. - etransitivity; rewrite F.of_Z_mod; - [rewrite Z.opp_mod_mod|]; reflexivity. -Qed. - -Hint Rewrite <-@F.of_Z_add : pull_FofZ. -Hint Rewrite <-@F.of_Z_mul : pull_FofZ. -Hint Rewrite <-@F.of_Z_sub : pull_FofZ. -Hint Rewrite <-@F_of_Z_opp : pull_FofZ. - -Ltac F_mod_eq := - cbv [Positional.Fdecode]; autorewrite with pull_FofZ; - apply mod_eq_Z2F_iff. - -Ltac presolve_op_mod_eq wt x := - transitivity (Positional.eval wt x); repeat autounfold; - [ cbv [mod_eq]; apply f_equal2; [|reflexivity]; - apply f_equal - | autorewrite with uncps push_id push_basesystem_eval ]. - -Ltac solve_op_mod_eq wt x := - presolve_op_mod_eq wt x; - [ basesystem_partial_evaluation_RHS; - do_replace_match_with_destructuring_match_in_goal - | reflexivity ]. - -Ltac solve_op_F wt x := F_mod_eq; solve_op_mod_eq wt x. -Ltac presolve_op_F wt x := F_mod_eq; presolve_op_mod_eq wt x. diff --git a/src/Arithmetic/CoreUnfolder.v b/src/Arithmetic/CoreUnfolder.v deleted file mode 100644 index b1c79f16d..000000000 --- a/src/Arithmetic/CoreUnfolder.v +++ /dev/null @@ -1,400 +0,0 @@ -Require Import Coq.ZArith.ZArith. -Require Import Crypto.Util.LetIn. -Require Import Crypto.Util.ZUtil.Definitions. -Require Import Crypto.Util.ZUtil.CPS. -Require Import Crypto.Util.IdfunWithAlt. -Require Import Crypto.Util.CPSUtil. -Require Import Crypto.Arithmetic.Core. -Require Import Crypto.Util.Tactics.VM. - -Create HintDb arithmetic_cps_unfolder. - -Hint Unfold Core.div Core.modulo : arithmetic_cps_unfolder. - -Ltac make_parameterized_sig t := - refine (_ : { v : _ | v = t }); - eexists; cbv delta [t - Core.B.Positional.chained_carries_reduce_cps_step - B.limb ListUtil.sum ListUtil.sum_firstn - CPSUtil.Tuple.mapi_with_cps CPSUtil.Tuple.mapi_with'_cps CPSUtil.flat_map_cps CPSUtil.on_tuple_cps CPSUtil.fold_right_cps2 - Decidable.dec Decidable.dec_eq_Z - id_tuple_with_alt id_tuple'_with_alt id_tuple_with_alt_cps' - Z.add_get_carry_full Z.mul_split - Z.add_get_carry_full_cps Z.mul_split_cps Z.mul_split_cps' - Z.add_get_carry_cps]; - repeat autorewrite with pattern_runtime; - reflexivity. - -Notation parameterize_sig t := ltac:(let v := constr:(t) in make_parameterized_sig v) (only parsing). - -Ltac make_parameterized_from_sig t_sig := - let t := (eval cbv [proj1_sig t_sig] in (proj1_sig t_sig)) in - let t := pattern_strip t in - exact t. - -Notation parameterize_from_sig t := ltac:(let v := constr:(t) in make_parameterized_from_sig v) (only parsing). - -Ltac make_parameterized_eq t t_sig := - let t := apply_patterned t in - exact (proj2_sig t_sig : t = _). - -Notation parameterize_eq t t_sig := ltac:(let v := constr:(t) in let v_sig := t_sig in make_parameterized_eq v v_sig) (only parsing). - -Ltac basesystem_partial_evaluation_RHS_fast := - repeat autorewrite with pattern_runtime; - let t := match goal with |- _ _ ?t => t end in - let t := pattern_strip t in - let t1 := fresh "t1" in - pose t as t1; - let t1' := apply_patterned t1 in - transitivity t1'; - [replace_with_vm_compute t1; clear t1|reflexivity]. - -Module B. - Module Associational. - (** -<< -#!/bin/bash -for i in eval multerm mul_cps mul split_cps split reduce_cps reduce negate_snd_cps negate_snd carryterm_cps carryterm carry_cps carry; do - echo " Definition ${i}_sig := parameterize_sig (@Core.B.Associational.${i})."; - echo " Definition ${i} := parameterize_from_sig ${i}_sig."; - echo " Definition ${i}_eq := parameterize_eq ${i} ${i}_sig."; - echo " Hint Unfold ${i} : basesystem_partial_evaluation_unfolder."; - echo " Hint Rewrite <- ${i}_eq : pattern_runtime."; echo ""; -done -echo " End Associational." -echo " Module Positional." -for i in to_associational_cps to_associational eval zeros add_to_nth_cps add_to_nth place_cps place from_associational_cps from_associational carry_cps carry chained_carries_cps chained_carries encode add_cps mul_cps reduce_cps carry_reduce_cps chained_carries_reduce_cps_step chained_carries_reduce_cps chained_carries_reduce negate_snd_cps split_cps scmul_cps unbalanced_sub_cps sub_cps sub opp_cps Fencode Fdecode eval_from select_cps select; do - echo " Definition ${i}_sig := parameterize_sig (@Core.B.Positional.${i})."; - echo " Definition ${i} := parameterize_from_sig ${i}_sig."; - echo " Definition ${i}_eq := parameterize_eq ${i} ${i}_sig."; - echo " Hint Unfold ${i} : basesystem_partial_evaluation_unfolder."; - echo " Hint Rewrite <- ${i}_eq : pattern_runtime."; echo ""; -done -echo " End Positional." -echo "End B." -echo "" -for i in modulo_cps div_cps modulo div; do - echo "Definition ${i}_sig := parameterize_sig (@Core.${i})."; - echo "Definition ${i} := parameterize_from_sig ${i}_sig."; - echo "Definition ${i}_eq := parameterize_eq ${i} ${i}_sig."; - echo "Hint Unfold ${i} : basesystem_partial_evaluation_unfolder."; - echo "Hint Rewrite <- ${i}_eq : pattern_runtime."; echo ""; -done ->> *) - Definition eval_sig := parameterize_sig (@Core.B.Associational.eval). - Definition eval := parameterize_from_sig eval_sig. - Definition eval_eq := parameterize_eq eval eval_sig. - Hint Unfold eval : basesystem_partial_evaluation_unfolder. - Hint Rewrite <- eval_eq : pattern_runtime. - - Definition multerm_sig := parameterize_sig (@Core.B.Associational.multerm). - Definition multerm := parameterize_from_sig multerm_sig. - Definition multerm_eq := parameterize_eq multerm multerm_sig. - Hint Unfold multerm : basesystem_partial_evaluation_unfolder. - Hint Rewrite <- multerm_eq : pattern_runtime. - - Definition mul_cps_sig := parameterize_sig (@Core.B.Associational.mul_cps). - Definition mul_cps := parameterize_from_sig mul_cps_sig. - Definition mul_cps_eq := parameterize_eq mul_cps mul_cps_sig. - Hint Unfold mul_cps : basesystem_partial_evaluation_unfolder. - Hint Rewrite <- mul_cps_eq : pattern_runtime. - - Definition mul_sig := parameterize_sig (@Core.B.Associational.mul). - Definition mul := parameterize_from_sig mul_sig. - Definition mul_eq := parameterize_eq mul mul_sig. - Hint Unfold mul : basesystem_partial_evaluation_unfolder. - Hint Rewrite <- mul_eq : pattern_runtime. - - Definition split_cps_sig := parameterize_sig (@Core.B.Associational.split_cps). - Definition split_cps := parameterize_from_sig split_cps_sig. - Definition split_cps_eq := parameterize_eq split_cps split_cps_sig. - Hint Unfold split_cps : basesystem_partial_evaluation_unfolder. - Hint Rewrite <- split_cps_eq : pattern_runtime. - - Definition split_sig := parameterize_sig (@Core.B.Associational.split). - Definition split := parameterize_from_sig split_sig. - Definition split_eq := parameterize_eq split split_sig. - Hint Unfold split : basesystem_partial_evaluation_unfolder. - Hint Rewrite <- split_eq : pattern_runtime. - - Definition reduce_cps_sig := parameterize_sig (@Core.B.Associational.reduce_cps). - Definition reduce_cps := parameterize_from_sig reduce_cps_sig. - Definition reduce_cps_eq := parameterize_eq reduce_cps reduce_cps_sig. - Hint Unfold reduce_cps : basesystem_partial_evaluation_unfolder. - Hint Rewrite <- reduce_cps_eq : pattern_runtime. - - Definition reduce_sig := parameterize_sig (@Core.B.Associational.reduce). - Definition reduce := parameterize_from_sig reduce_sig. - Definition reduce_eq := parameterize_eq reduce reduce_sig. - Hint Unfold reduce : basesystem_partial_evaluation_unfolder. - Hint Rewrite <- reduce_eq : pattern_runtime. - - Definition negate_snd_cps_sig := parameterize_sig (@Core.B.Associational.negate_snd_cps). - Definition negate_snd_cps := parameterize_from_sig negate_snd_cps_sig. - Definition negate_snd_cps_eq := parameterize_eq negate_snd_cps negate_snd_cps_sig. - Hint Unfold negate_snd_cps : basesystem_partial_evaluation_unfolder. - Hint Rewrite <- negate_snd_cps_eq : pattern_runtime. - - Definition negate_snd_sig := parameterize_sig (@Core.B.Associational.negate_snd). - Definition negate_snd := parameterize_from_sig negate_snd_sig. - Definition negate_snd_eq := parameterize_eq negate_snd negate_snd_sig. - Hint Unfold negate_snd : basesystem_partial_evaluation_unfolder. - Hint Rewrite <- negate_snd_eq : pattern_runtime. - - Definition carryterm_cps_sig := parameterize_sig (@Core.B.Associational.carryterm_cps). - Definition carryterm_cps := parameterize_from_sig carryterm_cps_sig. - Definition carryterm_cps_eq := parameterize_eq carryterm_cps carryterm_cps_sig. - Hint Unfold carryterm_cps : basesystem_partial_evaluation_unfolder. - Hint Rewrite <- carryterm_cps_eq : pattern_runtime. - - Definition carryterm_sig := parameterize_sig (@Core.B.Associational.carryterm). - Definition carryterm := parameterize_from_sig carryterm_sig. - Definition carryterm_eq := parameterize_eq carryterm carryterm_sig. - Hint Unfold carryterm : basesystem_partial_evaluation_unfolder. - Hint Rewrite <- carryterm_eq : pattern_runtime. - - Definition carry_cps_sig := parameterize_sig (@Core.B.Associational.carry_cps). - Definition carry_cps := parameterize_from_sig carry_cps_sig. - Definition carry_cps_eq := parameterize_eq carry_cps carry_cps_sig. - Hint Unfold carry_cps : basesystem_partial_evaluation_unfolder. - Hint Rewrite <- carry_cps_eq : pattern_runtime. - - Definition carry_sig := parameterize_sig (@Core.B.Associational.carry). - Definition carry := parameterize_from_sig carry_sig. - Definition carry_eq := parameterize_eq carry carry_sig. - Hint Unfold carry : basesystem_partial_evaluation_unfolder. - Hint Rewrite <- carry_eq : pattern_runtime. - - End Associational. - Module Positional. - Definition to_associational_cps_sig := parameterize_sig (@Core.B.Positional.to_associational_cps). - Definition to_associational_cps := parameterize_from_sig to_associational_cps_sig. - Definition to_associational_cps_eq := parameterize_eq to_associational_cps to_associational_cps_sig. - Hint Unfold to_associational_cps : basesystem_partial_evaluation_unfolder. - Hint Rewrite <- to_associational_cps_eq : pattern_runtime. - - Definition to_associational_sig := parameterize_sig (@Core.B.Positional.to_associational). - Definition to_associational := parameterize_from_sig to_associational_sig. - Definition to_associational_eq := parameterize_eq to_associational to_associational_sig. - Hint Unfold to_associational : basesystem_partial_evaluation_unfolder. - Hint Rewrite <- to_associational_eq : pattern_runtime. - - Definition eval_sig := parameterize_sig (@Core.B.Positional.eval). - Definition eval := parameterize_from_sig eval_sig. - Definition eval_eq := parameterize_eq eval eval_sig. - Hint Unfold eval : basesystem_partial_evaluation_unfolder. - Hint Rewrite <- eval_eq : pattern_runtime. - - Definition zeros_sig := parameterize_sig (@Core.B.Positional.zeros). - Definition zeros := parameterize_from_sig zeros_sig. - Definition zeros_eq := parameterize_eq zeros zeros_sig. - Hint Unfold zeros : basesystem_partial_evaluation_unfolder. - Hint Rewrite <- zeros_eq : pattern_runtime. - - Definition add_to_nth_cps_sig := parameterize_sig (@Core.B.Positional.add_to_nth_cps). - Definition add_to_nth_cps := parameterize_from_sig add_to_nth_cps_sig. - Definition add_to_nth_cps_eq := parameterize_eq add_to_nth_cps add_to_nth_cps_sig. - Hint Unfold add_to_nth_cps : basesystem_partial_evaluation_unfolder. - Hint Rewrite <- add_to_nth_cps_eq : pattern_runtime. - - Definition add_to_nth_sig := parameterize_sig (@Core.B.Positional.add_to_nth). - Definition add_to_nth := parameterize_from_sig add_to_nth_sig. - Definition add_to_nth_eq := parameterize_eq add_to_nth add_to_nth_sig. - Hint Unfold add_to_nth : basesystem_partial_evaluation_unfolder. - Hint Rewrite <- add_to_nth_eq : pattern_runtime. - - Definition place_cps_sig := parameterize_sig (@Core.B.Positional.place_cps). - Definition place_cps := parameterize_from_sig place_cps_sig. - Definition place_cps_eq := parameterize_eq place_cps place_cps_sig. - Hint Unfold place_cps : basesystem_partial_evaluation_unfolder. - Hint Rewrite <- place_cps_eq : pattern_runtime. - - Definition place_sig := parameterize_sig (@Core.B.Positional.place). - Definition place := parameterize_from_sig place_sig. - Definition place_eq := parameterize_eq place place_sig. - Hint Unfold place : basesystem_partial_evaluation_unfolder. - Hint Rewrite <- place_eq : pattern_runtime. - - Definition from_associational_cps_sig := parameterize_sig (@Core.B.Positional.from_associational_cps). - Definition from_associational_cps := parameterize_from_sig from_associational_cps_sig. - Definition from_associational_cps_eq := parameterize_eq from_associational_cps from_associational_cps_sig. - Hint Unfold from_associational_cps : basesystem_partial_evaluation_unfolder. - Hint Rewrite <- from_associational_cps_eq : pattern_runtime. - - Definition from_associational_sig := parameterize_sig (@Core.B.Positional.from_associational). - Definition from_associational := parameterize_from_sig from_associational_sig. - Definition from_associational_eq := parameterize_eq from_associational from_associational_sig. - Hint Unfold from_associational : basesystem_partial_evaluation_unfolder. - Hint Rewrite <- from_associational_eq : pattern_runtime. - - Definition carry_cps_sig := parameterize_sig (@Core.B.Positional.carry_cps). - Definition carry_cps := parameterize_from_sig carry_cps_sig. - Definition carry_cps_eq := parameterize_eq carry_cps carry_cps_sig. - Hint Unfold carry_cps : basesystem_partial_evaluation_unfolder. - Hint Rewrite <- carry_cps_eq : pattern_runtime. - - Definition carry_sig := parameterize_sig (@Core.B.Positional.carry). - Definition carry := parameterize_from_sig carry_sig. - Definition carry_eq := parameterize_eq carry carry_sig. - Hint Unfold carry : basesystem_partial_evaluation_unfolder. - Hint Rewrite <- carry_eq : pattern_runtime. - - Definition chained_carries_cps_sig := parameterize_sig (@Core.B.Positional.chained_carries_cps). - Definition chained_carries_cps := parameterize_from_sig chained_carries_cps_sig. - Definition chained_carries_cps_eq := parameterize_eq chained_carries_cps chained_carries_cps_sig. - Hint Unfold chained_carries_cps : basesystem_partial_evaluation_unfolder. - Hint Rewrite <- chained_carries_cps_eq : pattern_runtime. - - Definition chained_carries_sig := parameterize_sig (@Core.B.Positional.chained_carries). - Definition chained_carries := parameterize_from_sig chained_carries_sig. - Definition chained_carries_eq := parameterize_eq chained_carries chained_carries_sig. - Hint Unfold chained_carries : basesystem_partial_evaluation_unfolder. - Hint Rewrite <- chained_carries_eq : pattern_runtime. - - Definition encode_sig := parameterize_sig (@Core.B.Positional.encode). - Definition encode := parameterize_from_sig encode_sig. - Definition encode_eq := parameterize_eq encode encode_sig. - Hint Unfold encode : basesystem_partial_evaluation_unfolder. - Hint Rewrite <- encode_eq : pattern_runtime. - - Definition add_cps_sig := parameterize_sig (@Core.B.Positional.add_cps). - Definition add_cps := parameterize_from_sig add_cps_sig. - Definition add_cps_eq := parameterize_eq add_cps add_cps_sig. - Hint Unfold add_cps : basesystem_partial_evaluation_unfolder. - Hint Rewrite <- add_cps_eq : pattern_runtime. - - Definition mul_cps_sig := parameterize_sig (@Core.B.Positional.mul_cps). - Definition mul_cps := parameterize_from_sig mul_cps_sig. - Definition mul_cps_eq := parameterize_eq mul_cps mul_cps_sig. - Hint Unfold mul_cps : basesystem_partial_evaluation_unfolder. - Hint Rewrite <- mul_cps_eq : pattern_runtime. - - Definition reduce_cps_sig := parameterize_sig (@Core.B.Positional.reduce_cps). - Definition reduce_cps := parameterize_from_sig reduce_cps_sig. - Definition reduce_cps_eq := parameterize_eq reduce_cps reduce_cps_sig. - Hint Unfold reduce_cps : basesystem_partial_evaluation_unfolder. - Hint Rewrite <- reduce_cps_eq : pattern_runtime. - - Definition carry_reduce_cps_sig := parameterize_sig (@Core.B.Positional.carry_reduce_cps). - Definition carry_reduce_cps := parameterize_from_sig carry_reduce_cps_sig. - Definition carry_reduce_cps_eq := parameterize_eq carry_reduce_cps carry_reduce_cps_sig. - Hint Unfold carry_reduce_cps : basesystem_partial_evaluation_unfolder. - Hint Rewrite <- carry_reduce_cps_eq : pattern_runtime. - - Definition chained_carries_reduce_cps_step_sig := parameterize_sig (@Core.B.Positional.chained_carries_reduce_cps_step). - Definition chained_carries_reduce_cps_step := parameterize_from_sig chained_carries_reduce_cps_step_sig. - Definition chained_carries_reduce_cps_step_eq := parameterize_eq chained_carries_reduce_cps_step chained_carries_reduce_cps_step_sig. - Hint Unfold chained_carries_reduce_cps_step : basesystem_partial_evaluation_unfolder. - Hint Rewrite <- chained_carries_reduce_cps_step_eq : pattern_runtime. - - Definition chained_carries_reduce_cps_sig := parameterize_sig (@Core.B.Positional.chained_carries_reduce_cps). - Definition chained_carries_reduce_cps := parameterize_from_sig chained_carries_reduce_cps_sig. - Definition chained_carries_reduce_cps_eq := parameterize_eq chained_carries_reduce_cps chained_carries_reduce_cps_sig. - Hint Unfold chained_carries_reduce_cps : basesystem_partial_evaluation_unfolder. - Hint Rewrite <- chained_carries_reduce_cps_eq : pattern_runtime. - - Definition chained_carries_reduce_sig := parameterize_sig (@Core.B.Positional.chained_carries_reduce). - Definition chained_carries_reduce := parameterize_from_sig chained_carries_reduce_sig. - Definition chained_carries_reduce_eq := parameterize_eq chained_carries_reduce chained_carries_reduce_sig. - Hint Unfold chained_carries_reduce : basesystem_partial_evaluation_unfolder. - Hint Rewrite <- chained_carries_reduce_eq : pattern_runtime. - - Definition negate_snd_cps_sig := parameterize_sig (@Core.B.Positional.negate_snd_cps). - Definition negate_snd_cps := parameterize_from_sig negate_snd_cps_sig. - Definition negate_snd_cps_eq := parameterize_eq negate_snd_cps negate_snd_cps_sig. - Hint Unfold negate_snd_cps : basesystem_partial_evaluation_unfolder. - Hint Rewrite <- negate_snd_cps_eq : pattern_runtime. - - Definition split_cps_sig := parameterize_sig (@Core.B.Positional.split_cps). - Definition split_cps := parameterize_from_sig split_cps_sig. - Definition split_cps_eq := parameterize_eq split_cps split_cps_sig. - Hint Unfold split_cps : basesystem_partial_evaluation_unfolder. - Hint Rewrite <- split_cps_eq : pattern_runtime. - - Definition scmul_cps_sig := parameterize_sig (@Core.B.Positional.scmul_cps). - Definition scmul_cps := parameterize_from_sig scmul_cps_sig. - Definition scmul_cps_eq := parameterize_eq scmul_cps scmul_cps_sig. - Hint Unfold scmul_cps : basesystem_partial_evaluation_unfolder. - Hint Rewrite <- scmul_cps_eq : pattern_runtime. - - Definition unbalanced_sub_cps_sig := parameterize_sig (@Core.B.Positional.unbalanced_sub_cps). - Definition unbalanced_sub_cps := parameterize_from_sig unbalanced_sub_cps_sig. - Definition unbalanced_sub_cps_eq := parameterize_eq unbalanced_sub_cps unbalanced_sub_cps_sig. - Hint Unfold unbalanced_sub_cps : basesystem_partial_evaluation_unfolder. - Hint Rewrite <- unbalanced_sub_cps_eq : pattern_runtime. - - Definition sub_cps_sig := parameterize_sig (@Core.B.Positional.sub_cps). - Definition sub_cps := parameterize_from_sig sub_cps_sig. - Definition sub_cps_eq := parameterize_eq sub_cps sub_cps_sig. - Hint Unfold sub_cps : basesystem_partial_evaluation_unfolder. - Hint Rewrite <- sub_cps_eq : pattern_runtime. - - Definition sub_sig := parameterize_sig (@Core.B.Positional.sub). - Definition sub := parameterize_from_sig sub_sig. - Definition sub_eq := parameterize_eq sub sub_sig. - Hint Unfold sub : basesystem_partial_evaluation_unfolder. - Hint Rewrite <- sub_eq : pattern_runtime. - - Definition opp_cps_sig := parameterize_sig (@Core.B.Positional.opp_cps). - Definition opp_cps := parameterize_from_sig opp_cps_sig. - Definition opp_cps_eq := parameterize_eq opp_cps opp_cps_sig. - Hint Unfold opp_cps : basesystem_partial_evaluation_unfolder. - Hint Rewrite <- opp_cps_eq : pattern_runtime. - - Definition Fencode_sig := parameterize_sig (@Core.B.Positional.Fencode). - Definition Fencode := parameterize_from_sig Fencode_sig. - Definition Fencode_eq := parameterize_eq Fencode Fencode_sig. - Hint Unfold Fencode : basesystem_partial_evaluation_unfolder. - Hint Rewrite <- Fencode_eq : pattern_runtime. - - Definition Fdecode_sig := parameterize_sig (@Core.B.Positional.Fdecode). - Definition Fdecode := parameterize_from_sig Fdecode_sig. - Definition Fdecode_eq := parameterize_eq Fdecode Fdecode_sig. - Hint Unfold Fdecode : basesystem_partial_evaluation_unfolder. - Hint Rewrite <- Fdecode_eq : pattern_runtime. - - Definition eval_from_sig := parameterize_sig (@Core.B.Positional.eval_from). - Definition eval_from := parameterize_from_sig eval_from_sig. - Definition eval_from_eq := parameterize_eq eval_from eval_from_sig. - Hint Unfold eval_from : basesystem_partial_evaluation_unfolder. - Hint Rewrite <- eval_from_eq : pattern_runtime. - - Definition select_cps_sig := parameterize_sig (@Core.B.Positional.select_cps). - Definition select_cps := parameterize_from_sig select_cps_sig. - Definition select_cps_eq := parameterize_eq select_cps select_cps_sig. - Hint Unfold select_cps : basesystem_partial_evaluation_unfolder. - Hint Rewrite <- select_cps_eq : pattern_runtime. - - Definition select_sig := parameterize_sig (@Core.B.Positional.select). - Definition select := parameterize_from_sig select_sig. - Definition select_eq := parameterize_eq select select_sig. - Hint Unfold select : basesystem_partial_evaluation_unfolder. - Hint Rewrite <- select_eq : pattern_runtime. - - End Positional. -End B. - -Definition modulo_cps_sig := parameterize_sig (@Core.modulo_cps). -Definition modulo_cps := parameterize_from_sig modulo_cps_sig. -Definition modulo_cps_eq := parameterize_eq modulo_cps modulo_cps_sig. -Hint Unfold modulo_cps : basesystem_partial_evaluation_unfolder. -Hint Rewrite <- modulo_cps_eq : pattern_runtime. - -Definition div_cps_sig := parameterize_sig (@Core.div_cps). -Definition div_cps := parameterize_from_sig div_cps_sig. -Definition div_cps_eq := parameterize_eq div_cps div_cps_sig. -Hint Unfold div_cps : basesystem_partial_evaluation_unfolder. -Hint Rewrite <- div_cps_eq : pattern_runtime. - -Definition modulo_sig := parameterize_sig (@Core.modulo). -Definition modulo := parameterize_from_sig modulo_sig. -Definition modulo_eq := parameterize_eq modulo modulo_sig. -Hint Unfold modulo : basesystem_partial_evaluation_unfolder. -Hint Rewrite <- modulo_eq : pattern_runtime. - -Definition div_sig := parameterize_sig (@Core.div). -Definition div := parameterize_from_sig div_sig. -Definition div_eq := parameterize_eq div div_sig. -Hint Unfold div : basesystem_partial_evaluation_unfolder. -Hint Rewrite <- div_eq : pattern_runtime. diff --git a/src/Arithmetic/Karatsuba.v b/src/Arithmetic/Karatsuba.v deleted file mode 100644 index 1873e5ef1..000000000 --- a/src/Arithmetic/Karatsuba.v +++ /dev/null @@ -1,228 +0,0 @@ -Require Import Coq.ZArith.ZArith. -Require Import Coq.micromega.Lia. -Require Import Crypto.Algebra.Nsatz. -Require Import Crypto.Util.LetIn Crypto.Util.CPSUtil. -Require Import Crypto.Arithmetic.Core. Import B. Import Positional. -Require Import Crypto.Util.Tuple. -Require Import Crypto.Util.IdfunWithAlt. -Require Import Crypto.Util.ZUtil.EquivModulo. -Local Open Scope Z_scope. - -Section Karatsuba. -Context (weight : nat -> Z) - (weight_0 : weight 0%nat = 1%Z) - (weight_nonzero : forall i, weight i <> 0). - (* [tuple Z n] is the "half-length" type, - [tuple Z n2] is the "full-length" type *) - Context {n n2 : nat} (n_nonzero : n <> 0%nat) (n2_nonzero : n2 <> 0%nat). - Let T := tuple Z n. - Let T2 := tuple Z n2. - - (* - If x = x0 + sx1 and y = y0 + sy1, then xy = s^2 * z2 + s * z1 + s * z0, - with: - - z2 = x1y1 - z0 = x0y0 - z1 = (x1+x0)(y1+y0) - (z2 + z0) - - Computing z1 one operation at a time: - sum_z = z0 + z2 - sum_x = x1 + x0 - sum_y = y1 + y0 - mul_sumxy = sum_x * sum_y - z1 = mul_sumxy - sum_z - *) - Definition karatsuba_mul_cps s (x y : T2) {R} (f:T2->R) := - split_cps (n:=n2) (m1:=n) (m2:=n) weight s x - (fun x0_x1 => split_cps weight s y - (fun y0_y1 => mul_cps weight (fst x0_x1) (fst y0_y1) - (fun z0 => mul_cps weight(snd x0_x1) (snd y0_y1) - (fun z2 => add_cps weight z0 z2 - (fun sum_z => add_cps weight (fst x0_x1) (snd x0_x1) - (fun sum_x => add_cps weight (fst y0_y1) (snd y0_y1) - (fun sum_y => mul_cps weight sum_x sum_y - (fun mul_sumxy => unbalanced_sub_cps weight mul_sumxy sum_z - (fun z1 => scmul_cps weight s z1 - (fun sz1 => scmul_cps weight (s^2) z2 - (fun s2z2 => add_cps weight s2z2 sz1 - (fun add_s2z2_sz1 => add_cps weight add_s2z2_sz1 z0 f)))))))))))). - - Definition karatsuba_mul s x y := @karatsuba_mul_cps s x y _ id. - Lemma karatsuba_mul_id s x y R f : - @karatsuba_mul_cps s x y R f = f (karatsuba_mul s x y). - Proof. - cbv [karatsuba_mul karatsuba_mul_cps]. - repeat autounfold. - autorewrite with cancel_pair push_id uncps. - reflexivity. - Qed. - Hint Opaque karatsuba_mul : uncps. - Hint Rewrite karatsuba_mul_id : uncps. - - Lemma eval_karatsuba_mul s x y (s_nonzero:s <> 0) : - eval weight (karatsuba_mul s x y) = eval weight x * eval weight y. - Proof. - cbv [karatsuba_mul karatsuba_mul_cps]; repeat autounfold. - autorewrite with cancel_pair push_id uncps push_basesystem_eval. - repeat match goal with - | _ => rewrite <-eval_to_associational - | |- context [(to_associational ?w ?x)] => - rewrite <-(Associational.eval_split - s (to_associational w x)) by assumption - | _ => rewrite <-Associational.eval_split by assumption - | _ => setoid_rewrite Associational.eval_nil - end. - ring_simplify. - nsatz. - Qed. - - (* These definitions are intended to make bounds analysis go through - for karatsuba. Essentially, we provide a version of the code to - actually run and a version to bounds-check, along with a proof - that they are exactly equal. This works around cases where the - bounds proof requires high-level reasoning. *) - Local Notation id_with_alt_bounds_cps := id_tuple_with_alt_cps'. - - (* - If: - s^2 mod p = (s + 1) mod p - x = x0 + sx1 - y = y0 + sy1 - Then, with z0 and z2 as before (x0y0 and x1y1 respectively), let z1 = ((x0 + x1) * (y0 + y1)) - z0. - - Computing xy one operation at a time: - sum_z = z0 + z2 - sum_x = x0 + x1 - sum_y = y0 + y1 - mul_sumxy = sum_x * sum_y - z1 = mul_sumxy - z0 - sz1 = s * z1 - xy = sum_z - sz1 - - The subtraction in the computation of z1 presents issues for - bounds analysis. In particular, just analyzing the upper and lower - bounds of the values would indicate that it could underflow--we - know it won't because - - mul_sumxy -z0 = ((x0+x1) * (y0+y1)) - x0y0 - = (x0y0 + x1y0 + x0y1 + x1y1) - x0y0 - = x1y0 + x0y1 + x1y1 - - Therefore, we use id_with_alt_bounds to indicate that the - bounds-checker should check the non-subtracting form. - - *) - - (* - Definition goldilocks_mul_cps_for_bounds_checker - s (xs ys : T2) {R} (f:T2->R) := - split_cps (m1:=n) (m2:=n) weight s xs - (fun x0_x1 => split_cps weight s ys - - (fun z1 => Positional.to_associational_cps weight z1 - (fun z1 => Associational.mul_cps (pair s 1::nil) z1 - (fun sz1 => Positional.from_associational_cps weight n2 sz1 - (fun sz1 => add_cps weight sum_z sz1 f)))))))))))). - *) - - Let T3 := tuple Z (n2+n). - Definition goldilocks_mul_cps s (xs ys : T2) {R} (f:T3->R) := - split_cps (m1:=n) (m2:=n) weight s xs - (fun x0_x1 => split_cps weight s ys - (fun y0_y1 => mul_cps weight (fst x0_x1) (fst y0_y1) - (fun z0 => mul_cps weight (snd x0_x1) (snd y0_y1) - (fun z2 => add_cps weight z0 z2 - (fun sum_z : tuple _ n2 => add_cps weight (fst x0_x1) (snd x0_x1) - (fun sum_x => add_cps weight (fst y0_y1) (snd y0_y1) - (fun sum_y => mul_cps weight sum_x sum_y - (fun mul_sumxy => - - id_with_alt_bounds_cps (fun f => - (unbalanced_sub_cps weight mul_sumxy z0 f)) (fun f => - - (mul_cps weight (fst x0_x1) (snd y0_y1) - (fun x0_y1 => mul_cps weight (snd x0_x1) (fst y0_y1) - (fun x1_y0 => mul_cps weight (fst x0_x1) (fst y0_y1) - (fun z0 => mul_cps weight (snd x0_x1) (snd y0_y1) - (fun z2 => add_cps weight z0 z2 - (fun sum_z => add_cps weight x0_y1 x1_y0 - (fun z1' => add_cps weight z1' z2 f)))))))) (fun z1 => - - Positional.to_associational_cps weight z1 - (fun z1 => Associational.mul_cps (pair s 1::nil) z1 - (fun sz1 => Positional.to_associational_cps weight sum_z - (fun sum_z => Positional.from_associational_cps weight _ (sum_z++sz1) f - )))))))))))). - - Definition goldilocks_mul s xs ys := goldilocks_mul_cps s xs ys id. - Lemma goldilocks_mul_id s xs ys R f : - @goldilocks_mul_cps s xs ys R f = f (goldilocks_mul s xs ys). - Proof. - cbv [goldilocks_mul goldilocks_mul_cps Let_In]. - repeat autounfold. autorewrite with uncps push_id. - reflexivity. - Qed. - Hint Opaque goldilocks_mul : uncps. - Hint Rewrite goldilocks_mul_id : uncps. - - Local Existing Instances Z.equiv_modulo_Reflexive - RelationClasses.eq_Reflexive Z.equiv_modulo_Symmetric - Z.equiv_modulo_Transitive Z.mul_mod_Proper Z.add_mod_Proper - Z.modulo_equiv_modulo_Proper. - - Lemma goldilocks_mul_correct (p : Z) (p_nonzero : p <> 0) s (s_nonzero : s <> 0) (s2_modp : (s^2) mod p = (s+1) mod p) xs ys : - (eval weight (goldilocks_mul s xs ys)) mod p = (eval weight xs * eval weight ys) mod p. - Proof. - cbv [goldilocks_mul_cps goldilocks_mul Let_In]. - Zmod_to_equiv_modulo. - progress autounfold. - progress autorewrite with push_id cancel_pair uncps push_basesystem_eval. - rewrite !unfold_id_tuple_with_alt. - repeat match goal with - | _ => rewrite <-eval_to_associational - | |- context [(to_associational ?w ?x)] => - rewrite <-(Associational.eval_split - s (to_associational w x)) by assumption - | _ => rewrite <-Associational.eval_split by assumption - | _ => setoid_rewrite Associational.eval_nil - end. - progress autorewrite with push_id cancel_pair uncps push_basesystem_eval. - repeat (rewrite ?eval_from_associational, ?eval_to_associational). - progress autorewrite with push_id cancel_pair uncps push_basesystem_eval. - repeat match goal with - | _ => rewrite <-eval_to_associational - | |- context [(to_associational ?w ?x)] => - rewrite <-(Associational.eval_split - s (to_associational w x)) by assumption - | _ => rewrite <-Associational.eval_split by assumption - | _ => setoid_rewrite Associational.eval_nil - end. - ring_simplify. - setoid_rewrite s2_modp. - apply f_equal2; nsatz. - assumption. assumption. omega. - Qed. - - Lemma eval_goldilocks_mul (p : positive) s (s_nonzero : s <> 0) (s2_modp : mod_eq p (s^2) (s+1)) xs ys : - mod_eq p (eval weight (goldilocks_mul s xs ys)) (eval weight xs * eval weight ys). - Proof. - apply goldilocks_mul_correct; auto; lia. - Qed. -End Karatsuba. -Hint Opaque karatsuba_mul goldilocks_mul : uncps. -Hint Rewrite karatsuba_mul_id goldilocks_mul_id : uncps. - -Hint Rewrite - @eval_karatsuba_mul - @eval_goldilocks_mul - @goldilocks_mul_correct - using (assumption || (div_mod_cps_t; auto)) : push_basesystem_eval. - -Ltac basesystem_partial_evaluation_unfolder t := - let t := (eval cbv delta [goldilocks_mul karatsuba_mul goldilocks_mul_cps karatsuba_mul_cps] in t) in - let t := Arithmetic.Core.basesystem_partial_evaluation_unfolder t in - t. - -Ltac Arithmetic.Core.basesystem_partial_evaluation_default_unfolder t ::= - basesystem_partial_evaluation_unfolder t. diff --git a/src/Arithmetic/MontgomeryReduction/WordByWord/Abstract/Definition.v b/src/Arithmetic/MontgomeryReduction/WordByWord/Abstract/Definition.v deleted file mode 100644 index 2ea623b0b..000000000 --- a/src/Arithmetic/MontgomeryReduction/WordByWord/Abstract/Definition.v +++ /dev/null @@ -1,61 +0,0 @@ -(*** Word-By-Word Montgomery Multiplication *) -(** This file implements Montgomery Form, Montgomery Reduction, and - Montgomery Multiplication on an abstract [T]. See - https://github.com/mit-plv/fiat-crypto/issues/157 for a discussion - of the algorithm; note that it may be that none of the algorithms - there exactly match what we're doing here. *) -Require Import Coq.ZArith.ZArith. -Require Import Crypto.Util.Notations. -Require Import Crypto.Util.LetIn. -Require Import Crypto.Util.ZUtil.Definitions. - -Local Open Scope Z_scope. - -Section WordByWordMontgomery. - Local Coercion Z.pos : positive >-> Z. - Context - {T : Type} - {eval : T -> Z} - {numlimbs : T -> nat} - {zero : nat -> T} - {divmod : T -> T * Z} (* returns lowest limb and all-but-lowest-limb *) - {r : positive} - {scmul : Z -> T -> T} (* uses double-output multiply *) - {R : positive} - {add : T -> T -> T} (* joins carry *) - {drop_high : T -> T} (* drops the highest limb *) - (N : T). - - (* Recurse for a as many iterations as A has limbs, varying A := A, S := 0, r, bounds *) - Section Iteration. - Context (B : T) (k : Z). - Context (A S : T). - (* Given A, B < R, we want to compute A * B / R mod N. R = bound 0 * ... * bound (n-1) *) - Local Definition A_a := dlet p := divmod A in p. Local Definition A' := fst A_a. Local Definition a := snd A_a. - Local Definition S1 := add S (scmul a B). - Local Definition s := snd (divmod S1). - Local Definition q := fst (Z.mul_split r s k). - Local Definition S2 := add S1 (scmul q N). - Local Definition S3 := fst (divmod S2). - Local Definition S4 := drop_high S3. - End Iteration. - - Section loop. - Context (A B : T) (k : Z) (S' : T). - - Definition redc_body : T * T -> T * T - := fun '(A, S') => (A' A, S4 B k A S'). - - Fixpoint redc_loop (count : nat) : T * T -> T * T - := match count with - | O => fun A_S => A_S - | S count' => fun A_S => redc_loop count' (redc_body A_S) - end. - - Definition redc : T - := snd (redc_loop (numlimbs A) (A, zero (1 + numlimbs B))). - End loop. -End WordByWordMontgomery. - -Create HintDb word_by_word_montgomery. -Hint Unfold S4 S3 S2 q s S1 a A' A_a Let_In : word_by_word_montgomery. diff --git a/src/Arithmetic/MontgomeryReduction/WordByWord/Abstract/Dependent/Definition.v b/src/Arithmetic/MontgomeryReduction/WordByWord/Abstract/Dependent/Definition.v deleted file mode 100644 index cff906465..000000000 --- a/src/Arithmetic/MontgomeryReduction/WordByWord/Abstract/Dependent/Definition.v +++ /dev/null @@ -1,81 +0,0 @@ -(*** Word-By-Word Montgomery Multiplication *) -(** This file implements Montgomery Form, Montgomery Reduction, and - Montgomery Multiplication on an abstract [T : ℕ → Type]. See - https://github.com/mit-plv/fiat-crypto/issues/157 for a discussion - of the algorithm; note that it may be that none of the algorithms - there exactly match what we're doing here. *) -Require Import Coq.ZArith.ZArith. -Require Import Crypto.Util.Notations. -Require Import Crypto.Util.LetIn. -Require Import Crypto.Util.ZUtil.Definitions. - -Local Open Scope Z_scope. - -Section WordByWordMontgomery. - Local Coercion Z.pos : positive >-> Z. - Context - {T : nat -> Type} - {eval : forall {n}, T n -> Z} - {zero : forall {n}, T n} - {divmod : forall {n}, T (S n) -> T n * Z} (* returns lowest limb and all-but-lowest-limb *) - {r : positive} - {R : positive} - {R_numlimbs : nat} - {scmul : forall {n}, Z -> T n -> T (S n)} (* uses double-output multiply *) - {addT : forall {n}, T n -> T n -> T (S n)} (* joins carry *) - {addT' : forall {n}, T (S n) -> T n -> T (S (S n))} (* joins carry *) - {drop_high : T (S (S R_numlimbs)) -> T (S R_numlimbs)} (* drops the highest limb *) - {conditional_sub : T (S R_numlimbs) -> T R_numlimbs} (* computes [arg - N] if [N <= arg], and drops high bit *) - {sub_then_maybe_add : T R_numlimbs -> T R_numlimbs -> T R_numlimbs} (* computes [a - b + if (a - b) T pred_A_numlimbs * T (S R_numlimbs) - := fun '(A, S') => (A' _ A, S4 _ B k A S'). - - Fixpoint redc_loop (count : nat) : T count * T (S R_numlimbs) -> T O * T (S R_numlimbs) - := match count return T count * _ -> _ with - | O => fun A_S => A_S - | S count' => fun A_S => redc_loop count' (redc_body A_S) - end. - - Definition pre_redc : T (S R_numlimbs) - := snd (redc_loop A_numlimbs (A, zero (1 + R_numlimbs))). - - Definition redc : T R_numlimbs - := conditional_sub pre_redc. - End loop. - - Definition add (A B : T R_numlimbs) : T R_numlimbs - := conditional_sub (addT _ A B). - Definition sub (A B : T R_numlimbs) : T R_numlimbs - := sub_then_maybe_add A B. - Definition opp (A : T R_numlimbs) : T R_numlimbs - := sub (zero _) A. -End WordByWordMontgomery. - -Create HintDb word_by_word_montgomery. -Hint Unfold S4 S3 S2 q s S1 a A' A_a Let_In : word_by_word_montgomery. diff --git a/src/Arithmetic/MontgomeryReduction/WordByWord/Abstract/Dependent/Proofs.v b/src/Arithmetic/MontgomeryReduction/WordByWord/Abstract/Dependent/Proofs.v deleted file mode 100644 index 3dd7fc0b3..000000000 --- a/src/Arithmetic/MontgomeryReduction/WordByWord/Abstract/Dependent/Proofs.v +++ /dev/null @@ -1,582 +0,0 @@ -(*** Word-By-Word Montgomery Multiplication Proofs *) -Require Import Coq.Arith.Arith. -Require Import Coq.ZArith.BinInt Coq.ZArith.ZArith Coq.ZArith.Zdiv Coq.micromega.Lia. -Require Import Crypto.Util.LetIn. -Require Import Crypto.Util.Prod. -Require Import Crypto.Util.NatUtil. -Require Import Crypto.Arithmetic.ModularArithmeticTheorems Crypto.Spec.ModularArithmetic. -Require Import Crypto.Arithmetic.MontgomeryReduction.WordByWord.Abstract.Dependent.Definition. -Require Import Crypto.Algebra.Ring. -Require Import Crypto.Util.ZUtil.MulSplit. -Require Import Crypto.Util.ZUtil.Div. -Require Import Crypto.Util.ZUtil.EquivModulo. -Require Import Crypto.Util.ZUtil.Modulo. -Require Import Crypto.Util.ZUtil.Modulo.PullPush. -Require Import Crypto.Util.ZUtil.Tactics.PeelLe. -Require Import Crypto.Util.ZUtil.Tactics.ZeroBounds. -Require Import Crypto.Util.ZUtil.Tactics.RewriteModSmall. -Require Import Crypto.Util.ZUtil.Tactics.PullPush.Modulo. -Require Import Crypto.Util.ZUtil.Tactics.LtbToLt. -Require Import Crypto.Util.Sigma. -Require Import Crypto.Util.Tactics.SetEvars. -Require Import Crypto.Util.Tactics.SubstEvars. -Require Import Crypto.Util.Tactics.DestructHead. -Require Import Crypto.Util.Tactics.BreakMatch. -Local Open Scope Z_scope. - -Section WordByWordMontgomery. - Context - {T : nat -> Type} - {eval : forall {n}, T n -> Z} - {zero : forall {n}, T n} - {divmod : forall {n}, T (S n) -> T n * Z} (* returns lowest limb and all-but-lowest-limb *) - {r : positive} - {r_big : r > 1} - {R : positive} - {R_numlimbs : nat} - {R_correct : R = r^Z.of_nat R_numlimbs :> Z} - {small : forall {n}, T n -> Prop} - {eval_zero : forall n, eval (@zero n) = 0} - {small_zero : forall n, small (@zero n)} - {eval_div : forall n v, small v -> eval (fst (@divmod n v)) = eval v / r} - {eval_mod : forall n v, small v -> snd (@divmod n v) = eval v mod r} - {small_div : forall n v, small v -> small (fst (@divmod n v))} - {scmul : forall {n}, Z -> T n -> T (S n)} (* uses double-output multiply *) - {eval_scmul: forall n a v, small v -> 0 <= a < r -> 0 <= eval v < R -> eval (@scmul n a v) = a * eval v} - {small_scmul : forall n a v, small v -> 0 <= a < r -> 0 <= eval v < R -> small (@scmul n a v)} - {addT : forall {n}, T n -> T n -> T (S n)} (* joins carry *) - {eval_addT : forall n a b, eval (@addT n a b) = eval a + eval b} - {small_addT : forall n a b, small a -> small b -> small (@addT n a b)} - {addT' : forall {n}, T (S n) -> T n -> T (S (S n))} (* joins carry *) - {eval_addT' : forall n a b, eval (@addT' n a b) = eval a + eval b} - {small_addT' : forall n a b, small a -> small b -> small (@addT' n a b)} - {drop_high : T (S (S R_numlimbs)) -> T (S R_numlimbs)} (* drops the highest limb *) - {eval_drop_high : forall v, small v -> eval (drop_high v) = eval v mod (r * r^Z.of_nat R_numlimbs)} - {small_drop_high : forall v, small v -> small (drop_high v)} - (N : T R_numlimbs) (Npos : positive) (Npos_correct: eval N = Z.pos Npos) - (small_N : small N) - (N_lt_R : eval N < R) - {conditional_sub : T (S R_numlimbs) -> T R_numlimbs} (* computes [arg - N] if [N <= arg], and drops high bit *) - {eval_conditional_sub : forall v, small v -> 0 <= eval v < eval N + R -> eval (conditional_sub v) = eval v + if eval N <=? eval v then -eval N else 0} - {small_conditional_sub : forall v, small v -> 0 <= eval v < eval N + R -> small (conditional_sub v)} - {sub_then_maybe_add : T R_numlimbs -> T R_numlimbs -> T R_numlimbs} (* computes [a - b + if (a - b) small b -> 0 <= eval a < eval N -> 0 <= eval b < eval N -> eval (sub_then_maybe_add a b) = eval a - eval b + if eval a - eval b destruct H - end ]. - Hint Rewrite - eval_zero - eval_div - eval_mod - eval_addT - eval_addT' - eval_scmul - eval_drop_high - eval_conditional_sub - eval_sub_then_maybe_add - using (repeat autounfold with word_by_word_montgomery; t_small) - : push_eval. - - Local Arguments eval {_} _. - Local Arguments small {_} _. - Local Arguments divmod {_} _. - - (* Recurse for a as many iterations as A has limbs, varying A := A, S := 0, r, bounds *) - Section Iteration. - Context (pred_A_numlimbs : nat) - (A : T (S pred_A_numlimbs)) - (S : T (S R_numlimbs)) - (small_A : small A) - (small_S : small S) - (S_nonneg : 0 <= eval S). - (* Given A, B < R, we want to compute A * B / R mod N. R = bound 0 * ... * bound (n-1) *) - - Local Coercion eval : T >-> Z. - - Local Notation a := (@WordByWord.Abstract.Dependent.Definition.a T (@divmod) pred_A_numlimbs A). - Local Notation A' := (@WordByWord.Abstract.Dependent.Definition.A' T (@divmod) pred_A_numlimbs A). - Local Notation S1 := (@WordByWord.Abstract.Dependent.Definition.S1 T (@divmod) R_numlimbs scmul addT pred_A_numlimbs B A S). - Local Notation s := (@WordByWord.Abstract.Dependent.Definition.s T (@divmod) R_numlimbs scmul addT pred_A_numlimbs B A S). - Local Notation q := (@WordByWord.Abstract.Dependent.Definition.q T (@divmod) r R_numlimbs scmul addT pred_A_numlimbs B k A S). - Local Notation S2 := (@WordByWord.Abstract.Dependent.Definition.S2 T (@divmod) r R_numlimbs scmul addT addT' N pred_A_numlimbs B k A S). - Local Notation S3 := (@WordByWord.Abstract.Dependent.Definition.S3 T (@divmod) r R_numlimbs scmul addT addT' N pred_A_numlimbs B k A S). - Local Notation S4 := (@WordByWord.Abstract.Dependent.Definition.S4 T (@divmod) r R_numlimbs scmul addT addT' drop_high N pred_A_numlimbs B k A S). - - Lemma S3_bound - : eval S < eval N + eval B - -> eval S3 < eval N + eval B. - Proof. - assert (Hmod : forall a b, 0 < b -> a mod b <= b - 1) - by (intros x y; pose proof (Z_mod_lt x y); omega). - intro HS. - unfold S3, S2, S1. - autorewrite with push_eval; []. - eapply Z.le_lt_trans. - { transitivity ((N+B-1 + (r-1)*B + (r-1)*N) / r); - [ | set_evars; ring_simplify_subterms; subst_evars; reflexivity ]. - Z.peel_le; repeat apply Z.add_le_mono; repeat apply Z.mul_le_mono_nonneg; try lia; - repeat autounfold with word_by_word_montgomery; rewrite ?Z.mul_split_mod; - autorewrite with push_eval; - try Z.zero_bounds; - auto with lia. } - rewrite (Z.mul_comm _ r), <- Z.add_sub_assoc, <- Z.add_opp_r, !Z.div_add_l' by lia. - autorewrite with zsimplify. - simpl; omega. - Qed. - - Lemma small_A' - : small A'. - Proof. - repeat autounfold with word_by_word_montgomery; auto. - Qed. - - Lemma small_S3 - : small S3. - Proof. repeat autounfold with word_by_word_montgomery; t_small. Qed. - - Lemma S3_nonneg : 0 <= eval S3. - Proof. - repeat autounfold with word_by_word_montgomery; rewrite ?Z.mul_split_mod; - autorewrite with push_eval; []. - rewrite ?Npos_correct; Z.zero_bounds; lia. - Qed. - - Lemma S4_nonneg : 0 <= eval S4. - Proof. unfold S4; rewrite eval_drop_high by apply small_S3; Z.zero_bounds. Qed. - - Lemma S4_bound - : eval S < eval N + eval B - -> eval S4 < eval N + eval B. - Proof. - intro H; pose proof (S3_bound H); pose proof S3_nonneg. - unfold S4. - rewrite eval_drop_high by apply small_S3. - rewrite Z.mod_small by nia. - assumption. - Qed. - - Lemma small_S4 - : small S4. - Proof. repeat autounfold with word_by_word_montgomery; t_small. Qed. - - Lemma S1_eq : eval S1 = S + a*B. - Proof. - cbv [S1 a A']. - repeat autorewrite with push_eval. - reflexivity. - Qed. - - Lemma S2_mod_N : (eval S2) mod N = (S + a*B) mod N. - Proof. - cbv [S2]; autorewrite with push_eval zsimplify. rewrite S1_eq. reflexivity. - Qed. - - Lemma S2_mod_r : S2 mod r = 0. - Proof. - cbv [S2 q s]; autorewrite with push_eval. - assert (r > 0) by lia. - assert (Hr : (-(1 mod r)) mod r = r - 1 /\ (-(1)) mod r = r - 1). - { destruct (Z.eq_dec r 1) as [H'|H']. - { rewrite H'; split; reflexivity. } - { rewrite !Z_mod_nz_opp_full; rewrite ?Z.mod_mod; Z.rewrite_mod_small; [ split; reflexivity | omega.. ]. } } - autorewrite with pull_Zmod. - replace 0 with (0 mod r) by apply Zmod_0_l. - eapply F.eq_of_Z_iff. - rewrite Z.mul_split_mod. - repeat rewrite ?F.of_Z_add, ?F.of_Z_mul, <-?F.of_Z_mod. - rewrite <-Algebra.Hierarchy.associative. - replace ((F.of_Z r k * F.of_Z r (eval N))%F) with (F.opp (m:=r) F.one). - { cbv [F.of_Z F.add]; simpl. - apply path_sig_hprop; [ intro; exact HProp.allpath_hprop | ]. - simpl. - rewrite (proj1 Hr), Z.mul_sub_distr_l. - push_Zmod; pull_Zmod. - autorewrite with zsimplify; reflexivity. } - { rewrite <- F.of_Z_mul. - rewrite F.of_Z_mod. - rewrite k_correct. - cbv [F.of_Z F.add F.opp F.one]; simpl. - change (-(1)) with (-1) in *. - apply path_sig_hprop; [ intro; exact HProp.allpath_hprop | ]; simpl. - rewrite (proj1 Hr), (proj2 Hr); Z.rewrite_mod_small; reflexivity. } - Qed. - - Lemma S3_mod_N - : S3 mod N = (S + a*B)*ri mod N. - Proof. - cbv [S3]; autorewrite with push_eval cancel_pair. - pose proof fun a => Z.div_to_inv_modulo N a r ri eq_refl ri_correct as HH; - cbv [Z.equiv_modulo] in HH; rewrite HH; clear HH. - etransitivity; [rewrite (fun a => Z.mul_mod_l a ri N)| - rewrite (fun a => Z.mul_mod_l a ri N); reflexivity]. - rewrite <-S2_mod_N; repeat (f_equal; []); autorewrite with push_eval. - autorewrite with push_Zmod; - rewrite S2_mod_r; - autorewrite with zsimplify. - reflexivity. - Qed. - - Lemma S4_mod_N - (Hbound : eval S < eval N + eval B) - : S4 mod N = (S + a*B)*ri mod N. - Proof. - pose proof (S3_bound Hbound); pose proof S3_nonneg. - unfold S4; autorewrite with push_eval. - rewrite (Z.mod_small _ (r * _)) by nia. - apply S3_mod_N. - Qed. - End Iteration. - - Local Notation redc_body := (@redc_body T (@divmod) r R_numlimbs scmul addT addT' drop_high N B k). - Local Notation redc_loop := (@redc_loop T (@divmod) r R_numlimbs scmul addT addT' drop_high N B k). - Local Notation pre_redc A := (@pre_redc T zero (@divmod) r R_numlimbs scmul addT addT' drop_high N _ A B k). - Local Notation redc A := (@redc T zero (@divmod) r R_numlimbs scmul addT addT' drop_high conditional_sub N _ A B k). - - Section body. - Context (pred_A_numlimbs : nat) - (A_S : T (S pred_A_numlimbs) * T (S R_numlimbs)). - Let A:=fst A_S. - Let S:=snd A_S. - Let A_a:=divmod A. - Let a:=snd A_a. - Context (small_A : small A) - (small_S : small S) - (S_bound : 0 <= eval S < eval N + eval B). - - Lemma small_fst_redc_body : small (fst (redc_body A_S)). - Proof. destruct A_S; apply small_A'; assumption. Qed. - Lemma small_snd_redc_body : small (snd (redc_body A_S)). - Proof. destruct A_S; unfold redc_body; apply small_S4; assumption. Qed. - Lemma snd_redc_body_nonneg : 0 <= eval (snd (redc_body A_S)). - Proof. destruct A_S; apply S4_nonneg; assumption. Qed. - - Lemma snd_redc_body_mod_N - : (eval (snd (redc_body A_S))) mod (eval N) = (eval S + a*eval B)*ri mod (eval N). - Proof. destruct A_S; apply S4_mod_N; auto; omega. Qed. - - Lemma fst_redc_body - : (eval (fst (redc_body A_S))) = eval (fst A_S) / r. - Proof. - destruct A_S; simpl; repeat autounfold with word_by_word_montgomery; simpl. - autorewrite with push_eval. - reflexivity. - Qed. - - Lemma fst_redc_body_mod_N - : (eval (fst (redc_body A_S))) mod (eval N) = ((eval (fst A_S) - a)*ri) mod (eval N). - Proof. - rewrite fst_redc_body. - etransitivity; [ eapply Z.div_to_inv_modulo; try eassumption; lia | ]. - unfold a, A_a, A. - autorewrite with push_eval. - reflexivity. - Qed. - - Lemma redc_body_bound - : eval S < eval N + eval B - -> eval (snd (redc_body A_S)) < eval N + eval B. - Proof. - destruct A_S; apply S4_bound; unfold S in *; cbn [snd] in *; try assumption; try omega. - Qed. - End body. - - Local Arguments Z.pow !_ !_. - Local Arguments Z.of_nat !_. - Local Ltac induction_loop count IHcount - := induction count as [|count IHcount]; intros; cbn [redc_loop] in *; [ | (*rewrite redc_loop_comm_body in * *) ]. - Lemma redc_loop_good count A_S - (Hsmall : small (fst A_S) /\ small (snd A_S)) - (Hbound : 0 <= eval (snd A_S) < eval N + eval B) - : (small (fst (redc_loop count A_S)) /\ small (snd (redc_loop count A_S))) - /\ 0 <= eval (snd (redc_loop count A_S)) < eval N + eval B. - Proof. - induction_loop count IHcount; auto; []. - change (id (0 <= eval B < R)) in B_bounds (* don't let [destruct_head'_and] loop *). - destruct_head'_and. - repeat first [ apply conj - | apply small_fst_redc_body - | apply small_snd_redc_body - | apply redc_body_bound - | apply snd_redc_body_nonneg - | apply IHcount - | solve [ auto ] ]. - Qed. - - Lemma small_redc_loop count A_S - (Hsmall : small (fst A_S) /\ small (snd A_S)) - (Hbound : 0 <= eval (snd A_S) < eval N + eval B) - : small (fst (redc_loop count A_S)) /\ small (snd (redc_loop count A_S)). - Proof. apply redc_loop_good; assumption. Qed. - - Lemma redc_loop_bound count A_S - (Hsmall : small (fst A_S) /\ small (snd A_S)) - (Hbound : 0 <= eval (snd A_S) < eval N + eval B) - : 0 <= eval (snd (redc_loop count A_S)) < eval N + eval B. - Proof. apply redc_loop_good; assumption. Qed. - - Local Ltac handle_IH_small := - repeat first [ apply redc_loop_good - | apply small_fst_redc_body - | apply small_snd_redc_body - | apply redc_body_bound - | apply snd_redc_body_nonneg - | apply conj - | progress cbn [fst snd] - | progress destruct_head' and - | solve [ auto ] ]. - - Lemma fst_redc_loop count A_S - (Hsmall : small (fst A_S) /\ small (snd A_S)) - (Hbound : 0 <= eval (snd A_S) < eval N + eval B) - : eval (fst (redc_loop count A_S)) = eval (fst A_S) / r^(Z.of_nat count). - Proof. - induction_loop count IHcount. - { simpl; autorewrite with zsimplify; reflexivity. } - { rewrite IHcount, fst_redc_body by handle_IH_small. - change (1 + R_numlimbs)%nat with (S R_numlimbs) in *. - rewrite Zdiv_Zdiv by Z.zero_bounds. - rewrite <- (Z.pow_1_r r) at 1. - rewrite <- Z.pow_add_r by lia. - replace (1 + Z.of_nat count) with (Z.of_nat (S count)) by lia. - reflexivity. } - Qed. - - Lemma fst_redc_loop_mod_N count A_S - (Hsmall : small (fst A_S) /\ small (snd A_S)) - (Hbound : 0 <= eval (snd A_S) < eval N + eval B) - : eval (fst (redc_loop count A_S)) mod (eval N) - = (eval (fst A_S) - eval (fst A_S) mod r^Z.of_nat count) - * ri^(Z.of_nat count) mod (eval N). - Proof. - rewrite fst_redc_loop by assumption. - destruct count. - { simpl; autorewrite with zsimplify; reflexivity. } - { etransitivity; - [ eapply Z.div_to_inv_modulo; - try solve [ eassumption - | apply Z.lt_gt, Z.pow_pos_nonneg; lia ] - | ]. - { erewrite <- Z.pow_mul_l, <- Z.pow_1_l. - { apply Z.pow_mod_Proper; [ eassumption | reflexivity ]. } - { lia. } } - reflexivity. } - Qed. - - Local Arguments Z.pow : simpl never. - Lemma snd_redc_loop_mod_N count A_S - (Hsmall : small (fst A_S) /\ small (snd A_S)) - (Hbound : 0 <= eval (snd A_S) < eval N + eval B) - : (eval (snd (redc_loop count A_S))) mod (eval N) - = ((eval (snd A_S) + (eval (fst A_S) mod r^(Z.of_nat count))*eval B)*ri^(Z.of_nat count)) mod (eval N). - Proof. - induction_loop count IHcount. - { simpl; autorewrite with zsimplify; reflexivity. } - { rewrite IHcount by handle_IH_small. - push_Zmod; rewrite snd_redc_body_mod_N, fst_redc_body by handle_IH_small; pull_Zmod. - autorewrite with push_eval; []. - match goal with - | [ |- ?x mod ?N = ?y mod ?N ] - => change (Z.equiv_modulo N x y) - end. - destruct A_S as [A S]. - cbn [fst snd]. - change (Z.pos (Pos.of_succ_nat ?n)) with (Z.of_nat (Datatypes.S n)). - rewrite !Z.mul_add_distr_r. - rewrite <- !Z.mul_assoc. - replace (ri * ri^(Z.of_nat count)) with (ri^(Z.of_nat (Datatypes.S count))) - by (change (Datatypes.S count) with (1 + count)%nat; - autorewrite with push_Zof_nat; rewrite Z.pow_add_r by lia; simpl Z.succ; rewrite Z.pow_1_r; nia). - rewrite <- !Z.add_assoc. - apply Z.add_mod_Proper; [ reflexivity | ]. - unfold Z.equiv_modulo; push_Zmod; rewrite (Z.mul_mod_l (_ mod r) _ (eval N)). - rewrite Z.mod_pull_div by auto with zarith lia. - push_Zmod. - erewrite Z.div_to_inv_modulo; - [ - | apply Z.lt_gt; lia - | eassumption ]. - pull_Zmod. - match goal with - | [ |- ?x mod ?N = ?y mod ?N ] - => change (Z.equiv_modulo N x y) - end. - repeat first [ rewrite <- !Z.pow_succ_r, <- !Nat2Z.inj_succ by lia - | rewrite (Z.mul_comm _ ri) - | rewrite (Z.mul_assoc _ ri _) - | rewrite (Z.mul_comm _ (ri^_)) - | rewrite (Z.mul_assoc _ (ri^_) _) ]. - repeat first [ rewrite <- Z.mul_assoc - | rewrite <- Z.mul_add_distr_l - | rewrite (Z.mul_comm _ (eval B)) - | rewrite !Nat2Z.inj_succ, !Z.pow_succ_r by lia; - rewrite <- Znumtheory.Zmod_div_mod by (apply Z.divide_factor_r || Z.zero_bounds) - | rewrite Zplus_minus - | rewrite (Z.mul_comm r (r^_)) - | reflexivity ]. } - Qed. - - Lemma pre_redc_bound A_numlimbs (A : T A_numlimbs) - (small_A : small A) - : 0 <= eval (pre_redc A) < eval N + eval B. - Proof. - unfold pre_redc. - apply redc_loop_good; simpl; autorewrite with push_eval; - rewrite ?Npos_correct; auto; lia. - Qed. - - Lemma small_pre_redc A_numlimbs (A : T A_numlimbs) - (small_A : small A) - : small (pre_redc A). - Proof. - unfold pre_redc. - apply redc_loop_good; simpl; autorewrite with push_eval; - rewrite ?Npos_correct; auto; lia. - Qed. - - Lemma pre_redc_mod_N A_numlimbs (A : T A_numlimbs) (small_A : small A) (A_bound : 0 <= eval A < r ^ Z.of_nat A_numlimbs) - : (eval (pre_redc A)) mod (eval N) = (eval A * eval B * ri^(Z.of_nat A_numlimbs)) mod (eval N). - Proof. - unfold pre_redc. - rewrite snd_redc_loop_mod_N; cbn [fst snd]; - autorewrite with push_eval zsimplify; - [ | rewrite ?Npos_correct; auto; lia.. ]. - Z.rewrite_mod_small. - reflexivity. - Qed. - - Lemma redc_mod_N A_numlimbs (A : T A_numlimbs) (small_A : small A) (A_bound : 0 <= eval A < r ^ Z.of_nat A_numlimbs) - : (eval (redc A)) mod (eval N) = (eval A * eval B * ri^(Z.of_nat A_numlimbs)) mod (eval N). - Proof. - pose proof (@small_pre_redc _ A small_A). - pose proof (@pre_redc_bound _ A small_A). - unfold redc. - autorewrite with push_eval; []. - break_innermost_match; - try rewrite Z.add_opp_r, Zminus_mod, Z_mod_same_full; - autorewrite with zsimplify_fast; - apply pre_redc_mod_N; auto. - Qed. - - Lemma redc_bound_tight A_numlimbs (A : T A_numlimbs) - (small_A : small A) - : 0 <= eval (redc A) < eval N + eval B + if eval N <=? eval (pre_redc A) then -eval N else 0. - Proof. - pose proof (@small_pre_redc _ A small_A). - pose proof (@pre_redc_bound _ A small_A). - unfold redc. - rewrite eval_conditional_sub by t_small. - break_innermost_match; Z.ltb_to_lt; omega. - Qed. - - Lemma redc_bound_N A_numlimbs (A : T A_numlimbs) - (small_A : small A) - : eval B < eval N -> 0 <= eval (redc A) < eval N. - Proof. - pose proof (@small_pre_redc _ A small_A). - pose proof (@pre_redc_bound _ A small_A). - unfold redc. - rewrite eval_conditional_sub by t_small. - break_innermost_match; Z.ltb_to_lt; omega. - Qed. - - Lemma redc_bound A_numlimbs (A : T A_numlimbs) - (small_A : small A) - (A_bound : 0 <= eval A < r ^ Z.of_nat A_numlimbs) - : 0 <= eval (redc A) < R. - Proof. - pose proof (@small_pre_redc _ A small_A). - pose proof (@pre_redc_bound _ A small_A). - unfold redc. - rewrite eval_conditional_sub by t_small. - break_innermost_match; Z.ltb_to_lt; try omega. - Qed. - - Lemma small_redc A_numlimbs (A : T A_numlimbs) - (small_A : small A) - (A_bound : 0 <= eval A < r ^ Z.of_nat A_numlimbs) - : small (redc A). - Proof. - pose proof (@small_pre_redc _ A small_A). - pose proof (@pre_redc_bound _ A small_A). - unfold redc. - apply small_conditional_sub; [ apply small_pre_redc | .. ]; auto; omega. - Qed. - - Local Notation add := (@add T R_numlimbs addT conditional_sub). - Local Notation sub := (@sub T R_numlimbs sub_then_maybe_add). - Local Notation opp := (@opp T (@zero) R_numlimbs sub_then_maybe_add). - - Section add_sub. - Context (Av Bv : T R_numlimbs) - (small_Av : small Av) - (small_Bv : small Bv) - (Av_bound : 0 <= eval Av < eval N) - (Bv_bound : 0 <= eval Bv < eval N). - - Local Ltac do_clear := - clear dependent B; clear dependent k; clear dependent ri; clear dependent Npos. - - Lemma small_add : small (add Av Bv). - Proof. do_clear; unfold add; t_small. Qed. - Lemma small_sub : small (sub Av Bv). - Proof. do_clear; unfold sub; t_small. Qed. - Lemma small_opp : small (opp Av). - Proof. clear dependent Bv; do_clear; unfold opp, sub; t_small. Qed. - - Lemma eval_add : eval (add Av Bv) = eval Av + eval Bv + if (eval N <=? eval Av + eval Bv) then -eval N else 0. - Proof. do_clear; unfold add; autorewrite with push_eval; reflexivity. Qed. - Lemma eval_sub : eval (sub Av Bv) = eval Av - eval Bv + if (eval Av - eval Bv Z} - {numlimbs : T -> nat} - {zero : nat -> T} - {divmod : T -> T * Z} (* returns lowest limb and all-but-lowest-limb *) - {r : positive} - {r_big : r > 1} - {R : positive} - {R_numlimbs : nat} - {R_correct : R = r^Z.of_nat R_numlimbs :> Z} - {small : T -> Prop} - {eval_zero : forall n, eval (zero n) = 0} - {numlimbs_zero : forall n, numlimbs (zero n) = n} - {eval_div : forall v, small v -> eval (fst (divmod v)) = eval v / r} - {eval_mod : forall v, small v -> snd (divmod v) = eval v mod r} - {small_div : forall v, small v -> small (fst (divmod v))} - {numlimbs_div : forall v, numlimbs (fst (divmod v)) = pred (numlimbs v)} - {scmul : Z -> T -> T} (* uses double-output multiply *) - {eval_scmul: forall a v, 0 <= a < r -> 0 <= eval v < R -> eval (scmul a v) = a * eval v} - {numlimbs_scmul : forall a v, 0 <= a < r -> numlimbs (scmul a v) = S (numlimbs v)} - {add : T -> T -> T} (* joins carry *) - {eval_add : forall a b, eval (add a b) = eval a + eval b} - {small_add : forall a b, small (add a b)} - {numlimbs_add : forall a b, numlimbs (add a b) = Datatypes.S (max (numlimbs a) (numlimbs b))} - {drop_high : T -> T} (* drops things after [S R_numlimbs] *) - {eval_drop_high : forall v, small v -> eval (drop_high v) = eval v mod (r * r^Z.of_nat R_numlimbs)} - {numlimbs_drop_high : forall v, numlimbs (drop_high v) = min (numlimbs v) (S R_numlimbs)} - (N : T) (Npos : positive) (Npos_correct: eval N = Z.pos Npos) - (N_lt_R : eval N < R) - (B : T) - (B_bounds : 0 <= eval B < R) - ri (ri_correct : r*ri mod (eval N) = 1 mod (eval N)). - Context (k : Z) (k_correct : k * eval N mod r = (-1) mod r). - - Create HintDb push_numlimbs discriminated. - Create HintDb push_eval discriminated. - Local Ltac t_small := - repeat first [ assumption - | apply small_add - | apply small_div - | apply Z_mod_lt - | rewrite Z.mul_split_mod - | solve [ auto with zarith ] - | lia - | progress autorewrite with push_eval - | progress autorewrite with push_numlimbs ]. - Hint Rewrite - eval_zero - eval_div - eval_mod - eval_add - eval_scmul - eval_drop_high - using (repeat autounfold with word_by_word_montgomery; t_small) - : push_eval. - Hint Rewrite - numlimbs_zero - numlimbs_div - numlimbs_add - numlimbs_scmul - numlimbs_drop_high - using (repeat autounfold with word_by_word_montgomery; t_small) - : push_numlimbs. - Hint Rewrite <- Max.succ_max_distr pred_Sn Min.succ_min_distr : push_numlimbs. - - - (* Recurse for a as many iterations as A has limbs, varying A := A, S := 0, r, bounds *) - Section Iteration. - Context (A S : T) - (small_A : small A) - (S_nonneg : 0 <= eval S). - (* Given A, B < R, we want to compute A * B / R mod N. R = bound 0 * ... * bound (n-1) *) - - Local Coercion eval : T >-> Z. - - Local Notation a := (@WordByWord.Abstract.Definition.a T divmod A). - Local Notation A' := (@WordByWord.Abstract.Definition.A' T divmod A). - Local Notation S1 := (@WordByWord.Abstract.Definition.S1 T divmod scmul add B A S). - Local Notation S2 := (@WordByWord.Abstract.Definition.S2 T divmod r scmul add N B k A S). - Local Notation S3 := (@WordByWord.Abstract.Definition.S3 T divmod r scmul add N B k A S). - Local Notation S4 := (@WordByWord.Abstract.Definition.S4 T divmod r scmul add drop_high N B k A S). - - Lemma S3_bound - : eval S < eval N + eval B - -> eval S3 < eval N + eval B. - Proof. - assert (Hmod : forall a b, 0 < b -> a mod b <= b - 1) - by (intros x y; pose proof (Z_mod_lt x y); omega). - intro HS. - unfold S3, WordByWord.Abstract.Definition.S2, WordByWord.Abstract.Definition.S1. - autorewrite with push_eval; []. - eapply Z.le_lt_trans. - { transitivity ((N+B-1 + (r-1)*B + (r-1)*N) / r); - [ | set_evars; ring_simplify_subterms; subst_evars; reflexivity ]. - Z.peel_le; repeat apply Z.add_le_mono; repeat apply Z.mul_le_mono_nonneg; try lia; - repeat autounfold with word_by_word_montgomery; rewrite ?Z.mul_split_mod; - autorewrite with push_eval; - try Z.zero_bounds; - auto with lia. } - rewrite (Z.mul_comm _ r), <- Z.add_sub_assoc, <- Z.add_opp_r, !Z.div_add_l' by lia. - autorewrite with zsimplify. - omega. - Qed. - - Lemma small_A' - : small A'. - Proof. - repeat autounfold with word_by_word_montgomery; auto. - Qed. - - Lemma small_S3 - : small S3. - Proof. repeat autounfold with word_by_word_montgomery; t_small. Qed. - - Lemma S3_nonneg : 0 <= eval S3. - Proof. - repeat autounfold with word_by_word_montgomery; rewrite Z.mul_split_mod; - autorewrite with push_eval; []. - rewrite ?Npos_correct; Z.zero_bounds; lia. - Qed. - - Lemma S4_nonneg : 0 <= eval S4. - Proof. unfold S4; rewrite eval_drop_high by apply small_S3; Z.zero_bounds. Qed. - - Lemma S4_bound - : eval S < eval N + eval B - -> eval S4 < eval N + eval B. - Proof. - intro H; pose proof (S3_bound H); pose proof S3_nonneg. - unfold S4. - rewrite eval_drop_high by apply small_S3. - rewrite Z.mod_small by nia. - assumption. - Qed. - - Lemma numlimbs_S4 : numlimbs S4 = min (max (1 + numlimbs S) (1 + max (1 + numlimbs B) (numlimbs N))) (1 + R_numlimbs). - Proof. - cbn [plus]. - repeat autounfold with word_by_word_montgomery; rewrite Z.mul_split_mod. - repeat autorewrite with push_numlimbs. - change Init.Nat.max with Nat.max. - rewrite <- ?(Max.max_assoc (numlimbs S)). - reflexivity. - Qed. - - Lemma S1_eq : eval S1 = S + a*B. - Proof. - cbv [S1 a WordByWord.Abstract.Definition.A']. - repeat autorewrite with push_eval. - reflexivity. - Qed. - - Lemma S2_mod_N : (eval S2) mod N = (S + a*B) mod N. - Proof. - cbv [S2 WordByWord.Abstract.Definition.q WordByWord.Abstract.Definition.s]; autorewrite with push_eval zsimplify. rewrite S1_eq. reflexivity. - Qed. - - Lemma S2_mod_r : S2 mod r = 0. - cbv [S2 WordByWord.Abstract.Definition.q WordByWord.Abstract.Definition.s]; autorewrite with push_eval. - assert (r > 0) by lia. - assert (Hr : (-(1 mod r)) mod r = r - 1 /\ (-(1)) mod r = r - 1). - { destruct (Z.eq_dec r 1) as [H'|H']. - { rewrite H'; split; reflexivity. } - { rewrite !Z_mod_nz_opp_full; rewrite ?Z.mod_mod; Z.rewrite_mod_small; [ split; reflexivity | omega.. ]. } } - autorewrite with pull_Zmod. - replace 0 with (0 mod r) by apply Zmod_0_l. - eapply F.eq_of_Z_iff. - rewrite Z.mul_split_mod. - repeat rewrite ?F.of_Z_add, ?F.of_Z_mul, <-?F.of_Z_mod. - rewrite <-Algebra.Hierarchy.associative. - replace ((F.of_Z r k * F.of_Z r (eval N))%F) with (F.opp (m:=r) F.one). - { cbv [F.of_Z F.add]; simpl. - apply path_sig_hprop; [ intro; exact HProp.allpath_hprop | ]. - simpl. - rewrite (proj1 Hr), Z.mul_sub_distr_l. - push_Zmod; pull_Zmod. - autorewrite with zsimplify; reflexivity. } - { rewrite <- F.of_Z_mul. - rewrite F.of_Z_mod. - rewrite k_correct. - cbv [F.of_Z F.add F.opp F.one]; simpl. - change (-(1)) with (-1) in *. - apply path_sig_hprop; [ intro; exact HProp.allpath_hprop | ]; simpl. - rewrite (proj1 Hr), (proj2 Hr); Z.rewrite_mod_small; reflexivity. } - Qed. - - Lemma S3_mod_N - : S3 mod N = (S + a*B)*ri mod N. - Proof. - cbv [S3]; autorewrite with push_eval cancel_pair. - pose proof fun a => Z.div_to_inv_modulo N a r ri eq_refl ri_correct as HH; - cbv [Z.equiv_modulo] in HH; rewrite HH; clear HH. - etransitivity; [rewrite (fun a => Z.mul_mod_l a ri N)| - rewrite (fun a => Z.mul_mod_l a ri N); reflexivity]. - rewrite <-S2_mod_N; repeat (f_equal; []); autorewrite with push_eval. - autorewrite with push_Zmod; - rewrite S2_mod_r; - autorewrite with zsimplify. - reflexivity. - Qed. - - Lemma S4_mod_N - (Hbound : eval S < eval N + eval B) - : S4 mod N = (S + a*B)*ri mod N. - Proof. - pose proof (S3_bound Hbound); pose proof S3_nonneg. - unfold S4; autorewrite with push_eval. - rewrite (Z.mod_small _ (r * _)) by nia. - apply S3_mod_N. - Qed. - End Iteration. - - Local Notation redc_body := (@redc_body T divmod r scmul add drop_high N B k). - Local Notation redc_loop := (@redc_loop T divmod r scmul add drop_high N B k). - Local Notation redc A := (@redc T numlimbs zero divmod r scmul add drop_high N A B k). - - Lemma redc_loop_comm_body count - : forall A_S, redc_loop count (redc_body A_S) = redc_body (redc_loop count A_S). - Proof. - induction count as [|count IHcount]; try reflexivity. - simpl; intro; rewrite IHcount; reflexivity. - Qed. - - Section body. - Context (A_S : T * T). - Let A:=fst A_S. - Let S:=snd A_S. - Let A_a:=divmod A. - Let a:=snd A_a. - Context (small_A : small A) - (S_bound : 0 <= eval S < eval N + eval B). - - Lemma small_fst_redc_body : small (fst (redc_body A_S)). - Proof. destruct A_S; apply small_A'; assumption. Qed. - Lemma snd_redc_body_nonneg : 0 <= eval (snd (redc_body A_S)). - Proof. destruct A_S; apply S4_nonneg; assumption. Qed. - - Lemma snd_redc_body_mod_N - : (eval (snd (redc_body A_S))) mod (eval N) = (eval S + a*eval B)*ri mod (eval N). - Proof. destruct A_S; apply S4_mod_N; auto; omega. Qed. - - Lemma fst_redc_body - : (eval (fst (redc_body A_S))) = eval (fst A_S) / r. - Proof. - destruct A_S; simpl; unfold WordByWord.Abstract.Definition.A', WordByWord.Abstract.Definition.A_a, Let_In, a, A_a, A; simpl. - autorewrite with push_eval. - reflexivity. - Qed. - - Lemma fst_redc_body_mod_N - : (eval (fst (redc_body A_S))) mod (eval N) = ((eval (fst A_S) - a)*ri) mod (eval N). - Proof. - rewrite fst_redc_body. - etransitivity; [ eapply Z.div_to_inv_modulo; try eassumption; lia | ]. - unfold a, A_a, A. - autorewrite with push_eval. - reflexivity. - Qed. - - Lemma redc_body_bound - : eval S < eval N + eval B - -> eval (snd (redc_body A_S)) < eval N + eval B. - Proof. - destruct A_S; apply S4_bound; unfold S in *; cbn [snd] in *; try assumption; try omega. - Qed. - - Lemma numlimbs_redc_body : numlimbs (snd (redc_body A_S)) - = min (max (1 + numlimbs (snd A_S)) (1 + max (1 + numlimbs B) (numlimbs N))) (1 + R_numlimbs). - Proof. destruct A_S; apply numlimbs_S4; assumption. Qed. - End body. - - Local Arguments Z.pow !_ !_. - Local Arguments Z.of_nat !_. - Local Ltac induction_loop count IHcount - := induction count as [|count IHcount]; intros; cbn [redc_loop] in *; [ | rewrite redc_loop_comm_body in * ]. - Lemma redc_loop_good A_S count - (Hsmall : small (fst A_S)) - (Hbound : 0 <= eval (snd A_S) < eval N + eval B) - : small (fst (redc_loop count A_S)) - /\ 0 <= eval (snd (redc_loop count A_S)) < eval N + eval B. - Proof. - induction_loop count IHcount; auto; []. - change (id (0 <= eval B < R)) in B_bounds (* don't let [destruct_head'_and] loop *). - destruct_head'_and. - repeat first [ apply conj - | apply small_fst_redc_body - | apply redc_body_bound - | apply snd_redc_body_nonneg - | solve [ auto ] ]. - Qed. - - Lemma redc_loop_bound A_S count - (Hsmall : small (fst A_S)) - (Hbound : 0 <= eval (snd A_S) < eval N + eval B) - : 0 <= eval (snd (redc_loop count A_S)) < eval N + eval B. - Proof. apply redc_loop_good; assumption. Qed. - - Local Ltac t_min_max_step _ := - match goal with - | [ |- context[Init.Nat.max ?x ?y] ] - => first [ rewrite (Max.max_l x y) by omega - | rewrite (Max.max_r x y) by omega ] - | [ |- context[Init.Nat.min ?x ?y] ] - => first [ rewrite (Min.min_l x y) by omega - | rewrite (Min.min_r x y) by omega ] - | _ => progress change Init.Nat.max with Nat.max - | _ => progress change Init.Nat.min with Nat.min - end. - - Lemma numlimbs_redc_loop A_S count - (Hsmall : small (fst A_S)) - (Hbound : 0 <= eval (snd A_S) < eval N + eval B) - (Hnumlimbs : (R_numlimbs <= numlimbs (snd A_S))%nat) - : numlimbs (snd (redc_loop count A_S)) - = match count with - | O => numlimbs (snd A_S) - | S _ => 1 + R_numlimbs - end%nat. - Proof. - assert (Hgen - : numlimbs (snd (redc_loop count A_S)) - = match count with - | O => numlimbs (snd A_S) - | S _ => min (max (count + numlimbs (snd A_S)) (1 + max (1 + numlimbs B) (numlimbs N))) (1 + R_numlimbs) - end). - { induction_loop count IHcount; [ reflexivity | ]. - rewrite numlimbs_redc_body by (try apply redc_loop_good; auto). - rewrite IHcount; clear IHcount. - destruct count; [ reflexivity | ]. - destruct (Compare_dec.le_lt_dec (1 + max (1 + numlimbs B) (numlimbs N)) (S count + numlimbs (snd A_S))), - (Compare_dec.le_lt_dec (1 + R_numlimbs) (S count + numlimbs (snd A_S))), - (Compare_dec.le_lt_dec (1 + R_numlimbs) (1 + max (1 + numlimbs B) (numlimbs N))); - repeat first [ reflexivity - | t_min_max_step () - | progress autorewrite with push_numlimbs - | rewrite Nat.min_comm, Nat.min_max_distr ]. } - rewrite Hgen; clear Hgen. - destruct count; [ reflexivity | ]. - repeat apply Max.max_case_strong; apply Min.min_case_strong; omega. - Qed. - - - Lemma fst_redc_loop A_S count - (Hsmall : small (fst A_S)) - (Hbound : 0 <= eval (snd A_S) < eval N + eval B) - : eval (fst (redc_loop count A_S)) = eval (fst A_S) / r^(Z.of_nat count). - Proof. - induction_loop count IHcount. - { simpl; autorewrite with zsimplify; reflexivity. } - { rewrite fst_redc_body, IHcount - by (apply redc_loop_good; auto). - rewrite Zdiv_Zdiv by Z.zero_bounds. - rewrite <- (Z.pow_1_r r) at 2. - rewrite <- Z.pow_add_r by lia. - replace (Z.of_nat count + 1) with (Z.of_nat (S count)) by (simpl; lia). - reflexivity. } - Qed. - - Lemma fst_redc_loop_mod_N A_S count - (Hsmall : small (fst A_S)) - (Hbound : 0 <= eval (snd A_S) < eval N + eval B) - : eval (fst (redc_loop count A_S)) mod (eval N) - = (eval (fst A_S) - eval (fst A_S) mod r^Z.of_nat count) - * ri^(Z.of_nat count) mod (eval N). - Proof. - rewrite fst_redc_loop by assumption. - destruct count. - { simpl; autorewrite with zsimplify; reflexivity. } - { etransitivity; - [ eapply Z.div_to_inv_modulo; - try solve [ eassumption - | apply Z.lt_gt, Z.pow_pos_nonneg; lia ] - | ]. - { erewrite <- Z.pow_mul_l, <- Z.pow_1_l. - { apply Z.pow_mod_Proper; [ eassumption | reflexivity ]. } - { lia. } } - reflexivity. } - Qed. - - Local Arguments Z.pow : simpl never. - Lemma snd_redc_loop_mod_N A_S count - (Hsmall : small (fst A_S)) - (Hbound : 0 <= eval (snd A_S) < eval N + eval B) - : (eval (snd (redc_loop count A_S))) mod (eval N) - = ((eval (snd A_S) + (eval (fst A_S) mod r^(Z.of_nat count))*eval B)*ri^(Z.of_nat count)) mod (eval N). - Proof. - induction_loop count IHcount. - { simpl; autorewrite with zsimplify; reflexivity. } - { simpl; rewrite snd_redc_body_mod_N - by (apply redc_loop_good; auto). - push_Zmod; rewrite IHcount; pull_Zmod. - autorewrite with push_eval; [ | apply redc_loop_good; auto.. ]; []. - match goal with - | [ |- ?x mod ?N = ?y mod ?N ] - => change (Z.equiv_modulo N x y) - end. - destruct A_S as [A S]. - cbn [fst snd]. - change (Z.pos (Pos.of_succ_nat ?n)) with (Z.of_nat (Datatypes.S n)). - rewrite !Z.mul_add_distr_r. - rewrite <- !Z.mul_assoc. - replace (ri^(Z.of_nat count) * ri) with (ri^(Z.of_nat (Datatypes.S count))) - by (change (Datatypes.S count) with (1 + count)%nat; - autorewrite with push_Zof_nat; rewrite Z.pow_add_r by lia; simpl Z.succ; rewrite Z.pow_1_r; nia). - rewrite <- !Z.add_assoc. - apply Z.add_mod_Proper; [ reflexivity | ]. - unfold Z.equiv_modulo; push_Zmod; rewrite (Z.mul_mod_l (_ mod r) _ (eval N)). - rewrite fst_redc_loop by (try apply redc_loop_good; auto; omega). - cbn [fst]. - rewrite Z.mod_pull_div by lia. - erewrite Z.div_to_inv_modulo; - [ - | solve [ eassumption | apply Z.lt_gt, Z.pow_pos_nonneg; lia ] - | erewrite <- Z.pow_mul_l, <- Z.pow_1_l; - [ apply Z.pow_mod_Proper; [ eassumption | reflexivity ] - | lia ] ]. - pull_Zmod. - match goal with - | [ |- ?x mod ?N = ?y mod ?N ] - => change (Z.equiv_modulo N x y) - end. - repeat first [ rewrite <- !Z.pow_succ_r, <- !Nat2Z.inj_succ by lia - | rewrite (Z.mul_comm _ ri) - | rewrite (Z.mul_assoc _ ri _) - | rewrite (Z.mul_comm _ (ri^_)) - | rewrite (Z.mul_assoc _ (ri^_) _) ]. - repeat first [ rewrite <- Z.mul_assoc - | rewrite <- Z.mul_add_distr_l - | rewrite (Z.mul_comm _ (eval B)) - | rewrite !Nat2Z.inj_succ, !Z.pow_succ_r by lia; - rewrite <- Znumtheory.Zmod_div_mod by (apply Z.divide_factor_r || Z.zero_bounds) - | rewrite Zplus_minus - | reflexivity ]. } - Qed. - - Lemma redc_bound A - (small_A : small A) - : 0 <= eval (redc A) < eval N + eval B. - Proof. - unfold redc. - apply redc_loop_good; simpl; autorewrite with push_eval; - rewrite ?Npos_correct; auto; lia. - Qed. - - Lemma numlimbs_redc_gen A (small_A : small A) (Hnumlimbs : (R_numlimbs <= numlimbs B)%nat) - : numlimbs (redc A) - = match numlimbs A with - | O => S (numlimbs B) - | _ => S R_numlimbs - end. - Proof. - unfold redc; rewrite numlimbs_redc_loop by (cbn [fst snd]; t_small); - cbn [snd]; rewrite ?numlimbs_zero. - reflexivity. - Qed. - Lemma numlimbs_redc A (small_A : small A) (Hnumlimbs : R_numlimbs = numlimbs B) - : numlimbs (redc A) = S (numlimbs B). - Proof. rewrite numlimbs_redc_gen; subst; auto; destruct (numlimbs A); reflexivity. Qed. - - Lemma redc_mod_N A (small_A : small A) (A_bound : 0 <= eval A < r ^ Z.of_nat (numlimbs A)) - : (eval (redc A)) mod (eval N) = (eval A * eval B * ri^(Z.of_nat (numlimbs A))) mod (eval N). - Proof. - unfold redc. - rewrite snd_redc_loop_mod_N; cbn [fst snd]; - autorewrite with push_eval zsimplify; - [ | rewrite ?Npos_correct; auto; lia.. ]. - Z.rewrite_mod_small. - reflexivity. - Qed. -End WordByWordMontgomery. diff --git a/src/Arithmetic/MontgomeryReduction/WordByWord/Definition.v b/src/Arithmetic/MontgomeryReduction/WordByWord/Definition.v deleted file mode 100644 index fd4869f23..000000000 --- a/src/Arithmetic/MontgomeryReduction/WordByWord/Definition.v +++ /dev/null @@ -1,108 +0,0 @@ -(*** Word-By-Word Montgomery Multiplication *) -(** This file implements Montgomery Form, Montgomery Reduction, and - Montgomery Multiplication on an abstract [ℤⁿ]. See - https://github.com/mit-plv/fiat-crypto/issues/157 for a discussion - of the algorithm; note that it may be that none of the algorithms - there exactly match what we're doing here. *) -Require Import Coq.ZArith.ZArith. -Require Import Crypto.Arithmetic.Saturated.MontgomeryAPI. -Require Import Crypto.Arithmetic.MontgomeryReduction.WordByWord.Abstract.Dependent.Definition. -Require Import Crypto.Util.Notations. -Require Import Crypto.Util.LetIn. -Require Import Crypto.Util.ZUtil.Definitions. -Require Import Crypto.Util.ZUtil.CPS. - -Local Open Scope Z_scope. - -Section WordByWordMontgomery. - Local Coercion Z.pos : positive >-> Z. - (** TODO: pick better names for the arguments to this definition. *) - Context - {r : positive} - {R_numlimbs : nat} - (N : T R_numlimbs). - - Local Notation scmul := (@scmul (Z.pos r)). - Local Notation addT' := (@MontgomeryAPI.add_S1 (Z.pos r)). - Local Notation addT := (@MontgomeryAPI.add (Z.pos r)). - Local Notation conditional_sub_cps := (fun V => @conditional_sub_cps (Z.pos r) _ V N _). - Local Notation conditional_sub := (fun V => @conditional_sub (Z.pos r) _ V N). - Local Notation sub_then_maybe_add_cps := - (fun V1 V2 => @sub_then_maybe_add_cps (Z.pos r) R_numlimbs (Z.pos r - 1) V1 V2 N). - Local Notation sub_then_maybe_add := (fun V1 V2 => @sub_then_maybe_add (Z.pos r) R_numlimbs (Z.pos r - 1) V1 V2 N). - - Definition redc_body_no_cps (B : T R_numlimbs) (k : Z) {pred_A_numlimbs} (A_S : T (S pred_A_numlimbs) * T (S R_numlimbs)) - : T pred_A_numlimbs * T (S R_numlimbs) - := @redc_body T (@divmod) r R_numlimbs (@scmul) (@addT) (@addT') (@drop_high (S R_numlimbs)) N B k _ A_S. - Definition redc_loop_no_cps (B : T R_numlimbs) (k : Z) (count : nat) (A_S : T count * T (S R_numlimbs)) - : T 0 * T (S R_numlimbs) - := @redc_loop T (@divmod) r R_numlimbs (@scmul) (@addT) (@addT') (@drop_high (S R_numlimbs)) N B k count A_S. - Definition pre_redc_no_cps {A_numlimbs} (A : T A_numlimbs) (B : T R_numlimbs) (k : Z) : T (S R_numlimbs) - := @pre_redc T (@zero) (@divmod) r R_numlimbs (@scmul) (@addT) (@addT') (@drop_high (S R_numlimbs)) N _ A B k. - Definition redc_no_cps {A_numlimbs} (A : T A_numlimbs) (B : T R_numlimbs) (k : Z) : T R_numlimbs - := @redc T (@zero) (@divmod) r R_numlimbs (@scmul) (@addT) (@addT') (@drop_high (S R_numlimbs)) conditional_sub N _ A B k. - - Definition redc_body_cps {pred_A_numlimbs} (A : T (S pred_A_numlimbs)) (B : T R_numlimbs) (k : Z) (S' : T (S R_numlimbs)) - {cpsT} (rest : T pred_A_numlimbs * T (S R_numlimbs) -> cpsT) - : cpsT - := divmod_cps A (fun '(A, a) => - @scmul_cps r _ a B _ (fun aB => @add_cps r _ S' aB _ (fun S1 => - divmod_cps S1 (fun '(_, s) => - Z.mul_split_cps' r s k (fun mul_split_r_s_k => - dlet q := fst mul_split_r_s_k in - @scmul_cps r _ q N _ (fun qN => @add_S1_cps r _ S1 qN _ (fun S2 => - divmod_cps S2 (fun '(S3, _) => - @drop_high_cps (S R_numlimbs) S3 _ (fun S4 => rest (A, S4)))))))))). - - Section loop. - Context {A_numlimbs} (A : T A_numlimbs) (B : T R_numlimbs) (k : Z) {cpsT : Type}. - Fixpoint redc_loop_cps (count : nat) (rest : T 0 * T (S R_numlimbs) -> cpsT) : T count * T (S R_numlimbs) -> cpsT - := match count with - | O => rest - | S count' => fun '(A, S') => redc_body_cps A B k S' (redc_loop_cps count' rest) - end. - - Definition pre_redc_cps (rest : T (S R_numlimbs) -> cpsT) : cpsT - := redc_loop_cps A_numlimbs (fun '(A, S') => rest S') (A, zero). - - Definition redc_cps (rest : T R_numlimbs -> cpsT) : cpsT - := pre_redc_cps (fun v => conditional_sub_cps v rest). - End loop. - - Definition redc_body {pred_A_numlimbs} (A : T (S pred_A_numlimbs)) (B : T R_numlimbs) (k : Z) (S' : T (S R_numlimbs)) - : T pred_A_numlimbs * T (S R_numlimbs) - := redc_body_cps A B k S' id. - Definition redc_loop (B : T R_numlimbs) (k : Z) (count : nat) : T count * T (S R_numlimbs) -> T 0 * T (S R_numlimbs) - := redc_loop_cps B k count id. - Definition pre_redc {A_numlimbs} (A : T A_numlimbs) (B : T R_numlimbs) (k : Z) : T (S R_numlimbs) - := pre_redc_cps A B k id. - Definition redc {A_numlimbs} (A : T A_numlimbs) (B : T R_numlimbs) (k : Z) : T R_numlimbs - := redc_cps A B k id. - - Definition add_no_cps (A B : T R_numlimbs) : T R_numlimbs - := @add T R_numlimbs (@addT) (@conditional_sub) A B. - Definition sub_no_cps (A B : T R_numlimbs) : T R_numlimbs - := @sub T R_numlimbs (@sub_then_maybe_add) A B. - Definition opp_no_cps (A : T R_numlimbs) : T R_numlimbs - := @opp T (@zero) R_numlimbs (@sub_then_maybe_add) A. - - Definition add_cps (A B : T R_numlimbs) {cpsT} (rest : T R_numlimbs -> cpsT) : cpsT - := @add_cps r _ A B - _ (fun v => conditional_sub_cps v rest). - Definition add (A B : T R_numlimbs) : T R_numlimbs - := add_cps A B id. - Definition sub_cps (A B : T R_numlimbs) {cpsT} (rest : T R_numlimbs -> cpsT) : cpsT - := @sub_then_maybe_add_cps A B _ rest. - Definition sub (A B : T R_numlimbs) : T R_numlimbs - := sub_cps A B id. - Definition opp_cps (A : T R_numlimbs) {cpsT} (rest : T R_numlimbs -> cpsT) : cpsT - := sub_cps zero A rest. - Definition opp (A : T R_numlimbs) : T R_numlimbs - := opp_cps A id. - Definition nonzero_cps (A : T R_numlimbs) {cpsT} (f : Z -> cpsT) : cpsT - := @nonzero_cps R_numlimbs A cpsT f. - Definition nonzero (A : T R_numlimbs) : Z - := nonzero_cps A id. -End WordByWordMontgomery. - -Hint Opaque redc pre_redc redc_body redc_loop add sub opp nonzero : uncps. diff --git a/src/Arithmetic/MontgomeryReduction/WordByWord/Proofs.v b/src/Arithmetic/MontgomeryReduction/WordByWord/Proofs.v deleted file mode 100644 index 35c9e377b..000000000 --- a/src/Arithmetic/MontgomeryReduction/WordByWord/Proofs.v +++ /dev/null @@ -1,329 +0,0 @@ -(*** Word-By-Word Montgomery Multiplication Proofs *) -Require Import Coq.ZArith.BinInt. -Require Import Coq.micromega.Lia. -Require Import Crypto.Arithmetic.Saturated.UniformWeight. -Require Import Crypto.Arithmetic.Saturated.MontgomeryAPI. -Require Import Crypto.Arithmetic.MontgomeryReduction.WordByWord.Abstract.Dependent.Definition. -Require Import Crypto.Arithmetic.MontgomeryReduction.WordByWord.Abstract.Dependent.Proofs. -Require Import Crypto.Arithmetic.MontgomeryReduction.WordByWord.Definition. -Require Import Crypto.Util.Tactics.BreakMatch. - -Local Open Scope Z_scope. -Local Coercion Z.pos : positive >-> Z. -Section WordByWordMontgomery. - (** XXX TODO: pick better names for things like [R_numlimbs] *) - Context (r : positive) - (R_numlimbs : nat). - Local Notation small := (@small (Z.pos r)). - Local Notation eval := (@eval (Z.pos r)). - Local Notation addT' := (@MontgomeryAPI.add_S1 (Z.pos r)). - Local Notation addT := (@MontgomeryAPI.add (Z.pos r)). - Local Notation scmul := (@scmul (Z.pos r)). - Local Notation eval_zero := (@eval_zero (Z.pos r)). - Local Notation small_zero := (@small_zero r (Zorder.Zgt_pos_0 _)). - Local Notation small_scmul := (fun n a v _ _ _ => @small_scmul r (Zorder.Zgt_pos_0 _) n a v). - Local Notation eval_join0 := (@eval_zero (Z.pos r) (Zorder.Zgt_pos_0 _)). - Local Notation eval_div := (@eval_div (Z.pos r) (Zorder.Zgt_pos_0 _)). - Local Notation eval_mod := (@eval_mod (Z.pos r)). - Local Notation small_div := (@small_div (Z.pos r)). - Local Notation eval_scmul := (fun n a v smallv abound vbound => @eval_scmul (Z.pos r) (Zorder.Zgt_pos_0 _) n a v smallv abound). - Local Notation eval_addT := (@eval_add_same (Z.pos r) (Zorder.Zgt_pos_0 _)). - Local Notation eval_addT' := (@eval_add_S1 (Z.pos r) (Zorder.Zgt_pos_0 _)). - Local Notation drop_high := (@drop_high (S R_numlimbs)). - Local Notation small_drop_high := (@small_drop_high (Z.pos r) (S R_numlimbs)). - Context (A_numlimbs : nat) - (N : T R_numlimbs) - (A : T A_numlimbs) - (B : T R_numlimbs) - (k : Z). - Context ri - (r_big : r > 1) - (small_A : small A) - (ri_correct : r*ri mod (eval N) = 1 mod (eval N)) - (small_N : small N) - (small_B : small B) - (N_nonzero : eval N <> 0) - (N_mask : Tuple.map (Z.land (Z.pos r - 1)) N = N) - (k_correct : k * eval N mod r = (-1) mod r). - Let R : positive := match (Z.pos r ^ Z.of_nat R_numlimbs)%Z with - | Z.pos R => R - | _ => 1%positive - end. - Let Npos : positive := match eval N with - | Z.pos N => N - | _ => 1%positive - end. - Local Lemma R_correct : Z.pos R = Z.pos r ^ Z.of_nat R_numlimbs. - Proof. - assert (0 < r^Z.of_nat R_numlimbs) by (apply Z.pow_pos_nonneg; lia). - subst R; destruct (Z.pos r ^ Z.of_nat R_numlimbs) eqn:?; [ | reflexivity | ]; - lia. - Qed. - Local Lemma small_addT : forall n a b, small a -> small b -> small (@addT n a b). - Proof. - intros; apply MontgomeryAPI.small_add; auto; lia. - Qed. - Local Lemma small_addT' : forall n a b, small a -> small b -> small (@addT' n a b). - Proof. - intros; apply MontgomeryAPI.small_add_S1; auto; lia. - Qed. - - Local Notation conditional_sub_cps := (fun V : T (S R_numlimbs) => @conditional_sub_cps (Z.pos r) _ V N _). - Local Notation conditional_sub := (fun V : T (S R_numlimbs) => @conditional_sub (Z.pos r) _ V N). - Local Notation eval_conditional_sub' := (fun V small_V V_bound => @eval_conditional_sub (Z.pos r) (Zorder.Zgt_pos_0 _) _ V N small_V small_N V_bound). - - Local Lemma eval_conditional_sub - : forall v, small v -> 0 <= eval v < eval N + R -> eval (conditional_sub v) = eval v + if eval N <=? eval v then -eval N else 0. - Proof. rewrite R_correct; exact eval_conditional_sub'. Qed. - Local Notation small_conditional_sub' := (fun V small_V V_bound => @small_conditional_sub (Z.pos r) (Zorder.Zgt_pos_0 _) _ V N small_V small_N V_bound). - Local Lemma small_conditional_sub - : forall v : T (S R_numlimbs), small v -> 0 <= eval v < eval N + R -> small (conditional_sub v). - Proof. rewrite R_correct; exact small_conditional_sub'. Qed. - - Local Lemma A_bound : 0 <= eval A < Z.pos r ^ Z.of_nat A_numlimbs. - Proof. apply eval_small; auto; lia. Qed. - Local Lemma B_bound' : 0 <= eval B < r^Z.of_nat R_numlimbs. - Proof. apply eval_small; auto; lia. Qed. - Local Lemma N_bound' : 0 <= eval N < r^Z.of_nat R_numlimbs. - Proof. apply eval_small; auto; lia. Qed. - Local Lemma N_bound : 0 < eval N < r^Z.of_nat R_numlimbs. - Proof. pose proof N_bound'; lia. Qed. - Local Lemma Npos_correct: eval N = Z.pos Npos. - Proof. pose proof N_bound; subst Npos; destruct (eval N); [ | reflexivity | ]; lia. Qed. - Local Lemma N_lt_R : eval N < R. - Proof. rewrite R_correct; apply N_bound. Qed. - Local Lemma B_bound : 0 <= eval B < R. - Proof. rewrite R_correct; apply B_bound'. Qed. - Local Lemma eval_drop_high : forall v, small v -> eval (drop_high v) = eval v mod (r * r^Z.of_nat R_numlimbs). - Proof. - intros; erewrite eval_drop_high by (eassumption || lia). - f_equal; unfold uweight. - rewrite Znat.Nat2Z.inj_succ, Z.pow_succ_r by lia; reflexivity. - Qed. - - Local Notation redc_body_no_cps := (@redc_body_no_cps r R_numlimbs N). - Local Notation redc_body_cps := (@redc_body_cps r R_numlimbs N). - Local Notation redc_body := (@redc_body r R_numlimbs N). - Local Notation redc_loop_no_cps := (@redc_loop_no_cps r R_numlimbs N B k). - Local Notation redc_loop_cps := (@redc_loop_cps r R_numlimbs N B k). - Local Notation redc_loop := (@redc_loop r R_numlimbs N B k). - Local Notation pre_redc_no_cps := (@pre_redc_no_cps r R_numlimbs N A_numlimbs A B k). - Local Notation pre_redc_cps := (@pre_redc_cps r R_numlimbs N A_numlimbs A B k). - Local Notation pre_redc := (@pre_redc r R_numlimbs N A_numlimbs A B k). - Local Notation redc_no_cps := (@redc_no_cps r R_numlimbs N A_numlimbs A B k). - Local Notation redc_cps := (@redc_cps r R_numlimbs N A_numlimbs A B k). - Local Notation redc := (@redc r R_numlimbs N A_numlimbs A B k). - - Definition redc_no_cps_bound : 0 <= eval redc_no_cps < R - := @redc_bound T (@eval) (@zero) (@divmod) r r_big R R_numlimbs R_correct (@small) eval_zero small_zero eval_div eval_mod small_div (@scmul) eval_scmul small_scmul (@addT) eval_addT small_addT (@addT') eval_addT' small_addT' drop_high eval_drop_high small_drop_high N Npos Npos_correct small_N N_lt_R conditional_sub eval_conditional_sub B B_bound small_B ri k A_numlimbs A small_A A_bound. - Definition redc_no_cps_bound_N : eval B < eval N -> 0 <= eval redc_no_cps < eval N - := @redc_bound_N T (@eval) (@zero) (@divmod) r r_big R R_numlimbs R_correct (@small) eval_zero small_zero eval_div eval_mod small_div (@scmul) eval_scmul small_scmul (@addT) eval_addT small_addT (@addT') eval_addT' small_addT' drop_high eval_drop_high small_drop_high N Npos Npos_correct small_N N_lt_R conditional_sub eval_conditional_sub B B_bound small_B ri k A_numlimbs A small_A. - Definition redc_no_cps_mod_N - : (eval redc_no_cps) mod (eval N) = (eval A * eval B * ri^(Z.of_nat A_numlimbs)) mod (eval N) - := @redc_mod_N T (@eval) (@zero) (@divmod) r r_big R R_numlimbs R_correct (@small) eval_zero small_zero eval_div eval_mod small_div (@scmul) eval_scmul small_scmul (@addT) eval_addT small_addT (@addT') eval_addT' small_addT' drop_high eval_drop_high small_drop_high N Npos Npos_correct small_N N_lt_R conditional_sub eval_conditional_sub B B_bound small_B ri ri_correct k k_correct A_numlimbs A small_A A_bound. - Definition small_redc_no_cps - : small redc_no_cps - := @small_redc T (@eval) (@zero) (@divmod) r r_big R R_numlimbs R_correct (@small) eval_zero small_zero eval_div eval_mod small_div (@scmul) eval_scmul small_scmul (@addT) eval_addT small_addT (@addT') eval_addT' small_addT' drop_high eval_drop_high small_drop_high N Npos Npos_correct small_N N_lt_R conditional_sub small_conditional_sub B B_bound small_B ri k A_numlimbs A small_A A_bound. - - Lemma redc_body_cps_id pred_A_numlimbs (A' : T (S pred_A_numlimbs)) (S' : T (S R_numlimbs)) {cpsT} f - : @redc_body_cps pred_A_numlimbs A' B k S' cpsT f = f (redc_body A' B k S'). - Proof. - unfold redc_body, redc_body_cps, LetIn.Let_In. - repeat first [ reflexivity - | break_innermost_match_step - | progress autorewrite with uncps ]. - Qed. - - Lemma redc_loop_cps_id (count : nat) (A_S : T count * T (S R_numlimbs)) {cpsT} f - : @redc_loop_cps cpsT count f A_S = f (redc_loop count A_S). - Proof. - unfold redc_loop. - revert A_S f. - induction count as [|count IHcount]. - { reflexivity. } - { intros [A' S']; simpl; intros. - etransitivity; rewrite @redc_body_cps_id; [ rewrite IHcount | ]; reflexivity. } - Qed. - Lemma pre_redc_cps_id {cpsT} f : @pre_redc_cps cpsT f = f pre_redc. - Proof. - unfold pre_redc, pre_redc_cps. - etransitivity; rewrite redc_loop_cps_id; [ | reflexivity ]; break_innermost_match; - reflexivity. - Qed. - Lemma redc_cps_id {cpsT} f : @redc_cps cpsT f = f redc. - Proof. - unfold redc, redc_cps. - etransitivity; rewrite pre_redc_cps_id; [ | reflexivity ]; - autorewrite with uncps; - reflexivity. - Qed. - - Lemma redc_body_id_no_cps pred_A_numlimbs A' S' - : @redc_body pred_A_numlimbs A' B k S' = redc_body_no_cps B k (A', S'). - Proof. - unfold redc_body, redc_body_cps, redc_body_no_cps, Abstract.Dependent.Definition.redc_body, LetIn.Let_In, id. - repeat autounfold with word_by_word_montgomery. - repeat first [ reflexivity - | progress cbn [fst snd id] - | progress autorewrite with uncps - | break_innermost_match_step - | f_equal; [] ]. - Qed. - Lemma redc_loop_cps_id_no_cps count A_S - : redc_loop count A_S = redc_loop_no_cps count A_S. - Proof. - unfold redc_loop_no_cps, id. - revert A_S. - induction count as [|count IHcount]; simpl; [ reflexivity | ]. - intros [A' S']; unfold redc_loop; simpl. - rewrite redc_body_cps_id, redc_loop_cps_id, IHcount, redc_body_id_no_cps. - reflexivity. - Qed. - Lemma pre_redc_cps_id_no_cps : pre_redc = pre_redc_no_cps. - Proof. - unfold pre_redc, pre_redc_cps, pre_redc_no_cps, Abstract.Dependent.Definition.pre_redc. - rewrite redc_loop_cps_id, (surjective_pairing (redc_loop _ _)). - rewrite redc_loop_cps_id_no_cps; reflexivity. - Qed. - Lemma redc_cps_id_no_cps : redc = redc_no_cps. - Proof. - unfold redc, redc_no_cps, redc_cps, Abstract.Dependent.Definition.redc. - rewrite pre_redc_cps_id, pre_redc_cps_id_no_cps. - autorewrite with uncps; reflexivity. - Qed. - - Lemma redc_bound : 0 <= eval redc < R. - Proof. rewrite redc_cps_id_no_cps; apply redc_no_cps_bound. Qed. - Lemma redc_bound_N : eval B < eval N -> 0 <= eval redc < eval N. - Proof. rewrite redc_cps_id_no_cps; apply redc_no_cps_bound_N. Qed. - Lemma redc_mod_N - : (eval redc) mod (eval N) = (eval A * eval B * ri^(Z.of_nat A_numlimbs)) mod (eval N). - Proof. rewrite redc_cps_id_no_cps; apply redc_no_cps_mod_N. Qed. - Lemma small_redc - : small redc. - Proof. rewrite redc_cps_id_no_cps; apply small_redc_no_cps. Qed. - - Section add_sub. - Context (Av Bv : T R_numlimbs) - (small_Av : small Av) - (small_Bv : small Bv) - (Av_bound : 0 <= eval Av < eval N) - (Bv_bound : 0 <= eval Bv < eval N). - Local Notation add_no_cps := (@add_no_cps r R_numlimbs N Av Bv). - Local Notation add_cps := (@add_cps r R_numlimbs N Av Bv). - Local Notation add := (@add r R_numlimbs N Av Bv). - Local Notation sub_no_cps := (@sub_no_cps r R_numlimbs N Av Bv). - Local Notation sub_cps := (@sub_cps r R_numlimbs N Av Bv). - Local Notation sub := (@sub r R_numlimbs N Av Bv). - Local Notation opp_no_cps := (@opp_no_cps r R_numlimbs N Av). - Local Notation opp_cps := (@opp_cps r R_numlimbs N Av). - Local Notation opp := (@opp r R_numlimbs N Av). - Local Notation sub_then_maybe_add_cps := - (fun p q => @sub_then_maybe_add_cps (Z.pos r) R_numlimbs (Z.pos r - 1) p q N). - Local Notation sub_then_maybe_add := - (fun p q => @sub_then_maybe_add (Z.pos r) R_numlimbs (Z.pos r - 1) p q N). - Local Notation eval_sub_then_maybe_add := - (fun p q smp smq => @eval_sub_then_maybe_add (Z.pos r) (Zorder.Zgt_pos_0 _) _ (Z.pos r - 1) p q N smp smq small_N N_mask). - Local Notation small_sub_then_maybe_add := - (fun p q => @small_sub_then_maybe_add (Z.pos r) (Zorder.Zgt_pos_0 _) _ (Z.pos r - 1) p q N). - - Definition add_no_cps_bound : 0 <= eval add_no_cps < eval N - := @add_bound T (@eval) r R R_numlimbs (@small) (@addT) (@eval_addT) (@small_addT) N N_lt_R (@conditional_sub) (@eval_conditional_sub) Av Bv small_Av small_Bv Av_bound Bv_bound. - Definition sub_no_cps_bound : 0 <= eval sub_no_cps < eval N - := @sub_bound T (@eval) r R R_numlimbs (@small) N (@sub_then_maybe_add) (@eval_sub_then_maybe_add) Av Bv small_Av small_Bv Av_bound Bv_bound. - Definition opp_no_cps_bound : 0 <= eval opp_no_cps < eval N - := @opp_bound T (@eval) (@zero) r R R_numlimbs (@small) (@eval_zero) (@small_zero) N (@sub_then_maybe_add) (@eval_sub_then_maybe_add) Av small_Av Av_bound. - - Definition small_add_no_cps : small add_no_cps - := @small_add T (@eval) r R R_numlimbs (@small) (@addT) (@eval_addT) (@small_addT) N N_lt_R (@conditional_sub) (@small_conditional_sub) Av Bv small_Av small_Bv Av_bound Bv_bound. - Definition small_sub_no_cps : small sub_no_cps - := @small_sub T R_numlimbs (@small) (@sub_then_maybe_add) (@small_sub_then_maybe_add) Av Bv. - Definition small_opp_no_cps : small opp_no_cps - := @small_opp T (@zero) R_numlimbs (@small) (@sub_then_maybe_add) (@small_sub_then_maybe_add) Av. - - Definition eval_add_no_cps : eval add_no_cps = eval Av + eval Bv + (if eval N <=? eval Av + eval Bv then - eval N else 0) - := @eval_add T (@eval) r R R_numlimbs (@small) (@addT) (@eval_addT) (@small_addT) N N_lt_R (@conditional_sub) (@eval_conditional_sub) Av Bv small_Av small_Bv Av_bound Bv_bound. - Definition eval_sub_no_cps : eval sub_no_cps = eval Av - eval Bv + (if eval Av - eval Bv @nonzero R_numlimbs Av = 0 <-> eval Av = 0. - Proof. apply eval_nonzero; lia. Qed. - End nonzero. -End WordByWordMontgomery. - -Hint Rewrite redc_body_cps_id redc_loop_cps_id pre_redc_cps_id redc_cps_id add_cps_id sub_cps_id opp_cps_id : uncps. diff --git a/src/Arithmetic/Saturated/AddSub.v b/src/Arithmetic/Saturated/AddSub.v deleted file mode 100644 index d3ab6897f..000000000 --- a/src/Arithmetic/Saturated/AddSub.v +++ /dev/null @@ -1,285 +0,0 @@ -Require Import Coq.ZArith.ZArith. -Require Import Coq.Lists.List. -Local Open Scope Z_scope. - -Require Import Crypto.Arithmetic.Core. -Require Import Crypto.Arithmetic.Saturated.Core. -Require Import Crypto.Arithmetic.Saturated.UniformWeight. -Require Import Crypto.Util.ZUtil.Definitions. -Require Import Crypto.Util.ZUtil.CPS. -Require Import Crypto.Util.ZUtil.AddGetCarry. -Require Import Crypto.Util.Tuple Crypto.Util.LetIn. -Require Import Crypto.Util.Tactics.BreakMatch. -Local Notation "A ^ n" := (tuple A n) : type_scope. - -Module B. - Module Positional. - Section Positional. - Context {s:Z} {s_pos : 0 < s}. (* s is bitwidth *) - Let small {n} := @small s n. - Section GenericOp. - Context {op : Z -> Z -> Z} - {op_get_carry_cps : forall {T}, Z -> Z -> (Z * Z -> T) -> T} (* no carry in, carry out *) - {op_with_carry_cps : forall {T}, Z -> Z -> Z -> (Z * Z -> T) -> T}. (* carry in, carry out *) - Let op_get_carry x y := op_get_carry_cps _ x y id. - Let op_with_carry x y z := op_with_carry_cps _ x y z id. - Context {op_get_carry_id : forall {T} x y f, - @op_get_carry_cps T x y f = f (op_get_carry x y)} - {op_with_carry_id : forall {T} x y z f, - @op_with_carry_cps T x y z f = f (op_with_carry x y z)}. - Hint Rewrite @op_get_carry_id @op_with_carry_id : uncps. - - Section chain_op'_cps. - Context (T : Type). - - Fixpoint chain_op'_cps {n} (c:option Z) (p q:Z^n) - : (Z*Z^n->T)->T := - match n return option Z -> Z^n -> Z^n -> (Z*Z^n -> T) -> T with - | O => fun c p _ f => - let carry := match c with | None => 0 | Some x => x end in - f (carry,p) - | S n' => - fun c p q f => - (* for the first call, use op_get_carry, then op_with_carry *) - let op'_cps := match c with - | None => op_get_carry_cps _ - | Some x => op_with_carry_cps _ x end in - op'_cps (hd p) (hd q) (fun carry_result => - dlet carry_result := carry_result in - chain_op'_cps (Some (snd carry_result)) (tl p) (tl q) - (fun carry_pq => - f (fst carry_pq, - append (fst carry_result) (snd carry_pq)))) - end c p q. - End chain_op'_cps. - Definition chain_op' {n} c p q := @chain_op'_cps _ n c p q id. - Definition chain_op_cps {n} p q {T} f := @chain_op'_cps T n None p q f. - Definition chain_op {n} p q : Z * Z^n := chain_op_cps p q id. - - Lemma chain_op'_id {n} : forall c p q T f, - @chain_op'_cps T n c p q f = f (chain_op' c p q). - Proof. - cbv [chain_op']; induction n; intros; destruct c; - simpl chain_op'_cps; cbv [Let_In]; try reflexivity; - autorewrite with uncps. - { etransitivity; rewrite IHn; reflexivity. } - { etransitivity; rewrite IHn; reflexivity. } - Qed. - - Lemma chain_op_id {n} p q T f : - @chain_op_cps n p q T f = f (chain_op p q). - Proof. apply (@chain_op'_id n None). Qed. - End GenericOp. - Hint Opaque chain_op chain_op' : uncps. - Hint Rewrite @chain_op_id @chain_op'_id using (assumption || (intros; autorewrite with uncps; reflexivity)) : uncps. - - Section AddSub. - Create HintDb divmod discriminated. - Hint Rewrite - Z.add_get_carry_full_mod - Z.add_get_carry_full_div - Z.add_with_get_carry_full_mod - Z.add_with_get_carry_full_div - Z.sub_get_borrow_full_mod - Z.sub_get_borrow_full_div - Z.sub_with_get_borrow_full_mod - Z.sub_with_get_borrow_full_div - : divmod. - Let eval {n} := B.Positional.eval (n:=n) (uweight s). - - Definition sat_add_cps {n} p q T (f:Z*Z^n->T) := - chain_op_cps (op_get_carry_cps := fun T => Z.add_get_carry_full_cps s) - (op_with_carry_cps := fun T => Z.add_with_get_carry_full_cps s) - p q f. - Definition sat_add {n} p q := @sat_add_cps n p q _ id. - - Lemma sat_add_id n p q T f : - @sat_add_cps n p q T f = f (sat_add p q). - Proof. cbv [sat_add sat_add_cps]. autorewrite with uncps. reflexivity. Qed. - - Lemma sat_add_mod_step n c d : - c mod s + s * ((d + c / s) mod (uweight s n)) - = (s * d + c) mod (s * uweight s n). - Proof. - assert (0 < uweight s n) as wt_pos - by auto using Z.lt_gt, Z.gt_lt, uweight_positive. - rewrite <-(Columns.compact_mod_step s (uweight s n) c d s_pos wt_pos). - repeat (ring_simplify; f_equal; ring_simplify; try omega). - Qed. - - Lemma sat_add_div_step n c d : - (d + c / s) / uweight s n = (s * d + c) / (s * uweight s n). - Proof. - assert (0 < uweight s n) as wt_pos - by auto using Z.lt_gt, Z.gt_lt, uweight_positive. - rewrite <-(Columns.compact_div_step s (uweight s n) c d s_pos wt_pos). - repeat (ring_simplify; f_equal; ring_simplify; try omega). - Qed. - - Lemma sat_add_divmod n p q : - eval (snd (@sat_add n p q)) = (eval p + eval q) mod (uweight s n) - /\ fst (@sat_add n p q) = (eval p + eval q) / (uweight s n). - Proof. - cbv [sat_add sat_add_cps chain_op_cps]. - remember None as c. - replace (eval p + eval q) with - (eval p + eval q + match c with | None => 0 | Some x => x end) - by (subst; ring). - destruct Heqc. revert c. - induction n; [|destruct c]; intros; simpl chain_op'_cps; - repeat match goal with - | _ => progress cbv [eval Let_In] in * - | _ => progress autorewrite with uncps divmod push_id cancel_pair push_basesystem_eval - | _ => rewrite uweight_0, ?Z.mod_1_r, ?Z.div_1_r - | _ => rewrite uweight_succ - | _ => rewrite Z.sub_opp_r - | _ => rewrite sat_add_mod_step - | _ => rewrite sat_add_div_step - | p : Z ^ 0 |- _ => destruct p - | _ => rewrite uweight_eval_step, ?hd_append, ?tl_append - | |- context[B.Positional.eval _ (snd (chain_op' ?c ?p ?q))] - => specialize (IHn p q c); autorewrite with push_id uncps in IHn; - rewrite (proj1 IHn); rewrite (proj2 IHn) - | _ => split; ring - | _ => solve [split; repeat (f_equal; ring_simplify; try omega)] - end. - Qed. - - Lemma sat_add_mod n p q : - eval (snd (@sat_add n p q)) = (eval p + eval q) mod (uweight s n). - Proof. exact (proj1 (sat_add_divmod n p q)). Qed. - - Lemma sat_add_div n p q : - fst (@sat_add n p q) = (eval p + eval q) / (uweight s n). - Proof. exact (proj2 (sat_add_divmod n p q)). Qed. - - Lemma small_sat_add n p q : small (snd (@sat_add n p q)). - Proof. - cbv [small UniformWeight.small sat_add sat_add_cps chain_op_cps]. - remember None as c. destruct Heqc. revert c. - induction n; intros; - repeat match goal with - | p: Z^0 |- _ => destruct p - | _ => progress (cbv [Let_In] in * ) - | _ => progress (simpl chain_op'_cps in * ) - | _ => progress autorewrite with uncps push_id cancel_pair in H - | H : _ |- _ => rewrite to_list_append in H; - simpl In in H - | H : _ \/ _ |- _ => destruct H - | _ => contradiction - | _ => break_innermost_match_hyps_step - | _ => progress subst - | [ H : In _ (to_list _ (snd _)) |- _ ] - => apply IHn in H; assumption - end; - try solve [ rewrite ?Z.add_with_get_carry_full_mod, - ?Z.add_get_carry_full_mod; - apply Z.mod_pos_bound; omega ]. - Qed. - - Definition sat_sub_cps {n} p q T (f:Z*Z^n->T) := - chain_op_cps (op_get_carry_cps := fun T => Z.sub_get_borrow_full_cps s) - (op_with_carry_cps := fun T => Z.sub_with_get_borrow_full_cps s) - p q f. - Definition sat_sub {n} p q := @sat_sub_cps n p q _ id. - - Lemma sat_sub_id n p q T f : - @sat_sub_cps n p q T f = f (sat_sub p q). - Proof. cbv [sat_sub sat_sub_cps]. autorewrite with uncps. reflexivity. Qed. - Lemma sat_sub_divmod n p q : - eval (snd (@sat_sub n p q)) = (eval p - eval q) mod (uweight s n) - /\ fst (@sat_sub n p q) = - ((eval p - eval q) / (uweight s n)). - Proof. - cbv [sat_sub sat_sub_cps chain_op_cps]. - remember None as c. - replace (eval p - eval q) with - (eval p - eval q - match c with | None => 0 | Some x => x end) - by (subst; ring). - destruct Heqc. revert c. - induction n; [|destruct c]; intros; simpl chain_op'_cps; - repeat match goal with - | _ => progress cbv [eval Let_In] in * - | _ => progress autorewrite with uncps divmod push_id cancel_pair push_basesystem_eval - | _ => rewrite uweight_0, ?Z.mod_1_r, ?Z.div_1_r - | _ => rewrite uweight_succ - | _ => rewrite Z.sub_opp_r - | _ => rewrite sat_add_mod_step - | _ => rewrite sat_add_div_step - | p : Z ^ 0 |- _ => destruct p - | _ => rewrite uweight_eval_step, ?hd_append, ?tl_append - | |- context[B.Positional.eval _ (snd (chain_op' ?c ?p ?q))] - => specialize (IHn p q c); autorewrite with push_id uncps in IHn; - rewrite (proj1 IHn); rewrite (proj2 IHn) - | _ => split; ring - | _ => solve [split; repeat (f_equal; ring_simplify; try omega)] - end. - Qed. - - Lemma sat_sub_mod n p q : - eval (snd (@sat_sub n p q)) = (eval p - eval q) mod (uweight s n). - Proof. exact (proj1 (sat_sub_divmod n p q)). Qed. - - Lemma sat_sub_div n p q : - fst (@sat_sub n p q) = - ((eval p - eval q) / uweight s n). - Proof. exact (proj2 (sat_sub_divmod n p q)). Qed. - - Lemma small_sat_sub n p q : small (snd (@sat_sub n p q)). - Proof. - cbv [small UniformWeight.small sat_sub sat_sub_cps chain_op_cps]. - remember None as c. destruct Heqc. revert c. - induction n; intros; - repeat match goal with - | p: Z^0 |- _ => destruct p - | _ => progress (cbv [Let_In] in * ) - | _ => progress (simpl chain_op'_cps in * ) - | _ => progress autorewrite with uncps push_id cancel_pair in H - | H : _ |- _ => rewrite to_list_append in H; - simpl In in H - | H : _ \/ _ |- _ => destruct H - | _ => contradiction - | _ => break_innermost_match_hyps_step - | _ => progress subst - | [ H : In _ (to_list _ (snd _)) |- _ ] - => apply IHn in H; assumption - end; - try solve [ rewrite ?Z.sub_with_get_borrow_full_mod, - ?Z.sub_get_borrow_full_mod; - apply Z.mod_pos_bound; omega ]. - Qed. - End AddSub. - End Positional. - End Positional. -End B. -Hint Opaque B.Positional.sat_sub B.Positional.sat_add B.Positional.chain_op B.Positional.chain_op' : uncps. -Hint Rewrite @B.Positional.sat_sub_id @B.Positional.sat_add_id : uncps. -Hint Rewrite @B.Positional.chain_op_id @B.Positional.chain_op' using (assumption || (intros; autorewrite with uncps; reflexivity)) : uncps. -Hint Rewrite @B.Positional.sat_sub_mod @B.Positional.sat_sub_div @B.Positional.sat_add_mod @B.Positional.sat_add_div using (omega || assumption) : push_basesystem_eval. - -Hint Unfold - B.Positional.chain_op'_cps - B.Positional.chain_op' - B.Positional.chain_op_cps - B.Positional.chain_op - B.Positional.sat_add_cps - B.Positional.sat_add - B.Positional.sat_sub_cps - B.Positional.sat_sub - : basesystem_partial_evaluation_unfolder. - -Ltac basesystem_partial_evaluation_unfolder t := - let t := (eval cbv delta [ - B.Positional.chain_op'_cps - B.Positional.chain_op' - B.Positional.chain_op_cps - B.Positional.chain_op - B.Positional.sat_add_cps - B.Positional.sat_add - B.Positional.sat_sub_cps - B.Positional.sat_sub - ] in t) in - let t := Arithmetic.Saturated.Core.basesystem_partial_evaluation_unfolder t in - let t := Arithmetic.Core.basesystem_partial_evaluation_unfolder t in - t. - -Ltac Arithmetic.Core.basesystem_partial_evaluation_default_unfolder t ::= - basesystem_partial_evaluation_unfolder t. diff --git a/src/Arithmetic/Saturated/Core.v b/src/Arithmetic/Saturated/Core.v deleted file mode 100644 index a597b7bf2..000000000 --- a/src/Arithmetic/Saturated/Core.v +++ /dev/null @@ -1,485 +0,0 @@ -Require Import Coq.micromega.Lia. -Require Import Coq.Init.Nat. -Require Import Coq.ZArith.ZArith. -Require Import Coq.Lists.List. -Local Open Scope Z_scope. - -Require Import Crypto.Algebra.Nsatz. -Require Import Crypto.Arithmetic.Core. -Require Import Crypto.Util.LetIn Crypto.Util.CPSUtil. -Require Import Crypto.Util.Tuple Crypto.Util.ListUtil. -Require Import Crypto.Util.Tactics.BreakMatch. -Require Import Crypto.Util.Decidable. -Require Import Crypto.Util.ZUtil.Modulo.PullPush. -Require Import Crypto.Util.ZUtil.Le. -Require Import Crypto.Util.ZUtil.Modulo. -Require Import Crypto.Util.ZUtil.Div. -Require Import Crypto.Util.ZUtil.Tactics.ZeroBounds. -Require Import Crypto.Util.NatUtil. -Require Import Crypto.Util.Tactics.SpecializeBy. -Local Notation "A ^ n" := (tuple A n) : type_scope. - -(*** - -Arithmetic on bignums that handles carry bits; this is useful for -saturated limbs. Compatible with mixed-radix bases. - -Uses "columns" representation: a bignum has type [tuple (list Z) n]. -Associated with a weight function w, the bignum B represents: - - \sum_{i=0}^{n}{w[i] * sum{B[i]}} - -Example: ([a21, a20],[],[a0]) with weight function (fun i => 10^i) -represents - - a0 + 10*0 + 100 * (a20 + a21) - -If you picture this representation with the weights on the bottom and -the terms in each list stacked above the corresponding weight, - - a20 - a0 a21 - --------------- - 1 10 100 - -it's easy to see how the lists can be called "columns". - -This is a particularly useful representation for adding partial -products after multiplication, particularly when we want to do this -using a carrying add. We want to add together the terms from each -column, accumulating the carries together along the way. Then we want -to add the carry accumulator to the next column, and repeat, producing -a [tuple Z n] as output. This operation is called "compact". - -As an example, let's compact the product of 571 and 645 in base 10. -At first, the partial products look like this: - - - 1*6 - 1*4 7*4 7*6 - 1*5 7*5 5*5 5*4 5*6 - ------------------------------------ - 1 10 100 1000 10000 - - 6 - 4 28 42 - 5 35 25 20 30 - ------------------------------------ - 1 10 100 1000 10000 - -Now, we process the first column: - - {carry_acc = 0; output =()} - STEP [5] - {carry_acc = 0; output=(5,)} - -Since we have only one term, there's no addition to do, and no carry -bit. We add a 0 to the next column and continue. - - STEP [0,4,35] (0 + 4 = 4) - {carry_acc = 0; output=(5,)} - STEP [4,35] (4 + 35 = 39) - {carry_acc = 3; output=(9,5)} - -This time, we have a carry. We add it to the third column and process -that: - - STEP [3,6,28,25] (3 + 6 = 9) - {carry_acc = 0; output=(9,5)} - STEP [9,28,25] (9 + 28 = 37) - {carry_acc = 3; output=(9,5)} - STEP [7,25] (7 + 25 = 32) - {carry_acc = 6; output=(2,9,5)} - -You're probably getting the idea, but here are the fourth and fifth -columns: - - STEP [6,42,20] (6 + 42 = 48) - {carry_acc = 4; output=(2,9,5)} - STEP [8,20] (8 + 20 = 28) - {carry_acc = 6; output=(8,2,9,5)} - - STEP [6,30] (6 + 30 = 36) - {carry_acc = 3; output=(6,8,2,9,5)} - -The final result is the output plus the final carry, so we produce -(6,8,2,9,5) and 3, representing the number 368295. A quick calculator -check confirms our result. - - ***) - -Module Columns. - Section Columns. - Context (weight : nat->Z) - {weight_0 : weight 0%nat = 1} - {weight_nonzero : forall i, weight i <> 0} - {weight_positive : forall i, weight i > 0} - {weight_multiples : forall i, weight (S i) mod weight i = 0} - {weight_divides : forall i : nat, weight (S i) / weight i > 0} - (* add_get_carry takes in a number at which to split output *) - {add_get_carry_cps: forall {T}, Z ->Z -> Z -> (Z * Z -> T) -> T} - {div_cps modulo_cps : forall {T}, Z -> Z -> (Z -> T) -> T}. - Let add_get_carry s x y := add_get_carry_cps _ s x y id. - Let div x y := div_cps _ x y id. - Let modulo x y := modulo_cps _ x y id. - Context {add_get_carry_cps_id : forall {T} s x y f, - @add_get_carry_cps T s x y f = f (add_get_carry s x y)} - {add_get_carry_mod : forall s x y, - fst (add_get_carry s x y) = (x + y) mod s} - {add_get_carry_div : forall s x y, - snd (add_get_carry s x y) = (x + y) / s} - {div_cps_id : forall {T} x y f, - @div_cps T x y f = f (div x y)} - {modulo_cps_id : forall {T} x y f, - @modulo_cps T x y f = f (modulo x y)} - {div_correct : forall a b, div a b = a / b} - {modulo_correct : forall a b, modulo a b = a mod b} - . - Hint Rewrite div_correct modulo_correct add_get_carry_mod add_get_carry_div : div_mod. - Hint Rewrite add_get_carry_cps_id div_cps_id modulo_cps_id : uncps. - - Definition eval {n} (x : (list Z)^n) : Z := - B.Positional.eval weight (Tuple.map sum x). - - Lemma eval_unit (x:unit) : eval (n:=0) x = 0. - Proof. reflexivity. Qed. - Hint Rewrite eval_unit : push_basesystem_eval. - - Lemma eval_single (x:list Z) : eval (n:=1) x = sum x. - Proof. - cbv [eval]. simpl map. cbv - [Z.mul Z.add sum]. - rewrite weight_0; ring. - Qed. Hint Rewrite eval_single : push_basesystem_eval. - - Definition eval_from {n} (offset:nat) (x : (list Z)^n) : Z := - B.Positional.eval (fun i => weight (i+offset)) (Tuple.map sum x). - - Lemma eval_from_0 {n} x : @eval_from n 0 x = eval x. - Proof using Type. cbv [eval_from eval]. auto using B.Positional.eval_wt_equiv. Qed. - - Lemma eval_from_S {n}: forall i (inp : (list Z)^(S n)), - eval_from i inp = eval_from (S i) (tl inp) + weight i * sum (hd inp). - Proof using Type. - intros i inp; cbv [eval_from]. - replace inp with (append (hd inp) (tl inp)) - by (simpl in *; destruct n; destruct inp; reflexivity). - rewrite map_append, B.Positional.eval_step, hd_append, tl_append. - autorewrite with natsimplify; ring_simplify; rewrite Group.cancel_left. - apply B.Positional.eval_wt_equiv; intros; f_equal; omega. - Qed. - - (* Sums a list of integers using carry bits. - Output : carry, sum - *) - Section compact_digit_cps. - Context (n : nat) {T : Type}. - - Fixpoint compact_digit_cps (digit : list Z) (f:Z * Z->T) := - match digit with - | nil => f (0, 0) - | x :: nil => div_cps _ x (weight (S n) / weight n) (fun d => - modulo_cps _ x (weight (S n) / weight n) (fun m => - f (d, m))) - | x :: y :: nil => - add_get_carry_cps _ (weight (S n) / weight n) x y (fun sum_carry => - dlet sum_carry := sum_carry in - dlet carry := snd sum_carry in - f (carry, fst sum_carry)) - | x :: tl => - compact_digit_cps tl - (fun rec => - add_get_carry_cps _ (weight (S n) / weight n) x (snd rec) (fun sum_carry => - dlet sum_carry := sum_carry in - dlet carry' := (fst rec + snd sum_carry)%RT in - f (carry', fst sum_carry))) - end. - End compact_digit_cps. - - Definition compact_digit n digit := compact_digit_cps n digit id. - Lemma compact_digit_id n digit: forall {T} f, - @compact_digit_cps n T digit f = f (compact_digit n digit). - Proof using add_get_carry_cps_id div_cps_id modulo_cps_id. - induction digit; intros; cbv [compact_digit]; [reflexivity|]. - simpl compact_digit_cps; break_match; rewrite ?IHdigit; clear IHdigit; - cbv [Let_In]; autorewrite with uncps; reflexivity. - Qed. - Hint Opaque compact_digit : uncps. - Hint Rewrite compact_digit_id : uncps. - - Definition compact_step_cps (index:nat) (carry:Z) (digit: list Z) - {T} (f:Z * Z->T) := - compact_digit_cps index (carry::digit) f. - - Definition compact_step i c d := compact_step_cps i c d id. - Lemma compact_step_id i c d T f : - @compact_step_cps i c d T f = f (compact_step i c d). - Proof using add_get_carry_cps_id div_cps_id modulo_cps_id. cbv [compact_step_cps compact_step]; autorewrite with uncps; reflexivity. Qed. - Hint Opaque compact_step : uncps. - Hint Rewrite compact_step_id : uncps. - - Definition compact_cps {n} (xs : (list Z)^n) {T} (f:Z * Z^n->T) := - Tuple.mapi_with_cps compact_step_cps 0 xs f. - - Definition compact {n} xs := @compact_cps n xs _ id. - Lemma compact_id {n} xs {T} f : @compact_cps n xs T f = f (compact xs). - Proof using add_get_carry_cps_id div_cps_id modulo_cps_id. cbv [compact_cps compact]; autorewrite with uncps; reflexivity. Qed. - - Lemma compact_digit_mod i (xs : list Z) : - snd (compact_digit i xs) = sum xs mod (weight (S i) / weight i). - Proof using add_get_carry_div add_get_carry_mod div_correct modulo_correct add_get_carry_cps_id div_cps_id modulo_cps_id. - induction xs; cbv [compact_digit]; simpl compact_digit_cps; - cbv [Let_In]; - repeat match goal with - | _ => progress autorewrite with div_mod - | _ => rewrite IHxs, <-Z.add_mod_r - | _ => progress (rewrite ?sum_cons, ?sum_nil in * ) - | _ => progress (autorewrite with uncps push_id cancel_pair in * ) - | _ => progress break_match; try discriminate - | _ => reflexivity - | _ => f_equal; ring - end. - Qed. Hint Rewrite compact_digit_mod : div_mod. - - Lemma compact_digit_div i (xs : list Z) : - fst (compact_digit i xs) = sum xs / (weight (S i) / weight i). - Proof using add_get_carry_div add_get_carry_mod div_correct modulo_correct weight_0 weight_divides add_get_carry_cps_id div_cps_id modulo_cps_id. - induction xs; cbv [compact_digit]; simpl compact_digit_cps; - cbv [Let_In]; - repeat match goal with - | _ => progress autorewrite with div_mod - | _ => rewrite IHxs - | _ => progress (rewrite ?sum_cons, ?sum_nil in * ) - | _ => progress (autorewrite with uncps push_id cancel_pair in * ) - | _ => progress break_match; try discriminate - | _ => reflexivity - | _ => f_equal; ring - end. - assert (weight (S i) / weight i <> 0) by auto using Z.positive_is_nonzero. - match goal with |- _ = (?a + ?X) / ?D => - transitivity ((a + X mod D + D * (X / D)) / D); - [| rewrite (Z.div_mod'' X D) at 3; f_equal; auto; ring] - end. - rewrite Z.div_add' by auto; nsatz. - Qed. - - Lemma small_mod_eq a b n: a mod n = b mod n -> 0 <= a < n -> a = b mod n. - Proof. intros; rewrite <-(Z.mod_small a n); auto. Qed. - - (* helper for some of the modular logic in compact *) - Lemma compact_mod_step a b c d: 0 < a -> 0 < b -> - a * ((c / a + d) mod b) + c mod a = (a * d + c) mod (a * b). - Proof. - clear. - intros Ha Hb. assert (a <= a * b) by (apply Z.le_mul_diag_r; omega). - pose proof (Z.mod_pos_bound c a Ha). - pose proof (Z.mod_pos_bound (c/a+d) b Hb). - apply small_mod_eq. - { rewrite <-(Z.mod_small (c mod a) (a * b)) by omega. - rewrite <-Z.mul_mod_distr_l with (c:=a) by omega. - rewrite Z.mul_add_distr_l, Z.mul_div_eq, <-Z.add_mod_full by omega. - f_equal; ring. } - { split; [Z.zero_bounds|]. - apply Z.lt_le_trans with (m:=a*(b-1)+a); [|ring_simplify; omega]. - apply Z.add_le_lt_mono; try apply Z.mul_le_mono_nonneg_l; omega. } - Qed. - - Lemma compact_div_step a b c d : 0 < a -> 0 < b -> - (c / a + d) / b = (a * d + c) / (a * b). - Proof. - clear. intros Ha Hb. - rewrite <-Z.div_div by omega. - rewrite Z.div_add_l' by omega. - f_equal; ring. - Qed. - - Lemma compact_div_mod {n} inp : - (B.Positional.eval weight (snd (compact inp)) - = (eval inp) mod (weight n)) - /\ (fst (compact inp) = eval (n:=n) inp / weight n). - Proof. - cbv [compact compact_cps compact_step compact_step_cps]; - autorewrite with uncps push_id. - change (fun i s a => compact_digit_cps i (s :: a) id) - with (fun i s a => compact_digit i (s :: a)). - - apply mapi_with'_linvariant; [|tauto]. - - clear n inp. intros n st x0 xs ys Hst Hys [Hmod Hdiv]. - pose proof (weight_positive n). pose proof (weight_divides n). - autorewrite with push_basesystem_eval. - destruct n; cbv [mapi_with] in *; simpl tuple in *; - [destruct xs, ys; subst; simpl| cbv [eval] in *]; - repeat match goal with - | _ => rewrite mapi_with'_left_step - | _ => rewrite compact_digit_div, sum_cons - | _ => rewrite compact_digit_mod, sum_cons - | _ => rewrite map_left_append - | _ => rewrite B.Positional.eval_left_append - | _ => rewrite weight_0, ?Z.div_1_r, ?Z.mod_1_r - | _ => rewrite Hdiv - | _ => rewrite Hmod - | _ => progress subst - | _ => progress autorewrite with natsimplify cancel_pair push_basesystem_eval - | _ => solve [split; ring_simplify; f_equal; ring] - end. - remember (weight (S (S n)) / weight (S n)) as bound. - replace (weight (S (S n))) with (weight (S n) * bound) - by (subst bound; rewrite Z.mul_div_eq by omega; - rewrite weight_multiples; ring). - split; [apply compact_mod_step | apply compact_div_step]; omega. - Qed. - - Lemma compact_mod {n} inp : - (B.Positional.eval weight (snd (compact inp)) - = (eval (n:=n) inp) mod (weight n)). - Proof. apply (proj1 (compact_div_mod inp)). Qed. - Hint Rewrite @compact_mod : push_basesystem_eval. - - Lemma compact_div {n} inp : - fst (compact inp) = eval (n:=n) inp / weight n. - Proof. apply (proj2 (compact_div_mod inp)). Qed. - Hint Rewrite @compact_div : push_basesystem_eval. - - (* TODO : move to tuple *) - Lemma hd_to_list {A n} a (t : A^(S n)) : List.hd a (to_list (S n) t) = hd t. - Proof. - rewrite (subst_append t), to_list_append, hd_append. reflexivity. - Qed. - - Definition cons_to_nth_cps {n} i (x:Z) (t:(list Z)^n) - {T} (f:(list Z)^n->T) := - @on_tuple_cps _ _ nil (update_nth_cps i (cons x)) n n t _ f. - - Definition cons_to_nth {n} i x t := @cons_to_nth_cps n i x t _ id. - Lemma cons_to_nth_id {n} i x t T f : - @cons_to_nth_cps n i x t T f = f (cons_to_nth i x t). - Proof using Type. - cbv [cons_to_nth_cps cons_to_nth]. - assert (forall xs : list (list Z), length xs = n -> - length (update_nth_cps i (cons x) xs id) = n) as Hlen. - { intros. autorewrite with uncps push_id distr_length. assumption. } - rewrite !on_tuple_cps_correct with (H:=Hlen) - by (intros; autorewrite with uncps push_id; reflexivity). reflexivity. - Qed. - Hint Opaque cons_to_nth : uncps. - Hint Rewrite @cons_to_nth_id : uncps. - - Lemma map_sum_update_nth l : forall i x, - List.map sum (update_nth i (cons x) l) = - update_nth i (Z.add x) (List.map sum l). - Proof using Type. - induction l as [|a l IHl]; intros i x; destruct i; simpl; rewrite ?IHl; reflexivity. - Qed. - - Lemma cons_to_nth_add_to_nth n i x t : - map sum (@cons_to_nth n i x t) = B.Positional.add_to_nth i x (map sum t). - Proof using weight. - cbv [B.Positional.add_to_nth B.Positional.add_to_nth_cps cons_to_nth cons_to_nth_cps on_tuple_cps]. - induction n; [simpl; rewrite !update_nth_cps_correct; reflexivity|]. - specialize (IHn (tl t)). autorewrite with uncps push_id in *. - apply to_list_ext. rewrite <-!map_to_list. - erewrite !from_list_default_eq, !to_list_from_list. - rewrite map_sum_update_nth. reflexivity. - Unshelve. - distr_length. - distr_length. - Qed. - - Lemma eval_cons_to_nth n i x t : (i < n)%nat -> - eval (@cons_to_nth n i x t) = weight i * x + eval t. - Proof using Type. - cbv [eval]; intros. rewrite cons_to_nth_add_to_nth. - auto using B.Positional.eval_add_to_nth. - Qed. - Hint Rewrite eval_cons_to_nth using omega : push_basesystem_eval. - - Definition nils n : (list Z)^n := Tuple.repeat nil n. - - Lemma map_sum_nils n : map sum (nils n) = B.Positional.zeros n. - Proof using Type. - cbv [nils B.Positional.zeros]; induction n as [|n]; [reflexivity|]. - change (List.repeat nil (S n)) with (@nil Z :: List.repeat nil n). - rewrite Tuple.map_repeat, sum_nil. reflexivity. - Qed. - - Lemma eval_nils n : eval (nils n) = 0. - Proof using Type. cbv [eval]. rewrite map_sum_nils, B.Positional.eval_zeros. reflexivity. Qed. Hint Rewrite eval_nils : push_basesystem_eval. - - Definition from_associational_cps n (p:list B.limb) - {T} (f:(list Z)^n -> T) := - fold_right_cps2 - (fun t st T' f' => - B.Positional.place_cps weight t (pred n) - (fun p=> cons_to_nth_cps (fst p) (snd p) st f')) - (nils n) p f. - - Definition from_associational n p := from_associational_cps n p id. - Lemma from_associational_id n p T f : - @from_associational_cps n p T f = f (from_associational n p). - Proof using Type. - cbv [from_associational_cps from_associational]. - autorewrite with uncps push_id; reflexivity. - Qed. - Hint Opaque from_associational : uncps. - Hint Rewrite from_associational_id : uncps. - - Lemma eval_from_associational n p (n_nonzero:n<>0%nat): - eval (from_associational n p) = B.Associational.eval p. - Proof using weight_0 weight_nonzero. - cbv [from_associational_cps from_associational]; induction p; - autorewrite with uncps push_id push_basesystem_eval; [reflexivity|]. - pose proof (B.Positional.weight_place_cps weight weight_0 weight_nonzero a (pred n)). - pose proof (B.Positional.place_cps_in_range weight a (pred n)). - rewrite Nat.succ_pred in * by assumption. simpl. - autorewrite with uncps push_id push_basesystem_eval in *. - rewrite eval_cons_to_nth by omega. nsatz. - Qed. - End Columns. -End Columns. -Hint Rewrite - @Columns.compact_digit_id - @Columns.compact_step_id - @Columns.compact_id - using (assumption || (intros; autorewrite with uncps; reflexivity)) - : uncps. -Hint Rewrite - @Columns.cons_to_nth_id - @Columns.from_associational_id - : uncps. -Hint Rewrite - @Columns.compact_mod - @Columns.compact_div - @Columns.eval_cons_to_nth - @Columns.eval_from_associational - @Columns.eval_nils - using (assumption || omega): push_basesystem_eval. - -Hint Unfold - Columns.eval Columns.eval_from - Columns.compact_digit_cps Columns.compact_digit - Columns.compact_step_cps Columns.compact_step - Columns.compact_cps Columns.compact - Columns.cons_to_nth_cps Columns.cons_to_nth - Columns.nils - Columns.from_associational_cps Columns.from_associational - : basesystem_partial_evaluation_unfolder. - -Ltac basesystem_partial_evaluation_unfolder t := - let t := - (eval - cbv - delta [ - (* this list must contain all definitions referenced by t that reference [Let_In], [runtime_add], [runtime_opp], [runtime_mul], [runtime_shr], or [runtime_and] *) - Columns.eval Columns.eval_from - Columns.compact_digit_cps Columns.compact_digit - Columns.compact_step_cps Columns.compact_step - Columns.compact_cps Columns.compact - Columns.cons_to_nth_cps Columns.cons_to_nth - Columns.nils - Columns.from_associational_cps Columns.from_associational - ] in t) in - let t := Arithmetic.Core.basesystem_partial_evaluation_unfolder t in - t. - -Ltac Arithmetic.Core.basesystem_partial_evaluation_default_unfolder t ::= - basesystem_partial_evaluation_unfolder t. diff --git a/src/Arithmetic/Saturated/CoreUnfolder.v b/src/Arithmetic/Saturated/CoreUnfolder.v deleted file mode 100644 index 9a0e0c06a..000000000 --- a/src/Arithmetic/Saturated/CoreUnfolder.v +++ /dev/null @@ -1,97 +0,0 @@ -Require Import Crypto.Arithmetic.CoreUnfolder. -Require Import Crypto.Arithmetic.Saturated.Core. - -Hint Unfold Core.Columns.compact_digit_cps Core.Columns.compact_step_cps Core.Columns.compact_cps : arithmetic_cps_unfolder. - -Module Columns. - (** -<< -#!/bin/bash -for i in eval eval_from compact_digit_cps compact_digit compact_step_cps compact_step compact_cps compact cons_to_nth_cps cons_to_nth nils from_associational_cps from_associational; do - echo " Definition ${i}_sig := parameterize_sig (@Core.Columns.${i})."; - echo " Definition ${i} := parameterize_from_sig ${i}_sig."; - echo " Definition ${i}_eq := parameterize_eq ${i} ${i}_sig."; - echo " Hint Unfold ${i} : basesystem_partial_evaluation_unfolder."; - echo " Hint Rewrite <- ${i}_eq : pattern_runtime."; echo ""; -done -echo "End Columns." ->> *) - Definition eval_sig := parameterize_sig (@Core.Columns.eval). - Definition eval := parameterize_from_sig eval_sig. - Definition eval_eq := parameterize_eq eval eval_sig. - Hint Unfold eval : basesystem_partial_evaluation_unfolder. - Hint Rewrite <- eval_eq : pattern_runtime. - - Definition eval_from_sig := parameterize_sig (@Core.Columns.eval_from). - Definition eval_from := parameterize_from_sig eval_from_sig. - Definition eval_from_eq := parameterize_eq eval_from eval_from_sig. - Hint Unfold eval_from : basesystem_partial_evaluation_unfolder. - Hint Rewrite <- eval_from_eq : pattern_runtime. - - Definition compact_digit_cps_sig := parameterize_sig (@Core.Columns.compact_digit_cps). - Definition compact_digit_cps := parameterize_from_sig compact_digit_cps_sig. - Definition compact_digit_cps_eq := parameterize_eq compact_digit_cps compact_digit_cps_sig. - Hint Unfold compact_digit_cps : basesystem_partial_evaluation_unfolder. - Hint Rewrite <- compact_digit_cps_eq : pattern_runtime. - - Definition compact_digit_sig := parameterize_sig (@Core.Columns.compact_digit). - Definition compact_digit := parameterize_from_sig compact_digit_sig. - Definition compact_digit_eq := parameterize_eq compact_digit compact_digit_sig. - Hint Unfold compact_digit : basesystem_partial_evaluation_unfolder. - Hint Rewrite <- compact_digit_eq : pattern_runtime. - - Definition compact_step_cps_sig := parameterize_sig (@Core.Columns.compact_step_cps). - Definition compact_step_cps := parameterize_from_sig compact_step_cps_sig. - Definition compact_step_cps_eq := parameterize_eq compact_step_cps compact_step_cps_sig. - Hint Unfold compact_step_cps : basesystem_partial_evaluation_unfolder. - Hint Rewrite <- compact_step_cps_eq : pattern_runtime. - - Definition compact_step_sig := parameterize_sig (@Core.Columns.compact_step). - Definition compact_step := parameterize_from_sig compact_step_sig. - Definition compact_step_eq := parameterize_eq compact_step compact_step_sig. - Hint Unfold compact_step : basesystem_partial_evaluation_unfolder. - Hint Rewrite <- compact_step_eq : pattern_runtime. - - Definition compact_cps_sig := parameterize_sig (@Core.Columns.compact_cps). - Definition compact_cps := parameterize_from_sig compact_cps_sig. - Definition compact_cps_eq := parameterize_eq compact_cps compact_cps_sig. - Hint Unfold compact_cps : basesystem_partial_evaluation_unfolder. - Hint Rewrite <- compact_cps_eq : pattern_runtime. - - Definition compact_sig := parameterize_sig (@Core.Columns.compact). - Definition compact := parameterize_from_sig compact_sig. - Definition compact_eq := parameterize_eq compact compact_sig. - Hint Unfold compact : basesystem_partial_evaluation_unfolder. - Hint Rewrite <- compact_eq : pattern_runtime. - - Definition cons_to_nth_cps_sig := parameterize_sig (@Core.Columns.cons_to_nth_cps). - Definition cons_to_nth_cps := parameterize_from_sig cons_to_nth_cps_sig. - Definition cons_to_nth_cps_eq := parameterize_eq cons_to_nth_cps cons_to_nth_cps_sig. - Hint Unfold cons_to_nth_cps : basesystem_partial_evaluation_unfolder. - Hint Rewrite <- cons_to_nth_cps_eq : pattern_runtime. - - Definition cons_to_nth_sig := parameterize_sig (@Core.Columns.cons_to_nth). - Definition cons_to_nth := parameterize_from_sig cons_to_nth_sig. - Definition cons_to_nth_eq := parameterize_eq cons_to_nth cons_to_nth_sig. - Hint Unfold cons_to_nth : basesystem_partial_evaluation_unfolder. - Hint Rewrite <- cons_to_nth_eq : pattern_runtime. - - Definition nils_sig := parameterize_sig (@Core.Columns.nils). - Definition nils := parameterize_from_sig nils_sig. - Definition nils_eq := parameterize_eq nils nils_sig. - Hint Unfold nils : basesystem_partial_evaluation_unfolder. - Hint Rewrite <- nils_eq : pattern_runtime. - - Definition from_associational_cps_sig := parameterize_sig (@Core.Columns.from_associational_cps). - Definition from_associational_cps := parameterize_from_sig from_associational_cps_sig. - Definition from_associational_cps_eq := parameterize_eq from_associational_cps from_associational_cps_sig. - Hint Unfold from_associational_cps : basesystem_partial_evaluation_unfolder. - Hint Rewrite <- from_associational_cps_eq : pattern_runtime. - - Definition from_associational_sig := parameterize_sig (@Core.Columns.from_associational). - Definition from_associational := parameterize_from_sig from_associational_sig. - Definition from_associational_eq := parameterize_eq from_associational from_associational_sig. - Hint Unfold from_associational : basesystem_partial_evaluation_unfolder. - Hint Rewrite <- from_associational_eq : pattern_runtime. - -End Columns. diff --git a/src/Arithmetic/Saturated/Freeze.v b/src/Arithmetic/Saturated/Freeze.v deleted file mode 100644 index d8e7f4b5e..000000000 --- a/src/Arithmetic/Saturated/Freeze.v +++ /dev/null @@ -1,145 +0,0 @@ -Require Import Coq.ZArith.ZArith. -Require Import Coq.Lists.List. -Local Open Scope Z_scope. - -Require Import Crypto.Arithmetic.Core. -Require Import Crypto.Arithmetic.Saturated.Core. -Require Import Crypto.Arithmetic.Saturated.Wrappers. -Require Import Crypto.Util.ZUtil.AddGetCarry. -Require Import Crypto.Util.ZUtil.Definitions. -Require Import Crypto.Util.ZUtil.Modulo.PullPush. -Require Import Crypto.Util.ZUtil.Le. -Require Import Crypto.Util.ZUtil.CPS. -Require Import Crypto.Util.Tactics.BreakMatch. -Require Import Crypto.Util.Decidable. -Require Import Crypto.Util.Tuple Crypto.Util.LetIn. -Local Notation "A ^ n" := (tuple A n) : type_scope. - -(* Canonicalize bignums by fully reducing them modulo p. - This works on unsaturated digits, but uses saturated add/subtract - loops.*) -Section Freeze. - Context (weight : nat->Z) - {weight_0 : weight 0%nat = 1} - {weight_nonzero : forall i, weight i <> 0} - {weight_positive : forall i, weight i > 0} - {weight_multiples : forall i, weight (S i) mod weight i = 0} - {weight_divides : forall i : nat, weight (S i) / weight i > 0} - . - - - (* - The input to [freeze] should be less than 2*m (this can probably - be accomplished by a single carry_reduce step, for most moduli). - - [freeze] has the following steps: - (1) subtract modulus in a carrying loop (in our framework, this - consists of two steps; [Columns.unbalanced_sub_cps] combines the - input p and the modulus m such that the ith limb in the output is - the list [p[i];-m[i]]. We can then call [Columns.compact].) - (2) look at the final carry, which should be either 0 or -1. If - it's -1, then we add the modulus back in. Otherwise we add 0 for - constant-timeness. - (3) discard the carry after this last addition; it should be 1 if - the carry in step 3 was -1, so they cancel out. - *) - Definition freeze_cps {n} mask (m:Z^n) (p:Z^n) {T} (f : Z^n->T) := - Columns.unbalanced_sub_cps (n3:=n) weight p m - (fun carry_p => Columns.conditional_add_cps (n3:=n) weight mask (fst carry_p) (snd carry_p) m - (fun carry_r => f (snd carry_r))) - . - - Definition freeze {n} mask m p := - @freeze_cps n mask m p _ id. - Lemma freeze_id {n} mask m p T f: - @freeze_cps n mask m p T f = f (freeze mask m p). - Proof. - cbv [freeze_cps freeze]; repeat progress autounfold; - autorewrite with uncps push_id; reflexivity. - Qed. - Hint Opaque freeze : uncps. - Hint Rewrite @freeze_id : uncps. - - Lemma freezeZ m s c y y0 z z0 c0 a : - m = s - c -> - 0 < c < s -> - s <> 0 -> - 0 <= y < 2*m -> - y0 = y - m -> - z = y0 mod s -> - c0 = y0 / s -> - z0 = z + (if (dec (c0 = 0)) then 0 else m) -> - a = z0 mod s -> - a mod m = y0 mod m. - Proof. - clear. intros. subst. break_match. - { rewrite Z.add_0_r, Z.mod_mod by omega. - assert (-(s-c) <= y - (s-c) < s-c) by omega. - match goal with H : s <> 0 |- _ => - rewrite (proj2 (Z.mod_small_iff _ s H)) - by (apply Z.div_small_iff; assumption) - end. - reflexivity. } - { rewrite <-Z.add_mod_l, Z.sub_mod_full. - rewrite Z.mod_same, Z.sub_0_r, Z.mod_mod by omega. - rewrite Z.mod_small with (b := s) - by (pose proof (Z.div_small (y - (s-c)) s); omega). - f_equal. ring. } - Qed. - - Lemma eval_freeze {n} c mask m p - (n_nonzero:n<>0%nat) - (Hc : 0 < B.Associational.eval c < weight n) - (Hmask : Tuple.map (Z.land mask) m = m) - modulus (Hm : B.Positional.eval weight m = Z.pos modulus) - (Hp : 0 <= B.Positional.eval weight p < 2*(Z.pos modulus)) - (Hsc : Z.pos modulus = weight n - B.Associational.eval c) - : - mod_eq modulus - (B.Positional.eval weight (@freeze n mask m p)) - (B.Positional.eval weight p). - Proof. - cbv [freeze_cps freeze]. - repeat progress autounfold. - pose proof Z.add_get_carry_full_mod. - pose proof Z.add_get_carry_full_div. - pose proof div_correct. pose proof modulo_correct. - pose proof @div_id. pose proof @modulo_id. - pose proof @Z.add_get_carry_full_cps_correct. - autorewrite with uncps push_id push_basesystem_eval. - - pose proof (weight_nonzero n). - - remember (B.Positional.eval weight p) as y. - remember (y + -B.Positional.eval weight m) as y0. - rewrite Hm in *. - - transitivity y0; cbv [mod_eq]. - { eapply (freezeZ (Z.pos modulus) (weight n) (B.Associational.eval c) y y0); - try assumption; reflexivity. } - { subst y0. - assert (Z.pos modulus <> 0) by auto using Z.positive_is_nonzero, Zgt_pos_0. - rewrite Z.add_mod by assumption. - rewrite Z.mod_opp_l_z by auto using Z.mod_same. - rewrite Z.add_0_r, Z.mod_mod by assumption. - reflexivity. } - Qed. -End Freeze. -Hint Opaque freeze_cps : uncps. -Hint Rewrite @freeze_id : uncps. -Hint Rewrite @eval_freeze - using (assumption || reflexivity || auto || eassumption || omega) : push_basesystem_eval. - -Hint Unfold - freeze freeze_cps - : basesystem_partial_evaluation_unfolder. - -Ltac basesystem_partial_evaluation_unfolder t := - let t := (eval cbv delta [freeze freeze_cps] in t) in - let t := Saturated.Wrappers.basesystem_partial_evaluation_unfolder t in - let t := Saturated.Core.basesystem_partial_evaluation_unfolder t in - let t := Arithmetic.Core.basesystem_partial_evaluation_unfolder t in - t. - -Ltac Arithmetic.Core.basesystem_partial_evaluation_default_unfolder t ::= - basesystem_partial_evaluation_unfolder t. diff --git a/src/Arithmetic/Saturated/FreezeUnfolder.v b/src/Arithmetic/Saturated/FreezeUnfolder.v deleted file mode 100644 index bae1a87f3..000000000 --- a/src/Arithmetic/Saturated/FreezeUnfolder.v +++ /dev/null @@ -1,27 +0,0 @@ -Require Import Crypto.Arithmetic.CoreUnfolder. -Require Import Crypto.Arithmetic.Saturated.CoreUnfolder. -Require Import Crypto.Arithmetic.Saturated.WrappersUnfolder. -Require Import Crypto.Arithmetic.Saturated.Freeze. - -(** -<< -#!/bin/bash -for i in freeze freeze_cps; do - echo "Definition ${i}_sig := parameterize_sig (@Freeze.${i})."; - echo "Definition ${i} := parameterize_from_sig ${i}_sig."; - echo "Definition ${i}_eq := parameterize_eq ${i} ${i}_sig."; - echo "Hint Unfold ${i} : basesystem_partial_evaluation_unfolder."; - echo "Hint Rewrite <- ${i}_eq : pattern_runtime."; echo ""; -done ->> *) -Definition freeze_cps_sig := parameterize_sig (@Freeze.freeze_cps). -Definition freeze_cps := parameterize_from_sig freeze_cps_sig. -Definition freeze_cps_eq := parameterize_eq freeze_cps freeze_cps_sig. -Hint Unfold freeze_cps : basesystem_partial_evaluation_unfolder. -Hint Rewrite <- freeze_cps_eq : pattern_runtime. - -Definition freeze_sig := parameterize_sig (@Freeze.freeze). -Definition freeze := parameterize_from_sig freeze_sig. -Definition freeze_eq := parameterize_eq freeze freeze_sig. -Hint Unfold freeze : basesystem_partial_evaluation_unfolder. -Hint Rewrite <- freeze_eq : pattern_runtime. diff --git a/src/Arithmetic/Saturated/MontgomeryAPI.v b/src/Arithmetic/Saturated/MontgomeryAPI.v deleted file mode 100644 index d08fe7a8b..000000000 --- a/src/Arithmetic/Saturated/MontgomeryAPI.v +++ /dev/null @@ -1,691 +0,0 @@ -Require Import Coq.ZArith.ZArith. -Require Import Coq.micromega.Lia. -Require Import Coq.Lists.List. -Local Open Scope Z_scope. - -Require Import Crypto.Arithmetic.Core. -Require Import Crypto.Arithmetic.Saturated.Core. -Require Import Crypto.Arithmetic.Saturated.UniformWeight. -Require Import Crypto.Arithmetic.Saturated.Wrappers. -Require Import Crypto.Arithmetic.Saturated.AddSub. -Require Import Crypto.Util.LetIn Crypto.Util.CPSUtil. -Require Import Crypto.Util.Tuple Crypto.Util.LetIn. -Require Import Crypto.Util.Decidable. -Require Import Crypto.Util.ListUtil. -Require Import Crypto.Util.ZUtil.Tactics.ZeroBounds. -Require Import Crypto.Util.ZUtil.Tactics.DivModToQuotRem. -Require Import Crypto.Util.ZUtil.Modulo. -Require Import Crypto.Util.ZUtil.Definitions. -Require Import Crypto.Util.ZUtil.CPS. -Require Import Crypto.Util.ZUtil.Zselect. -Require Import Crypto.Util.ZUtil.AddGetCarry. -Require Import Crypto.Util.ZUtil.MulSplit. -Require Import Crypto.Util.ZUtil.Div. -Require Import Crypto.Util.ZUtil.Tactics.LtbToLt. -Require Import Crypto.Util.ZUtil.Opp. -Require Import Crypto.Util.Tactics.UniquePose. -Local Notation "A ^ n" := (tuple A n) : type_scope. - -Section API. - Context (bound : Z) {bound_pos : bound > 0}. - Definition T : nat -> Type := tuple Z. - - (* lowest limb is less than its bound; this is required for [divmod] - to simply separate the lowest limb from the rest and be equivalent - to normal div/mod with [bound]. *) - Local Notation small := (@small bound). - - Definition zero {n:nat} : T n := B.Positional.zeros n. - - (** Returns 0 iff all limbs are 0 *) - Definition nonzero_cps {n} (p : T n) {cpsT} (f : Z -> cpsT) : cpsT - := CPSUtil.to_list_cps _ p (fun p => CPSUtil.fold_right_cps runtime_lor 0%Z p f). - Definition nonzero {n} (p : T n) : Z - := nonzero_cps p id. - - Definition join0_cps {n:nat} (p : T n) {R} (f:T (S n) -> R) - := Tuple.left_append_cps 0 p f. - Definition join0 {n} p : T (S n) := @join0_cps n p _ id. - - Definition divmod_cps {n} (p : T (S n)) {R} (f:T n * Z->R) : R - := Tuple.tl_cps p (fun d => Tuple.hd_cps p (fun m => f (d, m))). - Definition divmod {n} p : T n * Z := @divmod_cps n p _ id. - - Definition drop_high_cps {n : nat} (p : T (S n)) {R} (f:T n->R) - := Tuple.left_tl_cps p f. - Definition drop_high {n} p : T n := @drop_high_cps n p _ id. - - Definition scmul_cps {n} (c : Z) (p : T n) {R} (f:T (S n)->R) := - Columns.mul_cps (n1:=1) (n3:=S n) (uweight bound) bound c p - (* The carry that comes out of Columns.mul_cps will be 0, since - (S n) limbs is enough to hold the result of the - multiplication, so we can safely discard it. *) - (fun carry_result =>f (snd carry_result)). - Definition scmul {n} c p : T (S n) := @scmul_cps n c p _ id. - - Definition add_cps {n} (p q: T n) {R} (f:T (S n)->R) := - B.Positional.sat_add_cps (s:=bound) p q _ - (* join the last carry *) - (fun carry_result => Tuple.left_append_cps (fst carry_result) (snd carry_result) f). - Definition add {n} p q : T (S n) := @add_cps n p q _ id. - - (* Wrappers for additions with slightly uneven limb counts *) - Definition add_S1_cps {n} (p: T (S n)) (q: T n) {R} (f:T (S (S n))->R) := - join0_cps q (fun Q => add_cps p Q f). - Definition add_S1 {n} p q := @add_S1_cps n p q _ id. - Definition add_S2_cps {n} (p: T n) (q: T (S n)) {R} (f:T (S (S n))->R) := - join0_cps p (fun P => add_cps P q f). - Definition add_S2 {n} p q := @add_S2_cps n p q _ id. - - Definition sub_then_maybe_add_cps {n} mask (p q r : T n) - {R} (f:T n -> R) := - B.Positional.sat_sub_cps (s:=bound) p q _ - (* the carry will be 0 unless we underflow--we do the addition only - in the underflow case *) - (fun carry_result => - B.Positional.select_cps mask (fst carry_result) r - (fun selected => join0_cps selected - (fun selected' => - B.Positional.sat_add_cps (s:=bound) (left_append (- (fst carry_result))%RT (snd carry_result)) selected' _ - (* We can now safely discard the carry and the highest digit. - This relies on the precondition that p - q + r < bound^n. *) - (fun carry_result' => drop_high_cps (snd carry_result') f)))). - Definition sub_then_maybe_add {n} mask (p q r : T n) := - sub_then_maybe_add_cps mask p q r id. - - (* Subtract q if and only if p >= q. We rely on the preconditions - that 0 <= p < 2*q and q < bound^n (this ensures the output is less - than bound^n). *) - Definition conditional_sub_cps {n} (p:Z^S n) (q:Z^n) R (f:Z^n->R) := - join0_cps q - (fun qq => B.Positional.sat_sub_cps (s:=bound) p qq _ - (* if carry is zero, we select the result of the subtraction, - otherwise the first input *) - (fun carry_result => - Tuple.map2_cps (Z.zselect (fst carry_result)) (snd carry_result) p - (* in either case, since our result must be < q and therefore < - bound^n, we can drop the high digit *) - (fun r => drop_high_cps r f))). - Definition conditional_sub {n} p q := @conditional_sub_cps n p q _ id. - - Hint Opaque join0 divmod drop_high scmul add sub_then_maybe_add conditional_sub : uncps. - - Section CPSProofs. - - Local Ltac prove_id := - repeat autounfold; - repeat (intros; autorewrite with uncps push_id); - reflexivity. - - Lemma nonzero_id n p {cpsT} f : @nonzero_cps n p cpsT f = f (@nonzero n p). - Proof. cbv [nonzero nonzero_cps]. prove_id. Qed. - - Lemma join0_id n p R f : - @join0_cps n p R f = f (join0 p). - Proof. cbv [join0_cps join0]. prove_id. Qed. - - Lemma divmod_id n p R f : - @divmod_cps n p R f = f (divmod p). - Proof. cbv [divmod_cps divmod]; prove_id. Qed. - - Lemma drop_high_id n p R f : - @drop_high_cps n p R f = f (drop_high p). - Proof. cbv [drop_high_cps drop_high]; prove_id. Qed. - Hint Rewrite drop_high_id : uncps. - - Lemma scmul_id n c p R f : - @scmul_cps n c p R f = f (scmul c p). - Proof. cbv [scmul_cps scmul]. prove_id. Qed. - - Lemma add_id n p q R f : - @add_cps n p q R f = f (add p q). - Proof. cbv [add_cps add Let_In]. prove_id. Qed. - Hint Rewrite add_id : uncps. - - Lemma add_S1_id n p q R f : - @add_S1_cps n p q R f = f (add_S1 p q). - Proof. cbv [add_S1_cps add_S1 join0_cps]. prove_id. Qed. - - Lemma add_S2_id n p q R f : - @add_S2_cps n p q R f = f (add_S2 p q). - Proof. cbv [add_S2_cps add_S2 join0_cps]. prove_id. Qed. - - Lemma sub_then_maybe_add_id n mask p q r R f : - @sub_then_maybe_add_cps n mask p q r R f = f (sub_then_maybe_add mask p q r). - Proof. cbv [sub_then_maybe_add_cps sub_then_maybe_add join0_cps Let_In]. prove_id. Qed. - - Lemma conditional_sub_id n p q R f : - @conditional_sub_cps n p q R f = f (conditional_sub p q). - Proof. cbv [conditional_sub_cps conditional_sub join0_cps Let_In]. prove_id. Qed. - - End CPSProofs. - Hint Rewrite nonzero_id join0_id divmod_id drop_high_id scmul_id add_id sub_then_maybe_add_id conditional_sub_id : uncps. - - Section Proofs. - - Definition eval {n} (p : T n) : Z := - B.Positional.eval (uweight bound) p. - - Definition encode n (z : Z) : T n := - B.Positional.encode (uweight bound) (modulo_cps:=@modulo_cps) (div_cps:=@div_cps) z. - - Lemma eval_small n (p : T n) (Hsmall : small p) : - 0 <= eval p < uweight bound n. - Proof. - cbv [small eval] in *; intros. - induction n; cbv [T uweight] in *; [destruct p|rewrite (subst_left_append p)]; - repeat match goal with - | _ => progress autorewrite with push_basesystem_eval - | _ => rewrite Z.pow_0_r - | _ => specialize (IHn (left_tl p)) - | _ => - let H := fresh "H" in - match type of IHn with - ?P -> _ => assert P as H by auto using Tuple.In_to_list_left_tl; - specialize (IHn H) - end - | |- context [?b ^ Z.of_nat (S ?n)] => - replace (b ^ Z.of_nat (S n)) with (b ^ Z.of_nat n * b) by - (rewrite Nat2Z.inj_succ, <-Z.add_1_r, Z.pow_add_r, - Z.pow_1_r by (omega || auto using Nat2Z.is_nonneg); - reflexivity) - | _ => omega - end. - - specialize (Hsmall _ (Tuple.In_left_hd _ p)). - split; [Z.zero_bounds; omega |]. - apply Z.lt_le_trans with (m:=bound^Z.of_nat n * (left_hd p+1)). - { rewrite Z.mul_add_distr_l. - apply Z.add_le_lt_mono; omega. } - { apply Z.mul_le_mono_nonneg; omega. } - Qed. - - Lemma small_encode n (v : Z) (Hsmall : 0 <= v < uweight bound n) - : small (encode n v). - Proof. - Admitted. (* TODO(jadep): prove me *) - - Lemma eval_encode n (v : Z) (Hsmall : 0 <= v < uweight bound n) - : eval (encode n v) = v. - Proof. - destruct n as [|n]. - { cbv -[Z.le Z.lt Z.gt] in *; omega. } - { cbv [eval encode]. - pose proof (@uweight_divides _ bound_pos) as Hdiv. - apply B.Positional.eval_encode; try reflexivity; - eauto using modulo_id, div_id, div_mod, uweight_nonzero. - { intros i ?; specialize (Hdiv i); omega. } } - Qed. - - Lemma eval_zero n : eval (@zero n) = 0. - Proof. - cbv [eval zero]. - autorewrite with push_basesystem_eval. - reflexivity. - Qed. - - Lemma small_zero n : small (@zero n). - Proof. - cbv [zero small B.Positional.zeros]. destruct n; [simpl;tauto|]. - rewrite to_list_repeat. - intros x H; apply repeat_spec in H; subst x; omega. - Qed. - - Lemma small_hd n p : @small (S n) p -> 0 <= hd p < bound. - Proof. - cbv [small]. let H := fresh "H" in intro H; apply H. - rewrite (subst_append p). rewrite to_list_append, hd_append. - apply in_eq. - Qed. - - Lemma In_to_list_tl {A n} (p : A^(S n)) x : - In x (to_list n (tl p)) -> In x (to_list (S n) p). - Proof. - intros. rewrite (subst_append p). - rewrite to_list_append. simpl In. tauto. - Qed. - - Lemma small_tl n p : @small (S n) p -> small (tl p). - Proof. - cbv [small]. let H := fresh "H" in intros H ? ?; apply H. - auto using In_to_list_tl. - Qed. - - Lemma add_nonneg_zero_iff a b c : 0 <= a -> 0 <= b -> 0 < c -> - a = 0 /\ b = 0 <-> a + c * b = 0. - Proof. nia. Qed. - - Lemma eval_pair n (p : T (S (S n))) : small p -> - (snd p = 0 /\ eval (n:=S n) (fst p) = 0) <-> eval p = 0. - Proof. - intro Hsmall. cbv [eval]. - rewrite uweight_eval_step with (p:=p). - change (fst p) with (tl p). change (snd p) with (hd p). - apply add_nonneg_zero_iff; try omega. - { apply small_hd in Hsmall. omega. } - { apply small_tl, eval_small in Hsmall. - cbv [eval] in Hsmall; omega. } - Qed. - - Lemma eval_nonzero n p : small p -> @nonzero n p = 0 <-> eval p = 0. - Proof. - destruct n as [|n]. - { compute; split; trivial. } - induction n as [|n IHn]. - { simpl; rewrite Z.lor_0_r; unfold eval, id. - cbv -[Z.add iff]. - rewrite Z.add_0_r. - destruct p; omega. } - { destruct p as [ps p]; specialize (IHn ps). - unfold nonzero, nonzero_cps in *. - autorewrite with uncps in *. - unfold id in *. - setoid_rewrite to_list_S. - set (k := S n) in *; simpl in *. - intro Hsmall. - rewrite Z.lor_eq_0_iff, IHn - by (hnf in Hsmall |- *; simpl in *; eauto); - clear IHn. - exact (eval_pair n (ps, p) Hsmall). } - Qed. - - Lemma eval_join0 n p - : eval (@join0 n p) = eval p. - Proof. - cbv [join0 join0_cps eval]. autorewrite with uncps push_id. - rewrite B.Positional.eval_left_append. ring. - Qed. - - Local Ltac pose_uweight bound := - match goal with H : bound > 0 |- _ => - pose proof (uweight_0 bound); - pose proof (@uweight_positive bound H); - pose proof (@uweight_nonzero bound H); - pose proof (@uweight_multiples bound); - pose proof (@uweight_divides bound H) - end. - - Local Ltac pose_all := - pose_uweight bound; - pose proof Z.add_get_carry_full_div; - pose proof Z.add_get_carry_full_mod; - pose proof Z.mul_split_div; pose proof Z.mul_split_mod; - pose proof div_correct; pose proof modulo_correct; - pose proof @div_id; pose proof @modulo_id; - pose proof @Z.add_get_carry_full_cps_correct; - pose proof @Z.mul_split_cps_correct; - pose proof @Z.mul_split_cps'_correct. - - Lemma eval_add n p q : - eval (@add n p q) = eval p + eval q. - Proof. - intros. pose_all. cbv [add_cps add eval Let_In]. - autorewrite with uncps push_id cancel_pair push_basesystem_eval. - symmetry; auto using Z.div_mod. - Qed. - - Lemma eval_add_same n p q - : eval (@add n p q) = eval p + eval q. - Proof. apply eval_add; omega. Qed. - Lemma eval_add_S1 n p q - : eval (@add_S1 n p q) = eval p + eval q. - Proof. - cbv [add_S1 add_S1_cps]. autorewrite with uncps push_id. - rewrite eval_add; rewrite eval_join0; reflexivity. - Qed. - Lemma eval_add_S2 n p q - : eval (@add_S2 n p q) = eval p + eval q. - Proof. - cbv [add_S2 add_S2_cps]. autorewrite with uncps push_id. - rewrite eval_add; rewrite eval_join0; reflexivity. - Qed. - Hint Rewrite eval_add_same eval_add_S1 eval_add_S2 using (omega || assumption): push_basesystem_eval. - - Local Definition compact {n} := Columns.compact (n:=n) (add_get_carry_cps:=@Z.add_get_carry_full_cps) (div_cps:=@div_cps) (modulo_cps:=@modulo_cps) (uweight bound). - Local Definition compact_digit := Columns.compact_digit (add_get_carry_cps:=@Z.add_get_carry_full_cps) (div_cps:=@div_cps) (modulo_cps:=@modulo_cps) (uweight bound). - Lemma small_compact {n} (p:(list Z)^n) : small (snd (compact p)). - Proof. - pose_all. - match goal with - |- ?G => assert (G /\ fst (compact p) = fst (compact p)); [|tauto] - end. (* assert a dummy second statement so that fst (compact x) is in context *) - cbv [compact Columns.compact Columns.compact_cps small - Columns.compact_step Columns.compact_step_cps]; - autorewrite with uncps push_id. - change (fun i s a => Columns.compact_digit_cps (uweight bound) i (s :: a) id) - with (fun i s a => compact_digit i (s :: a)). - remember (fun i s a => compact_digit i (s :: a)) as f. - - apply @mapi_with'_linvariant with (n:=n) (f:=f) (inp:=p); - intros; [|simpl; tauto]. split; [|reflexivity]. - let P := fresh "H" in - match goal with H : _ /\ _ |- _ => destruct H end. - destruct n0; subst f. - { cbv [compact_digit uweight to_list to_list' In]. - rewrite Columns.compact_digit_mod - by (assumption || (intros; autorewrite with uncps push_id; auto)). - rewrite Z.pow_0_r, Z.pow_1_r, Z.div_1_r. intros x ?. - match goal with - H : _ \/ False |- _ => destruct H; [|exfalso; assumption] end. - subst x. apply Z.mod_pos_bound, Z.gt_lt, bound_pos. } - { rewrite Tuple.to_list_left_append. - let H := fresh "H" in - intros x H; apply in_app_or in H; destruct H; - [solve[auto]| cbv [In] in H; destruct H; - [|exfalso; assumption] ]. - subst x. cbv [compact_digit]. - rewrite Columns.compact_digit_mod - by (assumption || (intros; autorewrite with uncps push_id; auto)). - rewrite !uweight_succ, Z.div_mul by - (apply Z.neq_mul_0; split; auto; omega). - apply Z.mod_pos_bound, Z.gt_lt, bound_pos. } - Qed. - - Lemma small_left_append {n} b x : - 0 <= x < bound -> small b -> small (@left_append _ n x b). - Proof. - intros. - cbv [small]. - setoid_rewrite Tuple.to_list_left_append. - setoid_rewrite in_app_iff. - intros y HIn; destruct HIn as [HIn|[]]; (contradiction||omega||eauto). - Qed. - - Lemma small_add n a b : - (2 <= bound) -> - small a -> small b -> small (@add n a b). - Proof. - intros. - cbv [add add_cps]; autorewrite with uncps push_id in *. - pose proof @B.Positional.small_sat_add bound ltac:(omega) _ a b. - eapply small_left_append; eauto. - rewrite @B.Positional.sat_add_div by omega. - repeat match goal with H:_|-_=> unique pose proof (eval_small _ _ H) end. - cbv [eval] in *; Z.div_mod_to_quot_rem_in_goal; nia. - Qed. - - Lemma small_join0 {n} b : small b -> small (@join0 n b). - Proof. - cbv [join0 join0_cps]; autorewrite with uncps push_id in *. - eapply small_left_append; omega. - Qed. - - Lemma small_add_S1 n a b : - (2 <= bound) -> - small a -> small b -> small (@add_S1 n a b). - Proof. - intros. - cbv [add_S1 add_S1_cps Let_In]; autorewrite with uncps push_id in *. - eauto using small_add, small_join0. - Qed. - - Lemma small_left_tl n (v:T (S n)) : small v -> small (left_tl v). - Proof. cbv [small]. auto using Tuple.In_to_list_left_tl. Qed. - - Lemma eval_drop_high n v : - small v -> eval (@drop_high n v) = eval v mod (uweight bound n). - Proof. - cbv [drop_high drop_high_cps eval]. - rewrite Tuple.left_tl_cps_correct, push_id. (* TODO : for some reason autorewrite with uncps doesn't work here *) - intro H. apply small_left_tl in H. - rewrite (subst_left_append v) at 2. - autorewrite with push_basesystem_eval. - apply eval_small in H. - rewrite Z.mod_add_l' by (pose_uweight bound; auto). - rewrite Z.mod_small; auto. - Qed. - - Lemma small_drop_high n v : small v -> small (@drop_high n v). - Proof. - cbv [drop_high drop_high_cps]. - rewrite Tuple.left_tl_cps_correct, push_id. - apply small_left_tl. - Qed. - - Lemma div_nonzero_neg_iff x y : x < y -> 0 < y -> - - (x / y) = 0 <-> x progress intros - | _ => rewrite Z.ltb_ge - | _ => rewrite Z.opp_eq_0_iff - | _ => rewrite Z.div_small_iff by omega - | _ => split - | _ => omega - end. - Qed. - - Lemma eval_sub_then_maybe_add n mask p q r: - small p -> small q -> small r -> - (map (Z.land mask) r = r) -> - (0 <= eval p < eval r) -> (0 <= eval q < eval r) -> - eval (@sub_then_maybe_add n mask p q r) = eval p - eval q + (if eval p - eval q progress (intros; cbv [eval runtime_opp sub_then_maybe_add sub_then_maybe_add_cps] in * ) - | _ => progress autorewrite with uncps push_id push_basesystem_eval - | _ => rewrite eval_drop_high by (apply @B.Positional.small_sat_add; omega) - | _ => rewrite B.Positional.sat_sub_mod by omega - | _ => rewrite B.Positional.sat_sub_div by omega - | _ => rewrite B.Positional.sat_add_mod by omega - | _ => rewrite B.Positional.eval_left_append - | _ => rewrite eval_join0 - | H : small _ |- _ => apply eval_small in H - end. - let H := fresh "H" in - match goal with |- context [- (?X / ?Y) = 0] => - assert ((- (X / Y) = 0) <-> X _ |- _ - => specialize (H (eq_refl x)) end; - try congruence; - match goal with - | H : _ |- _ => rewrite Z.ltb_ge in H - | H : _ |- _ => rewrite Z.ltb_lt in H - end. - { repeat (rewrite Z.mod_small; try omega). } - { rewrite !Z.mul_opp_r, Z.opp_involutive. - rewrite Z.mul_div_eq_full by (subst; auto). - match goal with |- context [?a - ?b + ?b] => - replace (a - b + b) with a by ring end. - repeat (rewrite Z.mod_small; try omega). } - Qed. - - Lemma small_sub_then_maybe_add n mask (p q r : T n) : - small (sub_then_maybe_add mask p q r). - Proof. - cbv [sub_then_maybe_add_cps sub_then_maybe_add]; intros. - repeat progress autounfold. autorewrite with uncps push_id. - apply small_drop_high, @B.Positional.small_sat_add; omega. - Qed. - - Lemma map2_zselect n cond x y : - Tuple.map2 (n:=n) (Z.zselect cond) x y = if dec (cond = 0) then x else y. - Proof. - unfold Z.zselect. - break_innermost_match; Z.ltb_to_lt; subst; try omega; - [ rewrite Tuple.map2_fst, Tuple.map_id - | rewrite Tuple.map2_snd, Tuple.map_id ]; - reflexivity. - Qed. - - Lemma eval_conditional_sub_nz n (p:T (S n)) (q:T n) - (n_nonzero: (n <> 0)%nat) (psmall : small p) (qsmall : small q): - 0 <= eval p < eval q + uweight bound n -> - eval (conditional_sub p q) = eval p + (if eval q <=? eval p then - eval q else 0). - Proof. - pose_all. - pose proof (@uweight_le_mono _ bound_pos n (S n) (Nat.le_succ_diag_r _)). - intros. - repeat match goal with - | _ => progress (intros; cbv [eval conditional_sub conditional_sub_cps] in * ) - | _ => progress autorewrite with uncps push_id push_basesystem_eval - | _ => rewrite eval_drop_high - by (break_match; try assumption; apply @B.Positional.small_sat_sub; omega) - | _ => rewrite map2_zselect - | _ => rewrite B.Positional.sat_sub_mod by omega - | _ => rewrite B.Positional.sat_sub_div by omega - | _ => rewrite B.Positional.sat_add_mod by omega - | _ => rewrite B.Positional.eval_left_append - | _ => rewrite eval_join0 - | H : small _ |- _ => apply eval_small in H - end. - let H := fresh "H" in - match goal with |- context [- (?X / ?Y) = 0] => - assert ((- (X / Y) = 0) <-> X _ |- _ - => specialize (H (eq_refl x)) end; - repeat match goal with - | H : _ |- _ => rewrite Z.leb_gt in H - | H : _ |- _ => rewrite Z.leb_le in H - | H : _ |- _ => rewrite Z.ltb_lt in H - | H : _ |- _ => rewrite Z.ltb_ge in H - end; try omega. - { rewrite @B.Positional.sat_sub_mod by omega. - rewrite eval_join0; cbv [eval]. - repeat (rewrite Z.mod_small; try omega). } - { repeat (rewrite Z.mod_small; try omega). } - Qed. - - Lemma eval_conditional_sub n (p:T (S n)) (q:T n) - (psmall : small p) (qsmall : small q) : - 0 <= eval p < eval q + uweight bound n -> - eval (conditional_sub p q) = eval p + (if eval q <=? eval p then - eval q else 0). - Proof. - destruct n; [|solve[auto using eval_conditional_sub_nz]]. - repeat match goal with - | _ => progress (intros; cbv [T tuple tuple'] in p, q) - | q : unit |- _ => destruct q - | _ => progress (cbv [conditional_sub conditional_sub_cps eval] in * ) - | _ => progress autounfold - | _ => progress (autorewrite with uncps push_id push_basesystem_eval in * ) - | _ => (rewrite uweight_0 in * ) - | _ => assert (p = 0) by omega; subst p; break_match; ring - end. - Qed. - - Lemma small_conditional_sub n (p:T (S n)) (q:T n) - (psmall : small p) (qsmall : small q) : - 0 <= eval p < eval q + uweight bound n -> - small (conditional_sub p q). - Proof. - intros. - cbv [conditional_sub conditional_sub_cps]; autorewrite with uncps push_id. - eapply small_drop_high. - rewrite map2_zselect; break_match; [|assumption]. - eauto using @B.Positional.small_sat_sub with omega. - Qed. - - Lemma eval_scmul n a v : small v -> 0 <= a < bound -> - eval (@scmul n a v) = a * eval v. - Proof. - intro Hsmall. pose_all. apply eval_small in Hsmall. - intros. cbv [scmul scmul_cps eval] in *. repeat autounfold. - autorewrite with uncps. - autorewrite with push_basesystem_eval. - autorewrite with uncps push_id push_basesystem_eval. - rewrite uweight_0, Z.mul_1_l. apply Z.mod_small. - split; [solve[Z.zero_bounds]|]. cbv [uweight] in *. - rewrite !Nat2Z.inj_succ, Z.pow_succ_r by auto using Nat2Z.is_nonneg. - apply Z.mul_lt_mono_nonneg; omega. - Qed. - - Lemma small_scmul n a v : small (@scmul n a v). - Proof. - cbv [scmul scmul_cps eval] in *. repeat autounfold. - autorewrite with uncps push_id push_basesystem_eval. - apply small_compact. - Qed. - - (* TODO : move to tuple *) - Lemma from_list_tl {A n} (ls : list A) H H': - from_list n (List.tl ls) H = tl (from_list (S n) ls H'). - Proof. - induction ls; distr_length. simpl List.tl. - rewrite from_list_cons, tl_append, <-!(from_list_default_eq a ls). - reflexivity. - Qed. - - Lemma eval_div n p : small p -> eval (fst (@divmod n p)) = eval p / bound. - Proof. - cbv [divmod divmod_cps eval]. intros. - autorewrite with uncps push_id cancel_pair. - rewrite (subst_append p) at 2. - rewrite uweight_eval_step. rewrite hd_append, tl_append. - rewrite Z.div_add' by omega. rewrite Z.div_small by auto using small_hd. - ring. - Qed. - - Lemma eval_mod n p : small p -> snd (@divmod n p) = eval p mod bound. - Proof. - cbv [divmod divmod_cps eval]. intros. - autorewrite with uncps push_id cancel_pair. - rewrite (subst_append p) at 2. - rewrite uweight_eval_step, Z.mod_add'_full, hd_append. - rewrite Z.mod_small by auto using small_hd. reflexivity. - Qed. - - Lemma small_div n v : small v -> small (fst (@divmod n v)). - Proof. - cbv [divmod divmod_cps]. intros. - autorewrite with uncps push_id cancel_pair. - auto using small_tl. - Qed. - End Proofs. -End API. -Hint Rewrite nonzero_id join0_id divmod_id drop_high_id scmul_id add_id add_S1_id add_S2_id sub_then_maybe_add_id conditional_sub_id : uncps. - -Hint Unfold - nonzero_cps - nonzero - scmul_cps - scmul - add_cps - add - add_S1_cps - add_S1 - add_S2_cps - add_S2 - sub_then_maybe_add_cps - sub_then_maybe_add - conditional_sub_cps - conditional_sub - eval - encode - : basesystem_partial_evaluation_unfolder. - -Ltac basesystem_partial_evaluation_unfolder t := - let t := (eval cbv delta [ - nonzero_cps - nonzero - scmul_cps - scmul - add_cps - add - add_S1_cps - add_S1 - add_S2_cps - add_S2 - sub_then_maybe_add_cps - sub_then_maybe_add - conditional_sub_cps - conditional_sub - eval - encode - ] in t) in - let t := Saturated.AddSub.basesystem_partial_evaluation_unfolder t in - let t := Saturated.Wrappers.basesystem_partial_evaluation_unfolder t in - let t := Saturated.Core.basesystem_partial_evaluation_unfolder t in - let t := Arithmetic.Core.basesystem_partial_evaluation_unfolder t in - t. - -Ltac Arithmetic.Core.basesystem_partial_evaluation_default_unfolder t ::= - basesystem_partial_evaluation_unfolder t. diff --git a/src/Arithmetic/Saturated/MulSplit.v b/src/Arithmetic/Saturated/MulSplit.v deleted file mode 100644 index cd86a8b48..000000000 --- a/src/Arithmetic/Saturated/MulSplit.v +++ /dev/null @@ -1,100 +0,0 @@ -Require Import Coq.ZArith.ZArith. -Require Import Coq.Lists.List. -Local Open Scope Z_scope. - -Require Import Crypto.Arithmetic.Core. -Require Import Crypto.Util.LetIn Crypto.Util.CPSUtil. - -(* Defines bignum multiplication using a two-output multiply operation. *) -Module B. - Module Associational. - Section Associational. - Context {mul_split_cps : forall {T}, Z -> Z -> Z -> (Z * Z -> T) -> T} (* first argument is where to split output; [mul_split s x y] gives ((x * y) mod s, (x * y) / s) *) - {mul_split_cps_id : forall {T} s x y f, - @mul_split_cps T s x y f = f (@mul_split_cps _ s x y id)} - {mul_split_mod : forall s x y, - fst (mul_split_cps s x y id) = (x * y) mod s} - {mul_split_div : forall s x y, - snd (mul_split_cps s x y id) = (x * y) / s} - . - - Local Lemma mul_split_cps_correct {T} s x y f - : @mul_split_cps T s x y f = f ((x * y) mod s, (x * y) / s). - Proof. - now rewrite mul_split_cps_id, <- mul_split_mod, <- mul_split_div, <- surjective_pairing. - Qed. - Hint Rewrite @mul_split_cps_correct : uncps. - - Definition sat_multerm_cps s (t t' : B.limb) {T} (f:list B.limb ->T) := - mul_split_cps _ s (snd t) (snd t') (fun xy => - dlet xy := xy in - f ((fst t * fst t', fst xy) :: (fst t * fst t' * s, snd xy) :: nil)). - - Definition sat_multerm s t t' := sat_multerm_cps s t t' id. - Lemma sat_multerm_id s t t' T f : - @sat_multerm_cps s t t' T f = f (sat_multerm s t t'). - Proof. - unfold sat_multerm, sat_multerm_cps; - etransitivity; rewrite mul_split_cps_id; reflexivity. - Qed. - Hint Opaque sat_multerm : uncps. - Hint Rewrite sat_multerm_id : uncps. - - Definition sat_mul_cps s (p q : list B.limb) {T} (f : list B.limb -> T) := - flat_map_cps (fun t => @flat_map_cps _ _ (sat_multerm_cps s t) q) p f. - - Definition sat_mul s p q := sat_mul_cps s p q id. - Lemma sat_mul_id s p q T f : @sat_mul_cps s p q T f = f (sat_mul s p q). - Proof. cbv [sat_mul sat_mul_cps]. autorewrite with uncps. reflexivity. Qed. - Hint Opaque sat_mul : uncps. - Hint Rewrite sat_mul_id : uncps. - - Lemma eval_map_sat_multerm s a q (s_nonzero:s<>0): - B.Associational.eval (flat_map (sat_multerm s a) q) = fst a * snd a * B.Associational.eval q. - Proof. - cbv [sat_multerm sat_multerm_cps Let_In]; induction q; - repeat match goal with - | _ => progress autorewrite with uncps push_id cancel_pair push_basesystem_eval in * - | _ => progress simpl flat_map - | _ => progress unfold id in * - | _ => progress rewrite ?IHq, ?mul_split_mod, ?mul_split_div - | _ => rewrite Z.mod_eq by assumption - | _ => rewrite B.Associational.eval_nil - | _ => progress change (Z * Z)%type with B.limb - | _ => ring_simplify; omega - end. - Qed. - Hint Rewrite eval_map_sat_multerm using (omega || assumption) - : push_basesystem_eval. - - Lemma eval_sat_mul s p q (s_nonzero:s<>0): - B.Associational.eval (sat_mul s p q) = B.Associational.eval p * B.Associational.eval q. - Proof. - cbv [sat_mul sat_mul_cps]; induction p; [reflexivity|]. - repeat match goal with - | _ => progress (autorewrite with uncps push_id push_basesystem_eval in * ) - | _ => progress simpl flat_map - | _ => rewrite IHp - | _ => progress change (fun x => sat_multerm_cps s a x id) with (sat_multerm s a) - | _ => ring_simplify; omega - end. - Qed. - Hint Rewrite eval_sat_mul : push_basesystem_eval. - End Associational. - End Associational. -End B. -Hint Opaque B.Associational.sat_mul B.Associational.sat_multerm : uncps. -Hint Rewrite @B.Associational.sat_mul_id @B.Associational.sat_multerm_id using (assumption || (intros; autorewrite with uncps; reflexivity)) : uncps. -Hint Rewrite @B.Associational.eval_sat_mul @B.Associational.eval_map_sat_multerm using (omega || assumption) : push_basesystem_eval. - -Hint Unfold - B.Associational.sat_multerm_cps B.Associational.sat_multerm B.Associational.sat_mul_cps B.Associational.sat_mul - : basesystem_partial_evaluation_unfolder. - -Ltac basesystem_partial_evaluation_unfolder t := - let t := (eval cbv delta [B.Associational.sat_multerm_cps B.Associational.sat_multerm B.Associational.sat_mul_cps B.Associational.sat_mul] in t) in - let t := Arithmetic.Core.basesystem_partial_evaluation_unfolder t in - t. - -Ltac Arithmetic.Core.basesystem_partial_evaluation_default_unfolder t ::= - basesystem_partial_evaluation_unfolder t. diff --git a/src/Arithmetic/Saturated/MulSplitUnfolder.v b/src/Arithmetic/Saturated/MulSplitUnfolder.v deleted file mode 100644 index e9747eb79..000000000 --- a/src/Arithmetic/Saturated/MulSplitUnfolder.v +++ /dev/null @@ -1,45 +0,0 @@ -Require Import Crypto.Arithmetic.CoreUnfolder. -Require Import Crypto.Arithmetic.Saturated.CoreUnfolder. -Require Import Crypto.Arithmetic.Saturated.MulSplit. - -Module B. - Module Associational. -(** -<< -#!/bin/bash -for i in sat_multerm_cps sat_multerm sat_mul_cps sat_mul; do - echo " Definition ${i}_sig := parameterize_sig (@MulSplit.B.Associational.${i})."; - echo " Definition ${i} := parameterize_from_sig ${i}_sig."; - echo " Definition ${i}_eq := parameterize_eq ${i} ${i}_sig."; - echo " Hint Unfold ${i} : basesystem_partial_evaluation_unfolder."; - echo " Hint Rewrite <- ${i}_eq : pattern_runtime."; echo ""; -done -echo " End Associational." -echo "End B." ->> *) - Definition sat_multerm_cps_sig := parameterize_sig (@MulSplit.B.Associational.sat_multerm_cps). - Definition sat_multerm_cps := parameterize_from_sig sat_multerm_cps_sig. - Definition sat_multerm_cps_eq := parameterize_eq sat_multerm_cps sat_multerm_cps_sig. - Hint Unfold sat_multerm_cps : basesystem_partial_evaluation_unfolder. - Hint Rewrite <- sat_multerm_cps_eq : pattern_runtime. - - Definition sat_multerm_sig := parameterize_sig (@MulSplit.B.Associational.sat_multerm). - Definition sat_multerm := parameterize_from_sig sat_multerm_sig. - Definition sat_multerm_eq := parameterize_eq sat_multerm sat_multerm_sig. - Hint Unfold sat_multerm : basesystem_partial_evaluation_unfolder. - Hint Rewrite <- sat_multerm_eq : pattern_runtime. - - Definition sat_mul_cps_sig := parameterize_sig (@MulSplit.B.Associational.sat_mul_cps). - Definition sat_mul_cps := parameterize_from_sig sat_mul_cps_sig. - Definition sat_mul_cps_eq := parameterize_eq sat_mul_cps sat_mul_cps_sig. - Hint Unfold sat_mul_cps : basesystem_partial_evaluation_unfolder. - Hint Rewrite <- sat_mul_cps_eq : pattern_runtime. - - Definition sat_mul_sig := parameterize_sig (@MulSplit.B.Associational.sat_mul). - Definition sat_mul := parameterize_from_sig sat_mul_sig. - Definition sat_mul_eq := parameterize_eq sat_mul sat_mul_sig. - Hint Unfold sat_mul : basesystem_partial_evaluation_unfolder. - Hint Rewrite <- sat_mul_eq : pattern_runtime. - - End Associational. -End B. diff --git a/src/Arithmetic/Saturated/UniformWeight.v b/src/Arithmetic/Saturated/UniformWeight.v deleted file mode 100644 index bf069f2d6..000000000 --- a/src/Arithmetic/Saturated/UniformWeight.v +++ /dev/null @@ -1,93 +0,0 @@ -Require Import Coq.ZArith.ZArith. -Require Import Coq.Lists.List. -Local Open Scope Z_scope. - -Require Import Crypto.Arithmetic.Core. -Require Import Crypto.Arithmetic.Saturated.Core. -Require Import Crypto.Util.ZUtil.Le. -Require Import Crypto.Util.ZUtil.Modulo. -Require Import Crypto.Util.ZUtil.Tactics.PeelLe. -Require Import Crypto.Util.LetIn Crypto.Util.Tuple. -Local Notation "A ^ n" := (tuple A n) : type_scope. - -Section UniformWeight. - Context (bound : Z) {bound_pos : bound > 0}. - - Definition uweight : nat -> Z := fun i => bound ^ Z.of_nat i. - Lemma uweight_0 : uweight 0%nat = 1. Proof. reflexivity. Qed. - Lemma uweight_positive i : uweight i > 0. - Proof. apply Z.lt_gt, Z.pow_pos_nonneg; omega. Qed. - Lemma uweight_nonzero i : uweight i <> 0. - Proof. auto using Z.positive_is_nonzero, uweight_positive. Qed. - Lemma uweight_multiples i : uweight (S i) mod uweight i = 0. - Proof. apply Z.mod_same_pow; rewrite Nat2Z.inj_succ; omega. Qed. - Lemma uweight_divides i : uweight (S i) / uweight i > 0. - Proof. - cbv [uweight]. rewrite <-Z.pow_sub_r by (rewrite ?Nat2Z.inj_succ; omega). - apply Z.lt_gt, Z.pow_pos_nonneg; rewrite ?Nat2Z.inj_succ; omega. - Qed. - - (* TODO : move to Positional *) - Lemma eval_from_eq {n} (p:Z^n) wt offset : - (forall i, wt i = uweight (i + offset)) -> - B.Positional.eval wt p = B.Positional.eval_from uweight offset p. - Proof. cbv [B.Positional.eval_from]. auto using B.Positional.eval_wt_equiv. Qed. - - Lemma uweight_eval_from {n} (p:Z^n): forall offset, - B.Positional.eval_from uweight offset p = uweight offset * B.Positional.eval uweight p. - Proof. - induction n; intros; cbv [B.Positional.eval_from]; - [|rewrite (subst_append p)]; - repeat match goal with - | _ => destruct p - | _ => rewrite B.Positional.eval_unit; [ ] - | _ => rewrite B.Positional.eval_step; [ ] - | _ => rewrite IHn; [ ] - | _ => rewrite eval_from_eq with (offset0:=S offset) - by (intros; f_equal; omega) - | _ => rewrite eval_from_eq with - (wt:=fun i => uweight (S i)) (offset0:=1%nat) - by (intros; f_equal; omega) - | _ => ring - end. - repeat match goal with - | _ => cbv [uweight]; progress autorewrite with natsimplify - | _ => progress (rewrite ?Nat2Z.inj_succ, ?Nat2Z.inj_0, ?Z.pow_0_r) - | _ => rewrite !Z.pow_succ_r by (try apply Nat2Z.is_nonneg; omega) - | _ => ring - end. - Qed. - - Lemma uweight_eval_step {n} (p:Z^S n): - B.Positional.eval uweight p = hd p + bound * B.Positional.eval uweight (tl p). - Proof. - rewrite (subst_append p) at 1; rewrite B.Positional.eval_step. - rewrite eval_from_eq with (offset := 1%nat) by (intros; f_equal; omega). - rewrite uweight_eval_from. cbv [uweight]; rewrite Z.pow_0_r, Z.pow_1_r. - ring. - Qed. - - Lemma uweight_le_mono n m : (n <= m)%nat -> - uweight n <= uweight m. - Proof. - unfold uweight; intro; Z.peel_le; omega. - Qed. - - Lemma uweight_lt_mono (bound_gt_1 : bound > 1) n m : (n < m)%nat -> - uweight n < uweight m. - Proof. - clear bound_pos. - unfold uweight; intro; apply Z.pow_lt_mono_r; omega. - Qed. - - Lemma uweight_succ n : uweight (S n) = bound * uweight n. - Proof. - unfold uweight. - rewrite Nat2Z.inj_succ, Z.pow_succ_r by auto using Nat2Z.is_nonneg; reflexivity. - Qed. - - - Definition small {n} (p : Z^n) : Prop := - forall x, In x (to_list _ p) -> 0 <= x < bound. - -End UniformWeight. diff --git a/src/Arithmetic/Saturated/UniformWeightInstances.v b/src/Arithmetic/Saturated/UniformWeightInstances.v deleted file mode 100644 index 7ca7b1f3e..000000000 --- a/src/Arithmetic/Saturated/UniformWeightInstances.v +++ /dev/null @@ -1,34 +0,0 @@ -Require Import Coq.ZArith.BinInt. -Require Import Crypto.Arithmetic.Saturated.UniformWeight. -Require Import Crypto.Util.Tuple. -Require Import Crypto.Util.Decidable. -Require Import Crypto.Util.Tactics.DestructHead. - -Fixpoint small_Decidable' {bound n} : forall (p : Tuple.tuple' Z n), Decidable (small (n:=S n) bound p). -Proof. - refine match n as n return forall p : Tuple.tuple' Z n, id (Decidable (small (n:=S n) bound p)) with - | 0 - => fun p : Z - => if dec (0 <= p < bound)%Z then left _ else right _ - | S n' - => fun p : Tuple.tuple' Z n' * Z - => if dec (0 <= snd p < bound)%Z - then if dec (small (n:=S n') bound (fst p))%Z - then left _ - else right _ - else right _ - end; - unfold id, small in *; simpl in *; - [ clear small_Decidable' n - | clear small_Decidable' n - | clear small_Decidable'; simpl in *.. ]; - [ abstract (simpl in *; intros; destruct_head'_or; subst; auto; exfalso; assumption) - | abstract (simpl in *; intros; destruct_head'_or; subst; auto; exfalso; assumption) - | abstract (destruct p; simpl in *; intros; destruct_head'_or; subst; auto).. ]. -Defined. - -Global Instance small_Decidable {bound n} : forall (p : Tuple.tuple Z n), Decidable (small bound p). -Proof. - destruct n; simpl; [ left | apply small_Decidable' ]. - intros ??; exfalso; assumption. -Defined. diff --git a/src/Arithmetic/Saturated/Wrappers.v b/src/Arithmetic/Saturated/Wrappers.v deleted file mode 100644 index cbd4c42b5..000000000 --- a/src/Arithmetic/Saturated/Wrappers.v +++ /dev/null @@ -1,68 +0,0 @@ -Require Import Coq.ZArith.ZArith. -Require Import Coq.Lists.List. -Local Open Scope Z_scope. - -Require Import Crypto.Arithmetic.Core. -Require Import Crypto.Arithmetic.Saturated.Core. -Require Import Crypto.Arithmetic.Saturated.MulSplit. -Require Import Crypto.Util.ZUtil.Definitions. -Require Import Crypto.Util.ZUtil.MulSplit. -Require Import Crypto.Util.ZUtil.CPS. -Require Import Crypto.Util.Tuple. -Local Notation "A ^ n" := (tuple A n) : type_scope. - -(* Define wrapper definitions that use Columns representation -internally but with input and output in Positonal representation.*) -Module Columns. - Section Wrappers. - Context (weight : nat->Z). - - Definition add_cps {n1 n2 n3} (p : Z^n1) (q : Z^n2) - {T} (f : (Z*Z^n3)->T) := - B.Positional.to_associational_cps weight p - (fun P => B.Positional.to_associational_cps weight q - (fun Q => Columns.from_associational_cps weight n3 (P++Q) - (fun R => Columns.compact_cps (div_cps:=@div_cps) (modulo_cps:=@modulo_cps) (add_get_carry_cps:=@Z.add_get_carry_full_cps) weight R f))). - - Definition unbalanced_sub_cps {n1 n2 n3} (p : Z^n1) (q:Z^n2) - {T} (f : (Z*Z^n3)->T) := - B.Positional.to_associational_cps weight p - (fun P => B.Positional.negate_snd_cps weight q - (fun nq => B.Positional.to_associational_cps weight nq - (fun Q => Columns.from_associational_cps weight n3 (P++Q) - (fun R => Columns.compact_cps (div_cps:=@div_cps) (modulo_cps:=@modulo_cps) (add_get_carry_cps:=@Z.add_get_carry_full_cps) weight R f)))). - - Definition mul_cps {n1 n2 n3} s (p : Z^n1) (q : Z^n2) - {T} (f : (Z*Z^n3)->T) := - B.Positional.to_associational_cps weight p - (fun P => B.Positional.to_associational_cps weight q - (fun Q => B.Associational.sat_mul_cps (mul_split_cps := @Z.mul_split_cps') s P Q - (fun PQ => Columns.from_associational_cps weight n3 PQ - (fun R => Columns.compact_cps (div_cps:=@div_cps) (modulo_cps:=@modulo_cps) (add_get_carry_cps:=@Z.add_get_carry_full_cps) weight R f)))). - - Definition conditional_add_cps {n1 n2 n3} mask cond (p:Z^n1) (q:Z^n2) - {T} (f:_->T) := - B.Positional.select_cps mask cond q - (fun qq => add_cps (n3:=n3) p qq f). - - End Wrappers. -End Columns. -Hint Unfold - Columns.conditional_add_cps - Columns.add_cps - Columns.unbalanced_sub_cps - Columns.mul_cps. - -Hint Unfold - Columns.add_cps Columns.unbalanced_sub_cps Columns.mul_cps Columns.conditional_add_cps - : basesystem_partial_evaluation_unfolder. - -Ltac basesystem_partial_evaluation_unfolder t := - let t := (eval cbv delta [Columns.add_cps Columns.unbalanced_sub_cps Columns.mul_cps Columns.conditional_add_cps] in t) in - let t := Saturated.MulSplit.basesystem_partial_evaluation_unfolder t in - let t := Saturated.Core.basesystem_partial_evaluation_unfolder t in - let t := Arithmetic.Core.basesystem_partial_evaluation_unfolder t in - t. - -Ltac Arithmetic.Core.basesystem_partial_evaluation_default_unfolder t ::= - basesystem_partial_evaluation_unfolder t. diff --git a/src/Arithmetic/Saturated/WrappersUnfolder.v b/src/Arithmetic/Saturated/WrappersUnfolder.v deleted file mode 100644 index b8fe6afc6..000000000 --- a/src/Arithmetic/Saturated/WrappersUnfolder.v +++ /dev/null @@ -1,45 +0,0 @@ -Require Import Crypto.Arithmetic.CoreUnfolder. -Require Import Crypto.Arithmetic.Saturated.CoreUnfolder. -Require Import Crypto.Arithmetic.Saturated.MulSplitUnfolder. -Require Import Crypto.Arithmetic.Saturated.Wrappers. - -Hint Unfold Wrappers.Columns.add_cps Wrappers.Columns.unbalanced_sub_cps Wrappers.Columns.mul_cps : arithmetic_cps_unfolder. - -Module Columns. - (** -<< -#!/bin/bash -for i in add_cps unbalanced_sub_cps mul_cps conditional_add_cps; do - echo " Definition ${i}_sig := parameterize_sig (@Wrappers.Columns.${i})."; - echo " Definition ${i} := parameterize_from_sig ${i}_sig."; - echo " Definition ${i}_eq := parameterize_eq ${i} ${i}_sig."; - echo " Hint Unfold ${i} : basesystem_partial_evaluation_unfolder."; - echo " Hint Rewrite <- ${i}_eq : pattern_runtime."; echo ""; -done -echo "End Columns." ->> *) - Definition add_cps_sig := parameterize_sig (@Wrappers.Columns.add_cps). - Definition add_cps := parameterize_from_sig add_cps_sig. - Definition add_cps_eq := parameterize_eq add_cps add_cps_sig. - Hint Unfold add_cps : basesystem_partial_evaluation_unfolder. - Hint Rewrite <- add_cps_eq : pattern_runtime. - - Definition unbalanced_sub_cps_sig := parameterize_sig (@Wrappers.Columns.unbalanced_sub_cps). - Definition unbalanced_sub_cps := parameterize_from_sig unbalanced_sub_cps_sig. - Definition unbalanced_sub_cps_eq := parameterize_eq unbalanced_sub_cps unbalanced_sub_cps_sig. - Hint Unfold unbalanced_sub_cps : basesystem_partial_evaluation_unfolder. - Hint Rewrite <- unbalanced_sub_cps_eq : pattern_runtime. - - Definition mul_cps_sig := parameterize_sig (@Wrappers.Columns.mul_cps). - Definition mul_cps := parameterize_from_sig mul_cps_sig. - Definition mul_cps_eq := parameterize_eq mul_cps mul_cps_sig. - Hint Unfold mul_cps : basesystem_partial_evaluation_unfolder. - Hint Rewrite <- mul_cps_eq : pattern_runtime. - - Definition conditional_add_cps_sig := parameterize_sig (@Wrappers.Columns.conditional_add_cps). - Definition conditional_add_cps := parameterize_from_sig conditional_add_cps_sig. - Definition conditional_add_cps_eq := parameterize_eq conditional_add_cps conditional_add_cps_sig. - Hint Unfold conditional_add_cps : basesystem_partial_evaluation_unfolder. - Hint Rewrite <- conditional_add_cps_eq : pattern_runtime. - -End Columns. -- cgit v1.2.3