aboutsummaryrefslogtreecommitdiff
path: root/src/Arithmetic/Saturated/AddSub.v
diff options
context:
space:
mode:
Diffstat (limited to 'src/Arithmetic/Saturated/AddSub.v')
-rw-r--r--src/Arithmetic/Saturated/AddSub.v70
1 files changed, 43 insertions, 27 deletions
diff --git a/src/Arithmetic/Saturated/AddSub.v b/src/Arithmetic/Saturated/AddSub.v
index e34230904..e886c36de 100644
--- a/src/Arithmetic/Saturated/AddSub.v
+++ b/src/Arithmetic/Saturated/AddSub.v
@@ -6,8 +6,10 @@ Require Import Crypto.Arithmetic.Core.
Require Import Crypto.Arithmetic.Saturated.Core.
Require Import Crypto.Arithmetic.Saturated.UniformWeight.
Require Import Crypto.Util.ZUtil.Definitions.
+Require Import Crypto.Util.ZUtil.CPS.
Require Import Crypto.Util.ZUtil.AddGetCarry.
Require Import Crypto.Util.Tuple Crypto.Util.LetIn.
+Require Import Crypto.Util.Tactics.BreakMatch.
Local Notation "A ^ n" := (tuple A n) : type_scope.
Module B.
@@ -17,8 +19,15 @@ Module B.
Let small {n} := @small s n.
Section GenericOp.
Context {op : Z -> Z -> Z}
- {op_get_carry : Z -> Z -> Z * Z} (* no carry in, carry out *)
- {op_with_carry : Z -> Z -> Z -> Z * Z}. (* carry in, carry out *)
+ {op_get_carry_cps : forall {T}, Z -> Z -> (Z * Z -> T) -> T} (* no carry in, carry out *)
+ {op_with_carry_cps : forall {T}, Z -> Z -> Z -> (Z * Z -> T) -> T}. (* carry in, carry out *)
+ Let op_get_carry x y := op_get_carry_cps _ x y id.
+ Let op_with_carry x y z := op_with_carry_cps _ x y z id.
+ Context {op_get_carry_id : forall {T} x y f,
+ @op_get_carry_cps T x y f = f (op_get_carry x y)}
+ {op_with_carry_id : forall {T} x y z f,
+ @op_with_carry_cps T x y z f = f (op_with_carry x y z)}.
+ Hint Rewrite @op_get_carry_id @op_with_carry_id : uncps.
Section chain_op'_cps.
Context (T : Type).
@@ -32,14 +41,15 @@ Module B.
| S n' =>
fun c p q f =>
(* for the first call, use op_get_carry, then op_with_carry *)
- let op' := match c with
- | None => op_get_carry
- | Some x => op_with_carry x end in
- dlet carry_result := op' (hd p) (hd q) in
+ let op'_cps := match c with
+ | None => op_get_carry_cps _
+ | Some x => op_with_carry_cps _ x end in
+ op'_cps (hd p) (hd q) (fun carry_result =>
+ dlet carry_result := carry_result in
chain_op'_cps (Some (snd carry_result)) (tl p) (tl q)
(fun carry_pq =>
f (fst carry_pq,
- append (fst carry_result) (snd carry_pq)))
+ append (fst carry_result) (snd carry_pq))))
end c p q.
End chain_op'_cps.
Definition chain_op' {n} c p q := @chain_op'_cps _ n c p q id.
@@ -50,7 +60,8 @@ Module B.
@chain_op'_cps T n c p q f = f (chain_op' c p q).
Proof.
cbv [chain_op']; induction n; intros; destruct c;
- simpl chain_op'_cps; cbv [Let_In]; try reflexivity.
+ simpl chain_op'_cps; cbv [Let_In]; try reflexivity;
+ autorewrite with uncps.
{ etransitivity; rewrite IHn; reflexivity. }
{ etransitivity; rewrite IHn; reflexivity. }
Qed.
@@ -60,7 +71,7 @@ Module B.
Proof. apply (@chain_op'_id n None). Qed.
End GenericOp.
Hint Opaque chain_op chain_op' : uncps.
- Hint Rewrite @chain_op_id @chain_op'_id : uncps.
+ Hint Rewrite @chain_op_id @chain_op'_id using (assumption || (intros; autorewrite with uncps; reflexivity)) : uncps.
Section AddSub.
Create HintDb divmod discriminated.
@@ -77,14 +88,14 @@ Module B.
Let eval {n} := B.Positional.eval (n:=n) (uweight s).
Definition sat_add_cps {n} p q T (f:Z*Z^n->T) :=
- chain_op_cps (op_get_carry := Z.add_get_carry_full s)
- (op_with_carry := Z.add_with_get_carry_full s)
+ chain_op_cps (op_get_carry_cps := fun T => Z.add_get_carry_full_cps s)
+ (op_with_carry_cps := fun T => Z.add_with_get_carry_full_cps s)
p q f.
Definition sat_add {n} p q := @sat_add_cps n p q _ id.
Lemma sat_add_id n p q T f :
@sat_add_cps n p q T f = f (sat_add p q).
- Proof. cbv [sat_add sat_add_cps]. rewrite !chain_op_id. reflexivity. Qed.
+ Proof. cbv [sat_add sat_add_cps]. autorewrite with uncps. reflexivity. Qed.
Lemma sat_add_mod_step n c d :
c mod s + s * ((d + c / s) mod (uweight s n))
@@ -156,23 +167,25 @@ Module B.
simpl In in H
| H : _ \/ _ |- _ => destruct H
| _ => contradiction
- end.
- { subst x.
- destruct c; rewrite ?Z.add_with_get_carry_full_mod,
- ?Z.add_get_carry_full_mod;
- apply Z.mod_pos_bound; omega. }
- { apply IHn in H. assumption. }
+ | _ => break_innermost_match_hyps_step
+ | _ => progress subst
+ | [ H : In _ (to_list _ (snd _)) |- _ ]
+ => apply IHn in H; assumption
+ end;
+ try solve [ rewrite ?Z.add_with_get_carry_full_mod,
+ ?Z.add_get_carry_full_mod;
+ apply Z.mod_pos_bound; omega ].
Qed.
Definition sat_sub_cps {n} p q T (f:Z*Z^n->T) :=
- chain_op_cps (op_get_carry := Z.sub_get_borrow_full s)
- (op_with_carry := Z.sub_with_get_borrow_full s)
+ chain_op_cps (op_get_carry_cps := fun T => Z.sub_get_borrow_full_cps s)
+ (op_with_carry_cps := fun T => Z.sub_with_get_borrow_full_cps s)
p q f.
Definition sat_sub {n} p q := @sat_sub_cps n p q _ id.
Lemma sat_sub_id n p q T f :
@sat_sub_cps n p q T f = f (sat_sub p q).
- Proof. cbv [sat_sub sat_sub_cps]. rewrite !chain_op_id. reflexivity. Qed.
+ Proof. cbv [sat_sub sat_sub_cps]. autorewrite with uncps. reflexivity. Qed.
Lemma sat_sub_divmod n p q :
eval (snd (@sat_sub n p q)) = (eval p - eval q) mod (uweight s n)
/\ fst (@sat_sub n p q) = - ((eval p - eval q) / (uweight s n)).
@@ -224,19 +237,22 @@ Module B.
simpl In in H
| H : _ \/ _ |- _ => destruct H
| _ => contradiction
- end.
- { subst x.
- destruct c; rewrite ?Z.sub_with_get_borrow_full_mod,
+ | _ => break_innermost_match_hyps_step
+ | _ => progress subst
+ | [ H : In _ (to_list _ (snd _)) |- _ ]
+ => apply IHn in H; assumption
+ end;
+ try solve [ rewrite ?Z.sub_with_get_borrow_full_mod,
?Z.sub_get_borrow_full_mod;
- apply Z.mod_pos_bound; omega. }
- { apply IHn in H. assumption. }
+ apply Z.mod_pos_bound; omega ].
Qed.
End AddSub.
End Positional.
End Positional.
End B.
Hint Opaque B.Positional.sat_sub B.Positional.sat_add B.Positional.chain_op B.Positional.chain_op' : uncps.
-Hint Rewrite @B.Positional.sat_sub_id @B.Positional.sat_add_id @B.Positional.chain_op_id @B.Positional.chain_op' : uncps.
+Hint Rewrite @B.Positional.sat_sub_id @B.Positional.sat_add_id : uncps.
+Hint Rewrite @B.Positional.chain_op_id @B.Positional.chain_op' using (assumption || (intros; autorewrite with uncps; reflexivity)) : uncps.
Hint Rewrite @B.Positional.sat_sub_mod @B.Positional.sat_sub_div @B.Positional.sat_add_mod @B.Positional.sat_add_div using (omega || assumption) : push_basesystem_eval.
Hint Unfold