aboutsummaryrefslogtreecommitdiff
path: root/src
diff options
context:
space:
mode:
authorGravatar Jade Philipoom <jadep@google.com>2018-04-09 13:06:01 +0200
committerGravatar Jade Philipoom <jadep@google.com>2018-04-09 13:06:01 +0200
commite2bf39c696a8ecc35c6d9b31011f54854ec3142a (patch)
tree9d81b6d756e5054c170d9f4201674e4698f5eecc /src
parent0fc77cb8ae1d0c85ac6d5ed59aaca85472735c4a (diff)
package properties of weight functions into a record
Diffstat (limited to 'src')
-rw-r--r--src/Experiments/SimplyTypedArithmetic.v198
1 files changed, 64 insertions, 134 deletions
diff --git a/src/Experiments/SimplyTypedArithmetic.v b/src/Experiments/SimplyTypedArithmetic.v
index 95c6929d2..4c0ff4e95 100644
--- a/src/Experiments/SimplyTypedArithmetic.v
+++ b/src/Experiments/SimplyTypedArithmetic.v
@@ -427,6 +427,15 @@ Module Positional. Section Positional.
Hint Rewrite @length_sub @length_opp : distr_length.
End Positional. End Positional.
+Record weight_properties {weight : nat -> Z} :=
+ {
+ weight_0 : weight 0%nat = 1;
+ weight_positive : forall i, 0 < weight i;
+ weight_multiples : forall i, weight (S i) mod weight i = 0;
+ weight_divides : forall i : nat, 0 < weight (S i) / weight i;
+ }.
+Hint Resolve weight_0 weight_positive weight_multiples weight_divides.
+
Section mod_ops.
Import Positional.
Local Coercion Z.of_nat : nat >-> Z.
@@ -463,34 +472,26 @@ Section mod_ops.
try reflexivity; try lia.
Qed.
- Lemma weight_0 : weight 0 = 1.
- Proof.
- clear.
- cbv [weight Z.of_nat]; autorewrite with zsimplify_fast; reflexivity.
- Qed.
- Local Hint Immediate weight_0.
-
Local Ltac t_weight_with lem :=
clear -limbwidth_good;
intros; rewrite !weight_ZQ_correct;
apply lem;
try omega; Q_cbv; destruct limbwidth_den; cbn; try lia.
- Local Lemma weight_nz : forall i, weight i <> 0.
- Proof. t_weight_with (@pow_ceil_mul_nat_nonzero 2). Qed.
- Local Hint Immediate weight_nz.
-
- Local Lemma weight_div_nz : forall i : nat, weight (S i) / weight i <> 0.
- Proof. t_weight_with (@pow_ceil_mul_nat_divide_nonzero 2). Qed.
- Local Hint Immediate weight_div_nz.
-
- (* lemmas for montred *)
- Local Lemma weight_divides : forall i, weight (S i) / weight i > 0.
- Proof. t_weight_with (@pow_ceil_mul_nat_divide 2). Qed.
- Local Lemma weight_positive : forall i, weight i > 0.
- Proof. t_weight_with (@pow_ceil_mul_nat_pos 2). Qed.
- Local Lemma weight_multiples : forall i, weight (S i) mod weight i = 0.
- Proof. t_weight_with (@pow_ceil_mul_nat_multiples 2). Qed.
+ Definition wprops : @weight_properties weight.
+ Proof.
+ constructor.
+ { cbv [weight Z.of_nat]; autorewrite with zsimplify_fast; reflexivity. }
+ { intros; apply Z.gt_lt. t_weight_with (@pow_ceil_mul_nat_pos 2). }
+ { t_weight_with (@pow_ceil_mul_nat_multiples 2). }
+ { intros; apply Z.gt_lt. t_weight_with (@pow_ceil_mul_nat_divide 2). }
+ Defined.
+ Local Hint Immediate (weight_0 wprops).
+ Local Hint Immediate (weight_positive wprops).
+ Local Hint Immediate (weight_multiples wprops).
+ Local Hint Immediate (weight_divides wprops).
+ Local Hint Resolve Z.positive_is_nonzero Z.lt_gt.
+
Local Lemma weight_1_gt_1 : weight 1 > 1.
Proof.
clear -limbwidth_good.
@@ -596,13 +597,11 @@ Section mod_ops.
End mod_ops.
Module Saturated.
+ Hint Resolve weight_positive weight_0 weight_multiples weight_divides.
+ Hint Resolve Z.positive_is_nonzero Z.lt_gt Nat2Z.is_nonneg.
+
Section Weight.
- Context (weight : nat->Z)
- {weight_0 : weight 0%nat = 1}
- {weight_nonzero : forall i, weight i <> 0}
- {weight_positive : forall i, weight i > 0}
- {weight_multiples : forall i, weight (S i) mod weight i = 0}
- {weight_divides : forall i : nat, weight (S i) / weight i > 0}.
+ Context weight {wprops : @weight_properties weight}.
Lemma weight_multiples_full' j : forall i, weight (i+j) mod weight i = 0.
Proof.
@@ -623,8 +622,8 @@ Module Saturated.
apply weight_multiples_full'.
Qed.
- Lemma weight_divides_full j i : (i <= j)%nat -> weight j / weight i > 0.
- Proof. auto using Z.div_positive_gt_0, weight_multiples_full. Qed.
+ Lemma weight_divides_full j i : (i <= j)%nat -> 0 < weight j / weight i.
+ Proof. auto using Z.gt_lt, Z.div_positive_gt_0, weight_multiples_full. Qed.
Lemma weight_div_mod j i : (i <= j)%nat -> weight j = weight i * (weight j / weight i).
Proof. intros. apply Z.div_exact; auto using weight_multiples_full. Qed.
@@ -738,12 +737,7 @@ End Saturated.
Module Columns.
Import Saturated.
Section Columns.
- Context (weight : nat->Z)
- {weight_0 : weight 0%nat = 1}
- {weight_nonzero : forall i, weight i <> 0}
- {weight_positive : forall i, weight i > 0}
- {weight_multiples : forall i, weight (S i) mod weight i = 0}
- {weight_divides : forall i : nat, weight (S i) / weight i > 0}.
+ Context weight {wprops : @weight_properties weight}.
Definition eval n (x : list (list Z)) : Z := Positional.eval weight n (map sum x).
@@ -953,7 +947,7 @@ Module Columns.
eval n (from_associational n p) = Associational.eval p.
Proof.
erewrite <-Positional.eval_from_associational by eauto.
- induction p; [ autorewrite with push_eval; congruence |].
+ induction p; [ autorewrite with push_eval; solve [auto] |].
cbv [from_associational Positional.from_associational]; autorewrite with push_fold_right.
fold (from_associational n p); fold (Positional.from_associational weight n p).
cbv [Let_In].
@@ -989,12 +983,7 @@ End Columns.
Module Rows.
Import Saturated.
Section Rows.
- Context (weight : nat->Z)
- {weight_0 : weight 0%nat = 1}
- {weight_nonzero : forall i, weight i <> 0}
- {weight_positive : forall i, weight i > 0}
- {weight_multiples : forall i, weight (S i) mod weight i = 0}
- {weight_divides : forall i : nat, weight (S i) / weight i > 0}.
+ Context weight {wprops : @weight_properties weight}.
Local Notation rows := (list (list Z)) (only parsing).
Local Notation cols := (list (list Z)) (only parsing).
@@ -1003,7 +992,6 @@ Module Rows.
Positional.eval_to_associational
Columns.eval_nil Columns.eval_snoc using (auto; solve [distr_length]) : push_eval.
Hint Resolve in_eq in_cons.
- Hint Resolve Z.gt_lt.
Definition eval n (inp : rows) :=
sum (map (Positional.eval weight n) inp).
@@ -1359,7 +1347,7 @@ Module Rows.
| _ => rewrite <-(Z.add_assoc _ x1 x2)
end.
{ rewrite div_step by auto using Z.gt_lt.
- rewrite Z.mul_div_eq_full by auto; rewrite weight_multiples. push. }
+ rewrite Z.mul_div_eq_full by auto; rewrite weight_multiples by auto. push. }
{ rewrite weight_div_mod with (j:=length (fst start_state)) (i:=S j) by (auto; omega).
push_Zmod. autorewrite with zsimplify_fast. reflexivity. }
{ push. replace (length (fst start_state)) with j in * by omega.
@@ -1418,7 +1406,7 @@ Module Rows.
| _ => progress In_cases
| |- _ /\ _ => split
| |- context [?x mod ?y] => unique pose proof (Z.mul_div_eq_full x y ltac:(auto)); lia
- | _ => solve [repeat (f_equal; try ring)]
+ | _ => solve [repeat (ring_simplify; f_equal; try ring)]
| _ => congruence
| _ => solve [eauto]
end.
@@ -1436,7 +1424,7 @@ Module Rows.
destruct (dec (inp = nil)); [subst inp; cbv [is_div_mod]
| eapply is_div_mod_result_equal; try apply IHinp]; push.
{ autorewrite with zsimplify; push. }
- { autorewrite with zsimplify; push. }
+ { rewrite Z.div_add' by auto; push. }
Qed.
Hint Rewrite (@Positional.length_zeros weight) : distr_length.
@@ -1494,7 +1482,7 @@ Module Rows.
{ subst inp; push. rewrite sum_rows_partitions with (n:=n) by eauto. push. }
{ erewrite IHinp; push.
rewrite add_mod_l_multiple by auto using weight_divides_full, weight_multiples_full.
- repeat (f_equal; try ring). }
+ push. }
Qed.
Lemma flatten_partitions inp n :
@@ -1574,15 +1562,10 @@ End Rows.
Module BaseConversion.
Import Positional.
Section BaseConversion.
+ Hint Resolve Z.gt_lt.
Context (sw dw : nat -> Z) (* source/destination weight functions *)
- (dw_0 : dw 0%nat = 1)
- (sw_0 : sw 0%nat = 1)
- (dw_nz : forall i, dw i <> 0)
- (sw_nz : forall i, sw i <> 0)
- (sw_pos : forall i, sw i > 0)
- (sw_multiples : forall i, sw (S i) mod sw i = 0)
- (sw_divides : forall i, sw (S i) / sw i > 0)
- (dw_divides : forall i, dw (S i) / dw i > 0).
+ {swprops : @weight_properties sw}
+ {dwprops : @weight_properties dw}.
Definition convert_bases (sn dn : nat) (p : list Z) : list Z :=
let p' := Positional.from_associational dw dn (Positional.to_associational sw sn p) in
@@ -1691,41 +1674,31 @@ Module BaseConversion.
(* multiply two (n*k)-bit numbers by converting them to n k-bit limbs each, multiplying, then converting back *)
Section widemul.
- Context (dw : nat -> Z) (n : nat) (n_nz : n <> 0%nat).
- Context (dw_0 : dw 0%nat = 1)
- (dw_pos : forall i, dw i > 0)
- (dw_multiples : forall i, dw (S i) mod dw i = 0)
- (dw_divides : forall i, dw (S i) / dw i > 0).
- Context (nout : nat) (nout_nz : nout <> 0%nat). (* this is always 2, but reification has trouble if it's a constant *)
- Let sw i := (dw i) ^ Z.of_nat n.
+ Context (log2base : Z) (log2base_pos : 0 < log2base).
+ Context (n : nat) (n_nz : n <> 0%nat) (n_le_log2base : Z.of_nat n <= log2base)
+ (nout : nat) (nout_2 : nout = 2%nat). (* nout is always 2, but partial evaluation is overeager if it's a constant *)
+ Let dw : nat -> Z := weight (log2base / Z.of_nat n) 1.
+ Let sw : nat -> Z := weight log2base 1.
+
+ Local Lemma base_bounds : 0 < 1 <= log2base. Proof. auto with zarith. Qed.
+ Local Lemma dbase_bounds : 0 < 1 <= log2base / Z.of_nat n. Proof. auto with zarith. Qed.
+ Let dwprops : @weight_properties dw := wprops (log2base / Z.of_nat n) 1 dbase_bounds.
+ Let swprops : @weight_properties sw := wprops log2base 1 base_bounds.
Hint Resolve Z.gt_lt Z.positive_is_nonzero Nat2Z.is_nonneg.
- (* TODO : There has got to be a cleaner way to do weight functions *)
- Lemma sw_0 : sw 0 = 1.
- Proof. subst sw; cbv beta. rewrite dw_0. auto using Z.pow_1_l. Qed.
- Lemma sw_pos : forall i, sw i > 0.
- Proof. subst sw; intros; apply Z.lt_gt. Z.zero_bounds. Qed.
- Lemma sw_nz : forall i, sw i <> 0.
- Proof. auto using sw_pos. Qed.
- Lemma sw_multiples : forall i, sw (S i) mod sw i = 0.
- Proof.
- subst sw; cbv beta; intros.
- rewrite Saturated.weight_div_mod with (weight := dw) (j:=S i) (i:=i) by auto.
- rewrite Z.pow_mul_l.
- push_Zmod. autorewrite with zsimplify. reflexivity.
- Qed.
- Lemma sw_divides : forall i, sw (S i) / sw i > 0.
- Proof. intros; apply Z.div_positive_gt_0; auto using sw_pos, sw_multiples. Qed.
- Hint Resolve sw_0 sw_pos sw_nz sw_multiples sw_divides.
Definition widemul a b := mul_converted sw dw 1 1 n n nout (aligned_carries n nout) [a] [b].
- Lemma widemul_partitions a b :
- widemul a b = Rows.partition sw nout (a * b).
+ Lemma widemul_correct a b :
+ 0 <= a * b < 2^log2base * 2^log2base ->
+ widemul a b = [(a * b) mod 2^log2base; (a * b) / 2^log2base].
Proof.
- cbv [widemul].
+ cbv [widemul]; intros.
rewrite mul_converted_partitions by auto with zarith.
- cbn; rewrite sw_0; ring_simplify_subterms. reflexivity.
+ subst nout sw; cbv [weight]; cbn.
+ autorewrite with zsimplify.
+ rewrite Z.pow_mul_r, Z.pow_2_r by omega.
+ Z.rewrite_mod_small. reflexivity.
Qed.
End widemul.
End BaseConversion.
@@ -7789,45 +7762,11 @@ Module MontgomeryReduction.
Context (R_big_enough : n <= Zlog2R)
(R_two_pow : 2^Zlog2R = R).
Let w_mul : nat -> Z := weight (Zlog2R / n) 1.
- Local Lemma w_mul_pown : forall i, (w_mul i) ^ n = w i.
- Proof.
- cbv [w_mul w weight]; intro.
- autorewrite with pull_Zpow zsimplify.
- rewrite <-Z.pow_mul_r by Z.zero_bounds. apply f_equal.
- rewrite (Z.div_mod Zlog2R n) at 2 by Z.zero_bounds.
- rewrite n_good. lia.
- Qed.
- Local Lemma log2R_good : 0 < 1 <= Zlog2R.
- Proof. clear -R_big_enough Hn_nz; lia. Qed.
- Local Lemma half_log2R_good : 0 < 1 <= Zlog2R / n.
- Proof. clear -R_big_enough Hn_nz; Z.div_mod_to_quot_rem; nia. Qed.
- Let w_mul_0 : w_mul 0%nat = 1 := weight_0 _ _.
- Let w_mul_nonzero : forall i, w_mul i <> 0
- := weight_nz _ _ half_log2R_good.
- Let w_mul_positive : forall i, w_mul i > 0
- := weight_positive _ _ half_log2R_good.
- Let w_mul_multiples : forall i, w_mul (S i) mod w_mul i = 0
- := weight_multiples _ _ half_log2R_good.
- Let w_mul_divides : forall i : nat, w_mul (S i) / w_mul i > 0
- := weight_divides _ _ half_log2R_good.
-(*
- Let w_0 : w 0%nat = 1 := weight_0 _ _.
- Let w_nonzero : forall i, w i <> 0
- := weight_nz _ _ log2R_good.
- Let w_positive : forall i, w i > 0
- := weight_positive _ _ log2R_good.
- Let w_multiples : forall i, w (S i) mod w i = 0
- := weight_multiples _ _ log2R_good.
- Let w_divides : forall i : nat, w (S i) / w i > 0
- := weight_divides _ _ log2R_good.
- Let w_1_gt1 : w 1 > 1 := weight_1_gt_1 _ _ log2R_good.
-*)
- Let w_mul_1_gt1 : w_mul 1 > 1 := weight_1_gt_1 _ _ half_log2R_good.
Context (nout : nat) (Hnout : nout = 2%nat).
Definition montred' (lo_hi : (Z * Z)) :=
- dlet_nd y := nth_default 0 (BaseConversion.widemul w_mul n nout (fst lo_hi) N') 0 in
- dlet_nd t1_t2 := (BaseConversion.widemul w_mul n nout y N) in
+ dlet_nd y := nth_default 0 (BaseConversion.widemul Zlog2R n nout (fst lo_hi) N') 0 in
+ dlet_nd t1_t2 := (BaseConversion.widemul Zlog2R n nout y N) in
dlet_nd lo'_carry := Z.add_get_carry_full R (fst lo_hi) (nth_default 0 t1_t2 0) in
dlet_nd hi'_carry := Z.add_with_get_carry_full R (snd lo'_carry) (snd lo_hi) (nth_default 0 t1_t2 1) in
dlet_nd y' := Z.zselect (snd hi'_carry) 0 N in
@@ -7841,7 +7780,7 @@ Module MontgomeryReduction.
rewrite Z.pow_mul_r, R_two_pow by omega; reflexivity.
Qed.
- Local Ltac change_weight := rewrite ?w_mul_pown, !Hw, ?Z.pow_0_r, ?Z.pow_1_r, ?Z.pow_2_r in *.
+ Local Ltac change_weight := rewrite !Hw, ?Z.pow_0_r, ?Z.pow_1_r, ?Z.pow_2_r, ?Z.pow_1_l in *.
Local Ltac solve_range :=
repeat match goal with
| _ => progress change_weight
@@ -7850,20 +7789,9 @@ Module MontgomeryReduction.
| |- 0 <= _ * _ < _ * _ =>
split; [ solve [Z.zero_bounds] | apply Z.mul_lt_mono_nonneg; omega ]
| _ => solve [auto]
- | _ => cbn
- | _ => nia
+ | _ => omega
end.
- Lemma widemul_correct x y :
- 0 <= x * y < w 2 -> BaseConversion.widemul w_mul n nout x y = [(x * y) mod R; (x * y) / R].
- Proof.
- intros.
- rewrite BaseConversion.widemul_partitions by (auto; omega).
- subst nout. cbn. change_weight.
- autorewrite with zsimplify.
- Z.rewrite_mod_small. reflexivity.
- Qed.
-
Lemma montred'_eq lo_hi T (HT_range: 0 <= T < R * N)
(Hlo: fst lo_hi = T mod R) (Hhi: snd lo_hi = T / R):
montred' lo_hi = reduce_via_partial N R N' T.
@@ -7873,7 +7801,9 @@ Module MontgomeryReduction.
rewrite Hlo, Hhi.
assert (0 <= T mod R * N' < w 2) by (solve_range).
- rewrite !widemul_correct by (rewrite ?widemul_correct; autorewrite with push_nth_default; solve_range).
+ rewrite !BaseConversion.widemul_correct
+ by (rewrite ?BaseConversion.widemul_correct; autorewrite with push_nth_default; solve_range).
+ rewrite R_two_pow.
autorewrite with push_nth_default.
autorewrite with to_div_mod. rewrite ?Z.zselect_correct, ?Z.add_modulo_correct.