aboutsummaryrefslogtreecommitdiff
path: root/src/Experiments/NewPipeline
diff options
context:
space:
mode:
authorGravatar jadep <jade.philipoom@gmail.com>2018-09-15 21:06:28 -0400
committerGravatar jadephilipoom <jade.philipoom@gmail.com>2018-09-17 21:34:36 -0400
commit663864c832e2e94d87b8d19bd8163bbd4c4293a3 (patch)
tree62cf56584f002411c5e65956ccd4250e7f90ca14 /src/Experiments/NewPipeline
parent49395ae4814a31abc055f0bd9d026a3daa3e33f4 (diff)
redo all Rows correctness proofs using partition and sanity, remove now-unused Saturated.DivMod
Diffstat (limited to 'src/Experiments/NewPipeline')
-rw-r--r--src/Experiments/NewPipeline/Arithmetic.v281
1 files changed, 68 insertions, 213 deletions
diff --git a/src/Experiments/NewPipeline/Arithmetic.v b/src/Experiments/NewPipeline/Arithmetic.v
index 8461e13d4..07f2b3a79 100644
--- a/src/Experiments/NewPipeline/Arithmetic.v
+++ b/src/Experiments/NewPipeline/Arithmetic.v
@@ -1288,72 +1288,6 @@ Module Saturated.
Hint Rewrite eval_sat_mul_const : push_eval.
End Associational.
End Associational.
-
- Section DivMod.
- Lemma mod_step a b c d: 0 < a -> 0 < b ->
- c mod a + a * ((c / a + d) mod b) = (a * d + c) mod (a * b).
- Proof using Type.
- intros; rewrite Z.rem_mul_r by omega. push_Zmod.
- autorewrite with zsimplify pull_Zmod. repeat (f_equal; try ring).
- Qed.
-
- Lemma div_step a b c d : 0 < a -> 0 < b ->
- (c / a + d) / b = (a * d + c) / (a * b).
- Proof using Type. intros; Z.div_mod_to_quot_rem_in_goal; nia. Qed.
-
- Lemma add_mod_div_multiple a b n m:
- n > 0 ->
- 0 <= m / n ->
- m mod n = 0 ->
- (a / n + b) mod (m / n) = (a + n * b) mod m / n.
- Proof using Type.
- intros. rewrite <-!Z.div_add' by auto using Z.positive_is_nonzero.
- rewrite Z.mod_pull_div, Z.mul_div_eq' by auto using Z.gt_lt.
- repeat (f_equal; try omega).
- Qed.
-
- Lemma add_mod_l_multiple a b n m:
- 0 < n / m -> m <> 0 -> n mod m = 0 ->
- (a mod n + b) mod m = (a + b) mod m.
- Proof using Type.
- intros.
- rewrite (proj2 (Z.div_exact n m ltac:(auto))) by auto.
- rewrite Z.rem_mul_r by auto.
- push_Zmod. autorewrite with zsimplify.
- pull_Zmod. reflexivity.
- Qed.
-
- Definition is_div_mod {T} (evalf : T -> Z) dm y n :=
- evalf (fst dm) = y mod n /\ snd dm = y / n.
-
- Lemma is_div_mod_step {T} evalf1 evalf2 dm1 dm2 y1 y2 n1 n2 x :
- n1 > 0 ->
- 0 < n2 / n1 ->
- n2 mod n1 = 0 ->
- evalf2 (fst dm2) = evalf1 (fst dm1) + n1 * ((snd dm1 + x) mod (n2 / n1)) ->
- snd dm2 = (snd dm1 + x) / (n2 / n1) ->
- y2 = y1 + n1 * x ->
- @is_div_mod T evalf1 dm1 y1 n1 ->
- @is_div_mod T evalf2 dm2 y2 n2.
- Proof using Type.
- intros; subst y2; cbv [is_div_mod] in *.
- repeat match goal with
- | H: _ /\ _ |- _ => destruct H
- | H: ?LHS = _ |- _ => match LHS with context [dm2] => rewrite H end
- | H: ?LHS = _ |- _ => match LHS with context [dm1] => rewrite H end
- | _ => rewrite mod_step by omega
- | _ => rewrite div_step by omega
- | _ => rewrite Z.mul_div_eq_full by omega
- end.
- split; f_equal; omega.
- Qed.
-
- Lemma is_div_mod_result_equal {T} evalf dm y1 y2 n :
- y1 = y2 ->
- @is_div_mod T evalf dm y1 n ->
- @is_div_mod T evalf dm y2 n.
- Proof using Type. congruence. Qed.
- End DivMod.
End Saturated.
Module Partition.
@@ -1496,6 +1430,7 @@ Module Partition.
Qed.
End Partition.
Hint Rewrite length_partition length_recursive_partition : distr_length.
+ Hint Rewrite eval_partition using auto : push_eval.
End Partition.
Module Columns.
@@ -2040,55 +1975,8 @@ Module Rows.
Hint Rewrite sum_rows'_cons sum_rows'_nil : push_sum_rows.
- Lemma sum_rows'_div_mod_length row1 :
- forall nm start_state row2 row1' row2',
- let m := snd start_state in
- let n := length row1 in
- length row2 = n ->
- length row1' = m ->
- length row2' = m ->
- length (fst (fst start_state)) = m ->
- (nm = n + m)%nat ->
- let eval := Positional.eval weight in
- is_div_mod (eval m) (fst start_state) (eval m row1' + eval m row2') (weight m) ->
- length (fst (fst (sum_rows' start_state row1 row2))) = nm
- /\ is_div_mod (eval nm) (fst (sum_rows' start_state row1 row2))
- (eval nm (row1' ++ row1) + eval nm (row2' ++ row2))
- (weight nm).
- Proof using wprops.
- induction row1 as [|x1 row1]; destruct row2 as [|x2 row2]; intros; subst nm; push; [ ].
- rewrite (app_cons_app_app _ row1'), (app_cons_app_app _ row2').
- apply IHrow1; clear IHrow1; autorewrite with cancel_pair distr_length in *; try omega.
- eapply is_div_mod_step with (x := x1 + x2); try eassumption; push.
- Qed.
-
- (* TODO : examine is_div_mod to see if it can be improved,
- especially since partition gives the "mod" part for
- free. Maybe put the remainder on the end, prove that this is a
- partition, and then split them? *)
- Lemma sum_rows_div_mod n row1 row2 :
- length row1 = n -> length row2 = n ->
- let eval := Positional.eval weight in
- is_div_mod (eval n) (sum_rows row1 row2) (eval n row1 + eval n row2) (weight n).
- Proof using wprops.
- cbv [sum_rows]; intros.
- apply sum_rows'_div_mod_length with (row1':=nil) (row2':=nil);
- cbv [is_div_mod]; autorewrite with cancel_pair push_eval zsimplify; distr_length.
- Qed.
-
- Lemma sum_rows_mod n row1 row2 :
- length row1 = n -> length row2 = n ->
- Positional.eval weight n (fst (sum_rows row1 row2))
- = (Positional.eval weight n row1 + Positional.eval weight n row2) mod (weight n).
- Proof using wprops. apply sum_rows_div_mod. Qed.
- Lemma sum_rows_div row1 row2 n:
- length row1 = n -> length row2 = n ->
- snd (sum_rows row1 row2)
- = (Positional.eval weight n row1 + Positional.eval weight n row2) / (weight n).
- Proof using wprops. apply sum_rows_div_mod. Qed.
-
- Lemma sum_rows'_partitions row1 :
- forall nm start_state row2 row1' row2',
+ Lemma sum_rows'_correct row1 :
+ forall start_state nm row2 row1' row2',
let m := snd start_state in
let n := length row1 in
length row2 = n ->
@@ -2099,48 +1987,60 @@ Module Rows.
let eval := Positional.eval weight in
snd (fst start_state) = (eval m row1' + eval m row2') / weight m ->
(fst (fst start_state) = partition weight m (eval m row1' + eval m row2')) ->
- fst (fst (sum_rows' start_state row1 row2))
- = partition weight nm (eval nm (row1' ++ row1) + eval nm (row2' ++ row2)).
+ let sum := eval nm (row1' ++ row1) + eval nm (row2' ++ row2) in
+ sum_rows' start_state row1 row2
+ = (partition weight nm sum, sum / weight nm, nm) .
Proof using wprops.
- induction row1 as [|x1 row1]; destruct row2 as [|x2 row2]; intros; subst nm; push; [].
+ destruct start_state as [ [acc rem] m].
+ cbn [fst snd]. revert acc rem m.
+ induction row1 as [|x1 row1];
+ destruct row2 as [|x2 row2]; intros;
+ subst nm; push; [ congruence | ].
rewrite (app_cons_app_app _ row1'), (app_cons_app_app _ row2').
+ subst rem acc.
apply IHrow1; clear IHrow1;
repeat match goal with
- | H : ?LHS = _ |- _ =>
- match LHS with context [start_state] => rewrite H end
| _ => rewrite <-(Z.add_assoc _ x1 x2)
| _ => rewrite div_step by auto using Z.gt_lt
| _ => rewrite Z.mul_div_eq_full by auto
| _ => rewrite weight_multiples by auto
| _ => rewrite partition_step by auto
- | _ => rewrite add_mod_div_multiple by auto using Z.lt_le_incl
+ | _ => rewrite weight_div_pull_div by auto
| _ => rewrite weight_mod_pull_div by auto
+ | _ => rewrite <-Z.div_add' by auto
| _ => progress push
end.
- f_equal; push; [ ].
- apply (@partition_eq_mod _ wprops).
- push_Zmod.
- autorewrite with zsimplify_fast; reflexivity.
+ f_equal; push; [ ].
+ apply (@partition_eq_mod _ wprops).
+ push_Zmod.
+ autorewrite with zsimplify_fast; reflexivity.
Qed.
- Lemma sum_rows_partitions row1: forall row2 n,
+ Lemma sum_rows_correct row1: forall row2 n,
length row1 = n -> length row2 = n ->
- fst (sum_rows row1 row2)
- = partition weight n (Positional.eval weight n row1 + Positional.eval weight n row2).
+ let sum := Positional.eval weight n row1 + Positional.eval weight n row2 in
+ sum_rows row1 row2 = (partition weight n sum, sum / weight n).
+ Proof using wprops.
+ cbv [sum_rows]; intros.
+ erewrite sum_rows'_correct with (nm:=n) (row1':=nil) (row2':=nil)by (cbn; distr_length; reflexivity).
+ reflexivity.
+ Qed.
+
+ Lemma sum_rows_mod n row1 row2 :
+ length row1 = n -> length row2 = n ->
+ Positional.eval weight n (fst (sum_rows row1 row2))
+ = (Positional.eval weight n row1 + Positional.eval weight n row2) mod (weight n).
Proof using wprops.
- cbv [sum_rows]; intros. rewrite <-(Nat.add_0_r n).
- rewrite <-(app_nil_l row1), <-(app_nil_l row2).
- apply sum_rows'_partitions; intros;
- autorewrite with cancel_pair push_eval zsimplify_fast; distr_length; reflexivity.
+ intros; erewrite sum_rows_correct by eauto.
+ cbn [fst]. auto using eval_partition.
Qed.
Lemma length_sum_rows row1 row2 n:
length row1 = n -> length row2 = n ->
length (fst (sum_rows row1 row2)) = n.
Proof using wprops.
- cbv [sum_rows]; intros.
- eapply sum_rows'_div_mod_length; cbv [is_div_mod];
- autorewrite with cancel_pair; distr_length; auto using nil_length0.
+ intros; erewrite sum_rows_correct by eauto.
+ cbn [fst]. distr_length.
Qed. Hint Rewrite length_sum_rows : distr_length.
End SumRows.
Hint Resolve length_sum_rows.
@@ -2169,8 +2069,8 @@ Module Rows.
Ltac push :=
repeat match goal with
| _ => progress intros
- | H: length ?x = ?n |- context [snd (sum_rows ?x _)] => rewrite sum_rows_div with (n:=n) by (distr_length; eauto)
- | H: length ?x = ?n |- context [snd (sum_rows _ ?x)] => rewrite sum_rows_div with (n:=n) by (distr_length; eauto)
+ | _ => erewrite sum_rows_correct by (eassumption || distr_length; reflexivity)
+ | _ => rewrite eval_partition by auto
| H: length _ = _ |- _ => rewrite H
| _ => progress autorewrite with cancel_pair push_flatten push_eval distr_length zsimplify_fast
| _ => progress In_cases
@@ -2182,93 +2082,50 @@ Module Rows.
| _ => solve [eauto]
end.
- Lemma flatten'_div_mod_length n inp : forall start_state,
+ Lemma flatten'_correct n inp : forall start_state,
length (fst start_state) = n ->
(forall row, In row inp -> length row = n) ->
- length (fst (flatten' start_state inp)) = n
- /\ (inp <> nil ->
- is_div_mod (Positional.eval weight n) (flatten' start_state inp)
- (Positional.eval weight n (fst start_state) + eval n inp + weight n * snd start_state)
- (weight n)).
+ inp <> nil ->
+ let sum := Positional.eval weight n (fst start_state) + eval n inp + weight n * snd start_state in
+ flatten' start_state inp = (partition weight n sum, sum / weight n).
Proof using wprops.
- induction inp using rev_ind; push; [apply IHinp; push|].
- destruct (dec (inp = nil)); [subst inp; cbv [is_div_mod]
- | eapply is_div_mod_result_equal; try apply IHinp]; push.
- { autorewrite with zsimplify; push. }
- { rewrite Z.div_add' by auto; push. }
+ induction inp using rev_ind; push. subst sum.
+ destruct (dec (inp = nil)); [ subst inp; cbn | ];
+ repeat match goal with
+ | _ => rewrite IHinp by push; clear IHinp
+ | |- pair _ _ = pair _ _ => f_equal
+ | _ => apply (@partition_eq_mod _ wprops)
+ | _ => rewrite <-Z.div_add_l' by auto
+ | _ => rewrite Z.mod_add'_full by omega
+ | _ => rewrite Z.mul_div_eq_full by auto
+ | _ => progress (push_Zmod; pull_Zmod)
+ | _ => progress push
+ end.
Qed.
Hint Rewrite (@Positional.length_zeros) : distr_length.
Hint Rewrite (@Positional.eval_zeros) using auto : push_eval.
- Lemma flatten_div_mod inp n :
+ Lemma flatten_correct inp n :
(forall row, In row inp -> length row = n) ->
- is_div_mod (Positional.eval weight n) (flatten n inp) (eval n inp) (weight n).
+ flatten n inp = (partition weight n (eval n inp), eval n inp / weight n).
Proof using wprops.
intros; cbv [flatten].
- destruct inp; [|destruct inp]; cbn [hd tl].
- { cbv [is_div_mod]; push.
- erewrite sum_rows_div by (distr_length; reflexivity).
- push. }
- { cbv [is_div_mod]; push. }
- { eapply is_div_mod_result_equal; try apply flatten'_div_mod_length; push. }
+ destruct inp; [|destruct inp]; cbn [hd tl];
+ [ | | erewrite ?flatten'_correct ]; push.
Qed.
Lemma flatten_mod inp n :
(forall row, In row inp -> length row = n) ->
Positional.eval weight n (fst (flatten n inp)) = (eval n inp) mod (weight n).
- Proof using wprops. apply flatten_div_mod. Qed.
- Lemma flatten_div inp n :
- (forall row, In row inp -> length row = n) ->
- snd (flatten n inp) = (eval n inp) / (weight n).
- Proof using wprops. apply flatten_div_mod. Qed.
-
- Lemma length_flatten' n start_state inp :
- length (fst start_state) = n ->
- (forall row, In row inp -> length row = n) ->
- length (fst (flatten' start_state inp)) = n.
- Proof using wprops. apply flatten'_div_mod_length. Qed.
- Hint Rewrite length_flatten' : distr_length.
+ Proof using wprops. intros; rewrite flatten_correct; push. Qed.
Lemma length_flatten n inp :
(forall row, In row inp -> length row = n) ->
length (fst (flatten n inp)) = n.
- Proof using wprops.
- intros.
- apply length_flatten'; push;
- destruct inp as [|? [|? ?] ]; try congruence; cbn [hd tl] in *; push;
- subst row; distr_length.
- Qed. Hint Rewrite length_flatten : distr_length.
-
- Lemma flatten'_partitions n inp : forall start_state,
- inp <> nil ->
- length (fst start_state) = n ->
- (forall row, In row inp -> length row = n) ->
- fst (flatten' start_state inp)
- = partition weight n (Positional.eval weight n (fst start_state) + eval n inp).
- Proof using wprops.
- induction inp using rev_ind; push.
- destruct (dec (inp = nil)).
- { subst inp; push. rewrite sum_rows_partitions with (n:=n) by eauto. f_equal; ring. }
- { erewrite IHinp; push.
- apply (@partition_eq_mod _ wprops).
- pull_Zmod. f_equal; ring. }
- Qed.
-
- Lemma flatten_partitions' inp n :
- (forall row, In row inp -> length row = n) ->
- fst (flatten n inp) = partition weight n (eval n inp).
- Proof using wprops.
- intros; cbv [flatten].
- intros; destruct inp as [| ? [| ? ?] ]; cbn [hd tl] in *;
- repeat match goal with
- | _ => erewrite flatten'_partitions by push
- | _ => erewrite sum_rows_partitions by (distr_length; reflexivity)
- | _ => progress push
- end.
- Qed.
+ Proof using wprops. intros; rewrite flatten_correct by assumption; push. Qed.
End Flatten.
- Hint Rewrite length_partition : distr_length.
+ Hint Rewrite length_flatten : distr_length.
Section Ops.
Definition add n p q := flatten n [p; q].
@@ -2324,11 +2181,9 @@ Module Rows.
Hint Rewrite Associational.eval_sat_mul_const Associational.eval_sat_mul Associational.eval_split using solve [auto] : push_eval.
Hint Rewrite eval_from_associational using solve [auto] : push_eval.
- Hint Rewrite eval_partition using solve [auto] : push_eval.
Ltac solver :=
intros; cbv [sub add mul mulmod sat_reduce];
- rewrite ?flatten_partitions' by (intros; In_cases; subst; distr_length; eauto using length_from_associational);
- rewrite ?flatten_div by (intros; In_cases; subst; distr_length; eauto using length_from_associational);
+ rewrite ?flatten_correct by (intros; In_cases; subst; distr_length; eauto using length_from_associational);
autorewrite with push_eval; ring_simplify_subterms;
try reflexivity.
@@ -2395,7 +2250,7 @@ Module Rows.
Lemma length_mul base n m p q :
length p = n -> length q = n ->
length (fst (mul base n m p q)) = m.
- Proof using wprops. solver; distr_length. Qed.
+ Proof using wprops. solver; cbn [fst snd]; distr_length. Qed.
Lemma eval_sat_reduce base s c p :
base <> 0 -> s - Associational.eval c <> 0 -> s <> 0 ->
@@ -2426,7 +2281,8 @@ Module Rows.
+ weight n * (snd (mulmod base s c n nreductions p q))) mod (s - Associational.eval c)
= (Positional.eval weight n p * Positional.eval weight n q) mod (s - Associational.eval c).
Proof using wprops.
- solver.
+ solver. cbn [fst snd].
+ rewrite eval_partition by auto.
rewrite <-Z.div_mod'' by auto.
autorewrite with push_eval; reflexivity.
Qed.
@@ -2438,8 +2294,7 @@ Module Rows.
End Rows.
Hint Rewrite length_from_columns using eassumption : distr_length.
Hint Rewrite length_sum_rows using solve [ reflexivity | eassumption | distr_length; eauto ] : distr_length.
- Hint Rewrite length_fst_extract_row length_snd_extract_row length_flatten length_flatten' length_partition length_fst_from_columns' length_snd_from_columns' : distr_length.
- Hint Rewrite @eval_partition using auto : push_eval.
+ Hint Rewrite length_fst_extract_row length_snd_extract_row length_flatten length_fst_from_columns' length_snd_from_columns' : distr_length.
End Rows.
Module BaseConversion.
@@ -2527,7 +2382,7 @@ Module BaseConversion.
from_associational idxs n p = partition sw n (Associational.eval p).
Proof using dwprops swprops.
intros. cbv [from_associational].
- rewrite Rows.flatten_partitions' with (n:=n) by eauto using Rows.length_from_associational.
+ rewrite Rows.flatten_correct with (n:=n) by eauto using Rows.length_from_associational.
rewrite Associational.bind_snd_correct.
push_eval.
Qed.
@@ -2657,7 +2512,7 @@ Module BaseConversion.
{ subst widemul_inlined_reverse; reflexivity. }
{ rewrite from_associational_inlined_correct by (subst nout; auto).
cbv [from_associational].
- rewrite !Rows.flatten_partitions' by eauto using Rows.length_from_associational.
+ rewrite !Rows.flatten_correct by eauto using Rows.length_from_associational.
rewrite !Rows.eval_from_associational by (subst nout; auto).
f_equal.
rewrite !eval_carries, !Associational.bind_snd_correct, !Associational.eval_rev by auto.
@@ -2911,7 +2766,7 @@ Section freeze_mod_ops.
Proof using Hn_nz limbwidth_good.
clear -Hn_nz limbwidth_good.
intros; cbv [to_bytes].
- rewrite Rows.flatten_partitions' by eauto using wprops, Rows.length_from_associational.
+ rewrite Rows.flatten_correct by eauto using wprops, Rows.length_from_associational.
rewrite Rows.eval_from_associational by eauto using bytes_nz with omega.
rewrite eval_to_associational.
cbv [to_bytes'].