aboutsummaryrefslogtreecommitdiff
path: root/src/ModularArithmetic
diff options
context:
space:
mode:
authorGravatar jadep <jade.philipoom@gmail.com>2016-07-19 15:35:35 -0400
committerGravatar jadep <jade.philipoom@gmail.com>2016-07-19 15:35:35 -0400
commitb4875d9ca86b5626512178c0bf48e324a6391b7b (patch)
treed0e2081c5ca29724c7f80a0bc9a0b035cf01d702 /src/ModularArithmetic
parent6bc05eaded36d4c2e31e8d9979ee8660ad179080 (diff)
parent51602bd1ccf7493e53f78afa958238cad14571f2 (diff)
merge
Diffstat (limited to 'src/ModularArithmetic')
-rw-r--r--src/ModularArithmetic/ExtendedBaseVector.v36
-rw-r--r--src/ModularArithmetic/ModularBaseSystemOpt.v11
-rw-r--r--src/ModularArithmetic/ModularBaseSystemProofs.v126
-rw-r--r--src/ModularArithmetic/Pow2Base.v19
-rw-r--r--src/ModularArithmetic/Pow2BaseProofs.v188
5 files changed, 180 insertions, 200 deletions
diff --git a/src/ModularArithmetic/ExtendedBaseVector.v b/src/ModularArithmetic/ExtendedBaseVector.v
index d4df6040f..0afd6b484 100644
--- a/src/ModularArithmetic/ExtendedBaseVector.v
+++ b/src/ModularArithmetic/ExtendedBaseVector.v
@@ -7,12 +7,13 @@ Require Import Crypto.ModularArithmetic.Pow2Base.
Require Import Crypto.ModularArithmetic.Pow2BaseProofs.
Require Import Crypto.ModularArithmetic.PseudoMersenneBaseParams.
Require Import Crypto.ModularArithmetic.PseudoMersenneBaseParamProofs.
+Require Import Crypto.BaseSystemProofs.
Require Crypto.BaseSystem.
Local Open Scope Z_scope.
Section ExtendedBaseVector.
Context `{prm : PseudoMersenneBaseParams}.
- Local Notation base := (Pow2Base.base_from_limb_widths limb_widths).
+ Local Notation base := (base_from_limb_widths limb_widths).
(* This section defines a new BaseVector that has double the length of the BaseVector
* used to construct [params]. The coefficients of the new vector are as follows:
@@ -37,11 +38,19 @@ Section ExtendedBaseVector.
*
* This sum may be short enough to express using base; if not, we can reduce again.
*)
- Definition ext_base := base ++ (map (Z.mul (2^k)) base).
+ Definition ext_limb_widths := limb_widths ++ limb_widths.
+ Definition ext_base := base_from_limb_widths ext_limb_widths.
+ Lemma ext_base_alt : ext_base = base ++ (map (Z.mul (2^k)) base).
+ Proof.
+ unfold ext_base, ext_limb_widths.
+ rewrite base_from_limb_widths_app by auto using limb_widths_pos, Z.lt_le_incl.
+ rewrite two_p_equiv.
+ reflexivity.
+ Qed.
Lemma ext_base_positive : forall b, In b ext_base -> b > 0.
Proof.
- unfold ext_base. intros b In_b_base.
+ rewrite ext_base_alt. intros b In_b_base.
rewrite in_app_iff in In_b_base.
destruct In_b_base as [In_b_base | In_b_extbase].
+ eapply BaseSystem.base_positive.
@@ -68,7 +77,7 @@ Section ExtendedBaseVector.
Lemma b0_1 : forall x, nth_default x ext_base 0 = 1.
Proof.
- intros. unfold ext_base.
+ intros. rewrite ext_base_alt.
rewrite nth_default_app.
assert (0 < length base)%nat by (apply base_length_nonzero).
destruct (lt_dec 0 (length base)); try apply BaseSystem.b0_1; try omega.
@@ -120,7 +129,7 @@ Section ExtendedBaseVector.
Proof.
intros.
subst b. subst r.
- unfold ext_base in *.
+ rewrite ext_base_alt in *.
rewrite app_length in H; rewrite map_length in H.
repeat rewrite nth_default_app.
repeat break_if; try omega.
@@ -157,6 +166,21 @@ Section ExtendedBaseVector.
Lemma extended_base_length:
length ext_base = (length base + length base)%nat.
Proof.
- unfold ext_base; rewrite app_length; rewrite map_length; auto.
+ rewrite ext_base_alt, app_length, map_length; auto.
Qed.
+
+ Lemma firstn_us_base_ext_base : forall (us : BaseSystem.digits),
+ (length us <= length base)%nat
+ -> firstn (length us) base = firstn (length us) ext_base.
+ Proof.
+ rewrite ext_base_alt; intros.
+ rewrite firstn_app_inleft; auto; omega.
+ Qed.
+
+ Lemma decode_short : forall (us : BaseSystem.digits),
+ (length us <= length base)%nat ->
+ BaseSystem.decode base us = BaseSystem.decode ext_base us.
+ Proof. auto using decode_short_initial, firstn_us_base_ext_base. Qed.
End ExtendedBaseVector.
+
+Hint Rewrite @extended_base_length : distr_length.
diff --git a/src/ModularArithmetic/ModularBaseSystemOpt.v b/src/ModularArithmetic/ModularBaseSystemOpt.v
index 7c7004dce..696f10438 100644
--- a/src/ModularArithmetic/ModularBaseSystemOpt.v
+++ b/src/ModularArithmetic/ModularBaseSystemOpt.v
@@ -3,6 +3,7 @@ Require Import Crypto.ModularArithmetic.PseudoMersenneBaseParams.
Require Import Crypto.ModularArithmetic.PseudoMersenneBaseParamProofs.
Require Import Crypto.ModularArithmetic.ModularBaseSystemProofs.
Require Import Crypto.ModularArithmetic.ExtendedBaseVector.
+Require Import Crypto.ModularArithmetic.Pow2BaseProofs.
Require Import Crypto.BaseSystem Crypto.ModularArithmetic.ModularBaseSystem.
Require Import Coq.Lists.List.
Require Import Crypto.Util.ListUtil Crypto.Util.ZUtil Crypto.Util.NatUtil Crypto.Util.CaseUtil.
@@ -118,9 +119,10 @@ Section Carries.
cbv [carry].
rewrite <- pull_app_if_sumbool.
cbv beta delta
- [carry carry_and_reduce Pow2Base.carry_simple Pow2Base.add_to_nth
+ [carry carry_and_reduce Pow2Base.carry_gen Pow2Base.carry_and_reduce_single Pow2Base.carry_simple
Z.pow2_mod Z.ones Z.pred
PseudoMersenneBaseParams.limb_widths].
+ rewrite !add_to_nth_set_nth.
change @Pow2Base.base_from_limb_widths with @base_from_limb_widths_opt.
change @nth_default with @nth_default_opt in *.
change @set_nth with @set_nth_opt in *.
@@ -374,7 +376,7 @@ Section Multiplication.
cbv [mul_bi'_step].
opt_step.
{ reflexivity. }
- { cbv [crosscoef ext_base].
+ { cbv [crosscoef].
change Z.div with Z_div_opt.
change Z.mul with Z_mul_opt at 2.
change @nth_default with @nth_default_opt.
@@ -403,7 +405,7 @@ Section Multiplication.
rewrite <- IHvsr; clear IHvsr.
unfold mul_bi'_opt, mul_bi'_opt_step.
apply f_equal2; [ | reflexivity ].
- cbv [crosscoef ext_base].
+ cbv [crosscoef].
change Z.div with Z_div_opt.
change Z.mul with Z_mul_opt at 2.
change @nth_default with @nth_default_opt.
@@ -475,7 +477,8 @@ Section Multiplication.
Definition mul_opt_sig (us vs : digits) : { b : digits | b = mul us vs }.
Proof.
eexists.
- cbv [BaseSystem.mul mul mul_each mul_bi mul_bi' zeros ext_base reduce].
+ cbv [BaseSystem.mul mul mul_each mul_bi mul_bi' zeros reduce].
+ rewrite ext_base_alt.
rewrite <- mul'_opt_correct.
change @Pow2Base.base_from_limb_widths with base_from_limb_widths_opt.
rewrite Z.map_shiftl by apply k_nonneg.
diff --git a/src/ModularArithmetic/ModularBaseSystemProofs.v b/src/ModularArithmetic/ModularBaseSystemProofs.v
index 58f44e8e9..7d4f0107c 100644
--- a/src/ModularArithmetic/ModularBaseSystemProofs.v
+++ b/src/ModularArithmetic/ModularBaseSystemProofs.v
@@ -1,7 +1,7 @@
-Require Import Zpower ZArith.
+Require Import Coq.ZArith.Zpower Coq.ZArith.ZArith.
Require Import Coq.Numbers.Natural.Peano.NPeano.
-Require Import List.
-Require Import VerdiTactics.
+Require Import Coq.Lists.List.
+Require Import Crypto.Tactics.VerdiTactics.
Require Import Crypto.BaseSystem.
Require Import Crypto.BaseSystemProofs.
Require Import Crypto.ModularArithmetic.ExtendedBaseVector.
@@ -19,6 +19,7 @@ Require Import Crypto.Util.Tactics.
Require Import Crypto.Util.Notations.
Local Open Scope Z_scope.
+Local Opaque add_to_nth carry_simple.
Section PseudoMersenneProofs.
Context `{prm :PseudoMersenneBaseParams}.
@@ -30,6 +31,7 @@ Section PseudoMersenneProofs.
Local Notation "u ~= x" := (rep u x).
Local Notation digits := (tuple Z (length limb_widths)).
Local Hint Resolve (@limb_widths_nonneg _ prm) sum_firstn_limb_widths_nonneg.
+
Local Hint Resolve log_cap_nonneg.
Local Notation base := (base_from_limb_widths limb_widths).
Local Notation log_cap i := (nth_default 0 limb_widths i).
@@ -85,27 +87,29 @@ Section PseudoMersenneProofs.
f_equal; assumption.
Qed.
- Lemma decode_short : forall (us : BaseSystem.digits),
- (length us <= length base)%nat ->
- BaseSystem.decode base us = BaseSystem.decode ext_base us.
+ Lemma firstn_us_base_ext_base : forall (us : BaseSystem.digits),
+ (length us <= length base)%nat
+ -> firstn (length us) base = firstn (length us) ext_base.
Proof.
- intros.
- unfold BaseSystem.decode, BaseSystem.decode'.
- rewrite combine_truncate_r.
- rewrite (combine_truncate_r us ext_base).
- f_equal; f_equal.
- unfold ext_base.
+ rewrite ext_base_alt; intros.
rewrite firstn_app_inleft; auto; omega.
Qed.
+ Local Hint Immediate firstn_us_base_ext_base.
+
+ Lemma decode_short : forall (us : BaseSystem.digits),
+ (length us <= length base)%nat ->
+ BaseSystem.decode base us = BaseSystem.decode ext_base us.
+ Proof. auto using decode_short_initial. Qed.
+
+ Local Hint Immediate ExtBaseVector.
Lemma mul_rep_extended : forall (us vs : BaseSystem.digits),
(length us <= length base)%nat ->
(length vs <= length base)%nat ->
(BaseSystem.decode base us) * (BaseSystem.decode base vs) = BaseSystem.decode ext_base (BaseSystem.mul ext_base us vs).
Proof.
- intros.
- rewrite mul_rep by (apply ExtBaseVector || unfold ext_base; simpl_list; omega).
- f_equal; rewrite decode_short; auto.
+ intros; apply mul_rep_two_base; auto;
+ autorewrite with distr_length; try omega.
Qed.
Lemma modulus_nonzero : modulus <> 0.
@@ -137,7 +141,7 @@ Section PseudoMersenneProofs.
Proof.
intros.
unfold BaseSystem.decode; rewrite <- mul_each_rep.
- unfold ext_base.
+ rewrite ext_base_alt.
replace (map (Z.mul (2 ^ k)) base) with (BaseSystem.mul_each (2 ^ k) base) by auto.
rewrite base_mul_app.
rewrite <- mul_each_rep; auto.
@@ -463,8 +467,8 @@ Section CanonicalizationProofs.
Opaque Z.pow2_mod max_value.
(* automation *)
- Ltac carry_length_conditions' := unfold carry_full, add_to_nth;
- rewrite ?length_set_nth, ?length_carry, ?carry_sequence_length;
+ Ltac carry_length_conditions' := unfold carry_full;
+ rewrite ?length_set_nth, ?length_add_to_nth, ?length_carry, ?carry_sequence_length;
try omega; try solve [pose proof base_length; pose proof base_length_nonzero; omega || auto ].
Ltac carry_length_conditions := try split; try omega; repeat carry_length_conditions'.
@@ -499,11 +503,8 @@ Section CanonicalizationProofs.
+ unfold carry_and_reduce.
add_set_nth.
apply pow2_mod_log_cap_bounds_upper.
- + unfold carry_simple.
- destruct (lt_dec i (length us)).
- - add_set_nth.
- apply pow2_mod_log_cap_bounds_upper.
- - rewrite nth_default_out_of_bounds by carry_length_conditions; auto.
+ + autorewrite with push_nth_default natsimplify.
+ destruct (lt_dec i (length us)); auto using pow2_mod_log_cap_bounds_upper.
Qed.
Local Hint Resolve nth_default_carry_bound_upper.
@@ -515,11 +516,8 @@ Section CanonicalizationProofs.
+ unfold carry_and_reduce.
add_set_nth.
apply pow2_mod_log_cap_bounds_lower.
- + unfold carry_simple.
- destruct (lt_dec i (length us)).
- - add_set_nth.
- apply pow2_mod_log_cap_bounds_lower.
- - rewrite nth_default_out_of_bounds by carry_length_conditions; omega.
+ + autorewrite with push_nth_default natsimplify.
+ break_if; auto using pow2_mod_log_cap_bounds_lower, Z.le_refl.
Qed.
Local Hint Resolve nth_default_carry_bound_lower.
@@ -533,10 +531,8 @@ Section CanonicalizationProofs.
rewrite nth_default_out_of_bounds; carry_length_conditions.
unfold carry_and_reduce.
carry_length_conditions.
- + unfold carry_simple.
- destruct (lt_dec (S i) (length us)).
- - add_set_nth; zero_bounds.
- - rewrite nth_default_out_of_bounds by carry_length_conditions; omega.
+ + autorewrite with push_nth_default natsimplify.
+ break_if; zero_bounds.
Qed.
Lemma carry_unaffected_low : forall i j us, ((0 < i < j)%nat \/ (i = 0 /\ j <> 0 /\ j <> pred (length base))%nat)->
@@ -548,12 +544,8 @@ Section CanonicalizationProofs.
break_if.
+ unfold carry_and_reduce.
add_set_nth.
- + unfold carry_simple.
- destruct (lt_dec i (length us)).
- - add_set_nth.
- - rewrite !nth_default_out_of_bounds by
- (omega || rewrite length_add_to_nth; rewrite length_set_nth; pose proof base_length_nonzero; omega).
- reflexivity.
+ + autorewrite with push_nth_default simpl_nth_default natsimplify.
+ repeat break_if; autorewrite with simpl_nth_default natsimplify; omega.
Qed.
Lemma carry_unaffected_high : forall i j us, (S j < i)%nat -> (length us = length base) ->
@@ -562,21 +554,27 @@ Section CanonicalizationProofs.
intros.
destruct (lt_dec i (length us));
[ | rewrite !nth_default_out_of_bounds by carry_length_conditions; reflexivity].
- unfold carry, carry_simple.
- break_if; [omega | add_set_nth].
+ unfold carry.
+ break_if; [omega | autorewrite with push_nth_default natsimplify; repeat break_if; omega ].
Qed.
+ Hint Rewrite max_bound_shiftr_eq_0 using omega : core.
+ Hint Rewrite pow2_mod_log_cap_small using assumption : core.
+
Lemma carry_nothing : forall i j us, (i < length base)%nat ->
(length us = length base)%nat ->
0 <= nth_default 0 us j <= max_value j ->
nth_default 0 (carry j us) i = nth_default 0 us i.
Proof.
- unfold carry, carry_simple, carry_and_reduce; intros.
- break_if; (add_set_nth;
- [ rewrite max_value_shiftr_eq_0 by omega; ring
- | subst; apply pow2_mod_log_cap_small; assumption ]).
+ unfold carry, carry_and_reduce; intros.
+ repeat (break_if
+ || subst
+ || (autorewrite with push_nth_default natsimplify core)
+ || omega).
Qed.
+ Hint Rewrite pow2_mod_log_cap_small using (intuition; auto using shiftr_eq_0_max_bound) : core.
+
Lemma carry_carry_done_done : forall i us,
(length us = length base)%nat ->
(i < length base)%nat ->
@@ -590,15 +588,17 @@ Section CanonicalizationProofs.
split; [ apply Hcarry_done; auto | ].
apply shiftr_eq_0_max_value.
apply Hcarry_done; auto.
- + unfold carry, carry_simple, carry_and_reduce; break_if; subst.
+ + unfold carry, carry_and_reduce; break_if; subst.
- add_set_nth; subst.
* rewrite shiftr_0_i, Z.mul_0_r, Z.add_0_l.
assumption.
* rewrite pow2_mod_log_cap_small by (intuition; auto using shiftr_eq_0_max_value).
assumption.
- - rewrite shiftr_0_i by omega.
- rewrite pow2_mod_log_cap_small by (intuition; auto using shiftr_eq_0_max_value).
- add_set_nth; subst; rewrite ?Z.add_0_l; auto.
+ - repeat (carry_length_conditions
+ || (autorewrite with push_nth_default natsimplify core zsimplify)
+ || break_if
+ || subst
+ || rewrite shiftr_0_i by omega).
Qed.
Lemma carry_sequence_chain_step : forall i us,
@@ -712,8 +712,10 @@ Section CanonicalizationProofs.
0 <= nth_default 0 (carry i us) (S i) < 2 ^ B.
Proof.
intros.
- unfold carry, carry_simple; break_if; try omega.
- add_set_nth.
+ unfold carry; break_if; try omega.
+ autorewrite with push_nth_default natsimplify.
+ break_if; try omega.
+ rewrite Z.add_comm.
replace (2 ^ B) with (2 ^ (B - log_cap i) + (2 ^ B - 2 ^ (B - log_cap i))) by omega.
split; [ zero_bounds | ].
apply Z.add_lt_mono; try omega.
@@ -790,8 +792,9 @@ Section CanonicalizationProofs.
Proof.
induction i; intros; try omega.
simpl.
- unfold carry, carry_simple; break_if; try omega.
- add_set_nth.
+ unfold carry; break_if; try omega.
+ autorewrite with push_nth_default natsimplify distr_length; break_if; [ | omega ].
+ rewrite Z.add_comm.
split.
+ zero_bounds; [destruct (eq_nat_dec i 0); subst | ].
- simpl; apply carry_full_bounds_0; auto.
@@ -849,8 +852,9 @@ Section CanonicalizationProofs.
0 <= nth_default 0 (carry_simple limb_widths i
(carry_sequence (make_chain i) (carry_full (carry_full us)))) (S i) <= 2 ^ log_cap (S i).
Proof.
- unfold carry_simple; intros ? ? PCB length_eq ? IH.
- add_set_nth.
+ intros ? ? PCB length_eq ? IH.
+ autorewrite with push_nth_default natsimplify distr_length; break_if; [ | omega ].
+ rewrite Z.add_comm.
split.
+ zero_bounds. destruct i;
[ simpl; pose proof (carry_full_2_bounds_0 us PCB length_eq); omega | ].
@@ -875,7 +879,9 @@ Section CanonicalizationProofs.
simpl; unfold carry.
break_if; try omega.
split; (destruct (eq_nat_dec i 0); subst;
- [ cbv [make_chain carry_sequence fold_right carry_simple]; add_set_nth
+ [ cbv [make_chain carry_sequence fold_right];
+ autorewrite with push_nth_default natsimplify distr_length; break_if; [ | omega ];
+ rewrite Z.add_comm
| eapply carry_full_2_bounds_succ; eauto; omega]).
+ zero_bounds.
- eapply carry_full_2_bounds_0; eauto.
@@ -923,8 +929,7 @@ Section CanonicalizationProofs.
break_if;
[ pose proof base_length_nonzero; replace (length base) with 1%nat in *; omega | ].
simpl.
- unfold carry_simple.
- add_set_nth.
+ autorewrite with push_nth_default natsimplify.
apply pow2_mod_log_cap_bounds_lower.
+ rewrite carry_unaffected_low by carry_length_conditions.
assert (0 < S i < length base)%nat by omega.
@@ -962,10 +967,11 @@ Section CanonicalizationProofs.
split; [ auto using carry_full_2_bounds_lower | ].
destruct i; rewrite <-max_value_log_cap, Z.lt_succ_r; auto.
apply carry_full_bounds; auto using carry_full_bounds_lower.
- - left; unfold carry, carry_simple.
+ - left; unfold carry.
break_if;
[ pose proof base_length_nonzero; replace (length base) with 1%nat in *; omega | ].
- add_set_nth. simpl.
+ autorewrite with push_nth_default natsimplify.
+ simpl.
remember ((nth_default 0 (carry_full (carry_full us)) 0)) as x.
apply Z.le_trans with (m := (max_value 0 + c) - (1 + max_value 0)); try omega.
replace x with ((x - 2 ^ log_cap 0) + (1 * 2 ^ log_cap 0)) by ring.
@@ -1882,7 +1888,7 @@ Section CanonicalizationProofs.
Lemma freeze_canonical : forall us vs x,
pre_carry_bounds us -> rep us x ->
pre_carry_bounds vs -> rep vs x ->
- freeze us = freeze vs.
+ freeze us = freeze vs.
Proof.
intros.
assert (length us = length base) by (unfold rep in *; intuition).
diff --git a/src/ModularArithmetic/Pow2Base.v b/src/ModularArithmetic/Pow2Base.v
index f434a0c9f..acc96ad73 100644
--- a/src/ModularArithmetic/Pow2Base.v
+++ b/src/ModularArithmetic/Pow2Base.v
@@ -48,31 +48,20 @@ Section Pow2Base.
carrying. *)
Notation log_cap i := (nth_default 0 limb_widths i).
-
- Definition add_to_nth n (x:Z) xs :=
- set_nth n (x + nth_default 0 xs n) xs.
- (* TODO: Maybe we should use this instead? *)
- (*
Definition add_to_nth n (x:Z) xs :=
update_nth n (fun y => x + y) xs.
-
Definition carry_and_reduce_single i := fun di =>
(Z.pow2_mod di (log_cap i),
Z.shiftr di (log_cap i)).
- Definition carry_gen f i := fun us =>
- let i := (i mod length us)%nat in
+ Definition carry_gen fc fi i := fun us =>
+ let i := fi (length us) i in
let di := nth_default 0 us i in
let '(di', ci) := carry_and_reduce_single i di in
let us' := set_nth i di' us in
- add_to_nth ((S i) mod (length us)) (f ci) us'.
+ add_to_nth (fi (length us) (S i)) (fc ci) us'.
- Definition carry_simple := carry_gen (fun ci => ci).
- *)
- Definition carry_simple i := fun us =>
- let di := nth_default 0 us i in
- let us' := set_nth i (Z.pow2_mod di (log_cap i)) us in
- add_to_nth (S i) ( (Z.shiftr di (log_cap i))) us'.
+ Definition carry_simple := carry_gen (fun ci => ci) (fun _ i => i).
Definition carry_simple_sequence is us := fold_right carry_simple us is.
diff --git a/src/ModularArithmetic/Pow2BaseProofs.v b/src/ModularArithmetic/Pow2BaseProofs.v
index a7d7da800..db910ba93 100644
--- a/src/ModularArithmetic/Pow2BaseProofs.v
+++ b/src/ModularArithmetic/Pow2BaseProofs.v
@@ -150,6 +150,18 @@ Section Pow2BaseProofs.
reflexivity.
Qed.
+ Lemma base_from_limb_widths_app : forall l0 l
+ (l0_nonneg : forall x, In x l0 -> 0 <= x)
+ (l_nonneg : forall x, In x l -> 0 <= x),
+ base_from_limb_widths (l0 ++ l)
+ = base_from_limb_widths l0 ++ map (Z.mul (two_p (sum_firstn l0 (length l0)))) (base_from_limb_widths l).
+ Proof.
+ induction l0 as [|?? IHl0].
+ { simpl; intros; rewrite <- map_id at 1; apply map_ext; intros; omega. }
+ { simpl; intros; rewrite !IHl0, !map_app, map_map, sum_firstn_succ_cons, two_p_is_exp by auto with znonzero.
+ do 2 f_equal; apply map_ext; intros; lia. }
+ Qed.
+
End Pow2BaseProofs.
Section BitwiseDecodeEncode.
@@ -575,7 +587,7 @@ Section carrying_helper.
Lemma add_to_nth_sum : forall n x us, (n < length us \/ n >= length base)%nat ->
BaseSystem.decode base (add_to_nth n x us) =
x * nth_default 0 base n + BaseSystem.decode base us.
- Proof. unfold add_to_nth; intros; rewrite set_nth_sum; try ring_simplify; auto. Qed.
+ Proof. intros; rewrite add_to_nth_set_nth, set_nth_sum; try ring_simplify; auto. Qed.
Lemma add_to_nth_nth_default_full : forall n x l i d,
nth_default d (add_to_nth n x l) i =
@@ -615,12 +627,10 @@ Section carrying.
Local Notation log_cap i := (nth_default 0 limb_widths i).
Local Hint Resolve limb_widths_nonneg sum_firstn_limb_widths_nonneg.
- (*
- Lemma length_carry_gen : forall f i us, length (carry_gen limb_widths f i us) = length us.
+ Lemma length_carry_gen : forall fc fi i us, length (carry_gen limb_widths fc fi i us) = length us.
Proof. intros; unfold carry_gen, carry_and_reduce_single; distr_length; reflexivity. Qed.
Hint Rewrite @length_carry_gen : distr_length.
- *)
Lemma length_carry_simple : forall i us, length (carry_simple limb_widths i us) = length us.
Proof. intros; unfold carry_simple; distr_length; reflexivity. Qed.
@@ -634,26 +644,29 @@ Section carrying.
autorewrite with simpl_sum_firstn; reflexivity.
Qed.
- (*
- Lemma carry_gen_decode_eq : forall f i' us (i := (i' mod length base)%nat),
+ Lemma carry_gen_decode_eq : forall fc fi i' us
+ (i := fi (length base) i')
+ (Si := fi (length base) (S i)),
(length us = length base) ->
- BaseSystem.decode base (carry_gen limb_widths f i' us)
- = ((f (nth_default 0 us i / 2 ^ log_cap i))
- * (if eq_nat_dec (S i mod length base) 0
- then nth_default 0 base 0
- else (2 ^ log_cap i) * (nth_default 0 base i))
- - (nth_default 0 us i / 2 ^ log_cap i) * 2 ^ log_cap i * nth_default 0 base i
- )
+ BaseSystem.decode base (carry_gen limb_widths fc fi i' us)
+ = (fc (nth_default 0 us i / 2 ^ log_cap i) *
+ (if eq_nat_dec Si (S i)
+ then if lt_dec (S i) (length base)
+ then 2 ^ log_cap i * nth_default 0 base i
+ else 0
+ else nth_default 0 base Si)
+ - 2 ^ log_cap i * (nth_default 0 us i / 2 ^ log_cap i) * nth_default 0 base i)
+ BaseSystem.decode base us.
Proof.
- intros f i' us i H; intros.
+ intros fc fi i' us i Si H; intros.
destruct (eq_nat_dec 0 (length base));
[ destruct limb_widths, us, i; simpl in *; try congruence;
+ break_match;
unfold carry_gen, carry_and_reduce_single, add_to_nth;
autorewrite with zsimplify simpl_nth_default simpl_set_nth simpl_update_nth distr_length;
reflexivity
| ].
- assert (0 <= i < length base)%nat by (subst i; auto with arith).
+ (*assert (0 <= i < length base)%nat by (subst i; auto with arith).*)
assert (0 <= log_cap i) by auto using log_cap_nonneg.
assert (2 ^ log_cap i <> 0) by (apply Z.pow_nonzero; lia).
unfold carry_gen, carry_and_reduce_single.
@@ -663,17 +676,17 @@ Section carrying.
unfold Z.pow2_mod.
rewrite Z.land_ones by auto using log_cap_nonneg.
rewrite Z.shiftr_div_pow2 by auto using log_cap_nonneg.
- destruct (eq_nat_dec (S i mod length base) 0);
- repeat first [ ring
- | congruence
- | match goal with H : _ = _ |- _ => rewrite !H in * end
- | rewrite nth_default_base_succ by omega
- | rewrite !(nth_default_out_of_bounds _ base) by omega
- | rewrite !(nth_default_out_of_bounds _ us) by omega
- | rewrite Z.mod_eq by assumption
- | progress distr_length
- | progress autorewrite with natsimplify zsimplify in *
- | progress break_match ].
+ change (fi (length base) i') with i.
+ subst Si.
+ repeat first [ ring
+ | match goal with H : _ = _ |- _ => rewrite !H in * end
+ | rewrite nth_default_base_succ by omega
+ | rewrite !(nth_default_out_of_bounds _ base) by omega
+ | rewrite !(nth_default_out_of_bounds _ us) by omega
+ | rewrite Z.mod_eq by assumption
+ | progress distr_length
+ | progress autorewrite with natsimplify zsimplify in *
+ | progress break_match ].
Qed.
Lemma carry_simple_decode_eq : forall i us,
@@ -685,26 +698,7 @@ Section carrying.
autorewrite with natsimplify.
break_match; lia.
Qed.
-*)
- Lemma carry_simple_decode_eq : forall i us,
- (length us = length base) ->
- (i < (pred (length base)))%nat ->
- BaseSystem.decode base (carry_simple limb_widths i us) = BaseSystem.decode base us.
- Proof.
- unfold carry_simple. intros.
- rewrite add_to_nth_sum by (rewrite length_set_nth; omega).
- rewrite set_nth_sum by omega.
- unfold Z.pow2_mod.
- rewrite Z.land_ones by eauto using log_cap_nonneg.
- rewrite Z.shiftr_div_pow2 by eauto using log_cap_nonneg.
- rewrite nth_default_base_succ by omega.
- rewrite Z.mul_assoc.
- rewrite (Z.mul_comm _ (2 ^ log_cap i)).
- rewrite Z.mul_div_eq; try ring.
- apply Z.gt_lt_iff.
- apply Z.pow_pos_nonneg; omega || eauto using log_cap_nonneg.
- Qed.
Lemma length_carry_simple_sequence : forall is us, length (carry_simple_sequence limb_widths is us) = length us.
Proof.
@@ -732,32 +726,23 @@ Section carrying.
Proof.
induction x; simpl; intuition.
Qed.
-(*
- Lemma nth_default_carry_gen_full : forall f d i n us,
- nth_default d (carry_gen limb_widths f i us) n
+
+ Lemma nth_default_carry_gen_full fc fi d i n us
+ : nth_default d (carry_gen limb_widths fc fi i us) n
= if lt_dec n (length us)
- then if eq_nat_dec n (i mod length us)
- then (if eq_nat_dec (S n) (length us)
- then (if eq_nat_dec n 0
- then f ((nth_default 0 us n) >> log_cap n)
- else 0)
- else 0)
- + Z.pow2_mod (nth_default 0 us n) (log_cap n)
- else (if eq_nat_dec n (if eq_nat_dec (S (i mod length us)) (length us) then 0%nat else S (i mod length us))
- then f (nth_default 0 us (i mod length us) >> log_cap (i mod length us))
- else 0)
- + nth_default d us n
+ then (if eq_nat_dec n (fi (length us) i)
+ then Z.pow2_mod (nth_default 0 us n) (log_cap n)
+ else nth_default 0 us n) +
+ if eq_nat_dec n (fi (length us) (S (fi (length us) i)))
+ then fc (nth_default 0 us (fi (length us) i) >> log_cap (fi (length us) i))
+ else 0
else d.
Proof.
unfold carry_gen, carry_and_reduce_single.
intros; autorewrite with push_nth_default natsimplify distr_length.
- edestruct lt_dec; [ | reflexivity ].
- change (S ?x) with (1 + x)%nat.
- rewrite (Nat.add_mod_idemp_r 1 i (length us)) by omega.
- autorewrite with natsimplify.
- change (1 + ?x)%nat with (S x).
- destruct (eq_nat_dec n (i mod length us));
- subst; repeat break_match; omega.
+ edestruct (lt_dec n (length us)) as [H|H]; [ | reflexivity ].
+ rewrite !(@nth_default_in_bounds Z 0 d) by assumption.
+ repeat break_match; subst; try omega; try rewrite_hyp *; omega.
Qed.
Hint Rewrite @nth_default_carry_gen_full : push_nth_default.
@@ -765,72 +750,45 @@ Section carrying.
Lemma nth_default_carry_simple_full : forall d i n us,
nth_default d (carry_simple limb_widths i us) n
= if lt_dec n (length us)
- then if eq_nat_dec n (i mod length us)
- then (if eq_nat_dec (S n) (length us)
- then (if eq_nat_dec n 0
- then (nth_default 0 us n >> log_cap n + Z.pow2_mod (nth_default 0 us n) (log_cap n))
- (* FIXME: The above is just [nth_default 0 us n], but do we really care about the case of [n = 0], [length us = 1]? *)
- else Z.pow2_mod (nth_default 0 us n) (log_cap n))
- else Z.pow2_mod (nth_default 0 us n) (log_cap n))
- else (if eq_nat_dec n (if eq_nat_dec (S (i mod length us)) (length us) then 0%nat else S (i mod length us))
- then nth_default 0 us (i mod length us) >> log_cap (i mod length us)
- else 0)
- + nth_default d us n
+ then if eq_nat_dec n i
+ then Z.pow2_mod (nth_default 0 us n) (log_cap n)
+ else nth_default 0 us n +
+ if eq_nat_dec n (S i) then nth_default 0 us i >> log_cap i else 0
else d.
Proof.
- intros; unfold carry_simple; autorewrite with push_nth_default;
- repeat break_match; reflexivity.
+ intros; unfold carry_simple; autorewrite with push_nth_default.
+ repeat break_match; try omega; try reflexivity.
Qed.
Hint Rewrite @nth_default_carry_simple_full : push_nth_default.
Lemma nth_default_carry_gen
- : forall f i us,
+ : forall fc fi i us,
(0 <= i < length us)%nat
- -> nth_default 0 (carry_gen limb_widths f i us) i
- = (if PeanoNat.Nat.eq_dec i (if PeanoNat.Nat.eq_dec (S i) (length us) then 0%nat else S i)
- then f (nth_default 0 us i >> log_cap i) + Z.pow2_mod (nth_default 0 us i) (log_cap i)
- else Z.pow2_mod (nth_default 0 us i) (log_cap i)).
+ -> nth_default 0 (carry_gen limb_widths fc fi i us) i
+ = (if eq_nat_dec i (fi (length us) i)
+ then Z.pow2_mod (nth_default 0 us i) (log_cap i)
+ else nth_default 0 us i) +
+ if eq_nat_dec i (fi (length us) (S (fi (length us) i)))
+ then fc (nth_default 0 us (fi (length us) i) >> log_cap (fi (length us) i))
+ else 0.
Proof.
- unfold carry_gen, carry_and_reduce_single.
- intros; autorewrite with push_nth_default natsimplify; reflexivity.
+ intros; autorewrite with push_nth_default natsimplify; break_match; omega.
Qed.
Hint Rewrite @nth_default_carry_gen using (omega || distr_length; omega) : push_nth_default.
Lemma nth_default_carry_simple
- : forall f i us,
- (0 <= i < length us)%nat
- -> nth_default 0 (carry_gen limb_widths f i us) i
- = (if PeanoNat.Nat.eq_dec i (if PeanoNat.Nat.eq_dec (S i) (length us) then 0%nat else S i)
- then f (nth_default 0 us i >> log_cap i) + Z.pow2_mod (nth_default 0 us i) (log_cap i)
- else Z.pow2_mod (nth_default 0 us i) (log_cap i)).
- Proof.
- unfold carry_gen, carry_and_reduce_single.
- intros; autorewrite with push_nth_default natsimplify; reflexivity.
- Qed.
- Hint Rewrite @nth_default_carry_gen using (omega || distr_length; omega) : push_nth_default.
-
-
- Lemma nth_default_carry_gen
- : forall f i us,
+ : forall i us,
(0 <= i < length us)%nat
- -> nth_default 0 (carry_gen limb_widths f i us) i
- = (if PeanoNat.Nat.eq_dec i (if PeanoNat.Nat.eq_dec (S i) (length us) then 0%nat else S i)
- then f (nth_default 0 us i >> log_cap i) + Z.pow2_mod (nth_default 0 us i) (log_cap i)
- else Z.pow2_mod (nth_default 0 us i) (log_cap i)).
+ -> nth_default 0 (carry_simple limb_widths i us) i
+ = Z.pow2_mod (nth_default 0 us i) (log_cap i).
Proof.
- unfold carry_gen, carry_and_reduce_single.
- intros; autorewrite with push_nth_default natsimplify; reflexivity.
+ intros; autorewrite with push_nth_default natsimplify; break_match; omega.
Qed.
- Hint Rewrite @nth_default_carry_gen using (omega || distr_length; omega) : push_nth_default.
-*)
+ Hint Rewrite @nth_default_carry_simple using (omega || distr_length; omega) : push_nth_default.
End carrying.
-(*
Hint Rewrite @length_carry_gen : distr_length.
-*)
Hint Rewrite @length_carry_simple @length_carry_simple_sequence @length_make_chain @length_full_carry_chain @length_carry_simple_full : distr_length.
-(*
-Hint Rewrite @nth_default_carry_gen_full : push_nth_default.
-Hint Rewrite @nth_default_carry_gen using (omega || distr_length; omega) : push_nth_default.
-*)
+Hint Rewrite @nth_default_carry_simple_full @nth_default_carry_gen_full : push_nth_default.
+Hint Rewrite @nth_default_carry_simple @nth_default_carry_gen using (omega || distr_length; omega) : push_nth_default.