diff options
Diffstat (limited to 'src/ModularArithmetic/ModularBaseSystemOpt.v')
-rw-r--r-- | src/ModularArithmetic/ModularBaseSystemOpt.v | 246 |
1 files changed, 143 insertions, 103 deletions
diff --git a/src/ModularArithmetic/ModularBaseSystemOpt.v b/src/ModularArithmetic/ModularBaseSystemOpt.v index 7c171faf7..436d309c7 100644 --- a/src/ModularArithmetic/ModularBaseSystemOpt.v +++ b/src/ModularArithmetic/ModularBaseSystemOpt.v @@ -1,13 +1,14 @@ Require Import Crypto.ModularArithmetic.PrimeFieldTheorems. Require Import Crypto.ModularArithmetic.PseudoMersenneBaseParams. Require Import Crypto.ModularArithmetic.PseudoMersenneBaseParamProofs. -Require Import Crypto.ModularArithmetic.ModularBaseSystemProofs. Require Import Crypto.ModularArithmetic.ExtendedBaseVector. +Require Import Crypto.ModularArithmetic.Pow2Base. Require Import Crypto.ModularArithmetic.Pow2BaseProofs. Require Import Crypto.BaseSystem. Require Import Crypto.ModularArithmetic.ModularBaseSystemList. Require Import Crypto.ModularArithmetic.ModularBaseSystemListProofs. Require Import Crypto.ModularArithmetic.ModularBaseSystem. +Require Import Crypto.ModularArithmetic.ModularBaseSystemProofs. Require Import Coq.Lists.List. Require Import Crypto.Util.Tuple. Require Import Crypto.Util.ListUtil Crypto.Util.ZUtil Crypto.Util.NatUtil Crypto.Util.CaseUtil. @@ -30,6 +31,7 @@ Definition Z_mul_opt := Eval compute in Z.mul. Definition Z_div_opt := Eval compute in Z.div. Definition Z_pow_opt := Eval compute in Z.pow. Definition Z_opp_opt := Eval compute in Z.opp. +Definition Z_ones_opt := Eval compute in Z.ones. Definition Z_shiftl_opt := Eval compute in Z.shiftl. Definition Z_shiftl_by_opt := Eval compute in Z.shiftl_by. @@ -115,35 +117,53 @@ Section Carries. Local Notation base := (Pow2Base.base_from_limb_widths limb_widths). Local Notation digits := (tuple Z (length limb_widths)). + Definition carry_gen_opt_sig fc fi i us + : { d : list Z | (0 <= fi (S (fi i)) < length us)%nat -> + d = carry_gen limb_widths fc fi i us}. + Proof. + eexists; intros. + cbv beta iota delta [carry_gen carry_single Z.pow2_mod]. + rewrite add_to_nth_set_nth. + change @nth_default with @nth_default_opt in *. + change @set_nth with @set_nth_opt in *. + change Z.ones with Z_ones_opt. + rewrite set_nth_nth_default by assumption. + rewrite <- @beq_nat_eq_nat_dec. + reflexivity. + Defined. + + Definition carry_gen_opt fc fi i us := Eval cbv [proj1_sig carry_gen_opt_sig] in + proj1_sig (carry_gen_opt_sig fc fi i us). + + Definition carry_gen_opt_correct fc fi i us + : (0 <= fi (S (fi i)) < length us)%nat -> + carry_gen_opt fc fi i us = carry_gen limb_widths fc fi i us + := proj2_sig (carry_gen_opt_sig fc fi i us). + Definition carry_opt_sig - (i : nat) (b : digits) - : { d : digits | (i < length limb_widths)%nat -> d = carry i b }. + (i : nat) (b : list Z) + : { d : list Z | (length b = length limb_widths) + -> (i < length limb_widths)%nat + -> d = carry i b }. Proof. eexists ; intros. - cbv [carry ModularBaseSystemList.carry]. - rewrite <-from_list_default_eq with (d := 0%Z). + cbv [carry]. rewrite <-pull_app_if_sumbool. cbv beta delta - [carry carry_and_reduce Pow2Base.carry_gen Pow2Base.carry_single Pow2Base.carry_simple - Z.pow2_mod Z.ones Z.pred - PseudoMersenneBaseParams.limb_widths]. - rewrite !add_to_nth_set_nth. - change @Pow2Base.base_from_limb_widths with @base_from_limb_widths_opt. - change @nth_default with @nth_default_opt in *. - change @set_nth with @set_nth_opt in *. + [carry carry_and_reduce carry_simple]. lazymatch goal with - | [ |- _ = ?f (if ?br then ?c else ?d) ] - => let x := fresh "x" in let y := fresh "y" in evar (x:digits); evar (y:digits); transitivity (if br then x else y); subst x; subst y + | [ |- _ = (if ?br then ?c else ?d) ] + => let x := fresh "x" in let y := fresh "y" in evar (x:list Z); evar (y:list Z); transitivity (if br then x else y); subst x; subst y end. - 2:cbv zeta. - 2:break_if; reflexivity. - - change @from_list_default with @from_list_default_opt. - change @nth_default with @nth_default_opt. + Focus 2. { + cbv zeta. + break_if; rewrite <-carry_gen_opt_correct by (omega || + (replace (length b) with (length limb_widths) by congruence; + apply Nat.mod_bound_pos; omega)); reflexivity. + } Unfocus. rewrite c_subst. - change @set_nth with @set_nth_opt. - change @map with @map_opt. rewrite <- @beq_nat_eq_nat_dec. + cbv [carry_gen_opt]. reflexivity. Defined. @@ -151,11 +171,15 @@ Section Carries. proj1_sig (carry_opt_sig is us). Definition carry_opt_correct i us - : (i < length limb_widths)%nat -> carry_opt i us = carry i us + : length us = length limb_widths + -> (i < length limb_widths)%nat + -> carry_opt i us = carry i us := proj2_sig (carry_opt_sig i us). - Definition carry_sequence_opt_sig (is : list nat) (us : digits) - : { b : digits | (forall i, In i is -> i < length base)%nat -> b = carry_sequence is us }. + Definition carry_sequence_opt_sig (is : list nat) (us : list Z) + : { b : list Z | (length us = length limb_widths) + -> (forall i, In i is -> i < length limb_widths)%nat + -> b = carry_sequence is us }. Proof. eexists. intros H. cbv [carry_sequence]. @@ -164,9 +188,9 @@ Section Carries. { induction is; [ reflexivity | ]. simpl; rewrite IHis, carry_opt_correct. - reflexivity. - - rewrite base_length in H. - apply H; apply in_eq. - - intros. apply H. right. auto. + - fold (carry_sequence is us). auto using length_carry_sequence. + - auto using in_eq. + - intros. auto using in_cons. } Unfocus. reflexivity. @@ -176,123 +200,128 @@ Section Carries. proj1_sig (carry_sequence_opt_sig is us). Definition carry_sequence_opt_correct is us - : (forall i, In i is -> i < length base)%nat -> carry_sequence_opt is us = carry_sequence is us + : (length us = length limb_widths) + -> (forall i, In i is -> i < length limb_widths)%nat + -> carry_sequence_opt is us = carry_sequence is us := proj2_sig (carry_sequence_opt_sig is us). - Definition carry_opt_cps_sig - {T} + Definition carry_gen_opt_cps_sig + {T} fc fi (i : nat) - (f : digits -> T) - (b : digits) - : { d : T | (i < length base)%nat -> d = f (carry i b) }. + (f : list Z -> T) + (b : list Z) + : { d : T | (0 <= fi (S (fi i)) < length b)%nat -> d = f (carry_gen limb_widths fc fi i b) }. Proof. eexists. intros H. - rewrite <- carry_opt_correct by (rewrite base_length in H; assumption). - cbv beta iota delta [carry_opt]. + rewrite <-carry_gen_opt_correct by assumption. + cbv beta iota delta [carry_gen_opt]. + match goal with |- appcontext[?a & Z_ones_opt _] => let LHS := match goal with |- ?LHS = ?RHS => LHS end in let RHS := match goal with |- ?LHS = ?RHS => RHS end in - let RHSf := match (eval pattern (nth_default_opt 0%Z (to_list _ b) i) in RHS) with ?RHSf _ => RHSf end in - change (LHS = Let_In (nth_default_opt 0%Z (to_list _ b) i) RHSf). - change Z.shiftl with Z_shiftl_opt. - match goal with |- appcontext[ ?x + -1] => change (x + -1) with (Z_add_opt x (Z_opp_opt 1)) end. + let RHSf := match (eval pattern (a) in RHS) with ?RHSf _ => RHSf end in + change (LHS = Let_In (a) RHSf) end. reflexivity. Defined. - Definition carry_opt_cps {T} i f b - := Eval cbv beta iota delta [proj1_sig carry_opt_cps_sig] in proj1_sig (@carry_opt_cps_sig T i f b). + Definition carry_gen_opt_cps {T} fc fi i f b + := Eval cbv beta iota delta [proj1_sig carry_gen_opt_cps_sig] in + proj1_sig (@carry_gen_opt_cps_sig T fc fi i f b). - Definition carry_opt_cps_correct {T} i f b : - (i < length base)%nat -> - @carry_opt_cps T i f b = f (carry i b) - := proj2_sig (carry_opt_cps_sig i f b). + Definition carry_gen_opt_cps_correct {T} fc fi i f b : + (0 <= fi (S (fi i)) < length b)%nat -> + @carry_gen_opt_cps T fc fi i f b = f (carry_gen limb_widths fc fi i b) + := proj2_sig (carry_gen_opt_cps_sig fc fi i f b). - Definition carry_sequence_opt_cps2_sig {T} (is : list nat) (us : digits) - (f : digits -> T) - : { b : T | (forall i, In i is -> i < length base)%nat -> b = f (carry_sequence is us) }. + Definition carry_opt_cps_sig + {T} + (i : nat) + (f : list Z -> T) + (b : list Z) + : { d : T | (length b = length limb_widths) + -> (i < length limb_widths)%nat + -> d = f (carry i b) }. Proof. - eexists. - cbv [carry_sequence]. - transitivity (fold_right carry_opt_cps f (List.rev is) us). - Focus 2. - { - assert (forall i, In i (rev is) -> i < length base)%nat as Hr. { - subst. intros. rewrite <- in_rev in *. auto. } - remember (rev is) as ris eqn:Heq. - rewrite <- (rev_involutive is), <- Heq. - clear H Heq is. - rewrite fold_left_rev_right. - revert us; induction ris; [ reflexivity | ]; intros. - { simpl. - rewrite <- IHris; clear IHris; [|intros; apply Hr; right; assumption]. - rewrite carry_opt_cps_correct; [reflexivity|]. - apply Hr; left; reflexivity. - } } - Unfocus. + eexists. intros. + cbv beta delta + [carry carry_and_reduce carry_simple]. + rewrite <-pull_app_if_sumbool. + lazymatch goal with + | [ |- _ = ?f (if ?br then ?c else ?d) ] + => let x := fresh "x" in let y := fresh "y" in evar (x:T); evar (y:T); transitivity (if br then x else y); subst x; subst y + end. + Focus 2. { + cbv zeta. + break_if; rewrite <-carry_gen_opt_cps_correct by (omega || + (replace (length b) with (length limb_widths) by congruence; + apply Nat.mod_bound_pos; omega)); reflexivity. + } Unfocus. + rewrite c_subst. + rewrite <- @beq_nat_eq_nat_dec. reflexivity. Defined. - Definition carry_sequence_opt_cps2 {T} is us (f : digits -> T) := - Eval cbv [proj1_sig carry_sequence_opt_cps2_sig] in - proj1_sig (carry_sequence_opt_cps2_sig is us f). + Definition carry_opt_cps {T} i f b + := Eval cbv beta iota delta [proj1_sig carry_opt_cps_sig] in proj1_sig (@carry_opt_cps_sig T i f b). - Definition carry_sequence_opt_cps2_correct {T} is us (f : digits -> T) - : (forall i, In i is -> i < length base)%nat -> carry_sequence_opt_cps2 is us f = f (carry_sequence is us) - := proj2_sig (carry_sequence_opt_cps2_sig is us f). + Definition carry_opt_cps_correct {T} i f b : + (length b = length limb_widths) + -> (i < length limb_widths)%nat + -> @carry_opt_cps T i f b = f (carry i b) + := proj2_sig (carry_opt_cps_sig i f b). - Definition carry_sequence_opt_cps_sig (is : list nat) (us : digits) - : { b : digits | (forall i, In i is -> i < length base)%nat -> b = carry_sequence is us }. + Definition carry_sequence_opt_cps_sig {T} (is : list nat) (us : list Z) + (f : list Z -> T) + : { b : T | (length us = length limb_widths) + -> (forall i, In i is -> i < length limb_widths)%nat + -> b = f (carry_sequence is us) }. Proof. eexists. cbv [carry_sequence]. - transitivity (fold_right carry_opt_cps id (List.rev is) us). + transitivity (fold_right carry_opt_cps f (List.rev is) us). Focus 2. { - assert (forall i, In i (rev is) -> i < length base)%nat as Hr. { + assert (forall i, In i (rev is) -> i < length limb_widths)%nat as Hr. { subst. intros. rewrite <- in_rev in *. auto. } remember (rev is) as ris eqn:Heq. - rewrite <- (rev_involutive is), <- Heq. - clear H Heq is. + rewrite <- (rev_involutive is), <- Heq in H0 |- *. + clear H0 Heq is. rewrite fold_left_rev_right. - revert us; induction ris; [ reflexivity | ]; intros. + revert H. revert us; induction ris; [ reflexivity | ]; intros. { simpl. - rewrite <- IHris; clear IHris; [|intros; apply Hr; right; assumption]. - rewrite carry_opt_cps_correct; [reflexivity|]. + rewrite <- IHris; clear IHris; + [|intros; apply Hr; right; assumption|auto using length_carry]. + rewrite carry_opt_cps_correct; [reflexivity|congruence|]. apply Hr; left; reflexivity. } } Unfocus. + cbv [carry_opt_cps]. reflexivity. Defined. - Definition carry_sequence_opt_cps is us := Eval cbv [proj1_sig carry_sequence_opt_cps_sig] in - proj1_sig (carry_sequence_opt_cps_sig is us). - - Definition carry_sequence_opt_cps_correct is us - : (forall i, In i is -> i < length base)%nat -> carry_sequence_opt_cps is us = carry_sequence is us - := proj2_sig (carry_sequence_opt_cps_sig is us). - + Definition carry_sequence_opt_cps {T} is us (f : list Z -> T) := + Eval cbv [proj1_sig carry_sequence_opt_cps_sig] in + proj1_sig (carry_sequence_opt_cps_sig is us f). - Lemma carry_sequence_opt_cps_rep - : forall (is : list nat) (us : digits) (x : F modulus), - (forall i : nat, In i is -> i < length base)%nat -> - rep us x -> rep (carry_sequence_opt_cps is us) x. - Proof. - intros. - rewrite carry_sequence_opt_cps_correct by assumption. - auto using carry_sequence_rep. - Qed. + Definition carry_sequence_opt_cps_correct {T} is us (f : list Z -> T) + : (length us = length limb_widths) + -> (forall i, In i is -> i < length limb_widths)%nat + -> carry_sequence_opt_cps is us f = f (carry_sequence is us) + := proj2_sig (carry_sequence_opt_cps_sig is us f). - Lemma full_carry_chain_bounds : forall i, In i (Pow2Base.full_carry_chain limb_widths) -> (i < length base)%nat. + Lemma full_carry_chain_bounds : forall i, In i (Pow2Base.full_carry_chain limb_widths) -> + (i < length limb_widths)%nat. Proof. - unfold Pow2Base.full_carry_chain; rewrite <-base_length; intros. + unfold Pow2Base.full_carry_chain; intros. apply Pow2BaseProofs.make_chain_lt; auto. Qed. Definition carry_full_opt_sig (us : digits) : { b : digits | b = carry_full us }. Proof. eexists. - cbv [carry_full]. + cbv [carry_full ModularBaseSystemList.carry_full]. + rewrite <-from_list_default_eq with (d := 0). + rewrite <-carry_sequence_opt_cps_correct by (rewrite ?length_to_list; auto; apply full_carry_chain_bounds). change @Pow2Base.full_carry_chain with full_carry_chain_opt. - rewrite <-carry_sequence_opt_cps_correct by (auto; apply full_carry_chain_bounds). reflexivity. Defined. @@ -311,8 +340,10 @@ Section Carries. eexists. rewrite <- carry_full_opt_correct. cbv beta iota delta [carry_full_opt]. - rewrite carry_sequence_opt_cps_correct by apply full_carry_chain_bounds. - rewrite <-carry_sequence_opt_cps2_correct by apply full_carry_chain_bounds. + rewrite carry_sequence_opt_cps_correct by (apply length_to_list || apply full_carry_chain_bounds). + match goal with |- ?LHS = ?f (?g (carry_sequence ?is ?us)) => + change (LHS = (fun x => f (g x)) (carry_sequence is us)) end. + rewrite <-carry_sequence_opt_cps_correct by (apply length_to_list || apply full_carry_chain_bounds). reflexivity. Defined. @@ -518,7 +549,16 @@ Section Multiplication. cbv [carry_mul]. erewrite <-carry_full_opt_cps_correct by eauto. erewrite <-mul_opt_correct. + cbv [carry_full_opt_cps mul_opt]. + erewrite from_list_default_eq. + rewrite to_list_from_list. reflexivity. + Grab Existential Variables. + rewrite mul'_opt_correct. + distr_length. + assert (0 < length limb_widths)%nat by (pose proof limb_widths_nonnil; destruct limb_widths; congruence || simpl; omega). + rewrite Min.min_l; rewrite !length_to_list; break_match; try omega. + rewrite Max.max_l; omega. Defined. Definition carry_mul_opt_cps {T} (f:digits -> T) (us vs : digits) : T |