aboutsummaryrefslogtreecommitdiff
path: root/src
diff options
context:
space:
mode:
authorGravatar Jason Gross <jagro@google.com>2018-06-28 19:28:21 -0400
committerGravatar Jason Gross <jasongross9@gmail.com>2018-07-03 19:28:55 -0400
commitdbd5c4be0fa10cbfed1998ee09d702944d6f5c91 (patch)
tree71fbfb89f6b29c1b196c22895db741350e2de37c /src
parent16a39010f466bcc3471e5c9cea03fe8ec006232d (diff)
Add select
Diffstat (limited to 'src')
-rw-r--r--src/Experiments/NewPipeline/Arithmetic.v57
1 files changed, 47 insertions, 10 deletions
diff --git a/src/Experiments/NewPipeline/Arithmetic.v b/src/Experiments/NewPipeline/Arithmetic.v
index 87cc5c7b2..72d488d80 100644
--- a/src/Experiments/NewPipeline/Arithmetic.v
+++ b/src/Experiments/NewPipeline/Arithmetic.v
@@ -490,32 +490,58 @@ Module Positional. Section Positional.
Hint Rewrite @length_sub @length_opp : distr_length.
Section select.
- Definition select (mask cond:Z) (p:list Z) :=
+ Definition zselect (mask cond:Z) (p:list Z) :=
dlet t := Z.zselect cond 0 mask in List.map (Z.land t) p.
+ Definition select (cond:Z) (if_zero if_nonzero:list Z) :=
+ List.map (fun '(p, q) => Z.zselect cond p q) (List.combine if_zero if_nonzero).
+
Lemma map_and_0 n (p:list Z) : length p = n -> map (Z.land 0) p = zeros n.
Proof.
intro; subst; induction p as [|x xs IHxs]; [reflexivity | ].
cbn; f_equal; auto.
Qed.
- Lemma eval_select n mask cond p (H:List.map (Z.land mask) p = p) :
+ Lemma eval_zselect n mask cond p (H:List.map (Z.land mask) p = p) :
length p = n
- -> eval n (select mask cond p) =
+ -> eval n (zselect mask cond p) =
if dec (cond = 0) then 0 else eval n p.
Proof.
- cbv [select Let_In].
+ cbv [zselect Let_In].
rewrite Z.zselect_correct; break_match.
{ intros; erewrite map_and_0 by eassumption. apply eval_zeros. }
{ rewrite H; reflexivity. }
Qed.
- Lemma length_select mask cond p :
- length (select mask cond p) = length p.
- Proof using Type. clear dependent weight. cbv [select Let_In]; break_match; intros; distr_length. Qed.
+ Lemma length_zselect mask cond p :
+ length (zselect mask cond p) = length p.
+ Proof using Type. clear dependent weight. cbv [zselect Let_In]; break_match; intros; distr_length. Qed.
+ Lemma eval_select n cond p q :
+ length p = n -> length q = n
+ -> eval n (select cond p q) =
+ if dec (cond = 0) then eval n p else eval n q.
+ Proof.
+ cbv [select Let_In]; intro; subst.
+ rewrite <- (List.rev_involutive q), <- (List.rev_involutive p).
+ generalize (rev p) (rev q); clear p q; intros p q; revert q.
+ induction p as [|p ps IHps], q as [|q qs]; cbn [length map combine rev]; distr_length; rewrite ?Nat.add_1_r; try omega.
+ { break_match; reflexivity. }
+ { intro; rewrite !combine_snoc, !map_app by (distr_length; omega).
+ cbn [map].
+ rewrite !eval_snoc with (n:=length ps), IHps by (distr_length; omega* ).
+ rewrite !Z.zselect_correct; break_match; reflexivity. }
+ Qed.
+ Lemma length_select_min cond p q :
+ length (select cond p q) = Nat.min (length p) (length q).
+ Proof using Type. clear dependent weight. cbv [select Let_In]; distr_length. Qed.
+ Hint Rewrite length_select_min : distr_length.
+ Lemma length_select n cond p q :
+ length p = n -> length q = n ->
+ length (select cond p q) = n.
+ Proof using Type. clear dependent weight. distr_length; omega **. Qed.
End select.
End Positional.
(* Hint Rewrite disappears after the end of a section *)
-Hint Rewrite length_zeros length_add_to_nth length_from_associational @length_add @length_carry_reduce @length_chained_carries @length_encode @length_sub @length_opp @length_select : distr_length.
-Hint Rewrite @eval_select : push_eval.
+Hint Rewrite length_zeros length_add_to_nth length_from_associational @length_add @length_carry_reduce @length_chained_carries @length_encode @length_sub @length_opp @length_select @length_zselect @length_select_min : distr_length.
+Hint Rewrite @eval_select @eval_zselect : push_eval.
Section Positional_nonuniform.
Context (weight weight' : nat -> Z).
@@ -1759,6 +1785,7 @@ Module Rows.
fst (flatten n inp) = partition n (eval n inp).
Proof using wprops. auto using nth_default_partitions, flatten_partitions, length_flatten. Qed.
End Flatten.
+ Hint Rewrite length_partition : distr_length.
Section Ops.
Definition add n p q := flatten n [p; q].
@@ -1773,9 +1800,14 @@ Module Rows.
Definition sub n p q := flatten n [p; map (fun x => dlet y := x in Z.opp y) q].
Definition conditional_add n mask cond (p q:list Z) :=
- let qq := Positional.select mask cond q in
+ let qq := Positional.zselect mask cond q in
add n p qq.
+ (* Subtract q if and only if p >= q. *)
+ Definition conditional_sub n (p q:list Z) :=
+ let '(v, c) := sub n p q in
+ Positional.select c v p.
+
Hint Rewrite eval_cons eval_nil using solve [auto] : push_eval.
Definition mul base n m (p q : list Z) :=
@@ -1864,6 +1896,11 @@ Module Rows.
fst (mul base n m p q) = partition m (Positional.eval weight n p * Positional.eval weight n q).
Proof using wprops. solver. Qed.
+ Lemma length_mul base n m p q :
+ length p = n -> length q = n ->
+ length (fst (mul base n m p q)) = m.
+ Proof using wprops. solver; distr_length. Qed.
+
Lemma eval_sat_reduce base s c p :
base <> 0 -> s - Associational.eval c <> 0 -> s <> 0 ->
Associational.eval (sat_reduce base s c p) mod (s - Associational.eval c)