diff options
Diffstat (limited to 'src/Arithmetic/Saturated/AddSub.v')
-rw-r--r-- | src/Arithmetic/Saturated/AddSub.v | 70 |
1 files changed, 43 insertions, 27 deletions
diff --git a/src/Arithmetic/Saturated/AddSub.v b/src/Arithmetic/Saturated/AddSub.v index e34230904..e886c36de 100644 --- a/src/Arithmetic/Saturated/AddSub.v +++ b/src/Arithmetic/Saturated/AddSub.v @@ -6,8 +6,10 @@ Require Import Crypto.Arithmetic.Core. Require Import Crypto.Arithmetic.Saturated.Core. Require Import Crypto.Arithmetic.Saturated.UniformWeight. Require Import Crypto.Util.ZUtil.Definitions. +Require Import Crypto.Util.ZUtil.CPS. Require Import Crypto.Util.ZUtil.AddGetCarry. Require Import Crypto.Util.Tuple Crypto.Util.LetIn. +Require Import Crypto.Util.Tactics.BreakMatch. Local Notation "A ^ n" := (tuple A n) : type_scope. Module B. @@ -17,8 +19,15 @@ Module B. Let small {n} := @small s n. Section GenericOp. Context {op : Z -> Z -> Z} - {op_get_carry : Z -> Z -> Z * Z} (* no carry in, carry out *) - {op_with_carry : Z -> Z -> Z -> Z * Z}. (* carry in, carry out *) + {op_get_carry_cps : forall {T}, Z -> Z -> (Z * Z -> T) -> T} (* no carry in, carry out *) + {op_with_carry_cps : forall {T}, Z -> Z -> Z -> (Z * Z -> T) -> T}. (* carry in, carry out *) + Let op_get_carry x y := op_get_carry_cps _ x y id. + Let op_with_carry x y z := op_with_carry_cps _ x y z id. + Context {op_get_carry_id : forall {T} x y f, + @op_get_carry_cps T x y f = f (op_get_carry x y)} + {op_with_carry_id : forall {T} x y z f, + @op_with_carry_cps T x y z f = f (op_with_carry x y z)}. + Hint Rewrite @op_get_carry_id @op_with_carry_id : uncps. Section chain_op'_cps. Context (T : Type). @@ -32,14 +41,15 @@ Module B. | S n' => fun c p q f => (* for the first call, use op_get_carry, then op_with_carry *) - let op' := match c with - | None => op_get_carry - | Some x => op_with_carry x end in - dlet carry_result := op' (hd p) (hd q) in + let op'_cps := match c with + | None => op_get_carry_cps _ + | Some x => op_with_carry_cps _ x end in + op'_cps (hd p) (hd q) (fun carry_result => + dlet carry_result := carry_result in chain_op'_cps (Some (snd carry_result)) (tl p) (tl q) (fun carry_pq => f (fst carry_pq, - append (fst carry_result) (snd carry_pq))) + append (fst carry_result) (snd carry_pq)))) end c p q. End chain_op'_cps. Definition chain_op' {n} c p q := @chain_op'_cps _ n c p q id. @@ -50,7 +60,8 @@ Module B. @chain_op'_cps T n c p q f = f (chain_op' c p q). Proof. cbv [chain_op']; induction n; intros; destruct c; - simpl chain_op'_cps; cbv [Let_In]; try reflexivity. + simpl chain_op'_cps; cbv [Let_In]; try reflexivity; + autorewrite with uncps. { etransitivity; rewrite IHn; reflexivity. } { etransitivity; rewrite IHn; reflexivity. } Qed. @@ -60,7 +71,7 @@ Module B. Proof. apply (@chain_op'_id n None). Qed. End GenericOp. Hint Opaque chain_op chain_op' : uncps. - Hint Rewrite @chain_op_id @chain_op'_id : uncps. + Hint Rewrite @chain_op_id @chain_op'_id using (assumption || (intros; autorewrite with uncps; reflexivity)) : uncps. Section AddSub. Create HintDb divmod discriminated. @@ -77,14 +88,14 @@ Module B. Let eval {n} := B.Positional.eval (n:=n) (uweight s). Definition sat_add_cps {n} p q T (f:Z*Z^n->T) := - chain_op_cps (op_get_carry := Z.add_get_carry_full s) - (op_with_carry := Z.add_with_get_carry_full s) + chain_op_cps (op_get_carry_cps := fun T => Z.add_get_carry_full_cps s) + (op_with_carry_cps := fun T => Z.add_with_get_carry_full_cps s) p q f. Definition sat_add {n} p q := @sat_add_cps n p q _ id. Lemma sat_add_id n p q T f : @sat_add_cps n p q T f = f (sat_add p q). - Proof. cbv [sat_add sat_add_cps]. rewrite !chain_op_id. reflexivity. Qed. + Proof. cbv [sat_add sat_add_cps]. autorewrite with uncps. reflexivity. Qed. Lemma sat_add_mod_step n c d : c mod s + s * ((d + c / s) mod (uweight s n)) @@ -156,23 +167,25 @@ Module B. simpl In in H | H : _ \/ _ |- _ => destruct H | _ => contradiction - end. - { subst x. - destruct c; rewrite ?Z.add_with_get_carry_full_mod, - ?Z.add_get_carry_full_mod; - apply Z.mod_pos_bound; omega. } - { apply IHn in H. assumption. } + | _ => break_innermost_match_hyps_step + | _ => progress subst + | [ H : In _ (to_list _ (snd _)) |- _ ] + => apply IHn in H; assumption + end; + try solve [ rewrite ?Z.add_with_get_carry_full_mod, + ?Z.add_get_carry_full_mod; + apply Z.mod_pos_bound; omega ]. Qed. Definition sat_sub_cps {n} p q T (f:Z*Z^n->T) := - chain_op_cps (op_get_carry := Z.sub_get_borrow_full s) - (op_with_carry := Z.sub_with_get_borrow_full s) + chain_op_cps (op_get_carry_cps := fun T => Z.sub_get_borrow_full_cps s) + (op_with_carry_cps := fun T => Z.sub_with_get_borrow_full_cps s) p q f. Definition sat_sub {n} p q := @sat_sub_cps n p q _ id. Lemma sat_sub_id n p q T f : @sat_sub_cps n p q T f = f (sat_sub p q). - Proof. cbv [sat_sub sat_sub_cps]. rewrite !chain_op_id. reflexivity. Qed. + Proof. cbv [sat_sub sat_sub_cps]. autorewrite with uncps. reflexivity. Qed. Lemma sat_sub_divmod n p q : eval (snd (@sat_sub n p q)) = (eval p - eval q) mod (uweight s n) /\ fst (@sat_sub n p q) = - ((eval p - eval q) / (uweight s n)). @@ -224,19 +237,22 @@ Module B. simpl In in H | H : _ \/ _ |- _ => destruct H | _ => contradiction - end. - { subst x. - destruct c; rewrite ?Z.sub_with_get_borrow_full_mod, + | _ => break_innermost_match_hyps_step + | _ => progress subst + | [ H : In _ (to_list _ (snd _)) |- _ ] + => apply IHn in H; assumption + end; + try solve [ rewrite ?Z.sub_with_get_borrow_full_mod, ?Z.sub_get_borrow_full_mod; - apply Z.mod_pos_bound; omega. } - { apply IHn in H. assumption. } + apply Z.mod_pos_bound; omega ]. Qed. End AddSub. End Positional. End Positional. End B. Hint Opaque B.Positional.sat_sub B.Positional.sat_add B.Positional.chain_op B.Positional.chain_op' : uncps. -Hint Rewrite @B.Positional.sat_sub_id @B.Positional.sat_add_id @B.Positional.chain_op_id @B.Positional.chain_op' : uncps. +Hint Rewrite @B.Positional.sat_sub_id @B.Positional.sat_add_id : uncps. +Hint Rewrite @B.Positional.chain_op_id @B.Positional.chain_op' using (assumption || (intros; autorewrite with uncps; reflexivity)) : uncps. Hint Rewrite @B.Positional.sat_sub_mod @B.Positional.sat_sub_div @B.Positional.sat_add_mod @B.Positional.sat_add_div using (omega || assumption) : push_basesystem_eval. Hint Unfold |