aboutsummaryrefslogtreecommitdiff
path: root/src/Arithmetic/Saturated/Freeze.v
blob: d8e7f4b5ea948e13ec836738528e94b365f5dbd3 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
Require Import Coq.ZArith.ZArith.
Require Import Coq.Lists.List.
Local Open Scope Z_scope.

Require Import Crypto.Arithmetic.Core.
Require Import Crypto.Arithmetic.Saturated.Core.
Require Import Crypto.Arithmetic.Saturated.Wrappers.
Require Import Crypto.Util.ZUtil.AddGetCarry.
Require Import Crypto.Util.ZUtil.Definitions.
Require Import Crypto.Util.ZUtil.Modulo.PullPush.
Require Import Crypto.Util.ZUtil.Le.
Require Import Crypto.Util.ZUtil.CPS.
Require Import Crypto.Util.Tactics.BreakMatch.
Require Import Crypto.Util.Decidable.
Require Import Crypto.Util.Tuple Crypto.Util.LetIn.
(* [A ^ n] abbreviates [tuple A n], the type of n-element tuples over [A]
   ([tuple] presumably comes from Crypto.Util.Tuple, imported above); it is
   used below for limb vectors such as [Z^n]. Being [Local], the notation
   is not exported to files that import this one. *)
Local Notation "A ^ n" := (tuple A n) : type_scope.

(* Canonicalize bignums by fully reducing them modulo the modulus [m]
   (the input bignum is called [p] below).
   This works on unsaturated digits, but uses saturated add/subtract
   loops.*)
Section Freeze.
    (* Limb weights, presumably as consumed by [B.Positional.eval].
       The hypotheses state that the first weight is 1, every weight is
       positive (hence nonzero), and each weight divides the next one
       with a positive quotient. *)
    Context (weight : nat->Z)
            {weight_0 : weight 0%nat = 1}
            {weight_nonzero : forall i, weight i <> 0}
            {weight_positive : forall i, weight i > 0}
            {weight_multiples : forall i, weight (S i) mod weight i = 0}
            {weight_divides : forall i : nat, weight (S i) / weight i > 0}
    .


  (*
    The input to [freeze] should be less than 2*m (this can probably
    be accomplished by a single carry_reduce step, for most moduli).

    [freeze] has the following steps:
    (1) subtract modulus in a carrying loop (in our framework, this
    consists of two steps; [Columns.unbalanced_sub_cps] combines the
    input p and the modulus m such that the ith limb in the output is
    the list [p[i];-m[i]]. We can then call [Columns.compact].)
    (2) look at the final carry, which should be either 0 or -1. If
    it is -1, then we add the modulus back in. Otherwise we add 0 for
    constant-timeness.
    (3) discard the carry after this last addition; it should be 1 if
    the carry examined in step (2) was -1, so they cancel out.

    In the code below, step (2) is the call to
    [Columns.conditional_add_cps], which receives the carry and limbs
    produced by step (1) together with [mask] and [m]; step (3) is
    taking [snd] of its result. [eval_freeze] requires [mask] to leave
    every limb of [m] unchanged under [Z.land].
   *)
  (* CPS form of [freeze]: the continuation [f] receives only the
     final limbs; the carry of the conditional addition is dropped. *)
  Definition freeze_cps {n} mask (m:Z^n) (p:Z^n) {T} (f : Z^n->T) :=
    Columns.unbalanced_sub_cps (n3:=n) weight p m
      (fun carry_p => Columns.conditional_add_cps (n3:=n) weight mask (fst carry_p) (snd carry_p) m
      (fun carry_r => f (snd carry_r)))
  .

  (* Direct-style [freeze]: [freeze_cps] run with the identity
     continuation. *)
  Definition freeze {n} mask m p :=
    @freeze_cps n mask m p _ id.
  (* Any continuation applied by [freeze_cps] equals that continuation
     applied to the direct-style result; used to eliminate CPS via the
     [uncps] rewrite database. *)
  Lemma freeze_id {n} mask m p T f:
    @freeze_cps n mask m p T f = f (freeze mask m p).
  Proof.
    cbv [freeze_cps freeze]; repeat progress autounfold;
      autorewrite with uncps push_id; reflexivity.
  Qed.
  (* Section-local hints; equivalent hints are declared again after
     [End Freeze].
     NOTE(review): here [freeze] is made opaque for [uncps], whereas the
     post-section declaration makes [freeze_cps] opaque instead;
     confirm which one is intended. *)
  Hint Opaque freeze : uncps.
  Hint Rewrite @freeze_id : uncps.

  (* Scalar model of [freeze]. With modulus [m = s - c]
     ([0 < c < s]) and input [0 <= y < 2*m]:
     - [y0 = y - m] is the result of subtracting the modulus,
     - [z = y0 mod s] and [c0 = y0 / s] split it into a value below [s]
       and a final carry,
     - [z0] adds [m] back exactly when that carry is nonzero,
     - [a = z0 mod s] drops the carry of that addition.
     The conclusion is that [a] agrees with [y0] modulo [m]. *)
  Lemma freezeZ m s c y y0 z z0 c0 a :
    m = s - c ->
    0 < c < s ->
    s <> 0 ->
    0 <= y < 2*m ->
    y0 = y - m ->
    z = y0 mod s ->
    c0 = y0 / s ->
    z0 = z + (if (dec (c0 = 0)) then 0 else m) ->
    a = z0 mod s ->
    a mod m = y0 mod m.
  Proof.
    (* [clear] drops the section context (weight hypotheses), which this
       purely arithmetic lemma does not need; [break_match] splits on
       the [dec (c0 = 0)] test. *)
    clear. intros. subst. break_match.
    (* Case c0 = 0: nothing is added back. Since y0 / s = 0, the
       [Z.div_small_iff] / [Z.mod_small_iff] pair gives y0 mod s = y0,
       so both sides coincide. *)
    { rewrite Z.add_0_r, Z.mod_mod by omega.
      assert (-(s-c) <= y - (s-c) < s-c) by omega.
      match goal with H : s <> 0 |- _ =>
                      rewrite (proj2 (Z.mod_small_iff _ s H))
                              by (apply Z.div_small_iff; assumption)
      end.
      reflexivity. }
    (* Case c0 <> 0: a borrow occurred and m is added back; show the
       result still agrees with y0 modulo m. *)
    { rewrite <-Z.add_mod_l, Z.sub_mod_full.
      rewrite Z.mod_same, Z.sub_0_r, Z.mod_mod by omega.
      rewrite Z.mod_small with (b := s)
      by (pose proof (Z.div_small (y - (s-c)) s); omega).
      f_equal. ring. }
  Qed.

  (* Correctness of [freeze] up to congruence: for an input evaluating
     into [0, 2*modulus), the output evaluates to a value congruent to
     the input modulo [modulus]. Note that only [mod_eq] is stated here,
     not that the output is below [modulus].
     Hypotheses: [n] is nonzero; [c] evaluates into (0, weight n) with
     [modulus = weight n - eval c]; [m] evaluates to [modulus];
     [mask] leaves every limb of [m] unchanged under [Z.land]. *)
  Lemma eval_freeze {n} c mask m p
        (n_nonzero:n<>0%nat)
        (Hc : 0 < B.Associational.eval c < weight n)
        (Hmask : Tuple.map (Z.land mask) m = m)
        modulus (Hm : B.Positional.eval weight m = Z.pos modulus)
        (Hp : 0 <= B.Positional.eval weight p < 2*(Z.pos modulus))
        (Hsc : Z.pos modulus = weight n - B.Associational.eval c)
    :
      mod_eq modulus
             (B.Positional.eval weight (@freeze n mask m p))
             (B.Positional.eval weight p).
  Proof.
    cbv [freeze_cps freeze].
    repeat progress autounfold.
    (* Put the carry/div/mod facts in context, presumably so that the
       side conditions of the rewrites below can be discharged by
       assumption. *)
    pose proof Z.add_get_carry_full_mod.
    pose proof Z.add_get_carry_full_div.
    pose proof div_correct. pose proof modulo_correct.
    pose proof @div_id. pose proof @modulo_id.
    pose proof @Z.add_get_carry_full_cps_correct.
    (* Remove CPS and push [B.Positional.eval] through the operations. *)
    autorewrite with uncps push_id push_basesystem_eval.

    pose proof (weight_nonzero n).

    remember (B.Positional.eval weight p) as y.
    remember (y + -B.Positional.eval weight m) as y0.
    rewrite Hm in *.

    (* Go through y0 = y - modulus: the output agrees with y0 (by the
       scalar model [freezeZ]), and y0 agrees with y modulo modulus. *)
    transitivity y0; cbv [mod_eq].
    { eapply (freezeZ (Z.pos modulus) (weight n) (B.Associational.eval c) y y0);
        try assumption; reflexivity. }
    { subst y0.
      assert (Z.pos modulus <> 0) by auto using Z.positive_is_nonzero, Zgt_pos_0.
      rewrite Z.add_mod by assumption.
      rewrite Z.mod_opp_l_z by auto using Z.mod_same.
      rewrite Z.add_0_r, Z.mod_mod by assumption.
      reflexivity. }
  Qed.
End Freeze.
(* Hints declared inside [Section Freeze] are local to the section, so
   the CPS-elimination hints are declared again here for clients of this
   file.
   NOTE(review): inside the section the constant made opaque for [uncps]
   is [freeze], while here it is [freeze_cps]; confirm which one is
   intended. *)
Hint Opaque freeze_cps : uncps.
Hint Rewrite @freeze_id : uncps.
(* Register [eval_freeze] in [push_basesystem_eval]; its hypotheses are
   discharged by the tactic given in the [using] clause. *)
Hint Rewrite @eval_freeze
     using (assumption || reflexivity || auto || eassumption || omega) : push_basesystem_eval.

(* Let the partial-evaluation unfolder database unfold both the direct
   and the CPS form of [freeze]. *)
Hint Unfold
     freeze freeze_cps
  : basesystem_partial_evaluation_unfolder.

(* Partial-evaluation unfolder for this file: delta-unfold [freeze] and
   [freeze_cps] in the term [t], then pass the result through the
   unfolders of Saturated.Wrappers, Saturated.Core and Arithmetic.Core,
   in that order, and return the final term. *)
Ltac basesystem_partial_evaluation_unfolder t :=
  let t := (eval cbv delta [freeze freeze_cps] in t) in
  let t := Saturated.Wrappers.basesystem_partial_evaluation_unfolder t in
  let t := Saturated.Core.basesystem_partial_evaluation_unfolder t in
  let t := Arithmetic.Core.basesystem_partial_evaluation_unfolder t in
  t.

(* Redefine (via Ltac [::=]) the default unfolder declared in
   Arithmetic.Core so that it calls this file's
   [basesystem_partial_evaluation_unfolder], which also unfolds
   [freeze]. *)
Ltac Arithmetic.Core.basesystem_partial_evaluation_default_unfolder t ::=
  basesystem_partial_evaluation_unfolder t.