diff options
author | jadep <jade.philipoom@gmail.com> | 2017-06-15 20:12:22 -0400 |
---|---|---|
committer | jadep <jade.philipoom@gmail.com> | 2017-06-15 20:12:30 -0400 |
commit | 06d3a5f4cffdf615f209677f6ffccd3e8b23a03b (patch) | |
tree | 06634ee13149ab6a6f54ae3d3d9ab1d81a063498 /src | |
parent | 131f341f368b606fd50b57f135e602e40e132b46 (diff) |
CPSify Saturated API in preparation for CPSifying Montgomery (see #194)
Diffstat (limited to 'src')
-rw-r--r-- | src/Arithmetic/Saturated.v | 54 | ||||
-rw-r--r-- | src/Util/CPSUtil.v | 34 |
2 files changed, 76 insertions, 12 deletions
diff --git a/src/Arithmetic/Saturated.v b/src/Arithmetic/Saturated.v index bf269ce79..3e33a136a 100644 --- a/src/Arithmetic/Saturated.v +++ b/src/Arithmetic/Saturated.v @@ -722,23 +722,55 @@ Section API. Definition zero (n:nat) : T := to_list _ (B.Positional.zeros n). - Definition divmod (p : T) : T * Z := (List.tl p, List.hd 0 p). + Definition divmod_cps (p : T) {R} (f:T * Z->R) + := f (List.tl p, List.hd 0 p). + Definition divmod p : T * Z := divmod_cps p id. - Definition drop_high (n : nat) (p : T) : T := firstn n p. + Definition drop_high_cps (n : nat) (p : T) {R} (f:T->R) + := firstn_cps n p f. + Definition drop_high n p : T := drop_high_cps n p id. - Definition scmul (c : Z) (p : T) : T := + Definition scmul_cps (c : Z) (p : T) {R} (f:T->R) := let P := Tuple.from_list (length p) p (eq_refl _) in Columns.mul_cps (n1:=1) (n3:=length p) (uweight bound) bound c P - (fun carry_result => - to_list _ (left_append (fst carry_result) (snd carry_result))). + (fun carry_result => Tuple.left_append_cps (fst carry_result) (snd carry_result) + (fun result => to_list_cps _ result f)). + Definition scmul c p : T := scmul_cps c p id. - Definition add (p q : T) : T := + Definition add_cps (p q : T) {R} (f:T->R) := let P := Tuple.from_list (length p) p (eq_refl _) in let Q := Tuple.from_list (length q) q (eq_refl _) in dlet n := max (length p) (length q) in Columns.add_cps (uweight bound) P Q - (fun carry_result => - to_list (S n) (left_append (fst carry_result) (snd carry_result))). + (fun carry_result => Tuple.left_append_cps (fst carry_result) (snd carry_result) + (fun result => to_list_cps (S n) result f)). + Definition add p q : T := add_cps p q id. + + Hint Opaque divmod drop_high scmul add : uncps. + + Section CPSProofs. + + Local Ltac prove_id := + repeat autounfold; autorewrite with uncps; reflexivity. + + Lemma divmod_id p R f : + @divmod_cps p R f = f (divmod p). + Proof. cbv [divmod_cps divmod]. prove_id. Qed. + + Lemma drop_high_id n p R f : + @drop_high_cps n p R f = f (drop_high n p). + Proof. cbv [drop_high_cps drop_high]. prove_id. Qed. + + Lemma scmul_id c p R f : + @scmul_cps c p R f = f (scmul c p). + Proof. cbv [scmul_cps scmul]. prove_id. Qed. + + Lemma add_id p q R f : + @add_cps p q R f = f (add p q). + Proof. cbv [add_cps add Let_In]. prove_id. Qed. + + End CPSProofs. + Hint Rewrite divmod_id drop_high_id scmul_id add_id : uncps. Section Proofs. @@ -775,7 +807,7 @@ Section API. pose proof Z.add_get_carry_full_mod. pose proof div_correct. pose proof modulo_correct. repeat match goal with - | _ => progress (cbv [add eval Let_In]; repeat autounfold) + | _ => progress (cbv [add_cps add eval Let_In]; repeat autounfold) | _ => progress autorewrite with uncps push_id cancel_pair push_basesystem_eval | _ => rewrite B.Positional.eval_left_append @@ -806,7 +838,7 @@ Section API. pose proof Z.add_get_carry_full_div. pose proof Z.add_get_carry_full_mod. pose proof div_correct. pose proof modulo_correct. - cbv [small add Let_In]. repeat autounfold. + cbv [small add_cps add Let_In]. repeat autounfold. autorewrite with uncps push_id. destruct (max (length a) (length b)); [simpl; omega |]. rewrite Columns.hd_to_list, hd_left_append. @@ -841,7 +873,7 @@ Section API. Lemma eval_div p : small p -> eval (fst (divmod p)) = eval p / bound. Proof. repeat match goal with - | _ => progress (intros; cbv [divmod eval]; repeat autounfold) + | _ => progress (intros; cbv [divmod_cps divmod eval]; repeat autounfold) | _ => progress autorewrite with uncps push_id cancel_pair push_basesystem_eval end. erewrite from_list_tl. diff --git a/src/Util/CPSUtil.v b/src/Util/CPSUtil.v index 15782fa61..67a909a61 100644 --- a/src/Util/CPSUtil.v +++ b/src/Util/CPSUtil.v @@ -51,6 +51,19 @@ Lemma map_cps_correct {A B} g ls: forall {T} f, Proof. induction ls as [|?? IHls]; simpl; intros; rewrite ?IHls; reflexivity. Qed. Create HintDb uncps discriminated. Hint Rewrite @map_cps_correct : uncps. +Fixpoint firstn_cps {A} (n:nat) (l:list A) {T} (f:list A->T) := + match n with + | O => f nil + | S n' => match l with + | nil => f nil + | a :: l' => f (a :: firstn n' l') + end + end. +Lemma firstn_cps_correct {A} n l T f : + @firstn_cps A n l T f = f (firstn n l). +Proof. induction n; destruct l; reflexivity. Qed. +Hint Rewrite @firstn_cps_correct : uncps. + Fixpoint flat_map_cps {A B} (g:A->forall {T}, (list B->T)->T) (ls : list A) {T} (f:list B->T) := match ls with | nil => f nil @@ -285,7 +298,26 @@ Module Tuple. Proof. destruct n; simpl; rewrite ?mapi_with'_cps_correct by assumption; reflexivity. Qed. Hint Rewrite @mapi_with_cps_correct @mapi_with'_cps_correct using (intros; autorewrite with uncps; auto): uncps. + + Fixpoint left_append_cps {A n} (x:A) : + tuple A n -> forall {R}, (tuple A (S n) -> R) -> R := + match + n as n0 return (tuple A n0 -> forall R, (tuple A (S n0) -> R) -> R) + with + | 0%nat => fun _ _ f => f x + | S n' => + fun xs _ f => + left_append_cps x (tl xs) (fun r => f (append (hd xs) r)) + end. + Lemma left_append_cps_correct A n x xs R f : + @left_append_cps A n x xs R f = f (left_append x xs). + Proof. + induction n; [reflexivity|]. + simpl left_append. simpl left_append_cps. + rewrite IHn. reflexivity. + Qed. + End Tuple. -Hint Rewrite @Tuple.map_cps_correct : uncps. +Hint Rewrite @Tuple.map_cps_correct @Tuple.left_append_cps_correct : uncps. Hint Rewrite @Tuple.mapi_with_cps_correct @Tuple.mapi_with'_cps_correct using (intros; autorewrite with uncps; auto): uncps. |