aboutsummaryrefslogtreecommitdiff
path: root/src/ModularArithmetic/ModularBaseSystemOpt.v
diff options
context:
space:
mode:
Diffstat (limited to 'src/ModularArithmetic/ModularBaseSystemOpt.v')
-rw-r--r--src/ModularArithmetic/ModularBaseSystemOpt.v246
1 files changed, 143 insertions, 103 deletions
diff --git a/src/ModularArithmetic/ModularBaseSystemOpt.v b/src/ModularArithmetic/ModularBaseSystemOpt.v
index 7c171faf7..436d309c7 100644
--- a/src/ModularArithmetic/ModularBaseSystemOpt.v
+++ b/src/ModularArithmetic/ModularBaseSystemOpt.v
@@ -1,13 +1,14 @@
Require Import Crypto.ModularArithmetic.PrimeFieldTheorems.
Require Import Crypto.ModularArithmetic.PseudoMersenneBaseParams.
Require Import Crypto.ModularArithmetic.PseudoMersenneBaseParamProofs.
-Require Import Crypto.ModularArithmetic.ModularBaseSystemProofs.
Require Import Crypto.ModularArithmetic.ExtendedBaseVector.
+Require Import Crypto.ModularArithmetic.Pow2Base.
Require Import Crypto.ModularArithmetic.Pow2BaseProofs.
Require Import Crypto.BaseSystem.
Require Import Crypto.ModularArithmetic.ModularBaseSystemList.
Require Import Crypto.ModularArithmetic.ModularBaseSystemListProofs.
Require Import Crypto.ModularArithmetic.ModularBaseSystem.
+Require Import Crypto.ModularArithmetic.ModularBaseSystemProofs.
Require Import Coq.Lists.List.
Require Import Crypto.Util.Tuple.
Require Import Crypto.Util.ListUtil Crypto.Util.ZUtil Crypto.Util.NatUtil Crypto.Util.CaseUtil.
@@ -30,6 +31,7 @@ Definition Z_mul_opt := Eval compute in Z.mul.
Definition Z_div_opt := Eval compute in Z.div.
Definition Z_pow_opt := Eval compute in Z.pow.
Definition Z_opp_opt := Eval compute in Z.opp.
+Definition Z_ones_opt := Eval compute in Z.ones.
Definition Z_shiftl_opt := Eval compute in Z.shiftl.
Definition Z_shiftl_by_opt := Eval compute in Z.shiftl_by.
@@ -115,35 +117,53 @@ Section Carries.
Local Notation base := (Pow2Base.base_from_limb_widths limb_widths).
Local Notation digits := (tuple Z (length limb_widths)).
+ Definition carry_gen_opt_sig fc fi i us
+ : { d : list Z | (0 <= fi (S (fi i)) < length us)%nat ->
+ d = carry_gen limb_widths fc fi i us}.
+ Proof.
+ eexists; intros.
+ cbv beta iota delta [carry_gen carry_single Z.pow2_mod].
+ rewrite add_to_nth_set_nth.
+ change @nth_default with @nth_default_opt in *.
+ change @set_nth with @set_nth_opt in *.
+ change Z.ones with Z_ones_opt.
+ rewrite set_nth_nth_default by assumption.
+ rewrite <- @beq_nat_eq_nat_dec.
+ reflexivity.
+ Defined.
+
+ Definition carry_gen_opt fc fi i us := Eval cbv [proj1_sig carry_gen_opt_sig] in
+ proj1_sig (carry_gen_opt_sig fc fi i us).
+
+ Definition carry_gen_opt_correct fc fi i us
+ : (0 <= fi (S (fi i)) < length us)%nat ->
+ carry_gen_opt fc fi i us = carry_gen limb_widths fc fi i us
+ := proj2_sig (carry_gen_opt_sig fc fi i us).
+
Definition carry_opt_sig
- (i : nat) (b : digits)
- : { d : digits | (i < length limb_widths)%nat -> d = carry i b }.
+ (i : nat) (b : list Z)
+ : { d : list Z | (length b = length limb_widths)
+ -> (i < length limb_widths)%nat
+ -> d = carry i b }.
Proof.
eexists ; intros.
- cbv [carry ModularBaseSystemList.carry].
- rewrite <-from_list_default_eq with (d := 0%Z).
+ cbv [carry].
rewrite <-pull_app_if_sumbool.
cbv beta delta
- [carry carry_and_reduce Pow2Base.carry_gen Pow2Base.carry_single Pow2Base.carry_simple
- Z.pow2_mod Z.ones Z.pred
- PseudoMersenneBaseParams.limb_widths].
- rewrite !add_to_nth_set_nth.
- change @Pow2Base.base_from_limb_widths with @base_from_limb_widths_opt.
- change @nth_default with @nth_default_opt in *.
- change @set_nth with @set_nth_opt in *.
+ [carry carry_and_reduce carry_simple].
lazymatch goal with
- | [ |- _ = ?f (if ?br then ?c else ?d) ]
- => let x := fresh "x" in let y := fresh "y" in evar (x:digits); evar (y:digits); transitivity (if br then x else y); subst x; subst y
+ | [ |- _ = (if ?br then ?c else ?d) ]
+ => let x := fresh "x" in let y := fresh "y" in evar (x:list Z); evar (y:list Z); transitivity (if br then x else y); subst x; subst y
end.
- 2:cbv zeta.
- 2:break_if; reflexivity.
-
- change @from_list_default with @from_list_default_opt.
- change @nth_default with @nth_default_opt.
+ Focus 2. {
+ cbv zeta.
+ break_if; rewrite <-carry_gen_opt_correct by (omega ||
+ (replace (length b) with (length limb_widths) by congruence;
+ apply Nat.mod_bound_pos; omega)); reflexivity.
+ } Unfocus.
rewrite c_subst.
- change @set_nth with @set_nth_opt.
- change @map with @map_opt.
rewrite <- @beq_nat_eq_nat_dec.
+ cbv [carry_gen_opt].
reflexivity.
Defined.
@@ -151,11 +171,15 @@ Section Carries.
proj1_sig (carry_opt_sig is us).
Definition carry_opt_correct i us
- : (i < length limb_widths)%nat -> carry_opt i us = carry i us
+ : length us = length limb_widths
+ -> (i < length limb_widths)%nat
+ -> carry_opt i us = carry i us
:= proj2_sig (carry_opt_sig i us).
- Definition carry_sequence_opt_sig (is : list nat) (us : digits)
- : { b : digits | (forall i, In i is -> i < length base)%nat -> b = carry_sequence is us }.
+ Definition carry_sequence_opt_sig (is : list nat) (us : list Z)
+ : { b : list Z | (length us = length limb_widths)
+ -> (forall i, In i is -> i < length limb_widths)%nat
+ -> b = carry_sequence is us }.
Proof.
eexists. intros H.
cbv [carry_sequence].
@@ -164,9 +188,9 @@ Section Carries.
{ induction is; [ reflexivity | ].
simpl; rewrite IHis, carry_opt_correct.
- reflexivity.
- - rewrite base_length in H.
- apply H; apply in_eq.
- - intros. apply H. right. auto.
+ - fold (carry_sequence is us). auto using length_carry_sequence.
+ - auto using in_eq.
+ - intros. auto using in_cons.
}
Unfocus.
reflexivity.
@@ -176,123 +200,128 @@ Section Carries.
proj1_sig (carry_sequence_opt_sig is us).
Definition carry_sequence_opt_correct is us
- : (forall i, In i is -> i < length base)%nat -> carry_sequence_opt is us = carry_sequence is us
+ : (length us = length limb_widths)
+ -> (forall i, In i is -> i < length limb_widths)%nat
+ -> carry_sequence_opt is us = carry_sequence is us
:= proj2_sig (carry_sequence_opt_sig is us).
- Definition carry_opt_cps_sig
- {T}
+ Definition carry_gen_opt_cps_sig
+ {T} fc fi
(i : nat)
- (f : digits -> T)
- (b : digits)
- : { d : T | (i < length base)%nat -> d = f (carry i b) }.
+ (f : list Z -> T)
+ (b : list Z)
+ : { d : T | (0 <= fi (S (fi i)) < length b)%nat -> d = f (carry_gen limb_widths fc fi i b) }.
Proof.
eexists. intros H.
- rewrite <- carry_opt_correct by (rewrite base_length in H; assumption).
- cbv beta iota delta [carry_opt].
+ rewrite <-carry_gen_opt_correct by assumption.
+ cbv beta iota delta [carry_gen_opt].
+ match goal with |- appcontext[?a & Z_ones_opt _] =>
let LHS := match goal with |- ?LHS = ?RHS => LHS end in
let RHS := match goal with |- ?LHS = ?RHS => RHS end in
- let RHSf := match (eval pattern (nth_default_opt 0%Z (to_list _ b) i) in RHS) with ?RHSf _ => RHSf end in
- change (LHS = Let_In (nth_default_opt 0%Z (to_list _ b) i) RHSf).
- change Z.shiftl with Z_shiftl_opt.
- match goal with |- appcontext[ ?x + -1] => change (x + -1) with (Z_add_opt x (Z_opp_opt 1)) end.
+ let RHSf := match (eval pattern (a) in RHS) with ?RHSf _ => RHSf end in
+ change (LHS = Let_In (a) RHSf) end.
reflexivity.
Defined.
- Definition carry_opt_cps {T} i f b
- := Eval cbv beta iota delta [proj1_sig carry_opt_cps_sig] in proj1_sig (@carry_opt_cps_sig T i f b).
+ Definition carry_gen_opt_cps {T} fc fi i f b
+ := Eval cbv beta iota delta [proj1_sig carry_gen_opt_cps_sig] in
+ proj1_sig (@carry_gen_opt_cps_sig T fc fi i f b).
- Definition carry_opt_cps_correct {T} i f b :
- (i < length base)%nat ->
- @carry_opt_cps T i f b = f (carry i b)
- := proj2_sig (carry_opt_cps_sig i f b).
+ Definition carry_gen_opt_cps_correct {T} fc fi i f b :
+ (0 <= fi (S (fi i)) < length b)%nat ->
+ @carry_gen_opt_cps T fc fi i f b = f (carry_gen limb_widths fc fi i b)
+ := proj2_sig (carry_gen_opt_cps_sig fc fi i f b).
- Definition carry_sequence_opt_cps2_sig {T} (is : list nat) (us : digits)
- (f : digits -> T)
- : { b : T | (forall i, In i is -> i < length base)%nat -> b = f (carry_sequence is us) }.
+ Definition carry_opt_cps_sig
+ {T}
+ (i : nat)
+ (f : list Z -> T)
+ (b : list Z)
+ : { d : T | (length b = length limb_widths)
+ -> (i < length limb_widths)%nat
+ -> d = f (carry i b) }.
Proof.
- eexists.
- cbv [carry_sequence].
- transitivity (fold_right carry_opt_cps f (List.rev is) us).
- Focus 2.
- {
- assert (forall i, In i (rev is) -> i < length base)%nat as Hr. {
- subst. intros. rewrite <- in_rev in *. auto. }
- remember (rev is) as ris eqn:Heq.
- rewrite <- (rev_involutive is), <- Heq.
- clear H Heq is.
- rewrite fold_left_rev_right.
- revert us; induction ris; [ reflexivity | ]; intros.
- { simpl.
- rewrite <- IHris; clear IHris; [|intros; apply Hr; right; assumption].
- rewrite carry_opt_cps_correct; [reflexivity|].
- apply Hr; left; reflexivity.
- } }
- Unfocus.
+ eexists. intros.
+ cbv beta delta
+ [carry carry_and_reduce carry_simple].
+ rewrite <-pull_app_if_sumbool.
+ lazymatch goal with
+ | [ |- _ = ?f (if ?br then ?c else ?d) ]
+ => let x := fresh "x" in let y := fresh "y" in evar (x:T); evar (y:T); transitivity (if br then x else y); subst x; subst y
+ end.
+ Focus 2. {
+ cbv zeta.
+ break_if; rewrite <-carry_gen_opt_cps_correct by (omega ||
+ (replace (length b) with (length limb_widths) by congruence;
+ apply Nat.mod_bound_pos; omega)); reflexivity.
+ } Unfocus.
+ rewrite c_subst.
+ rewrite <- @beq_nat_eq_nat_dec.
reflexivity.
Defined.
- Definition carry_sequence_opt_cps2 {T} is us (f : digits -> T) :=
- Eval cbv [proj1_sig carry_sequence_opt_cps2_sig] in
- proj1_sig (carry_sequence_opt_cps2_sig is us f).
+ Definition carry_opt_cps {T} i f b
+ := Eval cbv beta iota delta [proj1_sig carry_opt_cps_sig] in proj1_sig (@carry_opt_cps_sig T i f b).
- Definition carry_sequence_opt_cps2_correct {T} is us (f : digits -> T)
- : (forall i, In i is -> i < length base)%nat -> carry_sequence_opt_cps2 is us f = f (carry_sequence is us)
- := proj2_sig (carry_sequence_opt_cps2_sig is us f).
+ Definition carry_opt_cps_correct {T} i f b :
+ (length b = length limb_widths)
+ -> (i < length limb_widths)%nat
+ -> @carry_opt_cps T i f b = f (carry i b)
+ := proj2_sig (carry_opt_cps_sig i f b).
- Definition carry_sequence_opt_cps_sig (is : list nat) (us : digits)
- : { b : digits | (forall i, In i is -> i < length base)%nat -> b = carry_sequence is us }.
+ Definition carry_sequence_opt_cps_sig {T} (is : list nat) (us : list Z)
+ (f : list Z -> T)
+ : { b : T | (length us = length limb_widths)
+ -> (forall i, In i is -> i < length limb_widths)%nat
+ -> b = f (carry_sequence is us) }.
Proof.
eexists.
cbv [carry_sequence].
- transitivity (fold_right carry_opt_cps id (List.rev is) us).
+ transitivity (fold_right carry_opt_cps f (List.rev is) us).
Focus 2.
{
- assert (forall i, In i (rev is) -> i < length base)%nat as Hr. {
+ assert (forall i, In i (rev is) -> i < length limb_widths)%nat as Hr. {
subst. intros. rewrite <- in_rev in *. auto. }
remember (rev is) as ris eqn:Heq.
- rewrite <- (rev_involutive is), <- Heq.
- clear H Heq is.
+ rewrite <- (rev_involutive is), <- Heq in H0 |- *.
+ clear H0 Heq is.
rewrite fold_left_rev_right.
- revert us; induction ris; [ reflexivity | ]; intros.
+ revert H. revert us; induction ris; [ reflexivity | ]; intros.
{ simpl.
- rewrite <- IHris; clear IHris; [|intros; apply Hr; right; assumption].
- rewrite carry_opt_cps_correct; [reflexivity|].
+ rewrite <- IHris; clear IHris;
+ [|intros; apply Hr; right; assumption|auto using length_carry].
+ rewrite carry_opt_cps_correct; [reflexivity|congruence|].
apply Hr; left; reflexivity.
} }
Unfocus.
+ cbv [carry_opt_cps].
reflexivity.
Defined.
- Definition carry_sequence_opt_cps is us := Eval cbv [proj1_sig carry_sequence_opt_cps_sig] in
- proj1_sig (carry_sequence_opt_cps_sig is us).
-
- Definition carry_sequence_opt_cps_correct is us
- : (forall i, In i is -> i < length base)%nat -> carry_sequence_opt_cps is us = carry_sequence is us
- := proj2_sig (carry_sequence_opt_cps_sig is us).
-
+ Definition carry_sequence_opt_cps {T} is us (f : list Z -> T) :=
+ Eval cbv [proj1_sig carry_sequence_opt_cps_sig] in
+ proj1_sig (carry_sequence_opt_cps_sig is us f).
- Lemma carry_sequence_opt_cps_rep
- : forall (is : list nat) (us : digits) (x : F modulus),
- (forall i : nat, In i is -> i < length base)%nat ->
- rep us x -> rep (carry_sequence_opt_cps is us) x.
- Proof.
- intros.
- rewrite carry_sequence_opt_cps_correct by assumption.
- auto using carry_sequence_rep.
- Qed.
+ Definition carry_sequence_opt_cps_correct {T} is us (f : list Z -> T)
+ : (length us = length limb_widths)
+ -> (forall i, In i is -> i < length limb_widths)%nat
+ -> carry_sequence_opt_cps is us f = f (carry_sequence is us)
+ := proj2_sig (carry_sequence_opt_cps_sig is us f).
- Lemma full_carry_chain_bounds : forall i, In i (Pow2Base.full_carry_chain limb_widths) -> (i < length base)%nat.
+ Lemma full_carry_chain_bounds : forall i, In i (Pow2Base.full_carry_chain limb_widths) ->
+ (i < length limb_widths)%nat.
Proof.
- unfold Pow2Base.full_carry_chain; rewrite <-base_length; intros.
+ unfold Pow2Base.full_carry_chain; intros.
apply Pow2BaseProofs.make_chain_lt; auto.
Qed.
Definition carry_full_opt_sig (us : digits) : { b : digits | b = carry_full us }.
Proof.
eexists.
- cbv [carry_full].
+ cbv [carry_full ModularBaseSystemList.carry_full].
+ rewrite <-from_list_default_eq with (d := 0).
+ rewrite <-carry_sequence_opt_cps_correct by (rewrite ?length_to_list; auto; apply full_carry_chain_bounds).
change @Pow2Base.full_carry_chain with full_carry_chain_opt.
- rewrite <-carry_sequence_opt_cps_correct by (auto; apply full_carry_chain_bounds).
reflexivity.
Defined.
@@ -311,8 +340,10 @@ Section Carries.
eexists.
rewrite <- carry_full_opt_correct.
cbv beta iota delta [carry_full_opt].
- rewrite carry_sequence_opt_cps_correct by apply full_carry_chain_bounds.
- rewrite <-carry_sequence_opt_cps2_correct by apply full_carry_chain_bounds.
+ rewrite carry_sequence_opt_cps_correct by (apply length_to_list || apply full_carry_chain_bounds).
+ match goal with |- ?LHS = ?f (?g (carry_sequence ?is ?us)) =>
+ change (LHS = (fun x => f (g x)) (carry_sequence is us)) end.
+ rewrite <-carry_sequence_opt_cps_correct by (apply length_to_list || apply full_carry_chain_bounds).
reflexivity.
Defined.
@@ -518,7 +549,16 @@ Section Multiplication.
cbv [carry_mul].
erewrite <-carry_full_opt_cps_correct by eauto.
erewrite <-mul_opt_correct.
+ cbv [carry_full_opt_cps mul_opt].
+ erewrite from_list_default_eq.
+ rewrite to_list_from_list.
reflexivity.
+ Grab Existential Variables.
+ rewrite mul'_opt_correct.
+ distr_length.
+ assert (0 < length limb_widths)%nat by (pose proof limb_widths_nonnil; destruct limb_widths; congruence || simpl; omega).
+ rewrite Min.min_l; rewrite !length_to_list; break_match; try omega.
+ rewrite Max.max_l; omega.
Defined.
Definition carry_mul_opt_cps {T} (f:digits -> T) (us vs : digits) : T