diff options
Diffstat (limited to 'src/ModularArithmetic/ModularBaseSystem.v')
-rw-r--r-- | src/ModularArithmetic/ModularBaseSystem.v | 84 |
1 files changed, 44 insertions, 40 deletions
diff --git a/src/ModularArithmetic/ModularBaseSystem.v b/src/ModularArithmetic/ModularBaseSystem.v index 558b9a5a2..ca8c19d18 100644 --- a/src/ModularArithmetic/ModularBaseSystem.v +++ b/src/ModularArithmetic/ModularBaseSystem.v @@ -11,14 +11,14 @@ Local Open Scope Z_scope. Section PseudoMersenneBase. Context `{prm :PseudoMersenneBaseParams}. - + Definition decode (us : digits) : F modulus := ZToField (BaseSystem.decode base us). - - Definition rep (us : digits) (x : F modulus) := (length us <= length base)%nat /\ decode us = x. + + Definition rep (us : digits) (x : F modulus) := (length us = length base)%nat /\ decode us = x. Local Notation "u '~=' x" := (rep u x) (at level 70). Local Hint Unfold rep. - Definition encode (x : F modulus) := encode x. + Definition encode (x : F modulus) := encode x ++ BaseSystem.zeros (length base - 1)%nat. (* Converts from length of extended base to length of base by reduction modulo M.*) Definition reduce (us : digits) : digits := @@ -35,13 +35,13 @@ Section PseudoMersenneBase. End PseudoMersenneBase. Section CarryBasePow2. - Context `{prm :PseudoMersenneBaseParams}. + Context `{prm :PseudoMersenneBaseParams}. Definition log_cap i := nth_default 0 limb_widths i. Definition add_to_nth n (x:Z) xs := set_nth n (x + nth_default 0 xs n) xs. - + Definition pow2_mod n i := Z.land n (Z.ones i). Definition carry_simple i := fun us => @@ -54,64 +54,68 @@ Section CarryBasePow2. let us' := set_nth i (pow2_mod di (log_cap i)) us in add_to_nth 0 (c * (Z.shiftr di (log_cap i))) us'. - Definition carry i : digits -> digits := + Definition carry i : digits -> digits := if eq_nat_dec i (pred (length base)) then carry_and_reduce i else carry_simple i. Definition carry_sequence is us := fold_right carry us is. -End CarryBasePow2. - -Section Canonicalization. - Context `{prm :PseudoMersenneBaseParams}. - Fixpoint make_chain i := match i with | O => nil | S i' => i' :: make_chain i' end. - (* compute at compile time *) Definition full_carry_chain := make_chain (length limb_widths). - (* compute at compile time *) - Definition max_ones := Z.ones - ((fix loop current_max lw := - match lw with - | nil => current_max - | w :: lw' => loop (Z.max w current_max) lw' - end - ) 0 limb_widths). - - (* compute at compile time? *) Definition carry_full := carry_sequence full_carry_chain. + Definition carry_mul us vs := carry_full (mul us vs). + +End CarryBasePow2. + +Section Canonicalization. + Context `{prm :PseudoMersenneBaseParams}. + + (* compute at compile time *) + Definition max_ones := Z.ones (fold_right Z.max 0 limb_widths). + Definition max_bound i := Z.ones (log_cap i). - Definition isFull us := - (fix loop full i := - match i with - | O => full (* don't test 0; the test for 0 is the initial value of [full]. *) - | S i' => loop (andb (Z.eqb (max_bound i) (nth_default 0 us i)) full) i' - end - ) (Z.ltb (max_bound 0 - (c + 1)) (nth_default 0 us 0)) (length us - 1)%nat. + Fixpoint isFull' us full i := + match i with + | O => andb (Z.ltb (max_bound 0 - c) (nth_default 0 us 0)) full + | S i' => isFull' us (andb (Z.eqb (max_bound i) (nth_default 0 us i)) full) i' + end. + + Definition isFull us := isFull' us true (length base - 1)%nat. - Fixpoint range' n m := - match m with - | O => nil - | S m' => (n - m)%nat :: range' n m' + Fixpoint modulus_digits' i := + match i with + | O => max_bound i - c + 1 :: nil + | S i' => modulus_digits' i' ++ max_bound i :: nil end. - Definition range n := range' n n. + (* compute at compile time *) + Definition modulus_digits := modulus_digits' (length base - 1). + + Fixpoint map2 {A B C} (f : A -> B -> C) (la : list A) (lb : list B) : list C := + match la with + | nil => nil + | a :: la' => match lb with + | nil => nil + | b :: lb' => f a b :: map2 f la' lb' + end + end. + + Definition and_term us := if isFull us then max_ones else 0. - Definition land_max_bound and_term i := Z.land and_term (max_bound i). - Definition freeze us := let us' := carry_full (carry_full (carry_full us)) in - let and_term := if isFull us' then max_ones else 0 in + let and_term := and_term us' in (* [and_term] is all ones if us' is full, so the subtractions subtract q overall. Otherwise, it's all zeroes, and the subtractions do nothing. *) - map (fun x => (snd x) - land_max_bound and_term (fst x)) (combine (range (length us')) us'). - + map2 (fun x y => x - y) us' (map (Z.land and_term) modulus_digits). + End Canonicalization. |