diff options
author | Jade Philipoom <jadep@google.com> | 2018-03-07 15:51:49 +0100 |
---|---|---|
committer | jadephilipoom <jade.philipoom@gmail.com> | 2018-04-03 09:00:55 -0400 |
commit | 8690fc7b8054901178a0756d1fb6f47342a2bd55 (patch) | |
tree | 192844dc24b5b0ae0c67e07ab98c209a3830be4f /src | |
parent | 0be8c543e56273ba58d919836e4053e52e87b90a (diff) |
organize proofs into sections
Diffstat (limited to 'src')
-rw-r--r-- | src/Experiments/SimplyTypedArithmetic.v | 427 |
1 files changed, 216 insertions, 211 deletions
diff --git a/src/Experiments/SimplyTypedArithmetic.v b/src/Experiments/SimplyTypedArithmetic.v index 61df7b12f..48661f4c6 100644 --- a/src/Experiments/SimplyTypedArithmetic.v +++ b/src/Experiments/SimplyTypedArithmetic.v @@ -679,35 +679,29 @@ Module Columns. apply Positional.eval_snoc; distr_length. Qed. Hint Rewrite eval_snoc using (solve [distr_length]) : push_eval. - Section flatten_column. - Context (fw : Z). (* maximum size of the result *) - - (* Outputs (sum, carry) *) - Definition flatten_column (digit: list Z) : (Z * Z) := - list_rect (fun _ => (Z * Z)%type) (0,0) - (fun xx tl flatten_column_tl => - list_rect - (fun _ => (Z * Z)%type) (xx mod fw, xx / fw) - (fun yy tl' _ => - list_rect - (fun _ => (Z * Z)%type) (dlet_nd x := xx in dlet_nd y := yy in Z.add_get_carry_full fw x y) - (fun _ _ _ => - dlet_nd x := xx in - dlet_nd rec := flatten_column_tl in (* recursively get the sum and carry *) - dlet_nd sum_carry := Z.add_get_carry_full fw x (fst rec) in (* add the new value to the sum *) - dlet_nd carry' := snd sum_carry + snd rec in (* add the two carries together *) - (fst sum_carry, carry')) - tl') - tl) - digit. - End flatten_column. - - Definition flatten_step (digit:list Z) (acc_carry:list Z * Z) : list Z * Z := - dlet sum_carry := flatten_column (weight (S (length (fst acc_carry))) / weight (length (fst acc_carry))) (snd acc_carry::digit) in - (fst acc_carry ++ fst sum_carry :: nil, snd sum_carry). - - Definition flatten (xs : list (list Z)) : list Z * Z := - fold_right (fun a b => flatten_step a b) (nil,0) (rev xs). + (* TODO: move out of Columns? *) + Section Weight. + Lemma weight_multiples_full' j : forall i, weight (i+j) mod weight i = 0. + Proof. + induction j; intros; + repeat match goal with + | _ => rewrite Nat.add_succ_r + | _ => rewrite IHj + | |- context [weight (S ?x) mod weight _] => + rewrite (Z.div_mod (weight (S x)) (weight x)), weight_multiples by auto + | _ => progress autorewrite with push_Zmod natsimplify zsimplify_fast + | _ => reflexivity + end. + Qed. + Lemma weight_multiples_full j i : (i <= j)%nat -> weight j mod weight i = 0. + Proof. + intros; replace j with (i + (j - i))%nat by omega. + apply weight_multiples_full'. + Qed. + + Lemma weight_div_mod j i : (i <= j)%nat -> weight j = weight i * (weight j / weight i). + Proof. intros. apply Z.div_exact; auto using weight_multiples_full. Qed. + End Weight. Lemma list_rect_to_match A (P:list A -> Type) (Pnil: P []) (PS: forall a tl, P (a :: tl)) ls : @list_rect A P Pnil (fun a tl _ => PS a tl) ls = match ls with @@ -718,17 +712,47 @@ Module Columns. Hint Rewrite <- Z.div_add' using omega : pull_Zdiv. - - Local Ltac cases := + Ltac cases := match goal with | |- _ /\ _ => split | H: _ /\ _ |- _ => destruct H | H: _ \/ _ |- _ => destruct H | _ => progress break_match; try discriminate end. - - Local Ltac push_fast := - repeat match goal with + + Section Flatten. + Section flatten_column. + Context (fw : Z). (* maximum size of the result *) + + (* Outputs (sum, carry) *) + Definition flatten_column (digit: list Z) : (Z * Z) := + list_rect (fun _ => (Z * Z)%type) (0,0) + (fun xx tl flatten_column_tl => + list_rect + (fun _ => (Z * Z)%type) (xx mod fw, xx / fw) + (fun yy tl' _ => + list_rect + (fun _ => (Z * Z)%type) (dlet_nd x := xx in dlet_nd y := yy in Z.add_get_carry_full fw x y) + (fun _ _ _ => + dlet_nd x := xx in + dlet_nd rec := flatten_column_tl in (* recursively get the sum and carry *) + dlet_nd sum_carry := Z.add_get_carry_full fw x (fst rec) in (* add the new value to the sum *) + dlet_nd carry' := snd sum_carry + snd rec in (* add the two carries together *) + (fst sum_carry, carry')) + tl') + tl) + digit. + End flatten_column. + + Definition flatten_step (digit:list Z) (acc_carry:list Z * Z) : list Z * Z := + dlet sum_carry := flatten_column (weight (S (length (fst acc_carry))) / weight (length (fst acc_carry))) (snd acc_carry::digit) in + (fst acc_carry ++ fst sum_carry :: nil, snd sum_carry). + + Definition flatten (xs : list (list Z)) : list Z * Z := + fold_right (fun a b => flatten_step a b) (nil,0) (rev xs). + + Ltac push_fast := + repeat match goal with | _ => progress cbv [Let_In] | |- context [list_rect _ _ _ ?ls] => rewrite list_rect_to_match; destruct ls | _ => progress (unfold flatten_step in *; fold flatten_step in * ) @@ -740,195 +764,176 @@ Module Columns. | _ => congruence | _ => progress cases end. - Local Ltac push := - repeat match goal with - | _ => progress push_fast - | _ => progress autorewrite with cancel_pair to_div_mod - | _ => progress autorewrite with push_sum push_fold_right push_nth_default in * - | _ => progress autorewrite with pull_Zmod pull_Zdiv zsimplify_fast - | _ => progress autorewrite with list distr_length push_eval - end. - - - - Lemma flatten_column_mod fw (xs : list Z) : - fst (flatten_column fw xs) = sum xs mod fw. - Proof. - induction xs; simpl flatten_column; cbv [Let_In]; + Ltac push := repeat match goal with - | _ => rewrite IHxs - | _ => progress push + | _ => progress push_fast + | _ => progress autorewrite with cancel_pair to_div_mod + | _ => progress autorewrite with push_sum push_fold_right push_nth_default in * + | _ => progress autorewrite with pull_Zmod pull_Zdiv zsimplify_fast + | _ => progress autorewrite with list distr_length push_eval end. - Qed. Hint Rewrite flatten_column_mod : to_div_mod. + + Lemma flatten_column_mod fw (xs : list Z) : + fst (flatten_column fw xs) = sum xs mod fw. + Proof. + induction xs; simpl flatten_column; cbv [Let_In]; + repeat match goal with + | _ => rewrite IHxs + | _ => progress push + end. + Qed. Hint Rewrite flatten_column_mod : to_div_mod. - Lemma flatten_column_div fw (xs : list Z) (fw_nz : fw <> 0) : - snd (flatten_column fw xs) = sum xs / fw. - Proof. - induction xs; simpl flatten_column; cbv [Let_In]; - repeat match goal with - | _ => rewrite IHxs - | _ => rewrite Z.mul_div_eq_full by omega - | _ => progress push - end. - Qed. Hint Rewrite flatten_column_div using auto with zarith : to_div_mod. - - (* helper for some of the modular logic in flatten *) - Lemma flatten_mod_step a b c d: 0 < a -> 0 < b -> - c mod a + a * ((c / a + d) mod b) = (a * d + c) mod (a * b). - Proof. intros; rewrite Z.rem_mul_r by omega. push_Zmod. push. Qed. - - Lemma flatten_div_step a b c d : 0 < a -> 0 < b -> - (c / a + d) / b = (a * d + c) / (a * b). - Proof. intros; push. Qed. - - Hint Rewrite Positional.eval_nil : push_eval. - Hint Resolve Z.gt_lt. - - Lemma length_flatten_step digit state : - length (fst (flatten_step digit state)) = S (length (fst state)). - Proof. cbv [flatten_step]; push. Qed. - Hint Rewrite length_flatten_step : distr_length. - Lemma length_flatten inp : length (fst (flatten inp)) = length inp. - Proof. cbv [flatten]. induction inp using rev_ind; push. Qed. - Hint Rewrite length_flatten : distr_length. - - Lemma flatten_div_mod n inp : - length inp = n -> - (Positional.eval weight n (fst (flatten inp)) - = (eval n inp) mod (weight n)) - /\ (snd (flatten inp) = eval n inp / weight n). - Proof. - (* to make the invariant take the right form, we make everything depend on output length, not input length *) - intro. subst n. rewrite <-(length_flatten inp). cbv [flatten]. - induction inp using rev_ind; intros; [push|]. - repeat match goal with - | _ => rewrite Nat.add_1_r - | _ => progress (fold (flatten inp) in * ) - | _ => erewrite Positional.eval_snoc by (distr_length; reflexivity) - | H: _ = _ mod (weight _) |- _ => rewrite H - | H: _ = _ / (weight _) |- _ => rewrite H - | _ => progress rewrite ?flatten_mod_step, ?flatten_div_step by auto - | _ => progress autorewrite with cancel_pair to_div_mod push_sum list push_fold_right push_eval - | _ => progress (distr_length; push_fast) - end. - Qed. + Lemma flatten_column_div fw (xs : list Z) (fw_nz : fw <> 0) : + snd (flatten_column fw xs) = sum xs / fw. + Proof. + induction xs; simpl flatten_column; cbv [Let_In]; + repeat match goal with + | _ => rewrite IHxs + | _ => rewrite Z.mul_div_eq_full by omega + | _ => progress push + end. + Qed. Hint Rewrite flatten_column_div using auto with zarith : to_div_mod. - Lemma flatten_mod {n} inp : - length inp = n -> - (Positional.eval weight n (fst (flatten inp)) = (eval n inp) mod (weight n)). - Proof. apply flatten_div_mod. Qed. - Hint Rewrite @flatten_mod : push_eval. + (* helper for some of the modular logic in flatten *) + Lemma flatten_mod_step a b c d: 0 < a -> 0 < b -> + c mod a + a * ((c / a + d) mod b) = (a * d + c) mod (a * b). + Proof. intros; rewrite Z.rem_mul_r by omega. push_Zmod. push. Qed. - Lemma flatten_div {n} inp : - length inp = n -> snd (flatten inp) = eval n inp / weight n. - Proof. apply flatten_div_mod. Qed. - Hint Rewrite @flatten_div : push_eval. + Lemma flatten_div_step a b c d : 0 < a -> 0 < b -> + (c / a + d) / b = (a * d + c) / (a * b). + Proof. intros; push. Qed. - (* nils *) - Definition nils n : list (list Z) := List.repeat nil n. - Lemma length_nils n : length (nils n) = n. Proof. cbv [nils]. distr_length. Qed. - Hint Rewrite length_nils : distr_length. - Lemma eval_nils n : eval n (nils n) = 0. - Proof. - erewrite <-Positional.eval_zeros by eauto. - cbv [eval nils]; rewrite List.map_repeat; reflexivity. - Qed. Hint Rewrite eval_nils : push_eval. - - (* cons_to_nth *) - Definition cons_to_nth i x (xs : list (list Z)) : list (list Z) := - ListUtil.update_nth i (fun y => cons x y) xs. - Lemma length_cons_to_nth i x xs : length (cons_to_nth i x xs) = length xs. - Proof. cbv [cons_to_nth]. distr_length. Qed. - Hint Rewrite length_cons_to_nth : distr_length. - Lemma cons_to_nth_add_to_nth xs : forall i x, - map sum (cons_to_nth i x xs) = Positional.add_to_nth i x (map sum xs). - Proof. - cbv [cons_to_nth]; induction xs as [|? ? IHxs]; - intros i x; destruct i; simpl; rewrite ?IHxs; reflexivity. - Qed. - Lemma eval_cons_to_nth n i x xs : (i < length xs)%nat -> length xs = n -> - eval n (cons_to_nth i x xs) = weight i * x + eval n xs. - Proof using Type. - cbv [eval]; intros. rewrite cons_to_nth_add_to_nth. - apply Positional.eval_add_to_nth; distr_length. - Qed. Hint Rewrite eval_cons_to_nth using (solve [distr_length]) : push_eval. - - Hint Rewrite Positional.eval_zeros : push_eval. - Hint Rewrite Positional.length_from_associational : distr_length. - Hint Rewrite Positional.eval_add_to_nth using (solve [distr_length]): push_eval. - - (* from_associational *) - Definition from_associational n (p:list (Z*Z)) : list (list Z) := - List.fold_right (fun t ls => - let p := Positional.place weight t (pred n) in - cons_to_nth (fst p) (snd p) ls ) (nils n) p. - Lemma length_from_associational n p : length (from_associational n p) = n. - Proof. cbv [from_associational]. apply fold_right_invariant; intros; distr_length. Qed. - Hint Rewrite length_from_associational: distr_length. - Lemma eval_from_associational n p (n_nonzero:n<>0%nat\/p=nil): - eval n (from_associational n p) = Associational.eval p. - Proof. - erewrite <-Positional.eval_from_associational by eauto. - induction p; push. - cbv [from_associational Positional.from_associational]; autorewrite with push_fold_right. - fold (from_associational n p); fold (Positional.from_associational weight n p). - match goal with |- context [Positional.place _ ?x ?n] => - pose proof (Positional.place_in_range weight x n) end. - repeat match goal with - | _ => rewrite Nat.succ_pred in * by auto - | _ => rewrite IHp by auto - | _ => progress push - end. - Qed. + Hint Rewrite Positional.eval_nil : push_eval. + Hint Resolve Z.gt_lt. - Lemma flatten_snoc x inp : flatten (inp ++ [x]) = flatten_step x (flatten inp). - Proof. cbv [flatten]. rewrite rev_unit. reflexivity. Qed. + Lemma length_flatten_step digit state : + length (fst (flatten_step digit state)) = S (length (fst state)). + Proof. cbv [flatten_step]; push. Qed. + Hint Rewrite length_flatten_step : distr_length. + Lemma length_flatten inp : length (fst (flatten inp)) = length inp. + Proof. cbv [flatten]. induction inp using rev_ind; push. Qed. + Hint Rewrite length_flatten : distr_length. - Lemma weight_multiples_full' j : forall i, weight (i+j) mod weight i = 0. - Proof. - induction j; intros; + Lemma flatten_div_mod n inp : + length inp = n -> + (Positional.eval weight n (fst (flatten inp)) + = (eval n inp) mod (weight n)) + /\ (snd (flatten inp) = eval n inp / weight n). + Proof. + (* to make the invariant take the right form, we make everything depend on output length, not input length *) + intro. subst n. rewrite <-(length_flatten inp). cbv [flatten]. + induction inp using rev_ind; intros; [push|]. repeat match goal with - | _ => rewrite Nat.add_succ_r - | _ => rewrite IHj - | |- context [weight (S ?x) mod weight _] => - rewrite (Z.div_mod (weight (S x)) (weight x)), weight_multiples by auto - | _ => progress autorewrite with push_Zmod natsimplify zsimplify_fast - | _ => reflexivity + | _ => rewrite Nat.add_1_r + | _ => progress (fold (flatten inp) in * ) + | _ => erewrite Positional.eval_snoc by (distr_length; reflexivity) + | H: _ = _ mod (weight _) |- _ => rewrite H + | H: _ = _ / (weight _) |- _ => rewrite H + | _ => progress rewrite ?flatten_mod_step, ?flatten_div_step by auto + | _ => progress autorewrite with cancel_pair to_div_mod push_sum list push_fold_right push_eval + | _ => progress (distr_length; push_fast) end. - Qed. - - Lemma weight_multiples_full j i : (i <= j)%nat -> weight j mod weight i = 0. - Proof. - intros; replace j with (i + (j - i))%nat by omega. - apply weight_multiples_full'. - Qed. + Qed. - Lemma weight_div_mod j i : (i <= j)%nat -> weight j = weight i * (weight j / weight i). - Proof. intros. apply Z.div_exact; auto using weight_multiples_full. Qed. + Lemma flatten_mod {n} inp : + length inp = n -> + (Positional.eval weight n (fst (flatten inp)) = (eval n inp) mod (weight n)). + Proof. apply flatten_div_mod. Qed. + Hint Rewrite @flatten_mod : push_eval. - (* TODO: move to ZUtil *) - Lemma Z_divide_div_mul_exact' a b c : b <> 0 -> (b | a) -> a * c / b = c * (a / b). - Proof. intros. rewrite Z.mul_comm. auto using Z.divide_div_mul_exact. Qed. + Lemma flatten_div {n} inp : + length inp = n -> snd (flatten inp) = eval n inp / weight n. + Proof. apply flatten_div_mod. Qed. + Hint Rewrite @flatten_div : push_eval. - Lemma flatten_partitions inp: - forall n i, length inp = n -> (i < n)%nat -> - nth_default 0 (fst (flatten inp)) i = ((eval n inp) mod (weight (S i))) / weight i. - Proof. - induction inp using rev_ind; intros; destruct n; distr_length. - rewrite flatten_snoc. - push; distr_length; - [rewrite IHinp with (n:=n) by omega; rewrite (weight_div_mod n (S i)) by omega; push_Zmod; push |]. - repeat match goal with - | _ => progress replace (length inp) with n by omega - | _ => progress replace i with n by omega - | _ => progress push - | _ => erewrite flatten_div by eauto - | _ => rewrite <-Z.div_add' by auto - | _ => rewrite Z.mul_div_eq' by auto - | _ => rewrite Z.mod_pull_div by auto using Z.lt_le_incl - | _ => progress autorewrite with push_nth_default natsimplify - end. - Qed. + Lemma flatten_snoc x inp : flatten (inp ++ [x]) = flatten_step x (flatten inp). + Proof. cbv [flatten]. rewrite rev_unit. reflexivity. Qed. + + (* TODO: move to ZUtil *) + Lemma Z_divide_div_mul_exact' a b c : b <> 0 -> (b | a) -> a * c / b = c * (a / b). + Proof. intros. rewrite Z.mul_comm. auto using Z.divide_div_mul_exact. Qed. + + Lemma flatten_partitions inp: + forall n i, length inp = n -> (i < n)%nat -> + nth_default 0 (fst (flatten inp)) i = ((eval n inp) mod (weight (S i))) / weight i. + Proof. + induction inp using rev_ind; intros; destruct n; distr_length. + rewrite flatten_snoc. + push; distr_length; + [rewrite IHinp with (n:=n) by omega; rewrite (weight_div_mod n (S i)) by omega; push_Zmod; push |]. + repeat match goal with + | _ => progress replace (length inp) with n by omega + | _ => progress replace i with n by omega + | _ => progress push + | _ => erewrite flatten_div by eauto + | _ => rewrite <-Z.div_add' by auto + | _ => rewrite Z.mul_div_eq' by auto + | _ => rewrite Z.mod_pull_div by auto using Z.lt_le_incl + | _ => progress autorewrite with push_nth_default natsimplify + end. + Qed. + End Flatten. + + Section FromAssociational. + (* nils *) + Definition nils n : list (list Z) := List.repeat nil n. + Lemma length_nils n : length (nils n) = n. Proof. cbv [nils]. distr_length. Qed. + Hint Rewrite length_nils : distr_length. + Lemma eval_nils n : eval n (nils n) = 0. + Proof. + erewrite <-Positional.eval_zeros by eauto. + cbv [eval nils]; rewrite List.map_repeat; reflexivity. + Qed. Hint Rewrite eval_nils : push_eval. + + (* cons_to_nth *) + Definition cons_to_nth i x (xs : list (list Z)) : list (list Z) := + ListUtil.update_nth i (fun y => cons x y) xs. + Lemma length_cons_to_nth i x xs : length (cons_to_nth i x xs) = length xs. + Proof. cbv [cons_to_nth]. distr_length. Qed. + Hint Rewrite length_cons_to_nth : distr_length. + Lemma cons_to_nth_add_to_nth xs : forall i x, + map sum (cons_to_nth i x xs) = Positional.add_to_nth i x (map sum xs). + Proof. + cbv [cons_to_nth]; induction xs as [|? ? IHxs]; + intros i x; destruct i; simpl; rewrite ?IHxs; reflexivity. + Qed. + Lemma eval_cons_to_nth n i x xs : (i < length xs)%nat -> length xs = n -> + eval n (cons_to_nth i x xs) = weight i * x + eval n xs. + Proof using Type. + cbv [eval]; intros. rewrite cons_to_nth_add_to_nth. + apply Positional.eval_add_to_nth; distr_length. + Qed. Hint Rewrite eval_cons_to_nth using (solve [distr_length]) : push_eval. + + Hint Rewrite Positional.eval_zeros : push_eval. + Hint Rewrite Positional.length_from_associational : distr_length. + Hint Rewrite Positional.eval_add_to_nth using (solve [distr_length]): push_eval. + + (* from_associational *) + Definition from_associational n (p:list (Z*Z)) : list (list Z) := + List.fold_right (fun t ls => + let p := Positional.place weight t (pred n) in + cons_to_nth (fst p) (snd p) ls ) (nils n) p. + Lemma length_from_associational n p : length (from_associational n p) = n. + Proof. cbv [from_associational]. apply fold_right_invariant; intros; distr_length. Qed. + Hint Rewrite length_from_associational: distr_length. + Lemma eval_from_associational n p (n_nonzero:n<>0%nat\/p=nil): + eval n (from_associational n p) = Associational.eval p. + Proof. + erewrite <-Positional.eval_from_associational by eauto. + induction p; [ autorewrite with push_eval; congruence |]. + cbv [from_associational Positional.from_associational]; autorewrite with push_fold_right. + fold (from_associational n p); fold (Positional.from_associational weight n p). + match goal with |- context [Positional.place _ ?x ?n] => + pose proof (Positional.place_in_range weight x n) end. + repeat match goal with + | _ => rewrite Nat.succ_pred in * by auto + | _ => rewrite IHp by auto + | _ => progress autorewrite with push_eval + | _ => progress cases + | _ => congruence + end. + Qed. + End FromAssociational. Section mul. Definition mul s n m (p q : list Z) : list Z := @@ -951,9 +956,9 @@ Module Rows. Local Notation rows := (list (list Z)) (only parsing). Local Notation cols := (list (list Z)) (only parsing). + Hint Rewrite Positional.eval_nil Positional.eval0 @Positional.eval_snoc Columns.eval_nil Columns.eval_snoc using (auto; solve [distr_length]) : push_eval. - Hint Resolve in_eq in_cons. Definition eval n (inp : rows) := |