aboutsummaryrefslogtreecommitdiff
path: root/src/Experiments/NewPipeline/Arithmetic.v
diff options
context:
space:
mode:
Diffstat (limited to 'src/Experiments/NewPipeline/Arithmetic.v')
-rw-r--r--src/Experiments/NewPipeline/Arithmetic.v147
1 files changed, 135 insertions, 12 deletions
diff --git a/src/Experiments/NewPipeline/Arithmetic.v b/src/Experiments/NewPipeline/Arithmetic.v
index d263ca151..04f4bdd4d 100644
--- a/src/Experiments/NewPipeline/Arithmetic.v
+++ b/src/Experiments/NewPipeline/Arithmetic.v
@@ -2079,7 +2079,7 @@ Module Freeze.
Definition freeze n mask (m p:list Z) : list Z :=
let '(p, carry) := Rows.sub weight n p m in
- let '(r, carry) := Rows.conditional_add weight n mask carry p m in
+ let '(r, carry) := Rows.conditional_add weight n mask (-carry) p m in
r.
Lemma freezeZ m s c y :
@@ -2087,13 +2087,14 @@ Module Freeze.
0 < c < s ->
s <> 0 ->
0 <= y < 2*m ->
- ((y - m) + (if (dec ((y - m) / s = 0)) then 0 else m)) mod s
+ ((y - m) + (if (dec (-((y - m) / s) = 0)) then 0 else m)) mod s
= y mod m.
Proof using Type.
clear; intros.
transitivity ((y - m) mod m);
repeat first [ progress intros
| progress subst
+ | rewrite Z.opp_eq_0_iff in *
| break_innermost_match_step
| progress autorewrite with zsimplify_fast
| rewrite Z.div_small_iff in * by auto
@@ -2117,7 +2118,7 @@ Module Freeze.
(Hmlen : length m = n)
: Positional.eval weight n (@freeze n mask m p)
= (Positional.eval weight n p - Positional.eval weight n m +
- (if dec ((Positional.eval weight n p - Positional.eval weight n m) / weight n = 0) then 0 else Positional.eval weight n m))
+ (if dec (-((Positional.eval weight n p - Positional.eval weight n m) / weight n) = 0) then 0 else Positional.eval weight n m))
mod weight n.
(*if dec ((Positional.eval weight n p - Positional.eval weight n m) / weight n = 0)
then Positional.eval weight n p - Positional.eval weight n m
@@ -2144,13 +2145,13 @@ Module Freeze.
(n_nonzero:n<>0%nat)
(Hc : 0 < Associational.eval c < weight n)
(Hmask : List.map (Z.land mask) m = m)
- modulus (Hm : Positional.eval weight n m = Z.pos modulus)
- (Hp : 0 <= Positional.eval weight n p < 2*(Z.pos modulus))
- (Hsc : Z.pos modulus = weight n - Associational.eval c)
+ (modulus:=weight n - Associational.eval c)
+ (Hm : Positional.eval weight n m = modulus)
+ (Hp : 0 <= Positional.eval weight n p < 2*modulus)
(Hplen : length p = n)
(Hmlen : length m = n)
: Positional.eval weight n (@freeze n mask m p)
- = Positional.eval weight n p mod (Z.pos modulus).
+ = Positional.eval weight n p mod modulus.
Proof using wprops.
pose proof (@weight_positive weight wprops n).
rewrite eval_freeze_eq by assumption.
@@ -2162,16 +2163,17 @@ Module Freeze.
(n_nonzero:n<>0%nat)
(Hc : 0 < Associational.eval c < weight n)
(Hmask : List.map (Z.land mask) m = m)
- modulus (Hm : Positional.eval weight n m = Z.pos modulus)
- (Hp : 0 <= Positional.eval weight n p < 2*(Z.pos modulus))
- (Hsc : Z.pos modulus = weight n - Associational.eval c)
+ (modulus:=weight n - Associational.eval c)
+ (Hm : Positional.eval weight n m = modulus)
+ (Hp : 0 <= Positional.eval weight n p < 2*modulus)
(Hplen : length p = n)
(Hmlen : length m = n)
- : @freeze n mask m p = Rows.partition weight n (Positional.eval weight n p mod (Z.pos modulus)).
+ : @freeze n mask m p = Rows.partition weight n (Positional.eval weight n p mod modulus).
Proof using wprops.
pose proof (@weight_positive weight wprops n).
pose proof (fun v => Z.mod_pos_bound v (weight n) ltac:(lia)).
- pose proof (Z.mod_pos_bound (Positional.eval weight n p) (Z.pos modulus) ltac:(lia)).
+ pose proof (Z.mod_pos_bound (Positional.eval weight n p) modulus ltac:(lia)).
+ subst modulus.
erewrite <- eval_freeze by eassumption.
cbv [freeze]; eta_expand.
rewrite Rows.conditional_add_partitions by (auto; rewrite Rows.sub_partitions; auto; distr_length).
@@ -2182,3 +2184,124 @@ Module Freeze.
Qed.
End Freeze.
End Freeze.
+Hint Rewrite Freeze.length_freeze : distr_length.
+
+Section freeze_mod_ops.
+ Import Positional.
+ Import Freeze.
+ Local Coercion Z.of_nat : nat >-> Z.
+ Local Coercion QArith_base.inject_Z : Z >-> Q.
+ (* Design constraints:
+ - inputs must be [Z] (b/c reification does not support Q)
+ - internal structure must not match on the arguments (b/c reification does not support [positive]) *)
+ Context (limbwidth_num limbwidth_den : Z)
+ (limbwidth_good : 0 < limbwidth_den <= limbwidth_num)
+ (s : Z)
+ (c : list (Z*Z))
+ (n : nat)
+ (bitwidth : Z)
+ (m_enc : list Z)
+ (m_nz:s - Associational.eval c <> 0) (s_nz:s <> 0)
+ (Hn_nz : n <> 0%nat).
+ Local Notation bytes_weight := (@weight 8 1).
+ Local Notation weight := (@weight limbwidth_num limbwidth_den).
+ Let m := (s - Associational.eval c).
+
+ Context (Hs : s = weight n).
+ Context (c_small : 0 < Associational.eval c < weight n)
+ (m_enc_bounded : List.map (BinInt.Z.land (Z.ones bitwidth)) m_enc = m_enc)
+ (m_enc_correct : Positional.eval weight n m_enc = m)
+ (Hm_enc_len : length m_enc = n).
+
+ Definition wprops_bytes := (@wprops 8 1 ltac:(lia)).
+ Local Notation wprops := (@wprops limbwidth_num limbwidth_den limbwidth_good).
+
+ Local Hint Immediate (weight_0 wprops).
+ Local Hint Immediate (weight_positive wprops).
+ Local Hint Immediate (weight_multiples wprops).
+ Local Hint Immediate (weight_divides wprops).
+ Local Hint Immediate (weight_0 wprops_bytes).
+ Local Hint Immediate (weight_positive wprops_bytes).
+ Local Hint Immediate (weight_multiples wprops_bytes).
+ Local Hint Immediate (weight_divides wprops_bytes).
+ Local Hint Resolve Z.positive_is_nonzero Z.lt_gt.
+
+ Definition bytes_n := (1 + (Z.to_nat (Z.log2_up (weight n) / 8)))%nat.
+
+ Definition to_bytes' (v : list Z)
+ := BaseConversion.convert_bases weight bytes_weight n bytes_n v.
+
+ Definition from_bytes (v : list Z)
+ := BaseConversion.convert_bases bytes_weight weight bytes_n n v.
+
+ Definition to_bytesmod (f : list Z) : list Z
+ := to_bytes' (freeze weight n (Z.ones bitwidth) m_enc f).
+
+ Definition from_bytesmod (f : list Z) : list Z
+ := from_bytes f.
+
+ Lemma eval_to_bytesmod
+ : forall (f : list Z)
+ (Hf : length f = n)
+ (Hf_bounded : 0 <= eval weight n f < 2 * m),
+ (eval bytes_weight bytes_n (to_bytesmod f)) = (eval weight n f) mod m
+ /\ to_bytesmod f = to_bytes' (Rows.partition weight n (Positional.eval weight n f mod m)).
+ Proof.
+ intros; subst m s; split.
+ { erewrite <- eval_freeze with (weight := weight) (n:=n) (mask:=Z.ones bitwidth) (m:=m_enc) ; auto using wprops.
+ erewrite <- BaseConversion.eval_convert_bases with (sw:=weight) (dw:=bytes_weight) (sn:=n) (dn:=bytes_n) (p:=freeze _ _ _ _ _)
+ by (cbv [bytes_n]; auto using wprops_bytes; distr_length; auto using wprops).
+ reflexivity. }
+ { cbv [to_bytesmod].
+ erewrite freeze_partitions by eauto using wprops.
+ reflexivity. }
+ Qed.
+
+ Lemma eval_from_bytesmod
+ : forall (f : list Z)
+ (Hf : length f = bytes_n),
+ eval weight n (from_bytesmod f) = eval bytes_weight bytes_n f.
+ Proof.
+ cbv [from_bytesmod from_bytes]; intros.
+ rewrite BaseConversion.eval_convert_bases by eauto using wprops.
+ reflexivity.
+ Qed.
+End freeze_mod_ops.
+
+Section primitives.
+ Definition mulx (bitwidth : Z) := Eval cbv [Z.mul_split_at_bitwidth] in Z.mul_split_at_bitwidth bitwidth.
+ Definition addcarryx (bitwidth : Z) := Eval cbv [Z.add_with_get_carry Z.add_with_carry Z.get_carry] in Z.add_with_get_carry bitwidth.
+ Definition subborrowx (bitwidth : Z) := Eval cbv [Z.sub_with_get_borrow Z.sub_with_borrow Z.get_borrow Z.get_carry Z.add_with_carry] in Z.sub_with_get_borrow bitwidth.
+ Definition cmovznz (bitwidth : Z) (cond : Z) (z nz : Z)
+ := dlet t := (0 - Z.bneg (Z.bneg cond)) mod 2^bitwidth in Z.lor (Z.land t nz) (Z.land (Z.lnot_modulo t (2^bitwidth)) z).
+
+ Lemma cmovznz_correct bitwidth cond z nz
+ : 0 <= z < 2^bitwidth
+ -> 0 <= nz < 2^bitwidth
+ -> cmovznz bitwidth cond z nz = Z.zselect cond z nz.
+ Proof.
+ intros.
+ assert (0 < 2^bitwidth) by omega.
+ assert (0 <= bitwidth) by auto with zarith.
+ assert (0 < bitwidth -> 1 < 2^bitwidth) by auto with zarith.
+ pose proof Z.log2_lt_pow2_alt.
+ assert (bitwidth = 0 \/ 0 < bitwidth) by omega.
+ repeat first [ progress cbv [cmovznz Z.zselect Z.bneg Let_In Z.lnot_modulo]
+ | progress split_iff
+ | progress subst
+ | progress Z.ltb_to_lt
+ | progress destruct_head'_or
+ | congruence
+ | omega
+ | progress break_innermost_match_step
+ | progress break_innermost_match_hyps_step
+ | progress autorewrite with zsimplify_const in *
+ | progress pull_Zmod
+ | progress intros
+ | rewrite !Z.sub_1_r, <- Z.ones_equiv, <- ?Z.sub_1_r
+ | rewrite Z_mod_nz_opp_full by (Z.rewrite_mod_small; omega)
+ | rewrite (Z.land_comm (Z.ones _))
+ | rewrite Z.land_ones_low by auto with omega
+ | progress Z.rewrite_mod_small ].
+ Qed.
+End primitives.