aboutsummaryrefslogtreecommitdiff
path: root/src/Arithmetic
diff options
context:
space:
mode:
authorGravatar Jason Gross <jgross@mit.edu>2017-11-07 14:00:36 -0500
committerGravatar Jason Gross <jgross@mit.edu>2017-11-07 14:00:36 -0500
commitffcde2d3a9ed23a2236ea1d5692d5544207b6da6 (patch)
tree192f52bf18f7dd6b2fefb0cbec0bf5465602d71a /src/Arithmetic
parentfb5cdd711657ad7ce278eda5556b4e2b0b9119f5 (diff)
Move chained_carries' (now chained_carries_reduce)
It now lives in Arithmetic.Core.B.Positional, where it belongs, rather than in Specific/.../HelperTactics. Andres notes that we probably don't need this at all, and could instead make chained_carries reduce after every index (and the spurious reductions should be no-ops). I didn't want to bother verifying this, at the moment, so I left it as-is.
Diffstat (limited to 'src/Arithmetic')
-rw-r--r--src/Arithmetic/Core.v277
-rw-r--r--src/Arithmetic/CoreUnfolder.v21
2 files changed, 209 insertions, 89 deletions
diff --git a/src/Arithmetic/Core.v b/src/Arithmetic/Core.v
index f3692c86d..98600d9a3 100644
--- a/src/Arithmetic/Core.v
+++ b/src/Arithmetic/Core.v
@@ -662,87 +662,6 @@ Module B.
Hint Rewrite @eval_from_associational using omega
: push_basesystem_eval.
- Section Carries.
- Context {modulo_cps div_cps:forall {R},Z->Z->(Z->R)->R}.
- Let modulo x y := modulo_cps _ x y id.
- Let div x y := div_cps _ x y id.
- Context {modulo_cps_id : forall R x y f, modulo_cps R x y f = f (modulo x y)}
- {div_cps_id : forall R x y f, div_cps R x y f = f (div x y)}.
- Context {div_mod : forall a b:Z, b <> 0 ->
- a = b * (div a b) + modulo a b}.
- Hint Rewrite modulo_cps_id div_cps_id : uncps.
-
- Definition carry_cps {n m} (index:nat) (p:tuple Z n)
- {T} (f:tuple Z m->T) :=
- to_associational_cps p
- (fun P => @Associational.carry_cps
- modulo_cps div_cps
- (weight index)
- (weight (S index) / weight index)
- P T
- (fun R => from_associational_cps m R f)).
-
- Definition carry {n m} i p := @carry_cps n m i p _ id.
- Lemma carry_cps_id {n m} i p {T} f:
- @carry_cps n m i p T f = f (carry i p).
- Proof.
- cbv [carry_cps carry]; prove_id; rewrite carry_cps_id; reflexivity.
- Qed.
- Hint Opaque carry : uncps. Hint Rewrite @carry_cps_id : uncps.
-
- Lemma eval_carry {n m} i p: (n <> 0%nat) -> (m <> 0%nat) ->
- weight (S i) / weight i <> 0 ->
- eval (carry (n:=n) (m:=m) i p) = eval p.
- Proof.
- cbv [carry_cps carry]; intros. prove_eval.
- rewrite @eval_carry by eauto.
- apply eval_to_associational.
- Qed.
- Hint Rewrite @eval_carry : push_basesystem_eval.
-
- (* N.B. It is important to reverse [idxs] here. Like
- [fold_right], [fold_right_cps2] is written such that the first
- terms in the list are actually used last in the computation. For
- example, running:
-
- `Eval cbv - [Z.add] in (fun a b c d => fold_right Z.add d [a;b;c]).`
-
- will produce [fun a b c d => (a + (b + (c + d)))].*)
- Definition chained_carries_cps {n} (p:tuple Z n) (idxs : list nat)
- {T} (f:tuple Z n->T) :=
- fold_right_cps2 carry_cps p (rev idxs) f.
-
- Definition chained_carries {n} p idxs := @chained_carries_cps n p idxs _ id.
- Lemma chained_carries_id {n} p idxs : forall {T} f,
- @chained_carries_cps n p idxs T f = f (chained_carries p idxs).
- Proof using modulo_cps_id div_cps_id.
- cbv [chained_carries_cps chained_carries]; prove_id.
- Qed.
- Hint Opaque chained_carries : uncps.
- Hint Rewrite @chained_carries_id : uncps.
-
- Lemma eval_chained_carries {n} (p:tuple Z n) idxs :
- (forall i, In i idxs -> weight (S i) / weight i <> 0) ->
- eval (chained_carries p idxs) = eval p.
- Proof using Type*.
- cbv [chained_carries chained_carries_cps]; intros;
- autorewrite with uncps push_id.
- apply fold_right_invariant; [|intro; rewrite <-in_rev];
- destruct n; prove_eval; auto.
- Qed. Hint Rewrite @eval_chained_carries : push_basesystem_eval.
-
- (* Reverse of [eval]; ranslate from Z to basesystem by putting
- everything in first digit and then carrying. This function, like
- [eval], is not defined using CPS. *)
- Definition encode {n} (x : Z) : tuple Z n :=
- chained_carries (from_associational n [(1,x)]) (seq 0 n).
- Lemma eval_encode {n} x : (n <> 0%nat) ->
- (forall i, In i (seq 0 n) -> weight (S i) / weight i <> 0) ->
- eval (@encode n x) = x.
- Proof using Type*. cbv [encode]; intros; prove_eval; auto. Qed.
- Hint Rewrite @eval_encode : push_basesystem_eval.
-
- End Carries.
Section Wrappers.
(* Simple wrappers for Associational definitions; convert to
@@ -765,12 +684,6 @@ Module B.
(fun P => Associational.reduce_cps s c P
(fun R => from_associational_cps n R f)).
- Definition carry_reduce_cps {n div_cps modulo_cps}
- (s:Z) (c:list limb) (p : tuple Z n)
- {T} (f: tuple Z n ->T) :=
- carry_cps (div_cps:=div_cps) (modulo_cps:=modulo_cps) (n:=n) (m:=S n) (pred n) p
- (fun r => reduce_cps (m:=S n) (n:=n) s c r f).
-
Definition negate_snd_cps {n} (p : tuple Z n)
{T} (f:tuple Z n->T) :=
to_associational_cps p
@@ -808,13 +721,193 @@ Module B.
Positional.add_cps
Positional.mul_cps
Positional.reduce_cps
- Positional.carry_reduce_cps
Positional.negate_snd_cps
Positional.split_cps
Positional.scmul_cps
Positional.unbalanced_sub_cps
.
+ Section Carries.
+ Context {modulo_cps div_cps:forall {R},Z->Z->(Z->R)->R}.
+ Let modulo x y := modulo_cps _ x y id.
+ Let div x y := div_cps _ x y id.
+ Context {modulo_cps_id : forall R x y f, modulo_cps R x y f = f (modulo x y)}
+ {div_cps_id : forall R x y f, div_cps R x y f = f (div x y)}.
+ Context {div_mod : forall a b:Z, b <> 0 ->
+ a = b * (div a b) + modulo a b}.
+ Hint Rewrite modulo_cps_id div_cps_id : uncps.
+
+ Definition carry_cps {n m} (index:nat) (p:tuple Z n)
+ {T} (f:tuple Z m->T) :=
+ to_associational_cps p
+ (fun P => @Associational.carry_cps
+ modulo_cps div_cps
+ (weight index)
+ (weight (S index) / weight index)
+ P T
+ (fun R => from_associational_cps m R f)).
+
+ Definition carry {n m} i p := @carry_cps n m i p _ id.
+ Lemma carry_cps_id {n m} i p {T} f:
+ @carry_cps n m i p T f = f (carry i p).
+ Proof.
+ cbv [carry_cps carry]; prove_id; rewrite carry_cps_id; reflexivity.
+ Qed.
+ Hint Opaque carry : uncps. Hint Rewrite @carry_cps_id : uncps.
+
+ Lemma eval_carry {n m} i p: (n <> 0%nat) -> (m <> 0%nat) ->
+ weight (S i) / weight i <> 0 ->
+ eval (carry (n:=n) (m:=m) i p) = eval p.
+ Proof.
+ cbv [carry_cps carry]; intros. prove_eval.
+ rewrite @eval_carry by eauto.
+ apply eval_to_associational.
+ Qed.
+ Hint Rewrite @eval_carry : push_basesystem_eval.
+
+ Definition carry_reduce_cps {n}
+ (s:Z) (c:list limb) (p : tuple Z n)
+ {T} (f: tuple Z n ->T) :=
+ carry_cps (n:=n) (m:=S n) (pred n) p
+ (fun r => reduce_cps (m:=S n) (n:=n) s c r f).
+ Hint Unfold carry_reduce_cps.
+
+ (* N.B. It is important to reverse [idxs] here. Like
+ [fold_right], [fold_right_cps2] is written such that the first
+ terms in the list are actually used last in the computation. For
+ example, running:
+
+ `Eval cbv - [Z.add] in (fun a b c d => fold_right Z.add d [a;b;c]).`
+
+ will produce [fun a b c d => (a + (b + (c + d)))].*)
+ Definition chained_carries_cps {n} (p:tuple Z n) (idxs : list nat)
+ {T} (f:tuple Z n->T) :=
+ fold_right_cps2 carry_cps p (rev idxs) f.
+
+ Definition chained_carries {n} p idxs := @chained_carries_cps n p idxs _ id.
+ Lemma chained_carries_id {n} p idxs : forall {T} f,
+ @chained_carries_cps n p idxs T f = f (chained_carries p idxs).
+ Proof using modulo_cps_id div_cps_id.
+ cbv [chained_carries_cps chained_carries]; prove_id.
+ Qed.
+ Hint Opaque chained_carries : uncps.
+ Hint Rewrite @chained_carries_id : uncps.
+
+ Lemma eval_chained_carries {n} (p:tuple Z n) idxs :
+ (forall i, In i idxs -> weight (S i) / weight i <> 0) ->
+ eval (chained_carries p idxs) = eval p.
+ Proof using Type*.
+ cbv [chained_carries chained_carries_cps]; intros;
+ autorewrite with uncps push_id.
+ apply fold_right_invariant; [|intro; rewrite <-in_rev];
+ destruct n; prove_eval; auto.
+ Qed. Hint Rewrite @eval_chained_carries : push_basesystem_eval.
+
+ Definition chained_carries_reduce_cps_step {n} (s:Z) (c:list limb) {T}
+ (chained_carries_reduce_cps : forall (p:tuple Z n) (carry_chains : list (list nat)) (f : tuple Z n -> T), T)
+ (p : tuple Z n) (carry_chains : list (list nat))
+ (f : tuple Z n -> T)
+ : T
+ := match carry_chains with
+ | nil => f p
+ | carry_chain :: nil
+ => chained_carries_cps
+ (n:=n) p carry_chain f
+ | carry_chain :: carry_chains
+ => chained_carries_cps
+ (n:=n) p carry_chain
+ (fun r => carry_reduce_cps (n:=n) s c r
+ (fun r' => chained_carries_reduce_cps r' carry_chains f))
+ end.
+ Section chained_carries_reduce_cps.
+ Context {n:nat} (s:Z) (c:list limb) {T:Type}.
+
+ Fixpoint chained_carries_reduce_cps
+ (p : tuple Z n) (carry_chains : list (list nat))
+ (f : tuple Z n -> T)
+ : T
+ := @chained_carries_reduce_cps_step
+ n s c T
+ chained_carries_reduce_cps p carry_chains f.
+ End chained_carries_reduce_cps.
+
+ Lemma step_chained_carries_reduce_cps {n} (s:Z) (c:list limb) {T} p carry_chain carry_chains (f : tuple Z n -> T)
+ : chained_carries_reduce_cps s c p (carry_chain :: carry_chains) f
+ = match length carry_chains with
+ | O => chained_carries_cps
+ (n:=n) p carry_chain f
+ | S _
+ => chained_carries_cps
+ (n:=n) p carry_chain
+ (fun r => carry_reduce_cps (n:=n) s c r
+ (fun r' => chained_carries_reduce_cps s c r' carry_chains f))
+ end.
+ Proof.
+ destruct carry_chains; reflexivity.
+ Qed.
+
+ Definition chained_carries_reduce {n} (s:Z) (c:list limb) (p:tuple Z n) (carry_chains : list (list nat))
+ : tuple Z n
+ := chained_carries_reduce_cps s c p carry_chains id.
+
+ Lemma chained_carries_reduce_id {n} s c {T} p carry_chains f
+ : @chained_carries_reduce_cps n s c T p carry_chains f
+ = f (@chained_carries_reduce n s c p carry_chains).
+ Proof.
+ destruct carry_chains as [|carry_chain carry_chains]; [ reflexivity | ].
+ cbv [chained_carries_reduce].
+ revert p carry_chain; induction carry_chains as [|? carry_chains IHcarry_chains]; intros.
+ { simpl; repeat autounfold; autorewrite with uncps. reflexivity. }
+ { rewrite !step_chained_carries_reduce_cps.
+ simpl @length; cbv iota beta.
+ repeat autounfold; autorewrite with uncps.
+ rewrite !IHcarry_chains.
+ reflexivity. }
+ Qed.
+ Hint Opaque chained_carries_reduce : uncps.
+ Hint Rewrite @chained_carries_reduce_id : uncps.
+
+ Lemma eval_chained_carries_reduce {n} (s:Z) (c:list limb) (p:tuple Z n) carry_chains
+ (Hn : n <> 0%nat)
+ (s_nonzero:s<>0) m (m_eq : Z.pos m = s - Associational.eval c)
+ (Hwt : weight (S (Init.Nat.pred n)) / weight (Init.Nat.pred n) <> 0)
+ : (List.fold_right
+ and
+ True
+ (List.map
+ (fun idxs
+ => forall i, In i idxs -> weight (S i) / weight i <> 0)
+ carry_chains)) ->
+ mod_eq m (eval (chained_carries_reduce s c p carry_chains)) (eval p).
+ Proof using Type*.
+ destruct carry_chains as [|carry_chain carry_chains]; [ reflexivity | ].
+ cbv [chained_carries_reduce].
+ revert p carry_chain; induction carry_chains as [|? carry_chains IHcarry_chains]; intros.
+ { cbn in *; prove_eval; auto. }
+ { rewrite !step_chained_carries_reduce_cps.
+ simpl @length; cbv iota beta.
+ repeat autounfold; autorewrite with uncps push_id push_basesystem_eval.
+ cbv [chained_carries_reduce].
+ rewrite !IHcarry_chains by (cbn in *; tauto); clear IHcarry_chains.
+ cbn in * |- .
+ prove_eval; auto. }
+ Qed.
+ Hint Rewrite @eval_chained_carries_reduce using (omega || assumption) : push_basesystem_eval.
+
+ (* Reverse of [eval]; translate from Z to basesystem by putting
+ everything in first digit and then carrying. This function, like
+ [eval], is not defined using CPS. *)
+ Definition encode {n} (x : Z) : tuple Z n :=
+ chained_carries (from_associational n [(1,x)]) (seq 0 n).
+ Lemma eval_encode {n} x : (n <> 0%nat) ->
+ (forall i, In i (seq 0 n) -> weight (S i) / weight i <> 0) ->
+ eval (@encode n x) = x.
+ Proof using Type*. cbv [encode]; intros; prove_eval; auto. Qed.
+ Hint Rewrite @eval_encode : push_basesystem_eval.
+
+ End Carries.
+ Hint Unfold carry_reduce_cps.
+
Section Subtraction.
Context {m n} {coef : tuple Z n}
{coef_mod : mod_eq m (eval coef) 0}.
@@ -982,6 +1075,7 @@ Module B.
@Associational.carryterm_cps_id
@Positional.carry_cps_id
@Positional.chained_carries_id
+ @Positional.chained_carries_reduce_id
using div_mod_cps_t : uncps.
Hint Rewrite
@Associational.eval_mul
@@ -998,6 +1092,7 @@ Module B.
@Positional.eval_from_associational
@Positional.eval_add_to_nth
@Positional.eval_chained_carries
+ @Positional.eval_chained_carries_reduce
@Positional.eval_sub
@Positional.eval_select
using (assumption || (div_mod_cps_t; auto) || vm_decide) : push_basesystem_eval.
@@ -1092,6 +1187,9 @@ Hint Unfold
Positional.carry
Positional.chained_carries_cps
Positional.chained_carries
+ Positional.chained_carries_reduce_cps_step
+ Positional.chained_carries_reduce_cps
+ Positional.chained_carries_reduce
Positional.encode
Positional.add_cps
Positional.mul_cps
@@ -1135,6 +1233,9 @@ Ltac basesystem_partial_evaluation_unfolder t :=
Positional.from_associational_cps Positional.from_associational
Positional.carry_cps Positional.carry
Positional.chained_carries_cps Positional.chained_carries
+ Positional.chained_carries_reduce_cps
+ Positional.chained_carries_reduce
+ Positional.chained_carries_reduce_cps_step
Positional.sub_cps Positional.sub Positional.split_cps
Positional.scmul_cps Positional.unbalanced_sub_cps
Positional.negate_snd_cps Positional.add_cps Positional.opp_cps
diff --git a/src/Arithmetic/CoreUnfolder.v b/src/Arithmetic/CoreUnfolder.v
index cad4f6e7c..991ca3193 100644
--- a/src/Arithmetic/CoreUnfolder.v
+++ b/src/Arithmetic/CoreUnfolder.v
@@ -14,6 +14,7 @@ Hint Unfold Core.div Core.modulo : arithmetic_cps_unfolder.
Ltac make_parameterized_sig t :=
refine (_ : { v : _ | v = t });
eexists; cbv delta [t
+ Core.B.Positional.chained_carries_reduce_cps_step
B.limb ListUtil.sum ListUtil.sum_firstn
CPSUtil.Tuple.mapi_with_cps CPSUtil.Tuple.mapi_with'_cps CPSUtil.flat_map_cps CPSUtil.on_tuple_cps CPSUtil.fold_right_cps2
Decidable.dec Decidable.dec_eq_Z
@@ -63,7 +64,7 @@ for i in eval multerm mul_cps mul split_cps split reduce_cps reduce negate_snd_c
done
echo " End Associational."
echo " Module Positional."
-for i in to_associational_cps to_associational eval zeros add_to_nth_cps add_to_nth place_cps place from_associational_cps from_associational carry_cps carry chained_carries_cps chained_carries encode add_cps mul_cps reduce_cps carry_reduce_cps negate_snd_cps split_cps scmul_cps unbalanced_sub_cps sub_cps sub opp_cps Fencode Fdecode eval_from select_cps select; do
+for i in to_associational_cps to_associational eval zeros add_to_nth_cps add_to_nth place_cps place from_associational_cps from_associational carry_cps carry chained_carries_cps chained_carries encode add_cps mul_cps reduce_cps carry_reduce_cps chained_carries_reduce_cps_step chained_carries_reduce_cps chained_carries_reduce negate_snd_cps split_cps scmul_cps unbalanced_sub_cps sub_cps sub opp_cps Fencode Fdecode eval_from select_cps select; do
echo " Definition ${i}_sig := parameterize_sig (@Core.B.Positional.${i}).";
echo " Definition ${i} := parameterize_from_sig ${i}_sig.";
echo " Definition ${i}_eq := parameterize_eq ${i} ${i}_sig.";
@@ -281,6 +282,24 @@ done
Hint Unfold carry_reduce_cps : basesystem_partial_evaluation_unfolder.
Hint Rewrite <- carry_reduce_cps_eq : pattern_runtime.
+ Definition chained_carries_reduce_cps_step_sig := parameterize_sig (@Core.B.Positional.chained_carries_reduce_cps_step).
+ Definition chained_carries_reduce_cps_step := parameterize_from_sig chained_carries_reduce_cps_step_sig.
+ Definition chained_carries_reduce_cps_step_eq := parameterize_eq chained_carries_reduce_cps_step chained_carries_reduce_cps_step_sig.
+ Hint Unfold chained_carries_reduce_cps_step : basesystem_partial_evaluation_unfolder.
+ Hint Rewrite <- chained_carries_reduce_cps_step_eq : pattern_runtime.
+
+ Definition chained_carries_reduce_cps_sig := parameterize_sig (@Core.B.Positional.chained_carries_reduce_cps).
+ Definition chained_carries_reduce_cps := parameterize_from_sig chained_carries_reduce_cps_sig.
+ Definition chained_carries_reduce_cps_eq := parameterize_eq chained_carries_reduce_cps chained_carries_reduce_cps_sig.
+ Hint Unfold chained_carries_reduce_cps : basesystem_partial_evaluation_unfolder.
+ Hint Rewrite <- chained_carries_reduce_cps_eq : pattern_runtime.
+
+ Definition chained_carries_reduce_sig := parameterize_sig (@Core.B.Positional.chained_carries_reduce).
+ Definition chained_carries_reduce := parameterize_from_sig chained_carries_reduce_sig.
+ Definition chained_carries_reduce_eq := parameterize_eq chained_carries_reduce chained_carries_reduce_sig.
+ Hint Unfold chained_carries_reduce : basesystem_partial_evaluation_unfolder.
+ Hint Rewrite <- chained_carries_reduce_eq : pattern_runtime.
+
Definition negate_snd_cps_sig := parameterize_sig (@Core.B.Positional.negate_snd_cps).
Definition negate_snd_cps := parameterize_from_sig negate_snd_cps_sig.
Definition negate_snd_cps_eq := parameterize_eq negate_snd_cps negate_snd_cps_sig.