diff options
Diffstat (limited to 'src/ModularArithmetic/ModularBaseSystem.v')
-rw-r--r-- | src/ModularArithmetic/ModularBaseSystem.v | 126 |
1 files changed, 51 insertions, 75 deletions
diff --git a/src/ModularArithmetic/ModularBaseSystem.v b/src/ModularArithmetic/ModularBaseSystem.v index 8c850c941..70c8138da 100644 --- a/src/ModularArithmetic/ModularBaseSystem.v +++ b/src/ModularArithmetic/ModularBaseSystem.v @@ -1,103 +1,79 @@ Require Import Coq.ZArith.Zpower Coq.ZArith.ZArith. Require Import Coq.Lists.List. -Require Import Crypto.Util.ListUtil Crypto.Util.CaseUtil Crypto.Util.ZUtil. Require Import Crypto.ModularArithmetic.PrimeFieldTheorems. Require Import Crypto.BaseSystem. -Require Import Crypto.ModularArithmetic.PseudoMersenneBaseParams. +Require Import Crypto.BaseSystemProofs. Require Import Crypto.ModularArithmetic.ExtendedBaseVector. -Require Import Crypto.Tactics.VerdiTactics. -Require Import Crypto.Util.Notations. Require Import Crypto.ModularArithmetic.Pow2Base. +Require Import Crypto.ModularArithmetic.PseudoMersenneBaseParams. +Require Import Crypto.ModularArithmetic.ModularBaseSystemList. +Require Import Crypto.ModularArithmetic.ModularBaseSystemListProofs. +Require Import Crypto.Util.ListUtil Crypto.Util.CaseUtil Crypto.Util.ZUtil. +Require Import Crypto.Util.Tuple. +Require Import Crypto.Util.Notations. +Require Import Crypto.Tactics.VerdiTactics. Local Open Scope Z_scope. -Section PseudoMersenneBase. - Context `{prm :PseudoMersenneBaseParams} (modulus_multiple : digits). +Section ModularBaseSystem. + Context `{prm :PseudoMersenneBaseParams}. Local Notation base := (base_from_limb_widths limb_widths). + Local Notation digits := (tuple Z (length limb_widths)). + Local Arguments to_list {_ _} _. + Local Arguments from_list {_ _} _ _. + Local Arguments length_to_list {_ _ _}. + Local Notation "[[ u ]]" := (to_list u). - Definition decode (us : digits) : F modulus := ZToField (BaseSystem.decode base us). - - Definition rep (us : digits) (x : F modulus) := (length us = length base)%nat /\ decode us = x. - Local Notation "u ~= x" := (rep u x). - Local Hint Unfold rep. - - (* max must be greater than input; this is used to truncate last digit *) - Definition encode (x : F modulus) := encodeZ limb_widths x. - - (* Converts from length of extended base to length of base by reduction modulo M.*) - Definition reduce (us : digits) : digits := - let high := skipn (length base) us in - let low := firstn (length base) us in - let wrap := map (Z.mul c) high in - BaseSystem.add low wrap. - - Definition mul (us vs : digits) := reduce (BaseSystem.mul (ext_base limb_widths) us vs). + Definition decode (us : digits) : F modulus := decode [[us]]. - (* In order to subtract without underflowing, we add a multiple of the modulus first. *) - Definition sub (us vs : digits) := BaseSystem.sub (add modulus_multiple us) vs. + Definition encode (x : F modulus) : digits := from_list (encode x) length_encode. -End PseudoMersenneBase. - -Section CarryBasePow2. - Context `{prm :PseudoMersenneBaseParams}. - Local Notation base := (base_from_limb_widths limb_widths). - Local Notation log_cap i := (nth_default 0 limb_widths i). + Definition add (us vs : digits) : digits := from_list (add [[us]] [[vs]]) + (add_same_length _ _ _ length_to_list length_to_list). - (* - Definition carry_and_reduce := - carry_gen limb_widths (fun ci => c * ci). - *) - Definition carry_and_reduce i := fun us => - let di := nth_default 0 us i in - let us' := set_nth i (Z.pow2_mod di (log_cap i)) us in - add_to_nth 0 (c * (Z.shiftr di (log_cap i))) us'. + Definition mul (us vs : digits) : digits := from_list (mul [[us]] [[vs]]) + (length_mul length_to_list length_to_list). - Definition carry i : digits -> digits := - if eq_nat_dec i (pred (length base)) - then carry_and_reduce i - else carry_simple limb_widths i. + Definition sub (modulus_multiple us vs : digits) : digits := + from_list (sub [[modulus_multiple]] [[us]] [[vs]]) + (length_sub length_to_list length_to_list length_to_list). - Definition carry_sequence is us := fold_right carry us is. + Definition zero : digits := encode (ZToField 0). - Definition carry_full := carry_sequence (full_carry_chain limb_widths). + Definition one : digits := encode (ZToField 1). - Definition carry_mul us vs := carry_full (mul us vs). + (* Placeholder *) + Definition opp (x : digits) : digits := encode (ModularArithmetic.opp (decode x)). -End CarryBasePow2. + (* Placeholder *) + Definition inv (x : digits) : digits := encode (ModularArithmetic.inv (decode x)). -Section Canonicalization. - Context `{prm :PseudoMersenneBaseParams}. - Local Notation base := (base_from_limb_widths limb_widths). - Local Notation log_cap i := (nth_default 0 limb_widths i). + (* Placeholder *) + Definition div (x y : digits) : digits := encode (ModularArithmetic.div (decode x) (decode y)). - (* compute at compile time *) - Definition max_ones := Z.ones (fold_right Z.max 0 limb_widths). + Definition carry i (us : digits) : digits := from_list (carry i [[us]]) + (length_carry length_to_list). - Definition max_bound i := Z.ones (log_cap i). + Definition rep (us : digits) (x : F modulus) := decode us = x. + Local Notation "u ~= x" := (rep u x). + Local Hint Unfold rep. - Fixpoint isFull' us full i := - match i with - | O => andb (Z.ltb (max_bound 0 - c) (nth_default 0 us 0)) full - | S i' => isFull' us (andb (Z.eqb (max_bound i) (nth_default 0 us i)) full) i' - end. + Definition carry_sequence is (us : digits) : digits := fold_right carry us is. - Definition isFull us := isFull' us true (length base - 1)%nat. + Definition carry_full : digits -> digits := carry_sequence (full_carry_chain limb_widths). - Fixpoint modulus_digits' i := - match i with - | O => max_bound i - c + 1 :: nil - | S i' => modulus_digits' i' ++ max_bound i :: nil - end. + Definition carry_mul (us vs : digits) : digits := carry_full (mul us vs). - (* compute at compile time *) - Definition modulus_digits := modulus_digits' (length base - 1). + Definition freeze (us : digits) : digits := + let us' := carry_full (carry_full (carry_full us)) in + from_list (conditional_subtract_modulus [[us']] (ge_modulus [[us']])) + (length_conditional_subtract_modulus length_to_list). - Definition and_term us := if isFull us then max_ones else 0. + Definition eq (x y : digits) : Prop := decode x = decode y. - Definition freeze us := - let us' := carry_full (carry_full (carry_full us)) in - let and_term := and_term us' in - (* [and_term] is all ones if us' is full, so the subtractions subtract q overall. - Otherwise, it's all zeroes, and the subtractions do nothing. *) - map2 (fun x y => x - y) us' (map (Z.land and_term) modulus_digits). + Import Morphisms. + Global Instance eq_Equivalence : Equivalence eq. + Proof. + split; cbv [eq]; repeat intro; congruence. + Qed. -End Canonicalization. +End ModularBaseSystem.
\ No newline at end of file |