(* Source: root/src/ModularArithmetic/ModularBaseSystem.v
   (blob 2f264fa6cb62e7b858c41438444d87ab3226fb90) *)
Require Import Coq.ZArith.Zpower Coq.ZArith.ZArith.
Require Import Coq.Lists.List.
Require Import Crypto.ModularArithmetic.PrimeFieldTheorems.
Require Import Crypto.BaseSystem.
Require Import Crypto.BaseSystemProofs.
Require Import Crypto.ModularArithmetic.ExtendedBaseVector.
Require Import Crypto.ModularArithmetic.Pow2Base.
Require Import Crypto.ModularArithmetic.PseudoMersenneBaseParams.
Require Import Crypto.ModularArithmetic.PseudoMersenneBaseParamProofs.
Require Import Crypto.ModularArithmetic.ModularBaseSystemList.
Require Import Crypto.ModularArithmetic.ModularBaseSystemListProofs.
Require Import Crypto.Util.ListUtil Crypto.Util.CaseUtil Crypto.Util.ZUtil.
Require Import Crypto.Util.Tuple.
Require Import Crypto.Util.AdditionChainExponentiation.
Require Import Crypto.Util.Notations.
Require Import Crypto.Tactics.VerdiTactics.
Local Open Scope Z_scope.

(* Tuple-based interface to the modular base system.

   Numbers are represented as fixed-length tuples of [Z] limbs, with one
   entry per element of [limb_widths]. Most operations below follow the
   same pattern:
   - convert the tuple arguments to lists with [to_list];
   - apply the list-level operation of the same name (presumably the one
     from ModularBaseSystemList);
   - convert the result back with [from_list], passing a proof that the
     resulting list still has [length limb_widths] entries. *)
Section ModularBaseSystem.
  Context `{prm :PseudoMersenneBaseParams}.
  Local Notation base := (base_from_limb_widths limb_widths).
  (* The tuple representation: one [Z] limb per entry of [limb_widths]. *)
  Local Notation digits := (tuple Z (length limb_widths)).
  (* Make the element type and tuple length implicit in the tuple/list
     conversions and in the length lemma, so the definitions below only
     pass the list and the length proof. *)
  Local Arguments to_list {_ _} _.
  Local Arguments from_list {_ _} _ _.
  Local Arguments length_to_list {_ _ _}.
  (* Shorthand for viewing a tuple as a list. *)
  Local Notation "[[ u ]]" := (to_list u).

  (* Interpret a tuple of limbs as an element of [F modulus] by decoding
     its list form. *)
  Definition decode (us : digits) : F modulus := decode [[us]].

  (* Represent a field element as a tuple; [length_encode] certifies that
     the list-level encoding has the right length. *)
  Definition encode (x : F modulus) : digits := from_list (encode x) length_encode.

  (* Addition via the list-level [add]; both inputs have length
     [length limb_widths], which is what the length proof consumes. *)
  Definition add (us vs : digits) : digits := from_list (add [[us]] [[vs]])
    (add_same_length _ _ _ length_to_list length_to_list).

  (* Multiplication via the list-level [mul]. *)
  Definition mul (us vs : digits) : digits := from_list (mul [[us]] [[vs]])
    (length_mul length_to_list length_to_list).

  (* Subtraction via the list-level [sub], which takes an extra
     [modulus_multiple] tuple. NOTE(review): presumably this is a
     representation of a multiple of the modulus added to keep limbs
     nonnegative; confirm against the list-level definition. *)
  Definition sub (modulus_multiple us vs : digits) : digits :=
   from_list (sub [[modulus_multiple]] [[us]] [[vs]])
   (length_sub length_to_list length_to_list length_to_list).

  (* Encodings of the field constants 0 and 1. *)
  Definition zero : digits := encode (F.of_Z _ 0).

  Definition one : digits := encode (F.of_Z _ 1).

  (* Placeholder: computes the negation by decoding to [F modulus],
     applying [F.opp], and re-encoding, rather than working on limbs. *)
  Definition opp (x : digits) : digits := encode (F.opp (decode x)).

  (* Exponentiation along an addition chain: [fold_chain] is given [one],
     [mul], the chain, and the initial list [x :: nil]. The resulting
     exponent is presumably the value obtained by running the same chain
     with addition starting from 1 (compare the hypothesis of [inv]);
     see AdditionChainExponentiation. *)
  Definition pow (x : digits) (chain : list (nat * nat)) : digits :=
    fold_chain one mul chain (x :: nil).

  (* Inversion as [pow x chain], where [chain_correct] states that running
     [chain] additively from 1 yields [Z.to_N (modulus - 2)]. By Fermat's
     little theorem this is the inverse of nonzero [x], assuming [modulus]
     is prime (not established in this section). [chain_correct] is not
     used in the body; presumably it is carried so that correctness proofs
     can rely on it. *)
  Definition inv (chain : list (nat * nat))
                 (chain_correct : fold_chain 0%N N.add chain (1%N :: nil) = Z.to_N (modulus - 2))
                 (x : digits) : digits := pow x chain.

  (* Placeholder: computes the quotient by decoding both arguments,
     applying [F.div], and re-encoding, rather than working on limbs. *)
  Definition div (x y : digits) : digits := encode (F.div (decode x) (decode y)).

  (* Multiply, then run the list-level [carry_sequence] over
     [carry_chain] on the product. NOTE(review): presumably [carry_chain]
     lists the limb indices to carry from, in order; confirm in the
     list-level definition. *)
  Definition carry_mul (carry_chain : list nat) (us vs : digits) : digits :=
    from_list (carry_sequence carry_chain [[mul us vs]]) (length_carry_sequence length_to_list).
  
  (* [us ~= x] holds exactly when [us] decodes to [x]. *)
  Definition rep (us : digits) (x : F modulus) := decode us = x.
  Local Notation "u ~= x" := (rep u x).
  (* Let [auto] unfold [rep] within this section. *)
  Local Hint Unfold rep.

  (* Full carry via the list-level [carry_full]. *)
  Definition carry_full (us : digits) : digits := from_list (carry_full [[us]])
    (length_carry_full length_to_list).

  (* Apply [carry_full] three times, then conditionally subtract the
     modulus, using the flag computed by [ge_modulus] on the carried
     limbs. Presumably this yields a canonical representative; the need
     for exactly three carry passes is not justified here, so check the
     accompanying proofs. *)
  Definition freeze (us : digits) : digits :=
    let us' := carry_full (carry_full (carry_full us)) in
     from_list (conditional_subtract_modulus [[us']] (ge_modulus [[us']]))
     (length_conditional_subtract_modulus length_to_list).

  (* Two tuples are considered equal when they decode to the same field
     element. *)
  Definition eq (x y : digits) : Prop := decode x = decode y.

  Import Morphisms.
  (* [eq] is an equivalence relation, since it is equality of [decode]
     results. *)
  Global Instance eq_Equivalence : Equivalence eq.
  Proof.
    split; cbv [eq]; repeat intro; congruence.
  Qed.

  (* Parameters for repacking into a different limb layout
     [target_widths]. Every target width must be nonnegative, and both
     layouts must cover the same total number of bits (the sum of all of
     their widths). *)
  Context {target_widths} (target_widths_nonneg : forall x, In x target_widths -> 0 <= x)
          (bits_eq : sum_firstn limb_widths   (length limb_widths) =
                     sum_firstn target_widths (length target_widths)).
  (* A tuple with one entry per element of [target_widths]. *)
  Local Notation target_digits := (tuple Z (length target_widths)).

  (* Repack from the [limb_widths] layout into the [target_widths] layout
     via the list-level [pack]. *)
  Definition pack (x : digits) : target_digits :=
    from_list (pack target_widths_nonneg bits_eq [[x]]) length_pack.
  
  (* Repack from the [target_widths] layout back into the [limb_widths]
     layout via the list-level [unpack]. *)
  Definition unpack (x : target_digits) : digits :=
    from_list (unpack target_widths_nonneg bits_eq [[x]]) length_unpack.

End ModularBaseSystem.