aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorGravatar jadep <jade.philipoom@gmail.com>2016-07-25 21:06:07 -0400
committerGravatar jadep <jade.philipoom@gmail.com>2016-07-25 21:06:07 -0400
commit39a6c95de8a900c859726d875cc40ea96298d31b (patch)
tree750571dc101f477c34340716db87a3697cca41eb
parentea9397e3da37f35d088488be141cb18cc38ea11b (diff)
Put ModularBaseSystem carries in terms of [carry_gen], and pushed this change through the pipeline. Also began the process of redoing canonicalization proofs, attempting to put the messy case analysis in theorem statements rather than separate lemmas.
-rw-r--r--src/ModularArithmetic/ExtendedBaseVector.v1
-rw-r--r--src/ModularArithmetic/ModularBaseSystem.v8
-rw-r--r--src/ModularArithmetic/ModularBaseSystemList.v17
-rw-r--r--src/ModularArithmetic/ModularBaseSystemListProofs.v14
-rw-r--r--src/ModularArithmetic/ModularBaseSystemOpt.v246
-rw-r--r--src/ModularArithmetic/ModularBaseSystemProofs.v1831
-rw-r--r--src/ModularArithmetic/Pow2Base.v11
-rw-r--r--src/ModularArithmetic/Pow2BaseProofs.v135
8 files changed, 494 insertions, 1769 deletions
diff --git a/src/ModularArithmetic/ExtendedBaseVector.v b/src/ModularArithmetic/ExtendedBaseVector.v
index fcd871aae..ef8c9716a 100644
--- a/src/ModularArithmetic/ExtendedBaseVector.v
+++ b/src/ModularArithmetic/ExtendedBaseVector.v
@@ -122,6 +122,7 @@ Section ExtendedBaseVector.
rewrite (map_nth_default _ _ _ _ 0) by omega.
apply base_matches_modulus; auto using limb_widths_nonnegative, limb_widths_match_modulus;
distr_length.
+ assumption.
} { (* i < length base, j >= length base, i + j >= length base *)
do 2 rewrite map_nth_default_base_high by omega.
remember (j - length base)%nat as j'.
diff --git a/src/ModularArithmetic/ModularBaseSystem.v b/src/ModularArithmetic/ModularBaseSystem.v
index 70c8138da..4fd881e1e 100644
--- a/src/ModularArithmetic/ModularBaseSystem.v
+++ b/src/ModularArithmetic/ModularBaseSystem.v
@@ -50,16 +50,12 @@ Section ModularBaseSystem.
(* Placeholder *)
Definition div (x y : digits) : digits := encode (ModularArithmetic.div (decode x) (decode y)).
- Definition carry i (us : digits) : digits := from_list (carry i [[us]])
- (length_carry length_to_list).
-
Definition rep (us : digits) (x : F modulus) := decode us = x.
Local Notation "u ~= x" := (rep u x).
Local Hint Unfold rep.
- Definition carry_sequence is (us : digits) : digits := fold_right carry us is.
-
- Definition carry_full : digits -> digits := carry_sequence (full_carry_chain limb_widths).
+ Definition carry_full (us : digits) : digits := from_list (carry_full [[us]])
+ (length_carry_full length_to_list).
Definition carry_mul (us vs : digits) : digits := carry_full (mul us vs).
diff --git a/src/ModularArithmetic/ModularBaseSystemList.v b/src/ModularArithmetic/ModularBaseSystemList.v
index 07b2c2bac..cbab03d6a 100644
--- a/src/ModularArithmetic/ModularBaseSystemList.v
+++ b/src/ModularArithmetic/ModularBaseSystemList.v
@@ -31,20 +31,21 @@ Section Defs.
(* In order to subtract without underflowing, we add a multiple of the modulus first. *)
Definition sub (us vs : digits) := BaseSystem.sub (add modulus_multiple us) vs.
- (*
+ (* [carry_and_reduce] multiplies the carried value by c, and, if carrying
+ from index [i] in a list [us], adds the value to the digit with index
+ [(S i) mod (length us)] *)
Definition carry_and_reduce :=
- carry_gen limb_widths (fun ci => c * ci).
- *)
- Definition carry_and_reduce i := fun us =>
- let di := us [i] in
- let us' := set_nth i (Z.pow2_mod di (limb_widths [i])) us in
- add_to_nth 0 (c * (Z.shiftr di (limb_widths [i]))) us'.
+ carry_gen limb_widths (fun ci => c * ci) (fun Si => (Si mod (length limb_widths))%nat).
Definition carry i : digits -> digits :=
- if eq_nat_dec i (pred (length base))
+ if eq_nat_dec i (pred (length limb_widths))
then carry_and_reduce i
else carry_simple limb_widths i.
+ Definition carry_sequence is (us : digits) : digits := fold_right carry us is.
+
+ Definition carry_full : digits -> digits := carry_sequence (full_carry_chain limb_widths).
+
Definition modulus_digits := encodeZ limb_widths modulus.
(* compute at compile time *)
diff --git a/src/ModularArithmetic/ModularBaseSystemListProofs.v b/src/ModularArithmetic/ModularBaseSystemListProofs.v
index 35de02cde..a49c26a53 100644
--- a/src/ModularArithmetic/ModularBaseSystemListProofs.v
+++ b/src/ModularArithmetic/ModularBaseSystemListProofs.v
@@ -76,6 +76,20 @@ Section LengthProofs.
Proof. intros; unfold carry; break_if; autorewrite with distr_length; omega. Qed.
Hint Rewrite @length_carry : distr_length.
+ Lemma length_carry_sequence {u i} :
+ length u = length limb_widths
+ -> length (carry_sequence i u) = length limb_widths.
+ Proof.
+ induction i; intros; unfold carry_sequence;
+ simpl; autorewrite with distr_length; auto. Qed.
+ Hint Rewrite @length_carry_sequence : distr_length.
+
+ Lemma length_carry_full {u} :
+ length u = length limb_widths
+ -> length (carry_full u) = length limb_widths.
+ Proof. intros; unfold carry_full; autorewrite with distr_length; congruence. Qed.
+ Hint Rewrite @length_carry_full : distr_length.
+
Lemma length_modulus_digits : length modulus_digits = length limb_widths.
Proof.
intros; unfold modulus_digits, encodeZ.
diff --git a/src/ModularArithmetic/ModularBaseSystemOpt.v b/src/ModularArithmetic/ModularBaseSystemOpt.v
index 7c171faf7..436d309c7 100644
--- a/src/ModularArithmetic/ModularBaseSystemOpt.v
+++ b/src/ModularArithmetic/ModularBaseSystemOpt.v
@@ -1,13 +1,14 @@
Require Import Crypto.ModularArithmetic.PrimeFieldTheorems.
Require Import Crypto.ModularArithmetic.PseudoMersenneBaseParams.
Require Import Crypto.ModularArithmetic.PseudoMersenneBaseParamProofs.
-Require Import Crypto.ModularArithmetic.ModularBaseSystemProofs.
Require Import Crypto.ModularArithmetic.ExtendedBaseVector.
+Require Import Crypto.ModularArithmetic.Pow2Base.
Require Import Crypto.ModularArithmetic.Pow2BaseProofs.
Require Import Crypto.BaseSystem.
Require Import Crypto.ModularArithmetic.ModularBaseSystemList.
Require Import Crypto.ModularArithmetic.ModularBaseSystemListProofs.
Require Import Crypto.ModularArithmetic.ModularBaseSystem.
+Require Import Crypto.ModularArithmetic.ModularBaseSystemProofs.
Require Import Coq.Lists.List.
Require Import Crypto.Util.Tuple.
Require Import Crypto.Util.ListUtil Crypto.Util.ZUtil Crypto.Util.NatUtil Crypto.Util.CaseUtil.
@@ -30,6 +31,7 @@ Definition Z_mul_opt := Eval compute in Z.mul.
Definition Z_div_opt := Eval compute in Z.div.
Definition Z_pow_opt := Eval compute in Z.pow.
Definition Z_opp_opt := Eval compute in Z.opp.
+Definition Z_ones_opt := Eval compute in Z.ones.
Definition Z_shiftl_opt := Eval compute in Z.shiftl.
Definition Z_shiftl_by_opt := Eval compute in Z.shiftl_by.
@@ -115,35 +117,53 @@ Section Carries.
Local Notation base := (Pow2Base.base_from_limb_widths limb_widths).
Local Notation digits := (tuple Z (length limb_widths)).
+ Definition carry_gen_opt_sig fc fi i us
+ : { d : list Z | (0 <= fi (S (fi i)) < length us)%nat ->
+ d = carry_gen limb_widths fc fi i us}.
+ Proof.
+ eexists; intros.
+ cbv beta iota delta [carry_gen carry_single Z.pow2_mod].
+ rewrite add_to_nth_set_nth.
+ change @nth_default with @nth_default_opt in *.
+ change @set_nth with @set_nth_opt in *.
+ change Z.ones with Z_ones_opt.
+ rewrite set_nth_nth_default by assumption.
+ rewrite <- @beq_nat_eq_nat_dec.
+ reflexivity.
+ Defined.
+
+ Definition carry_gen_opt fc fi i us := Eval cbv [proj1_sig carry_gen_opt_sig] in
+ proj1_sig (carry_gen_opt_sig fc fi i us).
+
+ Definition carry_gen_opt_correct fc fi i us
+ : (0 <= fi (S (fi i)) < length us)%nat ->
+ carry_gen_opt fc fi i us = carry_gen limb_widths fc fi i us
+ := proj2_sig (carry_gen_opt_sig fc fi i us).
+
Definition carry_opt_sig
- (i : nat) (b : digits)
- : { d : digits | (i < length limb_widths)%nat -> d = carry i b }.
+ (i : nat) (b : list Z)
+ : { d : list Z | (length b = length limb_widths)
+ -> (i < length limb_widths)%nat
+ -> d = carry i b }.
Proof.
eexists ; intros.
- cbv [carry ModularBaseSystemList.carry].
- rewrite <-from_list_default_eq with (d := 0%Z).
+ cbv [carry].
rewrite <-pull_app_if_sumbool.
cbv beta delta
- [carry carry_and_reduce Pow2Base.carry_gen Pow2Base.carry_single Pow2Base.carry_simple
- Z.pow2_mod Z.ones Z.pred
- PseudoMersenneBaseParams.limb_widths].
- rewrite !add_to_nth_set_nth.
- change @Pow2Base.base_from_limb_widths with @base_from_limb_widths_opt.
- change @nth_default with @nth_default_opt in *.
- change @set_nth with @set_nth_opt in *.
+ [carry carry_and_reduce carry_simple].
lazymatch goal with
- | [ |- _ = ?f (if ?br then ?c else ?d) ]
- => let x := fresh "x" in let y := fresh "y" in evar (x:digits); evar (y:digits); transitivity (if br then x else y); subst x; subst y
+ | [ |- _ = (if ?br then ?c else ?d) ]
+ => let x := fresh "x" in let y := fresh "y" in evar (x:list Z); evar (y:list Z); transitivity (if br then x else y); subst x; subst y
end.
- 2:cbv zeta.
- 2:break_if; reflexivity.
-
- change @from_list_default with @from_list_default_opt.
- change @nth_default with @nth_default_opt.
+ Focus 2. {
+ cbv zeta.
+ break_if; rewrite <-carry_gen_opt_correct by (omega ||
+ (replace (length b) with (length limb_widths) by congruence;
+ apply Nat.mod_bound_pos; omega)); reflexivity.
+ } Unfocus.
rewrite c_subst.
- change @set_nth with @set_nth_opt.
- change @map with @map_opt.
rewrite <- @beq_nat_eq_nat_dec.
+ cbv [carry_gen_opt].
reflexivity.
Defined.
@@ -151,11 +171,15 @@ Section Carries.
proj1_sig (carry_opt_sig is us).
Definition carry_opt_correct i us
- : (i < length limb_widths)%nat -> carry_opt i us = carry i us
+ : length us = length limb_widths
+ -> (i < length limb_widths)%nat
+ -> carry_opt i us = carry i us
:= proj2_sig (carry_opt_sig i us).
- Definition carry_sequence_opt_sig (is : list nat) (us : digits)
- : { b : digits | (forall i, In i is -> i < length base)%nat -> b = carry_sequence is us }.
+ Definition carry_sequence_opt_sig (is : list nat) (us : list Z)
+ : { b : list Z | (length us = length limb_widths)
+ -> (forall i, In i is -> i < length limb_widths)%nat
+ -> b = carry_sequence is us }.
Proof.
eexists. intros H.
cbv [carry_sequence].
@@ -164,9 +188,9 @@ Section Carries.
{ induction is; [ reflexivity | ].
simpl; rewrite IHis, carry_opt_correct.
- reflexivity.
- - rewrite base_length in H.
- apply H; apply in_eq.
- - intros. apply H. right. auto.
+ - fold (carry_sequence is us). auto using length_carry_sequence.
+ - auto using in_eq.
+ - intros. auto using in_cons.
}
Unfocus.
reflexivity.
@@ -176,123 +200,128 @@ Section Carries.
proj1_sig (carry_sequence_opt_sig is us).
Definition carry_sequence_opt_correct is us
- : (forall i, In i is -> i < length base)%nat -> carry_sequence_opt is us = carry_sequence is us
+ : (length us = length limb_widths)
+ -> (forall i, In i is -> i < length limb_widths)%nat
+ -> carry_sequence_opt is us = carry_sequence is us
:= proj2_sig (carry_sequence_opt_sig is us).
- Definition carry_opt_cps_sig
- {T}
+ Definition carry_gen_opt_cps_sig
+ {T} fc fi
(i : nat)
- (f : digits -> T)
- (b : digits)
- : { d : T | (i < length base)%nat -> d = f (carry i b) }.
+ (f : list Z -> T)
+ (b : list Z)
+ : { d : T | (0 <= fi (S (fi i)) < length b)%nat -> d = f (carry_gen limb_widths fc fi i b) }.
Proof.
eexists. intros H.
- rewrite <- carry_opt_correct by (rewrite base_length in H; assumption).
- cbv beta iota delta [carry_opt].
+ rewrite <-carry_gen_opt_correct by assumption.
+ cbv beta iota delta [carry_gen_opt].
+ match goal with |- appcontext[?a & Z_ones_opt _] =>
let LHS := match goal with |- ?LHS = ?RHS => LHS end in
let RHS := match goal with |- ?LHS = ?RHS => RHS end in
- let RHSf := match (eval pattern (nth_default_opt 0%Z (to_list _ b) i) in RHS) with ?RHSf _ => RHSf end in
- change (LHS = Let_In (nth_default_opt 0%Z (to_list _ b) i) RHSf).
- change Z.shiftl with Z_shiftl_opt.
- match goal with |- appcontext[ ?x + -1] => change (x + -1) with (Z_add_opt x (Z_opp_opt 1)) end.
+ let RHSf := match (eval pattern (a) in RHS) with ?RHSf _ => RHSf end in
+ change (LHS = Let_In (a) RHSf) end.
reflexivity.
Defined.
- Definition carry_opt_cps {T} i f b
- := Eval cbv beta iota delta [proj1_sig carry_opt_cps_sig] in proj1_sig (@carry_opt_cps_sig T i f b).
+ Definition carry_gen_opt_cps {T} fc fi i f b
+ := Eval cbv beta iota delta [proj1_sig carry_gen_opt_cps_sig] in
+ proj1_sig (@carry_gen_opt_cps_sig T fc fi i f b).
- Definition carry_opt_cps_correct {T} i f b :
- (i < length base)%nat ->
- @carry_opt_cps T i f b = f (carry i b)
- := proj2_sig (carry_opt_cps_sig i f b).
+ Definition carry_gen_opt_cps_correct {T} fc fi i f b :
+ (0 <= fi (S (fi i)) < length b)%nat ->
+ @carry_gen_opt_cps T fc fi i f b = f (carry_gen limb_widths fc fi i b)
+ := proj2_sig (carry_gen_opt_cps_sig fc fi i f b).
- Definition carry_sequence_opt_cps2_sig {T} (is : list nat) (us : digits)
- (f : digits -> T)
- : { b : T | (forall i, In i is -> i < length base)%nat -> b = f (carry_sequence is us) }.
+ Definition carry_opt_cps_sig
+ {T}
+ (i : nat)
+ (f : list Z -> T)
+ (b : list Z)
+ : { d : T | (length b = length limb_widths)
+ -> (i < length limb_widths)%nat
+ -> d = f (carry i b) }.
Proof.
- eexists.
- cbv [carry_sequence].
- transitivity (fold_right carry_opt_cps f (List.rev is) us).
- Focus 2.
- {
- assert (forall i, In i (rev is) -> i < length base)%nat as Hr. {
- subst. intros. rewrite <- in_rev in *. auto. }
- remember (rev is) as ris eqn:Heq.
- rewrite <- (rev_involutive is), <- Heq.
- clear H Heq is.
- rewrite fold_left_rev_right.
- revert us; induction ris; [ reflexivity | ]; intros.
- { simpl.
- rewrite <- IHris; clear IHris; [|intros; apply Hr; right; assumption].
- rewrite carry_opt_cps_correct; [reflexivity|].
- apply Hr; left; reflexivity.
- } }
- Unfocus.
+ eexists. intros.
+ cbv beta delta
+ [carry carry_and_reduce carry_simple].
+ rewrite <-pull_app_if_sumbool.
+ lazymatch goal with
+ | [ |- _ = ?f (if ?br then ?c else ?d) ]
+ => let x := fresh "x" in let y := fresh "y" in evar (x:T); evar (y:T); transitivity (if br then x else y); subst x; subst y
+ end.
+ Focus 2. {
+ cbv zeta.
+ break_if; rewrite <-carry_gen_opt_cps_correct by (omega ||
+ (replace (length b) with (length limb_widths) by congruence;
+ apply Nat.mod_bound_pos; omega)); reflexivity.
+ } Unfocus.
+ rewrite c_subst.
+ rewrite <- @beq_nat_eq_nat_dec.
reflexivity.
Defined.
- Definition carry_sequence_opt_cps2 {T} is us (f : digits -> T) :=
- Eval cbv [proj1_sig carry_sequence_opt_cps2_sig] in
- proj1_sig (carry_sequence_opt_cps2_sig is us f).
+ Definition carry_opt_cps {T} i f b
+ := Eval cbv beta iota delta [proj1_sig carry_opt_cps_sig] in proj1_sig (@carry_opt_cps_sig T i f b).
- Definition carry_sequence_opt_cps2_correct {T} is us (f : digits -> T)
- : (forall i, In i is -> i < length base)%nat -> carry_sequence_opt_cps2 is us f = f (carry_sequence is us)
- := proj2_sig (carry_sequence_opt_cps2_sig is us f).
+ Definition carry_opt_cps_correct {T} i f b :
+ (length b = length limb_widths)
+ -> (i < length limb_widths)%nat
+ -> @carry_opt_cps T i f b = f (carry i b)
+ := proj2_sig (carry_opt_cps_sig i f b).
- Definition carry_sequence_opt_cps_sig (is : list nat) (us : digits)
- : { b : digits | (forall i, In i is -> i < length base)%nat -> b = carry_sequence is us }.
+ Definition carry_sequence_opt_cps_sig {T} (is : list nat) (us : list Z)
+ (f : list Z -> T)
+ : { b : T | (length us = length limb_widths)
+ -> (forall i, In i is -> i < length limb_widths)%nat
+ -> b = f (carry_sequence is us) }.
Proof.
eexists.
cbv [carry_sequence].
- transitivity (fold_right carry_opt_cps id (List.rev is) us).
+ transitivity (fold_right carry_opt_cps f (List.rev is) us).
Focus 2.
{
- assert (forall i, In i (rev is) -> i < length base)%nat as Hr. {
+ assert (forall i, In i (rev is) -> i < length limb_widths)%nat as Hr. {
subst. intros. rewrite <- in_rev in *. auto. }
remember (rev is) as ris eqn:Heq.
- rewrite <- (rev_involutive is), <- Heq.
- clear H Heq is.
+ rewrite <- (rev_involutive is), <- Heq in H0 |- *.
+ clear H0 Heq is.
rewrite fold_left_rev_right.
- revert us; induction ris; [ reflexivity | ]; intros.
+ revert H. revert us; induction ris; [ reflexivity | ]; intros.
{ simpl.
- rewrite <- IHris; clear IHris; [|intros; apply Hr; right; assumption].
- rewrite carry_opt_cps_correct; [reflexivity|].
+ rewrite <- IHris; clear IHris;
+ [|intros; apply Hr; right; assumption|auto using length_carry].
+ rewrite carry_opt_cps_correct; [reflexivity|congruence|].
apply Hr; left; reflexivity.
} }
Unfocus.
+ cbv [carry_opt_cps].
reflexivity.
Defined.
- Definition carry_sequence_opt_cps is us := Eval cbv [proj1_sig carry_sequence_opt_cps_sig] in
- proj1_sig (carry_sequence_opt_cps_sig is us).
-
- Definition carry_sequence_opt_cps_correct is us
- : (forall i, In i is -> i < length base)%nat -> carry_sequence_opt_cps is us = carry_sequence is us
- := proj2_sig (carry_sequence_opt_cps_sig is us).
-
+ Definition carry_sequence_opt_cps {T} is us (f : list Z -> T) :=
+ Eval cbv [proj1_sig carry_sequence_opt_cps_sig] in
+ proj1_sig (carry_sequence_opt_cps_sig is us f).
- Lemma carry_sequence_opt_cps_rep
- : forall (is : list nat) (us : digits) (x : F modulus),
- (forall i : nat, In i is -> i < length base)%nat ->
- rep us x -> rep (carry_sequence_opt_cps is us) x.
- Proof.
- intros.
- rewrite carry_sequence_opt_cps_correct by assumption.
- auto using carry_sequence_rep.
- Qed.
+ Definition carry_sequence_opt_cps_correct {T} is us (f : list Z -> T)
+ : (length us = length limb_widths)
+ -> (forall i, In i is -> i < length limb_widths)%nat
+ -> carry_sequence_opt_cps is us f = f (carry_sequence is us)
+ := proj2_sig (carry_sequence_opt_cps_sig is us f).
- Lemma full_carry_chain_bounds : forall i, In i (Pow2Base.full_carry_chain limb_widths) -> (i < length base)%nat.
+ Lemma full_carry_chain_bounds : forall i, In i (Pow2Base.full_carry_chain limb_widths) ->
+ (i < length limb_widths)%nat.
Proof.
- unfold Pow2Base.full_carry_chain; rewrite <-base_length; intros.
+ unfold Pow2Base.full_carry_chain; intros.
apply Pow2BaseProofs.make_chain_lt; auto.
Qed.
Definition carry_full_opt_sig (us : digits) : { b : digits | b = carry_full us }.
Proof.
eexists.
- cbv [carry_full].
+ cbv [carry_full ModularBaseSystemList.carry_full].
+ rewrite <-from_list_default_eq with (d := 0).
+ rewrite <-carry_sequence_opt_cps_correct by (rewrite ?length_to_list; auto; apply full_carry_chain_bounds).
change @Pow2Base.full_carry_chain with full_carry_chain_opt.
- rewrite <-carry_sequence_opt_cps_correct by (auto; apply full_carry_chain_bounds).
reflexivity.
Defined.
@@ -311,8 +340,10 @@ Section Carries.
eexists.
rewrite <- carry_full_opt_correct.
cbv beta iota delta [carry_full_opt].
- rewrite carry_sequence_opt_cps_correct by apply full_carry_chain_bounds.
- rewrite <-carry_sequence_opt_cps2_correct by apply full_carry_chain_bounds.
+ rewrite carry_sequence_opt_cps_correct by (apply length_to_list || apply full_carry_chain_bounds).
+ match goal with |- ?LHS = ?f (?g (carry_sequence ?is ?us)) =>
+ change (LHS = (fun x => f (g x)) (carry_sequence is us)) end.
+ rewrite <-carry_sequence_opt_cps_correct by (apply length_to_list || apply full_carry_chain_bounds).
reflexivity.
Defined.
@@ -518,7 +549,16 @@ Section Multiplication.
cbv [carry_mul].
erewrite <-carry_full_opt_cps_correct by eauto.
erewrite <-mul_opt_correct.
+ cbv [carry_full_opt_cps mul_opt].
+ erewrite from_list_default_eq.
+ rewrite to_list_from_list.
reflexivity.
+ Grab Existential Variables.
+ rewrite mul'_opt_correct.
+ distr_length.
+ assert (0 < length limb_widths)%nat by (pose proof limb_widths_nonnil; destruct limb_widths; congruence || simpl; omega).
+ rewrite Min.min_l; rewrite !length_to_list; break_match; try omega.
+ rewrite Max.max_l; omega.
Defined.
Definition carry_mul_opt_cps {T} (f:digits -> T) (us vs : digits) : T
diff --git a/src/ModularArithmetic/ModularBaseSystemProofs.v b/src/ModularArithmetic/ModularBaseSystemProofs.v
index e6351dc17..01c073f06 100644
--- a/src/ModularArithmetic/ModularBaseSystemProofs.v
+++ b/src/ModularArithmetic/ModularBaseSystemProofs.v
@@ -184,7 +184,7 @@ Section PseudoMersenneProofs.
Proof.
intros.
apply Z_div_exact_2; try (apply nth_default_base_positive; omega).
- apply base_succ; eauto.
+ apply base_succ; distr_length; eauto.
Qed.
Lemma Fdecode_decode_mod : forall us x,
@@ -253,40 +253,41 @@ Section CarryProofs.
Lemma carry_decode_eq_reduce : forall us,
(length us = length limb_widths) ->
- BaseSystem.decode base (carry_and_reduce (pred (length base)) us) mod modulus
+ BaseSystem.decode base (carry_and_reduce (pred (length limb_widths)) us) mod modulus
= BaseSystem.decode base us mod modulus.
Proof.
- unfold carry_and_reduce; intros ? length_eq.
- pose proof base_length_nonzero.
- rewrite add_to_nth_sum by (rewrite length_set_nth, base_length; omega).
- rewrite set_nth_sum by (rewrite base_length; omega).
- rewrite Zplus_comm, <- Z.mul_assoc, <- pseudomersenne_add, BaseSystem.b0_1.
- rewrite (Z.mul_comm (2 ^ k)), <- Zred_factor0.
- f_equal.
- rewrite <- (Z.add_comm (BaseSystem.decode base us)), <- Z.add_assoc, <- Z.add_0_r.
- f_equal.
- destruct (NPeano.Nat.eq_dec (length base) 0) as [length_zero | length_nonzero].
- + pose proof (base_length) as limbs_length.
- destruct us; rewrite ?(@nil_length0 Z), ?(@length_cons Z) in *;
- pose proof base_length_nonzero; omega.
- + rewrite nth_default_base by (omega || eauto).
- rewrite <- Z.add_opp_l, <- Z.opp_sub_distr.
- unfold Z.pow2_mod.
- rewrite Z.land_ones by eauto using log_cap_nonneg.
- rewrite <- Z.mul_div_eq by (apply Z.gt_lt_iff; apply Z.pow_pos_nonneg; omega || eauto using log_cap_nonneg).
- rewrite Z.shiftr_div_pow2 by eauto using log_cap_nonneg.
- unfold k.
- replace (length limb_widths) with (S (pred (length base))) by
- (subst; rewrite <- base_length; apply NPeano.Nat.succ_pred; omega).
- rewrite sum_firstn_succ with (x:= log_cap (pred (length base))) by
- (apply nth_error_Some_nth_default; rewrite <- base_length; omega).
- rewrite Z.pow_add_r by eauto using log_cap_nonneg.
- ring.
+ cbv [carry_and_reduce]; intros.
+ rewrite carry_gen_decode_eq; auto.
+ distr_length.
+ assert (0 < length limb_widths)%nat by (pose proof limb_widths_nonnil;
+ destruct limb_widths; distr_length; congruence).
+ repeat break_if; repeat rewrite ?pred_mod, ?Nat.succ_pred,?Nat.mod_same in * by omega;
+ try omega.
+ rewrite !nth_default_base by (omega || auto).
+ rewrite sum_firstn_0.
+ autorewrite with zsimplify.
+ match goal with |- appcontext [2 ^ ?a * ?b * 2 ^ ?c] =>
+ replace (2 ^ a * b * 2 ^ c) with (2 ^ (a + c) * b) end.
+ { rewrite <-sum_firstn_succ by (apply nth_error_Some_nth_default; omega).
+ rewrite Nat.succ_pred by omega.
+ remember (Init.Nat.pred (length limb_widths)) as pred_len.
+ fold k.
+ rewrite <-Z.mul_sub_distr_r.
+ replace (c - 2 ^ k) with (modulus * -1) by (cbv [c]; ring).
+ rewrite <-Z.mul_assoc.
+ apply Z.mod_add_l'.
+ pose proof prime_modulus. Z.prime_bound. }
+ { rewrite Z.pow_add_r; auto using log_cap_nonneg, sum_firstn_limb_widths_nonneg.
+ rewrite <-!Z.mul_assoc.
+ apply Z.mul_cancel_l; try ring.
+ apply Z.pow_nonzero; (omega || auto using log_cap_nonneg). }
Qed.
Lemma carry_rep : forall i us x,
- (i < length base)%nat ->
- us ~= x -> carry i us ~= x.
+ (length us = length limb_widths)%nat ->
+ (i < length limb_widths)%nat ->
+ forall pf1 pf2,
+ from_list _ us pf1 ~= x -> from_list _ (carry i us) pf2 ~= x.
Proof.
cbv [carry rep decode]; intros.
rewrite to_list_from_list.
@@ -295,18 +296,24 @@ Section CarryProofs.
specialize_by eauto.
cbv [ModularBaseSystemList.carry].
break_if; subst; eauto.
- apply F_eq; apply carry_decode_eq_reduce; apply length_to_list.
+ apply F_eq.
+ rewrite to_list_from_list.
+ apply carry_decode_eq_reduce. auto.
cbv [ModularBaseSystemList.decode].
- f_equal.
- apply carry_simple_decode_eq; try omega; rewrite ?base_length; auto using length_to_list.
+ apply ZToField_eqmod.
+ rewrite to_list_from_list, carry_simple_decode_eq; try omega; distr_length; auto.
Qed.
Hint Resolve carry_rep.
Lemma carry_sequence_rep : forall is us x,
- (forall i, In i is -> (i < length base)%nat) ->
- us ~= x -> carry_sequence is us ~= x.
+ (forall i, In i is -> (i < length limb_widths)%nat) ->
+ us ~= x -> forall pf, from_list _ (carry_sequence is (to_list _ us)) pf ~= x.
Proof.
- induction is; boring.
+ induction is; intros.
+ + cbv [carry_sequence fold_right]. rewrite from_list_to_list. assumption.
+ + simpl. apply carry_rep with (pf1 := length_carry_sequence (length_to_list us));
+ auto using length_carry_sequence, length_to_list, in_eq.
+ apply IHis; auto using in_cons.
Qed.
Lemma carry_full_preserves_rep : forall us x,
@@ -314,7 +321,7 @@ Section CarryProofs.
Proof.
unfold carry_full; intros.
apply carry_sequence_rep; auto.
- unfold full_carry_chain; rewrite base_length; apply make_chain_lt.
+ unfold full_carry_chain; apply make_chain_lt.
Qed.
Opaque carry_full.
@@ -334,1561 +341,227 @@ Section CanonicalizationProofs.
Context `{prm : PseudoMersenneBaseParams}.
Local Notation base := (base_from_limb_widths limb_widths).
Local Notation log_cap i := (nth_default 0 limb_widths i).
- Context (lt_1_length_base : (1 < length base)%nat)
+ Context (lt_1_length_base : (1 < length limb_widths)%nat)
{B} (B_pos : 0 < B) (B_compat : forall w, In w limb_widths -> w <= B)
(* on the first reduce step, we add at most one bit of width to the first digit *)
- (c_reduce1 : c * (Z.ones (B - log_cap (pred (length base)))) < 2 ^ log_cap 0)
+ (c_reduce1 : c * ((2 ^ B) >> log_cap (pred (length limb_widths))) <= 2 ^ log_cap 0)
(* on the second reduce step, we add at most one bit of width to the first digit,
and leave room to carry c one more time after the highest bit is carried *)
(c_reduce2 : c <= nth_default 0 modulus_digits 0)
(* this condition is probably implied by c_reduce2, but is more straighforward to compute than to prove *)
(two_pow_k_le_2modulus : 2 ^ k <= 2 * modulus).
-(*
- (* BEGIN groundwork proofs *)
- Local Hint Resolve (@log_cap_nonneg limb_widths) limb_widths_nonneg.
-
- Lemma pow_2_log_cap_pos : forall i, 0 < 2 ^ log_cap i.
- Proof.
- intros; apply Z.pow_pos_nonneg; eauto using log_cap_nonneg; omega.
- Qed.
- Local Hint Resolve pow_2_log_cap_pos.
-
- Lemma max_value_log_cap : forall i, Z.succ (max_value i) = 2 ^ log_cap i.
- Proof.
- intros.
- unfold max_value, Z.ones.
- rewrite Z.shiftl_1_l.
- omega.
- Qed.
-
- Lemma pow2_mod_log_cap_range : forall a i, 0 <= Z.pow2_mod a (log_cap i) <= max_value i.
- Proof.
- intros.
- unfold Z.pow2_mod.
- rewrite Z.land_ones by eauto using log_cap_nonneg.
- unfold max_value, Z.ones.
- rewrite Z.shiftl_1_l, <-Z.lt_le_pred.
- apply Z_mod_lt.
- pose proof (pow_2_log_cap_pos i).
- omega.
- Qed.
-
- Lemma pow2_mod_log_cap_bounds_lower : forall a i, 0 <= Z.pow2_mod a (log_cap i).
- Proof.
- intros.
- pose proof (pow2_mod_log_cap_range a i); omega.
- Qed.
-
- Lemma pow2_mod_log_cap_bounds_upper : forall a i, Z.pow2_mod a (log_cap i) <= max_value i.
- Proof.
- intros.
- pose proof (pow2_mod_log_cap_range a i); omega.
- Qed.
-
- Lemma pow2_mod_log_cap_small : forall a i, 0 <= a <= max_value i ->
- Z.pow2_mod a (log_cap i) = a.
- Proof.
- intros.
- unfold Z.pow2_mod.
- rewrite Z.land_ones by eauto using log_cap_nonneg.
- apply Z.mod_small.
- split; try omega.
- rewrite <- max_value_log_cap.
- omega.
- Qed.
-
- Lemma max_value_pos : forall i, (i < length base)%nat -> 0 < max_value i.
- Proof.
- unfold max_value; intros; apply Z.ones_pos_pos.
- apply limb_widths_pos.
- rewrite nth_default_eq.
- apply nth_In.
- rewrite <-base_length; assumption.
- Qed.
- Local Hint Resolve max_value_pos.
-
- Lemma max_value_nonneg : forall i, 0 <= max_value i.
- Proof.
- unfold max_value; intros; eauto using Z.ones_nonneg.
- Qed.
- Local Hint Resolve max_value_nonneg.
-
- Lemma shiftr_eq_0_max_value : forall i a, Z.shiftr a (log_cap i) = 0 ->
- a <= max_value i.
- Proof.
- intros ? ? shiftr_0.
- apply Z.shiftr_eq_0_iff in shiftr_0.
- intuition; subst; try apply max_value_nonneg.
- match goal with H : Z.log2 _ < log_cap _ |- _ => apply Z.log2_lt_pow2 in H;
- replace (2 ^ log_cap i) with (Z.succ (max_value i)) in H by
- (unfold max_value, Z.ones; rewrite Z.shiftl_1_l; omega)
- end; auto.
- omega.
- Qed.
-
- Lemma B_compat_log_cap : forall i, 0 <= B - log_cap i.
- Proof.
- intros.
- destruct (lt_dec i (length limb_widths)).
- + apply Z.le_0_sub.
- apply B_compat.
- rewrite nth_default_eq.
- apply nth_In; assumption.
- + replace (nth_default 0 limb_widths i) with 0; try omega.
- symmetry; apply nth_default_out_of_bounds.
- omega.
- Qed.
- Local Hint Resolve B_compat_log_cap.
-
- Lemma max_value_shiftr_eq_0 : forall i a, 0 <= a -> a <= max_value i ->
- Z.shiftr a (log_cap i) = 0.
- Proof.
- intros ? ? ? le_max_value.
- apply Z.shiftr_eq_0_iff.
- destruct (Z_eq_dec a 0); auto.
- right.
- split; try omega.
- apply Z.log2_lt_pow2; try omega.
- rewrite <-max_value_log_cap.
- omega.
- Qed.
-
- Lemma log_cap_eq : forall i, log_cap i = nth_default 0 limb_widths i.
- Proof.
- reflexivity.
- Qed.
-
- (* END groundwork proofs *)
- Opaque Z.pow2_mod max_value.
-
- (* automation *)
- Ltac carry_length_conditions' := unfold carry_full;
- rewrite ?length_set_nth, ?length_add_to_nth, ?length_carry, ?carry_sequence_length;
- try omega; try solve [pose proof base_length; pose proof base_length_nonzero; omega || auto ].
- Ltac carry_length_conditions := try split; try omega; repeat carry_length_conditions'.
-
- Ltac add_set_nth :=
- rewrite ?add_to_nth_nth_default by carry_length_conditions; break_if; try omega;
- rewrite ?set_nth_nth_default by carry_length_conditions; break_if; try omega.
-
- (* BEGIN defs *)
-
- Definition pre_carry_bounds us := forall i, 0 <= nth_default 0 us i <
- if (eq_nat_dec i 0) then 2 ^ B else 2 ^ B - 2 ^ (B - log_cap (pred i)).
-
- Lemma pre_carry_bounds_nonzero : forall us, pre_carry_bounds us ->
- (forall i, 0 <= nth_default 0 us i).
- Proof.
- unfold pre_carry_bounds.
- intros ? PCB i.
- specialize (PCB i).
- omega.
- Qed.
- Local Hint Resolve pre_carry_bounds_nonzero.
-
- (* END defs *)
-
- (* BEGIN proofs about first carry loop *)
-
- Lemma nth_default_carry_bound_upper : forall i us, (length us = length base) ->
- nth_default 0 (carry i us) i <= max_value i.
- Proof.
- unfold carry; intros.
- break_if.
- + unfold carry_and_reduce.
- add_set_nth.
- apply pow2_mod_log_cap_bounds_upper.
- + autorewrite with push_nth_default natsimplify.
- destruct (lt_dec i (length us)); auto using pow2_mod_log_cap_bounds_upper.
- Qed.
- Local Hint Resolve nth_default_carry_bound_upper.
-
- Lemma nth_default_carry_bound_lower : forall i us, (length us = length base) ->
- 0 <= nth_default 0 (carry i us) i.
- Proof.
- unfold carry; intros.
- break_if.
- + unfold carry_and_reduce.
- add_set_nth.
- apply pow2_mod_log_cap_bounds_lower.
- + autorewrite with push_nth_default natsimplify.
- break_if; auto using pow2_mod_log_cap_bounds_lower, Z.le_refl.
- Qed.
- Local Hint Resolve nth_default_carry_bound_lower.
-
- Lemma nth_default_carry_bound_succ_lower : forall i us, (forall i, 0 <= nth_default 0 us i) ->
- (length us = length base) ->
- 0 <= nth_default 0 (carry i us) (S i).
- Proof.
- unfold carry; intros ? ? PCB ? .
- break_if.
- + subst. replace (S (pred (length base))) with (length base) by omega.
- rewrite nth_default_out_of_bounds; carry_length_conditions.
- unfold carry_and_reduce.
- carry_length_conditions.
- + autorewrite with push_nth_default natsimplify.
- break_if; zero_bounds.
- Qed.
-
- Lemma carry_unaffected_low : forall i j us, ((0 < i < j)%nat \/ (i = 0 /\ j <> 0 /\ j <> pred (length base))%nat)->
- (length us = length base) ->
- nth_default 0 (carry j us) i = nth_default 0 us i.
- Proof.
- intros.
- unfold carry.
- break_if.
- + unfold carry_and_reduce.
- add_set_nth.
- + autorewrite with push_nth_default simpl_nth_default natsimplify.
- repeat break_if; autorewrite with simpl_nth_default natsimplify; omega.
- Qed.
-
- Lemma carry_unaffected_high : forall i j us, (S j < i)%nat -> (length us = length base) ->
- nth_default 0 (carry j us) i = nth_default 0 us i.
- Proof.
- intros.
- destruct (lt_dec i (length us));
- [ | rewrite !nth_default_out_of_bounds by carry_length_conditions; reflexivity].
- unfold carry.
- break_if; [omega | autorewrite with push_nth_default natsimplify; repeat break_if; omega ].
- Qed.
-
- Hint Rewrite max_bound_shiftr_eq_0 using omega : core.
- Hint Rewrite pow2_mod_log_cap_small using assumption : core.
-
- Lemma carry_nothing : forall i j us, (i < length base)%nat ->
- (length us = length base)%nat ->
- 0 <= nth_default 0 us j <= max_value j ->
- nth_default 0 (carry j us) i = nth_default 0 us i.
- Proof.
- unfold carry, carry_and_reduce; intros.
- repeat (break_if
- || subst
- || (autorewrite with push_nth_default natsimplify core)
- || omega).
- Qed.
-
- Hint Rewrite pow2_mod_log_cap_small using (intuition; auto using shiftr_eq_0_max_bound) : core.
-
- Lemma carry_carry_done_done : forall i us,
- (length us = length base)%nat ->
- (i < length base)%nat ->
- carry_done us -> carry_done (carry i us).
- Proof.
- unfold carry_done; intros i ? ? i_bound Hcarry_done x x_bound.
- destruct (Hcarry_done x x_bound) as [lower_bound_x shiftr_0_x].
- destruct (Hcarry_done i i_bound) as [lower_bound_i shiftr_0_i].
- split.
- + rewrite carry_nothing; auto.
- split; [ apply Hcarry_done; auto | ].
- apply shiftr_eq_0_max_value.
- apply Hcarry_done; auto.
- + unfold carry, carry_and_reduce; break_if; subst.
- - add_set_nth; subst.
- * rewrite shiftr_0_i, Z.mul_0_r, Z.add_0_l.
- assumption.
- * rewrite pow2_mod_log_cap_small by (intuition; auto using shiftr_eq_0_max_value).
- assumption.
- - repeat (carry_length_conditions
- || (autorewrite with push_nth_default natsimplify core zsimplify)
- || break_if
- || subst
- || rewrite shiftr_0_i by omega).
- Qed.
-
- Lemma carry_sequence_chain_step : forall i us,
- carry_sequence (make_chain (S i)) us = carry i (carry_sequence (make_chain i) us).
- Proof.
- reflexivity.
- Qed.
-
- Lemma carry_bounds_0_upper : forall us j, (length us = length base) ->
- (0 < j < length base)%nat ->
- nth_default 0 (carry_sequence (make_chain j) us) 0 <= max_value 0.
- Proof.
- induction j as [ | [ | j ] IHj ]; [simpl; intros; omega | | ]; intros.
- + subst; simpl; auto.
- + rewrite carry_sequence_chain_step, carry_unaffected_low; carry_length_conditions.
- Qed.
-
- Lemma carry_bounds_upper : forall i us j, (0 < i < j)%nat -> (length us = length base) ->
- nth_default 0 (carry_sequence (make_chain j) us) i <= max_value i.
- Proof.
- unfold carry_sequence;
- induction j; [simpl; intros; omega | ].
- intros.
- simpl in *.
- assert (i = j \/ i < j)%nat as cases by omega.
- destruct cases as [eq_j_i | lt_i_j]; subst.
- + apply nth_default_carry_bound_upper; fold (carry_sequence (make_chain j) us); carry_length_conditions.
- + rewrite carry_unaffected_low; try omega.
- fold (carry_sequence (make_chain j) us); carry_length_conditions.
- Qed.
-
- Lemma carry_sequence_unaffected : forall i us j, (j < i)%nat -> (length us = length base)%nat ->
- nth_default 0 (carry_sequence (make_chain j) us) i = nth_default 0 us i.
- Proof.
- induction j; [simpl; intros; omega | ].
- intros.
- simpl in *.
- rewrite carry_unaffected_high by carry_length_conditions.
- apply IHj; omega.
- Qed.
-
- (* makes omega run faster *)
- Ltac clear_obvious :=
- match goal with
- | [H : ?a <= ?a |- _] => clear H
- | [H : ?a <= S ?a |- _] => clear H
- | [H : ?a < S ?a |- _] => clear H
- | [H : ?a = ?a |- _] => clear H
- end.
-
- Lemma carry_sequence_bounds_lower : forall j i us, (length us = length base) ->
- (forall i, 0 <= nth_default 0 us i) -> (j <= length base)%nat ->
- 0 <= nth_default 0 (carry_sequence (make_chain j) us) i.
- Proof.
- induction j; intros; simpl; auto.
- destruct (lt_dec (S j) i).
- + rewrite carry_unaffected_high by carry_length_conditions.
- apply IHj; auto; omega.
- + assert ((i = S j) \/ (i = j) \/ (i < j))%nat as cases by omega.
- destruct cases as [? | [? | ?]].
- - subst. apply nth_default_carry_bound_succ_lower; carry_length_conditions.
- intros; eapply IHj; auto; omega.
- - subst. apply nth_default_carry_bound_lower; carry_length_conditions.
- - destruct (eq_nat_dec j (pred (length base)));
- [ | rewrite carry_unaffected_low by carry_length_conditions; apply IHj; auto; omega ].
- subst.
- do 2 match goal with H : appcontext[S (pred (length base))] |- _ =>
- erewrite <-(S_pred (length base)) in H by eauto end.
- unfold carry; break_if; [ unfold carry_and_reduce | omega ].
- clear_obvious. pose proof c_pos.
- add_set_nth; [ zero_bounds | ]; apply IHj; auto; omega.
- Qed.
-
- Ltac carry_seq_lower_bound :=
- repeat (intros; eapply carry_sequence_bounds_lower; eauto; carry_length_conditions).
-
- Lemma carry_bounds_lower : forall i us j, (0 < i <= j)%nat -> (length us = length base) ->
- (forall i, 0 <= nth_default 0 us i) -> (j <= length base)%nat ->
- 0 <= nth_default 0 (carry_sequence (make_chain j) us) i.
- Proof.
- unfold carry_sequence;
- induction j; [simpl; intros; omega | ].
- intros.
- simpl in *.
- destruct (eq_nat_dec i (S j)).
- + subst. apply nth_default_carry_bound_succ_lower; auto;
- fold (carry_sequence (make_chain j) us); carry_length_conditions.
- carry_seq_lower_bound.
- + assert (i = j \/ i < j)%nat as cases by omega.
- destruct cases as [eq_j_i | lt_i_j]; subst;
- [apply nth_default_carry_bound_lower| rewrite carry_unaffected_low]; try omega;
- fold (carry_sequence (make_chain j) us); carry_length_conditions.
- carry_seq_lower_bound.
- Qed.
-
- Lemma carry_full_bounds : forall us i, (i <> 0)%nat -> (forall i, 0 <= nth_default 0 us i) ->
- (length us = length base)%nat ->
- 0 <= nth_default 0 (carry_full us) i <= max_value i.
- Proof.
- unfold carry_full, full_carry_chain; intros.
- split; (destruct (lt_dec i (length limb_widths));
- [ | rewrite nth_default_out_of_bounds by carry_length_conditions; omega || auto ]).
- + apply carry_bounds_lower; carry_length_conditions.
- + apply carry_bounds_upper; carry_length_conditions.
- Qed.
-
- Lemma carry_simple_no_overflow : forall us i, (i < pred (length base))%nat ->
- length us = length base ->
- 0 <= nth_default 0 us i < 2 ^ B ->
- 0 <= nth_default 0 us (S i) < 2 ^ B - 2 ^ (B - log_cap i) ->
- 0 <= nth_default 0 (carry i us) (S i) < 2 ^ B.
- Proof.
- intros.
- unfold carry; break_if; try omega.
- autorewrite with push_nth_default natsimplify.
- break_if; try omega.
- rewrite Z.add_comm.
- replace (2 ^ B) with (2 ^ (B - log_cap i) + (2 ^ B - 2 ^ (B - log_cap i))) by omega.
- split; [ zero_bounds | ].
- apply Z.add_lt_mono; try omega.
- rewrite Z.shiftr_div_pow2 by eauto using log_cap_nonneg.
- apply Z.div_lt_upper_bound; try apply pow_2_log_cap_pos.
- rewrite <-Z.pow_add_r by (eauto using log_cap_nonneg || apply B_compat_log_cap).
- replace (log_cap i + (B - log_cap i)) with B by ring.
- omega.
- Qed.
-
- Lemma carry_sequence_no_overflow : forall i us, pre_carry_bounds us ->
- (length us = length base) ->
- nth_default 0 (carry_sequence (make_chain i) us) i < 2 ^ B.
- Proof.
- unfold pre_carry_bounds.
- intros ? ? PCB ?.
- induction i.
- + simpl. specialize (PCB 0%nat).
- intuition.
- + simpl.
- destruct (lt_eq_lt_dec i (pred (length base))) as [[? | ? ] | ? ].
- - apply carry_simple_no_overflow; carry_length_conditions; carry_seq_lower_bound.
- rewrite carry_sequence_unaffected; try omega.
- specialize (PCB (S i)); rewrite Nat.pred_succ in PCB.
- break_if; intuition.
- - unfold carry; break_if; try omega.
- rewrite nth_default_out_of_bounds; [ apply Z.pow_pos_nonneg; omega | ].
- subst; unfold carry_and_reduce.
- carry_length_conditions.
- - rewrite nth_default_out_of_bounds; [ apply Z.pow_pos_nonneg; omega | ].
- carry_length_conditions.
- Qed.
-
- Lemma carry_full_bounds_0 : forall us, pre_carry_bounds us ->
- (length us = length base)%nat ->
- 0 <= nth_default 0 (carry_full us) 0 <= max_value 0 + c * (Z.ones (B - log_cap (pred (length base)))).
- Proof.
- unfold carry_full, full_carry_chain; intros.
- rewrite <- base_length.
- replace (length base) with (S (pred (length base))) by omega.
- simpl.
- unfold carry, carry_and_reduce; break_if; try omega.
- clear_obvious; add_set_nth.
- split; [pose proof c_pos; zero_bounds; carry_seq_lower_bound | ].
- rewrite Z.add_comm.
- apply Z.add_le_mono.
- + apply carry_bounds_0_upper; auto; omega.
- + apply Z.mul_le_mono_pos_l; auto using c_pos.
- apply Z.shiftr_ones; eauto;
- [ | pose proof (B_compat_log_cap (pred (length base))); omega ].
- split.
- - apply carry_bounds_lower; auto; omega.
- - apply carry_sequence_no_overflow; auto.
- Qed.
-
- Lemma carry_full_bounds_lower : forall i us, pre_carry_bounds us ->
- (length us = length base)%nat ->
- 0 <= nth_default 0 (carry_full us) i.
- Proof.
- destruct i; intros.
- + apply carry_full_bounds_0; auto.
- + destruct (lt_dec (S i) (length base)).
- - apply carry_bounds_lower; carry_length_conditions.
- - rewrite nth_default_out_of_bounds; carry_length_conditions.
- Qed.
-
- (* END proofs about first carry loop *)
-
- (* BEGIN proofs about second carry loop *)
-
- Lemma carry_sequence_carry_full_bounds_same : forall us i, pre_carry_bounds us ->
- (length us = length base)%nat -> (0 < i < length base)%nat ->
- 0 <= nth_default 0 (carry_sequence (make_chain i) (carry_full us)) i <= 2 ^ log_cap i.
- Proof.
- induction i; intros; try omega.
- simpl.
- unfold carry; break_if; try omega.
- autorewrite with push_nth_default natsimplify distr_length; break_if; [ | omega ].
- rewrite Z.add_comm.
- split.
- + zero_bounds; [destruct (eq_nat_dec i 0); subst | ].
- - simpl; apply carry_full_bounds_0; auto.
- - apply IHi; auto; omega.
- - rewrite carry_sequence_unaffected by carry_length_conditions.
- apply carry_full_bounds; auto; omega.
- + rewrite <-max_value_log_cap, <-Z.add_1_l.
- apply Z.add_le_mono.
- - rewrite Z.shiftr_div_pow2 by eauto using log_cap_nonneg.
- apply Z.div_floor; auto.
- destruct i.
- * simpl.
- eapply Z.le_lt_trans; [ apply carry_full_bounds_0; auto | ].
- replace (2 ^ log_cap 0 * 2) with (2 ^ log_cap 0 + 2 ^ log_cap 0) by ring.
- rewrite <-max_value_log_cap, <-Z.add_1_l.
- apply Z.add_lt_le_mono; omega.
- * eapply Z.le_lt_trans; [ apply IHi; auto; omega | ].
- apply Z.lt_mul_diag_r; auto; omega.
- - rewrite carry_sequence_unaffected by carry_length_conditions.
- apply carry_full_bounds; auto; omega.
- Qed.
-
- Lemma carry_full_2_bounds_0 : forall us, pre_carry_bounds us ->
- (length us = length base)%nat -> (1 < length base)%nat ->
- 0 <= nth_default 0 (carry_full (carry_full us)) 0 <= max_value 0 + c.
- Proof.
- intros.
- unfold carry_full at 1 3, full_carry_chain.
- rewrite <-base_length.
- replace (length base) with (S (pred (length base))) by (pose proof base_length_nonzero; omega).
- simpl.
- unfold carry, carry_and_reduce; break_if; try omega.
- clear_obvious; add_set_nth.
- split.
- + pose proof c_pos; zero_bounds; [ | carry_seq_lower_bound].
- apply carry_sequence_carry_full_bounds_same; auto; omega.
- + rewrite Z.add_comm.
- apply Z.add_le_mono.
- - apply carry_bounds_0_upper; carry_length_conditions.
- - etransitivity; [ | replace c with (c * 1) by ring; reflexivity ].
- apply Z.mul_le_mono_pos_l; try (pose proof c_pos; omega).
- rewrite Z.shiftr_div_pow2 by eauto.
- apply Z.div_le_upper_bound; auto.
- ring_simplify.
- apply carry_sequence_carry_full_bounds_same; auto.
- omega.
- Qed.
-
- Lemma carry_full_2_bounds_succ : forall us i, pre_carry_bounds us ->
- (length us = length base)%nat -> (0 < i < pred (length base))%nat ->
- ((0 < i < length base)%nat ->
- 0 <= nth_default 0
- (carry_sequence (make_chain i) (carry_full (carry_full us))) i <=
- 2 ^ log_cap i) ->
- 0 <= nth_default 0 (carry_simple limb_widths i
- (carry_sequence (make_chain i) (carry_full (carry_full us)))) (S i) <= 2 ^ log_cap (S i).
- Proof.
- intros ? ? PCB length_eq ? IH.
- autorewrite with push_nth_default natsimplify distr_length; break_if; [ | omega ].
- rewrite Z.add_comm.
- split.
- + zero_bounds. destruct i;
- [ simpl; pose proof (carry_full_2_bounds_0 us PCB length_eq); omega | ].
- rewrite carry_sequence_unaffected by carry_length_conditions.
- apply carry_full_bounds; carry_length_conditions.
- carry_seq_lower_bound.
- + rewrite <-max_value_log_cap, <-Z.add_1_l.
- rewrite Z.shiftr_div_pow2 by eauto using log_cap_nonneg.
- apply Z.add_le_mono.
- - apply Z.div_le_upper_bound; auto.
- ring_simplify. apply IH. omega.
- - rewrite carry_sequence_unaffected by carry_length_conditions.
- apply carry_full_bounds; carry_length_conditions.
- carry_seq_lower_bound.
- Qed.
-
- Lemma carry_full_2_bounds_same : forall us i, pre_carry_bounds us ->
- (length us = length base)%nat -> (0 < i < length base)%nat ->
- 0 <= nth_default 0 (carry_sequence (make_chain i) (carry_full (carry_full us))) i <= 2 ^ log_cap i.
- Proof.
- intros; induction i; try omega.
- simpl; unfold carry.
- break_if; try omega.
- split; (destruct (eq_nat_dec i 0); subst;
- [ cbv [make_chain carry_sequence fold_right];
- autorewrite with push_nth_default natsimplify distr_length; break_if; [ | omega ];
- rewrite Z.add_comm
- | eapply carry_full_2_bounds_succ; eauto; omega]).
- + zero_bounds.
- - eapply carry_full_2_bounds_0; eauto.
- - eapply carry_full_bounds; eauto; carry_length_conditions.
- carry_seq_lower_bound.
- + rewrite <-max_value_log_cap, <-Z.add_1_l.
- rewrite Z.shiftr_div_pow2 by eauto using log_cap_nonneg.
- apply Z.add_le_mono.
- - apply Z.div_floor; auto.
- eapply Z.le_lt_trans; [ eapply carry_full_2_bounds_0; eauto | ].
- replace (Z.succ 1) with (2 ^ 1) by ring.
- rewrite <-max_value_log_cap.
- ring_simplify. pose proof c_pos; omega.
- - apply carry_full_bounds; carry_length_conditions; carry_seq_lower_bound.
- Qed.
-
- Lemma carry_full_2_bounds' : forall us i j, pre_carry_bounds us ->
- (length us = length base)%nat -> (0 < i < length base)%nat -> (i + j < length base)%nat -> (j <> 0)%nat ->
- 0 <= nth_default 0 (carry_sequence (make_chain (i + j)) (carry_full (carry_full us))) i <= max_value i.
- Proof.
- induction j; intros; try omega.
- split; (destruct j; [ rewrite Nat.add_1_r; simpl
- | rewrite <-plus_n_Sm; simpl; rewrite carry_unaffected_low by carry_length_conditions; eapply IHj; eauto; omega ]).
- + apply nth_default_carry_bound_lower; carry_length_conditions.
- + apply nth_default_carry_bound_upper; carry_length_conditions.
- Qed.
-
- Lemma carry_full_2_bounds : forall us i j, pre_carry_bounds us ->
- (length us = length base)%nat -> (0 < i < length base)%nat -> (i < j < length base)%nat ->
- 0 <= nth_default 0 (carry_sequence (make_chain j) (carry_full (carry_full us))) i <= max_value i.
- Proof.
- intros.
- replace j with (i + (j - i))%nat by omega.
- eapply carry_full_2_bounds'; eauto; omega.
- Qed.
-
- Lemma carry_carry_full_2_bounds_0_lower : forall us i, pre_carry_bounds us ->
- (length us = length base)%nat -> (0 < i < length base)%nat ->
- (0 <= nth_default 0 (carry_sequence (make_chain i) (carry_full (carry_full us))) 0).
- Proof.
- induction i; try omega.
- intros ? length_eq ?; simpl.
- destruct i.
- + unfold carry.
- break_if;
- [ pose proof base_length_nonzero; replace (length base) with 1%nat in *; omega | ].
- simpl.
- autorewrite with push_nth_default natsimplify.
- apply pow2_mod_log_cap_bounds_lower.
- + rewrite carry_unaffected_low by carry_length_conditions.
- assert (0 < S i < length base)%nat by omega.
- intuition.
- Qed.
-
- Lemma carry_full_2_bounds_lower :forall us i, pre_carry_bounds us ->
- (length us = length base)%nat ->
- 0 <= nth_default 0 (carry_full (carry_full us)) i.
- Proof.
- intros; destruct i.
- + apply carry_full_2_bounds_0; auto.
- + apply carry_full_bounds; try solve [carry_length_conditions].
- intro j; destruct j.
- - apply carry_full_bounds_0; auto.
- - apply carry_full_bounds; carry_length_conditions.
- Qed.
-
- Local Hint Resolve carry_full_length.
-
- Lemma carry_carry_full_2_bounds_0_upper : forall us i, pre_carry_bounds us ->
- (length us = length base)%nat -> (0 < i < length base)%nat ->
- (nth_default 0 (carry_sequence (make_chain i) (carry_full (carry_full us))) 0 <= max_value 0 - c)
- \/ carry_done (carry_sequence (make_chain i) (carry_full (carry_full us))).
- Proof.
- induction i; try omega.
- intros ? length_eq ?; simpl.
- destruct i.
- + destruct (Z_le_dec (nth_default 0 (carry_full (carry_full us)) 0) (max_value 0)).
- - right.
- apply carry_carry_done_done; try solve [carry_length_conditions].
- apply carry_done_bounds; try solve [carry_length_conditions].
- intros.
- simpl.
- split; [ auto using carry_full_2_bounds_lower | ].
- destruct i; rewrite <-max_value_log_cap, Z.lt_succ_r; auto.
- apply carry_full_bounds; auto using carry_full_bounds_lower.
- - left; unfold carry.
- break_if;
- [ pose proof base_length_nonzero; replace (length base) with 1%nat in *; omega | ].
- autorewrite with push_nth_default natsimplify.
- simpl.
- remember ((nth_default 0 (carry_full (carry_full us)) 0)) as x.
- apply Z.le_trans with (m := (max_value 0 + c) - (1 + max_value 0)); try omega.
- replace x with ((x - 2 ^ log_cap 0) + (1 * 2 ^ log_cap 0)) by ring.
- rewrite Z.pow2_mod_spec by eauto.
- cbv [make_chain carry_sequence fold_right].
- rewrite Z.mod_add by (pose proof (pow_2_log_cap_pos 0); omega).
- rewrite <-max_value_log_cap, <-Z.add_1_l, Z.mod_small;
- [ apply Z.sub_le_mono_r; subst; apply carry_full_2_bounds_0; auto | ].
- split; try omega.
- pose proof carry_full_2_bounds_0.
- apply Z.le_lt_trans with (m := (max_value 0 + c) - (1 + max_value 0));
- [ apply Z.sub_le_mono_r; subst x; apply carry_full_2_bounds_0; auto;
- ring_simplify | ]; pose proof c_pos; omega.
- + rewrite carry_unaffected_low by carry_length_conditions.
- assert (0 < S i < length base)%nat by omega.
- intuition; right.
- apply carry_carry_done_done; try solve [carry_length_conditions].
- assumption.
- Qed.
-
-
- (* END proofs about second carry loop *)
-
- (* BEGIN proofs about third carry loop *)
-
- Lemma carry_full_3_bounds : forall us i, pre_carry_bounds us ->
- (length us = length base)%nat ->(i < length base)%nat ->
- 0 <= nth_default 0 (carry_full (carry_full (carry_full us))) i <= max_value i.
- Proof.
- intros.
- destruct i; [ | apply carry_full_bounds; carry_length_conditions;
- carry_seq_lower_bound ].
- unfold carry_full at 1 4, full_carry_chain.
- case_eq limb_widths; [intros; pose proof limb_widths_nonnil; congruence | ].
- simpl.
- intros ? ? limb_widths_eq.
- replace (length l) with (pred (length limb_widths)) by (rewrite limb_widths_eq; auto).
- rewrite <- base_length.
- unfold carry, carry_and_reduce; break_if; try omega; intros.
- add_set_nth. pose proof c_pos.
- split.
- + zero_bounds.
- - eapply carry_full_2_bounds_same; eauto; omega.
- - eapply carry_carry_full_2_bounds_0_lower; eauto; omega.
- + pose proof (carry_carry_full_2_bounds_0_upper us (pred (length base))).
- assert (0 < pred (length base) < length base)%nat by omega.
- intuition.
- - replace (max_value 0) with (c + (max_value 0 - c)) by ring.
- apply Z.add_le_mono; try assumption.
- etransitivity; [ | replace c with (c * 1) by ring; reflexivity ].
- apply Z.mul_le_mono_pos_l; try omega.
- rewrite Z.shiftr_div_pow2 by eauto.
- apply Z.div_le_upper_bound; auto.
- ring_simplify.
- apply carry_full_2_bounds_same; auto.
- - match goal with H0 : (pred (length base) < length base)%nat,
- H : carry_done _ |- _ =>
- destruct (H (pred (length base)) H0) as [Hcd1 Hcd2]; rewrite Hcd2 by omega end.
- ring_simplify.
- apply shiftr_eq_0_max_value; auto.
- assert (0 < length base)%nat as zero_lt_length by omega.
- match goal with H : carry_done _ |- _ =>
- destruct (H 0%nat zero_lt_length) end.
- assumption.
- Qed.
-
- Lemma carry_full_3_done : forall us, pre_carry_bounds us ->
- (length us = length base)%nat ->
- carry_done (carry_full (carry_full (carry_full us))).
- Proof.
- intros.
- apply carry_done_bounds; [ carry_length_conditions | intros ].
- destruct (lt_dec i (length base)).
- + rewrite <-max_value_log_cap, Z.lt_succ_r.
- auto using carry_full_3_bounds.
- + rewrite nth_default_out_of_bounds; carry_length_conditions.
- Qed.
-
- (* END proofs about third carry loop *)
-
- Lemma isFull'_false : forall us n, isFull' us false n = false.
- Proof.
- unfold isFull'; induction n; intros; rewrite Bool.andb_false_r; auto.
- Qed.
-
- Lemma isFull'_last : forall us b j, (j <> 0)%nat -> isFull' us b j = true ->
- max_value j = nth_default 0 us j.
- Proof.
- induction j; simpl; intros; try omega.
- match goal with
- | [H : isFull' _ ((?comp ?a ?b) && _) _ = true |- _ ] =>
- case_eq (comp a b); rewrite ?Z.eqb_eq; intro comp_eq; try assumption;
- rewrite comp_eq, Bool.andb_false_l, isFull'_false in H; congruence
- end.
- Qed.
-
- Lemma isFull'_lower_bound_0 : forall j us b, isFull' us b j = true ->
- max_value 0 - c < nth_default 0 us 0.
- Proof.
- induction j; intros.
- + match goal with H : isFull' _ _ 0 = _ |- _ => cbv [isFull'] in H;
- apply Bool.andb_true_iff in H; destruct H end.
- apply Z.ltb_lt; assumption.
- + eauto.
- Qed.
-
- Lemma isFull'_true_full : forall us i j b, (i <> 0)%nat -> (i <= j)%nat -> isFull' us b j = true ->
- max_value i = nth_default 0 us i.
- Proof.
- induction j; intros; try omega.
- assert (i = S j \/ i <= j)%nat as cases by omega.
- destruct cases.
- + subst. eapply isFull'_last; eauto.
- + eapply IHj; eauto.
- Qed.
-
- Lemma max_ones_nonneg : 0 <= max_ones.
- Proof.
- unfold max_ones.
- apply Z.ones_nonneg.
- clear.
- pose proof limb_widths_nonneg.
- induction limb_widths as [|?? IHl].
- cbv; congruence.
- simpl.
- apply Z.max_le_iff.
- right.
- apply IHl; eauto using in_cons.
- Qed.
-
- Lemma land_max_ones_noop : forall x i, 0 <= x < 2 ^ log_cap i -> Z.land max_ones x = x.
- Proof.
- unfold max_ones.
- intros ? ? x_range.
- rewrite Z.land_comm.
- rewrite Z.land_ones by apply Z.le_fold_right_max_initial.
- apply Z.mod_small.
- split; try omega.
- eapply Z.lt_le_trans; try eapply x_range.
- apply Z.pow_le_mono_r; try omega.
- destruct (lt_dec i (length limb_widths)).
- + apply Z.le_fold_right_max.
- - apply limb_widths_nonneg.
- - rewrite nth_default_eq.
- auto using nth_In.
- + rewrite nth_default_out_of_bounds by omega.
- apply Z.le_fold_right_max_initial.
- Qed.
-
- Lemma full_isFull'_true : forall j us, (length us = length base) ->
- ( max_value 0 - c < nth_default 0 us 0
- /\ (forall i, (0 < i <= j)%nat -> nth_default 0 us i = max_value i)) ->
- isFull' us true j = true.
- Proof.
- induction j; intros.
- + cbv [isFull']; apply Bool.andb_true_iff.
- rewrite Z.ltb_lt; intuition.
- + intuition.
- simpl.
- match goal with H : forall j, _ -> ?b j = ?a j |- appcontext[?a ?i =? ?b ?i] =>
- replace (a i =? b i) with true by (symmetry; apply Z.eqb_eq; symmetry; apply H; omega) end.
- apply IHj; auto; intuition.
- Qed.
-
- Lemma isFull'_true_iff : forall j us, (length us = length base) -> (isFull' us true j = true <->
- max_value 0 - c < nth_default 0 us 0
- /\ (forall i, (0 < i <= j)%nat -> nth_default 0 us i = max_value i)).
- Proof.
- intros; split; intros; auto using full_isFull'_true.
- split; eauto using isFull'_lower_bound_0.
- intros.
- symmetry; eapply isFull'_true_full; [ omega | | eauto].
- omega.
- Qed.
-
- Lemma isFull'_true_step : forall us j, isFull' us true (S j) = true ->
- isFull' us true j = true.
- Proof.
- simpl; intros ? ? succ_true.
- destruct (max_value (S j) =? nth_default 0 us (S j)); auto.
- rewrite isFull'_false in succ_true.
- congruence.
- Qed.
-
- Opaque isFull' max_ones.
-
- Lemma carry_full_3_length : forall us, (length us = length base) ->
- length (carry_full (carry_full (carry_full us))) = length us.
- Proof.
- intros.
- repeat rewrite carry_full_length by (repeat rewrite carry_full_length; auto); auto.
- Qed.
- Local Hint Resolve carry_full_3_length.
-
- Lemma modulus_digits'_length : forall i, length (modulus_digits' i) = S i.
- Proof.
- induction i; intros; [ cbv; congruence | ].
- unfold modulus_digits'; fold modulus_digits'.
- rewrite app_length, IHi.
- cbv [length]; omega.
- Qed.
-
- Lemma modulus_digits_length : length modulus_digits = length base.
- Proof.
- unfold modulus_digits.
- rewrite modulus_digits'_length; omega.
- Qed.
-
- (* Helps with solving goals of the form [x = y -> min x y = x] or [x = y -> min x y = y] *)
- Local Hint Resolve Nat.eq_le_incl eq_le_incl_rev.
-
- Hint Rewrite app_length cons_length map2_length modulus_digits_length length_zeros
- map_length combine_length firstn_length map_app : lengths.
- Ltac simpl_lengths := autorewrite with lengths;
- repeat rewrite carry_full_length by (repeat rewrite carry_full_length; auto);
- auto using Min.min_l; auto using Min.min_r.
-
- Lemma freeze_length : forall us, (length us = length base) ->
- length (freeze us) = length us.
- Proof.
- unfold freeze; intros; simpl_lengths.
- rewrite Min.min_l by omega. congruence.
- Qed.
-
- Lemma decode_firstn_succ : forall n us, (length us = length base) ->
- (n < length base)%nat ->
- BaseSystem.decode' (firstn (S n) base) (firstn (S n) us) =
- BaseSystem.decode' (firstn n base) (firstn n us) +
- nth_default 0 base n * nth_default 0 us n.
- Proof.
- intros.
- rewrite !firstn_succ with (d := 0) by omega.
- rewrite base_app, firstn_app.
- autorewrite with lengths; rewrite !Min.min_l by omega.
- rewrite Nat.sub_diag, firstn_firstn, firstn0, app_nil_r by omega.
- rewrite skipn_app_sharp by (autorewrite with lengths; apply Min.min_l; omega).
- rewrite decode'_cons, decode_nil, Z.add_0_r.
- reflexivity.
- Qed.
-
- Local Hint Resolve sum_firstn_limb_widths_nonneg.
- Local Hint Resolve limb_widths_nonneg.
- Local Hint Resolve nth_error_value_In.
-
- Lemma decode_carry_done_upper_bound' : forall n us, carry_done us ->
- (length us = length base) ->
- BaseSystem.decode (firstn n base) (firstn n us) < 2 ^ (sum_firstn limb_widths n).
- Proof.
- induction n; intros; [ cbv; congruence | ].
- destruct (lt_dec n (length base)) as [ n_lt_length | ? ].
- + rewrite decode_firstn_succ; auto.
- rewrite base_length in n_lt_length.
- destruct (nth_error_length_exists_value _ _ n_lt_length).
- erewrite sum_firstn_succ; eauto.
- rewrite Z.pow_add_r; eauto.
- rewrite nth_default_base by
- (try rewrite base_from_limb_widths_length; omega || eauto).
- rewrite Z.lt_add_lt_sub_r.
- eapply Z.lt_le_trans; eauto.
- rewrite Z.mul_comm at 1.
- rewrite <-Z.mul_sub_distr_l.
- rewrite <-Z.mul_1_r at 1.
- apply Z.mul_le_mono_nonneg_l; [ apply Z.pow_nonneg; omega | ].
- replace 1 with (Z.succ 0) by reflexivity.
- rewrite Z.le_succ_l, Z.lt_0_sub.
- match goal with H : carry_done us |- _ => rewrite carry_done_bounds in H by auto; specialize (H n) end.
- replace x with (log_cap n); try intuition.
- apply nth_error_value_eq_nth_default; auto.
- + repeat erewrite firstn_all_strong by omega.
- rewrite sum_firstn_all_succ by (rewrite <-base_length; omega).
- eapply Z.le_lt_trans; [ | eauto].
- repeat erewrite firstn_all_strong by omega.
- omega.
- Qed.
-
- Lemma decode_carry_done_upper_bound : forall us, carry_done us ->
- (length us = length base) -> BaseSystem.decode base us < 2 ^ k.
- Proof.
- unfold k; intros.
- rewrite <-(firstn_all_strong base (length limb_widths)) by (rewrite <-base_length; auto).
- rewrite <-(firstn_all_strong us (length limb_widths)) by (rewrite <-base_length; auto).
- auto using decode_carry_done_upper_bound'.
- Qed.
-
- Lemma decode_carry_done_lower_bound' : forall n us, carry_done us ->
- (length us = length base) ->
- 0 <= BaseSystem.decode (firstn n base) (firstn n us).
- Proof.
- induction n; intros; [ cbv; congruence | ].
- destruct (lt_dec n (length base)) as [ n_lt_length | ? ].
- + rewrite decode_firstn_succ by auto.
- zero_bounds.
- - rewrite nth_default_base by (omega || eauto).
- apply Z.pow_nonneg; omega.
- - match goal with H : carry_done us |- _ => rewrite carry_done_bounds in H by auto; specialize (H n) end.
- intuition.
- + eapply Z.le_trans; [ apply IHn; eauto | ].
- repeat rewrite firstn_all_strong by omega.
- omega.
- Qed.
-
- Lemma decode_carry_done_lower_bound : forall us, carry_done us ->
- (length us = length base) -> 0 <= BaseSystem.decode base us.
- Proof.
- intros.
- rewrite <-(firstn_all_strong base (length limb_widths)) by (rewrite <-base_length; auto).
- rewrite <-(firstn_all_strong us (length limb_widths)) by (rewrite <-base_length; auto).
- auto using decode_carry_done_lower_bound'.
- Qed.
-
-
- Lemma nth_default_modulus_digits' : forall d j i,
- nth_default d (modulus_digits' j) i =
- if lt_dec i (S j)
- then (if (eq_nat_dec i 0) then max_value i - c + 1 else max_value i)
- else d.
- Proof.
- induction j; intros; (break_if; [| apply nth_default_out_of_bounds; rewrite modulus_digits'_length; omega]).
- + replace i with 0%nat by omega.
- apply nth_default_cons.
- + simpl. rewrite nth_default_app.
- rewrite modulus_digits'_length.
- break_if.
- - rewrite IHj; break_if; try omega; reflexivity.
- - replace i with (S j) by omega.
- rewrite Nat.sub_diag, nth_default_cons.
- reflexivity.
- Qed.
-
- Lemma nth_default_modulus_digits : forall d i,
- nth_default d modulus_digits i =
- if lt_dec i (length base)
- then (if (eq_nat_dec i 0) then max_value i - c + 1 else max_value i)
- else d.
- Proof.
- unfold modulus_digits; intros.
- rewrite nth_default_modulus_digits'.
- replace (S (length base - 1)) with (length base) by omega.
- reflexivity.
- Qed.
-
- Lemma carry_done_modulus_digits : carry_done modulus_digits.
- Proof.
- apply carry_done_bounds; [apply modulus_digits_length | ].
- intros.
- rewrite nth_default_modulus_digits.
- break_if; [ | split; auto; omega].
- break_if; subst; split; auto; try rewrite <- max_value_log_cap; pose proof c_pos; omega.
- Qed.
- Local Hint Resolve carry_done_modulus_digits.
-
- Lemma decode_mod : forall us vs x, (length us = length base) -> (length vs = length base) ->
- decode us = x ->
- BaseSystem.decode base us mod modulus = BaseSystem.decode base vs mod modulus ->
- decode vs = x.
- Proof.
- unfold decode; intros until 2; intros decode_us_x BSdecode_eq.
- rewrite ZToField_mod in decode_us_x |- *.
- rewrite <-BSdecode_eq.
- assumption.
- Qed.
-
- Lemma decode_map2_sub : forall us vs,
- (length us = length vs) ->
- BaseSystem.decode' base (map2 (fun x y => x - y) us vs)
- = BaseSystem.decode' base us - BaseSystem.decode' base vs.
- Proof.
- induction us using rev_ind; induction vs using rev_ind;
- intros; autorewrite with lengths in *; simpl_list_lengths;
- rewrite ?decode_nil; try omega.
- rewrite map2_app by omega.
- rewrite map2_cons, map2_nil_l.
- rewrite !set_higher.
- autorewrite with lengths.
- rewrite Min.min_l by omega.
- rewrite IHus by omega.
- replace (length vs) with (length us) by omega.
- ring.
- Qed.
-
- Lemma decode_modulus_digits' : forall i, (i <= length base)%nat ->
- BaseSystem.decode' base (modulus_digits' i) = 2 ^ (sum_firstn limb_widths (S i)) - c.
- Proof.
- induction i; intros; unfold modulus_digits'; fold modulus_digits'.
- + let base := constr:(base) in
- case_eq base;
- [ intro base_eq; rewrite base_eq, (@nil_length0 Z) in lt_1_length_base; omega | ].
- intros z ? base_eq.
- rewrite decode'_cons, decode_nil, Z.add_0_r.
- replace z with (nth_default 0 base 0) by (rewrite base_eq; auto).
- rewrite nth_default_base by (omega || eauto).
- replace (max_value 0 - c + 1) with (Z.succ (max_value 0) - c) by ring.
- rewrite max_value_log_cap.
- rewrite sum_firstn_succ with (x := log_cap 0) by (
- apply nth_error_Some_nth_default; rewrite <-base_length; omega).
- rewrite Z.pow_add_r by eauto.
- cbv [sum_firstn fold_right firstn].
- ring.
- + assert (S i < length base \/ S i = length base)%nat as cases by omega.
- destruct cases.
- - rewrite sum_firstn_succ with (x := log_cap (S i)) by
- (apply nth_error_Some_nth_default;
- rewrite <-base_length; omega).
- rewrite Z.pow_add_r, <-max_value_log_cap, set_higher by eauto.
- rewrite IHi, modulus_digits'_length by omega.
- rewrite nth_default_base by (omega || eauto).
- ring.
- - rewrite sum_firstn_all_succ by (rewrite <-base_length; omega).
- rewrite decode'_splice, modulus_digits'_length, firstn_all by auto.
- rewrite skipn_all, decode_base_nil, Z.add_0_r by omega.
- apply IHi.
- omega.
- Qed.
-
- Lemma decode_modulus_digits : BaseSystem.decode' base modulus_digits = modulus.
- Proof.
- unfold modulus_digits; rewrite decode_modulus_digits' by omega.
- replace (S (length base - 1)) with (length base) by omega.
- rewrite base_length.
- fold k. unfold c.
- ring.
- Qed.
-
- Lemma map_land_max_ones_modulus_digits' : forall i,
- map (Z.land max_ones) (modulus_digits' i) = (modulus_digits' i).
- Proof.
- induction i; intros.
- + cbv [modulus_digits' map].
- f_equal.
- apply land_max_ones_noop with (i := 0%nat).
- rewrite <-max_value_log_cap.
- pose proof c_pos; omega.
- + unfold modulus_digits'; fold modulus_digits'.
- rewrite map_app.
- f_equal; [ apply IHi; omega | ].
- cbv [map]; f_equal.
- apply land_max_ones_noop with (i := S i).
- rewrite <-max_value_log_cap.
- split; auto; omega.
- Qed.
-
- Lemma map_land_max_ones_modulus_digits : map (Z.land max_ones) modulus_digits = modulus_digits.
- Proof.
- apply map_land_max_ones_modulus_digits'.
- Qed.
-
- Opaque modulus_digits.
-
- Lemma map_land_zero : forall ls, map (Z.land 0) ls = BaseSystem.zeros (length ls).
- Proof.
- induction ls; boring.
- Qed.
-
- Lemma carry_full_preserves_Fdecode : forall us x, (length us = length base) ->
- decode us = x -> decode (carry_full us) = x.
- Proof.
- intros.
- apply carry_full_preserves_rep; auto.
- unfold rep; auto.
- Qed.
-
- Lemma freeze_preserves_rep : forall us x, rep us x -> rep (freeze us) x.
- Proof.
- unfold rep; intros.
- intuition; rewrite ?freeze_length; auto.
- unfold freeze, and_term.
- break_if.
- + apply decode_mod with (us := carry_full (carry_full (carry_full us))).
- - rewrite carry_full_3_length; auto.
- - autorewrite with lengths.
- apply Min.min_r.
- simpl_lengths; omega.
- - repeat apply carry_full_preserves_rep; repeat rewrite carry_full_length; auto.
- unfold rep; intuition.
- - rewrite decode_map2_sub by (simpl_lengths; omega).
- rewrite map_land_max_ones_modulus_digits.
- rewrite decode_modulus_digits.
- destruct (Z_eq_dec modulus 0); [ subst; rewrite !Zmod_0_r; reflexivity | ].
- rewrite <-Z.add_opp_r.
- replace (-modulus) with (-1 * modulus) by ring.
- symmetry; auto using Z.mod_add.
- + eapply decode_mod; eauto.
- simpl_lengths.
- rewrite map_land_zero, decode_map2_sub, zeros_rep, Z.sub_0_r by simpl_lengths.
- match goal with H : decode ?us = ?x |- _ => erewrite Fdecode_decode_mod; eauto;
- do 3 apply carry_full_preserves_Fdecode in H; simpl_lengths
- end.
- erewrite Fdecode_decode_mod; eauto; simpl_lengths.
- Qed.
- Hint Resolve freeze_preserves_rep.
-
- Lemma isFull_true_iff : forall us, (length us = length base) -> (isFull us = true <->
- max_value 0 - c < nth_default 0 us 0
- /\ (forall i, (0 < i <= length base - 1)%nat -> nth_default 0 us i = max_value i)).
- Proof.
- unfold isFull; intros; auto using isFull'_true_iff.
- Qed.
-
- Definition minimal_rep us := BaseSystem.decode base us = (BaseSystem.decode base us) mod modulus.
-
- Fixpoint compare' us vs i :=
- match i with
- | O => Eq
- | S i' => if Z_eq_dec (nth_default 0 us i') (nth_default 0 vs i')
- then compare' us vs i'
- else Z.compare (nth_default 0 us i') (nth_default 0 vs i')
- end.
-
- (* Lexicographically compare two vectors of equal length, starting from the END of the list
- (in our context, this is the most significant end). NOT constant time. *)
- Definition compare us vs := compare' us vs (length us).
-
- Lemma compare'_Eq : forall us vs i, (length us = length vs) ->
- compare' us vs i = Eq -> firstn i us = firstn i vs.
- Proof.
- induction i; intros; [ cbv; congruence | ].
- destruct (lt_dec i (length us)).
- + repeat rewrite firstn_succ with (d := 0) by omega.
- match goal with H : compare' _ _ (S _) = Eq |- _ =>
- inversion H end.
- break_if; f_equal; auto.
- - f_equal; auto.
- - rewrite Z.compare_eq_iff in *. congruence.
- - rewrite Z.compare_eq_iff in *. congruence.
- + rewrite !firstn_all_strong in IHi by omega.
- match goal with H : compare' _ _ (S _) = Eq |- _ =>
- inversion H end.
- rewrite (nth_default_out_of_bounds i us) in * by omega.
- rewrite (nth_default_out_of_bounds i vs) in * by omega.
- break_if; try congruence.
- f_equal; auto.
- Qed.
-
- Lemma compare_Eq : forall us vs, (length us = length vs) ->
- compare us vs = Eq -> us = vs.
- Proof.
- intros.
- erewrite <-(firstn_all _ us); eauto.
- erewrite <-(firstn_all _ vs); eauto.
- apply compare'_Eq; auto.
- Qed.
-
- Lemma decode_lt_next_digit : forall us n, (length us = length base) ->
- (n < length base)%nat -> (n < length us)%nat ->
- carry_done us ->
- BaseSystem.decode' (firstn n base) (firstn n us) <
- (nth_default 0 base n).
- Proof.
- induction n; intros ? ? ? bounded.
- + cbv [firstn].
- rewrite decode_base_nil.
- apply Z.gt_lt; auto using nth_default_base_positive.
- + rewrite decode_firstn_succ by (auto || omega).
- rewrite nth_default_base_succ by (eauto || omega).
- eapply Z.lt_le_trans.
- - apply Z.add_lt_mono_r.
- apply IHn; auto; omega.
- - rewrite <-(Z.mul_1_r (nth_default 0 base n)) at 1.
- rewrite <-Z.mul_add_distr_l, Z.mul_comm.
- apply Z.mul_le_mono_pos_r.
- * apply Z.gt_lt. apply nth_default_base_positive; omega.
- * rewrite Z.add_1_l.
- apply Z.le_succ_l.
- rewrite carry_done_bounds in bounded by assumption.
- apply bounded.
- Qed.
-
- Lemma highest_digit_determines : forall us vs n x, (x < 0) ->
- (length us = length base) ->
- (length vs = length base) ->
- (n < length us)%nat -> carry_done us ->
- (n < length vs)%nat -> carry_done vs ->
- BaseSystem.decode (firstn n base) (firstn n us) +
- nth_default 0 base n * x -
- BaseSystem.decode (firstn n base) (firstn n vs) < 0.
- Proof.
- intros.
- eapply Z.le_lt_trans.
- + apply Z.le_sub_nonneg.
- apply decode_carry_done_lower_bound'; auto.
- + eapply Z.le_lt_trans.
- - eapply Z.add_le_mono with (q := nth_default 0 base n * -1); [ apply Z.le_refl | ].
- apply Z.mul_le_mono_nonneg_l; try omega.
- rewrite nth_default_base by (omega || eauto).
- zero_bounds.
- - ring_simplify.
- apply Z.lt_sub_0.
- apply decode_lt_next_digit; auto.
- omega.
- Qed.
-
- Lemma Z_compare_decode_step_eq : forall n us vs,
- (length us = length base) ->
- (length us = length vs) ->
- (S n <= length base)%nat ->
- (nth_default 0 us n = nth_default 0 vs n) ->
- (BaseSystem.decode (firstn (S n) base) us ?=
- BaseSystem.decode (firstn (S n) base) vs) =
- (BaseSystem.decode (firstn n base) us ?=
- BaseSystem.decode (firstn n base) vs).
- Proof.
- intros until 3; intro nth_default_eq.
- destruct (lt_dec n (length us)); try omega.
- rewrite firstn_succ with (d := 0), !base_app by omega.
- autorewrite with lengths; rewrite Min.min_l by omega.
- do 2 (rewrite skipn_nth_default with (d := 0) by omega;
- rewrite decode'_cons, decode_base_nil, Z.add_0_r).
- rewrite Z.compare_sub, nth_default_eq, Z.add_add_simpl_r_r.
- rewrite BaseSystem.decode'_truncate with (us := us).
- rewrite BaseSystem.decode'_truncate with (us := vs).
- rewrite firstn_length, Min.min_l, <-Z.compare_sub by omega.
- reflexivity.
- Qed.
-
- Lemma Z_compare_decode_step_lt : forall n us vs,
- (length us = length base) ->
- (length us = length vs) ->
- (S n <= length base)%nat ->
- carry_done us -> carry_done vs ->
- (nth_default 0 us n < nth_default 0 vs n) ->
- (BaseSystem.decode (firstn (S n) base) us ?=
- BaseSystem.decode (firstn (S n) base) vs) = Lt.
- Proof.
- intros until 5; intro nth_default_lt.
- destruct (lt_dec n (length us)).
- + rewrite firstn_succ with (d := 0) by omega.
- rewrite !base_app.
- autorewrite with lengths; rewrite Min.min_l by omega.
- do 2 (rewrite skipn_nth_default with (d := 0) by omega;
- rewrite decode'_cons, decode_base_nil, Z.add_0_r).
- rewrite Z.compare_sub.
- apply Z.compare_lt_iff.
- ring_simplify.
- rewrite <-Z.add_sub_assoc.
- rewrite <-Z.mul_sub_distr_l.
- apply highest_digit_determines; auto; omega.
- + rewrite !nth_default_out_of_bounds in nth_default_lt; omega.
- Qed.
-
- Lemma Z_compare_decode_step_neq : forall n us vs,
- (length us = length base) -> (length us = length vs) ->
- (S n <= length base)%nat ->
- carry_done us -> carry_done vs ->
- (nth_default 0 us n <> nth_default 0 vs n) ->
- (BaseSystem.decode (firstn (S n) base) us ?=
- BaseSystem.decode (firstn (S n) base) vs) =
- (nth_default 0 us n ?= nth_default 0 vs n).
- Proof.
- intros.
- destruct (Z_dec (nth_default 0 us n) (nth_default 0 vs n)) as [[?|Hgt]|?]; try congruence.
- + etransitivity; try apply Z_compare_decode_step_lt; auto.
- + match goal with |- (?a ?= ?b) = (?c ?= ?d) =>
- rewrite (Z.compare_antisym b a); rewrite (Z.compare_antisym d c) end.
- apply CompOpp_inj; rewrite !CompOpp_involutive.
- apply Z.gt_lt_iff in Hgt.
- etransitivity; try apply Z_compare_decode_step_lt; auto; omega.
- Qed.
-
- Lemma decode_compare' : forall n us vs,
- (length us = length base) ->
- (length us = length vs) ->
- (n <= length base)%nat ->
- carry_done us -> carry_done vs ->
- (BaseSystem.decode (firstn n base) us ?= BaseSystem.decode (firstn n base) vs)
- = compare' us vs n.
- Proof.
- induction n; intros.
- + cbv [firstn compare']; rewrite !decode_base_nil; auto.
- + unfold compare'; fold compare'.
- break_if.
- - rewrite Z_compare_decode_step_eq by (auto || omega).
- apply IHn; auto; omega.
- - rewrite Z_compare_decode_step_neq; (auto || omega).
- Qed.
-
- Lemma decode_compare : forall us vs,
- (length us = length base) -> carry_done us ->
- (length vs = length base) -> carry_done vs ->
- Z.compare (BaseSystem.decode base us) (BaseSystem.decode base vs) = compare us vs.
- Proof.
- unfold compare; intros.
- erewrite <-(firstn_all _ base).
- + apply decode_compare'; auto; omega.
- + assumption.
- Qed.
+ Local Hint Resolve (@limb_widths_nonneg _ prm) sum_firstn_limb_widths_nonneg.
+ Local Hint Resolve log_cap_nonneg.
- Lemma compare'_succ : forall us j vs, compare' us vs (S j) =
- if Z.eq_dec (nth_default 0 us j) (nth_default 0 vs j)
- then compare' us vs j
- else nth_default 0 us j ?= nth_default 0 vs j.
- Proof.
+ Lemma nth_default_carry_and_reduce_full : forall n i us,
+ nth_default 0 (carry_and_reduce i us) n
+ = if lt_dec n (length us)
+ then
+ (if eq_nat_dec n (i mod length limb_widths)
+ then Z.pow2_mod (nth_default 0 us n) (log_cap n)
+ else nth_default 0 us n) +
+ if PeanoNat.Nat.eq_dec n (S (i mod length limb_widths) mod length limb_widths)
+ then c * nth_default 0 us (i mod length limb_widths) >> log_cap (i mod length limb_widths)
+ else 0
+ else 0.
+ Proof.
+ cbv [carry_and_reduce]; intros.
+ autorewrite with push_nth_default.
reflexivity.
Qed.
-
- Lemma compare'_firstn_r_small_index : forall us j vs, (j <= length vs)%nat ->
- compare' us vs j = compare' us (firstn j vs) j.
- Proof.
- induction j; intros; auto.
- rewrite !compare'_succ by omega.
- rewrite firstn_succ with (d := 0) by omega.
- rewrite nth_default_app.
- simpl_lengths.
- rewrite Min.min_l by omega.
- destruct (lt_dec j j); try omega.
- rewrite Nat.sub_diag.
- rewrite nth_default_cons.
- break_if; try reflexivity.
- rewrite IHj with (vs := firstn j vs ++ nth_default 0 vs j :: nil) by
- (autorewrite with lengths; rewrite Min.min_l; omega).
- rewrite firstn_app_sharp by (autorewrite with lengths; apply Min.min_l; omega).
- apply IHj; omega.
- Qed.
-
- Lemma compare'_firstn_r : forall us j vs,
- compare' us vs j = compare' us (firstn j vs) j.
- Proof.
- intros.
- destruct (le_dec j (length vs)).
- + auto using compare'_firstn_r_small_index.
- + f_equal. symmetry.
- apply firstn_all_strong.
- omega.
- Qed.
-
- Lemma compare'_not_Lt : forall us vs j, j <> 0%nat ->
- (forall i, (0 < i < j)%nat -> 0 <= nth_default 0 us i <= nth_default 0 vs i) ->
- compare' us vs j <> Lt ->
- nth_default 0 vs 0 <= nth_default 0 us 0 /\
- (forall i : nat, (0 < i < j)%nat -> nth_default 0 us i = nth_default 0 vs i).
- Proof.
- induction j; try congruence.
- rewrite compare'_succ.
- intros; destruct (eq_nat_dec j 0).
- + break_if; subst; split; intros; try omega.
- rewrite Z.compare_ge_iff in *; omega.
- + break_if.
- - split; intros; [ | destruct (eq_nat_dec i j); subst; auto ];
- apply IHj; auto; intros; try omega;
- match goal with H : forall i, _ -> 0 <= ?f i <= ?g i |- 0 <= ?f _ <= ?g _ =>
- apply H; omega end.
- - exfalso. rewrite Z.compare_ge_iff in *.
- match goal with H : forall i, ?P -> 0 <= ?f i <= ?g i |- _ =>
- specialize (H j) end; omega.
- Qed.
-
- Lemma isFull'_compare' : forall us j, j <> 0%nat -> (length us = length base) ->
- (j <= length base)%nat -> carry_done us ->
- (isFull' us true (j - 1) = true <-> compare' us modulus_digits j <> Lt).
- Proof.
- unfold compare; induction j; intros; try congruence.
- replace (S j - 1)%nat with j by omega.
- split; intros.
- + simpl.
- break_if; [destruct (eq_nat_dec j 0) | ].
- - subst. cbv; congruence.
- - apply IHj; auto; try omega.
- apply isFull'_true_step.
- replace (S (j - 1)) with j by omega; auto.
- - rewrite nth_default_modulus_digits in *.
- repeat (break_if; try omega).
- * subst.
- match goal with H : isFull' _ _ _ = true |- _ =>
- apply isFull'_lower_bound_0 in H end.
- apply Z.compare_ge_iff.
- omega.
- * match goal with H : isFull' _ _ _ = true |- _ =>
- apply isFull'_true_iff in H; try assumption; destruct H as [? eq_max_value] end.
- specialize (eq_max_value j).
- omega.
- + apply isFull'_true_iff; try assumption.
- match goal with H : compare' _ _ _ <> Lt |- _ => apply compare'_not_Lt in H; [ destruct H as [Hdigit0 Hnonzero] | | ] end.
- - split; [ | intros i i_range; assert (0 < i < S j)%nat as i_range' by omega;
- specialize (Hnonzero i i_range')];
- rewrite nth_default_modulus_digits in *;
- repeat (break_if; try omega).
- - congruence.
- - intros.
- rewrite nth_default_modulus_digits.
- repeat (break_if; try omega).
- rewrite <-Z.lt_succ_r with (m := max_value i).
- rewrite max_value_log_cap; apply carry_done_bounds; assumption.
- Qed.
-
- Lemma isFull_compare : forall us, (length us = length base) -> carry_done us ->
- (isFull us = true <-> compare us modulus_digits <> Lt).
- Proof.
- unfold compare, isFull; intros ? lengths_eq. intros.
- rewrite lengths_eq.
- apply isFull'_compare'; try omega.
- assumption.
- Qed.
-
- Lemma isFull_decode : forall us, (length us = length base) -> carry_done us ->
- (isFull us = true <->
- (BaseSystem.decode base us ?= BaseSystem.decode base modulus_digits <> Lt)).
- Proof.
- intros.
- rewrite decode_compare; autorewrite with lengths; auto.
- apply isFull_compare; auto.
- Qed.
-
- Lemma isFull_false_upper_bound : forall us, (length us = length base) ->
- carry_done us -> isFull us = false ->
- BaseSystem.decode base us < modulus.
- Proof.
- intros.
- destruct (Z_lt_dec (BaseSystem.decode base us) modulus) as [? | nlt_modulus];
- [assumption | exfalso].
- apply Z.compare_nlt_iff in nlt_modulus.
- rewrite <-decode_modulus_digits in nlt_modulus at 2.
- apply isFull_decode in nlt_modulus; try assumption; congruence.
- Qed.
-
- Lemma isFull_true_lower_bound : forall us, (length us = length base) ->
- carry_done us -> isFull us = true ->
- modulus <= BaseSystem.decode base us.
- Proof.
- intros.
- rewrite <-decode_modulus_digits at 1.
- apply Z.compare_ge_iff.
- apply isFull_decode; auto.
- Qed.
-
- Lemma freeze_in_bounds : forall us,
- pre_carry_bounds us -> (length us = length base) ->
- carry_done (freeze us).
- Proof.
- unfold freeze, and_term; intros ? PCB lengths_eq.
- rewrite carry_done_bounds by simpl_lengths; intro i.
- rewrite nth_default_map2 with (d1 := 0) (d2 := 0).
- simpl_lengths.
- break_if; [ | split; (omega || auto)].
+ Hint Rewrite @nth_default_carry_and_reduce_full : push_nth_default.
+
+ Lemma nth_default_carry_full : forall n i us,
+ length us = length limb_widths ->
+ nth_default 0 (carry i us) n
+ = if lt_dec n (length us)
+ then
+ if eq_nat_dec i (pred (length limb_widths))
+ then (if eq_nat_dec n i
+ then Z.pow2_mod (nth_default 0 us n) (log_cap n)
+ else nth_default 0 us n) +
+ if eq_nat_dec n 0
+ then c * (nth_default 0 us i >> log_cap i)
+ else 0
+ else if eq_nat_dec n i
+ then Z.pow2_mod (nth_default 0 us n) (log_cap n)
+ else nth_default 0 us n +
+ if eq_nat_dec n (S i)
+ then nth_default 0 us i >> log_cap i
+ else 0
+ else 0.
+ Proof.
+ intros.
+ cbv [carry].
break_if.
- + rewrite map_land_max_ones_modulus_digits.
- apply isFull_true_iff in Heqb; [ | simpl_lengths].
- destruct Heqb as [first_digit high_digits].
- destruct (eq_nat_dec i 0).
- - subst.
- clear high_digits.
- rewrite nth_default_modulus_digits.
- repeat (break_if; try omega).
- pose proof (carry_full_3_done us PCB lengths_eq) as cf3_done.
- rewrite carry_done_bounds in cf3_done by simpl_lengths.
- specialize (cf3_done 0%nat).
- pose proof c_pos; omega.
- - assert ((0 < i <= length base - 1)%nat) as i_range by
- (simpl_lengths; apply lt_min_l in l; omega).
- specialize (high_digits i i_range).
- clear first_digit i_range.
- rewrite high_digits.
- rewrite <-max_value_log_cap.
- rewrite nth_default_modulus_digits.
- repeat (break_if; try omega).
- * rewrite Z.sub_diag.
- split; try omega.
- apply Z.lt_succ_r; auto.
- * rewrite Z.lt_succ_r, Z.sub_0_r. split; (omega || auto).
- + rewrite map_land_zero, nth_default_zeros.
- rewrite Z.sub_0_r.
- apply carry_done_bounds; [ simpl_lengths | ].
- auto using carry_full_3_done.
- Qed.
- Local Hint Resolve freeze_in_bounds.
-
- Local Hint Resolve carry_full_3_done.
-
- Lemma freeze_minimal_rep : forall us, pre_carry_bounds us -> (length us = length base) ->
- minimal_rep (freeze us).
- Proof.
- unfold minimal_rep, freeze, and_term.
- intros.
- symmetry. apply Z.mod_small.
- split; break_if; rewrite decode_map2_sub; simpl_lengths.
- + rewrite map_land_max_ones_modulus_digits, decode_modulus_digits.
- apply Z.le_0_sub.
- apply isFull_true_lower_bound; simpl_lengths.
- + rewrite map_land_zero, zeros_rep, Z.sub_0_r.
- apply decode_carry_done_lower_bound; simpl_lengths.
- + rewrite map_land_max_ones_modulus_digits, decode_modulus_digits.
- rewrite Z.lt_sub_lt_add_r.
- apply Z.lt_le_trans with (m := 2 * modulus); try omega.
- eapply Z.lt_le_trans; [ | apply two_pow_k_le_2modulus ].
- apply decode_carry_done_upper_bound; simpl_lengths.
- + rewrite map_land_zero, zeros_rep, Z.sub_0_r.
- apply isFull_false_upper_bound; simpl_lengths.
+ + subst i. autorewrite with push_nth_default natsimplify.
+ destruct (eq_nat_dec (length limb_widths) (length us)); congruence.
+ + autorewrite with push_nth_default; reflexivity.
+ Qed.
+ Hint Rewrite @nth_default_carry_full : push_nth_default.
+
+ Lemma nth_default_carry_sequence_make_chain_full : forall i n us,
+ length us = length limb_widths ->
+ (i <= length limb_widths)%nat ->
+ nth_default 0 (carry_sequence (make_chain i) us) n
+ = if lt_dec n (length limb_widths)
+ then
+ if eq_nat_dec i 0
+ then nth_default 0 us n
+ else
+ if lt_dec i (length limb_widths)
+ then
+ if lt_dec n i
+ then
+ if eq_nat_dec n (pred i)
+ then Z.pow2_mod
+ (nth_default 0 (carry_sequence (make_chain (pred i)) us) n)
+ (log_cap n)
+ else nth_default 0 (carry_sequence (make_chain (pred i)) us) n
+ else nth_default 0 (carry_sequence (make_chain (pred i)) us) n +
+ (if eq_nat_dec n i
+ then (nth_default 0 (carry_sequence (make_chain (pred i)) us) (pred i))
+ >> log_cap (pred i)
+ else 0)
+ else
+ if lt_dec n (pred i)
+ then nth_default 0 (carry_sequence (make_chain (pred i)) us) n +
+ (if eq_nat_dec n 0
+ then c * (nth_default 0 (carry_sequence (make_chain (pred i)) us) (pred i))
+ >> log_cap (pred i)
+ else 0)
+ else Z.pow2_mod
+ (nth_default 0 (carry_sequence (make_chain (pred i)) us) n)
+ (log_cap n)
+ else 0.
+ Proof.
+ induction i; intros; cbv [ModularBaseSystemList.carry_sequence].
+ + cbv [pred make_chain fold_right].
+ repeat break_if; subst; omega || reflexivity || auto using Z.add_0_r.
+ apply nth_default_out_of_bounds. omega.
+ + replace (make_chain (S i)) with (i :: make_chain i) by reflexivity.
+ rewrite fold_right_cons.
+ autorewrite with push_nth_default natsimplify;
+ rewrite ?Nat.pred_succ; fold (carry_sequence (make_chain i) us);
+ rewrite length_carry_sequence; auto.
+ repeat break_if; try omega; rewrite ?IHi by (omega || auto);
+ rewrite ?Z.add_0_r; try reflexivity.
+ Qed.
+
+ Lemma nth_default_carry_full_full : forall n us,
+ length us = length limb_widths ->
+ nth_default 0 (ModularBaseSystemList.carry_full us) n
+ = if lt_dec n (length limb_widths)
+ then
+ if eq_nat_dec n (pred (length limb_widths))
+ then Z.pow2_mod
+ (nth_default 0 (carry_sequence (make_chain (pred (length limb_widths))) us) n)
+ (log_cap n)
+ else nth_default 0 (carry_sequence (make_chain (pred (length limb_widths))) us) n +
+ (if eq_nat_dec n 0
+ then c * (nth_default 0 (carry_sequence (make_chain (pred (length limb_widths))) us) (pred (length limb_widths)))
+ >> log_cap (pred (length limb_widths))
+ else 0)
+ else 0.
+ Proof.
+ intros.
+ cbv [ModularBaseSystemList.carry_full full_carry_chain].
+ rewrite (nth_default_carry_sequence_make_chain_full (length limb_widths)) by omega.
+ repeat break_if; try omega; reflexivity.
+ Qed.
+ Hint Rewrite @nth_default_carry_full_full : push_nth_default.
+
+ Lemma nth_default_carry : forall i us,
+ length us = length limb_widths ->
+ (i < length us)%nat ->
+ nth_default 0 (ModularBaseSystemList.carry i us) i
+ = Z.pow2_mod (nth_default 0 us i) (log_cap i).
+ Proof.
+ intros; autorewrite with push_nth_default natsimplify; break_match; omega.
+ Qed.
+ Hint Rewrite @nth_default_carry using (omega || distr_length; omega) : push_nth_default.
+
+ Local Notation pred := Init.Nat.pred.
+ Local Notation "u '[' i ']' " := (nth_default 0 u i) (at level 30).
+ Local Notation "u '{{' i '}}' " := (carry_sequence (make_chain i) u) (at level 30).
+
+ Lemma bound_during_first_loop : forall i n us,
+ length us = length limb_widths ->
+ (i <= length limb_widths)%nat ->
+ (forall n, 0 <= nth_default 0 us n < 2 ^ B - if eq_nat_dec n 0 then 0 else ((2 ^ B) >> log_cap (pred n))) ->
+ 0 <= us{{i}}[n] < if eq_nat_dec i 0 then us[n] + 1 else
+ if lt_dec i (length limb_widths)
+ then
+ if lt_dec n i
+ then 2 ^ (log_cap n)
+ else if eq_nat_dec n i
+ then 2 ^ B
+ else us[n] + 1
+ else
+ if eq_nat_dec n 0
+ then 2 * 2 ^ limb_widths [n]
+ else 2 ^ limb_widths [n].
+ Proof.
+ induction i; intros; cbv [ModularBaseSystemList.carry_sequence].
+ + break_if; try omega.
+ cbv [make_chain fold_right]. split; try omega. apply H1.
+ + replace (make_chain (S i)) with (i :: make_chain i) by reflexivity.
+ rewrite fold_right_cons.
+ autorewrite with push_nth_default natsimplify; rewrite ?Nat.pred_succ;
+ fold (carry_sequence (make_chain i) us); rewrite length_carry_sequence; auto.
+ repeat (break_if; try omega);
+ try solve [rewrite Z.pow2_mod_spec by auto; autorewrite with zsimplify; apply Z.mod_pos_bound; zero_bounds];
+ pose proof (IHi i us); pose proof (IHi n us); specialize_by assumption; specialize_by auto;
+ repeat break_if; try omega; pose proof c_pos; (split; try solve [zero_bounds]).
+ (* TODO (jadep) : clean up/automate these leftover cases. *)
+ - replace (2 * 2 ^ limb_widths [n]) with (2 ^ limb_widths [n] + 2 ^ limb_widths [n]) by ring.
+ apply Z.add_lt_le_mono; subst n. omega.
+ eapply Z.le_trans; eauto.
+ apply Z.mul_le_mono_nonneg_l; try omega. subst i.
+ apply Z.shiftr_le; auto. apply Z.lt_le_incl. apply H2.
+ - replace (2 ^ B) with ((2 ^ B - ((2 ^ B) >> log_cap i)) + ((2 ^ B) >> log_cap i)) by ring.
+ apply Z.add_lt_le_mono.
+ * eapply Z.le_lt_trans with (m := us [n]); try omega.
+ replace i with (pred n) by omega.
+ eapply Z.lt_le_trans; [ apply H1 | ].
+ break_if; omega.
+ * apply Z.shiftr_le. auto.
+ apply Z.le_trans with (m := us [i]); [ omega | ].
+ eapply Z.le_trans. apply Z.lt_le_incl. apply H1.
+ break_if; omega.
+ - replace (2 ^ B) with ((2 ^ B - ((2 ^ B) >> log_cap i)) + ((2 ^ B) >> log_cap i)) by ring.
+ apply Z.add_lt_le_mono.
+ * eapply Z.le_lt_trans with (m := us [n]); try omega.
+ replace i with (pred n) by omega.
+ eapply Z.lt_le_trans; [ apply H1 | ].
+ break_if; omega.
+ * apply Z.shiftr_le. auto. omega.
+ Qed.
+
+ Lemma bound_after_first_loop : forall n us,
+ length us = length limb_widths ->
+ (forall n, 0 <= nth_default 0 us n < 2 ^ B - if eq_nat_dec n 0 then 0 else ((2 ^ B) >> log_cap (pred n))) ->
+ 0 <= (ModularBaseSystemList.carry_full us)[n] <
+ if eq_nat_dec n 0
+ then 2 * 2 ^ limb_widths [n]
+ else 2 ^ limb_widths [n].
+ Proof.
+ cbv [ModularBaseSystemList.carry_full full_carry_chain]; intros.
+ pose proof (bound_during_first_loop (length limb_widths) n us).
+ specialize_by eauto.
+ repeat (break_if; try omega).
Qed.
- Local Hint Resolve freeze_minimal_rep.
- Lemma rep_decode_mod : forall us vs x, rep us x -> rep vs x ->
- (BaseSystem.decode base us) mod modulus = (BaseSystem.decode base vs) mod modulus.
- Proof.
- unfold rep, decode; intros.
- intuition.
- repeat rewrite <-FieldToZ_ZToField.
- congruence.
- Qed.
+ (* TODO(jadep):
+ - Proof of bound after 3 loops
+ - Proof of correctness for [ge_modulus] and [cond_subtract_modulus]
+ - Proof of correctness for [freeze]
+ * freeze us = encode (decode us)
+ * decode us = x ->
+ canonicalized_BSToWord (freeze us)) = FToWord x
- Lemma minimal_rep_unique : forall us vs x,
- rep us x -> minimal_rep us -> carry_done us ->
- rep vs x -> minimal_rep vs -> carry_done vs ->
- us = vs.
- Proof.
- intros.
- match goal with Hrep1 : rep _ ?x, Hrep2 : rep _ ?x |- _ =>
- pose proof (rep_decode_mod _ _ _ Hrep1 Hrep2) as eqmod end.
- repeat match goal with Hmin : minimal_rep ?us |- _ => unfold minimal_rep in Hmin;
- rewrite <- Hmin in eqmod; clear Hmin end.
- apply Z.compare_eq_iff in eqmod.
- rewrite decode_compare in eqmod; unfold rep in *; auto; intuition; try congruence.
- apply compare_Eq; auto.
- congruence.
- Qed.
+ (where [canonicalized_BSToWord] uses bitwise operations to concatenate digits
+ in BaseSystem in canonical form, splitting along word capacities)
+ *)
- Lemma freeze_canonical : forall us vs x,
- pre_carry_bounds us -> rep us x ->
- pre_carry_bounds vs -> rep vs x ->
- freeze us = freeze vs.
- Proof.
- intros.
- assert (length us = length base) by (unfold rep in *; intuition).
- assert (length vs = length base) by (unfold rep in *; intuition).
- eapply minimal_rep_unique; eauto; rewrite freeze_length; assumption.
- Qed.
-*)
End CanonicalizationProofs.
diff --git a/src/ModularArithmetic/Pow2Base.v b/src/ModularArithmetic/Pow2Base.v
index a2c76016d..9d6cc2410 100644
--- a/src/ModularArithmetic/Pow2Base.v
+++ b/src/ModularArithmetic/Pow2Base.v
@@ -53,14 +53,19 @@ Section Pow2Base.
(Z.pow2_mod di (log_cap i),
Z.shiftr di (log_cap i)).
+ (* [fi] is fed [length us] and [S i] and produces the index of
+ the digit to which value should be added;
+ [fc] modifies the carried value before adding it to that digit *)
Definition carry_gen fc fi i := fun us =>
- let i := fi (length us) i in
+ let i := fi i in
let di := nth_default 0 us i in
let '(di', ci) := carry_single i di in
let us' := set_nth i di' us in
- add_to_nth (fi (length us) (S i)) (fc ci) us'.
+ add_to_nth (fi (S i)) (fc ci) us'.
- Definition carry_simple := carry_gen (fun ci => ci) (fun _ i => i).
+ (* carry_simple does not modify the carried value, and always adds it
+ to the digit with index [S i] *)
+ Definition carry_simple := carry_gen (fun ci => ci) (fun i => i).
Definition carry_simple_sequence is us := fold_right carry_simple us is.
diff --git a/src/ModularArithmetic/Pow2BaseProofs.v b/src/ModularArithmetic/Pow2BaseProofs.v
index 9255f033f..4b616c288 100644
--- a/src/ModularArithmetic/Pow2BaseProofs.v
+++ b/src/ModularArithmetic/Pow2BaseProofs.v
@@ -34,7 +34,7 @@ Section Pow2BaseProofs.
Lemma two_sum_firstn_limb_widths_nonzero n : 2^sum_firstn limb_widths n <> 0.
Proof. pose proof (two_sum_firstn_limb_widths_pos n); omega. Qed.
- Lemma base_from_limb_widths_step : forall i b w, (S i < length base)%nat ->
+ Lemma base_from_limb_widths_step : forall i b w, (S i < length limb_widths)%nat ->
nth_error base i = Some b ->
nth_error limb_widths i = Some w ->
nth_error base (S i) = Some (two_p w * b).
@@ -42,7 +42,7 @@ Section Pow2BaseProofs.
induction limb_widths; intros ? ? ? ? nth_err_w nth_err_b;
unfold base_from_limb_widths in *; fold base_from_limb_widths in *;
[rewrite (@nil_length0 Z) in *; omega | ].
- simpl in *; rewrite map_length in *.
+ simpl in *.
case_eq i; intros; subst.
+ subst; apply nth_error_first in nth_err_w.
apply nth_error_first in nth_err_b; subst.
@@ -60,15 +60,14 @@ Section Pow2BaseProofs.
Qed.
- Lemma nth_error_base : forall i, (i < length base)%nat ->
+ Lemma nth_error_base : forall i, (i < length limb_widths)%nat ->
nth_error base i = Some (two_p (sum_firstn limb_widths i)).
Proof.
induction i; intros.
+ unfold sum_firstn, base_from_limb_widths in *; case_eq limb_widths; try reflexivity.
intro lw_nil; rewrite lw_nil, (@nil_length0 Z) in *; omega.
- + assert (i < length base)%nat as lt_i_length by omega.
+ + assert (i < length limb_widths)%nat as lt_i_length by omega.
specialize (IHi lt_i_length).
- rewrite base_from_limb_widths_length in lt_i_length.
destruct (nth_error_length_exists_value _ _ lt_i_length) as [w nth_err_w].
erewrite base_from_limb_widths_step; eauto.
f_equal.
@@ -86,19 +85,16 @@ Section Pow2BaseProofs.
eapply nth_error_value_In; eauto.
Qed.
- Lemma nth_default_base : forall d i, (i < length base)%nat ->
+ Lemma nth_default_base : forall d i, (i < length limb_widths)%nat ->
nth_default d base i = 2 ^ (sum_firstn limb_widths i).
Proof.
intros ? ? i_lt_length.
- destruct (nth_error_length_exists_value _ _ i_lt_length) as [x nth_err_x].
- unfold nth_default.
- rewrite nth_err_x.
- rewrite nth_error_base in nth_err_x by assumption.
- rewrite two_p_correct in nth_err_x.
- congruence.
+ apply nth_error_value_eq_nth_default.
+ rewrite nth_error_base, two_p_correct by assumption.
+ reflexivity.
Qed.
- Lemma base_succ : forall i, ((S i) < length base)%nat ->
+ Lemma base_succ : forall i, ((S i) < length limb_widths)%nat ->
nth_default 0 base (S i) mod nth_default 0 base i = 0.
Proof.
intros.
@@ -111,8 +107,7 @@ Section Pow2BaseProofs.
apply limb_widths_nonneg.
rewrite lw_eq.
apply in_eq.
- + assert (i < length base)%nat as i_lt_length by omega.
- rewrite base_from_limb_widths_length in *.
+ + assert (i < length limb_widths)%nat as i_lt_length by omega.
apply nth_error_length_exists_value in i_lt_length.
destruct i_lt_length as [x nth_err_x].
erewrite sum_firstn_succ; eauto.
@@ -126,6 +121,7 @@ Section Pow2BaseProofs.
Proof.
intros i b nth_err_b.
pose proof (nth_error_value_length _ _ _ _ nth_err_b).
+ rewrite base_from_limb_widths_length in *.
rewrite nth_error_base in nth_err_b by assumption.
rewrite two_p_correct in nth_err_b.
congruence.
@@ -168,19 +164,19 @@ Section Pow2BaseProofs.
Section make_base_vector.
Local Notation k := (sum_firstn limb_widths (length limb_widths)).
Context (limb_widths_match_modulus : forall i j,
- (i < length limb_widths)%nat ->
- (j < length limb_widths)%nat ->
- (i + j >= length limb_widths)%nat ->
+ (i < length base)%nat ->
+ (j < length base)%nat ->
+ (i + j >= length base)%nat ->
let w_sum := sum_firstn limb_widths in
- k + w_sum (i + j - length limb_widths)%nat <= w_sum i + w_sum j)
+ k + w_sum (i + j - length base)%nat <= w_sum i + w_sum j)
(limb_widths_good : forall i j, (i + j < length limb_widths)%nat ->
sum_firstn limb_widths (i + j) <=
sum_firstn limb_widths i + sum_firstn limb_widths j).
Lemma base_matches_modulus: forall i j,
- (i < length limb_widths)%nat ->
- (j < length limb_widths)%nat ->
- (i+j >= length limb_widths)%nat->
+ (i < length base)%nat ->
+ (j < length base)%nat ->
+ (i+j >= length base)%nat->
let b := nth_default 0 base in
let r := (b i * b j) / (2^k * b (i+j-length base)%nat) in
b i * b j = r * (2^k * b (i+j-length base)%nat).
@@ -188,20 +184,20 @@ Section Pow2BaseProofs.
intros.
rewrite (Z.mul_comm r).
subst r.
+ rewrite base_from_limb_widths_length in *;
assert (i + j - length limb_widths < length limb_widths)%nat by omega.
- rewrite Z.mul_div_eq by (apply Z.gt_lt_iff; apply Z.mul_pos_pos;
- subst b; rewrite ?nth_default_base; zero_bounds; rewrite ?base_from_limb_widths_length;
- auto using sum_firstn_limb_widths_nonneg, limb_widths_nonneg).
+ rewrite Z.mul_div_eq by (apply Z.gt_lt_iff; subst b; rewrite ?nth_default_base; zero_bounds;
+ assumption).
rewrite (Zminus_0_l_reverse (b i * b j)) at 1.
f_equal.
subst b.
- repeat rewrite nth_default_base by (rewrite ?base_from_limb_widths_length; auto).
+ repeat rewrite nth_default_base by auto.
do 2 rewrite <- Z.pow_add_r by auto using sum_firstn_limb_widths_nonneg.
symmetry.
apply Z.mod_same_pow.
split.
+ apply Z.add_nonneg_nonneg; auto using sum_firstn_limb_widths_nonneg.
- + rewrite base_from_limb_widths_length; auto using limb_widths_nonneg, limb_widths_match_modulus.
+ + auto using limb_widths_match_modulus.
Qed.
Lemma base_good : forall i j : nat,
@@ -211,7 +207,9 @@ Section Pow2BaseProofs.
b i * b j = r * b (i + j)%nat.
Proof.
intros; subst b r.
- repeat rewrite nth_default_base by (omega || auto).
+ clear limb_widths_match_modulus.
+ rewrite base_from_limb_widths_length in *.
+ repeat rewrite nth_default_base by omega.
rewrite (Z.mul_comm _ (2 ^ (sum_firstn limb_widths (i+j)))).
rewrite Z.mul_div_eq by (apply Z.gt_lt_iff; zero_bounds;
auto using sum_firstn_limb_widths_nonneg).
@@ -219,10 +217,11 @@ Section Pow2BaseProofs.
rewrite Z.mod_same_pow; try ring.
split; [ auto using sum_firstn_limb_widths_nonneg | ].
apply limb_widths_good.
- rewrite <-base_from_limb_widths_length; auto using limb_widths_nonneg.
+ assumption.
Qed.
End make_base_vector.
End Pow2BaseProofs.
+Hint Rewrite @base_from_limb_widths_length : distr_length.
Section BitwiseDecodeEncode.
Context {limb_widths} (bv : BaseSystem.BaseVector (base_from_limb_widths limb_widths))
@@ -232,7 +231,7 @@ Section BitwiseDecodeEncode.
Local Notation base := (base_from_limb_widths limb_widths).
Local Notation upper_bound := (upper_bound limb_widths).
- Lemma encode'_spec : forall x i, (i <= length base)%nat ->
+ Lemma encode'_spec : forall x i, (i <= length limb_widths)%nat ->
encode' limb_widths x i = BaseSystem.encode' base x upper_bound i.
Proof.
induction i; intros.
@@ -240,13 +239,12 @@ Section BitwiseDecodeEncode.
+ rewrite encode'_succ, <-IHi by omega.
simpl; do 2 f_equal.
rewrite Z.land_ones, Z.shiftr_div_pow2 by auto using sum_firstn_limb_widths_nonneg.
- match goal with H : (S _ <= length base)%nat |- _ =>
+ match goal with H : (S _ <= length limb_widths)%nat |- _ =>
apply le_lt_or_eq in H; destruct H end.
- repeat f_equal; rewrite nth_default_base by (omega || auto); reflexivity.
- repeat f_equal; try solve [rewrite nth_default_base by (omega || auto); reflexivity].
- rewrite nth_default_out_of_bounds by omega.
+ rewrite nth_default_out_of_bounds by (distr_length; omega).
unfold Pow2Base.upper_bound.
- rewrite <-base_from_limb_widths_length by auto.
congruence.
Qed.
@@ -258,20 +256,17 @@ Section BitwiseDecodeEncode.
Lemma base_upper_bound_compatible : @base_max_succ_divide base upper_bound.
Proof.
unfold base_max_succ_divide; intros i lt_Si_length.
+ rewrite base_from_limb_widths_length in lt_Si_length.
rewrite Nat.lt_eq_cases in lt_Si_length; destruct lt_Si_length;
rewrite !nth_default_base by (omega || auto).
- + erewrite sum_firstn_succ by (eapply nth_error_Some_nth_default with (x := 0);
- rewrite <-base_from_limb_widths_length by auto; omega).
+ + erewrite sum_firstn_succ by (eapply nth_error_Some_nth_default with (x := 0); omega).
rewrite Z.pow_add_r; auto using sum_firstn_limb_widths_nonneg.
apply Z.divide_factor_r.
- + rewrite nth_default_out_of_bounds by omega.
+ + rewrite nth_default_out_of_bounds by (distr_length; omega).
unfold Pow2Base.upper_bound.
- replace (length limb_widths) with (S (pred (length limb_widths))) by
- (rewrite base_from_limb_widths_length in H by auto; omega).
- replace i with (pred (length limb_widths)) by
- (rewrite base_from_limb_widths_length in H by auto; omega).
- erewrite sum_firstn_succ by (eapply nth_error_Some_nth_default with (x := 0);
- rewrite <-base_from_limb_widths_length by auto; omega).
+ replace (length limb_widths) with (S (pred (length limb_widths))) by omega.
+ replace i with (pred (length limb_widths)) by omega.
+ erewrite sum_firstn_succ by (eapply nth_error_Some_nth_default with (x := 0); omega).
rewrite Z.pow_add_r; auto using sum_firstn_limb_widths_nonneg.
apply Z.divide_factor_r.
Qed.
@@ -281,7 +276,7 @@ Section BitwiseDecodeEncode.
BaseSystem.decode base (encodeZ limb_widths x) = x mod upper_bound.
Proof.
intros.
- assert (length base = length limb_widths) by auto using base_from_limb_widths_length.
+ assert (length base = length limb_widths) by distr_length.
unfold encodeZ; rewrite encode'_spec by omega.
rewrite BaseSystemProofs.encode'_spec; unfold Pow2Base.upper_bound; try zero_bounds;
auto using sum_firstn_limb_widths_nonneg.
@@ -521,7 +516,7 @@ Section carrying_helper.
Local Notation base := (base_from_limb_widths limb_widths).
Local Notation log_cap i := (nth_default 0 limb_widths i).
- Lemma update_nth_sum : forall n f us, (n < length us \/ n >= length base)%nat ->
+ Lemma update_nth_sum : forall n f us, (n < length us \/ n >= length limb_widths)%nat ->
BaseSystem.decode base (update_nth n f us) =
(let v := nth_default 0 us n in f v - v) * nth_default 0 base n + BaseSystem.decode base us.
Proof.
@@ -540,17 +535,17 @@ Section carrying_helper.
erewrite (nth_error_value_eq_nth_default _ _ us) by eassumption.
rewrite firstn_length in Heqn0.
rewrite Min.min_l in Heqn0 by omega; subst n0.
- destruct (le_lt_dec (length base) n). {
- rewrite (@nth_default_out_of_bounds _ _ base) by auto.
- rewrite skipn_all by omega.
+ destruct (le_lt_dec (length limb_widths) n). {
+ rewrite (@nth_default_out_of_bounds _ _ base) by (distr_length; auto).
+ rewrite skipn_all by (rewrite base_from_limb_widths_length; omega).
do 2 rewrite decode_base_nil.
ring_simplify; auto.
} {
- rewrite (skipn_nth_default n base 0) by omega.
+ rewrite (skipn_nth_default n base 0) by (distr_length; omega).
do 2 rewrite decode'_cons.
ring_simplify; ring.
} }
- { rewrite (nth_default_out_of_bounds _ base) by omega; ring_simplify.
+ { rewrite (nth_default_out_of_bounds _ base) by (distr_length; omega); ring_simplify.
etransitivity; rewrite BaseSystem.decode'_truncate; [ reflexivity | ].
apply f_equal.
autorewrite with push_firstn simpl_update_nth.
@@ -639,12 +634,12 @@ Section carrying_helper.
Hint Rewrite @length_add_to_nth : distr_length.
- Lemma set_nth_sum : forall n x us, (n < length us \/ n >= length base)%nat ->
+ Lemma set_nth_sum : forall n x us, (n < length us \/ n >= length limb_widths)%nat ->
BaseSystem.decode base (set_nth n x us) =
(x - nth_default 0 us n) * nth_default 0 base n + BaseSystem.decode base us.
Proof. intros; unfold set_nth; rewrite update_nth_sum by assumption; reflexivity. Qed.
- Lemma add_to_nth_sum : forall n x us, (n < length us \/ n >= length base)%nat ->
+ Lemma add_to_nth_sum : forall n x us, (n < length us \/ n >= length limb_widths)%nat ->
BaseSystem.decode base (add_to_nth n x us) =
x * nth_default 0 base n + BaseSystem.decode base us.
Proof. intros; rewrite add_to_nth_set_nth, set_nth_sum; try ring_simplify; auto. Qed.
@@ -696,7 +691,7 @@ Section carrying.
Proof. intros; unfold carry_simple; distr_length; reflexivity. Qed.
Hint Rewrite @length_carry_simple : distr_length.
- Lemma nth_default_base_succ : forall i, (S i < length base)%nat ->
+ Lemma nth_default_base_succ : forall i, (S i < length limb_widths)%nat ->
nth_default 0 base (S i) = 2 ^ log_cap i * nth_default 0 base i.
Proof.
intros.
@@ -705,13 +700,13 @@ Section carrying.
Qed.
Lemma carry_gen_decode_eq : forall fc fi i' us
- (i := fi (length base) i')
- (Si := fi (length base) (S i)),
- (length us = length base) ->
+ (i := fi i')
+ (Si := fi (S i)),
+ (length us = length limb_widths) ->
BaseSystem.decode base (carry_gen limb_widths fc fi i' us)
= (fc (nth_default 0 us i / 2 ^ log_cap i) *
(if eq_nat_dec Si (S i)
- then if lt_dec (S i) (length base)
+ then if lt_dec (S i) (length limb_widths)
then 2 ^ log_cap i * nth_default 0 base i
else 0
else nth_default 0 base Si)
@@ -719,29 +714,29 @@ Section carrying.
+ BaseSystem.decode base us.
Proof.
intros fc fi i' us i Si H; intros.
- destruct (eq_nat_dec 0 (length base));
+ destruct (eq_nat_dec 0 (length limb_widths));
[ destruct limb_widths, us, i; simpl in *; try congruence;
break_match;
unfold carry_gen, carry_single, add_to_nth;
autorewrite with zsimplify simpl_nth_default simpl_set_nth simpl_update_nth distr_length;
reflexivity
| ].
- (*assert (0 <= i < length base)%nat by (subst i; auto with arith).*)
+ (*assert (0 <= i < length limb_widths)%nat by (subst i; auto with arith).*)
assert (0 <= log_cap i) by auto using log_cap_nonneg.
assert (2 ^ log_cap i <> 0) by (apply Z.pow_nonzero; lia).
unfold carry_gen, carry_single.
- rewrite H; change (i' mod length base)%nat with i.
+ change (i' mod length limb_widths)%nat with i.
rewrite add_to_nth_sum by (rewrite length_set_nth; omega).
rewrite set_nth_sum by omega.
unfold Z.pow2_mod.
rewrite Z.land_ones by auto using log_cap_nonneg.
rewrite Z.shiftr_div_pow2 by auto using log_cap_nonneg.
- change (fi (length base) i') with i.
+ change (fi i') with i.
subst Si.
repeat first [ ring
| match goal with H : _ = _ |- _ => rewrite !H in * end
| rewrite nth_default_base_succ by omega
- | rewrite !(nth_default_out_of_bounds _ base) by omega
+ | rewrite !(nth_default_out_of_bounds _ base) by (distr_length; omega)
| rewrite !(nth_default_out_of_bounds _ us) by omega
| rewrite Z.mod_eq by assumption
| progress distr_length
@@ -750,8 +745,8 @@ Section carrying.
Qed.
Lemma carry_simple_decode_eq : forall i us,
- (length us = length base) ->
- (i < (pred (length base)))%nat ->
+ (length us = length limb_widths) ->
+ (i < (pred (length limb_widths)))%nat ->
BaseSystem.decode base (carry_simple limb_widths i us) = BaseSystem.decode base us.
Proof.
unfold carry_simple; intros; rewrite carry_gen_decode_eq by assumption.
@@ -790,11 +785,11 @@ Section carrying.
Lemma nth_default_carry_gen_full fc fi d i n us
: nth_default d (carry_gen limb_widths fc fi i us) n
= if lt_dec n (length us)
- then (if eq_nat_dec n (fi (length us) i)
+ then (if eq_nat_dec n (fi i)
then Z.pow2_mod (nth_default 0 us n) (log_cap n)
else nth_default 0 us n) +
- if eq_nat_dec n (fi (length us) (S (fi (length us) i)))
- then fc (nth_default 0 us (fi (length us) i) >> log_cap (fi (length us) i))
+ if eq_nat_dec n (fi (S (fi i)))
+ then fc (nth_default 0 us (fi i) >> log_cap (fi i))
else 0
else d.
Proof.
@@ -826,11 +821,11 @@ Section carrying.
: forall fc fi i us,
(0 <= i < length us)%nat
-> nth_default 0 (carry_gen limb_widths fc fi i us) i
- = (if eq_nat_dec i (fi (length us) i)
+ = (if eq_nat_dec i (fi i)
then Z.pow2_mod (nth_default 0 us i) (log_cap i)
else nth_default 0 us i) +
- if eq_nat_dec i (fi (length us) (S (fi (length us) i)))
- then fc (nth_default 0 us (fi (length us) i) >> log_cap (fi (length us) i))
+ if eq_nat_dec i (fi (S (fi i)))
+ then fc (nth_default 0 us (fi i) >> log_cap (fi i))
else 0.
Proof.
intros; autorewrite with push_nth_default natsimplify; break_match; omega.
@@ -848,7 +843,7 @@ Section carrying.
Hint Rewrite @nth_default_carry_simple using (omega || distr_length; omega) : push_nth_default.
End carrying.
-Hint Rewrite @length_carry_gen @base_from_limb_widths_length : distr_length.
+Hint Rewrite @length_carry_gen : distr_length.
Hint Rewrite @length_carry_simple @length_carry_simple_sequence @length_make_chain @length_full_carry_chain @length_carry_simple_full : distr_length.
Hint Rewrite @nth_default_carry_simple_full @nth_default_carry_gen_full : push_nth_default.
Hint Rewrite @nth_default_carry_simple @nth_default_carry_gen using (omega || distr_length; omega) : push_nth_default.