aboutsummaryrefslogtreecommitdiff
path: root/src
diff options
context:
space:
mode:
authorGravatar Jade Philipoom <jadep@google.com>2018-03-08 14:33:04 +0100
committerGravatar jadephilipoom <jade.philipoom@gmail.com>2018-04-03 09:00:55 -0400
commit25137d3f403a612adb8b948278b1f35ca3682638 (patch)
treeb203e61f3e9d04b4d400c9731beff039e746ee28 /src
parenta4fd850b6d41a20e0282685186216ea82c707ce8 (diff)
automate some Rows proofs
Diffstat (limited to 'src')
-rw-r--r--src/Experiments/SimplyTypedArithmetic.v318
1 files changed, 169 insertions, 149 deletions
diff --git a/src/Experiments/SimplyTypedArithmetic.v b/src/Experiments/SimplyTypedArithmetic.v
index 4ca2e89be..38a779f8c 100644
--- a/src/Experiments/SimplyTypedArithmetic.v
+++ b/src/Experiments/SimplyTypedArithmetic.v
@@ -945,7 +945,41 @@ Module Columns.
End Columns.
End Columns.
+Module DivMod.
+ Definition is_div_mod {T} (evalf : T -> Z) dm y n :=
+ evalf (fst dm) = y mod n /\ snd dm = y / n.
+
+ Lemma is_div_mod_step {T} evalf1 evalf2 dm1 dm2 y1 y2 n1 n2 x :
+ n1 > 0 ->
+ 0 < n2 / n1 ->
+ n2 mod n1 = 0 ->
+ evalf2 (fst dm2) = evalf1 (fst dm1) + n1 * ((snd dm1 + x) mod (n2 / n1)) ->
+ snd dm2 = (snd dm1 + x) / (n2 / n1) ->
+ y2 = y1 + n1 * x ->
+ @is_div_mod T evalf1 dm1 y1 n1 ->
+ @is_div_mod T evalf2 dm2 y2 n2.
+ Proof.
+ intros; subst y2; cbv [is_div_mod] in *.
+ repeat match goal with
+ | H: _ /\ _ |- _ => destruct H
+ | H: ?LHS = _ |- _ => match LHS with context [dm2] => rewrite H end
+ | H: ?LHS = _ |- _ => match LHS with context [dm1] => rewrite H end
+ | _ => rewrite (Columns.flatten_mod_step (fun _ => 0)) by omega
+ | _ => rewrite (Columns.flatten_div_step (fun _ => 0)) by omega
+ | _ => rewrite Z.mul_div_eq_full by omega
+ end.
+ split; f_equal; omega.
+ Qed.
+
+ Lemma is_div_mod_result_equal {T} evalf dm y1 y2 n :
+ y1 = y2 ->
+ @is_div_mod T evalf dm y1 n ->
+ @is_div_mod T evalf dm y2 n.
+ Proof. congruence. Qed.
+End DivMod.
+
Module Rows.
+ Import DivMod.
Section Rows.
Context (weight : nat->Z)
{weight_0 : weight 0%nat = 1}
@@ -960,6 +994,7 @@ Module Rows.
Hint Rewrite Positional.eval_nil Positional.eval0 @Positional.eval_snoc
Columns.eval_nil Columns.eval_snoc using (auto; solve [distr_length]) : push_eval.
Hint Resolve in_eq in_cons.
+ Hint Resolve Z.gt_lt.
Definition eval n (inp : rows) :=
sum (map (Positional.eval weight n) inp).
@@ -976,6 +1011,16 @@ Module Rows.
Proof. cbv [eval]; autorewrite with push_map push_sum; reflexivity. Qed.
Hint Rewrite eval_app : push_eval.
+ Ltac In_cases :=
+ repeat match goal with
+ | H: In _ (_ ++ _) |- _ => apply in_app_or in H; destruct H
+ | H: In _ (_ :: _) |- _ => apply in_inv in H; destruct H
+ | H: In _ nil |- _ => contradiction H
+ | H: forall x, In x (?y :: ?ls) -> ?P |- _ =>
+ unique pose proof (H y ltac:(apply in_eq));
+ unique assert (forall x, In x ls -> P) by auto
+ end.
+
Section FromAssociational.
(* extract row *)
Definition extract_row (inp : cols) : cols * list Z := (map (fun c => tl c) inp, map (fun c => hd 0 c) inp).
@@ -1042,13 +1087,6 @@ Module Rows.
Definition from_columns (inp : cols) : rows := snd (from_columns' (max_column_size inp) (inp, [])).
- Local Ltac In_cases :=
- repeat match goal with
- | H: In _ (_ ++ _) |- _ => apply in_app_or in H; destruct H
- | H: In _ (_ :: _) |- _ => apply in_inv in H; destruct H
- | H: In _ nil |- _ => contradiction H
- end.
-
Lemma eval_from_columns'_with_length m st n:
(length (fst st) = n) ->
length (fst (from_columns' m st)) = n /\
@@ -1157,62 +1195,66 @@ Module Rows.
(fst state ++ [fst sum_carry], snd sum_carry)) start_state (rev (combine row1 row2)).
Definition sum_rows := sum_rows' (nil,0).
+ Ltac push :=
+ repeat match goal with
+ | _ => progress cbv [Let_In]
+ | _ => rewrite Nat.add_1_r
+ | _ => erewrite Positional.eval_snoc by eauto
+ | H : length _ = _ |- _ => rewrite H
+ | H: 0%nat = _ |- _ => rewrite <-H
+ | p := _ |- _ => subst p
+ | _ => progress autorewrite with cancel_pair natsimplify push_sum_rows list
+ | _ => progress distr_length
+ | _ => ring
+ | _ => solve [ repeat (f_equal; try ring) ]
+ | _ => tauto
+ | _ => solve [eauto]
+ end.
+
+ Lemma sum_rows'_cons state x1 row1 x2 row2 :
+ sum_rows' state (x1 :: row1) (x2 :: row2) =
+ sum_rows' (fst state ++ [(snd state + x1 + x2) mod (fw (length (fst state)))], (snd state + x1 + x2) / fw (length (fst state))) row1 row2.
+ Proof.
+ cbv [sum_rows' Let_In]; autorewrite with push_combine.
+ rewrite !fold_left_rev_right. cbn [fold_left].
+ autorewrite with cancel_pair to_div_mod. congruence.
+ Qed.
+
+ Lemma sum_rows'_nil state :
+ sum_rows' state nil nil = state.
+ Proof. reflexivity. Qed.
+
+ Hint Rewrite sum_rows'_cons sum_rows'_nil : push_sum_rows.
+
Lemma sum_rows'_div_mod_length row1 :
- forall n m start_state row2 row1' row2',
- length (fst start_state) = m -> length row1 = n -> length row2 = n ->
- length row1' = m -> length row2' = m ->
- let nm : nat := (n + m)%nat in
+ forall nm start_state row2 row1' row2',
+ let m := length (fst start_state) in
+ let n := length row1 in
+ length row2 = n ->
+ length row1' = m ->
+ length row2' = m ->
+ (nm = n + m)%nat ->
let eval := Positional.eval weight in
- eval m (fst start_state) = (eval m row1' + eval m row2') mod (weight m) ->
- snd start_state = (eval m row1' + eval m row2') / weight m ->
- length (fst (sum_rows' start_state row1 row2)) = (n + m)%nat
- /\ eval nm (fst (sum_rows' start_state row1 row2))
- = (eval nm (row1' ++ row1) + eval nm (row2' ++ row2)) mod (weight nm)
- /\ snd (sum_rows' start_state row1 row2)
- = (eval nm (row1' ++ row1) + eval nm (row2' ++ row2)) / (weight nm).
+ is_div_mod (eval m) start_state (eval m row1' + eval m row2') (weight m) ->
+ length (fst (sum_rows' start_state row1 row2)) = nm
+ /\ is_div_mod (eval nm) (sum_rows' start_state row1 row2)
+ (eval nm (row1' ++ row1) + eval nm (row2' ++ row2))
+ (weight nm).
Proof.
- cbv [sum_rows' Let_In].
- induction row1 as [|x1 row1]; intros; rewrite fold_left_rev_right in *;
- destruct row2 as [|x2 row2]; distr_length; [ subst n | ];
- repeat match goal with
- | _ => progress autorewrite with natsimplify list
- | _ => progress cbn [fold_left combine]
- | _ => omega
- end.
-
- specialize (IHrow1 (pred n) (S m)).
- replace (pred n + S m)%nat with (n + m)%nat in IHrow1 by omega.
+ induction row1 as [|x1 row1]; destruct row2 as [|x2 row2]; intros; subst nm; push.
rewrite (app_cons_app_app _ row1'), (app_cons_app_app _ row2').
- rewrite <-fold_left_rev_right.
- apply IHrow1; clear IHrow1; autorewrite with cancel_pair distr_length; try omega;
- repeat match goal with
- | H : ?LHS = _ |- _ =>
- match LHS with context [start_state] => rewrite H end
- | H : length _ = _ |- _ => rewrite H
- | _ => rewrite <-Z.div_add by auto
- | _ => rewrite Z.div_div by auto using Z.gt_lt
- | _ => rewrite Z.mul_div_eq by auto
- | _ => rewrite weight_multiples
- | _ => erewrite Positional.eval_snoc by eauto
- | _ => progress autorewrite with cancel_pair distr_length to_div_mod in *
- | |- context [ ?x mod ?m + ?m * (((?x + ?a * ?m + ?b * ?m)/ ?m) mod ?c) ] =>
- replace (x mod m) with ((x + a * m + b * m) mod m) by
- (autorewrite with zsimplify; ring);
- rewrite <-Z.rem_mul_r by auto using Z.gt_lt
- | _ => f_equal; ring
- end.
+ apply IHrow1; clear IHrow1; autorewrite with cancel_pair distr_length in *; try omega.
+ eapply is_div_mod_step with (x := x1 + x2); try eassumption; push.
Qed.
Lemma sum_rows_div_mod n row1 row2 :
length row1 = n -> length row2 = n ->
let eval := Positional.eval weight in
- eval n (fst (sum_rows row1 row2)) = (eval n row1 + eval n row2) mod (weight n)
- /\ snd (sum_rows row1 row2) = (eval n row1 + eval n row2) / (weight n).
+ is_div_mod (eval n) (sum_rows row1 row2) (eval n row1 + eval n row2) (weight n).
Proof.
- cbv [sum_rows]; intros. rewrite <-(Nat.add_0_r n).
- edestruct sum_rows'_div_mod_length as [Hlen [Hmod Hdiv] ];
- try erewrite Hmod, Hdiv; auto using nil_length0;
- autorewrite with cancel_pair push_eval zsimplify_fast; distr_length.
+ cbv [sum_rows]; intros.
+ apply sum_rows'_div_mod_length with (row1':=nil) (row2':=nil);
+ cbv [is_div_mod]; autorewrite with cancel_pair push_eval zsimplify; distr_length.
Qed.
Lemma sum_rows_mod n row1 row2 :
@@ -1295,21 +1337,23 @@ Module Rows.
rewrite Z.div_0_l by auto; omega.
Qed.
- Lemma length_sum_rows row1 row2 n :
+ Lemma length_sum_rows row1 row2 n:
length row1 = n -> length row2 = n ->
length (fst (sum_rows row1 row2)) = n.
Proof.
- cbv [sum_rows]; intros. rewrite <-(Nat.add_0_r n).
- eapply sum_rows'_div_mod_length; auto using nil_length0.
+ cbv [sum_rows]; intros.
+ eapply sum_rows'_div_mod_length; cbv [is_div_mod];
+ autorewrite with cancel_pair; distr_length; auto using nil_length0.
Qed. Hint Rewrite length_sum_rows : distr_length.
End SumRows.
+ Hint Resolve length_sum_rows.
Hint Rewrite sum_rows_mod using (auto; solve [distr_length; auto]) : push_eval.
Definition flatten' (start_state : list Z * Z) (inp : rows) : list Z * Z :=
fold_right (fun next_row (state : list Z * Z)=>
let out_carry := sum_rows next_row (fst state) in
- (fst out_carry, snd state + snd out_carry)) start_state inp.
+ (fst out_carry, snd state + snd out_carry)) start_state (rev inp).
(* For correctness if there is only one row, we add a row of
zeroes with the same length so that the add loop still happens. *)
@@ -1317,51 +1361,54 @@ Module Rows.
let first_row := hd nil inp in
flatten' (first_row, 0) (hd (Positional.zeros (length first_row)) (tl inp) :: tl (tl inp)).
+ (* TODO : move to ListUtil *)
+ Lemma rev_cons {A} x ls : @rev A (x :: ls) = rev ls ++ [x]. Proof. reflexivity. Qed.
+ Hint Rewrite @rev_cons : list.
+
(* TODO: move to ListUtil *)
- (* connect state to remaining input *)
- Lemma fold_right_invariant_strong {A B} (P: list B -> A -> Type) (f: B -> A -> A):
- forall bs a0
- (Pnil : P nil a0)
- (IHfold:
- forall bs' b a',
- In b bs ->
- a' = fold_right f a0 bs' ->
- P bs' a' ->
- P (b :: bs') (f b a')),
- P bs (fold_right f a0 bs).
- Proof. induction bs; intros; autorewrite with push_fold_right; auto using in_eq,in_cons. Qed.
-
- Lemma flatten'_div_mod_length n start_state inp:
+ Lemma fold_right_snoc {A B} f a x ls:
+ @fold_right A B f a (ls ++ [x]) = fold_right f (f x a) ls.
+ Proof.
+ rewrite <-(rev_involutive ls), <-rev_cons.
+ rewrite !fold_left_rev_right; reflexivity.
+ Qed.
+ Hint Rewrite @fold_right_snoc : push_fold_right.
+
+ Lemma flatten'_cons state r inp :
+ flatten' state (r :: inp) = flatten' (fst (sum_rows r (fst state)), snd state + snd (sum_rows r (fst state))) inp.
+ Proof. cbv [flatten']; autorewrite with list push_fold_right. reflexivity. Qed.
+ Lemma flatten'_nil state : flatten' state [] = state. Proof. reflexivity. Qed.
+ Hint Rewrite flatten'_cons flatten'_nil : push_flatten.
+
+ Ltac push :=
+ repeat match goal with
+ | _ => progress intros
+ | H: length ?x = ?n |- context [snd (sum_rows ?x _)] => rewrite sum_rows_div with (n:=n) by (distr_length; eauto)
+ | H: length ?x = ?n |- context [snd (sum_rows _ ?x)] => rewrite sum_rows_div with (n:=n) by (distr_length; eauto)
+ | H: length _ = _ |- _ => rewrite H
+ | _ => progress autorewrite with cancel_pair push_flatten push_eval distr_length zsimplify_fast
+ | _ => progress In_cases
+ | |- _ /\ _ => split
+ | |- context [?x mod ?y] => unique pose proof (Z.mul_div_eq_full x y ltac:(auto)); lia
+ | _ => solve [repeat (f_equal; try ring)]
+ | _ => congruence
+ | _ => solve [eauto]
+ end.
+
+ Lemma flatten'_div_mod_length n inp : forall start_state,
length (fst start_state) = n ->
(forall row, In row inp -> length row = n) ->
length (fst (flatten' start_state inp)) = n
/\ (inp <> nil ->
- Positional.eval weight n (fst (flatten' start_state inp)) = (Positional.eval weight n (fst start_state) + eval n inp) mod (weight n)
- /\ snd (flatten' start_state inp) = snd start_state + (Positional.eval weight n (fst start_state) + eval n inp) / weight n).
+ is_div_mod (Positional.eval weight n) (flatten' start_state inp)
+ (Positional.eval weight n (fst start_state) + eval n inp + weight n * snd start_state)
+ (weight n)).
Proof.
- intro.
- cbv [flatten'].
- apply fold_right_invariant_strong with (bs:=inp); intros; [ tauto | destruct (dec (bs' = nil)) ];
- repeat match goal with
- | _ => subst a'
- | _ => subst bs'
- | _ => rewrite sum_rows_div with (n:=n) by auto
- | _ => rewrite @fold_right_nil in *
- | _ => progress autorewrite with cancel_pair push_eval in *
- | H : _ -> _ /\ _ |- _ =>
- let X := fresh in let Y := fresh in
- destruct H as [X Y]; [ solve [auto using in_cons] | rewrite ?X, ?Y ]
- | _ => split
- | _ => progress autorewrite with pull_Zmod zsimplify_fast
- | _ => solve [ auto using length_sum_rows ]
- | _ => solve [ ring_simplify; repeat (f_equal; try ring) ]
- end.
- apply Z.mul_cancel_l with (p:=weight n); [ apply weight_nonzero |].
- autorewrite with push_Zmul.
- rewrite !Z.mul_div_eq_full by apply weight_nonzero.
- autorewrite with pull_Zmod.
- ring_simplify_subterms. ring_simplify.
- repeat (f_equal; try ring).
+ induction inp; push; [apply IHinp; push|].
+ destruct (dec (inp = nil)); [subst inp; cbv [is_div_mod]
+ | eapply is_div_mod_result_equal; try apply IHinp]; push.
+ { autorewrite with zsimplify; push. }
+ { autorewrite with zsimplify; push. }
Qed.
Hint Rewrite (@Positional.length_zeros weight) : distr_length.
@@ -1369,22 +1416,11 @@ Module Rows.
Lemma flatten_div_mod inp n :
(forall row, In row inp -> length row = n) ->
- Positional.eval weight n (fst (flatten inp)) = (eval n inp) mod (weight n)
- /\ snd (flatten inp) = (eval n inp) / weight n.
+ is_div_mod (Positional.eval weight n) (flatten inp) (eval n inp) (weight n).
Proof.
intros; cbv [flatten].
- destruct inp; [|destruct inp]; cbn [hd tl].
- { cbn. autorewrite with push_eval. tauto. }
- { cbn.
- match goal with H : forall r, In r [?x] -> _ |- _ =>
- specialize (H x ltac:(auto)); rewrite H
- end.
- rewrite sum_rows_div with (n:=n) by distr_length.
- autorewrite with push_eval.
- split; f_equal; ring. }
- { autorewrite with push_eval.
- apply flatten'_div_mod_length; auto.
- congruence. }
+ destruct inp; [|destruct inp]; cbn [hd tl]; try solve [cbv [is_div_mod]; push].
+ eapply is_div_mod_result_equal; try apply flatten'_div_mod_length; push.
Qed.
Lemma flatten_mod inp n :
@@ -1408,25 +1444,24 @@ Module Rows.
inp <> nil ->
length (fst (flatten inp)) = n.
Proof.
- intros.
- destruct inp; [congruence |]; destruct inp;
- repeat match goal with
- | _ => progress intros
- | _ => progress cbn [hd tl] in *
- | _ => progress autorewrite with cancel_pair
- | _ => apply flatten'_div_mod_length
- | H: _ |- _ => apply in_inv in H; destruct H
- | _ => solve [auto]
- end;
- subst row; distr_length; auto.
+ intros. apply flatten'_div_mod_length; push;
+ destruct inp as [|? [|? ?] ]; try congruence; cbn [hd tl] in *; push.
+ subst row; distr_length; auto.
Qed. Hint Rewrite length_flatten : distr_length.
- Lemma flatten'_cons state x inp :
- flatten' state (x :: inp)
- = (fst (sum_rows x (fst (flatten' state inp))), snd (flatten' state inp) + snd (sum_rows x (fst (flatten' state inp)))).
- Proof. reflexivity. Qed.
+ (* TODO: move to ZUtil *)
+ Lemma add_mod_l_multiple a b n m:
+ 0 < n / m -> m <> 0 -> n mod m = 0 ->
+ (a mod n + b) mod m = (a + b) mod m.
+ Proof.
+ intros.
+ rewrite (proj2 (Z.div_exact n m ltac:(auto))) by auto.
+ rewrite Z.rem_mul_r by auto.
+ push_Zmod. autorewrite with zsimplify.
+ pull_Zmod. reflexivity.
+ Qed.
- Lemma flatten'_partitions n start_state inp:
+ Lemma flatten'_partitions n inp : forall start_state,
length (fst start_state) = n ->
(forall row, In row inp -> length row = n) ->
inp <> nil ->
@@ -1434,18 +1469,11 @@ Module Rows.
nth_default 0 (fst (flatten' start_state inp)) i
= ((Positional.eval weight n (fst start_state) + eval n inp) mod weight (S i)) / (weight i).
Proof.
- intro.
- induction inp; intros; [congruence|].
- rewrite flatten'_cons.
- autorewrite with push_fold_right cancel_pair.
- rewrite sum_rows_partitions with (n:=n) by (rewrite ?length_flatten' with (n:=n); auto).
+ induction inp; push.
destruct (dec (inp = nil)).
- { subst inp. cbn. repeat (f_equal; try ring). }
- { edestruct flatten'_div_mod_length with (inp:=inp) as [Hlen [Hmod Hdiv] ]; eauto.
- rewrite Hmod. autorewrite with push_eval.
- rewrite Z.add_mod_full. erewrite @Columns.weight_div_mod with (j:=n) (i:=S i) by eauto.
- rewrite Z.rem_mul_r by auto using Z.gt_lt, weight_divides_full.
- autorewrite with zsimplify. pull_Zmod.
+ { subst inp; push. rewrite sum_rows_partitions with (n:=n) by eauto. push. }
+ { erewrite IHinp; push.
+ rewrite add_mod_l_multiple by auto using weight_divides_full, Columns.weight_multiples_full.
repeat (f_equal; try ring). }
Qed.
@@ -1455,18 +1483,10 @@ Module Rows.
nth_default 0 (fst (flatten inp)) i = (eval n inp mod weight (S i)) / (weight i).
Proof.
intros; cbv [flatten].
- destruct inp; [|destruct inp]; cbn [hd tl].
- { cbn. autorewrite with push_eval push_nth_default zsimplify. reflexivity. }
- { cbn.
- match goal with H : forall r, In r [?x] -> _ |- _ =>
- specialize (H x ltac:(auto)); rewrite H
- end.
- rewrite sum_rows_partitions with (n:=n) by distr_length.
- autorewrite with push_eval.
- repeat (f_equal; try ring). }
- { autorewrite with push_eval.
- apply flatten'_partitions; auto.
- congruence. }
+ intros; destruct inp as [| ? [| ? ?] ]; try congruence; cbn [hd tl] in *; try solve [push].
+ { cbn. autorewrite with push_nth_default. reflexivity. }
+ { push. rewrite sum_rows_partitions with (n:=n) by distr_length; push. }
+ { rewrite flatten'_partitions with (n:=n); push. }
Qed.
End Flatten.