aboutsummaryrefslogtreecommitdiff
path: root/src/ModularArithmetic
diff options
context:
space:
mode:
authorGravatar jadep <jade.philipoom@gmail.com>2016-07-12 11:54:53 -0400
committerGravatar jadep <jade.philipoom@gmail.com>2016-07-12 11:54:53 -0400
commit58a9cb64c067f931568310b6df0c566c8b603dfd (patch)
tree3098b9d7f30517d4f2cd2e8cadb4981f833c4e9a /src/ModularArithmetic
parentc62b9eaf24020e6fb66cec6c40802c2428c6975d (diff)
pushing through a tweak to the arguments of [sub], and defining a field over ModularBaseSystemInterface using some placeholder operations.
Diffstat (limited to 'src/ModularArithmetic')
-rw-r--r--src/ModularArithmetic/ModularBaseSystem.v6
-rw-r--r--src/ModularArithmetic/ModularBaseSystemBasicInterface.v172
-rw-r--r--src/ModularArithmetic/ModularBaseSystemInterface.v60
-rw-r--r--src/ModularArithmetic/ModularBaseSystemOpt.v6
-rw-r--r--src/ModularArithmetic/ModularBaseSystemProofs.v47
5 files changed, 261 insertions, 30 deletions
diff --git a/src/ModularArithmetic/ModularBaseSystem.v b/src/ModularArithmetic/ModularBaseSystem.v
index 8ce395289..b6138381e 100644
--- a/src/ModularArithmetic/ModularBaseSystem.v
+++ b/src/ModularArithmetic/ModularBaseSystem.v
@@ -11,7 +11,7 @@ Require Import Crypto.ModularArithmetic.Pow2Base.
Local Open Scope Z_scope.
Section PseudoMersenneBase.
- Context `{prm :PseudoMersenneBaseParams}.
+ Context `{prm :PseudoMersenneBaseParams} (modulus_multiple : digits).
Local Notation base := (base_from_limb_widths limb_widths).
Definition decode (us : digits) : F modulus := ZToField (BaseSystem.decode base us).
@@ -32,8 +32,8 @@ Section PseudoMersenneBase.
Definition mul (us vs : digits) := reduce (BaseSystem.mul ext_base us vs).
- Definition sub (xs : digits) (xs_0_mod : (BaseSystem.decode base xs) mod modulus = 0) (us vs : digits) :=
- BaseSystem.sub (add xs us) vs.
+ (* In order to subtract without underflowing, we add a multiple of the modulus first. *)
+ Definition sub (us vs : digits) := BaseSystem.sub (add modulus_multiple us) vs.
End PseudoMersenneBase.
diff --git a/src/ModularArithmetic/ModularBaseSystemBasicInterface.v b/src/ModularArithmetic/ModularBaseSystemBasicInterface.v
new file mode 100644
index 000000000..fc576d372
--- /dev/null
+++ b/src/ModularArithmetic/ModularBaseSystemBasicInterface.v
@@ -0,0 +1,172 @@
+Require Import Crypto.BaseSystem.
+Require Import Crypto.ModularArithmetic.ModularBaseSystem.
+Require Import Crypto.ModularArithmetic.ModularBaseSystemProofs.
+Require Import Crypto.BaseSystemProofs.
+Require Import Crypto.Util.Tuple Crypto.Util.CaseUtil.
+Require Import ZArith.
+Require Import Crypto.Algebra.
+Import Group.
+Import PseudoMersenneBaseParams PseudoMersenneBaseParamProofs.
+
+(* Wraps ModularBaseSystem operations in tuples *)
+Generalizable All Variables.
+Section s.
+ Context `{prm:PseudoMersenneBaseParams m}.
+
+ Definition fe := tuple Z (length base).
+
+(* to have them in specific, we need them in interface. To put them in interface, we
+ need to either
+
+ a) define them in interface
+ b) define them in opt
+ c) define them in MBS
+ d) define them in some kind of tuple wrapper for MBS which we then later unfold
+
+abstractions:
+- Fq
+- MBS with list rep +length proofs
+- MBS with tuple rep
+ (do correctness of MBS here)
+- optimized MBS
+*)
+
+ Definition mul (x y:fe) : fe :=
+ carry_mul k_ c (from_list_default 0%Z (length base))
+ (to_list _ x) (to_list _ y).
+
+ Definition add : fe -> fe -> fe.
+ refine (on_tuple2 add_opt _).
+ abstract (intros; rewrite add_opt_correct, add_length_exact; case_max; omega).
+ Defined.
+
+ Definition sub : fe -> fe -> fe.
+ refine (on_tuple2 sub_opt _).
+ abstract (intros; rewrite sub_opt_correct; apply length_sub; rewrite ?coeff_length; auto).
+ Defined.
+
+ Definition carry_simple i := fun us =>
+ let di := nth_default 0 us i in
+ let us' := set_nth i (pow2_mod di (log_cap i)) us in
+ add_to_nth (S i) ( (Z.shiftr di (log_cap i))) us'.
+
+ Definition carry_and_reduce i := fun us =>
+ let di := nth_default 0 us i in
+ let us' := set_nth i (pow2_mod di (log_cap i)) us in
+ add_to_nth 0 (c * (Z.shiftr di (log_cap i))) us'.
+
+ Definition carry i : digits -> digits :=
+ if eq_nat_dec i (pred (length base))
+ then carry_and_reduce i
+ else carry_simple i.
+
+ Definition carry_sequence is us := fold_right carry us is.
+
+ Fixpoint make_chain i :=
+ match i with
+ | O => nil
+ | S i' => i' :: make_chain i'
+ end.
+
+ Definition full_carry_chain := make_chain (length limb_widths).
+
+ Definition carry_full := carry_sequence full_carry_chain.
+
+ Definition carry_mul us vs := carry_full (mul us vs).
+
+End CarryBasePow2.
+
+Section Canonicalization.
+ Context `{prm :PseudoMersenneBaseParams}.
+
+ (* compute at compile time *)
+ Definition max_ones := Z.ones (fold_right Z.max 0 limb_widths).
+
+ Definition max_bound i := Z.ones (log_cap i).
+
+ Fixpoint isFull' us full i :=
+ match i with
+ | O => andb (Z.ltb (max_bound 0 - c) (nth_default 0 us 0)) full
+ | S i' => isFull' us (andb (Z.eqb (max_bound i) (nth_default 0 us i)) full) i'
+ end.
+
+ Definition isFull us := isFull' us true (length base - 1)%nat.
+
+ Fixpoint modulus_digits' i :=
+ match i with
+ | O => max_bound i - c + 1 :: nil
+ | S i' => modulus_digits' i' ++ max_bound i :: nil
+ end.
+
+ (* compute at compile time *)
+ Definition modulus_digits := modulus_digits' (length base - 1).
+
+ Definition and_term us := if isFull us then max_ones else 0.
+
+ Definition freeze us :=
+ let us' := carry_full (carry_full (carry_full us)) in
+ let and_term := and_term us' in
+ (* [and_term] is all ones if us' is full, so the subtractions subtract q overall.
+ Otherwise, it's all zeroes, and the subtractions do nothing. *)
+ map2 (fun x y => x - y) us' (map (Z.land and_term) modulus_digits).
+
+End Canonicalization.
+
+
+ Arguments to_list {_ _} _.
+ Definition phi (a : fe) : ModularArithmetic.F m := decode (to_list a).
+ Definition phi_inv (a : ModularArithmetic.F m) : fe :=
+ from_list_default 0%Z _ (encode a).
+
+ Lemma phi_inv_spec : forall a, phi (phi_inv a) = a.
+ Proof.
+ intros; cbv [phi_inv phi].
+ erewrite from_list_default_eq.
+ rewrite to_list_from_list.
+ apply ModularBaseSystemProofs.encode_rep.
+ Grab Existential Variables.
+ apply ModularBaseSystemProofs.encode_rep.
+ Qed.
+
+ Definition eq (x y : fe) : Prop := phi x = phi y.
+
+ Definition zero : fe := phi_inv (ModularArithmetic.ZToField 0).
+
+ Definition opp (x : fe) : fe := phi_inv (ModularArithmetic.opp (phi x)).
+
+ Lemma add_correct : forall a b,
+ to_list (add a b) = BaseSystem.add (to_list a) (to_list b).
+ Proof.
+ intros; cbv [add on_tuple2].
+ rewrite to_list_from_list.
+ apply add_opt_correct.
+ Qed.
+
+ Lemma add_phi : forall a b : fe,
+ phi (add a b) = ModularArithmetic.add (phi a) (phi b).
+ Proof.
+ intros; cbv [phi].
+ rewrite add_correct.
+ apply ModularBaseSystemProofs.add_rep; auto using decode_rep, length_to_list.
+ Qed.
+
+ Lemma mul_correct : forall a b,
+ to_list (mul a b) = carry_mul (to_list a) (to_list b).
+ Proof.
+ intros; cbv [mul].
+ rewrite carry_mul_opt_cps_correct by assumption.
+ erewrite from_list_default_eq.
+ apply to_list_from_list.
+ Grab Existential Variables.
+ apply carry_mul_length; apply length_to_list.
+ Qed.
+
+ Lemma mul_phi : forall a b : fe,
+ phi (mul a b) = ModularArithmetic.mul (phi a) (phi b).
+ Proof.
+ intros; cbv beta delta [phi].
+ rewrite mul_correct.
+ apply carry_mul_rep; auto using decode_rep, length_to_list.
+ Qed.
+
+End s. \ No newline at end of file
diff --git a/src/ModularArithmetic/ModularBaseSystemInterface.v b/src/ModularArithmetic/ModularBaseSystemInterface.v
index 4a3859077..998e7d959 100644
--- a/src/ModularArithmetic/ModularBaseSystemInterface.v
+++ b/src/ModularArithmetic/ModularBaseSystemInterface.v
@@ -48,10 +48,27 @@ Section s.
Definition eq (x y : fe) : Prop := phi x = phi y.
+ Import Morphisms.
+ Global Instance eq_Equivalence : Equivalence eq.
+ Proof.
+ split; cbv [eq]; repeat intro; congruence.
+ Qed.
+
+ Lemma phi_inv_spec_reverse : forall a, eq (phi_inv (phi a)) a.
+ Proof.
+ intros. unfold eq. rewrite phi_inv_spec; reflexivity.
+ Qed.
+
Definition zero : fe := phi_inv (ModularArithmetic.ZToField 0).
Definition opp (x : fe) : fe := phi_inv (ModularArithmetic.opp (phi x)).
+ Definition one : fe := phi_inv (ModularArithmetic.ZToField 1).
+
+ Definition inv (x : fe) : fe := phi_inv (ModularArithmetic.inv (phi x)).
+
+ Definition div (x y : fe) : fe := phi_inv (ModularArithmetic.div (phi x) (phi y)).
+
Lemma add_correct : forall a b,
to_list (add a b) = BaseSystem.add (to_list a) (to_list b).
Proof.
@@ -68,6 +85,23 @@ Section s.
apply ModularBaseSystemProofs.add_rep; auto using decode_rep, length_to_list.
Qed.
+ Lemma sub_correct : forall a b,
+ to_list (sub a b) = ModularBaseSystem.sub coeff (to_list a) (to_list b).
+ Proof.
+ intros; cbv [sub on_tuple2].
+ rewrite to_list_from_list.
+ apply sub_opt_correct.
+ Qed.
+
+ Lemma sub_phi : forall a b : fe,
+ phi (sub a b) = ModularArithmetic.sub (phi a) (phi b).
+ Proof.
+ intros; cbv [phi].
+ rewrite sub_correct.
+ apply ModularBaseSystemProofs.sub_rep; auto using decode_rep, length_to_list,
+ coeff_length, coeff_mod.
+ Qed.
+
Lemma mul_correct : forall a b,
to_list (mul a b) = carry_mul (to_list a) (to_list b).
Proof.
@@ -87,4 +121,28 @@ Section s.
apply carry_mul_rep; auto using decode_rep, length_to_list.
Qed.
-End s.
+ Lemma modular_base_system_field : @field fe eq zero one opp add sub mul inv div.
+ Proof.
+ eapply (Field.isomorphism_to_subfield_field (phi := phi) (fieldR := PrimeFieldTheorems.field_modulo (prime_q := prime_modulus))).
+ Grab Existential Variables.
+ + intros; apply phi_inv_spec.
+ + intros; apply phi_inv_spec.
+ + intros; apply phi_inv_spec.
+ + intros; apply phi_inv_spec.
+ + intros; apply mul_phi.
+ + intros; apply sub_phi.
+ + intros; apply add_phi.
+ + intros; apply phi_inv_spec.
+ + cbv [eq zero one]. rewrite !phi_inv_spec. intro A.
+ eapply (PrimeFieldTheorems.Fq_1_neq_0 (prime_q := prime_modulus)). congruence.
+ + trivial.
+ + repeat intro. cbv [div]. congruence.
+ + repeat intro. cbv [inv]. congruence.
+ + repeat intro. cbv [eq]. rewrite !mul_phi. congruence.
+ + repeat intro. cbv [eq]. rewrite !sub_phi. congruence.
+ + repeat intro. cbv [eq]. rewrite !add_phi. congruence.
+ + repeat intro. cbv [opp]. congruence.
+ + cbv [eq]. auto using ModularArithmeticTheorems.F_eq_dec.
+ Qed.
+
+End s. \ No newline at end of file
diff --git a/src/ModularArithmetic/ModularBaseSystemOpt.v b/src/ModularArithmetic/ModularBaseSystemOpt.v
index f7d33a97b..80e2f58ce 100644
--- a/src/ModularArithmetic/ModularBaseSystemOpt.v
+++ b/src/ModularArithmetic/ModularBaseSystemOpt.v
@@ -15,7 +15,7 @@ Local Open Scope Z.
Class SubtractionCoefficient (m : Z) (prm : PseudoMersenneBaseParams m) := {
coeff : BaseSystem.digits;
coeff_length : (length coeff = length (Pow2Base.base_from_limb_widths limb_widths))%nat;
- coeff_mod: (BaseSystem.decode (Pow2Base.base_from_limb_widths limb_widths) coeff) mod m = 0
+ coeff_mod: decode coeff = 0%F
}.
(* Computed versions of some functions. *)
@@ -337,7 +337,7 @@ End Addition.
Section Subtraction.
Context `{prm : PseudoMersenneBaseParams} {sc : SubtractionCoefficient modulus prm}.
- Definition sub_opt_sig (us vs : digits) : { b : digits | b = sub coeff coeff_mod us vs }.
+ Definition sub_opt_sig (us vs : digits) : { b : digits | b = sub coeff us vs }.
Proof.
eexists.
cbv [BaseSystem.add ModularBaseSystem.sub BaseSystem.sub].
@@ -348,7 +348,7 @@ Section Subtraction.
:= Eval cbv [proj1_sig sub_opt_sig] in proj1_sig (sub_opt_sig us vs).
Definition sub_opt_correct us vs
- : sub_opt us vs = sub coeff coeff_mod us vs
+ : sub_opt us vs = sub coeff us vs
:= proj2_sig (sub_opt_sig us vs).
End Subtraction.
diff --git a/src/ModularArithmetic/ModularBaseSystemProofs.v b/src/ModularArithmetic/ModularBaseSystemProofs.v
index ba06d4e6c..c59e04ca6 100644
--- a/src/ModularArithmetic/ModularBaseSystemProofs.v
+++ b/src/ModularArithmetic/ModularBaseSystemProofs.v
@@ -124,29 +124,6 @@ Section PseudoMersenneProofs.
subst; auto.
Qed.
- Lemma length_sub : forall c x u v,
- length c = length base
- -> length u = length base
- -> length v = length base
- -> length (ModularBaseSystem.sub c x u v) = length base.
- Proof.
- autounfold; unfold ModularBaseSystem.sub; intuition idtac.
- rewrite sub_length, add_length_exact.
- case_max; try rewrite Max.max_r; omega.
- Qed.
-
- Lemma sub_rep : forall c c_0modq, (length c = length base)%nat ->
- forall u v x y, u ~= x -> v ~= y ->
- ModularBaseSystem.sub c c_0modq u v ~= (x-y)%F.
- Proof.
- split; autounfold in *.
- { apply length_sub; intuition (auto; omega). }
- { unfold decode, ModularBaseSystem.sub, BaseSystem.decode in *; intuition idtac.
- rewrite BaseSystemProofs.sub_rep, BaseSystemProofs.add_rep.
- rewrite ZToField_sub, ZToField_add, ZToField_mod.
- rewrite c_0modq, F_add_0_l. congruence. }
- Qed.
-
Lemma decode_short : forall (us : BaseSystem.digits),
(length us <= length base)%nat ->
BaseSystem.decode base us = BaseSystem.decode ext_base us.
@@ -366,6 +343,30 @@ Section PseudoMersenneProofs.
apply Z.log2_lt_pow2; auto.
Qed.
+ Context mm (mm_length : length mm = length base) (mm_spec : decode mm = 0%F).
+
+ Lemma length_sub : forall u v,
+ length u = length base
+ -> length v = length base
+ -> length (ModularBaseSystem.sub mm u v) = length base.
+ Proof.
+ autounfold; unfold ModularBaseSystem.sub; intuition idtac.
+ rewrite sub_length, add_length_exact.
+ case_max; try rewrite Max.max_r; omega.
+ Qed.
+
+ Lemma sub_rep : forall u v x y, u ~= x -> v ~= y ->
+ ModularBaseSystem.sub mm u v ~= (x-y)%F.
+ Proof.
+ split; autounfold in *.
+ { apply length_sub; intuition (auto; omega). }
+ { unfold decode, ModularBaseSystem.sub, BaseSystem.decode in *; intuition idtac.
+ rewrite BaseSystemProofs.sub_rep, BaseSystemProofs.add_rep.
+ rewrite ZToField_sub, ZToField_add.
+ match goal with H : _ = 0%F |- _ => rewrite H end.
+ rewrite F_add_0_l. congruence. }
+ Qed.
+
End PseudoMersenneProofs.
Section CarryProofs.