diff options
author | Jade Philipoom <jadep@google.com> | 2018-03-08 15:51:21 +0100 |
---|---|---|
committer | jadephilipoom <jade.philipoom@gmail.com> | 2018-04-03 09:00:55 -0400 |
commit | ab4bea12d33f3c4225d0af577242c1bbae12d420 (patch) | |
tree | 25908bf5eb407a48f62285cd88bdc195cda7055d /src/Experiments/SimplyTypedArithmetic.v | |
parent | dda6031883d48cb9a775be561154f09d44e6303c (diff) |
move some shared lemmas between Columns/Rows into a Saturated module
Diffstat (limited to 'src/Experiments/SimplyTypedArithmetic.v')
-rw-r--r-- | src/Experiments/SimplyTypedArithmetic.v | 239 |
1 files changed, 126 insertions, 113 deletions
diff --git a/src/Experiments/SimplyTypedArithmetic.v b/src/Experiments/SimplyTypedArithmetic.v index da53fcf6c..62168b9a8 100644 --- a/src/Experiments/SimplyTypedArithmetic.v +++ b/src/Experiments/SimplyTypedArithmetic.v @@ -617,14 +617,51 @@ Module BaseConversion. End BaseConversion. End BaseConversion. -(* Non-CPS version of Arithmetic/Saturated/MulSplit.v *) -Module MulSplit. +Require Import Crypto.Util.ZUtil. +Require Import Crypto.Util.ZUtil.Modulo Crypto.Util.ZUtil.Div Crypto.Util.ZUtil.Hints.Core. +Require Import Crypto.Util.ZUtil.Hints.PullPush. + +Module Saturated. + Section Weight. + Context (weight : nat->Z) + {weight_0 : weight 0%nat = 1} + {weight_nonzero : forall i, weight i <> 0} + {weight_positive : forall i, weight i > 0} + {weight_multiples : forall i, weight (S i) mod weight i = 0} + {weight_divides : forall i : nat, weight (S i) / weight i > 0}. + + Lemma weight_multiples_full' j : forall i, weight (i+j) mod weight i = 0. + Proof. + induction j; intros; + repeat match goal with + | _ => rewrite Nat.add_succ_r + | _ => rewrite IHj + | |- context [weight (S ?x) mod weight _] => + rewrite (Z.div_mod (weight (S x)) (weight x)), weight_multiples by auto + | _ => progress autorewrite with push_Zmod natsimplify zsimplify_fast + | _ => reflexivity + end. + Qed. + + Lemma weight_multiples_full j i : (i <= j)%nat -> weight j mod weight i = 0. + Proof. + intros; replace j with (i + (j - i))%nat by omega. + apply weight_multiples_full'. + Qed. + + Lemma weight_divides_full j i : (i <= j)%nat -> weight j / weight i > 0. + Proof. auto using Z.div_positive_gt_0, weight_multiples_full. Qed. + + Lemma weight_div_mod j i : (i <= j)%nat -> weight j = weight i * (weight j / weight i). + Proof. intros. apply Z.div_exact; auto using weight_multiples_full. Qed. + End Weight. + Module Associational. Section Associational. Definition sat_multerm s (t t' : (Z * Z)) : list (Z * Z) := dlet_nd xy := Z.mul_split s (snd t) (snd t') in - [(fst t * fst t', fst xy); (fst t * fst t' * s, snd xy)]. + [(fst t * fst t', fst xy); (fst t * fst t' * s, snd xy)]. Definition sat_mul s (p q : list (Z * Z)) : list (Z * Z) := flat_map (fun t => flat_map (sat_multerm s t) q) p. @@ -646,20 +683,90 @@ Module MulSplit. Lemma eval_sat_mul s p q (s_nonzero:s<>0): Associational.eval (sat_mul s p q) = Associational.eval p * Associational.eval q. Proof. - cbv [sat_mul]; induction p; [reflexivity|]. - repeat match goal with - | _ => progress (autorewrite with push_eval in * ) - | _ => progress simpl flat_map - | _ => rewrite IHp - | _ => ring_simplify; omega - end. + cbv [sat_mul]; induction p; [reflexivity|]. + repeat match goal with + | _ => progress (autorewrite with push_eval in * ) + | _ => progress simpl flat_map + | _ => rewrite IHp + | _ => ring_simplify; omega + end. Qed. Hint Rewrite eval_sat_mul : push_eval. End Associational. End Associational. -End MulSplit. + + Section DivMod. + Lemma mod_step a b c d: 0 < a -> 0 < b -> + c mod a + a * ((c / a + d) mod b) = (a * d + c) mod (a * b). + Proof. + intros; rewrite Z.rem_mul_r by omega. push_Zmod. + autorewrite with zsimplify pull_Zmod. repeat (f_equal; try ring). + Qed. + + Lemma div_step a b c d : 0 < a -> 0 < b -> + (c / a + d) / b = (a * d + c) / (a * b). + Proof. + intros. rewrite <- Z.div_add' by omega. + autorewrite with pull_Zdiv. repeat (f_equal; try ring ). + Qed. + + Lemma add_mod_div_multiple a b n m: + n > 0 -> + 0 <= m / n -> + m mod n = 0 -> + (a / n + b) mod (m / n) = (a + n * b) mod m / n. + Proof. + intros. rewrite <-!Z.div_add' by auto using Z.positive_is_nonzero. + rewrite Z.mod_pull_div, Z.mul_div_eq' by auto using Z.gt_lt. + repeat (f_equal; try omega). + Qed. + + Lemma add_mod_l_multiple a b n m: + 0 < n / m -> m <> 0 -> n mod m = 0 -> + (a mod n + b) mod m = (a + b) mod m. + Proof. + intros. + rewrite (proj2 (Z.div_exact n m ltac:(auto))) by auto. + rewrite Z.rem_mul_r by auto. + push_Zmod. autorewrite with zsimplify. + pull_Zmod. reflexivity. + Qed. + + Definition is_div_mod {T} (evalf : T -> Z) dm y n := + evalf (fst dm) = y mod n /\ snd dm = y / n. + + Lemma is_div_mod_step {T} evalf1 evalf2 dm1 dm2 y1 y2 n1 n2 x : + n1 > 0 -> + 0 < n2 / n1 -> + n2 mod n1 = 0 -> + evalf2 (fst dm2) = evalf1 (fst dm1) + n1 * ((snd dm1 + x) mod (n2 / n1)) -> + snd dm2 = (snd dm1 + x) / (n2 / n1) -> + y2 = y1 + n1 * x -> + @is_div_mod T evalf1 dm1 y1 n1 -> + @is_div_mod T evalf2 dm2 y2 n2. + Proof. + intros; subst y2; cbv [is_div_mod] in *. + repeat match goal with + | H: _ /\ _ |- _ => destruct H + | H: ?LHS = _ |- _ => match LHS with context [dm2] => rewrite H end + | H: ?LHS = _ |- _ => match LHS with context [dm1] => rewrite H end + | _ => rewrite mod_step by omega + | _ => rewrite div_step by omega + | _ => rewrite Z.mul_div_eq_full by omega + end. + split; f_equal; omega. + Qed. + + Lemma is_div_mod_result_equal {T} evalf dm y1 y2 n : + y1 = y2 -> + @is_div_mod T evalf dm y1 n -> + @is_div_mod T evalf dm y2 n. + Proof. congruence. Qed. + End DivMod. +End Saturated. Module Columns. + Import Saturated. Section Columns. Context (weight : nat->Z) {weight_0 : weight 0%nat = 1} @@ -679,30 +786,7 @@ Module Columns. apply Positional.eval_snoc; distr_length. Qed. Hint Rewrite eval_snoc using (solve [distr_length]) : push_eval. - (* TODO: move out of Columns? *) - Section Weight. - Lemma weight_multiples_full' j : forall i, weight (i+j) mod weight i = 0. - Proof. - induction j; intros; - repeat match goal with - | _ => rewrite Nat.add_succ_r - | _ => rewrite IHj - | |- context [weight (S ?x) mod weight _] => - rewrite (Z.div_mod (weight (S x)) (weight x)), weight_multiples by auto - | _ => progress autorewrite with push_Zmod natsimplify zsimplify_fast - | _ => reflexivity - end. - Qed. - Lemma weight_multiples_full j i : (i <= j)%nat -> weight j mod weight i = 0. - Proof. - intros; replace j with (i + (j - i))%nat by omega. - apply weight_multiples_full'. - Qed. - - Lemma weight_div_mod j i : (i <= j)%nat -> weight j = weight i * (weight j / weight i). - Proof. intros. apply Z.div_exact; auto using weight_multiples_full. Qed. - End Weight. - + (* TODO: move to ListUtil *) Lemma list_rect_to_match A (P:list A -> Type) (Pnil: P []) (PS: forall a tl, P (a :: tl)) ls : @list_rect A P Pnil (fun a tl _ => PS a tl) ls = match ls with | cons a tl => PS a tl @@ -794,15 +878,6 @@ Module Columns. end. Qed. Hint Rewrite flatten_column_div using auto with zarith : to_div_mod. - (* helper for some of the modular logic in flatten *) - Lemma flatten_mod_step a b c d: 0 < a -> 0 < b -> - c mod a + a * ((c / a + d) mod b) = (a * d + c) mod (a * b). - Proof. intros; rewrite Z.rem_mul_r by omega. push_Zmod. push. Qed. - - Lemma flatten_div_step a b c d : 0 < a -> 0 < b -> - (c / a + d) / b = (a * d + c) / (a * b). - Proof. intros; push. Qed. - Hint Rewrite Positional.eval_nil : push_eval. Hint Resolve Z.gt_lt. @@ -829,7 +904,7 @@ Module Columns. | _ => erewrite Positional.eval_snoc by (distr_length; reflexivity) | H: _ = _ mod (weight _) |- _ => rewrite H | H: _ = _ / (weight _) |- _ => rewrite H - | _ => progress rewrite ?flatten_mod_step, ?flatten_div_step by auto + | _ => progress rewrite ?mod_step, ?div_step by auto | _ => progress autorewrite with cancel_pair to_div_mod push_sum list push_fold_right push_eval | _ => progress (distr_length; push_fast) end. @@ -860,7 +935,7 @@ Module Columns. induction inp using rev_ind; intros; destruct n; distr_length. rewrite flatten_snoc. push; distr_length; - [rewrite IHinp with (n:=n) by omega; rewrite (weight_div_mod n (S i)) by omega; push_Zmod; push |]. + [rewrite IHinp with (n:=n) by omega; rewrite weight_div_mod with (j:=n) (i:=S i) by (eauto; omega); push_Zmod; push |]. repeat match goal with | _ => progress replace (length inp) with n by omega | _ => progress replace i with n by omega @@ -940,47 +1015,14 @@ Module Columns. Definition mul s n m (p q : list Z) : list Z := let p_a := Positional.to_associational weight n p in let q_a := Positional.to_associational weight n q in - let pq_a := MulSplit.Associational.sat_mul s p_a q_a in + let pq_a := Associational.sat_mul s p_a q_a in fst (flatten (from_associational m pq_a)). End mul. End Columns. End Columns. -Module DivMod. - Definition is_div_mod {T} (evalf : T -> Z) dm y n := - evalf (fst dm) = y mod n /\ snd dm = y / n. - - Lemma is_div_mod_step {T} evalf1 evalf2 dm1 dm2 y1 y2 n1 n2 x : - n1 > 0 -> - 0 < n2 / n1 -> - n2 mod n1 = 0 -> - evalf2 (fst dm2) = evalf1 (fst dm1) + n1 * ((snd dm1 + x) mod (n2 / n1)) -> - snd dm2 = (snd dm1 + x) / (n2 / n1) -> - y2 = y1 + n1 * x -> - @is_div_mod T evalf1 dm1 y1 n1 -> - @is_div_mod T evalf2 dm2 y2 n2. - Proof. - intros; subst y2; cbv [is_div_mod] in *. - repeat match goal with - | H: _ /\ _ |- _ => destruct H - | H: ?LHS = _ |- _ => match LHS with context [dm2] => rewrite H end - | H: ?LHS = _ |- _ => match LHS with context [dm1] => rewrite H end - | _ => rewrite (Columns.flatten_mod_step (fun _ => 0)) by omega - | _ => rewrite (Columns.flatten_div_step (fun _ => 0)) by omega - | _ => rewrite Z.mul_div_eq_full by omega - end. - split; f_equal; omega. - Qed. - - Lemma is_div_mod_result_equal {T} evalf dm y1 y2 n : - y1 = y2 -> - @is_div_mod T evalf dm y1 n -> - @is_div_mod T evalf dm y2 n. - Proof. congruence. Qed. -End DivMod. - Module Rows. - Import DivMod. + Import Saturated. Section Rows. Context (weight : nat->Z) {weight_0 : weight 0%nat = 1} @@ -1272,22 +1314,6 @@ Module Rows. = (Positional.eval weight n row1 + Positional.eval weight n row2) / (weight n). Proof. apply sum_rows_div_mod. Qed. - (* TODO: figure out where to put this and weight_multiples_full *) - Lemma weight_divides_full j i : (i <= j)%nat -> weight j / weight i > 0. - Proof. auto using Z.div_positive_gt_0, Columns.weight_multiples_full. Qed. - - (* TODO: move to ZUtil *) - Lemma add_mod_div_multiple a b n m: - n > 0 -> - 0 <= m / n -> - m mod n = 0 -> - (a / n + b) mod (m / n) = (a + n * b) mod m / n. - Proof. - intros. rewrite <-!Z.div_add' by auto using Z.positive_is_nonzero. - rewrite Z.mod_pull_div, Z.mul_div_eq' by auto using Z.gt_lt. - repeat (f_equal; try omega). - Qed. - Lemma sum_rows'_partitions row1 : forall nm start_state row2 row1' row2', let m := length (fst start_state) in @@ -1314,9 +1340,9 @@ Module Rows. | H : context [nth_default 0 (fst start_state)] |- _ => rewrite H by omega | _ => rewrite <-(Z.add_assoc _ x1 x2) end. - { rewrite (@Columns.flatten_div_step weight) by auto using Z.gt_lt. + { rewrite div_step by auto using Z.gt_lt. rewrite Z.mul_div_eq_full by auto; rewrite weight_multiples. push. } - { erewrite @Columns.weight_div_mod with (j:=length (fst start_state)) (i:=S j) by (eauto; omega). + { rewrite weight_div_mod with (j:=length (fst start_state)) (i:=S j) by (auto; omega). push_Zmod. autorewrite with zsimplify_fast. reflexivity. } { push. replace (length (fst start_state)) with j in * by omega. push. rewrite add_mod_div_multiple by auto using Z.lt_le_incl. @@ -1447,18 +1473,6 @@ Module Rows. subst row; distr_length; auto. Qed. Hint Rewrite length_flatten : distr_length. - (* TODO: move to ZUtil *) - Lemma add_mod_l_multiple a b n m: - 0 < n / m -> m <> 0 -> n mod m = 0 -> - (a mod n + b) mod m = (a + b) mod m. - Proof. - intros. - rewrite (proj2 (Z.div_exact n m ltac:(auto))) by auto. - rewrite Z.rem_mul_r by auto. - push_Zmod. autorewrite with zsimplify. - pull_Zmod. reflexivity. - Qed. - Lemma flatten'_partitions n inp : forall start_state, length (fst start_state) = n -> (forall row, In row inp -> length row = n) -> @@ -1471,7 +1485,7 @@ Module Rows. destruct (dec (inp = nil)). { subst inp; push. rewrite sum_rows_partitions with (n:=n) by eauto. push. } { erewrite IHinp; push. - rewrite add_mod_l_multiple by auto using weight_divides_full, Columns.weight_multiples_full. + rewrite add_mod_l_multiple by auto using weight_divides_full, weight_multiples_full. repeat (f_equal; try ring). } Qed. @@ -8291,7 +8305,6 @@ Local Open Scope expr_scope. Print Montgomery256.montred256. (* -<<<<<<< HEAD c.ShiftR($x0, $x_lo, 128); c.Lower128($x1, $x_lo); c.Mul128x128($x2, Lower128{RegPinv}, $x0); |