diff options
author | Jason Gross <jgross@mit.edu> | 2017-10-17 22:37:26 -0400 |
---|---|---|
committer | Jason Gross <jgross@mit.edu> | 2017-10-18 11:02:14 -0400 |
commit | a426187067726ecb0362aabe12c3877166e427a0 (patch) | |
tree | af46a607b5f63c5d6ecf034bef7d4d20a10007d2 | |
parent | b5c975c9b5a7cf522d9bd94a7843b96d91f64a9b (diff) |
Allow instantiating type arguments without reducing matches
-rw-r--r-- | src/Arithmetic/Core.v | 42 | ||||
-rw-r--r-- | src/Arithmetic/Saturated/AddSub.v | 72 | ||||
-rw-r--r-- | src/Arithmetic/Saturated/Core.v | 36 | ||||
-rw-r--r-- | src/Util/CPSUtil.v | 12 |
4 files changed, 89 insertions, 73 deletions
diff --git a/src/Arithmetic/Core.v b/src/Arithmetic/Core.v index 542b7d72b..8e7f3cb23 100644 --- a/src/Arithmetic/Core.v +++ b/src/Arithmetic/Core.v @@ -348,21 +348,25 @@ Module B. Proof. cbv [mul mul_cps]; induction p; prove_eval. Qed. Hint Rewrite eval_mul : push_basesystem_eval. - Fixpoint split_cps (s:Z) (xs:list limb) - {T} (f :list limb*list limb->T) := - match xs with - | nil => f (nil, nil) - | cons x xs' => - split_cps s xs' - (fun sxs' => - if dec (fst x mod s = 0) - then f (fst sxs', cons (fst x / s, snd x) (snd sxs')) - else f (cons x (fst sxs'), snd sxs')) - end. + Section split_cps. + Context (s:Z) {T : Type}. + + Fixpoint split_cps (xs:list limb) + (f :list limb*list limb->T) := + match xs with + | nil => f (nil, nil) + | cons x xs' => + split_cps xs' + (fun sxs' => + if dec (fst x mod s = 0) + then f (fst sxs', cons (fst x / s, snd x) (snd sxs')) + else f (cons x (fst sxs'), snd sxs')) + end. + End split_cps. Definition split s xs := split_cps s xs id. Lemma split_cps_id s p: forall {T} f, - @split_cps s p T f = f (split s p). + @split_cps s T p f = f (split s p). Proof. induction p as [|?? IHp]; repeat match goal with @@ -576,14 +580,18 @@ Module B. intros; subst. autorewrite with uncps push_id. distr_length. Qed. Hint Rewrite @eval_add_to_nth using omega : push_basesystem_eval. - Fixpoint place_cps (t:limb) (i:nat) {T} (f:nat * Z->T) := - if dec (fst t mod weight i = 0) - then f (i, let c := fst t / weight i in (c * snd t)%RT) - else match i with S i' => place_cps t i' f | O => f (O, fst t * snd t)%RT end. + Section place_cps. + Context {T : Type}. + + Fixpoint place_cps (t:limb) (i:nat) (f:nat * Z->T) := + if dec (fst t mod weight i = 0) + then f (i, let c := fst t / weight i in (c * snd t)%RT) + else match i with S i' => place_cps t i' f | O => f (O, fst t * snd t)%RT end. + End place_cps. Definition place t i := place_cps t i id. Lemma place_cps_id t i {T} f : - @place_cps t i T f = f (place t i). + @place_cps T t i f = f (place t i). Proof using Type. cbv [place]; induction i; prove_id. Qed. Hint Opaque place : uncps. Hint Rewrite place_cps_id : uncps. diff --git a/src/Arithmetic/Saturated/AddSub.v b/src/Arithmetic/Saturated/AddSub.v index b8b6ee31a..c1c701870 100644 --- a/src/Arithmetic/Saturated/AddSub.v +++ b/src/Arithmetic/Saturated/AddSub.v @@ -20,30 +20,34 @@ Module B. {op_get_carry : Z -> Z -> Z * Z} (* no carry in, carry out *) {op_with_carry : Z -> Z -> Z -> Z * Z}. (* carry in, carry out *) - Fixpoint chain_op'_cps {n}: - option Z->Z^n->Z^n->forall T, (Z*Z^n->T)->T := - match n with - | O => fun c p _ _ f => - let carry := match c with | None => 0 | Some x => x end in - f (carry,p) - | S n' => - fun c p q _ f => - (* for the first call, use op_get_carry, then op_with_carry *) - let op' := match c with - | None => op_get_carry - | Some x => op_with_carry x end in - dlet carry_result := op' (hd p) (hd q) in - chain_op'_cps (Some (snd carry_result)) (tl p) (tl q) _ - (fun carry_pq => - f (fst carry_pq, - append (fst carry_result) (snd carry_pq))) - end. - Definition chain_op' {n} c p q := @chain_op'_cps n c p q _ id. - Definition chain_op_cps {n} p q {T} f := @chain_op'_cps n None p q T f. + Section chain_op'_cps. + Context (T : Type). + + Fixpoint chain_op'_cps {n} (c:option Z) (p q:Z^n) + : (Z*Z^n->T)->T := + match n return option Z -> Z^n -> Z^n -> (Z*Z^n -> T) -> T with + | O => fun c p _ f => + let carry := match c with | None => 0 | Some x => x end in + f (carry,p) + | S n' => + fun c p q f => + (* for the first call, use op_get_carry, then op_with_carry *) + let op' := match c with + | None => op_get_carry + | Some x => op_with_carry x end in + dlet carry_result := op' (hd p) (hd q) in + chain_op'_cps (Some (snd carry_result)) (tl p) (tl q) + (fun carry_pq => + f (fst carry_pq, + append (fst carry_result) (snd carry_pq))) + end c p q. + End chain_op'_cps. + Definition chain_op' {n} c p q := @chain_op'_cps _ n c p q id. + Definition chain_op_cps {n} p q {T} f := @chain_op'_cps T n None p q f. Definition chain_op {n} p q : Z * Z^n := chain_op_cps p q id. Lemma chain_op'_id {n} : forall c p q T f, - @chain_op'_cps n c p q T f = f (chain_op' c p q). + @chain_op'_cps T n c p q f = f (chain_op' c p q). Proof. cbv [chain_op']; induction n; intros; destruct c; simpl chain_op'_cps; cbv [Let_In]; try reflexivity. @@ -53,7 +57,7 @@ Module B. Lemma chain_op_id {n} p q T f : @chain_op_cps n p q T f = f (chain_op p q). - Proof. apply chain_op'_id. Qed. + Proof. apply (@chain_op'_id n None). Qed. End GenericOp. Hint Opaque chain_op chain_op' : uncps. Hint Rewrite @chain_op_id @chain_op'_id : uncps. @@ -117,9 +121,9 @@ Module B. | _ => progress autorewrite with uncps divmod push_id cancel_pair push_basesystem_eval | _ => rewrite uweight_0, ?Z.mod_1_r, ?Z.div_1_r | _ => rewrite uweight_succ - | _ => rewrite Z.sub_opp_r - | _ => rewrite sat_add_mod_step - | _ => rewrite sat_add_div_step + | _ => rewrite Z.sub_opp_r + | _ => rewrite sat_add_mod_step + | _ => rewrite sat_add_div_step | p : Z ^ 0 |- _ => destruct p | _ => rewrite uweight_eval_step, ?hd_append, ?tl_append | |- context[B.Positional.eval _ (snd (chain_op' ?c ?p ?q))] @@ -149,8 +153,8 @@ Module B. | _ => progress (simpl chain_op'_cps in * ) | _ => progress autorewrite with uncps push_id cancel_pair in H | H : _ |- _ => rewrite to_list_append in H; - simpl In in H - | H : _ \/ _ |- _ => destruct H + simpl In in H + | H : _ \/ _ |- _ => destruct H | _ => contradiction end. { subst x. @@ -185,9 +189,9 @@ Module B. | _ => progress autorewrite with uncps divmod push_id cancel_pair push_basesystem_eval | _ => rewrite uweight_0, ?Z.mod_1_r, ?Z.div_1_r | _ => rewrite uweight_succ - | _ => rewrite Z.sub_opp_r - | _ => rewrite sat_add_mod_step - | _ => rewrite sat_add_div_step + | _ => rewrite Z.sub_opp_r + | _ => rewrite sat_add_mod_step + | _ => rewrite sat_add_div_step | p : Z ^ 0 |- _ => destruct p | _ => rewrite uweight_eval_step, ?hd_append, ?tl_append | |- context[B.Positional.eval _ (snd (chain_op' ?c ?p ?q))] @@ -197,7 +201,7 @@ Module B. | _ => solve [split; repeat (f_equal; ring_simplify; try omega)] end. Qed. - + Lemma sat_sub_mod n p q : eval (snd (@sat_sub n p q)) = (eval p - eval q) mod (uweight s n). Proof. exact (proj1 (sat_sub_divmod n p q)). Qed. @@ -217,8 +221,8 @@ Module B. | _ => progress (simpl chain_op'_cps in * ) | _ => progress autorewrite with uncps push_id cancel_pair in H | H : _ |- _ => rewrite to_list_append in H; - simpl In in H - | H : _ \/ _ |- _ => destruct H + simpl In in H + | H : _ \/ _ |- _ => destruct H | _ => contradiction end. { subst x. @@ -233,4 +237,4 @@ Module B. End B. Hint Opaque B.Positional.sat_sub B.Positional.sat_add B.Positional.chain_op B.Positional.chain_op' : uncps. Hint Rewrite @B.Positional.sat_sub_id @B.Positional.sat_add_id @B.Positional.chain_op_id @B.Positional.chain_op' : uncps. -Hint Rewrite @B.Positional.sat_sub_mod @B.Positional.sat_sub_div @B.Positional.sat_add_mod @B.Positional.sat_add_div using (omega || assumption) : push_basesystem_eval.
\ No newline at end of file +Hint Rewrite @B.Positional.sat_sub_mod @B.Positional.sat_sub_div @B.Positional.sat_add_mod @B.Positional.sat_add_div using (omega || assumption) : push_basesystem_eval. diff --git a/src/Arithmetic/Saturated/Core.v b/src/Arithmetic/Saturated/Core.v index face608ab..ce9888b74 100644 --- a/src/Arithmetic/Saturated/Core.v +++ b/src/Arithmetic/Saturated/Core.v @@ -156,25 +156,29 @@ Module Columns. (* Sums a list of integers using carry bits. Output : carry, sum *) - Fixpoint compact_digit_cps n (digit : list Z) {T} (f:Z * Z->T) := - match digit with - | nil => f (0, 0) - | x :: nil => f (div x (weight (S n) / weight n), modulo x (weight (S n) / weight n)) - | x :: y :: nil => - dlet sum_carry := add_get_carry (weight (S n) / weight n) x y in - dlet carry := snd sum_carry in - f (carry, fst sum_carry) - | x :: tl => - compact_digit_cps n tl - (fun rec => - dlet sum_carry := add_get_carry (weight (S n) / weight n) x (snd rec) in - dlet carry' := (fst rec + snd sum_carry)%RT in - f (carry', fst sum_carry)) - end. + Section compact_digit_cps. + Context (n : nat) {T : Type}. + + Fixpoint compact_digit_cps (digit : list Z) (f:Z * Z->T) := + match digit with + | nil => f (0, 0) + | x :: nil => f (div x (weight (S n) / weight n), modulo x (weight (S n) / weight n)) + | x :: y :: nil => + dlet sum_carry := add_get_carry (weight (S n) / weight n) x y in + dlet carry := snd sum_carry in + f (carry, fst sum_carry) + | x :: tl => + compact_digit_cps tl + (fun rec => + dlet sum_carry := add_get_carry (weight (S n) / weight n) x (snd rec) in + dlet carry' := (fst rec + snd sum_carry)%RT in + f (carry', fst sum_carry)) + end. + End compact_digit_cps. Definition compact_digit n digit := compact_digit_cps n digit id. Lemma compact_digit_id n digit: forall {T} f, - @compact_digit_cps n digit T f = f (compact_digit n digit). + @compact_digit_cps n T digit f = f (compact_digit n digit). Proof using Type. induction digit; intros; cbv [compact_digit]; [reflexivity|]; simpl compact_digit_cps; break_match; rewrite ?IHdigit; diff --git a/src/Util/CPSUtil.v b/src/Util/CPSUtil.v index f8d51e5b3..9b249b460 100644 --- a/src/Util/CPSUtil.v +++ b/src/Util/CPSUtil.v @@ -325,12 +325,12 @@ Module Tuple. := fun ts R => @mapi_with'_cps_specialized R T A B (fun n t a => @f n t a R) n i start ts. Definition mapi_with_cps {S A B n} - (f: nat->S->A->forall {T}, (S*B->T)->T) (start:S) - : tuple A n -> forall {T}, (S * tuple B n->T)->T := - match n as n0 return (tuple A n0 -> forall {T}, (S * tuple B n0->T)->T) with - | O => fun ys {T} ret => ret (start, tt) - | S n' => fun ys {T} ret => mapi_with'_cps 0%nat f start ys ret - end. + (f: nat->S->A->forall {T}, (S*B->T)->T) (start:S) (ys:tuple A n) {T} + : (S * tuple B n->T)->T := + match n as n0 return (tuple A n0 -> (S * tuple B n0->T)->T) with + | O => fun ys ret => ret (start, tt) + | S n' => fun ys ret => mapi_with'_cps 0%nat f start ys ret + end ys. Lemma unfold_mapi_with'_cps {T A B n} i (f: nat->T->A->forall {R}, (T*B->R)->R) (start:T) |