aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
m---------bbv0
m---------etc/coq-scripts0
-rw-r--r--src/Experiments/SimplyTypedArithmetic.v1372
3 files changed, 1234 insertions, 138 deletions
diff --git a/bbv b/bbv
-Subproject 4dcd180f0605c6aa401097685593433d6806201
+Subproject 38fe6a40ea26ce738637e340d7f8e9f0eb85fbc
diff --git a/etc/coq-scripts b/etc/coq-scripts
-Subproject 7bd683da1fac8b5eb42de1e44a3274db4fd0ce4
+Subproject ef2d7f9e7e9530f05fb3b2362db787a2885c59b
diff --git a/src/Experiments/SimplyTypedArithmetic.v b/src/Experiments/SimplyTypedArithmetic.v
index 814161335..76a885d70 100644
--- a/src/Experiments/SimplyTypedArithmetic.v
+++ b/src/Experiments/SimplyTypedArithmetic.v
@@ -126,6 +126,9 @@ Module Associational.
push; [|rewrite IHp]; reflexivity.
Qed.
+ Lemma eval_rev p : eval (rev p) = eval p.
+ Proof. induction p; cbn [rev]; push; lia. Qed.
+
Section Carries.
Definition carryterm (w fw:Z) (t:Z * Z) :=
if (Z.eqb (fst t) w)
@@ -1758,6 +1761,7 @@ Module BaseConversion.
@Associational.eval_carry
@Associational.eval_mul
@Positional.eval_to_associational
+ Associational.eval_carryterm
@eval_convert_bases using solve [auto using Z.positive_is_nonzero] : push_eval.
Ltac push_eval := intros; autorewrite with push_eval; auto with zarith.
@@ -1767,14 +1771,30 @@ Module BaseConversion.
let p' := convert_bases n m p in
Positional.to_associational dw m p'.
+ (* TODO : move to Associational? *)
+ Section reorder.
+ Definition reordering_carry (w fw : Z) (p : list (Z * Z)) :=
+ fold_right (fun t acc =>
+ let r := Associational.carryterm w fw t in
+ if fst t =? w then acc ++ r else r ++ acc) nil p.
+
+ Lemma eval_reordering_carry w fw p (_:fw<>0):
+ Associational.eval (reordering_carry w fw p) = Associational.eval p.
+ Proof.
+ cbv [reordering_carry]. induction p; [reflexivity |].
+ autorewrite with push_fold_right. break_match; push_eval.
+ Qed.
+ End reorder.
+ Hint Rewrite eval_reordering_carry using solve [auto using Z.positive_is_nonzero] : push_eval.
+
(* carry at specified indices in dw, then use Rows.flatten to convert to Positional with sw *)
Definition from_associational idxs n (p : list (Z * Z)) : list Z :=
(* important not to use Positional.carry here; we don't want to accumulate yet *)
- let p' := fold_right (fun i acc => Associational.carry (dw i) (dw (S i) / dw i) acc) (Associational.bind_snd p) (rev idxs) in
+ let p' := fold_right (fun i acc => reordering_carry (dw i) (dw (S i) / dw i) acc) (Associational.bind_snd p) (rev idxs) in
fst (Rows.flatten sw n (Rows.from_associational sw n p')).
Lemma eval_carries p idxs :
- Associational.eval (fold_right (fun i acc => Associational.carry (dw i) (dw (S i) / dw i) acc) p idxs) =
+ Associational.eval (fold_right (fun i acc => reordering_carry (dw i) (dw (S i) / dw i) acc) p idxs) =
Associational.eval p.
Proof. apply fold_right_invariant; push_eval. Qed.
Hint Rewrite eval_carries: push_eval.
@@ -1821,7 +1841,7 @@ Module BaseConversion.
As from_associational_inlined_correct.
Proof.
intros.
- cbv beta iota delta [from_associational Associational.carry Associational.carryterm].
+ cbv beta iota delta [from_associational reordering_carry Associational.carryterm].
cbv beta iota delta [Let_In]. (* inlines all shifts/lands from carryterm *)
cbv beta iota delta [from_associational Rows.from_associational Columns.from_associational].
cbv beta iota delta [Let_In]. (* inlines the shifts from place *)
@@ -1908,9 +1928,6 @@ Module BaseConversion.
Z.rewrite_mod_small. reflexivity.
Qed.
- (* For some reason, this is a universe inconsistency if not factored out *)
- Lemma nout_nonzero : nout <> 0%nat. Proof. omega. Qed.
-
Derive widemul_inlined
SuchThat (forall a b,
0 <= a * b < 2^log2base * 2^log2base ->
@@ -1922,9 +1939,33 @@ Module BaseConversion.
cbv beta iota delta [widemul mul_converted].
rewrite <-to_associational_inlined_correct with (p:=[a]).
rewrite <-to_associational_inlined_correct with (p:=[b]).
- rewrite <-from_associational_inlined_correct by (apply nout_nonzero || assumption).
+ rewrite <-from_associational_inlined_correct.
subst widemul_inlined; reflexivity.
Qed.
+
+ Derive widemul_inlined_reverse
+ SuchThat (forall a b,
+ 0 <= a * b < 2^log2base * 2^log2base ->
+ widemul_inlined_reverse a b = [(a * b) mod 2^log2base; (a * b) / 2^log2base])
+ As widemul_inlined_reverse_correct.
+ Proof.
+ intros.
+ rewrite <-widemul_inlined_correct by assumption.
+ cbv [widemul_inlined].
+ match goal with |- _ = from_associational_inlined sw dw ?idxs ?n ?p =>
+ transitivity (from_associational_inlined sw dw idxs n (rev p));
+ [ | transitivity (from_associational sw dw idxs n p); [ | reflexivity ] ](* reverse to make addc chains line up *)
+ end.
+ Focus 2. {
+ rewrite from_associational_inlined_correct by (subst nout; auto).
+ cbv [from_associational].
+ rewrite !Rows.flatten_partitions' by eauto using Rows.length_from_associational.
+ rewrite !Rows.eval_from_associational by (subst nout; auto).
+ f_equal.
+ rewrite !eval_carries, !Associational.bind_snd_correct, !Associational.eval_rev by auto.
+ reflexivity. } Unfocus.
+ subst widemul_inlined_reverse; reflexivity.
+ Qed.
End widemul.
End BaseConversion.
@@ -8187,8 +8228,8 @@ Module Straightline.
Section interp.
Context {ident : type -> type -> Type} {interp_ident : forall s d, ident s d -> type.interp s -> type.interp d}.
+ Context {interp_cast : zrange -> Z -> Z}.
- Definition interp_cast (r : zrange) (x : Z) : Z := ident.cast ident.cast_outside_of_range r x.
Definition interp_cast2 (r : zrange * zrange) (x : Z * Z) : Z * Z :=
(interp_cast (fst r) (fst x), interp_cast (snd r) (snd x)).
@@ -8220,9 +8261,10 @@ Module Straightline.
End interp.
Section proofs.
- Local Notation straightline_interp := (expr.interp (ident:=default.ident) (interp_ident:=@ident.interp)).
+ Local Notation straightline_interp := (expr.interp (ident:=default.ident) (interp_ident:=@ident.interp) (interp_cast:=ident.cast (@ident.cast_outside_of_range))).
Local Notation uinterp := (Uncurried.expr.interp (@ident.interp)).
Local Notation uexpr := (@Uncurried.expr.expr ident type.interp).
+ Local Notation interp_scalar := (interp_scalar (interp_cast:=ident.cast (@ident.cast_outside_of_range))).
Inductive ok_scalar_ident : forall {s d}, ident.ident s d -> Prop :=
| ok_si_cast : forall r, ok_scalar_ident (ident.Z.cast r)
@@ -8273,11 +8315,11 @@ Module Straightline.
.
Lemma interp_cast_correct r (x : uexpr type.Z) :
- interp_cast r (uinterp x) = uinterp (AppIdent (ident.Z.cast r) x).
+ ident.cast ident.cast_outside_of_range r (uinterp x) = uinterp (AppIdent (ident.Z.cast r) x).
Proof. reflexivity. Qed.
Lemma interp_cast2_correct r (x : uexpr (type.prod type.Z type.Z)) :
- interp_cast2 r (uinterp x) = uinterp (AppIdent (ident.Z.cast2 r) x).
+ @interp_cast2 (ident.cast ident.cast_outside_of_range) r (uinterp x) = uinterp (AppIdent (ident.Z.cast2 r) x).
Proof. cbn; break_match; reflexivity. Qed.
Ltac invert H :=
@@ -8542,6 +8584,28 @@ Module PreFancy.
apply BinInt.Z.mul_div_le; omega. }
Qed.
+ Lemma wordmax_half_bits_le_wordmax : wordmax_half_bits <= wordmax.
+ Proof.
+ subst wordmax half_bits wordmax_half_bits.
+ apply Z.pow_le_mono_r; [lia|].
+ apply Z.div_le_upper_bound; lia.
+ Qed.
+
+ Lemma ones_half_bits : wordmax_half_bits - 1 = Z.ones half_bits.
+ Proof.
+ subst wordmax_half_bits. cbv [Z.ones].
+ rewrite Z.shiftl_mul_pow2, <-Z.sub_1_r by auto using half_bits_nonneg.
+ lia.
+ Qed.
+
+ Lemma wordmax_half_bits_squared : wordmax_half_bits * wordmax_half_bits = wordmax.
+ Proof.
+ subst wordmax half_bits wordmax_half_bits.
+ rewrite <-Z.pow_add_r by Z.zero_bounds.
+ rewrite Z.add_diag, Z.mul_div_eq by omega.
+ f_equal; lia.
+ Qed.
+
Section with_var.
Context {var : type -> Type} (dummy_arrow : forall s d, var (type.arrow s d)) (consts : list Z).
Local Notation Z := (type.type_primitive type.Z).
@@ -8580,7 +8644,7 @@ Module PreFancy.
option (@scalar var Z) :=
match e in scalar t return option (@scalar var Z) with
| Cast r (Land n x) =>
- if (lower r =? 0) && (upper r =? (wordmax_half_bits - 1)) && (n =? 2^half_bits-1)
+ if (lower r =? 0) && (upper r =? (wordmax_half_bits - 1)) && (n =? wordmax_half_bits-1)
then Some x
else None
| _ => None
@@ -8764,6 +8828,9 @@ Module PreFancy.
End with_var.
Section interp.
+ Context {interp_cast : zrange -> Z -> Z}.
+ Local Notation interp_scalar := (interp_scalar (interp_cast:=interp_cast)).
+ Local Notation interp_cast2 := (interp_cast2 (interp_cast:=interp_cast)).
Local Notation low x := (Z.land x (wordmax_half_bits - 1)).
Local Notation high x := (x >> half_bits).
Local Notation shift x imm := ((x << imm) mod wordmax).
@@ -8797,6 +8864,9 @@ Module PreFancy.
Section proofs.
Context (dummy_arrow : forall s d, type.interp (s -> d)%ctype) (consts : list Z)
(consts_ok : forall x, In x consts -> 0 <= x <= wordmax - 1).
+ Context {interp_cast : zrange -> Z -> Z} {interp_cast_correct : forall r x, lower r <= x <= upper r -> interp_cast r x = x}.
+ Local Notation interp_scalar := (interp_scalar (interp_cast:=interp_cast)).
+ Local Notation interp_cast2 := (interp_cast2 (interp_cast:=interp_cast)).
Local Notation word_range := (r[0~>wordmax-1])%zrange.
Local Notation half_word_range := (r[0~>wordmax_half_bits-1])%zrange.
@@ -8979,7 +9049,7 @@ Module PreFancy.
.
Inductive ok_expr : forall {t}, @expr type.interp ident.ident t -> Prop :=
- | ok_of_scalar : forall t s, @ok_expr t (Scalar s)
+ | ok_of_scalar : forall t s, ok_scalar s -> @ok_expr t (Scalar s)
| ok_letin_z : forall s d r idc x f,
ok_ident _ type.Z x r idc ->
ok_scalar x ->
@@ -9018,23 +9088,15 @@ Module PreFancy.
Lemma interp_cast_noop x r :
@has_range type.Z r x ->
interp_cast r x = x.
- Proof.
- cbv [has_range interp_cast ident.cast]; intros.
- break_match;
- try match goal with H : _ && _ = true |- _ => rewrite andb_true_iff in H; destruct H end;
- try match goal with H : _ && _ = false |- _ => rewrite andb_false_iff in H; destruct H end;
- Z.ltb_to_lt; omega.
- Qed.
+ Proof. cbv [has_range]; intros; auto. Qed.
Lemma interp_cast2_noop x r :
@has_range (type.prod type.Z type.Z) r x ->
interp_cast2 r x = x.
Proof.
- cbv [has_range interp_cast2 interp_cast ident.cast]; intros.
- break_match; destruct x;
- repeat match goal with H : _ && _ = true |- _ => rewrite andb_true_iff in H; destruct H end;
- repeat match goal with H : _ && _ = false |- _ => rewrite andb_false_iff in H; destruct H end;
- Z.ltb_to_lt; try omega; reflexivity.
+ cbv [has_range interp_cast2]; intros.
+ rewrite !interp_cast_correct by tauto.
+ destruct x; reflexivity.
Qed.
Lemma has_range_shiftr n (x : scalar type.Z) :
@@ -9440,8 +9502,8 @@ Module PreFancy.
is_halved y ->
ok_scalar (Pair x y) ->
@has_range type.Z word_range (ident.interp ident.Z.mul (interp_scalar (Pair x y))) ->
- interp (of_straightline_ident dummy_arrow consts ident.Z.mul t word_range (Pair x y) g) =
- interp (g (ident.interp ident.Z.mul (interp_scalar (Pair x y)))).
+ @interp interp_cast _ (of_straightline_ident dummy_arrow consts ident.Z.mul t word_range (Pair x y) g) =
+ @interp interp_cast _ (g (ident.interp ident.Z.mul (interp_scalar (Pair x y)))).
Proof.
intros Hx Hy Hok ?; invert Hok; cbn [interp_scalar of_straightline_ident].
destruct (is_halved_cases x Hx ltac:(assumption)) as [ [? [Pxlow [Pxhigh Pxi] ] ] | [? [Pxlow [Pxhigh Pxi] ] ] ];
@@ -9459,12 +9521,39 @@ Module PreFancy.
intros. apply Z.mod_small; omega.
Qed.
+ Lemma half_word_range_le_word_range r :
+ upper r = wordmax_half_bits - 1 ->
+ lower r = 0 ->
+ (r <=? word_range)%zrange = true.
+ Proof.
+ pose proof wordmax_half_bits_le_wordmax.
+ destruct r; cbv [is_tighter_than_bool ZRange.lower ZRange.upper].
+ intros; subst.
+ apply andb_true_iff; split; Z.ltb_to_lt; lia.
+ Qed.
+
+ Lemma and_shiftl_half_bits_eq x :
+ (x &' (wordmax_half_bits - 1)) << half_bits = x << half_bits mod wordmax.
+ Proof.
+ rewrite ones_half_bits.
+ rewrite Z.land_ones, !Z.shiftl_mul_pow2 by auto using half_bits_nonneg.
+ rewrite <-wordmax_half_bits_squared.
+ subst wordmax_half_bits.
+ rewrite Z.mul_mod_distr_r_full.
+ reflexivity.
+ Qed.
+
+ Lemma in_word_range_word_range : in_word_range word_range.
+ Proof.
+ cbv [in_word_range is_tighter_than_bool].
+ rewrite !Z.leb_refl; reflexivity.
+ Qed.
+
Lemma invert_shift_correct (s : scalar type.Z) x imm :
ok_scalar s ->
invert_shift consts s = Some (x, imm) ->
interp_scalar s = (interp_scalar x << imm) mod wordmax.
Proof.
- (*
intros Hok ?; invert Hok;
try match goal with H : ok_scalar ?x, H' : context[Cast _ ?x] |- _ =>
invert H end;
@@ -9472,44 +9561,34 @@ Module PreFancy.
invert H end;
try match goal with H : ok_scalar ?x, H' : context[Shiftl _ (Cast _ ?x)] |- _ =>
invert H end;
- try (cbn [invert_shift invert_upper invert_upper'] in *; discriminate).
- {
+ try (cbn [invert_shift invert_upper invert_upper'] in *; discriminate);
repeat match goal with
- | _ => progress (cbn [invert_shift invert_upper invert_upper'
- interp_scalar fst snd] in * )
- | _ => rewrite interp_cast_noop by eauto using has_range_loosen
- | H : ok_scalar (Shiftr _ _) |- _ => apply has_range_interp_scalar in H
- | H : context [if ?x then _ else _] |- _ =>
- let Heq := fresh in case_eq x; intro Heq; rewrite Heq in H
- | H : _ |- _ => rewrite andb_true_iff in H; destruct H; Z.ltb_to_lt
- | H : Some _ = Some _ |- _ => progress (invert H)
- | _ => progress subst
- | _ => reflexivity
- | _ => discriminate
- end.
- rewrite has_word_range_mod_small.
- 2:eauto using has_range_loosen.
- }
- {
- repeat match goal with
- | _ => progress (cbn [invert_shift invert_upper invert_upper'
- invert_lower invert_lower'
- interp_scalar fst snd] in * )
- | _ => rewrite interp_cast_noop by eauto using has_range_loosen
+ | _ => progress (cbn [invert_shift invert_lower invert_lower' invert_upper invert_upper' interp_scalar fst snd] in * )
+ | _ => rewrite interp_cast_noop by eauto using has_half_word_range_land, has_half_word_range_shiftr, in_word_range_word_range, has_range_loosen
| H : ok_scalar (Shiftr _ _) |- _ => apply has_range_interp_scalar in H
| H : ok_scalar (Shiftl _ _) |- _ => apply has_range_interp_scalar in H
| H : ok_scalar (Land _ _) |- _ => apply has_range_interp_scalar in H
| H : context [if ?x then _ else _] |- _ =>
let Heq := fresh in case_eq x; intro Heq; rewrite Heq in H
+ | H : context [match @constant_to_scalar ?v ?consts ?x with _ => _ end] |- _ =>
+ let Heq := fresh in
+ case_eq (@constant_to_scalar v consts x); intros until 0; intro Heq; rewrite Heq in *; [|discriminate];
+ destruct (constant_to_scalar_cases _ _ Heq) as [ [? [? ?] ] | [? [? ?] ] ]; subst;
+ pose proof (ok_scalar_constant_to_scalar _ _ Heq)
+ | H : constant_to_scalar _ _ = Some _ |- _ => erewrite <-(constant_to_scalar_correct _ _ H)
| H : _ |- _ => rewrite andb_true_iff in H; destruct H; Z.ltb_to_lt
| H : Some _ = Some _ |- _ => progress (invert H)
+ | _ => rewrite has_word_range_mod_small by eauto using has_range_loosen, half_word_range_le_word_range
+ | _ => rewrite has_word_range_mod_small by
+ (eapply has_range_loosen with (r1:=half_word_range);
+ [ eapply has_half_word_range_shiftr with (r:=word_range) | ];
+ eauto using in_word_range_word_range, half_word_range_le_word_range)
+ | _ => rewrite and_shiftl_half_bits_eq
| _ => progress subst
| _ => reflexivity
| _ => discriminate
end.
Qed.
- *)
- Admitted.
Local Ltac solve_commutative_replace :=
match goal with
@@ -9521,8 +9600,8 @@ Module PreFancy.
Lemma of_straightline_ident_correct s d t x r (idc : ident.ident s d) g :
ok_ident s d x r idc ->
ok_scalar x ->
- interp (of_straightline_ident dummy_arrow consts idc t r x g) =
- interp (g (ident.interp idc (interp_scalar x))).
+ @interp interp_cast _ (of_straightline_ident dummy_arrow consts idc t r x g) =
+ @interp interp_cast _ (g (ident.interp idc (interp_scalar x))).
Proof.
intros.
pose proof wordmax_half_bits_pos.
@@ -9559,7 +9638,8 @@ Module PreFancy.
Lemma of_straightline_correct {t} (e : expr t) :
ok_expr e ->
- interp (of_straightline dummy_arrow consts e) = Straightline.expr.interp (interp_ident:=@ident.interp) e.
+ @interp interp_cast _ (of_straightline dummy_arrow consts e)
+ = Straightline.expr.interp (interp_ident:=@ident.interp) (interp_cast:=interp_cast) e.
Proof.
induction 1; cbn [of_straightline]; intros;
repeat match goal with
@@ -9572,12 +9652,79 @@ Module PreFancy.
end.
Qed.
End proofs.
+
+ Section no_interp_cast.
+ Context (dummy_arrow : forall s d, type.interp (s -> d)%ctype) (consts : list Z)
+ (consts_ok : forall x, In x consts -> 0 <= x <= wordmax - 1).
+
+ Local Arguments interp _ {_} _.
+ Local Arguments interp_scalar _ {_} _.
+
+ Local Ltac tighter_than_to_le :=
+ repeat match goal with
+ | _ => progress (cbv [is_tighter_than_bool] in * )
+ | _ => rewrite andb_true_iff in *
+ | H : _ /\ _ |- _ => destruct H
+ end; Z.ltb_to_lt.
+
+ Lemma replace_interp_cast_scalar {t} (x : scalar t) interp_cast interp_cast'
+ (interp_cast_correct : forall r x, lower r <= x <= upper r -> interp_cast r x = x)
+ (interp_cast'_correct : forall r x, lower r <= x <= upper r -> interp_cast' r x = x) :
+ ok_scalar x ->
+ interp_scalar interp_cast x = interp_scalar interp_cast' x.
+ Proof.
+ induction 1; cbn [interp_scalar Straightline.expr.interp_scalar];
+ repeat match goal with
+ | _ => progress (cbv [has_range interp_cast2] in * )
+ | _ => progress tighter_than_to_le
+ | H : ok_scalar _ |- _ => apply (has_range_interp_scalar (interp_cast_correct:=interp_cast_correct)) in H
+ | _ => rewrite <-IHok_scalar
+ | _ => rewrite interp_cast_correct by omega
+ | _ => rewrite interp_cast'_correct by omega
+ | _ => congruence
+ end.
+ Qed.
+
+ Lemma replace_interp_cast {t} (e : expr t) interp_cast interp_cast'
+ (interp_cast_correct : forall r x, lower r <= x <= upper r -> interp_cast r x = x)
+ (interp_cast'_correct : forall r x, lower r <= x <= upper r -> interp_cast' r x = x) :
+ ok_expr consts e ->
+ interp interp_cast (of_straightline dummy_arrow consts e) =
+ interp interp_cast' (of_straightline dummy_arrow consts e).
+ Proof.
+ induction 1; intros; cbn [of_straightline interp].
+ { apply replace_interp_cast_scalar; auto. }
+ { rewrite !of_straightline_ident_correct by auto.
+ rewrite replace_interp_cast_scalar with (interp_cast'0:=interp_cast') by auto.
+ eauto using ident_interp_has_range. }
+ { rewrite !of_straightline_ident_correct by auto.
+ rewrite replace_interp_cast_scalar with (interp_cast'0:=interp_cast') by auto.
+ eauto using ident_interp_has_range. }
+ Qed.
+ End no_interp_cast.
End with_wordmax.
Definition of_Expr {s d} (log2wordmax : Z) (consts : list Z) (e : Expr (s -> d))
(var : type -> Type) (x : var s) dummy_arrow : @Straightline.expr.expr var ident d :=
@of_straightline log2wordmax var dummy_arrow consts _ (Straightline.of_Expr e var x dummy_arrow).
+ Definition interp_cast_mod w r x := if (lower r =? 0)
+ then if (upper r =? 2^w - 1)
+ then x mod (2^w)
+ else if (upper r =? 1)
+ then x mod 2
+ else x
+ else x.
+
+ Lemma interp_cast_mod_correct w r x :
+ lower r <= x <= upper r ->
+ interp_cast_mod w r x = x.
+ Proof.
+ cbv [interp_cast_mod].
+ intros; break_match; rewrite ?andb_true_iff in *; intuition; Z.ltb_to_lt;
+ apply Z.mod_small; omega.
+ Qed.
+
Lemma of_Expr_correct {s d} (log2wordmax : Z) (consts : list Z) (e : Expr (s -> d))
(e' : (type.interp s -> Uncurried.expr.expr d))
(x : type.interp s) dummy_arrow :
@@ -9589,11 +9736,14 @@ Module PreFancy.
ok_expr log2wordmax consts
(of_uncurried (dummy_arrow:=dummy_arrow) (depth (fun _ : type => unit) (fun _ : type => tt) (e _)) (e' x)) ->
(depth type.interp (@DefaultValue.type.default) (e' x) <= depth (fun _ : type => unit) (fun _ : type => tt) (e _))%nat ->
- interp log2wordmax (of_Expr log2wordmax consts e type.interp x dummy_arrow) = @Uncurried.expr.interp _ (@ident.interp) _ (e type.interp) x.
+ @interp log2wordmax (interp_cast_mod log2wordmax) _ (of_Expr log2wordmax consts e type.interp x dummy_arrow) = @Uncurried.expr.interp _ (@ident.interp) _ (e type.interp) x.
Proof.
intro He'; intros; cbv [of_Expr Straightline.of_Expr].
rewrite He'; cbn [invert_Abs expr.interp].
- rewrite of_straightline_correct by assumption.
+ assert (forall r z, lower r <= z <= upper r -> ident.cast ident.cast_outside_of_range r z = z) as interp_cast_correct.
+ { cbv [ident.cast]; intros; break_match; rewrite ?andb_true_iff, ?andb_false_iff in *; intuition; Z.ltb_to_lt; omega. }
+ erewrite replace_interp_cast with (interp_cast':=ident.cast ident.cast_outside_of_range) by auto using interp_cast_mod_correct.
+ rewrite of_straightline_correct by auto.
erewrite Straightline.expr.of_uncurried_correct by eassumption.
reflexivity.
Qed.
@@ -9651,6 +9801,661 @@ Module PreFancy.
End Notations.
End PreFancy.
+Module Fancy.
+ Import Straightline.expr.
+
+ Module CC.
+ Inductive code : Type :=
+ | C : code
+ | M : code
+ | L : code
+ | Z : code
+ .
+
+ Record state :=
+ { cc_c : bool; cc_m : bool; cc_l : bool; cc_z : bool }.
+
+ Definition code_dec (x y : code) : {x = y} + {x <> y}.
+ Proof. destruct x, y; try apply (left eq_refl); right; congruence. Defined.
+
+ Definition update (to_write : list code) (result : BinInt.Z) (cc_spec : code -> BinInt.Z -> bool) (old_state : state)
+ : state :=
+ {|
+ cc_c := if (In_dec code_dec C to_write)
+ then cc_spec C result
+ else old_state.(cc_c);
+ cc_m := if (In_dec code_dec M to_write)
+ then cc_spec M result
+ else old_state.(cc_m);
+ cc_l := if (In_dec code_dec L to_write)
+ then cc_spec L result
+ else old_state.(cc_l);
+ cc_z := if (In_dec code_dec Z to_write)
+ then cc_spec Z result
+ else old_state.(cc_z)
+ |}.
+
+ End CC.
+
+ Record instruction :=
+ {
+ num_source_regs : nat;
+ writes_conditions : list CC.code;
+ spec : tuple Z num_source_regs -> CC.state -> Z
+ }.
+
+ Section expr.
+ Context {name : Type} (name_eqb : name -> name -> bool) (wordmax : Z) (cc_spec : CC.code -> Z -> bool).
+
+ Inductive expr :=
+ | Ret : name -> expr
+ | Instr (i : instruction)
+ (rd : name) (* destination register *)
+ (args : tuple name i.(num_source_regs)) (* source registers *)
+ (cont : expr) (* next line *)
+ : expr
+ .
+
+ Fixpoint interp (e : expr) (cc : CC.state) (ctx : name -> Z) : Z :=
+ match e with
+ | Ret n => ctx n
+ | Instr i rd args cont =>
+ let result := i.(spec) (Tuple.map ctx args) cc in
+ let new_cc := CC.update i.(writes_conditions) result cc_spec cc in
+ let new_ctx := (fun n : name => if name_eqb n rd then result mod wordmax else ctx n) in
+ interp cont new_cc new_ctx
+ end.
+ End expr.
+
+ Section ISA.
+ Import CC.
+
+ (* For the C flag, we have to consider cases with a negative result (like the one returned by an underflowing borrow).
+ In these cases, we want to set the C flag to true. *)
+ Definition cc_spec (x : CC.code) (result : BinInt.Z) : bool :=
+ match x with
+ | CC.C => if result <? 0 then true else Z.testbit result 256
+ | CC.M => Z.testbit result 255
+ | CC.L => Z.testbit result 0
+ | CC.Z => result =? 0
+ end.
+
+ Local Definition lower128 x := (Z.land x (Z.ones 128)).
+ Local Definition upper128 x := (Z.shiftr x 128).
+ Local Notation "x '[C]'" := (if x.(cc_c) then 1 else 0) (at level 20).
+ Local Notation "x '[M]'" := (if x.(cc_m) then 1 else 0) (at level 20).
+ Local Notation "x '[L]'" := (if x.(cc_l) then 1 else 0) (at level 20).
+ Local Notation "x '[Z]'" := (if x.(cc_z) then 1 else 0) (at level 20).
+ Local Notation "'int'" := (BinInt.Z).
+ Local Notation "x << y" := ((x << y) mod (2^256)) : Z_scope. (* truncating left shift *)
+
+
+ (* Note: In the specification document, argument order gets a bit
+ confusing. Like here, r0 is always the first argument "source 0"
+ and r1 the second. But the specification of MUL128LU is:
+ (R[RS1][127:0] * R[RS0][255:128])
+
+ while the specification of SUB is:
+ (R[RS0] - shift(R[RS1], imm))
+
+ In the SUB case, r0 is really treated the first argument, but in
+ MUL128LU the order seems to be reversed; rather than low-high, we
+ take the high part of the first argument r0 and the low parts of
+ r1. This is also true for MUL128UL. *)
+
+ Definition ADD (imm : int) : instruction :=
+ {|
+ num_source_regs := 2;
+ writes_conditions := [C; M; L; Z];
+ spec := (fun '(r0, r1) cc =>
+ r0 + (r1 << imm))
+ |}.
+
+ Definition ADDC (imm : int) : instruction :=
+ {|
+ num_source_regs := 2;
+ writes_conditions := [C; M; L; Z];
+ spec := (fun '(r0, r1) cc =>
+ r0 + (r1 << imm) + cc[C])
+ |}.
+
+ Definition SUB (imm : int) : instruction :=
+ {|
+ num_source_regs := 2;
+ writes_conditions := [C; M; L; Z];
+ spec := (fun '(r0, r1) cc =>
+ r0 - (r1 << imm))
+ |}.
+
+ Definition MUL128LL : instruction :=
+ {|
+ num_source_regs := 2;
+ writes_conditions := [M; L; Z];
+ spec := (fun '(r0, r1) cc =>
+ (lower128 r0) * (lower128 r1))
+ |}.
+
+ Definition MUL128LU : instruction :=
+ {|
+ num_source_regs := 2;
+ writes_conditions := [M; L; Z];
+ spec := (fun '(r0, r1) cc =>
+ (lower128 r1) * (upper128 r0)) (* see note *)
+ |}.
+
+ Definition MUL128UL : instruction :=
+ {|
+ num_source_regs := 2;
+ writes_conditions := [M; L; Z];
+ spec := (fun '(r0, r1) cc =>
+ (upper128 r1) * (lower128 r0)) (* see note *)
+ |}.
+
+ Definition MUL128UU : instruction :=
+ {|
+ num_source_regs := 2;
+ writes_conditions := [M; L; Z];
+ spec := (fun '(r0, r1) cc =>
+ (upper128 r0) * (upper128 r1))
+ |}.
+
+ Definition RSHI (imm : int) : instruction :=
+ {|
+ num_source_regs := 2;
+ writes_conditions := [M; L; Z];
+ spec := (fun '(r0, r1) cc =>
+ (r0 + (r1 << 256)) >> imm)
+ |}.
+
+ Definition SELC : instruction :=
+ {|
+ num_source_regs := 2;
+ writes_conditions := [];
+ spec := (fun '(r0, r1) cc =>
+ if cc[C] =? 1 then r0 else r1)
+ |}.
+
+ Definition SELM : instruction :=
+ {|
+ num_source_regs := 2;
+ writes_conditions := [];
+ spec := (fun '(r0, r1) cc =>
+ if cc[M] =? 1 then r0 else r1)
+ |}.
+
+ Definition SELL : instruction :=
+ {|
+ num_source_regs := 2;
+ writes_conditions := [];
+ spec := (fun '(r0, r1) cc =>
+ if cc[L] =? 1 then r0 else r1)
+ |}.
+
+ (* TODO : treat the MOD register specially, like CC *)
+ Definition ADDM : instruction :=
+ {|
+ num_source_regs := 3;
+ writes_conditions := [M; L; Z];
+ spec := (fun '(r0, r1, MOD) cc =>
+ let ra := r0 + r1 in
+ if ra >=? MOD
+ then ra - MOD
+ else ra)
+ |}.
+
+ End ISA.
+
+ Module Registers.
+ Inductive register : Type :=
+ | r0 : register
+ | r1 : register
+ | r2 : register
+ | r3 : register
+ | r4 : register
+ | r5 : register
+ | r6 : register
+ | r7 : register
+ | r8 : register
+ | r9 : register
+ | r10 : register
+ | r11 : register
+ | r12 : register
+ | r13 : register
+ | r14 : register
+ | r15 : register
+ | r16 : register
+ | r17 : register
+ | r18 : register
+ | r19 : register
+ | r20 : register
+ | r21 : register
+ | r22 : register
+ | r23 : register
+ | r24 : register
+ | r25 : register
+ | r26 : register
+ | r27 : register
+ | r28 : register
+ | r29 : register
+ | r30 : register
+ | RegZero : register (* r31 *)
+ | RegMod : register
+ .
+
+ Definition reg_dec (x y : register) : {x = y} + {x <> y}.
+ Proof. destruct x, y; try (apply left; congruence); right; congruence. Defined.
+ Definition reg_eqb x y := if reg_dec x y then true else false.
+
+ Lemma reg_eqb_neq x y : x <> y -> reg_eqb x y = false.
+ Proof. cbv [reg_eqb]; break_match; congruence. Qed.
+ Lemma reg_eqb_refl x : reg_eqb x x = true.
+ Proof. cbv [reg_eqb]; break_match; congruence. Qed.
+ End Registers.
+
+ Section of_prefancy.
+ Context (name : Type) (name_succ : name -> name) (error : name) (consts : Z -> option name).
+
+ Fixpoint var (t : type) : Type :=
+ match t with
+ | type.type_primitive type.Z => name
+ | type.prod a b => var a * var b
+ | _ => unit
+ end.
+
+ Fixpoint of_prefancy_scalar {t} (s : @scalar var t) : var t :=
+ match s with
+ | Var t v => v
+ | Pair a b x y => (of_prefancy_scalar x, of_prefancy_scalar y)
+ | Cast r x => of_prefancy_scalar x
+ | Cast2 r x => of_prefancy_scalar x
+ | Fst a b x => fst (of_prefancy_scalar x)
+ | Snd a b x => snd (of_prefancy_scalar x)
+ | Shiftr n x => error
+ | Shiftl n x => error
+ | Land n x => error
+ | CC_m n x => error
+ | @Primitive _ type.Z x => match consts x with
+ | Some n => n
+ | None => error
+ end
+ | @Primitive _ _ x => tt
+ | TT => tt
+ | Nil _ => tt
+ end.
+
+ (* Note : some argument orders are reversed for MUL128LU, MUL128UL, SELC, SELM, and SELL *)
+ Definition of_prefancy_ident {s d} (idc : PreFancy.ident s d)
+ : @scalar var s -> {i : instruction & tuple name i.(num_source_regs) } :=
+ match idc in PreFancy.ident s d return _ with
+ | PreFancy.add imm => fun args : @scalar var (type.Z * type.Z) =>
+ existT _ (ADD imm) (of_prefancy_scalar args)
+ | PreFancy.addc imm => fun args : @scalar var (type.Z * type.Z * type.Z) =>
+ existT _ (ADDC imm) (of_prefancy_scalar (Pair (Snd (Fst args)) (Snd args)))
+ | PreFancy.sub imm => fun args : @scalar var (type.Z * type.Z) =>
+ existT _ (SUB imm) (of_prefancy_scalar args)
+ | PreFancy.mulll => fun args : @scalar var (type.Z * type.Z) =>
+ existT _ MUL128LL (of_prefancy_scalar args)
+ | PreFancy.mullh => fun args : @scalar var (type.Z * type.Z) =>
+ existT _ MUL128LU (of_prefancy_scalar (Pair (Snd args) (Fst args)))
+ | PreFancy.mulhl => fun args : @scalar var (type.Z * type.Z) =>
+ existT _ MUL128UL (of_prefancy_scalar (Pair (Snd args) (Fst args)))
+ | PreFancy.mulhh => fun args : @scalar var (type.Z * type.Z) =>
+ existT _ MUL128UU (of_prefancy_scalar args)
+ | PreFancy.rshi imm => fun args : @scalar var (type.Z * type.Z) =>
+ existT _ (RSHI imm) (of_prefancy_scalar args)
+ | PreFancy.selc => fun args : @scalar var (type.Z * type.Z * type.Z) =>
+ existT _ SELC (of_prefancy_scalar (Pair (Snd args) (Snd (Fst args))))
+ | PreFancy.selm => fun args : @scalar var (type.Z * type.Z * type.Z) =>
+ existT _ SELM (of_prefancy_scalar (Pair (Snd args) (Snd (Fst args))))
+ | PreFancy.sell => fun args : @scalar var (type.Z * type.Z * type.Z) =>
+ existT _ SELL (of_prefancy_scalar (Pair (Snd args) (Snd (Fst args))))
+ | PreFancy.addm => fun args : @scalar var (type.Z * type.Z * type.Z) =>
+ existT _ ADDM (of_prefancy_scalar args)
+ end.
+
+ Fixpoint of_prefancy (next_name : name) {t} (e : @Straightline.expr.expr var PreFancy.ident t) : expr :=
+ match e with
+ | LetInAppIdentZ s d r idc x f =>
+ let instr_args := @of_prefancy_ident s type.Z idc x in
+ let i : instruction := projT1 instr_args in
+ let args : tuple name i.(num_source_regs) := projT2 instr_args in
+ Instr i next_name args (of_prefancy (name_succ next_name) (f next_name))
+ | LetInAppIdentZZ s d r idc x f =>
+ let instr_args := @of_prefancy_ident s (type.Z * type.Z) idc x in
+ let i : instruction := projT1 instr_args in
+ let args : tuple name i.(num_source_regs) := projT2 instr_args in
+ Instr i next_name args (of_prefancy (name_succ next_name) (f (next_name, error))) (* we pass the error code as the carry register, because it cannot be read from directly. *)
+ | Scalar type.Z s => Ret (of_prefancy_scalar s)
+ | _ => Ret error
+ end.
+ End of_prefancy.
+
+ Section allocate_registers.
+ Context (reg name : Type) (name_eqb : name -> name -> bool) (error : reg).
+ Fixpoint allocate (e : @expr name) (reg_list : list reg) (name_to_reg : name -> reg) : @expr reg :=
+ match e with
+ | Ret n => Ret (name_to_reg n)
+ | Instr i rd args cont =>
+ match reg_list with
+ | r :: reg_list' => Instr i r (Tuple.map name_to_reg args) (allocate cont reg_list' (fun n => if name_eqb n rd then r else name_to_reg n))
+ | nil => Ret error
+ end
+ end.
+ End allocate_registers.
+
+ Definition test_prog : @expr positive :=
+ Instr (ADD (128)) 3%positive (1, 2)%positive
+ (Instr (ADDC 0) 4%positive (3,1)%positive
+ (Ret 4%positive)).
+
+ Definition x1 := 2^256 - 1.
+ Definition x2 := 2^128 - 1.
+ Definition wordmax := 2^256.
+ Definition expected :=
+ let r3' := (x1 + (x2 << 128)) in
+ let r3 := r3' mod wordmax in
+ let c := r3' / wordmax in
+ let r4' := (r3 + x1 + c) in
+ r4' mod wordmax.
+ Definition actual :=
+ interp Pos.eqb
+ (2^256) cc_spec test_prog {|CC.cc_c:=false; CC.cc_m:=false; CC.cc_l:=false; CC.cc_z:=false|}
+ (fun n => if n =? 1%positive
+ then x1
+ else if n =? 2%positive
+ then x2
+ else 0).
+ Lemma test_prog_ok : expected = actual.
+ Proof. reflexivity. Qed.
+
+ Definition of_Expr {s d} next_name (consts : Z -> option positive) (consts_list : list Z) (e : Expr (s -> d)) (x : var positive s) dummy_arrow : positive -> @expr positive :=
+ fun error =>
+ @of_prefancy positive Pos.succ error consts next_name _ (PreFancy.of_Expr 256 consts_list e (var positive) x dummy_arrow).
+
+End Fancy.
+
+Module Prod.
+ Import Fancy. Import Registers.
+
+ Definition Mul256 (out src1 src2 tmp : register) (cont : Fancy.expr) : Fancy.expr :=
+ Instr MUL128LL out (src1, src2)
+ (Instr MUL128UL tmp (src1, src2)
+ (Instr (ADD 128) out (out, tmp)
+ (Instr MUL128LU tmp (src1, src2)
+ (Instr (ADD 128) out (out, tmp) cont)))).
+ Definition Mul256x256 (out outHigh src1 src2 tmp : register) (cont : Fancy.expr) : Fancy.expr :=
+ Instr MUL128LL out (src1, src2)
+ (Instr MUL128UU outHigh (src1, src2)
+ (Instr MUL128UL tmp (src1, src2)
+ (Instr (ADD 128) out (out, tmp)
+ (Instr (ADDC (-128)) outHigh (outHigh, tmp)
+ (Instr MUL128LU tmp (src1, src2)
+ (Instr (ADD 128) out (out, tmp)
+ (Instr (ADDC (-128)) outHigh (outHigh, tmp) cont))))))).
+
+ Definition MontRed256 lo hi y t1 t2 scratch RegPInv : @Fancy.expr register :=
+ Mul256 y lo RegPInv t1
+ (Mul256x256 t1 t2 y RegMod scratch
+ (Instr (ADD 0) lo (lo, t1)
+ (Instr (ADDC 0) hi (hi, t2)
+ (Instr SELC y (RegMod, RegZero)
+ (Instr (SUB 0) lo (hi, y)
+ (Instr ADDM lo (lo, RegZero, RegMod)
+ (Ret lo))))))).
+End Prod.
+
+Module ProdEquiv.
+ Import Fancy. Import Registers.
+
+ Definition interp256 := Fancy.interp reg_eqb (2^256) cc_spec.
+ Lemma interp_step i rd args cont cc ctx :
+ interp256 (Instr i rd args cont) cc ctx =
+ let result := spec i (Tuple.map ctx args) cc in
+ let new_cc := CC.update (writes_conditions i) result cc_spec cc in
+ let new_ctx := fun n => if reg_eqb n rd then result mod wordmax else ctx n in interp256 cont new_cc new_ctx.
+ Proof. reflexivity. Qed.
+
+ (* TODO : move *)
+ Lemma tuple_map_ext {A B} (f g : A -> B) n (t : tuple A n) :
+ (forall x : A, f x = g x) ->
+ Tuple.map f t = Tuple.map g t.
+ Proof.
+ destruct n; [reflexivity|]; cbn in *.
+ induction n; cbn in *; intro H; auto; [ ].
+ rewrite IHn by assumption.
+ rewrite H; reflexivity.
+ Qed.
+
+ Lemma interp_state_equiv e :
+ forall cc ctx cc' ctx',
+ cc = cc' -> (forall r, ctx r = ctx' r) ->
+ interp256 e cc ctx = interp256 e cc' ctx'.
+ Proof.
+ induction e; intros; subst; cbn; [solve[auto]|].
+ apply IHe; rewrite tuple_map_ext with (g:=ctx') by auto;
+ [reflexivity|].
+ intros; break_match; auto.
+ Qed.
+ Lemma cc_overwrite_full x1 x2 l1 cc :
+ CC.update [CC.C; CC.M; CC.L; CC.Z] x2 cc_spec (CC.update l1 x1 cc_spec cc) = CC.update [CC.C; CC.M; CC.L; CC.Z] x2 cc_spec cc.
+ Proof.
+ cbv [CC.update]. cbn [CC.cc_c CC.cc_m CC.cc_l CC.cc_z].
+ break_match; try match goal with H : ~ In _ _ |- _ => cbv [In] in H; tauto end.
+ reflexivity.
+ Qed.
+
+ Lemma tuple_map_ext_In {A B} (f g : A -> B) n (t : tuple A n) :
+ (forall x, In x (to_list n t) -> f x = g x) ->
+ Tuple.map f t = Tuple.map g t.
+ Proof.
+ destruct n; [reflexivity|]; cbn in *.
+ induction n; cbn in *; intro H; auto; [ ].
+ destruct t.
+ rewrite IHn by auto using in_cons.
+ rewrite H; auto using in_eq.
+ Qed.
+
+ Definition value_unused r e : Prop :=
+ forall x cc ctx, interp256 e cc ctx = interp256 e cc (fun r' => if reg_eqb r' r then x else ctx r').
+
+ Lemma value_unused_skip r i rd args cont (Hcont: value_unused r cont) :
+ r <> rd ->
+ (~ In r (Tuple.to_list _ args)) ->
+ value_unused r (Instr i rd args cont).
+ Proof.
+ cbv [value_unused] in *; intros.
+ rewrite !interp_step; cbv zeta.
+ rewrite Hcont with (x:=x).
+ match goal with |- ?lhs = ?rhs =>
+ match lhs with context [Tuple.map ?f ?t] =>
+ match rhs with context [Tuple.map ?g ?t] =>
+ rewrite (tuple_map_ext_In f g) by (intros; cbv [reg_eqb]; break_match; congruence)
+ end end end.
+ apply interp_state_equiv; [ congruence | ].
+ { intros; cbv [reg_eqb] in *; break_match; congruence. }
+ Qed.
+
+ Lemma value_unused_overwrite r i args cont :
+ (~ In r (Tuple.to_list _ args)) ->
+ value_unused r (Instr i r args cont).
+ Proof.
+ cbv [value_unused]; intros; rewrite !interp_step; cbv zeta.
+ match goal with |- ?lhs = ?rhs =>
+ match lhs with context [Tuple.map ?f ?t] =>
+ match rhs with context [Tuple.map ?g ?t] =>
+ rewrite (tuple_map_ext_In f g) by (intros; cbv [reg_eqb]; break_match; congruence)
+ end end end.
+ apply interp_state_equiv; [ congruence | ].
+ { intros; cbv [reg_eqb] in *; break_match; congruence. }
+ Qed.
+
+ Lemma value_unused_ret r r' :
+ r <> r' ->
+ value_unused r (Ret r').
+ Proof.
+ cbv - [reg_dec]; intros.
+ break_match; congruence.
+ Qed.
+
+ Ltac remember_results :=
+ repeat match goal with |- context [(spec ?i ?args ?flags) mod ?w] =>
+ let x := fresh "x" in
+ let y := fresh "y" in
+ let Heqx := fresh "Heqx" in
+ remember (spec i args flags) as x eqn:Heqx;
+ remember (x mod w) as y
+ end.
+
+ Ltac do_interp_step :=
+ rewrite interp_step; cbn - [interp spec];
+ repeat progress rewrite ?reg_eqb_neq, ?reg_eqb_refl by congruence;
+ remember_results.
+
+ Lemma interp_Mul256 out src1 src2 tmp tmp2 cont cc ctx:
+ out <> src1 ->
+ out <> src2 ->
+ out <> tmp ->
+ out <> tmp2 ->
+ src1 <> src2 ->
+ src1 <> tmp ->
+ src1 <> tmp2 ->
+ src2 <> tmp ->
+ src2 <> tmp2 ->
+ tmp <> tmp2 ->
+ value_unused tmp cont ->
+ value_unused tmp2 cont ->
+ interp256 (Prod.Mul256 out src1 src2 tmp cont) cc ctx =
+ interp256 (
+ Instr MUL128LU tmp (src1, src2)
+ (Instr MUL128UL tmp2 (src1, src2)
+ (Instr MUL128LL out (src1, src2)
+ (Instr (ADD 128) out (out, tmp2)
+ (Instr (ADD 128) out (out, tmp) cont))))) cc ctx.
+ Proof.
+ intros; cbv [Prod.Mul256].
+ repeat (do_interp_step; cbn [spec MUL128LL MUL128UL MUL128LU ADD] in * ).
+
+ match goal with H : value_unused tmp _ |- _ => erewrite H end.
+ match goal with H : value_unused tmp2 _ |- _ => erewrite H end.
+ apply interp_state_equiv.
+ { rewrite !cc_overwrite_full.
+ f_equal. subst. lia. }
+ { intros; cbv [reg_eqb].
+ repeat (break_match_step ltac:(fun _ => idtac); try congruence); reflexivity. }
+ Qed.
+
+ Lemma interp_Mul256x256 out outHigh src1 src2 tmp tmp2 cont cc ctx:
+ out <> src1 ->
+ out <> outHigh ->
+ out <> src2 ->
+ out <> tmp ->
+ out <> tmp2 ->
+ outHigh <> src1 ->
+ outHigh <> src2 ->
+ outHigh <> tmp ->
+ outHigh <> tmp2 ->
+ src1 <> src2 ->
+ src1 <> tmp ->
+ src1 <> tmp2 ->
+ src2 <> tmp ->
+ src2 <> tmp2 ->
+ tmp <> tmp2 ->
+ value_unused tmp cont ->
+ value_unused tmp2 cont ->
+ interp256 (Prod.Mul256x256 out outHigh src1 src2 tmp cont) cc ctx =
+ interp256 (
+ Instr MUL128LL out (src1, src2)
+ (Instr MUL128LU tmp (src1, src2)
+ (Instr MUL128UL tmp2 (src1, src2)
+ (Instr MUL128UU outHigh (src1, src2)
+ (Instr (ADD 128) out (out, tmp2)
+ (Instr (ADDC (-128)) outHigh (outHigh, tmp2)
+ (Instr (ADD 128) out (out, tmp)
+ (Instr (ADDC (-128)) outHigh (outHigh, tmp) cont)))))))) cc ctx.
+ Proof.
+ intros; cbv [Prod.Mul256x256].
+ repeat (do_interp_step; cbn [spec MUL128LL MUL128UL MUL128LU MUL128UU ADD ADDC] in * ).
+
+ match goal with H : value_unused tmp _ |- _ => erewrite H end.
+ match goal with H : value_unused tmp2 _ |- _ => erewrite H end.
+ apply interp_state_equiv.
+ { rewrite !cc_overwrite_full.
+ f_equal.
+ subst. cbn - [Z.add Z.modulo Z.testbit Z.mul Z.shiftl Fancy.lower128 Fancy.upper128].
+ lia. }
+ { intros; cbv [reg_eqb].
+ repeat (break_match_step ltac:(fun _ => idtac); try congruence); try reflexivity; [ ].
+ subst. cbn - [Z.add Z.modulo Z.testbit Z.mul Z.shiftl Fancy.lower128 Fancy.upper128].
+ lia. }
+ Qed.
+End ProdEquiv.
+
+(* Lemmas to help prove that a fancy and prefancy expression have the
+same meaning -- should be replaced eventually with a proof of fancy
+passes in general. *)
+Module Fancy_PreFancy_Equiv.
+ Import Fancy.Registers.
+
+ Lemma interp_cast_mod_eq w u x: u = 2^w - 1 -> PreFancy.interp_cast_mod w r[0 ~> u] x = x mod 2^w.
+ Proof.
+ cbv [PreFancy.interp_cast_mod upper lower]; intros; subst.
+ rewrite !Z.eqb_refl.
+ reflexivity.
+ Qed.
+ Lemma interp_cast_mod_flag w x: PreFancy.interp_cast_mod w r[0 ~> 1] x = x mod 2.
+ Proof.
+ cbv [PreFancy.interp_cast_mod upper lower].
+ break_match; Z.ltb_to_lt; subst; try omega.
+ f_equal; lia.
+ Qed.
+
+ Lemma interp_equivZ {s} w u (Hu : u = 2^w-1) i rd regs e cc ctx idc args f :
+ (Fancy.spec i (Tuple.map ctx regs) cc
+ = PreFancy.interp_ident (d:=type.Z) w idc (Straightline.expr.interp_scalar (interp_cast:=PreFancy.interp_cast_mod w) args)) ->
+ ( let r := Fancy.spec i (Tuple.map ctx regs) cc in
+ Fancy.interp reg_eqb (2 ^ w) Fancy.cc_spec e
+ (Fancy.CC.update (Fancy.writes_conditions i) r Fancy.cc_spec cc)
+ (fun n : register => if reg_eqb n rd then r mod 2 ^ w else ctx n) =
+ PreFancy.interp (t:=type.Z) (interp_cast:=PreFancy.interp_cast_mod w) w (f (r mod 2 ^ w))) ->
+ Fancy.interp reg_eqb (2^w) Fancy.cc_spec (Fancy.Instr i rd regs e) cc ctx
+ = PreFancy.interp (t:=type.Z) (interp_cast:=PreFancy.interp_cast_mod w) w
+ (@Straightline.expr.LetInAppIdentZ _ _ s _ (r[0~>2^w-1])%zrange idc args f).
+ Proof.
+ cbv zeta; intros spec_eq next_eq.
+ cbn [Fancy.interp PreFancy.interp].
+ rewrite next_eq.
+ rewrite <-spec_eq.
+ rewrite interp_cast_mod_eq by omega.
+ reflexivity.
+ Qed.
+
+ Lemma interp_equivZZ {s} w (Hw : 2 < 2 ^ w) u (Hu : u = 2^w - 1) i rd regs e cc ctx idc args f :
+ ((Fancy.spec i (Tuple.map ctx regs) cc) mod 2 ^ w
+ = fst (PreFancy.interp_ident (d:=type.Z*type.Z) w idc (Straightline.expr.interp_scalar (interp_cast:=PreFancy.interp_cast_mod w) args))) ->
+ ((if Fancy.cc_spec Fancy.CC.C(Fancy.spec i (Tuple.map ctx regs) cc) then 1 else 0)
+ = snd (PreFancy.interp_ident (d:=type.Z*type.Z) w idc (Straightline.expr.interp_scalar (interp_cast:=PreFancy.interp_cast_mod w) args)) mod 2) ->
+ ( let r := Fancy.spec i (Tuple.map ctx regs) cc in
+ Fancy.interp reg_eqb (2 ^ w) Fancy.cc_spec e
+ (Fancy.CC.update (Fancy.writes_conditions i) r Fancy.cc_spec cc)
+ (fun n : register => if reg_eqb n rd then r mod 2 ^ w else ctx n) =
+ PreFancy.interp (t:=type.Z) (interp_cast:=PreFancy.interp_cast_mod w) w
+ (f (r mod 2 ^ w, if (Fancy.cc_spec Fancy.CC.C r) then 1 else 0))) ->
+ Fancy.interp reg_eqb (2^w) Fancy.cc_spec (Fancy.Instr i rd regs e) cc ctx
+ = PreFancy.interp (t:=type.Z) (interp_cast:=PreFancy.interp_cast_mod w) w
+ (@Straightline.expr.LetInAppIdentZZ _ _ s _ (r[0~>u], r[0~>1])%zrange idc args f).
+ Proof.
+ cbv zeta; intros spec_eq1 spec_eq2 next_eq.
+ cbn [Fancy.interp PreFancy.interp].
+ cbv [Straightline.expr.interp_cast2]. cbn [fst snd].
+ rewrite next_eq.
+ rewrite interp_cast_mod_eq by omega.
+ rewrite interp_cast_mod_flag by omega.
+ rewrite <-spec_eq1, <-spec_eq2.
+ rewrite Z.mod_mod by omega.
+ reflexivity.
+ Qed.
+End Fancy_PreFancy_Equiv.
+
Module BarrettReduction.
(* TODO : generalize to multi-word and operate on (list Z) instead of T; maybe stop taking ops as context variables *)
Section Generic.
@@ -10201,6 +11006,7 @@ Module Barrett256.
Proof. Time solve_rbarrett_red machine_wordsize. Time Qed.
Import PrintingNotations.
+ Set Printing Width 1000.
Open Scope expr_scope.
Print barrett_red256.
@@ -10209,21 +11015,16 @@ Module Barrett256.
This is why their results are not cast (because the carry has range [-1~>0]). *)
(*
barrett_red256 = fun var : type -> Type => λ x : var (type.type_primitive type.Z * type.type_primitive type.Z)%ctype,
- expr_let x0 := SELM (x₂, 0,
- 26959946667150639793205513449348445388433292963828203772348655992835) in
+ expr_let x0 := SELM (x₂, 0, 26959946667150639793205513449348445388433292963828203772348655992835) in
expr_let x1 := RSHI (0, x₂, 255) in
expr_let x2 := RSHI (x₂, x₁, 255) in
expr_let x3 := 79228162514264337589248983038 *₂₅₆ (uint128)(x2 >> 128) in
- expr_let x4 := 79228162514264337589248983038 *₂₅₆
- ((uint128)(x2) & 340282366920938463463374607431768211455) in
+ expr_let x4 := 79228162514264337589248983038 *₂₅₆ ((uint128)(x2) & 340282366920938463463374607431768211455) in
expr_let x5 := 340282366841710300930663525764514709507 *₂₅₆ (uint128)(x2 >> 128) in
- expr_let x6 := 340282366841710300930663525764514709507 *₂₅₆
- ((uint128)(x2) & 340282366920938463463374607431768211455) in
- expr_let x7 := ADD_256 ((uint256)(((uint128)(x4) & 340282366920938463463374607431768211455) << 128),
- x6) in
- expr_let x8 := ADDC_256 (x7₂, x3, (uint128)(x5 >> 128)) in
- expr_let x9 := ADD_256 ((uint256)(((uint128)(x5) & 340282366920938463463374607431768211455) << 128),
- x7₁) in
+ expr_let x6 := 340282366841710300930663525764514709507 *₂₅₆ ((uint128)(x2) & 340282366920938463463374607431768211455) in
+ expr_let x7 := ADD_256 ((uint256)(((uint128)(x5) & 340282366920938463463374607431768211455) << 128), x6) in
+ expr_let x8 := ADDC_256 (x7₂, (uint128)(x5 >> 128), x3) in
+ expr_let x9 := ADD_256 ((uint256)(((uint128)(x4) & 340282366920938463463374607431768211455) << 128), x7₁) in
expr_let x10 := ADDC_256 (x9₂, (uint128)(x4 >> 128), x8₁) in
expr_let x11 := ADD_256 (x2, x10₁) in
expr_let x12 := ADDC_128 (x11₂, 0, x1) in
@@ -10232,27 +11033,17 @@ barrett_red256 = fun var : type -> Type => λ x : var (type.type_primitive type
expr_let x15 := RSHI (x14₁, x13₁, 1) in
expr_let x16 := 340282366841710300967557013911933812736 *₂₅₆ (uint128)(x15 >> 128) in
expr_let x17 := 79228162514264337593543950335 *₂₅₆ (uint128)(x15 >> 128) in
- expr_let x18 := 340282366841710300967557013911933812736 *₂₅₆
- ((uint128)(x15) & 340282366920938463463374607431768211455) in
- expr_let x19 := 79228162514264337593543950335 *₂₅₆
- ((uint128)(x15) & 340282366920938463463374607431768211455) in
- expr_let x20 := ADD_256 ((uint256)(((uint128)(x17) & 340282366920938463463374607431768211455) << 128),
- x19) in
- expr_let x21 := ADDC_256 (x20₂, x16, (uint128)(x18 >> 128)) in
- expr_let x22 := ADD_256 ((uint256)(((uint128)(x18) & 340282366920938463463374607431768211455) << 128),
- x20₁) in
+ expr_let x18 := 340282366841710300967557013911933812736 *₂₅₆ ((uint128)(x15) & 340282366920938463463374607431768211455) in
+ expr_let x19 := 79228162514264337593543950335 *₂₅₆ ((uint128)(x15) & 340282366920938463463374607431768211455) in
+ expr_let x20 := ADD_256 ((uint256)(((uint128)(x18) & 340282366920938463463374607431768211455) << 128), x19) in
+ expr_let x21 := ADDC_256 (x20₂, (uint128)(x18 >> 128), x16) in
+ expr_let x22 := ADD_256 ((uint256)(((uint128)(x17) & 340282366920938463463374607431768211455) << 128), x20₁) in
expr_let x23 := ADDC_256 (x22₂, (uint128)(x17 >> 128), x21₁) in
- expr_let x24 := Z.add_get_carry_concrete
- 115792089237316195423570985008687907853269984665640564039457584007913129639936 @@
- (Z.opp @@ (fst @@ x22), x₁) in
- expr_let x25 := Z.add_with_get_carry_concrete
- 115792089237316195423570985008687907853269984665640564039457584007913129639936 @@
- (x24₂, Z.opp @@ (fst @@ x23), x₂) in
- expr_let x26 := SELL (x25₁, 0,
- 115792089210356248762697446949407573530086143415290314195533631308867097853951) in
+ expr_let x24 := Z.add_get_carry_concrete 115792089237316195423570985008687907853269984665640564039457584007913129639936 @@ (Z.opp @@ (fst @@ x22), x₁) in
+ expr_let x25 := Z.add_with_get_carry_concrete 115792089237316195423570985008687907853269984665640564039457584007913129639936 @@ (x24₂, Z.opp @@ (fst @@ x23), x₂) in
+ expr_let x26 := SELL (x25₁, 0, 115792089210356248762697446949407573530086143415290314195533631308867097853951) in
expr_let x27 := Z.cast uint256 @@ (fst @@ SUB_256 (x24₁, x26)) in
- ADDM (x27, 0,
- 115792089210356248762697446949407573530086143415290314195533631308867097853951)
+ ADDM (x27, 0, 115792089210356248762697446949407573530086143415290314195533631308867097853951)
: Expr (type.uncurry (type.type_primitive type.Z -> type.type_primitive type.Z -> type.type_primitive type.Z))
*)
@@ -10274,9 +11065,9 @@ barrett_red256 = fun var : type -> Type => λ x : var (type.type_primitive type
mulhl@(y3, RegMuLow, $y1);
mullh@(y4, RegMuLow, $y1);
mulll@(y5, RegMuLow, $y1);
- add@(y6, $y5, $y3, 128);
+ add@(y6, $y5, $y4, 128);
addc@(y7, carry{$y6}, $y2, $y4, -128);
- add@(y8, $y6, $y4, 128);
+ add@(y8, $y6, $y3, 128);
addc@(y9, carry{$y8}, $y7, $y3, -128);
add@(y10, $y1, $y9, 0);
addc@(y11, carry{$y10}, RegZero, $y0, 0); #128
@@ -10287,9 +11078,9 @@ barrett_red256 = fun var : type -> Type => λ x : var (type.type_primitive type
mullh@(y16, RegMod, $y14);
mulhl@(y17, RegMod, $y14);
mulll@(y18, RegMod, $y14);
- add@(y19, $y18, $y16, 128);
+ add@(y19, $y18, $y17, 128);
addc@(y20, carry{$y19}, $y15, $y17, -128);
- add@(y21, $y19, $y17, 128);
+ add@(y21, $y19, $y16, 128);
addc@(_, carry{$y21}, $y20, $y16, -128);
Straightline.expr.Scalar (Straightline.expr.Primitive (-1))
*)
@@ -10706,7 +11497,7 @@ Module MontgomeryReduction.
Definition montred' (lo_hi : (Z * Z)) :=
dlet_nd y := nth_default 0 (BaseConversion.widemul_inlined Zlog2R n nout (fst lo_hi) N') 0 in
- dlet_nd t1_t2 := (BaseConversion.widemul_inlined Zlog2R n nout y N) in
+ dlet_nd t1_t2 := (BaseConversion.widemul_inlined_reverse Zlog2R n nout N y) in
dlet_nd sum_carry := Rows.add (weight Zlog2R 1) 2 [fst lo_hi; snd lo_hi] t1_t2 in
dlet_nd y' := Z.zselect (snd sum_carry) 0 N in
dlet_nd lo''_carry := Z.sub_get_borrow_full R (nth_default 0 (fst sum_carry) 1) y' in
@@ -10734,6 +11525,9 @@ Module MontgomeryReduction.
Local Lemma eval2 x y : eval w 2 [x;y] = x + R * y.
Proof. cbn. change_weight. ring. Qed.
+ Hint Rewrite BaseConversion.widemul_inlined_reverse_correct BaseConversion.widemul_inlined_correct
+ using (autorewrite with widemul push_nth_default; solve [solve_range]) : widemul.
+
Lemma montred'_eq lo_hi T (HT_range: 0 <= T < R * N)
(Hlo: fst lo_hi = T mod R) (Hhi: snd lo_hi = T / R):
montred' lo_hi = reduce_via_partial N R N' T.
@@ -10741,10 +11535,9 @@ Module MontgomeryReduction.
rewrite <-reduce_via_partial_alt_eq by nia.
cbv [montred' partial_reduce_alt reduce_via_partial_alt prereduce Let_In].
rewrite Hlo, Hhi.
- assert (0 <= T mod R * N' < w 2) by (solve_range).
+ assert (0 <= (T mod R) * N' < w 2) by (solve_range).
- rewrite !BaseConversion.widemul_inlined_correct
- by (rewrite ?BaseConversion.widemul_inlined_correct; autorewrite with push_nth_default; solve_range).
+ autorewrite with widemul.
rewrite Rows.add_partitions, Rows.add_div by (distr_length; apply wprops; omega).
rewrite R_two_pow.
cbv [Rows.partition seq]. rewrite !eval2.
@@ -10757,7 +11550,8 @@ Module MontgomeryReduction.
let P := fresh "H" in assert (x = y) as P; [|rewrite P; reflexivity] end.
autorewrite with zsimplify.
- break_match; try reflexivity; autorewrite with ltb_to_lt in *; rewrite Z.div_small_iff in * by omega;
+ rewrite (Z.mul_comm (((T mod R) * N') mod R) N) in *.
+ break_match; try reflexivity; Z.ltb_to_lt; rewrite Z.div_small_iff in * by omega;
repeat match goal with
| _ => progress autorewrite with zsimplify_fast
| |- context [?x mod (R * R)] =>
@@ -10842,6 +11636,7 @@ Module Montgomery256.
Definition N := Eval lazy in (2^256-2^224+2^192+2^96-1).
Definition N':= (115792089210356248768974548684794254293921932838497980611635986753331132366849).
Definition R := Eval lazy in (2^256).
+ Definition R' := 115792089183396302114378112356516095823261736990586219612555396166510339686400.
Definition machine_wordsize := 256.
Derive montred256
@@ -10856,6 +11651,7 @@ Module Montgomery256.
As montred256_prefancy_eq.
Proof. lazy - [type.interp]; reflexivity. Qed.
+
Lemma montred'_correct_specialized R' (R'_correct : Z.equiv_modulo N (R * R') 1) :
forall (lo hi : Z),
0 <= lo < R -> 0 <= hi < R -> 0 <= lo + R * hi < R * N ->
@@ -10933,14 +11729,14 @@ Module Montgomery256.
| |- context [PreFancy.ok_expr] => constructor; cbn [fst snd]; repeat ok_expr_step'
end; intros; cbn [Nat.max].
- Lemma montred256_prefancy_correct R' (R'_correct : Z.equiv_modulo N (R * R') 1) :
+ Lemma montred256_prefancy_correct :
forall (lo hi : Z) dummy_arrow,
0 <= lo < R -> 0 <= hi < R -> 0 <= lo + R * hi < R * N ->
- @PreFancy.interp machine_wordsize type.Z (montred256_prefancy (lo,hi) dummy_arrow) = ((lo + R * hi) * R') mod N.
+ @PreFancy.interp machine_wordsize (PreFancy.interp_cast_mod machine_wordsize) type.Z (montred256_prefancy (lo,hi) dummy_arrow) = ((lo + R * hi) * R') mod N.
Proof.
intros. rewrite montred256_prefancy_eq; cbv [montred256_prefancy'].
erewrite PreFancy.of_Expr_correct.
- { apply montred256_correct_full; assumption. }
+ { apply montred256_correct_full; try assumption; reflexivity. }
{ reflexivity. }
{ lazy; reflexivity. }
{ lazy; reflexivity. }
@@ -10950,10 +11746,298 @@ Module Montgomery256.
repeat (ok_expr_step; [ ]).
ok_expr_step.
lazy; congruence.
+ constructor.
constructor. }
{ lazy. omega. }
Qed.
+ Definition montred256_fancy' (lo hi RegMod RegPInv RegZero error : positive) :=
+ Fancy.of_Expr 3%positive
+ (fun z => if z =? N then Some RegMod else if z =? N' then Some RegPInv else if z =? 0 then Some RegZero else None)
+ [N;N']
+ montred256
+ (lo, hi)%positive
+ (fun _ _ => tt)
+ error.
+ Derive montred256_fancy
+ SuchThat (forall RegMod RegPInv RegZero,
+ montred256_fancy RegMod RegPInv RegZero = montred256_fancy' RegMod RegPInv RegZero)
+ As montred256_fancy_eq.
+ Proof.
+ intros.
+ lazy - [Fancy.ADD Fancy.ADDC Fancy.SUB
+ Fancy.MUL128LL Fancy.MUL128LU Fancy.MUL128UL Fancy.MUL128UU
+ Fancy.RSHI Fancy.SELC Fancy.SELM Fancy.SELL Fancy.ADDM].
+ reflexivity.
+ Qed.
+
+ Import Fancy.Registers.
+
+ Definition montred256_alloc' lo hi RegPInv :=
+ fun errorP errorR =>
+ Fancy.allocate register
+ positive Pos.eqb
+ errorR
+ (montred256_fancy 1000%positive 1001%positive 1002%positive 1003%positive 1004%positive errorP)
+ [r2;r3;r4;r5;r6;r7;r8;r9;r10;r11;r12;r13;r14;r15;r16;r17;r18;r19;r20]
+ (fun n => if n =? 1000 then lo
+ else if n =? 1001 then hi
+ else if n =? 1002 then RegMod
+ else if n =? 1003 then RegPInv
+ else if n =? 1004 then RegZero
+ else errorR).
+ Derive montred256_alloc
+ SuchThat (montred256_alloc = montred256_alloc')
+ As montred256_alloc_eq.
+ Proof.
+ intros.
+ cbv [montred256_alloc' montred256_fancy].
+ cbn. subst montred256_alloc.
+ reflexivity.
+ Qed.
+
+ (* TODO : move *)
+ Lemma mulll_comm rd x y cont cc ctx :
+ ProdEquiv.interp256 (Fancy.Instr Fancy.MUL128LL rd (x, y) cont) cc ctx = ProdEquiv.interp256 (Fancy.Instr Fancy.MUL128LL rd (y, x) cont) cc ctx.
+ Proof. rewrite !ProdEquiv.interp_step. cbn - [Fancy.interp]. rewrite Z.mul_comm. reflexivity. Qed.
+
+ Lemma mulhh_comm rd x y cont cc ctx :
+ ProdEquiv.interp256 (Fancy.Instr Fancy.MUL128UU rd (x, y) cont) cc ctx = ProdEquiv.interp256 (Fancy.Instr Fancy.MUL128UU rd (y, x) cont) cc ctx.
+ Proof. rewrite !ProdEquiv.interp_step. cbn - [Fancy.interp]. rewrite Z.mul_comm. reflexivity. Qed.
+
+ Lemma mullh_mulhl rd x y cont cc ctx :
+ ProdEquiv.interp256 (Fancy.Instr Fancy.MUL128LU rd (x, y) cont) cc ctx = ProdEquiv.interp256 (Fancy.Instr Fancy.MUL128UL rd (y, x) cont) cc ctx.
+ Proof. rewrite !ProdEquiv.interp_step. cbn - [Fancy.interp]. rewrite Z.mul_comm. reflexivity. Qed.
+
+ Lemma add_comm rd x y cont cc ctx :
+ 0 <= ctx x < 2^256 ->
+ 0 <= ctx y < 2^256 ->
+ ProdEquiv.interp256 (Fancy.Instr (Fancy.ADD 0) rd (x, y) cont) cc ctx = ProdEquiv.interp256 (Fancy.Instr (Fancy.ADD 0) rd (y, x) cont) cc ctx.
+ Proof.
+ intros; rewrite !ProdEquiv.interp_step. cbn - [Fancy.interp]. rewrite Z.add_comm.
+ rewrite !(Z.mod_small (ctx _)) by (cbn in *; omega). reflexivity.
+ Qed.
+
+ Lemma addc_comm rd x y cont cc ctx :
+ 0 <= ctx x < 2^256 ->
+ 0 <= ctx y < 2^256 ->
+ ProdEquiv.interp256 (Fancy.Instr (Fancy.ADDC 0) rd (x, y) cont) cc ctx = ProdEquiv.interp256 (Fancy.Instr (Fancy.ADDC 0) rd (y, x) cont) cc ctx.
+ Proof.
+ intros; rewrite !ProdEquiv.interp_step. cbn - [Fancy.interp]. rewrite (Z.add_comm (ctx x)).
+ rewrite !(Z.mod_small (ctx _)) by (cbn in *; omega). reflexivity.
+ Qed.
+
+ SearchAbout reg_eqb.
+
+ Local Ltac push_value_unused :=
+ repeat match goal with
+ | |- ~ In _ _ => cbn; intuition; congruence
+ | _ => apply ProdEquiv.value_unused_overwrite
+ | _ => apply ProdEquiv.value_unused_skip; [ | congruence | ]
+ | _ => apply ProdEquiv.value_unused_ret; congruence
+ end.
+
+ Local Ltac remember_single_result :=
+ match goal with |- context [(Fancy.spec ?i ?args ?cc) mod ?w] =>
+ let x := fresh "x" in
+ let y := fresh "y" in
+ let Heqx := fresh "Heqx" in
+ remember (Fancy.spec i args cc) as x eqn:Heqx;
+ remember (x mod w) as y
+ end.
+ Local Ltac step_both_sides :=
+ match goal with |- ProdEquiv.interp256 (Fancy.Instr ?i ?rd1 ?args1 _) _ ?ctx1 = ProdEquiv.interp256 (Fancy.Instr ?i ?rd2 ?args2 _) _ ?ctx2 =>
+ rewrite (ProdEquiv.interp_step i rd1 args1); rewrite (ProdEquiv.interp_step i rd2 args2);
+ cbn - [Fancy.interp Fancy.spec];
+ repeat progress rewrite ?reg_eqb_neq, ?reg_eqb_refl by congruence;
+ remember_single_result;
+ lazymatch goal with
+ | |- context [Fancy.spec i _ _] =>
+ let Heqa1 := fresh in
+ let Heqa2 := fresh in
+ remember (Tuple.map (n:=i.(Fancy.num_source_regs)) ctx1 args1) eqn:Heqa1;
+ remember (Tuple.map (n:=i.(Fancy.num_source_regs)) ctx2 args2) eqn:Heqa2;
+ cbn in Heqa1; cbn in Heqa2;
+ repeat progress rewrite ?reg_eqb_neq, ?reg_eqb_refl in Heqa1 by congruence;
+ repeat progress rewrite ?reg_eqb_neq, ?reg_eqb_refl in Heqa2 by congruence;
+ let a1 := match type of Heqa1 with _ = ?a1 => a1 end in
+ let a2 := match type of Heqa2 with _ = ?a2 => a2 end in
+ (fail 1 "arguments to " i " do not match; LHS has " a1 " and RHS has " a2)
+ | _ => idtac
+ end
+ end.
+
+ Local Ltac solve_bounds :=
+ match goal with
+ | H : ?a = ?b mod ?c |- 0 <= ?a < ?c => rewrite H; apply Z.mod_pos_bound; omega
+ | _ => assumption
+ end.
+
+ Lemma montred256_alloc_equivalent errorP errorR cc_start_state start_context :
+ forall lo hi y t1 t2 scratch RegPInv extra_reg,
+ NoDup [lo; hi; y; t1; t2; scratch; RegPInv; extra_reg; RegMod; RegZero] ->
+ 0 <= start_context lo < R ->
+ 0 <= start_context hi < R ->
+ 0 <= start_context RegPInv < R ->
+ ProdEquiv.interp256 (montred256_alloc r0 r1 r30 errorP errorR) cc_start_state
+ (fun r => if reg_eqb r r0
+ then start_context lo
+ else if reg_eqb r r1
+ then start_context hi
+ else if reg_eqb r r30
+ then start_context RegPInv
+ else start_context r)
+ = ProdEquiv.interp256 (Prod.MontRed256 lo hi y t1 t2 scratch RegPInv) cc_start_state start_context.
+ Proof.
+ intros. cbv [R] in *.
+ cbv [Prod.MontRed256 montred256_alloc].
+
+ (* Extract proofs that no registers are equal to each other *)
+ repeat match goal with
+ | H : NoDup _ |- _ => inversion H; subst; clear H
+ | H : ~ In _ _ |- _ => cbv [In] in H
+ | H : ~ (_ \/ _) |- _ => apply Decidable.not_or in H; destruct H
+ | H : ~ False |- _ => clear H
+ end.
+
+ rewrite ProdEquiv.interp_Mul256 with (tmp2:=extra_reg) by (congruence || push_value_unused).
+
+ step_both_sides.
+ step_both_sides.
+ rewrite mulll_comm. step_both_sides.
+ step_both_sides.
+ step_both_sides.
+
+ rewrite ProdEquiv.interp_Mul256x256 with (tmp2:=extra_reg) by (congruence || push_value_unused).
+
+ rewrite mulll_comm. step_both_sides.
+ step_both_sides.
+ step_both_sides.
+ rewrite mulhh_comm. step_both_sides.
+ step_both_sides.
+ step_both_sides.
+ step_both_sides.
+ step_both_sides.
+
+
+ rewrite add_comm by (cbn; solve_bounds). step_both_sides.
+ rewrite addc_comm by (cbn; solve_bounds). step_both_sides.
+ step_both_sides.
+ step_both_sides.
+ step_both_sides.
+
+ cbn; repeat progress rewrite ?reg_eqb_neq, ?reg_eqb_refl by congruence.
+ reflexivity.
+ Qed.
+
+ Import Fancy_PreFancy_Equiv.
+
+ Definition interp_equivZZ_256 {s} :=
+ @interp_equivZZ s 256 ltac:(cbv; congruence) 115792089237316195423570985008687907853269984665640564039457584007913129639935 ltac:(reflexivity).
+ Definition interp_equivZ_256 {s} :=
+ @interp_equivZ s 256 115792089237316195423570985008687907853269984665640564039457584007913129639935 ltac:(reflexivity).
+
+ Local Ltac simplify_op_equiv start_ctx :=
+ cbn - [Fancy.spec PreFancy.interp_ident Fancy.cc_spec];
+ repeat match goal with H : start_ctx _ = _ |- _ => rewrite H end;
+ cbv - [
+ Z.add_with_get_carry_full
+ Z.add_get_carry_full Z.sub_get_borrow_full
+ Z.le Z.ltb Z.leb Z.geb Z.eqb Z.land Z.shiftr Z.shiftl
+ Z.add Z.mul Z.div Z.sub Z.modulo Z.testbit Z.pow Z.ones
+ fst snd]; cbn [fst snd];
+ try (replace (2 ^ (256 / 2) - 1) with (Z.ones 128) by reflexivity; rewrite !Z.land_ones by omega);
+ autorewrite with to_div_mod; rewrite ?Z.mod_mod, <-?Z.testbit_spec' by omega;
+ repeat match goal with
+ | H : 0 <= ?x < ?m |- context [?x mod ?m] => rewrite (Z.mod_small x m) by apply H
+ | |- context [?x <? 0] => rewrite (proj2 (Z.ltb_ge x 0)) by (break_match; Z.zero_bounds)
+ | _ => rewrite Z.mod_small with (b:=2) by (break_match; omega)
+ | |- context [ (if Z.testbit ?a ?n then 1 else 0) + ?b + ?c] =>
+ replace ((if Z.testbit a n then 1 else 0) + b + c) with (b + c + (if Z.testbit a n then 1 else 0)) by ring
+ end.
+
+ Local Ltac solve_nonneg ctx :=
+ match goal with x := (Fancy.spec _ _ _) |- _ => subst x end;
+ simplify_op_equiv ctx; Z.zero_bounds.
+
+ Local Ltac generalize_result :=
+ let v := fresh "v" in intro v; generalize v; clear v; intro v.
+
+ Local Ltac generalize_result_nonneg ctx :=
+ let v := fresh "v" in
+ let v_nonneg := fresh "v_nonneg" in
+ intro v; assert (0 <= v) as v_nonneg; [solve_nonneg ctx |generalize v v_nonneg; clear v v_nonneg; intros v v_nonneg].
+
+ Local Ltac step ctx :=
+ match goal with
+ | |- Fancy.interp _ _ _ (Fancy.Instr (Fancy.ADD _) _ _ (Fancy.Instr (Fancy.ADDC _) _ _ _)) _ _ = _ =>
+ apply interp_equivZZ_256; [ simplify_op_equiv ctx | simplify_op_equiv ctx | generalize_result_nonneg ctx]
+ | _ => apply interp_equivZ_256; [simplify_op_equiv ctx | generalize_result]
+ | _ => apply interp_equivZZ_256; [ simplify_op_equiv ctx | simplify_op_equiv ctx | generalize_result]
+ end.
+
+ (* TODO: move this lemma to ZUtil *)
+ Lemma testbit_neg_eq_if x y n :
+ 0 <= n ->
+ 0 <= x < 2 ^ n ->
+ 0 <= y < 2 ^ n ->
+ Z.b2z (if (x - y) <? 0 then true else Z.testbit (x - y) n) = - ((x - y) / 2 ^ n) mod 2.
+ Proof.
+ intros. rewrite Z.sub_pos_bound_div_eq by omega.
+ break_innermost_match; Z.ltb_to_lt; try lia; try reflexivity; [ ].
+ rewrite Z.testbit_eqb, Z.div_between_0_if by omega.
+ break_innermost_match; Z.ltb_to_lt; try lia; reflexivity.
+ Qed.
+
+ Lemma prod_montred256_correct :
+ forall (cc_start_state : Fancy.CC.state) (* starting carry flags can be anything *)
+ (start_context : register -> Z) (* starting register values *)
+ (lo hi y t1 t2 scratch RegPInv extra_reg : register), (* registers to use in computation *)
+ NoDup [lo; hi; y; t1; t2; scratch; RegPInv; extra_reg; RegMod; RegZero] -> (* registers must be distinct *)
+ start_context RegPInv = N' -> (* RegPInv needs to hold the inverse of the modulus *)
+ start_context RegMod = N -> (* RegMod needs to hold the modulus *)
+ start_context RegZero = 0 -> (* RegZero needs to hold zero *)
+ (0 <= start_context lo < R) -> (* value in lo is in bounds (R=2^256) *)
+ (0 <= start_context hi < R) -> (* value in hi is in bounds (R=2^256) *)
+ let x := (start_context lo) + R * (start_context hi) in (* x is the input (split into two registers) *)
+ (0 <= x < R * N) -> (* input precondition *)
+ (ProdEquiv.interp256 (Prod.MontRed256 lo hi y t1 t2 scratch RegPInv) cc_start_state start_context = (x * R') mod N).
+ Proof.
+ intros. subst x. cbv [N R N'] in *.
+ rewrite <-montred256_prefancy_correct with (dummy_arrow := fun s d _ => DefaultValue.type.default) by auto.
+ rewrite <-montred256_alloc_equivalent with (errorR := RegZero) (errorP := 1%positive) (extra_reg:=extra_reg)
+ by (cbv [R]; auto with omega).
+ cbv [ProdEquiv.interp256].
+ cbv [montred256_alloc montred256_prefancy].
+
+ step start_context; [ reflexivity | ].
+ step start_context; [ reflexivity | ].
+ step start_context; [ reflexivity | ].
+ step start_context; [ reflexivity | reflexivity | ].
+ step start_context; [ reflexivity | reflexivity | ].
+ step start_context; [ reflexivity | ].
+ step start_context; [ reflexivity | ].
+ step start_context; [ reflexivity | ].
+ step start_context; [ reflexivity | ].
+ step start_context; [ reflexivity | reflexivity | ].
+ step start_context; [ reflexivity | reflexivity | ].
+ step start_context; [ reflexivity | reflexivity | ].
+ step start_context; [ reflexivity | reflexivity | ].
+
+ step start_context; [ reflexivity | reflexivity | ].
+ step start_context; [ reflexivity | reflexivity | ].
+ step start_context; [ break_innermost_match; Z.ltb_to_lt; omega | ].
+ step start_context; [ reflexivity | | ].
+ {
+ let r := eval cbv in (2^256) in replace (2^256) with r by reflexivity.
+ rewrite !Z.shiftl_0_r, !Z.mod_mod by omega.
+ apply testbit_neg_eq_if;
+ let r := eval cbv in (2^256) in replace (2^256) with r by reflexivity;
+ auto using Z.mod_pos_bound with omega. }
+ step start_context; [ break_innermost_match; Z.ltb_to_lt; omega | ].
+ reflexivity.
+ Qed.
+
Import PrintingNotations.
Set Printing Width 10000.
@@ -10963,15 +12047,15 @@ montred256 = fun var : type -> Type => (λ x : var (type.type_primitive type.Z *
expr_let x0 := 79228162514264337593543950337 *₂₅₆ (uint128)(x₁ >> 128) in
expr_let x1 := 340282366841710300986003757985643364352 *₂₅₆ ((uint128)(x₁) & 340282366920938463463374607431768211455) in
expr_let x2 := 79228162514264337593543950337 *₂₅₆ ((uint128)(x₁) & 340282366920938463463374607431768211455) in
- expr_let x3 := ADD_256 ((uint256)(((uint128)(x0) & 340282366920938463463374607431768211455) << 128), x2) in
- expr_let x4 := ADD_256 ((uint256)(((uint128)(x1) & 340282366920938463463374607431768211455) << 128), x3₁) in
- expr_let x5 := 340282366841710300967557013911933812736 *₂₅₆ (uint128)(x4₁ >> 128) in
+ expr_let x3 := ADD_256 ((uint256)(((uint128)(x1) & 340282366920938463463374607431768211455) << 128), x2) in
+ expr_let x4 := ADD_256 ((uint256)(((uint128)(x0) & 340282366920938463463374607431768211455) << 128), x3₁) in
+ expr_let x5 := 79228162514264337593543950335 *₂₅₆ ((uint128)(x4₁) & 340282366920938463463374607431768211455) in
expr_let x6 := 79228162514264337593543950335 *₂₅₆ (uint128)(x4₁ >> 128) in
expr_let x7 := 340282366841710300967557013911933812736 *₂₅₆ ((uint128)(x4₁) & 340282366920938463463374607431768211455) in
- expr_let x8 := 79228162514264337593543950335 *₂₅₆ ((uint128)(x4₁) & 340282366920938463463374607431768211455) in
- expr_let x9 := ADD_256 ((uint256)(((uint128)(x6) & 340282366920938463463374607431768211455) << 128), x8) in
- expr_let x10 := ADDC_256 (x9₂, x5, (uint128)(x7 >> 128)) in
- expr_let x11 := ADD_256 ((uint256)(((uint128)(x7) & 340282366920938463463374607431768211455) << 128), x9₁) in
+ expr_let x8 := 340282366841710300967557013911933812736 *₂₅₆ (uint128)(x4₁ >> 128) in
+ expr_let x9 := ADD_256 ((uint256)(((uint128)(x7) & 340282366920938463463374607431768211455) << 128), x5) in
+ expr_let x10 := ADDC_256 (x9₂, (uint128)(x7 >> 128), x8) in
+ expr_let x11 := ADD_256 ((uint256)(((uint128)(x6) & 340282366920938463463374607431768211455) << 128), x9₁) in
expr_let x12 := ADDC_256 (x11₂, (uint128)(x6 >> 128), x10₁) in
expr_let x13 := ADD_256 (x11₁, x₁) in
expr_let x14 := ADDC_256 (x13₂, x12₁, x₂) in
@@ -10989,23 +12073,24 @@ montred256 = fun var : type -> Type => (λ x : var (type.type_primitive type.Z *
(*
mulhl@(y0, RegPInv, $x₁);
mulll@(y1, RegPInv, $x₁);
- add@(y2, $y1, $y, 128);
- add@(y3, $y2, $y0, 128);
- mulhh@(y4, RegMod, $y3);
+ add@(y2, $y1, $y0, 128);
+ add@(y3, $y2, $y, 128);
+ mulll@(y4, RegMod, $y3);
mullh@(y5, RegMod, $y3);
mulhl@(y6, RegMod, $y3);
- mulll@(y7, RegMod, $y3);
- add@(y8, $y7, $y5, 128);
- addc@(y9, carry{$y8}, $y4, $y6, -128);
- add@(y10, $y8, $y6, 128);
+ mulhh@(y7, RegMod, $y3);
+ add@(y8, $y4, $y6, 128);
+ addc@(y9, carry{$y8}, $y7, $y6, -128);
+ add@(y10, $y8, $y5, 128);
addc@(y11, carry{$y10}, $y9, $y5, -128);
add@(y12, $y10, $x₁, 0);
addc@(y13, carry{$y12}, $y11, $x₂, 0);
selc@(y14, carry{$y13}, RegZero, RegMod);
sub@(y15, $y13, $y14, 0);
addm@(y16, $y15, RegZero, RegMod);
- ret $y16
+ ret $y16
*)
+
End Montgomery256.
(* Extra-specialized ad-hoc pretty-printing *)
@@ -11139,6 +12224,39 @@ Module FancyPrintingNotations.
*)
End FancyPrintingNotations.
+
+Local Notation "i rd x y ; cont" := (Fancy.Instr i rd (x, y) cont) (at level 40, cont at level 200, format "i rd x y ; '//' cont").
+Local Notation "i rd x y z ; cont" := (Fancy.Instr i rd (x, y, z) cont) (at level 40, cont at level 200, format "i rd x y z ; '//' cont").
+
+Import Fancy.Registers.
+Import Fancy.
+Eval cbv beta iota delta [Prod.MontRed256 Prod.Mul256 Prod.Mul256x256] in Prod.MontRed256.
+(*
+ = fun lo hi y t1 t2 scratch RegPInv : register =>
+ MUL128LL y lo RegPInv;
+ MUL128UL t1 lo RegPInv;
+ ADD 128 y y t1;
+ MUL128LU t1 lo RegPInv;
+ ADD 128 y y t1;
+ MUL128LL t1 y RegMod;
+ MUL128UU t2 y RegMod;
+ MUL128UL scratch y RegMod;
+ ADD 128 t1 t1 scratch;
+ ADDC (-128) t2 t2 scratch;
+ MUL128LU scratch y RegMod;
+ ADD 128 t1 t1 scratch;
+ ADDC (-128) t2 t2 scratch;
+ ADD 0 lo lo t1;
+ ADDC 0 hi hi t2;
+ SELC y RegMod RegZero;
+ SUB 0 lo hi y;
+ ADDM lo lo RegZero RegMod;
+ Ret lo
+ *)
+Import Montgomery256.
+Check Montgomery256.prod_montred256_correct.
+(* Print Assumptions Montgomery256.prod_montred256_correct. *)
+
Import FancyPrintingNotations.
Local Open Scope expr_scope.
@@ -11177,26 +12295,4 @@ expr_let x25 := Z.add_with_get_carry_concrete
c.Sell($x26, $x25_lo, RegZero, RegMod);
c.Sub($x27, $x24_lo, $x26);
c.AddM($ret, $x27, RegZero, RegMod);
-*)
-
-Print Montgomery256.montred256.
-(*
-c.Mul128x128($x0, c.Lower(RegPinv), c.Upper($x_lo));
-c.Mul128x128($x1, c.Upper(RegPinv), c.Lower($x_lo));
-c.Mul128x128($x2, c.Lower(RegPinv), c.Lower($x_lo));
-c.Add256($x3, (c.Lower($x0) << 128), $x2);
-c.Add256($x4, (c.Lower($x1) << 128), $x3_lo);
-c.Mul128x128($x5, c.Upper(RegMod), c.Upper($x4_lo));
-c.Mul128x128($x6, c.Lower(RegMod), c.Upper($x4_lo));
-c.Mul128x128($x7, c.Upper(RegMod), c.Lower($x4_lo));
-c.Mul128x128($x8, c.Lower(RegMod), c.Lower($x4_lo));
-c.Add256($x9, (c.Lower($x6) << 128), $x8);
-c.Addc256($x10, $x9_hi, $x5, c.Upper($x7));
-c.Add256($x11, (c.Lower($x7) << 128), $x9_lo);
-c.Addc256($x12, $x11_hi, c.Upper($x6), $x10_lo);
-c.Add256($x13, $x11_lo, $x_lo);
-c.Addc256($x14, $x13_hi, $x12_lo, $x_hi);
-c.Selc($x15, $x14_hi, RegZero, RegMod);
-c.Sub($x16, $x14_lo, $x15);
-c.AddM($ret, $x16_lo, RegZero, RegMod);
- *) \ No newline at end of file
+ *)