aboutsummaryrefslogtreecommitdiff
path: root/src/Util/ListUtil.v
diff options
context:
space:
mode:
authorGravatar Jason Gross <jagro@google.com>2016-07-06 15:41:14 -0700
committerGravatar Jason Gross <jagro@google.com>2016-07-06 15:41:14 -0700
commit56e58b21bb80e7b460b0010a8b307f97c3fefea4 (patch)
tree436ce6f5151b59b2452c39de699db695c33e2f4f /src/Util/ListUtil.v
parent6ddfe39affef2f47836b03d49fc6e4b9266e75c8 (diff)
Add [update_nth] to ListUtil, change [set_nth]
Define [set_nth] in terms of [update_nth]
Diffstat (limited to 'src/Util/ListUtil.v')
-rw-r--r--src/Util/ListUtil.v241
1 files changed, 195 insertions, 46 deletions
diff --git a/src/Util/ListUtil.v b/src/Util/ListUtil.v
index 060372341..50927c2a4 100644
--- a/src/Util/ListUtil.v
+++ b/src/Util/ListUtil.v
@@ -1,8 +1,27 @@
Require Import Coq.Lists.List.
Require Import Coq.omega.Omega.
Require Import Coq.Arith.Peano_dec.
+Require Import Coq.Classes.Morphisms.
Require Import Crypto.Tactics.VerdiTactics.
Require Import Coq.Numbers.Natural.Peano.NPeano.
+Require Import Crypto.Util.NatUtil.
+
+Create HintDb distr_length discriminated.
+Create HintDb simpl_set_nth discriminated.
+Create HintDb simpl_update_nth discriminated.
+
+Hint Rewrite
+ @app_length
+ @rev_length
+ @map_length
+ @seq_length
+ @fold_left_length
+ @split_length_l
+ @split_length_r
+ @firstn_length
+ @combine_length
+ @prod_length
+ : distr_length.
Definition sum_firstn l n := fold_right Z.add 0%Z (firstn n l).
@@ -106,42 +125,130 @@ Proof.
induction xs; boring.
Qed.
-(* xs[n] := x *)
-Fixpoint set_nth {T} n x (xs:list T) {struct n} :=
+(* xs[n] := f xs[n] *)
+Fixpoint update_nth {T} n f (xs:list T) {struct n} :=
match n with
| O => match xs with
| nil => nil
- | x'::xs' => x::xs'
+ | x'::xs' => f x'::xs'
end
| S n' => match xs with
| nil => nil
- | x'::xs' => x'::set_nth n' x xs'
+ | x'::xs' => x'::update_nth n' f xs'
end
end.
-Lemma nth_set_nth : forall m {T} (xs:list T) (n:nat) (x x':T),
- nth_error (set_nth m x xs) n =
- if eq_nat_dec n m
- then (if lt_dec n (length xs) then Some x else None)
- else nth_error xs n.
+(* xs[n] := x *)
+Definition set_nth {T} n x (xs:list T)
+ := update_nth n (fun _ => x) xs.
+
+Lemma unfold_set_nth {T} n x
+ : forall xs,
+ @set_nth T n x xs
+ = match n with
+ | O => match xs with
+ | nil => nil
+ | x'::xs' => x::xs'
+ end
+ | S n' => match xs with
+ | nil => nil
+ | x'::xs' => x'::set_nth n' x xs'
+ end
+ end.
Proof.
- induction m.
+ induction n; destruct xs; reflexivity.
+Qed.
+
+Lemma simpl_set_nth_0 {T} x
+ : forall xs,
+ @set_nth T 0 x xs
+ = match xs with
+ | nil => nil
+ | x'::xs' => x::xs'
+ end.
+Proof. intro; rewrite unfold_set_nth; reflexivity. Qed.
+
+Lemma simpl_set_nth_S {T} x n
+ : forall xs,
+ @set_nth T (S n) x xs
+ = match xs with
+ | nil => nil
+ | x'::xs' => x'::set_nth n x xs'
+ end.
+Proof. intro; rewrite unfold_set_nth; reflexivity. Qed.
+
+Hint Rewrite @simpl_set_nth_S @simpl_set_nth_0 : simpl_set_nth.
+
+Lemma update_nth_ext {T} f g n
+ : forall xs, (forall x, nth_error xs n = Some x -> f x = g x)
+ -> @update_nth T n f xs = @update_nth T n g xs.
+Proof.
+ induction n; destruct xs; simpl; intros H;
+ try rewrite IHn; try rewrite H;
+ try congruence; trivial.
+Qed.
+
+Global Instance update_nth_Proper {T}
+ : Proper (eq ==> pointwise_relation _ eq ==> eq ==> eq) (@update_nth T).
+Proof. repeat intro; subst; apply update_nth_ext; trivial. Qed.
+
+Lemma update_nth_id_eq_specific {T} f n
+ : forall (xs : list T) (H : forall x, nth_error xs n = Some x -> f x = x),
+ update_nth n f xs = xs.
+Proof.
+ induction n; destruct xs; simpl; intros;
+ try rewrite IHn; try rewrite H;
+ try congruence; assumption.
+Qed.
- destruct n, xs; auto.
+Hint Rewrite @update_nth_id_eq_specific using congruence : simpl_update_nth.
- intros; destruct xs, n; auto.
- simpl; unfold error; match goal with
- [ |- None = if ?x then None else None ] => destruct x
- end; auto.
+Lemma update_nth_id_eq : forall {T} f (H : forall x, f x = x) n (xs : list T),
+ update_nth n f xs = xs.
+Proof. intros; apply update_nth_id_eq_specific; trivial. Qed.
- simpl nth_error; erewrite IHm by auto; clear IHm.
- destruct (eq_nat_dec n m), (eq_nat_dec (S n) (S m)); nth_tac.
+Hint Rewrite @update_nth_id_eq using congruence : simpl_update_nth.
+
+Lemma update_nth_id : forall {T} n (xs : list T),
+ update_nth n (fun x => x) xs = xs.
+Proof. intros; apply update_nth_id_eq; trivial. Qed.
+
+Hint Rewrite @update_nth_id : simpl_update_nth.
+
+Lemma nth_update_nth : forall m {T} (xs:list T) (n:nat) (f:T -> T),
+ nth_error (update_nth m f xs) n =
+ if eq_nat_dec n m
+ then option_map f (nth_error xs n)
+ else nth_error xs n.
+Proof.
+ induction m.
+ { destruct n, xs; auto. }
+ { destruct xs, n; intros; simpl; auto;
+ [ | rewrite IHm ]; clear IHm;
+ edestruct eq_nat_dec; reflexivity. }
Qed.
-Lemma length_set_nth : forall {T} i (x:T) xs, length (set_nth i x xs) = length xs.
+Lemma length_update_nth : forall {T} i f (xs:list T), length (update_nth i f xs) = length xs.
+Proof.
induction i, xs; boring.
Qed.
+Lemma nth_set_nth : forall m {T} (xs:list T) (n:nat) x,
+ nth_error (set_nth m x xs) n =
+ if eq_nat_dec n m
+ then (if lt_dec n (length xs) then Some x else None)
+ else nth_error xs n.
+Proof.
+ intros; unfold set_nth; rewrite nth_update_nth.
+ destruct (nth_error xs n) eqn:?, (lt_dec n (length xs)) as [p|p];
+ rewrite <- nth_error_Some in p;
+ solve [ reflexivity
+ | exfalso; apply p; congruence ].
+Qed.
+
+Lemma length_set_nth : forall {T} i x (xs:list T), length (set_nth i x xs) = length xs.
+Proof. intros; apply length_update_nth. Qed.
+
Lemma nth_error_length_exists_value : forall {A} (i : nat) (xs : list A),
(i < length xs)%nat -> exists x, nth_error xs i = Some x.
Proof.
@@ -208,11 +315,50 @@ Lemma set_nth_equiv_splice_nth: forall {T} n x (xs:list T),
then splice_nth n x xs
else xs.
Proof.
- induction n; destruct xs; intros; simpl in *;
- try (rewrite IHn; clear IHn); auto.
+ induction n; destruct xs; intros;
+ autorewrite with simpl_set_nth in *; simpl in *;
+ try (rewrite IHn; clear IHn); auto.
break_if; break_if; auto; omega.
Qed.
+Lemma combine_update_nth : forall {A B} n f g (xs:list A) (ys:list B),
+ combine (update_nth n f xs) (update_nth n g ys) =
+ update_nth n (fun xy => (f (fst xy), g (snd xy))) (combine xs ys).
+Proof.
+ induction n; destruct xs, ys; simpl; try rewrite IHn; reflexivity.
+Qed.
+
+(* grumble, grumble, [rewrite] is bad at inferring the identity function, and constant functions *)
+Ltac rewrite_rev_combine_update_nth :=
+ let lem := match goal with
+ | [ |- appcontext[update_nth ?n (fun xy => (@?f xy, @?g xy)) (combine ?xs ?ys)] ]
+ => let f := match (eval cbv [fst] in (fun y x => f (x, y))) with
+ | fun _ => ?f => f
+ end in
+ let g := match (eval cbv [snd] in (fun x y => g (x, y))) with
+ | fun _ => ?g => g
+ end in
+ constr:(@combine_update_nth _ _ n f g xs ys)
+ end in
+ rewrite <- lem.
+
+Lemma combine_update_nth_l : forall {A B} n (f : A -> A) xs (ys:list B),
+ combine (update_nth n f xs) ys =
+ update_nth n (fun xy => (f (fst xy), snd xy)) (combine xs ys).
+Proof.
+ intros ??? f xs ys.
+ etransitivity; [ | apply combine_update_nth with (g := fun x => x) ].
+ rewrite update_nth_id; reflexivity.
+Qed.
+
+Lemma combine_update_nth_r : forall {A B} n (g : B -> B) (xs:list A) (ys:list B),
+ combine xs (update_nth n g ys) =
+ update_nth n (fun xy => (fst xy, g (snd xy))) (combine xs ys).
+Proof.
+ intros ??? g xs ys.
+ etransitivity; [ | apply combine_update_nth with (f := fun x => x) ].
+ rewrite update_nth_id; reflexivity.
+Qed.
Lemma combine_set_nth : forall {A B} n (x:A) xs (ys:list B),
combine (set_nth n x xs) ys =
@@ -221,12 +367,12 @@ Lemma combine_set_nth : forall {A B} n (x:A) xs (ys:list B),
| Some y => set_nth n (x,y) (combine xs ys)
end.
Proof.
- (* TODO(andreser): this proof can totally be automated, but requires writing ltac that vets multiple hypotheses at once *)
- induction n, xs, ys; nth_tac; try rewrite IHn; nth_tac;
- try (f_equal; specialize (IHn x xs ys ); rewrite H in IHn; rewrite <- IHn);
- try (specialize (nth_error_value_length _ _ _ _ H); omega).
- assert (Some b0=Some b1) as HA by (rewrite <-H, <-H0; auto).
- injection HA; intros; subst; auto.
+ intros; unfold set_nth; rewrite combine_update_nth_l.
+ nth_tac;
+ [ repeat rewrite_rev_combine_update_nth; apply f_equal2
+ | assert (nth_error (combine xs ys) n = None)
+ by (apply nth_error_None; rewrite combine_length; omega * ) ];
+ autorewrite with simpl_update_nth; reflexivity.
Qed.
Lemma nth_error_value_In : forall {T} n xs (x:T),
@@ -461,42 +607,40 @@ Proof.
reflexivity.
Qed.
+Lemma update_nth_cons : forall {T} f (u0 : T) us, update_nth 0 f (u0 :: us) = (f u0) :: us.
+Proof. reflexivity. Qed.
+
Lemma set_nth_cons : forall {T} (x u0 : T) us, set_nth 0 x (u0 :: us) = x :: us.
-Proof.
- auto.
-Qed.
+Proof. intros; apply update_nth_cons. Qed.
-Create HintDb distr_length discriminated.
Hint Rewrite
@nil_length0
@length_cons
- @app_length
- @rev_length
- @map_length
- @seq_length
- @fold_left_length
- @split_length_l
- @split_length_r
- @firstn_length
@skipn_length
- @combine_length
- @prod_length
+ @length_update_nth
@length_set_nth
: distr_length.
Ltac distr_length := autorewrite with distr_length in *;
try solve [simpl in *; omega].
-Lemma cons_set_nth : forall {T} n (x y : T) us,
- y :: set_nth n x us = set_nth (S n) x (y :: us).
+Lemma cons_update_nth : forall {T} n f (y : T) us,
+ y :: update_nth n f us = update_nth (S n) f (y :: us).
Proof.
induction n; boring.
Qed.
-Lemma set_nth_nil : forall {T} n (x : T), set_nth n x nil = nil.
+Lemma update_nth_nil : forall {T} n f, set_nth n f (@nil T) = @nil T.
Proof.
induction n; boring.
Qed.
+Lemma cons_set_nth : forall {T} n (x y : T) us,
+ y :: set_nth n x us = set_nth (S n) x (y :: us).
+Proof. intros; apply cons_update_nth. Qed.
+
+Lemma set_nth_nil : forall {T} n (x : T), set_nth n x nil = nil.
+Proof. intros; apply update_nth_nil. Qed.
+
Lemma nth_default_nil : forall {T} n (d : T), nth_default d nil n = d.
Proof.
induction n; boring.
@@ -629,15 +773,20 @@ Proof.
omega.
Qed.
-Lemma set_nth_nth_default : forall {A} (d:A) n x l i, (0 <= i < length l)%nat ->
- nth_default d (set_nth n x l) i =
- if (eq_nat_dec i n) then x else nth_default d l i.
+Lemma update_nth_nth_default : forall {A} (d:A) n f l i, (0 <= i < length l)%nat ->
+ nth_default d (update_nth n f l) i =
+ if (eq_nat_dec i n) then f (nth_default d l i) else nth_default d l i.
Proof.
induction n; (destruct l; [intros; simpl in *; omega | ]); simpl;
destruct i; break_if; try omega; intros; try apply nth_default_cons;
rewrite !nth_default_cons_S, ?IHn; try break_if; omega || reflexivity.
Qed.
+Lemma set_nth_nth_default : forall {A} (d:A) n x l i, (0 <= i < length l)%nat ->
+ nth_default d (set_nth n x l) i =
+ if (eq_nat_dec i n) then x else nth_default d l i.
+Proof. intros; apply update_nth_nth_default; assumption. Qed.
+
Lemma nth_default_preserves_properties : forall {A} (P : A -> Prop) l n d,
(forall x, In x l -> P x) -> P d -> P (nth_default d l n).
Proof.