aboutsummaryrefslogtreecommitdiff
path: root/src/ModularArithmetic/ModularBaseSystem.v
diff options
context:
space:
mode:
Diffstat (limited to 'src/ModularArithmetic/ModularBaseSystem.v')
-rw-r--r--src/ModularArithmetic/ModularBaseSystem.v126
1 files changed, 51 insertions, 75 deletions
diff --git a/src/ModularArithmetic/ModularBaseSystem.v b/src/ModularArithmetic/ModularBaseSystem.v
index 8c850c941..70c8138da 100644
--- a/src/ModularArithmetic/ModularBaseSystem.v
+++ b/src/ModularArithmetic/ModularBaseSystem.v
@@ -1,103 +1,79 @@
Require Import Coq.ZArith.Zpower Coq.ZArith.ZArith.
Require Import Coq.Lists.List.
-Require Import Crypto.Util.ListUtil Crypto.Util.CaseUtil Crypto.Util.ZUtil.
Require Import Crypto.ModularArithmetic.PrimeFieldTheorems.
Require Import Crypto.BaseSystem.
-Require Import Crypto.ModularArithmetic.PseudoMersenneBaseParams.
+Require Import Crypto.BaseSystemProofs.
Require Import Crypto.ModularArithmetic.ExtendedBaseVector.
-Require Import Crypto.Tactics.VerdiTactics.
-Require Import Crypto.Util.Notations.
Require Import Crypto.ModularArithmetic.Pow2Base.
+Require Import Crypto.ModularArithmetic.PseudoMersenneBaseParams.
+Require Import Crypto.ModularArithmetic.ModularBaseSystemList.
+Require Import Crypto.ModularArithmetic.ModularBaseSystemListProofs.
+Require Import Crypto.Util.ListUtil Crypto.Util.CaseUtil Crypto.Util.ZUtil.
+Require Import Crypto.Util.Tuple.
+Require Import Crypto.Util.Notations.
+Require Import Crypto.Tactics.VerdiTactics.
Local Open Scope Z_scope.
-Section PseudoMersenneBase.
- Context `{prm :PseudoMersenneBaseParams} (modulus_multiple : digits).
+Section ModularBaseSystem.
+ Context `{prm :PseudoMersenneBaseParams}.
Local Notation base := (base_from_limb_widths limb_widths).
+ Local Notation digits := (tuple Z (length limb_widths)).
+ Local Arguments to_list {_ _} _.
+ Local Arguments from_list {_ _} _ _.
+ Local Arguments length_to_list {_ _ _}.
+ Local Notation "[[ u ]]" := (to_list u).
- Definition decode (us : digits) : F modulus := ZToField (BaseSystem.decode base us).
-
- Definition rep (us : digits) (x : F modulus) := (length us = length base)%nat /\ decode us = x.
- Local Notation "u ~= x" := (rep u x).
- Local Hint Unfold rep.
-
- (* max must be greater than input; this is used to truncate last digit *)
- Definition encode (x : F modulus) := encodeZ limb_widths x.
-
- (* Converts from length of extended base to length of base by reduction modulo M.*)
- Definition reduce (us : digits) : digits :=
- let high := skipn (length base) us in
- let low := firstn (length base) us in
- let wrap := map (Z.mul c) high in
- BaseSystem.add low wrap.
-
- Definition mul (us vs : digits) := reduce (BaseSystem.mul (ext_base limb_widths) us vs).
+ Definition decode (us : digits) : F modulus := decode [[us]].
- (* In order to subtract without underflowing, we add a multiple of the modulus first. *)
- Definition sub (us vs : digits) := BaseSystem.sub (add modulus_multiple us) vs.
+ Definition encode (x : F modulus) : digits := from_list (encode x) length_encode.
-End PseudoMersenneBase.
-
-Section CarryBasePow2.
- Context `{prm :PseudoMersenneBaseParams}.
- Local Notation base := (base_from_limb_widths limb_widths).
- Local Notation log_cap i := (nth_default 0 limb_widths i).
+ Definition add (us vs : digits) : digits := from_list (add [[us]] [[vs]])
+ (add_same_length _ _ _ length_to_list length_to_list).
- (*
- Definition carry_and_reduce :=
- carry_gen limb_widths (fun ci => c * ci).
- *)
- Definition carry_and_reduce i := fun us =>
- let di := nth_default 0 us i in
- let us' := set_nth i (Z.pow2_mod di (log_cap i)) us in
- add_to_nth 0 (c * (Z.shiftr di (log_cap i))) us'.
+ Definition mul (us vs : digits) : digits := from_list (mul [[us]] [[vs]])
+ (length_mul length_to_list length_to_list).
- Definition carry i : digits -> digits :=
- if eq_nat_dec i (pred (length base))
- then carry_and_reduce i
- else carry_simple limb_widths i.
+ Definition sub (modulus_multiple us vs : digits) : digits :=
+ from_list (sub [[modulus_multiple]] [[us]] [[vs]])
+ (length_sub length_to_list length_to_list length_to_list).
- Definition carry_sequence is us := fold_right carry us is.
+ Definition zero : digits := encode (ZToField 0).
- Definition carry_full := carry_sequence (full_carry_chain limb_widths).
+ Definition one : digits := encode (ZToField 1).
- Definition carry_mul us vs := carry_full (mul us vs).
+ (* Placeholder *)
+ Definition opp (x : digits) : digits := encode (ModularArithmetic.opp (decode x)).
-End CarryBasePow2.
+ (* Placeholder *)
+ Definition inv (x : digits) : digits := encode (ModularArithmetic.inv (decode x)).
-Section Canonicalization.
- Context `{prm :PseudoMersenneBaseParams}.
- Local Notation base := (base_from_limb_widths limb_widths).
- Local Notation log_cap i := (nth_default 0 limb_widths i).
+ (* Placeholder *)
+ Definition div (x y : digits) : digits := encode (ModularArithmetic.div (decode x) (decode y)).
- (* compute at compile time *)
- Definition max_ones := Z.ones (fold_right Z.max 0 limb_widths).
+ Definition carry i (us : digits) : digits := from_list (carry i [[us]])
+ (length_carry length_to_list).
- Definition max_bound i := Z.ones (log_cap i).
+ Definition rep (us : digits) (x : F modulus) := decode us = x.
+ Local Notation "u ~= x" := (rep u x).
+ Local Hint Unfold rep.
- Fixpoint isFull' us full i :=
- match i with
- | O => andb (Z.ltb (max_bound 0 - c) (nth_default 0 us 0)) full
- | S i' => isFull' us (andb (Z.eqb (max_bound i) (nth_default 0 us i)) full) i'
- end.
+ Definition carry_sequence is (us : digits) : digits := fold_right carry us is.
- Definition isFull us := isFull' us true (length base - 1)%nat.
+ Definition carry_full : digits -> digits := carry_sequence (full_carry_chain limb_widths).
- Fixpoint modulus_digits' i :=
- match i with
- | O => max_bound i - c + 1 :: nil
- | S i' => modulus_digits' i' ++ max_bound i :: nil
- end.
+ Definition carry_mul (us vs : digits) : digits := carry_full (mul us vs).
- (* compute at compile time *)
- Definition modulus_digits := modulus_digits' (length base - 1).
+ Definition freeze (us : digits) : digits :=
+ let us' := carry_full (carry_full (carry_full us)) in
+ from_list (conditional_subtract_modulus [[us']] (ge_modulus [[us']]))
+ (length_conditional_subtract_modulus length_to_list).
- Definition and_term us := if isFull us then max_ones else 0.
+ Definition eq (x y : digits) : Prop := decode x = decode y.
- Definition freeze us :=
- let us' := carry_full (carry_full (carry_full us)) in
- let and_term := and_term us' in
- (* [and_term] is all ones if us' is full, so the subtractions subtract q overall.
- Otherwise, it's all zeroes, and the subtractions do nothing. *)
- map2 (fun x y => x - y) us' (map (Z.land and_term) modulus_digits).
+ Import Morphisms.
+ Global Instance eq_Equivalence : Equivalence eq.
+ Proof.
+ split; cbv [eq]; repeat intro; congruence.
+ Qed.
-End Canonicalization.
+End ModularBaseSystem. \ No newline at end of file