diff options
author | Jade Philipoom <jadep@google.com> | 2018-03-08 14:33:04 +0100 |
---|---|---|
committer | jadephilipoom <jade.philipoom@gmail.com> | 2018-04-03 09:00:55 -0400 |
commit | 25137d3f403a612adb8b948278b1f35ca3682638 (patch) | |
tree | b203e61f3e9d04b4d400c9731beff039e746ee28 /src | |
parent | a4fd850b6d41a20e0282685186216ea82c707ce8 (diff) |
automate some Rows proofs
Diffstat (limited to 'src')
-rw-r--r-- | src/Experiments/SimplyTypedArithmetic.v | 318 |
1 files changed, 169 insertions, 149 deletions
diff --git a/src/Experiments/SimplyTypedArithmetic.v b/src/Experiments/SimplyTypedArithmetic.v index 4ca2e89be..38a779f8c 100644 --- a/src/Experiments/SimplyTypedArithmetic.v +++ b/src/Experiments/SimplyTypedArithmetic.v @@ -945,7 +945,41 @@ Module Columns. End Columns. End Columns. +Module DivMod. + Definition is_div_mod {T} (evalf : T -> Z) dm y n := + evalf (fst dm) = y mod n /\ snd dm = y / n. + + Lemma is_div_mod_step {T} evalf1 evalf2 dm1 dm2 y1 y2 n1 n2 x : + n1 > 0 -> + 0 < n2 / n1 -> + n2 mod n1 = 0 -> + evalf2 (fst dm2) = evalf1 (fst dm1) + n1 * ((snd dm1 + x) mod (n2 / n1)) -> + snd dm2 = (snd dm1 + x) / (n2 / n1) -> + y2 = y1 + n1 * x -> + @is_div_mod T evalf1 dm1 y1 n1 -> + @is_div_mod T evalf2 dm2 y2 n2. + Proof. + intros; subst y2; cbv [is_div_mod] in *. + repeat match goal with + | H: _ /\ _ |- _ => destruct H + | H: ?LHS = _ |- _ => match LHS with context [dm2] => rewrite H end + | H: ?LHS = _ |- _ => match LHS with context [dm1] => rewrite H end + | _ => rewrite (Columns.flatten_mod_step (fun _ => 0)) by omega + | _ => rewrite (Columns.flatten_div_step (fun _ => 0)) by omega + | _ => rewrite Z.mul_div_eq_full by omega + end. + split; f_equal; omega. + Qed. + + Lemma is_div_mod_result_equal {T} evalf dm y1 y2 n : + y1 = y2 -> + @is_div_mod T evalf dm y1 n -> + @is_div_mod T evalf dm y2 n. + Proof. congruence. Qed. +End DivMod. + Module Rows. + Import DivMod. Section Rows. Context (weight : nat->Z) {weight_0 : weight 0%nat = 1} @@ -960,6 +994,7 @@ Module Rows. Hint Rewrite Positional.eval_nil Positional.eval0 @Positional.eval_snoc Columns.eval_nil Columns.eval_snoc using (auto; solve [distr_length]) : push_eval. Hint Resolve in_eq in_cons. + Hint Resolve Z.gt_lt. Definition eval n (inp : rows) := sum (map (Positional.eval weight n) inp). @@ -976,6 +1011,16 @@ Module Rows. Proof. cbv [eval]; autorewrite with push_map push_sum; reflexivity. Qed. Hint Rewrite eval_app : push_eval. + Ltac In_cases := + repeat match goal with + | H: In _ (_ ++ _) |- _ => apply in_app_or in H; destruct H + | H: In _ (_ :: _) |- _ => apply in_inv in H; destruct H + | H: In _ nil |- _ => contradiction H + | H: forall x, In x (?y :: ?ls) -> ?P |- _ => + unique pose proof (H y ltac:(apply in_eq)); + unique assert (forall x, In x ls -> P) by auto + end. + Section FromAssociational. (* extract row *) Definition extract_row (inp : cols) : cols * list Z := (map (fun c => tl c) inp, map (fun c => hd 0 c) inp). @@ -1042,13 +1087,6 @@ Module Rows. Definition from_columns (inp : cols) : rows := snd (from_columns' (max_column_size inp) (inp, [])). - Local Ltac In_cases := - repeat match goal with - | H: In _ (_ ++ _) |- _ => apply in_app_or in H; destruct H - | H: In _ (_ :: _) |- _ => apply in_inv in H; destruct H - | H: In _ nil |- _ => contradiction H - end. - Lemma eval_from_columns'_with_length m st n: (length (fst st) = n) -> length (fst (from_columns' m st)) = n /\ @@ -1157,62 +1195,66 @@ Module Rows. (fst state ++ [fst sum_carry], snd sum_carry)) start_state (rev (combine row1 row2)). Definition sum_rows := sum_rows' (nil,0). + Ltac push := + repeat match goal with + | _ => progress cbv [Let_In] + | _ => rewrite Nat.add_1_r + | _ => erewrite Positional.eval_snoc by eauto + | H : length _ = _ |- _ => rewrite H + | H: 0%nat = _ |- _ => rewrite <-H + | p := _ |- _ => subst p + | _ => progress autorewrite with cancel_pair natsimplify push_sum_rows list + | _ => progress distr_length + | _ => ring + | _ => solve [ repeat (f_equal; try ring) ] + | _ => tauto + | _ => solve [eauto] + end. + + Lemma sum_rows'_cons state x1 row1 x2 row2 : + sum_rows' state (x1 :: row1) (x2 :: row2) = + sum_rows' (fst state ++ [(snd state + x1 + x2) mod (fw (length (fst state)))], (snd state + x1 + x2) / fw (length (fst state))) row1 row2. + Proof. + cbv [sum_rows' Let_In]; autorewrite with push_combine. + rewrite !fold_left_rev_right. cbn [fold_left]. + autorewrite with cancel_pair to_div_mod. congruence. + Qed. + + Lemma sum_rows'_nil state : + sum_rows' state nil nil = state. + Proof. reflexivity. Qed. + + Hint Rewrite sum_rows'_cons sum_rows'_nil : push_sum_rows. + Lemma sum_rows'_div_mod_length row1 : - forall n m start_state row2 row1' row2', - length (fst start_state) = m -> length row1 = n -> length row2 = n -> - length row1' = m -> length row2' = m -> - let nm : nat := (n + m)%nat in + forall nm start_state row2 row1' row2', + let m := length (fst start_state) in + let n := length row1 in + length row2 = n -> + length row1' = m -> + length row2' = m -> + (nm = n + m)%nat -> let eval := Positional.eval weight in - eval m (fst start_state) = (eval m row1' + eval m row2') mod (weight m) -> - snd start_state = (eval m row1' + eval m row2') / weight m -> - length (fst (sum_rows' start_state row1 row2)) = (n + m)%nat - /\ eval nm (fst (sum_rows' start_state row1 row2)) - = (eval nm (row1' ++ row1) + eval nm (row2' ++ row2)) mod (weight nm) - /\ snd (sum_rows' start_state row1 row2) - = (eval nm (row1' ++ row1) + eval nm (row2' ++ row2)) / (weight nm). + is_div_mod (eval m) start_state (eval m row1' + eval m row2') (weight m) -> + length (fst (sum_rows' start_state row1 row2)) = nm + /\ is_div_mod (eval nm) (sum_rows' start_state row1 row2) + (eval nm (row1' ++ row1) + eval nm (row2' ++ row2)) + (weight nm). Proof. - cbv [sum_rows' Let_In]. - induction row1 as [|x1 row1]; intros; rewrite fold_left_rev_right in *; - destruct row2 as [|x2 row2]; distr_length; [ subst n | ]; - repeat match goal with - | _ => progress autorewrite with natsimplify list - | _ => progress cbn [fold_left combine] - | _ => omega - end. - - specialize (IHrow1 (pred n) (S m)). - replace (pred n + S m)%nat with (n + m)%nat in IHrow1 by omega. + induction row1 as [|x1 row1]; destruct row2 as [|x2 row2]; intros; subst nm; push. rewrite (app_cons_app_app _ row1'), (app_cons_app_app _ row2'). - rewrite <-fold_left_rev_right. - apply IHrow1; clear IHrow1; autorewrite with cancel_pair distr_length; try omega; - repeat match goal with - | H : ?LHS = _ |- _ => - match LHS with context [start_state] => rewrite H end - | H : length _ = _ |- _ => rewrite H - | _ => rewrite <-Z.div_add by auto - | _ => rewrite Z.div_div by auto using Z.gt_lt - | _ => rewrite Z.mul_div_eq by auto - | _ => rewrite weight_multiples - | _ => erewrite Positional.eval_snoc by eauto - | _ => progress autorewrite with cancel_pair distr_length to_div_mod in * - | |- context [ ?x mod ?m + ?m * (((?x + ?a * ?m + ?b * ?m)/ ?m) mod ?c) ] => - replace (x mod m) with ((x + a * m + b * m) mod m) by - (autorewrite with zsimplify; ring); - rewrite <-Z.rem_mul_r by auto using Z.gt_lt - | _ => f_equal; ring - end. + apply IHrow1; clear IHrow1; autorewrite with cancel_pair distr_length in *; try omega. + eapply is_div_mod_step with (x := x1 + x2); try eassumption; push. Qed. Lemma sum_rows_div_mod n row1 row2 : length row1 = n -> length row2 = n -> let eval := Positional.eval weight in - eval n (fst (sum_rows row1 row2)) = (eval n row1 + eval n row2) mod (weight n) - /\ snd (sum_rows row1 row2) = (eval n row1 + eval n row2) / (weight n). + is_div_mod (eval n) (sum_rows row1 row2) (eval n row1 + eval n row2) (weight n). Proof. - cbv [sum_rows]; intros. rewrite <-(Nat.add_0_r n). - edestruct sum_rows'_div_mod_length as [Hlen [Hmod Hdiv] ]; - try erewrite Hmod, Hdiv; auto using nil_length0; - autorewrite with cancel_pair push_eval zsimplify_fast; distr_length. + cbv [sum_rows]; intros. + apply sum_rows'_div_mod_length with (row1':=nil) (row2':=nil); + cbv [is_div_mod]; autorewrite with cancel_pair push_eval zsimplify; distr_length. Qed. Lemma sum_rows_mod n row1 row2 : @@ -1295,21 +1337,23 @@ Module Rows. rewrite Z.div_0_l by auto; omega. Qed. - Lemma length_sum_rows row1 row2 n : + Lemma length_sum_rows row1 row2 n: length row1 = n -> length row2 = n -> length (fst (sum_rows row1 row2)) = n. Proof. - cbv [sum_rows]; intros. rewrite <-(Nat.add_0_r n). - eapply sum_rows'_div_mod_length; auto using nil_length0. + cbv [sum_rows]; intros. + eapply sum_rows'_div_mod_length; cbv [is_div_mod]; + autorewrite with cancel_pair; distr_length; auto using nil_length0. Qed. Hint Rewrite length_sum_rows : distr_length. End SumRows. + Hint Resolve length_sum_rows. Hint Rewrite sum_rows_mod using (auto; solve [distr_length; auto]) : push_eval. Definition flatten' (start_state : list Z * Z) (inp : rows) : list Z * Z := fold_right (fun next_row (state : list Z * Z)=> let out_carry := sum_rows next_row (fst state) in - (fst out_carry, snd state + snd out_carry)) start_state inp. + (fst out_carry, snd state + snd out_carry)) start_state (rev inp). (* For correctness if there is only one row, we add a row of zeroes with the same length so that the add loop still happens. *) @@ -1317,51 +1361,54 @@ Module Rows. let first_row := hd nil inp in flatten' (first_row, 0) (hd (Positional.zeros (length first_row)) (tl inp) :: tl (tl inp)). + (* TODO : move to ListUtil *) + Lemma rev_cons {A} x ls : @rev A (x :: ls) = rev ls ++ [x]. Proof. reflexivity. Qed. + Hint Rewrite @rev_cons : list. + (* TODO: move to ListUtil *) - (* connect state to remaining input *) - Lemma fold_right_invariant_strong {A B} (P: list B -> A -> Type) (f: B -> A -> A): - forall bs a0 - (Pnil : P nil a0) - (IHfold: - forall bs' b a', - In b bs -> - a' = fold_right f a0 bs' -> - P bs' a' -> - P (b :: bs') (f b a')), - P bs (fold_right f a0 bs). - Proof. induction bs; intros; autorewrite with push_fold_right; auto using in_eq,in_cons. Qed. - - Lemma flatten'_div_mod_length n start_state inp: + Lemma fold_right_snoc {A B} f a x ls: + @fold_right A B f a (ls ++ [x]) = fold_right f (f x a) ls. + Proof. + rewrite <-(rev_involutive ls), <-rev_cons. + rewrite !fold_left_rev_right; reflexivity. + Qed. + Hint Rewrite @fold_right_snoc : push_fold_right. + + Lemma flatten'_cons state r inp : + flatten' state (r :: inp) = flatten' (fst (sum_rows r (fst state)), snd state + snd (sum_rows r (fst state))) inp. + Proof. cbv [flatten']; autorewrite with list push_fold_right. reflexivity. Qed. + Lemma flatten'_nil state : flatten' state [] = state. Proof. reflexivity. Qed. + Hint Rewrite flatten'_cons flatten'_nil : push_flatten. + + Ltac push := + repeat match goal with + | _ => progress intros + | H: length ?x = ?n |- context [snd (sum_rows ?x _)] => rewrite sum_rows_div with (n:=n) by (distr_length; eauto) + | H: length ?x = ?n |- context [snd (sum_rows _ ?x)] => rewrite sum_rows_div with (n:=n) by (distr_length; eauto) + | H: length _ = _ |- _ => rewrite H + | _ => progress autorewrite with cancel_pair push_flatten push_eval distr_length zsimplify_fast + | _ => progress In_cases + | |- _ /\ _ => split + | |- context [?x mod ?y] => unique pose proof (Z.mul_div_eq_full x y ltac:(auto)); lia + | _ => solve [repeat (f_equal; try ring)] + | _ => congruence + | _ => solve [eauto] + end. + + Lemma flatten'_div_mod_length n inp : forall start_state, length (fst start_state) = n -> (forall row, In row inp -> length row = n) -> length (fst (flatten' start_state inp)) = n /\ (inp <> nil -> - Positional.eval weight n (fst (flatten' start_state inp)) = (Positional.eval weight n (fst start_state) + eval n inp) mod (weight n) - /\ snd (flatten' start_state inp) = snd start_state + (Positional.eval weight n (fst start_state) + eval n inp) / weight n). + is_div_mod (Positional.eval weight n) (flatten' start_state inp) + (Positional.eval weight n (fst start_state) + eval n inp + weight n * snd start_state) + (weight n)). Proof. - intro. - cbv [flatten']. - apply fold_right_invariant_strong with (bs:=inp); intros; [ tauto | destruct (dec (bs' = nil)) ]; - repeat match goal with - | _ => subst a' - | _ => subst bs' - | _ => rewrite sum_rows_div with (n:=n) by auto - | _ => rewrite @fold_right_nil in * - | _ => progress autorewrite with cancel_pair push_eval in * - | H : _ -> _ /\ _ |- _ => - let X := fresh in let Y := fresh in - destruct H as [X Y]; [ solve [auto using in_cons] | rewrite ?X, ?Y ] - | _ => split - | _ => progress autorewrite with pull_Zmod zsimplify_fast - | _ => solve [ auto using length_sum_rows ] - | _ => solve [ ring_simplify; repeat (f_equal; try ring) ] - end. - apply Z.mul_cancel_l with (p:=weight n); [ apply weight_nonzero |]. - autorewrite with push_Zmul. - rewrite !Z.mul_div_eq_full by apply weight_nonzero. - autorewrite with pull_Zmod. - ring_simplify_subterms. ring_simplify. - repeat (f_equal; try ring). + induction inp; push; [apply IHinp; push|]. + destruct (dec (inp = nil)); [subst inp; cbv [is_div_mod] + | eapply is_div_mod_result_equal; try apply IHinp]; push. + { autorewrite with zsimplify; push. } + { autorewrite with zsimplify; push. } Qed. Hint Rewrite (@Positional.length_zeros weight) : distr_length. @@ -1369,22 +1416,11 @@ Module Rows. Lemma flatten_div_mod inp n : (forall row, In row inp -> length row = n) -> - Positional.eval weight n (fst (flatten inp)) = (eval n inp) mod (weight n) - /\ snd (flatten inp) = (eval n inp) / weight n. + is_div_mod (Positional.eval weight n) (flatten inp) (eval n inp) (weight n). Proof. intros; cbv [flatten]. - destruct inp; [|destruct inp]; cbn [hd tl]. - { cbn. autorewrite with push_eval. tauto. } - { cbn. - match goal with H : forall r, In r [?x] -> _ |- _ => - specialize (H x ltac:(auto)); rewrite H - end. - rewrite sum_rows_div with (n:=n) by distr_length. - autorewrite with push_eval. - split; f_equal; ring. } - { autorewrite with push_eval. - apply flatten'_div_mod_length; auto. - congruence. } + destruct inp; [|destruct inp]; cbn [hd tl]; try solve [cbv [is_div_mod]; push]. + eapply is_div_mod_result_equal; try apply flatten'_div_mod_length; push. Qed. Lemma flatten_mod inp n : @@ -1408,25 +1444,24 @@ Module Rows. inp <> nil -> length (fst (flatten inp)) = n. Proof. - intros. - destruct inp; [congruence |]; destruct inp; - repeat match goal with - | _ => progress intros - | _ => progress cbn [hd tl] in * - | _ => progress autorewrite with cancel_pair - | _ => apply flatten'_div_mod_length - | H: _ |- _ => apply in_inv in H; destruct H - | _ => solve [auto] - end; - subst row; distr_length; auto. + intros. apply flatten'_div_mod_length; push; + destruct inp as [|? [|? ?] ]; try congruence; cbn [hd tl] in *; push. + subst row; distr_length; auto. Qed. Hint Rewrite length_flatten : distr_length. - Lemma flatten'_cons state x inp : - flatten' state (x :: inp) - = (fst (sum_rows x (fst (flatten' state inp))), snd (flatten' state inp) + snd (sum_rows x (fst (flatten' state inp)))). - Proof. reflexivity. Qed. + (* TODO: move to ZUtil *) + Lemma add_mod_l_multiple a b n m: + 0 < n / m -> m <> 0 -> n mod m = 0 -> + (a mod n + b) mod m = (a + b) mod m. + Proof. + intros. + rewrite (proj2 (Z.div_exact n m ltac:(auto))) by auto. + rewrite Z.rem_mul_r by auto. + push_Zmod. autorewrite with zsimplify. + pull_Zmod. reflexivity. + Qed. - Lemma flatten'_partitions n start_state inp: + Lemma flatten'_partitions n inp : forall start_state, length (fst start_state) = n -> (forall row, In row inp -> length row = n) -> inp <> nil -> @@ -1434,18 +1469,11 @@ Module Rows. nth_default 0 (fst (flatten' start_state inp)) i = ((Positional.eval weight n (fst start_state) + eval n inp) mod weight (S i)) / (weight i). Proof. - intro. - induction inp; intros; [congruence|]. - rewrite flatten'_cons. - autorewrite with push_fold_right cancel_pair. - rewrite sum_rows_partitions with (n:=n) by (rewrite ?length_flatten' with (n:=n); auto). + induction inp; push. destruct (dec (inp = nil)). - { subst inp. cbn. repeat (f_equal; try ring). } - { edestruct flatten'_div_mod_length with (inp:=inp) as [Hlen [Hmod Hdiv] ]; eauto. - rewrite Hmod. autorewrite with push_eval. - rewrite Z.add_mod_full. erewrite @Columns.weight_div_mod with (j:=n) (i:=S i) by eauto. - rewrite Z.rem_mul_r by auto using Z.gt_lt, weight_divides_full. - autorewrite with zsimplify. pull_Zmod. + { subst inp; push. rewrite sum_rows_partitions with (n:=n) by eauto. push. } + { erewrite IHinp; push. + rewrite add_mod_l_multiple by auto using weight_divides_full, Columns.weight_multiples_full. repeat (f_equal; try ring). } Qed. @@ -1455,18 +1483,10 @@ Module Rows. nth_default 0 (fst (flatten inp)) i = (eval n inp mod weight (S i)) / (weight i). Proof. intros; cbv [flatten]. - destruct inp; [|destruct inp]; cbn [hd tl]. - { cbn. autorewrite with push_eval push_nth_default zsimplify. reflexivity. } - { cbn. - match goal with H : forall r, In r [?x] -> _ |- _ => - specialize (H x ltac:(auto)); rewrite H - end. - rewrite sum_rows_partitions with (n:=n) by distr_length. - autorewrite with push_eval. - repeat (f_equal; try ring). } - { autorewrite with push_eval. - apply flatten'_partitions; auto. - congruence. } + intros; destruct inp as [| ? [| ? ?] ]; try congruence; cbn [hd tl] in *; try solve [push]. + { cbn. autorewrite with push_nth_default. reflexivity. } + { push. rewrite sum_rows_partitions with (n:=n) by distr_length; push. } + { rewrite flatten'_partitions with (n:=n); push. } Qed. End Flatten. |