blob: e64ed5d0f278665660ea3097bf436efeac90a0b0 (
plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
|
Require Import Coq.ZArith.Zpower Coq.ZArith.ZArith.
Require Import Coq.Numbers.Natural.Peano.NPeano.
Require Import Coq.Lists.List.
Require Import Crypto.Util.ListUtil Crypto.Util.CaseUtil Crypto.Util.ZUtil.
Require Import Crypto.ModularArithmetic.PrimeFieldTheorems.
Require Import Crypto.BaseSystem.
Require Import Crypto.ModularArithmetic.PseudoMersenneBaseParams.
Require Import Crypto.ModularArithmetic.PseudoMersenneBaseParamProofs.
Require Import Crypto.ModularArithmetic.ExtendedBaseVector.
Require Import Crypto.Tactics.VerdiTactics.
Require Import Crypto.Util.LetIn.
Require Import Crypto.Util.Notations.
Require Import Crypto.ModularArithmetic.Pow2Base.
Require Import Crypto.ModularArithmetic.Conversion.
Local Open Scope Z_scope.
(* Arithmetic on numbers represented as lists of [digits] with respect to
   [base_from_limb_widths limb_widths], with reduction modulo [modulus] using
   the parameters from [PseudoMersenneBaseParams]. *)
Section Defs.
(* [modulus_multiple] is added to the minuend in [sub] before subtracting;
   presumably it encodes a multiple of [modulus] with digits large enough
   that no digit underflows -- TODO confirm against the correctness proofs. *)
Context `{prm :PseudoMersenneBaseParams} (modulus_multiple : digits).
Local Notation base := (base_from_limb_widths limb_widths).
(* [u [ i ]] reads digit [i] of [u] with [nth_default], using 0 as the
   default value. *)
Local Notation "u [ i ]" := (nth_default 0 u i).
(* [decode us] evaluates [us] as an integer in [base] and injects the result
   into the field [F modulus]. *)
Definition decode (us : digits) := F.of_Z modulus (BaseSystem.decode base us).
(* [encode x] splits the integer representative [F.to_Z x] into digits
   according to [limb_widths]. *)
Definition encode (x : F modulus) := encodeZ limb_widths (F.to_Z x).
(* Converts from length of extended base to length of base by reduction modulo
   [modulus]: every digit at position [length limb_widths] or beyond is
   multiplied by [c] and added (via [BaseSystem.add]) to the low digits.
   This presumably relies on [modulus] having the form [2^k - c], so that
   [2^k] is congruent to [c] -- see [PseudoMersenneBaseParams].
   No carrying is performed here, so result digits may exceed their limb
   widths. *)
Definition reduce (us : digits) : digits :=
let high := skipn (length limb_widths) us in
let low := firstn (length limb_widths) us in
let wrap := map (Z.mul c) high in
BaseSystem.add low wrap.
(* [mul us vs] multiplies the digit lists with respect to
   [ext_base limb_widths] (presumably sized to hold the longer product) and
   then folds the high digits back down with [reduce]. *)
Definition mul (us vs : digits) := reduce (BaseSystem.mul (ext_base limb_widths) us vs).
(* In order to subtract without underflowing, we add a multiple of the modulus first.
   [add] here is the digit-wise [BaseSystem.add]; no carrying is done. *)
Definition sub (us vs : digits) := BaseSystem.sub (add modulus_multiple us) vs.
(* [carry_and_reduce] multiplies the carried value by [c] and, if carrying
   from index [i], adds the result to the digit with index
   [(S i) mod (length limb_widths)]. For the most significant index this
   wraps around to digit 0. *)
Definition carry_and_reduce :=
carry_gen limb_widths (fun ci => c * ci) (fun Si => (Si mod (length limb_widths))%nat).
(* [carry i] carries out of digit [i]. Only the most significant index
   [pred (length limb_widths)] uses [carry_and_reduce], whose carry wraps
   around to digit 0 scaled by [c]. Every other index uses [carry_simple]. *)
Definition carry i : digits -> digits :=
if eq_nat_dec i (pred (length limb_widths))
then carry_and_reduce i
else carry_simple limb_widths i.
(* Applies [carry] once for each index in [is]. Because this is a
   [fold_right], the LAST index in [is] is carried first and the head of
   [is] is carried last. *)
Definition carry_sequence is (us : digits) : digits := fold_right carry us is.
(* Carries along the index chain [full_carry_chain limb_widths]. *)
Definition carry_full : digits -> digits := carry_sequence (full_carry_chain limb_widths).
(* [modulus] split into digits according to [limb_widths]. *)
Definition modulus_digits := encodeZ limb_widths modulus.
(* Constant-time comparison with modulus; only works if all digits of [us]
   are less than 2 ^ their respective limb width.
   [ge_modulus' f us result i] walks from digit [i] down to digit 1. At each
   step it keeps the running value only if the digit of [us] EQUALS the
   corresponding digit of [modulus_digits], and replaces it with 0
   otherwise. At digit 0 it keeps the value only if
   [modulus_digits [0] <= us [0]]. The final value is then passed to [f].
   NOTE(review): the upper digits are compared for equality rather than
   [>=]; presumably this suffices because, under the precondition above, the
   upper digits of a pseudo-Mersenne modulus are already the maximum value for
   their limb widths -- confirm. *)
Fixpoint ge_modulus' {A} (f : Z -> A) us (result : Z) i :=
(* [dlet] is the let-binding notation from Crypto.Util.LetIn; presumably it
   marks a binding to be preserved by later simplification -- confirm. *)
dlet r := result in
match i return A with
| O => dlet x := if Z.leb (modulus_digits [0]) (us [0])
then r
else 0 in f x
| S i' => ge_modulus' f us
(if Z.eqb (modulus_digits [i]) (us [i])
then r
else 0) i'
end.
(* Returns 1 when [ge_modulus'] judges [us] to be at least [modulus],
   starting from the most significant digit index, and 0 otherwise. *)
Definition ge_modulus us := ge_modulus' id us 1 (length limb_widths - 1)%nat.
(* analogous to NEG assembly instruction on an integer that is 0 or 1:
neg 1 = 2^64 - 1 (on 64-bit; 2^32-1 on 32-bit, etc.)
neg 0 = 0
Any [b] other than 1 also yields 0. *)
Definition neg (int_width : Z) (b : Z) := if b =? 1 then Z.ones int_width else 0.
(* Subtracts [modulus_digits] from [us] digit-wise when [cond] is 1, and
   subtracts nothing otherwise; no carrying is performed. *)
Definition conditional_subtract_modulus int_width (us : digits) (cond : Z) :=
(* The mask [neg int_width cond] is [Z.ones int_width] when [cond] is 1, so
   each [Z.land] keeps the low [int_width] bits of a digit of
   [modulus_digits] (the whole digit, presumably, as long as every modulus
   digit fits in [int_width] bits), and the subtractions subtract [modulus]
   overall. Otherwise the mask is 0 and every subtracted term is 0. *)
map2 (fun x y => x - y) us (map (Z.land (neg int_width cond)) modulus_digits).
(* [freeze] runs [carry_full] three times, then subtracts [modulus] once if
   [ge_modulus] reports the carried value is at least [modulus].
   Presumably intended to produce the canonical representative in
   [0, modulus) -- confirm against the correctness proofs. *)
Definition freeze int_width (us : digits) : digits :=
let us' := carry_full (carry_full (carry_full us)) in
conditional_subtract_modulus int_width us' (ge_modulus us').
(* Limb widths of the alternative representation used by [pack]/[unpack].
   [bits_eq] requires both representations to cover the same total number
   of bits. *)
Context {target_widths} (target_widths_nonneg : forall x, In x target_widths -> 0 <= x)
(bits_eq : sum_firstn limb_widths (length limb_widths) =
sum_firstn target_widths (length target_widths)).
(* Re-expresses digits over [limb_widths] as digits over [target_widths]
   using [convert]; the required bit-count inequality is derived from
   [bits_eq]. *)
Definition pack := @convert limb_widths limb_widths_nonneg
target_widths target_widths_nonneg
(Z.eq_le_incl _ _ bits_eq).
(* Inverse direction of [pack]: digits over [target_widths] back to digits
   over [limb_widths], using the symmetric form of [bits_eq]. *)
Definition unpack := @convert target_widths target_widths_nonneg
limb_widths limb_widths_nonneg
(Z.eq_le_incl _ _ (Z.eq_sym bits_eq)).
End Defs.
|