diff options
Diffstat (limited to 'src/Experiments/NewPipeline/Arithmetic.v')
-rw-r--r-- | src/Experiments/NewPipeline/Arithmetic.v | 147 |
1 files changed, 135 insertions, 12 deletions
diff --git a/src/Experiments/NewPipeline/Arithmetic.v b/src/Experiments/NewPipeline/Arithmetic.v index d263ca151..04f4bdd4d 100644 --- a/src/Experiments/NewPipeline/Arithmetic.v +++ b/src/Experiments/NewPipeline/Arithmetic.v @@ -2079,7 +2079,7 @@ Module Freeze. Definition freeze n mask (m p:list Z) : list Z := let '(p, carry) := Rows.sub weight n p m in - let '(r, carry) := Rows.conditional_add weight n mask carry p m in + let '(r, carry) := Rows.conditional_add weight n mask (-carry) p m in r. Lemma freezeZ m s c y : @@ -2087,13 +2087,14 @@ Module Freeze. 0 < c < s -> s <> 0 -> 0 <= y < 2*m -> - ((y - m) + (if (dec ((y - m) / s = 0)) then 0 else m)) mod s + ((y - m) + (if (dec (-((y - m) / s) = 0)) then 0 else m)) mod s = y mod m. Proof using Type. clear; intros. transitivity ((y - m) mod m); repeat first [ progress intros | progress subst + | rewrite Z.opp_eq_0_iff in * | break_innermost_match_step | progress autorewrite with zsimplify_fast | rewrite Z.div_small_iff in * by auto @@ -2117,7 +2118,7 @@ Module Freeze. (Hmlen : length m = n) : Positional.eval weight n (@freeze n mask m p) = (Positional.eval weight n p - Positional.eval weight n m + - (if dec ((Positional.eval weight n p - Positional.eval weight n m) / weight n = 0) then 0 else Positional.eval weight n m)) + (if dec (-((Positional.eval weight n p - Positional.eval weight n m) / weight n) = 0) then 0 else Positional.eval weight n m)) mod weight n. (*if dec ((Positional.eval weight n p - Positional.eval weight n m) / weight n = 0) then Positional.eval weight n p - Positional.eval weight n m @@ -2144,13 +2145,13 @@ Module Freeze. (n_nonzero:n<>0%nat) (Hc : 0 < Associational.eval c < weight n) (Hmask : List.map (Z.land mask) m = m) - modulus (Hm : Positional.eval weight n m = Z.pos modulus) - (Hp : 0 <= Positional.eval weight n p < 2*(Z.pos modulus)) - (Hsc : Z.pos modulus = weight n - Associational.eval c) + (modulus:=weight n - Associational.eval c) + (Hm : Positional.eval weight n m = modulus) + (Hp : 0 <= Positional.eval weight n p < 2*modulus) (Hplen : length p = n) (Hmlen : length m = n) : Positional.eval weight n (@freeze n mask m p) - = Positional.eval weight n p mod (Z.pos modulus). + = Positional.eval weight n p mod modulus. Proof using wprops. pose proof (@weight_positive weight wprops n). rewrite eval_freeze_eq by assumption. @@ -2162,16 +2163,17 @@ Module Freeze. (n_nonzero:n<>0%nat) (Hc : 0 < Associational.eval c < weight n) (Hmask : List.map (Z.land mask) m = m) - modulus (Hm : Positional.eval weight n m = Z.pos modulus) - (Hp : 0 <= Positional.eval weight n p < 2*(Z.pos modulus)) - (Hsc : Z.pos modulus = weight n - Associational.eval c) + (modulus:=weight n - Associational.eval c) + (Hm : Positional.eval weight n m = modulus) + (Hp : 0 <= Positional.eval weight n p < 2*modulus) (Hplen : length p = n) (Hmlen : length m = n) - : @freeze n mask m p = Rows.partition weight n (Positional.eval weight n p mod (Z.pos modulus)). + : @freeze n mask m p = Rows.partition weight n (Positional.eval weight n p mod modulus). Proof using wprops. pose proof (@weight_positive weight wprops n). pose proof (fun v => Z.mod_pos_bound v (weight n) ltac:(lia)). - pose proof (Z.mod_pos_bound (Positional.eval weight n p) (Z.pos modulus) ltac:(lia)). + pose proof (Z.mod_pos_bound (Positional.eval weight n p) modulus ltac:(lia)). + subst modulus. erewrite <- eval_freeze by eassumption. cbv [freeze]; eta_expand. rewrite Rows.conditional_add_partitions by (auto; rewrite Rows.sub_partitions; auto; distr_length). @@ -2182,3 +2184,124 @@ Module Freeze. Qed. End Freeze. End Freeze. +Hint Rewrite Freeze.length_freeze : distr_length. + +Section freeze_mod_ops. + Import Positional. + Import Freeze. + Local Coercion Z.of_nat : nat >-> Z. + Local Coercion QArith_base.inject_Z : Z >-> Q. + (* Design constraints: + - inputs must be [Z] (b/c reification does not support Q) + - internal structure must not match on the arguments (b/c reification does not support [positive]) *) + Context (limbwidth_num limbwidth_den : Z) + (limbwidth_good : 0 < limbwidth_den <= limbwidth_num) + (s : Z) + (c : list (Z*Z)) + (n : nat) + (bitwidth : Z) + (m_enc : list Z) + (m_nz:s - Associational.eval c <> 0) (s_nz:s <> 0) + (Hn_nz : n <> 0%nat). + Local Notation bytes_weight := (@weight 8 1). + Local Notation weight := (@weight limbwidth_num limbwidth_den). + Let m := (s - Associational.eval c). + + Context (Hs : s = weight n). + Context (c_small : 0 < Associational.eval c < weight n) + (m_enc_bounded : List.map (BinInt.Z.land (Z.ones bitwidth)) m_enc = m_enc) + (m_enc_correct : Positional.eval weight n m_enc = m) + (Hm_enc_len : length m_enc = n). + + Definition wprops_bytes := (@wprops 8 1 ltac:(lia)). + Local Notation wprops := (@wprops limbwidth_num limbwidth_den limbwidth_good). + + Local Hint Immediate (weight_0 wprops). + Local Hint Immediate (weight_positive wprops). + Local Hint Immediate (weight_multiples wprops). + Local Hint Immediate (weight_divides wprops). + Local Hint Immediate (weight_0 wprops_bytes). + Local Hint Immediate (weight_positive wprops_bytes). + Local Hint Immediate (weight_multiples wprops_bytes). + Local Hint Immediate (weight_divides wprops_bytes). + Local Hint Resolve Z.positive_is_nonzero Z.lt_gt. + + Definition bytes_n := (1 + (Z.to_nat (Z.log2_up (weight n) / 8)))%nat. + + Definition to_bytes' (v : list Z) + := BaseConversion.convert_bases weight bytes_weight n bytes_n v. + + Definition from_bytes (v : list Z) + := BaseConversion.convert_bases bytes_weight weight bytes_n n v. + + Definition to_bytesmod (f : list Z) : list Z + := to_bytes' (freeze weight n (Z.ones bitwidth) m_enc f). + + Definition from_bytesmod (f : list Z) : list Z + := from_bytes f. + + Lemma eval_to_bytesmod + : forall (f : list Z) + (Hf : length f = n) + (Hf_bounded : 0 <= eval weight n f < 2 * m), + (eval bytes_weight bytes_n (to_bytesmod f)) = (eval weight n f) mod m + /\ to_bytesmod f = to_bytes' (Rows.partition weight n (Positional.eval weight n f mod m)). + Proof. + intros; subst m s; split. + { erewrite <- eval_freeze with (weight := weight) (n:=n) (mask:=Z.ones bitwidth) (m:=m_enc) ; auto using wprops. + erewrite <- BaseConversion.eval_convert_bases with (sw:=weight) (dw:=bytes_weight) (sn:=n) (dn:=bytes_n) (p:=freeze _ _ _ _ _) + by (cbv [bytes_n]; auto using wprops_bytes; distr_length; auto using wprops). + reflexivity. } + { cbv [to_bytesmod]. + erewrite freeze_partitions by eauto using wprops. + reflexivity. } + Qed. + + Lemma eval_from_bytesmod + : forall (f : list Z) + (Hf : length f = bytes_n), + eval weight n (from_bytesmod f) = eval bytes_weight bytes_n f. + Proof. + cbv [from_bytesmod from_bytes]; intros. + rewrite BaseConversion.eval_convert_bases by eauto using wprops. + reflexivity. + Qed. +End freeze_mod_ops. + +Section primitives. + Definition mulx (bitwidth : Z) := Eval cbv [Z.mul_split_at_bitwidth] in Z.mul_split_at_bitwidth bitwidth. + Definition addcarryx (bitwidth : Z) := Eval cbv [Z.add_with_get_carry Z.add_with_carry Z.get_carry] in Z.add_with_get_carry bitwidth. + Definition subborrowx (bitwidth : Z) := Eval cbv [Z.sub_with_get_borrow Z.sub_with_borrow Z.get_borrow Z.get_carry Z.add_with_carry] in Z.sub_with_get_borrow bitwidth. + Definition cmovznz (bitwidth : Z) (cond : Z) (z nz : Z) + := dlet t := (0 - Z.bneg (Z.bneg cond)) mod 2^bitwidth in Z.lor (Z.land t nz) (Z.land (Z.lnot_modulo t (2^bitwidth)) z). + + Lemma cmovznz_correct bitwidth cond z nz + : 0 <= z < 2^bitwidth + -> 0 <= nz < 2^bitwidth + -> cmovznz bitwidth cond z nz = Z.zselect cond z nz. + Proof. + intros. + assert (0 < 2^bitwidth) by omega. + assert (0 <= bitwidth) by auto with zarith. + assert (0 < bitwidth -> 1 < 2^bitwidth) by auto with zarith. + pose proof Z.log2_lt_pow2_alt. + assert (bitwidth = 0 \/ 0 < bitwidth) by omega. + repeat first [ progress cbv [cmovznz Z.zselect Z.bneg Let_In Z.lnot_modulo] + | progress split_iff + | progress subst + | progress Z.ltb_to_lt + | progress destruct_head'_or + | congruence + | omega + | progress break_innermost_match_step + | progress break_innermost_match_hyps_step + | progress autorewrite with zsimplify_const in * + | progress pull_Zmod + | progress intros + | rewrite !Z.sub_1_r, <- Z.ones_equiv, <- ?Z.sub_1_r + | rewrite Z_mod_nz_opp_full by (Z.rewrite_mod_small; omega) + | rewrite (Z.land_comm (Z.ones _)) + | rewrite Z.land_ones_low by auto with omega + | progress Z.rewrite_mod_small ]. + Qed. +End primitives. |