aboutsummaryrefslogtreecommitdiff
path: root/src/Arithmetic/Saturated.v
diff options
context:
space:
mode:
authorGravatar jadep <jade.philipoom@gmail.com>2017-06-24 10:44:02 -0400
committerGravatar jadep <jade.philipoom@gmail.com>2017-06-24 10:44:02 -0400
commitaea636f7399e5388e75fc8116baf465ac09160dd (patch)
tree337dacb51e2f5c235614ac0bb239611e9a901ef2 /src/Arithmetic/Saturated.v
parentcd1f0f68f16ce26f5de7e895d6cb6d801e251d84 (diff)
made conditional_add a wrapper, defined and proved sub_then_maybe_add
Diffstat (limited to 'src/Arithmetic/Saturated.v')
-rw-r--r--src/Arithmetic/Saturated.v147
1 files changed, 90 insertions, 57 deletions
diff --git a/src/Arithmetic/Saturated.v b/src/Arithmetic/Saturated.v
index 411f8a558..df080b116 100644
--- a/src/Arithmetic/Saturated.v
+++ b/src/Arithmetic/Saturated.v
@@ -515,59 +515,17 @@ Module Columns.
(fun PQ => from_associational_cps weight n3 PQ
(fun R => compact_cps (div:=div) (modulo:=modulo) (add_get_carry:=Z.add_get_carry_full) weight R f)))).
- End Wrappers.
- Hint Unfold add_cps unbalanced_sub_cps mul_cps.
-
- (* These come after the wrapper definitions because they depend on
- e.g. unbalanced_sub and add. *)
- Section Conditionals.
- Context (weight : nat->Z)
- {weight_0 : weight 0%nat = 1}
- {weight_nonzero : forall i, weight i <> 0}
- {weight_positive : forall i, weight i > 0}
- {weight_multiples : forall i, weight (S i) mod weight i = 0}
- {weight_divides : forall i : nat, weight (S i) / weight i > 0}
- .
-
- Definition conditional_add_cps {n} mask cond (p q : Z^n)
+ Definition conditional_add_cps {n1 n2 n3} mask cond (p:Z^n1) (q:Z^n2)
{T} (f:_->T) :=
B.Positional.select_cps mask cond q
- (fun qq => add_cps (n3:=n) weight p qq f).
- Definition conditional_add {n} mask cond p q :=
- @conditional_add_cps n mask cond p q _ id.
- Lemma conditional_add_id {n} mask cond p q T f:
- @conditional_add_cps n mask cond p q T f
- = f (conditional_add mask cond p q).
- Proof.
- cbv [conditional_add_cps conditional_add]; autounfold;
- autorewrite with push_id uncps; reflexivity.
- Qed.
- Hint Opaque conditional_add : uncps.
- Hint Rewrite @conditional_add_id : uncps.
+ (fun qq => add_cps (n3:=n3) p qq f).
+
+ End Wrappers.
+ Hint Unfold add_cps unbalanced_sub_cps mul_cps conditional_add_cps.
- Lemma eval_conditional_add {n} mask cond p q (n_nonzero:n<>0%nat)
- (H:Tuple.map (Z.land mask) q = q) :
- B.Positional.eval weight (snd (@conditional_add n mask cond p q))
- = B.Positional.eval weight p + (if (dec (cond = 0)) then 0 else B.Positional.eval weight q) - weight n * (fst (conditional_add mask cond p q)).
- Proof.
- cbv [conditional_add_cps conditional_add];
- repeat progress autounfold in *. cbv [add_cps].
- pose proof Z.add_get_carry_full_mod.
- pose proof Z.add_get_carry_full_div.
- pose proof div_correct. pose proof modulo_correct.
- autorewrite with uncps push_id push_basesystem_eval.
- break_match;
- match goal with
- |- context [weight ?n * (?x / weight ?n)] =>
- pose proof (Z.div_mod x (weight n) (weight_nonzero n))
- end; omega.
- Qed.
- Hint Rewrite @eval_conditional_add using (omega || assumption)
- : push_basesystem_eval.
- End Conditionals.
-
End Columns.
Hint Unfold
+ Columns.conditional_add_cps
Columns.add_cps
Columns.unbalanced_sub_cps
Columns.mul_cps.
@@ -577,7 +535,6 @@ Hint Rewrite
@Columns.compact_id
@Columns.cons_to_nth_id
@Columns.from_associational_id
- @Columns.conditional_add_id
: uncps.
Hint Rewrite
@Columns.compact_mod
@@ -585,7 +542,6 @@ Hint Rewrite
@Columns.eval_cons_to_nth
@Columns.eval_from_associational
@Columns.eval_nils
- @Columns.eval_conditional_add
using (assumption || omega): push_basesystem_eval.
Section Freeze.
@@ -614,8 +570,8 @@ Section Freeze.
the carry in step 3 was -1, so they cancel out.
*)
Definition freeze_cps {n} mask (m:Z^n) (p:Z^n) {T} (f : Z^n->T) :=
- Columns.unbalanced_sub_cps weight p m
- (fun carry_p => Columns.conditional_add_cps weight mask (fst carry_p) (snd carry_p) m
+ Columns.unbalanced_sub_cps (n3:=n) weight p m
+ (fun carry_p => Columns.conditional_add_cps (n3:=n) weight mask (fst carry_p) (snd carry_p) m
(fun carry_r => f (snd carry_r)))
.
@@ -669,7 +625,7 @@ Section Freeze.
(B.Positional.eval weight (@freeze n mask m p))
(B.Positional.eval weight p).
Proof.
- cbv [freeze_cps freeze Columns.conditional_add_cps].
+ cbv [freeze_cps freeze].
repeat progress autounfold.
pose proof Z.add_get_carry_full_mod.
pose proof Z.add_get_carry_full_div.
@@ -751,9 +707,21 @@ Section API.
f).
Definition add {n m pred_nm} p q : T (S pred_nm) := @add_cps n m pred_nm p q _ id.
+ Definition sub_then_maybe_add_cps {n} mask (p q r : T n) {R} (f:T n -> R) :=
+ Columns.unbalanced_sub_cps (n3:=n) (uweight bound) p q
+ (* the carry will be 0 unless we underflow--we do the addition only
+ in the underflow case *)
+ (fun carry_result =>
+ Columns.conditional_add_cps (uweight bound) mask (fst carry_result) (left_append (fst carry_result) (snd carry_result)) r
+ (* We can now safely discard the carry. This relies on the
+ precondition that p - q + r < bound^n. *)
+ (fun carry_result' => f (snd carry_result'))).
+ Definition sub_then_maybe_add {n} mask (p q r : T n) :=
+ sub_then_maybe_add_cps mask p q r id.
+
(* Subtract q if and only if p < bound^n. The correctness of this
lemma depends on the precondition that p < q + bound^n. *)
- Definition conditional_sub_cps {n} (p : Z^S n) (q : Z ^ n)
+ Definition conditional_sub_cps {n} mask (p : Z^S n) (q : Z ^ n)
{T} (f:Z^n->T) :=
(* we pass the highest digit of p into select_cps as the
conditional argument; if it is 0, the subtraction will not
@@ -768,7 +736,7 @@ Section API.
Definition conditional_sub {n} mask p q := @conditional_sub_cps n mask p q _ id.
- Hint Opaque join0 divmod drop_high scmul add conditional_sub : uncps.
+ Hint Opaque join0 divmod drop_high scmul add sub_then_maybe_add conditional_sub : uncps.
Section CPSProofs.
@@ -796,12 +764,16 @@ Section API.
@add_cps n m pred_nm p q R f = f (add p q).
Proof. cbv [add_cps add Let_In]. prove_id. Qed.
+ Lemma sub_then_maybe_add_id n mask p q r R f :
+ @sub_then_maybe_add_cps n mask p q r R f = f (sub_then_maybe_add mask p q r).
+ Proof. cbv [sub_then_maybe_add_cps sub_then_maybe_add Let_In]. prove_id. Qed.
+
Lemma conditional_sub_id n mask p q R f :
@conditional_sub_cps n mask p q R f = f (conditional_sub mask p q).
Proof. cbv [conditional_sub_cps conditional_sub Let_In]. prove_id. Qed.
End CPSProofs.
- Hint Rewrite join0_id divmod_id drop_high_id scmul_id add_id conditional_sub_id : uncps.
+ Hint Rewrite join0_id divmod_id drop_high_id scmul_id add_id sub_then_maybe_add_id conditional_sub_id : uncps.
Section Proofs.
@@ -1038,6 +1010,67 @@ Section API.
apply small_left_tl.
Qed.
+ Lemma div_nonzero_neg_iff x y : x < y -> 0 < y -> x / y <> 0 <-> x < 0.
+ Proof.
+ repeat match goal with
+ | _ => progress intros
+ | _ => rewrite Z.div_small_iff by omega
+ | _ => split
+ | _ => omega
+ end.
+ Qed.
+
+ Lemma eval_sub_then_maybe_add_nz n mask p q r:
+ small p -> small q -> small r -> (n<>0)%nat ->
+ (map (Z.land mask) r = r) ->
+ (0 <= eval p < eval r) -> (0 <= eval q < eval r) ->
+ eval (@sub_then_maybe_add n mask p q r) = eval p - eval q + (if eval p - eval q <? 0 then eval r else 0).
+ Proof.
+ pose_all.
+ repeat match goal with
+ | _ => progress (cbv [sub_then_maybe_add sub_then_maybe_add_cps eval] in *; intros)
+ | _ => progress autounfold
+ | _ => progress autorewrite with uncps push_id push_basesystem_eval
+ | H : small _ |- _ => apply eval_small in H
+ | _ => progress break_match
+ | _ => (rewrite Z.add_opp_r in * )
+ | _ => progress autorewrite with zsimplify; [ ]
+ | H : _ |- _ => rewrite Z.ltb_lt in H;
+ rewrite <-div_nonzero_neg_iff with
+ (y:=uweight bound n) in H by (auto; omega)
+ | H : _ |- _ => rewrite Z.ltb_ge in H
+ | _ => rewrite Z.mod_small by omega
+ | _ => omega
+ end;
+ repeat match goal with
+ | H : _ |- _ => rewrite div_nonzero_neg_iff in H
+ by (auto; omega)
+ | |- context [-?x + ?y mod ?x] =>
+ replace (-x + y mod x) with y
+ by (rewrite Z.mod_eq, Z.div_small_neg; omega)
+ | _ => apply Z.mod_small; omega
+ | _ => omega
+ end.
+ Qed.
+
+ Lemma eval_sub_then_maybe_add n mask p q r:
+ small p -> small q -> small r ->
+ (map (Z.land mask) r = r) ->
+ (0 <= eval p < eval r) -> (0 <= eval q < eval r) ->
+ eval (@sub_then_maybe_add n mask p q r) = eval p - eval q + (if eval p - eval q <? 0 then eval r else 0).
+ Proof.
+ destruct n; [|solve[auto using eval_sub_then_maybe_add_nz]].
+ destruct p, q, r; reflexivity.
+ Qed.
+
+ Lemma small_sub_then_maybe_add n mask (p q r : T n) :
+ small (sub_then_maybe_add mask p q r).
+ Proof.
+ cbv [sub_then_maybe_add_cps sub_then_maybe_add]; intros.
+ repeat progress autounfold. autorewrite with uncps push_id.
+ apply small_compact.
+ Qed.
+
Lemma small_highest_zero_iff {n} (p: T (S n)) (Hsmall : small p) :
(left_hd p = 0 <-> eval p < uweight bound n).
Proof.
@@ -1148,7 +1181,7 @@ Section API.
End Proofs.
End API.
-Hint Rewrite join0_id divmod_id drop_high_id scmul_id add_id conditional_sub_id : uncps.
+Hint Rewrite join0_id divmod_id drop_high_id scmul_id add_id sub_then_maybe_add_id conditional_sub_id : uncps.
(*
(* Just some pretty-printing *)