aboutsummaryrefslogtreecommitdiff
path: root/src/ModularArithmetic/ModularBaseSystem.v
blob: 1769f86c4098c4d4b21a12cbd588e67f3927be6d (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
Require Import Coq.ZArith.Zpower Coq.ZArith.ZArith.
Require Import Coq.Lists.List.
Require Import Crypto.ModularArithmetic.PrimeFieldTheorems.
Require Import Crypto.BaseSystem.
Require Import Crypto.BaseSystemProofs.
Require Import Crypto.ModularArithmetic.ExtendedBaseVector.
Require Import Crypto.ModularArithmetic.Pow2Base.
Require Import Crypto.ModularArithmetic.PseudoMersenneBaseParams.
Require Import Crypto.ModularArithmetic.PseudoMersenneBaseParamProofs.
Require Import Crypto.ModularArithmetic.ModularBaseSystemList.
Require Import Crypto.ModularArithmetic.ModularBaseSystemListProofs.
Require Import Crypto.Util.ListUtil Crypto.Util.CaseUtil Crypto.Util.ZUtil.
Require Import Crypto.Util.Tuple.
Require Import Crypto.Util.AdditionChainExponentiation.
Require Import Crypto.Util.Notations.
Require Import Crypto.Tactics.VerdiTactics.
Local Open Scope Z_scope.

Section ModularBaseSystem.
  (* Tuple-based wrappers around the list-based modular arithmetic
     operations. A field element is represented as [digits], a tuple of Z
     with exactly [length limb_widths] entries. Most definitions convert
     their tuple arguments to lists with [to_list], call the list-level
     operation of the same name, and convert the result back with
     [from_list], passing the length lemma that [from_list] requires. *)
  Context `{prm :PseudoMersenneBaseParams}.
  (* [base] is the list computed from [limb_widths] by
     [base_from_limb_widths]. *)
  Local Notation base := (base_from_limb_widths limb_widths).
  Local Notation digits := (tuple Z (length limb_widths)).
  (* Make the leading arguments of the tuple/list conversion functions
     (presumably the element type and the length) implicit, so that
     [to_list u] and [from_list l proof] can be written directly. *)
  Local Arguments to_list {_ _} _.
  Local Arguments from_list {_ _} _ _.
  Local Arguments length_to_list {_ _ _}.
  Local Notation "[[ u ]]" := (to_list u).

  (* The field element represented by [us], computed by the list-level
     [decode] (a Definition does not see its own name in its body). *)
  Definition decode (us : digits) : F modulus := decode [[us]].

  (* Representation of the field element [x], produced by the list-level
     [encode]; [length_encode] proves the list has the required length. *)
  Definition encode (x : F modulus) : digits := from_list (encode x) length_encode.

  (* Result of the list-level [add] on the two limb lists;
     [add_same_length] proves the result keeps the common length. *)
  Definition add (us vs : digits) : digits := from_list (add [[us]] [[vs]])
    (add_same_length _ _ _ length_to_list length_to_list).

  (* Result of the list-level [mul] on the two limb lists. This definition
     does not itself call [carry_sequence]; see [carry_mul]. *)
  Definition mul (us vs : digits) : digits := from_list (mul [[us]] [[vs]])
    (length_mul length_to_list length_to_list).

  (* Result of the list-level [sub] on [modulus_multiple], [us] and [vs].
     [modulus_multiple_correct] is not used in the computation. It requires
     [modulus_multiple] to represent 0, so that (presumably) adding it to
     keep limbs from going negative does not change the represented
     value. *)
  Definition sub (modulus_multiple: digits)
                 (modulus_multiple_correct : decode modulus_multiple = 0%F)
                 (us vs : digits) : digits :=
   from_list (sub [[modulus_multiple]] [[us]] [[vs]])
   (length_sub length_to_list length_to_list length_to_list).

  (* Encoding of the field constant 0. *)
  Definition zero : digits := encode (F.of_Z _ 0).

  (* Encoding of the field constant 1. *)
  Definition one : digits := encode (F.of_Z _ 1).

  (* Additive inverse, computed as [zero] minus [x] with [sub]. *)
  Definition opp (modulus_multiple : digits)
                 (modulus_multiple_correct : decode modulus_multiple = 0%F)
                 (x : digits) :
    digits := sub modulus_multiple modulus_multiple_correct zero x.

  (* Exponentiation by an addition chain. [fold_chain] is given [one] and
     [mul] in the positions where the exponent computation in the
     hypotheses below gives [0%N] and [N.add]. So the exponent a chain
     encodes is [fold_chain 0%N N.add chain (1%N :: nil)]. *)
  Definition pow (x : digits) (chain : list (nat * nat)) : digits :=
    fold_chain one mul chain (x :: nil).

  (* [pow x chain], where [chain_correct] requires the chain to encode the
     exponent [modulus - 2]. When [modulus] is prime and [x] is nonzero,
     x ^ (modulus - 2) is the inverse of x (Fermat's little theorem).
     [chain_correct] is not used in the computation. *)
  Definition inv (chain : list (nat * nat))
                 (chain_correct : fold_chain 0%N N.add chain (1%N :: nil) = Z.to_N (modulus - 2))
                 (x : digits) : digits := pow x chain.

  (* Placeholder: decodes both arguments, divides in [F modulus] with
     [F.div], and re-encodes the quotient. *)
  Definition div (x y : digits) : digits := encode (F.div (decode x) (decode y)).

  (* [mul us vs] followed by [carry_sequence] over [carry_chain]
     (presumably a list of limb indices to carry from). *)
  Definition carry_mul (carry_chain : list nat) (us vs : digits) : digits :=
    from_list (carry_sequence carry_chain [[mul us vs]]) (length_carry_sequence length_to_list).

  (* [us ~= x] holds when [us] decodes to [x]. *)
  Definition rep (us : digits) (x : F modulus) := decode us = x.
  Local Notation "u ~= x" := (rep u x).
  (* The hint database is named explicitly. Hints with no database already
     go to [core], so this matches the old behaviour without the
     deprecated syntax. *)
  Local Hint Unfold rep : core.

  (* Two digit tuples are related when they decode to the same field
     element. *)
  Definition eq (x y : digits) : Prop := decode x = decode y.

  (* Result of the list-level [freeze]. It is presumably a canonical
     representation, which [eqb] relies on; confirm in
     ModularBaseSystemList. *)
  Definition freeze (x : digits) : digits :=
    from_list (freeze [[x]]) (length_freeze length_to_list).

  (* Boolean equality: freezes both arguments and compares the results limb
     by limb with [Z.eqb] via [fieldwiseb]. *)
  Definition eqb (x y : digits) : bool := fieldwiseb Z.eqb (freeze x) (freeze y).

  (* Note: both of the following square root definitions produce garbage
     output if the input is not a square mod [modulus]. The caller should
     either provably only call them with square input, or check that the
     output squared equals the input and case split on the result. *)

  (* [pow x chain], where [chain_correct] requires the exponent
     [modulus / 4 + 1]. That equals (modulus + 1) / 4 when
     modulus mod 4 = 3. *)
  Definition sqrt_3mod4 (chain : list (nat * nat))
                  (chain_correct : fold_chain 0%N N.add chain (1%N :: nil) = Z.to_N (modulus / 4 + 1))
                  (x : digits) : digits := pow x chain.

  (* sqrt_5mod8 is parameterized over implementations of [mul] and [pow]
     because it relies on bounds-checking for these two functions. That is
     much easier for simplified implementations than for the more general
     ones defined here.

     Let [b] be [pow_ x chain]. The exponent is [modulus / 8 + 1], which
     equals (modulus + 3) / 8 when modulus mod 8 = 5. The function returns
     [b] if [mul_ b b] equals [x] according to [eqb], and
     [mul_ sqrt_minus1 b] otherwise. *)
  Definition sqrt_5mod8 mul_ pow_ (chain : list (nat * nat))
                  (chain_correct : fold_chain 0%N N.add chain (1%N :: nil) = Z.to_N (modulus / 8 + 1))
                  (sqrt_minus1 x : digits) : digits :=
    let b := pow_ x chain in if eqb (mul_ b b) x then b else mul_ sqrt_minus1 b.

  Import Morphisms.
  (* [eq] is an equivalence relation because it is equality of decoded
     values. *)
  Global Instance eq_Equivalence : Equivalence eq.
  Proof.
    split; cbv [eq]; repeat intro; congruence.
  Qed.

  (* Parameters for [pack] and [unpack]: a second list of limb widths, all
     nonnegative. [bits_eq] states that [sum_firstn] over the full length
     of each list is the same for [limb_widths] and [target_widths]. *)
  Context {target_widths} (target_widths_nonneg : forall x, In x target_widths -> 0 <= x)
          (bits_eq : sum_firstn limb_widths   (length limb_widths) =
                     sum_firstn target_widths (length target_widths)).
  Local Notation target_digits := (tuple Z (length target_widths)).

  (* Converts digits in the [limb_widths] layout to a tuple laid out by
     [target_widths], using the list-level [pack]. *)
  Definition pack (x : digits) : target_digits :=
    from_list (pack target_widths_nonneg bits_eq [[x]]) length_pack.

  (* Converts from the [target_widths] layout back to [digits], using the
     list-level [unpack]. *)
  Definition unpack (x : target_digits) : digits :=
    from_list (unpack target_widths_nonneg bits_eq [[x]]) length_unpack.

End ModularBaseSystem.