diff options
-rw-r--r-- | _CoqProject | 6 | ||||
-rw-r--r-- | src/BaseSystem.v | 470 | ||||
-rw-r--r-- | src/BaseSystemProofs.v | 490 | ||||
-rw-r--r-- | src/ModularArithmetic/ExtendedBaseVector.v | 163 | ||||
-rw-r--r-- | src/ModularArithmetic/ModularBaseSystem.v | 631 | ||||
-rw-r--r-- | src/ModularArithmetic/ModularBaseSystemProofs.v | 449 | ||||
-rw-r--r-- | src/ModularArithmetic/PseudoMersenneBaseParamProofs.v | 277 | ||||
-rw-r--r-- | src/ModularArithmetic/PseudoMersenneBaseParams.v | 26 | ||||
-rw-r--r-- | src/ModularArithmetic/PseudoMersenneBaseRep.v | 43 |
9 files changed, 1480 insertions, 1075 deletions
diff --git a/_CoqProject b/_CoqProject index d7c71f7cb..cae81fdae 100644 --- a/_CoqProject +++ b/_CoqProject @@ -1,5 +1,6 @@ -R src Crypto src/BaseSystem.v +src/BaseSystemProofs.v src/BoundedIterOp.v src/EdDSAProofs.v src/CompleteEdwardsCurve/CompleteEdwardsCurveTheorems.v @@ -7,12 +8,17 @@ src/CompleteEdwardsCurve/DoubleAndAdd.v src/CompleteEdwardsCurve/ExtendedCoordinates.v src/CompleteEdwardsCurve/Pre.v src/Encoding/EncodingTheorems.v +src/ModularArithmetic/ExtendedBaseVector.v src/ModularArithmetic/FField.v src/ModularArithmetic/FNsatz.v src/ModularArithmetic/ModularArithmeticTheorems.v src/ModularArithmetic/ModularBaseSystem.v +src/ModularArithmetic/ModularBaseSystemProofs.v src/ModularArithmetic/Pre.v src/ModularArithmetic/PrimeFieldTheorems.v +src/ModularArithmetic/PseudoMersenneBaseParams.v +src/ModularArithmetic/PseudoMersenneBaseParamProofs.v +src/ModularArithmetic/PseudoMersenneBaseRep.v src/ModularArithmetic/Tutorial.v src/Spec/CompleteEdwardsCurve.v src/Spec/Ed25519.v diff --git a/src/BaseSystem.v b/src/BaseSystem.v index e9e4bc9d4..f3a1a0fdb 100644 --- a/src/BaseSystem.v +++ b/src/BaseSystem.v @@ -16,7 +16,7 @@ Class BaseVector (base : list Z):= { }. Section BaseSystem. - Context (base : list Z) (base_vector : BaseVector base). + Context (base : list Z). (** [BaseSystem] implements an constrained positional number system. A wide variety of bases are supported: the base coefficients are not required to be powers of 2, and it is NOT necessarily the case that @@ -32,7 +32,6 @@ Section BaseSystem. Definition accumulate p acc := fst p * snd p + acc. Definition decode' bs u := fold_right accumulate 0 (combine u bs). Definition decode := decode' base. - Hint Unfold accumulate. (* Does not carry; z becomes the lowest and only digit. *) Definition encode (z : Z) := z :: nil. @@ -51,50 +50,7 @@ Section BaseSystem. Hint Extern 1 (@eq Z _ _) => ring. - Lemma add_rep : forall bs us vs, decode' bs (add us vs) = decode' bs us + decode' bs vs. - Proof. - unfold decode, decode'; induction bs; destruct us; destruct vs; boring. - Qed. - - Lemma decode_nil : forall bs, decode' bs nil = 0. - auto. - Qed. - Hint Rewrite decode_nil. - - Lemma decode_base_nil : forall us, decode' nil us = 0. - Proof. - intros; rewrite decode'_truncate; auto. - Qed. - Hint Rewrite decode_base_nil. - Definition mul_each u := map (Z.mul u). - Lemma mul_each_rep : forall bs u vs, - decode' bs (mul_each u vs) = u * decode' bs vs. - Proof. - unfold decode'; induction bs; destruct vs; boring. - Qed. - - Lemma base_eq_1cons: base = 1 :: skipn 1 base. - Proof. - pose proof (b0_1 0) as H. - destruct base; compute in H; try discriminate; boring. - Qed. - - Lemma decode'_cons : forall x1 x2 xs1 xs2, - decode' (x1 :: xs1) (x2 :: xs2) = x1 * x2 + decode' xs1 xs2. - Proof. - unfold decode'; boring. - Qed. - Hint Rewrite decode'_cons. - - Lemma decode_cons : forall x us, - decode (x :: us) = x + decode (0 :: us). - Proof. - unfold decode; intros. - rewrite base_eq_1cons. - autorewrite with core; ring_simplify; auto. - Qed. - Fixpoint sub (us vs:digits) : digits := match us,vs with | u::us', v::vs' => u-v :: sub us' vs' @@ -102,76 +58,13 @@ Section BaseSystem. | nil, v::vs' => (0-v)::sub nil vs' end. - Lemma sub_rep : forall bs us vs, decode' bs (sub us vs) = decode' bs us - decode' bs vs. - Proof. - induction bs; destruct us; destruct vs; boring. - Qed. - - Lemma encode_rep : forall z, decode (encode z) = z. - Proof. - pose proof base_eq_1cons. - unfold decode, encode; destruct z; boring. - Qed. - - Lemma mul_each_base : forall us bs c, - decode' bs (mul_each c us) = decode' (mul_each c bs) us. - Proof. - induction us; destruct bs; boring. - Qed. - - Hint Rewrite (@nth_default_nil Z). - Hint Rewrite (@firstn_nil Z). - Hint Rewrite (@skipn_nil Z). - - Lemma base_app : forall us low high, - decode' (low ++ high) us = decode' low (firstn (length low) us) + decode' high (skipn (length low) us). - Proof. - induction us; destruct low; boring. - Qed. - - Lemma base_mul_app : forall low c us, - decode' (low ++ mul_each c low) us = decode' low (firstn (length low) us) + - c * decode' low (skipn (length low) us). - Proof. - intros. - rewrite base_app; f_equal. - rewrite <- mul_each_rep. - rewrite mul_each_base. - reflexivity. - Qed. - Definition crosscoef i j : Z := let b := nth_default 0 base in (b(i) * b(j)) / b(i+j)%nat. Hint Unfold crosscoef. Fixpoint zeros n := match n with O => nil | S n' => 0::zeros n' end. - Lemma zeros_rep : forall bs n, decode' bs (zeros n) = 0. - induction bs; destruct n; boring. - Qed. - Lemma length_zeros : forall n, length (zeros n) = n. - induction n; boring. - Qed. - Hint Rewrite length_zeros. - - Lemma app_zeros_zeros : forall n m, zeros n ++ zeros m = zeros (n + m). - Proof. - induction n; boring. - Qed. - Hint Rewrite app_zeros_zeros. - - Lemma zeros_app0 : forall m, zeros m ++ 0 :: nil = zeros (S m). - Proof. - induction m; boring. - Qed. - Hint Rewrite zeros_app0. - - Lemma rev_zeros : forall n, rev (zeros n) = zeros n. - Proof. - induction n; boring. - Qed. - Hint Rewrite rev_zeros. - + (* mul' is multiplication with the SECOND ARGUMENT REVERSED and OUTPUT REVERSED *) Fixpoint mul_bi' (i:nat) (vsr:digits) := match vsr with @@ -180,317 +73,7 @@ Section BaseSystem. end. Definition mul_bi (i:nat) (vs:digits) : digits := zeros i ++ rev (mul_bi' i (rev vs)). - - Hint Unfold nth_default. - - Lemma decode_single : forall n bs x, - decode' bs (zeros n ++ x :: nil) = nth_default 0 bs n * x. - Proof. - induction n; destruct bs; boring. - Qed. - Hint Rewrite decode_single. - - Lemma peel_decode : forall xs ys x y, decode' (x::xs) (y::ys) = x*y + decode' xs ys. - Proof. - boring. - Qed. - Hint Rewrite zeros_rep peel_decode. - - Lemma decode_highzeros : forall xs bs n, decode' bs (xs ++ zeros n) = decode' bs xs. - Proof. - induction xs; destruct bs; boring. - Qed. - - Lemma mul_bi'_zeros : forall n m, mul_bi' n (zeros m) = zeros m. - induction m; boring. - Qed. - Hint Rewrite mul_bi'_zeros. - - Lemma nth_error_base_nonzero : forall n x, - nth_error base n = Some x -> x <> 0. - Proof. - eauto using (@nth_error_value_In Z), Zgt0_neq0, base_positive. - Qed. - - Hint Rewrite plus_0_r. - - Lemma mul_bi_single : forall m n x, - (n + m < length base)%nat -> - decode (mul_bi n (zeros m ++ x :: nil)) = nth_default 0 base m * x * nth_default 0 base n. - Proof. - unfold mul_bi, decode. - destruct m; simpl; simpl_list; simpl; intros. { - pose proof nth_error_base_nonzero as nth_nonzero. - case_eq base; [intros; boring | intros z l base_eq]. - specialize (b0_1 0); intro b0_1'. - rewrite base_eq in *. - rewrite nth_default_cons in b0_1'. - rewrite b0_1' in *. - boring; nth_tac. - rewrite Z_div_mul'; eauto. - destruct x; ring. - } { - ssimpl_list. - autorewrite with core. - rewrite app_assoc. - autorewrite with core. - unfold crosscoef; simpl; ring_simplify. - rewrite Nat.add_1_r. - rewrite base_good by auto. - rewrite Z_div_mult by (apply base_positive; rewrite nth_default_eq; apply nth_In; auto). - rewrite <- Z.mul_assoc. - rewrite <- Z.mul_comm. - rewrite <- Z.mul_assoc. - rewrite <- Z.mul_assoc. - destruct (Z.eq_dec x 0); subst; try ring. - rewrite Z.mul_cancel_l by auto. - rewrite <- base_good by auto. - ring. - } - Qed. - - Lemma set_higher' : forall vs x, vs++x::nil = vs .+ (zeros (length vs) ++ x :: nil). - induction vs; boring; f_equal; ring. - Qed. - - Lemma set_higher : forall bs vs x, - decode' bs (vs++x::nil) = decode' bs vs + nth_default 0 bs (length vs) * x. - Proof. - intros. - rewrite set_higher'. - rewrite add_rep. - f_equal. - apply decode_single. - Qed. - - Lemma zeros_plus_zeros : forall n, zeros n = zeros n .+ zeros n. - induction n; auto. - simpl; f_equal; auto. - Qed. - - Lemma mul_bi'_n_nil : forall n, mul_bi' n nil = nil. - Proof. - unfold mul_bi; auto. - Qed. - Hint Rewrite mul_bi'_n_nil. - - Lemma add_nil_l : forall us, nil .+ us = us. - induction us; auto. - Qed. - Hint Rewrite add_nil_l. - - Lemma add_nil_r : forall us, us .+ nil = us. - induction us; auto. - Qed. - Hint Rewrite add_nil_r. - - Lemma add_first_terms : forall us vs a b, - (a :: us) .+ (b :: vs) = (a + b) :: (us .+ vs). - auto. - Qed. - Hint Rewrite add_first_terms. - - Lemma mul_bi'_cons : forall n x us, - mul_bi' n (x :: us) = x * crosscoef n (length us) :: mul_bi' n us. - Proof. - unfold mul_bi'; auto. - Qed. - - Lemma add_same_length : forall us vs l, (length us = l) -> (length vs = l) -> - length (us .+ vs) = l. - Proof. - induction us, vs; boring. - erewrite (IHus vs (pred l)); boring. - Qed. - - Hint Rewrite app_nil_l. - Hint Rewrite app_nil_r. - - Lemma add_snoc_same_length : forall l us vs a b, - (length us = l) -> (length vs = l) -> - (us ++ a :: nil) .+ (vs ++ b :: nil) = (us .+ vs) ++ (a + b) :: nil. - Proof. - induction l, us, vs; boring; discriminate. - Qed. - - Lemma mul_bi'_add : forall us n vs l - (Hlus: length us = l) - (Hlvs: length vs = l), - mul_bi' n (rev (us .+ vs)) = - mul_bi' n (rev us) .+ mul_bi' n (rev vs). - Proof. - (* TODO(adamc): please help prettify this *) - induction us using rev_ind; - try solve [destruct vs; boring; congruence]. - destruct vs using rev_ind; boring; clear IHvs; simpl_list. - erewrite (add_snoc_same_length (pred l) us vs _ _); simpl_list. - repeat rewrite mul_bi'_cons; rewrite add_first_terms; simpl_list. - rewrite (IHus n vs (pred l)). - replace (length us) with (pred l). - replace (length vs) with (pred l). - rewrite (add_same_length us vs (pred l)). - f_equal; ring. - - erewrite length_snoc; eauto. - erewrite length_snoc; eauto. - erewrite length_snoc; eauto. - erewrite length_snoc; eauto. - erewrite length_snoc; eauto. - erewrite length_snoc; eauto. - erewrite length_snoc; eauto. - erewrite length_snoc; eauto. - Qed. - - Lemma zeros_cons0 : forall n, 0 :: zeros n = zeros (S n). - auto. - Qed. - - Lemma add_leading_zeros : forall n us vs, - (zeros n ++ us) .+ (zeros n ++ vs) = zeros n ++ (us .+ vs). - Proof. - induction n; boring. - Qed. - Lemma rev_add_rev : forall us vs l, (length us = l) -> (length vs = l) -> - (rev us) .+ (rev vs) = rev (us .+ vs). - Proof. - induction us, vs; boring; try solve [subst; discriminate]. - rewrite (add_snoc_same_length (pred l) _ _ _ _) by (subst; simpl_list; omega). - rewrite (IHus vs (pred l)) by omega; auto. - Qed. - Hint Rewrite rev_add_rev. - - Lemma mul_bi'_length : forall us n, length (mul_bi' n us) = length us. - Proof. - induction us, n; boring. - Qed. - Hint Rewrite mul_bi'_length. - - Lemma add_comm : forall us vs, us .+ vs = vs .+ us. - Proof. - induction us, vs; boring; f_equal; auto. - Qed. - - Hint Rewrite rev_length. - - Lemma mul_bi_add_same_length : forall n us vs l, - (length us = l) -> (length vs = l) -> - mul_bi n (us .+ vs) = mul_bi n us .+ mul_bi n vs. - Proof. - unfold mul_bi; boring. - rewrite add_leading_zeros. - erewrite mul_bi'_add; boring. - erewrite rev_add_rev; boring. - Qed. - - Lemma add_zeros_same_length : forall us, us .+ (zeros (length us)) = us. - Proof. - induction us; boring; f_equal; omega. - Qed. - - Hint Rewrite add_zeros_same_length. - Hint Rewrite minus_diag. - - Lemma add_trailing_zeros : forall us vs, (length us >= length vs)%nat -> - us .+ vs = us .+ (vs ++ (zeros (length us - length vs))). - Proof. - induction us, vs; boring; f_equal; boring. - Qed. - - Lemma length_add_ge : forall us vs, (length us >= length vs)%nat -> - (length (us .+ vs) <= length us)%nat. - Proof. - intros. - rewrite add_trailing_zeros by trivial. - erewrite add_same_length by (pose proof app_length; boring); omega. - Qed. - - Lemma add_length_le_max : forall us vs, - (length (us .+ vs) <= max (length us) (length vs))%nat. - Proof. - intros; case_max; (rewrite add_comm; apply length_add_ge; omega) || - (apply length_add_ge; omega) . - Qed. - - Lemma sub_nil_length: forall us : digits, length (sub nil us) = length us. - Proof. - induction us; boring. - Qed. - - Lemma sub_length_le_max : forall us vs, - (length (sub us vs) <= max (length us) (length vs))%nat. - Proof. - induction us, vs; boring. - rewrite sub_nil_length; auto. - Qed. - - Lemma mul_bi_length : forall us n, length (mul_bi n us) = (length us + n)%nat. - Proof. - pose proof mul_bi'_length; unfold mul_bi. - destruct us; repeat progress (simpl_list; boring). - Qed. - Hint Rewrite mul_bi_length. - - Lemma mul_bi_trailing_zeros : forall m n us, - mul_bi n us ++ zeros m = mul_bi n (us ++ zeros m). - Proof. - unfold mul_bi. - induction m; intros; try solve [boring]. - rewrite <- zeros_app0. - rewrite app_assoc. - repeat progress (boring; rewrite rev_app_distr). - Qed. - - Lemma mul_bi_add_longer : forall n us vs, - (length us >= length vs)%nat -> - mul_bi n (us .+ vs) = mul_bi n us .+ mul_bi n vs. - Proof. - boring. - rewrite add_trailing_zeros by auto. - rewrite (add_trailing_zeros (mul_bi n us) (mul_bi n vs)) - by (repeat (rewrite mul_bi_length); omega). - erewrite mul_bi_add_same_length by - (eauto; simpl_list; rewrite length_zeros; omega). - rewrite mul_bi_trailing_zeros. - repeat (f_equal; boring). - Qed. - - Lemma mul_bi_add : forall n us vs, - mul_bi n (us .+ vs) = (mul_bi n us) .+ (mul_bi n vs). - Proof. - intros; pose proof mul_bi_add_longer. - destruct (le_ge_dec (length us) (length vs)). { - replace (mul_bi n us .+ mul_bi n vs) - with (mul_bi n vs .+ mul_bi n us) - by (apply add_comm). - replace (us .+ vs) - with (vs .+ us) - by (apply add_comm). - boring. - } { - boring. - } - Qed. - - Lemma mul_bi_rep : forall i vs, - (i + length vs < length base)%nat -> - decode (mul_bi i vs) = decode vs * nth_default 0 base i. - Proof. - unfold decode. - induction vs using rev_ind; intros; try solve [unfold mul_bi; boring]. - assert (i + length vs < length base)%nat by - (rewrite app_length in *; boring). - - rewrite set_higher. - ring_simplify. - rewrite <- IHvs by auto; clear IHvs. - rewrite <- mul_bi_single by auto. - rewrite <- add_rep. - rewrite <- mul_bi_add. - rewrite set_higher'. - auto. - Qed. - (* mul' is multiplication with the FIRST ARGUMENT REVERSED *) Fixpoint mul' (usr vs:digits) : digits := match usr with @@ -500,51 +83,6 @@ Section BaseSystem. end. Definition mul us := mul' (rev us). - Lemma mul'_rep : forall us vs, - (length us + length vs <= length base)%nat -> - decode (mul' (rev us) vs) = decode us * decode vs. - Proof. - unfold decode. - induction us using rev_ind; boring. - - assert (length us + length vs < length base)%nat by - (rewrite app_length in *; boring). - - ssimpl_list. - rewrite add_rep. - boring. - rewrite set_higher. - rewrite mul_each_rep. - rewrite mul_bi_rep by auto. - unfold decode; ring. - Qed. - - Lemma mul_rep : forall us vs, - (length us + length vs <= length base)%nat -> - decode (mul us vs) = decode us * decode vs. - Proof. - exact mul'_rep. - Qed. - - Lemma mul'_length: forall us vs, - (length (mul' us vs) <= length us + length vs)%nat. - Proof. - pose proof add_length_le_max. - induction us; boring. - unfold mul_each. - simpl_list; case_max; boring; omega. - Qed. - - Lemma mul_length: forall us vs, - (length (mul us vs) <= length us + length vs)%nat. - Proof. - intros; unfold mul. - rewrite mul'_length. - rewrite rev_length; omega. - Qed. - -(* Print Assumptions mul_rep. *) - End BaseSystem. Section PolynomialBaseCoefs. @@ -615,7 +153,6 @@ Section PolynomialBaseCoefs. split; apply Z.pow_nonzero; try apply Zle_0_nat; try solve [intro H; inversion H]. Qed. - Print BaseVector. Instance PolyBaseVector : BaseVector poly_base := { base_positive := poly_base_positive; b0_1 := poly_b0_1; @@ -633,8 +170,7 @@ Section BaseSystemExample. Qed. Definition base2 := poly_base 2 baseLength. - About mul. - Example three_times_two : @mul base2 [1;1;0] [0;1;0] = [0;1;1;0;0]. + Example three_times_two : mul base2 [1;1;0] [0;1;0] = [0;1;1;0;0]. Proof. reflexivity. Qed. diff --git a/src/BaseSystemProofs.v b/src/BaseSystemProofs.v new file mode 100644 index 000000000..84374fe8f --- /dev/null +++ b/src/BaseSystemProofs.v @@ -0,0 +1,490 @@ +Require Import List. +Require Import Util.ListUtil Util.CaseUtil Util.ZUtil. +Require Import ZArith.ZArith ZArith.Zdiv. +Require Import Omega NPeano Arith. +Require Import Crypto.BaseSystem. +Local Open Scope Z. + +Local Infix ".+" := add (at level 50). + +Local Hint Extern 1 (@eq Z _ _) => ring. + +Section BaseSystemProofs. + Context `(base_vector : BaseVector). + + Lemma decode'_truncate : forall bs us, decode' bs us = decode' bs (firstn (length bs) us). + Proof. + unfold decode'; intros; f_equal; apply combine_truncate_l. + Qed. + + Lemma add_rep : forall bs us vs, decode' bs (add us vs) = decode' bs us + decode' bs vs. + Proof. + unfold decode', accumulate; induction bs; destruct us, vs; boring; ring. + Qed. + + Lemma decode_nil : forall bs, decode' bs nil = 0. + auto. + Qed. + Hint Rewrite decode_nil. + + Lemma decode_base_nil : forall us, decode' nil us = 0. + Proof. + intros; rewrite decode'_truncate; auto. + Qed. + Hint Rewrite decode_base_nil. + + Lemma mul_each_rep : forall bs u vs, + decode' bs (mul_each u vs) = u * decode' bs vs. + Proof. + unfold decode', accumulate; induction bs; destruct vs; boring; ring. + Qed. + + Lemma base_eq_1cons: base = 1 :: skipn 1 base. + Proof. + pose proof (b0_1 0) as H. + destruct base; compute in H; try discriminate; boring. + Qed. + + Lemma decode'_cons : forall x1 x2 xs1 xs2, + decode' (x1 :: xs1) (x2 :: xs2) = x1 * x2 + decode' xs1 xs2. + Proof. + unfold decode', accumulate; boring; ring. + Qed. + Hint Rewrite decode'_cons. + + Lemma decode_cons : forall x us, + decode base (x :: us) = x + decode base (0 :: us). + Proof. + unfold decode; intros. + rewrite base_eq_1cons. + autorewrite with core; ring_simplify; auto. + Qed. + + Lemma sub_rep : forall bs us vs, decode' bs (sub us vs) = decode' bs us - decode' bs vs. + Proof. + induction bs; destruct us; destruct vs; boring; ring. + Qed. + + Lemma encode_rep : forall z, decode base (encode z) = z. + Proof. + pose proof base_eq_1cons. + unfold decode, encode; destruct z; boring. + Qed. + + Lemma mul_each_base : forall us bs c, + decode' bs (mul_each c us) = decode' (mul_each c bs) us. + Proof. + induction us; destruct bs; boring; ring. + Qed. + + Hint Rewrite (@nth_default_nil Z). + Hint Rewrite (@firstn_nil Z). + Hint Rewrite (@skipn_nil Z). + + Lemma base_app : forall us low high, + decode' (low ++ high) us = decode' low (firstn (length low) us) + decode' high (skipn (length low) us). + Proof. + induction us; destruct low; boring. + Qed. + + Lemma base_mul_app : forall low c us, + decode' (low ++ mul_each c low) us = decode' low (firstn (length low) us) + + c * decode' low (skipn (length low) us). + Proof. + intros. + rewrite base_app; f_equal. + rewrite <- mul_each_rep. + rewrite mul_each_base. + reflexivity. + Qed. + + Lemma zeros_rep : forall bs n, decode' bs (zeros n) = 0. + induction bs; destruct n; boring. + Qed. + Lemma length_zeros : forall n, length (zeros n) = n. + induction n; boring. + Qed. + Hint Rewrite length_zeros. + + Lemma app_zeros_zeros : forall n m, zeros n ++ zeros m = zeros (n + m). + Proof. + induction n; boring. + Qed. + Hint Rewrite app_zeros_zeros. + + Lemma zeros_app0 : forall m, zeros m ++ 0 :: nil = zeros (S m). + Proof. + induction m; boring. + Qed. + Hint Rewrite zeros_app0. + + Lemma rev_zeros : forall n, rev (zeros n) = zeros n. + Proof. + induction n; boring. + Qed. + Hint Rewrite rev_zeros. + + Hint Unfold nth_default. + + Lemma decode_single : forall n bs x, + decode' bs (zeros n ++ x :: nil) = nth_default 0 bs n * x. + Proof. + induction n; destruct bs; boring. + Qed. + Hint Rewrite decode_single. + + Lemma peel_decode : forall xs ys x y, decode' (x::xs) (y::ys) = x*y + decode' xs ys. + Proof. + boring. + Qed. + Hint Rewrite zeros_rep peel_decode. + + Lemma decode_highzeros : forall xs bs n, decode' bs (xs ++ zeros n) = decode' bs xs. + Proof. + induction xs; destruct bs; boring. + Qed. + + Lemma mul_bi'_zeros : forall n m, mul_bi' base n (zeros m) = zeros m. + induction m; boring. + Qed. + Hint Rewrite mul_bi'_zeros. + + Lemma nth_error_base_nonzero : forall n x, + nth_error base n = Some x -> x <> 0. + Proof. + eauto using (@nth_error_value_In Z), Zgt0_neq0, base_positive. + Qed. + + Hint Rewrite plus_0_r. + + Lemma mul_bi_single : forall m n x, + (n + m < length base)%nat -> + decode base (mul_bi base n (zeros m ++ x :: nil)) = nth_default 0 base m * x * nth_default 0 base n. + Proof. + unfold mul_bi, decode. + destruct m; simpl; simpl_list; simpl; intros. { + pose proof nth_error_base_nonzero as nth_nonzero. + case_eq base; [intros; boring | intros z l base_eq]. + specialize (b0_1 0); intro b0_1'. + rewrite base_eq in *. + rewrite nth_default_cons in b0_1'. + rewrite b0_1' in *. + unfold crosscoef. + autounfold; autorewrite with core. + unfold nth_default. + nth_tac. + rewrite Z.mul_1_r. + rewrite Z_div_same_full. + destruct x; ring. + eapply nth_nonzero; eauto. + } { + ssimpl_list. + autorewrite with core. + rewrite app_assoc. + autorewrite with core. + unfold crosscoef; simpl; ring_simplify. + rewrite Nat.add_1_r. + rewrite base_good by auto. + rewrite Z_div_mult by (apply base_positive; rewrite nth_default_eq; apply nth_In; auto). + rewrite <- Z.mul_assoc. + rewrite <- Z.mul_comm. + rewrite <- Z.mul_assoc. + rewrite <- Z.mul_assoc. + destruct (Z.eq_dec x 0); subst; try ring. + rewrite Z.mul_cancel_l by auto. + rewrite <- base_good by auto. + ring. + } + Qed. + + Lemma set_higher' : forall vs x, vs++x::nil = vs .+ (zeros (length vs) ++ x :: nil). + induction vs; boring; f_equal; ring. + Qed. + + Lemma set_higher : forall bs vs x, + decode' bs (vs++x::nil) = decode' bs vs + nth_default 0 bs (length vs) * x. + Proof. + intros. + rewrite set_higher'. + rewrite add_rep. + f_equal. + apply decode_single. + Qed. + + Lemma zeros_plus_zeros : forall n, zeros n = zeros n .+ zeros n. + induction n; auto. + simpl; f_equal; auto. + Qed. + + Lemma mul_bi'_n_nil : forall n, mul_bi' base n nil = nil. + Proof. + unfold mul_bi; auto. + Qed. + Hint Rewrite mul_bi'_n_nil. + + Lemma add_nil_l : forall us, nil .+ us = us. + induction us; auto. + Qed. + Hint Rewrite add_nil_l. + + Lemma add_nil_r : forall us, us .+ nil = us. + induction us; auto. + Qed. + Hint Rewrite add_nil_r. + + Lemma add_first_terms : forall us vs a b, + (a :: us) .+ (b :: vs) = (a + b) :: (us .+ vs). + auto. + Qed. + Hint Rewrite add_first_terms. + + Lemma mul_bi'_cons : forall n x us, + mul_bi' base n (x :: us) = x * crosscoef base n (length us) :: mul_bi' base n us. + Proof. + unfold mul_bi'; auto. + Qed. + + Lemma add_same_length : forall us vs l, (length us = l) -> (length vs = l) -> + length (us .+ vs) = l. + Proof. + induction us, vs; boring. + erewrite (IHus vs (pred l)); boring. + Qed. + + Hint Rewrite app_nil_l. + Hint Rewrite app_nil_r. + + Lemma add_snoc_same_length : forall l us vs a b, + (length us = l) -> (length vs = l) -> + (us ++ a :: nil) .+ (vs ++ b :: nil) = (us .+ vs) ++ (a + b) :: nil. + Proof. + induction l, us, vs; boring; discriminate. + Qed. + + Lemma mul_bi'_add : forall us n vs l + (Hlus: length us = l) + (Hlvs: length vs = l), + mul_bi' base n (rev (us .+ vs)) = + mul_bi' base n (rev us) .+ mul_bi' base n (rev vs). + Proof. + (* TODO(adamc): please help prettify this *) + induction us using rev_ind; + try solve [destruct vs; boring; congruence]. + destruct vs using rev_ind; boring; clear IHvs; simpl_list. + erewrite (add_snoc_same_length (pred l) us vs _ _); simpl_list. + repeat rewrite mul_bi'_cons; rewrite add_first_terms; simpl_list. + rewrite (IHus n vs (pred l)). + replace (length us) with (pred l). + replace (length vs) with (pred l). + rewrite (add_same_length us vs (pred l)). + f_equal; ring. + + erewrite length_snoc; eauto. + erewrite length_snoc; eauto. + erewrite length_snoc; eauto. + erewrite length_snoc; eauto. + erewrite length_snoc; eauto. + erewrite length_snoc; eauto. + erewrite length_snoc; eauto. + erewrite length_snoc; eauto. + Qed. + + Lemma zeros_cons0 : forall n, 0 :: zeros n = zeros (S n). + auto. + Qed. + + Lemma add_leading_zeros : forall n us vs, + (zeros n ++ us) .+ (zeros n ++ vs) = zeros n ++ (us .+ vs). + Proof. + induction n; boring. + Qed. + + Lemma rev_add_rev : forall us vs l, (length us = l) -> (length vs = l) -> + (rev us) .+ (rev vs) = rev (us .+ vs). + Proof. + induction us, vs; boring; try solve [subst; discriminate]. + rewrite (add_snoc_same_length (pred l) _ _ _ _) by (subst; simpl_list; omega). + rewrite (IHus vs (pred l)) by omega; auto. + Qed. + Hint Rewrite rev_add_rev. + + Lemma mul_bi'_length : forall us n, length (mul_bi' base n us) = length us. + Proof. + induction us, n; boring. + Qed. + Hint Rewrite mul_bi'_length. + + Lemma add_comm : forall us vs, us .+ vs = vs .+ us. + Proof. + induction us, vs; boring; f_equal; auto. + Qed. + + Hint Rewrite rev_length. + + Lemma mul_bi_add_same_length : forall n us vs l, + (length us = l) -> (length vs = l) -> + mul_bi base n (us .+ vs) = mul_bi base n us .+ mul_bi base n vs. + Proof. + unfold mul_bi; boring. + rewrite add_leading_zeros. + erewrite mul_bi'_add; boring. + erewrite rev_add_rev; boring. + Qed. + + Lemma add_zeros_same_length : forall us, us .+ (zeros (length us)) = us. + Proof. + induction us; boring; f_equal; omega. + Qed. + + Hint Rewrite add_zeros_same_length. + Hint Rewrite minus_diag. + + Lemma add_trailing_zeros : forall us vs, (length us >= length vs)%nat -> + us .+ vs = us .+ (vs ++ (zeros (length us - length vs))). + Proof. + induction us, vs; boring; f_equal; boring. + Qed. + + Lemma length_add_ge : forall us vs, (length us >= length vs)%nat -> + (length (us .+ vs) <= length us)%nat. + Proof. + intros. + rewrite add_trailing_zeros by trivial. + erewrite add_same_length by (pose proof app_length; boring); omega. + Qed. + + Lemma add_length_le_max : forall us vs, + (length (us .+ vs) <= max (length us) (length vs))%nat. + Proof. + intros; case_max; (rewrite add_comm; apply length_add_ge; omega) || + (apply length_add_ge; omega) . + Qed. + + Lemma sub_nil_length: forall us : digits, length (sub nil us) = length us. + Proof. + induction us; boring. + Qed. + + Lemma sub_length_le_max : forall us vs, + (length (sub us vs) <= max (length us) (length vs))%nat. + Proof. + induction us, vs; boring. + rewrite sub_nil_length; auto. + Qed. + + Lemma mul_bi_length : forall us n, length (mul_bi base n us) = (length us + n)%nat. + Proof. + pose proof mul_bi'_length; unfold mul_bi. + destruct us; repeat progress (simpl_list; boring). + Qed. + Hint Rewrite mul_bi_length. + + Lemma mul_bi_trailing_zeros : forall m n us, + mul_bi base n us ++ zeros m = mul_bi base n (us ++ zeros m). + Proof. + unfold mul_bi. + induction m; intros; try solve [boring]. + rewrite <- zeros_app0. + rewrite app_assoc. + repeat progress (boring; rewrite rev_app_distr). + Qed. + + Lemma mul_bi_add_longer : forall n us vs, + (length us >= length vs)%nat -> + mul_bi base n (us .+ vs) = mul_bi base n us .+ mul_bi base n vs. + Proof. + boring. + rewrite add_trailing_zeros by auto. + rewrite (add_trailing_zeros (mul_bi base n us) (mul_bi base n vs)) + by (repeat (rewrite mul_bi_length); omega). + erewrite mul_bi_add_same_length by + (eauto; simpl_list; rewrite length_zeros; omega). + rewrite mul_bi_trailing_zeros. + repeat (f_equal; boring). + Qed. + + Lemma mul_bi_add : forall n us vs, + mul_bi base n (us .+ vs) = (mul_bi base n us) .+ (mul_bi base n vs). + Proof. + intros; pose proof mul_bi_add_longer. + destruct (le_ge_dec (length us) (length vs)). { + rewrite add_comm. + rewrite (add_comm (mul_bi base n us)). + boring. + } { + boring. + } + Qed. + + Lemma mul_bi_rep : forall i vs, + (i + length vs < length base)%nat -> + decode base (mul_bi base i vs) = decode base vs * nth_default 0 base i. + Proof. + unfold decode. + induction vs using rev_ind; intros; try solve [unfold mul_bi; boring]. + assert (i + length vs < length base)%nat by + (rewrite app_length in *; boring). + + rewrite set_higher. + ring_simplify. + rewrite <- IHvs by auto; clear IHvs. + rewrite <- mul_bi_single by auto. + rewrite <- add_rep. + rewrite <- mul_bi_add. + rewrite set_higher'. + auto. + Qed. + + (* mul' is multiplication with the FIRST ARGUMENT REVERSED *) + Fixpoint mul' (usr vs:digits) : digits := + match usr with + | u::usr' => + mul_each u (mul_bi base (length usr') vs) .+ mul' usr' vs + | _ => nil + end. + Definition mul us := mul' (rev us). + + Lemma mul'_rep : forall us vs, + (length us + length vs <= length base)%nat -> + decode base (mul' (rev us) vs) = decode base us * decode base vs. + Proof. + unfold decode. + induction us using rev_ind; boring. + + assert (length us + length vs < length base)%nat by + (rewrite app_length in *; boring). + + ssimpl_list. + rewrite add_rep. + boring. + rewrite set_higher. + rewrite mul_each_rep. + rewrite mul_bi_rep by auto. + unfold decode; ring. + Qed. + + Lemma mul_rep : forall us vs, + (length us + length vs <= length base)%nat -> + decode base (mul us vs) = decode base us * decode base vs. + Proof. + exact mul'_rep. + Qed. + + Lemma mul'_length: forall us vs, + (length (mul' us vs) <= length us + length vs)%nat. + Proof. + pose proof add_length_le_max. + induction us; boring. + unfold mul_each. + simpl_list; case_max; boring; omega. + Qed. + + Lemma mul_length: forall us vs, + (length (mul us vs) <= length us + length vs)%nat. + Proof. + intros; unfold mul. + rewrite mul'_length. + rewrite rev_length; omega. + Qed. + +End BaseSystemProofs. diff --git a/src/ModularArithmetic/ExtendedBaseVector.v b/src/ModularArithmetic/ExtendedBaseVector.v new file mode 100644 index 000000000..746791d8d --- /dev/null +++ b/src/ModularArithmetic/ExtendedBaseVector.v @@ -0,0 +1,163 @@ +Require Import Zpower ZArith. +Require Import List. +Require Import Crypto.Util.ListUtil Crypto.Util.CaseUtil Crypto.Util.ZUtil. +Require Import Crypto.ModularArithmetic.PrimeFieldTheorems. +Require Import VerdiTactics. +Require Import Crypto.ModularArithmetic.PseudoMersenneBaseParams Crypto.ModularArithmetic.PseudoMersenneBaseParamProofs. +Require Crypto.BaseSystem. +Local Open Scope Z_scope. + +Section ExtendedBaseVector. + Context `{prm : PseudoMersenneBaseParams}. + Existing Instance bv. + + (* This section defines a new BaseVector that has double the length of the BaseVector + * used to construct [params]. The coefficients of the new vector are as follows: + * + * ext_base[i] = if (i < length base) then base[i] else 2^k * base[i] + * + * The purpose of this construction is that it allows us to multiply numbers expressed + * using [base], obtaining a number expressed using [ext_base]. (Numbers are "expressed" as + * vectors of digits; the value of a digit vector is obtained by doing a dot product with + * the base vector.) So if x, y are digit vectors: + * + * (x \dot base) * (y \dot base) = (z \dot ext_base) + * + * Then we can separate z into its first and second halves: + * + * (z \dot ext_base) = (z1 \dot base) + (2 ^ k) * (z2 \dot base) + * + * Now, if we want to reduce the product modulo 2 ^ k - c: + * + * (z \dot ext_base) mod (2^k-c)= (z1 \dot base) + (2 ^ k) * (z2 \dot base) mod (2^k-c) + * (z \dot ext_base) mod (2^k-c)= (z1 \dot base) + c * (z2 \dot base) mod (2^k-c) + * + * This sum may be short enough to express using base; if not, we can reduce again. + *) + Definition ext_base := base ++ (map (Z.mul (2^k)) base). + + Lemma ext_base_positive : forall b, In b ext_base -> b > 0. + Proof. + unfold ext_base. intros b In_b_base. + rewrite in_app_iff in In_b_base. + destruct In_b_base as [In_b_base | In_b_extbase]. + + eapply BaseSystem.base_positive. + eapply In_b_base. + + eapply in_map_iff in In_b_extbase. + destruct In_b_extbase as [b' [b'_2k_b In_b'_base]]. + subst. + specialize (BaseSystem.base_positive b' In_b'_base); intro base_pos. + replace 0 with (2 ^ k * 0) by ring. + apply (Zmult_gt_compat_l b' 0 (2 ^ k)); [| apply base_pos; intuition]. + rewrite Z.gt_lt_iff. + apply Z.pow_pos_nonneg; intuition. + pose proof k_nonneg; omega. + Qed. + + Lemma base_length_nonzero : (0 < length base)%nat. + Proof. + assert (nth_default 0 base 0 = 1) by (apply BaseSystem.b0_1). + unfold nth_default in H. + case_eq (nth_error base 0); intros; + try (rewrite H0 in H; omega). + apply (nth_error_value_length _ 0 base z); auto. + Qed. + + Lemma b0_1 : forall x, nth_default x ext_base 0 = 1. + Proof. + intros. unfold ext_base. + rewrite nth_default_app. + assert (0 < length base)%nat by (apply base_length_nonzero). + destruct (lt_dec 0 (length base)); try apply BaseSystem.b0_1; try omega. + Qed. + + Lemma two_k_nonzero : 2^k <> 0. + Proof. + pose proof (Z.pow_eq_0 2 k k_nonneg). + intuition. + Qed. + + Lemma map_nth_default_base_high : forall n, (n < (length base))%nat -> + nth_default 0 (map (Z.mul (2 ^ k)) base) n = + (2 ^ k) * (nth_default 0 base n). + Proof. + intros. + erewrite map_nth_default; auto. + Qed. + + Lemma base_good_over_boundary : forall + (i : nat) + (l : (i < length base)%nat) + (j' : nat) + (Hj': (i + j' < length base)%nat) + , + 2 ^ k * (nth_default 0 base i * nth_default 0 base j') = + 2 ^ k * (nth_default 0 base i * nth_default 0 base j') / + (2 ^ k * nth_default 0 base (i + j')) * + (2 ^ k * nth_default 0 base (i + j')) + . + Proof. + intros. + remember (nth_default 0 base) as b. + rewrite Zdiv_mult_cancel_l by (exact two_k_nonzero). + replace (b i * b j' / b (i + j')%nat * (2 ^ k * b (i + j')%nat)) + with ((2 ^ k * (b (i + j')%nat * (b i * b j' / b (i + j')%nat)))) by ring. + rewrite Z.mul_cancel_l by (exact two_k_nonzero). + replace (b (i + j')%nat * (b i * b j' / b (i + j')%nat)) + with ((b i * b j' / b (i + j')%nat) * b (i + j')%nat) by ring. + subst b. + apply (BaseSystem.base_good i j'); omega. + Qed. + + Lemma ext_base_good : + forall i j, (i+j < length ext_base)%nat -> + let b := nth_default 0 ext_base in + let r := (b i * b j) / b (i+j)%nat in + b i * b j = r * b (i+j)%nat. + Proof. + intros. + subst b. subst r. + unfold ext_base in *. + rewrite app_length in H; rewrite map_length in H. + repeat rewrite nth_default_app. + destruct (lt_dec i (length base)); + destruct (lt_dec j (length base)); + destruct (lt_dec (i + j) (length base)); + try omega. + { (* i < length base, j < length base, i + j < length base *) + apply BaseSystem.base_good; auto. + } { (* i < length base, j < length base, i + j >= length base *) + rewrite (map_nth_default _ _ _ _ 0) by omega. + apply base_matches_modulus; omega. + } { (* i < length base, j >= length base, i + j >= length base *) + do 2 rewrite map_nth_default_base_high by omega. + remember (j - length base)%nat as j'. + replace (i + j - length base)%nat with (i + j')%nat by omega. + replace (nth_default 0 base i * (2 ^ k * nth_default 0 base j')) + with (2 ^ k * (nth_default 0 base i * nth_default 0 base j')) + by ring. + eapply base_good_over_boundary; eauto; omega. + } { (* i >= length base, j < length base, i + j >= length base *) + do 2 rewrite map_nth_default_base_high by omega. + remember (i - length base)%nat as i'. + replace (i + j - length base)%nat with (j + i')%nat by omega. + replace (2 ^ k * nth_default 0 base i' * nth_default 0 base j) + with (2 ^ k * (nth_default 0 base j * nth_default 0 base i')) + by ring. + eapply base_good_over_boundary; eauto; omega. + } + Qed. + + Instance ExtBaseVector : BaseSystem.BaseVector ext_base := { + base_positive := ext_base_positive; + b0_1 := b0_1; + base_good := ext_base_good + }. + + Lemma extended_base_length: + length ext_base = (length base + length base)%nat. + Proof. + unfold ext_base; rewrite app_length; rewrite map_length; auto. + Qed. +End ExtendedBaseVector. + diff --git a/src/ModularArithmetic/ModularBaseSystem.v b/src/ModularArithmetic/ModularBaseSystem.v index b0e493871..d121e9d5c 100644 --- a/src/ModularArithmetic/ModularBaseSystem.v +++ b/src/ModularArithmetic/ModularBaseSystem.v @@ -1,643 +1,58 @@ Require Import Zpower ZArith. Require Import List. -Require Import Crypto.Util.ListUtil Crypto.Util.CaseUtil Crypto.Util.ZUtil. +Require Import Crypto.Util.ListUtil. Require Import Crypto.ModularArithmetic.PrimeFieldTheorems. -Require Import VerdiTactics. -Require Crypto.BaseSystem. +Require Import Crypto.BaseSystem Crypto.ModularArithmetic.PseudoMersenneBaseParams Crypto.ModularArithmetic.PseudoMersenneBaseParamProofs Crypto.ModularArithmetic.ExtendedBaseVector. Local Open Scope Z_scope. -Class PseudoMersenneBaseParams (modulus : Z) (base : list Z) (bv : BaseSystem.BaseVector base) := { - k : Z; - c : Z; - modulus_pseudomersenne : modulus = 2^k - c; - prime_modulus : Znumtheory.prime modulus; - base_matches_modulus : - forall i j, - (i < length base)%nat -> - (j < length base)%nat -> - (i+j >= length base)%nat-> - let b := nth_default 0 base in - let r := (b i * b j) / (2^k * b (i+j-length base)%nat) in - b i * b j = r * (2^k * b (i+j-length base)%nat); - base_succ : forall i, ((S i) < length base)%nat -> - let b := nth_default 0 base in - b (S i) mod b i = 0; - base_tail_matches_modulus: - 2^k mod nth_default 0 base (pred (length base)) = 0; - k_nonneg : 0 <= k (* Probably implied by modulus_pseudomersenne. *) -}. -(* -Class RepZMod (modulus : Z) := { - T : Type; - encode : F modulus -> T; - decode : T -> F modulus; - - rep : T -> F modulus -> Prop; - encode_rep : forall x, rep (encode x) x; - rep_decode : forall u x, rep u x -> decode u = x; - - add : T -> T -> T; - add_rep : forall u v x y, rep u x -> rep v y -> rep (add u v) (x+y)%F; - - sub : T -> T -> T; - sub_rep : forall u v x y, rep u x -> rep v y -> rep (sub u v) (x-y)%F; - - mul : T -> T -> T; - mul_rep : forall u v x y, rep u x -> rep v y -> rep (mul u v) (x*y)%F -}. -*) -Print PseudoMersenneBaseParams. -Section ExtendedBaseVector. - Context (base : list Z) {modulus : Z} `(params : PseudoMersenneBaseParams modulus base). - (* This section defines a new BaseVector that has double the length of the BaseVector - * used to construct [params]. The coefficients of the new vector are as follows: - * - * ext_base[i] = if (i < length base) then base[i] else 2^k * base[i] - * - * The purpose of this construction is that it allows us to multiply numbers expressed - * using [base], obtaining a number expressed using [ext_base]. (Numbers are "expressed" as - * vectors of digits; the value of a digit vector is obtained by doing a dot product with - * the base vector.) So if x, y are digit vectors: - * - * (x \dot base) * (y \dot base) = (z \dot ext_base) - * - * Then we can separate z into its first and second halves: - * - * (z \dot ext_base) = (z1 \dot base) + (2 ^ k) * (z2 \dot base) - * - * Now, if we want to reduce the product modulo 2 ^ k - c: - * - * (z \dot ext_base) mod (2^k-c)= (z1 \dot base) + (2 ^ k) * (z2 \dot base) mod (2^k-c) - * (z \dot ext_base) mod (2^k-c)= (z1 \dot base) + c * (z2 \dot base) mod (2^k-c) - * - * This sum may be short enough to express using base; if not, we can reduce again. - *) - Definition ext_base := base ++ (map (Z.mul (2^k)) base). - - Lemma ext_base_positive : forall b, In b ext_base -> b > 0. - Proof. - unfold ext_base. intros b In_b_base. - rewrite in_app_iff in In_b_base. - destruct In_b_base as [? | In_b_extbase]; auto using BaseSystem.base_positive. - apply in_map_iff in In_b_extbase. - destruct In_b_extbase as [b' [b'_2k_b In_b'_base]]. - subst. - specialize (BaseSystem.base_positive b' In_b'_base); intro base_pos. - replace 0 with (2 ^ k * 0) by ring. - apply (Zmult_gt_compat_l b' 0 (2 ^ k)); [| apply base_pos; intuition]. - rewrite Z.gt_lt_iff. - apply Z.pow_pos_nonneg; intuition. - pose proof k_nonneg; omega. - Qed. - - Lemma base_length_nonzero : (0 < length base)%nat. - Proof. - assert (nth_default 0 base 0 = 1) by (apply BaseSystem.b0_1). - unfold nth_default in H. - case_eq (nth_error base 0); intros; - try (rewrite H0 in H; omega). - apply (nth_error_value_length _ 0 base z); auto. - Qed. - - Lemma b0_1 : forall x, nth_default x ext_base 0 = 1. - Proof. - intros. unfold ext_base. - rewrite nth_default_app. - assert (0 < length base)%nat by (apply base_length_nonzero). - destruct (lt_dec 0 (length base)); try apply BaseSystem.b0_1; try omega. - Qed. - - Lemma two_k_nonzero : 2^k <> 0. - Proof. - pose proof (Z.pow_eq_0 2 k k_nonneg). - intuition. - Qed. - - Lemma map_nth_default_base_high : forall n, (n < (length base))%nat -> - nth_default 0 (map (Z.mul (2 ^ k)) base) n = - (2 ^ k) * (nth_default 0 base n). - Proof. - intros. - erewrite map_nth_default; auto. - Qed. - - Lemma ext_base_succ : forall i, ((S i) < length ext_base)%nat -> - let b := nth_default 0 ext_base in - b (S i) mod b i = 0. - Proof. - intros; subst b; unfold ext_base. - repeat rewrite nth_default_app. - do 2 break_if; [apply base_succ; auto | omega | | ]. { - destruct (lt_eq_lt_dec (S i) (length base)); try omega. - destruct s; intuition. - rewrite map_nth_default_base_high by omega. - replace i with (pred(length base)) by omega. - rewrite <- Zmult_mod_idemp_l. - rewrite base_tail_matches_modulus. - rewrite Zmod_0_l; auto. - } { - unfold ext_base in *; rewrite app_length, map_length in *. - repeat rewrite map_nth_default_base_high by omega. - rewrite Zmult_mod_distr_l. - rewrite <- minus_Sn_m by omega. - rewrite base_succ by omega; ring. - } - Qed. - - Lemma base_good_over_boundary : forall - (i : nat) - (l : (i < length base)%nat) - (j' : nat) - (Hj': (i + j' < length base)%nat) - , - 2 ^ k * (nth_default 0 base i * nth_default 0 base j') = - 2 ^ k * (nth_default 0 base i * nth_default 0 base j') / - (2 ^ k * nth_default 0 base (i + j')) * - (2 ^ k * nth_default 0 base (i + j')) - . - Proof. - intros. - remember (nth_default 0 base) as b. - rewrite Zdiv_mult_cancel_l by (exact two_k_nonzero). - replace (b i * b j' / b (i + j')%nat * (2 ^ k * b (i + j')%nat)) - with ((2 ^ k * (b (i + j')%nat * (b i * b j' / b (i + j')%nat)))) by ring. - rewrite Z.mul_cancel_l by (exact two_k_nonzero). - replace (b (i + j')%nat * (b i * b j' / b (i + j')%nat)) - with ((b i * b j' / b (i + j')%nat) * b (i + j')%nat) by ring. - subst b. - apply (BaseSystem.base_good i j'); omega. - Qed. - - Lemma ext_base_good : - forall i j, (i+j < length ext_base)%nat -> - let b := nth_default 0 ext_base in - let r := (b i * b j) / b (i+j)%nat in - b i * b j = r * b (i+j)%nat. - Proof. - intros. - subst b. subst r. - unfold ext_base in *. - rewrite app_length in H; rewrite map_length in H. - repeat rewrite nth_default_app. - destruct (lt_dec i (length base)); - destruct (lt_dec j (length base)); - destruct (lt_dec (i + j) (length base)); - try omega. - { (* i < length base, j < length base, i + j < length base *) - apply BaseSystem.base_good; auto. - } { (* i < length base, j < length base, i + j >= length base *) - rewrite (map_nth_default _ _ _ _ 0) by omega. - apply base_matches_modulus; omega. - } { (* i < length base, j >= length base, i + j >= length base *) - do 2 rewrite map_nth_default_base_high by omega. - remember (j - length base)%nat as j'. - replace (i + j - length base)%nat with (i + j')%nat by omega. - replace (nth_default 0 base i * (2 ^ k * nth_default 0 base j')) - with (2 ^ k * (nth_default 0 base i * nth_default 0 base j')) - by ring. - eapply base_good_over_boundary; eauto; omega. - } { (* i >= length base, j < length base, i + j >= length base *) - do 2 rewrite map_nth_default_base_high by omega. - remember (i - length base)%nat as i'. - replace (i + j - length base)%nat with (j + i')%nat by omega. - replace (2 ^ k * nth_default 0 base i' * nth_default 0 base j) - with (2 ^ k * (nth_default 0 base j * nth_default 0 base i')) - by ring. - eapply base_good_over_boundary; eauto; omega. - } - Qed. - Instance ExtBaseVector : BaseSystem.BaseVector ext_base := { - base_positive := ext_base_positive; - b0_1 := b0_1; - base_good := ext_base_good - }. -End ExtendedBaseVector. - -Print ExtBaseVector. Section PseudoMersenneBase. - Context `(prm :PseudoMersenneBaseParams). - - Definition T := BaseSystem.digits. - Definition decode (us : T) : F modulus := ZToField (BaseSystem.decode base us). - Local Hint Unfold decode. - Definition rep (us : T) (x : F modulus) := (length us <= length base)%nat /\ decode us = x. + Context `{prm :PseudoMersenneBaseParams}. + Existing Instance bv. + + Definition decode (us : digits) : F modulus := ZToField (BaseSystem.decode base us). + + Definition rep (us : digits) (x : F modulus) := (length us <= length base)%nat /\ decode us = x. Local Notation "u '~=' x" := (rep u x) (at level 70). Local Hint Unfold rep. - Lemma rep_decode : forall us x, us ~= x -> decode us = x. - Proof. - autounfold; intuition. - Qed. - - Definition encode (x : F modulus) := BaseSystem.encode x. - - Lemma encode_rep : forall x : F modulus, encode x ~= x. - Proof. - intros. unfold encode, rep. - split. { - unfold encode; simpl. - apply base_length_nonzero. - assumption. - } { - unfold decode. - rewrite BaseSystem.encode_rep. - apply ZToField_FieldToZ. - assumption. - } - Qed. - - Lemma add_rep : forall u v x y, u ~= x -> v ~= y -> BaseSystem.add u v ~= (x+y)%F. - Proof. - autounfold; intuition. { - unfold add. - rewrite BaseSystem.add_length_le_max. - case_max; try rewrite Max.max_r; omega. - } - unfold decode in *; unfold BaseSystem.decode in *. - rewrite BaseSystem.add_rep. - rewrite ZToField_add. - subst; auto. - Qed. - - Lemma sub_rep : forall u v x y, u ~= x -> v ~= y -> BaseSystem.sub u v ~= (x-y)%F. - Proof. - autounfold; intuition. { - rewrite BaseSystem.sub_length_le_max. - case_max; try rewrite Max.max_r; omega. - } - unfold decode in *; unfold BaseSystem.decode in *. - rewrite BaseSystem.sub_rep. - rewrite ZToField_sub. - subst; auto. - Qed. - - Lemma decode_short : forall (us : T), - (length us <= length base)%nat -> - BaseSystem.decode base us = BaseSystem.decode (ext_base base prm) us. - Proof. - intros. - unfold BaseSystem.decode, BaseSystem.decode'. - rewrite combine_truncate_r. - rewrite (combine_truncate_r us (ext_base base prm)). - f_equal; f_equal. - unfold ext_base. - rewrite firstn_app_inleft; auto; omega. - Qed. - - Lemma extended_base_length: - length (ext_base base prm) = (length base + length base)%nat. - Proof. - unfold ext_base; rewrite app_length; rewrite map_length; auto. - Qed. - - Lemma mul_rep_extended : forall (us vs : T), - (length us <= length base)%nat -> - (length vs <= length base)%nat -> - (BaseSystem.decode base us) * (BaseSystem.decode base vs) = BaseSystem.decode (ext_base base prm) (BaseSystem.mul (ext_base base prm) us vs). - Proof. - intros. - rewrite BaseSystem.mul_rep by (apply ExtBaseVector || unfold ext_base; simpl_list; omega). - f_equal; rewrite decode_short; auto. - Qed. + Definition encode (x : F modulus) := encode x. (* Converts from length of extended base to length of base by reduction modulo M.*) - Definition reduce (us : T) : T := + Definition reduce (us : digits) : digits := let high := skipn (length base) us in let low := firstn (length base) us in let wrap := map (Z.mul c) high in BaseSystem.add low wrap. - Lemma modulus_nonzero : modulus <> 0. - pose proof (Znumtheory.prime_ge_2 _ prime_modulus); omega. - Qed. - - (* a = r + s(2^k) = r + s(2^k - c + c) = r + s(2^k - c) + cs = r + cs *) - Lemma pseudomersenne_add: forall x y, (x + ((2^k) * y)) mod modulus = (x + (c * y)) mod modulus. - Proof. - intros. - replace (2^k) with ((2^k - c) + c) by ring. - rewrite Z.mul_add_distr_r. - rewrite Zplus_mod. - rewrite <- modulus_pseudomersenne. - rewrite Z.mul_comm. - rewrite mod_mult_plus; auto using modulus_nonzero. - rewrite <- Zplus_mod; auto. - Qed. + Definition mul (us vs : digits) := reduce (BaseSystem.mul ext_base us vs). - Lemma extended_shiftadd: forall (us : T), - BaseSystem.decode (ext_base base prm) us = - BaseSystem.decode base (firstn (length base) us) - + (2^k * BaseSystem.decode base (skipn (length base) us)). - Proof. - intros. - unfold BaseSystem.decode; rewrite <- BaseSystem.mul_each_rep. - unfold ext_base. - replace (map (Z.mul (2 ^ k)) base) with (BaseSystem.mul_each (2 ^ k) base) by auto. - rewrite BaseSystem.base_mul_app. - rewrite <- BaseSystem.mul_each_rep; auto. - Qed. - - Lemma reduce_rep : forall us, - BaseSystem.decode base (reduce us) mod modulus = - BaseSystem.decode (ext_base base prm) us mod modulus. - Proof. - intros. - rewrite extended_shiftadd. - rewrite pseudomersenne_add. - unfold reduce. - remember (firstn (length base) us) as low. - remember (skipn (length base) us) as high. - unfold BaseSystem.decode. - rewrite BaseSystem.add_rep. - replace (map (Z.mul c) high) with (BaseSystem.mul_each c high) by auto. - rewrite BaseSystem.mul_each_rep; auto. - Qed. +End PseudoMersenneBase. - Lemma reduce_length : forall us, - (length us <= length (ext_base base prm))%nat -> - (length (reduce us) <= length (base))%nat. - Proof. - intros. - unfold reduce. - remember (map (Z.mul c) (skipn (length base) us)) as high. - remember (firstn (length base) us) as low. - assert (length low >= length high)%nat. { - subst. rewrite firstn_length. - rewrite map_length. - rewrite skipn_length. - destruct (le_dec (length base) (length us)). { - rewrite Min.min_l by omega. - rewrite extended_base_length in H. omega. - } { - rewrite Min.min_r; omega. - } - } - assert ((length low <= length base)%nat) - by (rewrite Heqlow; rewrite firstn_length; apply Min.le_min_l). - assert (length high <= length base)%nat - by (rewrite Heqhigh; rewrite map_length; rewrite skipn_length; - rewrite extended_base_length in H; omega). - rewrite BaseSystem.add_trailing_zeros; auto. - rewrite (BaseSystem.add_same_length _ _ (length low)); auto. - rewrite app_length. - rewrite BaseSystem.length_zeros; intuition. - Qed. +Section CarryBasePow2. + Context `{prm :PseudoMersenneBaseParams}. - Definition mul (us vs : T) := reduce (BaseSystem.mul (ext_base base prm) us vs). - Lemma mul_rep : forall u v x y, u ~= x -> v ~= y -> mul u v ~= (x*y)%F. - Proof. - autounfold; unfold mul; intuition. - { - apply reduce_length. - rewrite BaseSystem.mul_length, extended_base_length. - omega. - } { - rewrite ZToField_mod, reduce_rep, <-ZToField_mod. - rewrite BaseSystem.mul_rep by - (apply ExtBaseVector || rewrite extended_base_length; omega). - subst. - do 2 rewrite decode_short by auto. - apply ZToField_mul. - } - Qed. + Definition log_cap i := nth_default 0 limb_widths i. Definition add_to_nth n (x:Z) xs := set_nth n (x + nth_default 0 xs n) xs. - Hint Unfold add_to_nth. - - (* i must be in the domain of base *) - Definition cap i := - if eq_nat_dec i (pred (length base)) - then (2^k) / nth_default 0 base i - else nth_default 0 base (S i) / nth_default 0 base i. + + Definition pow2_mod n i := Z.land n (Z.ones i). Definition carry_simple i := fun us => let di := nth_default 0 us i in - let us' := set_nth i (di mod cap i) us in - add_to_nth (S i) ( (di / cap i)) us'. + let us' := set_nth i (pow2_mod di (log_cap i)) us in + add_to_nth (S i) ( (Z.shiftr di (log_cap i))) us'. Definition carry_and_reduce i := fun us => let di := nth_default 0 us i in - let us' := set_nth i (di mod cap i) us in - add_to_nth 0 (c * (di / cap i)) us'. + let us' := set_nth i (pow2_mod di (log_cap i)) us in + add_to_nth 0 (c * (Z.shiftr di (log_cap i))) us'. - Definition carry i : T -> T := + Definition carry i : digits -> digits := if eq_nat_dec i (pred (length base)) then carry_and_reduce i else carry_simple i. - (* TODO: move to BaseSystemProofs *) - Lemma decode'_splice : forall xs ys bs, - BaseSystem.decode' bs (xs ++ ys) = - BaseSystem.decode' (firstn (length xs) bs) xs + - BaseSystem.decode' (skipn (length xs) bs) ys. - Proof. - unfold BaseSystem.decode'. - induction xs; destruct ys, bs; boring. - + rewrite combine_truncate_r. - do 2 rewrite Z.add_0_r; auto. - + unfold BaseSystem.accumulate. - apply Z.add_assoc. - Qed. - - Lemma set_nth_sum : forall n x us, (n < length us)%nat -> - BaseSystem.decode base (set_nth n x us) = - (x - nth_default 0 us n) * nth_default 0 base n + BaseSystem.decode base us. - Proof. - intros. - unfold BaseSystem.decode. - nth_inbounds; auto. (* TODO(andreser): nth_inbounds should do this auto*) - unfold splice_nth. - rewrite <- (firstn_skipn n us) at 4. - do 2 rewrite decode'_splice. - remember (length (firstn n us)) as n0. - ring_simplify. - remember (BaseSystem.decode' (firstn n0 base) (firstn n us)). - rewrite (skipn_nth_default n us 0) by omega. - rewrite firstn_length in Heqn0. - rewrite Min.min_l in Heqn0 by omega; subst n0. - destruct (le_lt_dec (length base) n). { - rewrite nth_default_out_of_bounds by auto. - rewrite skipn_all by omega. - do 2 rewrite BaseSystem.decode_base_nil. - ring_simplify; auto. - } { - rewrite (skipn_nth_default n base 0) by omega. - do 2 rewrite BaseSystem.decode'_cons. - ring_simplify; ring. - } - Qed. - - Lemma add_to_nth_sum : forall n x us, (n < length us)%nat -> - BaseSystem.decode base (add_to_nth n x us) = - x * nth_default 0 base n + BaseSystem.decode base us. - Proof. - unfold add_to_nth; intros; rewrite set_nth_sum; try ring_simplify; auto. - Qed. - - Lemma nth_default_base_positive : forall i, (i < length base)%nat -> - nth_default 0 base i > 0. - Proof. - intros. - pose proof (nth_error_length_exists_value _ _ H). - destruct H0. - pose proof (nth_error_value_In _ _ _ H0). - pose proof (BaseSystem.base_positive _ H1). - unfold nth_default. - rewrite H0; auto. - Qed. - - Lemma base_succ_div_mult : forall i, ((S i) < length base)%nat -> - nth_default 0 base (S i) = nth_default 0 base i * - (nth_default 0 base (S i) / nth_default 0 base i). - Proof. - intros. - apply Z_div_exact_2; try (apply nth_default_base_positive; omega). - apply base_succ; auto. - Qed. - - Lemma base_length_lt_pred : (pred (length base) < length base)%nat. - Proof. - pose proof (base_length_nonzero base); omega. - Qed. - Hint Resolve base_length_lt_pred. - - Lemma cap_positive: forall i, (i < length base)%nat -> cap i > 0. - Proof. - unfold cap; intros; break_if. { - apply div_positive_gt_0; try (subst; apply base_tail_matches_modulus). { - rewrite <- two_p_equiv. - apply two_p_gt_ZERO. - apply k_nonneg. - } { - apply nth_default_base_positive; subst; auto. - } - } { - apply div_positive_gt_0; try (apply base_succ; omega); - try (apply nth_default_base_positive; omega). - } - Qed. - - Lemma cap_div_mod : forall us i, (i < (pred (length base)))%nat -> - let di := nth_default 0 us i in - (di - (di mod cap i)) * nth_default 0 base i = - (di / cap i) * nth_default 0 base (S i). - Proof. - intros. - rewrite (Z_div_mod_eq di (cap i)) at 1 by (apply cap_positive; omega); - ring_simplify. - unfold cap; break_if; intuition. - rewrite base_succ_div_mult at 4 by omega; ring. - Qed. - - Lemma carry_simple_decode_eq : forall i us, - (length us = length base) -> - (i < (pred (length base)))%nat -> - BaseSystem.decode base (carry_simple i us) = BaseSystem.decode base us. - Proof. - unfold carry_simple. intros. - rewrite add_to_nth_sum by (rewrite length_set_nth; omega). - rewrite set_nth_sum by omega. - rewrite <- cap_div_mod by auto; ring_simplify; auto. - Qed. - - Lemma two_k_div_mul_last : - 2 ^ k = nth_default 0 base (pred (length base)) * - (2 ^ k / nth_default 0 base (pred (length base))). - Proof. - intros. - pose proof base_tail_matches_modulus. - rewrite (Z_div_mod_eq (2 ^ k) (nth_default 0 base (pred (length base)))) at 1 by - (apply nth_default_base_positive; auto); omega. - Qed. - - Lemma cap_div_mod_reduce : forall us, - let i := pred (length base) in - let di := nth_default 0 us i in - (di - (di mod cap i)) * nth_default 0 base i = - (di / cap i) * 2 ^ k. - Proof. - intros. - rewrite (Z_div_mod_eq di (cap i)) at 1 by - (apply cap_positive; auto); ring_simplify. - unfold cap; break_if; intuition. - rewrite Z.mul_comm, Z.mul_assoc. - subst i; rewrite <- two_k_div_mul_last; ring. - Qed. - - Lemma carry_decode_eq_reduce : forall us, - (length us = length base) -> - BaseSystem.decode base (carry_and_reduce (pred (length base)) us) mod modulus - = BaseSystem.decode base us mod modulus. - Proof. - unfold carry_and_reduce; intros. - pose proof (base_length_nonzero base). - rewrite add_to_nth_sum by (rewrite length_set_nth; omega). - rewrite set_nth_sum by omega. - rewrite Zplus_comm, <- Z.mul_assoc. - rewrite <- pseudomersenne_add. - rewrite BaseSystem.b0_1. - rewrite (Z.mul_comm (2 ^ k)). - rewrite <- Zred_factor0. - rewrite <- cap_div_mod_reduce by auto. - do 2 rewrite Zmult_minus_distr_r. - f_equal. - ring. - Qed. - - Lemma carry_length : forall i us, - (length us <= length base)%nat -> - (length (carry i us) <= length base)%nat. - Proof. - unfold carry, carry_simple, carry_and_reduce, add_to_nth. - intros; break_if; subst; repeat (rewrite length_set_nth); auto. - Qed. - Hint Resolve carry_length. - - Lemma carry_rep : forall i us x, - (length us = length base) -> - (i < length base)%nat -> - us ~= x -> carry i us ~= x. - Proof. - pose carry_length. pose carry_decode_eq_reduce. pose carry_simple_decode_eq. - unfold rep, decode, carry in *; intros. - intuition; break_if; subst; eauto; - apply F_eq; simpl; intuition. - Qed. - Hint Resolve carry_rep. - Definition carry_sequence is us := fold_right carry us is. - Lemma carry_sequence_length: forall is us, - (length us <= length base)%nat -> - (length (carry_sequence is us) <= length base)%nat. - Proof. - induction is; boring. - Qed. - Hint Resolve carry_sequence_length. - - Lemma carry_length_exact : forall i us, - (length us = length base)%nat -> - (length (carry i us) = length base)%nat. - Proof. - unfold carry, carry_simple, carry_and_reduce, add_to_nth. - intros; break_if; subst; repeat (rewrite length_set_nth); auto. - Qed. - - Lemma carry_sequence_length_exact: forall is us, - (length us = length base)%nat -> - (length (carry_sequence is us) = length base)%nat. - Proof. - induction is; boring. - apply carry_length_exact; auto. - Qed. - Hint Resolve carry_sequence_length_exact. - - Lemma carry_sequence_rep : forall is us x, - (forall i, In i is -> (i < length base)%nat) -> - (length us = length base) -> - us ~= x -> carry_sequence is us ~= x. - Proof. - induction is; boring. - Qed. -End PseudoMersenneBase. +End CarryBasePow2. diff --git a/src/ModularArithmetic/ModularBaseSystemProofs.v b/src/ModularArithmetic/ModularBaseSystemProofs.v new file mode 100644 index 000000000..524c2da27 --- /dev/null +++ b/src/ModularArithmetic/ModularBaseSystemProofs.v @@ -0,0 +1,449 @@ +Require Import Zpower ZArith. +Require Import List. +Require Import Crypto.Util.ListUtil Crypto.Util.CaseUtil Crypto.Util.ZUtil. +Require Import VerdiTactics. +Require Crypto.BaseSystem. +Require Import Crypto.ModularArithmetic.ModularBaseSystem Crypto.ModularArithmetic.PrimeFieldTheorems. +Require Import Crypto.BaseSystemProofs Crypto.ModularArithmetic.PseudoMersenneBaseParams Crypto.ModularArithmetic.PseudoMersenneBaseParamProofs Crypto.ModularArithmetic.ExtendedBaseVector. +Local Open Scope Z_scope. + +Section PseudoMersenneProofs. + Context `{prm :PseudoMersenneBaseParams}. + Existing Instance bv. + + Local Hint Unfold decode. + Local Notation "u '~=' x" := (rep u x) (at level 70). + Local Notation "u '.+' x" := (add u x) (at level 70). + Local Notation "u '.*' x" := (ModularBaseSystem.mul u x) (at level 70). + Local Hint Unfold rep. + + Lemma rep_decode : forall us x, us ~= x -> decode us = x. + Proof. + autounfold; intuition. + Qed. + + Lemma encode_rep : forall x : F modulus, encode x ~= x. + Proof. + intros. unfold encode, rep. + split. { + unfold encode; simpl. + apply base_length_nonzero. + } { + unfold decode. + rewrite encode_rep. + apply ZToField_FieldToZ. + apply bv. + } + Qed. + + Lemma add_rep : forall u v x y, u ~= x -> v ~= y -> BaseSystem.add u v ~= (x+y)%F. + Proof. + autounfold; intuition. { + unfold add. + rewrite add_length_le_max. + case_max; try rewrite Max.max_r; omega. + } + unfold decode in *; unfold decode in *. + rewrite add_rep. + rewrite ZToField_add. + subst; auto. + Qed. + + Lemma sub_rep : forall u v x y, u ~= x -> v ~= y -> BaseSystem.sub u v ~= (x-y)%F. + Proof. + autounfold; intuition. { + rewrite sub_length_le_max. + case_max; try rewrite Max.max_r; omega. + } + unfold decode in *; unfold BaseSystem.decode in *. + rewrite sub_rep. + rewrite ZToField_sub. + subst; auto. + Qed. + + Lemma decode_short : forall (us : BaseSystem.digits), + (length us <= length base)%nat -> + BaseSystem.decode base us = BaseSystem.decode ext_base us. + Proof. + intros. + unfold BaseSystem.decode, BaseSystem.decode'. + rewrite combine_truncate_r. + rewrite (combine_truncate_r us ext_base). + f_equal; f_equal. + unfold ext_base. + rewrite firstn_app_inleft; auto; omega. + Qed. + + Lemma mul_rep_extended : forall (us vs : BaseSystem.digits), + (length us <= length base)%nat -> + (length vs <= length base)%nat -> + (BaseSystem.decode base us) * (BaseSystem.decode base vs) = BaseSystem.decode ext_base (BaseSystem.mul ext_base us vs). + Proof. + intros. + rewrite mul_rep by (apply ExtBaseVector || unfold ext_base; simpl_list; omega). + f_equal; rewrite decode_short; auto. + Qed. + + Lemma modulus_nonzero : modulus <> 0. + pose proof (Znumtheory.prime_ge_2 _ prime_modulus); omega. + Qed. + + (* a = r + s(2^k) = r + s(2^k - c + c) = r + s(2^k - c) + cs = r + cs *) + Lemma pseudomersenne_add: forall x y, (x + ((2^k) * y)) mod modulus = (x + (c * y)) mod modulus. + Proof. + intros. + replace (2^k) with ((2^k - c) + c) by ring. + rewrite Z.mul_add_distr_r. + rewrite Zplus_mod. + rewrite <- modulus_pseudomersenne. + rewrite Z.mul_comm. + rewrite mod_mult_plus; auto using modulus_nonzero. + rewrite <- Zplus_mod; auto. + Qed. + + Lemma extended_shiftadd: forall (us : BaseSystem.digits), + BaseSystem.decode ext_base us = + BaseSystem.decode base (firstn (length base) us) + + (2^k * BaseSystem.decode base (skipn (length base) us)). + Proof. + intros. + unfold BaseSystem.decode; rewrite <- mul_each_rep. + unfold ext_base. + replace (map (Z.mul (2 ^ k)) base) with (BaseSystem.mul_each (2 ^ k) base) by auto. + rewrite base_mul_app. + rewrite <- mul_each_rep; auto. + Qed. + + Lemma reduce_rep : forall us, + BaseSystem.decode base (reduce us) mod modulus = + BaseSystem.decode ext_base us mod modulus. + Proof. + intros. + rewrite extended_shiftadd. + rewrite pseudomersenne_add. + unfold reduce. + remember (firstn (length base) us) as low. + remember (skipn (length base) us) as high. + unfold BaseSystem.decode. + rewrite BaseSystemProofs.add_rep. + replace (map (Z.mul c) high) with (BaseSystem.mul_each c high) by auto. + rewrite mul_each_rep; auto. + Qed. + + Lemma reduce_length : forall us, + (length us <= length ext_base)%nat -> + (length (reduce us) <= length base)%nat. + Proof. + intros. + unfold reduce. + remember (map (Z.mul c) (skipn (length base) us)) as high. + remember (firstn (length base) us) as low. + assert (length low >= length high)%nat. { + subst. rewrite firstn_length. + rewrite map_length. + rewrite skipn_length. + destruct (le_dec (length base) (length us)). { + rewrite Min.min_l by omega. + rewrite extended_base_length in H. omega. + } { + rewrite Min.min_r; omega. + } + } + assert ((length low <= length base)%nat) + by (rewrite Heqlow; rewrite firstn_length; apply Min.le_min_l). + assert (length high <= length base)%nat + by (rewrite Heqhigh; rewrite map_length; rewrite skipn_length; + rewrite extended_base_length in H; omega). + rewrite add_trailing_zeros; auto. + rewrite (add_same_length _ _ (length low)); auto. + rewrite app_length. + rewrite length_zeros; intuition. + Qed. + + Lemma mul_rep : forall u v x y, u ~= x -> v ~= y -> u .* v ~= (x*y)%F. + Proof. + autounfold; unfold ModularBaseSystem.mul; intuition. + { + apply reduce_length. + rewrite mul_length, extended_base_length. + omega. + } { + rewrite ZToField_mod, reduce_rep, <-ZToField_mod. + rewrite mul_rep by + (apply ExtBaseVector || rewrite extended_base_length; omega). + subst. + do 2 rewrite decode_short by auto. + apply ZToField_mul. + } + Qed. + + (* TODO: move to BaseSystemProofs *) + Lemma decode'_splice : forall xs ys bs, + BaseSystem.decode' bs (xs ++ ys) = + BaseSystem.decode' (firstn (length xs) bs) xs + + BaseSystem.decode' (skipn (length xs) bs) ys. + Proof. + unfold BaseSystem.decode'. + induction xs; destruct ys, bs; boring. + + rewrite combine_truncate_r. + do 2 rewrite Z.add_0_r; auto. + + unfold BaseSystem.accumulate. + apply Z.add_assoc. + Qed. + + Lemma set_nth_sum : forall n x us, (n < length us)%nat -> + BaseSystem.decode base (set_nth n x us) = + (x - nth_default 0 us n) * nth_default 0 base n + BaseSystem.decode base us. + Proof. + intros. + unfold BaseSystem.decode. + nth_inbounds; auto. (* TODO(andreser): nth_inbounds should do this auto*) + unfold splice_nth. + rewrite <- (firstn_skipn n us) at 4. + do 2 rewrite decode'_splice. + remember (length (firstn n us)) as n0. + ring_simplify. + remember (BaseSystem.decode' (firstn n0 base) (firstn n us)). + rewrite (skipn_nth_default n us 0) by omega. + rewrite firstn_length in Heqn0. + rewrite Min.min_l in Heqn0 by omega; subst n0. + destruct (le_lt_dec (length base) n). { + rewrite nth_default_out_of_bounds by auto. + rewrite skipn_all by omega. + do 2 rewrite decode_base_nil. + ring_simplify; auto. + } { + rewrite (skipn_nth_default n base 0) by omega. + do 2 rewrite decode'_cons. + ring_simplify; ring. + } + Qed. + + Lemma add_to_nth_sum : forall n x us, (n < length us)%nat -> + BaseSystem.decode base (add_to_nth n x us) = + x * nth_default 0 base n + BaseSystem.decode base us. + Proof. + unfold add_to_nth; intros; rewrite set_nth_sum; try ring_simplify; auto. + Qed. + + Lemma nth_default_base_positive : forall i, (i < length base)%nat -> + nth_default 0 base i > 0. + Proof. + intros. + pose proof (nth_error_length_exists_value _ _ H). + destruct H0. + pose proof (nth_error_value_In _ _ _ H0). + pose proof (BaseSystem.base_positive _ H1). + unfold nth_default. + rewrite H0; auto. + Qed. + + Lemma base_succ_div_mult : forall i, ((S i) < length base)%nat -> + nth_default 0 base (S i) = nth_default 0 base i * + (nth_default 0 base (S i) / nth_default 0 base i). + Proof. + intros. + apply Z_div_exact_2; try (apply nth_default_base_positive; omega). + apply base_succ; auto. + Qed. + +End PseudoMersenneProofs. + +Section CarryProofs. + Context `{prm : PseudoMersenneBaseParams}. + Existing Instance bv. + Local Notation "u '~=' x" := (rep u x) (at level 70). + Hint Unfold log_cap. + + Lemma base_length_lt_pred : (pred (length base) < length base)%nat. + Proof. + pose proof base_length_nonzero; omega. + Qed. + Hint Resolve base_length_lt_pred. + + Lemma log_cap_nonneg : forall i, 0 <= log_cap i. + Proof. + unfold log_cap, nth_default; intros. + case_eq (nth_error limb_widths i); intros; try omega. + apply limb_widths_nonneg. + eapply nth_error_value_In; eauto. + Qed. + + (* TODO : move to ZUtil *) + Lemma div_pow2succ : forall n x, (0 <= x) -> + n / 2 ^ Z.succ x = Z.div2 (n / 2 ^ x). + Proof. + intros. + rewrite Z.pow_succ_r, Z.mul_comm by auto. + rewrite <- Z.div_div by (try apply Z.pow_nonzero; omega). + rewrite Zdiv2_div. + reflexivity. + Qed. + + (* TODO: move to ZUtil *) + Lemma shiftr_succ : forall n x, + Z.shiftr n (Z.succ x) = Z.shiftr (Z.shiftr n x) 1. + Proof. + intros. + rewrite Z.shiftr_shiftr by omega. + reflexivity. + Qed. + + (* TODO : move to ZUtil *) + Lemma shiftr_div : forall n i, (0 <= i) -> Z.shiftr n i = n / (2 ^ i). + Proof. + intro. + apply natlike_ind; intros; [boring|]. + rewrite div_pow2succ by auto. + rewrite shiftr_succ. + unfold Z.shiftr. + simpl; f_equal. + auto. + Qed. + + (* TODO : move to ListUtil *) + Lemma nth_error_Some_nth_default : forall {T} i x (l : list T), (i < length l)%nat -> + nth_error l i = Some (nth_default x l i). + Proof. + intros ? ? ? ? i_lt_length. + destruct (nth_error_length_exists_value _ _ i_lt_length) as [k nth_err_k]. + unfold nth_default. + rewrite nth_err_k. + reflexivity. + Qed. + + Lemma nth_default_base_succ : forall i, (S i < length base)%nat -> + nth_default 0 base (S i) = 2 ^ log_cap i * nth_default 0 base i. + Proof. + intros. + repeat rewrite nth_default_base by omega. + rewrite <- Z.pow_add_r by (apply log_cap_nonneg || apply sum_firstn_limb_widths_nonneg). + destruct (NPeano.Nat.eq_dec i 0). + + subst; f_equal. + unfold sum_firstn, log_cap. + destruct limb_widths; auto. + + erewrite sum_firstn_succ; eauto. + unfold log_cap. + apply nth_error_Some_nth_default. + rewrite <- base_length; omega. + Qed. + + Lemma carry_simple_decode_eq : forall i us, + (length us = length base) -> + (i < (pred (length base)))%nat -> + BaseSystem.decode base (carry_simple i us) = BaseSystem.decode base us. + Proof. + unfold carry_simple. intros. + rewrite add_to_nth_sum by (rewrite length_set_nth; omega). + rewrite set_nth_sum by omega. + unfold pow2_mod. + rewrite Z.land_ones by apply log_cap_nonneg. + rewrite shiftr_div by apply log_cap_nonneg. + rewrite nth_default_base_succ by omega. + rewrite Z.mul_assoc. + rewrite (Z.mul_comm _ (2 ^ log_cap i)). + rewrite mul_div_eq; try ring. + apply gt_lt_symmetry. + apply Z.pow_pos_nonneg; omega || apply log_cap_nonneg. + Qed. + + Lemma carry_decode_eq_reduce : forall us, + (length us = length base) -> + BaseSystem.decode base (carry_and_reduce (pred (length base)) us) mod modulus + = BaseSystem.decode base us mod modulus. + Proof. + unfold carry_and_reduce; intros ? length_eq. + pose proof base_length_nonzero. + rewrite add_to_nth_sum by (rewrite length_set_nth; omega). + rewrite set_nth_sum by omega. + rewrite Zplus_comm, <- Z.mul_assoc, <- pseudomersenne_add, BaseSystem.b0_1. + rewrite (Z.mul_comm (2 ^ k)), <- Zred_factor0. + f_equal. + rewrite <- (Z.add_comm (BaseSystem.decode base us)), <- Z.add_assoc, <- Z.add_0_r. + f_equal. + destruct (NPeano.Nat.eq_dec (length base) 0) as [length_zero | length_nonzero]. + + apply length0_nil in length_zero. + pose proof (base_length) as limbs_length. + rewrite length_zero in length_eq, limbs_length. + apply length0_nil in length_eq. + symmetry in limbs_length. + apply length0_nil in limbs_length. + unfold log_cap. + subst; rewrite length_zero, limbs_length, nth_default_nil. + reflexivity. + + rewrite nth_default_base by omega. + rewrite <- Z.add_opp_l, <- Z.opp_sub_distr. + unfold pow2_mod. + rewrite Z.land_ones by apply log_cap_nonneg. + rewrite <- mul_div_eq by (apply gt_lt_symmetry; apply Z.pow_pos_nonneg; omega || apply log_cap_nonneg). + rewrite Z.shiftr_div_pow2 by apply log_cap_nonneg. + rewrite Zopp_mult_distr_r. + rewrite Z.mul_comm. + rewrite Z.mul_assoc. + rewrite <- Z.pow_add_r by (apply log_cap_nonneg || apply sum_firstn_limb_widths_nonneg). + rewrite <- k_matches_limb_widths. + replace (length limb_widths) with (S (pred (length base))) by + (subst; rewrite <- base_length; apply NPeano.Nat.succ_pred; omega). + rewrite sum_firstn_succ with (x:= log_cap (pred (length base))) by + (unfold log_cap; apply nth_error_Some_nth_default; rewrite <- base_length; omega). + rewrite <- Zopp_mult_distr_r. + rewrite Z.mul_comm. + rewrite (Z.add_comm (log_cap (pred (length base)))). + ring. + Qed. + + Lemma carry_length : forall i us, + (length us <= length base)%nat -> + (length (carry i us) <= length base)%nat. + Proof. + unfold carry, carry_simple, carry_and_reduce, add_to_nth. + intros; break_if; subst; repeat (rewrite length_set_nth); auto. + Qed. + Hint Resolve carry_length. + + Lemma carry_rep : forall i us x, + (length us = length base) -> + (i < length base)%nat -> + us ~= x -> carry i us ~= x. + Proof. + pose carry_length. pose carry_decode_eq_reduce. pose carry_simple_decode_eq. + unfold rep, decode, carry in *; intros. + intuition; break_if; subst; eauto; + apply F_eq; simpl; intuition. + Qed. + Hint Resolve carry_rep. + + Lemma carry_sequence_length: forall is us, + (length us <= length base)%nat -> + (length (carry_sequence is us) <= length base)%nat. + Proof. + induction is; boring. + Qed. + Hint Resolve carry_sequence_length. + + Lemma carry_length_exact : forall i us, + (length us = length base)%nat -> + (length (carry i us) = length base)%nat. + Proof. + unfold carry, carry_simple, carry_and_reduce, add_to_nth. + intros; break_if; subst; repeat (rewrite length_set_nth); auto. + Qed. + + Lemma carry_sequence_length_exact: forall is us, + (length us = length base)%nat -> + (length (carry_sequence is us) = length base)%nat. + Proof. + induction is; boring. + apply carry_length_exact; auto. + Qed. + Hint Resolve carry_sequence_length_exact. + + Lemma carry_sequence_rep : forall is us x, + (forall i, In i is -> (i < length base)%nat) -> + (length us = length base) -> + us ~= x -> carry_sequence is us ~= x. + Proof. + induction is; boring. + Qed. + +End CarryProofs. diff --git a/src/ModularArithmetic/PseudoMersenneBaseParamProofs.v b/src/ModularArithmetic/PseudoMersenneBaseParamProofs.v new file mode 100644 index 000000000..847b8e85f --- /dev/null +++ b/src/ModularArithmetic/PseudoMersenneBaseParamProofs.v @@ -0,0 +1,277 @@ +Require Import Zpower ZArith. +Require Import List. +Require Import Crypto.Util.ListUtil Crypto.Util.CaseUtil Crypto.Util.ZUtil. +Require Import Crypto.ModularArithmetic.PrimeFieldTheorems. +Require Import VerdiTactics. +Require Import Crypto.ModularArithmetic.PseudoMersenneBaseParams. +Require Crypto.BaseSystem. +Local Open Scope Z_scope. + +Section PseudoMersenneBaseParamProofs. + Context `{prm : PseudoMersenneBaseParams}. + + Fixpoint base_from_limb_widths limb_widths := + match limb_widths with + | nil => nil + | w :: lw => 1 :: map (Z.mul (two_p w)) (base_from_limb_widths lw) + end. + + Definition base := base_from_limb_widths limb_widths. + + Lemma base_length : length base = length limb_widths. + Proof. + unfold base. + induction limb_widths; try reflexivity. + simpl; rewrite map_length; auto. + Qed. + + Lemma nth_error_first : forall {T} (a b : T) l, nth_error (a :: l) 0 = Some b -> + a = b. + Proof. + intros; simpl in *. + unfold value in *. + congruence. + Qed. + + Lemma base_from_limb_widths_step : forall i b w, (S i < length base)%nat -> + nth_error base i = Some b -> + nth_error limb_widths i = Some w -> + nth_error base (S i) = Some (two_p w * b). + Proof. + unfold base; induction limb_widths; intros ? ? ? ? nth_err_w nth_err_b; + unfold base_from_limb_widths in *; fold base_from_limb_widths in *; + [rewrite (@nil_length0 Z) in *; omega | ]. + simpl in *; rewrite map_length in *. + case_eq i; intros; subst. + + subst; apply nth_error_first in nth_err_w. + apply nth_error_first in nth_err_b; subst. + apply map_nth_error. + case_eq l; intros; subst; [simpl in *; omega | ]. + unfold base_from_limb_widths; fold base_from_limb_widths. + reflexivity. + + simpl in nth_err_w. + apply nth_error_map in nth_err_w. + destruct nth_err_w as [x [A B]]. + subst. + replace (two_p w * (two_p a * x)) with (two_p a * (two_p w * x)) by ring. + apply map_nth_error. + apply IHl; auto; omega. + Qed. + + Lemma nth_error_exists_first : forall {T} l (x : T) (H : nth_error l 0 = Some x), + exists l', l = x :: l'. + Proof. + induction l; try discriminate; eexists. + apply nth_error_first in H. + subst; eauto. + Qed. + + Lemma sum_firstn_succ : forall l i x, + nth_error l i = Some x -> + sum_firstn l (S i) = x + sum_firstn l i. + Proof. + unfold sum_firstn; induction l; + [intros; rewrite (@nth_error_nil_error Z) in *; congruence | ]. + intros ? x nth_err_x; destruct (NPeano.Nat.eq_dec i 0). + + subst; simpl in *; unfold value in *. + congruence. + + rewrite <- (NPeano.Nat.succ_pred i) at 2 by auto. + rewrite <- (NPeano.Nat.succ_pred i) in nth_err_x by auto. + simpl. simpl in nth_err_x. + specialize (IHl (pred i) x). + rewrite NPeano.Nat.succ_pred in IHl by auto. + destruct (NPeano.Nat.eq_dec (pred i) 0). + - replace i with 1%nat in * by omega. + simpl. replace (pred 1) with 0%nat in * by auto. + apply nth_error_exists_first in nth_err_x. + destruct nth_err_x as [l' ?]. + subst; simpl; ring. + - rewrite IHl by auto; ring. + Qed. + + (* TODO : move to LsitUtil *) + Lemma fold_right_invariant : forall {A} P (f: A -> A -> A) l x, + P x -> (forall y, In y l -> forall z, P z -> P (f y z)) -> + P (fold_right f x l). + Proof. + induction l; intros ? ? step; auto. + simpl. + apply step; try apply in_eq. + apply IHl; auto. + intros y in_y_l. + apply (in_cons a) in in_y_l. + auto. + Qed. + + (* TODO : move to ListUtil *) + Lemma In_firstn : forall {T} n l (x : T), In x (firstn n l) -> In x l. + Proof. + induction n; destruct l; boring. + Qed. + + Lemma sum_firstn_limb_widths_nonneg : forall n, 0 <= sum_firstn limb_widths n. + Proof. + unfold sum_firstn; intros. + apply fold_right_invariant; try omega. + intros y In_y_lw ? ?. + apply Z.add_nonneg_nonneg; try assumption. + apply limb_widths_nonneg. + eapply In_firstn; eauto. + Qed. + + Lemma k_nonneg : 0 <= k. + Proof. + rewrite <- k_matches_limb_widths. + apply sum_firstn_limb_widths_nonneg. + Qed. + + Lemma nth_error_base : forall i, (i < length base)%nat -> + nth_error base i = Some (two_p (sum_firstn limb_widths i)). + Proof. + induction i; intros. + + unfold base, sum_firstn, base_from_limb_widths in *; case_eq limb_widths; try reflexivity. + intro lw_nil; rewrite lw_nil, (@nil_length0 Z) in *; omega. + + + assert (i < length base)%nat as lt_i_length by omega. + specialize (IHi lt_i_length). + rewrite base_length in lt_i_length. + destruct (nth_error_length_exists_value _ _ lt_i_length) as [w nth_err_w]. + erewrite base_from_limb_widths_step; eauto. + f_equal. + simpl. + destruct (NPeano.Nat.eq_dec i 0). + - subst; unfold sum_firstn; simpl. + apply nth_error_exists_first in nth_err_w. + destruct nth_err_w as [l' lw_destruct]; subst. + rewrite lw_destruct. + ring_simplify. + f_equal; simpl; ring. + - erewrite sum_firstn_succ; eauto. + symmetry. + apply two_p_is_exp; auto using sum_firstn_limb_widths_nonneg. + apply limb_widths_nonneg. + eapply nth_error_value_In; eauto. + Qed. + + Lemma nth_default_base : forall d i, (i < length base)%nat -> + nth_default d base i = 2 ^ (sum_firstn limb_widths i). + Proof. + intros ? ? i_lt_length. + destruct (nth_error_length_exists_value _ _ i_lt_length) as [x nth_err_x]. + unfold nth_default. + rewrite nth_err_x. + rewrite nth_error_base in nth_err_x by assumption. + rewrite two_p_correct in nth_err_x. + congruence. + Qed. + + + (* TODO : move to ZUtil *) + Lemma mod_same_pow : forall a b c, 0 <= c <= b -> a ^ b mod a ^ c = 0. + Proof. + intros. + replace b with (b - c + c) by ring. + rewrite Z.pow_add_r by omega. + apply Z_mod_mult. + Qed. + + Lemma base_matches_modulus: forall i j, + (i < length base)%nat -> + (j < length base)%nat -> + (i+j >= length base)%nat-> + let b := nth_default 0 base in + let r := (b i * b j) / (2^k * b (i+j-length base)%nat) in + b i * b j = r * (2^k * b (i+j-length base)%nat). + Proof. + intros. + rewrite (Z.mul_comm r). + subst r. + assert (i + j - length base < length base)%nat by omega. + rewrite mul_div_eq by (apply gt_lt_symmetry; apply Z.mul_pos_pos; + [ | subst b; rewrite nth_default_base; try assumption ]; + apply Z.pow_pos_nonneg; omega || apply k_nonneg || apply sum_firstn_limb_widths_nonneg). + rewrite (Zminus_0_l_reverse (b i * b j)) at 1. + f_equal. + subst b. + repeat rewrite nth_default_base by assumption. + do 2 rewrite <- Z.pow_add_r by (apply sum_firstn_limb_widths_nonneg || apply k_nonneg). + symmetry. + apply mod_same_pow. + split. + + apply Z.add_nonneg_nonneg; apply sum_firstn_limb_widths_nonneg || apply k_nonneg. + + rewrite base_length in *; apply limb_widths_match_modulus; assumption. + Qed. + + Lemma base_succ : forall i, ((S i) < length base)%nat -> + nth_default 0 base (S i) mod nth_default 0 base i = 0. + Proof. + intros. + repeat rewrite nth_default_base by omega. + apply mod_same_pow. + split; [apply sum_firstn_limb_widths_nonneg | ]. + destruct (NPeano.Nat.eq_dec i 0); subst. + + case_eq limb_widths; intro; unfold sum_firstn; simpl; try omega; intros l' lw_eq. + apply Z.add_nonneg_nonneg; try omega. + apply limb_widths_nonneg. + rewrite lw_eq. + apply in_eq. + + assert (i < length base)%nat as i_lt_length by omega. + rewrite base_length in *. + apply nth_error_length_exists_value in i_lt_length. + destruct i_lt_length as [x nth_err_x]. + erewrite sum_firstn_succ; eauto. + apply nth_error_value_In in nth_err_x. + apply limb_widths_nonneg in nth_err_x. + omega. + Qed. + + Lemma nth_error_subst : forall i b, nth_error base i = Some b -> + b = 2 ^ (sum_firstn limb_widths i). + Proof. + intros i b nth_err_b. + pose proof (nth_error_value_length _ _ _ _ nth_err_b). + rewrite nth_error_base in nth_err_b by assumption. + rewrite two_p_correct in nth_err_b. + congruence. + Qed. + + Lemma base_positive : forall b : Z, In b base -> b > 0. + Proof. + intros b In_b_base. + apply In_nth_error_value in In_b_base. + destruct In_b_base as [i nth_err_b]. + apply nth_error_subst in nth_err_b. + rewrite nth_err_b. + apply gt_lt_symmetry. + apply Z.pow_pos_nonneg; omega || apply sum_firstn_limb_widths_nonneg. + Qed. + + Lemma b0_1 : forall x : Z, nth_default x base 0 = 1. + Proof. + unfold base; case_eq limb_widths; intros; [pose proof limb_widths_nonnil; congruence | reflexivity]. + Qed. + + Lemma base_good : forall i j : nat, + (i + j < length base)%nat -> + let b := nth_default 0 base in + let r := b i * b j / b (i + j)%nat in + b i * b j = r * b (i + j)%nat. + Proof. + intros; subst b r. + repeat rewrite nth_default_base by omega. + rewrite (Z.mul_comm _ (2 ^ (sum_firstn limb_widths (i+j)))). + rewrite mul_div_eq by (apply gt_lt_symmetry; apply Z.pow_pos_nonneg; omega || apply sum_firstn_limb_widths_nonneg). + rewrite <- Z.pow_add_r by apply sum_firstn_limb_widths_nonneg. + rewrite mod_same_pow; try ring. + split; [ apply sum_firstn_limb_widths_nonneg | ]. + apply limb_widths_good. + rewrite <- base_length; assumption. + Qed. + + Instance bv : BaseSystem.BaseVector base := { + base_positive := base_positive; + b0_1 := b0_1; + base_good := base_good + }. + +End PseudoMersenneBaseParamProofs. diff --git a/src/ModularArithmetic/PseudoMersenneBaseParams.v b/src/ModularArithmetic/PseudoMersenneBaseParams.v new file mode 100644 index 000000000..122cac0ab --- /dev/null +++ b/src/ModularArithmetic/PseudoMersenneBaseParams.v @@ -0,0 +1,26 @@ +Require Import ZArith. +Require Import List. +Require Crypto.BaseSystem. +Local Open Scope Z_scope. + +Definition sum_firstn l n := fold_right Z.add 0 (firstn n l). + +Class PseudoMersenneBaseParams (modulus : Z) := { + limb_widths : list Z; + limb_widths_nonneg : forall w, In w limb_widths -> 0 <= w; + limb_widths_nonnil : limb_widths <> nil; + limb_widths_good : forall i j, (i + j < length limb_widths)%nat -> + sum_firstn limb_widths (i + j) <= + sum_firstn limb_widths i + sum_firstn limb_widths j; + k : Z; + c : Z; + modulus_pseudomersenne : modulus = 2^k - c; + prime_modulus : Znumtheory.prime modulus; + limb_widths_match_modulus : forall i j, + (i < length limb_widths)%nat -> + (j < length limb_widths)%nat -> + (i + j >= length limb_widths)%nat -> + let w_sum := sum_firstn limb_widths in + k + w_sum (i + j - length limb_widths)%nat <= w_sum i + w_sum j; + k_matches_limb_widths : sum_firstn limb_widths (length limb_widths) = k +}. diff --git a/src/ModularArithmetic/PseudoMersenneBaseRep.v b/src/ModularArithmetic/PseudoMersenneBaseRep.v new file mode 100644 index 000000000..2cc12b933 --- /dev/null +++ b/src/ModularArithmetic/PseudoMersenneBaseRep.v @@ -0,0 +1,43 @@ +Require Import ZArith. +Require Crypto.BaseSystem. +Require Import Crypto.ModularArithmetic.PrimeFieldTheorems. +Require Import Crypto.ModularArithmetic.ModularBaseSystemProofs Crypto.ModularArithmetic.PseudoMersenneBaseParams. +Local Open Scope Z_scope. + +Class RepZMod (modulus : Z) := { + T : Type; + encode : F modulus -> T; + decode : T -> F modulus; + + rep : T -> F modulus -> Prop; + encode_rep : forall x, rep (encode x) x; + rep_decode : forall u x, rep u x -> decode u = x; + + add : T -> T -> T; + add_rep : forall u v x y, rep u x -> rep v y -> rep (add u v) (x+y)%F; + + sub : T -> T -> T; + sub_rep : forall u v x y, rep u x -> rep v y -> rep (sub u v) (x-y)%F; + + mul : T -> T -> T; + mul_rep : forall u v x y, rep u x -> rep v y -> rep (mul u v) (x*y)%F +}. + +Instance PseudoMersenneBase m (prm : PseudoMersenneBaseParams m) : RepZMod m := { + T := list Z; + encode := ModularBaseSystem.encode; + decode := ModularBaseSystem.decode; + + rep := ModularBaseSystem.rep; + encode_rep := ModularBaseSystemProofs.encode_rep; + rep_decode := ModularBaseSystemProofs.rep_decode; + + add := BaseSystem.add; + add_rep := ModularBaseSystemProofs.add_rep; + + sub := BaseSystem.sub; + sub_rep := ModularBaseSystemProofs.sub_rep; + + mul := ModularBaseSystem.mul; + mul_rep := ModularBaseSystemProofs.mul_rep +}. |