aboutsummaryrefslogtreecommitdiff
path: root/src/NewBaseSystem.v
diff options
context:
space:
mode:
authorGravatar jadep <jade.philipoom@gmail.com>2017-02-27 11:55:10 -0500
committerGravatar jadep <jade.philipoom@gmail.com>2017-02-27 11:55:34 -0500
commitb94f708b8ee634da2c6920f9417a7e72ca47814c (patch)
tree19f2828e16992a663729c1b1d11007444188b0cd /src/NewBaseSystem.v
parent1fb1958d505974a3864322d3f10f5dfa042f363a (diff)
added Positional wrappers for Associational operations, added correctness proof of
Diffstat (limited to 'src/NewBaseSystem.v')
-rw-r--r--src/NewBaseSystem.v88
1 files changed, 63 insertions, 25 deletions
diff --git a/src/NewBaseSystem.v b/src/NewBaseSystem.v
index 40f7e3821..7f73c7cdf 100644
--- a/src/NewBaseSystem.v
+++ b/src/NewBaseSystem.v
@@ -418,7 +418,7 @@ Module B.
nsatz.
Qed. Hint Rewrite eval_carryterm using auto : push_basesystem_eval.
- Definition carry_cps(w fw:Z) (p:list limb) {T} (f:list limb->T) :=
+ Definition carry_cps (w fw:Z) (p:list limb) {T} (f:list limb->T) :=
flat_map_cps (carryterm_cps w fw) p f.
Definition carry w fw p := carry_cps w fw p id.
@@ -595,13 +595,56 @@ Module B.
Proof. cbv [carry_cps carry]; intros; eapply @eval_carry; eauto. Qed.
Hint Rewrite @eval_carry : push_basesystem_eval.
- (* TODO make a correctness proof for this *)
- Definition chained_carries (p:list limb) (idxs : list nat)
- {T} (f:list limb->T) :=
- fold_right_cps2 carry_cps p idxs f.
+ Definition chained_carries_cps {n} (p:tuple Z n) (idxs : list nat)
+ {T} (f:tuple Z n->T) :=
+ to_associational_cps p
+ (fun P => fold_right_cps2 carry_cps P idxs
+ (fun R => from_associational_cps n R f)).
+
+ Definition chained_carries {n} p idxs := @chained_carries_cps n p idxs _ id.
+ Lemma chained_carries_id {n} p idxs : forall {T} f,
+ @chained_carries_cps n p idxs T f = f (chained_carries p idxs).
+ Proof. cbv [chained_carries_cps chained_carries]; prove_id. Qed.
+ Hint Opaque chained_carries : uncps.
+ Hint Rewrite @chained_carries_id : uncps.
+
+ Lemma eval_chained_carries {n} (p:tuple Z n) idxs :
+ (forall i, In i idxs -> weight (S i) / weight i <> 0) ->
+ eval (chained_carries p idxs) = eval p.
+ Proof.
+ cbv [chained_carries chained_carries_cps]; intros;
+ autorewrite with uncps push_id.
+ apply fold_right_invariant; destruct n; prove_eval; auto.
+ Qed. Hint Rewrite @eval_chained_carries : push_basesystem_eval.
End Carries.
+
+ Section Wrappers.
+ (* Simple wrappers for Associational definitions; convert to
+ associational, do the operation, convert back. *)
+
+ Definition add_cps {n} (p q : tuple Z n) {T} (f:tuple Z n->T) :=
+ to_associational_cps p
+ (fun P => to_associational_cps q
+ (fun Q => from_associational_cps n (P++Q) f)).
+
+ Definition mul_cps {n m} (p q : tuple Z n) {T} (f:tuple Z m->T) :=
+ to_associational_cps p
+ (fun P => to_associational_cps q
+ (fun Q => Associational.mul_cps P Q
+ (fun PQ => from_associational_cps m PQ f))).
+
+ Definition reduce_cps {m n} (s:Z) (c:list B.limb) (p : tuple Z m)
+ {T} (f:tuple Z n->T) :=
+ to_associational_cps p
+ (fun P => Associational.reduce_cps s c P
+ (fun R => from_associational_cps n R f)).
+ End Wrappers.
End Positional.
End Positional.
+ Hint Unfold
+ Positional.add_cps
+ Positional.mul_cps
+ Positional.reduce_cps.
Hint Rewrite
@Associational.carry_cps_id
@Associational.carryterm_cps_id
@@ -613,6 +656,7 @@ Module B.
@Positional.place_cps_id
@Positional.add_to_nth_cps_id
@Positional.to_associational_cps_id
+ @Positional.chained_carries_id
: uncps.
Hint Rewrite
@Associational.eval_mul
@@ -624,6 +668,7 @@ Module B.
@Positional.eval_carry
@Positional.eval_from_associational
@Positional.eval_add_to_nth
+ @Positional.eval_chained_carries
using (omega || assumption) : push_basesystem_eval.
End B.
@@ -635,7 +680,7 @@ Ltac basesystem_partial_evaluation_RHS :=
let t0 := match goal with |- _ _ ?t => t end in
let t := (eval cbv delta [
(* this list must contain all definitions referenced by t that reference [Let_In], [runtime_add], or [runtime_mul] *)
-Positional.to_associational_cps Positional.to_associational Positional.eval Positional.zeros Positional.add_to_nth_cps Positional.add_to_nth Positional.place_cps Positional.place Positional.from_associational_cps Positional.from_associational Positional.carry_cps Positional.carry Positional.chained_carries
+Positional.to_associational_cps Positional.to_associational Positional.eval Positional.zeros Positional.add_to_nth_cps Positional.add_to_nth Positional.place_cps Positional.place Positional.from_associational_cps Positional.from_associational Positional.carry_cps Positional.carry Positional.chained_carries_cps Positional.chained_carries
Associational.eval Associational.multerm Associational.mul_cps Associational.mul Associational.split_cps Associational.split Associational.reduce_cps Associational.reduce Associational.carryterm_cps Associational.carryterm Associational.carry_cps Associational.carry
] in t0) in
let t := (eval pattern @runtime_mul in t) in
@@ -662,6 +707,10 @@ Ltac assert_preconditions :=
unique assert (n <> 0%nat) by (cbv; congruence)
| |- context [Positional.carry_cps?wt ?i] =>
unique assert (wt (S i) / wt i <> 0) by (cbv; congruence)
+ | |- context [Positional.chained_carries_cps ?wt _ ?idxs] =>
+ unique assert (forall i, In i idxs -> wt (S i) / wt i <> 0)
+ by (clear; simpl; intuition; subst_let; subst; cbv in *;
+ congruence)
| |- context [Associational.reduce_cps ?s _] =>
unique assert (s <> 0) by (cbv; congruence)
| |- context [Associational.reduce_cps ?s ?c] =>
@@ -716,13 +765,10 @@ Section Ops.
eval (add a b) = eval a + eval b }.
Proof.
let x := constr:(fun wt a b =>
- Positional.to_associational_cps (n := sz) wt a
- (fun r => Positional.to_associational_cps (n := sz) wt b
- (fun r0 => Positional.from_associational_cps wt sz (r ++ r0) id
- ))) in
+ Positional.add_cps (n := sz) wt a b id) in
lift2_sig; eexists;
- transitivity (Positional.eval wt (x wt a b));
- [|assert_preconditions; autorewrite with uncps push_id push_basesystem_eval; reflexivity].
+ transitivity (Positional.eval wt (x wt a b)); autounfold;
+ [|assert_preconditions; autorewrite with uncps push_id push_basesystem_eval; reflexivity].
apply f_equal.
@@ -738,20 +784,12 @@ Section Ops.
mod_eq m (eval (mul a b)) (eval a * eval b)}.
Proof.
let x := constr:(fun w a b =>
- Positional.to_associational_cps (n := sz) w a
- (fun r => Positional.to_associational_cps (n := sz) w b
- (fun r0 => Associational.mul_cps r r0
- (fun r1 => Positional.from_associational_cps w sz2 r1
- (fun r2 => Positional.to_associational_cps w r2
- (fun r3 => Associational.reduce_cps s c r3
- (fun r4 => Positional.from_associational_cps w sz r4
- (fun r5 => Positional.to_associational_cps w r5
- (fun r6 => Positional.chained_carries(div:=div)(modulo:=modulo) w r6 (seq 0 sz)
- (fun r13 => Positional.from_associational_cps w sz r13 id
- )))))))))) in
+ Positional.mul_cps (n:=sz) (m:=sz2) w a b
+ (fun ab => Positional.reduce_cps (n:=sz) (m:=sz2) w s c ab
+ (fun r => Positional.chained_carries_cps (n:=sz) (div:=div)(modulo:=modulo) w r (seq 0 sz) id))) in
lift2_sig; eexists;
- transitivity (Positional.eval wt (x wt a b));
- [|cbv [Positional.chained_carries fold_right_cps2 seq fold_right sz2 sz]; assert_preconditions; autorewrite with uncps push_id push_basesystem_eval; reflexivity].
+ transitivity (Positional.eval wt (x wt a b)); autounfold;
+ [|assert_preconditions; autorewrite with uncps push_id push_basesystem_eval; reflexivity].
cbv [mod_eq].
apply f_equal2; [|reflexivity].