aboutsummaryrefslogtreecommitdiff
path: root/src/NewBaseSystem.v
diff options
context:
space:
mode:
authorGravatar jadep <jade.philipoom@gmail.com>2017-02-23 10:20:07 -0500
committerGravatar jadep <jade.philipoom@gmail.com>2017-02-23 10:20:49 -0500
commit371b69d283ead05f75c698e31892778397286428 (patch)
treebde84ca733c2325d66905d7f09766a12c5b21304 /src/NewBaseSystem.v
parent2b10dce1a3c60398f3fce7cd695b20b2a881d378 (diff)
added explanation of why CPS is useful
Diffstat (limited to 'src/NewBaseSystem.v')
-rw-r--r--src/NewBaseSystem.v457
1 files changed, 382 insertions, 75 deletions
diff --git a/src/NewBaseSystem.v b/src/NewBaseSystem.v
index 7a7a5214f..5b5dae743 100644
--- a/src/NewBaseSystem.v
+++ b/src/NewBaseSystem.v
@@ -10,6 +10,249 @@ Require Import Crypto.Util.CPSUtil Crypto.Util.Prod.
Require Import Coq.Lists.List. Import ListNotations.
Require Crypto.Util.Tuple. Local Notation tuple := Tuple.tuple.
+(*****
+
+This file provides a generalized version of arithmetic with "mixed
+radix" numerical systems. Later, parameters are entered into the
+general functions, and they are partially evaluated until only runtime
+basic arithmetic operations remain.
+
+CPS
+---
+
+Fuctions are written in continuation passing style (CPS). This means
+that each operation is passed a "continuation" function, which it is
+expected to call on its own output (like a callback). See the end of
+this comment for a motivating example explaining why we do CPS,
+despite a fair amount of resulting boilerplate code for each
+operation. The code block for an operation called A would look like
+this:
+
+```
+Definition A_cps x y {T} f : T := ...
+
+Definition A x y := A_cps x y id.
+Lemma A_cps_id x y : forall {T} f, @A_cps x y T f = f (A x y).
+Hint Opaque A : uncps.
+Hint Rewrite A_cps_id : uncps.
+
+Lemma eval_A x y : eval (A x y) = ...
+Hint Rewrite eval_A : push_basesystem_eval.
+```
+
+`A_cps` is the main, CPS-style definition of the operation (`f` is the
+continuation function). `A` is the non-CPS version of `A_cps`, simply
+defined by passing an identity function to `A_cps`. `A_cps_id` states
+that we can replace the CPS version with the non-cps version. `eval_A`
+is the actual correctness lemma for the operation, stating that it has
+the correct arithmetic properties. In general, the middle block
+containing `A` and `A_cps_id` is boring boilerplate and can be safely
+ignored.
+
+HintDbs
+-------
+
++ `uncps` : Converts CPS operations to their non-CPS versions.
++ `push_basesystem_eval` : Contains all the correctness lemmas for
+ operations in this file, which are in terms of the `eval` function.
+
+Positional/Associational
+------------------------
+
+We represent mixed-radix numbers in a few different ways:
+
++ "Positional" : a tuple of numbers and a weight function (nat->Z),
+which is evaluated by multiplying the `i`th element of the tuple by
+`weight i`, and then summing the products.
++ "Associational" : a list of pairs of numbers--the first is the
+weight, the second is the runtime value. Evaluated by multiplying each
+pair and summing the products.
+
+The associational representation is good for basic operations like
+addition and multiplication; for addition, one can simply just append
+two associational lists. But the end-result code should use the
+positional representation (with each digit representing a machine
+word). Since converting to and fro can be easily compiled away once
+the weight function is known, we use associational to write most of
+the operations and liberally convert back and forth to ensure correct
+output. In particular, it is important to convert before carrying.
+
+Runtime Operations
+------------------
+
+Since some instances of e.g. Z.add or Z.mul operate on (compile-time)
+weights, and some operate on runtime values, we need a way to
+differentiate these cases before partial evaluation. We define a
+runtime_scope to mark certain additions/multiplications as runtime
+values, so they will not be unfolded during partial evaluation. For
+instance, if we have:
+
+```
+Definition f (x y : Z * Z) := (fst x + fst y, (snd x + snd y)%RT).
+```
+
+then when we are partially evaluating `f`, we can easily exclude the
+runtime operations (`cbv - [runtime_add]`) and prevent Coq from trying
+to simplify the second addition.
+
+
+Why CPS?
+--------
+
+Let's suppose we want to add corresponding elements of two `list Z`s
+(so on inputs `[1,2,3]` and `[2,3,1]`, we get `[3,5,4]`). We might
+write our function like this :
+
+```
+Fixpoint add_lists (p q : list Z) :=
+ match p, q with
+ | p0 :: p', q0 :: q' =>
+ dlet sum := p0 + q0 in
+ sum :: add_lists p' q'
+ | _, _ => nil
+ end.
+```
+
+(Note : `dlet` is a notation for `Let_In`, which is just a dumb
+wrapper for `let`. This allows us to `cbv - [Let_In]` if we want to
+not simplify certain `let`s.)
+
+A CPS equivalent of `add_lists` would look like this:
+
+```
+Fixpoint add_lists_cps (p q : list Z) {T} (f:list Z->T) :=
+ match p, q with
+ | p0 :: p', q0 :: q' =>
+ dlet sum := p0 + q0 in
+ add_lists_cps p' q' (fun r => f (sum :: r))
+ | _, _ => f nil
+ end.
+```
+
+Now let's try some partial evaluation. The expression we'll evaluate is:
+
+```
+Definition x :=
+ (fun a0 a1 a2 b0 b1 b2 =>
+ let r := add_lists [a0;a1;a2] [b0;b1;b2] in
+ let rr := add_lists r r in
+ add_lists rr rr).
+```
+
+Or, using `add_lists_cps`:
+
+```
+Definition y :=
+ (fun a0 a1 a2 b0 b1 b2 =>
+ add_lists_cps [a0;a1;a2] [b0;b1;b2]
+ (fun r => add_lists_cps r r
+ (fun rr => add_lists_cps rr rr id))).
+```
+
+If we run `Eval cbv -[Z.add] in x` and `Eval cbv -[Z.add] in y`, we get
+identical output:
+
+```
+fun a0 a1 a2 b0 b1 b2 : Z =>
+ [a0 + b0 + (a0 + b0) + (a0 + b0 + (a0 + b0));
+ a1 + b1 + (a1 + b1) + (a1 + b1 + (a1 + b1));
+ a2 + b2 + (a2 + b2) + (a2 + b2 + (a2 + b2))]
+```
+
+However, there are a lot of common subexpressions here--this is what
+the `dlet` we put into the functions should help us avoid. Let's try
+`Eval cbv -[Let_In Z.add] in x`:
+
+```
+fun a0 a1 a2 b0 b1 b2 : Z =>
+ (fix add_lists (p q : list Z) {struct p} :
+ list Z :=
+ match p with
+ | [] => []
+ | p0 :: p' =>
+ match q with
+ | [] => []
+ | q0 :: q' =>
+ dlet sum := p0 + q0 in
+ sum :: add_lists p' q'
+ end
+ end)
+ ((fix add_lists (p q : list Z) {struct p} :
+ list Z :=
+ match p with
+ | [] => []
+ | p0 :: p' =>
+ match q with
+ | [] => []
+ | q0 :: q' =>
+ dlet sum := p0 + q0 in
+ sum :: add_lists p' q'
+ end
+ end)
+ (dlet sum := a0 + b0 in
+ sum
+ :: (dlet sum0 := a1 + b1 in
+ sum0 :: (dlet sum1 := a2 + b2 in
+ [sum1])))
+ (dlet sum := a0 + b0 in
+ sum
+ :: (dlet sum0 := a1 + b1 in
+ sum0 :: (dlet sum1 := a2 + b2 in
+ [sum1]))))
+ ((fix add_lists (p q : list Z) {struct p} :
+ list Z :=
+ match p with
+ | [] => []
+ | p0 :: p' =>
+ match q with
+ | [] => []
+ | q0 :: q' =>
+ dlet sum := p0 + q0 in
+ sum :: add_lists p' q'
+ end
+ end)
+ (dlet sum := a0 + b0 in
+ sum
+ :: (dlet sum0 := a1 + b1 in
+ sum0 :: (dlet sum1 := a2 + b2 in
+ [sum1])))
+ (dlet sum := a0 + b0 in
+ sum
+ :: (dlet sum0 := a1 + b1 in
+ sum0 :: (dlet sum1 := a2 + b2 in
+ [sum1]))))
+```
+
+Not so great. Because the `dlet`s are stuck in the inner terms, we
+can't simplify the expression very nicely. Let's try that on the CPS
+version (`Eval cbv -[Let_In Z.add] in y`):
+
+```
+fun a0 a1 a2 b0 b1 b2 : Z =>
+ dlet sum := a0 + b0 in
+ dlet sum0 := a1 + b1 in
+ dlet sum1 := a2 + b2 in
+ dlet sum2 := sum + sum in
+ dlet sum3 := sum0 + sum0 in
+ dlet sum4 := sum1 + sum1 in
+ dlet sum5 := sum2 + sum2 in
+ dlet sum6 := sum3 + sum3 in
+ dlet sum7 := sum4 + sum4 in
+ [sum5; sum6; sum7]
+```
+
+Isn't that lovely? Since we can push continuation functions "under"
+the `dlet`s, we can end up with a nice, concise, simplified
+expression.
+
+One might suggest that we could just inline the `dlet`s and do common
+subexpression elimination. But some of our terms have so many `dlet`s
+that inlining them all would make a term too huge to process in
+reasonable time, so this is not really an option.
+
+*****)
+
+
Local Ltac prove_id :=
repeat match goal with
| _ => progress intros
@@ -63,66 +306,78 @@ Module B.
Definition multerm (t t' : limb) : limb :=
(fst t * fst t', (snd t * snd t')%RT).
+ Lemma eval_map_multerm (a:limb) (q:list limb)
+ : eval (List.map (multerm a) q) = fst a * snd a * eval q.
+ Proof.
+ induction q; cbv [multerm]; simpl List.map;
+ autorewrite with push_basesystem_eval cancel_pair; nsatz.
+ Qed. Hint Rewrite eval_map_multerm : push_basesystem_eval.
+
Definition mul_cps (p q:list limb) {T} (f : list limb->T) :=
flat_map_cps (fun t => @map_cps _ _ (multerm t) q) p f.
+
Definition mul (p q:list limb) := mul_cps p q id.
+ Lemma mul_cps_id p q: forall {T} f, @mul_cps p q T f = f (mul p q).
+ Proof. cbv [mul_cps mul]; prove_id. Qed.
Hint Opaque mul : uncps.
- Lemma eval_map_mul (a:limb) (q:list limb) : eval (List.map (multerm a) q) = fst a * snd a * eval q.
- Proof.
- induction q; cbv [multerm]; simpl List.map;
- autorewrite with push_basesystem_eval cancel_pair; nsatz.
- Qed. Hint Rewrite eval_map_mul : push_basesystem_eval.
- Lemma mul_cps_id p q: forall {T} f,
- @mul_cps p q T f = f (mul p q).
- Proof. cbv [mul_cps mul]; prove_id. Qed. Hint Rewrite mul_cps_id : uncps.
- Lemma eval_mul_noncps p q:
- eval (mul p q) = eval p * eval q.
- Proof.
- cbv [mul mul_cps]; induction p; prove_eval. Qed. Hint Rewrite eval_mul_noncps : push_basesystem_eval.
+ Hint Rewrite mul_cps_id : uncps.
- Fixpoint split (s:Z) (xs:list limb)
+ Lemma eval_mul p q: eval (mul p q) = eval p * eval q.
+ Proof. cbv [mul mul_cps]; induction p; prove_eval. Qed.
+ Hint Rewrite eval_mul : push_basesystem_eval.
+
+ Fixpoint split_cps (s:Z) (xs:list limb)
{T} (f :list limb*list limb->T) :=
match xs with
| nil => f (nil, nil)
| cons x xs' =>
- split s xs'
+ split_cps s xs'
(fun sxs' =>
if dec (fst x mod s = 0)
then f (fst sxs', cons (fst x / s, snd x) (snd sxs'))
else f (cons x (fst sxs'), snd sxs'))
end.
- Definition split_noncps s xs := split s xs id.
- Hint Opaque split_noncps : uncps.
- Lemma split_id s p: forall {T} f,
- @split s p T f = f (split_noncps s p).
+
+ Definition split s xs := split_cps s xs id.
+ Lemma split_cps_id s p: forall {T} f,
+ @split_cps s p T f = f (split s p).
Proof.
induction p;
repeat match goal with
| _ => rewrite IHp
- | _ => progress (cbv [split_noncps]; prove_id)
+ | _ => progress (cbv [split]; prove_id)
end.
- Qed. Hint Rewrite split_id : uncps.
- Lemma eval_split_noncps s p (s_nonzero:s<>0):
- eval (fst (split_noncps s p)) + s*eval (snd (split_noncps s p)) = eval p.
+ Qed.
+ Hint Opaque split : uncps.
+ Hint Rewrite split_cps_id : uncps.
+
+ Lemma eval_split s p (s_nonzero:s<>0):
+ eval (fst (split s p)) + s*eval (snd (split s p)) = eval p.
Proof.
- cbv [split_noncps]; induction p; prove_eval.
- match goal with H:_ |- _ =>
- unique pose proof (Z_div_exact_full_2 _ _ s_nonzero H)
+ cbv [split]; induction p; prove_eval.
+ match goal with
+ H:_ |- _ =>
+ unique pose proof (Z_div_exact_full_2 _ _ s_nonzero H)
end; nsatz.
- Qed. Hint Rewrite @eval_split_noncps using auto : push_basesystem_eval.
+ Qed. Hint Rewrite @eval_split using auto : push_basesystem_eval.
Definition reduce_cps (s:Z) (c:list limb) (p:list limb)
{T} (f : list limb->T) :=
- split s p (fun ab =>mul_cps c (snd ab) (fun rr =>f (fst ab ++ rr))).
+ split_cps s p
+ (fun ab => mul_cps c (snd ab)
+ (fun rr =>f (fst ab ++ rr))).
+
Definition reduce s c p := reduce_cps s c p id.
+ Lemma reduce_cps_id s c p {T} f:
+ @reduce_cps s c p T f = f (reduce s c p).
+ Proof. cbv [reduce_cps reduce]; prove_id. Qed.
Hint Opaque reduce : uncps.
+ Hint Rewrite reduce_cps_id : uncps.
+
Lemma reduction_rule a b s c (modulus_nonzero:s-c<>0) :
(a + s * b) mod (s - c) = (a + c * b) mod (s - c).
Proof. replace (a + s * b) with ((a + c*b) + b*(s-c)) by nsatz.
rewrite Z.add_mod, Z_mod_mult, Z.add_0_r, Z.mod_mod; trivial. Qed.
- Lemma reduce_cps_id s c p {T} f:
- @reduce_cps s c p T f = f (reduce s c p).
- Proof. cbv [reduce_cps reduce]; prove_id. Qed. Hint Rewrite reduce_cps_id : uncps.
Lemma eval_reduce s c p (s_nonzero:s<>0) (modulus_nonzero:s-eval c<>0):
eval (reduce s c p) mod (s - eval c) = eval p mod (s - eval c).
Proof.
@@ -140,16 +395,16 @@ Module B.
then dlet t2 := snd t in
f ((w*fw, div t2 fw) :: (w, modulo t2 fw) :: @nil limb)
else f [t].
- Definition carry_cps(w fw:Z) (p:list limb) {T} (f:list limb->T) :=
- flat_map_cps (carryterm_cps w fw) p f.
+
Definition carryterm w fw t := carryterm_cps w fw t id.
- Hint Opaque carryterm : uncps.
- Definition carry w fw p := carry_cps w fw p id.
- Hint Opaque carry : uncps.
Lemma carryterm_cps_id w fw t {T} f :
@carryterm_cps w fw t T f
= f (@carryterm w fw t).
- Proof. cbv [carryterm_cps carryterm Let_In]; prove_id. Qed. Hint Rewrite carryterm_cps_id : uncps.
+ Proof. cbv [carryterm_cps carryterm Let_In]; prove_id. Qed.
+ Hint Opaque carryterm : uncps.
+ Hint Rewrite carryterm_cps_id : uncps.
+
+
Lemma eval_carryterm w fw (t:limb) (fw_nonzero:fw<>0):
eval (carryterm w fw t) = eval [t].
Proof.
@@ -157,10 +412,17 @@ Module B.
specialize (div_mod (snd t) fw fw_nonzero).
nsatz.
Qed. Hint Rewrite eval_carryterm using auto : push_basesystem_eval.
+
+ Definition carry_cps(w fw:Z) (p:list limb) {T} (f:list limb->T) :=
+ flat_map_cps (carryterm_cps w fw) p f.
+
+ Definition carry w fw p := carry_cps w fw p id.
Lemma carry_cps_id w fw p {T} f:
@carry_cps w fw p T f = f (carry w fw p).
Proof. cbv [carry_cps carry]; prove_id. Qed.
+ Hint Opaque carry : uncps.
Hint Rewrite carry_cps_id : uncps.
+
Lemma eval_carry w fw p (fw_nonzero:fw<>0):
eval (carry w fw p) = eval p.
Proof. cbv [carry_cps carry]; induction p; prove_eval. Qed.
@@ -179,17 +441,17 @@ Module B.
Definition sat_multerm_cps (t t' : limb) {T} (f:list limb->T) :=
dlet tt' := mul (snd t) (snd t') in
f ((fst t*fst t', runtime_fst tt') :: (fst t*fst t'*word_max, runtime_snd tt') :: nil)%list.
- Definition sat_mul_cps (p q : list limb) {T} (f:list limb->T) :=
- flat_map_cps (fun t => @flat_map_cps _ _ (sat_multerm_cps t) q) p f.
- (* TODO (jgross): kind of an interesting behavior--it infers the type arguments like this but fails to check if I leave them implicit *)
+
Definition sat_multerm t t' := sat_multerm_cps t t' id.
- Definition sat_mul p q := sat_mul_cps p q id.
- Hint Opaque sat_multerm sat_mul : uncps.
Lemma sat_multerm_cps_id t t' : forall {T} (f:list limb->T),
sat_multerm_cps t t' f = f (sat_multerm t t').
- Proof. reflexivity. Qed. Hint Rewrite sat_multerm_cps_id : uncps.
+ Proof. reflexivity. Qed.
+ Hint Opaque sat_multerm : uncps.
+ Hint Rewrite sat_multerm_cps_id : uncps.
+
Lemma eval_map_sat_multerm_cps t q :
- eval (flat_map (fun x => sat_multerm_cps t x id) q) = fst t * snd t * eval q.
+ eval (flat_map (fun x => sat_multerm_cps t x id) q)
+ = fst t * snd t * eval q.
Proof.
cbv [sat_multerm sat_multerm_cps Let_In runtime_fst runtime_snd];
induction q; prove_eval;
@@ -197,8 +459,19 @@ Module B.
specialize (mul_correct a b) end;
nsatz.
Qed. Hint Rewrite eval_map_sat_multerm_cps : push_basesystem_eval.
- Lemma sat_mul_cps_id p q {T} f : @sat_mul_cps p q T f = f (sat_mul p q).
- Proof. cbv [sat_mul_cps sat_mul]; prove_id. Qed. Hint Rewrite sat_mul_cps_id : uncps.
+
+ Definition sat_mul_cps (p q : list limb) {T} (f:list limb->T) :=
+ flat_map_cps (fun t =>
+ @flat_map_cps _ _ (sat_multerm_cps t) q) p f.
+ (* TODO (jgross): kind of an interesting behavior--it infers the type arguments like this but fails to check if I leave them implicit *)
+
+ Definition sat_mul p q := sat_mul_cps p q id.
+ Lemma sat_mul_cps_id p q {T} f :
+ @sat_mul_cps p q T f = f (sat_mul p q).
+ Proof. cbv [sat_mul_cps sat_mul]; prove_id. Qed.
+ Hint Opaque sat_mul : uncps.
+ Hint Rewrite sat_mul_cps_id : uncps.
+
Lemma eval_sat_mul p q : eval (sat_mul p q) = eval p * eval q.
Proof. cbv [sat_mul_cps sat_mul]; induction p; prove_eval. Qed.
Hint Rewrite eval_sat_mul : push_basesystem_eval.
@@ -211,7 +484,7 @@ Module B.
@Associational.carry_cps_id
@Associational.carryterm_cps_id
@Associational.reduce_cps_id
- @Associational.split_id
+ @Associational.split_cps_id
@Associational.mul_cps_id : uncps.
Module Positional.
@@ -222,22 +495,28 @@ Module B.
(weight_nonzero : forall i, weight i <> 0).
(** Converting from positional to associational *)
-
Definition to_associational_cps {n:nat} (xs:tuple Z n)
{T} (f:list limb->T) :=
map_cps weight (seq 0 n)
(fun r =>
to_list_cps n xs (fun rr => combine_cps r rr f)).
- Definition to_associational {n} xs := @to_associational_cps n xs _ id.
- Definition eval {n} x := @to_associational_cps n x _ Associational.eval.
+
+ Definition to_associational {n} xs :=
+ @to_associational_cps n xs _ id.
Lemma to_associational_cps_id {n} x {T} f:
@to_associational_cps n x T f = f (to_associational x).
Proof. cbv [to_associational_cps to_associational]; prove_id. Qed.
+ Hint Opaque to_associational : uncps.
Hint Rewrite @to_associational_cps_id : uncps.
+
+ Definition eval {n} x :=
+ @to_associational_cps n x _ Associational.eval.
+
Lemma eval_to_associational {n} x :
Associational.eval (@to_associational n x) = eval x.
- Proof. cbv [to_associational_cps eval to_associational]; prove_eval. Qed.
- Hint Rewrite @eval_to_associational : push_basesystem_eval.
+ Proof.
+ cbv [to_associational_cps eval to_associational]; prove_eval.
+ Qed. Hint Rewrite @eval_to_associational : push_basesystem_eval.
(** Converting from associational to positional *)
@@ -252,8 +531,8 @@ Module B.
Definition add_to_nth_cps {n} i x t {T} (f:tuple Z n->T) :=
@on_tuple_cps _ _ 0 (update_nth_cps i (runtime_add x)) n n t _ f.
+
Definition add_to_nth {n} i x t := @add_to_nth_cps n i x t _ id.
- Hint Opaque add_to_nth : uncps.
Lemma add_to_nth_cps_id {n} i x xs {T} f:
@add_to_nth_cps n i x xs T f = f (add_to_nth i x xs).
Proof.
@@ -261,7 +540,10 @@ Module B.
by (intros; autorewrite with uncps; reflexivity); prove_id.
Unshelve.
intros; subst. autorewrite with uncps push_id. distr_length.
- Qed. Hint Rewrite @add_to_nth_cps_id : uncps.
+ Qed.
+ Hint Opaque add_to_nth : uncps.
+ Hint Rewrite @add_to_nth_cps_id : uncps.
+
Lemma eval_add_to_nth {n} (i:nat) (x:Z) (H:(i<n)%nat) (xs:tuple Z n):
eval (@add_to_nth n i x xs) = weight i * x + eval xs.
Proof.
@@ -287,52 +569,77 @@ Module B.
if dec (fst t mod weight i = 0)
then f (i, let c := fst t / weight i in (c * snd t)%RT)
else match i with S i' => place_cps t i' f | O => f (O, fst t * snd t)%RT end.
- Lemma place_cps_in_range (t:limb) (n:nat) : (fst (place_cps t n id) < S n)%nat.
+
+ Definition place t i := place_cps t i id.
+ Lemma place_cps_id t i {T} f :
+ @place_cps t i T f = f (place t i).
+ Proof. cbv [place]; induction i; prove_id. Qed.
+ Hint Opaque place : uncps.
+ Hint Rewrite place_cps_id : uncps.
+
+ Lemma place_cps_in_range (t:limb) (n:nat)
+ : (fst (place_cps t n id) < S n)%nat.
Proof. induction n; simpl; break_match; simpl; omega. Qed.
- Lemma weight_place_cps t i : weight (fst (place_cps t i id)) * snd (place_cps t i id) = fst t * snd t.
+ Lemma weight_place_cps t i
+ : weight (fst (place_cps t i id)) * snd (place_cps t i id)
+ = fst t * snd t.
Proof.
induction i; cbv [id]; simpl place_cps; break_match;
autorewrite with cancel_pair;
try find_apply_lem_hyp Z_div_exact_full_2; nsatz || auto.
Qed.
- Definition place t i := place_cps t i id.
- Hint Opaque place : uncps.
- Lemma place_cps_id t i {T} f :
- @place_cps t i T f = f (place t i).
- Proof. cbv [place]; induction i; prove_id. Qed.
- Hint Rewrite place_cps_id : uncps.
- Definition from_associational_cps n (p:list limb) {T} (f:tuple Z n->T):=
- fold_right_cps (fun t st => place_cps t (pred n) (fun p=> add_to_nth_cps (fst p) (snd p) st id)) (zeros n) p f.
+
+ Definition from_associational_cps n (p:list limb)
+ {T} (f:tuple Z n->T):=
+ fold_right_cps
+ (fun t st =>
+ place_cps t (pred n)
+ (fun p=> add_to_nth_cps (fst p) (snd p) st id))
+ (zeros n) p f.
+
Definition from_associational n p := from_associational_cps n p id.
- Hint Opaque from_associational : uncps.
Lemma from_associational_cps_id {n} p {T} f:
@from_associational_cps n p T f = f (from_associational n p).
- Proof. cbv [from_associational_cps from_associational]; prove_id. Qed.
+ Proof.
+ cbv [from_associational_cps from_associational]; prove_id.
+ Qed.
+ Hint Opaque from_associational : uncps.
Hint Rewrite @from_associational_cps_id : uncps.
+
Lemma eval_from_associational {n} p (n_nonzero:n<>O):
eval (from_associational n p) = Associational.eval p.
Proof.
cbv [from_associational_cps from_associational]; induction p;
[|pose proof (place_cps_in_range a (pred n))]; prove_eval.
cbv [place]; rewrite weight_place_cps. nsatz.
- Qed. Hint Rewrite @eval_from_associational using omega : push_basesystem_eval.
+ Qed.
+ Hint Rewrite @eval_from_associational using omega
+ : push_basesystem_eval.
Section Carries.
Context {modulo div : Z->Z->Z}.
Context {div_mod : forall a b:Z, b <> 0 ->
a = b * (div a b) + modulo a b}.
- Definition carry_cps(index:nat) (p:list limb) {T} (f:list limb->T) :=
- @Associational.carry_cps modulo div (weight index) (weight (S index) / weight index) p T f.
+ Definition carry_cps(index:nat) (p:list limb)
+ {T} (f:list limb->T) :=
+ @Associational.carry_cps modulo div
+ (weight index)
+ (weight (S index) / weight index)
+ p T f.
+
Definition carry i p := carry_cps i p id.
- Hint Opaque carry : uncps.
Lemma carry_cps_id i p {T} f:
@carry_cps i p T f = f (carry i p).
- Proof. cbv [carry_cps carry]; prove_id; rewrite carry_cps_id; reflexivity. Qed.
- Hint Rewrite carry_cps_id : uncps.
+ Proof.
+ cbv [carry_cps carry]; prove_id; rewrite carry_cps_id; reflexivity.
+ Qed.
+ Hint Opaque carry : uncps. Hint Rewrite carry_cps_id : uncps.
+
Lemma eval_carry i p: weight (S i) / weight i <> 0 ->
Associational.eval (carry i p) = Associational.eval p.
Proof. cbv [carry_cps carry]; intros; eapply @eval_carry; eauto. Qed.
Hint Rewrite @eval_carry : push_basesystem_eval.
+
End Carries.
End Positional.
End Positional.
@@ -342,7 +649,7 @@ Module B.
@Associational.carry_cps_id
@Associational.carryterm_cps_id
@Associational.reduce_cps_id
- @Associational.split_id
+ @Associational.split_cps_id
@Associational.mul_cps_id
@Positional.carry_cps_id
@Positional.from_associational_cps_id
@@ -352,12 +659,12 @@ Module B.
: uncps.
Hint Rewrite
@Associational.eval_sat_mul
- @Associational.eval_mul_noncps
+ @Associational.eval_mul
@Positional.eval_to_associational
@Associational.eval_carry
@Associational.eval_carryterm
@Associational.eval_reduce
- @Associational.eval_split_noncps
+ @Associational.eval_split
@Positional.eval_carry
@Positional.eval_from_associational
@Positional.eval_add_to_nth
@@ -455,4 +762,4 @@ End Ops.
Eval cbv [projT1 addT lift2_sig proj1_sig] in (projT1 addT).
Eval cbv [projT1 mulT lift2_sig proj1_sig] in
(fun m d div_mod => projT1 (mulT m d div_mod)).
-*) \ No newline at end of file
+*)