aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--src/Arithmetic/Saturated.v54
-rw-r--r--src/Util/CPSUtil.v34
2 files changed, 76 insertions, 12 deletions
diff --git a/src/Arithmetic/Saturated.v b/src/Arithmetic/Saturated.v
index bf269ce79..3e33a136a 100644
--- a/src/Arithmetic/Saturated.v
+++ b/src/Arithmetic/Saturated.v
@@ -722,23 +722,55 @@ Section API.
Definition zero (n:nat) : T := to_list _ (B.Positional.zeros n).
- Definition divmod (p : T) : T * Z := (List.tl p, List.hd 0 p).
+ Definition divmod_cps (p : T) {R} (f:T * Z->R)
+ := f (List.tl p, List.hd 0 p).
+ Definition divmod p : T * Z := divmod_cps p id.
- Definition drop_high (n : nat) (p : T) : T := firstn n p.
+ Definition drop_high_cps (n : nat) (p : T) {R} (f:T->R)
+ := firstn_cps n p f.
+ Definition drop_high n p : T := drop_high_cps n p id.
- Definition scmul (c : Z) (p : T) : T :=
+ Definition scmul_cps (c : Z) (p : T) {R} (f:T->R) :=
let P := Tuple.from_list (length p) p (eq_refl _) in
Columns.mul_cps (n1:=1) (n3:=length p) (uweight bound) bound c P
- (fun carry_result =>
- to_list _ (left_append (fst carry_result) (snd carry_result))).
+ (fun carry_result => Tuple.left_append_cps (fst carry_result) (snd carry_result)
+ (fun result => to_list_cps _ result f)).
+ Definition scmul c p : T := scmul_cps c p id.
- Definition add (p q : T) : T :=
+ Definition add_cps (p q : T) {R} (f:T->R) :=
let P := Tuple.from_list (length p) p (eq_refl _) in
let Q := Tuple.from_list (length q) q (eq_refl _) in
dlet n := max (length p) (length q) in
Columns.add_cps (uweight bound) P Q
- (fun carry_result =>
- to_list (S n) (left_append (fst carry_result) (snd carry_result))).
+ (fun carry_result => Tuple.left_append_cps (fst carry_result) (snd carry_result)
+ (fun result => to_list_cps (S n) result f)).
+ Definition add p q : T := add_cps p q id.
+
+ Hint Opaque divmod drop_high scmul add : uncps.
+
+ Section CPSProofs.
+
+ Local Ltac prove_id :=
+ repeat autounfold; autorewrite with uncps; reflexivity.
+
+ Lemma divmod_id p R f :
+ @divmod_cps p R f = f (divmod p).
+ Proof. cbv [divmod_cps divmod]. prove_id. Qed.
+
+ Lemma drop_high_id n p R f :
+ @drop_high_cps n p R f = f (drop_high n p).
+ Proof. cbv [drop_high_cps drop_high]. prove_id. Qed.
+
+ Lemma scmul_id c p R f :
+ @scmul_cps c p R f = f (scmul c p).
+ Proof. cbv [scmul_cps scmul]. prove_id. Qed.
+
+ Lemma add_id p q R f :
+ @add_cps p q R f = f (add p q).
+ Proof. cbv [add_cps add Let_In]. prove_id. Qed.
+
+ End CPSProofs.
+ Hint Rewrite divmod_id drop_high_id scmul_id add_id : uncps.
Section Proofs.
@@ -775,7 +807,7 @@ Section API.
pose proof Z.add_get_carry_full_mod.
pose proof div_correct. pose proof modulo_correct.
repeat match goal with
- | _ => progress (cbv [add eval Let_In]; repeat autounfold)
+ | _ => progress (cbv [add_cps add eval Let_In]; repeat autounfold)
| _ => progress autorewrite with uncps push_id cancel_pair push_basesystem_eval
| _ => rewrite B.Positional.eval_left_append
@@ -806,7 +838,7 @@ Section API.
pose proof Z.add_get_carry_full_div.
pose proof Z.add_get_carry_full_mod.
pose proof div_correct. pose proof modulo_correct.
- cbv [small add Let_In]. repeat autounfold.
+ cbv [small add_cps add Let_In]. repeat autounfold.
autorewrite with uncps push_id.
destruct (max (length a) (length b)); [simpl; omega |].
rewrite Columns.hd_to_list, hd_left_append.
@@ -841,7 +873,7 @@ Section API.
Lemma eval_div p : small p -> eval (fst (divmod p)) = eval p / bound.
Proof.
repeat match goal with
- | _ => progress (intros; cbv [divmod eval]; repeat autounfold)
+ | _ => progress (intros; cbv [divmod_cps divmod eval]; repeat autounfold)
| _ => progress autorewrite with uncps push_id cancel_pair push_basesystem_eval
end.
erewrite from_list_tl.
diff --git a/src/Util/CPSUtil.v b/src/Util/CPSUtil.v
index 15782fa61..67a909a61 100644
--- a/src/Util/CPSUtil.v
+++ b/src/Util/CPSUtil.v
@@ -51,6 +51,19 @@ Lemma map_cps_correct {A B} g ls: forall {T} f,
Proof. induction ls as [|?? IHls]; simpl; intros; rewrite ?IHls; reflexivity. Qed.
Create HintDb uncps discriminated. Hint Rewrite @map_cps_correct : uncps.
+Fixpoint firstn_cps {A} (n:nat) (l:list A) {T} (f:list A->T) :=
+ match n with
+ | O => f nil
+ | S n' => match l with
+ | nil => f nil
+ | a :: l' => f (a :: firstn n' l')
+ end
+ end.
+Lemma firstn_cps_correct {A} n l T f :
+ @firstn_cps A n l T f = f (firstn n l).
+Proof. induction n; destruct l; reflexivity. Qed.
+Hint Rewrite @firstn_cps_correct : uncps.
+
Fixpoint flat_map_cps {A B} (g:A->forall {T}, (list B->T)->T) (ls : list A) {T} (f:list B->T) :=
match ls with
| nil => f nil
@@ -285,7 +298,26 @@ Module Tuple.
Proof. destruct n; simpl; rewrite ?mapi_with'_cps_correct by assumption; reflexivity. Qed.
Hint Rewrite @mapi_with_cps_correct @mapi_with'_cps_correct
using (intros; autorewrite with uncps; auto): uncps.
+
+ Fixpoint left_append_cps {A n} (x:A) :
+ tuple A n -> forall {R}, (tuple A (S n) -> R) -> R :=
+ match
+ n as n0 return (tuple A n0 -> forall R, (tuple A (S n0) -> R) -> R)
+ with
+ | 0%nat => fun _ _ f => f x
+ | S n' =>
+ fun xs _ f =>
+ left_append_cps x (tl xs) (fun r => f (append (hd xs) r))
+ end.
+ Lemma left_append_cps_correct A n x xs R f :
+ @left_append_cps A n x xs R f = f (left_append x xs).
+ Proof.
+ induction n; [reflexivity|].
+ simpl left_append. simpl left_append_cps.
+ rewrite IHn. reflexivity.
+ Qed.
+
End Tuple.
-Hint Rewrite @Tuple.map_cps_correct : uncps.
+Hint Rewrite @Tuple.map_cps_correct @Tuple.left_append_cps_correct : uncps.
Hint Rewrite @Tuple.mapi_with_cps_correct @Tuple.mapi_with'_cps_correct
using (intros; autorewrite with uncps; auto): uncps.