aboutsummaryrefslogtreecommitdiff
path: root/src/Experiments/SimplyTypedArithmetic.v
diff options
context:
space:
mode:
authorGravatar Jade Philipoom <jadep@google.com>2018-03-08 15:51:21 +0100
committerGravatar jadephilipoom <jade.philipoom@gmail.com>2018-04-03 09:00:55 -0400
commitab4bea12d33f3c4225d0af577242c1bbae12d420 (patch)
tree25908bf5eb407a48f62285cd88bdc195cda7055d /src/Experiments/SimplyTypedArithmetic.v
parentdda6031883d48cb9a775be561154f09d44e6303c (diff)
move some shared lemmas between Columns/Rows into a Saturated module
Diffstat (limited to 'src/Experiments/SimplyTypedArithmetic.v')
-rw-r--r--src/Experiments/SimplyTypedArithmetic.v239
1 files changed, 126 insertions, 113 deletions
diff --git a/src/Experiments/SimplyTypedArithmetic.v b/src/Experiments/SimplyTypedArithmetic.v
index da53fcf6c..62168b9a8 100644
--- a/src/Experiments/SimplyTypedArithmetic.v
+++ b/src/Experiments/SimplyTypedArithmetic.v
@@ -617,14 +617,51 @@ Module BaseConversion.
End BaseConversion.
End BaseConversion.
-(* Non-CPS version of Arithmetic/Saturated/MulSplit.v *)
-Module MulSplit.
+Require Import Crypto.Util.ZUtil.
+Require Import Crypto.Util.ZUtil.Modulo Crypto.Util.ZUtil.Div Crypto.Util.ZUtil.Hints.Core.
+Require Import Crypto.Util.ZUtil.Hints.PullPush.
+
+Module Saturated.
+ Section Weight.
+ Context (weight : nat->Z)
+ {weight_0 : weight 0%nat = 1}
+ {weight_nonzero : forall i, weight i <> 0}
+ {weight_positive : forall i, weight i > 0}
+ {weight_multiples : forall i, weight (S i) mod weight i = 0}
+ {weight_divides : forall i : nat, weight (S i) / weight i > 0}.
+
+ Lemma weight_multiples_full' j : forall i, weight (i+j) mod weight i = 0.
+ Proof.
+ induction j; intros;
+ repeat match goal with
+ | _ => rewrite Nat.add_succ_r
+ | _ => rewrite IHj
+ | |- context [weight (S ?x) mod weight _] =>
+ rewrite (Z.div_mod (weight (S x)) (weight x)), weight_multiples by auto
+ | _ => progress autorewrite with push_Zmod natsimplify zsimplify_fast
+ | _ => reflexivity
+ end.
+ Qed.
+
+ Lemma weight_multiples_full j i : (i <= j)%nat -> weight j mod weight i = 0.
+ Proof.
+ intros; replace j with (i + (j - i))%nat by omega.
+ apply weight_multiples_full'.
+ Qed.
+
+ Lemma weight_divides_full j i : (i <= j)%nat -> weight j / weight i > 0.
+ Proof. auto using Z.div_positive_gt_0, weight_multiples_full. Qed.
+
+ Lemma weight_div_mod j i : (i <= j)%nat -> weight j = weight i * (weight j / weight i).
+ Proof. intros. apply Z.div_exact; auto using weight_multiples_full. Qed.
+ End Weight.
+
Module Associational.
Section Associational.
Definition sat_multerm s (t t' : (Z * Z)) : list (Z * Z) :=
dlet_nd xy := Z.mul_split s (snd t) (snd t') in
- [(fst t * fst t', fst xy); (fst t * fst t' * s, snd xy)].
+ [(fst t * fst t', fst xy); (fst t * fst t' * s, snd xy)].
Definition sat_mul s (p q : list (Z * Z)) : list (Z * Z) :=
flat_map (fun t => flat_map (sat_multerm s t) q) p.
@@ -646,20 +683,90 @@ Module MulSplit.
Lemma eval_sat_mul s p q (s_nonzero:s<>0):
Associational.eval (sat_mul s p q) = Associational.eval p * Associational.eval q.
Proof.
- cbv [sat_mul]; induction p; [reflexivity|].
- repeat match goal with
- | _ => progress (autorewrite with push_eval in * )
- | _ => progress simpl flat_map
- | _ => rewrite IHp
- | _ => ring_simplify; omega
- end.
+ cbv [sat_mul]; induction p; [reflexivity|].
+ repeat match goal with
+ | _ => progress (autorewrite with push_eval in * )
+ | _ => progress simpl flat_map
+ | _ => rewrite IHp
+ | _ => ring_simplify; omega
+ end.
Qed.
Hint Rewrite eval_sat_mul : push_eval.
End Associational.
End Associational.
-End MulSplit.
+
+ Section DivMod.
+ Lemma mod_step a b c d: 0 < a -> 0 < b ->
+ c mod a + a * ((c / a + d) mod b) = (a * d + c) mod (a * b).
+ Proof.
+ intros; rewrite Z.rem_mul_r by omega. push_Zmod.
+ autorewrite with zsimplify pull_Zmod. repeat (f_equal; try ring).
+ Qed.
+
+ Lemma div_step a b c d : 0 < a -> 0 < b ->
+ (c / a + d) / b = (a * d + c) / (a * b).
+ Proof.
+ intros. rewrite <- Z.div_add' by omega.
+ autorewrite with pull_Zdiv. repeat (f_equal; try ring ).
+ Qed.
+
+ Lemma add_mod_div_multiple a b n m:
+ n > 0 ->
+ 0 <= m / n ->
+ m mod n = 0 ->
+ (a / n + b) mod (m / n) = (a + n * b) mod m / n.
+ Proof.
+ intros. rewrite <-!Z.div_add' by auto using Z.positive_is_nonzero.
+ rewrite Z.mod_pull_div, Z.mul_div_eq' by auto using Z.gt_lt.
+ repeat (f_equal; try omega).
+ Qed.
+
+ Lemma add_mod_l_multiple a b n m:
+ 0 < n / m -> m <> 0 -> n mod m = 0 ->
+ (a mod n + b) mod m = (a + b) mod m.
+ Proof.
+ intros.
+ rewrite (proj2 (Z.div_exact n m ltac:(auto))) by auto.
+ rewrite Z.rem_mul_r by auto.
+ push_Zmod. autorewrite with zsimplify.
+ pull_Zmod. reflexivity.
+ Qed.
+
+ Definition is_div_mod {T} (evalf : T -> Z) dm y n :=
+ evalf (fst dm) = y mod n /\ snd dm = y / n.
+
+ Lemma is_div_mod_step {T} evalf1 evalf2 dm1 dm2 y1 y2 n1 n2 x :
+ n1 > 0 ->
+ 0 < n2 / n1 ->
+ n2 mod n1 = 0 ->
+ evalf2 (fst dm2) = evalf1 (fst dm1) + n1 * ((snd dm1 + x) mod (n2 / n1)) ->
+ snd dm2 = (snd dm1 + x) / (n2 / n1) ->
+ y2 = y1 + n1 * x ->
+ @is_div_mod T evalf1 dm1 y1 n1 ->
+ @is_div_mod T evalf2 dm2 y2 n2.
+ Proof.
+ intros; subst y2; cbv [is_div_mod] in *.
+ repeat match goal with
+ | H: _ /\ _ |- _ => destruct H
+ | H: ?LHS = _ |- _ => match LHS with context [dm2] => rewrite H end
+ | H: ?LHS = _ |- _ => match LHS with context [dm1] => rewrite H end
+ | _ => rewrite mod_step by omega
+ | _ => rewrite div_step by omega
+ | _ => rewrite Z.mul_div_eq_full by omega
+ end.
+ split; f_equal; omega.
+ Qed.
+
+ Lemma is_div_mod_result_equal {T} evalf dm y1 y2 n :
+ y1 = y2 ->
+ @is_div_mod T evalf dm y1 n ->
+ @is_div_mod T evalf dm y2 n.
+ Proof. congruence. Qed.
+ End DivMod.
+End Saturated.
Module Columns.
+ Import Saturated.
Section Columns.
Context (weight : nat->Z)
{weight_0 : weight 0%nat = 1}
@@ -679,30 +786,7 @@ Module Columns.
apply Positional.eval_snoc; distr_length.
Qed. Hint Rewrite eval_snoc using (solve [distr_length]) : push_eval.
- (* TODO: move out of Columns? *)
- Section Weight.
- Lemma weight_multiples_full' j : forall i, weight (i+j) mod weight i = 0.
- Proof.
- induction j; intros;
- repeat match goal with
- | _ => rewrite Nat.add_succ_r
- | _ => rewrite IHj
- | |- context [weight (S ?x) mod weight _] =>
- rewrite (Z.div_mod (weight (S x)) (weight x)), weight_multiples by auto
- | _ => progress autorewrite with push_Zmod natsimplify zsimplify_fast
- | _ => reflexivity
- end.
- Qed.
- Lemma weight_multiples_full j i : (i <= j)%nat -> weight j mod weight i = 0.
- Proof.
- intros; replace j with (i + (j - i))%nat by omega.
- apply weight_multiples_full'.
- Qed.
-
- Lemma weight_div_mod j i : (i <= j)%nat -> weight j = weight i * (weight j / weight i).
- Proof. intros. apply Z.div_exact; auto using weight_multiples_full. Qed.
- End Weight.
-
+ (* TODO: move to ListUtil *)
Lemma list_rect_to_match A (P:list A -> Type) (Pnil: P []) (PS: forall a tl, P (a :: tl)) ls :
@list_rect A P Pnil (fun a tl _ => PS a tl) ls = match ls with
| cons a tl => PS a tl
@@ -794,15 +878,6 @@ Module Columns.
end.
Qed. Hint Rewrite flatten_column_div using auto with zarith : to_div_mod.
- (* helper for some of the modular logic in flatten *)
- Lemma flatten_mod_step a b c d: 0 < a -> 0 < b ->
- c mod a + a * ((c / a + d) mod b) = (a * d + c) mod (a * b).
- Proof. intros; rewrite Z.rem_mul_r by omega. push_Zmod. push. Qed.
-
- Lemma flatten_div_step a b c d : 0 < a -> 0 < b ->
- (c / a + d) / b = (a * d + c) / (a * b).
- Proof. intros; push. Qed.
-
Hint Rewrite Positional.eval_nil : push_eval.
Hint Resolve Z.gt_lt.
@@ -829,7 +904,7 @@ Module Columns.
| _ => erewrite Positional.eval_snoc by (distr_length; reflexivity)
| H: _ = _ mod (weight _) |- _ => rewrite H
| H: _ = _ / (weight _) |- _ => rewrite H
- | _ => progress rewrite ?flatten_mod_step, ?flatten_div_step by auto
+ | _ => progress rewrite ?mod_step, ?div_step by auto
| _ => progress autorewrite with cancel_pair to_div_mod push_sum list push_fold_right push_eval
| _ => progress (distr_length; push_fast)
end.
@@ -860,7 +935,7 @@ Module Columns.
induction inp using rev_ind; intros; destruct n; distr_length.
rewrite flatten_snoc.
push; distr_length;
- [rewrite IHinp with (n:=n) by omega; rewrite (weight_div_mod n (S i)) by omega; push_Zmod; push |].
+ [rewrite IHinp with (n:=n) by omega; rewrite weight_div_mod with (j:=n) (i:=S i) by (eauto; omega); push_Zmod; push |].
repeat match goal with
| _ => progress replace (length inp) with n by omega
| _ => progress replace i with n by omega
@@ -940,47 +1015,14 @@ Module Columns.
Definition mul s n m (p q : list Z) : list Z :=
let p_a := Positional.to_associational weight n p in
let q_a := Positional.to_associational weight n q in
- let pq_a := MulSplit.Associational.sat_mul s p_a q_a in
+ let pq_a := Associational.sat_mul s p_a q_a in
fst (flatten (from_associational m pq_a)).
End mul.
End Columns.
End Columns.
-Module DivMod.
- Definition is_div_mod {T} (evalf : T -> Z) dm y n :=
- evalf (fst dm) = y mod n /\ snd dm = y / n.
-
- Lemma is_div_mod_step {T} evalf1 evalf2 dm1 dm2 y1 y2 n1 n2 x :
- n1 > 0 ->
- 0 < n2 / n1 ->
- n2 mod n1 = 0 ->
- evalf2 (fst dm2) = evalf1 (fst dm1) + n1 * ((snd dm1 + x) mod (n2 / n1)) ->
- snd dm2 = (snd dm1 + x) / (n2 / n1) ->
- y2 = y1 + n1 * x ->
- @is_div_mod T evalf1 dm1 y1 n1 ->
- @is_div_mod T evalf2 dm2 y2 n2.
- Proof.
- intros; subst y2; cbv [is_div_mod] in *.
- repeat match goal with
- | H: _ /\ _ |- _ => destruct H
- | H: ?LHS = _ |- _ => match LHS with context [dm2] => rewrite H end
- | H: ?LHS = _ |- _ => match LHS with context [dm1] => rewrite H end
- | _ => rewrite (Columns.flatten_mod_step (fun _ => 0)) by omega
- | _ => rewrite (Columns.flatten_div_step (fun _ => 0)) by omega
- | _ => rewrite Z.mul_div_eq_full by omega
- end.
- split; f_equal; omega.
- Qed.
-
- Lemma is_div_mod_result_equal {T} evalf dm y1 y2 n :
- y1 = y2 ->
- @is_div_mod T evalf dm y1 n ->
- @is_div_mod T evalf dm y2 n.
- Proof. congruence. Qed.
-End DivMod.
-
Module Rows.
- Import DivMod.
+ Import Saturated.
Section Rows.
Context (weight : nat->Z)
{weight_0 : weight 0%nat = 1}
@@ -1272,22 +1314,6 @@ Module Rows.
= (Positional.eval weight n row1 + Positional.eval weight n row2) / (weight n).
Proof. apply sum_rows_div_mod. Qed.
- (* TODO: figure out where to put this and weight_multiples_full *)
- Lemma weight_divides_full j i : (i <= j)%nat -> weight j / weight i > 0.
- Proof. auto using Z.div_positive_gt_0, Columns.weight_multiples_full. Qed.
-
- (* TODO: move to ZUtil *)
- Lemma add_mod_div_multiple a b n m:
- n > 0 ->
- 0 <= m / n ->
- m mod n = 0 ->
- (a / n + b) mod (m / n) = (a + n * b) mod m / n.
- Proof.
- intros. rewrite <-!Z.div_add' by auto using Z.positive_is_nonzero.
- rewrite Z.mod_pull_div, Z.mul_div_eq' by auto using Z.gt_lt.
- repeat (f_equal; try omega).
- Qed.
-
Lemma sum_rows'_partitions row1 :
forall nm start_state row2 row1' row2',
let m := length (fst start_state) in
@@ -1314,9 +1340,9 @@ Module Rows.
| H : context [nth_default 0 (fst start_state)] |- _ => rewrite H by omega
| _ => rewrite <-(Z.add_assoc _ x1 x2)
end.
- { rewrite (@Columns.flatten_div_step weight) by auto using Z.gt_lt.
+ { rewrite div_step by auto using Z.gt_lt.
rewrite Z.mul_div_eq_full by auto; rewrite weight_multiples. push. }
- { erewrite @Columns.weight_div_mod with (j:=length (fst start_state)) (i:=S j) by (eauto; omega).
+ { rewrite weight_div_mod with (j:=length (fst start_state)) (i:=S j) by (auto; omega).
push_Zmod. autorewrite with zsimplify_fast. reflexivity. }
{ push. replace (length (fst start_state)) with j in * by omega.
push. rewrite add_mod_div_multiple by auto using Z.lt_le_incl.
@@ -1447,18 +1473,6 @@ Module Rows.
subst row; distr_length; auto.
Qed. Hint Rewrite length_flatten : distr_length.
- (* TODO: move to ZUtil *)
- Lemma add_mod_l_multiple a b n m:
- 0 < n / m -> m <> 0 -> n mod m = 0 ->
- (a mod n + b) mod m = (a + b) mod m.
- Proof.
- intros.
- rewrite (proj2 (Z.div_exact n m ltac:(auto))) by auto.
- rewrite Z.rem_mul_r by auto.
- push_Zmod. autorewrite with zsimplify.
- pull_Zmod. reflexivity.
- Qed.
-
Lemma flatten'_partitions n inp : forall start_state,
length (fst start_state) = n ->
(forall row, In row inp -> length row = n) ->
@@ -1471,7 +1485,7 @@ Module Rows.
destruct (dec (inp = nil)).
{ subst inp; push. rewrite sum_rows_partitions with (n:=n) by eauto. push. }
{ erewrite IHinp; push.
- rewrite add_mod_l_multiple by auto using weight_divides_full, Columns.weight_multiples_full.
+ rewrite add_mod_l_multiple by auto using weight_divides_full, weight_multiples_full.
repeat (f_equal; try ring). }
Qed.
@@ -8291,7 +8305,6 @@ Local Open Scope expr_scope.
Print Montgomery256.montred256.
(*
-<<<<<<< HEAD
c.ShiftR($x0, $x_lo, 128);
c.Lower128($x1, $x_lo);
c.Mul128x128($x2, Lower128{RegPinv}, $x0);