diff options
author | Jason Gross <jagro@google.com> | 2018-06-28 19:28:21 -0400 |
---|---|---|
committer | Jason Gross <jasongross9@gmail.com> | 2018-07-03 19:28:55 -0400 |
commit | dbd5c4be0fa10cbfed1998ee09d702944d6f5c91 (patch) | |
tree | 71fbfb89f6b29c1b196c22895db741350e2de37c /src | |
parent | 16a39010f466bcc3471e5c9cea03fe8ec006232d (diff) |
Add select
Diffstat (limited to 'src')
-rw-r--r-- | src/Experiments/NewPipeline/Arithmetic.v | 57 |
1 files changed, 47 insertions, 10 deletions
diff --git a/src/Experiments/NewPipeline/Arithmetic.v b/src/Experiments/NewPipeline/Arithmetic.v index 87cc5c7b2..72d488d80 100644 --- a/src/Experiments/NewPipeline/Arithmetic.v +++ b/src/Experiments/NewPipeline/Arithmetic.v @@ -490,32 +490,58 @@ Module Positional. Section Positional. Hint Rewrite @length_sub @length_opp : distr_length. Section select. - Definition select (mask cond:Z) (p:list Z) := + Definition zselect (mask cond:Z) (p:list Z) := dlet t := Z.zselect cond 0 mask in List.map (Z.land t) p. + Definition select (cond:Z) (if_zero if_nonzero:list Z) := + List.map (fun '(p, q) => Z.zselect cond p q) (List.combine if_zero if_nonzero). + Lemma map_and_0 n (p:list Z) : length p = n -> map (Z.land 0) p = zeros n. Proof. intro; subst; induction p as [|x xs IHxs]; [reflexivity | ]. cbn; f_equal; auto. Qed. - Lemma eval_select n mask cond p (H:List.map (Z.land mask) p = p) : + Lemma eval_zselect n mask cond p (H:List.map (Z.land mask) p = p) : length p = n - -> eval n (select mask cond p) = + -> eval n (zselect mask cond p) = if dec (cond = 0) then 0 else eval n p. Proof. - cbv [select Let_In]. + cbv [zselect Let_In]. rewrite Z.zselect_correct; break_match. { intros; erewrite map_and_0 by eassumption. apply eval_zeros. } { rewrite H; reflexivity. } Qed. - Lemma length_select mask cond p : - length (select mask cond p) = length p. - Proof using Type. clear dependent weight. cbv [select Let_In]; break_match; intros; distr_length. Qed. + Lemma length_zselect mask cond p : + length (zselect mask cond p) = length p. + Proof using Type. clear dependent weight. cbv [zselect Let_In]; break_match; intros; distr_length. Qed. + Lemma eval_select n cond p q : + length p = n -> length q = n + -> eval n (select cond p q) = + if dec (cond = 0) then eval n p else eval n q. + Proof. + cbv [select Let_In]; intro; subst. + rewrite <- (List.rev_involutive q), <- (List.rev_involutive p). + generalize (rev p) (rev q); clear p q; intros p q; revert q. + induction p as [|p ps IHps], q as [|q qs]; cbn [length map combine rev]; distr_length; rewrite ?Nat.add_1_r; try omega. + { break_match; reflexivity. } + { intro; rewrite !combine_snoc, !map_app by (distr_length; omega). + cbn [map]. + rewrite !eval_snoc with (n:=length ps), IHps by (distr_length; omega* ). + rewrite !Z.zselect_correct; break_match; reflexivity. } + Qed. + Lemma length_select_min cond p q : + length (select cond p q) = Nat.min (length p) (length q). + Proof using Type. clear dependent weight. cbv [select Let_In]; distr_length. Qed. + Hint Rewrite length_select_min : distr_length. + Lemma length_select n cond p q : + length p = n -> length q = n -> + length (select cond p q) = n. + Proof using Type. clear dependent weight. distr_length; omega **. Qed. End select. End Positional. (* Hint Rewrite disappears after the end of a section *) -Hint Rewrite length_zeros length_add_to_nth length_from_associational @length_add @length_carry_reduce @length_chained_carries @length_encode @length_sub @length_opp @length_select : distr_length. -Hint Rewrite @eval_select : push_eval. +Hint Rewrite length_zeros length_add_to_nth length_from_associational @length_add @length_carry_reduce @length_chained_carries @length_encode @length_sub @length_opp @length_select @length_zselect @length_select_min : distr_length. +Hint Rewrite @eval_select @eval_zselect : push_eval. Section Positional_nonuniform. Context (weight weight' : nat -> Z). @@ -1759,6 +1785,7 @@ Module Rows. fst (flatten n inp) = partition n (eval n inp). Proof using wprops. auto using nth_default_partitions, flatten_partitions, length_flatten. Qed. End Flatten. + Hint Rewrite length_partition : distr_length. Section Ops. Definition add n p q := flatten n [p; q]. @@ -1773,9 +1800,14 @@ Module Rows. Definition sub n p q := flatten n [p; map (fun x => dlet y := x in Z.opp y) q]. Definition conditional_add n mask cond (p q:list Z) := - let qq := Positional.select mask cond q in + let qq := Positional.zselect mask cond q in add n p qq. + (* Subtract q if and only if p >= q. *) + Definition conditional_sub n (p q:list Z) := + let '(v, c) := sub n p q in + Positional.select c v p. + Hint Rewrite eval_cons eval_nil using solve [auto] : push_eval. Definition mul base n m (p q : list Z) := @@ -1864,6 +1896,11 @@ Module Rows. fst (mul base n m p q) = partition m (Positional.eval weight n p * Positional.eval weight n q). Proof using wprops. solver. Qed. + Lemma length_mul base n m p q : + length p = n -> length q = n -> + length (fst (mul base n m p q)) = m. + Proof using wprops. solver; distr_length. Qed. + Lemma eval_sat_reduce base s c p : base <> 0 -> s - Associational.eval c <> 0 -> s <> 0 -> Associational.eval (sat_reduce base s c p) mod (s - Associational.eval c) |