diff options
Diffstat (limited to 'src/ModularArithmetic/ModularBaseSystemOpt.v')
-rw-r--r-- | src/ModularArithmetic/ModularBaseSystemOpt.v | 142 |
1 files changed, 80 insertions, 62 deletions
diff --git a/src/ModularArithmetic/ModularBaseSystemOpt.v b/src/ModularArithmetic/ModularBaseSystemOpt.v index 1e748892d..ed8f80659 100644 --- a/src/ModularArithmetic/ModularBaseSystemOpt.v +++ b/src/ModularArithmetic/ModularBaseSystemOpt.v @@ -4,8 +4,12 @@ Require Import Crypto.ModularArithmetic.PseudoMersenneBaseParamProofs. Require Import Crypto.ModularArithmetic.ModularBaseSystemProofs. Require Import Crypto.ModularArithmetic.ExtendedBaseVector. Require Import Crypto.ModularArithmetic.Pow2BaseProofs. -Require Import Crypto.BaseSystem Crypto.ModularArithmetic.ModularBaseSystem. +Require Import Crypto.BaseSystem. +Require Import Crypto.ModularArithmetic.ModularBaseSystemList. +Require Import Crypto.ModularArithmetic.ModularBaseSystemListProofs. +Require Import Crypto.ModularArithmetic.ModularBaseSystem. Require Import Coq.Lists.List. +Require Import Crypto.Util.Tuple. Require Import Crypto.Util.ListUtil Crypto.Util.ZUtil Crypto.Util.NatUtil Crypto.Util.CaseUtil. Import ListNotations. Require Import Coq.ZArith.ZArith Coq.ZArith.Zpower Coq.ZArith.ZArith Coq.ZArith.Znumtheory. @@ -14,8 +18,7 @@ Require Import Crypto.Tactics.VerdiTactics. Local Open Scope Z. Class SubtractionCoefficient (m : Z) (prm : PseudoMersenneBaseParams m) := { - coeff : BaseSystem.digits; - coeff_length : (length coeff = length (Pow2Base.base_from_limb_widths limb_widths))%nat; + coeff : tuple Z (length limb_widths); coeff_mod: decode coeff = 0%F }. @@ -37,9 +40,9 @@ Definition map_opt {A B} := Eval compute in @map A B. Definition full_carry_chain_opt := Eval compute in @Pow2Base.full_carry_chain. Definition length_opt := Eval compute in length. Definition base_from_limb_widths_opt := Eval compute in @Pow2Base.base_from_limb_widths. -Definition max_ones_opt := Eval compute in @max_ones. -Definition max_bound_opt := Eval compute in @max_bound. Definition minus_opt := Eval compute in minus. +Definition max_ones_opt := Eval compute in @max_ones. +Definition from_list_default_opt {A} := Eval compute in (@from_list_default A). Definition Let_In {A P} (x : A) (f : forall y : A, P y) := let y := x in f y. @@ -110,14 +113,16 @@ Section Carries. (* allows caller to precompute k and c *) (k_ c_ : Z) (k_subst : k = k_) (c_subst : c = c_). Local Notation base := (Pow2Base.base_from_limb_widths limb_widths). + Local Notation digits := (tuple Z (length limb_widths)). Definition carry_opt_sig (i : nat) (b : digits) : { d : digits | (i < length limb_widths)%nat -> d = carry i b }. Proof. eexists ; intros. - cbv [carry]. - rewrite <- pull_app_if_sumbool. + cbv [carry ModularBaseSystemList.carry]. + rewrite <-from_list_default_eq with (d := 0%Z). + rewrite <-pull_app_if_sumbool. cbv beta delta [carry carry_and_reduce Pow2Base.carry_gen Pow2Base.carry_and_reduce_single Pow2Base.carry_simple Z.pow2_mod Z.ones Z.pred @@ -127,12 +132,13 @@ Section Carries. change @nth_default with @nth_default_opt in *. change @set_nth with @set_nth_opt in *. lazymatch goal with - | [ |- _ = (if ?br then ?c else ?d) ] + | [ |- _ = _ (if ?br then ?c else ?d) ] => let x := fresh "x" in let y := fresh "y" in evar (x:digits); evar (y:digits); transitivity (if br then x else y); subst x; subst y end. 2:cbv zeta. 2:break_if; reflexivity. + change @from_list_default with @from_list_default_opt. change @nth_default with @nth_default_opt. rewrite c_subst. change @set_nth with @set_nth_opt. @@ -141,10 +147,12 @@ Section Carries. reflexivity. Defined. - Definition carry_opt i b - := Eval cbv beta iota delta [proj1_sig carry_opt_sig] in proj1_sig (carry_opt_sig i b). + Definition carry_opt is us := Eval cbv [proj1_sig carry_opt_sig] in + proj1_sig (carry_opt_sig is us). - Definition carry_opt_correct i b : (i < length limb_widths)%nat -> carry_opt i b = carry i b := proj2_sig (carry_opt_sig i b). + Definition carry_opt_correct i us + : (i < length limb_widths)%nat -> carry_opt i us = carry i us + := proj2_sig (carry_opt_sig i us). Definition carry_sequence_opt_sig (is : list nat) (us : digits) : { b : digits | (forall i, In i is -> i < length base)%nat -> b = carry_sequence is us }. @@ -183,11 +191,10 @@ Section Carries. cbv beta iota delta [carry_opt]. let LHS := match goal with |- ?LHS = ?RHS => LHS end in let RHS := match goal with |- ?LHS = ?RHS => RHS end in - let RHSf := match (eval pattern (nth_default_opt 0%Z b i) in RHS) with ?RHSf _ => RHSf end in - change (LHS = Let_In (nth_default_opt 0%Z b i) RHSf). + let RHSf := match (eval pattern (nth_default_opt 0%Z (to_list _ b) i) in RHS) with ?RHSf _ => RHSf end in + change (LHS = Let_In (nth_default_opt 0%Z (to_list _ b) i) RHSf). change Z.shiftl with Z_shiftl_opt. - change (-1) with (Z_opp_opt 1). - change Z.add with Z_add_opt at 5 9 17 21. + match goal with |- appcontext[ ?x + -1] => change (x + -1) with (Z_add_opt x (Z_opp_opt 1)) end. reflexivity. Defined. @@ -265,13 +272,13 @@ Section Carries. Lemma carry_sequence_opt_cps_rep - : forall (is : list nat) (us : list Z) (x : F modulus), + : forall (is : list nat) (us : digits) (x : F modulus), (forall i : nat, In i is -> i < length base)%nat -> rep us x -> rep (carry_sequence_opt_cps is us) x. Proof. intros. rewrite carry_sequence_opt_cps_correct by assumption. - apply carry_sequence_rep; eauto using rep_length. + auto using carry_sequence_rep. Qed. Lemma full_carry_chain_bounds : forall i, In i (Pow2Base.full_carry_chain limb_widths) -> (i < length base)%nat. @@ -320,6 +327,7 @@ End Carries. Section Addition. Context `{prm : PseudoMersenneBaseParams} {sc : SubtractionCoefficient modulus prm}. + Local Notation digits := (tuple Z (length limb_widths)). Definition add_opt_sig (us vs : digits) : { b : digits | b = add us vs }. Proof. @@ -338,6 +346,7 @@ End Addition. Section Subtraction. Context `{prm : PseudoMersenneBaseParams} {sc : SubtractionCoefficient modulus prm}. + Local Notation digits := (tuple Z (length limb_widths)). Definition sub_opt_sig (us vs : digits) : { b : digits | b = sub coeff us vs }. Proof. @@ -358,9 +367,11 @@ Section Multiplication. Context `{prm : PseudoMersenneBaseParams} {sc : SubtractionCoefficient modulus prm} (* allows caller to precompute k and c *) (k_ c_ : Z) (k_subst : k = k_) (c_subst : c = c_). + Local Notation digits := (tuple Z (length limb_widths)). + Definition mul_bi'_step - (mul_bi' : nat -> digits -> list Z -> list Z) - (i : nat) (vsr : digits) (bs : list Z) + (mul_bi' : nat -> list Z -> list Z -> list Z) + (i : nat) (vsr : list Z) (bs : list Z) : list Z := match vsr with | [] => [] @@ -368,8 +379,8 @@ Section Multiplication. end. Definition mul_bi'_opt_step_sig - (mul_bi' : nat -> digits -> list Z -> list Z) - (i : nat) (vsr : digits) (bs : list Z) + (mul_bi' : nat -> list Z -> list Z -> list Z) + (i : nat) (vsr : list Z) (bs : list Z) : { l : list Z | l = mul_bi'_step mul_bi' i vsr bs }. Proof. eexists. @@ -384,19 +395,19 @@ Section Multiplication. Defined. Definition mul_bi'_opt_step - (mul_bi' : nat -> digits -> list Z -> list Z) - (i : nat) (vsr : digits) (bs : list Z) + (mul_bi' : nat -> list Z -> list Z -> list Z) + (i : nat) (vsr : list Z) (bs : list Z) : list Z := Eval cbv [proj1_sig mul_bi'_opt_step_sig] in proj1_sig (mul_bi'_opt_step_sig mul_bi' i vsr bs). Fixpoint mul_bi'_opt - (i : nat) (vsr : digits) (bs : list Z) {struct vsr} + (i : nat) (vsr : list Z) (bs : list Z) {struct vsr} : list Z := mul_bi'_opt_step mul_bi'_opt i vsr bs. Definition mul_bi'_opt_correct - (i : nat) (vsr : digits) (bs : list Z) + (i : nat) (vsr : list Z) (bs : list Z) : mul_bi'_opt i vsr bs = mul_bi' bs i vsr. Proof. revert i; induction vsr as [|vsr vsrs IHvsr]; intros. @@ -413,12 +424,12 @@ Section Multiplication. Qed. Definition mul'_step - (mul' : digits -> digits -> list Z -> digits) - (usr vs : digits) (bs : list Z) - : digits + (mul' : list Z -> list Z -> list Z -> list Z) + (usr vs : list Z) (bs : list Z) + : list Z := match usr with | [] => [] - | u :: usr' => add (mul_each u (mul_bi bs (length usr') vs)) (mul' usr' vs bs) + | u :: usr' => BaseSystem.add (mul_each u (mul_bi bs (length usr') vs)) (mul' usr' vs bs) end. Lemma map_zeros : forall a n l, @@ -428,9 +439,9 @@ Section Multiplication. Qed. Definition mul'_opt_step_sig - (mul' : digits -> digits -> list Z -> digits) - (usr vs : digits) (bs : list Z) - : { d : digits | d = mul'_step mul' usr vs bs }. + (mul' : list Z -> list Z -> list Z -> list Z) + (usr vs : list Z) (bs : list Z) + : { d : list Z | d = mul'_step mul' usr vs bs }. Proof. eexists. cbv [mul'_step]. @@ -449,18 +460,18 @@ Section Multiplication. Defined. Definition mul'_opt_step - (mul' : digits -> digits -> list Z -> digits) - (usr vs : digits) (bs : list Z) - : digits + (mul' : list Z -> list Z -> list Z -> list Z) + (usr vs : list Z) (bs : list Z) + : list Z := Eval cbv [proj1_sig mul'_opt_step_sig] in proj1_sig (mul'_opt_step_sig mul' usr vs bs). Fixpoint mul'_opt - (usr vs : digits) (bs : list Z) - : digits + (usr vs : list Z) (bs : list Z) + : list Z := mul'_opt_step mul'_opt usr vs bs. Definition mul'_opt_correct - (usr vs : digits) (bs : list Z) + (usr vs : list Z) (bs : list Z) : mul'_opt usr vs bs = mul' bs usr vs. Proof. revert vs; induction usr as [|usr usrs IHusr]; intros. @@ -471,13 +482,17 @@ Section Multiplication. cbv [mul_each mul_bi]. rewrite map_zeros. rewrite <- mul_bi'_opt_correct. + cbv [zeros]. reflexivity. } Qed. Definition mul_opt_sig (us vs : digits) : { b : digits | b = mul us vs }. Proof. eexists. - cbv [BaseSystem.mul mul mul_each mul_bi mul_bi' zeros reduce]. + cbv [mul ModularBaseSystemList.mul BaseSystem.mul mul_each mul_bi mul_bi' zeros reduce]. + rewrite <- from_list_default_eq with (d := 0%Z). + change (@from_list_default Z) with (@from_list_default_opt Z). + apply f_equal. rewrite ext_base_alt by auto using limb_widths_pos with zarith. rewrite <- mul'_opt_correct. change @Pow2Base.base_from_limb_widths with base_from_limb_widths_opt. @@ -512,6 +527,13 @@ Section Multiplication. Definition carry_mul_opt_cps_correct {T} (f:digits -> T) (us vs : digits) : carry_mul_opt_cps f us vs = f (carry_mul us vs) := proj2_sig (carry_mul_opt_sig f us vs). + + Definition carry_mul_opt := carry_mul_opt_cps id. + + Definition carry_mul_opt_correct (us vs : digits) + : carry_mul_opt us vs = carry_mul us vs := + carry_mul_opt_cps_correct id us vs. + End Multiplication. Section with_base. @@ -525,8 +547,8 @@ Section with_base. int_width_pos : 0 < int_width; int_width_compat : forall w, In w limb_widths -> w <= int_width; c_pos : 0 < c; - c_reduce1 : c * (Z.ones (int_width - log_cap (pred (length base)))) < max_bound 0 + 1; - c_reduce2 : c <= max_bound 0 - c; + c_reduce1 : c * (Z.ones (int_width - log_cap (pred (length base)))) < 2 ^ log_cap 0; + c_reduce2 : c < 2 ^ log_cap 0 - c; two_pow_k_le_2modulus : 2 ^ k <= 2 * modulus }. End with_base. @@ -538,30 +560,26 @@ Section Canonicalization. (* allows caller to precompute k and c *) (k_ c_ : Z) (k_subst : k = k_) (c_subst : c = c_) {int_width} (preconditions : freezePreconditions prm int_width). + Local Notation digits := (tuple Z (length limb_widths)). + + Definition encodeZ_opt := Eval compute in Pow2Base.encodeZ. Definition modulus_digits_opt_sig : - { b : digits | b = modulus_digits }. + { b : list Z | b = modulus_digits }. Proof. eexists. - cbv beta iota delta [modulus_digits modulus_digits' app]. - change @max_bound with max_bound_opt. - rewrite c_subst. - change length with length_opt. - change minus with minus_opt. - change Z.add with Z_add_opt. - change Z.sub with Z_sub_opt. - change @Pow2Base.base_from_limb_widths with base_from_limb_widths_opt. + cbv beta iota delta [modulus_digits]. + change Pow2Base.encodeZ with encodeZ_opt. reflexivity. Defined. - Definition modulus_digits_opt : digits + Definition modulus_digits_opt : list Z := Eval cbv [proj1_sig modulus_digits_opt_sig] in proj1_sig (modulus_digits_opt_sig). Definition modulus_digits_opt_correct : modulus_digits_opt = modulus_digits := proj2_sig (modulus_digits_opt_sig). - Definition carry_full_3_opt_cps_sig {T} (f : digits -> T) (us : digits) @@ -587,20 +605,20 @@ Section Canonicalization. { b : digits | b = freeze us }. Proof. eexists. - cbv [freeze]. - cbv [and_term]. + cbv [freeze conditional_subtract_modulus]. + rewrite <-from_list_default_eq with (d := 0%Z). + change (@from_list_default Z) with (@from_list_default_opt Z). let LHS := match goal with |- ?LHS = ?RHS => LHS end in let RHS := match goal with |- ?LHS = ?RHS => RHS end in - let RHSf := match (eval pattern (isFull (carry_full (carry_full (carry_full us)))) in RHS) with ?RHSf _ => RHSf end in - change (LHS = Let_In (isFull(carry_full (carry_full (carry_full us)))) RHSf). + let RHSf := match (eval pattern (to_list (length limb_widths) (carry_full (carry_full (carry_full us)))) in RHS) with ?RHSf _ => RHSf end in + change (LHS = Let_In (to_list (length limb_widths) (carry_full (carry_full (carry_full us)))) RHSf). let LHS := match goal with |- ?LHS = ?RHS => LHS end in let RHS := match goal with |- ?LHS = ?RHS => RHS end in let RHSf := match (eval pattern (carry_full (carry_full (carry_full us))) in RHS) with ?RHSf _ => RHSf end in rewrite <-carry_full_3_opt_cps_correct with (f := RHSf). - cbv beta iota delta [and_term isFull isFull']. + cbv beta iota delta [ge_modulus ge_modulus']. change length with length_opt. - change @max_bound with max_bound_opt. - rewrite c_subst. + change (nth_default 0 modulus_digits) with (nth_default_opt 0 modulus_digits_opt). change @max_ones with max_ones_opt. change @Pow2Base.base_from_limb_widths with base_from_limb_widths_opt. change minus with minus_opt. @@ -616,7 +634,7 @@ Section Canonicalization. Definition freeze_opt_correct us : freeze_opt us = freeze us := proj2_sig (freeze_opt_sig us). - +(* Lemma freeze_opt_canonical: forall us vs x, @pre_carry_bounds _ _ int_width us -> rep us x -> @pre_carry_bounds _ _ int_width vs -> rep vs x -> @@ -643,5 +661,5 @@ Section Canonicalization. split; eauto using freeze_opt_canonical. auto using freeze_opt_preserves_rep. Qed. - -End Canonicalization. +*) +End Canonicalization.
\ No newline at end of file |