about | summary | refs | log | tree | commit | diff
path: root/src/Experiments/SimplyTypedArithmetic.v
diff options
context:
space:
mode:
author: Jade Philipoom <jadep@google.com> 2018-04-09 17:17:08 +0200
committer: jadephilipoom <jade.philipoom@gmail.com> 2018-04-30 06:14:51 -0400
commit: eed8884a1b60e644af820ff8cacdfc6c0670f327 (patch)
tree: c495d6c5f1cfd70a239888476a171dd0e2a9003e /src/Experiments/SimplyTypedArithmetic.v
parent: c1e78a56f5f86f27257c8f79766ceaa9dab18faa (diff)
First stab at generating code for saturated Solinas modular
multiplication (currently produces way too many expressions because 1*x and -1*x are not simplified for two-output mul)
Diffstat (limited to 'src/Experiments/SimplyTypedArithmetic.v')
-rw-r--r-- src/Experiments/SimplyTypedArithmetic.v 415
1 files changed, 353 insertions, 62 deletions
diff --git a/src/Experiments/SimplyTypedArithmetic.v b/src/Experiments/SimplyTypedArithmetic.v
index a6eda9643..760a5146c 100644
--- a/src/Experiments/SimplyTypedArithmetic.v
+++ b/src/Experiments/SimplyTypedArithmetic.v
@@ -646,7 +646,7 @@ Module Saturated.
[(fst t * fst t', fst xy); (fst t * fst t' * s, snd xy)].
Definition sat_mul s (p q : list (Z * Z)) : list (Z * Z) :=
- flat_map (fun t => flat_map (sat_multerm s t) q) p.
+ flat_map (fun t => flat_map (fun t' => sat_multerm s t t') q) p.
Lemma eval_map_sat_multerm s a q (s_nonzero:s<>0):
Associational.eval (flat_map (sat_multerm s a) q) = fst a * snd a * Associational.eval q.
@@ -978,14 +978,6 @@ Module Columns.
(from_associational n p).
Proof. reflexivity. Qed.
End FromAssociational.
-
- Section mul.
- Definition mul s n m (p q : list Z) : list Z :=
- let p_a := Positional.to_associational weight n p in
- let q_a := Positional.to_associational weight n q in
- let pq_a := Associational.sat_mul s p_a q_a in
- fst (flatten (from_associational m pq_a)).
- End mul.
End Columns.
End Columns.
@@ -1536,6 +1528,28 @@ Module Rows.
break_match; omega. }
Qed.
+ Lemma partition_step n x :
+ partition (S n) x = partition n x ++ [(x mod weight (S n)) / weight n].
+ Proof.
+ cbv [partition]. rewrite seq_snoc.
+ autorewrite with natsimplify push_map. reflexivity.
+ Qed.
+
+ Lemma length_partition n x : length (partition n x) = n.
+ Proof. cbv [partition]; distr_length. Qed.
+ Hint Rewrite length_partition : distr_length.
+
+ Lemma eval_partition n x :
+ Positional.eval weight n (partition n x) = x mod (weight n).
+ Proof.
+ induction n; intros.
+ { cbn. rewrite (weight_0); auto with zarith. }
+ { rewrite (Z.div_mod (x mod weight (S n)) (weight n)) by auto.
+ rewrite <-Znumtheory.Zmod_div_mod by (try apply Z.mod_divide; auto).
+ rewrite partition_step, Positional.eval_snoc with (n:=n) by distr_length.
+ omega. }
+ Qed.
+
Lemma flatten_partitions' inp n :
(forall row, In row inp -> length row = n) ->
fst (flatten n inp) = partition n (eval n inp).
@@ -1554,25 +1568,39 @@ Module Rows.
Hint Rewrite eval_cons eval_nil using solve [auto] : push_eval.
+ Definition mul base n m (p q : list Z) :=
+ let p_a := Positional.to_associational weight n p in
+ let q_a := Positional.to_associational weight n q in
+ let pq_a := Associational.sat_mul base p_a q_a in
+ flatten m (from_associational m pq_a).
+
+ Definition mulmod base s c n m (p q : list Z) :=
+ let p_a := Positional.to_associational weight n p in
+ let q_a := Positional.to_associational weight n q in
+ let pq_a := Associational.sat_mul base p_a q_a in
+ let lo_hi := Associational.split s pq_a in
+ let r_a := fst lo_hi ++ (Associational.sat_mul base c (snd lo_hi)) in (* reduce, but using sat_mul *)
+ flatten m (from_associational m r_a).
+
+ Hint Rewrite Associational.eval_sat_mul Associational.eval_split using solve [auto] : push_eval.
+ Hint Rewrite eval_from_associational using solve [auto] : push_eval.
+ Hint Rewrite eval_partition using solve [auto] : push_eval.
+ Ltac solver :=
+ intros; cbv [sub add mul mulmod];
+ rewrite ?flatten_partitions' by (intros; In_cases; subst; distr_length; eauto using length_from_associational);
+ rewrite ?flatten_div by (intros; In_cases; subst; distr_length; eauto using length_from_associational);
+ autorewrite with push_eval; ring_simplify_subterms;
+ try reflexivity.
+
Lemma add_partitions n p q :
n <> 0%nat -> length p = n -> length q = n ->
fst (add n p q) = partition n (Positional.eval weight n p + Positional.eval weight n q).
- Proof.
- intros; cbv [add].
- rewrite flatten_partitions' by (intros; In_cases; subst; distr_length).
- autorewrite with push_eval. ring_simplify_subterms.
- reflexivity.
- Qed.
+ Proof. solver. Qed.
Lemma add_div n p q :
n <> 0%nat -> length p = n -> length q = n ->
snd (add n p q) = (Positional.eval weight n p + Positional.eval weight n q) / weight n.
- Proof.
- intros; cbv [add].
- rewrite flatten_div by (intros; In_cases; subst; distr_length).
- autorewrite with push_eval. ring_simplify_subterms.
- reflexivity.
- Qed.
+ Proof. solver. Qed.
Lemma eval_map_opp q :
forall n, length q = n ->
@@ -1590,21 +1618,29 @@ Module Rows.
Lemma sub_partitions n p q :
n <> 0%nat -> length p = n -> length q = n ->
fst (sub n p q) = partition n (Positional.eval weight n p - Positional.eval weight n q).
- Proof.
- intros; cbv [sub].
- rewrite flatten_partitions' by (intros; In_cases; subst; distr_length).
- autorewrite with push_eval. ring_simplify_subterms.
- reflexivity.
- Qed.
+ Proof. solver. Qed.
Lemma sub_div n p q :
n <> 0%nat -> length p = n -> length q = n ->
snd (sub n p q) = (Positional.eval weight n p - Positional.eval weight n q) / weight n.
+ Proof. solver. Qed.
+
+ Lemma mul_partitions base n m p q :
+ base <> 0 -> n <> 0%nat -> m <> 0%nat -> length p = n -> length q = n ->
+ fst (mul base n m p q) = partition m (Positional.eval weight n p * Positional.eval weight n q).
+ Proof. solver. Qed.
+
+ Lemma eval_mulmod base s c n m p q :
+ base <> 0 -> s <> 0 -> s - Associational.eval c <> 0 ->
+ n <> 0%nat -> m <> 0%nat -> length p = n -> length q = n ->
+ (Positional.eval weight m (fst (mulmod base s c n m p q))
+ + weight m * (snd (mulmod base s c n m p q))) mod (s - Associational.eval c)
+ = (Positional.eval weight n p * Positional.eval weight n q) mod (s - Associational.eval c).
Proof.
- intros; cbv [sub].
- rewrite flatten_div by (intros; In_cases; subst; distr_length).
- autorewrite with push_eval. ring_simplify_subterms.
- reflexivity.
+ solver.
+ rewrite <-Z.div_mod'' by auto.
+ rewrite <-Associational.reduction_rule by omega.
+ autorewrite with push_eval; auto.
Qed.
End Ops.
End Rows.
@@ -6893,7 +6929,6 @@ Proof.
try (rewrite <- Z.log2_up_le_pow2_full in *; omega).
Qed.
-
(** XXX TODO: Translate Jade's python script *)
Section rcarry_mul.
Context (n : nat)
@@ -7418,7 +7453,6 @@ Goal False.
Abort.
*)
-
Time Compute
(Pipeline.BoundsPipeline_full
true (relax_zrange_gen [64; 128])
@@ -7592,11 +7626,9 @@ Module X25519_32.
Import PrintingNotations.
Print base_25p5_carry_mul.
(*
-base_25p5_carry_mul =
+base_25p5_carry_mul =
fun var : type -> Type =>
-(λ x : var
- (type.list (type.type_primitive type.Z) *
- type.list (type.type_primitive type.Z))%ctype,
+(λ x : var (type.list (type.type_primitive type.Z) * type.list (type.type_primitive type.Z))%ctype,
expr_let x0 := x₁ [[0]] *₆₄ x₂ [[0]] +₆₄
((uint64)(x₁ [[1]] *₆₄ (19 * (uint32)(x₂[[9]])) << 1) +₆₄
(x₁ [[2]] *₆₄ (19 * (uint32)(x₂[[8]])) +₆₄
@@ -7605,8 +7637,7 @@ fun var : type -> Type =>
((uint64)(x₁ [[5]] *₆₄ (19 * (uint32)(x₂[[5]])) << 1) +₆₄
(x₁ [[6]] *₆₄ (19 * (uint32)(x₂[[4]])) +₆₄
((uint64)(x₁ [[7]] *₆₄ (19 * (uint32)(x₂[[3]])) << 1) +₆₄
- (x₁ [[8]] *₆₄ (19 * (uint32)(x₂[[2]])) +₆₄
- (uint64)(x₁ [[9]] *₆₄ (19 * (uint32)(x₂[[1]])) << 1))))))))) in
+ (x₁ [[8]] *₆₄ (19 * (uint32)(x₂[[2]])) +₆₄ (uint64)(x₁ [[9]] *₆₄ (19 * (uint32)(x₂[[1]])) << 1))))))))) in
expr_let x1 := (uint64)(x0 >> 26) +₆₄
(x₁ [[0]] *₆₄ x₂ [[1]] +₆₄
(x₁ [[1]] *₆₄ x₂ [[0]] +₆₄
@@ -7616,8 +7647,7 @@ fun var : type -> Type =>
(x₁ [[5]] *₆₄ (19 * (uint32)(x₂[[6]])) +₆₄
(x₁ [[6]] *₆₄ (19 * (uint32)(x₂[[5]])) +₆₄
(x₁ [[7]] *₆₄ (19 * (uint32)(x₂[[4]])) +₆₄
- (x₁ [[8]] *₆₄ (19 * (uint32)(x₂[[3]])) +₆₄
- x₁ [[9]] *₆₄ (19 * (uint32)(x₂[[2]]))))))))))) in
+ (x₁ [[8]] *₆₄ (19 * (uint32)(x₂[[3]])) +₆₄ x₁ [[9]] *₆₄ (19 * (uint32)(x₂[[2]]))))))))))) in
expr_let x2 := (uint64)(x1 >> 25) +₆₄
(x₁ [[0]] *₆₄ x₂ [[2]] +₆₄
((uint64)(x₁ [[1]] *₆₄ x₂ [[1]] << 1) +₆₄
@@ -7627,8 +7657,7 @@ fun var : type -> Type =>
((uint64)(x₁ [[5]] *₆₄ (19 * (uint32)(x₂[[7]])) << 1) +₆₄
(x₁ [[6]] *₆₄ (19 * (uint32)(x₂[[6]])) +₆₄
((uint64)(x₁ [[7]] *₆₄ (19 * (uint32)(x₂[[5]])) << 1) +₆₄
- (x₁ [[8]] *₆₄ (19 * (uint32)(x₂[[4]])) +₆₄
- (uint64)(x₁ [[9]] *₆₄ (19 * (uint32)(x₂[[3]])) << 1)))))))))) in
+ (x₁ [[8]] *₆₄ (19 * (uint32)(x₂[[4]])) +₆₄ (uint64)(x₁ [[9]] *₆₄ (19 * (uint32)(x₂[[3]])) << 1)))))))))) in
expr_let x3 := (uint64)(x2 >> 26) +₆₄
(x₁ [[0]] *₆₄ x₂ [[3]] +₆₄
(x₁ [[1]] *₆₄ x₂ [[2]] +₆₄
@@ -7638,8 +7667,7 @@ fun var : type -> Type =>
(x₁ [[5]] *₆₄ (19 * (uint32)(x₂[[8]])) +₆₄
(x₁ [[6]] *₆₄ (19 * (uint32)(x₂[[7]])) +₆₄
(x₁ [[7]] *₆₄ (19 * (uint32)(x₂[[6]])) +₆₄
- (x₁ [[8]] *₆₄ (19 * (uint32)(x₂[[5]])) +₆₄
- x₁ [[9]] *₆₄ (19 * (uint32)(x₂[[4]]))))))))))) in
+ (x₁ [[8]] *₆₄ (19 * (uint32)(x₂[[5]])) +₆₄ x₁ [[9]] *₆₄ (19 * (uint32)(x₂[[4]]))))))))))) in
expr_let x4 := (uint64)(x3 >> 25) +₆₄
(x₁ [[0]] *₆₄ x₂ [[4]] +₆₄
((uint64)(x₁ [[1]] *₆₄ x₂ [[3]] << 1) +₆₄
@@ -7649,8 +7677,7 @@ fun var : type -> Type =>
((uint64)(x₁ [[5]] *₆₄ (19 * (uint32)(x₂[[9]])) << 1) +₆₄
(x₁ [[6]] *₆₄ (19 * (uint32)(x₂[[8]])) +₆₄
((uint64)(x₁ [[7]] *₆₄ (19 * (uint32)(x₂[[7]])) << 1) +₆₄
- (x₁ [[8]] *₆₄ (19 * (uint32)(x₂[[6]])) +₆₄
- (uint64)(x₁ [[9]] *₆₄ (19 * (uint32)(x₂[[5]])) << 1)))))))))) in
+ (x₁ [[8]] *₆₄ (19 * (uint32)(x₂[[6]])) +₆₄ (uint64)(x₁ [[9]] *₆₄ (19 * (uint32)(x₂[[5]])) << 1)))))))))) in
expr_let x5 := (uint64)(x4 >> 26) +₆₄
(x₁ [[0]] *₆₄ x₂ [[5]] +₆₄
(x₁ [[1]] *₆₄ x₂ [[4]] +₆₄
@@ -7660,8 +7687,7 @@ fun var : type -> Type =>
(x₁ [[5]] *₆₄ x₂ [[0]] +₆₄
(x₁ [[6]] *₆₄ (19 * (uint32)(x₂[[9]])) +₆₄
(x₁ [[7]] *₆₄ (19 * (uint32)(x₂[[8]])) +₆₄
- (x₁ [[8]] *₆₄ (19 * (uint32)(x₂[[7]])) +₆₄
- x₁ [[9]] *₆₄ (19 * (uint32)(x₂[[6]]))))))))))) in
+ (x₁ [[8]] *₆₄ (19 * (uint32)(x₂[[7]])) +₆₄ x₁ [[9]] *₆₄ (19 * (uint32)(x₂[[6]]))))))))))) in
expr_let x6 := (uint64)(x5 >> 25) +₆₄
(x₁ [[0]] *₆₄ x₂ [[6]] +₆₄
((uint64)(x₁ [[1]] *₆₄ x₂ [[5]] << 1) +₆₄
@@ -7671,8 +7697,7 @@ fun var : type -> Type =>
((uint64)(x₁ [[5]] *₆₄ x₂ [[1]] << 1) +₆₄
(x₁ [[6]] *₆₄ x₂ [[0]] +₆₄
((uint64)(x₁ [[7]] *₆₄ (19 * (uint32)(x₂[[9]])) << 1) +₆₄
- (x₁ [[8]] *₆₄ (19 * (uint32)(x₂[[8]])) +₆₄
- (uint64)(x₁ [[9]] *₆₄ (19 * (uint32)(x₂[[7]])) << 1)))))))))) in
+ (x₁ [[8]] *₆₄ (19 * (uint32)(x₂[[8]])) +₆₄ (uint64)(x₁ [[9]] *₆₄ (19 * (uint32)(x₂[[7]])) << 1)))))))))) in
expr_let x7 := (uint64)(x6 >> 26) +₆₄
(x₁ [[0]] *₆₄ x₂ [[7]] +₆₄
(x₁ [[1]] *₆₄ x₂ [[6]] +₆₄
@@ -7681,9 +7706,7 @@ fun var : type -> Type =>
(x₁ [[4]] *₆₄ x₂ [[3]] +₆₄
(x₁ [[5]] *₆₄ x₂ [[2]] +₆₄
(x₁ [[6]] *₆₄ x₂ [[1]] +₆₄
- (x₁ [[7]] *₆₄ x₂ [[0]] +₆₄
- (x₁ [[8]] *₆₄ (19 * (uint32)(x₂[[9]])) +₆₄
- x₁ [[9]] *₆₄ (19 * (uint32)(x₂[[8]]))))))))))) in
+ (x₁ [[7]] *₆₄ x₂ [[0]] +₆₄ (x₁ [[8]] *₆₄ (19 * (uint32)(x₂[[9]])) +₆₄ x₁ [[9]] *₆₄ (19 * (uint32)(x₂[[8]]))))))))))) in
expr_let x8 := (uint64)(x7 >> 25) +₆₄
(x₁ [[0]] *₆₄ x₂ [[8]] +₆₄
((uint64)(x₁ [[1]] *₆₄ x₂ [[7]] << 1) +₆₄
@@ -7693,8 +7716,7 @@ fun var : type -> Type =>
((uint64)(x₁ [[5]] *₆₄ x₂ [[3]] << 1) +₆₄
(x₁ [[6]] *₆₄ x₂ [[2]] +₆₄
((uint64)(x₁ [[7]] *₆₄ x₂ [[1]] << 1) +₆₄
- (x₁ [[8]] *₆₄ x₂ [[0]] +₆₄
- (uint64)(x₁ [[9]] *₆₄ (19 * (uint32)(x₂[[9]])) << 1)))))))))) in
+ (x₁ [[8]] *₆₄ x₂ [[0]] +₆₄ (uint64)(x₁ [[9]] *₆₄ (19 * (uint32)(x₂[[9]])) << 1)))))))))) in
expr_let x9 := (uint64)(x8 >> 26) +₆₄
(x₁ [[0]] *₆₄ x₂ [[9]] +₆₄
(x₁ [[1]] *₆₄ x₂ [[8]] +₆₄
@@ -7702,9 +7724,7 @@ fun var : type -> Type =>
(x₁ [[3]] *₆₄ x₂ [[6]] +₆₄
(x₁ [[4]] *₆₄ x₂ [[5]] +₆₄
(x₁ [[5]] *₆₄ x₂ [[4]] +₆₄
- (x₁ [[6]] *₆₄ x₂ [[3]] +₆₄
- (x₁ [[7]] *₆₄ x₂ [[2]] +₆₄
- (x₁ [[8]] *₆₄ x₂ [[1]] +₆₄ x₁ [[9]] *₆₄ x₂ [[0]]))))))))) in
+ (x₁ [[6]] *₆₄ x₂ [[3]] +₆₄ (x₁ [[7]] *₆₄ x₂ [[2]] +₆₄ (x₁ [[8]] *₆₄ x₂ [[1]] +₆₄ x₁ [[9]] *₆₄ x₂ [[0]]))))))))) in
expr_let x10 := ((uint32)(x0) & 67108863) +₆₄ 19 *₆₄ (uint64)(x9 >> 25) in
expr_let x11 := (uint32)(x10 >> 26) +₃₂ ((uint32)(x1) & 33554431) in
((uint32)(x10) & 67108863)
@@ -7714,9 +7734,7 @@ fun var : type -> Type =>
:: ((uint32)(x4) & 67108863)
:: ((uint32)(x5) & 33554431)
:: ((uint32)(x6) & 67108863)
- :: ((uint32)(x7) & 33554431)
- :: ((uint32)(x8) & 67108863)
- :: ((uint32)(x9) & 33554431) :: [])%expr
+ :: ((uint32)(x7) & 33554431) :: ((uint32)(x8) & 67108863) :: ((uint32)(x9) & 33554431) :: [])%expr
: Expr
(type.uncurry
(type.list (type.type_primitive type.Z) ->
@@ -8357,6 +8375,279 @@ barrett_red256 = fun var : type -> Type => λ x : var (type.type_primitive type.
End Barrett256.
+Module SaturatedSolinas.
+ Section MulMod.
+ Context (s : Z) (c : list (Z * Z))
+ (s_nz : s <> 0) (modulus_nz : s - Associational.eval c <> 0).
+ Context (log2base : Z) (log2base_pos : 0 < log2base)
+ (n m : nat) (n_nz : n <> 0%nat) (m_nz : m <> 0%nat).
+
+ Let weight := weight log2base 1.
+ Let props : @weight_properties weight := wprops log2base 1 ltac:(omega).
+ Local Lemma base_nz : 2 ^ log2base <> 0. Proof. auto with zarith. Qed.
+
+ Derive mulmod
+ SuchThat (forall (f g : list Z)
+ (Hf : length f = n)
+ (Hg : length g = n),
+ (eval weight m (fst (mulmod f g)) + weight m * (snd (mulmod f g))) mod (s - Associational.eval c)
+ = (eval weight n f * eval weight n g) mod (s - Associational.eval c))
+ As eval_mulmod.
+ Proof.
+ intros.
+ rewrite <-Rows.eval_mulmod with (base:=2^log2base) (s:=s) (c:=c) (m:=m) by auto using base_nz.
+ eapply f_equal2; [|trivial].
+ (* expand_lists (). *) (* uncommenting this line removes some unused multiplications but also inlines a bunch of carry stuff at the end *)
+ subst mulmod. reflexivity.
+ Qed.
+ End MulMod.
+
+ Derive mulmod_gen
+ SuchThat (forall (log2base s : Z) (c : list (Z * Z)) (n m : nat)
+ (f g : list Z),
+ Interp (t:=type.reify_type_of mulmod)
+ mulmod_gen s c log2base n m f g
+ = mulmod s c log2base n m f g)
+ As mulmod_gen_correct.
+ Proof. Time cache_reify (). exact admit. (* correctness of initial parts of the pipeline *) Time Qed.
+ Module Export ReifyHints.
+ Global Hint Extern 1 (_ = mulmod _ _ _ _ _ _ _) => simple apply mulmod_gen_correct : reify_gen_cache.
+ End ReifyHints.
+
+ Section rmulmod.
+ Context (n m : nat)
+ (s : Z)
+ (c : list (Z * Z))
+ (machine_wordsize : Z).
+
+ Definition relax_zrange_of_machine_wordsize
+ := relax_zrange_gen [1; machine_wordsize]%Z.
+
+ Let relax_zrange := relax_zrange_of_machine_wordsize.
+ Let bound := Some r[0 ~> (2^machine_wordsize - 1)]%zrange.
+ Let boundsn : list (ZRange.type.option.interp type.Z)
+ := List.repeat bound n.
+ Let boundsm : list (ZRange.type.option.interp type.Z)
+ := List.repeat bound m.
+
+ Definition check_args {T} (res : Pipeline.ErrorT T)
+ : Pipeline.ErrorT T
+ := if (negb (0 <? s - Associational.eval c))%Z
+ then Pipeline.Error (Pipeline.Value_not_lt "s - Associational.eval c ≤ 0" 0 (s - Associational.eval c))
+ else if (s =? 0)%Z
+ then Pipeline.Error (Pipeline.Values_not_provably_distinct "s ≠ 0" s 0)
+ else if (n =? 0)%nat
+ then Pipeline.Error (Pipeline.Values_not_provably_distinct "n ≠ 0" n 0%nat)
+ else if (negb (0 <? machine_wordsize))
+ then Pipeline.Error (Pipeline.Value_not_lt "0 < machine_wordsize" 0 machine_wordsize)
+ else res.
+
+ Notation BoundsPipeline rop in_bounds out_bounds
+ := (Pipeline.BoundsPipeline
+ (*false*) true
+ relax_zrange
+ rop%Expr in_bounds out_bounds).
+
+ Notation BoundsPipeline_correct in_bounds out_bounds op
+ := (fun rv (rop : Expr (type.reify_type_of op)) Hrop
+ => @Pipeline.BoundsPipeline_correct_trans
+ (*false*) true
+ relax_zrange
+ (relax_zrange_gen_good _)
+ _
+ rop
+ in_bounds
+ out_bounds
+ op
+ Hrop rv)
+ (only parsing).
+
+ Definition rmulmod_correct
+ := BoundsPipeline_correct
+ (Some boundsn, Some boundsn)
+ (Some boundsm, bound)
+ (mulmod s c machine_wordsize n m).
+
+ Notation type_of_strip_3arrow := ((fun (d : Prop) (_ : forall A B C, d) => d) _).
+ Definition rmulmod_correctT rv : Prop
+ := type_of_strip_3arrow (@rmulmod_correct rv).
+ End rmulmod.
+End SaturatedSolinas.
+
+Ltac solve_rmulmod := solve_rop SaturatedSolinas.rmulmod_correct.
+Ltac solve_rmulmod_nocache := solve_rop_nocache SaturatedSolinas.rmulmod_correct.
+
+Module P192_64.
+ Definition n := 3%nat.
+ Definition m := 6%nat.
+ Definition s := 2^192.
+ Definition c := [(2^64, 1); (1,1)].
+ Definition machine_wordsize := 64.
+
+ Derive mulmod
+ SuchThat (SaturatedSolinas.rmulmod_correctT n m s c machine_wordsize mulmod)
+ As mulmod_correct.
+ Proof. Time solve_rmulmod machine_wordsize. Time Qed.
+
+ Print mulmod.
+ Import PrintingNotations.
+ Open Scope expr_scope.
+ Set Printing Width 100000.
+ Set Printing Depth 100000.
+
+ Local Notation "'mul64' '(' x ',' y ')'" :=
+ (Z.cast2 (uint64, _)%core @@ (Z.mul_split_concrete 18446744073709551616 @@ (x , y)))%expr (at level 50) : expr_scope.
+ Local Notation "'add64' '(' x ',' y ')'" :=
+ (Z.cast2 (uint64, bool)%core @@ (Z.add_get_carry_concrete 18446744073709551616 @@ (x , y)))%expr (at level 50) : expr_scope.
+ Local Notation "'adc64' '(' c ',' x ',' y ')'" :=
+ (Z.cast2 (uint64, bool)%core @@ (Z.add_with_get_carry_concrete 18446744073709551616 @@ (c, x , y)))%expr (at level 50) : expr_scope.
+
+ Print mulmod.
+ (*
+ expr_let x0 := mul64 ((uint64)(x₁[[2]]), (uint64)(x₂[[2]])) in
+ expr_let x1 := mul64 ((uint64)(x₁[[2]]), (uint64)(x₂[[1]])) in
+ expr_let x2 := mul64 ((uint64)(x₁[[2]]), (uint64)(x₂[[0]])) in
+ expr_let x3 := mul64 ((uint64)(x₁[[1]]), (uint64)(x₂[[2]])) in
+ expr_let x4 := mul64 ((uint64)(x₁[[1]]), (uint64)(x₂[[1]])) in
+ expr_let x5 := mul64 ((uint64)(x₁[[1]]), (uint64)(x₂[[0]])) in
+ expr_let x6 := mul64 ((uint64)(x₁[[0]]), (uint64)(x₂[[2]])) in
+ expr_let x7 := mul64 ((uint64)(x₁[[0]]), (uint64)(x₂[[1]])) in
+ expr_let x8 := mul64 ((uint64)(x₁[[0]]), (uint64)(x₂[[0]])) in
+ expr_let _ := mul64 (1, x0₂) in
+ expr_let _ := mul64 (1, x0₁) in
+ expr_let _ := mul64 (1, x1₂) in
+ expr_let _ := mul64 (1, x1₁) in
+ expr_let _ := mul64 (1, x2₂) in
+ expr_let _ := mul64 (1, x3₂) in
+ expr_let _ := mul64 (1, x3₁) in
+ expr_let _ := mul64 (1, x4₂) in
+ expr_let _ := mul64 (1, x6₂) in
+ expr_let _ := mul64 (1, x0₂) in
+ expr_let _ := mul64 (1, x0₁) in
+ expr_let _ := mul64 (1, x1₂) in
+ expr_let _ := mul64 (1, x1₁) in
+ expr_let _ := mul64 (1, x2₂) in
+ expr_let _ := mul64 (1, x3₂) in
+ expr_let _ := mul64 (1, x3₁) in
+ expr_let _ := mul64 (1, x4₂) in
+ expr_let _ := mul64 (1, x6₂) in
+ expr_let x27 := mul64 (1, x0₂) in
+ expr_let x28 := mul64 (1, x0₁) in
+ expr_let x29 := mul64 (1, x1₂) in
+ expr_let x30 := mul64 (1, x1₁) in
+ expr_let x31 := mul64 (1, x2₂) in
+ expr_let x32 := mul64 (1, x3₂) in
+ expr_let x33 := mul64 (1, x3₁) in
+ expr_let x34 := mul64 (1, x4₂) in
+ expr_let x35 := mul64 (1, x6₂) in
+ expr_let x36 := mul64 (1, x0₂) in
+ expr_let x37 := mul64 (1, x0₁) in
+ expr_let x38 := mul64 (1, x1₂) in
+ expr_let x39 := mul64 (1, x1₁) in
+ expr_let x40 := mul64 (1, x2₂) in
+ expr_let x41 := mul64 (1, x3₂) in
+ expr_let x42 := mul64 (1, x3₁) in
+ expr_let x43 := mul64 (1, x4₂) in
+ expr_let x44 := mul64 (1, x6₂) in
+ expr_let x45 := add64 (x35₁, x8₁) in
+ expr_let x46 := adc64 (x45₂, x7₁, x8₂) in
+ expr_let x47 := adc64 (x46₂, x6₁, x7₂) in
+ expr_let x48 := Z.cast2 (bool, bool)%core @@ (Z.add_with_get_carry_concrete 18446744073709551616 @@ (x47₂, x38₂, x41₂)) in
+ expr_let x49 := add64 (x34₁, x45₁) in
+ expr_let x50 := adc64 (x49₂, x5₁, x46₁) in
+ expr_let x51 := adc64 (x50₂, x5₂, x47₁) in
+ expr_let x52 := adc64 (x51₂, x37₂, x48₁) in
+ expr_let x53 := add64 (x33₁, x49₁) in
+ expr_let x54 := adc64 (x53₂, x44₁, x50₁) in
+ expr_let x55 := adc64 (x54₂, x4₁, x51₁) in
+ expr_let x56 := adc64 (x55₂, x36₁, x52₁) in
+ expr_let x57 := add64 (x31₁, x53₁) in
+ expr_let x58 := adc64 (x57₂, x43₁, x54₁) in
+ expr_let x59 := adc64 (x58₂, x2₁, x55₁) in
+ expr_let x60 := adc64 (x59₂, x27₂, x56₁) in
+ expr_let x61 := add64 (x30₁, x57₁) in
+ expr_let x62 := adc64 (x61₂, x42₁, x58₁) in
+ expr_let x63 := adc64 (x62₂, x44₂, x59₁) in
+ expr_let x64 := add64 (x40₁, x62₁) in
+ expr_let x65 := adc64 (x64₂, x43₂, x63₁) in
+ expr_let x66 := add64 (x39₁, x64₁) in
+ expr_let x67 := adc64 (x66₂, x42₂, x65₁) in
+ expr_let x68 := add64 (x35₂, x66₁) in
+ expr_let x69 := adc64 (x68₂, x41₁, x67₁) in
+ expr_let x70 := add64 (x34₂, x68₁) in
+ expr_let x71 := adc64 (x70₂, x40₂, x69₁) in
+ expr_let x72 := add64 (x33₂, x70₁) in
+ expr_let x73 := adc64 (x72₂, x39₂, x71₁) in
+ expr_let x74 := add64 (x32₁, x72₁) in
+ expr_let x75 := adc64 (x74₂, x38₁, x73₁) in
+ expr_let x76 := add64 (x31₂, x74₁) in
+ expr_let x77 := adc64 (x76₂, x37₁, x75₁) in
+ expr_let x78 := add64 (x30₂, x76₁) in
+ expr_let x79 := adc64 (x78₂, x32₂, x77₁) in
+ expr_let x80 := add64 (x29₁, x78₁) in
+ expr_let x81 := adc64 (x80₂, x29₂, x79₁) in
+ expr_let x82 := add64 (x28₁, x80₁) in
+ expr_let x83 := adc64 (x82₂, x28₂, x81₁) in
+ expr_let x84 := add64 (x27₁, x83₁) in
+ (x61₁ :: x82₁ :: x84₁ :: x60₁ :: x36₂ :: 0 :: [], 0)
+ : Expr (type.uncurry (type.list (type.type_primitive type.Z) -> type.list (type.type_primitive type.Z) -> type.list (type.type_primitive type.Z) * type.type_primitive type.Z))
+*)
+
+End P192_64.
+
+(* TODO: suuuuper slow--need to replace (Z.mul_split _ 1 x) with (x, 0) to reduce term size
+Module P192_32.
+ Definition n := 6%nat.
+ Definition m := 12%nat.
+ Definition s := 2^192.
+ Definition c := [(2^64, 1); (1,1)].
+ Definition machine_wordsize := 32.
+
+ Derive mulmod
+ SuchThat (SaturatedSolinas.rmulmod_correctT n m s c machine_wordsize mulmod)
+ As mulmod_correct.
+ Proof. Time solve_rmulmod machine_wordsize. Time Qed.
+ (* 165.87s for solve, 163.404 for Qed *)
+
+ Import PrintingNotations.
+ Open Scope expr_scope.
+ Set Printing Width 100000.
+ Set Printing Depth 100000.
+
+ Eval compute in (2^32).
+ Local Notation "'mul32' '(' x ',' y ')'" :=
+ (Z.cast2 (uint32, _)%core @@ (Z.mul_split_concrete 4294967296 @@ (x , y)))%expr (at level 50) : expr_scope.
+ Local Notation "'add32' '(' x ',' y ')'" :=
+ (Z.cast2 (uint32, bool)%core @@ (Z.add_get_carry_concrete 4294967296 @@ (x , y)))%expr (at level 50) : expr_scope.
+ Local Notation "'adc32' '(' c ',' x ',' y ')'" :=
+ (Z.cast2 (uint32, bool)%core @@ (Z.add_with_get_carry_concrete 4294967296 @@ (c, x , y)))%expr (at level 50) : expr_scope.
+
+ Print mulmod.
+
+End P192_32.
+
+Module P256_32.
+ Definition n := 8%nat.
+ Definition m := 16%nat.
+ Definition s := 2^256.
+ Definition c := [(2^224, 1); (2^192, -1); (2^96, -1); (1,1)].
+ Definition machine_wordsize := 32.
+
+ Derive mulmod
+ SuchThat (SaturatedSolinas.rmulmod_correctT n m s c machine_wordsize mulmod)
+ As mulmod_correct.
+ Proof. Time solve_rmulmod machine_wordsize. Time Qed.
+
+ Import PrintingNotations.
+ Open Scope expr_scope.
+ Set Printing Width 100000.
+
+ Print mulmod.
+
+
+End P256_32.
+*)
+
Module MontgomeryReduction.
Section MontRed'.
Context (N R N' R' : Z).