aboutsummaryrefslogtreecommitdiff
path: root/src/ModularArithmetic/ModularBaseSystem.v
diff options
context:
space:
mode:
Diffstat (limited to 'src/ModularArithmetic/ModularBaseSystem.v')
-rw-r--r--src/ModularArithmetic/ModularBaseSystem.v84
1 files changed, 44 insertions, 40 deletions
diff --git a/src/ModularArithmetic/ModularBaseSystem.v b/src/ModularArithmetic/ModularBaseSystem.v
index 558b9a5a2..ca8c19d18 100644
--- a/src/ModularArithmetic/ModularBaseSystem.v
+++ b/src/ModularArithmetic/ModularBaseSystem.v
@@ -11,14 +11,14 @@ Local Open Scope Z_scope.
Section PseudoMersenneBase.
Context `{prm :PseudoMersenneBaseParams}.
-
+
Definition decode (us : digits) : F modulus := ZToField (BaseSystem.decode base us).
-
- Definition rep (us : digits) (x : F modulus) := (length us <= length base)%nat /\ decode us = x.
+
+ Definition rep (us : digits) (x : F modulus) := (length us = length base)%nat /\ decode us = x.
Local Notation "u '~=' x" := (rep u x) (at level 70).
Local Hint Unfold rep.
- Definition encode (x : F modulus) := encode x.
+ Definition encode (x : F modulus) := encode x ++ BaseSystem.zeros (length base - 1)%nat.
(* Converts from length of extended base to length of base by reduction modulo M.*)
Definition reduce (us : digits) : digits :=
@@ -35,13 +35,13 @@ Section PseudoMersenneBase.
End PseudoMersenneBase.
Section CarryBasePow2.
- Context `{prm :PseudoMersenneBaseParams}.
+ Context `{prm :PseudoMersenneBaseParams}.
Definition log_cap i := nth_default 0 limb_widths i.
Definition add_to_nth n (x:Z) xs :=
set_nth n (x + nth_default 0 xs n) xs.
-
+
Definition pow2_mod n i := Z.land n (Z.ones i).
Definition carry_simple i := fun us =>
@@ -54,64 +54,68 @@ Section CarryBasePow2.
let us' := set_nth i (pow2_mod di (log_cap i)) us in
add_to_nth 0 (c * (Z.shiftr di (log_cap i))) us'.
- Definition carry i : digits -> digits :=
+ Definition carry i : digits -> digits :=
if eq_nat_dec i (pred (length base))
then carry_and_reduce i
else carry_simple i.
Definition carry_sequence is us := fold_right carry us is.
-End CarryBasePow2.
-
-Section Canonicalization.
- Context `{prm :PseudoMersenneBaseParams}.
-
Fixpoint make_chain i :=
match i with
| O => nil
| S i' => i' :: make_chain i'
end.
- (* compute at compile time *)
Definition full_carry_chain := make_chain (length limb_widths).
- (* compute at compile time *)
- Definition max_ones := Z.ones
- ((fix loop current_max lw :=
- match lw with
- | nil => current_max
- | w :: lw' => loop (Z.max w current_max) lw'
- end
- ) 0 limb_widths).
-
- (* compute at compile time? *)
Definition carry_full := carry_sequence full_carry_chain.
+ Definition carry_mul us vs := carry_full (mul us vs).
+
+End CarryBasePow2.
+
+Section Canonicalization.
+ Context `{prm :PseudoMersenneBaseParams}.
+
+ (* compute at compile time *)
+ Definition max_ones := Z.ones (fold_right Z.max 0 limb_widths).
+
Definition max_bound i := Z.ones (log_cap i).
- Definition isFull us :=
- (fix loop full i :=
- match i with
- | O => full (* don't test 0; the test for 0 is the initial value of [full]. *)
- | S i' => loop (andb (Z.eqb (max_bound i) (nth_default 0 us i)) full) i'
- end
- ) (Z.ltb (max_bound 0 - (c + 1)) (nth_default 0 us 0)) (length us - 1)%nat.
+ Fixpoint isFull' us full i :=
+ match i with
+ | O => andb (Z.ltb (max_bound 0 - c) (nth_default 0 us 0)) full
+ | S i' => isFull' us (andb (Z.eqb (max_bound i) (nth_default 0 us i)) full) i'
+ end.
+
+ Definition isFull us := isFull' us true (length base - 1)%nat.
- Fixpoint range' n m :=
- match m with
- | O => nil
- | S m' => (n - m)%nat :: range' n m'
+ Fixpoint modulus_digits' i :=
+ match i with
+ | O => max_bound i - c + 1 :: nil
+ | S i' => modulus_digits' i' ++ max_bound i :: nil
end.
- Definition range n := range' n n.
+ (* compute at compile time *)
+ Definition modulus_digits := modulus_digits' (length base - 1).
+
+ Fixpoint map2 {A B C} (f : A -> B -> C) (la : list A) (lb : list B) : list C :=
+ match la with
+ | nil => nil
+ | a :: la' => match lb with
+ | nil => nil
+ | b :: lb' => f a b :: map2 f la' lb'
+ end
+ end.
+
+ Definition and_term us := if isFull us then max_ones else 0.
- Definition land_max_bound and_term i := Z.land and_term (max_bound i).
-
Definition freeze us :=
let us' := carry_full (carry_full (carry_full us)) in
- let and_term := if isFull us' then max_ones else 0 in
+ let and_term := and_term us' in
(* [and_term] is all ones if us' is full, so the subtractions subtract q overall.
Otherwise, it's all zeroes, and the subtractions do nothing. *)
- map (fun x => (snd x) - land_max_bound and_term (fst x)) (combine (range (length us')) us').
-
+ map2 (fun x y => x - y) us' (map (Z.land and_term) modulus_digits).
+
End Canonicalization.