diff options
Diffstat (limited to 'src/Experiments/NewPipeline/Arithmetic.v')
-rw-r--r-- | src/Experiments/NewPipeline/Arithmetic.v | 281 |
1 files changed, 68 insertions, 213 deletions
diff --git a/src/Experiments/NewPipeline/Arithmetic.v b/src/Experiments/NewPipeline/Arithmetic.v index 8461e13d4..07f2b3a79 100644 --- a/src/Experiments/NewPipeline/Arithmetic.v +++ b/src/Experiments/NewPipeline/Arithmetic.v @@ -1288,72 +1288,6 @@ Module Saturated. Hint Rewrite eval_sat_mul_const : push_eval. End Associational. End Associational. - - Section DivMod. - Lemma mod_step a b c d: 0 < a -> 0 < b -> - c mod a + a * ((c / a + d) mod b) = (a * d + c) mod (a * b). - Proof using Type. - intros; rewrite Z.rem_mul_r by omega. push_Zmod. - autorewrite with zsimplify pull_Zmod. repeat (f_equal; try ring). - Qed. - - Lemma div_step a b c d : 0 < a -> 0 < b -> - (c / a + d) / b = (a * d + c) / (a * b). - Proof using Type. intros; Z.div_mod_to_quot_rem_in_goal; nia. Qed. - - Lemma add_mod_div_multiple a b n m: - n > 0 -> - 0 <= m / n -> - m mod n = 0 -> - (a / n + b) mod (m / n) = (a + n * b) mod m / n. - Proof using Type. - intros. rewrite <-!Z.div_add' by auto using Z.positive_is_nonzero. - rewrite Z.mod_pull_div, Z.mul_div_eq' by auto using Z.gt_lt. - repeat (f_equal; try omega). - Qed. - - Lemma add_mod_l_multiple a b n m: - 0 < n / m -> m <> 0 -> n mod m = 0 -> - (a mod n + b) mod m = (a + b) mod m. - Proof using Type. - intros. - rewrite (proj2 (Z.div_exact n m ltac:(auto))) by auto. - rewrite Z.rem_mul_r by auto. - push_Zmod. autorewrite with zsimplify. - pull_Zmod. reflexivity. - Qed. - - Definition is_div_mod {T} (evalf : T -> Z) dm y n := - evalf (fst dm) = y mod n /\ snd dm = y / n. - - Lemma is_div_mod_step {T} evalf1 evalf2 dm1 dm2 y1 y2 n1 n2 x : - n1 > 0 -> - 0 < n2 / n1 -> - n2 mod n1 = 0 -> - evalf2 (fst dm2) = evalf1 (fst dm1) + n1 * ((snd dm1 + x) mod (n2 / n1)) -> - snd dm2 = (snd dm1 + x) / (n2 / n1) -> - y2 = y1 + n1 * x -> - @is_div_mod T evalf1 dm1 y1 n1 -> - @is_div_mod T evalf2 dm2 y2 n2. - Proof using Type. - intros; subst y2; cbv [is_div_mod] in *. - repeat match goal with - | H: _ /\ _ |- _ => destruct H - | H: ?LHS = _ |- _ => match LHS with context [dm2] => rewrite H end - | H: ?LHS = _ |- _ => match LHS with context [dm1] => rewrite H end - | _ => rewrite mod_step by omega - | _ => rewrite div_step by omega - | _ => rewrite Z.mul_div_eq_full by omega - end. - split; f_equal; omega. - Qed. - - Lemma is_div_mod_result_equal {T} evalf dm y1 y2 n : - y1 = y2 -> - @is_div_mod T evalf dm y1 n -> - @is_div_mod T evalf dm y2 n. - Proof using Type. congruence. Qed. - End DivMod. End Saturated. Module Partition. @@ -1496,6 +1430,7 @@ Module Partition. Qed. End Partition. Hint Rewrite length_partition length_recursive_partition : distr_length. + Hint Rewrite eval_partition using auto : push_eval. End Partition. Module Columns. @@ -2040,55 +1975,8 @@ Module Rows. Hint Rewrite sum_rows'_cons sum_rows'_nil : push_sum_rows. - Lemma sum_rows'_div_mod_length row1 : - forall nm start_state row2 row1' row2', - let m := snd start_state in - let n := length row1 in - length row2 = n -> - length row1' = m -> - length row2' = m -> - length (fst (fst start_state)) = m -> - (nm = n + m)%nat -> - let eval := Positional.eval weight in - is_div_mod (eval m) (fst start_state) (eval m row1' + eval m row2') (weight m) -> - length (fst (fst (sum_rows' start_state row1 row2))) = nm - /\ is_div_mod (eval nm) (fst (sum_rows' start_state row1 row2)) - (eval nm (row1' ++ row1) + eval nm (row2' ++ row2)) - (weight nm). - Proof using wprops. - induction row1 as [|x1 row1]; destruct row2 as [|x2 row2]; intros; subst nm; push; [ ]. - rewrite (app_cons_app_app _ row1'), (app_cons_app_app _ row2'). - apply IHrow1; clear IHrow1; autorewrite with cancel_pair distr_length in *; try omega. - eapply is_div_mod_step with (x := x1 + x2); try eassumption; push. - Qed. - - (* TODO : examine is_div_mod to see if it can be improved, - especially since partition gives the "mod" part for - free. Maybe put the remainder on the end, prove that this is a - partition, and then split them? *) - Lemma sum_rows_div_mod n row1 row2 : - length row1 = n -> length row2 = n -> - let eval := Positional.eval weight in - is_div_mod (eval n) (sum_rows row1 row2) (eval n row1 + eval n row2) (weight n). - Proof using wprops. - cbv [sum_rows]; intros. - apply sum_rows'_div_mod_length with (row1':=nil) (row2':=nil); - cbv [is_div_mod]; autorewrite with cancel_pair push_eval zsimplify; distr_length. - Qed. - - Lemma sum_rows_mod n row1 row2 : - length row1 = n -> length row2 = n -> - Positional.eval weight n (fst (sum_rows row1 row2)) - = (Positional.eval weight n row1 + Positional.eval weight n row2) mod (weight n). - Proof using wprops. apply sum_rows_div_mod. Qed. - Lemma sum_rows_div row1 row2 n: - length row1 = n -> length row2 = n -> - snd (sum_rows row1 row2) - = (Positional.eval weight n row1 + Positional.eval weight n row2) / (weight n). - Proof using wprops. apply sum_rows_div_mod. Qed. - - Lemma sum_rows'_partitions row1 : - forall nm start_state row2 row1' row2', + Lemma sum_rows'_correct row1 : + forall start_state nm row2 row1' row2', let m := snd start_state in let n := length row1 in length row2 = n -> @@ -2099,48 +1987,60 @@ Module Rows. let eval := Positional.eval weight in snd (fst start_state) = (eval m row1' + eval m row2') / weight m -> (fst (fst start_state) = partition weight m (eval m row1' + eval m row2')) -> - fst (fst (sum_rows' start_state row1 row2)) - = partition weight nm (eval nm (row1' ++ row1) + eval nm (row2' ++ row2)). + let sum := eval nm (row1' ++ row1) + eval nm (row2' ++ row2) in + sum_rows' start_state row1 row2 + = (partition weight nm sum, sum / weight nm, nm) . Proof using wprops. - induction row1 as [|x1 row1]; destruct row2 as [|x2 row2]; intros; subst nm; push; []. + destruct start_state as [ [acc rem] m]. + cbn [fst snd]. revert acc rem m. + induction row1 as [|x1 row1]; + destruct row2 as [|x2 row2]; intros; + subst nm; push; [ congruence | ]. rewrite (app_cons_app_app _ row1'), (app_cons_app_app _ row2'). + subst rem acc. apply IHrow1; clear IHrow1; repeat match goal with - | H : ?LHS = _ |- _ => - match LHS with context [start_state] => rewrite H end | _ => rewrite <-(Z.add_assoc _ x1 x2) | _ => rewrite div_step by auto using Z.gt_lt | _ => rewrite Z.mul_div_eq_full by auto | _ => rewrite weight_multiples by auto | _ => rewrite partition_step by auto - | _ => rewrite add_mod_div_multiple by auto using Z.lt_le_incl + | _ => rewrite weight_div_pull_div by auto | _ => rewrite weight_mod_pull_div by auto + | _ => rewrite <-Z.div_add' by auto | _ => progress push end. - f_equal; push; [ ]. - apply (@partition_eq_mod _ wprops). - push_Zmod. - autorewrite with zsimplify_fast; reflexivity. + f_equal; push; [ ]. + apply (@partition_eq_mod _ wprops). + push_Zmod. + autorewrite with zsimplify_fast; reflexivity. Qed. - Lemma sum_rows_partitions row1: forall row2 n, + Lemma sum_rows_correct row1: forall row2 n, length row1 = n -> length row2 = n -> - fst (sum_rows row1 row2) - = partition weight n (Positional.eval weight n row1 + Positional.eval weight n row2). + let sum := Positional.eval weight n row1 + Positional.eval weight n row2 in + sum_rows row1 row2 = (partition weight n sum, sum / weight n). + Proof using wprops. + cbv [sum_rows]; intros. + erewrite sum_rows'_correct with (nm:=n) (row1':=nil) (row2':=nil)by (cbn; distr_length; reflexivity). + reflexivity. + Qed. + + Lemma sum_rows_mod n row1 row2 : + length row1 = n -> length row2 = n -> + Positional.eval weight n (fst (sum_rows row1 row2)) + = (Positional.eval weight n row1 + Positional.eval weight n row2) mod (weight n). Proof using wprops. - cbv [sum_rows]; intros. rewrite <-(Nat.add_0_r n). - rewrite <-(app_nil_l row1), <-(app_nil_l row2). - apply sum_rows'_partitions; intros; - autorewrite with cancel_pair push_eval zsimplify_fast; distr_length; reflexivity. + intros; erewrite sum_rows_correct by eauto. + cbn [fst]. auto using eval_partition. Qed. Lemma length_sum_rows row1 row2 n: length row1 = n -> length row2 = n -> length (fst (sum_rows row1 row2)) = n. Proof using wprops. - cbv [sum_rows]; intros. - eapply sum_rows'_div_mod_length; cbv [is_div_mod]; - autorewrite with cancel_pair; distr_length; auto using nil_length0. + intros; erewrite sum_rows_correct by eauto. + cbn [fst]. distr_length. Qed. Hint Rewrite length_sum_rows : distr_length. End SumRows. Hint Resolve length_sum_rows. @@ -2169,8 +2069,8 @@ Module Rows. Ltac push := repeat match goal with | _ => progress intros - | H: length ?x = ?n |- context [snd (sum_rows ?x _)] => rewrite sum_rows_div with (n:=n) by (distr_length; eauto) - | H: length ?x = ?n |- context [snd (sum_rows _ ?x)] => rewrite sum_rows_div with (n:=n) by (distr_length; eauto) + | _ => erewrite sum_rows_correct by (eassumption || distr_length; reflexivity) + | _ => rewrite eval_partition by auto | H: length _ = _ |- _ => rewrite H | _ => progress autorewrite with cancel_pair push_flatten push_eval distr_length zsimplify_fast | _ => progress In_cases @@ -2182,93 +2082,50 @@ Module Rows. | _ => solve [eauto] end. - Lemma flatten'_div_mod_length n inp : forall start_state, + Lemma flatten'_correct n inp : forall start_state, length (fst start_state) = n -> (forall row, In row inp -> length row = n) -> - length (fst (flatten' start_state inp)) = n - /\ (inp <> nil -> - is_div_mod (Positional.eval weight n) (flatten' start_state inp) - (Positional.eval weight n (fst start_state) + eval n inp + weight n * snd start_state) - (weight n)). + inp <> nil -> + let sum := Positional.eval weight n (fst start_state) + eval n inp + weight n * snd start_state in + flatten' start_state inp = (partition weight n sum, sum / weight n). Proof using wprops. - induction inp using rev_ind; push; [apply IHinp; push|]. - destruct (dec (inp = nil)); [subst inp; cbv [is_div_mod] - | eapply is_div_mod_result_equal; try apply IHinp]; push. - { autorewrite with zsimplify; push. } - { rewrite Z.div_add' by auto; push. } + induction inp using rev_ind; push. subst sum. + destruct (dec (inp = nil)); [ subst inp; cbn | ]; + repeat match goal with + | _ => rewrite IHinp by push; clear IHinp + | |- pair _ _ = pair _ _ => f_equal + | _ => apply (@partition_eq_mod _ wprops) + | _ => rewrite <-Z.div_add_l' by auto + | _ => rewrite Z.mod_add'_full by omega + | _ => rewrite Z.mul_div_eq_full by auto + | _ => progress (push_Zmod; pull_Zmod) + | _ => progress push + end. Qed. Hint Rewrite (@Positional.length_zeros) : distr_length. Hint Rewrite (@Positional.eval_zeros) using auto : push_eval. - Lemma flatten_div_mod inp n : + Lemma flatten_correct inp n : (forall row, In row inp -> length row = n) -> - is_div_mod (Positional.eval weight n) (flatten n inp) (eval n inp) (weight n). + flatten n inp = (partition weight n (eval n inp), eval n inp / weight n). Proof using wprops. intros; cbv [flatten]. - destruct inp; [|destruct inp]; cbn [hd tl]. - { cbv [is_div_mod]; push. - erewrite sum_rows_div by (distr_length; reflexivity). - push. } - { cbv [is_div_mod]; push. } - { eapply is_div_mod_result_equal; try apply flatten'_div_mod_length; push. } + destruct inp; [|destruct inp]; cbn [hd tl]; + [ | | erewrite ?flatten'_correct ]; push. Qed. Lemma flatten_mod inp n : (forall row, In row inp -> length row = n) -> Positional.eval weight n (fst (flatten n inp)) = (eval n inp) mod (weight n). - Proof using wprops. apply flatten_div_mod. Qed. - Lemma flatten_div inp n : - (forall row, In row inp -> length row = n) -> - snd (flatten n inp) = (eval n inp) / (weight n). - Proof using wprops. apply flatten_div_mod. Qed. - - Lemma length_flatten' n start_state inp : - length (fst start_state) = n -> - (forall row, In row inp -> length row = n) -> - length (fst (flatten' start_state inp)) = n. - Proof using wprops. apply flatten'_div_mod_length. Qed. - Hint Rewrite length_flatten' : distr_length. + Proof using wprops. intros; rewrite flatten_correct; push. Qed. Lemma length_flatten n inp : (forall row, In row inp -> length row = n) -> length (fst (flatten n inp)) = n. - Proof using wprops. - intros. - apply length_flatten'; push; - destruct inp as [|? [|? ?] ]; try congruence; cbn [hd tl] in *; push; - subst row; distr_length. - Qed. Hint Rewrite length_flatten : distr_length. - - Lemma flatten'_partitions n inp : forall start_state, - inp <> nil -> - length (fst start_state) = n -> - (forall row, In row inp -> length row = n) -> - fst (flatten' start_state inp) - = partition weight n (Positional.eval weight n (fst start_state) + eval n inp). - Proof using wprops. - induction inp using rev_ind; push. - destruct (dec (inp = nil)). - { subst inp; push. rewrite sum_rows_partitions with (n:=n) by eauto. f_equal; ring. } - { erewrite IHinp; push. - apply (@partition_eq_mod _ wprops). - pull_Zmod. f_equal; ring. } - Qed. - - Lemma flatten_partitions' inp n : - (forall row, In row inp -> length row = n) -> - fst (flatten n inp) = partition weight n (eval n inp). - Proof using wprops. - intros; cbv [flatten]. - intros; destruct inp as [| ? [| ? ?] ]; cbn [hd tl] in *; - repeat match goal with - | _ => erewrite flatten'_partitions by push - | _ => erewrite sum_rows_partitions by (distr_length; reflexivity) - | _ => progress push - end. - Qed. + Proof using wprops. intros; rewrite flatten_correct by assumption; push. Qed. End Flatten. - Hint Rewrite length_partition : distr_length. + Hint Rewrite length_flatten : distr_length. Section Ops. Definition add n p q := flatten n [p; q]. @@ -2324,11 +2181,9 @@ Module Rows. Hint Rewrite Associational.eval_sat_mul_const Associational.eval_sat_mul Associational.eval_split using solve [auto] : push_eval. Hint Rewrite eval_from_associational using solve [auto] : push_eval. - Hint Rewrite eval_partition using solve [auto] : push_eval. Ltac solver := intros; cbv [sub add mul mulmod sat_reduce]; - rewrite ?flatten_partitions' by (intros; In_cases; subst; distr_length; eauto using length_from_associational); - rewrite ?flatten_div by (intros; In_cases; subst; distr_length; eauto using length_from_associational); + rewrite ?flatten_correct by (intros; In_cases; subst; distr_length; eauto using length_from_associational); autorewrite with push_eval; ring_simplify_subterms; try reflexivity. @@ -2395,7 +2250,7 @@ Module Rows. Lemma length_mul base n m p q : length p = n -> length q = n -> length (fst (mul base n m p q)) = m. - Proof using wprops. solver; distr_length. Qed. + Proof using wprops. solver; cbn [fst snd]; distr_length. Qed. Lemma eval_sat_reduce base s c p : base <> 0 -> s - Associational.eval c <> 0 -> s <> 0 -> @@ -2426,7 +2281,8 @@ Module Rows. + weight n * (snd (mulmod base s c n nreductions p q))) mod (s - Associational.eval c) = (Positional.eval weight n p * Positional.eval weight n q) mod (s - Associational.eval c). Proof using wprops. - solver. + solver. cbn [fst snd]. + rewrite eval_partition by auto. rewrite <-Z.div_mod'' by auto. autorewrite with push_eval; reflexivity. Qed. @@ -2438,8 +2294,7 @@ Module Rows. End Rows. Hint Rewrite length_from_columns using eassumption : distr_length. Hint Rewrite length_sum_rows using solve [ reflexivity | eassumption | distr_length; eauto ] : distr_length. - Hint Rewrite length_fst_extract_row length_snd_extract_row length_flatten length_flatten' length_partition length_fst_from_columns' length_snd_from_columns' : distr_length. - Hint Rewrite @eval_partition using auto : push_eval. + Hint Rewrite length_fst_extract_row length_snd_extract_row length_flatten length_fst_from_columns' length_snd_from_columns' : distr_length. End Rows. Module BaseConversion. @@ -2527,7 +2382,7 @@ Module BaseConversion. from_associational idxs n p = partition sw n (Associational.eval p). Proof using dwprops swprops. intros. cbv [from_associational]. - rewrite Rows.flatten_partitions' with (n:=n) by eauto using Rows.length_from_associational. + rewrite Rows.flatten_correct with (n:=n) by eauto using Rows.length_from_associational. rewrite Associational.bind_snd_correct. push_eval. Qed. @@ -2657,7 +2512,7 @@ Module BaseConversion. { subst widemul_inlined_reverse; reflexivity. } { rewrite from_associational_inlined_correct by (subst nout; auto). cbv [from_associational]. - rewrite !Rows.flatten_partitions' by eauto using Rows.length_from_associational. + rewrite !Rows.flatten_correct by eauto using Rows.length_from_associational. rewrite !Rows.eval_from_associational by (subst nout; auto). f_equal. rewrite !eval_carries, !Associational.bind_snd_correct, !Associational.eval_rev by auto. @@ -2911,7 +2766,7 @@ Section freeze_mod_ops. Proof using Hn_nz limbwidth_good. clear -Hn_nz limbwidth_good. intros; cbv [to_bytes]. - rewrite Rows.flatten_partitions' by eauto using wprops, Rows.length_from_associational. + rewrite Rows.flatten_correct by eauto using wprops, Rows.length_from_associational. rewrite Rows.eval_from_associational by eauto using bytes_nz with omega. rewrite eval_to_associational. cbv [to_bytes']. |