From 6312b9ac8f252dabc190b568c7716d1d3e492b6e Mon Sep 17 00:00:00 2001 From: jadep Date: Thu, 29 Jun 2017 09:43:09 -0400 Subject: new add/carry chain logic with admitted proofs --- src/Arithmetic/Saturated.v | 363 +++++++++++++++++++++++---------------------- 1 file changed, 183 insertions(+), 180 deletions(-) (limited to 'src/Arithmetic/Saturated.v') diff --git a/src/Arithmetic/Saturated.v b/src/Arithmetic/Saturated.v index a63208e9e..a7b9e6484 100644 --- a/src/Arithmetic/Saturated.v +++ b/src/Arithmetic/Saturated.v @@ -171,93 +171,6 @@ Hint Opaque Associational.mul Associational.multerm : uncps. Hint Rewrite @Associational.mul_id @Associational.multerm_id : uncps. Hint Rewrite @Associational.eval_mul @Associational.eval_map_multerm using (omega || assumption) : push_basesystem_eval. -Module Positional. - Section Positional. - Context (weight : nat->Z) {s:Z}. (* s is number at which to split *) - Section GenericOp. - Context {op : Z -> Z -> Z} - {op_get_carry : Z -> Z -> Z * Z} (* no carry in, carry out *) - {op_with_carry : Z -> Z -> Z -> Z * Z}. (* carry in, carry out *) - - Fixpoint chain_op'_cps {n}: - option Z->Z^n->Z^n->forall T, (Z*Z^n->T)->T := - match n with - | O => fun c p _ _ f => - let carry := match c with | None => 0 | Some x => x end in - f (carry,p) - | S n' => - fun c p q _ f => - (* for the first call, use op_get_carry, then op_with_carry *) - let op' := match c with - | None => op_get_carry - | Some x => op_with_carry x end in - dlet carry_result := op' (hd p) (hd q) in - chain_op'_cps (Some (fst carry_result)) (tl p) (tl q) _ - (fun carry_pq => - f (fst carry_pq, - append (snd carry_result) (snd carry_pq))) - end. - Definition chain_op' {n} c p q := @chain_op'_cps n c p q _ id. - Definition chain_op_cps {n} p q {T} f := @chain_op'_cps n None p q T f. - Definition chain_op {n} p q : Z * Z^n := chain_op_cps p q id. - - Lemma chain_op'_id {n} : forall c p q T f, - @chain_op'_cps n c p q T f = f (chain_op' c p q). - Proof. - cbv [chain_op']; induction n; intros; destruct c; - simpl chain_op'_cps; cbv [Let_In]; try reflexivity. - { etransitivity; rewrite IHn; reflexivity. } - { etransitivity; rewrite IHn; reflexivity. } - Qed. - - Lemma chain_op_id {n} p q T f : - @chain_op_cps n p q T f = f (chain_op p q). - Proof. apply chain_op'_id. Qed. - End GenericOp. - - Section AddSub. - Local Definition eval {n} := B.Positional.eval (n:=n) weight. - - Definition sat_add_cps {n} p q T (f:Z*Z^n->T) := - chain_op_cps (op_get_carry := Z.add_get_carry_full s) - (op_with_carry := Z.add_with_get_carry_full s) - p q f. - Definition sat_add {n} p q := @sat_add_cps n p q _ id. - - Lemma sat_add_id n p q T f : - @sat_add_cps n p q T f = f (sat_add p q). - Proof. cbv [sat_add sat_add_cps]. rewrite !chain_op_id. reflexivity. Qed. - - Lemma sat_add_mod n p q : - eval (snd (@sat_add n p q)) = (eval p + eval q) mod s. - Admitted. - - Lemma sat_add_div n p q : - fst (@sat_add n p q) = (eval p + eval q) / s. - Admitted. - - Definition sat_sub_cps {n} p q T (f:Z*Z^n->T) := - chain_op_cps (op_get_carry := Z.sub_get_borrow_full s) - (op_with_carry := Z.sub_with_get_borrow_full s) - p q f. - Definition sat_sub {n} p q := @sat_sub_cps n p q _ id. - - Lemma sat_sub_id n p q T f : - @sat_sub_cps n p q T f = f (sat_sub p q). - Proof. cbv [sat_sub sat_sub_cps]. rewrite !chain_op_id. reflexivity. Qed. - - Lemma sat_sub_mod n p q : - eval (snd (@sat_sub n p q)) = (eval p - eval q) mod s. - Admitted. - - Lemma sat_sub_div n p q : - fst (@sat_sub n p q) = - ((eval p - eval q) / s). - Admitted. - - End Add. - End Positional. -End Positional. - Module Columns. Section Columns. @@ -795,8 +708,108 @@ Section UniformWeight. ring. Qed. + Definition small {n} (p : Z^n) : Prop := + forall x, In x (to_list _ p) -> 0 <= x < bound. + End UniformWeight. +Module Positional. + Section Positional. + Context {s:Z}. (* s is bitwidth *) + Let small {n} := @small s n. + Section GenericOp. + Context {op : Z -> Z -> Z} + {op_get_carry : Z -> Z -> Z * Z} (* no carry in, carry out *) + {op_with_carry : Z -> Z -> Z -> Z * Z}. (* carry in, carry out *) + + Fixpoint chain_op'_cps {n}: + option Z->Z^n->Z^n->forall T, (Z*Z^n->T)->T := + match n with + | O => fun c p _ _ f => + let carry := match c with | None => 0 | Some x => x end in + f (carry,p) + | S n' => + fun c p q _ f => + (* for the first call, use op_get_carry, then op_with_carry *) + let op' := match c with + | None => op_get_carry + | Some x => op_with_carry x end in + dlet carry_result := op' (hd p) (hd q) in + chain_op'_cps (Some (fst carry_result)) (tl p) (tl q) _ + (fun carry_pq => + f (fst carry_pq, + append (snd carry_result) (snd carry_pq))) + end. + Definition chain_op' {n} c p q := @chain_op'_cps n c p q _ id. + Definition chain_op_cps {n} p q {T} f := @chain_op'_cps n None p q T f. + Definition chain_op {n} p q : Z * Z^n := chain_op_cps p q id. + + Lemma chain_op'_id {n} : forall c p q T f, + @chain_op'_cps n c p q T f = f (chain_op' c p q). + Proof. + cbv [chain_op']; induction n; intros; destruct c; + simpl chain_op'_cps; cbv [Let_In]; try reflexivity. + { etransitivity; rewrite IHn; reflexivity. } + { etransitivity; rewrite IHn; reflexivity. } + Qed. + + Lemma chain_op_id {n} p q T f : + @chain_op_cps n p q T f = f (chain_op p q). + Proof. apply chain_op'_id. Qed. + End GenericOp. + + Section AddSub. + Let eval {n} := B.Positional.eval (n:=n) (uweight s). + + Definition sat_add_cps {n} p q T (f:Z*Z^n->T) := + chain_op_cps (op_get_carry := Z.add_get_carry_full s) + (op_with_carry := Z.add_with_get_carry_full s) + p q f. + Definition sat_add {n} p q := @sat_add_cps n p q _ id. + + Lemma sat_add_id n p q T f : + @sat_add_cps n p q T f = f (sat_add p q). + Proof. cbv [sat_add sat_add_cps]. rewrite !chain_op_id. reflexivity. Qed. + + Lemma sat_add_mod n p q : + eval (snd (@sat_add n p q)) = (eval p + eval q) mod (uweight s n). + Admitted. + + Lemma sat_add_div n p q : + fst (@sat_add n p q) = (eval p + eval q) / (uweight s n). + Admitted. + + Lemma small_sat_add n p q : small (snd (@sat_add n p q)). + Admitted. + + Definition sat_sub_cps {n} p q T (f:Z*Z^n->T) := + chain_op_cps (op_get_carry := Z.sub_get_borrow_full s) + (op_with_carry := Z.sub_with_get_borrow_full s) + p q f. + Definition sat_sub {n} p q := @sat_sub_cps n p q _ id. + + Lemma sat_sub_id n p q T f : + @sat_sub_cps n p q T f = f (sat_sub p q). + Proof. cbv [sat_sub sat_sub_cps]. rewrite !chain_op_id. reflexivity. Qed. + + Lemma sat_sub_mod n p q : + eval (snd (@sat_sub n p q)) = (eval p - eval q) mod (uweight s n). + Admitted. + + Lemma sat_sub_div n p q : + fst (@sat_sub n p q) = - ((eval p - eval q) / uweight s n). + Admitted. + + Lemma small_sat_sub n p q : small (snd (@sat_sub n p q)). + Admitted. + + End AddSub. + End Positional. +End Positional. +Hint Opaque Positional.sat_sub Positional.sat_add Positional.chain_op Positional.chain_op' : uncps. +Hint Rewrite @Positional.sat_sub_id @Positional.sat_add_id @Positional.chain_op_id @Positional.chain_op' : uncps. +Hint Rewrite @Positional.sat_sub_mod @Positional.sat_sub_div @Positional.sat_add_mod @Positional.sat_add_div using (omega || assumption) : push_basesystem_eval. + Section API. Context (bound : Z) {bound_pos : bound > 0}. Definition T : nat -> Type := tuple Z. @@ -804,8 +817,7 @@ Section API. (* lowest limb is less than its bound; this is required for [divmod] to simply separate the lowest limb from the rest and be equivalent to normal div/mod with [bound]. *) - Definition small {n} (p : T n) : Prop := - forall x, In x (to_list _ p) -> 0 <= x < bound. + Local Notation small := (@small bound). Definition zero {n:nat} : T n := B.Positional.zeros n. @@ -835,21 +847,33 @@ Section API. (fun carry_result =>f (snd carry_result)). Definition scmul {n} c p : T (S n) := @scmul_cps n c p _ id. - Definition add_cps {n m pred_nm} (p : T n) (q : T m) {R} (f:T (S pred_nm)->R) := - Columns.add_cps (uweight bound) p q - (fun carry_result => Tuple.left_append_cps (fst carry_result) (snd carry_result) - f). - Definition add {n m pred_nm} p q : T (S pred_nm) := @add_cps n m pred_nm p q _ id. - - Definition sub_then_maybe_add_cps {n} mask (p q r : T n) {R} (f:T n -> R) := - Columns.unbalanced_sub_cps (n3:=n) (uweight bound) p q + Definition add_cps {n} (p q: T n) {R} (f:T n->R) := + Positional.sat_add_cps (s:=bound) p q _ + (* drops last carry, this relies on bounds *) + (fun carry_result => f (snd carry_result)). + Definition add {n} p q : T n := @add_cps n p q _ id. + + (* Wrappers for additions with slightly uneven limb counts *) + Definition add_S1_cps {n} (p: T (S n)) (q: T n) {R} (f:T (S n)->R) := + join0_cps q (fun Q => add_cps p Q f). + Definition add_S1 {n} p q := @add_S1_cps n p q _ id. + Definition add_S2_cps {n} (p: T n) (q: T (S n)) {R} (f:T (S n)->R) := + join0_cps p (fun P => add_cps P q f). + Definition add_S2 {n} p q := @add_S2_cps n p q _ id. + + Definition sub_then_maybe_add_cps {n} mask (p q r : T n) + {R} (f:T n -> R) := + Positional.sat_sub_cps (s:=bound) p q _ (* the carry will be 0 unless we underflow--we do the addition only in the underflow case *) (fun carry_result => - Columns.conditional_add_cps (uweight bound) mask (fst carry_result) (left_append (fst carry_result) (snd carry_result)) r - (* We can now safely discard the carry. This relies on the - precondition that p - q + r < bound^n. *) - (fun carry_result' => f (snd carry_result'))). + B.Positional.select_cps mask (fst carry_result) r + (fun selected => join0_cps selected + (fun selected' => + Positional.sat_sub_cps (s:=bound) (left_append (fst carry_result) (snd carry_result)) selected' _ + (* We can now safely discard the carry and the highest digit. + This relies on the precondition that p - q + r < bound^n. *) + (fun carry_result' => drop_high_cps (snd carry_result') f)))). Definition sub_then_maybe_add {n} mask (p q r : T n) := sub_then_maybe_add_cps mask p q r id. @@ -857,14 +881,15 @@ Section API. that 0 <= p < 2*q and q < bound^n (this ensures the output is less than bound^n). *) Definition conditional_sub_cps {n} (p:Z^S n) (q:Z^n) R (f:Z^n->R) := - Columns.unbalanced_sub_cps (n3:=S n) (uweight bound) p q + join0_cps q + (fun qq => Positional.sat_sub_cps (s:=bound) p qq _ (* if carry is zero, we select the result of the subtraction, otherwise the first input *) (fun carry_result => Tuple.map2_cps (Z.zselect (fst carry_result)) (snd carry_result) p (* in either case, since our result must be < q and therefore < bound^n, we can drop the high digit *) - (fun r => drop_high_cps r f)). + (fun r => drop_high_cps r f))). Definition conditional_sub {n} p q := @conditional_sub_cps n p q _ id. Hint Opaque join0 divmod drop_high scmul add sub_then_maybe_add conditional_sub : uncps. @@ -894,17 +919,26 @@ Section API. @scmul_cps n c p R f = f (scmul c p). Proof. cbv [scmul_cps scmul]. prove_id. Qed. - Lemma add_id n m pred_nm p q R f : - @add_cps n m pred_nm p q R f = f (add p q). + Lemma add_id n p q R f : + @add_cps n p q R f = f (add p q). Proof. cbv [add_cps add Let_In]. prove_id. Qed. + Hint Rewrite add_id : uncps. + + Lemma add_S1_id n p q R f : + @add_S1_cps n p q R f = f (add_S1 p q). + Proof. cbv [add_S1_cps add_S1 join0_cps]. prove_id. Qed. + + Lemma add_S2_id n p q R f : + @add_S2_cps n p q R f = f (add_S2 p q). + Proof. cbv [add_S2_cps add_S2 join0_cps]. prove_id. Qed. Lemma sub_then_maybe_add_id n mask p q r R f : @sub_then_maybe_add_cps n mask p q r R f = f (sub_then_maybe_add mask p q r). - Proof. cbv [sub_then_maybe_add_cps sub_then_maybe_add Let_In]. prove_id. Qed. + Proof. cbv [sub_then_maybe_add_cps sub_then_maybe_add join0_cps Let_In]. prove_id. Qed. Lemma conditional_sub_id n p q R f : @conditional_sub_cps n p q R f = f (conditional_sub p q). - Proof. cbv [conditional_sub_cps conditional_sub Let_In]. prove_id. Qed. + Proof. cbv [conditional_sub_cps conditional_sub join0_cps Let_In]. prove_id. Qed. End CPSProofs. Hint Rewrite nonzero_id join0_id divmod_id drop_high_id scmul_id add_id sub_then_maybe_add_id conditional_sub_id : uncps. @@ -1005,13 +1039,13 @@ Section API. pose proof Z.mul_split_div; pose proof Z.mul_split_mod; pose proof div_correct; pose proof modulo_correct. - Lemma eval_add_nz n m pred_nm p q : - pred_nm <> 0%nat -> - eval (@add n m pred_nm p q) = eval p + eval q. + Lemma eval_add_nz n p q : + n <> 0%nat -> (0 <= eval p + eval q < uweight bound n) -> + eval (@add n p q) = eval p + eval q. Proof. intros. pose_all. repeat match goal with - | _ => progress (cbv [add_cps add eval Let_In]; repeat autounfold) + | _ => progress (cbv [add_cps add eval Let_In] in *; repeat autounfold) | _ => progress autorewrite with uncps push_id cancel_pair push_basesystem_eval | _ => rewrite B.Positional.eval_left_append @@ -1019,33 +1053,36 @@ Section API. (rewrite <-!from_list_default_eq with (d:=0); erewrite !length_to_list, !from_list_default_eq, from_list_to_list) - | _ => rewrite Z.mul_div_eq by auto; - omega + | _ => apply Z.mod_small; omega end. Qed. - Lemma eval_add_z n m pred_nm p q : - pred_nm = 0%nat -> + Lemma eval_add_z n p q : n = 0%nat -> - m = 0%nat -> - eval (@add n m pred_nm p q) = eval p + eval q. + eval (@add n p q) = eval p + eval q. Proof. intros; subst; reflexivity. Qed. - Lemma eval_add n m pred_nm p q (Hpred_nm : pred_nm = 0%nat -> n = 0%nat /\ m = 0%nat) - : eval (@add n m pred_nm p q) = eval p + eval q. + Lemma eval_add n p q (H:0 <= eval p + eval q < uweight bound n) + : eval (@add n p q) = eval p + eval q. Proof. - destruct (Nat.eq_dec pred_nm 0%nat); intuition auto using eval_add_z, eval_add_nz. + destruct (Nat.eq_dec n 0%nat); intuition auto using eval_add_z, eval_add_nz. Qed. - Lemma eval_add_same n p q - : eval (@add n n n p q) = eval p + eval q. - Proof. apply eval_add; omega. Qed. - Lemma eval_add_S1 n p q - : eval (@add (S n) n (S n) p q) = eval p + eval q. + Lemma eval_add_same n p q (H:0 <= eval p + eval q < uweight bound n) + : eval (@add n p q) = eval p + eval q. Proof. apply eval_add; omega. Qed. - Lemma eval_add_S2 n p q - : eval (@add n (S n) (S n) p q) = eval p + eval q. - Proof. apply eval_add; omega. Qed. - Hint Rewrite eval_add_same eval_add_S1 eval_add_S2 : push_basesystem_eval. + Lemma eval_add_S1 n p q (H:0 <= eval p + eval q < uweight bound (S n)) + : eval (@add_S1 n p q) = eval p + eval q. + Proof. + cbv [add_S1 add_S1_cps]. autorewrite with uncps push_id. + rewrite eval_add; rewrite eval_join0; [reflexivity|assumption]. + Qed. + Lemma eval_add_S2 n p q (H:0 <= eval p + eval q < uweight bound (S n)) + : eval (@add_S2 n p q) = eval p + eval q. + Proof. + cbv [add_S2 add_S2_cps]. autorewrite with uncps push_id. + rewrite eval_add; rewrite eval_join0; [reflexivity|assumption]. + Qed. + Hint Rewrite eval_add_same eval_add_S1 eval_add_S2 using (omega || assumption): push_basesystem_eval. Lemma uweight_le_mono n m : (n <= m)%nat -> uweight bound n <= uweight bound m. @@ -1104,43 +1141,15 @@ Section API. apply Z.mod_pos_bound, Z.gt_lt, bound_pos. } Qed. - Lemma small_add n m pred_nm a b : - (2 <= bound) -> (max n m <= pred_nm)%nat -> - small a -> small b -> small (@add n m pred_nm a b). + Lemma small_add n a b : + (2 <= bound) -> + small a -> small b -> small (@add n a b). Proof. intros. pose_all. - cbv [small add_cps add Let_In]. repeat autounfold. + cbv [add_cps add Let_In]. autorewrite with uncps push_id. - rewrite Tuple.to_list_left_append. - let H := fresh "H" in intros x H; apply in_app_or in H; - destruct H as [H | H]; - [apply (small_compact _ _ H) - |destruct H; [|exfalso; assumption] ]. - subst x. - rewrite Columns.compact_div by assumption. - repeat match goal with - H : small ?x |- _ => apply eval_small in H; cbv [eval] in H - end. - destruct pred_nm as [|pred_pred_nm]; autorewrite with push_basesystem_eval; - repeat match goal with - | [ H : (max ?x ?y <= 0)%nat |- _ ] - => assert (x = 0%nat) by omega *; - assert (y = 0%nat) by omega *; - clear H - | _ => progress subst - end. - { destruct bound; cbv -[Z.le Z.lt]; lia. } - split; [ solve [ unfold uweight in *; Z.zero_bounds ] | ]. - apply Zdiv_lt_upper_bound; [ solve [ unfold uweight in *; Z.zero_bounds ] | ]. - apply Z.lt_le_trans with (m:=uweight bound n + uweight bound m); - [omega|]. - apply Z.le_trans with (m:=uweight bound (max n m) + uweight bound (max n m)); auto using Z.add_le_mono, uweight_le_mono, Max.le_max_l, Max.le_max_r. - rewrite Z.add_diag. - pose proof (uweight_le_mono (max n m) (S pred_pred_nm)). - specialize_by_assumption. - apply Z.mul_le_mono_nonneg; try omega. - apply Max.max_case_strong; omega. - Qed. + apply Positional.small_sat_add. + Admitted. Lemma small_left_tl n (v:T (S n)) : small v -> small (left_tl v). Proof. cbv [small]. auto using Tuple.In_to_list_left_tl. Qed. @@ -1190,29 +1199,22 @@ Section API. | _ => progress (cbv [sub_then_maybe_add sub_then_maybe_add_cps eval] in *; intros) | _ => progress autounfold | _ => progress autorewrite with uncps push_id push_basesystem_eval + | _ => rewrite eval_drop_high + | _ => rewrite eval_join0 | H : small _ |- _ => apply eval_small in H | _ => progress break_match | _ => (rewrite Z.add_opp_r in * ) - | _ => progress autorewrite with zsimplify; [ ] | H : _ |- _ => rewrite Z.ltb_lt in H; rewrite <-div_nonzero_neg_iff with (y:=uweight bound n) in H by (auto; omega) | H : _ |- _ => rewrite Z.ltb_ge in H | _ => rewrite Z.mod_small by omega | _ => omega - end; - repeat match goal with - | H : _ |- _ => rewrite div_nonzero_neg_iff in H - by (auto; omega) - | |- context [-?x + ?y mod ?x] => - replace (-x + y mod x) with y - by (rewrite Z.mod_eq, Z.div_small_neg; omega) - | _ => apply Z.mod_small; omega - | _ => omega + | _ => progress autorewrite with zsimplify; [ ] end. - Qed. + Admitted. - Lemma eval_sub_then_maybe_add n mask p q r: + Lemma eval_sub_then_maybe_add n mask p q r : small p -> small q -> small r -> (map (Z.land mask) r = r) -> (0 <= eval p < eval r) -> (0 <= eval q < eval r) -> @@ -1227,7 +1229,7 @@ Section API. Proof. cbv [sub_then_maybe_add_cps sub_then_maybe_add]; intros. repeat progress autounfold. autorewrite with uncps push_id. - apply small_compact. + apply small_drop_high, Positional.small_sat_sub. Qed. (* TODO : remove if unneeded when all admits are proven @@ -1277,12 +1279,13 @@ Section API. | _ => progress cbv [eval] | H : (_ <=? _) = true |- _ => apply Z.leb_le in H | H : (_ <=? _) = false |- _ => apply Z.leb_gt in H - | _ => rewrite eval_drop_high by auto using small_compact + | _ => rewrite eval_drop_high by auto using Positional.small_sat_sub + | _ => (rewrite eval_join0 in * ) | _ => progress autorewrite with uncps push_id push_basesystem_eval - | _ => rewrite Z.mod_small by (rewrite ?Z.mod_small; omega) + | _ => repeat rewrite Z.mod_small; omega | _ => omega end. - Qed. + Admitted. Lemma eval_conditional_sub n (p:T (S n)) (q:T n) (psmall : small p) (qsmall : small q) : @@ -1391,4 +1394,4 @@ Definition base2pow25p5 i := Eval compute in 2^(25*Z.of_nat i + ((Z.of_nat i + 1 Time Eval cbv -[runtime_add runtime_mul Let_In] in (fun adc a0 a1 a2 a3 a4 a5 a6 a7 a8 a9 b0 b1 b2 b3 b4 b5 b6 b7 b8 b9 => Columns.mul_cps (weight := base2pow25p5) (n:=10) (a9,a8,a7,a6,a5,a4,a3,a2,a1,a0) (b9,b8,b7,b6,b5,b4,b3,b2,b1,b0) (fun ab => Columns.compact (n:=19) (add_get_carry:=adc) (weight:=base2pow25p5) ab)). (* Finished transaction in 97.341 secs *) -*) +*) \ No newline at end of file -- cgit v1.2.3