aboutsummaryrefslogtreecommitdiff
path: root/src/Arithmetic
diff options
context:
space:
mode:
authorGravatar Jason Gross <jgross@mit.edu>2017-10-17 22:37:26 -0400
committerGravatar Jason Gross <jgross@mit.edu>2017-10-18 11:02:14 -0400
commita426187067726ecb0362aabe12c3877166e427a0 (patch)
treeaf46a607b5f63c5d6ecf034bef7d4d20a10007d2 /src/Arithmetic
parentb5c975c9b5a7cf522d9bd94a7843b96d91f64a9b (diff)
Allow instantiating type arguments without reducing matches
Diffstat (limited to 'src/Arithmetic')
-rw-r--r--src/Arithmetic/Core.v42
-rw-r--r--src/Arithmetic/Saturated/AddSub.v72
-rw-r--r--src/Arithmetic/Saturated/Core.v36
3 files changed, 83 insertions, 67 deletions
diff --git a/src/Arithmetic/Core.v b/src/Arithmetic/Core.v
index 542b7d72b..8e7f3cb23 100644
--- a/src/Arithmetic/Core.v
+++ b/src/Arithmetic/Core.v
@@ -348,21 +348,25 @@ Module B.
Proof. cbv [mul mul_cps]; induction p; prove_eval. Qed.
Hint Rewrite eval_mul : push_basesystem_eval.
- Fixpoint split_cps (s:Z) (xs:list limb)
- {T} (f :list limb*list limb->T) :=
- match xs with
- | nil => f (nil, nil)
- | cons x xs' =>
- split_cps s xs'
- (fun sxs' =>
- if dec (fst x mod s = 0)
- then f (fst sxs', cons (fst x / s, snd x) (snd sxs'))
- else f (cons x (fst sxs'), snd sxs'))
- end.
+ Section split_cps.
+ Context (s:Z) {T : Type}.
+
+ Fixpoint split_cps (xs:list limb)
+ (f :list limb*list limb->T) :=
+ match xs with
+ | nil => f (nil, nil)
+ | cons x xs' =>
+ split_cps xs'
+ (fun sxs' =>
+ if dec (fst x mod s = 0)
+ then f (fst sxs', cons (fst x / s, snd x) (snd sxs'))
+ else f (cons x (fst sxs'), snd sxs'))
+ end.
+ End split_cps.
Definition split s xs := split_cps s xs id.
Lemma split_cps_id s p: forall {T} f,
- @split_cps s p T f = f (split s p).
+ @split_cps s T p f = f (split s p).
Proof.
induction p as [|?? IHp];
repeat match goal with
@@ -576,14 +580,18 @@ Module B.
intros; subst. autorewrite with uncps push_id. distr_length.
Qed. Hint Rewrite @eval_add_to_nth using omega : push_basesystem_eval.
- Fixpoint place_cps (t:limb) (i:nat) {T} (f:nat * Z->T) :=
- if dec (fst t mod weight i = 0)
- then f (i, let c := fst t / weight i in (c * snd t)%RT)
- else match i with S i' => place_cps t i' f | O => f (O, fst t * snd t)%RT end.
+ Section place_cps.
+ Context {T : Type}.
+
+ Fixpoint place_cps (t:limb) (i:nat) (f:nat * Z->T) :=
+ if dec (fst t mod weight i = 0)
+ then f (i, let c := fst t / weight i in (c * snd t)%RT)
+ else match i with S i' => place_cps t i' f | O => f (O, fst t * snd t)%RT end.
+ End place_cps.
Definition place t i := place_cps t i id.
Lemma place_cps_id t i {T} f :
- @place_cps t i T f = f (place t i).
+ @place_cps T t i f = f (place t i).
Proof using Type. cbv [place]; induction i; prove_id. Qed.
Hint Opaque place : uncps.
Hint Rewrite place_cps_id : uncps.
diff --git a/src/Arithmetic/Saturated/AddSub.v b/src/Arithmetic/Saturated/AddSub.v
index b8b6ee31a..c1c701870 100644
--- a/src/Arithmetic/Saturated/AddSub.v
+++ b/src/Arithmetic/Saturated/AddSub.v
@@ -20,30 +20,34 @@ Module B.
{op_get_carry : Z -> Z -> Z * Z} (* no carry in, carry out *)
{op_with_carry : Z -> Z -> Z -> Z * Z}. (* carry in, carry out *)
- Fixpoint chain_op'_cps {n}:
- option Z->Z^n->Z^n->forall T, (Z*Z^n->T)->T :=
- match n with
- | O => fun c p _ _ f =>
- let carry := match c with | None => 0 | Some x => x end in
- f (carry,p)
- | S n' =>
- fun c p q _ f =>
- (* for the first call, use op_get_carry, then op_with_carry *)
- let op' := match c with
- | None => op_get_carry
- | Some x => op_with_carry x end in
- dlet carry_result := op' (hd p) (hd q) in
- chain_op'_cps (Some (snd carry_result)) (tl p) (tl q) _
- (fun carry_pq =>
- f (fst carry_pq,
- append (fst carry_result) (snd carry_pq)))
- end.
- Definition chain_op' {n} c p q := @chain_op'_cps n c p q _ id.
- Definition chain_op_cps {n} p q {T} f := @chain_op'_cps n None p q T f.
+ Section chain_op'_cps.
+ Context (T : Type).
+
+ Fixpoint chain_op'_cps {n} (c:option Z) (p q:Z^n)
+ : (Z*Z^n->T)->T :=
+ match n return option Z -> Z^n -> Z^n -> (Z*Z^n -> T) -> T with
+ | O => fun c p _ f =>
+ let carry := match c with | None => 0 | Some x => x end in
+ f (carry,p)
+ | S n' =>
+ fun c p q f =>
+ (* for the first call, use op_get_carry, then op_with_carry *)
+ let op' := match c with
+ | None => op_get_carry
+ | Some x => op_with_carry x end in
+ dlet carry_result := op' (hd p) (hd q) in
+ chain_op'_cps (Some (snd carry_result)) (tl p) (tl q)
+ (fun carry_pq =>
+ f (fst carry_pq,
+ append (fst carry_result) (snd carry_pq)))
+ end c p q.
+ End chain_op'_cps.
+ Definition chain_op' {n} c p q := @chain_op'_cps _ n c p q id.
+ Definition chain_op_cps {n} p q {T} f := @chain_op'_cps T n None p q f.
Definition chain_op {n} p q : Z * Z^n := chain_op_cps p q id.
Lemma chain_op'_id {n} : forall c p q T f,
- @chain_op'_cps n c p q T f = f (chain_op' c p q).
+ @chain_op'_cps T n c p q f = f (chain_op' c p q).
Proof.
cbv [chain_op']; induction n; intros; destruct c;
simpl chain_op'_cps; cbv [Let_In]; try reflexivity.
@@ -53,7 +57,7 @@ Module B.
Lemma chain_op_id {n} p q T f :
@chain_op_cps n p q T f = f (chain_op p q).
- Proof. apply chain_op'_id. Qed.
+ Proof. apply (@chain_op'_id n None). Qed.
End GenericOp.
Hint Opaque chain_op chain_op' : uncps.
Hint Rewrite @chain_op_id @chain_op'_id : uncps.
@@ -117,9 +121,9 @@ Module B.
| _ => progress autorewrite with uncps divmod push_id cancel_pair push_basesystem_eval
| _ => rewrite uweight_0, ?Z.mod_1_r, ?Z.div_1_r
| _ => rewrite uweight_succ
- | _ => rewrite Z.sub_opp_r
- | _ => rewrite sat_add_mod_step
- | _ => rewrite sat_add_div_step
+ | _ => rewrite Z.sub_opp_r
+ | _ => rewrite sat_add_mod_step
+ | _ => rewrite sat_add_div_step
| p : Z ^ 0 |- _ => destruct p
| _ => rewrite uweight_eval_step, ?hd_append, ?tl_append
| |- context[B.Positional.eval _ (snd (chain_op' ?c ?p ?q))]
@@ -149,8 +153,8 @@ Module B.
| _ => progress (simpl chain_op'_cps in * )
| _ => progress autorewrite with uncps push_id cancel_pair in H
| H : _ |- _ => rewrite to_list_append in H;
- simpl In in H
- | H : _ \/ _ |- _ => destruct H
+ simpl In in H
+ | H : _ \/ _ |- _ => destruct H
| _ => contradiction
end.
{ subst x.
@@ -185,9 +189,9 @@ Module B.
| _ => progress autorewrite with uncps divmod push_id cancel_pair push_basesystem_eval
| _ => rewrite uweight_0, ?Z.mod_1_r, ?Z.div_1_r
| _ => rewrite uweight_succ
- | _ => rewrite Z.sub_opp_r
- | _ => rewrite sat_add_mod_step
- | _ => rewrite sat_add_div_step
+ | _ => rewrite Z.sub_opp_r
+ | _ => rewrite sat_add_mod_step
+ | _ => rewrite sat_add_div_step
| p : Z ^ 0 |- _ => destruct p
| _ => rewrite uweight_eval_step, ?hd_append, ?tl_append
| |- context[B.Positional.eval _ (snd (chain_op' ?c ?p ?q))]
@@ -197,7 +201,7 @@ Module B.
| _ => solve [split; repeat (f_equal; ring_simplify; try omega)]
end.
Qed.
-
+
Lemma sat_sub_mod n p q :
eval (snd (@sat_sub n p q)) = (eval p - eval q) mod (uweight s n).
Proof. exact (proj1 (sat_sub_divmod n p q)). Qed.
@@ -217,8 +221,8 @@ Module B.
| _ => progress (simpl chain_op'_cps in * )
| _ => progress autorewrite with uncps push_id cancel_pair in H
| H : _ |- _ => rewrite to_list_append in H;
- simpl In in H
- | H : _ \/ _ |- _ => destruct H
+ simpl In in H
+ | H : _ \/ _ |- _ => destruct H
| _ => contradiction
end.
{ subst x.
@@ -233,4 +237,4 @@ Module B.
End B.
Hint Opaque B.Positional.sat_sub B.Positional.sat_add B.Positional.chain_op B.Positional.chain_op' : uncps.
Hint Rewrite @B.Positional.sat_sub_id @B.Positional.sat_add_id @B.Positional.chain_op_id @B.Positional.chain_op' : uncps.
-Hint Rewrite @B.Positional.sat_sub_mod @B.Positional.sat_sub_div @B.Positional.sat_add_mod @B.Positional.sat_add_div using (omega || assumption) : push_basesystem_eval. \ No newline at end of file
+Hint Rewrite @B.Positional.sat_sub_mod @B.Positional.sat_sub_div @B.Positional.sat_add_mod @B.Positional.sat_add_div using (omega || assumption) : push_basesystem_eval.
diff --git a/src/Arithmetic/Saturated/Core.v b/src/Arithmetic/Saturated/Core.v
index face608ab..ce9888b74 100644
--- a/src/Arithmetic/Saturated/Core.v
+++ b/src/Arithmetic/Saturated/Core.v
@@ -156,25 +156,29 @@ Module Columns.
(* Sums a list of integers using carry bits.
Output : carry, sum
*)
- Fixpoint compact_digit_cps n (digit : list Z) {T} (f:Z * Z->T) :=
- match digit with
- | nil => f (0, 0)
- | x :: nil => f (div x (weight (S n) / weight n), modulo x (weight (S n) / weight n))
- | x :: y :: nil =>
- dlet sum_carry := add_get_carry (weight (S n) / weight n) x y in
- dlet carry := snd sum_carry in
- f (carry, fst sum_carry)
- | x :: tl =>
- compact_digit_cps n tl
- (fun rec =>
- dlet sum_carry := add_get_carry (weight (S n) / weight n) x (snd rec) in
- dlet carry' := (fst rec + snd sum_carry)%RT in
- f (carry', fst sum_carry))
- end.
+ Section compact_digit_cps.
+ Context (n : nat) {T : Type}.
+
+ Fixpoint compact_digit_cps (digit : list Z) (f:Z * Z->T) :=
+ match digit with
+ | nil => f (0, 0)
+ | x :: nil => f (div x (weight (S n) / weight n), modulo x (weight (S n) / weight n))
+ | x :: y :: nil =>
+ dlet sum_carry := add_get_carry (weight (S n) / weight n) x y in
+ dlet carry := snd sum_carry in
+ f (carry, fst sum_carry)
+ | x :: tl =>
+ compact_digit_cps tl
+ (fun rec =>
+ dlet sum_carry := add_get_carry (weight (S n) / weight n) x (snd rec) in
+ dlet carry' := (fst rec + snd sum_carry)%RT in
+ f (carry', fst sum_carry))
+ end.
+ End compact_digit_cps.
Definition compact_digit n digit := compact_digit_cps n digit id.
Lemma compact_digit_id n digit: forall {T} f,
- @compact_digit_cps n digit T f = f (compact_digit n digit).
+ @compact_digit_cps n T digit f = f (compact_digit n digit).
Proof using Type.
induction digit; intros; cbv [compact_digit]; [reflexivity|];
simpl compact_digit_cps; break_match; rewrite ?IHdigit;