From b94f708b8ee634da2c6920f9417a7e72ca47814c Mon Sep 17 00:00:00 2001 From: jadep Date: Mon, 27 Feb 2017 11:55:10 -0500 Subject: added Positional wrappers for Associational operations, added correctness proof of --- src/NewBaseSystem.v | 88 ++++++++++++++++++++++++++++++++++++++--------------- 1 file changed, 63 insertions(+), 25 deletions(-) (limited to 'src/NewBaseSystem.v') diff --git a/src/NewBaseSystem.v b/src/NewBaseSystem.v index 40f7e3821..7f73c7cdf 100644 --- a/src/NewBaseSystem.v +++ b/src/NewBaseSystem.v @@ -418,7 +418,7 @@ Module B. nsatz. Qed. Hint Rewrite eval_carryterm using auto : push_basesystem_eval. - Definition carry_cps(w fw:Z) (p:list limb) {T} (f:list limb->T) := + Definition carry_cps (w fw:Z) (p:list limb) {T} (f:list limb->T) := flat_map_cps (carryterm_cps w fw) p f. Definition carry w fw p := carry_cps w fw p id. @@ -595,13 +595,56 @@ Module B. Proof. cbv [carry_cps carry]; intros; eapply @eval_carry; eauto. Qed. Hint Rewrite @eval_carry : push_basesystem_eval. - (* TODO make a correctness proof for this *) - Definition chained_carries (p:list limb) (idxs : list nat) - {T} (f:list limb->T) := - fold_right_cps2 carry_cps p idxs f. + Definition chained_carries_cps {n} (p:tuple Z n) (idxs : list nat) + {T} (f:tuple Z n->T) := + to_associational_cps p + (fun P => fold_right_cps2 carry_cps P idxs + (fun R => from_associational_cps n R f)). + + Definition chained_carries {n} p idxs := @chained_carries_cps n p idxs _ id. + Lemma chained_carries_id {n} p idxs : forall {T} f, + @chained_carries_cps n p idxs T f = f (chained_carries p idxs). + Proof. cbv [chained_carries_cps chained_carries]; prove_id. Qed. + Hint Opaque chained_carries : uncps. + Hint Rewrite @chained_carries_id : uncps. + + Lemma eval_chained_carries {n} (p:tuple Z n) idxs : + (forall i, In i idxs -> weight (S i) / weight i <> 0) -> + eval (chained_carries p idxs) = eval p. + Proof. + cbv [chained_carries chained_carries_cps]; intros; + autorewrite with uncps push_id. + apply fold_right_invariant; destruct n; prove_eval; auto. + Qed. Hint Rewrite @eval_chained_carries : push_basesystem_eval. End Carries. + + Section Wrappers. + (* Simple wrappers for Associational definitions; convert to + associational, do the operation, convert back. *) + + Definition add_cps {n} (p q : tuple Z n) {T} (f:tuple Z n->T) := + to_associational_cps p + (fun P => to_associational_cps q + (fun Q => from_associational_cps n (P++Q) f)). + + Definition mul_cps {n m} (p q : tuple Z n) {T} (f:tuple Z m->T) := + to_associational_cps p + (fun P => to_associational_cps q + (fun Q => Associational.mul_cps P Q + (fun PQ => from_associational_cps m PQ f))). + + Definition reduce_cps {m n} (s:Z) (c:list B.limb) (p : tuple Z m) + {T} (f:tuple Z n->T) := + to_associational_cps p + (fun P => Associational.reduce_cps s c P + (fun R => from_associational_cps n R f)). + End Wrappers. End Positional. End Positional. + Hint Unfold + Positional.add_cps + Positional.mul_cps + Positional.reduce_cps. Hint Rewrite @Associational.carry_cps_id @Associational.carryterm_cps_id @@ -613,6 +656,7 @@ Module B. @Positional.place_cps_id @Positional.add_to_nth_cps_id @Positional.to_associational_cps_id + @Positional.chained_carries_id : uncps. Hint Rewrite @Associational.eval_mul @@ -624,6 +668,7 @@ Module B. @Positional.eval_carry @Positional.eval_from_associational @Positional.eval_add_to_nth + @Positional.eval_chained_carries using (omega || assumption) : push_basesystem_eval. End B. @@ -635,7 +680,7 @@ Ltac basesystem_partial_evaluation_RHS := let t0 := match goal with |- _ _ ?t => t end in let t := (eval cbv delta [ (* this list must contain all definitions referenced by t that reference [Let_In], [runtime_add], or [runtime_mul] *) -Positional.to_associational_cps Positional.to_associational Positional.eval Positional.zeros Positional.add_to_nth_cps Positional.add_to_nth Positional.place_cps Positional.place Positional.from_associational_cps Positional.from_associational Positional.carry_cps Positional.carry Positional.chained_carries +Positional.to_associational_cps Positional.to_associational Positional.eval Positional.zeros Positional.add_to_nth_cps Positional.add_to_nth Positional.place_cps Positional.place Positional.from_associational_cps Positional.from_associational Positional.carry_cps Positional.carry Positional.chained_carries_cps Positional.chained_carries Associational.eval Associational.multerm Associational.mul_cps Associational.mul Associational.split_cps Associational.split Associational.reduce_cps Associational.reduce Associational.carryterm_cps Associational.carryterm Associational.carry_cps Associational.carry ] in t0) in let t := (eval pattern @runtime_mul in t) in @@ -662,6 +707,10 @@ Ltac assert_preconditions := unique assert (n <> 0%nat) by (cbv; congruence) | |- context [Positional.carry_cps?wt ?i] => unique assert (wt (S i) / wt i <> 0) by (cbv; congruence) + | |- context [Positional.chained_carries_cps ?wt _ ?idxs] => + unique assert (forall i, In i idxs -> wt (S i) / wt i <> 0) + by (clear; simpl; intuition; subst_let; subst; cbv in *; + congruence) | |- context [Associational.reduce_cps ?s _] => unique assert (s <> 0) by (cbv; congruence) | |- context [Associational.reduce_cps ?s ?c] => @@ -716,13 +765,10 @@ Section Ops. eval (add a b) = eval a + eval b }. Proof. let x := constr:(fun wt a b => - Positional.to_associational_cps (n := sz) wt a - (fun r => Positional.to_associational_cps (n := sz) wt b - (fun r0 => Positional.from_associational_cps wt sz (r ++ r0) id - ))) in + Positional.add_cps (n := sz) wt a b id) in lift2_sig; eexists; - transitivity (Positional.eval wt (x wt a b)); - [|assert_preconditions; autorewrite with uncps push_id push_basesystem_eval; reflexivity]. + transitivity (Positional.eval wt (x wt a b)); autounfold; + [|assert_preconditions; autorewrite with uncps push_id push_basesystem_eval; reflexivity]. apply f_equal. @@ -738,20 +784,12 @@ Section Ops. mod_eq m (eval (mul a b)) (eval a * eval b)}. Proof. let x := constr:(fun w a b => - Positional.to_associational_cps (n := sz) w a - (fun r => Positional.to_associational_cps (n := sz) w b - (fun r0 => Associational.mul_cps r r0 - (fun r1 => Positional.from_associational_cps w sz2 r1 - (fun r2 => Positional.to_associational_cps w r2 - (fun r3 => Associational.reduce_cps s c r3 - (fun r4 => Positional.from_associational_cps w sz r4 - (fun r5 => Positional.to_associational_cps w r5 - (fun r6 => Positional.chained_carries(div:=div)(modulo:=modulo) w r6 (seq 0 sz) - (fun r13 => Positional.from_associational_cps w sz r13 id - )))))))))) in + Positional.mul_cps (n:=sz) (m:=sz2) w a b + (fun ab => Positional.reduce_cps (n:=sz) (m:=sz2) w s c ab + (fun r => Positional.chained_carries_cps (n:=sz) (div:=div)(modulo:=modulo) w r (seq 0 sz) id))) in lift2_sig; eexists; - transitivity (Positional.eval wt (x wt a b)); - [|cbv [Positional.chained_carries fold_right_cps2 seq fold_right sz2 sz]; assert_preconditions; autorewrite with uncps push_id push_basesystem_eval; reflexivity]. + transitivity (Positional.eval wt (x wt a b)); autounfold; + [|assert_preconditions; autorewrite with uncps push_id push_basesystem_eval; reflexivity]. cbv [mod_eq]. apply f_equal2; [|reflexivity]. -- cgit v1.2.3