(* src/ModularArithmetic/ModularBaseSystemList.v
   (blob 8cce5481ce9739f675a27ff031101b72514d928f) *)
Require Import Coq.ZArith.Zpower Coq.ZArith.ZArith.
Require Import Coq.Numbers.Natural.Peano.NPeano.
Require Import Coq.Lists.List.
Require Import Crypto.Util.ListUtil Crypto.Util.CaseUtil Crypto.Util.ZUtil.
Require Import Crypto.ModularArithmetic.PrimeFieldTheorems.
Require Import Crypto.BaseSystem.
Require Import Crypto.ModularArithmetic.PseudoMersenneBaseParams.
Require Import Crypto.ModularArithmetic.PseudoMersenneBaseParamProofs.
Require Import Crypto.ModularArithmetic.ExtendedBaseVector.
Require Import Crypto.ModularArithmetic.ModularBaseSystemListZOperations.
Require Import Crypto.Tactics.VerdiTactics.
Require Import Crypto.Util.LetIn.
Require Import Crypto.Util.Notations.
Require Import Crypto.ModularArithmetic.Pow2Base.
Require Import Crypto.ModularArithmetic.Conversion.
Local Open Scope Z_scope.

(** Executable list-based operations for a pseudo-Mersenne modular base
    system: encoding/decoding, reduction, multiplication, subtraction,
    carrying, freezing, and repacking into other limb widths.

    NOTE(review): proofs elsewhere presumably unfold these definitions, so
    their exact term structure matters; change the code here with care. *)
Section Defs.
  (* [modulus_multiple] is referenced only by [sub] below. *)
  Context `{prm :PseudoMersenneBaseParams} (modulus_multiple : digits).
  (* Positional weights derived from the limb widths. *)
  Local Notation base := (base_from_limb_widths limb_widths).
  (* Digit access via [nth_default 0] (presumably 0 for an out-of-range
     index). *)
  Local Notation "u [ i ]" := (nth_default 0 u i).

  (* Interprets [us] as digits over [base] with [BaseSystem.decode] and maps
     the resulting integer into [F modulus]. *)
  Definition decode (us : digits) := F.of_Z modulus (BaseSystem.decode base us).

  (* Encodes the integer [F.to_Z x] as digits with [encodeZ limb_widths]. *)
  Definition encode (x : F modulus) := encodeZ limb_widths (F.to_Z x).

  (* Converts from the length of the extended base to the length of the base
     by reduction modulo [modulus]: the digits at positions
     [length limb_widths] and above are each multiplied by [c] and added to
     the low digits with [BaseSystem.add].
     NOTE(review): correctness presumably relies on the pseudo-Mersenne
     relation between [c] and [modulus] from [PseudoMersenneBaseParams]. *)
  Definition reduce (us : digits) : digits :=
    let high := skipn (length limb_widths) us in
    let low := firstn (length limb_widths) us in
    let wrap := map (Z.mul c) high in
    BaseSystem.add low wrap.

  (* Multiplies [us] and [vs] over the extended base, then folds the high
     digits back down into the base with [reduce]. *)
  Definition mul (us vs : digits) := reduce (BaseSystem.mul (ext_base limb_widths) us vs).

  (* In order to subtract without underflowing, we first add
     [modulus_multiple] to [us] and only then subtract [vs] with
     [BaseSystem.sub].
     NOTE(review): [modulus_multiple] is assumed to encode a multiple of the
     modulus whose digits are large enough to dominate those of [vs];
     confirm with callers. *)
  Definition sub (us vs : digits) := BaseSystem.sub (add modulus_multiple us) vs.

  (* [carry_and_reduce i] carries out of digit [i] with [carry_gen]: the
     carried value is multiplied by [c] and added to the digit with index
     [(S i) mod (length limb_widths)]. For the last index this wraps around
     to digit 0. *)
  Definition carry_and_reduce :=
    carry_gen limb_widths (fun ci => c * ci) (fun Si => (Si mod (length limb_widths))%nat).

  (* Carries out of digit [i]. Only the last index
     ([pred (length limb_widths)]) uses the wrap-around [carry_and_reduce];
     every other index uses [carry_simple]. *)
  Definition carry i : digits -> digits :=
    if eq_nat_dec i (pred (length limb_widths))
    then carry_and_reduce i
    else carry_simple limb_widths i.

  (* Applies [carry] at each index in [is]. Because this is a [fold_right],
     the LAST index of [is] is carried first and the head of [is] last. *)
  Definition carry_sequence is (us : digits) : digits := fold_right carry us is.

  (* One carry pass along the chain [full_carry_chain limb_widths]. *)
  Definition carry_full : digits -> digits := carry_sequence (full_carry_chain limb_widths).

  (* The modulus encoded as digits with [encodeZ limb_widths]. *)
  Definition modulus_digits := encodeZ limb_widths modulus.

  (* Constant-time comparison with modulus; only works if all digits of [us]
    are less than 2 ^ their respective limb width.

    [ge_modulus' f us result i] threads an accumulator [r] (initially
    [result]) from digit [i] down to digit 0. At each index above 0 it becomes
    [cmovne (modulus_digits [i]) (us [i]) r 0]; at index 0 it becomes
    [cmovl (modulus_digits [0]) (us [0]) r 0]. Finally [f] is applied to the
    result. The selection goes through [cmovne]/[cmovl] rather than a Coq
    [if], presumably so that the computed code stays branch-free.
    NOTE(review): the exact semantics of [cmovne]/[cmovl] and of [dlet]
    live in ModularBaseSystemListZOperations and Crypto.Util.LetIn.
    Presumably [cmovne] falls back to 0 when the two digits differ. Because
    digits above index 0 are only compared for equality, this also seems to
    assume a particular shape of [modulus_digits] and a bound on [us];
    confirm against the proofs. *)
  Fixpoint ge_modulus' {A} (f : Z -> A) us (result : Z) i :=
    dlet r := result in
    match i return A with
    | O =>
      dlet x := (cmovl (modulus_digits [0]) (us [0]) r 0) in f x
    | S i' =>
      ge_modulus' f us (cmovne (modulus_digits [i]) (us [i]) r 0) i'
    end.

  (* Runs [ge_modulus'] from the top index [length limb_widths - 1] down to
     0, starting the accumulator at 1 and returning the final value as is
     ([id]). The subtraction is truncated [nat] subtraction, so an empty
     [limb_widths] starts at index 0. *)
  Definition ge_modulus us := ge_modulus' id us 1 (length limb_widths - 1)%nat.

  (* Subtracts [modulus_digits] from [us] digit-wise (via [map2]), after
     masking every modulus digit with [Z.land (neg int_width cond)]. *)
  Definition conditional_subtract_modulus int_width (us : digits) (cond : Z) :=
    (* The mask [neg int_width cond] is expected to be all ones when [cond]
       reports that [us] is at least the modulus, so that the subtractions
       subtract the modulus overall. Otherwise it is expected to be all
       zeroes, and the subtractions do nothing. See
       ModularBaseSystemListZOperations for [neg]. *)
     map2 (fun x y => x - y) us (map (Z.land (neg int_width cond)) modulus_digits).

  (* Applies three rounds of [carry_full] to [us], then calls
     [conditional_subtract_modulus] with [ge_modulus] of the carried digits as
     the condition. [int_width] is forwarded to the mask.
     NOTE(review): three rounds are presumably what is needed to bring every
     digit under its limb width, as [ge_modulus] requires, yielding a
     canonical representative; confirm against the accompanying proofs. *)
  Definition freeze int_width (us : digits) : digits :=
    let us' := carry_full (carry_full (carry_full us)) in
     conditional_subtract_modulus int_width us' (ge_modulus us').

  (* Parameters for repacking: every target width is non-negative, and the
     target widths sum to the same total as [limb_widths]. *)
  Context {target_widths} (target_widths_nonneg : forall x, In x target_widths -> 0 <= x)
          (bits_eq : sum_firstn limb_widths   (length limb_widths) =
                     sum_firstn target_widths (length target_widths)).

  (* Converts digits over [limb_widths] into digits over [target_widths].
     [convert] receives the bound (sum of [limb_widths]) <= (sum of
     [target_widths]), derived from [bits_eq] with [Z.eq_le_incl]. *)
  Definition pack := @convert limb_widths limb_widths_nonneg
                              target_widths target_widths_nonneg
                              (Z.eq_le_incl _ _ bits_eq).

  (* Inverse direction of [pack]: converts digits over [target_widths] back
     into digits over [limb_widths], using the symmetric bound obtained from
     [Z.eq_sym bits_eq]. *)
  Definition unpack := @convert target_widths target_widths_nonneg
                                limb_widths limb_widths_nonneg
                                (Z.eq_le_incl _ _ (Z.eq_sym bits_eq)).

End Defs.