aboutsummaryrefslogtreecommitdiff
path: root/src/Arithmetic
diff options
context:
space:
mode:
authorGravatar jadep <jade.philipoom@gmail.com>2017-06-29 21:32:04 -0400
committerGravatar jadep <jade.philipoom@gmail.com>2017-06-29 21:32:04 -0400
commit2876f7c688590a64189f47b439f7edf26c91c5de (patch)
treefffe55bc24e83105fca356a81a352e1fa4309999 /src/Arithmetic
parentb291707642db5986240b3e9eb9a80839d81ffe42 (diff)
Reorganization of saturated arithmetic
Diffstat (limited to 'src/Arithmetic')
-rw-r--r--src/Arithmetic/MontgomeryReduction/WordByWord/Definition.v6
-rw-r--r--src/Arithmetic/MontgomeryReduction/WordByWord/Proofs.v11
-rw-r--r--src/Arithmetic/Saturated/AddSub.v109
-rw-r--r--src/Arithmetic/Saturated/Core.v993
-rw-r--r--src/Arithmetic/Saturated/Freeze.v122
-rw-r--r--src/Arithmetic/Saturated/MontgomeryAPI.v599
-rw-r--r--src/Arithmetic/Saturated/MulSplit.v73
-rw-r--r--src/Arithmetic/Saturated/UniformWeight.v71
-rw-r--r--src/Arithmetic/Saturated/Wrappers.v53
9 files changed, 1036 insertions, 1001 deletions
diff --git a/src/Arithmetic/MontgomeryReduction/WordByWord/Definition.v b/src/Arithmetic/MontgomeryReduction/WordByWord/Definition.v
index f344cb7de..9affa82fa 100644
--- a/src/Arithmetic/MontgomeryReduction/WordByWord/Definition.v
+++ b/src/Arithmetic/MontgomeryReduction/WordByWord/Definition.v
@@ -5,7 +5,7 @@
of the algorithm; note that it may be that none of the algorithms
there exactly match what we're doing here. *)
Require Import Coq.ZArith.ZArith.
-Require Import Crypto.Arithmetic.Saturated.
+Require Import Crypto.Arithmetic.Saturated.MontgomeryAPI.
Require Import Crypto.Arithmetic.MontgomeryReduction.WordByWord.Abstract.Dependent.Definition.
Require Import Crypto.Util.Notations.
Require Import Crypto.Util.LetIn.
@@ -22,8 +22,8 @@ Section WordByWordMontgomery.
(N : T R_numlimbs).
Local Notation scmul := (@scmul (Z.pos r)).
- Local Notation addT' := (@Saturated.add_S1 (Z.pos r)).
- Local Notation addT := (@Saturated.add (Z.pos r)).
+ Local Notation addT' := (@MontgomeryAPI.add_S1 (Z.pos r)).
+ Local Notation addT := (@MontgomeryAPI.add (Z.pos r)).
Local Notation conditional_sub_cps := (fun V => @conditional_sub_cps (Z.pos r) _ V N _).
Local Notation conditional_sub := (fun V => @conditional_sub (Z.pos r) _ V N).
Local Notation sub_then_maybe_add_cps :=
diff --git a/src/Arithmetic/MontgomeryReduction/WordByWord/Proofs.v b/src/Arithmetic/MontgomeryReduction/WordByWord/Proofs.v
index 747280fe6..83791ec5f 100644
--- a/src/Arithmetic/MontgomeryReduction/WordByWord/Proofs.v
+++ b/src/Arithmetic/MontgomeryReduction/WordByWord/Proofs.v
@@ -1,7 +1,8 @@
(*** Word-By-Word Montgomery Multiplication Proofs *)
Require Import Coq.ZArith.BinInt.
Require Import Coq.micromega.Lia.
-Require Import Crypto.Arithmetic.Saturated.
+Require Import Crypto.Arithmetic.Saturated.UniformWeight.
+Require Import Crypto.Arithmetic.Saturated.MontgomeryAPI.
Require Import Crypto.Arithmetic.MontgomeryReduction.WordByWord.Abstract.Dependent.Definition.
Require Import Crypto.Arithmetic.MontgomeryReduction.WordByWord.Abstract.Dependent.Proofs.
Require Import Crypto.Arithmetic.MontgomeryReduction.WordByWord.Definition.
@@ -16,8 +17,8 @@ Section WordByWordMontgomery.
(R_numlimbs : nat).
Local Notation small := (@small (Z.pos r)).
Local Notation eval := (@eval (Z.pos r)).
- Local Notation addT' := (@Saturated.add_S1 (Z.pos r)).
- Local Notation addT := (@Saturated.add (Z.pos r)).
+ Local Notation addT' := (@MontgomeryAPI.add_S1 (Z.pos r)).
+ Local Notation addT := (@MontgomeryAPI.add (Z.pos r)).
Local Notation scmul := (@scmul (Z.pos r)).
Local Notation eval_zero := (@eval_zero (Z.pos r)).
Local Notation small_zero := (@small_zero r (Zorder.Zgt_pos_0 _)).
@@ -61,11 +62,11 @@ Section WordByWordMontgomery.
Qed.
Local Lemma small_addT : forall n a b, small a -> small b -> small (@addT n a b).
Proof.
- intros; apply Saturated.small_add; auto; lia.
+ intros; apply MontgomeryAPI.small_add; auto; lia.
Qed.
Local Lemma small_addT' : forall n a b, small a -> small b -> small (@addT' n a b).
Proof.
- intros; apply Saturated.small_add_S1; auto; lia.
+ intros; apply MontgomeryAPI.small_add_S1; auto; lia.
Qed.
Local Notation conditional_sub_cps := (fun V : T (S R_numlimbs) => @conditional_sub_cps (Z.pos r) _ V N _).
diff --git a/src/Arithmetic/Saturated/AddSub.v b/src/Arithmetic/Saturated/AddSub.v
new file mode 100644
index 000000000..c6758b865
--- /dev/null
+++ b/src/Arithmetic/Saturated/AddSub.v
@@ -0,0 +1,109 @@
+Require Import Coq.ZArith.ZArith.
+Require Import Coq.Lists.List.
+Local Open Scope Z_scope.
+
+Require Import Crypto.Arithmetic.Core.
+Require Import Crypto.Arithmetic.Saturated.Core.
+Require Import Crypto.Arithmetic.Saturated.UniformWeight.
+Require Import Crypto.Util.ZUtil.Definitions.
+Require Import Crypto.Util.Tuple Crypto.Util.LetIn.
+Local Notation "A ^ n" := (tuple A n) : type_scope.
+
+Module B.
+ Module Positional.
+ Section Positional.
+ Context {s:Z}. (* s is bitwidth *)
+ Let small {n} := @small s n.
+ Section GenericOp.
+ Context {op : Z -> Z -> Z}
+ {op_get_carry : Z -> Z -> Z * Z} (* no carry in, carry out *)
+ {op_with_carry : Z -> Z -> Z -> Z * Z}. (* carry in, carry out *)
+
+ Fixpoint chain_op'_cps {n}:
+ option Z->Z^n->Z^n->forall T, (Z*Z^n->T)->T :=
+ match n with
+ | O => fun c p _ _ f =>
+ let carry := match c with | None => 0 | Some x => x end in
+ f (carry,p)
+ | S n' =>
+ fun c p q _ f =>
+ (* for the first call, use op_get_carry, then op_with_carry *)
+ let op' := match c with
+ | None => op_get_carry
+ | Some x => op_with_carry x end in
+ dlet carry_result := op' (hd p) (hd q) in
+ chain_op'_cps (Some (snd carry_result)) (tl p) (tl q) _
+ (fun carry_pq =>
+ f (fst carry_pq,
+ append (fst carry_result) (snd carry_pq)))
+ end.
+ Definition chain_op' {n} c p q := @chain_op'_cps n c p q _ id.
+ Definition chain_op_cps {n} p q {T} f := @chain_op'_cps n None p q T f.
+ Definition chain_op {n} p q : Z * Z^n := chain_op_cps p q id.
+
+ Lemma chain_op'_id {n} : forall c p q T f,
+ @chain_op'_cps n c p q T f = f (chain_op' c p q).
+ Proof.
+ cbv [chain_op']; induction n; intros; destruct c;
+ simpl chain_op'_cps; cbv [Let_In]; try reflexivity.
+ { etransitivity; rewrite IHn; reflexivity. }
+ { etransitivity; rewrite IHn; reflexivity. }
+ Qed.
+
+ Lemma chain_op_id {n} p q T f :
+ @chain_op_cps n p q T f = f (chain_op p q).
+ Proof. apply chain_op'_id. Qed.
+ End GenericOp.
+
+ Section AddSub.
+ Let eval {n} := B.Positional.eval (n:=n) (uweight s).
+
+ Definition sat_add_cps {n} p q T (f:Z*Z^n->T) :=
+ chain_op_cps (op_get_carry := Z.add_get_carry_full s)
+ (op_with_carry := Z.add_with_get_carry_full s)
+ p q f.
+ Definition sat_add {n} p q := @sat_add_cps n p q _ id.
+
+ Lemma sat_add_id n p q T f :
+ @sat_add_cps n p q T f = f (sat_add p q).
+ Proof. cbv [sat_add sat_add_cps]. rewrite !chain_op_id. reflexivity. Qed.
+
+ Lemma sat_add_mod n p q :
+ eval (snd (@sat_add n p q)) = (eval p + eval q) mod (uweight s n).
+ Admitted.
+
+ Lemma sat_add_div n p q :
+ fst (@sat_add n p q) = (eval p + eval q) / (uweight s n).
+ Admitted.
+
+ Lemma small_sat_add n p q : small (snd (@sat_add n p q)).
+ Admitted.
+
+ Definition sat_sub_cps {n} p q T (f:Z*Z^n->T) :=
+ chain_op_cps (op_get_carry := Z.sub_get_borrow_full s)
+ (op_with_carry := Z.sub_with_get_borrow_full s)
+ p q f.
+ Definition sat_sub {n} p q := @sat_sub_cps n p q _ id.
+
+ Lemma sat_sub_id n p q T f :
+ @sat_sub_cps n p q T f = f (sat_sub p q).
+ Proof. cbv [sat_sub sat_sub_cps]. rewrite !chain_op_id. reflexivity. Qed.
+
+ Lemma sat_sub_mod n p q :
+ eval (snd (@sat_sub n p q)) = (eval p - eval q) mod (uweight s n).
+ Admitted.
+
+ Lemma sat_sub_div n p q :
+ fst (@sat_sub n p q) = - ((eval p - eval q) / uweight s n).
+ Admitted.
+
+ Lemma small_sat_sub n p q : small (snd (@sat_sub n p q)).
+ Admitted.
+
+ End AddSub.
+ End Positional.
+ End Positional.
+End B.
+Hint Opaque B.Positional.sat_sub B.Positional.sat_add B.Positional.chain_op B.Positional.chain_op' : uncps.
+Hint Rewrite @B.Positional.sat_sub_id @B.Positional.sat_add_id @B.Positional.chain_op_id @B.Positional.chain_op' : uncps.
+Hint Rewrite @B.Positional.sat_sub_mod @B.Positional.sat_sub_div @B.Positional.sat_add_mod @B.Positional.sat_add_div using (omega || assumption) : push_basesystem_eval. \ No newline at end of file
diff --git a/src/Arithmetic/Saturated/Core.v b/src/Arithmetic/Saturated/Core.v
index 0c059b93d..27171c741 100644
--- a/src/Arithmetic/Saturated/Core.v
+++ b/src/Arithmetic/Saturated/Core.v
@@ -11,10 +11,6 @@ Require Import Crypto.Util.Tuple Crypto.Util.ListUtil.
Require Import Crypto.Util.Tactics.BreakMatch.
Require Import Crypto.Util.Decidable Crypto.Util.ZUtil.
Require Import Crypto.Util.NatUtil.
-Require Import Crypto.Util.ZUtil.Definitions.
-Require Import Crypto.Util.ZUtil.AddGetCarry.
-Require Import Crypto.Util.ZUtil.Zselect.
-Require Import Crypto.Util.ZUtil.MulSplit.
Require Import Crypto.Util.Tactics.SpecializeBy.
Local Notation "A ^ n" := (tuple A n) : type_scope.
@@ -107,71 +103,6 @@ check confirms our result.
***)
-Module Associational.
- Section Associational.
- Context {mul_split : Z -> Z -> Z -> Z * Z} (* first argument is where to split output; [mul_split s x y] gives ((x * y) mod s, (x * y) / s) *)
- {mul_split_mod : forall s x y,
- fst (mul_split s x y) = (x * y) mod s}
- {mul_split_div : forall s x y,
- snd (mul_split s x y) = (x * y) / s}
- .
-
- Definition multerm_cps s (t t' : B.limb) {T} (f:list B.limb ->T) :=
- dlet xy := mul_split s (snd t) (snd t') in
- f ((fst t * fst t', fst xy) :: (fst t * fst t' * s, snd xy) :: nil).
-
- Definition multerm s t t' := multerm_cps s t t' id.
- Lemma multerm_id s t t' T f :
- @multerm_cps s t t' T f = f (multerm s t t').
- Proof. reflexivity. Qed.
- Hint Opaque multerm : uncps.
- Hint Rewrite multerm_id : uncps.
-
- Definition mul_cps s (p q : list B.limb) {T} (f : list B.limb -> T) :=
- flat_map_cps (fun t => @flat_map_cps _ _ (multerm_cps s t) q) p f.
-
- Definition mul s p q := mul_cps s p q id.
- Lemma mul_id s p q T f : @mul_cps s p q T f = f (mul s p q).
- Proof. cbv [mul mul_cps]. autorewrite with uncps. reflexivity. Qed.
- Hint Opaque mul : uncps.
- Hint Rewrite mul_id : uncps.
-
- Lemma eval_map_multerm s a q (s_nonzero:s<>0):
- B.Associational.eval (flat_map (multerm s a) q) = fst a * snd a * B.Associational.eval q.
- Proof.
- cbv [multerm multerm_cps Let_In]; induction q;
- repeat match goal with
- | _ => progress (autorewrite with uncps push_id cancel_pair push_basesystem_eval in * )
- | _ => progress simpl flat_map
- | _ => progress rewrite ?IHq, ?mul_split_mod, ?mul_split_div
- | _ => rewrite Z.mod_eq by assumption
- | _ => ring_simplify; omega
- end.
- Qed.
- Hint Rewrite eval_map_multerm using (omega || assumption)
- : push_basesystem_eval.
-
- Lemma eval_mul s p q (s_nonzero:s<>0):
- B.Associational.eval (mul s p q) = B.Associational.eval p * B.Associational.eval q.
- Proof.
- cbv [mul mul_cps]; induction p; [reflexivity|].
- repeat match goal with
- | _ => progress (autorewrite with uncps push_id push_basesystem_eval in * )
- | _ => progress simpl flat_map
- | _ => rewrite IHp
- | _ => progress change (fun x => multerm_cps s a x id) with (multerm s a)
- | _ => ring_simplify; omega
- end.
- Qed.
- Hint Rewrite eval_mul : push_basesystem_eval.
-
- End Associational.
-End Associational.
-Hint Opaque Associational.mul Associational.multerm : uncps.
-Hint Rewrite @Associational.mul_id @Associational.multerm_id : uncps.
-Hint Rewrite @Associational.eval_mul @Associational.eval_map_multerm using (omega || assumption) : push_basesystem_eval.
-
-
Module Columns.
Section Columns.
Context (weight : nat->Z)
@@ -480,56 +411,7 @@ Module Columns.
rewrite eval_cons_to_nth by omega. nsatz.
Qed.
End Columns.
- Hint Rewrite
- @Columns.compact_id
- @Columns.from_associational_id
- : uncps.
- Hint Rewrite
- @Columns.compact_mod
- @Columns.compact_div
- @Columns.eval_from_associational
- using (assumption || omega): push_basesystem_eval.
-
- Section Wrappers.
- Context (weight : nat->Z).
-
- Definition add_cps {n1 n2 n3} (p : Z^n1) (q : Z^n2)
- {T} (f : (Z*Z^n3)->T) :=
- B.Positional.to_associational_cps weight p
- (fun P => B.Positional.to_associational_cps weight q
- (fun Q => from_associational_cps weight n3 (P++Q)
- (fun R => compact_cps (div:=div) (modulo:=modulo) (add_get_carry:=Z.add_get_carry_full) weight R f))).
-
- Definition unbalanced_sub_cps {n1 n2 n3} (p : Z^n1) (q:Z^n2)
- {T} (f : (Z*Z^n3)->T) :=
- B.Positional.to_associational_cps weight p
- (fun P => B.Positional.negate_snd_cps weight q
- (fun nq => B.Positional.to_associational_cps weight nq
- (fun Q => from_associational_cps weight n3 (P++Q)
- (fun R => compact_cps (div:=div) (modulo:=modulo) (add_get_carry:=Z.add_get_carry_full) weight R f)))).
-
- Definition mul_cps {n1 n2 n3} s (p : Z^n1) (q : Z^n2)
- {T} (f : (Z*Z^n3)->T) :=
- B.Positional.to_associational_cps weight p
- (fun P => B.Positional.to_associational_cps weight q
- (fun Q => Associational.mul_cps (mul_split := Z.mul_split) s P Q
- (fun PQ => from_associational_cps weight n3 PQ
- (fun R => compact_cps (div:=div) (modulo:=modulo) (add_get_carry:=Z.add_get_carry_full) weight R f)))).
-
- Definition conditional_add_cps {n1 n2 n3} mask cond (p:Z^n1) (q:Z^n2)
- {T} (f:_->T) :=
- B.Positional.select_cps mask cond q
- (fun qq => add_cps (n3:=n3) p qq f).
-
- End Wrappers.
- Hint Unfold add_cps unbalanced_sub_cps mul_cps conditional_add_cps.
-
End Columns.
-Hint Unfold
- Columns.conditional_add_cps
- Columns.add_cps
- Columns.unbalanced_sub_cps
- Columns.mul_cps.
Hint Rewrite
@Columns.compact_digit_id
@Columns.compact_step_id
@@ -544,878 +426,3 @@ Hint Rewrite
@Columns.eval_from_associational
@Columns.eval_nils
using (assumption || omega): push_basesystem_eval.
-
-Section Freeze.
- Context (weight : nat->Z)
- {weight_0 : weight 0%nat = 1}
- {weight_nonzero : forall i, weight i <> 0}
- {weight_positive : forall i, weight i > 0}
- {weight_multiples : forall i, weight (S i) mod weight i = 0}
- {weight_divides : forall i : nat, weight (S i) / weight i > 0}
- .
-
-
- (*
- The input to [freeze] should be less than 2*m (this can probably
- be accomplished by a single carry_reduce step, for most moduli).
-
- [freeze] has the following steps:
- (1) subtract modulus in a carrying loop (in our framework, this
- consists of two steps; [Columns.unbalanced_sub_cps] combines the
- input p and the modulus m such that the ith limb in the output is
- the list [p[i];-m[i]]. We can then call [Columns.compact].)
- (2) look at the final carry, which should be either 0 or -1. If
- it's -1, then we add the modulus back in. Otherwise we add 0 for
- constant-timeness.
- (3) discard the carry after this last addition; it should be 1 if
- the carry in step 3 was -1, so they cancel out.
- *)
- Definition freeze_cps {n} mask (m:Z^n) (p:Z^n) {T} (f : Z^n->T) :=
- Columns.unbalanced_sub_cps (n3:=n) weight p m
- (fun carry_p => Columns.conditional_add_cps (n3:=n) weight mask (fst carry_p) (snd carry_p) m
- (fun carry_r => f (snd carry_r)))
- .
-
- Definition freeze {n} mask m p :=
- @freeze_cps n mask m p _ id.
- Lemma freeze_id {n} mask m p T f:
- @freeze_cps n mask m p T f = f (freeze mask m p).
- Proof.
- cbv [freeze_cps freeze]; repeat progress autounfold;
- autorewrite with uncps push_id; reflexivity.
- Qed.
- Hint Opaque freeze : uncps.
- Hint Rewrite @freeze_id : uncps.
-
- Lemma freezeZ m s c y y0 z z0 c0 a :
- m = s - c ->
- 0 < c < s ->
- s <> 0 ->
- 0 <= y < 2*m ->
- y0 = y - m ->
- z = y0 mod s ->
- c0 = y0 / s ->
- z0 = z + (if (dec (c0 = 0)) then 0 else m) ->
- a = z0 mod s ->
- a mod m = y0 mod m.
- Proof.
- clear. intros. subst. break_match.
- { rewrite Z.add_0_r, Z.mod_mod by omega.
- assert (-(s-c) <= y - (s-c) < s-c) by omega.
- match goal with H : s <> 0 |- _ =>
- rewrite (proj2 (Z.mod_small_iff _ s H))
- by (apply Z.div_small_iff; assumption)
- end.
- reflexivity. }
- { rewrite <-Z.add_mod_l, Z.sub_mod_full.
- rewrite Z.mod_same, Z.sub_0_r, Z.mod_mod by omega.
- rewrite Z.mod_small with (b := s)
- by (pose proof (Z.div_small (y - (s-c)) s); omega).
- f_equal. ring. }
- Qed.
-
- Lemma eval_freeze {n} c mask m p
- (n_nonzero:n<>0%nat)
- (Hc : 0 < B.Associational.eval c < weight n)
- (Hmask : Tuple.map (Z.land mask) m = m)
- modulus (Hm : B.Positional.eval weight m = Z.pos modulus)
- (Hp : 0 <= B.Positional.eval weight p < 2*(Z.pos modulus))
- (Hsc : Z.pos modulus = weight n - B.Associational.eval c)
- :
- mod_eq modulus
- (B.Positional.eval weight (@freeze n mask m p))
- (B.Positional.eval weight p).
- Proof.
- cbv [freeze_cps freeze].
- repeat progress autounfold.
- pose proof Z.add_get_carry_full_mod.
- pose proof Z.add_get_carry_full_div.
- pose proof div_correct. pose proof modulo_correct.
- autorewrite with uncps push_id push_basesystem_eval.
-
- pose proof (weight_nonzero n).
-
- remember (B.Positional.eval weight p) as y.
- remember (y + -B.Positional.eval weight m) as y0.
- rewrite Hm in *.
-
- transitivity y0; cbv [mod_eq].
- { eapply (freezeZ (Z.pos modulus) (weight n) (B.Associational.eval c) y y0);
- try assumption; reflexivity. }
- { subst y0.
- assert (Z.pos modulus <> 0) by auto using Z.positive_is_nonzero, Zgt_pos_0.
- rewrite Z.add_mod by assumption.
- rewrite Z.mod_opp_l_z by auto using Z.mod_same.
- rewrite Z.add_0_r, Z.mod_mod by assumption.
- reflexivity. }
- Qed.
-End Freeze.
-
-Section UniformWeight.
- Context (bound : Z) {bound_pos : bound > 0}.
-
- Definition uweight : nat -> Z := fun i => bound ^ Z.of_nat i.
- Lemma uweight_0 : uweight 0%nat = 1. Proof. reflexivity. Qed.
- Lemma uweight_positive i : uweight i > 0.
- Proof. apply Z.lt_gt, Z.pow_pos_nonneg; omega. Qed.
- Lemma uweight_nonzero i : uweight i <> 0.
- Proof. auto using Z.positive_is_nonzero, uweight_positive. Qed.
- Lemma uweight_multiples i : uweight (S i) mod uweight i = 0.
- Proof. apply Z.mod_same_pow; rewrite Nat2Z.inj_succ; omega. Qed.
- Lemma uweight_divides i : uweight (S i) / uweight i > 0.
- Proof.
- cbv [uweight]. rewrite <-Z.pow_sub_r by (rewrite ?Nat2Z.inj_succ; omega).
- apply Z.lt_gt, Z.pow_pos_nonneg; rewrite ?Nat2Z.inj_succ; omega.
- Qed.
-
- (* TODO : move to Positional *)
- Lemma eval_from_eq {n} (p:Z^n) wt offset :
- (forall i, wt i = uweight (i + offset)) ->
- B.Positional.eval wt p = B.Positional.eval_from uweight offset p.
- Proof. cbv [B.Positional.eval_from]. auto using B.Positional.eval_wt_equiv. Qed.
-
- Lemma uweight_eval_from {n} (p:Z^n): forall offset,
- B.Positional.eval_from uweight offset p = uweight offset * B.Positional.eval uweight p.
- Proof.
- induction n; intros; cbv [B.Positional.eval_from];
- [|rewrite (subst_append p)];
- repeat match goal with
- | _ => destruct p
- | _ => rewrite B.Positional.eval_unit; [ ]
- | _ => rewrite B.Positional.eval_step; [ ]
- | _ => rewrite IHn; [ ]
- | _ => rewrite eval_from_eq with (offset0:=S offset)
- by (intros; f_equal; omega)
- | _ => rewrite eval_from_eq with
- (wt:=fun i => uweight (S i)) (offset0:=1%nat)
- by (intros; f_equal; omega)
- | _ => ring
- end.
- repeat match goal with
- | _ => cbv [uweight]; progress autorewrite with natsimplify
- | _ => progress (rewrite ?Nat2Z.inj_succ, ?Nat2Z.inj_0, ?Z.pow_0_r)
- | _ => rewrite !Z.pow_succ_r by (try apply Nat2Z.is_nonneg; omega)
- | _ => ring
- end.
- Qed.
-
- Lemma uweight_eval_step {n} (p:Z^S n):
- B.Positional.eval uweight p = hd p + bound * B.Positional.eval uweight (tl p).
- Proof.
- rewrite (subst_append p) at 1; rewrite B.Positional.eval_step.
- rewrite eval_from_eq with (offset := 1%nat) by (intros; f_equal; omega).
- rewrite uweight_eval_from. cbv [uweight]; rewrite Z.pow_0_r, Z.pow_1_r.
- ring.
- Qed.
-
- Definition small {n} (p : Z^n) : Prop :=
- forall x, In x (to_list _ p) -> 0 <= x < bound.
-
-End UniformWeight.
-
-Module Positional.
- Section Positional.
- Context {s:Z}. (* s is bitwidth *)
- Let small {n} := @small s n.
- Section GenericOp.
- Context {op : Z -> Z -> Z}
- {op_get_carry : Z -> Z -> Z * Z} (* no carry in, carry out *)
- {op_with_carry : Z -> Z -> Z -> Z * Z}. (* carry in, carry out *)
-
- Fixpoint chain_op'_cps {n}:
- option Z->Z^n->Z^n->forall T, (Z*Z^n->T)->T :=
- match n with
- | O => fun c p _ _ f =>
- let carry := match c with | None => 0 | Some x => x end in
- f (carry,p)
- | S n' =>
- fun c p q _ f =>
- (* for the first call, use op_get_carry, then op_with_carry *)
- let op' := match c with
- | None => op_get_carry
- | Some x => op_with_carry x end in
- dlet carry_result := op' (hd p) (hd q) in
- chain_op'_cps (Some (snd carry_result)) (tl p) (tl q) _
- (fun carry_pq =>
- f (fst carry_pq,
- append (fst carry_result) (snd carry_pq)))
- end.
- Definition chain_op' {n} c p q := @chain_op'_cps n c p q _ id.
- Definition chain_op_cps {n} p q {T} f := @chain_op'_cps n None p q T f.
- Definition chain_op {n} p q : Z * Z^n := chain_op_cps p q id.
-
- Lemma chain_op'_id {n} : forall c p q T f,
- @chain_op'_cps n c p q T f = f (chain_op' c p q).
- Proof.
- cbv [chain_op']; induction n; intros; destruct c;
- simpl chain_op'_cps; cbv [Let_In]; try reflexivity.
- { etransitivity; rewrite IHn; reflexivity. }
- { etransitivity; rewrite IHn; reflexivity. }
- Qed.
-
- Lemma chain_op_id {n} p q T f :
- @chain_op_cps n p q T f = f (chain_op p q).
- Proof. apply chain_op'_id. Qed.
- End GenericOp.
-
- Section AddSub.
- Let eval {n} := B.Positional.eval (n:=n) (uweight s).
-
- Definition sat_add_cps {n} p q T (f:Z*Z^n->T) :=
- chain_op_cps (op_get_carry := Z.add_get_carry_full s)
- (op_with_carry := Z.add_with_get_carry_full s)
- p q f.
- Definition sat_add {n} p q := @sat_add_cps n p q _ id.
-
- Lemma sat_add_id n p q T f :
- @sat_add_cps n p q T f = f (sat_add p q).
- Proof. cbv [sat_add sat_add_cps]. rewrite !chain_op_id. reflexivity. Qed.
-
- Lemma sat_add_mod n p q :
- eval (snd (@sat_add n p q)) = (eval p + eval q) mod (uweight s n).
- Admitted.
-
- Lemma sat_add_div n p q :
- fst (@sat_add n p q) = (eval p + eval q) / (uweight s n).
- Admitted.
-
- Lemma small_sat_add n p q : small (snd (@sat_add n p q)).
- Admitted.
-
- Definition sat_sub_cps {n} p q T (f:Z*Z^n->T) :=
- chain_op_cps (op_get_carry := Z.sub_get_borrow_full s)
- (op_with_carry := Z.sub_with_get_borrow_full s)
- p q f.
- Definition sat_sub {n} p q := @sat_sub_cps n p q _ id.
-
- Lemma sat_sub_id n p q T f :
- @sat_sub_cps n p q T f = f (sat_sub p q).
- Proof. cbv [sat_sub sat_sub_cps]. rewrite !chain_op_id. reflexivity. Qed.
-
- Lemma sat_sub_mod n p q :
- eval (snd (@sat_sub n p q)) = (eval p - eval q) mod (uweight s n).
- Admitted.
-
- Lemma sat_sub_div n p q :
- fst (@sat_sub n p q) = - ((eval p - eval q) / uweight s n).
- Admitted.
-
- Lemma small_sat_sub n p q : small (snd (@sat_sub n p q)).
- Admitted.
-
- End AddSub.
- End Positional.
-End Positional.
-Hint Opaque Positional.sat_sub Positional.sat_add Positional.chain_op Positional.chain_op' : uncps.
-Hint Rewrite @Positional.sat_sub_id @Positional.sat_add_id @Positional.chain_op_id @Positional.chain_op' : uncps.
-Hint Rewrite @Positional.sat_sub_mod @Positional.sat_sub_div @Positional.sat_add_mod @Positional.sat_add_div using (omega || assumption) : push_basesystem_eval.
-
-Section API.
- Context (bound : Z) {bound_pos : bound > 0}.
- Definition T : nat -> Type := tuple Z.
-
- (* lowest limb is less than its bound; this is required for [divmod]
- to simply separate the lowest limb from the rest and be equivalent
- to normal div/mod with [bound]. *)
- Local Notation small := (@small bound).
-
- Definition zero {n:nat} : T n := B.Positional.zeros n.
-
- (** Returns 0 iff all limbs are 0 *)
- Definition nonzero_cps {n} (p : T n) {cpsT} (f : Z -> cpsT) : cpsT
- := CPSUtil.to_list_cps _ p (fun p => CPSUtil.fold_right_cps runtime_lor 0%Z p f).
- Definition nonzero {n} (p : T n) : Z
- := nonzero_cps p id.
-
- Definition join0_cps {n:nat} (p : T n) {R} (f:T (S n) -> R)
- := Tuple.left_append_cps 0 p f.
- Definition join0 {n} p : T (S n) := @join0_cps n p _ id.
-
- Definition divmod_cps {n} (p : T (S n)) {R} (f:T n * Z->R) : R
- := Tuple.tl_cps p (fun d => Tuple.hd_cps p (fun m => f (d, m))).
- Definition divmod {n} p : T n * Z := @divmod_cps n p _ id.
-
- Definition drop_high_cps {n : nat} (p : T (S n)) {R} (f:T n->R)
- := Tuple.left_tl_cps p f.
- Definition drop_high {n} p : T n := @drop_high_cps n p _ id.
-
- Definition scmul_cps {n} (c : Z) (p : T n) {R} (f:T (S n)->R) :=
- Columns.mul_cps (n1:=1) (n3:=S n) (uweight bound) bound c p
- (* The carry that comes out of Columns.mul_cps will be 0, since
- (S n) limbs is enough to hold the result of the
- multiplication, so we can safely discard it. *)
- (fun carry_result =>f (snd carry_result)).
- Definition scmul {n} c p : T (S n) := @scmul_cps n c p _ id.
-
- Definition add_cps {n} (p q: T n) {R} (f:T (S n)->R) :=
- Positional.sat_add_cps (s:=bound) p q _
- (* join the last carry *)
- (fun carry_result => Tuple.left_append_cps (fst carry_result) (snd carry_result) f).
- Definition add {n} p q : T (S n) := @add_cps n p q _ id.
-
- (* Wrappers for additions with slightly uneven limb counts *)
- Definition add_S1_cps {n} (p: T (S n)) (q: T n) {R} (f:T (S (S n))->R) :=
- join0_cps q (fun Q => add_cps p Q f).
- Definition add_S1 {n} p q := @add_S1_cps n p q _ id.
- Definition add_S2_cps {n} (p: T n) (q: T (S n)) {R} (f:T (S (S n))->R) :=
- join0_cps p (fun P => add_cps P q f).
- Definition add_S2 {n} p q := @add_S2_cps n p q _ id.
->>>>>>> addsubchains
-
- Definition sub_then_maybe_add_cps {n} mask (p q r : T n)
- {R} (f:T n -> R) :=
- Positional.sat_sub_cps (s:=bound) p q _
- (* the carry will be 0 unless we underflow--we do the addition only
- in the underflow case *)
- (fun carry_result =>
- B.Positional.select_cps mask (fst carry_result) r
- (fun selected => join0_cps selected
- (fun selected' =>
- Positional.sat_sub_cps (s:=bound) (left_append (fst carry_result) (snd carry_result)) selected' _
- (* We can now safely discard the carry and the highest digit.
- This relies on the precondition that p - q + r < bound^n. *)
- (fun carry_result' => drop_high_cps (snd carry_result') f)))).
- Definition sub_then_maybe_add {n} mask (p q r : T n) :=
- sub_then_maybe_add_cps mask p q r id.
-
- (* Subtract q if and only if p >= q. We rely on the preconditions
- that 0 <= p < 2*q and q < bound^n (this ensures the output is less
- than bound^n). *)
- Definition conditional_sub_cps {n} (p:Z^S n) (q:Z^n) R (f:Z^n->R) :=
- join0_cps q
- (fun qq => Positional.sat_sub_cps (s:=bound) p qq _
- (* if carry is zero, we select the result of the subtraction,
- otherwise the first input *)
- (fun carry_result =>
- Tuple.map2_cps (Z.zselect (fst carry_result)) (snd carry_result) p
- (* in either case, since our result must be < q and therefore <
- bound^n, we can drop the high digit *)
- (fun r => drop_high_cps r f))).
- Definition conditional_sub {n} p q := @conditional_sub_cps n p q _ id.
-
- Hint Opaque join0 divmod drop_high scmul add sub_then_maybe_add conditional_sub : uncps.
-
- Section CPSProofs.
-
- Local Ltac prove_id :=
- repeat autounfold; autorewrite with uncps; reflexivity.
-
- Lemma nonzero_id n p {cpsT} f : @nonzero_cps n p cpsT f = f (@nonzero n p).
- Proof. cbv [nonzero nonzero_cps]. prove_id. Qed.
-
- Lemma join0_id n p R f :
- @join0_cps n p R f = f (join0 p).
- Proof. cbv [join0_cps join0]. prove_id. Qed.
-
- Lemma divmod_id n p R f :
- @divmod_cps n p R f = f (divmod p).
- Proof. cbv [divmod_cps divmod]; prove_id. Qed.
-
- Lemma drop_high_id n p R f :
- @drop_high_cps n p R f = f (drop_high p).
- Proof. cbv [drop_high_cps drop_high]; prove_id. Qed.
- Hint Rewrite drop_high_id : uncps.
-
- Lemma scmul_id n c p R f :
- @scmul_cps n c p R f = f (scmul c p).
- Proof. cbv [scmul_cps scmul]. prove_id. Qed.
-
- Lemma add_id n p q R f :
- @add_cps n p q R f = f (add p q).
- Proof. cbv [add_cps add Let_In]. prove_id. Qed.
- Hint Rewrite add_id : uncps.
-
- Lemma add_S1_id n p q R f :
- @add_S1_cps n p q R f = f (add_S1 p q).
- Proof. cbv [add_S1_cps add_S1 join0_cps]. prove_id. Qed.
-
- Lemma add_S2_id n p q R f :
- @add_S2_cps n p q R f = f (add_S2 p q).
- Proof. cbv [add_S2_cps add_S2 join0_cps]. prove_id. Qed.
-
- Lemma sub_then_maybe_add_id n mask p q r R f :
- @sub_then_maybe_add_cps n mask p q r R f = f (sub_then_maybe_add mask p q r).
- Proof. cbv [sub_then_maybe_add_cps sub_then_maybe_add join0_cps Let_In]. prove_id. Qed.
-
- Lemma conditional_sub_id n p q R f :
- @conditional_sub_cps n p q R f = f (conditional_sub p q).
- Proof. cbv [conditional_sub_cps conditional_sub join0_cps Let_In]. prove_id. Qed.
-
- End CPSProofs.
- Hint Rewrite nonzero_id join0_id divmod_id drop_high_id scmul_id add_id sub_then_maybe_add_id conditional_sub_id : uncps.
-
- Section Proofs.
-
- Definition eval {n} (p : T n) : Z :=
- B.Positional.eval (uweight bound) p.
-
- Lemma eval_small n (p : T n) (Hsmall : small p) :
- 0 <= eval p < uweight bound n.
- Proof.
- cbv [small eval] in *; intros.
- induction n; cbv [T uweight] in *; [destruct p|rewrite (subst_left_append p)];
- repeat match goal with
- | _ => progress autorewrite with push_basesystem_eval
- | _ => rewrite Z.pow_0_r
- | _ => specialize (IHn (left_tl p))
- | _ =>
- let H := fresh "H" in
- match type of IHn with
- ?P -> _ => assert P as H by auto using Tuple.In_to_list_left_tl;
- specialize (IHn H)
- end
- | |- context [?b ^ Z.of_nat (S ?n)] =>
- replace (b ^ Z.of_nat (S n)) with (b ^ Z.of_nat n * b) by
- (rewrite Nat2Z.inj_succ, <-Z.add_1_r, Z.pow_add_r,
- Z.pow_1_r by (omega || auto using Nat2Z.is_nonneg);
- reflexivity)
- | _ => omega
- end.
-
- specialize (Hsmall _ (Tuple.In_left_hd _ p)).
- split; [Z.zero_bounds; omega |].
- apply Z.lt_le_trans with (m:=bound^Z.of_nat n * (left_hd p+1)).
- { rewrite Z.mul_add_distr_l.
- apply Z.add_le_lt_mono; omega. }
- { apply Z.mul_le_mono_nonneg; omega. }
- Qed.
-
- Lemma eval_zero n : eval (@zero n) = 0.
- Proof.
- cbv [eval zero].
- autorewrite with push_basesystem_eval.
- reflexivity.
- Qed.
-
- Lemma small_zero n : small (@zero n).
- Proof.
- cbv [zero small B.Positional.zeros]. destruct n; [simpl;tauto|].
- rewrite to_list_repeat.
- intros x H; apply repeat_spec in H; subst x; omega.
- Qed.
-
- Lemma eval_pair n (p : T (S (S n))) : small p -> (snd p = 0 /\ eval (n:=S n) (fst p) = 0) <-> eval p = 0.
- Admitted.
-
- Lemma eval_nonzero n p : small p -> @nonzero n p = 0 <-> eval p = 0.
- Proof.
- destruct n as [|n].
- { compute; split; trivial. }
- induction n as [|n IHn].
- { simpl; rewrite Z.lor_0_r; unfold eval, id.
- cbv -[Z.add iff].
- rewrite Z.add_0_r.
- destruct p; omega. }
- { destruct p as [ps p]; specialize (IHn ps).
- unfold nonzero, nonzero_cps in *.
- autorewrite with uncps in *.
- unfold id in *.
- setoid_rewrite to_list_S.
- set (k := S n) in *; simpl in *.
- intro Hsmall.
- rewrite Z.lor_eq_0_iff, IHn
- by (hnf in Hsmall |- *; simpl in *; eauto);
- clear IHn.
- exact (eval_pair n (ps, p) Hsmall). }
- Qed.
-
- Lemma eval_join0 n p
- : eval (@join0 n p) = eval p.
- Proof.
- Admitted.
-
- Local Ltac pose_uweight bound :=
- match goal with H : bound > 0 |- _ =>
- pose proof (uweight_0 bound);
- pose proof (@uweight_positive bound H);
- pose proof (@uweight_nonzero bound H);
- pose proof (@uweight_multiples bound);
- pose proof (@uweight_divides bound H)
- end.
-
- Local Ltac pose_all :=
- pose_uweight bound;
- pose proof Z.add_get_carry_full_div;
- pose proof Z.add_get_carry_full_mod;
- pose proof Z.mul_split_div; pose proof Z.mul_split_mod;
- pose proof div_correct; pose proof modulo_correct.
-
- Lemma eval_add_nz n p q :
- n <> 0%nat ->
- eval (@add n p q) = eval p + eval q.
- Proof.
- intros. pose_all.
- repeat match goal with
- | _ => progress (cbv [add_cps add eval Let_In] in *; repeat autounfold)
- | _ => progress autorewrite with uncps push_id cancel_pair push_basesystem_eval
- | _ => rewrite B.Positional.eval_left_append
-
- | _ => progress
- (rewrite <-!from_list_default_eq with (d:=0);
- erewrite !length_to_list, !from_list_default_eq,
- from_list_to_list)
- | _ => apply Z.mod_small; omega
- end.
- Admitted.
-
- Lemma eval_add_z n p q :
- n = 0%nat ->
- eval (@add n p q) = eval p + eval q.
- Proof. intros; subst; reflexivity. Qed.
-
- Lemma eval_add n p q
- : eval (@add n p q) = eval p + eval q.
- Proof.
- destruct (Nat.eq_dec n 0%nat); intuition auto using eval_add_z, eval_add_nz.
- Qed.
- Lemma eval_add_same n p q
- : eval (@add n p q) = eval p + eval q.
- Proof. apply eval_add; omega. Qed.
- Lemma eval_add_S1 n p q
- : eval (@add_S1 n p q) = eval p + eval q.
- Proof.
- cbv [add_S1 add_S1_cps]. autorewrite with uncps push_id.
- (*rewrite eval_add; rewrite eval_join0; [reflexivity|assumption].*)
- Admitted.
- Lemma eval_add_S2 n p q
- : eval (@add_S2 n p q) = eval p + eval q.
- Proof.
- cbv [add_S2 add_S2_cps]. autorewrite with uncps push_id.
- (*rewrite eval_add; rewrite eval_join0; [reflexivity|assumption].*)
- Admitted.
->>>>>>> addsubchains
- Hint Rewrite eval_add_same eval_add_S1 eval_add_S2 using (omega || assumption): push_basesystem_eval.
-
- Lemma uweight_le_mono n m : (n <= m)%nat ->
- uweight bound n <= uweight bound m.
- Proof.
- unfold uweight; intro; Z.peel_le; omega.
- Qed.
-
- Lemma uweight_lt_mono (bound_gt_1 : bound > 1) n m : (n < m)%nat ->
- uweight bound n < uweight bound m.
- Proof.
- clear bound_pos.
- unfold uweight; intro; apply Z.pow_lt_mono_r; omega.
- Qed.
-
- Lemma uweight_succ n : uweight bound (S n) = bound * uweight bound n.
- Proof.
- unfold uweight.
- rewrite Nat2Z.inj_succ, Z.pow_succ_r by auto using Nat2Z.is_nonneg; reflexivity.
- Qed.
-
- Local Definition compact {n} := Columns.compact (n:=n) (add_get_carry:=Z.add_get_carry_full) (div:=div) (modulo:=modulo) (uweight bound).
- Local Definition compact_digit := Columns.compact_digit (add_get_carry:=Z.add_get_carry_full) (div:=div) (modulo:=modulo) (uweight bound).
- Lemma small_compact {n} (p:(list Z)^n) : small (snd (compact p)).
- Proof.
- pose_all.
- match goal with
- |- ?G => assert (G /\ fst (compact p) = fst (compact p)); [|tauto]
- end. (* assert a dummy second statement so that fst (compact x) is in context *)
- cbv [compact Columns.compact Columns.compact_cps small
- Columns.compact_step Columns.compact_step_cps];
- autorewrite with uncps push_id.
- change (fun i s a => Columns.compact_digit_cps (uweight bound) i (s :: a) id)
- with (fun i s a => compact_digit i (s :: a)).
- remember (fun i s a => compact_digit i (s :: a)) as f.
-
- apply @mapi_with'_linvariant with (n:=n) (f:=f) (inp:=p);
- intros; [|simpl; tauto]. split; [|reflexivity].
- let P := fresh "H" in
- match goal with H : _ /\ _ |- _ => destruct H end.
- destruct n0; subst f.
- { cbv [compact_digit uweight to_list to_list' In].
- rewrite Columns.compact_digit_mod by assumption.
- rewrite Z.pow_0_r, Z.pow_1_r, Z.div_1_r. intros x ?.
- match goal with
- H : _ \/ False |- _ => destruct H; [|exfalso; assumption] end.
- subst x. apply Z.mod_pos_bound, Z.gt_lt, bound_pos. }
- { rewrite Tuple.to_list_left_append.
- let H := fresh "H" in
- intros x H; apply in_app_or in H; destruct H;
- [solve[auto]| cbv [In] in H; destruct H;
- [|exfalso; assumption] ].
- subst x. cbv [compact_digit].
- rewrite Columns.compact_digit_mod by assumption.
- rewrite !uweight_succ, Z.div_mul by
- (apply Z.neq_mul_0; split; auto; omega).
- apply Z.mod_pos_bound, Z.gt_lt, bound_pos. }
- Qed.
-
- Lemma small_add n a b :
- (2 <= bound) ->
- small a -> small b -> small (@add n a b).
- Proof.
- intros. pose_all.
- cbv [add_cps add Let_In].
- autorewrite with uncps push_id.
- apply Positional.small_sat_add.
- (*apply Positional.small_sat_add.*)
- Admitted.
-
- Lemma small_add_S1 n a b :
- (2 <= bound) ->
- small a -> small b -> small (@add_S1 n a b).
- Proof.
- intros. pose_all.
- cbv [add_cps add add_S1 Let_In].
- autorewrite with uncps push_id.
- (*apply Positional.small_sat_add.*)
- Admitted.
-
- Lemma small_add_S2 n a b :
- (2 <= bound) ->
- small a -> small b -> small (@add_S2 n a b).
- Proof.
- intros. pose_all.
- cbv [add_cps add add_S2 Let_In].
- autorewrite with uncps push_id.
- (*apply Positional.small_sat_add.*)
->>>>>>> addsubchains
- Admitted.
-
- Lemma small_left_tl n (v:T (S n)) : small v -> small (left_tl v).
- Proof. cbv [small]. auto using Tuple.In_to_list_left_tl. Qed.
-
- Lemma small_divmod n (p: T (S n)) (Hsmall : small p) :
- left_hd p = eval p / uweight bound n /\ eval (left_tl p) = eval p mod (uweight bound n).
- Admitted.
-
- Lemma eval_drop_high n v :
- small v -> eval (@drop_high n v) = eval v mod (uweight bound n).
- Proof.
- cbv [drop_high drop_high_cps eval].
- rewrite Tuple.left_tl_cps_correct, push_id. (* TODO : for some reason autorewrite with uncps doesn't work here *)
- intro H. apply small_left_tl in H.
- rewrite (subst_left_append v) at 2.
- autorewrite with push_basesystem_eval.
- apply eval_small in H.
- rewrite Z.mod_add_l' by (pose_uweight bound; auto).
- rewrite Z.mod_small; auto.
- Qed.
-
- Lemma small_drop_high n v : small v -> small (@drop_high n v).
- Proof.
- cbv [drop_high drop_high_cps].
- rewrite Tuple.left_tl_cps_correct, push_id.
- apply small_left_tl.
- Qed.
-
- Lemma div_nonzero_neg_iff x y : x < y -> 0 < y -> x / y <> 0 <-> x < 0.
- Proof.
- repeat match goal with
- | _ => progress intros
- | _ => rewrite Z.div_small_iff by omega
- | _ => split
- | _ => omega
- end.
- Qed.
-
- Lemma eval_sub_then_maybe_add_nz n mask p q r:
- small p -> small q -> small r -> (n<>0)%nat ->
- (map (Z.land mask) r = r) ->
- (0 <= eval p < eval r) -> (0 <= eval q < eval r) ->
- eval (@sub_then_maybe_add n mask p q r) = eval p - eval q + (if eval p - eval q <? 0 then eval r else 0).
- Proof.
- pose_all.
- repeat match goal with
- | _ => progress (cbv [sub_then_maybe_add sub_then_maybe_add_cps eval] in *; intros)
- | _ => progress autounfold
- | _ => progress autorewrite with uncps push_id push_basesystem_eval
- | _ => rewrite eval_drop_high
- | _ => rewrite eval_join0
- | H : small _ |- _ => apply eval_small in H
- | _ => progress break_match
- | _ => (rewrite Z.add_opp_r in * )
- | H : _ |- _ => rewrite Z.ltb_lt in H;
- rewrite <-div_nonzero_neg_iff with
- (y:=uweight bound n) in H by (auto; omega)
- | H : _ |- _ => rewrite Z.ltb_ge in H
- | _ => rewrite Z.mod_small by omega
- | _ => omega
- | _ => progress autorewrite with zsimplify; [ ]
- end.
- Admitted.
-
- Lemma eval_sub_then_maybe_add n mask p q r :
- small p -> small q -> small r ->
- (map (Z.land mask) r = r) ->
- (0 <= eval p < eval r) -> (0 <= eval q < eval r) ->
- eval (@sub_then_maybe_add n mask p q r) = eval p - eval q + (if eval p - eval q <? 0 then eval r else 0).
- Proof.
- destruct n; [|solve[auto using eval_sub_then_maybe_add_nz]].
- destruct p, q, r; reflexivity.
- Qed.
-
- Lemma small_sub_then_maybe_add n mask (p q r : T n) :
- small (sub_then_maybe_add mask p q r).
- Proof.
- cbv [sub_then_maybe_add_cps sub_then_maybe_add]; intros.
- repeat progress autounfold. autorewrite with uncps push_id.
- apply small_drop_high, Positional.small_sat_sub.
- Qed.
-
- (* TODO : remove if unneeded when all admits are proven
- Lemma small_highest_zero_iff {n} (p: T (S n)) (Hsmall : small p) :
- (left_hd p = 0 <-> eval p < uweight bound n).
- Proof.
- destruct (small_divmod _ p Hsmall) as [Hdiv Hmod].
- pose proof Hsmall as Hsmalltl. apply eval_small in Hsmall.
- apply small_left_tl, eval_small in Hsmalltl. rewrite Hdiv.
- rewrite (Z.div_small_iff (eval p) (uweight bound n))
- by auto using uweight_nonzero.
- split; [|intros; left; omega].
- let H := fresh "H" in intro H; destruct H; [|omega].
- omega.
- Qed.
- *)
-
- Lemma map2_zselect n cond x y :
- Tuple.map2 (n:=n) (Z.zselect cond) x y = if dec (cond = 0) then x else y.
- Proof.
- unfold Z.zselect.
- break_innermost_match; Z.ltb_to_lt; subst; try omega;
- [ rewrite Tuple.map2_fst, Tuple.map_id
- | rewrite Tuple.map2_snd, Tuple.map_id ];
- reflexivity.
- Qed.
-
- Lemma eval_conditional_sub_nz n (p:T (S n)) (q:T n)
- (n_nonzero: (n <> 0)%nat) (psmall : small p) (qsmall : small q):
- 0 <= eval p < eval q + uweight bound n ->
- eval (conditional_sub p q) = eval p + (if eval q <=? eval p then - eval q else 0).
- Proof.
- cbv [conditional_sub conditional_sub_cps]. intros. pose_all.
- repeat autounfold. apply eval_small in qsmall.
- pose proof psmall; apply eval_small in psmall.
- cbv [eval] in *. autorewrite with uncps push_id push_basesystem_eval.
- rewrite map2_zselect.
- let H := fresh "H" in let X := fresh "P" in
- match goal with |- context [?x / ?y] =>
- pose proof (div_nonzero_neg_iff x y) end;
- repeat match type of H with ?P -> _ =>
- assert P as X by omega; specialize (H X);
- clear X end.
-
- break_match;
- repeat match goal with
- | _ => progress cbv [eval]
- | H : (_ <=? _) = true |- _ => apply Z.leb_le in H
- | H : (_ <=? _) = false |- _ => apply Z.leb_gt in H
- | _ => rewrite eval_drop_high by auto using Positional.small_sat_sub
- | _ => (rewrite eval_join0 in * )
- | _ => progress autorewrite with uncps push_id push_basesystem_eval
- | _ => repeat rewrite Z.mod_small; omega
- | _ => omega
- end.
- Admitted.
-
- Lemma eval_conditional_sub n (p:T (S n)) (q:T n)
- (psmall : small p) (qsmall : small q) :
- 0 <= eval p < eval q + uweight bound n ->
- eval (conditional_sub p q) = eval p + (if eval q <=? eval p then - eval q else 0).
- Proof.
- destruct n; [|solve[auto using eval_conditional_sub_nz]].
- repeat match goal with
- | _ => progress (intros; cbv [T tuple tuple'] in p, q)
- | q : unit |- _ => destruct q
- | _ => progress (cbv [conditional_sub conditional_sub_cps eval] in * )
- | _ => progress autounfold
- | _ => progress (autorewrite with uncps push_id push_basesystem_eval in * )
- | _ => (rewrite uweight_0 in * )
- | _ => assert (p = 0) by omega; subst p; break_match; ring
- end.
- Qed.
-
- Lemma small_conditional_sub n (p:T (S n)) (q:T n)
- (psmall : small p) (qsmall : small q) :
- 0 <= eval p < eval q + uweight bound n ->
- small (conditional_sub p q).
- Admitted.
-
- Lemma eval_scmul n a v : small v -> 0 <= a < bound ->
- eval (@scmul n a v) = a * eval v.
- Proof.
- intro Hsmall. pose_all. apply eval_small in Hsmall.
- intros. cbv [scmul scmul_cps eval] in *. repeat autounfold.
- autorewrite with uncps push_id push_basesystem_eval.
- rewrite uweight_0, Z.mul_1_l. apply Z.mod_small.
- split; [solve[Z.zero_bounds]|]. cbv [uweight] in *.
- rewrite !Nat2Z.inj_succ, Z.pow_succ_r by auto using Nat2Z.is_nonneg.
- apply Z.mul_lt_mono_nonneg; omega.
- Qed.
-
- Lemma small_scmul n a v : small (@scmul n a v).
- Proof.
- cbv [scmul scmul_cps eval] in *. repeat autounfold.
- autorewrite with uncps push_id push_basesystem_eval.
- apply small_compact.
- Qed.
-
- (* TODO : move to tuple *)
- Lemma from_list_tl {A n} (ls : list A) H H':
- from_list n (List.tl ls) H = tl (from_list (S n) ls H').
- Proof.
- induction ls; distr_length. simpl List.tl.
- rewrite from_list_cons, tl_append, <-!(from_list_default_eq a ls).
- reflexivity.
- Qed.
-
- Lemma small_hd n p : @small (S n) p -> 0 <= hd p < bound.
- Proof.
- cbv [small]. let H := fresh "H" in intro H; apply H.
- rewrite (subst_append p). rewrite to_list_append, hd_append.
- apply in_eq.
- Qed.
-
-
- Lemma eval_div n p : small p -> eval (fst (@divmod n p)) = eval p / bound.
- Proof.
- cbv [divmod divmod_cps eval]. intros.
- autorewrite with uncps push_id cancel_pair.
- rewrite (subst_append p) at 2.
- rewrite uweight_eval_step. rewrite hd_append, tl_append.
- rewrite Z.div_add' by omega. rewrite Z.div_small by auto using small_hd.
- ring.
- Qed.
-
- Lemma eval_mod n p : small p -> snd (@divmod n p) = eval p mod bound.
- Proof.
- cbv [divmod divmod_cps eval]. intros.
- autorewrite with uncps push_id cancel_pair.
- rewrite (subst_append p) at 2.
- rewrite uweight_eval_step, Z.mod_add'_full, hd_append.
- rewrite Z.mod_small by auto using small_hd. reflexivity.
- Qed.
-
- Lemma small_div n v : small v -> small (fst (@divmod n v)).
- Admitted.
-
- End Proofs.
-End API.
-Hint Rewrite nonzero_id join0_id divmod_id drop_high_id scmul_id add_id add_S1_id add_S2_id sub_then_maybe_add_id conditional_sub_id : uncps.
-
-(*
-(* Just some pretty-printing *)
-Local Notation "fst~ a" := (let (x,_) := a in x) (at level 40, only printing).
-Local Notation "snd~ a" := (let (_,y) := a in y) (at level 40, only printing).
-
-(* Simple example : base 10, multiply two bignums and compact them *)
-Definition base10 i := Eval compute in 10^(Z.of_nat i).
-Eval cbv -[runtime_add runtime_mul Let_In] in
- (fun adc a0 a1 a2 b0 b1 b2 =>
- Columns.mul_cps (weight := base10) (n:=3) (a2,a1,a0) (b2,b1,b0) (fun ab => Columns.compact (n:=5) (add_get_carry:=adc) (weight:=base10) ab)).
-
-(* More complex example : base 2^56, 8 limbs *)
-Definition base2pow56 i := Eval compute in 2^(56*Z.of_nat i).
-Time Eval cbv -[runtime_add runtime_mul Let_In] in
- (fun adc a0 a1 a2 a3 a4 a5 a6 a7 b0 b1 b2 b3 b4 b5 b6 b7 =>
- Columns.mul_cps (weight := base2pow56) (n:=8) (a7,a6,a5,a4,a3,a2,a1,a0) (b7,b6,b5,b4,b3,b2,b1,b0) (fun ab => Columns.compact (n:=15) (add_get_carry:=adc) (weight:=base2pow56) ab)). (* Finished transaction in 151.392 secs *)
-
-(* Mixed-radix example : base 2^25.5, 10 limbs *)
-Definition base2pow25p5 i := Eval compute in 2^(25*Z.of_nat i + ((Z.of_nat i + 1) / 2)).
-Time Eval cbv -[runtime_add runtime_mul Let_In] in
- (fun adc a0 a1 a2 a3 a4 a5 a6 a7 a8 a9 b0 b1 b2 b3 b4 b5 b6 b7 b8 b9 =>
- Columns.mul_cps (weight := base2pow25p5) (n:=10) (a9,a8,a7,a6,a5,a4,a3,a2,a1,a0) (b9,b8,b7,b6,b5,b4,b3,b2,b1,b0) (fun ab => Columns.compact (n:=19) (add_get_carry:=adc) (weight:=base2pow25p5) ab)). (* Finished transaction in 97.341 secs *)
-*) \ No newline at end of file
diff --git a/src/Arithmetic/Saturated/Freeze.v b/src/Arithmetic/Saturated/Freeze.v
new file mode 100644
index 000000000..735663636
--- /dev/null
+++ b/src/Arithmetic/Saturated/Freeze.v
@@ -0,0 +1,122 @@
+Require Import Coq.ZArith.ZArith.
+Require Import Coq.Lists.List.
+Local Open Scope Z_scope.
+
+Require Import Crypto.Arithmetic.Core.
+Require Import Crypto.Arithmetic.Saturated.Core.
+Require Import Crypto.Arithmetic.Saturated.Wrappers.
+Require Import Crypto.Util.ZUtil.AddGetCarry.
+Require Import Crypto.Util.ZUtil.Definitions.
+Require Import Crypto.Util.Tactics.BreakMatch.
+Require Import Crypto.Util.Decidable Crypto.Util.ZUtil.
+Require Import Crypto.Util.Tuple Crypto.Util.LetIn.
+Local Notation "A ^ n" := (tuple A n) : type_scope.
+
+(* Canonicalize bignums by fully reducing them modulo p.
+ This works on unsaturated digits, but uses saturated add/subtract
+ loops.*)
+Section Freeze.
+ Context (weight : nat->Z)
+ {weight_0 : weight 0%nat = 1}
+ {weight_nonzero : forall i, weight i <> 0}
+ {weight_positive : forall i, weight i > 0}
+ {weight_multiples : forall i, weight (S i) mod weight i = 0}
+ {weight_divides : forall i : nat, weight (S i) / weight i > 0}
+ .
+
+
+ (*
+ The input to [freeze] should be less than 2*m (this can probably
+ be accomplished by a single carry_reduce step, for most moduli).
+
+ [freeze] has the following steps:
+ (1) subtract modulus in a carrying loop (in our framework, this
+ consists of two steps; [Columns.unbalanced_sub_cps] combines the
+ input p and the modulus m such that the ith limb in the output is
+ the list [p[i];-m[i]]. We can then call [Columns.compact].)
+ (2) look at the final carry, which should be either 0 or -1. If
+ it's -1, then we add the modulus back in. Otherwise we add 0 for
+ constant-timeness.
+ (3) discard the carry after this last addition; it should be 1 if
+ the carry in step 3 was -1, so they cancel out.
+ *)
+ Definition freeze_cps {n} mask (m:Z^n) (p:Z^n) {T} (f : Z^n->T) :=
+ Columns.unbalanced_sub_cps (n3:=n) weight p m
+ (fun carry_p => Columns.conditional_add_cps (n3:=n) weight mask (fst carry_p) (snd carry_p) m
+ (fun carry_r => f (snd carry_r)))
+ .
+
+ Definition freeze {n} mask m p :=
+ @freeze_cps n mask m p _ id.
+ Lemma freeze_id {n} mask m p T f:
+ @freeze_cps n mask m p T f = f (freeze mask m p).
+ Proof.
+ cbv [freeze_cps freeze]; repeat progress autounfold;
+ autorewrite with uncps push_id; reflexivity.
+ Qed.
+ Hint Opaque freeze : uncps.
+ Hint Rewrite @freeze_id : uncps.
+
+ Lemma freezeZ m s c y y0 z z0 c0 a :
+ m = s - c ->
+ 0 < c < s ->
+ s <> 0 ->
+ 0 <= y < 2*m ->
+ y0 = y - m ->
+ z = y0 mod s ->
+ c0 = y0 / s ->
+ z0 = z + (if (dec (c0 = 0)) then 0 else m) ->
+ a = z0 mod s ->
+ a mod m = y0 mod m.
+ Proof.
+ clear. intros. subst. break_match.
+ { rewrite Z.add_0_r, Z.mod_mod by omega.
+ assert (-(s-c) <= y - (s-c) < s-c) by omega.
+ match goal with H : s <> 0 |- _ =>
+ rewrite (proj2 (Z.mod_small_iff _ s H))
+ by (apply Z.div_small_iff; assumption)
+ end.
+ reflexivity. }
+ { rewrite <-Z.add_mod_l, Z.sub_mod_full.
+ rewrite Z.mod_same, Z.sub_0_r, Z.mod_mod by omega.
+ rewrite Z.mod_small with (b := s)
+ by (pose proof (Z.div_small (y - (s-c)) s); omega).
+ f_equal. ring. }
+ Qed.
+
+ Lemma eval_freeze {n} c mask m p
+ (n_nonzero:n<>0%nat)
+ (Hc : 0 < B.Associational.eval c < weight n)
+ (Hmask : Tuple.map (Z.land mask) m = m)
+ modulus (Hm : B.Positional.eval weight m = Z.pos modulus)
+ (Hp : 0 <= B.Positional.eval weight p < 2*(Z.pos modulus))
+ (Hsc : Z.pos modulus = weight n - B.Associational.eval c)
+ :
+ mod_eq modulus
+ (B.Positional.eval weight (@freeze n mask m p))
+ (B.Positional.eval weight p).
+ Proof.
+ cbv [freeze_cps freeze].
+ repeat progress autounfold.
+ pose proof Z.add_get_carry_full_mod.
+ pose proof Z.add_get_carry_full_div.
+ pose proof div_correct. pose proof modulo_correct.
+ autorewrite with uncps push_id push_basesystem_eval.
+
+ pose proof (weight_nonzero n).
+
+ remember (B.Positional.eval weight p) as y.
+ remember (y + -B.Positional.eval weight m) as y0.
+ rewrite Hm in *.
+
+ transitivity y0; cbv [mod_eq].
+ { eapply (freezeZ (Z.pos modulus) (weight n) (B.Associational.eval c) y y0);
+ try assumption; reflexivity. }
+ { subst y0.
+ assert (Z.pos modulus <> 0) by auto using Z.positive_is_nonzero, Zgt_pos_0.
+ rewrite Z.add_mod by assumption.
+ rewrite Z.mod_opp_l_z by auto using Z.mod_same.
+ rewrite Z.add_0_r, Z.mod_mod by assumption.
+ reflexivity. }
+ Qed.
+End Freeze. \ No newline at end of file
diff --git a/src/Arithmetic/Saturated/MontgomeryAPI.v b/src/Arithmetic/Saturated/MontgomeryAPI.v
new file mode 100644
index 000000000..0ce1ac265
--- /dev/null
+++ b/src/Arithmetic/Saturated/MontgomeryAPI.v
@@ -0,0 +1,599 @@
+Require Import Coq.ZArith.ZArith.
+Require Import Coq.Lists.List.
+Local Open Scope Z_scope.
+
+Require Import Crypto.Arithmetic.Core.
+Require Import Crypto.Arithmetic.Saturated.Core.
+Require Import Crypto.Arithmetic.Saturated.UniformWeight.
+Require Import Crypto.Arithmetic.Saturated.Wrappers.
+Require Import Crypto.Arithmetic.Saturated.AddSub.
+Require Import Crypto.Util.LetIn Crypto.Util.CPSUtil.
+Require Import Crypto.Util.Tuple Crypto.Util.LetIn.
+Require Import Crypto.Util.Tactics Crypto.Util.Decidable.
+Require Import Crypto.Util.ZUtil Crypto.Util.ListUtil.
+Require Import Crypto.Util.ZUtil.Definitions.
+Require Import Crypto.Util.ZUtil.Zselect.
+Require Import Crypto.Util.ZUtil.AddGetCarry.
+Require Import Crypto.Util.ZUtil.MulSplit.
+Local Notation "A ^ n" := (tuple A n) : type_scope.
+
+Section API.
+ Context (bound : Z) {bound_pos : bound > 0}.
+ Definition T : nat -> Type := tuple Z.
+
+ (* lowest limb is less than its bound; this is required for [divmod]
+ to simply separate the lowest limb from the rest and be equivalent
+ to normal div/mod with [bound]. *)
+ Local Notation small := (@small bound).
+
+ Definition zero {n:nat} : T n := B.Positional.zeros n.
+
+ (** Returns 0 iff all limbs are 0 *)
+ Definition nonzero_cps {n} (p : T n) {cpsT} (f : Z -> cpsT) : cpsT
+ := CPSUtil.to_list_cps _ p (fun p => CPSUtil.fold_right_cps runtime_lor 0%Z p f).
+ Definition nonzero {n} (p : T n) : Z
+ := nonzero_cps p id.
+
+ Definition join0_cps {n:nat} (p : T n) {R} (f:T (S n) -> R)
+ := Tuple.left_append_cps 0 p f.
+ Definition join0 {n} p : T (S n) := @join0_cps n p _ id.
+
+ Definition divmod_cps {n} (p : T (S n)) {R} (f:T n * Z->R) : R
+ := Tuple.tl_cps p (fun d => Tuple.hd_cps p (fun m => f (d, m))).
+ Definition divmod {n} p : T n * Z := @divmod_cps n p _ id.
+
+ Definition drop_high_cps {n : nat} (p : T (S n)) {R} (f:T n->R)
+ := Tuple.left_tl_cps p f.
+ Definition drop_high {n} p : T n := @drop_high_cps n p _ id.
+
+ Definition scmul_cps {n} (c : Z) (p : T n) {R} (f:T (S n)->R) :=
+ Columns.mul_cps (n1:=1) (n3:=S n) (uweight bound) bound c p
+ (* The carry that comes out of Columns.mul_cps will be 0, since
+ (S n) limbs is enough to hold the result of the
+ multiplication, so we can safely discard it. *)
+ (fun carry_result =>f (snd carry_result)).
+ Definition scmul {n} c p : T (S n) := @scmul_cps n c p _ id.
+
+ Definition add_cps {n} (p q: T n) {R} (f:T (S n)->R) :=
+ B.Positional.sat_add_cps (s:=bound) p q _
+ (* join the last carry *)
+ (fun carry_result => Tuple.left_append_cps (fst carry_result) (snd carry_result) f).
+ Definition add {n} p q : T (S n) := @add_cps n p q _ id.
+
+ (* Wrappers for additions with slightly uneven limb counts *)
+ Definition add_S1_cps {n} (p: T (S n)) (q: T n) {R} (f:T (S (S n))->R) :=
+ join0_cps q (fun Q => add_cps p Q f).
+ Definition add_S1 {n} p q := @add_S1_cps n p q _ id.
+ Definition add_S2_cps {n} (p: T n) (q: T (S n)) {R} (f:T (S (S n))->R) :=
+ join0_cps p (fun P => add_cps P q f).
+ Definition add_S2 {n} p q := @add_S2_cps n p q _ id.
+
+ Definition sub_then_maybe_add_cps {n} mask (p q r : T n)
+ {R} (f:T n -> R) :=
+ B.Positional.sat_sub_cps (s:=bound) p q _
+ (* the carry will be 0 unless we underflow--we do the addition only
+ in the underflow case *)
+ (fun carry_result =>
+ B.Positional.select_cps mask (fst carry_result) r
+ (fun selected => join0_cps selected
+ (fun selected' =>
+ B.Positional.sat_sub_cps (s:=bound) (left_append (fst carry_result) (snd carry_result)) selected' _
+ (* We can now safely discard the carry and the highest digit.
+ This relies on the precondition that p - q + r < bound^n. *)
+ (fun carry_result' => drop_high_cps (snd carry_result') f)))).
+ Definition sub_then_maybe_add {n} mask (p q r : T n) :=
+ sub_then_maybe_add_cps mask p q r id.
+
+ (* Subtract q if and only if p >= q. We rely on the preconditions
+ that 0 <= p < 2*q and q < bound^n (this ensures the output is less
+ than bound^n). *)
+ Definition conditional_sub_cps {n} (p:Z^S n) (q:Z^n) R (f:Z^n->R) :=
+ join0_cps q
+ (fun qq => B.Positional.sat_sub_cps (s:=bound) p qq _
+ (* if carry is zero, we select the result of the subtraction,
+ otherwise the first input *)
+ (fun carry_result =>
+ Tuple.map2_cps (Z.zselect (fst carry_result)) (snd carry_result) p
+ (* in either case, since our result must be < q and therefore <
+ bound^n, we can drop the high digit *)
+ (fun r => drop_high_cps r f))).
+ Definition conditional_sub {n} p q := @conditional_sub_cps n p q _ id.
+
+ Hint Opaque join0 divmod drop_high scmul add sub_then_maybe_add conditional_sub : uncps.
+
+ Section CPSProofs.
+
+ Local Ltac prove_id :=
+ repeat autounfold; autorewrite with uncps; reflexivity.
+
+ Lemma nonzero_id n p {cpsT} f : @nonzero_cps n p cpsT f = f (@nonzero n p).
+ Proof. cbv [nonzero nonzero_cps]. prove_id. Qed.
+
+ Lemma join0_id n p R f :
+ @join0_cps n p R f = f (join0 p).
+ Proof. cbv [join0_cps join0]. prove_id. Qed.
+
+ Lemma divmod_id n p R f :
+ @divmod_cps n p R f = f (divmod p).
+ Proof. cbv [divmod_cps divmod]; prove_id. Qed.
+
+ Lemma drop_high_id n p R f :
+ @drop_high_cps n p R f = f (drop_high p).
+ Proof. cbv [drop_high_cps drop_high]; prove_id. Qed.
+ Hint Rewrite drop_high_id : uncps.
+
+ Lemma scmul_id n c p R f :
+ @scmul_cps n c p R f = f (scmul c p).
+ Proof. cbv [scmul_cps scmul]. prove_id. Qed.
+
+ Lemma add_id n p q R f :
+ @add_cps n p q R f = f (add p q).
+ Proof. cbv [add_cps add Let_In]. prove_id. Qed.
+ Hint Rewrite add_id : uncps.
+
+ Lemma add_S1_id n p q R f :
+ @add_S1_cps n p q R f = f (add_S1 p q).
+ Proof. cbv [add_S1_cps add_S1 join0_cps]. prove_id. Qed.
+
+ Lemma add_S2_id n p q R f :
+ @add_S2_cps n p q R f = f (add_S2 p q).
+ Proof. cbv [add_S2_cps add_S2 join0_cps]. prove_id. Qed.
+
+ Lemma sub_then_maybe_add_id n mask p q r R f :
+ @sub_then_maybe_add_cps n mask p q r R f = f (sub_then_maybe_add mask p q r).
+ Proof. cbv [sub_then_maybe_add_cps sub_then_maybe_add join0_cps Let_In]. prove_id. Qed.
+
+ Lemma conditional_sub_id n p q R f :
+ @conditional_sub_cps n p q R f = f (conditional_sub p q).
+ Proof. cbv [conditional_sub_cps conditional_sub join0_cps Let_In]. prove_id. Qed.
+
+ End CPSProofs.
+ Hint Rewrite nonzero_id join0_id divmod_id drop_high_id scmul_id add_id sub_then_maybe_add_id conditional_sub_id : uncps.
+
+ Section Proofs.
+
+ Definition eval {n} (p : T n) : Z :=
+ B.Positional.eval (uweight bound) p.
+
+ Lemma eval_small n (p : T n) (Hsmall : small p) :
+ 0 <= eval p < uweight bound n.
+ Proof.
+ cbv [small eval] in *; intros.
+ induction n; cbv [T uweight] in *; [destruct p|rewrite (subst_left_append p)];
+ repeat match goal with
+ | _ => progress autorewrite with push_basesystem_eval
+ | _ => rewrite Z.pow_0_r
+ | _ => specialize (IHn (left_tl p))
+ | _ =>
+ let H := fresh "H" in
+ match type of IHn with
+ ?P -> _ => assert P as H by auto using Tuple.In_to_list_left_tl;
+ specialize (IHn H)
+ end
+ | |- context [?b ^ Z.of_nat (S ?n)] =>
+ replace (b ^ Z.of_nat (S n)) with (b ^ Z.of_nat n * b) by
+ (rewrite Nat2Z.inj_succ, <-Z.add_1_r, Z.pow_add_r,
+ Z.pow_1_r by (omega || auto using Nat2Z.is_nonneg);
+ reflexivity)
+ | _ => omega
+ end.
+
+ specialize (Hsmall _ (Tuple.In_left_hd _ p)).
+ split; [Z.zero_bounds; omega |].
+ apply Z.lt_le_trans with (m:=bound^Z.of_nat n * (left_hd p+1)).
+ { rewrite Z.mul_add_distr_l.
+ apply Z.add_le_lt_mono; omega. }
+ { apply Z.mul_le_mono_nonneg; omega. }
+ Qed.
+
+ Lemma eval_zero n : eval (@zero n) = 0.
+ Proof.
+ cbv [eval zero].
+ autorewrite with push_basesystem_eval.
+ reflexivity.
+ Qed.
+
+ Lemma small_zero n : small (@zero n).
+ Proof.
+ cbv [zero small B.Positional.zeros]. destruct n; [simpl;tauto|].
+ rewrite to_list_repeat.
+ intros x H; apply repeat_spec in H; subst x; omega.
+ Qed.
+
+ Lemma eval_pair n (p : T (S (S n))) : small p -> (snd p = 0 /\ eval (n:=S n) (fst p) = 0) <-> eval p = 0.
+ Admitted.
+
+ Lemma eval_nonzero n p : small p -> @nonzero n p = 0 <-> eval p = 0.
+ Proof.
+ destruct n as [|n].
+ { compute; split; trivial. }
+ induction n as [|n IHn].
+ { simpl; rewrite Z.lor_0_r; unfold eval, id.
+ cbv -[Z.add iff].
+ rewrite Z.add_0_r.
+ destruct p; omega. }
+ { destruct p as [ps p]; specialize (IHn ps).
+ unfold nonzero, nonzero_cps in *.
+ autorewrite with uncps in *.
+ unfold id in *.
+ setoid_rewrite to_list_S.
+ set (k := S n) in *; simpl in *.
+ intro Hsmall.
+ rewrite Z.lor_eq_0_iff, IHn
+ by (hnf in Hsmall |- *; simpl in *; eauto);
+ clear IHn.
+ exact (eval_pair n (ps, p) Hsmall). }
+ Qed.
+
+ Lemma eval_join0 n p
+ : eval (@join0 n p) = eval p.
+ Proof.
+ Admitted.
+
+ Local Ltac pose_uweight bound :=
+ match goal with H : bound > 0 |- _ =>
+ pose proof (uweight_0 bound);
+ pose proof (@uweight_positive bound H);
+ pose proof (@uweight_nonzero bound H);
+ pose proof (@uweight_multiples bound);
+ pose proof (@uweight_divides bound H)
+ end.
+
+ Local Ltac pose_all :=
+ pose_uweight bound;
+ pose proof Z.add_get_carry_full_div;
+ pose proof Z.add_get_carry_full_mod;
+ pose proof Z.mul_split_div; pose proof Z.mul_split_mod;
+ pose proof div_correct; pose proof modulo_correct.
+
+ Lemma eval_add_nz n p q :
+ n <> 0%nat ->
+ eval (@add n p q) = eval p + eval q.
+ Proof.
+ intros. pose_all.
+ repeat match goal with
+ | _ => progress (cbv [add_cps add eval Let_In] in *; repeat autounfold)
+ | _ => progress autorewrite with uncps push_id cancel_pair push_basesystem_eval
+ | _ => rewrite B.Positional.eval_left_append
+
+ | _ => progress
+ (rewrite <-!from_list_default_eq with (d:=0);
+ erewrite !length_to_list, !from_list_default_eq,
+ from_list_to_list)
+ | _ => apply Z.mod_small; omega
+ end.
+ Admitted.
+
+ Lemma eval_add_z n p q :
+ n = 0%nat ->
+ eval (@add n p q) = eval p + eval q.
+ Proof. intros; subst; reflexivity. Qed.
+
+ Lemma eval_add n p q
+ : eval (@add n p q) = eval p + eval q.
+ Proof.
+ destruct (Nat.eq_dec n 0%nat); intuition auto using eval_add_z, eval_add_nz.
+ Qed.
+ Lemma eval_add_same n p q
+ : eval (@add n p q) = eval p + eval q.
+ Proof. apply eval_add; omega. Qed.
+ Lemma eval_add_S1 n p q
+ : eval (@add_S1 n p q) = eval p + eval q.
+ Proof.
+ cbv [add_S1 add_S1_cps]. autorewrite with uncps push_id.
+ (*rewrite eval_add; rewrite eval_join0; [reflexivity|assumption].*)
+ Admitted.
+ Lemma eval_add_S2 n p q
+ : eval (@add_S2 n p q) = eval p + eval q.
+ Proof.
+ cbv [add_S2 add_S2_cps]. autorewrite with uncps push_id.
+ (*rewrite eval_add; rewrite eval_join0; [reflexivity|assumption].*)
+ Admitted.
+ Hint Rewrite eval_add_same eval_add_S1 eval_add_S2 using (omega || assumption): push_basesystem_eval.
+
+ Lemma uweight_le_mono n m : (n <= m)%nat ->
+ uweight bound n <= uweight bound m.
+ Proof.
+ unfold uweight; intro; Z.peel_le; omega.
+ Qed.
+
+ Lemma uweight_lt_mono (bound_gt_1 : bound > 1) n m : (n < m)%nat ->
+ uweight bound n < uweight bound m.
+ Proof.
+ clear bound_pos.
+ unfold uweight; intro; apply Z.pow_lt_mono_r; omega.
+ Qed.
+
+ Lemma uweight_succ n : uweight bound (S n) = bound * uweight bound n.
+ Proof.
+ unfold uweight.
+ rewrite Nat2Z.inj_succ, Z.pow_succ_r by auto using Nat2Z.is_nonneg; reflexivity.
+ Qed.
+
+ Local Definition compact {n} := Columns.compact (n:=n) (add_get_carry:=Z.add_get_carry_full) (div:=div) (modulo:=modulo) (uweight bound).
+ Local Definition compact_digit := Columns.compact_digit (add_get_carry:=Z.add_get_carry_full) (div:=div) (modulo:=modulo) (uweight bound).
+ Lemma small_compact {n} (p:(list Z)^n) : small (snd (compact p)).
+ Proof.
+ pose_all.
+ match goal with
+ |- ?G => assert (G /\ fst (compact p) = fst (compact p)); [|tauto]
+ end. (* assert a dummy second statement so that fst (compact x) is in context *)
+ cbv [compact Columns.compact Columns.compact_cps small
+ Columns.compact_step Columns.compact_step_cps];
+ autorewrite with uncps push_id.
+ change (fun i s a => Columns.compact_digit_cps (uweight bound) i (s :: a) id)
+ with (fun i s a => compact_digit i (s :: a)).
+ remember (fun i s a => compact_digit i (s :: a)) as f.
+
+ apply @mapi_with'_linvariant with (n:=n) (f:=f) (inp:=p);
+ intros; [|simpl; tauto]. split; [|reflexivity].
+ let P := fresh "H" in
+ match goal with H : _ /\ _ |- _ => destruct H end.
+ destruct n0; subst f.
+ { cbv [compact_digit uweight to_list to_list' In].
+ rewrite Columns.compact_digit_mod by assumption.
+ rewrite Z.pow_0_r, Z.pow_1_r, Z.div_1_r. intros x ?.
+ match goal with
+ H : _ \/ False |- _ => destruct H; [|exfalso; assumption] end.
+ subst x. apply Z.mod_pos_bound, Z.gt_lt, bound_pos. }
+ { rewrite Tuple.to_list_left_append.
+ let H := fresh "H" in
+ intros x H; apply in_app_or in H; destruct H;
+ [solve[auto]| cbv [In] in H; destruct H;
+ [|exfalso; assumption] ].
+ subst x. cbv [compact_digit].
+ rewrite Columns.compact_digit_mod by assumption.
+ rewrite !uweight_succ, Z.div_mul by
+ (apply Z.neq_mul_0; split; auto; omega).
+ apply Z.mod_pos_bound, Z.gt_lt, bound_pos. }
+ Qed.
+
+ Lemma small_add n a b :
+ (2 <= bound) ->
+ small a -> small b -> small (@add n a b).
+ Proof.
+ intros. pose_all.
+ cbv [add_cps add Let_In].
+ autorewrite with uncps push_id.
+ (*apply Positional.small_sat_add.*)
+ Admitted.
+
+ Lemma small_add_S1 n a b :
+ (2 <= bound) ->
+ small a -> small b -> small (@add_S1 n a b).
+ Proof.
+ intros. pose_all.
+ cbv [add_cps add add_S1 Let_In].
+ (*apply Positional.small_sat_add.*)
+ Admitted.
+
+ Lemma small_add_S2 n a b :
+ (2 <= bound) ->
+ small a -> small b -> small (@add_S2 n a b).
+ Proof.
+ intros. pose_all.
+ cbv [add_cps add add_S2 Let_In].
+ autorewrite with uncps push_id.
+ (*apply Positional.small_sat_add.*)
+ Admitted.
+
+ Lemma small_left_tl n (v:T (S n)) : small v -> small (left_tl v).
+ Proof. cbv [small]. auto using Tuple.In_to_list_left_tl. Qed.
+
+ Lemma small_divmod n (p: T (S n)) (Hsmall : small p) :
+ left_hd p = eval p / uweight bound n /\ eval (left_tl p) = eval p mod (uweight bound n).
+ Admitted.
+
+ Lemma eval_drop_high n v :
+ small v -> eval (@drop_high n v) = eval v mod (uweight bound n).
+ Proof.
+ cbv [drop_high drop_high_cps eval].
+ rewrite Tuple.left_tl_cps_correct, push_id. (* TODO : for some reason autorewrite with uncps doesn't work here *)
+ intro H. apply small_left_tl in H.
+ rewrite (subst_left_append v) at 2.
+ autorewrite with push_basesystem_eval.
+ apply eval_small in H.
+ rewrite Z.mod_add_l' by (pose_uweight bound; auto).
+ rewrite Z.mod_small; auto.
+ Qed.
+
+ Lemma small_drop_high n v : small v -> small (@drop_high n v).
+ Proof.
+ cbv [drop_high drop_high_cps].
+ rewrite Tuple.left_tl_cps_correct, push_id.
+ apply small_left_tl.
+ Qed.
+
+ Lemma div_nonzero_neg_iff x y : x < y -> 0 < y -> x / y <> 0 <-> x < 0.
+ Proof.
+ repeat match goal with
+ | _ => progress intros
+ | _ => rewrite Z.div_small_iff by omega
+ | _ => split
+ | _ => omega
+ end.
+ Qed.
+
+ Lemma eval_sub_then_maybe_add_nz n mask p q r:
+ small p -> small q -> small r -> (n<>0)%nat ->
+ (map (Z.land mask) r = r) ->
+ (0 <= eval p < eval r) -> (0 <= eval q < eval r) ->
+ eval (@sub_then_maybe_add n mask p q r) = eval p - eval q + (if eval p - eval q <? 0 then eval r else 0).
+ Proof.
+ pose_all.
+ repeat match goal with
+ | _ => progress (cbv [sub_then_maybe_add sub_then_maybe_add_cps eval] in *; intros)
+ | _ => progress autounfold
+ | _ => progress autorewrite with uncps push_id push_basesystem_eval
+ | _ => rewrite eval_drop_high
+ | _ => rewrite eval_join0
+ | H : small _ |- _ => apply eval_small in H
+ | _ => progress break_match
+ | _ => (rewrite Z.add_opp_r in * )
+ | H : _ |- _ => rewrite Z.ltb_lt in H;
+ rewrite <-div_nonzero_neg_iff with
+ (y:=uweight bound n) in H by (auto; omega)
+ | H : _ |- _ => rewrite Z.ltb_ge in H
+ | _ => rewrite Z.mod_small by omega
+ | _ => omega
+ | _ => progress autorewrite with zsimplify; [ ]
+ end.
+ Admitted.
+
+ Lemma eval_sub_then_maybe_add n mask p q r :
+ small p -> small q -> small r ->
+ (map (Z.land mask) r = r) ->
+ (0 <= eval p < eval r) -> (0 <= eval q < eval r) ->
+ eval (@sub_then_maybe_add n mask p q r) = eval p - eval q + (if eval p - eval q <? 0 then eval r else 0).
+ Proof.
+ destruct n; [|solve[auto using eval_sub_then_maybe_add_nz]].
+ destruct p, q, r; reflexivity.
+ Qed.
+
+ Lemma small_sub_then_maybe_add n mask (p q r : T n) :
+ small (sub_then_maybe_add mask p q r).
+ Proof.
+ cbv [sub_then_maybe_add_cps sub_then_maybe_add]; intros.
+ repeat progress autounfold. autorewrite with uncps push_id.
+ apply small_drop_high, B.Positional.small_sat_sub.
+ Qed.
+
+ (* TODO : remove if unneeded when all admits are proven
+ Lemma small_highest_zero_iff {n} (p: T (S n)) (Hsmall : small p) :
+ (left_hd p = 0 <-> eval p < uweight bound n).
+ Proof.
+ destruct (small_divmod _ p Hsmall) as [Hdiv Hmod].
+ pose proof Hsmall as Hsmalltl. apply eval_small in Hsmall.
+ apply small_left_tl, eval_small in Hsmalltl. rewrite Hdiv.
+ rewrite (Z.div_small_iff (eval p) (uweight bound n))
+ by auto using uweight_nonzero.
+ split; [|intros; left; omega].
+ let H := fresh "H" in intro H; destruct H; [|omega].
+ omega.
+ Qed.
+ *)
+
+ Lemma map2_zselect n cond x y :
+ Tuple.map2 (n:=n) (Z.zselect cond) x y = if dec (cond = 0) then x else y.
+ Proof.
+ unfold Z.zselect.
+ break_innermost_match; Z.ltb_to_lt; subst; try omega;
+ [ rewrite Tuple.map2_fst, Tuple.map_id
+ | rewrite Tuple.map2_snd, Tuple.map_id ];
+ reflexivity.
+ Qed.
+
+ Lemma eval_conditional_sub_nz n (p:T (S n)) (q:T n)
+ (n_nonzero: (n <> 0)%nat) (psmall : small p) (qsmall : small q):
+ 0 <= eval p < eval q + uweight bound n ->
+ eval (conditional_sub p q) = eval p + (if eval q <=? eval p then - eval q else 0).
+ Proof.
+ cbv [conditional_sub conditional_sub_cps]. intros. pose_all.
+ repeat autounfold. apply eval_small in qsmall.
+ pose proof psmall; apply eval_small in psmall.
+ cbv [eval] in *. autorewrite with uncps push_id push_basesystem_eval.
+ rewrite map2_zselect.
+ let H := fresh "H" in let X := fresh "P" in
+ match goal with |- context [?x / ?y] =>
+ pose proof (div_nonzero_neg_iff x y) end;
+ repeat match type of H with ?P -> _ =>
+ assert P as X by omega; specialize (H X);
+ clear X end.
+
+ break_match;
+ repeat match goal with
+ | _ => progress cbv [eval]
+ | H : (_ <=? _) = true |- _ => apply Z.leb_le in H
+ | H : (_ <=? _) = false |- _ => apply Z.leb_gt in H
+ | _ => rewrite eval_drop_high by auto using B.Positional.small_sat_sub
+ | _ => (rewrite eval_join0 in * )
+ | _ => progress autorewrite with uncps push_id push_basesystem_eval
+ | _ => repeat rewrite Z.mod_small; omega
+ | _ => omega
+ end.
+ Admitted.
+
+ Lemma eval_conditional_sub n (p:T (S n)) (q:T n)
+ (psmall : small p) (qsmall : small q) :
+ 0 <= eval p < eval q + uweight bound n ->
+ eval (conditional_sub p q) = eval p + (if eval q <=? eval p then - eval q else 0).
+ Proof.
+ destruct n; [|solve[auto using eval_conditional_sub_nz]].
+ repeat match goal with
+ | _ => progress (intros; cbv [T tuple tuple'] in p, q)
+ | q : unit |- _ => destruct q
+ | _ => progress (cbv [conditional_sub conditional_sub_cps eval] in * )
+ | _ => progress autounfold
+ | _ => progress (autorewrite with uncps push_id push_basesystem_eval in * )
+ | _ => (rewrite uweight_0 in * )
+ | _ => assert (p = 0) by omega; subst p; break_match; ring
+ end.
+ Qed.
+
+ Lemma small_conditional_sub n (p:T (S n)) (q:T n)
+ (psmall : small p) (qsmall : small q) :
+ 0 <= eval p < eval q + uweight bound n ->
+ small (conditional_sub p q).
+ Admitted.
+
+ Lemma eval_scmul n a v : small v -> 0 <= a < bound ->
+ eval (@scmul n a v) = a * eval v.
+ Proof.
+ intro Hsmall. pose_all. apply eval_small in Hsmall.
+ intros. cbv [scmul scmul_cps eval] in *. repeat autounfold.
+ autorewrite with uncps push_id push_basesystem_eval.
+ rewrite uweight_0, Z.mul_1_l. apply Z.mod_small.
+ split; [solve[Z.zero_bounds]|]. cbv [uweight] in *.
+ rewrite !Nat2Z.inj_succ, Z.pow_succ_r by auto using Nat2Z.is_nonneg.
+ apply Z.mul_lt_mono_nonneg; omega.
+ Qed.
+
+ Lemma small_scmul n a v : small (@scmul n a v).
+ Proof.
+ cbv [scmul scmul_cps eval] in *. repeat autounfold.
+ autorewrite with uncps push_id push_basesystem_eval.
+ apply small_compact.
+ Qed.
+
+ (* TODO : move to tuple *)
+ Lemma from_list_tl {A n} (ls : list A) H H':
+ from_list n (List.tl ls) H = tl (from_list (S n) ls H').
+ Proof.
+ induction ls; distr_length. simpl List.tl.
+ rewrite from_list_cons, tl_append, <-!(from_list_default_eq a ls).
+ reflexivity.
+ Qed.
+
+ Lemma small_hd n p : @small (S n) p -> 0 <= hd p < bound.
+ Proof.
+ cbv [small]. let H := fresh "H" in intro H; apply H.
+ rewrite (subst_append p). rewrite to_list_append, hd_append.
+ apply in_eq.
+ Qed.
+
+
+ Lemma eval_div n p : small p -> eval (fst (@divmod n p)) = eval p / bound.
+ Proof.
+ cbv [divmod divmod_cps eval]. intros.
+ autorewrite with uncps push_id cancel_pair.
+ rewrite (subst_append p) at 2.
+ rewrite uweight_eval_step. rewrite hd_append, tl_append.
+ rewrite Z.div_add' by omega. rewrite Z.div_small by auto using small_hd.
+ ring.
+ Qed.
+
+ Lemma eval_mod n p : small p -> snd (@divmod n p) = eval p mod bound.
+ Proof.
+ cbv [divmod divmod_cps eval]. intros.
+ autorewrite with uncps push_id cancel_pair.
+ rewrite (subst_append p) at 2.
+ rewrite uweight_eval_step, Z.mod_add'_full, hd_append.
+ rewrite Z.mod_small by auto using small_hd. reflexivity.
+ Qed.
+
+ Lemma small_div n v : small v -> small (fst (@divmod n v)).
+ Admitted.
+
+ End Proofs.
+End API.
+Hint Rewrite nonzero_id join0_id divmod_id drop_high_id scmul_id add_id add_S1_id add_S2_id sub_then_maybe_add_id conditional_sub_id : uncps. \ No newline at end of file
diff --git a/src/Arithmetic/Saturated/MulSplit.v b/src/Arithmetic/Saturated/MulSplit.v
new file mode 100644
index 000000000..45f37ef56
--- /dev/null
+++ b/src/Arithmetic/Saturated/MulSplit.v
@@ -0,0 +1,73 @@
+Require Import Coq.ZArith.ZArith.
+Require Import Coq.Lists.List.
+Local Open Scope Z_scope.
+
+Require Import Crypto.Arithmetic.Core.
+Require Import Crypto.Util.LetIn Crypto.Util.CPSUtil.
+
+(* Defines bignum multiplication using a two-output multiply operation. *)
+Module B.
+ Module Associational.
+ Section Associational.
+ Context {mul_split : Z -> Z -> Z -> Z * Z} (* first argument is where to split output; [mul_split s x y] gives ((x * y) mod s, (x * y) / s) *)
+ {mul_split_mod : forall s x y,
+ fst (mul_split s x y) = (x * y) mod s}
+ {mul_split_div : forall s x y,
+ snd (mul_split s x y) = (x * y) / s}
+ .
+
+ Definition sat_multerm_cps s (t t' : B.limb) {T} (f:list B.limb ->T) :=
+ dlet xy := mul_split s (snd t) (snd t') in
+ f ((fst t * fst t', fst xy) :: (fst t * fst t' * s, snd xy) :: nil).
+
+ Definition sat_multerm s t t' := sat_multerm_cps s t t' id.
+ Lemma sat_multerm_id s t t' T f :
+ @sat_multerm_cps s t t' T f = f (sat_multerm s t t').
+ Proof. reflexivity. Qed.
+ Hint Opaque sat_multerm : uncps.
+ Hint Rewrite sat_multerm_id : uncps.
+
+ Definition sat_mul_cps s (p q : list B.limb) {T} (f : list B.limb -> T) :=
+ flat_map_cps (fun t => @flat_map_cps _ _ (sat_multerm_cps s t) q) p f.
+
+ Definition sat_mul s p q := sat_mul_cps s p q id.
+ Lemma sat_mul_id s p q T f : @sat_mul_cps s p q T f = f (sat_mul s p q).
+ Proof. cbv [sat_mul sat_mul_cps]. autorewrite with uncps. reflexivity. Qed.
+ Hint Opaque sat_mul : uncps.
+ Hint Rewrite sat_mul_id : uncps.
+
+ Lemma eval_map_sat_multerm s a q (s_nonzero:s<>0):
+ B.Associational.eval (flat_map (sat_multerm s a) q) = fst a * snd a * B.Associational.eval q.
+ Proof.
+ cbv [sat_multerm sat_multerm_cps Let_In]; induction q;
+ repeat match goal with
+ | _ => progress (autorewrite with uncps push_id cancel_pair push_basesystem_eval in * )
+ | _ => progress simpl flat_map
+ | _ => progress rewrite ?IHq, ?mul_split_mod, ?mul_split_div
+ | _ => rewrite Z.mod_eq by assumption
+ | _ => ring_simplify; omega
+ end.
+ Qed.
+ Hint Rewrite eval_map_sat_multerm using (omega || assumption)
+ : push_basesystem_eval.
+
+ Lemma eval_sat_mul s p q (s_nonzero:s<>0):
+ B.Associational.eval (sat_mul s p q) = B.Associational.eval p * B.Associational.eval q.
+ Proof.
+ cbv [sat_mul sat_mul_cps]; induction p; [reflexivity|].
+ repeat match goal with
+ | _ => progress (autorewrite with uncps push_id push_basesystem_eval in * )
+ | _ => progress simpl flat_map
+ | _ => rewrite IHp
+ | _ => progress change (fun x => sat_multerm_cps s a x id) with (sat_multerm s a)
+ | _ => ring_simplify; omega
+ end.
+ Qed.
+ Hint Rewrite eval_sat_mul : push_basesystem_eval.
+ End Associational.
+ End Associational.
+End B.
+Hint Opaque B.Associational.sat_mul B.Associational.sat_multerm : uncps.
+Hint Rewrite @B.Associational.sat_mul_id @B.Associational.sat_multerm_id : uncps.
+Hint Rewrite @B.Associational.eval_sat_mul @B.Associational.eval_map_sat_multerm using (omega || assumption) : push_basesystem_eval.
+
diff --git a/src/Arithmetic/Saturated/UniformWeight.v b/src/Arithmetic/Saturated/UniformWeight.v
new file mode 100644
index 000000000..51eb71b0b
--- /dev/null
+++ b/src/Arithmetic/Saturated/UniformWeight.v
@@ -0,0 +1,71 @@
+Require Import Coq.ZArith.ZArith.
+Require Import Coq.Lists.List.
+Local Open Scope Z_scope.
+
+Require Import Crypto.Arithmetic.Core.
+Require Import Crypto.Arithmetic.Saturated.Core.
+Require Import Crypto.Util.ZUtil.
+Require Import Crypto.Util.LetIn Crypto.Util.Tuple.
+Local Notation "A ^ n" := (tuple A n) : type_scope.
+
+Section UniformWeight.
+ Context (bound : Z) {bound_pos : bound > 0}.
+
+ Definition uweight : nat -> Z := fun i => bound ^ Z.of_nat i.
+ Lemma uweight_0 : uweight 0%nat = 1. Proof. reflexivity. Qed.
+ Lemma uweight_positive i : uweight i > 0.
+ Proof. apply Z.lt_gt, Z.pow_pos_nonneg; omega. Qed.
+ Lemma uweight_nonzero i : uweight i <> 0.
+ Proof. auto using Z.positive_is_nonzero, uweight_positive. Qed.
+ Lemma uweight_multiples i : uweight (S i) mod uweight i = 0.
+ Proof. apply Z.mod_same_pow; rewrite Nat2Z.inj_succ; omega. Qed.
+ Lemma uweight_divides i : uweight (S i) / uweight i > 0.
+ Proof.
+ cbv [uweight]. rewrite <-Z.pow_sub_r by (rewrite ?Nat2Z.inj_succ; omega).
+ apply Z.lt_gt, Z.pow_pos_nonneg; rewrite ?Nat2Z.inj_succ; omega.
+ Qed.
+
+ (* TODO : move to Positional *)
+ Lemma eval_from_eq {n} (p:Z^n) wt offset :
+ (forall i, wt i = uweight (i + offset)) ->
+ B.Positional.eval wt p = B.Positional.eval_from uweight offset p.
+ Proof. cbv [B.Positional.eval_from]. auto using B.Positional.eval_wt_equiv. Qed.
+
+ Lemma uweight_eval_from {n} (p:Z^n): forall offset,
+ B.Positional.eval_from uweight offset p = uweight offset * B.Positional.eval uweight p.
+ Proof.
+ induction n; intros; cbv [B.Positional.eval_from];
+ [|rewrite (subst_append p)];
+ repeat match goal with
+ | _ => destruct p
+ | _ => rewrite B.Positional.eval_unit; [ ]
+ | _ => rewrite B.Positional.eval_step; [ ]
+ | _ => rewrite IHn; [ ]
+ | _ => rewrite eval_from_eq with (offset0:=S offset)
+ by (intros; f_equal; omega)
+ | _ => rewrite eval_from_eq with
+ (wt:=fun i => uweight (S i)) (offset0:=1%nat)
+ by (intros; f_equal; omega)
+ | _ => ring
+ end.
+ repeat match goal with
+ | _ => cbv [uweight]; progress autorewrite with natsimplify
+ | _ => progress (rewrite ?Nat2Z.inj_succ, ?Nat2Z.inj_0, ?Z.pow_0_r)
+ | _ => rewrite !Z.pow_succ_r by (try apply Nat2Z.is_nonneg; omega)
+ | _ => ring
+ end.
+ Qed.
+
+ Lemma uweight_eval_step {n} (p:Z^S n):
+ B.Positional.eval uweight p = hd p + bound * B.Positional.eval uweight (tl p).
+ Proof.
+ rewrite (subst_append p) at 1; rewrite B.Positional.eval_step.
+ rewrite eval_from_eq with (offset := 1%nat) by (intros; f_equal; omega).
+ rewrite uweight_eval_from. cbv [uweight]; rewrite Z.pow_0_r, Z.pow_1_r.
+ ring.
+ Qed.
+
+ Definition small {n} (p : Z^n) : Prop :=
+ forall x, In x (to_list _ p) -> 0 <= x < bound.
+
+End UniformWeight. \ No newline at end of file
diff --git a/src/Arithmetic/Saturated/Wrappers.v b/src/Arithmetic/Saturated/Wrappers.v
new file mode 100644
index 000000000..e1da74e60
--- /dev/null
+++ b/src/Arithmetic/Saturated/Wrappers.v
@@ -0,0 +1,53 @@
+Require Import Coq.ZArith.ZArith.
+Require Import Coq.Lists.List.
+Local Open Scope Z_scope.
+
+Require Import Crypto.Arithmetic.Core.
+Require Import Crypto.Arithmetic.Saturated.Core.
+Require Import Crypto.Arithmetic.Saturated.MulSplit.
+Require Import Crypto.Util.ZUtil.Definitions.
+Require Import Crypto.Util.ZUtil.MulSplit.
+Require Import Crypto.Util.Tuple.
+Local Notation "A ^ n" := (tuple A n) : type_scope.
+
+(* Define wrapper definitions that use Columns representation
+internally but with input and output in Positonal representation.*)
+Module Columns.
+ Section Wrappers.
+ Context (weight : nat->Z).
+
+ Definition add_cps {n1 n2 n3} (p : Z^n1) (q : Z^n2)
+ {T} (f : (Z*Z^n3)->T) :=
+ B.Positional.to_associational_cps weight p
+ (fun P => B.Positional.to_associational_cps weight q
+ (fun Q => Columns.from_associational_cps weight n3 (P++Q)
+ (fun R => Columns.compact_cps (div:=div) (modulo:=modulo) (add_get_carry:=Z.add_get_carry_full) weight R f))).
+
+ Definition unbalanced_sub_cps {n1 n2 n3} (p : Z^n1) (q:Z^n2)
+ {T} (f : (Z*Z^n3)->T) :=
+ B.Positional.to_associational_cps weight p
+ (fun P => B.Positional.negate_snd_cps weight q
+ (fun nq => B.Positional.to_associational_cps weight nq
+ (fun Q => Columns.from_associational_cps weight n3 (P++Q)
+ (fun R => Columns.compact_cps (div:=div) (modulo:=modulo) (add_get_carry:=Z.add_get_carry_full) weight R f)))).
+
+ Definition mul_cps {n1 n2 n3} s (p : Z^n1) (q : Z^n2)
+ {T} (f : (Z*Z^n3)->T) :=
+ B.Positional.to_associational_cps weight p
+ (fun P => B.Positional.to_associational_cps weight q
+ (fun Q => B.Associational.sat_mul_cps (mul_split := Z.mul_split) s P Q
+ (fun PQ => Columns.from_associational_cps weight n3 PQ
+ (fun R => Columns.compact_cps (div:=div) (modulo:=modulo) (add_get_carry:=Z.add_get_carry_full) weight R f)))).
+
+ Definition conditional_add_cps {n1 n2 n3} mask cond (p:Z^n1) (q:Z^n2)
+ {T} (f:_->T) :=
+ B.Positional.select_cps mask cond q
+ (fun qq => add_cps (n3:=n3) p qq f).
+
+ End Wrappers.
+End Columns.
+Hint Unfold
+ Columns.conditional_add_cps
+ Columns.add_cps
+ Columns.unbalanced_sub_cps
+ Columns.mul_cps. \ No newline at end of file