aboutsummaryrefslogtreecommitdiff
path: root/src
diff options
context:
space:
mode:
authorGravatar jadep <jade.philipoom@gmail.com>2016-07-25 21:06:07 -0400
committerGravatar jadep <jade.philipoom@gmail.com>2016-07-25 21:06:07 -0400
commit39a6c95de8a900c859726d875cc40ea96298d31b (patch)
tree750571dc101f477c34340716db87a3697cca41eb /src
parentea9397e3da37f35d088488be141cb18cc38ea11b (diff)
Put ModularBaseSystem carries in terms of [carry_gen], and pushed this change through the pipeline. Also began the process of redoing canonicalization proofs, attempting to put the messy case analysis in theorem statements rather than separate lemmas.
Diffstat (limited to 'src')
-rw-r--r--src/ModularArithmetic/ExtendedBaseVector.v1
-rw-r--r--src/ModularArithmetic/ModularBaseSystem.v8
-rw-r--r--src/ModularArithmetic/ModularBaseSystemList.v17
-rw-r--r--src/ModularArithmetic/ModularBaseSystemListProofs.v14
-rw-r--r--src/ModularArithmetic/ModularBaseSystemOpt.v246
-rw-r--r--src/ModularArithmetic/ModularBaseSystemProofs.v1831
-rw-r--r--src/ModularArithmetic/Pow2Base.v11
-rw-r--r--src/ModularArithmetic/Pow2BaseProofs.v135
8 files changed, 494 insertions, 1769 deletions
diff --git a/src/ModularArithmetic/ExtendedBaseVector.v b/src/ModularArithmetic/ExtendedBaseVector.v
index fcd871aae..ef8c9716a 100644
--- a/src/ModularArithmetic/ExtendedBaseVector.v
+++ b/src/ModularArithmetic/ExtendedBaseVector.v
@@ -122,6 +122,7 @@ Section ExtendedBaseVector.
rewrite (map_nth_default _ _ _ _ 0) by omega.
apply base_matches_modulus; auto using limb_widths_nonnegative, limb_widths_match_modulus;
distr_length.
+ assumption.
} { (* i < length base, j >= length base, i + j >= length base *)
do 2 rewrite map_nth_default_base_high by omega.
remember (j - length base)%nat as j'.
diff --git a/src/ModularArithmetic/ModularBaseSystem.v b/src/ModularArithmetic/ModularBaseSystem.v
index 70c8138da..4fd881e1e 100644
--- a/src/ModularArithmetic/ModularBaseSystem.v
+++ b/src/ModularArithmetic/ModularBaseSystem.v
@@ -50,16 +50,12 @@ Section ModularBaseSystem.
(* Placeholder *)
Definition div (x y : digits) : digits := encode (ModularArithmetic.div (decode x) (decode y)).
- Definition carry i (us : digits) : digits := from_list (carry i [[us]])
- (length_carry length_to_list).
-
Definition rep (us : digits) (x : F modulus) := decode us = x.
Local Notation "u ~= x" := (rep u x).
Local Hint Unfold rep.
- Definition carry_sequence is (us : digits) : digits := fold_right carry us is.
-
- Definition carry_full : digits -> digits := carry_sequence (full_carry_chain limb_widths).
+ Definition carry_full (us : digits) : digits := from_list (carry_full [[us]])
+ (length_carry_full length_to_list).
Definition carry_mul (us vs : digits) : digits := carry_full (mul us vs).
diff --git a/src/ModularArithmetic/ModularBaseSystemList.v b/src/ModularArithmetic/ModularBaseSystemList.v
index 07b2c2bac..cbab03d6a 100644
--- a/src/ModularArithmetic/ModularBaseSystemList.v
+++ b/src/ModularArithmetic/ModularBaseSystemList.v
@@ -31,20 +31,21 @@ Section Defs.
(* In order to subtract without underflowing, we add a multiple of the modulus first. *)
Definition sub (us vs : digits) := BaseSystem.sub (add modulus_multiple us) vs.
- (*
+ (* [carry_and_reduce] multiplies the carried value by c, and, if carrying
+ from index [i] in a list [us], adds the value to the digit with index
+ [(S i) mod (length us)] *)
Definition carry_and_reduce :=
- carry_gen limb_widths (fun ci => c * ci).
- *)
- Definition carry_and_reduce i := fun us =>
- let di := us [i] in
- let us' := set_nth i (Z.pow2_mod di (limb_widths [i])) us in
- add_to_nth 0 (c * (Z.shiftr di (limb_widths [i]))) us'.
+ carry_gen limb_widths (fun ci => c * ci) (fun Si => (Si mod (length limb_widths))%nat).
Definition carry i : digits -> digits :=
- if eq_nat_dec i (pred (length base))
+ if eq_nat_dec i (pred (length limb_widths))
then carry_and_reduce i
else carry_simple limb_widths i.
+ Definition carry_sequence is (us : digits) : digits := fold_right carry us is.
+
+ Definition carry_full : digits -> digits := carry_sequence (full_carry_chain limb_widths).
+
Definition modulus_digits := encodeZ limb_widths modulus.
(* compute at compile time *)
diff --git a/src/ModularArithmetic/ModularBaseSystemListProofs.v b/src/ModularArithmetic/ModularBaseSystemListProofs.v
index 35de02cde..a49c26a53 100644
--- a/src/ModularArithmetic/ModularBaseSystemListProofs.v
+++ b/src/ModularArithmetic/ModularBaseSystemListProofs.v
@@ -76,6 +76,20 @@ Section LengthProofs.
Proof. intros; unfold carry; break_if; autorewrite with distr_length; omega. Qed.
Hint Rewrite @length_carry : distr_length.
+ Lemma length_carry_sequence {u i} :
+ length u = length limb_widths
+ -> length (carry_sequence i u) = length limb_widths.
+ Proof.
+ induction i; intros; unfold carry_sequence;
+ simpl; autorewrite with distr_length; auto. Qed.
+ Hint Rewrite @length_carry_sequence : distr_length.
+
+ Lemma length_carry_full {u} :
+ length u = length limb_widths
+ -> length (carry_full u) = length limb_widths.
+ Proof. intros; unfold carry_full; autorewrite with distr_length; congruence. Qed.
+ Hint Rewrite @length_carry_full : distr_length.
+
Lemma length_modulus_digits : length modulus_digits = length limb_widths.
Proof.
intros; unfold modulus_digits, encodeZ.
diff --git a/src/ModularArithmetic/ModularBaseSystemOpt.v b/src/ModularArithmetic/ModularBaseSystemOpt.v
index 7c171faf7..436d309c7 100644
--- a/src/ModularArithmetic/ModularBaseSystemOpt.v
+++ b/src/ModularArithmetic/ModularBaseSystemOpt.v
@@ -1,13 +1,14 @@
Require Import Crypto.ModularArithmetic.PrimeFieldTheorems.
Require Import Crypto.ModularArithmetic.PseudoMersenneBaseParams.
Require Import Crypto.ModularArithmetic.PseudoMersenneBaseParamProofs.
-Require Import Crypto.ModularArithmetic.ModularBaseSystemProofs.
Require Import Crypto.ModularArithmetic.ExtendedBaseVector.
+Require Import Crypto.ModularArithmetic.Pow2Base.
Require Import Crypto.ModularArithmetic.Pow2BaseProofs.
Require Import Crypto.BaseSystem.
Require Import Crypto.ModularArithmetic.ModularBaseSystemList.
Require Import Crypto.ModularArithmetic.ModularBaseSystemListProofs.
Require Import Crypto.ModularArithmetic.ModularBaseSystem.
+Require Import Crypto.ModularArithmetic.ModularBaseSystemProofs.
Require Import Coq.Lists.List.
Require Import Crypto.Util.Tuple.
Require Import Crypto.Util.ListUtil Crypto.Util.ZUtil Crypto.Util.NatUtil Crypto.Util.CaseUtil.
@@ -30,6 +31,7 @@ Definition Z_mul_opt := Eval compute in Z.mul.
Definition Z_div_opt := Eval compute in Z.div.
Definition Z_pow_opt := Eval compute in Z.pow.
Definition Z_opp_opt := Eval compute in Z.opp.
+Definition Z_ones_opt := Eval compute in Z.ones.
Definition Z_shiftl_opt := Eval compute in Z.shiftl.
Definition Z_shiftl_by_opt := Eval compute in Z.shiftl_by.
@@ -115,35 +117,53 @@ Section Carries.
Local Notation base := (Pow2Base.base_from_limb_widths limb_widths).
Local Notation digits := (tuple Z (length limb_widths)).
+ Definition carry_gen_opt_sig fc fi i us
+ : { d : list Z | (0 <= fi (S (fi i)) < length us)%nat ->
+ d = carry_gen limb_widths fc fi i us}.
+ Proof.
+ eexists; intros.
+ cbv beta iota delta [carry_gen carry_single Z.pow2_mod].
+ rewrite add_to_nth_set_nth.
+ change @nth_default with @nth_default_opt in *.
+ change @set_nth with @set_nth_opt in *.
+ change Z.ones with Z_ones_opt.
+ rewrite set_nth_nth_default by assumption.
+ rewrite <- @beq_nat_eq_nat_dec.
+ reflexivity.
+ Defined.
+
+ Definition carry_gen_opt fc fi i us := Eval cbv [proj1_sig carry_gen_opt_sig] in
+ proj1_sig (carry_gen_opt_sig fc fi i us).
+
+ Definition carry_gen_opt_correct fc fi i us
+ : (0 <= fi (S (fi i)) < length us)%nat ->
+ carry_gen_opt fc fi i us = carry_gen limb_widths fc fi i us
+ := proj2_sig (carry_gen_opt_sig fc fi i us).
+
Definition carry_opt_sig
- (i : nat) (b : digits)
- : { d : digits | (i < length limb_widths)%nat -> d = carry i b }.
+ (i : nat) (b : list Z)
+ : { d : list Z | (length b = length limb_widths)
+ -> (i < length limb_widths)%nat
+ -> d = carry i b }.
Proof.
eexists ; intros.
- cbv [carry ModularBaseSystemList.carry].
- rewrite <-from_list_default_eq with (d := 0%Z).
+ cbv [carry].
rewrite <-pull_app_if_sumbool.
cbv beta delta
- [carry carry_and_reduce Pow2Base.carry_gen Pow2Base.carry_single Pow2Base.carry_simple
- Z.pow2_mod Z.ones Z.pred
- PseudoMersenneBaseParams.limb_widths].
- rewrite !add_to_nth_set_nth.
- change @Pow2Base.base_from_limb_widths with @base_from_limb_widths_opt.
- change @nth_default with @nth_default_opt in *.
- change @set_nth with @set_nth_opt in *.
+ [carry carry_and_reduce carry_simple].
lazymatch goal with
- | [ |- _ = ?f (if ?br then ?c else ?d) ]
- => let x := fresh "x" in let y := fresh "y" in evar (x:digits); evar (y:digits); transitivity (if br then x else y); subst x; subst y
+ | [ |- _ = (if ?br then ?c else ?d) ]
+ => let x := fresh "x" in let y := fresh "y" in evar (x:list Z); evar (y:list Z); transitivity (if br then x else y); subst x; subst y
end.
- 2:cbv zeta.
- 2:break_if; reflexivity.
-
- change @from_list_default with @from_list_default_opt.
- change @nth_default with @nth_default_opt.
+ Focus 2. {
+ cbv zeta.
+ break_if; rewrite <-carry_gen_opt_correct by (omega ||
+ (replace (length b) with (length limb_widths) by congruence;
+ apply Nat.mod_bound_pos; omega)); reflexivity.
+ } Unfocus.
rewrite c_subst.
- change @set_nth with @set_nth_opt.
- change @map with @map_opt.
rewrite <- @beq_nat_eq_nat_dec.
+ cbv [carry_gen_opt].
reflexivity.
Defined.
@@ -151,11 +171,15 @@ Section Carries.
proj1_sig (carry_opt_sig is us).
Definition carry_opt_correct i us
- : (i < length limb_widths)%nat -> carry_opt i us = carry i us
+ : length us = length limb_widths
+ -> (i < length limb_widths)%nat
+ -> carry_opt i us = carry i us
:= proj2_sig (carry_opt_sig i us).
- Definition carry_sequence_opt_sig (is : list nat) (us : digits)
- : { b : digits | (forall i, In i is -> i < length base)%nat -> b = carry_sequence is us }.
+ Definition carry_sequence_opt_sig (is : list nat) (us : list Z)
+ : { b : list Z | (length us = length limb_widths)
+ -> (forall i, In i is -> i < length limb_widths)%nat
+ -> b = carry_sequence is us }.
Proof.
eexists. intros H.
cbv [carry_sequence].
@@ -164,9 +188,9 @@ Section Carries.
{ induction is; [ reflexivity | ].
simpl; rewrite IHis, carry_opt_correct.
- reflexivity.
- - rewrite base_length in H.
- apply H; apply in_eq.
- - intros. apply H. right. auto.
+ - fold (carry_sequence is us). auto using length_carry_sequence.
+ - auto using in_eq.
+ - intros. auto using in_cons.
}
Unfocus.
reflexivity.
@@ -176,123 +200,128 @@ Section Carries.
proj1_sig (carry_sequence_opt_sig is us).
Definition carry_sequence_opt_correct is us
- : (forall i, In i is -> i < length base)%nat -> carry_sequence_opt is us = carry_sequence is us
+ : (length us = length limb_widths)
+ -> (forall i, In i is -> i < length limb_widths)%nat
+ -> carry_sequence_opt is us = carry_sequence is us
:= proj2_sig (carry_sequence_opt_sig is us).
- Definition carry_opt_cps_sig
- {T}
+ Definition carry_gen_opt_cps_sig
+ {T} fc fi
(i : nat)
- (f : digits -> T)
- (b : digits)
- : { d : T | (i < length base)%nat -> d = f (carry i b) }.
+ (f : list Z -> T)
+ (b : list Z)
+ : { d : T | (0 <= fi (S (fi i)) < length b)%nat -> d = f (carry_gen limb_widths fc fi i b) }.
Proof.
eexists. intros H.
- rewrite <- carry_opt_correct by (rewrite base_length in H; assumption).
- cbv beta iota delta [carry_opt].
+ rewrite <-carry_gen_opt_correct by assumption.
+ cbv beta iota delta [carry_gen_opt].
+ match goal with |- appcontext[?a & Z_ones_opt _] =>
let LHS := match goal with |- ?LHS = ?RHS => LHS end in
let RHS := match goal with |- ?LHS = ?RHS => RHS end in
- let RHSf := match (eval pattern (nth_default_opt 0%Z (to_list _ b) i) in RHS) with ?RHSf _ => RHSf end in
- change (LHS = Let_In (nth_default_opt 0%Z (to_list _ b) i) RHSf).
- change Z.shiftl with Z_shiftl_opt.
- match goal with |- appcontext[ ?x + -1] => change (x + -1) with (Z_add_opt x (Z_opp_opt 1)) end.
+ let RHSf := match (eval pattern (a) in RHS) with ?RHSf _ => RHSf end in
+ change (LHS = Let_In (a) RHSf) end.
reflexivity.
Defined.
- Definition carry_opt_cps {T} i f b
- := Eval cbv beta iota delta [proj1_sig carry_opt_cps_sig] in proj1_sig (@carry_opt_cps_sig T i f b).
+ Definition carry_gen_opt_cps {T} fc fi i f b
+ := Eval cbv beta iota delta [proj1_sig carry_gen_opt_cps_sig] in
+ proj1_sig (@carry_gen_opt_cps_sig T fc fi i f b).
- Definition carry_opt_cps_correct {T} i f b :
- (i < length base)%nat ->
- @carry_opt_cps T i f b = f (carry i b)
- := proj2_sig (carry_opt_cps_sig i f b).
+ Definition carry_gen_opt_cps_correct {T} fc fi i f b :
+ (0 <= fi (S (fi i)) < length b)%nat ->
+ @carry_gen_opt_cps T fc fi i f b = f (carry_gen limb_widths fc fi i b)
+ := proj2_sig (carry_gen_opt_cps_sig fc fi i f b).
- Definition carry_sequence_opt_cps2_sig {T} (is : list nat) (us : digits)
- (f : digits -> T)
- : { b : T | (forall i, In i is -> i < length base)%nat -> b = f (carry_sequence is us) }.
+ Definition carry_opt_cps_sig
+ {T}
+ (i : nat)
+ (f : list Z -> T)
+ (b : list Z)
+ : { d : T | (length b = length limb_widths)
+ -> (i < length limb_widths)%nat
+ -> d = f (carry i b) }.
Proof.
- eexists.
- cbv [carry_sequence].
- transitivity (fold_right carry_opt_cps f (List.rev is) us).
- Focus 2.
- {
- assert (forall i, In i (rev is) -> i < length base)%nat as Hr. {
- subst. intros. rewrite <- in_rev in *. auto. }
- remember (rev is) as ris eqn:Heq.
- rewrite <- (rev_involutive is), <- Heq.
- clear H Heq is.
- rewrite fold_left_rev_right.
- revert us; induction ris; [ reflexivity | ]; intros.
- { simpl.
- rewrite <- IHris; clear IHris; [|intros; apply Hr; right; assumption].
- rewrite carry_opt_cps_correct; [reflexivity|].
- apply Hr; left; reflexivity.
- } }
- Unfocus.
+ eexists. intros.
+ cbv beta delta
+ [carry carry_and_reduce carry_simple].
+ rewrite <-pull_app_if_sumbool.
+ lazymatch goal with
+ | [ |- _ = ?f (if ?br then ?c else ?d) ]
+ => let x := fresh "x" in let y := fresh "y" in evar (x:T); evar (y:T); transitivity (if br then x else y); subst x; subst y
+ end.
+ Focus 2. {
+ cbv zeta.
+ break_if; rewrite <-carry_gen_opt_cps_correct by (omega ||
+ (replace (length b) with (length limb_widths) by congruence;
+ apply Nat.mod_bound_pos; omega)); reflexivity.
+ } Unfocus.
+ rewrite c_subst.
+ rewrite <- @beq_nat_eq_nat_dec.
reflexivity.
Defined.
- Definition carry_sequence_opt_cps2 {T} is us (f : digits -> T) :=
- Eval cbv [proj1_sig carry_sequence_opt_cps2_sig] in
- proj1_sig (carry_sequence_opt_cps2_sig is us f).
+ Definition carry_opt_cps {T} i f b
+ := Eval cbv beta iota delta [proj1_sig carry_opt_cps_sig] in proj1_sig (@carry_opt_cps_sig T i f b).
- Definition carry_sequence_opt_cps2_correct {T} is us (f : digits -> T)
- : (forall i, In i is -> i < length base)%nat -> carry_sequence_opt_cps2 is us f = f (carry_sequence is us)
- := proj2_sig (carry_sequence_opt_cps2_sig is us f).
+ Definition carry_opt_cps_correct {T} i f b :
+ (length b = length limb_widths)
+ -> (i < length limb_widths)%nat
+ -> @carry_opt_cps T i f b = f (carry i b)
+ := proj2_sig (carry_opt_cps_sig i f b).
- Definition carry_sequence_opt_cps_sig (is : list nat) (us : digits)
- : { b : digits | (forall i, In i is -> i < length base)%nat -> b = carry_sequence is us }.
+ Definition carry_sequence_opt_cps_sig {T} (is : list nat) (us : list Z)
+ (f : list Z -> T)
+ : { b : T | (length us = length limb_widths)
+ -> (forall i, In i is -> i < length limb_widths)%nat
+ -> b = f (carry_sequence is us) }.
Proof.
eexists.
cbv [carry_sequence].
- transitivity (fold_right carry_opt_cps id (List.rev is) us).
+ transitivity (fold_right carry_opt_cps f (List.rev is) us).
Focus 2.
{
- assert (forall i, In i (rev is) -> i < length base)%nat as Hr. {
+ assert (forall i, In i (rev is) -> i < length limb_widths)%nat as Hr. {
subst. intros. rewrite <- in_rev in *. auto. }
remember (rev is) as ris eqn:Heq.
- rewrite <- (rev_involutive is), <- Heq.
- clear H Heq is.
+ rewrite <- (rev_involutive is), <- Heq in H0 |- *.
+ clear H0 Heq is.
rewrite fold_left_rev_right.
- revert us; induction ris; [ reflexivity | ]; intros.
+ revert H. revert us; induction ris; [ reflexivity | ]; intros.
{ simpl.
- rewrite <- IHris; clear IHris; [|intros; apply Hr; right; assumption].
- rewrite carry_opt_cps_correct; [reflexivity|].
+ rewrite <- IHris; clear IHris;
+ [|intros; apply Hr; right; assumption|auto using length_carry].
+ rewrite carry_opt_cps_correct; [reflexivity|congruence|].
apply Hr; left; reflexivity.
} }
Unfocus.
+ cbv [carry_opt_cps].
reflexivity.
Defined.
- Definition carry_sequence_opt_cps is us := Eval cbv [proj1_sig carry_sequence_opt_cps_sig] in
- proj1_sig (carry_sequence_opt_cps_sig is us).
-
- Definition carry_sequence_opt_cps_correct is us
- : (forall i, In i is -> i < length base)%nat -> carry_sequence_opt_cps is us = carry_sequence is us
- := proj2_sig (carry_sequence_opt_cps_sig is us).
-
+ Definition carry_sequence_opt_cps {T} is us (f : list Z -> T) :=
+ Eval cbv [proj1_sig carry_sequence_opt_cps_sig] in
+ proj1_sig (carry_sequence_opt_cps_sig is us f).
- Lemma carry_sequence_opt_cps_rep
- : forall (is : list nat) (us : digits) (x : F modulus),
- (forall i : nat, In i is -> i < length base)%nat ->
- rep us x -> rep (carry_sequence_opt_cps is us) x.
- Proof.
- intros.
- rewrite carry_sequence_opt_cps_correct by assumption.
- auto using carry_sequence_rep.
- Qed.
+ Definition carry_sequence_opt_cps_correct {T} is us (f : list Z -> T)
+ : (length us = length limb_widths)
+ -> (forall i, In i is -> i < length limb_widths)%nat
+ -> carry_sequence_opt_cps is us f = f (carry_sequence is us)
+ := proj2_sig (carry_sequence_opt_cps_sig is us f).
- Lemma full_carry_chain_bounds : forall i, In i (Pow2Base.full_carry_chain limb_widths) -> (i < length base)%nat.
+ Lemma full_carry_chain_bounds : forall i, In i (Pow2Base.full_carry_chain limb_widths) ->
+ (i < length limb_widths)%nat.
Proof.
- unfold Pow2Base.full_carry_chain; rewrite <-base_length; intros.
+ unfold Pow2Base.full_carry_chain; intros.
apply Pow2BaseProofs.make_chain_lt; auto.
Qed.
Definition carry_full_opt_sig (us : digits) : { b : digits | b = carry_full us }.
Proof.
eexists.
- cbv [carry_full].
+ cbv [carry_full ModularBaseSystemList.carry_full].
+ rewrite <-from_list_default_eq with (d := 0).
+ rewrite <-carry_sequence_opt_cps_correct by (rewrite ?length_to_list; auto; apply full_carry_chain_bounds).
change @Pow2Base.full_carry_chain with full_carry_chain_opt.
- rewrite <-carry_sequence_opt_cps_correct by (auto; apply full_carry_chain_bounds).
reflexivity.
Defined.
@@ -311,8 +340,10 @@ Section Carries.
eexists.
rewrite <- carry_full_opt_correct.
cbv beta iota delta [carry_full_opt].
- rewrite carry_sequence_opt_cps_correct by apply full_carry_chain_bounds.
- rewrite <-carry_sequence_opt_cps2_correct by apply full_carry_chain_bounds.
+ rewrite carry_sequence_opt_cps_correct by (apply length_to_list || apply full_carry_chain_bounds).
+ match goal with |- ?LHS = ?f (?g (carry_sequence ?is ?us)) =>
+ change (LHS = (fun x => f (g x)) (carry_sequence is us)) end.
+ rewrite <-carry_sequence_opt_cps_correct by (apply length_to_list || apply full_carry_chain_bounds).
reflexivity.
Defined.
@@ -518,7 +549,16 @@ Section Multiplication.
cbv [carry_mul].
erewrite <-carry_full_opt_cps_correct by eauto.
erewrite <-mul_opt_correct.
+ cbv [carry_full_opt_cps mul_opt].
+ erewrite from_list_default_eq.
+ rewrite to_list_from_list.
reflexivity.
+ Grab Existential Variables.
+ rewrite mul'_opt_correct.
+ distr_length.
+ assert (0 < length limb_widths)%nat by (pose proof limb_widths_nonnil; destruct limb_widths; congruence || simpl; omega).
+ rewrite Min.min_l; rewrite !length_to_list; break_match; try omega.
+ rewrite Max.max_l; omega.
Defined.
Definition carry_mul_opt_cps {T} (f:digits -> T) (us vs : digits) : T
diff --git a/src/ModularArithmetic/ModularBaseSystemProofs.v b/src/ModularArithmetic/ModularBaseSystemProofs.v
index e6351dc17..01c073f06 100644
--- a/src/ModularArithmetic/ModularBaseSystemProofs.v
+++ b/src/ModularArithmetic/ModularBaseSystemProofs.v
@@ -184,7 +184,7 @@ Section PseudoMersenneProofs.
Proof.
intros.
apply Z_div_exact_2; try (apply nth_default_base_positive; omega).
- apply base_succ; eauto.
+ apply base_succ; distr_length; eauto.
Qed.
Lemma Fdecode_decode_mod : forall us x,
@@ -253,40 +253,41 @@ Section CarryProofs.
Lemma carry_decode_eq_reduce : forall us,
(length us = length limb_widths) ->
- BaseSystem.decode base (carry_and_reduce (pred (length base)) us) mod modulus
+ BaseSystem.decode base (carry_and_reduce (pred (length limb_widths)) us) mod modulus
= BaseSystem.decode base us mod modulus.
Proof.
- unfold carry_and_reduce; intros ? length_eq.
- pose proof base_length_nonzero.
- rewrite add_to_nth_sum by (rewrite length_set_nth, base_length; omega).
- rewrite set_nth_sum by (rewrite base_length; omega).
- rewrite Zplus_comm, <- Z.mul_assoc, <- pseudomersenne_add, BaseSystem.b0_1.
- rewrite (Z.mul_comm (2 ^ k)), <- Zred_factor0.
- f_equal.
- rewrite <- (Z.add_comm (BaseSystem.decode base us)), <- Z.add_assoc, <- Z.add_0_r.
- f_equal.
- destruct (NPeano.Nat.eq_dec (length base) 0) as [length_zero | length_nonzero].
- + pose proof (base_length) as limbs_length.
- destruct us; rewrite ?(@nil_length0 Z), ?(@length_cons Z) in *;
- pose proof base_length_nonzero; omega.
- + rewrite nth_default_base by (omega || eauto).
- rewrite <- Z.add_opp_l, <- Z.opp_sub_distr.
- unfold Z.pow2_mod.
- rewrite Z.land_ones by eauto using log_cap_nonneg.
- rewrite <- Z.mul_div_eq by (apply Z.gt_lt_iff; apply Z.pow_pos_nonneg; omega || eauto using log_cap_nonneg).
- rewrite Z.shiftr_div_pow2 by eauto using log_cap_nonneg.
- unfold k.
- replace (length limb_widths) with (S (pred (length base))) by
- (subst; rewrite <- base_length; apply NPeano.Nat.succ_pred; omega).
- rewrite sum_firstn_succ with (x:= log_cap (pred (length base))) by
- (apply nth_error_Some_nth_default; rewrite <- base_length; omega).
- rewrite Z.pow_add_r by eauto using log_cap_nonneg.
- ring.
+ cbv [carry_and_reduce]; intros.
+ rewrite carry_gen_decode_eq; auto.
+ distr_length.
+ assert (0 < length limb_widths)%nat by (pose proof limb_widths_nonnil;
+ destruct limb_widths; distr_length; congruence).
+ repeat break_if; repeat rewrite ?pred_mod, ?Nat.succ_pred,?Nat.mod_same in * by omega;
+ try omega.
+ rewrite !nth_default_base by (omega || auto).
+ rewrite sum_firstn_0.
+ autorewrite with zsimplify.
+ match goal with |- appcontext [2 ^ ?a * ?b * 2 ^ ?c] =>
+ replace (2 ^ a * b * 2 ^ c) with (2 ^ (a + c) * b) end.
+ { rewrite <-sum_firstn_succ by (apply nth_error_Some_nth_default; omega).
+ rewrite Nat.succ_pred by omega.
+ remember (Init.Nat.pred (length limb_widths)) as pred_len.
+ fold k.
+ rewrite <-Z.mul_sub_distr_r.
+ replace (c - 2 ^ k) with (modulus * -1) by (cbv [c]; ring).
+ rewrite <-Z.mul_assoc.
+ apply Z.mod_add_l'.
+ pose proof prime_modulus. Z.prime_bound. }
+ { rewrite Z.pow_add_r; auto using log_cap_nonneg, sum_firstn_limb_widths_nonneg.
+ rewrite <-!Z.mul_assoc.
+ apply Z.mul_cancel_l; try ring.
+ apply Z.pow_nonzero; (omega || auto using log_cap_nonneg). }
Qed.
Lemma carry_rep : forall i us x,
- (i < length base)%nat ->
- us ~= x -> carry i us ~= x.
+ (length us = length limb_widths)%nat ->
+ (i < length limb_widths)%nat ->
+ forall pf1 pf2,
+ from_list _ us pf1 ~= x -> from_list _ (carry i us) pf2 ~= x.
Proof.
cbv [carry rep decode]; intros.
rewrite to_list_from_list.
@@ -295,18 +296,24 @@ Section CarryProofs.
specialize_by eauto.
cbv [ModularBaseSystemList.carry].
break_if; subst; eauto.
- apply F_eq; apply carry_decode_eq_reduce; apply length_to_list.
+ apply F_eq.
+ rewrite to_list_from_list.
+ apply carry_decode_eq_reduce. auto.
cbv [ModularBaseSystemList.decode].
- f_equal.
- apply carry_simple_decode_eq; try omega; rewrite ?base_length; auto using length_to_list.
+ apply ZToField_eqmod.
+ rewrite to_list_from_list, carry_simple_decode_eq; try omega; distr_length; auto.
Qed.
Hint Resolve carry_rep.
Lemma carry_sequence_rep : forall is us x,
- (forall i, In i is -> (i < length base)%nat) ->
- us ~= x -> carry_sequence is us ~= x.
+ (forall i, In i is -> (i < length limb_widths)%nat) ->
+ us ~= x -> forall pf, from_list _ (carry_sequence is (to_list _ us)) pf ~= x.
Proof.
- induction is; boring.
+ induction is; intros.
+ + cbv [carry_sequence fold_right]. rewrite from_list_to_list. assumption.
+ + simpl. apply carry_rep with (pf1 := length_carry_sequence (length_to_list us));
+ auto using length_carry_sequence, length_to_list, in_eq.
+ apply IHis; auto using in_cons.
Qed.
Lemma carry_full_preserves_rep : forall us x,
@@ -314,7 +321,7 @@ Section CarryProofs.
Proof.
unfold carry_full; intros.
apply carry_sequence_rep; auto.
- unfold full_carry_chain; rewrite base_length; apply make_chain_lt.
+ unfold full_carry_chain; apply make_chain_lt.
Qed.
Opaque carry_full.
@@ -334,1561 +341,227 @@ Section CanonicalizationProofs.
Context `{prm : PseudoMersenneBaseParams}.
Local Notation base := (base_from_limb_widths limb_widths).
Local Notation log_cap i := (nth_default 0 limb_widths i).
- Context (lt_1_length_base : (1 < length base)%nat)
+ Context (lt_1_length_base : (1 < length limb_widths)%nat)
{B} (B_pos : 0 < B) (B_compat : forall w, In w limb_widths -> w <= B)
(* on the first reduce step, we add at most one bit of width to the first digit *)
- (c_reduce1 : c * (Z.ones (B - log_cap (pred (length base)))) < 2 ^ log_cap 0)
+ (c_reduce1 : c * ((2 ^ B) >> log_cap (pred (length limb_widths))) <= 2 ^ log_cap 0)
(* on the second reduce step, we add at most one bit of width to the first digit,
and leave room to carry c one more time after the highest bit is carried *)
(c_reduce2 : c <= nth_default 0 modulus_digits 0)
(* this condition is probably implied by c_reduce2, but is more straighforward to compute than to prove *)
(two_pow_k_le_2modulus : 2 ^ k <= 2 * modulus).
-(*
- (* BEGIN groundwork proofs *)
- Local Hint Resolve (@log_cap_nonneg limb_widths) limb_widths_nonneg.
-
- Lemma pow_2_log_cap_pos : forall i, 0 < 2 ^ log_cap i.
- Proof.
- intros; apply Z.pow_pos_nonneg; eauto using log_cap_nonneg; omega.
- Qed.
- Local Hint Resolve pow_2_log_cap_pos.
-
- Lemma max_value_log_cap : forall i, Z.succ (max_value i) = 2 ^ log_cap i.
- Proof.
- intros.
- unfold max_value, Z.ones.
- rewrite Z.shiftl_1_l.
- omega.
- Qed.
-
- Lemma pow2_mod_log_cap_range : forall a i, 0 <= Z.pow2_mod a (log_cap i) <= max_value i.
- Proof.
- intros.
- unfold Z.pow2_mod.
- rewrite Z.land_ones by eauto using log_cap_nonneg.
- unfold max_value, Z.ones.
- rewrite Z.shiftl_1_l, <-Z.lt_le_pred.
- apply Z_mod_lt.
- pose proof (pow_2_log_cap_pos i).
- omega.
- Qed.
-
- Lemma pow2_mod_log_cap_bounds_lower : forall a i, 0 <= Z.pow2_mod a (log_cap i).
- Proof.
- intros.
- pose proof (pow2_mod_log_cap_range a i); omega.
- Qed.
-
- Lemma pow2_mod_log_cap_bounds_upper : forall a i, Z.pow2_mod a (log_cap i) <= max_value i.
- Proof.
- intros.
- pose proof (pow2_mod_log_cap_range a i); omega.
- Qed.
-
- Lemma pow2_mod_log_cap_small : forall a i, 0 <= a <= max_value i ->
- Z.pow2_mod a (log_cap i) = a.
- Proof.
- intros.
- unfold Z.pow2_mod.
- rewrite Z.land_ones by eauto using log_cap_nonneg.
- apply Z.mod_small.
- split; try omega.
- rewrite <- max_value_log_cap.
- omega.
- Qed.
-
- Lemma max_value_pos : forall i, (i < length base)%nat -> 0 < max_value i.
- Proof.
- unfold max_value; intros; apply Z.ones_pos_pos.
- apply limb_widths_pos.
- rewrite nth_default_eq.
- apply nth_In.
- rewrite <-base_length; assumption.
- Qed.
- Local Hint Resolve max_value_pos.
-
- Lemma max_value_nonneg : forall i, 0 <= max_value i.
- Proof.
- unfold max_value; intros; eauto using Z.ones_nonneg.
- Qed.
- Local Hint Resolve max_value_nonneg.
-
- Lemma shiftr_eq_0_max_value : forall i a, Z.shiftr a (log_cap i) = 0 ->
- a <= max_value i.
- Proof.
- intros ? ? shiftr_0.
- apply Z.shiftr_eq_0_iff in shiftr_0.
- intuition; subst; try apply max_value_nonneg.
- match goal with H : Z.log2 _ < log_cap _ |- _ => apply Z.log2_lt_pow2 in H;
- replace (2 ^ log_cap i) with (Z.succ (max_value i)) in H by
- (unfold max_value, Z.ones; rewrite Z.shiftl_1_l; omega)
- end; auto.
- omega.
- Qed.
-
- Lemma B_compat_log_cap : forall i, 0 <= B - log_cap i.
- Proof.
- intros.
- destruct (lt_dec i (length limb_widths)).
- + apply Z.le_0_sub.
- apply B_compat.
- rewrite nth_default_eq.
- apply nth_In; assumption.
- + replace (nth_default 0 limb_widths i) with 0; try omega.
- symmetry; apply nth_default_out_of_bounds.
- omega.
- Qed.
- Local Hint Resolve B_compat_log_cap.
-
- Lemma max_value_shiftr_eq_0 : forall i a, 0 <= a -> a <= max_value i ->
- Z.shiftr a (log_cap i) = 0.
- Proof.
- intros ? ? ? le_max_value.
- apply Z.shiftr_eq_0_iff.
- destruct (Z_eq_dec a 0); auto.
- right.
- split; try omega.
- apply Z.log2_lt_pow2; try omega.
- rewrite <-max_value_log_cap.
- omega.
- Qed.
-
- Lemma log_cap_eq : forall i, log_cap i = nth_default 0 limb_widths i.
- Proof.
- reflexivity.
- Qed.
-
- (* END groundwork proofs *)
- Opaque Z.pow2_mod max_value.
-
- (* automation *)
- Ltac carry_length_conditions' := unfold carry_full;
- rewrite ?length_set_nth, ?length_add_to_nth, ?length_carry, ?carry_sequence_length;
- try omega; try solve [pose proof base_length; pose proof base_length_nonzero; omega || auto ].
- Ltac carry_length_conditions := try split; try omega; repeat carry_length_conditions'.
-
- Ltac add_set_nth :=
- rewrite ?add_to_nth_nth_default by carry_length_conditions; break_if; try omega;
- rewrite ?set_nth_nth_default by carry_length_conditions; break_if; try omega.
-
- (* BEGIN defs *)
-
- Definition pre_carry_bounds us := forall i, 0 <= nth_default 0 us i <
- if (eq_nat_dec i 0) then 2 ^ B else 2 ^ B - 2 ^ (B - log_cap (pred i)).
-
- Lemma pre_carry_bounds_nonzero : forall us, pre_carry_bounds us ->
- (forall i, 0 <= nth_default 0 us i).
- Proof.
- unfold pre_carry_bounds.
- intros ? PCB i.
- specialize (PCB i).
- omega.
- Qed.
- Local Hint Resolve pre_carry_bounds_nonzero.
-
- (* END defs *)
-
- (* BEGIN proofs about first carry loop *)
-
- Lemma nth_default_carry_bound_upper : forall i us, (length us = length base) ->
- nth_default 0 (carry i us) i <= max_value i.
- Proof.
- unfold carry; intros.
- break_if.
- + unfold carry_and_reduce.
- add_set_nth.
- apply pow2_mod_log_cap_bounds_upper.
- + autorewrite with push_nth_default natsimplify.
- destruct (lt_dec i (length us)); auto using pow2_mod_log_cap_bounds_upper.
- Qed.
- Local Hint Resolve nth_default_carry_bound_upper.
-
- Lemma nth_default_carry_bound_lower : forall i us, (length us = length base) ->
- 0 <= nth_default 0 (carry i us) i.
- Proof.
- unfold carry; intros.
- break_if.
- + unfold carry_and_reduce.
- add_set_nth.
- apply pow2_mod_log_cap_bounds_lower.
- + autorewrite with push_nth_default natsimplify.
- break_if; auto using pow2_mod_log_cap_bounds_lower, Z.le_refl.
- Qed.
- Local Hint Resolve nth_default_carry_bound_lower.
-
- Lemma nth_default_carry_bound_succ_lower : forall i us, (forall i, 0 <= nth_default 0 us i) ->
- (length us = length base) ->
- 0 <= nth_default 0 (carry i us) (S i).
- Proof.
- unfold carry; intros ? ? PCB ? .
- break_if.
- + subst. replace (S (pred (length base))) with (length base) by omega.
- rewrite nth_default_out_of_bounds; carry_length_conditions.
- unfold carry_and_reduce.
- carry_length_conditions.
- + autorewrite with push_nth_default natsimplify.
- break_if; zero_bounds.
- Qed.
-
- Lemma carry_unaffected_low : forall i j us, ((0 < i < j)%nat \/ (i = 0 /\ j <> 0 /\ j <> pred (length base))%nat)->
- (length us = length base) ->
- nth_default 0 (carry j us) i = nth_default 0 us i.
- Proof.
- intros.
- unfold carry.
- break_if.
- + unfold carry_and_reduce.
- add_set_nth.
- + autorewrite with push_nth_default simpl_nth_default natsimplify.
- repeat break_if; autorewrite with simpl_nth_default natsimplify; omega.
- Qed.
-
- Lemma carry_unaffected_high : forall i j us, (S j < i)%nat -> (length us = length base) ->
- nth_default 0 (carry j us) i = nth_default 0 us i.
- Proof.
- intros.
- destruct (lt_dec i (length us));
- [ | rewrite !nth_default_out_of_bounds by carry_length_conditions; reflexivity].
- unfold carry.
- break_if; [omega | autorewrite with push_nth_default natsimplify; repeat break_if; omega ].
- Qed.
-
- Hint Rewrite max_bound_shiftr_eq_0 using omega : core.
- Hint Rewrite pow2_mod_log_cap_small using assumption : core.
-
- Lemma carry_nothing : forall i j us, (i < length base)%nat ->
- (length us = length base)%nat ->
- 0 <= nth_default 0 us j <= max_value j ->
- nth_default 0 (carry j us) i = nth_default 0 us i.
- Proof.
- unfold carry, carry_and_reduce; intros.
- repeat (break_if
- || subst
- || (autorewrite with push_nth_default natsimplify core)
- || omega).
- Qed.
-
- Hint Rewrite pow2_mod_log_cap_small using (intuition; auto using shiftr_eq_0_max_bound) : core.
-
- Lemma carry_carry_done_done : forall i us,
- (length us = length base)%nat ->
- (i < length base)%nat ->
- carry_done us -> carry_done (carry i us).
- Proof.
- unfold carry_done; intros i ? ? i_bound Hcarry_done x x_bound.
- destruct (Hcarry_done x x_bound) as [lower_bound_x shiftr_0_x].
- destruct (Hcarry_done i i_bound) as [lower_bound_i shiftr_0_i].
- split.
- + rewrite carry_nothing; auto.
- split; [ apply Hcarry_done; auto | ].
- apply shiftr_eq_0_max_value.
- apply Hcarry_done; auto.
- + unfold carry, carry_and_reduce; break_if; subst.
- - add_set_nth; subst.
- * rewrite shiftr_0_i, Z.mul_0_r, Z.add_0_l.
- assumption.
- * rewrite pow2_mod_log_cap_small by (intuition; auto using shiftr_eq_0_max_value).
- assumption.
- - repeat (carry_length_conditions
- || (autorewrite with push_nth_default natsimplify core zsimplify)
- || break_if
- || subst
- || rewrite shiftr_0_i by omega).
- Qed.
-
- Lemma carry_sequence_chain_step : forall i us,
- carry_sequence (make_chain (S i)) us = carry i (carry_sequence (make_chain i) us).
- Proof.
- reflexivity.
- Qed.
-
- Lemma carry_bounds_0_upper : forall us j, (length us = length base) ->
- (0 < j < length base)%nat ->
- nth_default 0 (carry_sequence (make_chain j) us) 0 <= max_value 0.
- Proof.
- induction j as [ | [ | j ] IHj ]; [simpl; intros; omega | | ]; intros.
- + subst; simpl; auto.
- + rewrite carry_sequence_chain_step, carry_unaffected_low; carry_length_conditions.
- Qed.
-
- Lemma carry_bounds_upper : forall i us j, (0 < i < j)%nat -> (length us = length base) ->
- nth_default 0 (carry_sequence (make_chain j) us) i <= max_value i.
- Proof.
- unfold carry_sequence;
- induction j; [simpl; intros; omega | ].
- intros.
- simpl in *.
- assert (i = j \/ i < j)%nat as cases by omega.
- destruct cases as [eq_j_i | lt_i_j]; subst.
- + apply nth_default_carry_bound_upper; fold (carry_sequence (make_chain j) us); carry_length_conditions.
- + rewrite carry_unaffected_low; try omega.
- fold (carry_sequence (make_chain j) us); carry_length_conditions.
- Qed.
-
- Lemma carry_sequence_unaffected : forall i us j, (j < i)%nat -> (length us = length base)%nat ->
- nth_default 0 (carry_sequence (make_chain j) us) i = nth_default 0 us i.
- Proof.
- induction j; [simpl; intros; omega | ].
- intros.
- simpl in *.
- rewrite carry_unaffected_high by carry_length_conditions.
- apply IHj; omega.
- Qed.
-
- (* makes omega run faster *)
- Ltac clear_obvious :=
- match goal with
- | [H : ?a <= ?a |- _] => clear H
- | [H : ?a <= S ?a |- _] => clear H
- | [H : ?a < S ?a |- _] => clear H
- | [H : ?a = ?a |- _] => clear H
- end.
-
- Lemma carry_sequence_bounds_lower : forall j i us, (length us = length base) ->
- (forall i, 0 <= nth_default 0 us i) -> (j <= length base)%nat ->
- 0 <= nth_default 0 (carry_sequence (make_chain j) us) i.
- Proof.
- induction j; intros; simpl; auto.
- destruct (lt_dec (S j) i).
- + rewrite carry_unaffected_high by carry_length_conditions.
- apply IHj; auto; omega.
- + assert ((i = S j) \/ (i = j) \/ (i < j))%nat as cases by omega.
- destruct cases as [? | [? | ?]].
- - subst. apply nth_default_carry_bound_succ_lower; carry_length_conditions.
- intros; eapply IHj; auto; omega.
- - subst. apply nth_default_carry_bound_lower; carry_length_conditions.
- - destruct (eq_nat_dec j (pred (length base)));
- [ | rewrite carry_unaffected_low by carry_length_conditions; apply IHj; auto; omega ].
- subst.
- do 2 match goal with H : appcontext[S (pred (length base))] |- _ =>
- erewrite <-(S_pred (length base)) in H by eauto end.
- unfold carry; break_if; [ unfold carry_and_reduce | omega ].
- clear_obvious. pose proof c_pos.
- add_set_nth; [ zero_bounds | ]; apply IHj; auto; omega.
- Qed.
-
- Ltac carry_seq_lower_bound :=
- repeat (intros; eapply carry_sequence_bounds_lower; eauto; carry_length_conditions).
-
- Lemma carry_bounds_lower : forall i us j, (0 < i <= j)%nat -> (length us = length base) ->
- (forall i, 0 <= nth_default 0 us i) -> (j <= length base)%nat ->
- 0 <= nth_default 0 (carry_sequence (make_chain j) us) i.
- Proof.
- unfold carry_sequence;
- induction j; [simpl; intros; omega | ].
- intros.
- simpl in *.
- destruct (eq_nat_dec i (S j)).
- + subst. apply nth_default_carry_bound_succ_lower; auto;
- fold (carry_sequence (make_chain j) us); carry_length_conditions.
- carry_seq_lower_bound.
- + assert (i = j \/ i < j)%nat as cases by omega.
- destruct cases as [eq_j_i | lt_i_j]; subst;
- [apply nth_default_carry_bound_lower| rewrite carry_unaffected_low]; try omega;
- fold (carry_sequence (make_chain j) us); carry_length_conditions.
- carry_seq_lower_bound.
- Qed.
-
- Lemma carry_full_bounds : forall us i, (i <> 0)%nat -> (forall i, 0 <= nth_default 0 us i) ->
- (length us = length base)%nat ->
- 0 <= nth_default 0 (carry_full us) i <= max_value i.
- Proof.
- unfold carry_full, full_carry_chain; intros.
- split; (destruct (lt_dec i (length limb_widths));
- [ | rewrite nth_default_out_of_bounds by carry_length_conditions; omega || auto ]).
- + apply carry_bounds_lower; carry_length_conditions.
- + apply carry_bounds_upper; carry_length_conditions.
- Qed.
-
- Lemma carry_simple_no_overflow : forall us i, (i < pred (length base))%nat ->
- length us = length base ->
- 0 <= nth_default 0 us i < 2 ^ B ->
- 0 <= nth_default 0 us (S i) < 2 ^ B - 2 ^ (B - log_cap i) ->
- 0 <= nth_default 0 (carry i us) (S i) < 2 ^ B.
- Proof.
- intros.
- unfold carry; break_if; try omega.
- autorewrite with push_nth_default natsimplify.
- break_if; try omega.
- rewrite Z.add_comm.
- replace (2 ^ B) with (2 ^ (B - log_cap i) + (2 ^ B - 2 ^ (B - log_cap i))) by omega.
- split; [ zero_bounds | ].
- apply Z.add_lt_mono; try omega.
- rewrite Z.shiftr_div_pow2 by eauto using log_cap_nonneg.
- apply Z.div_lt_upper_bound; try apply pow_2_log_cap_pos.
- rewrite <-Z.pow_add_r by (eauto using log_cap_nonneg || apply B_compat_log_cap).
- replace (log_cap i + (B - log_cap i)) with B by ring.
- omega.
- Qed.
-
- Lemma carry_sequence_no_overflow : forall i us, pre_carry_bounds us ->
- (length us = length base) ->
- nth_default 0 (carry_sequence (make_chain i) us) i < 2 ^ B.
- Proof.
- unfold pre_carry_bounds.
- intros ? ? PCB ?.
- induction i.
- + simpl. specialize (PCB 0%nat).
- intuition.
- + simpl.
- destruct (lt_eq_lt_dec i (pred (length base))) as [[? | ? ] | ? ].
- - apply carry_simple_no_overflow; carry_length_conditions; carry_seq_lower_bound.
- rewrite carry_sequence_unaffected; try omega.
- specialize (PCB (S i)); rewrite Nat.pred_succ in PCB.
- break_if; intuition.
- - unfold carry; break_if; try omega.
- rewrite nth_default_out_of_bounds; [ apply Z.pow_pos_nonneg; omega | ].
- subst; unfold carry_and_reduce.
- carry_length_conditions.
- - rewrite nth_default_out_of_bounds; [ apply Z.pow_pos_nonneg; omega | ].
- carry_length_conditions.
- Qed.
-
- Lemma carry_full_bounds_0 : forall us, pre_carry_bounds us ->
- (length us = length base)%nat ->
- 0 <= nth_default 0 (carry_full us) 0 <= max_value 0 + c * (Z.ones (B - log_cap (pred (length base)))).
- Proof.
- unfold carry_full, full_carry_chain; intros.
- rewrite <- base_length.
- replace (length base) with (S (pred (length base))) by omega.
- simpl.
- unfold carry, carry_and_reduce; break_if; try omega.
- clear_obvious; add_set_nth.
- split; [pose proof c_pos; zero_bounds; carry_seq_lower_bound | ].
- rewrite Z.add_comm.
- apply Z.add_le_mono.
- + apply carry_bounds_0_upper; auto; omega.
- + apply Z.mul_le_mono_pos_l; auto using c_pos.
- apply Z.shiftr_ones; eauto;
- [ | pose proof (B_compat_log_cap (pred (length base))); omega ].
- split.
- - apply carry_bounds_lower; auto; omega.
- - apply carry_sequence_no_overflow; auto.
- Qed.
-
- Lemma carry_full_bounds_lower : forall i us, pre_carry_bounds us ->
- (length us = length base)%nat ->
- 0 <= nth_default 0 (carry_full us) i.
- Proof.
- destruct i; intros.
- + apply carry_full_bounds_0; auto.
- + destruct (lt_dec (S i) (length base)).
- - apply carry_bounds_lower; carry_length_conditions.
- - rewrite nth_default_out_of_bounds; carry_length_conditions.
- Qed.
-
- (* END proofs about first carry loop *)
-
- (* BEGIN proofs about second carry loop *)
-
- Lemma carry_sequence_carry_full_bounds_same : forall us i, pre_carry_bounds us ->
- (length us = length base)%nat -> (0 < i < length base)%nat ->
- 0 <= nth_default 0 (carry_sequence (make_chain i) (carry_full us)) i <= 2 ^ log_cap i.
- Proof.
- induction i; intros; try omega.
- simpl.
- unfold carry; break_if; try omega.
- autorewrite with push_nth_default natsimplify distr_length; break_if; [ | omega ].
- rewrite Z.add_comm.
- split.
- + zero_bounds; [destruct (eq_nat_dec i 0); subst | ].
- - simpl; apply carry_full_bounds_0; auto.
- - apply IHi; auto; omega.
- - rewrite carry_sequence_unaffected by carry_length_conditions.
- apply carry_full_bounds; auto; omega.
- + rewrite <-max_value_log_cap, <-Z.add_1_l.
- apply Z.add_le_mono.
- - rewrite Z.shiftr_div_pow2 by eauto using log_cap_nonneg.
- apply Z.div_floor; auto.
- destruct i.
- * simpl.
- eapply Z.le_lt_trans; [ apply carry_full_bounds_0; auto | ].
- replace (2 ^ log_cap 0 * 2) with (2 ^ log_cap 0 + 2 ^ log_cap 0) by ring.
- rewrite <-max_value_log_cap, <-Z.add_1_l.
- apply Z.add_lt_le_mono; omega.
- * eapply Z.le_lt_trans; [ apply IHi; auto; omega | ].
- apply Z.lt_mul_diag_r; auto; omega.
- - rewrite carry_sequence_unaffected by carry_length_conditions.
- apply carry_full_bounds; auto; omega.
- Qed.
-
- Lemma carry_full_2_bounds_0 : forall us, pre_carry_bounds us ->
- (length us = length base)%nat -> (1 < length base)%nat ->
- 0 <= nth_default 0 (carry_full (carry_full us)) 0 <= max_value 0 + c.
- Proof.
- intros.
- unfold carry_full at 1 3, full_carry_chain.
- rewrite <-base_length.
- replace (length base) with (S (pred (length base))) by (pose proof base_length_nonzero; omega).
- simpl.
- unfold carry, carry_and_reduce; break_if; try omega.
- clear_obvious; add_set_nth.
- split.
- + pose proof c_pos; zero_bounds; [ | carry_seq_lower_bound].
- apply carry_sequence_carry_full_bounds_same; auto; omega.
- + rewrite Z.add_comm.
- apply Z.add_le_mono.
- - apply carry_bounds_0_upper; carry_length_conditions.
- - etransitivity; [ | replace c with (c * 1) by ring; reflexivity ].
- apply Z.mul_le_mono_pos_l; try (pose proof c_pos; omega).
- rewrite Z.shiftr_div_pow2 by eauto.
- apply Z.div_le_upper_bound; auto.
- ring_simplify.
- apply carry_sequence_carry_full_bounds_same; auto.
- omega.
- Qed.
-
- Lemma carry_full_2_bounds_succ : forall us i, pre_carry_bounds us ->
- (length us = length base)%nat -> (0 < i < pred (length base))%nat ->
- ((0 < i < length base)%nat ->
- 0 <= nth_default 0
- (carry_sequence (make_chain i) (carry_full (carry_full us))) i <=
- 2 ^ log_cap i) ->
- 0 <= nth_default 0 (carry_simple limb_widths i
- (carry_sequence (make_chain i) (carry_full (carry_full us)))) (S i) <= 2 ^ log_cap (S i).
- Proof.
- intros ? ? PCB length_eq ? IH.
- autorewrite with push_nth_default natsimplify distr_length; break_if; [ | omega ].
- rewrite Z.add_comm.
- split.
- + zero_bounds. destruct i;
- [ simpl; pose proof (carry_full_2_bounds_0 us PCB length_eq); omega | ].
- rewrite carry_sequence_unaffected by carry_length_conditions.
- apply carry_full_bounds; carry_length_conditions.
- carry_seq_lower_bound.
- + rewrite <-max_value_log_cap, <-Z.add_1_l.
- rewrite Z.shiftr_div_pow2 by eauto using log_cap_nonneg.
- apply Z.add_le_mono.
- - apply Z.div_le_upper_bound; auto.
- ring_simplify. apply IH. omega.
- - rewrite carry_sequence_unaffected by carry_length_conditions.
- apply carry_full_bounds; carry_length_conditions.
- carry_seq_lower_bound.
- Qed.
-
- Lemma carry_full_2_bounds_same : forall us i, pre_carry_bounds us ->
- (length us = length base)%nat -> (0 < i < length base)%nat ->
- 0 <= nth_default 0 (carry_sequence (make_chain i) (carry_full (carry_full us))) i <= 2 ^ log_cap i.
- Proof.
- intros; induction i; try omega.
- simpl; unfold carry.
- break_if; try omega.
- split; (destruct (eq_nat_dec i 0); subst;
- [ cbv [make_chain carry_sequence fold_right];
- autorewrite with push_nth_default natsimplify distr_length; break_if; [ | omega ];
- rewrite Z.add_comm
- | eapply carry_full_2_bounds_succ; eauto; omega]).
- + zero_bounds.
- - eapply carry_full_2_bounds_0; eauto.
- - eapply carry_full_bounds; eauto; carry_length_conditions.
- carry_seq_lower_bound.
- + rewrite <-max_value_log_cap, <-Z.add_1_l.
- rewrite Z.shiftr_div_pow2 by eauto using log_cap_nonneg.
- apply Z.add_le_mono.
- - apply Z.div_floor; auto.
- eapply Z.le_lt_trans; [ eapply carry_full_2_bounds_0; eauto | ].
- replace (Z.succ 1) with (2 ^ 1) by ring.
- rewrite <-max_value_log_cap.
- ring_simplify. pose proof c_pos; omega.
- - apply carry_full_bounds; carry_length_conditions; carry_seq_lower_bound.
- Qed.
-
- Lemma carry_full_2_bounds' : forall us i j, pre_carry_bounds us ->
- (length us = length base)%nat -> (0 < i < length base)%nat -> (i + j < length base)%nat -> (j <> 0)%nat ->
- 0 <= nth_default 0 (carry_sequence (make_chain (i + j)) (carry_full (carry_full us))) i <= max_value i.
- Proof.
- induction j; intros; try omega.
- split; (destruct j; [ rewrite Nat.add_1_r; simpl
- | rewrite <-plus_n_Sm; simpl; rewrite carry_unaffected_low by carry_length_conditions; eapply IHj; eauto; omega ]).
- + apply nth_default_carry_bound_lower; carry_length_conditions.
- + apply nth_default_carry_bound_upper; carry_length_conditions.
- Qed.
-
- Lemma carry_full_2_bounds : forall us i j, pre_carry_bounds us ->
- (length us = length base)%nat -> (0 < i < length base)%nat -> (i < j < length base)%nat ->
- 0 <= nth_default 0 (carry_sequence (make_chain j) (carry_full (carry_full us))) i <= max_value i.
- Proof.
- intros.
- replace j with (i + (j - i))%nat by omega.
- eapply carry_full_2_bounds'; eauto; omega.
- Qed.
-
- Lemma carry_carry_full_2_bounds_0_lower : forall us i, pre_carry_bounds us ->
- (length us = length base)%nat -> (0 < i < length base)%nat ->
- (0 <= nth_default 0 (carry_sequence (make_chain i) (carry_full (carry_full us))) 0).
- Proof.
- induction i; try omega.
- intros ? length_eq ?; simpl.
- destruct i.
- + unfold carry.
- break_if;
- [ pose proof base_length_nonzero; replace (length base) with 1%nat in *; omega | ].
- simpl.
- autorewrite with push_nth_default natsimplify.
- apply pow2_mod_log_cap_bounds_lower.
- + rewrite carry_unaffected_low by carry_length_conditions.
- assert (0 < S i < length base)%nat by omega.
- intuition.
- Qed.
-
- Lemma carry_full_2_bounds_lower :forall us i, pre_carry_bounds us ->
- (length us = length base)%nat ->
- 0 <= nth_default 0 (carry_full (carry_full us)) i.
- Proof.
- intros; destruct i.
- + apply carry_full_2_bounds_0; auto.
- + apply carry_full_bounds; try solve [carry_length_conditions].
- intro j; destruct j.
- - apply carry_full_bounds_0; auto.
- - apply carry_full_bounds; carry_length_conditions.
- Qed.
-
- Local Hint Resolve carry_full_length.
-
- Lemma carry_carry_full_2_bounds_0_upper : forall us i, pre_carry_bounds us ->
- (length us = length base)%nat -> (0 < i < length base)%nat ->
- (nth_default 0 (carry_sequence (make_chain i) (carry_full (carry_full us))) 0 <= max_value 0 - c)
- \/ carry_done (carry_sequence (make_chain i) (carry_full (carry_full us))).
- Proof.
- induction i; try omega.
- intros ? length_eq ?; simpl.
- destruct i.
- + destruct (Z_le_dec (nth_default 0 (carry_full (carry_full us)) 0) (max_value 0)).
- - right.
- apply carry_carry_done_done; try solve [carry_length_conditions].
- apply carry_done_bounds; try solve [carry_length_conditions].
- intros.
- simpl.
- split; [ auto using carry_full_2_bounds_lower | ].
- destruct i; rewrite <-max_value_log_cap, Z.lt_succ_r; auto.
- apply carry_full_bounds; auto using carry_full_bounds_lower.
- - left; unfold carry.
- break_if;
- [ pose proof base_length_nonzero; replace (length base) with 1%nat in *; omega | ].
- autorewrite with push_nth_default natsimplify.
- simpl.
- remember ((nth_default 0 (carry_full (carry_full us)) 0)) as x.
- apply Z.le_trans with (m := (max_value 0 + c) - (1 + max_value 0)); try omega.
- replace x with ((x - 2 ^ log_cap 0) + (1 * 2 ^ log_cap 0)) by ring.
- rewrite Z.pow2_mod_spec by eauto.
- cbv [make_chain carry_sequence fold_right].
- rewrite Z.mod_add by (pose proof (pow_2_log_cap_pos 0); omega).
- rewrite <-max_value_log_cap, <-Z.add_1_l, Z.mod_small;
- [ apply Z.sub_le_mono_r; subst; apply carry_full_2_bounds_0; auto | ].
- split; try omega.
- pose proof carry_full_2_bounds_0.
- apply Z.le_lt_trans with (m := (max_value 0 + c) - (1 + max_value 0));
- [ apply Z.sub_le_mono_r; subst x; apply carry_full_2_bounds_0; auto;
- ring_simplify | ]; pose proof c_pos; omega.
- + rewrite carry_unaffected_low by carry_length_conditions.
- assert (0 < S i < length base)%nat by omega.
- intuition; right.
- apply carry_carry_done_done; try solve [carry_length_conditions].
- assumption.
- Qed.
-
-
- (* END proofs about second carry loop *)
-
- (* BEGIN proofs about third carry loop *)
-
- Lemma carry_full_3_bounds : forall us i, pre_carry_bounds us ->
- (length us = length base)%nat ->(i < length base)%nat ->
- 0 <= nth_default 0 (carry_full (carry_full (carry_full us))) i <= max_value i.
- Proof.
- intros.
- destruct i; [ | apply carry_full_bounds; carry_length_conditions;
- carry_seq_lower_bound ].
- unfold carry_full at 1 4, full_carry_chain.
- case_eq limb_widths; [intros; pose proof limb_widths_nonnil; congruence | ].
- simpl.
- intros ? ? limb_widths_eq.
- replace (length l) with (pred (length limb_widths)) by (rewrite limb_widths_eq; auto).
- rewrite <- base_length.
- unfold carry, carry_and_reduce; break_if; try omega; intros.
- add_set_nth. pose proof c_pos.
- split.
- + zero_bounds.
- - eapply carry_full_2_bounds_same; eauto; omega.
- - eapply carry_carry_full_2_bounds_0_lower; eauto; omega.
- + pose proof (carry_carry_full_2_bounds_0_upper us (pred (length base))).
- assert (0 < pred (length base) < length base)%nat by omega.
- intuition.
- - replace (max_value 0) with (c + (max_value 0 - c)) by ring.
- apply Z.add_le_mono; try assumption.
- etransitivity; [ | replace c with (c * 1) by ring; reflexivity ].
- apply Z.mul_le_mono_pos_l; try omega.
- rewrite Z.shiftr_div_pow2 by eauto.
- apply Z.div_le_upper_bound; auto.
- ring_simplify.
- apply carry_full_2_bounds_same; auto.
- - match goal with H0 : (pred (length base) < length base)%nat,
- H : carry_done _ |- _ =>
- destruct (H (pred (length base)) H0) as [Hcd1 Hcd2]; rewrite Hcd2 by omega end.
- ring_simplify.
- apply shiftr_eq_0_max_value; auto.
- assert (0 < length base)%nat as zero_lt_length by omega.
- match goal with H : carry_done _ |- _ =>
- destruct (H 0%nat zero_lt_length) end.
- assumption.
- Qed.
-
- Lemma carry_full_3_done : forall us, pre_carry_bounds us ->
- (length us = length base)%nat ->
- carry_done (carry_full (carry_full (carry_full us))).
- Proof.
- intros.
- apply carry_done_bounds; [ carry_length_conditions | intros ].
- destruct (lt_dec i (length base)).
- + rewrite <-max_value_log_cap, Z.lt_succ_r.
- auto using carry_full_3_bounds.
- + rewrite nth_default_out_of_bounds; carry_length_conditions.
- Qed.
-
- (* END proofs about third carry loop *)
-
- Lemma isFull'_false : forall us n, isFull' us false n = false.
- Proof.
- unfold isFull'; induction n; intros; rewrite Bool.andb_false_r; auto.
- Qed.
-
- Lemma isFull'_last : forall us b j, (j <> 0)%nat -> isFull' us b j = true ->
- max_value j = nth_default 0 us j.
- Proof.
- induction j; simpl; intros; try omega.
- match goal with
- | [H : isFull' _ ((?comp ?a ?b) && _) _ = true |- _ ] =>
- case_eq (comp a b); rewrite ?Z.eqb_eq; intro comp_eq; try assumption;
- rewrite comp_eq, Bool.andb_false_l, isFull'_false in H; congruence
- end.
- Qed.
-
- Lemma isFull'_lower_bound_0 : forall j us b, isFull' us b j = true ->
- max_value 0 - c < nth_default 0 us 0.
- Proof.
- induction j; intros.
- + match goal with H : isFull' _ _ 0 = _ |- _ => cbv [isFull'] in H;
- apply Bool.andb_true_iff in H; destruct H end.
- apply Z.ltb_lt; assumption.
- + eauto.
- Qed.
-
- Lemma isFull'_true_full : forall us i j b, (i <> 0)%nat -> (i <= j)%nat -> isFull' us b j = true ->
- max_value i = nth_default 0 us i.
- Proof.
- induction j; intros; try omega.
- assert (i = S j \/ i <= j)%nat as cases by omega.
- destruct cases.
- + subst. eapply isFull'_last; eauto.
- + eapply IHj; eauto.
- Qed.
-
- Lemma max_ones_nonneg : 0 <= max_ones.
- Proof.
- unfold max_ones.
- apply Z.ones_nonneg.
- clear.
- pose proof limb_widths_nonneg.
- induction limb_widths as [|?? IHl].
- cbv; congruence.
- simpl.
- apply Z.max_le_iff.
- right.
- apply IHl; eauto using in_cons.
- Qed.
-
- Lemma land_max_ones_noop : forall x i, 0 <= x < 2 ^ log_cap i -> Z.land max_ones x = x.
- Proof.
- unfold max_ones.
- intros ? ? x_range.
- rewrite Z.land_comm.
- rewrite Z.land_ones by apply Z.le_fold_right_max_initial.
- apply Z.mod_small.
- split; try omega.
- eapply Z.lt_le_trans; try eapply x_range.
- apply Z.pow_le_mono_r; try omega.
- destruct (lt_dec i (length limb_widths)).
- + apply Z.le_fold_right_max.
- - apply limb_widths_nonneg.
- - rewrite nth_default_eq.
- auto using nth_In.
- + rewrite nth_default_out_of_bounds by omega.
- apply Z.le_fold_right_max_initial.
- Qed.
-
- Lemma full_isFull'_true : forall j us, (length us = length base) ->
- ( max_value 0 - c < nth_default 0 us 0
- /\ (forall i, (0 < i <= j)%nat -> nth_default 0 us i = max_value i)) ->
- isFull' us true j = true.
- Proof.
- induction j; intros.
- + cbv [isFull']; apply Bool.andb_true_iff.
- rewrite Z.ltb_lt; intuition.
- + intuition.
- simpl.
- match goal with H : forall j, _ -> ?b j = ?a j |- appcontext[?a ?i =? ?b ?i] =>
- replace (a i =? b i) with true by (symmetry; apply Z.eqb_eq; symmetry; apply H; omega) end.
- apply IHj; auto; intuition.
- Qed.
-
- Lemma isFull'_true_iff : forall j us, (length us = length base) -> (isFull' us true j = true <->
- max_value 0 - c < nth_default 0 us 0
- /\ (forall i, (0 < i <= j)%nat -> nth_default 0 us i = max_value i)).
- Proof.
- intros; split; intros; auto using full_isFull'_true.
- split; eauto using isFull'_lower_bound_0.
- intros.
- symmetry; eapply isFull'_true_full; [ omega | | eauto].
- omega.
- Qed.
-
- Lemma isFull'_true_step : forall us j, isFull' us true (S j) = true ->
- isFull' us true j = true.
- Proof.
- simpl; intros ? ? succ_true.
- destruct (max_value (S j) =? nth_default 0 us (S j)); auto.
- rewrite isFull'_false in succ_true.
- congruence.
- Qed.
-
- Opaque isFull' max_ones.
-
- Lemma carry_full_3_length : forall us, (length us = length base) ->
- length (carry_full (carry_full (carry_full us))) = length us.
- Proof.
- intros.
- repeat rewrite carry_full_length by (repeat rewrite carry_full_length; auto); auto.
- Qed.
- Local Hint Resolve carry_full_3_length.
-
- Lemma modulus_digits'_length : forall i, length (modulus_digits' i) = S i.
- Proof.
- induction i; intros; [ cbv; congruence | ].
- unfold modulus_digits'; fold modulus_digits'.
- rewrite app_length, IHi.
- cbv [length]; omega.
- Qed.
-
- Lemma modulus_digits_length : length modulus_digits = length base.
- Proof.
- unfold modulus_digits.
- rewrite modulus_digits'_length; omega.
- Qed.
-
- (* Helps with solving goals of the form [x = y -> min x y = x] or [x = y -> min x y = y] *)
- Local Hint Resolve Nat.eq_le_incl eq_le_incl_rev.
-
- Hint Rewrite app_length cons_length map2_length modulus_digits_length length_zeros
- map_length combine_length firstn_length map_app : lengths.
- Ltac simpl_lengths := autorewrite with lengths;
- repeat rewrite carry_full_length by (repeat rewrite carry_full_length; auto);
- auto using Min.min_l; auto using Min.min_r.
-
- Lemma freeze_length : forall us, (length us = length base) ->
- length (freeze us) = length us.
- Proof.
- unfold freeze; intros; simpl_lengths.
- rewrite Min.min_l by omega. congruence.
- Qed.
-
- Lemma decode_firstn_succ : forall n us, (length us = length base) ->
- (n < length base)%nat ->
- BaseSystem.decode' (firstn (S n) base) (firstn (S n) us) =
- BaseSystem.decode' (firstn n base) (firstn n us) +
- nth_default 0 base n * nth_default 0 us n.
- Proof.
- intros.
- rewrite !firstn_succ with (d := 0) by omega.
- rewrite base_app, firstn_app.
- autorewrite with lengths; rewrite !Min.min_l by omega.
- rewrite Nat.sub_diag, firstn_firstn, firstn0, app_nil_r by omega.
- rewrite skipn_app_sharp by (autorewrite with lengths; apply Min.min_l; omega).
- rewrite decode'_cons, decode_nil, Z.add_0_r.
- reflexivity.
- Qed.
-
- Local Hint Resolve sum_firstn_limb_widths_nonneg.
- Local Hint Resolve limb_widths_nonneg.
- Local Hint Resolve nth_error_value_In.
-
- Lemma decode_carry_done_upper_bound' : forall n us, carry_done us ->
- (length us = length base) ->
- BaseSystem.decode (firstn n base) (firstn n us) < 2 ^ (sum_firstn limb_widths n).
- Proof.
- induction n; intros; [ cbv; congruence | ].
- destruct (lt_dec n (length base)) as [ n_lt_length | ? ].
- + rewrite decode_firstn_succ; auto.
- rewrite base_length in n_lt_length.
- destruct (nth_error_length_exists_value _ _ n_lt_length).
- erewrite sum_firstn_succ; eauto.
- rewrite Z.pow_add_r; eauto.
- rewrite nth_default_base by
- (try rewrite base_from_limb_widths_length; omega || eauto).
- rewrite Z.lt_add_lt_sub_r.
- eapply Z.lt_le_trans; eauto.
- rewrite Z.mul_comm at 1.
- rewrite <-Z.mul_sub_distr_l.
- rewrite <-Z.mul_1_r at 1.
- apply Z.mul_le_mono_nonneg_l; [ apply Z.pow_nonneg; omega | ].
- replace 1 with (Z.succ 0) by reflexivity.
- rewrite Z.le_succ_l, Z.lt_0_sub.
- match goal with H : carry_done us |- _ => rewrite carry_done_bounds in H by auto; specialize (H n) end.
- replace x with (log_cap n); try intuition.
- apply nth_error_value_eq_nth_default; auto.
- + repeat erewrite firstn_all_strong by omega.
- rewrite sum_firstn_all_succ by (rewrite <-base_length; omega).
- eapply Z.le_lt_trans; [ | eauto].
- repeat erewrite firstn_all_strong by omega.
- omega.
- Qed.
-
- Lemma decode_carry_done_upper_bound : forall us, carry_done us ->
- (length us = length base) -> BaseSystem.decode base us < 2 ^ k.
- Proof.
- unfold k; intros.
- rewrite <-(firstn_all_strong base (length limb_widths)) by (rewrite <-base_length; auto).
- rewrite <-(firstn_all_strong us (length limb_widths)) by (rewrite <-base_length; auto).
- auto using decode_carry_done_upper_bound'.
- Qed.
-
- Lemma decode_carry_done_lower_bound' : forall n us, carry_done us ->
- (length us = length base) ->
- 0 <= BaseSystem.decode (firstn n base) (firstn n us).
- Proof.
- induction n; intros; [ cbv; congruence | ].
- destruct (lt_dec n (length base)) as [ n_lt_length | ? ].
- + rewrite decode_firstn_succ by auto.
- zero_bounds.
- - rewrite nth_default_base by (omega || eauto).
- apply Z.pow_nonneg; omega.
- - match goal with H : carry_done us |- _ => rewrite carry_done_bounds in H by auto; specialize (H n) end.
- intuition.
- + eapply Z.le_trans; [ apply IHn; eauto | ].
- repeat rewrite firstn_all_strong by omega.
- omega.
- Qed.
-
- Lemma decode_carry_done_lower_bound : forall us, carry_done us ->
- (length us = length base) -> 0 <= BaseSystem.decode base us.
- Proof.
- intros.
- rewrite <-(firstn_all_strong base (length limb_widths)) by (rewrite <-base_length; auto).
- rewrite <-(firstn_all_strong us (length limb_widths)) by (rewrite <-base_length; auto).
- auto using decode_carry_done_lower_bound'.
- Qed.
-
-
- Lemma nth_default_modulus_digits' : forall d j i,
- nth_default d (modulus_digits' j) i =
- if lt_dec i (S j)
- then (if (eq_nat_dec i 0) then max_value i - c + 1 else max_value i)
- else d.
- Proof.
- induction j; intros; (break_if; [| apply nth_default_out_of_bounds; rewrite modulus_digits'_length; omega]).
- + replace i with 0%nat by omega.
- apply nth_default_cons.
- + simpl. rewrite nth_default_app.
- rewrite modulus_digits'_length.
- break_if.
- - rewrite IHj; break_if; try omega; reflexivity.
- - replace i with (S j) by omega.
- rewrite Nat.sub_diag, nth_default_cons.
- reflexivity.
- Qed.
-
- Lemma nth_default_modulus_digits : forall d i,
- nth_default d modulus_digits i =
- if lt_dec i (length base)
- then (if (eq_nat_dec i 0) then max_value i - c + 1 else max_value i)
- else d.
- Proof.
- unfold modulus_digits; intros.
- rewrite nth_default_modulus_digits'.
- replace (S (length base - 1)) with (length base) by omega.
- reflexivity.
- Qed.
-
- Lemma carry_done_modulus_digits : carry_done modulus_digits.
- Proof.
- apply carry_done_bounds; [apply modulus_digits_length | ].
- intros.
- rewrite nth_default_modulus_digits.
- break_if; [ | split; auto; omega].
- break_if; subst; split; auto; try rewrite <- max_value_log_cap; pose proof c_pos; omega.
- Qed.
- Local Hint Resolve carry_done_modulus_digits.
-
- Lemma decode_mod : forall us vs x, (length us = length base) -> (length vs = length base) ->
- decode us = x ->
- BaseSystem.decode base us mod modulus = BaseSystem.decode base vs mod modulus ->
- decode vs = x.
- Proof.
- unfold decode; intros until 2; intros decode_us_x BSdecode_eq.
- rewrite ZToField_mod in decode_us_x |- *.
- rewrite <-BSdecode_eq.
- assumption.
- Qed.
-
- Lemma decode_map2_sub : forall us vs,
- (length us = length vs) ->
- BaseSystem.decode' base (map2 (fun x y => x - y) us vs)
- = BaseSystem.decode' base us - BaseSystem.decode' base vs.
- Proof.
- induction us using rev_ind; induction vs using rev_ind;
- intros; autorewrite with lengths in *; simpl_list_lengths;
- rewrite ?decode_nil; try omega.
- rewrite map2_app by omega.
- rewrite map2_cons, map2_nil_l.
- rewrite !set_higher.
- autorewrite with lengths.
- rewrite Min.min_l by omega.
- rewrite IHus by omega.
- replace (length vs) with (length us) by omega.
- ring.
- Qed.
-
- Lemma decode_modulus_digits' : forall i, (i <= length base)%nat ->
- BaseSystem.decode' base (modulus_digits' i) = 2 ^ (sum_firstn limb_widths (S i)) - c.
- Proof.
- induction i; intros; unfold modulus_digits'; fold modulus_digits'.
- + let base := constr:(base) in
- case_eq base;
- [ intro base_eq; rewrite base_eq, (@nil_length0 Z) in lt_1_length_base; omega | ].
- intros z ? base_eq.
- rewrite decode'_cons, decode_nil, Z.add_0_r.
- replace z with (nth_default 0 base 0) by (rewrite base_eq; auto).
- rewrite nth_default_base by (omega || eauto).
- replace (max_value 0 - c + 1) with (Z.succ (max_value 0) - c) by ring.
- rewrite max_value_log_cap.
- rewrite sum_firstn_succ with (x := log_cap 0) by (
- apply nth_error_Some_nth_default; rewrite <-base_length; omega).
- rewrite Z.pow_add_r by eauto.
- cbv [sum_firstn fold_right firstn].
- ring.
- + assert (S i < length base \/ S i = length base)%nat as cases by omega.
- destruct cases.
- - rewrite sum_firstn_succ with (x := log_cap (S i)) by
- (apply nth_error_Some_nth_default;
- rewrite <-base_length; omega).
- rewrite Z.pow_add_r, <-max_value_log_cap, set_higher by eauto.
- rewrite IHi, modulus_digits'_length by omega.
- rewrite nth_default_base by (omega || eauto).
- ring.
- - rewrite sum_firstn_all_succ by (rewrite <-base_length; omega).
- rewrite decode'_splice, modulus_digits'_length, firstn_all by auto.
- rewrite skipn_all, decode_base_nil, Z.add_0_r by omega.
- apply IHi.
- omega.
- Qed.
-
- Lemma decode_modulus_digits : BaseSystem.decode' base modulus_digits = modulus.
- Proof.
- unfold modulus_digits; rewrite decode_modulus_digits' by omega.
- replace (S (length base - 1)) with (length base) by omega.
- rewrite base_length.
- fold k. unfold c.
- ring.
- Qed.
-
- Lemma map_land_max_ones_modulus_digits' : forall i,
- map (Z.land max_ones) (modulus_digits' i) = (modulus_digits' i).
- Proof.
- induction i; intros.
- + cbv [modulus_digits' map].
- f_equal.
- apply land_max_ones_noop with (i := 0%nat).
- rewrite <-max_value_log_cap.
- pose proof c_pos; omega.
- + unfold modulus_digits'; fold modulus_digits'.
- rewrite map_app.
- f_equal; [ apply IHi; omega | ].
- cbv [map]; f_equal.
- apply land_max_ones_noop with (i := S i).
- rewrite <-max_value_log_cap.
- split; auto; omega.
- Qed.
-
- Lemma map_land_max_ones_modulus_digits : map (Z.land max_ones) modulus_digits = modulus_digits.
- Proof.
- apply map_land_max_ones_modulus_digits'.
- Qed.
-
- Opaque modulus_digits.
-
- Lemma map_land_zero : forall ls, map (Z.land 0) ls = BaseSystem.zeros (length ls).
- Proof.
- induction ls; boring.
- Qed.
-
- Lemma carry_full_preserves_Fdecode : forall us x, (length us = length base) ->
- decode us = x -> decode (carry_full us) = x.
- Proof.
- intros.
- apply carry_full_preserves_rep; auto.
- unfold rep; auto.
- Qed.
-
- Lemma freeze_preserves_rep : forall us x, rep us x -> rep (freeze us) x.
- Proof.
- unfold rep; intros.
- intuition; rewrite ?freeze_length; auto.
- unfold freeze, and_term.
- break_if.
- + apply decode_mod with (us := carry_full (carry_full (carry_full us))).
- - rewrite carry_full_3_length; auto.
- - autorewrite with lengths.
- apply Min.min_r.
- simpl_lengths; omega.
- - repeat apply carry_full_preserves_rep; repeat rewrite carry_full_length; auto.
- unfold rep; intuition.
- - rewrite decode_map2_sub by (simpl_lengths; omega).
- rewrite map_land_max_ones_modulus_digits.
- rewrite decode_modulus_digits.
- destruct (Z_eq_dec modulus 0); [ subst; rewrite !Zmod_0_r; reflexivity | ].
- rewrite <-Z.add_opp_r.
- replace (-modulus) with (-1 * modulus) by ring.
- symmetry; auto using Z.mod_add.
- + eapply decode_mod; eauto.
- simpl_lengths.
- rewrite map_land_zero, decode_map2_sub, zeros_rep, Z.sub_0_r by simpl_lengths.
- match goal with H : decode ?us = ?x |- _ => erewrite Fdecode_decode_mod; eauto;
- do 3 apply carry_full_preserves_Fdecode in H; simpl_lengths
- end.
- erewrite Fdecode_decode_mod; eauto; simpl_lengths.
- Qed.
- Hint Resolve freeze_preserves_rep.
-
- Lemma isFull_true_iff : forall us, (length us = length base) -> (isFull us = true <->
- max_value 0 - c < nth_default 0 us 0
- /\ (forall i, (0 < i <= length base - 1)%nat -> nth_default 0 us i = max_value i)).
- Proof.
- unfold isFull; intros; auto using isFull'_true_iff.
- Qed.
-
- Definition minimal_rep us := BaseSystem.decode base us = (BaseSystem.decode base us) mod modulus.
-
- Fixpoint compare' us vs i :=
- match i with
- | O => Eq
- | S i' => if Z_eq_dec (nth_default 0 us i') (nth_default 0 vs i')
- then compare' us vs i'
- else Z.compare (nth_default 0 us i') (nth_default 0 vs i')
- end.
-
- (* Lexicographically compare two vectors of equal length, starting from the END of the list
- (in our context, this is the most significant end). NOT constant time. *)
- Definition compare us vs := compare' us vs (length us).
-
- Lemma compare'_Eq : forall us vs i, (length us = length vs) ->
- compare' us vs i = Eq -> firstn i us = firstn i vs.
- Proof.
- induction i; intros; [ cbv; congruence | ].
- destruct (lt_dec i (length us)).
- + repeat rewrite firstn_succ with (d := 0) by omega.
- match goal with H : compare' _ _ (S _) = Eq |- _ =>
- inversion H end.
- break_if; f_equal; auto.
- - f_equal; auto.
- - rewrite Z.compare_eq_iff in *. congruence.
- - rewrite Z.compare_eq_iff in *. congruence.
- + rewrite !firstn_all_strong in IHi by omega.
- match goal with H : compare' _ _ (S _) = Eq |- _ =>
- inversion H end.
- rewrite (nth_default_out_of_bounds i us) in * by omega.
- rewrite (nth_default_out_of_bounds i vs) in * by omega.
- break_if; try congruence.
- f_equal; auto.
- Qed.
-
- Lemma compare_Eq : forall us vs, (length us = length vs) ->
- compare us vs = Eq -> us = vs.
- Proof.
- intros.
- erewrite <-(firstn_all _ us); eauto.
- erewrite <-(firstn_all _ vs); eauto.
- apply compare'_Eq; auto.
- Qed.
-
- Lemma decode_lt_next_digit : forall us n, (length us = length base) ->
- (n < length base)%nat -> (n < length us)%nat ->
- carry_done us ->
- BaseSystem.decode' (firstn n base) (firstn n us) <
- (nth_default 0 base n).
- Proof.
- induction n; intros ? ? ? bounded.
- + cbv [firstn].
- rewrite decode_base_nil.
- apply Z.gt_lt; auto using nth_default_base_positive.
- + rewrite decode_firstn_succ by (auto || omega).
- rewrite nth_default_base_succ by (eauto || omega).
- eapply Z.lt_le_trans.
- - apply Z.add_lt_mono_r.
- apply IHn; auto; omega.
- - rewrite <-(Z.mul_1_r (nth_default 0 base n)) at 1.
- rewrite <-Z.mul_add_distr_l, Z.mul_comm.
- apply Z.mul_le_mono_pos_r.
- * apply Z.gt_lt. apply nth_default_base_positive; omega.
- * rewrite Z.add_1_l.
- apply Z.le_succ_l.
- rewrite carry_done_bounds in bounded by assumption.
- apply bounded.
- Qed.
-
- Lemma highest_digit_determines : forall us vs n x, (x < 0) ->
- (length us = length base) ->
- (length vs = length base) ->
- (n < length us)%nat -> carry_done us ->
- (n < length vs)%nat -> carry_done vs ->
- BaseSystem.decode (firstn n base) (firstn n us) +
- nth_default 0 base n * x -
- BaseSystem.decode (firstn n base) (firstn n vs) < 0.
- Proof.
- intros.
- eapply Z.le_lt_trans.
- + apply Z.le_sub_nonneg.
- apply decode_carry_done_lower_bound'; auto.
- + eapply Z.le_lt_trans.
- - eapply Z.add_le_mono with (q := nth_default 0 base n * -1); [ apply Z.le_refl | ].
- apply Z.mul_le_mono_nonneg_l; try omega.
- rewrite nth_default_base by (omega || eauto).
- zero_bounds.
- - ring_simplify.
- apply Z.lt_sub_0.
- apply decode_lt_next_digit; auto.
- omega.
- Qed.
-
- Lemma Z_compare_decode_step_eq : forall n us vs,
- (length us = length base) ->
- (length us = length vs) ->
- (S n <= length base)%nat ->
- (nth_default 0 us n = nth_default 0 vs n) ->
- (BaseSystem.decode (firstn (S n) base) us ?=
- BaseSystem.decode (firstn (S n) base) vs) =
- (BaseSystem.decode (firstn n base) us ?=
- BaseSystem.decode (firstn n base) vs).
- Proof.
- intros until 3; intro nth_default_eq.
- destruct (lt_dec n (length us)); try omega.
- rewrite firstn_succ with (d := 0), !base_app by omega.
- autorewrite with lengths; rewrite Min.min_l by omega.
- do 2 (rewrite skipn_nth_default with (d := 0) by omega;
- rewrite decode'_cons, decode_base_nil, Z.add_0_r).
- rewrite Z.compare_sub, nth_default_eq, Z.add_add_simpl_r_r.
- rewrite BaseSystem.decode'_truncate with (us := us).
- rewrite BaseSystem.decode'_truncate with (us := vs).
- rewrite firstn_length, Min.min_l, <-Z.compare_sub by omega.
- reflexivity.
- Qed.
-
- Lemma Z_compare_decode_step_lt : forall n us vs,
- (length us = length base) ->
- (length us = length vs) ->
- (S n <= length base)%nat ->
- carry_done us -> carry_done vs ->
- (nth_default 0 us n < nth_default 0 vs n) ->
- (BaseSystem.decode (firstn (S n) base) us ?=
- BaseSystem.decode (firstn (S n) base) vs) = Lt.
- Proof.
- intros until 5; intro nth_default_lt.
- destruct (lt_dec n (length us)).
- + rewrite firstn_succ with (d := 0) by omega.
- rewrite !base_app.
- autorewrite with lengths; rewrite Min.min_l by omega.
- do 2 (rewrite skipn_nth_default with (d := 0) by omega;
- rewrite decode'_cons, decode_base_nil, Z.add_0_r).
- rewrite Z.compare_sub.
- apply Z.compare_lt_iff.
- ring_simplify.
- rewrite <-Z.add_sub_assoc.
- rewrite <-Z.mul_sub_distr_l.
- apply highest_digit_determines; auto; omega.
- + rewrite !nth_default_out_of_bounds in nth_default_lt; omega.
- Qed.
-
- Lemma Z_compare_decode_step_neq : forall n us vs,
- (length us = length base) -> (length us = length vs) ->
- (S n <= length base)%nat ->
- carry_done us -> carry_done vs ->
- (nth_default 0 us n <> nth_default 0 vs n) ->
- (BaseSystem.decode (firstn (S n) base) us ?=
- BaseSystem.decode (firstn (S n) base) vs) =
- (nth_default 0 us n ?= nth_default 0 vs n).
- Proof.
- intros.
- destruct (Z_dec (nth_default 0 us n) (nth_default 0 vs n)) as [[?|Hgt]|?]; try congruence.
- + etransitivity; try apply Z_compare_decode_step_lt; auto.
- + match goal with |- (?a ?= ?b) = (?c ?= ?d) =>
- rewrite (Z.compare_antisym b a); rewrite (Z.compare_antisym d c) end.
- apply CompOpp_inj; rewrite !CompOpp_involutive.
- apply Z.gt_lt_iff in Hgt.
- etransitivity; try apply Z_compare_decode_step_lt; auto; omega.
- Qed.
-
- Lemma decode_compare' : forall n us vs,
- (length us = length base) ->
- (length us = length vs) ->
- (n <= length base)%nat ->
- carry_done us -> carry_done vs ->
- (BaseSystem.decode (firstn n base) us ?= BaseSystem.decode (firstn n base) vs)
- = compare' us vs n.
- Proof.
- induction n; intros.
- + cbv [firstn compare']; rewrite !decode_base_nil; auto.
- + unfold compare'; fold compare'.
- break_if.
- - rewrite Z_compare_decode_step_eq by (auto || omega).
- apply IHn; auto; omega.
- - rewrite Z_compare_decode_step_neq; (auto || omega).
- Qed.
-
- Lemma decode_compare : forall us vs,
- (length us = length base) -> carry_done us ->
- (length vs = length base) -> carry_done vs ->
- Z.compare (BaseSystem.decode base us) (BaseSystem.decode base vs) = compare us vs.
- Proof.
- unfold compare; intros.
- erewrite <-(firstn_all _ base).
- + apply decode_compare'; auto; omega.
- + assumption.
- Qed.
+ Local Hint Resolve (@limb_widths_nonneg _ prm) sum_firstn_limb_widths_nonneg.
+ Local Hint Resolve log_cap_nonneg.
- Lemma compare'_succ : forall us j vs, compare' us vs (S j) =
- if Z.eq_dec (nth_default 0 us j) (nth_default 0 vs j)
- then compare' us vs j
- else nth_default 0 us j ?= nth_default 0 vs j.
- Proof.
+ Lemma nth_default_carry_and_reduce_full : forall n i us,
+ nth_default 0 (carry_and_reduce i us) n
+ = if lt_dec n (length us)
+ then
+ (if eq_nat_dec n (i mod length limb_widths)
+ then Z.pow2_mod (nth_default 0 us n) (log_cap n)
+ else nth_default 0 us n) +
+ if PeanoNat.Nat.eq_dec n (S (i mod length limb_widths) mod length limb_widths)
+ then c * nth_default 0 us (i mod length limb_widths) >> log_cap (i mod length limb_widths)
+ else 0
+ else 0.
+ Proof.
+ cbv [carry_and_reduce]; intros.
+ autorewrite with push_nth_default.
reflexivity.
Qed.
-
- Lemma compare'_firstn_r_small_index : forall us j vs, (j <= length vs)%nat ->
- compare' us vs j = compare' us (firstn j vs) j.
- Proof.
- induction j; intros; auto.
- rewrite !compare'_succ by omega.
- rewrite firstn_succ with (d := 0) by omega.
- rewrite nth_default_app.
- simpl_lengths.
- rewrite Min.min_l by omega.
- destruct (lt_dec j j); try omega.
- rewrite Nat.sub_diag.
- rewrite nth_default_cons.
- break_if; try reflexivity.
- rewrite IHj with (vs := firstn j vs ++ nth_default 0 vs j :: nil) by
- (autorewrite with lengths; rewrite Min.min_l; omega).
- rewrite firstn_app_sharp by (autorewrite with lengths; apply Min.min_l; omega).
- apply IHj; omega.
- Qed.
-
- Lemma compare'_firstn_r : forall us j vs,
- compare' us vs j = compare' us (firstn j vs) j.
- Proof.
- intros.
- destruct (le_dec j (length vs)).
- + auto using compare'_firstn_r_small_index.
- + f_equal. symmetry.
- apply firstn_all_strong.
- omega.
- Qed.
-
- Lemma compare'_not_Lt : forall us vs j, j <> 0%nat ->
- (forall i, (0 < i < j)%nat -> 0 <= nth_default 0 us i <= nth_default 0 vs i) ->
- compare' us vs j <> Lt ->
- nth_default 0 vs 0 <= nth_default 0 us 0 /\
- (forall i : nat, (0 < i < j)%nat -> nth_default 0 us i = nth_default 0 vs i).
- Proof.
- induction j; try congruence.
- rewrite compare'_succ.
- intros; destruct (eq_nat_dec j 0).
- + break_if; subst; split; intros; try omega.
- rewrite Z.compare_ge_iff in *; omega.
- + break_if.
- - split; intros; [ | destruct (eq_nat_dec i j); subst; auto ];
- apply IHj; auto; intros; try omega;
- match goal with H : forall i, _ -> 0 <= ?f i <= ?g i |- 0 <= ?f _ <= ?g _ =>
- apply H; omega end.
- - exfalso. rewrite Z.compare_ge_iff in *.
- match goal with H : forall i, ?P -> 0 <= ?f i <= ?g i |- _ =>
- specialize (H j) end; omega.
- Qed.
-
- Lemma isFull'_compare' : forall us j, j <> 0%nat -> (length us = length base) ->
- (j <= length base)%nat -> carry_done us ->
- (isFull' us true (j - 1) = true <-> compare' us modulus_digits j <> Lt).
- Proof.
- unfold compare; induction j; intros; try congruence.
- replace (S j - 1)%nat with j by omega.
- split; intros.
- + simpl.
- break_if; [destruct (eq_nat_dec j 0) | ].
- - subst. cbv; congruence.
- - apply IHj; auto; try omega.
- apply isFull'_true_step.
- replace (S (j - 1)) with j by omega; auto.
- - rewrite nth_default_modulus_digits in *.
- repeat (break_if; try omega).
- * subst.
- match goal with H : isFull' _ _ _ = true |- _ =>
- apply isFull'_lower_bound_0 in H end.
- apply Z.compare_ge_iff.
- omega.
- * match goal with H : isFull' _ _ _ = true |- _ =>
- apply isFull'_true_iff in H; try assumption; destruct H as [? eq_max_value] end.
- specialize (eq_max_value j).
- omega.
- + apply isFull'_true_iff; try assumption.
- match goal with H : compare' _ _ _ <> Lt |- _ => apply compare'_not_Lt in H; [ destruct H as [Hdigit0 Hnonzero] | | ] end.
- - split; [ | intros i i_range; assert (0 < i < S j)%nat as i_range' by omega;
- specialize (Hnonzero i i_range')];
- rewrite nth_default_modulus_digits in *;
- repeat (break_if; try omega).
- - congruence.
- - intros.
- rewrite nth_default_modulus_digits.
- repeat (break_if; try omega).
- rewrite <-Z.lt_succ_r with (m := max_value i).
- rewrite max_value_log_cap; apply carry_done_bounds; assumption.
- Qed.
-
- Lemma isFull_compare : forall us, (length us = length base) -> carry_done us ->
- (isFull us = true <-> compare us modulus_digits <> Lt).
- Proof.
- unfold compare, isFull; intros ? lengths_eq. intros.
- rewrite lengths_eq.
- apply isFull'_compare'; try omega.
- assumption.
- Qed.
-
- Lemma isFull_decode : forall us, (length us = length base) -> carry_done us ->
- (isFull us = true <->
- (BaseSystem.decode base us ?= BaseSystem.decode base modulus_digits <> Lt)).
- Proof.
- intros.
- rewrite decode_compare; autorewrite with lengths; auto.
- apply isFull_compare; auto.
- Qed.
-
- Lemma isFull_false_upper_bound : forall us, (length us = length base) ->
- carry_done us -> isFull us = false ->
- BaseSystem.decode base us < modulus.
- Proof.
- intros.
- destruct (Z_lt_dec (BaseSystem.decode base us) modulus) as [? | nlt_modulus];
- [assumption | exfalso].
- apply Z.compare_nlt_iff in nlt_modulus.
- rewrite <-decode_modulus_digits in nlt_modulus at 2.
- apply isFull_decode in nlt_modulus; try assumption; congruence.
- Qed.
-
- Lemma isFull_true_lower_bound : forall us, (length us = length base) ->
- carry_done us -> isFull us = true ->
- modulus <= BaseSystem.decode base us.
- Proof.
- intros.
- rewrite <-decode_modulus_digits at 1.
- apply Z.compare_ge_iff.
- apply isFull_decode; auto.
- Qed.
-
- Lemma freeze_in_bounds : forall us,
- pre_carry_bounds us -> (length us = length base) ->
- carry_done (freeze us).
- Proof.
- unfold freeze, and_term; intros ? PCB lengths_eq.
- rewrite carry_done_bounds by simpl_lengths; intro i.
- rewrite nth_default_map2 with (d1 := 0) (d2 := 0).
- simpl_lengths.
- break_if; [ | split; (omega || auto)].
+ Hint Rewrite @nth_default_carry_and_reduce_full : push_nth_default.
+
+ Lemma nth_default_carry_full : forall n i us,
+ length us = length limb_widths ->
+ nth_default 0 (carry i us) n
+ = if lt_dec n (length us)
+ then
+ if eq_nat_dec i (pred (length limb_widths))
+ then (if eq_nat_dec n i
+ then Z.pow2_mod (nth_default 0 us n) (log_cap n)
+ else nth_default 0 us n) +
+ if eq_nat_dec n 0
+ then c * (nth_default 0 us i >> log_cap i)
+ else 0
+ else if eq_nat_dec n i
+ then Z.pow2_mod (nth_default 0 us n) (log_cap n)
+ else nth_default 0 us n +
+ if eq_nat_dec n (S i)
+ then nth_default 0 us i >> log_cap i
+ else 0
+ else 0.
+ Proof.
+ intros.
+ cbv [carry].
break_if.
- + rewrite map_land_max_ones_modulus_digits.
- apply isFull_true_iff in Heqb; [ | simpl_lengths].
- destruct Heqb as [first_digit high_digits].
- destruct (eq_nat_dec i 0).
- - subst.
- clear high_digits.
- rewrite nth_default_modulus_digits.
- repeat (break_if; try omega).
- pose proof (carry_full_3_done us PCB lengths_eq) as cf3_done.
- rewrite carry_done_bounds in cf3_done by simpl_lengths.
- specialize (cf3_done 0%nat).
- pose proof c_pos; omega.
- - assert ((0 < i <= length base - 1)%nat) as i_range by
- (simpl_lengths; apply lt_min_l in l; omega).
- specialize (high_digits i i_range).
- clear first_digit i_range.
- rewrite high_digits.
- rewrite <-max_value_log_cap.
- rewrite nth_default_modulus_digits.
- repeat (break_if; try omega).
- * rewrite Z.sub_diag.
- split; try omega.
- apply Z.lt_succ_r; auto.
- * rewrite Z.lt_succ_r, Z.sub_0_r. split; (omega || auto).
- + rewrite map_land_zero, nth_default_zeros.
- rewrite Z.sub_0_r.
- apply carry_done_bounds; [ simpl_lengths | ].
- auto using carry_full_3_done.
- Qed.
- Local Hint Resolve freeze_in_bounds.
-
- Local Hint Resolve carry_full_3_done.
-
- Lemma freeze_minimal_rep : forall us, pre_carry_bounds us -> (length us = length base) ->
- minimal_rep (freeze us).
- Proof.
- unfold minimal_rep, freeze, and_term.
- intros.
- symmetry. apply Z.mod_small.
- split; break_if; rewrite decode_map2_sub; simpl_lengths.
- + rewrite map_land_max_ones_modulus_digits, decode_modulus_digits.
- apply Z.le_0_sub.
- apply isFull_true_lower_bound; simpl_lengths.
- + rewrite map_land_zero, zeros_rep, Z.sub_0_r.
- apply decode_carry_done_lower_bound; simpl_lengths.
- + rewrite map_land_max_ones_modulus_digits, decode_modulus_digits.
- rewrite Z.lt_sub_lt_add_r.
- apply Z.lt_le_trans with (m := 2 * modulus); try omega.
- eapply Z.lt_le_trans; [ | apply two_pow_k_le_2modulus ].
- apply decode_carry_done_upper_bound; simpl_lengths.
- + rewrite map_land_zero, zeros_rep, Z.sub_0_r.
- apply isFull_false_upper_bound; simpl_lengths.
+ + subst i. autorewrite with push_nth_default natsimplify.
+ destruct (eq_nat_dec (length limb_widths) (length us)); congruence.
+ + autorewrite with push_nth_default; reflexivity.
+ Qed.
+ Hint Rewrite @nth_default_carry_full : push_nth_default.
+
+ Lemma nth_default_carry_sequence_make_chain_full : forall i n us,
+ length us = length limb_widths ->
+ (i <= length limb_widths)%nat ->
+ nth_default 0 (carry_sequence (make_chain i) us) n
+ = if lt_dec n (length limb_widths)
+ then
+ if eq_nat_dec i 0
+ then nth_default 0 us n
+ else
+ if lt_dec i (length limb_widths)
+ then
+ if lt_dec n i
+ then
+ if eq_nat_dec n (pred i)
+ then Z.pow2_mod
+ (nth_default 0 (carry_sequence (make_chain (pred i)) us) n)
+ (log_cap n)
+ else nth_default 0 (carry_sequence (make_chain (pred i)) us) n
+ else nth_default 0 (carry_sequence (make_chain (pred i)) us) n +
+ (if eq_nat_dec n i
+ then (nth_default 0 (carry_sequence (make_chain (pred i)) us) (pred i))
+ >> log_cap (pred i)
+ else 0)
+ else
+ if lt_dec n (pred i)
+ then nth_default 0 (carry_sequence (make_chain (pred i)) us) n +
+ (if eq_nat_dec n 0
+ then c * (nth_default 0 (carry_sequence (make_chain (pred i)) us) (pred i))
+ >> log_cap (pred i)
+ else 0)
+ else Z.pow2_mod
+ (nth_default 0 (carry_sequence (make_chain (pred i)) us) n)
+ (log_cap n)
+ else 0.
+ Proof.
+ induction i; intros; cbv [ModularBaseSystemList.carry_sequence].
+ + cbv [pred make_chain fold_right].
+ repeat break_if; subst; omega || reflexivity || auto using Z.add_0_r.
+ apply nth_default_out_of_bounds. omega.
+ + replace (make_chain (S i)) with (i :: make_chain i) by reflexivity.
+ rewrite fold_right_cons.
+ autorewrite with push_nth_default natsimplify;
+ rewrite ?Nat.pred_succ; fold (carry_sequence (make_chain i) us);
+ rewrite length_carry_sequence; auto.
+ repeat break_if; try omega; rewrite ?IHi by (omega || auto);
+ rewrite ?Z.add_0_r; try reflexivity.
+ Qed.
+
+ Lemma nth_default_carry_full_full : forall n us,
+ length us = length limb_widths ->
+ nth_default 0 (ModularBaseSystemList.carry_full us) n
+ = if lt_dec n (length limb_widths)
+ then
+ if eq_nat_dec n (pred (length limb_widths))
+ then Z.pow2_mod
+ (nth_default 0 (carry_sequence (make_chain (pred (length limb_widths))) us) n)
+ (log_cap n)
+ else nth_default 0 (carry_sequence (make_chain (pred (length limb_widths))) us) n +
+ (if eq_nat_dec n 0
+ then c * (nth_default 0 (carry_sequence (make_chain (pred (length limb_widths))) us) (pred (length limb_widths)))
+ >> log_cap (pred (length limb_widths))
+ else 0)
+ else 0.
+ Proof.
+ intros.
+ cbv [ModularBaseSystemList.carry_full full_carry_chain].
+ rewrite (nth_default_carry_sequence_make_chain_full (length limb_widths)) by omega.
+ repeat break_if; try omega; reflexivity.
+ Qed.
+ Hint Rewrite @nth_default_carry_full_full : push_nth_default.
+
+ Lemma nth_default_carry : forall i us,
+ length us = length limb_widths ->
+ (i < length us)%nat ->
+ nth_default 0 (ModularBaseSystemList.carry i us) i
+ = Z.pow2_mod (nth_default 0 us i) (log_cap i).
+ Proof.
+ intros; autorewrite with push_nth_default natsimplify; break_match; omega.
+ Qed.
+ Hint Rewrite @nth_default_carry using (omega || distr_length; omega) : push_nth_default.
+
+ Local Notation pred := Init.Nat.pred.
+ Local Notation "u '[' i ']' " := (nth_default 0 u i) (at level 30).
+ Local Notation "u '{{' i '}}' " := (carry_sequence (make_chain i) u) (at level 30).
+
+ Lemma bound_during_first_loop : forall i n us,
+ length us = length limb_widths ->
+ (i <= length limb_widths)%nat ->
+ (forall n, 0 <= nth_default 0 us n < 2 ^ B - if eq_nat_dec n 0 then 0 else ((2 ^ B) >> log_cap (pred n))) ->
+ 0 <= us{{i}}[n] < if eq_nat_dec i 0 then us[n] + 1 else
+ if lt_dec i (length limb_widths)
+ then
+ if lt_dec n i
+ then 2 ^ (log_cap n)
+ else if eq_nat_dec n i
+ then 2 ^ B
+ else us[n] + 1
+ else
+ if eq_nat_dec n 0
+ then 2 * 2 ^ limb_widths [n]
+ else 2 ^ limb_widths [n].
+ Proof.
+ induction i; intros; cbv [ModularBaseSystemList.carry_sequence].
+ + break_if; try omega.
+ cbv [make_chain fold_right]. split; try omega. apply H1.
+ + replace (make_chain (S i)) with (i :: make_chain i) by reflexivity.
+ rewrite fold_right_cons.
+ autorewrite with push_nth_default natsimplify; rewrite ?Nat.pred_succ;
+ fold (carry_sequence (make_chain i) us); rewrite length_carry_sequence; auto.
+ repeat (break_if; try omega);
+ try solve [rewrite Z.pow2_mod_spec by auto; autorewrite with zsimplify; apply Z.mod_pos_bound; zero_bounds];
+ pose proof (IHi i us); pose proof (IHi n us); specialize_by assumption; specialize_by auto;
+ repeat break_if; try omega; pose proof c_pos; (split; try solve [zero_bounds]).
+ (* TODO (jadep) : clean up/automate these leftover cases. *)
+ - replace (2 * 2 ^ limb_widths [n]) with (2 ^ limb_widths [n] + 2 ^ limb_widths [n]) by ring.
+ apply Z.add_lt_le_mono; subst n. omega.
+ eapply Z.le_trans; eauto.
+ apply Z.mul_le_mono_nonneg_l; try omega. subst i.
+ apply Z.shiftr_le; auto. apply Z.lt_le_incl. apply H2.
+ - replace (2 ^ B) with ((2 ^ B - ((2 ^ B) >> log_cap i)) + ((2 ^ B) >> log_cap i)) by ring.
+ apply Z.add_lt_le_mono.
+ * eapply Z.le_lt_trans with (m := us [n]); try omega.
+ replace i with (pred n) by omega.
+ eapply Z.lt_le_trans; [ apply H1 | ].
+ break_if; omega.
+ * apply Z.shiftr_le. auto.
+ apply Z.le_trans with (m := us [i]); [ omega | ].
+ eapply Z.le_trans. apply Z.lt_le_incl. apply H1.
+ break_if; omega.
+ - replace (2 ^ B) with ((2 ^ B - ((2 ^ B) >> log_cap i)) + ((2 ^ B) >> log_cap i)) by ring.
+ apply Z.add_lt_le_mono.
+ * eapply Z.le_lt_trans with (m := us [n]); try omega.
+ replace i with (pred n) by omega.
+ eapply Z.lt_le_trans; [ apply H1 | ].
+ break_if; omega.
+ * apply Z.shiftr_le. auto. omega.
+ Qed.
+
+ Lemma bound_after_first_loop : forall n us,
+ length us = length limb_widths ->
+ (forall n, 0 <= nth_default 0 us n < 2 ^ B - if eq_nat_dec n 0 then 0 else ((2 ^ B) >> log_cap (pred n))) ->
+ 0 <= (ModularBaseSystemList.carry_full us)[n] <
+ if eq_nat_dec n 0
+ then 2 * 2 ^ limb_widths [n]
+ else 2 ^ limb_widths [n].
+ Proof.
+ cbv [ModularBaseSystemList.carry_full full_carry_chain]; intros.
+ pose proof (bound_during_first_loop (length limb_widths) n us).
+ specialize_by eauto.
+ repeat (break_if; try omega).
Qed.
- Local Hint Resolve freeze_minimal_rep.
- Lemma rep_decode_mod : forall us vs x, rep us x -> rep vs x ->
- (BaseSystem.decode base us) mod modulus = (BaseSystem.decode base vs) mod modulus.
- Proof.
- unfold rep, decode; intros.
- intuition.
- repeat rewrite <-FieldToZ_ZToField.
- congruence.
- Qed.
+ (* TODO(jadep):
+ - Proof of bound after 3 loops
+ - Proof of correctness for [ge_modulus] and [cond_subtract_modulus]
+ - Proof of correctness for [freeze]
+ * freeze us = encode (decode us)
+ * decode us = x ->
+ canonicalized_BSToWord (freeze us)) = FToWord x
- Lemma minimal_rep_unique : forall us vs x,
- rep us x -> minimal_rep us -> carry_done us ->
- rep vs x -> minimal_rep vs -> carry_done vs ->
- us = vs.
- Proof.
- intros.
- match goal with Hrep1 : rep _ ?x, Hrep2 : rep _ ?x |- _ =>
- pose proof (rep_decode_mod _ _ _ Hrep1 Hrep2) as eqmod end.
- repeat match goal with Hmin : minimal_rep ?us |- _ => unfold minimal_rep in Hmin;
- rewrite <- Hmin in eqmod; clear Hmin end.
- apply Z.compare_eq_iff in eqmod.
- rewrite decode_compare in eqmod; unfold rep in *; auto; intuition; try congruence.
- apply compare_Eq; auto.
- congruence.
- Qed.
+ (where [canonicalized_BSToWord] uses bitwise operations to concatenate digits
+ in BaseSystem in canonical form, splitting along word capacities)
+ *)
- Lemma freeze_canonical : forall us vs x,
- pre_carry_bounds us -> rep us x ->
- pre_carry_bounds vs -> rep vs x ->
- freeze us = freeze vs.
- Proof.
- intros.
- assert (length us = length base) by (unfold rep in *; intuition).
- assert (length vs = length base) by (unfold rep in *; intuition).
- eapply minimal_rep_unique; eauto; rewrite freeze_length; assumption.
- Qed.
-*)
End CanonicalizationProofs.
diff --git a/src/ModularArithmetic/Pow2Base.v b/src/ModularArithmetic/Pow2Base.v
index a2c76016d..9d6cc2410 100644
--- a/src/ModularArithmetic/Pow2Base.v
+++ b/src/ModularArithmetic/Pow2Base.v
@@ -53,14 +53,19 @@ Section Pow2Base.
(Z.pow2_mod di (log_cap i),
Z.shiftr di (log_cap i)).
+ (* [fi] is fed [length us] and [S i] and produces the index of
+ the digit to which value should be added;
+ [fc] modifies the carried value before adding it to that digit *)
Definition carry_gen fc fi i := fun us =>
- let i := fi (length us) i in
+ let i := fi i in
let di := nth_default 0 us i in
let '(di', ci) := carry_single i di in
let us' := set_nth i di' us in
- add_to_nth (fi (length us) (S i)) (fc ci) us'.
+ add_to_nth (fi (S i)) (fc ci) us'.
- Definition carry_simple := carry_gen (fun ci => ci) (fun _ i => i).
+ (* carry_simple does not modify the carried value, and always adds it
+ to the digit with index [S i] *)
+ Definition carry_simple := carry_gen (fun ci => ci) (fun i => i).
Definition carry_simple_sequence is us := fold_right carry_simple us is.
diff --git a/src/ModularArithmetic/Pow2BaseProofs.v b/src/ModularArithmetic/Pow2BaseProofs.v
index 9255f033f..4b616c288 100644
--- a/src/ModularArithmetic/Pow2BaseProofs.v
+++ b/src/ModularArithmetic/Pow2BaseProofs.v
@@ -34,7 +34,7 @@ Section Pow2BaseProofs.
Lemma two_sum_firstn_limb_widths_nonzero n : 2^sum_firstn limb_widths n <> 0.
Proof. pose proof (two_sum_firstn_limb_widths_pos n); omega. Qed.
- Lemma base_from_limb_widths_step : forall i b w, (S i < length base)%nat ->
+ Lemma base_from_limb_widths_step : forall i b w, (S i < length limb_widths)%nat ->
nth_error base i = Some b ->
nth_error limb_widths i = Some w ->
nth_error base (S i) = Some (two_p w * b).
@@ -42,7 +42,7 @@ Section Pow2BaseProofs.
induction limb_widths; intros ? ? ? ? nth_err_w nth_err_b;
unfold base_from_limb_widths in *; fold base_from_limb_widths in *;
[rewrite (@nil_length0 Z) in *; omega | ].
- simpl in *; rewrite map_length in *.
+ simpl in *.
case_eq i; intros; subst.
+ subst; apply nth_error_first in nth_err_w.
apply nth_error_first in nth_err_b; subst.
@@ -60,15 +60,14 @@ Section Pow2BaseProofs.
Qed.
- Lemma nth_error_base : forall i, (i < length base)%nat ->
+ Lemma nth_error_base : forall i, (i < length limb_widths)%nat ->
nth_error base i = Some (two_p (sum_firstn limb_widths i)).
Proof.
induction i; intros.
+ unfold sum_firstn, base_from_limb_widths in *; case_eq limb_widths; try reflexivity.
intro lw_nil; rewrite lw_nil, (@nil_length0 Z) in *; omega.
- + assert (i < length base)%nat as lt_i_length by omega.
+ + assert (i < length limb_widths)%nat as lt_i_length by omega.
specialize (IHi lt_i_length).
- rewrite base_from_limb_widths_length in lt_i_length.
destruct (nth_error_length_exists_value _ _ lt_i_length) as [w nth_err_w].
erewrite base_from_limb_widths_step; eauto.
f_equal.
@@ -86,19 +85,16 @@ Section Pow2BaseProofs.
eapply nth_error_value_In; eauto.
Qed.
- Lemma nth_default_base : forall d i, (i < length base)%nat ->
+ Lemma nth_default_base : forall d i, (i < length limb_widths)%nat ->
nth_default d base i = 2 ^ (sum_firstn limb_widths i).
Proof.
intros ? ? i_lt_length.
- destruct (nth_error_length_exists_value _ _ i_lt_length) as [x nth_err_x].
- unfold nth_default.
- rewrite nth_err_x.
- rewrite nth_error_base in nth_err_x by assumption.
- rewrite two_p_correct in nth_err_x.
- congruence.
+ apply nth_error_value_eq_nth_default.
+ rewrite nth_error_base, two_p_correct by assumption.
+ reflexivity.
Qed.
- Lemma base_succ : forall i, ((S i) < length base)%nat ->
+ Lemma base_succ : forall i, ((S i) < length limb_widths)%nat ->
nth_default 0 base (S i) mod nth_default 0 base i = 0.
Proof.
intros.
@@ -111,8 +107,7 @@ Section Pow2BaseProofs.
apply limb_widths_nonneg.
rewrite lw_eq.
apply in_eq.
- + assert (i < length base)%nat as i_lt_length by omega.
- rewrite base_from_limb_widths_length in *.
+ + assert (i < length limb_widths)%nat as i_lt_length by omega.
apply nth_error_length_exists_value in i_lt_length.
destruct i_lt_length as [x nth_err_x].
erewrite sum_firstn_succ; eauto.
@@ -126,6 +121,7 @@ Section Pow2BaseProofs.
Proof.
intros i b nth_err_b.
pose proof (nth_error_value_length _ _ _ _ nth_err_b).
+ rewrite base_from_limb_widths_length in *.
rewrite nth_error_base in nth_err_b by assumption.
rewrite two_p_correct in nth_err_b.
congruence.
@@ -168,19 +164,19 @@ Section Pow2BaseProofs.
Section make_base_vector.
Local Notation k := (sum_firstn limb_widths (length limb_widths)).
Context (limb_widths_match_modulus : forall i j,
- (i < length limb_widths)%nat ->
- (j < length limb_widths)%nat ->
- (i + j >= length limb_widths)%nat ->
+ (i < length base)%nat ->
+ (j < length base)%nat ->
+ (i + j >= length base)%nat ->
let w_sum := sum_firstn limb_widths in
- k + w_sum (i + j - length limb_widths)%nat <= w_sum i + w_sum j)
+ k + w_sum (i + j - length base)%nat <= w_sum i + w_sum j)
(limb_widths_good : forall i j, (i + j < length limb_widths)%nat ->
sum_firstn limb_widths (i + j) <=
sum_firstn limb_widths i + sum_firstn limb_widths j).
Lemma base_matches_modulus: forall i j,
- (i < length limb_widths)%nat ->
- (j < length limb_widths)%nat ->
- (i+j >= length limb_widths)%nat->
+ (i < length base)%nat ->
+ (j < length base)%nat ->
+ (i+j >= length base)%nat->
let b := nth_default 0 base in
let r := (b i * b j) / (2^k * b (i+j-length base)%nat) in
b i * b j = r * (2^k * b (i+j-length base)%nat).
@@ -188,20 +184,20 @@ Section Pow2BaseProofs.
intros.
rewrite (Z.mul_comm r).
subst r.
+ rewrite base_from_limb_widths_length in *;
assert (i + j - length limb_widths < length limb_widths)%nat by omega.
- rewrite Z.mul_div_eq by (apply Z.gt_lt_iff; apply Z.mul_pos_pos;
- subst b; rewrite ?nth_default_base; zero_bounds; rewrite ?base_from_limb_widths_length;
- auto using sum_firstn_limb_widths_nonneg, limb_widths_nonneg).
+ rewrite Z.mul_div_eq by (apply Z.gt_lt_iff; subst b; rewrite ?nth_default_base; zero_bounds;
+ assumption).
rewrite (Zminus_0_l_reverse (b i * b j)) at 1.
f_equal.
subst b.
- repeat rewrite nth_default_base by (rewrite ?base_from_limb_widths_length; auto).
+ repeat rewrite nth_default_base by auto.
do 2 rewrite <- Z.pow_add_r by auto using sum_firstn_limb_widths_nonneg.
symmetry.
apply Z.mod_same_pow.
split.
+ apply Z.add_nonneg_nonneg; auto using sum_firstn_limb_widths_nonneg.
- + rewrite base_from_limb_widths_length; auto using limb_widths_nonneg, limb_widths_match_modulus.
+ + auto using limb_widths_match_modulus.
Qed.
Lemma base_good : forall i j : nat,
@@ -211,7 +207,9 @@ Section Pow2BaseProofs.
b i * b j = r * b (i + j)%nat.
Proof.
intros; subst b r.
- repeat rewrite nth_default_base by (omega || auto).
+ clear limb_widths_match_modulus.
+ rewrite base_from_limb_widths_length in *.
+ repeat rewrite nth_default_base by omega.
rewrite (Z.mul_comm _ (2 ^ (sum_firstn limb_widths (i+j)))).
rewrite Z.mul_div_eq by (apply Z.gt_lt_iff; zero_bounds;
auto using sum_firstn_limb_widths_nonneg).
@@ -219,10 +217,11 @@ Section Pow2BaseProofs.
rewrite Z.mod_same_pow; try ring.
split; [ auto using sum_firstn_limb_widths_nonneg | ].
apply limb_widths_good.
- rewrite <-base_from_limb_widths_length; auto using limb_widths_nonneg.
+ assumption.
Qed.
End make_base_vector.
End Pow2BaseProofs.
+Hint Rewrite @base_from_limb_widths_length : distr_length.
Section BitwiseDecodeEncode.
Context {limb_widths} (bv : BaseSystem.BaseVector (base_from_limb_widths limb_widths))
@@ -232,7 +231,7 @@ Section BitwiseDecodeEncode.
Local Notation base := (base_from_limb_widths limb_widths).
Local Notation upper_bound := (upper_bound limb_widths).
- Lemma encode'_spec : forall x i, (i <= length base)%nat ->
+ Lemma encode'_spec : forall x i, (i <= length limb_widths)%nat ->
encode' limb_widths x i = BaseSystem.encode' base x upper_bound i.
Proof.
induction i; intros.
@@ -240,13 +239,12 @@ Section BitwiseDecodeEncode.
+ rewrite encode'_succ, <-IHi by omega.
simpl; do 2 f_equal.
rewrite Z.land_ones, Z.shiftr_div_pow2 by auto using sum_firstn_limb_widths_nonneg.
- match goal with H : (S _ <= length base)%nat |- _ =>
+ match goal with H : (S _ <= length limb_widths)%nat |- _ =>
apply le_lt_or_eq in H; destruct H end.
- repeat f_equal; rewrite nth_default_base by (omega || auto); reflexivity.
- repeat f_equal; try solve [rewrite nth_default_base by (omega || auto); reflexivity].
- rewrite nth_default_out_of_bounds by omega.
+ rewrite nth_default_out_of_bounds by (distr_length; omega).
unfold Pow2Base.upper_bound.
- rewrite <-base_from_limb_widths_length by auto.
congruence.
Qed.
@@ -258,20 +256,17 @@ Section BitwiseDecodeEncode.
Lemma base_upper_bound_compatible : @base_max_succ_divide base upper_bound.
Proof.
unfold base_max_succ_divide; intros i lt_Si_length.
+ rewrite base_from_limb_widths_length in lt_Si_length.
rewrite Nat.lt_eq_cases in lt_Si_length; destruct lt_Si_length;
rewrite !nth_default_base by (omega || auto).
- + erewrite sum_firstn_succ by (eapply nth_error_Some_nth_default with (x := 0);
- rewrite <-base_from_limb_widths_length by auto; omega).
+ + erewrite sum_firstn_succ by (eapply nth_error_Some_nth_default with (x := 0); omega).
rewrite Z.pow_add_r; auto using sum_firstn_limb_widths_nonneg.
apply Z.divide_factor_r.
- + rewrite nth_default_out_of_bounds by omega.
+ + rewrite nth_default_out_of_bounds by (distr_length; omega).
unfold Pow2Base.upper_bound.
- replace (length limb_widths) with (S (pred (length limb_widths))) by
- (rewrite base_from_limb_widths_length in H by auto; omega).
- replace i with (pred (length limb_widths)) by
- (rewrite base_from_limb_widths_length in H by auto; omega).
- erewrite sum_firstn_succ by (eapply nth_error_Some_nth_default with (x := 0);
- rewrite <-base_from_limb_widths_length by auto; omega).
+ replace (length limb_widths) with (S (pred (length limb_widths))) by omega.
+ replace i with (pred (length limb_widths)) by omega.
+ erewrite sum_firstn_succ by (eapply nth_error_Some_nth_default with (x := 0); omega).
rewrite Z.pow_add_r; auto using sum_firstn_limb_widths_nonneg.
apply Z.divide_factor_r.
Qed.
@@ -281,7 +276,7 @@ Section BitwiseDecodeEncode.
BaseSystem.decode base (encodeZ limb_widths x) = x mod upper_bound.
Proof.
intros.
- assert (length base = length limb_widths) by auto using base_from_limb_widths_length.
+ assert (length base = length limb_widths) by distr_length.
unfold encodeZ; rewrite encode'_spec by omega.
rewrite BaseSystemProofs.encode'_spec; unfold Pow2Base.upper_bound; try zero_bounds;
auto using sum_firstn_limb_widths_nonneg.
@@ -521,7 +516,7 @@ Section carrying_helper.
Local Notation base := (base_from_limb_widths limb_widths).
Local Notation log_cap i := (nth_default 0 limb_widths i).
- Lemma update_nth_sum : forall n f us, (n < length us \/ n >= length base)%nat ->
+ Lemma update_nth_sum : forall n f us, (n < length us \/ n >= length limb_widths)%nat ->
BaseSystem.decode base (update_nth n f us) =
(let v := nth_default 0 us n in f v - v) * nth_default 0 base n + BaseSystem.decode base us.
Proof.
@@ -540,17 +535,17 @@ Section carrying_helper.
erewrite (nth_error_value_eq_nth_default _ _ us) by eassumption.
rewrite firstn_length in Heqn0.
rewrite Min.min_l in Heqn0 by omega; subst n0.
- destruct (le_lt_dec (length base) n). {
- rewrite (@nth_default_out_of_bounds _ _ base) by auto.
- rewrite skipn_all by omega.
+ destruct (le_lt_dec (length limb_widths) n). {
+ rewrite (@nth_default_out_of_bounds _ _ base) by (distr_length; auto).
+ rewrite skipn_all by (rewrite base_from_limb_widths_length; omega).
do 2 rewrite decode_base_nil.
ring_simplify; auto.
} {
- rewrite (skipn_nth_default n base 0) by omega.
+ rewrite (skipn_nth_default n base 0) by (distr_length; omega).
do 2 rewrite decode'_cons.
ring_simplify; ring.
} }
- { rewrite (nth_default_out_of_bounds _ base) by omega; ring_simplify.
+ { rewrite (nth_default_out_of_bounds _ base) by (distr_length; omega); ring_simplify.
etransitivity; rewrite BaseSystem.decode'_truncate; [ reflexivity | ].
apply f_equal.
autorewrite with push_firstn simpl_update_nth.
@@ -639,12 +634,12 @@ Section carrying_helper.
Hint Rewrite @length_add_to_nth : distr_length.
- Lemma set_nth_sum : forall n x us, (n < length us \/ n >= length base)%nat ->
+ Lemma set_nth_sum : forall n x us, (n < length us \/ n >= length limb_widths)%nat ->
BaseSystem.decode base (set_nth n x us) =
(x - nth_default 0 us n) * nth_default 0 base n + BaseSystem.decode base us.
Proof. intros; unfold set_nth; rewrite update_nth_sum by assumption; reflexivity. Qed.
- Lemma add_to_nth_sum : forall n x us, (n < length us \/ n >= length base)%nat ->
+ Lemma add_to_nth_sum : forall n x us, (n < length us \/ n >= length limb_widths)%nat ->
BaseSystem.decode base (add_to_nth n x us) =
x * nth_default 0 base n + BaseSystem.decode base us.
Proof. intros; rewrite add_to_nth_set_nth, set_nth_sum; try ring_simplify; auto. Qed.
@@ -696,7 +691,7 @@ Section carrying.
Proof. intros; unfold carry_simple; distr_length; reflexivity. Qed.
Hint Rewrite @length_carry_simple : distr_length.
- Lemma nth_default_base_succ : forall i, (S i < length base)%nat ->
+ Lemma nth_default_base_succ : forall i, (S i < length limb_widths)%nat ->
nth_default 0 base (S i) = 2 ^ log_cap i * nth_default 0 base i.
Proof.
intros.
@@ -705,13 +700,13 @@ Section carrying.
Qed.
Lemma carry_gen_decode_eq : forall fc fi i' us
- (i := fi (length base) i')
- (Si := fi (length base) (S i)),
- (length us = length base) ->
+ (i := fi i')
+ (Si := fi (S i)),
+ (length us = length limb_widths) ->
BaseSystem.decode base (carry_gen limb_widths fc fi i' us)
= (fc (nth_default 0 us i / 2 ^ log_cap i) *
(if eq_nat_dec Si (S i)
- then if lt_dec (S i) (length base)
+ then if lt_dec (S i) (length limb_widths)
then 2 ^ log_cap i * nth_default 0 base i
else 0
else nth_default 0 base Si)
@@ -719,29 +714,29 @@ Section carrying.
+ BaseSystem.decode base us.
Proof.
intros fc fi i' us i Si H; intros.
- destruct (eq_nat_dec 0 (length base));
+ destruct (eq_nat_dec 0 (length limb_widths));
[ destruct limb_widths, us, i; simpl in *; try congruence;
break_match;
unfold carry_gen, carry_single, add_to_nth;
autorewrite with zsimplify simpl_nth_default simpl_set_nth simpl_update_nth distr_length;
reflexivity
| ].
- (*assert (0 <= i < length base)%nat by (subst i; auto with arith).*)
+ (*assert (0 <= i < length limb_widths)%nat by (subst i; auto with arith).*)
assert (0 <= log_cap i) by auto using log_cap_nonneg.
assert (2 ^ log_cap i <> 0) by (apply Z.pow_nonzero; lia).
unfold carry_gen, carry_single.
- rewrite H; change (i' mod length base)%nat with i.
+ change (i' mod length limb_widths)%nat with i.
rewrite add_to_nth_sum by (rewrite length_set_nth; omega).
rewrite set_nth_sum by omega.
unfold Z.pow2_mod.
rewrite Z.land_ones by auto using log_cap_nonneg.
rewrite Z.shiftr_div_pow2 by auto using log_cap_nonneg.
- change (fi (length base) i') with i.
+ change (fi i') with i.
subst Si.
repeat first [ ring
| match goal with H : _ = _ |- _ => rewrite !H in * end
| rewrite nth_default_base_succ by omega
- | rewrite !(nth_default_out_of_bounds _ base) by omega
+ | rewrite !(nth_default_out_of_bounds _ base) by (distr_length; omega)
| rewrite !(nth_default_out_of_bounds _ us) by omega
| rewrite Z.mod_eq by assumption
| progress distr_length
@@ -750,8 +745,8 @@ Section carrying.
Qed.
Lemma carry_simple_decode_eq : forall i us,
- (length us = length base) ->
- (i < (pred (length base)))%nat ->
+ (length us = length limb_widths) ->
+ (i < (pred (length limb_widths)))%nat ->
BaseSystem.decode base (carry_simple limb_widths i us) = BaseSystem.decode base us.
Proof.
unfold carry_simple; intros; rewrite carry_gen_decode_eq by assumption.
@@ -790,11 +785,11 @@ Section carrying.
Lemma nth_default_carry_gen_full fc fi d i n us
: nth_default d (carry_gen limb_widths fc fi i us) n
= if lt_dec n (length us)
- then (if eq_nat_dec n (fi (length us) i)
+ then (if eq_nat_dec n (fi i)
then Z.pow2_mod (nth_default 0 us n) (log_cap n)
else nth_default 0 us n) +
- if eq_nat_dec n (fi (length us) (S (fi (length us) i)))
- then fc (nth_default 0 us (fi (length us) i) >> log_cap (fi (length us) i))
+ if eq_nat_dec n (fi (S (fi i)))
+ then fc (nth_default 0 us (fi i) >> log_cap (fi i))
else 0
else d.
Proof.
@@ -826,11 +821,11 @@ Section carrying.
: forall fc fi i us,
(0 <= i < length us)%nat
-> nth_default 0 (carry_gen limb_widths fc fi i us) i
- = (if eq_nat_dec i (fi (length us) i)
+ = (if eq_nat_dec i (fi i)
then Z.pow2_mod (nth_default 0 us i) (log_cap i)
else nth_default 0 us i) +
- if eq_nat_dec i (fi (length us) (S (fi (length us) i)))
- then fc (nth_default 0 us (fi (length us) i) >> log_cap (fi (length us) i))
+ if eq_nat_dec i (fi (S (fi i)))
+ then fc (nth_default 0 us (fi i) >> log_cap (fi i))
else 0.
Proof.
intros; autorewrite with push_nth_default natsimplify; break_match; omega.
@@ -848,7 +843,7 @@ Section carrying.
Hint Rewrite @nth_default_carry_simple using (omega || distr_length; omega) : push_nth_default.
End carrying.
-Hint Rewrite @length_carry_gen @base_from_limb_widths_length : distr_length.
+Hint Rewrite @length_carry_gen : distr_length.
Hint Rewrite @length_carry_simple @length_carry_simple_sequence @length_make_chain @length_full_carry_chain @length_carry_simple_full : distr_length.
Hint Rewrite @nth_default_carry_simple_full @nth_default_carry_gen_full : push_nth_default.
Hint Rewrite @nth_default_carry_simple @nth_default_carry_gen using (omega || distr_length; omega) : push_nth_default.