blob: ca8c19d18cd008d5b124573270de3041053920f8 (
plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
|
Require Import Coq.ZArith.Zpower Coq.ZArith.ZArith.
Require Import Coq.Lists.List.
Require Import Crypto.Util.ListUtil Crypto.Util.CaseUtil Crypto.Util.ZUtil.
Require Import Crypto.ModularArithmetic.PrimeFieldTheorems.
Require Import Crypto.BaseSystem.
Require Import Crypto.ModularArithmetic.PseudoMersenneBaseParams.
Require Import Crypto.ModularArithmetic.PseudoMersenneBaseParamProofs.
Require Import Crypto.ModularArithmetic.ExtendedBaseVector.
Require Import Crypto.Tactics.VerdiTactics.
Local Open Scope Z_scope.
(* Representation of elements of [F modulus] as digit lists over [base], and
   the basic operations (encode, reduce, mul, sub) on that representation.
   Unqualified names used below ([base], [ext_base], [modulus], [c],
   [digits], ...) come from the context instance [prm] and the imported
   Crypto modules. *)
Section PseudoMersenneBase.
Context `{prm :PseudoMersenneBaseParams}.
(* Interpret a digit list as an integer with [BaseSystem.decode base] and map
   that integer into the field [F modulus] with [ZToField]. *)
Definition decode (us : digits) : F modulus := ZToField (BaseSystem.decode base us).
(* [rep us x] (written [us ~= x]): [us] has exactly as many digits as [base]
   and decodes to the field element [x]. *)
Definition rep (us : digits) (x : F modulus) := (length us = length base)%nat /\ decode us = x.
Local Notation "u '~=' x" := (rep u x) (at level 70).
Local Hint Unfold rep.
(* The [encode] on the right-hand side is the one already in scope before
   this definition (presumably [BaseSystem.encode] -- TODO confirm). Its
   result is padded with [length base - 1] zero digits. *)
Definition encode (x : F modulus) := encode x ++ BaseSystem.zeros (length base - 1)%nat.
(* Converts a digit list over the extended base back to the length of [base]
   by reduction modulo [modulus]:
   - the digits at positions >= [length base] ([high]) are each multiplied by
     [c];
   - the scaled digits are added digit-wise (BaseSystem.add) to the first
     [length base] digits ([low]).
   This presumably relies on each weight in the upper half of [ext_base]
   being congruent to [c] times the matching weight of [base] modulo
   [modulus] -- TODO confirm against ExtendedBaseVector. *)
Definition reduce (us : digits) : digits :=
let high := skipn (length base) us in
let low := firstn (length base) us in
let wrap := map (Z.mul c) high in
BaseSystem.add low wrap.
(* Multiply two digit lists with [BaseSystem.mul] over [ext_base], then bring
   the product back to the length of [base] with [reduce]. Presumably
   [ext_base] is used so that the unreduced product has room for all its
   digits. *)
Definition mul (us vs : digits) := reduce (BaseSystem.mul ext_base us vs).
(* Subtraction [us - vs], computed as [(xs + us) - vs] with digit-wise
   [add] and [BaseSystem.sub].
   - [xs_0_mod] states that [xs] decodes to a multiple of [modulus]. Given
     that decoding is additive, adding [xs] therefore does not change the
     value modulo [modulus].
   - Presumably [xs] is chosen large enough that the digit-wise subtraction
     stays non-negative -- TODO confirm with callers.
   - The proof argument [xs_0_mod] is not used in the computation itself. *)
Definition sub (xs : digits) (xs_0_mod : (BaseSystem.decode base xs) mod modulus = 0) (us vs : digits) :=
BaseSystem.sub (add xs us) vs.
End PseudoMersenneBase.
(* Carry propagation for digit lists whose limb [i] holds [log_cap i] bits,
   with [limb_widths] from the context instance [prm]. *)
Section CarryBasePow2.
Context `{prm :PseudoMersenneBaseParams}.
(* Bit width of limb [i]; 0 when [i] is past the end of [limb_widths]. *)
Definition log_cap i := nth_default 0 limb_widths i.
(* Add [x] to the digit at position [n] of [xs]. The old digit is read with
   default 0 and written back with [set_nth]. *)
Definition add_to_nth n (x:Z) xs :=
set_nth n (x + nth_default 0 xs n) xs.
(* The low [i] bits of [n]: [n land (2^i - 1)], i.e. [n mod 2^i] for
   non-negative [i]. *)
Definition pow2_mod n i := Z.land n (Z.ones i).
(* Carry out of limb [i] into limb [i+1]:
   - limb [i] keeps only its low [log_cap i] bits;
   - the remaining high part ([Z.shiftr], floor division by 2^(log_cap i))
     is added to limb [S i]. *)
Definition carry_simple i := fun us =>
let di := nth_default 0 us i in
let us' := set_nth i (pow2_mod di (log_cap i)) us in
add_to_nth (S i) ( (Z.shiftr di (log_cap i))) us'.
(* Carry out of limb [i] with wrap-around. It works like [carry_simple],
   except that the high part is multiplied by [c] and added to limb 0
   instead of limb [i+1]. *)
Definition carry_and_reduce i := fun us =>
let di := nth_default 0 us i in
let us' := set_nth i (pow2_mod di (log_cap i)) us in
add_to_nth 0 (c * (Z.shiftr di (log_cap i))) us'.
(* Carry out of limb [i]. The last limb ([pred (length base)]) wraps around
   via [carry_and_reduce]; every other limb uses [carry_simple]. *)
Definition carry i : digits -> digits :=
if eq_nat_dec i (pred (length base))
then carry_and_reduce i
else carry_simple i.
(* Apply [carry] once for each index in [is]. Because of [fold_right], the
   LAST index in [is] is carried first and the first index is carried last. *)
Definition carry_sequence is us := fold_right carry us is.
(* [make_chain n] is the descending list [n-1; n-2; ...; 0]. *)
Fixpoint make_chain i :=
match i with
| O => nil
| S i' => i' :: make_chain i'
end.
(* The indices [length limb_widths - 1] down to 0. Fed to [carry_sequence],
   this carries limb 0 first and the highest limb last. *)
Definition full_carry_chain := make_chain (length limb_widths).
(* One carry pass over every limb, from limb 0 upwards. Presumably
   [length limb_widths = length base], so the final carry is the wrapping
   [carry_and_reduce] -- TODO confirm in PseudoMersenneBaseParams. *)
Definition carry_full := carry_sequence full_carry_chain.
(* Multiply with [mul], then run one full carry pass on the product. *)
Definition carry_mul us vs := carry_full (mul us vs).
End CarryBasePow2.
(* Canonicalization ([freeze]): after carrying, conditionally subtract the
   modulus digit-wise when the digits are at or above it. *)
Section Canonicalization.
Context `{prm :PseudoMersenneBaseParams}.
(* All-ones mask as wide as the widest limb: [Z.ones] of the maximum entry of
   [limb_widths], or of 0 if the list is empty. It depends only on the
   parameters, so it can presumably be evaluated ahead of time (at compile
   time). *)
Definition max_ones := Z.ones (fold_right Z.max 0 limb_widths).
(* Largest value that fits in limb [i]: [2^(log_cap i) - 1]. *)
Definition max_bound i := Z.ones (log_cap i).
(* Scan limbs from index [i] down to 0, and-ing each check into the
   accumulator [full]:
   - every limb at index > 0 must equal [max_bound] at that index;
   - limb 0 must be strictly greater than [max_bound 0 - c]. *)
Fixpoint isFull' us full i :=
match i with
| O => andb (Z.ltb (max_bound 0 - c) (nth_default 0 us 0)) full
| S i' => isFull' us (andb (Z.eqb (max_bound i) (nth_default 0 us i)) full) i'
end.
(* [isFull'] started with [full = true] at the last limb index. It checks
   that [us] is digit-wise at least [modulus_digits]: the upper limbs are
   at their maxima and limb 0 is at least [max_bound 0 - c + 1]. *)
Definition isFull us := isFull' us true (length base - 1)%nat.
(* The digit list [max_bound 0 - c + 1; max_bound 1; ...; max_bound i],
   which has length [i + 1]. *)
Fixpoint modulus_digits' i :=
match i with
| O => max_bound i - c + 1 :: nil
| S i' => modulus_digits' i' ++ max_bound i :: nil
end.
(* Presumably the digit representation of [modulus] over [base]. It depends
   only on the parameters, so it can be evaluated ahead of time (at compile
   time). *)
Definition modulus_digits := modulus_digits' (length base - 1).
(* Apply [f] pointwise to two lists. The result is truncated to the length
   of the shorter input. *)
Fixpoint map2 {A B C} (f : A -> B -> C) (la : list A) (lb : list B) : list C :=
match la with
| nil => nil
| a :: la' => match lb with
| nil => nil
| b :: lb' => f a b :: map2 f la' lb'
end
end.
(* Mask used by [freeze]: [max_ones] when [us] is full, otherwise 0. *)
Definition and_term us := if isFull us then max_ones else 0.
(* Carry three times, then subtract [modulus_digits] digit-wise, with each
   digit masked by [and_term].
   - Presumably three [carry_full] passes suffice to bring every limb within
     its bound -- TODO confirm against the freeze correctness proof.
   - The local [and_term] shadows the global definition of the same name.
   - [map2] truncates, so the result has length
     [min (length us') (length modulus_digits)]. *)
Definition freeze us :=
let us' := carry_full (carry_full (carry_full us)) in
let and_term := and_term us' in
(* If us' is full, [and_term] is [max_ones]. Masking then leaves each entry
   of [modulus_digits] unchanged (assuming every entry is non-negative and
   fits in [max_ones]), so the subtractions remove [modulus] once overall.
   Otherwise [and_term] is 0 and the subtractions do nothing. *)
map2 (fun x y => x - y) us' (map (Z.land and_term) modulus_digits).
End Canonicalization.
|