aboutsummaryrefslogtreecommitdiff
path: root/src/Arithmetic/Saturated
diff options
context:
space:
mode:
authorGravatar jadep <jade.philipoom@gmail.com>2017-06-30 23:11:55 -0400
committerGravatar jadep <jade.philipoom@gmail.com>2017-06-30 23:11:55 -0400
commit3b0113a9c52855d5362eeaebabe2556efcafcb87 (patch)
tree9734fd5ed429fc2e0e670541e2e38d867fbc3afa /src/Arithmetic/Saturated
parent518f79958112a93eae30942e62096173c4fb0b28 (diff)
Prove saturated carrying-addition-chain correct
Diffstat (limited to 'src/Arithmetic/Saturated')
-rw-r--r--src/Arithmetic/Saturated/AddSub.v61
-rw-r--r--src/Arithmetic/Saturated/Core.v3
-rw-r--r--src/Arithmetic/Saturated/MontgomeryAPI.v21
-rw-r--r--src/Arithmetic/Saturated/UniformWeight.v20
4 files changed, 81 insertions, 24 deletions
diff --git a/src/Arithmetic/Saturated/AddSub.v b/src/Arithmetic/Saturated/AddSub.v
index c6758b865..76d369c50 100644
--- a/src/Arithmetic/Saturated/AddSub.v
+++ b/src/Arithmetic/Saturated/AddSub.v
@@ -6,13 +6,14 @@ Require Import Crypto.Arithmetic.Core.
Require Import Crypto.Arithmetic.Saturated.Core.
Require Import Crypto.Arithmetic.Saturated.UniformWeight.
Require Import Crypto.Util.ZUtil.Definitions.
+Require Import Crypto.Util.ZUtil.AddGetCarry.
Require Import Crypto.Util.Tuple Crypto.Util.LetIn.
Local Notation "A ^ n" := (tuple A n) : type_scope.
Module B.
Module Positional.
Section Positional.
- Context {s:Z}. (* s is bitwidth *)
+ Context {s:Z} {s_pos : 0 < s}. (* s is bitwidth *)
Let small {n} := @small s n.
Section GenericOp.
Context {op : Z -> Z -> Z}
@@ -54,8 +55,16 @@ Module B.
@chain_op_cps n p q T f = f (chain_op p q).
Proof. apply chain_op'_id. Qed.
End GenericOp.
+ Hint Opaque chain_op chain_op' : uncps.
+ Hint Rewrite @chain_op_id @chain_op'_id : uncps.
Section AddSub.
+ Create HintDb divmod discriminated.
+ Hint Rewrite Z.add_get_carry_full_mod
+ Z.add_get_carry_full_div
+ Z.add_with_get_carry_full_mod
+ Z.add_with_get_carry_full_div
+ : divmod.
Let eval {n} := B.Positional.eval (n:=n) (uweight s).
Definition sat_add_cps {n} p q T (f:Z*Z^n->T) :=
@@ -68,13 +77,59 @@ Module B.
@sat_add_cps n p q T f = f (sat_add p q).
Proof. cbv [sat_add sat_add_cps]. rewrite !chain_op_id. reflexivity. Qed.
+ Lemma sat_add_mod_step n c d :
+ c mod s + s * ((d + c / s) mod (uweight s n))
+ = (s * d + c) mod (s * uweight s n).
+ Proof.
+ assert (0 < uweight s n) as wt_pos
+ by auto using Z.lt_gt, Z.gt_lt, uweight_positive.
+ rewrite <-(Columns.compact_mod_step s (uweight s n) c d s_pos wt_pos).
+ repeat (ring_simplify; f_equal; ring_simplify; try omega).
+ Qed.
+
+ Lemma sat_add_div_step n c d :
+ (d + c / s) / uweight s n = (s * d + c) / (s * uweight s n).
+ Proof.
+ assert (0 < uweight s n) as wt_pos
+ by auto using Z.lt_gt, Z.gt_lt, uweight_positive.
+ rewrite <-(Columns.compact_div_step s (uweight s n) c d s_pos wt_pos).
+ repeat (ring_simplify; f_equal; ring_simplify; try omega).
+ Qed.
+
+ Lemma sat_add_divmod n p q :
+ eval (snd (@sat_add n p q)) = (eval p + eval q) mod (uweight s n)
+ /\ fst (@sat_add n p q) = (eval p + eval q) / (uweight s n).
+ Proof.
+ cbv [sat_add sat_add_cps chain_op_cps].
+ remember None as c.
+ replace (eval p + eval q) with
+ (eval p + eval q + match c with | None => 0 | Some x => x end)
+ by (subst; ring).
+ destruct Heqc. revert c.
+ induction n; [|destruct c]; intros; simpl chain_op'_cps;
+ repeat match goal with
+ | _ => progress cbv [eval Let_In] in *
+ | _ => progress autorewrite with uncps divmod push_id cancel_pair push_basesystem_eval
+ | _ => rewrite uweight_0, ?Z.mod_1_r, ?Z.div_1_r
+ | _ => rewrite uweight_succ
+ | p : Z ^ 0 |- _ => destruct p
+ | _ => rewrite uweight_eval_step, ?hd_append, ?tl_append
+ | |- context[B.Positional.eval _ (snd (chain_op' ?c ?p ?q))]
+ => specialize (IHn p q c); autorewrite with push_id uncps in IHn;
+ rewrite (proj1 IHn); rewrite (proj2 IHn)
+ | _ => tauto
+ end;
+ (split; [rewrite sat_add_mod_step | rewrite sat_add_div_step];
+ f_equal; ring_simplify; omega).
+ Qed.
+
Lemma sat_add_mod n p q :
eval (snd (@sat_add n p q)) = (eval p + eval q) mod (uweight s n).
- Admitted.
+ Proof. exact (proj1 (sat_add_divmod n p q)). Qed.
Lemma sat_add_div n p q :
fst (@sat_add n p q) = (eval p + eval q) / (uweight s n).
- Admitted.
+ Proof. exact (proj2 (sat_add_divmod n p q)). Qed.
Lemma small_sat_add n p q : small (snd (@sat_add n p q)).
Admitted.
diff --git a/src/Arithmetic/Saturated/Core.v b/src/Arithmetic/Saturated/Core.v
index 27171c741..355d6b429 100644
--- a/src/Arithmetic/Saturated/Core.v
+++ b/src/Arithmetic/Saturated/Core.v
@@ -246,6 +246,7 @@ Module Columns.
Lemma compact_mod_step a b c d: 0 < a -> 0 < b ->
a * ((c / a + d) mod b) + c mod a = (a * d + c) mod (a * b).
Proof.
+ clear.
intros Ha Hb. assert (a <= a * b) by (apply Z.le_mul_diag_r; omega).
pose proof (Z.mod_pos_bound c a Ha).
pose proof (Z.mod_pos_bound (c/a+d) b Hb).
@@ -262,7 +263,7 @@ Module Columns.
Lemma compact_div_step a b c d : 0 < a -> 0 < b ->
(c / a + d) / b = (a * d + c) / (a * b).
Proof.
- intros Ha Hb.
+ clear. intros Ha Hb.
rewrite <-Z.div_div by omega.
rewrite Z.div_add_l' by omega.
f_equal; ring.
diff --git a/src/Arithmetic/Saturated/MontgomeryAPI.v b/src/Arithmetic/Saturated/MontgomeryAPI.v
index 0ce1ac265..d2ad92e4f 100644
--- a/src/Arithmetic/Saturated/MontgomeryAPI.v
+++ b/src/Arithmetic/Saturated/MontgomeryAPI.v
@@ -291,25 +291,6 @@ Section API.
Admitted.
Hint Rewrite eval_add_same eval_add_S1 eval_add_S2 using (omega || assumption): push_basesystem_eval.
- Lemma uweight_le_mono n m : (n <= m)%nat ->
- uweight bound n <= uweight bound m.
- Proof.
- unfold uweight; intro; Z.peel_le; omega.
- Qed.
-
- Lemma uweight_lt_mono (bound_gt_1 : bound > 1) n m : (n < m)%nat ->
- uweight bound n < uweight bound m.
- Proof.
- clear bound_pos.
- unfold uweight; intro; apply Z.pow_lt_mono_r; omega.
- Qed.
-
- Lemma uweight_succ n : uweight bound (S n) = bound * uweight bound n.
- Proof.
- unfold uweight.
- rewrite Nat2Z.inj_succ, Z.pow_succ_r by auto using Nat2Z.is_nonneg; reflexivity.
- Qed.
-
Local Definition compact {n} := Columns.compact (n:=n) (add_get_carry:=Z.add_get_carry_full) (div:=div) (modulo:=modulo) (uweight bound).
Local Definition compact_digit := Columns.compact_digit (add_get_carry:=Z.add_get_carry_full) (div:=div) (modulo:=modulo) (uweight bound).
Lemma small_compact {n} (p:(list Z)^n) : small (snd (compact p)).
@@ -455,7 +436,7 @@ Section API.
Proof.
cbv [sub_then_maybe_add_cps sub_then_maybe_add]; intros.
repeat progress autounfold. autorewrite with uncps push_id.
- apply small_drop_high, B.Positional.small_sat_sub.
+ apply small_drop_high, @B.Positional.small_sat_sub; omega.
Qed.
(* TODO : remove if unneeded when all admits are proven
diff --git a/src/Arithmetic/Saturated/UniformWeight.v b/src/Arithmetic/Saturated/UniformWeight.v
index 51eb71b0b..bd351b6cd 100644
--- a/src/Arithmetic/Saturated/UniformWeight.v
+++ b/src/Arithmetic/Saturated/UniformWeight.v
@@ -65,6 +65,26 @@ Section UniformWeight.
ring.
Qed.
+ Lemma uweight_le_mono n m : (n <= m)%nat ->
+ uweight n <= uweight m.
+ Proof.
+ unfold uweight; intro; Z.peel_le; omega.
+ Qed.
+
+ Lemma uweight_lt_mono (bound_gt_1 : bound > 1) n m : (n < m)%nat ->
+ uweight n < uweight m.
+ Proof.
+ clear bound_pos.
+ unfold uweight; intro; apply Z.pow_lt_mono_r; omega.
+ Qed.
+
+ Lemma uweight_succ n : uweight (S n) = bound * uweight n.
+ Proof.
+ unfold uweight.
+ rewrite Nat2Z.inj_succ, Z.pow_succ_r by auto using Nat2Z.is_nonneg; reflexivity.
+ Qed.
+
+
Definition small {n} (p : Z^n) : Prop :=
forall x, In x (to_list _ p) -> 0 <= x < bound.