aboutsummaryrefslogtreecommitdiff
path: root/src/Arithmetic
diff options
context:
space:
mode:
authorGravatar Jason Gross <jgross@mit.edu>2017-10-20 00:17:37 -0400
committerGravatar Jason Gross <jgross@mit.edu>2017-10-20 00:17:37 -0400
commit19f2d44ac898a345b7f35acba5103cc476005766 (patch)
tree2e18a676ca9e011347cd82c2b4d3981a788f1f06 /src/Arithmetic
parenta2e8a332f4ea13d90d56e00fdc8419d8ba245e5c (diff)
Fix bug in previous commit
Diffstat (limited to 'src/Arithmetic')
-rw-r--r--src/Arithmetic/Saturated/Core.v38
-rw-r--r--src/Arithmetic/Saturated/Freeze.v1
-rw-r--r--src/Arithmetic/Saturated/MontgomeryAPI.v5
-rw-r--r--src/Arithmetic/Saturated/Wrappers.v6
4 files changed, 29 insertions, 21 deletions
diff --git a/src/Arithmetic/Saturated/Core.v b/src/Arithmetic/Saturated/Core.v
index fa7bbe919..631c842e2 100644
--- a/src/Arithmetic/Saturated/Core.v
+++ b/src/Arithmetic/Saturated/Core.v
@@ -113,19 +113,25 @@ Module Columns.
{weight_divides : forall i : nat, weight (S i) / weight i > 0}
(* add_get_carry takes in a number at which to split output *)
{add_get_carry_cps: forall {T}, Z ->Z -> Z -> (Z * Z -> T) -> T}
- {add_get_carry_cps_id : forall {T} s x y f,
- @add_get_carry_cps T s x y f = f (@add_get_carry_cps _ s x y id)}
+ {div_cps modulo_cps : forall {T}, Z -> Z -> (Z -> T) -> T}.
+ Let add_get_carry s x y := add_get_carry_cps _ s x y id.
+ Let div x y := div_cps _ x y id.
+ Let modulo x y := modulo_cps _ x y id.
+ Context {add_get_carry_cps_id : forall {T} s x y f,
+ @add_get_carry_cps T s x y f = f (add_get_carry s x y)}
{add_get_carry_mod : forall s x y,
- fst (add_get_carry_cps s x y id) = (x + y) mod s}
+ fst (add_get_carry s x y) = (x + y) mod s}
{add_get_carry_div : forall s x y,
- snd (add_get_carry_cps s x y id) = (x + y) / s}
- {div modulo : Z -> Z -> Z}
+ snd (add_get_carry s x y) = (x + y) / s}
+ {div_cps_id : forall {T} x y f,
+ @div_cps T x y f = f (div x y)}
+ {modulo_cps_id : forall {T} x y f,
+ @modulo_cps T x y f = f (modulo x y)}
{div_correct : forall a b, div a b = a / b}
{modulo_correct : forall a b, modulo a b = a mod b}
.
Hint Rewrite div_correct modulo_correct add_get_carry_mod add_get_carry_div : div_mod.
- Let add_get_carry s x y := add_get_carry_cps _ s x y id.
- Hint Rewrite (add_get_carry_cps_id : forall T s x y f, _ = f (@add_get_carry s x y)) : uncps.
+ Hint Rewrite add_get_carry_cps_id div_cps_id modulo_cps_id : uncps.
Definition eval {n} (x : (list Z)^n) : Z :=
B.Positional.eval weight (Tuple.map sum x).
@@ -166,8 +172,8 @@ Module Columns.
Fixpoint compact_digit_cps (digit : list Z) (f:Z * Z->T) :=
match digit with
| nil => f (0, 0)
- | x :: nil => div_cps x (weight (S n) / weight n) (fun d =>
- modulo_cps x (weight (S n) / weight n) (fun m =>
+ | x :: nil => div_cps _ x (weight (S n) / weight n) (fun d =>
+ modulo_cps _ x (weight (S n) / weight n) (fun m =>
f (d, m)))
| x :: y :: nil =>
add_get_carry_cps _ (weight (S n) / weight n) x y (fun sum_carry =>
@@ -187,7 +193,7 @@ Module Columns.
Definition compact_digit n digit := compact_digit_cps n digit id.
Lemma compact_digit_id n digit: forall {T} f,
@compact_digit_cps n T digit f = f (compact_digit n digit).
- Proof using add_get_carry_cps_id.
+ Proof using add_get_carry_cps_id div_cps_id modulo_cps_id.
induction digit; intros; cbv [compact_digit]; [reflexivity|].
simpl compact_digit_cps; break_match; rewrite ?IHdigit; clear IHdigit;
cbv [Let_In]; autorewrite with uncps; reflexivity.
@@ -202,7 +208,7 @@ Module Columns.
Definition compact_step i c d := compact_step_cps i c d id.
Lemma compact_step_id i c d T f :
@compact_step_cps i c d T f = f (compact_step i c d).
- Proof using add_get_carry_cps_id. cbv [compact_step_cps compact_step]; autorewrite with uncps; reflexivity. Qed.
+ Proof using add_get_carry_cps_id div_cps_id modulo_cps_id. cbv [compact_step_cps compact_step]; autorewrite with uncps; reflexivity. Qed.
Hint Opaque compact_step : uncps.
Hint Rewrite compact_step_id : uncps.
@@ -211,15 +217,15 @@ Module Columns.
Definition compact {n} xs := @compact_cps n xs _ id.
Lemma compact_id {n} xs {T} f : @compact_cps n xs T f = f (compact xs).
- Proof using add_get_carry_cps_id. cbv [compact_cps compact]; autorewrite with uncps; reflexivity. Qed.
+ Proof using add_get_carry_cps_id div_cps_id modulo_cps_id. cbv [compact_cps compact]; autorewrite with uncps; reflexivity. Qed.
Lemma compact_digit_mod i (xs : list Z) :
snd (compact_digit i xs) = sum xs mod (weight (S i) / weight i).
- Proof using add_get_carry_div add_get_carry_mod div_correct modulo_correct add_get_carry_cps_id.
+ Proof using add_get_carry_div add_get_carry_mod div_correct modulo_correct add_get_carry_cps_id div_cps_id modulo_cps_id.
induction xs; cbv [compact_digit]; simpl compact_digit_cps;
cbv [Let_In];
repeat match goal with
- | _ => cbv [add_get_carry]; progress autorewrite with div_mod
+ | _ => progress autorewrite with div_mod
| _ => rewrite IHxs, <-Z.add_mod_r
| _ => progress (rewrite ?sum_cons, ?sum_nil in * )
| _ => progress (autorewrite with uncps push_id cancel_pair in * )
@@ -231,11 +237,11 @@ Module Columns.
Lemma compact_digit_div i (xs : list Z) :
fst (compact_digit i xs) = sum xs / (weight (S i) / weight i).
- Proof using add_get_carry_div add_get_carry_mod div_correct modulo_correct weight_0 weight_divides add_get_carry_cps_id.
+ Proof using add_get_carry_div add_get_carry_mod div_correct modulo_correct weight_0 weight_divides add_get_carry_cps_id div_cps_id modulo_cps_id.
induction xs; cbv [compact_digit]; simpl compact_digit_cps;
cbv [Let_In];
repeat match goal with
- | _ => cbv [add_get_carry]; progress autorewrite with div_mod
+ | _ => progress autorewrite with div_mod
| _ => rewrite IHxs
| _ => progress (rewrite ?sum_cons, ?sum_nil in * )
| _ => progress (autorewrite with uncps push_id cancel_pair in * )
diff --git a/src/Arithmetic/Saturated/Freeze.v b/src/Arithmetic/Saturated/Freeze.v
index b56a69a3d..658eb867d 100644
--- a/src/Arithmetic/Saturated/Freeze.v
+++ b/src/Arithmetic/Saturated/Freeze.v
@@ -102,6 +102,7 @@ Section Freeze.
pose proof Z.add_get_carry_full_mod.
pose proof Z.add_get_carry_full_div.
pose proof div_correct. pose proof modulo_correct.
+ pose proof @div_id. pose proof @modulo_id.
pose proof @Z.add_get_carry_full_cps_correct.
autorewrite with uncps push_id push_basesystem_eval.
diff --git a/src/Arithmetic/Saturated/MontgomeryAPI.v b/src/Arithmetic/Saturated/MontgomeryAPI.v
index fb896749b..49e45663d 100644
--- a/src/Arithmetic/Saturated/MontgomeryAPI.v
+++ b/src/Arithmetic/Saturated/MontgomeryAPI.v
@@ -285,6 +285,7 @@ Section API.
pose proof Z.add_get_carry_full_mod;
pose proof Z.mul_split_div; pose proof Z.mul_split_mod;
pose proof div_correct; pose proof modulo_correct;
+ pose proof @div_id; pose proof @modulo_id;
pose proof @Z.add_get_carry_full_cps_correct;
pose proof @Z.mul_split_cps_correct;
pose proof @Z.mul_split_cps'_correct.
@@ -314,8 +315,8 @@ Section API.
Qed.
Hint Rewrite eval_add_same eval_add_S1 eval_add_S2 using (omega || assumption): push_basesystem_eval.
- Local Definition compact {n} := Columns.compact (n:=n) (add_get_carry_cps:=@Z.add_get_carry_full_cps) (div:=div) (modulo:=modulo) (uweight bound).
- Local Definition compact_digit := Columns.compact_digit (add_get_carry_cps:=@Z.add_get_carry_full_cps) (div:=div) (modulo:=modulo) (uweight bound).
+ Local Definition compact {n} := Columns.compact (n:=n) (add_get_carry_cps:=@Z.add_get_carry_full_cps) (div_cps:=@div_cps) (modulo_cps:=@modulo_cps) (uweight bound).
+ Local Definition compact_digit := Columns.compact_digit (add_get_carry_cps:=@Z.add_get_carry_full_cps) (div_cps:=@div_cps) (modulo_cps:=@modulo_cps) (uweight bound).
Lemma small_compact {n} (p:(list Z)^n) : small (snd (compact p)).
Proof.
pose_all.
diff --git a/src/Arithmetic/Saturated/Wrappers.v b/src/Arithmetic/Saturated/Wrappers.v
index 6bb3893d5..e750cfd36 100644
--- a/src/Arithmetic/Saturated/Wrappers.v
+++ b/src/Arithmetic/Saturated/Wrappers.v
@@ -22,7 +22,7 @@ Module Columns.
B.Positional.to_associational_cps weight p
(fun P => B.Positional.to_associational_cps weight q
(fun Q => Columns.from_associational_cps weight n3 (P++Q)
- (fun R => Columns.compact_cps (div:=div) (modulo:=modulo) (add_get_carry_cps:=@Z.add_get_carry_full_cps) weight R f))).
+ (fun R => Columns.compact_cps (div_cps:=@div_cps) (modulo_cps:=@modulo_cps) (add_get_carry_cps:=@Z.add_get_carry_full_cps) weight R f))).
Definition unbalanced_sub_cps {n1 n2 n3} (p : Z^n1) (q:Z^n2)
{T} (f : (Z*Z^n3)->T) :=
@@ -30,7 +30,7 @@ Module Columns.
(fun P => B.Positional.negate_snd_cps weight q
(fun nq => B.Positional.to_associational_cps weight nq
(fun Q => Columns.from_associational_cps weight n3 (P++Q)
- (fun R => Columns.compact_cps (div:=div) (modulo:=modulo) (add_get_carry_cps:=@Z.add_get_carry_full_cps) weight R f)))).
+ (fun R => Columns.compact_cps (div_cps:=@div_cps) (modulo_cps:=@modulo_cps) (add_get_carry_cps:=@Z.add_get_carry_full_cps) weight R f)))).
Definition mul_cps {n1 n2 n3} s (p : Z^n1) (q : Z^n2)
{T} (f : (Z*Z^n3)->T) :=
@@ -38,7 +38,7 @@ Module Columns.
(fun P => B.Positional.to_associational_cps weight q
(fun Q => B.Associational.sat_mul_cps (mul_split_cps := @Z.mul_split_cps') s P Q
(fun PQ => Columns.from_associational_cps weight n3 PQ
- (fun R => Columns.compact_cps (div:=div) (modulo:=modulo) (add_get_carry_cps:=@Z.add_get_carry_full_cps) weight R f)))).
+ (fun R => Columns.compact_cps (div_cps:=@div_cps) (modulo_cps:=@modulo_cps) (add_get_carry_cps:=@Z.add_get_carry_full_cps) weight R f)))).
Definition conditional_add_cps {n1 n2 n3} mask cond (p:Z^n1) (q:Z^n2)
{T} (f:_->T) :=