diff options
author | jadep <jade.philipoom@gmail.com> | 2017-06-24 10:44:02 -0400 |
---|---|---|
committer | jadep <jade.philipoom@gmail.com> | 2017-06-24 10:44:02 -0400 |
commit | aea636f7399e5388e75fc8116baf465ac09160dd (patch) | |
tree | 337dacb51e2f5c235614ac0bb239611e9a901ef2 /src/Arithmetic/Saturated.v | |
parent | cd1f0f68f16ce26f5de7e895d6cb6d801e251d84 (diff) |
made conditional_add a wrapper, defined and proved sub_then_maybe_add
Diffstat (limited to 'src/Arithmetic/Saturated.v')
-rw-r--r-- | src/Arithmetic/Saturated.v | 147 |
1 files changed, 90 insertions, 57 deletions
diff --git a/src/Arithmetic/Saturated.v b/src/Arithmetic/Saturated.v index 411f8a558..df080b116 100644 --- a/src/Arithmetic/Saturated.v +++ b/src/Arithmetic/Saturated.v @@ -515,59 +515,17 @@ Module Columns. (fun PQ => from_associational_cps weight n3 PQ (fun R => compact_cps (div:=div) (modulo:=modulo) (add_get_carry:=Z.add_get_carry_full) weight R f)))). - End Wrappers. - Hint Unfold add_cps unbalanced_sub_cps mul_cps. - - (* These come after the wrapper definitions because they depend on - e.g. unbalanced_sub and add. *) - Section Conditionals. - Context (weight : nat->Z) - {weight_0 : weight 0%nat = 1} - {weight_nonzero : forall i, weight i <> 0} - {weight_positive : forall i, weight i > 0} - {weight_multiples : forall i, weight (S i) mod weight i = 0} - {weight_divides : forall i : nat, weight (S i) / weight i > 0} - . - - Definition conditional_add_cps {n} mask cond (p q : Z^n) + Definition conditional_add_cps {n1 n2 n3} mask cond (p:Z^n1) (q:Z^n2) {T} (f:_->T) := B.Positional.select_cps mask cond q - (fun qq => add_cps (n3:=n) weight p qq f). - Definition conditional_add {n} mask cond p q := - @conditional_add_cps n mask cond p q _ id. - Lemma conditional_add_id {n} mask cond p q T f: - @conditional_add_cps n mask cond p q T f - = f (conditional_add mask cond p q). - Proof. - cbv [conditional_add_cps conditional_add]; autounfold; - autorewrite with push_id uncps; reflexivity. - Qed. - Hint Opaque conditional_add : uncps. - Hint Rewrite @conditional_add_id : uncps. + (fun qq => add_cps (n3:=n3) p qq f). + + End Wrappers. + Hint Unfold add_cps unbalanced_sub_cps mul_cps conditional_add_cps. - Lemma eval_conditional_add {n} mask cond p q (n_nonzero:n<>0%nat) - (H:Tuple.map (Z.land mask) q = q) : - B.Positional.eval weight (snd (@conditional_add n mask cond p q)) - = B.Positional.eval weight p + (if (dec (cond = 0)) then 0 else B.Positional.eval weight q) - weight n * (fst (conditional_add mask cond p q)). - Proof. - cbv [conditional_add_cps conditional_add]; - repeat progress autounfold in *. cbv [add_cps]. - pose proof Z.add_get_carry_full_mod. - pose proof Z.add_get_carry_full_div. - pose proof div_correct. pose proof modulo_correct. - autorewrite with uncps push_id push_basesystem_eval. - break_match; - match goal with - |- context [weight ?n * (?x / weight ?n)] => - pose proof (Z.div_mod x (weight n) (weight_nonzero n)) - end; omega. - Qed. - Hint Rewrite @eval_conditional_add using (omega || assumption) - : push_basesystem_eval. - End Conditionals. - End Columns. Hint Unfold + Columns.conditional_add_cps Columns.add_cps Columns.unbalanced_sub_cps Columns.mul_cps. @@ -577,7 +535,6 @@ Hint Rewrite @Columns.compact_id @Columns.cons_to_nth_id @Columns.from_associational_id - @Columns.conditional_add_id : uncps. Hint Rewrite @Columns.compact_mod @@ -585,7 +542,6 @@ Hint Rewrite @Columns.eval_cons_to_nth @Columns.eval_from_associational @Columns.eval_nils - @Columns.eval_conditional_add using (assumption || omega): push_basesystem_eval. Section Freeze. @@ -614,8 +570,8 @@ Section Freeze. the carry in step 3 was -1, so they cancel out. *) Definition freeze_cps {n} mask (m:Z^n) (p:Z^n) {T} (f : Z^n->T) := - Columns.unbalanced_sub_cps weight p m - (fun carry_p => Columns.conditional_add_cps weight mask (fst carry_p) (snd carry_p) m + Columns.unbalanced_sub_cps (n3:=n) weight p m + (fun carry_p => Columns.conditional_add_cps (n3:=n) weight mask (fst carry_p) (snd carry_p) m (fun carry_r => f (snd carry_r))) . @@ -669,7 +625,7 @@ Section Freeze. (B.Positional.eval weight (@freeze n mask m p)) (B.Positional.eval weight p). Proof. - cbv [freeze_cps freeze Columns.conditional_add_cps]. + cbv [freeze_cps freeze]. repeat progress autounfold. pose proof Z.add_get_carry_full_mod. pose proof Z.add_get_carry_full_div. @@ -751,9 +707,21 @@ Section API. f). Definition add {n m pred_nm} p q : T (S pred_nm) := @add_cps n m pred_nm p q _ id. + Definition sub_then_maybe_add_cps {n} mask (p q r : T n) {R} (f:T n -> R) := + Columns.unbalanced_sub_cps (n3:=n) (uweight bound) p q + (* the carry will be 0 unless we underflow--we do the addition only + in the underflow case *) + (fun carry_result => + Columns.conditional_add_cps (uweight bound) mask (fst carry_result) (left_append (fst carry_result) (snd carry_result)) r + (* We can now safely discard the carry. This relies on the + precondition that p - q + r < bound^n. *) + (fun carry_result' => f (snd carry_result'))). + Definition sub_then_maybe_add {n} mask (p q r : T n) := + sub_then_maybe_add_cps mask p q r id. + (* Subtract q if and only if p < bound^n. The correctness of this lemma depends on the precondition that p < q + bound^n. *) - Definition conditional_sub_cps {n} (p : Z^S n) (q : Z ^ n) + Definition conditional_sub_cps {n} mask (p : Z^S n) (q : Z ^ n) {T} (f:Z^n->T) := (* we pass the highest digit of p into select_cps as the conditional argument; if it is 0, the subtraction will not @@ -768,7 +736,7 @@ Section API. Definition conditional_sub {n} mask p q := @conditional_sub_cps n mask p q _ id. - Hint Opaque join0 divmod drop_high scmul add conditional_sub : uncps. + Hint Opaque join0 divmod drop_high scmul add sub_then_maybe_add conditional_sub : uncps. Section CPSProofs. @@ -796,12 +764,16 @@ Section API. @add_cps n m pred_nm p q R f = f (add p q). Proof. cbv [add_cps add Let_In]. prove_id. Qed. + Lemma sub_then_maybe_add_id n mask p q r R f : + @sub_then_maybe_add_cps n mask p q r R f = f (sub_then_maybe_add mask p q r). + Proof. cbv [sub_then_maybe_add_cps sub_then_maybe_add Let_In]. prove_id. Qed. + Lemma conditional_sub_id n mask p q R f : @conditional_sub_cps n mask p q R f = f (conditional_sub mask p q). Proof. cbv [conditional_sub_cps conditional_sub Let_In]. prove_id. Qed. End CPSProofs. - Hint Rewrite join0_id divmod_id drop_high_id scmul_id add_id conditional_sub_id : uncps. + Hint Rewrite join0_id divmod_id drop_high_id scmul_id add_id sub_then_maybe_add_id conditional_sub_id : uncps. Section Proofs. @@ -1038,6 +1010,67 @@ Section API. apply small_left_tl. Qed. + Lemma div_nonzero_neg_iff x y : x < y -> 0 < y -> x / y <> 0 <-> x < 0. + Proof. + repeat match goal with + | _ => progress intros + | _ => rewrite Z.div_small_iff by omega + | _ => split + | _ => omega + end. + Qed. + + Lemma eval_sub_then_maybe_add_nz n mask p q r: + small p -> small q -> small r -> (n<>0)%nat -> + (map (Z.land mask) r = r) -> + (0 <= eval p < eval r) -> (0 <= eval q < eval r) -> + eval (@sub_then_maybe_add n mask p q r) = eval p - eval q + (if eval p - eval q <? 0 then eval r else 0). + Proof. + pose_all. + repeat match goal with + | _ => progress (cbv [sub_then_maybe_add sub_then_maybe_add_cps eval] in *; intros) + | _ => progress autounfold + | _ => progress autorewrite with uncps push_id push_basesystem_eval + | H : small _ |- _ => apply eval_small in H + | _ => progress break_match + | _ => (rewrite Z.add_opp_r in * ) + | _ => progress autorewrite with zsimplify; [ ] + | H : _ |- _ => rewrite Z.ltb_lt in H; + rewrite <-div_nonzero_neg_iff with + (y:=uweight bound n) in H by (auto; omega) + | H : _ |- _ => rewrite Z.ltb_ge in H + | _ => rewrite Z.mod_small by omega + | _ => omega + end; + repeat match goal with + | H : _ |- _ => rewrite div_nonzero_neg_iff in H + by (auto; omega) + | |- context [-?x + ?y mod ?x] => + replace (-x + y mod x) with y + by (rewrite Z.mod_eq, Z.div_small_neg; omega) + | _ => apply Z.mod_small; omega + | _ => omega + end. + Qed. + + Lemma eval_sub_then_maybe_add n mask p q r: + small p -> small q -> small r -> + (map (Z.land mask) r = r) -> + (0 <= eval p < eval r) -> (0 <= eval q < eval r) -> + eval (@sub_then_maybe_add n mask p q r) = eval p - eval q + (if eval p - eval q <? 0 then eval r else 0). + Proof. + destruct n; [|solve[auto using eval_sub_then_maybe_add_nz]]. + destruct p, q, r; reflexivity. + Qed. + + Lemma small_sub_then_maybe_add n mask (p q r : T n) : + small (sub_then_maybe_add mask p q r). + Proof. + cbv [sub_then_maybe_add_cps sub_then_maybe_add]; intros. + repeat progress autounfold. autorewrite with uncps push_id. + apply small_compact. + Qed. + Lemma small_highest_zero_iff {n} (p: T (S n)) (Hsmall : small p) : (left_hd p = 0 <-> eval p < uweight bound n). Proof. @@ -1148,7 +1181,7 @@ Section API. End Proofs. End API. -Hint Rewrite join0_id divmod_id drop_high_id scmul_id add_id conditional_sub_id : uncps. +Hint Rewrite join0_id divmod_id drop_high_id scmul_id add_id sub_then_maybe_add_id conditional_sub_id : uncps. (* (* Just some pretty-printing *) |