aboutsummaryrefslogtreecommitdiff
path: root/src/ModularArithmetic/PseudoMersenneBaseParamProofs.v
diff options
context:
space:
mode:
Diffstat (limited to 'src/ModularArithmetic/PseudoMersenneBaseParamProofs.v')
-rw-r--r--src/ModularArithmetic/PseudoMersenneBaseParamProofs.v246
1 files changed, 246 insertions, 0 deletions
diff --git a/src/ModularArithmetic/PseudoMersenneBaseParamProofs.v b/src/ModularArithmetic/PseudoMersenneBaseParamProofs.v
new file mode 100644
index 000000000..1a7b3316e
--- /dev/null
+++ b/src/ModularArithmetic/PseudoMersenneBaseParamProofs.v
@@ -0,0 +1,246 @@
+Require Import Zpower ZArith.
+Require Import List.
+Require Import Crypto.Util.ListUtil Crypto.Util.CaseUtil Crypto.Util.ZUtil.
+Require Import Crypto.ModularArithmetic.PrimeFieldTheorems.
+Require Import VerdiTactics.
+Require Import Crypto.ModularArithmetic.PseudoMersenneBaseParams.
+Require Crypto.BaseSystem.
+Local Open Scope Z_scope.
+
+Section PseudoMersenneBaseParamProofs.
+ Context `{prm : PseudoMersenneBaseParams}.
+
+ Fixpoint base_from_limb_widths limb_widths :=
+ match limb_widths with
+ | nil => nil
+ | w :: lw => 1 :: map (Z.mul (two_p w)) (base_from_limb_widths lw)
+ end.
+
+ Definition base := base_from_limb_widths limb_widths.
+
+ Lemma base_length : length base = length limb_widths.
+ Proof.
+ unfold base.
+ induction limb_widths; try reflexivity.
+ simpl; rewrite map_length; auto.
+ Qed.
+
+ Lemma nth_error_first : forall {T} (a b : T) l, nth_error (a :: l) 0 = Some b ->
+ a = b.
+ Proof.
+ intros; simpl in *.
+ unfold value in *.
+ congruence.
+ Qed.
+
+ Lemma base_from_limb_widths_step : forall i b w, (S i < length base)%nat ->
+ nth_error base i = Some b ->
+ nth_error limb_widths i = Some w ->
+ nth_error base (S i) = Some (two_p w * b).
+ Proof.
+ unfold base; induction limb_widths; intros ? ? ? ? nth_err_w nth_err_b;
+ unfold base_from_limb_widths in *; fold base_from_limb_widths in *;
+ [rewrite (@nil_length0 Z) in *; omega | ].
+ simpl in *; rewrite map_length in *.
+ case_eq i; intros; subst.
+ + subst; apply nth_error_first in nth_err_w.
+ apply nth_error_first in nth_err_b; subst.
+ apply map_nth_error.
+ case_eq l; intros; subst; [simpl in *; omega | ].
+ unfold base_from_limb_widths; fold base_from_limb_widths.
+ reflexivity.
+ + simpl in nth_err_w.
+ apply nth_error_map in nth_err_w.
+ destruct nth_err_w as [x [A B]].
+ subst.
+ replace (two_p w * (two_p a * x)) with (two_p a * (two_p w * x)) by ring.
+ apply map_nth_error.
+ apply IHl; auto; omega.
+ Qed.
+
+ Lemma nth_error_exists_first : forall {T} l (x : T) (H : nth_error l 0 = Some x),
+ exists l', l = x :: l'.
+ Proof.
+ induction l; try discriminate; eexists.
+ apply nth_error_first in H.
+ subst; eauto.
+ Qed.
+
+ Lemma sum_firstn_succ : forall l i x,
+ nth_error l i = Some x ->
+ sum_firstn l (S i) = x + sum_firstn l i.
+ Proof.
+ unfold sum_firstn; induction l;
+ [intros; rewrite (@nth_error_nil_error Z) in *; congruence | ].
+ intros ? x nth_err_x; destruct (NPeano.Nat.eq_dec i 0).
+ + subst; simpl in *; unfold value in *.
+ congruence.
+ + rewrite <- (NPeano.Nat.succ_pred i) at 2 by auto.
+ rewrite <- (NPeano.Nat.succ_pred i) in nth_err_x by auto.
+ simpl. simpl in nth_err_x.
+ specialize (IHl (pred i) x).
+ rewrite NPeano.Nat.succ_pred in IHl by auto.
+ destruct (NPeano.Nat.eq_dec (pred i) 0).
+ - replace i with 1%nat in * by omega.
+ simpl. replace (pred 1) with 0%nat in * by auto.
+ apply nth_error_exists_first in nth_err_x.
+ destruct nth_err_x as [l' ?].
+ subst; simpl; ring.
+ - rewrite IHl by auto; ring.
+ Qed.
+
+ Lemma sum_firstn_limb_widths_nonneg : forall n, 0 <= sum_firstn limb_widths n.
+ Proof.
+ unfold sum_firstn; intros.
+ apply fold_right_invariant; try omega.
+ intros y In_y_lw ? ?.
+ apply Z.add_nonneg_nonneg; try assumption.
+ apply limb_widths_nonneg.
+ eapply In_firstn; eauto.
+ Qed.
+
+ Lemma k_nonneg : 0 <= k.
+ Proof.
+ apply sum_firstn_limb_widths_nonneg.
+ Qed.
+
+ Lemma nth_error_base : forall i, (i < length base)%nat ->
+ nth_error base i = Some (two_p (sum_firstn limb_widths i)).
+ Proof.
+ induction i; intros.
+ + unfold base, sum_firstn, base_from_limb_widths in *; case_eq limb_widths; try reflexivity.
+ intro lw_nil; rewrite lw_nil, (@nil_length0 Z) in *; omega.
+ +
+ assert (i < length base)%nat as lt_i_length by omega.
+ specialize (IHi lt_i_length).
+ rewrite base_length in lt_i_length.
+ destruct (nth_error_length_exists_value _ _ lt_i_length) as [w nth_err_w].
+ erewrite base_from_limb_widths_step; eauto.
+ f_equal.
+ simpl.
+ destruct (NPeano.Nat.eq_dec i 0).
+ - subst; unfold sum_firstn; simpl.
+ apply nth_error_exists_first in nth_err_w.
+ destruct nth_err_w as [l' lw_destruct]; subst.
+ rewrite lw_destruct.
+ ring_simplify.
+ f_equal; simpl; ring.
+ - erewrite sum_firstn_succ; eauto.
+ symmetry.
+ apply two_p_is_exp; auto using sum_firstn_limb_widths_nonneg.
+ apply limb_widths_nonneg.
+ eapply nth_error_value_In; eauto.
+ Qed.
+
+ Lemma nth_default_base : forall d i, (i < length base)%nat ->
+ nth_default d base i = 2 ^ (sum_firstn limb_widths i).
+ Proof.
+ intros ? ? i_lt_length.
+ destruct (nth_error_length_exists_value _ _ i_lt_length) as [x nth_err_x].
+ unfold nth_default.
+ rewrite nth_err_x.
+ rewrite nth_error_base in nth_err_x by assumption.
+ rewrite two_p_correct in nth_err_x.
+ congruence.
+ Qed.
+
+ Lemma base_matches_modulus: forall i j,
+ (i < length base)%nat ->
+ (j < length base)%nat ->
+ (i+j >= length base)%nat->
+ let b := nth_default 0 base in
+ let r := (b i * b j) / (2^k * b (i+j-length base)%nat) in
+ b i * b j = r * (2^k * b (i+j-length base)%nat).
+ Proof.
+ intros.
+ rewrite (Z.mul_comm r).
+ subst r.
+ assert (i + j - length base < length base)%nat by omega.
+ rewrite mul_div_eq by (apply gt_lt_symmetry; apply Z.mul_pos_pos;
+ [ | subst b; rewrite nth_default_base; try assumption ];
+ apply Z.pow_pos_nonneg; omega || apply k_nonneg || apply sum_firstn_limb_widths_nonneg).
+ rewrite (Zminus_0_l_reverse (b i * b j)) at 1.
+ f_equal.
+ subst b.
+ repeat rewrite nth_default_base by assumption.
+ do 2 rewrite <- Z.pow_add_r by (apply sum_firstn_limb_widths_nonneg || apply k_nonneg).
+ symmetry.
+ apply mod_same_pow.
+ split.
+ + apply Z.add_nonneg_nonneg; apply sum_firstn_limb_widths_nonneg || apply k_nonneg.
+ + rewrite base_length in *; apply limb_widths_match_modulus; assumption.
+ Qed.
+
+ Lemma base_succ : forall i, ((S i) < length base)%nat ->
+ nth_default 0 base (S i) mod nth_default 0 base i = 0.
+ Proof.
+ intros.
+ repeat rewrite nth_default_base by omega.
+ apply mod_same_pow.
+ split; [apply sum_firstn_limb_widths_nonneg | ].
+ destruct (NPeano.Nat.eq_dec i 0); subst.
+ + case_eq limb_widths; intro; unfold sum_firstn; simpl; try omega; intros l' lw_eq.
+ apply Z.add_nonneg_nonneg; try omega.
+ apply limb_widths_nonneg.
+ rewrite lw_eq.
+ apply in_eq.
+ + assert (i < length base)%nat as i_lt_length by omega.
+ rewrite base_length in *.
+ apply nth_error_length_exists_value in i_lt_length.
+ destruct i_lt_length as [x nth_err_x].
+ erewrite sum_firstn_succ; eauto.
+ apply nth_error_value_In in nth_err_x.
+ apply limb_widths_nonneg in nth_err_x.
+ omega.
+ Qed.
+
+ Lemma nth_error_subst : forall i b, nth_error base i = Some b ->
+ b = 2 ^ (sum_firstn limb_widths i).
+ Proof.
+ intros i b nth_err_b.
+ pose proof (nth_error_value_length _ _ _ _ nth_err_b).
+ rewrite nth_error_base in nth_err_b by assumption.
+ rewrite two_p_correct in nth_err_b.
+ congruence.
+ Qed.
+
+ Lemma base_positive : forall b : Z, In b base -> b > 0.
+ Proof.
+ intros b In_b_base.
+ apply In_nth_error_value in In_b_base.
+ destruct In_b_base as [i nth_err_b].
+ apply nth_error_subst in nth_err_b.
+ rewrite nth_err_b.
+ apply gt_lt_symmetry.
+ apply Z.pow_pos_nonneg; omega || apply sum_firstn_limb_widths_nonneg.
+ Qed.
+
+ Lemma b0_1 : forall x : Z, nth_default x base 0 = 1.
+ Proof.
+ unfold base; case_eq limb_widths; intros; [pose proof limb_widths_nonnil; congruence | reflexivity].
+ Qed.
+
+ Lemma base_good : forall i j : nat,
+ (i + j < length base)%nat ->
+ let b := nth_default 0 base in
+ let r := b i * b j / b (i + j)%nat in
+ b i * b j = r * b (i + j)%nat.
+ Proof.
+ intros; subst b r.
+ repeat rewrite nth_default_base by omega.
+ rewrite (Z.mul_comm _ (2 ^ (sum_firstn limb_widths (i+j)))).
+ rewrite mul_div_eq by (apply gt_lt_symmetry; apply Z.pow_pos_nonneg; omega || apply sum_firstn_limb_widths_nonneg).
+ rewrite <- Z.pow_add_r by apply sum_firstn_limb_widths_nonneg.
+ rewrite mod_same_pow; try ring.
+ split; [ apply sum_firstn_limb_widths_nonneg | ].
+ apply limb_widths_good.
+ rewrite <- base_length; assumption.
+ Qed.
+
+ Global Instance bv : BaseSystem.BaseVector base := {
+ base_positive := base_positive;
+ b0_1 := b0_1;
+ base_good := base_good
+ }.
+
+End PseudoMersenneBaseParamProofs.