From 6f2493f77f61b3922f3bc01ce3ea613f2a70230c Mon Sep 17 00:00:00 2001 From: Jade Philipoom Date: Wed, 30 May 2018 15:30:04 +0200 Subject: Define machine model, write prefancy->fancy pass, and prove Montgomery code correct --- src/Experiments/SimplyTypedArithmetic.v | 1372 +++++++++++++++++++++++++++---- 1 file changed, 1234 insertions(+), 138 deletions(-) (limited to 'src') diff --git a/src/Experiments/SimplyTypedArithmetic.v b/src/Experiments/SimplyTypedArithmetic.v index 814161335..76a885d70 100644 --- a/src/Experiments/SimplyTypedArithmetic.v +++ b/src/Experiments/SimplyTypedArithmetic.v @@ -126,6 +126,9 @@ Module Associational. push; [|rewrite IHp]; reflexivity. Qed. + Lemma eval_rev p : eval (rev p) = eval p. + Proof. induction p; cbn [rev]; push; lia. Qed. + Section Carries. Definition carryterm (w fw:Z) (t:Z * Z) := if (Z.eqb (fst t) w) @@ -1758,6 +1761,7 @@ Module BaseConversion. @Associational.eval_carry @Associational.eval_mul @Positional.eval_to_associational + Associational.eval_carryterm @eval_convert_bases using solve [auto using Z.positive_is_nonzero] : push_eval. Ltac push_eval := intros; autorewrite with push_eval; auto with zarith. @@ -1767,14 +1771,30 @@ Module BaseConversion. let p' := convert_bases n m p in Positional.to_associational dw m p'. + (* TODO : move to Associational? *) + Section reorder. + Definition reordering_carry (w fw : Z) (p : list (Z * Z)) := + fold_right (fun t acc => + let r := Associational.carryterm w fw t in + if fst t =? w then acc ++ r else r ++ acc) nil p. + + Lemma eval_reordering_carry w fw p (_:fw<>0): + Associational.eval (reordering_carry w fw p) = Associational.eval p. + Proof. + cbv [reordering_carry]. induction p; [reflexivity |]. + autorewrite with push_fold_right. break_match; push_eval. + Qed. + End reorder. + Hint Rewrite eval_reordering_carry using solve [auto using Z.positive_is_nonzero] : push_eval. + (* carry at specified indices in dw, then use Rows.flatten to convert to Positional with sw *) Definition from_associational idxs n (p : list (Z * Z)) : list Z := (* important not to use Positional.carry here; we don't want to accumulate yet *) - let p' := fold_right (fun i acc => Associational.carry (dw i) (dw (S i) / dw i) acc) (Associational.bind_snd p) (rev idxs) in + let p' := fold_right (fun i acc => reordering_carry (dw i) (dw (S i) / dw i) acc) (Associational.bind_snd p) (rev idxs) in fst (Rows.flatten sw n (Rows.from_associational sw n p')). Lemma eval_carries p idxs : - Associational.eval (fold_right (fun i acc => Associational.carry (dw i) (dw (S i) / dw i) acc) p idxs) = + Associational.eval (fold_right (fun i acc => reordering_carry (dw i) (dw (S i) / dw i) acc) p idxs) = Associational.eval p. Proof. apply fold_right_invariant; push_eval. Qed. Hint Rewrite eval_carries: push_eval. @@ -1821,7 +1841,7 @@ Module BaseConversion. As from_associational_inlined_correct. Proof. intros. - cbv beta iota delta [from_associational Associational.carry Associational.carryterm]. + cbv beta iota delta [from_associational reordering_carry Associational.carryterm]. cbv beta iota delta [Let_In]. (* inlines all shifts/lands from carryterm *) cbv beta iota delta [from_associational Rows.from_associational Columns.from_associational]. cbv beta iota delta [Let_In]. (* inlines the shifts from place *) @@ -1908,9 +1928,6 @@ Module BaseConversion. Z.rewrite_mod_small. reflexivity. Qed. - (* For some reason, this is a universe inconsistency if not factored out *) - Lemma nout_nonzero : nout <> 0%nat. Proof. omega. Qed. - Derive widemul_inlined SuchThat (forall a b, 0 <= a * b < 2^log2base * 2^log2base -> @@ -1922,9 +1939,33 @@ Module BaseConversion. cbv beta iota delta [widemul mul_converted]. rewrite <-to_associational_inlined_correct with (p:=[a]). rewrite <-to_associational_inlined_correct with (p:=[b]). - rewrite <-from_associational_inlined_correct by (apply nout_nonzero || assumption). + rewrite <-from_associational_inlined_correct. subst widemul_inlined; reflexivity. Qed. + + Derive widemul_inlined_reverse + SuchThat (forall a b, + 0 <= a * b < 2^log2base * 2^log2base -> + widemul_inlined_reverse a b = [(a * b) mod 2^log2base; (a * b) / 2^log2base]) + As widemul_inlined_reverse_correct. + Proof. + intros. + rewrite <-widemul_inlined_correct by assumption. + cbv [widemul_inlined]. + match goal with |- _ = from_associational_inlined sw dw ?idxs ?n ?p => + transitivity (from_associational_inlined sw dw idxs n (rev p)); + [ | transitivity (from_associational sw dw idxs n p); [ | reflexivity ] ](* reverse to make addc chains line up *) + end. + Focus 2. { + rewrite from_associational_inlined_correct by (subst nout; auto). + cbv [from_associational]. + rewrite !Rows.flatten_partitions' by eauto using Rows.length_from_associational. + rewrite !Rows.eval_from_associational by (subst nout; auto). + f_equal. + rewrite !eval_carries, !Associational.bind_snd_correct, !Associational.eval_rev by auto. + reflexivity. } Unfocus. + subst widemul_inlined_reverse; reflexivity. + Qed. End widemul. End BaseConversion. @@ -8187,8 +8228,8 @@ Module Straightline. Section interp. Context {ident : type -> type -> Type} {interp_ident : forall s d, ident s d -> type.interp s -> type.interp d}. + Context {interp_cast : zrange -> Z -> Z}. - Definition interp_cast (r : zrange) (x : Z) : Z := ident.cast ident.cast_outside_of_range r x. Definition interp_cast2 (r : zrange * zrange) (x : Z * Z) : Z * Z := (interp_cast (fst r) (fst x), interp_cast (snd r) (snd x)). @@ -8220,9 +8261,10 @@ Module Straightline. End interp. Section proofs. - Local Notation straightline_interp := (expr.interp (ident:=default.ident) (interp_ident:=@ident.interp)). + Local Notation straightline_interp := (expr.interp (ident:=default.ident) (interp_ident:=@ident.interp) (interp_cast:=ident.cast (@ident.cast_outside_of_range))). Local Notation uinterp := (Uncurried.expr.interp (@ident.interp)). Local Notation uexpr := (@Uncurried.expr.expr ident type.interp). + Local Notation interp_scalar := (interp_scalar (interp_cast:=ident.cast (@ident.cast_outside_of_range))). Inductive ok_scalar_ident : forall {s d}, ident.ident s d -> Prop := | ok_si_cast : forall r, ok_scalar_ident (ident.Z.cast r) @@ -8273,11 +8315,11 @@ Module Straightline. . Lemma interp_cast_correct r (x : uexpr type.Z) : - interp_cast r (uinterp x) = uinterp (AppIdent (ident.Z.cast r) x). + ident.cast ident.cast_outside_of_range r (uinterp x) = uinterp (AppIdent (ident.Z.cast r) x). Proof. reflexivity. Qed. Lemma interp_cast2_correct r (x : uexpr (type.prod type.Z type.Z)) : - interp_cast2 r (uinterp x) = uinterp (AppIdent (ident.Z.cast2 r) x). + @interp_cast2 (ident.cast ident.cast_outside_of_range) r (uinterp x) = uinterp (AppIdent (ident.Z.cast2 r) x). Proof. cbn; break_match; reflexivity. Qed. Ltac invert H := @@ -8542,6 +8584,28 @@ Module PreFancy. apply BinInt.Z.mul_div_le; omega. } Qed. + Lemma wordmax_half_bits_le_wordmax : wordmax_half_bits <= wordmax. + Proof. + subst wordmax half_bits wordmax_half_bits. + apply Z.pow_le_mono_r; [lia|]. + apply Z.div_le_upper_bound; lia. + Qed. + + Lemma ones_half_bits : wordmax_half_bits - 1 = Z.ones half_bits. + Proof. + subst wordmax_half_bits. cbv [Z.ones]. + rewrite Z.shiftl_mul_pow2, <-Z.sub_1_r by auto using half_bits_nonneg. + lia. + Qed. + + Lemma wordmax_half_bits_squared : wordmax_half_bits * wordmax_half_bits = wordmax. + Proof. + subst wordmax half_bits wordmax_half_bits. + rewrite <-Z.pow_add_r by Z.zero_bounds. + rewrite Z.add_diag, Z.mul_div_eq by omega. + f_equal; lia. + Qed. + Section with_var. Context {var : type -> Type} (dummy_arrow : forall s d, var (type.arrow s d)) (consts : list Z). Local Notation Z := (type.type_primitive type.Z). @@ -8580,7 +8644,7 @@ Module PreFancy. option (@scalar var Z) := match e in scalar t return option (@scalar var Z) with | Cast r (Land n x) => - if (lower r =? 0) && (upper r =? (wordmax_half_bits - 1)) && (n =? 2^half_bits-1) + if (lower r =? 0) && (upper r =? (wordmax_half_bits - 1)) && (n =? wordmax_half_bits-1) then Some x else None | _ => None @@ -8764,6 +8828,9 @@ Module PreFancy. End with_var. Section interp. + Context {interp_cast : zrange -> Z -> Z}. + Local Notation interp_scalar := (interp_scalar (interp_cast:=interp_cast)). + Local Notation interp_cast2 := (interp_cast2 (interp_cast:=interp_cast)). Local Notation low x := (Z.land x (wordmax_half_bits - 1)). Local Notation high x := (x >> half_bits). Local Notation shift x imm := ((x << imm) mod wordmax). @@ -8797,6 +8864,9 @@ Module PreFancy. Section proofs. Context (dummy_arrow : forall s d, type.interp (s -> d)%ctype) (consts : list Z) (consts_ok : forall x, In x consts -> 0 <= x <= wordmax - 1). + Context {interp_cast : zrange -> Z -> Z} {interp_cast_correct : forall r x, lower r <= x <= upper r -> interp_cast r x = x}. + Local Notation interp_scalar := (interp_scalar (interp_cast:=interp_cast)). + Local Notation interp_cast2 := (interp_cast2 (interp_cast:=interp_cast)). Local Notation word_range := (r[0~>wordmax-1])%zrange. Local Notation half_word_range := (r[0~>wordmax_half_bits-1])%zrange. @@ -8979,7 +9049,7 @@ Module PreFancy. . Inductive ok_expr : forall {t}, @expr type.interp ident.ident t -> Prop := - | ok_of_scalar : forall t s, @ok_expr t (Scalar s) + | ok_of_scalar : forall t s, ok_scalar s -> @ok_expr t (Scalar s) | ok_letin_z : forall s d r idc x f, ok_ident _ type.Z x r idc -> ok_scalar x -> @@ -9018,23 +9088,15 @@ Module PreFancy. Lemma interp_cast_noop x r : @has_range type.Z r x -> interp_cast r x = x. - Proof. - cbv [has_range interp_cast ident.cast]; intros. - break_match; - try match goal with H : _ && _ = true |- _ => rewrite andb_true_iff in H; destruct H end; - try match goal with H : _ && _ = false |- _ => rewrite andb_false_iff in H; destruct H end; - Z.ltb_to_lt; omega. - Qed. + Proof. cbv [has_range]; intros; auto. Qed. Lemma interp_cast2_noop x r : @has_range (type.prod type.Z type.Z) r x -> interp_cast2 r x = x. Proof. - cbv [has_range interp_cast2 interp_cast ident.cast]; intros. - break_match; destruct x; - repeat match goal with H : _ && _ = true |- _ => rewrite andb_true_iff in H; destruct H end; - repeat match goal with H : _ && _ = false |- _ => rewrite andb_false_iff in H; destruct H end; - Z.ltb_to_lt; try omega; reflexivity. + cbv [has_range interp_cast2]; intros. + rewrite !interp_cast_correct by tauto. + destruct x; reflexivity. Qed. Lemma has_range_shiftr n (x : scalar type.Z) : @@ -9440,8 +9502,8 @@ Module PreFancy. is_halved y -> ok_scalar (Pair x y) -> @has_range type.Z word_range (ident.interp ident.Z.mul (interp_scalar (Pair x y))) -> - interp (of_straightline_ident dummy_arrow consts ident.Z.mul t word_range (Pair x y) g) = - interp (g (ident.interp ident.Z.mul (interp_scalar (Pair x y)))). + @interp interp_cast _ (of_straightline_ident dummy_arrow consts ident.Z.mul t word_range (Pair x y) g) = + @interp interp_cast _ (g (ident.interp ident.Z.mul (interp_scalar (Pair x y)))). Proof. intros Hx Hy Hok ?; invert Hok; cbn [interp_scalar of_straightline_ident]. destruct (is_halved_cases x Hx ltac:(assumption)) as [ [? [Pxlow [Pxhigh Pxi] ] ] | [? [Pxlow [Pxhigh Pxi] ] ] ]; @@ -9459,12 +9521,39 @@ Module PreFancy. intros. apply Z.mod_small; omega. Qed. + Lemma half_word_range_le_word_range r : + upper r = wordmax_half_bits - 1 -> + lower r = 0 -> + (r <=? word_range)%zrange = true. + Proof. + pose proof wordmax_half_bits_le_wordmax. + destruct r; cbv [is_tighter_than_bool ZRange.lower ZRange.upper]. + intros; subst. + apply andb_true_iff; split; Z.ltb_to_lt; lia. + Qed. + + Lemma and_shiftl_half_bits_eq x : + (x &' (wordmax_half_bits - 1)) << half_bits = x << half_bits mod wordmax. + Proof. + rewrite ones_half_bits. + rewrite Z.land_ones, !Z.shiftl_mul_pow2 by auto using half_bits_nonneg. + rewrite <-wordmax_half_bits_squared. + subst wordmax_half_bits. + rewrite Z.mul_mod_distr_r_full. + reflexivity. + Qed. + + Lemma in_word_range_word_range : in_word_range word_range. + Proof. + cbv [in_word_range is_tighter_than_bool]. + rewrite !Z.leb_refl; reflexivity. + Qed. + Lemma invert_shift_correct (s : scalar type.Z) x imm : ok_scalar s -> invert_shift consts s = Some (x, imm) -> interp_scalar s = (interp_scalar x << imm) mod wordmax. Proof. - (* intros Hok ?; invert Hok; try match goal with H : ok_scalar ?x, H' : context[Cast _ ?x] |- _ => invert H end; @@ -9472,44 +9561,34 @@ Module PreFancy. invert H end; try match goal with H : ok_scalar ?x, H' : context[Shiftl _ (Cast _ ?x)] |- _ => invert H end; - try (cbn [invert_shift invert_upper invert_upper'] in *; discriminate). - { + try (cbn [invert_shift invert_upper invert_upper'] in *; discriminate); repeat match goal with - | _ => progress (cbn [invert_shift invert_upper invert_upper' - interp_scalar fst snd] in * ) - | _ => rewrite interp_cast_noop by eauto using has_range_loosen - | H : ok_scalar (Shiftr _ _) |- _ => apply has_range_interp_scalar in H - | H : context [if ?x then _ else _] |- _ => - let Heq := fresh in case_eq x; intro Heq; rewrite Heq in H - | H : _ |- _ => rewrite andb_true_iff in H; destruct H; Z.ltb_to_lt - | H : Some _ = Some _ |- _ => progress (invert H) - | _ => progress subst - | _ => reflexivity - | _ => discriminate - end. - rewrite has_word_range_mod_small. - 2:eauto using has_range_loosen. - } - { - repeat match goal with - | _ => progress (cbn [invert_shift invert_upper invert_upper' - invert_lower invert_lower' - interp_scalar fst snd] in * ) - | _ => rewrite interp_cast_noop by eauto using has_range_loosen + | _ => progress (cbn [invert_shift invert_lower invert_lower' invert_upper invert_upper' interp_scalar fst snd] in * ) + | _ => rewrite interp_cast_noop by eauto using has_half_word_range_land, has_half_word_range_shiftr, in_word_range_word_range, has_range_loosen | H : ok_scalar (Shiftr _ _) |- _ => apply has_range_interp_scalar in H | H : ok_scalar (Shiftl _ _) |- _ => apply has_range_interp_scalar in H | H : ok_scalar (Land _ _) |- _ => apply has_range_interp_scalar in H | H : context [if ?x then _ else _] |- _ => let Heq := fresh in case_eq x; intro Heq; rewrite Heq in H + | H : context [match @constant_to_scalar ?v ?consts ?x with _ => _ end] |- _ => + let Heq := fresh in + case_eq (@constant_to_scalar v consts x); intros until 0; intro Heq; rewrite Heq in *; [|discriminate]; + destruct (constant_to_scalar_cases _ _ Heq) as [ [? [? ?] ] | [? [? ?] ] ]; subst; + pose proof (ok_scalar_constant_to_scalar _ _ Heq) + | H : constant_to_scalar _ _ = Some _ |- _ => erewrite <-(constant_to_scalar_correct _ _ H) | H : _ |- _ => rewrite andb_true_iff in H; destruct H; Z.ltb_to_lt | H : Some _ = Some _ |- _ => progress (invert H) + | _ => rewrite has_word_range_mod_small by eauto using has_range_loosen, half_word_range_le_word_range + | _ => rewrite has_word_range_mod_small by + (eapply has_range_loosen with (r1:=half_word_range); + [ eapply has_half_word_range_shiftr with (r:=word_range) | ]; + eauto using in_word_range_word_range, half_word_range_le_word_range) + | _ => rewrite and_shiftl_half_bits_eq | _ => progress subst | _ => reflexivity | _ => discriminate end. Qed. - *) - Admitted. Local Ltac solve_commutative_replace := match goal with @@ -9521,8 +9600,8 @@ Module PreFancy. Lemma of_straightline_ident_correct s d t x r (idc : ident.ident s d) g : ok_ident s d x r idc -> ok_scalar x -> - interp (of_straightline_ident dummy_arrow consts idc t r x g) = - interp (g (ident.interp idc (interp_scalar x))). + @interp interp_cast _ (of_straightline_ident dummy_arrow consts idc t r x g) = + @interp interp_cast _ (g (ident.interp idc (interp_scalar x))). Proof. intros. pose proof wordmax_half_bits_pos. @@ -9559,7 +9638,8 @@ Module PreFancy. Lemma of_straightline_correct {t} (e : expr t) : ok_expr e -> - interp (of_straightline dummy_arrow consts e) = Straightline.expr.interp (interp_ident:=@ident.interp) e. + @interp interp_cast _ (of_straightline dummy_arrow consts e) + = Straightline.expr.interp (interp_ident:=@ident.interp) (interp_cast:=interp_cast) e. Proof. induction 1; cbn [of_straightline]; intros; repeat match goal with @@ -9572,12 +9652,79 @@ Module PreFancy. end. Qed. End proofs. + + Section no_interp_cast. + Context (dummy_arrow : forall s d, type.interp (s -> d)%ctype) (consts : list Z) + (consts_ok : forall x, In x consts -> 0 <= x <= wordmax - 1). + + Local Arguments interp _ {_} _. + Local Arguments interp_scalar _ {_} _. + + Local Ltac tighter_than_to_le := + repeat match goal with + | _ => progress (cbv [is_tighter_than_bool] in * ) + | _ => rewrite andb_true_iff in * + | H : _ /\ _ |- _ => destruct H + end; Z.ltb_to_lt. + + Lemma replace_interp_cast_scalar {t} (x : scalar t) interp_cast interp_cast' + (interp_cast_correct : forall r x, lower r <= x <= upper r -> interp_cast r x = x) + (interp_cast'_correct : forall r x, lower r <= x <= upper r -> interp_cast' r x = x) : + ok_scalar x -> + interp_scalar interp_cast x = interp_scalar interp_cast' x. + Proof. + induction 1; cbn [interp_scalar Straightline.expr.interp_scalar]; + repeat match goal with + | _ => progress (cbv [has_range interp_cast2] in * ) + | _ => progress tighter_than_to_le + | H : ok_scalar _ |- _ => apply (has_range_interp_scalar (interp_cast_correct:=interp_cast_correct)) in H + | _ => rewrite <-IHok_scalar + | _ => rewrite interp_cast_correct by omega + | _ => rewrite interp_cast'_correct by omega + | _ => congruence + end. + Qed. + + Lemma replace_interp_cast {t} (e : expr t) interp_cast interp_cast' + (interp_cast_correct : forall r x, lower r <= x <= upper r -> interp_cast r x = x) + (interp_cast'_correct : forall r x, lower r <= x <= upper r -> interp_cast' r x = x) : + ok_expr consts e -> + interp interp_cast (of_straightline dummy_arrow consts e) = + interp interp_cast' (of_straightline dummy_arrow consts e). + Proof. + induction 1; intros; cbn [of_straightline interp]. + { apply replace_interp_cast_scalar; auto. } + { rewrite !of_straightline_ident_correct by auto. + rewrite replace_interp_cast_scalar with (interp_cast'0:=interp_cast') by auto. + eauto using ident_interp_has_range. } + { rewrite !of_straightline_ident_correct by auto. + rewrite replace_interp_cast_scalar with (interp_cast'0:=interp_cast') by auto. + eauto using ident_interp_has_range. } + Qed. + End no_interp_cast. End with_wordmax. Definition of_Expr {s d} (log2wordmax : Z) (consts : list Z) (e : Expr (s -> d)) (var : type -> Type) (x : var s) dummy_arrow : @Straightline.expr.expr var ident d := @of_straightline log2wordmax var dummy_arrow consts _ (Straightline.of_Expr e var x dummy_arrow). + Definition interp_cast_mod w r x := if (lower r =? 0) + then if (upper r =? 2^w - 1) + then x mod (2^w) + else if (upper r =? 1) + then x mod 2 + else x + else x. + + Lemma interp_cast_mod_correct w r x : + lower r <= x <= upper r -> + interp_cast_mod w r x = x. + Proof. + cbv [interp_cast_mod]. + intros; break_match; rewrite ?andb_true_iff in *; intuition; Z.ltb_to_lt; + apply Z.mod_small; omega. + Qed. + Lemma of_Expr_correct {s d} (log2wordmax : Z) (consts : list Z) (e : Expr (s -> d)) (e' : (type.interp s -> Uncurried.expr.expr d)) (x : type.interp s) dummy_arrow : @@ -9589,11 +9736,14 @@ Module PreFancy. ok_expr log2wordmax consts (of_uncurried (dummy_arrow:=dummy_arrow) (depth (fun _ : type => unit) (fun _ : type => tt) (e _)) (e' x)) -> (depth type.interp (@DefaultValue.type.default) (e' x) <= depth (fun _ : type => unit) (fun _ : type => tt) (e _))%nat -> - interp log2wordmax (of_Expr log2wordmax consts e type.interp x dummy_arrow) = @Uncurried.expr.interp _ (@ident.interp) _ (e type.interp) x. + @interp log2wordmax (interp_cast_mod log2wordmax) _ (of_Expr log2wordmax consts e type.interp x dummy_arrow) = @Uncurried.expr.interp _ (@ident.interp) _ (e type.interp) x. Proof. intro He'; intros; cbv [of_Expr Straightline.of_Expr]. rewrite He'; cbn [invert_Abs expr.interp]. - rewrite of_straightline_correct by assumption. + assert (forall r z, lower r <= z <= upper r -> ident.cast ident.cast_outside_of_range r z = z) as interp_cast_correct. + { cbv [ident.cast]; intros; break_match; rewrite ?andb_true_iff, ?andb_false_iff in *; intuition; Z.ltb_to_lt; omega. } + erewrite replace_interp_cast with (interp_cast':=ident.cast ident.cast_outside_of_range) by auto using interp_cast_mod_correct. + rewrite of_straightline_correct by auto. erewrite Straightline.expr.of_uncurried_correct by eassumption. reflexivity. Qed. @@ -9651,6 +9801,661 @@ Module PreFancy. End Notations. End PreFancy. +Module Fancy. + Import Straightline.expr. + + Module CC. + Inductive code : Type := + | C : code + | M : code + | L : code + | Z : code + . + + Record state := + { cc_c : bool; cc_m : bool; cc_l : bool; cc_z : bool }. + + Definition code_dec (x y : code) : {x = y} + {x <> y}. + Proof. destruct x, y; try apply (left eq_refl); right; congruence. Defined. + + Definition update (to_write : list code) (result : BinInt.Z) (cc_spec : code -> BinInt.Z -> bool) (old_state : state) + : state := + {| + cc_c := if (In_dec code_dec C to_write) + then cc_spec C result + else old_state.(cc_c); + cc_m := if (In_dec code_dec M to_write) + then cc_spec M result + else old_state.(cc_m); + cc_l := if (In_dec code_dec L to_write) + then cc_spec L result + else old_state.(cc_l); + cc_z := if (In_dec code_dec Z to_write) + then cc_spec Z result + else old_state.(cc_z) + |}. + + End CC. + + Record instruction := + { + num_source_regs : nat; + writes_conditions : list CC.code; + spec : tuple Z num_source_regs -> CC.state -> Z + }. + + Section expr. + Context {name : Type} (name_eqb : name -> name -> bool) (wordmax : Z) (cc_spec : CC.code -> Z -> bool). + + Inductive expr := + | Ret : name -> expr + | Instr (i : instruction) + (rd : name) (* destination register *) + (args : tuple name i.(num_source_regs)) (* source registers *) + (cont : expr) (* next line *) + : expr + . + + Fixpoint interp (e : expr) (cc : CC.state) (ctx : name -> Z) : Z := + match e with + | Ret n => ctx n + | Instr i rd args cont => + let result := i.(spec) (Tuple.map ctx args) cc in + let new_cc := CC.update i.(writes_conditions) result cc_spec cc in + let new_ctx := (fun n : name => if name_eqb n rd then result mod wordmax else ctx n) in + interp cont new_cc new_ctx + end. + End expr. + + Section ISA. + Import CC. + + (* For the C flag, we have to consider cases with a negative result (like the one returned by an underflowing borrow). + In these cases, we want to set the C flag to true. *) + Definition cc_spec (x : CC.code) (result : BinInt.Z) : bool := + match x with + | CC.C => if result Z.testbit result 255 + | CC.L => Z.testbit result 0 + | CC.Z => result =? 0 + end. + + Local Definition lower128 x := (Z.land x (Z.ones 128)). + Local Definition upper128 x := (Z.shiftr x 128). + Local Notation "x '[C]'" := (if x.(cc_c) then 1 else 0) (at level 20). + Local Notation "x '[M]'" := (if x.(cc_m) then 1 else 0) (at level 20). + Local Notation "x '[L]'" := (if x.(cc_l) then 1 else 0) (at level 20). + Local Notation "x '[Z]'" := (if x.(cc_z) then 1 else 0) (at level 20). + Local Notation "'int'" := (BinInt.Z). + Local Notation "x << y" := ((x << y) mod (2^256)) : Z_scope. (* truncating left shift *) + + + (* Note: In the specification document, argument order gets a bit + confusing. Like here, r0 is always the first argument "source 0" + and r1 the second. But the specification of MUL128LU is: + (R[RS1][127:0] * R[RS0][255:128]) + + while the specification of SUB is: + (R[RS0] - shift(R[RS1], imm)) + + In the SUB case, r0 is really treated the first argument, but in + MUL128LU the order seems to be reversed; rather than low-high, we + take the high part of the first argument r0 and the low parts of + r1. This is also true for MUL128UL. *) + + Definition ADD (imm : int) : instruction := + {| + num_source_regs := 2; + writes_conditions := [C; M; L; Z]; + spec := (fun '(r0, r1) cc => + r0 + (r1 << imm)) + |}. + + Definition ADDC (imm : int) : instruction := + {| + num_source_regs := 2; + writes_conditions := [C; M; L; Z]; + spec := (fun '(r0, r1) cc => + r0 + (r1 << imm) + cc[C]) + |}. + + Definition SUB (imm : int) : instruction := + {| + num_source_regs := 2; + writes_conditions := [C; M; L; Z]; + spec := (fun '(r0, r1) cc => + r0 - (r1 << imm)) + |}. + + Definition MUL128LL : instruction := + {| + num_source_regs := 2; + writes_conditions := [M; L; Z]; + spec := (fun '(r0, r1) cc => + (lower128 r0) * (lower128 r1)) + |}. + + Definition MUL128LU : instruction := + {| + num_source_regs := 2; + writes_conditions := [M; L; Z]; + spec := (fun '(r0, r1) cc => + (lower128 r1) * (upper128 r0)) (* see note *) + |}. + + Definition MUL128UL : instruction := + {| + num_source_regs := 2; + writes_conditions := [M; L; Z]; + spec := (fun '(r0, r1) cc => + (upper128 r1) * (lower128 r0)) (* see note *) + |}. + + Definition MUL128UU : instruction := + {| + num_source_regs := 2; + writes_conditions := [M; L; Z]; + spec := (fun '(r0, r1) cc => + (upper128 r0) * (upper128 r1)) + |}. + + Definition RSHI (imm : int) : instruction := + {| + num_source_regs := 2; + writes_conditions := [M; L; Z]; + spec := (fun '(r0, r1) cc => + (r0 + (r1 << 256)) >> imm) + |}. + + Definition SELC : instruction := + {| + num_source_regs := 2; + writes_conditions := []; + spec := (fun '(r0, r1) cc => + if cc[C] =? 1 then r0 else r1) + |}. + + Definition SELM : instruction := + {| + num_source_regs := 2; + writes_conditions := []; + spec := (fun '(r0, r1) cc => + if cc[M] =? 1 then r0 else r1) + |}. + + Definition SELL : instruction := + {| + num_source_regs := 2; + writes_conditions := []; + spec := (fun '(r0, r1) cc => + if cc[L] =? 1 then r0 else r1) + |}. + + (* TODO : treat the MOD register specially, like CC *) + Definition ADDM : instruction := + {| + num_source_regs := 3; + writes_conditions := [M; L; Z]; + spec := (fun '(r0, r1, MOD) cc => + let ra := r0 + r1 in + if ra >=? MOD + then ra - MOD + else ra) + |}. + + End ISA. + + Module Registers. + Inductive register : Type := + | r0 : register + | r1 : register + | r2 : register + | r3 : register + | r4 : register + | r5 : register + | r6 : register + | r7 : register + | r8 : register + | r9 : register + | r10 : register + | r11 : register + | r12 : register + | r13 : register + | r14 : register + | r15 : register + | r16 : register + | r17 : register + | r18 : register + | r19 : register + | r20 : register + | r21 : register + | r22 : register + | r23 : register + | r24 : register + | r25 : register + | r26 : register + | r27 : register + | r28 : register + | r29 : register + | r30 : register + | RegZero : register (* r31 *) + | RegMod : register + . + + Definition reg_dec (x y : register) : {x = y} + {x <> y}. + Proof. destruct x, y; try (apply left; congruence); right; congruence. Defined. + Definition reg_eqb x y := if reg_dec x y then true else false. + + Lemma reg_eqb_neq x y : x <> y -> reg_eqb x y = false. + Proof. cbv [reg_eqb]; break_match; congruence. Qed. + Lemma reg_eqb_refl x : reg_eqb x x = true. + Proof. cbv [reg_eqb]; break_match; congruence. Qed. + End Registers. + + Section of_prefancy. + Context (name : Type) (name_succ : name -> name) (error : name) (consts : Z -> option name). + + Fixpoint var (t : type) : Type := + match t with + | type.type_primitive type.Z => name + | type.prod a b => var a * var b + | _ => unit + end. + + Fixpoint of_prefancy_scalar {t} (s : @scalar var t) : var t := + match s with + | Var t v => v + | Pair a b x y => (of_prefancy_scalar x, of_prefancy_scalar y) + | Cast r x => of_prefancy_scalar x + | Cast2 r x => of_prefancy_scalar x + | Fst a b x => fst (of_prefancy_scalar x) + | Snd a b x => snd (of_prefancy_scalar x) + | Shiftr n x => error + | Shiftl n x => error + | Land n x => error + | CC_m n x => error + | @Primitive _ type.Z x => match consts x with + | Some n => n + | None => error + end + | @Primitive _ _ x => tt + | TT => tt + | Nil _ => tt + end. + + (* Note : some argument orders are reversed for MUL128LU, MUL128UL, SELC, SELM, and SELL *) + Definition of_prefancy_ident {s d} (idc : PreFancy.ident s d) + : @scalar var s -> {i : instruction & tuple name i.(num_source_regs) } := + match idc in PreFancy.ident s d return _ with + | PreFancy.add imm => fun args : @scalar var (type.Z * type.Z) => + existT _ (ADD imm) (of_prefancy_scalar args) + | PreFancy.addc imm => fun args : @scalar var (type.Z * type.Z * type.Z) => + existT _ (ADDC imm) (of_prefancy_scalar (Pair (Snd (Fst args)) (Snd args))) + | PreFancy.sub imm => fun args : @scalar var (type.Z * type.Z) => + existT _ (SUB imm) (of_prefancy_scalar args) + | PreFancy.mulll => fun args : @scalar var (type.Z * type.Z) => + existT _ MUL128LL (of_prefancy_scalar args) + | PreFancy.mullh => fun args : @scalar var (type.Z * type.Z) => + existT _ MUL128LU (of_prefancy_scalar (Pair (Snd args) (Fst args))) + | PreFancy.mulhl => fun args : @scalar var (type.Z * type.Z) => + existT _ MUL128UL (of_prefancy_scalar (Pair (Snd args) (Fst args))) + | PreFancy.mulhh => fun args : @scalar var (type.Z * type.Z) => + existT _ MUL128UU (of_prefancy_scalar args) + | PreFancy.rshi imm => fun args : @scalar var (type.Z * type.Z) => + existT _ (RSHI imm) (of_prefancy_scalar args) + | PreFancy.selc => fun args : @scalar var (type.Z * type.Z * type.Z) => + existT _ SELC (of_prefancy_scalar (Pair (Snd args) (Snd (Fst args)))) + | PreFancy.selm => fun args : @scalar var (type.Z * type.Z * type.Z) => + existT _ SELM (of_prefancy_scalar (Pair (Snd args) (Snd (Fst args)))) + | PreFancy.sell => fun args : @scalar var (type.Z * type.Z * type.Z) => + existT _ SELL (of_prefancy_scalar (Pair (Snd args) (Snd (Fst args)))) + | PreFancy.addm => fun args : @scalar var (type.Z * type.Z * type.Z) => + existT _ ADDM (of_prefancy_scalar args) + end. + + Fixpoint of_prefancy (next_name : name) {t} (e : @Straightline.expr.expr var PreFancy.ident t) : expr := + match e with + | LetInAppIdentZ s d r idc x f => + let instr_args := @of_prefancy_ident s type.Z idc x in + let i : instruction := projT1 instr_args in + let args : tuple name i.(num_source_regs) := projT2 instr_args in + Instr i next_name args (of_prefancy (name_succ next_name) (f next_name)) + | LetInAppIdentZZ s d r idc x f => + let instr_args := @of_prefancy_ident s (type.Z * type.Z) idc x in + let i : instruction := projT1 instr_args in + let args : tuple name i.(num_source_regs) := projT2 instr_args in + Instr i next_name args (of_prefancy (name_succ next_name) (f (next_name, error))) (* we pass the error code as the carry register, because it cannot be read from directly. *) + | Scalar type.Z s => Ret (of_prefancy_scalar s) + | _ => Ret error + end. + End of_prefancy. + + Section allocate_registers. + Context (reg name : Type) (name_eqb : name -> name -> bool) (error : reg). + Fixpoint allocate (e : @expr name) (reg_list : list reg) (name_to_reg : name -> reg) : @expr reg := + match e with + | Ret n => Ret (name_to_reg n) + | Instr i rd args cont => + match reg_list with + | r :: reg_list' => Instr i r (Tuple.map name_to_reg args) (allocate cont reg_list' (fun n => if name_eqb n rd then r else name_to_reg n)) + | nil => Ret error + end + end. + End allocate_registers. + + Definition test_prog : @expr positive := + Instr (ADD (128)) 3%positive (1, 2)%positive + (Instr (ADDC 0) 4%positive (3,1)%positive + (Ret 4%positive)). + + Definition x1 := 2^256 - 1. + Definition x2 := 2^128 - 1. + Definition wordmax := 2^256. + Definition expected := + let r3' := (x1 + (x2 << 128)) in + let r3 := r3' mod wordmax in + let c := r3' / wordmax in + let r4' := (r3 + x1 + c) in + r4' mod wordmax. + Definition actual := + interp Pos.eqb + (2^256) cc_spec test_prog {|CC.cc_c:=false; CC.cc_m:=false; CC.cc_l:=false; CC.cc_z:=false|} + (fun n => if n =? 1%positive + then x1 + else if n =? 2%positive + then x2 + else 0). + Lemma test_prog_ok : expected = actual. + Proof. reflexivity. Qed. + + Definition of_Expr {s d} next_name (consts : Z -> option positive) (consts_list : list Z) (e : Expr (s -> d)) (x : var positive s) dummy_arrow : positive -> @expr positive := + fun error => + @of_prefancy positive Pos.succ error consts next_name _ (PreFancy.of_Expr 256 consts_list e (var positive) x dummy_arrow). + +End Fancy. + +Module Prod. + Import Fancy. Import Registers. + + Definition Mul256 (out src1 src2 tmp : register) (cont : Fancy.expr) : Fancy.expr := + Instr MUL128LL out (src1, src2) + (Instr MUL128UL tmp (src1, src2) + (Instr (ADD 128) out (out, tmp) + (Instr MUL128LU tmp (src1, src2) + (Instr (ADD 128) out (out, tmp) cont)))). + Definition Mul256x256 (out outHigh src1 src2 tmp : register) (cont : Fancy.expr) : Fancy.expr := + Instr MUL128LL out (src1, src2) + (Instr MUL128UU outHigh (src1, src2) + (Instr MUL128UL tmp (src1, src2) + (Instr (ADD 128) out (out, tmp) + (Instr (ADDC (-128)) outHigh (outHigh, tmp) + (Instr MUL128LU tmp (src1, src2) + (Instr (ADD 128) out (out, tmp) + (Instr (ADDC (-128)) outHigh (outHigh, tmp) cont))))))). + + Definition MontRed256 lo hi y t1 t2 scratch RegPInv : @Fancy.expr register := + Mul256 y lo RegPInv t1 + (Mul256x256 t1 t2 y RegMod scratch + (Instr (ADD 0) lo (lo, t1) + (Instr (ADDC 0) hi (hi, t2) + (Instr SELC y (RegMod, RegZero) + (Instr (SUB 0) lo (hi, y) + (Instr ADDM lo (lo, RegZero, RegMod) + (Ret lo))))))). +End Prod. + +Module ProdEquiv. + Import Fancy. Import Registers. + + Definition interp256 := Fancy.interp reg_eqb (2^256) cc_spec. + Lemma interp_step i rd args cont cc ctx : + interp256 (Instr i rd args cont) cc ctx = + let result := spec i (Tuple.map ctx args) cc in + let new_cc := CC.update (writes_conditions i) result cc_spec cc in + let new_ctx := fun n => if reg_eqb n rd then result mod wordmax else ctx n in interp256 cont new_cc new_ctx. + Proof. reflexivity. Qed. + + (* TODO : move *) + Lemma tuple_map_ext {A B} (f g : A -> B) n (t : tuple A n) : + (forall x : A, f x = g x) -> + Tuple.map f t = Tuple.map g t. + Proof. + destruct n; [reflexivity|]; cbn in *. + induction n; cbn in *; intro H; auto; [ ]. + rewrite IHn by assumption. + rewrite H; reflexivity. + Qed. + + Lemma interp_state_equiv e : + forall cc ctx cc' ctx', + cc = cc' -> (forall r, ctx r = ctx' r) -> + interp256 e cc ctx = interp256 e cc' ctx'. + Proof. + induction e; intros; subst; cbn; [solve[auto]|]. + apply IHe; rewrite tuple_map_ext with (g:=ctx') by auto; + [reflexivity|]. + intros; break_match; auto. + Qed. + Lemma cc_overwrite_full x1 x2 l1 cc : + CC.update [CC.C; CC.M; CC.L; CC.Z] x2 cc_spec (CC.update l1 x1 cc_spec cc) = CC.update [CC.C; CC.M; CC.L; CC.Z] x2 cc_spec cc. + Proof. + cbv [CC.update]. cbn [CC.cc_c CC.cc_m CC.cc_l CC.cc_z]. + break_match; try match goal with H : ~ In _ _ |- _ => cbv [In] in H; tauto end. + reflexivity. + Qed. + + Lemma tuple_map_ext_In {A B} (f g : A -> B) n (t : tuple A n) : + (forall x, In x (to_list n t) -> f x = g x) -> + Tuple.map f t = Tuple.map g t. + Proof. + destruct n; [reflexivity|]; cbn in *. + induction n; cbn in *; intro H; auto; [ ]. + destruct t. + rewrite IHn by auto using in_cons. + rewrite H; auto using in_eq. + Qed. + + Definition value_unused r e : Prop := + forall x cc ctx, interp256 e cc ctx = interp256 e cc (fun r' => if reg_eqb r' r then x else ctx r'). + + Lemma value_unused_skip r i rd args cont (Hcont: value_unused r cont) : + r <> rd -> + (~ In r (Tuple.to_list _ args)) -> + value_unused r (Instr i rd args cont). + Proof. + cbv [value_unused] in *; intros. + rewrite !interp_step; cbv zeta. + rewrite Hcont with (x:=x). + match goal with |- ?lhs = ?rhs => + match lhs with context [Tuple.map ?f ?t] => + match rhs with context [Tuple.map ?g ?t] => + rewrite (tuple_map_ext_In f g) by (intros; cbv [reg_eqb]; break_match; congruence) + end end end. + apply interp_state_equiv; [ congruence | ]. + { intros; cbv [reg_eqb] in *; break_match; congruence. } + Qed. + + Lemma value_unused_overwrite r i args cont : + (~ In r (Tuple.to_list _ args)) -> + value_unused r (Instr i r args cont). + Proof. + cbv [value_unused]; intros; rewrite !interp_step; cbv zeta. + match goal with |- ?lhs = ?rhs => + match lhs with context [Tuple.map ?f ?t] => + match rhs with context [Tuple.map ?g ?t] => + rewrite (tuple_map_ext_In f g) by (intros; cbv [reg_eqb]; break_match; congruence) + end end end. + apply interp_state_equiv; [ congruence | ]. + { intros; cbv [reg_eqb] in *; break_match; congruence. } + Qed. + + Lemma value_unused_ret r r' : + r <> r' -> + value_unused r (Ret r'). + Proof. + cbv - [reg_dec]; intros. + break_match; congruence. + Qed. + + Ltac remember_results := + repeat match goal with |- context [(spec ?i ?args ?flags) mod ?w] => + let x := fresh "x" in + let y := fresh "y" in + let Heqx := fresh "Heqx" in + remember (spec i args flags) as x eqn:Heqx; + remember (x mod w) as y + end. + + Ltac do_interp_step := + rewrite interp_step; cbn - [interp spec]; + repeat progress rewrite ?reg_eqb_neq, ?reg_eqb_refl by congruence; + remember_results. + + Lemma interp_Mul256 out src1 src2 tmp tmp2 cont cc ctx: + out <> src1 -> + out <> src2 -> + out <> tmp -> + out <> tmp2 -> + src1 <> src2 -> + src1 <> tmp -> + src1 <> tmp2 -> + src2 <> tmp -> + src2 <> tmp2 -> + tmp <> tmp2 -> + value_unused tmp cont -> + value_unused tmp2 cont -> + interp256 (Prod.Mul256 out src1 src2 tmp cont) cc ctx = + interp256 ( + Instr MUL128LU tmp (src1, src2) + (Instr MUL128UL tmp2 (src1, src2) + (Instr MUL128LL out (src1, src2) + (Instr (ADD 128) out (out, tmp2) + (Instr (ADD 128) out (out, tmp) cont))))) cc ctx. + Proof. + intros; cbv [Prod.Mul256]. + repeat (do_interp_step; cbn [spec MUL128LL MUL128UL MUL128LU ADD] in * ). + + match goal with H : value_unused tmp _ |- _ => erewrite H end. + match goal with H : value_unused tmp2 _ |- _ => erewrite H end. + apply interp_state_equiv. + { rewrite !cc_overwrite_full. + f_equal. subst. lia. } + { intros; cbv [reg_eqb]. + repeat (break_match_step ltac:(fun _ => idtac); try congruence); reflexivity. } + Qed. + + Lemma interp_Mul256x256 out outHigh src1 src2 tmp tmp2 cont cc ctx: + out <> src1 -> + out <> outHigh -> + out <> src2 -> + out <> tmp -> + out <> tmp2 -> + outHigh <> src1 -> + outHigh <> src2 -> + outHigh <> tmp -> + outHigh <> tmp2 -> + src1 <> src2 -> + src1 <> tmp -> + src1 <> tmp2 -> + src2 <> tmp -> + src2 <> tmp2 -> + tmp <> tmp2 -> + value_unused tmp cont -> + value_unused tmp2 cont -> + interp256 (Prod.Mul256x256 out outHigh src1 src2 tmp cont) cc ctx = + interp256 ( + Instr MUL128LL out (src1, src2) + (Instr MUL128LU tmp (src1, src2) + (Instr MUL128UL tmp2 (src1, src2) + (Instr MUL128UU outHigh (src1, src2) + (Instr (ADD 128) out (out, tmp2) + (Instr (ADDC (-128)) outHigh (outHigh, tmp2) + (Instr (ADD 128) out (out, tmp) + (Instr (ADDC (-128)) outHigh (outHigh, tmp) cont)))))))) cc ctx. + Proof. + intros; cbv [Prod.Mul256x256]. + repeat (do_interp_step; cbn [spec MUL128LL MUL128UL MUL128LU MUL128UU ADD ADDC] in * ). + + match goal with H : value_unused tmp _ |- _ => erewrite H end. + match goal with H : value_unused tmp2 _ |- _ => erewrite H end. + apply interp_state_equiv. + { rewrite !cc_overwrite_full. + f_equal. + subst. cbn - [Z.add Z.modulo Z.testbit Z.mul Z.shiftl Fancy.lower128 Fancy.upper128]. + lia. } + { intros; cbv [reg_eqb]. + repeat (break_match_step ltac:(fun _ => idtac); try congruence); try reflexivity; [ ]. + subst. cbn - [Z.add Z.modulo Z.testbit Z.mul Z.shiftl Fancy.lower128 Fancy.upper128]. + lia. } + Qed. +End ProdEquiv. + +(* Lemmas to help prove that a fancy and prefancy expression have the +same meaning -- should be replaced eventually with a proof of fancy +passes in general. *) +Module Fancy_PreFancy_Equiv. + Import Fancy.Registers. + + Lemma interp_cast_mod_eq w u x: u = 2^w - 1 -> PreFancy.interp_cast_mod w r[0 ~> u] x = x mod 2^w. + Proof. + cbv [PreFancy.interp_cast_mod upper lower]; intros; subst. + rewrite !Z.eqb_refl. + reflexivity. + Qed. + Lemma interp_cast_mod_flag w x: PreFancy.interp_cast_mod w r[0 ~> 1] x = x mod 2. + Proof. + cbv [PreFancy.interp_cast_mod upper lower]. + break_match; Z.ltb_to_lt; subst; try omega. + f_equal; lia. + Qed. + + Lemma interp_equivZ {s} w u (Hu : u = 2^w-1) i rd regs e cc ctx idc args f : + (Fancy.spec i (Tuple.map ctx regs) cc + = PreFancy.interp_ident (d:=type.Z) w idc (Straightline.expr.interp_scalar (interp_cast:=PreFancy.interp_cast_mod w) args)) -> + ( let r := Fancy.spec i (Tuple.map ctx regs) cc in + Fancy.interp reg_eqb (2 ^ w) Fancy.cc_spec e + (Fancy.CC.update (Fancy.writes_conditions i) r Fancy.cc_spec cc) + (fun n : register => if reg_eqb n rd then r mod 2 ^ w else ctx n) = + PreFancy.interp (t:=type.Z) (interp_cast:=PreFancy.interp_cast_mod w) w (f (r mod 2 ^ w))) -> + Fancy.interp reg_eqb (2^w) Fancy.cc_spec (Fancy.Instr i rd regs e) cc ctx + = PreFancy.interp (t:=type.Z) (interp_cast:=PreFancy.interp_cast_mod w) w + (@Straightline.expr.LetInAppIdentZ _ _ s _ (r[0~>2^w-1])%zrange idc args f). + Proof. + cbv zeta; intros spec_eq next_eq. + cbn [Fancy.interp PreFancy.interp]. + rewrite next_eq. + rewrite <-spec_eq. + rewrite interp_cast_mod_eq by omega. + reflexivity. + Qed. + + Lemma interp_equivZZ {s} w (Hw : 2 < 2 ^ w) u (Hu : u = 2^w - 1) i rd regs e cc ctx idc args f : + ((Fancy.spec i (Tuple.map ctx regs) cc) mod 2 ^ w + = fst (PreFancy.interp_ident (d:=type.Z*type.Z) w idc (Straightline.expr.interp_scalar (interp_cast:=PreFancy.interp_cast_mod w) args))) -> + ((if Fancy.cc_spec Fancy.CC.C(Fancy.spec i (Tuple.map ctx regs) cc) then 1 else 0) + = snd (PreFancy.interp_ident (d:=type.Z*type.Z) w idc (Straightline.expr.interp_scalar (interp_cast:=PreFancy.interp_cast_mod w) args)) mod 2) -> + ( let r := Fancy.spec i (Tuple.map ctx regs) cc in + Fancy.interp reg_eqb (2 ^ w) Fancy.cc_spec e + (Fancy.CC.update (Fancy.writes_conditions i) r Fancy.cc_spec cc) + (fun n : register => if reg_eqb n rd then r mod 2 ^ w else ctx n) = + PreFancy.interp (t:=type.Z) (interp_cast:=PreFancy.interp_cast_mod w) w + (f (r mod 2 ^ w, if (Fancy.cc_spec Fancy.CC.C r) then 1 else 0))) -> + Fancy.interp reg_eqb (2^w) Fancy.cc_spec (Fancy.Instr i rd regs e) cc ctx + = PreFancy.interp (t:=type.Z) (interp_cast:=PreFancy.interp_cast_mod w) w + (@Straightline.expr.LetInAppIdentZZ _ _ s _ (r[0~>u], r[0~>1])%zrange idc args f). + Proof. + cbv zeta; intros spec_eq1 spec_eq2 next_eq. + cbn [Fancy.interp PreFancy.interp]. + cbv [Straightline.expr.interp_cast2]. cbn [fst snd]. + rewrite next_eq. + rewrite interp_cast_mod_eq by omega. + rewrite interp_cast_mod_flag by omega. + rewrite <-spec_eq1, <-spec_eq2. + rewrite Z.mod_mod by omega. + reflexivity. + Qed. +End Fancy_PreFancy_Equiv. + Module BarrettReduction. (* TODO : generalize to multi-word and operate on (list Z) instead of T; maybe stop taking ops as context variables *) Section Generic. @@ -10201,6 +11006,7 @@ Module Barrett256. Proof. Time solve_rbarrett_red machine_wordsize. Time Qed. Import PrintingNotations. + Set Printing Width 1000. Open Scope expr_scope. Print barrett_red256. @@ -10209,21 +11015,16 @@ Module Barrett256. This is why their results are not cast (because the carry has range [-1~>0]). *) (* barrett_red256 = fun var : type -> Type => λ x : var (type.type_primitive type.Z * type.type_primitive type.Z)%ctype, - expr_let x0 := SELM (x₂, 0, - 26959946667150639793205513449348445388433292963828203772348655992835) in + expr_let x0 := SELM (x₂, 0, 26959946667150639793205513449348445388433292963828203772348655992835) in expr_let x1 := RSHI (0, x₂, 255) in expr_let x2 := RSHI (x₂, x₁, 255) in expr_let x3 := 79228162514264337589248983038 *₂₅₆ (uint128)(x2 >> 128) in - expr_let x4 := 79228162514264337589248983038 *₂₅₆ - ((uint128)(x2) & 340282366920938463463374607431768211455) in + expr_let x4 := 79228162514264337589248983038 *₂₅₆ ((uint128)(x2) & 340282366920938463463374607431768211455) in expr_let x5 := 340282366841710300930663525764514709507 *₂₅₆ (uint128)(x2 >> 128) in - expr_let x6 := 340282366841710300930663525764514709507 *₂₅₆ - ((uint128)(x2) & 340282366920938463463374607431768211455) in - expr_let x7 := ADD_256 ((uint256)(((uint128)(x4) & 340282366920938463463374607431768211455) << 128), - x6) in - expr_let x8 := ADDC_256 (x7₂, x3, (uint128)(x5 >> 128)) in - expr_let x9 := ADD_256 ((uint256)(((uint128)(x5) & 340282366920938463463374607431768211455) << 128), - x7₁) in + expr_let x6 := 340282366841710300930663525764514709507 *₂₅₆ ((uint128)(x2) & 340282366920938463463374607431768211455) in + expr_let x7 := ADD_256 ((uint256)(((uint128)(x5) & 340282366920938463463374607431768211455) << 128), x6) in + expr_let x8 := ADDC_256 (x7₂, (uint128)(x5 >> 128), x3) in + expr_let x9 := ADD_256 ((uint256)(((uint128)(x4) & 340282366920938463463374607431768211455) << 128), x7₁) in expr_let x10 := ADDC_256 (x9₂, (uint128)(x4 >> 128), x8₁) in expr_let x11 := ADD_256 (x2, x10₁) in expr_let x12 := ADDC_128 (x11₂, 0, x1) in @@ -10232,27 +11033,17 @@ barrett_red256 = fun var : type -> Type => λ x : var (type.type_primitive type expr_let x15 := RSHI (x14₁, x13₁, 1) in expr_let x16 := 340282366841710300967557013911933812736 *₂₅₆ (uint128)(x15 >> 128) in expr_let x17 := 79228162514264337593543950335 *₂₅₆ (uint128)(x15 >> 128) in - expr_let x18 := 340282366841710300967557013911933812736 *₂₅₆ - ((uint128)(x15) & 340282366920938463463374607431768211455) in - expr_let x19 := 79228162514264337593543950335 *₂₅₆ - ((uint128)(x15) & 340282366920938463463374607431768211455) in - expr_let x20 := ADD_256 ((uint256)(((uint128)(x17) & 340282366920938463463374607431768211455) << 128), - x19) in - expr_let x21 := ADDC_256 (x20₂, x16, (uint128)(x18 >> 128)) in - expr_let x22 := ADD_256 ((uint256)(((uint128)(x18) & 340282366920938463463374607431768211455) << 128), - x20₁) in + expr_let x18 := 340282366841710300967557013911933812736 *₂₅₆ ((uint128)(x15) & 340282366920938463463374607431768211455) in + expr_let x19 := 79228162514264337593543950335 *₂₅₆ ((uint128)(x15) & 340282366920938463463374607431768211455) in + expr_let x20 := ADD_256 ((uint256)(((uint128)(x18) & 340282366920938463463374607431768211455) << 128), x19) in + expr_let x21 := ADDC_256 (x20₂, (uint128)(x18 >> 128), x16) in + expr_let x22 := ADD_256 ((uint256)(((uint128)(x17) & 340282366920938463463374607431768211455) << 128), x20₁) in expr_let x23 := ADDC_256 (x22₂, (uint128)(x17 >> 128), x21₁) in - expr_let x24 := Z.add_get_carry_concrete - 115792089237316195423570985008687907853269984665640564039457584007913129639936 @@ - (Z.opp @@ (fst @@ x22), x₁) in - expr_let x25 := Z.add_with_get_carry_concrete - 115792089237316195423570985008687907853269984665640564039457584007913129639936 @@ - (x24₂, Z.opp @@ (fst @@ x23), x₂) in - expr_let x26 := SELL (x25₁, 0, - 115792089210356248762697446949407573530086143415290314195533631308867097853951) in + expr_let x24 := Z.add_get_carry_concrete 115792089237316195423570985008687907853269984665640564039457584007913129639936 @@ (Z.opp @@ (fst @@ x22), x₁) in + expr_let x25 := Z.add_with_get_carry_concrete 115792089237316195423570985008687907853269984665640564039457584007913129639936 @@ (x24₂, Z.opp @@ (fst @@ x23), x₂) in + expr_let x26 := SELL (x25₁, 0, 115792089210356248762697446949407573530086143415290314195533631308867097853951) in expr_let x27 := Z.cast uint256 @@ (fst @@ SUB_256 (x24₁, x26)) in - ADDM (x27, 0, - 115792089210356248762697446949407573530086143415290314195533631308867097853951) + ADDM (x27, 0, 115792089210356248762697446949407573530086143415290314195533631308867097853951) : Expr (type.uncurry (type.type_primitive type.Z -> type.type_primitive type.Z -> type.type_primitive type.Z)) *) @@ -10274,9 +11065,9 @@ barrett_red256 = fun var : type -> Type => λ x : var (type.type_primitive type mulhl@(y3, RegMuLow, $y1); mullh@(y4, RegMuLow, $y1); mulll@(y5, RegMuLow, $y1); - add@(y6, $y5, $y3, 128); + add@(y6, $y5, $y4, 128); addc@(y7, carry{$y6}, $y2, $y4, -128); - add@(y8, $y6, $y4, 128); + add@(y8, $y6, $y3, 128); addc@(y9, carry{$y8}, $y7, $y3, -128); add@(y10, $y1, $y9, 0); addc@(y11, carry{$y10}, RegZero, $y0, 0); #128 @@ -10287,9 +11078,9 @@ barrett_red256 = fun var : type -> Type => λ x : var (type.type_primitive type mullh@(y16, RegMod, $y14); mulhl@(y17, RegMod, $y14); mulll@(y18, RegMod, $y14); - add@(y19, $y18, $y16, 128); + add@(y19, $y18, $y17, 128); addc@(y20, carry{$y19}, $y15, $y17, -128); - add@(y21, $y19, $y17, 128); + add@(y21, $y19, $y16, 128); addc@(_, carry{$y21}, $y20, $y16, -128); Straightline.expr.Scalar (Straightline.expr.Primitive (-1)) *) @@ -10706,7 +11497,7 @@ Module MontgomeryReduction. Definition montred' (lo_hi : (Z * Z)) := dlet_nd y := nth_default 0 (BaseConversion.widemul_inlined Zlog2R n nout (fst lo_hi) N') 0 in - dlet_nd t1_t2 := (BaseConversion.widemul_inlined Zlog2R n nout y N) in + dlet_nd t1_t2 := (BaseConversion.widemul_inlined_reverse Zlog2R n nout N y) in dlet_nd sum_carry := Rows.add (weight Zlog2R 1) 2 [fst lo_hi; snd lo_hi] t1_t2 in dlet_nd y' := Z.zselect (snd sum_carry) 0 N in dlet_nd lo''_carry := Z.sub_get_borrow_full R (nth_default 0 (fst sum_carry) 1) y' in @@ -10734,6 +11525,9 @@ Module MontgomeryReduction. Local Lemma eval2 x y : eval w 2 [x;y] = x + R * y. Proof. cbn. change_weight. ring. Qed. + Hint Rewrite BaseConversion.widemul_inlined_reverse_correct BaseConversion.widemul_inlined_correct + using (autorewrite with widemul push_nth_default; solve [solve_range]) : widemul. + Lemma montred'_eq lo_hi T (HT_range: 0 <= T < R * N) (Hlo: fst lo_hi = T mod R) (Hhi: snd lo_hi = T / R): montred' lo_hi = reduce_via_partial N R N' T. @@ -10741,10 +11535,9 @@ Module MontgomeryReduction. rewrite <-reduce_via_partial_alt_eq by nia. cbv [montred' partial_reduce_alt reduce_via_partial_alt prereduce Let_In]. rewrite Hlo, Hhi. - assert (0 <= T mod R * N' < w 2) by (solve_range). + assert (0 <= (T mod R) * N' < w 2) by (solve_range). - rewrite !BaseConversion.widemul_inlined_correct - by (rewrite ?BaseConversion.widemul_inlined_correct; autorewrite with push_nth_default; solve_range). + autorewrite with widemul. rewrite Rows.add_partitions, Rows.add_div by (distr_length; apply wprops; omega). rewrite R_two_pow. cbv [Rows.partition seq]. rewrite !eval2. @@ -10757,7 +11550,8 @@ Module MontgomeryReduction. let P := fresh "H" in assert (x = y) as P; [|rewrite P; reflexivity] end. autorewrite with zsimplify. - break_match; try reflexivity; autorewrite with ltb_to_lt in *; rewrite Z.div_small_iff in * by omega; + rewrite (Z.mul_comm (((T mod R) * N') mod R) N) in *. + break_match; try reflexivity; Z.ltb_to_lt; rewrite Z.div_small_iff in * by omega; repeat match goal with | _ => progress autorewrite with zsimplify_fast | |- context [?x mod (R * R)] => @@ -10842,6 +11636,7 @@ Module Montgomery256. Definition N := Eval lazy in (2^256-2^224+2^192+2^96-1). Definition N':= (115792089210356248768974548684794254293921932838497980611635986753331132366849). Definition R := Eval lazy in (2^256). + Definition R' := 115792089183396302114378112356516095823261736990586219612555396166510339686400. Definition machine_wordsize := 256. Derive montred256 @@ -10856,6 +11651,7 @@ Module Montgomery256. As montred256_prefancy_eq. Proof. lazy - [type.interp]; reflexivity. Qed. + Lemma montred'_correct_specialized R' (R'_correct : Z.equiv_modulo N (R * R') 1) : forall (lo hi : Z), 0 <= lo < R -> 0 <= hi < R -> 0 <= lo + R * hi < R * N -> @@ -10933,14 +11729,14 @@ Module Montgomery256. | |- context [PreFancy.ok_expr] => constructor; cbn [fst snd]; repeat ok_expr_step' end; intros; cbn [Nat.max]. - Lemma montred256_prefancy_correct R' (R'_correct : Z.equiv_modulo N (R * R') 1) : + Lemma montred256_prefancy_correct : forall (lo hi : Z) dummy_arrow, 0 <= lo < R -> 0 <= hi < R -> 0 <= lo + R * hi < R * N -> - @PreFancy.interp machine_wordsize type.Z (montred256_prefancy (lo,hi) dummy_arrow) = ((lo + R * hi) * R') mod N. + @PreFancy.interp machine_wordsize (PreFancy.interp_cast_mod machine_wordsize) type.Z (montred256_prefancy (lo,hi) dummy_arrow) = ((lo + R * hi) * R') mod N. Proof. intros. rewrite montred256_prefancy_eq; cbv [montred256_prefancy']. erewrite PreFancy.of_Expr_correct. - { apply montred256_correct_full; assumption. } + { apply montred256_correct_full; try assumption; reflexivity. } { reflexivity. } { lazy; reflexivity. } { lazy; reflexivity. } @@ -10950,10 +11746,298 @@ Module Montgomery256. repeat (ok_expr_step; [ ]). ok_expr_step. lazy; congruence. + constructor. constructor. } { lazy. omega. } Qed. + Definition montred256_fancy' (lo hi RegMod RegPInv RegZero error : positive) := + Fancy.of_Expr 3%positive + (fun z => if z =? N then Some RegMod else if z =? N' then Some RegPInv else if z =? 0 then Some RegZero else None) + [N;N'] + montred256 + (lo, hi)%positive + (fun _ _ => tt) + error. + Derive montred256_fancy + SuchThat (forall RegMod RegPInv RegZero, + montred256_fancy RegMod RegPInv RegZero = montred256_fancy' RegMod RegPInv RegZero) + As montred256_fancy_eq. + Proof. + intros. + lazy - [Fancy.ADD Fancy.ADDC Fancy.SUB + Fancy.MUL128LL Fancy.MUL128LU Fancy.MUL128UL Fancy.MUL128UU + Fancy.RSHI Fancy.SELC Fancy.SELM Fancy.SELL Fancy.ADDM]. + reflexivity. + Qed. + + Import Fancy.Registers. + + Definition montred256_alloc' lo hi RegPInv := + fun errorP errorR => + Fancy.allocate register + positive Pos.eqb + errorR + (montred256_fancy 1000%positive 1001%positive 1002%positive 1003%positive 1004%positive errorP) + [r2;r3;r4;r5;r6;r7;r8;r9;r10;r11;r12;r13;r14;r15;r16;r17;r18;r19;r20] + (fun n => if n =? 1000 then lo + else if n =? 1001 then hi + else if n =? 1002 then RegMod + else if n =? 1003 then RegPInv + else if n =? 1004 then RegZero + else errorR). + Derive montred256_alloc + SuchThat (montred256_alloc = montred256_alloc') + As montred256_alloc_eq. + Proof. + intros. + cbv [montred256_alloc' montred256_fancy]. + cbn. subst montred256_alloc. + reflexivity. + Qed. + + (* TODO : move *) + Lemma mulll_comm rd x y cont cc ctx : + ProdEquiv.interp256 (Fancy.Instr Fancy.MUL128LL rd (x, y) cont) cc ctx = ProdEquiv.interp256 (Fancy.Instr Fancy.MUL128LL rd (y, x) cont) cc ctx. + Proof. rewrite !ProdEquiv.interp_step. cbn - [Fancy.interp]. rewrite Z.mul_comm. reflexivity. Qed. + + Lemma mulhh_comm rd x y cont cc ctx : + ProdEquiv.interp256 (Fancy.Instr Fancy.MUL128UU rd (x, y) cont) cc ctx = ProdEquiv.interp256 (Fancy.Instr Fancy.MUL128UU rd (y, x) cont) cc ctx. + Proof. rewrite !ProdEquiv.interp_step. cbn - [Fancy.interp]. rewrite Z.mul_comm. reflexivity. Qed. + + Lemma mullh_mulhl rd x y cont cc ctx : + ProdEquiv.interp256 (Fancy.Instr Fancy.MUL128LU rd (x, y) cont) cc ctx = ProdEquiv.interp256 (Fancy.Instr Fancy.MUL128UL rd (y, x) cont) cc ctx. + Proof. rewrite !ProdEquiv.interp_step. cbn - [Fancy.interp]. rewrite Z.mul_comm. reflexivity. Qed. + + Lemma add_comm rd x y cont cc ctx : + 0 <= ctx x < 2^256 -> + 0 <= ctx y < 2^256 -> + ProdEquiv.interp256 (Fancy.Instr (Fancy.ADD 0) rd (x, y) cont) cc ctx = ProdEquiv.interp256 (Fancy.Instr (Fancy.ADD 0) rd (y, x) cont) cc ctx. + Proof. + intros; rewrite !ProdEquiv.interp_step. cbn - [Fancy.interp]. rewrite Z.add_comm. + rewrite !(Z.mod_small (ctx _)) by (cbn in *; omega). reflexivity. + Qed. + + Lemma addc_comm rd x y cont cc ctx : + 0 <= ctx x < 2^256 -> + 0 <= ctx y < 2^256 -> + ProdEquiv.interp256 (Fancy.Instr (Fancy.ADDC 0) rd (x, y) cont) cc ctx = ProdEquiv.interp256 (Fancy.Instr (Fancy.ADDC 0) rd (y, x) cont) cc ctx. + Proof. + intros; rewrite !ProdEquiv.interp_step. cbn - [Fancy.interp]. rewrite (Z.add_comm (ctx x)). + rewrite !(Z.mod_small (ctx _)) by (cbn in *; omega). reflexivity. + Qed. + + SearchAbout reg_eqb. + + Local Ltac push_value_unused := + repeat match goal with + | |- ~ In _ _ => cbn; intuition; congruence + | _ => apply ProdEquiv.value_unused_overwrite + | _ => apply ProdEquiv.value_unused_skip; [ | congruence | ] + | _ => apply ProdEquiv.value_unused_ret; congruence + end. + + Local Ltac remember_single_result := + match goal with |- context [(Fancy.spec ?i ?args ?cc) mod ?w] => + let x := fresh "x" in + let y := fresh "y" in + let Heqx := fresh "Heqx" in + remember (Fancy.spec i args cc) as x eqn:Heqx; + remember (x mod w) as y + end. + Local Ltac step_both_sides := + match goal with |- ProdEquiv.interp256 (Fancy.Instr ?i ?rd1 ?args1 _) _ ?ctx1 = ProdEquiv.interp256 (Fancy.Instr ?i ?rd2 ?args2 _) _ ?ctx2 => + rewrite (ProdEquiv.interp_step i rd1 args1); rewrite (ProdEquiv.interp_step i rd2 args2); + cbn - [Fancy.interp Fancy.spec]; + repeat progress rewrite ?reg_eqb_neq, ?reg_eqb_refl by congruence; + remember_single_result; + lazymatch goal with + | |- context [Fancy.spec i _ _] => + let Heqa1 := fresh in + let Heqa2 := fresh in + remember (Tuple.map (n:=i.(Fancy.num_source_regs)) ctx1 args1) eqn:Heqa1; + remember (Tuple.map (n:=i.(Fancy.num_source_regs)) ctx2 args2) eqn:Heqa2; + cbn in Heqa1; cbn in Heqa2; + repeat progress rewrite ?reg_eqb_neq, ?reg_eqb_refl in Heqa1 by congruence; + repeat progress rewrite ?reg_eqb_neq, ?reg_eqb_refl in Heqa2 by congruence; + let a1 := match type of Heqa1 with _ = ?a1 => a1 end in + let a2 := match type of Heqa2 with _ = ?a2 => a2 end in + (fail 1 "arguments to " i " do not match; LHS has " a1 " and RHS has " a2) + | _ => idtac + end + end. + + Local Ltac solve_bounds := + match goal with + | H : ?a = ?b mod ?c |- 0 <= ?a < ?c => rewrite H; apply Z.mod_pos_bound; omega + | _ => assumption + end. + + Lemma montred256_alloc_equivalent errorP errorR cc_start_state start_context : + forall lo hi y t1 t2 scratch RegPInv extra_reg, + NoDup [lo; hi; y; t1; t2; scratch; RegPInv; extra_reg; RegMod; RegZero] -> + 0 <= start_context lo < R -> + 0 <= start_context hi < R -> + 0 <= start_context RegPInv < R -> + ProdEquiv.interp256 (montred256_alloc r0 r1 r30 errorP errorR) cc_start_state + (fun r => if reg_eqb r r0 + then start_context lo + else if reg_eqb r r1 + then start_context hi + else if reg_eqb r r30 + then start_context RegPInv + else start_context r) + = ProdEquiv.interp256 (Prod.MontRed256 lo hi y t1 t2 scratch RegPInv) cc_start_state start_context. + Proof. + intros. cbv [R] in *. + cbv [Prod.MontRed256 montred256_alloc]. + + (* Extract proofs that no registers are equal to each other *) + repeat match goal with + | H : NoDup _ |- _ => inversion H; subst; clear H + | H : ~ In _ _ |- _ => cbv [In] in H + | H : ~ (_ \/ _) |- _ => apply Decidable.not_or in H; destruct H + | H : ~ False |- _ => clear H + end. + + rewrite ProdEquiv.interp_Mul256 with (tmp2:=extra_reg) by (congruence || push_value_unused). + + step_both_sides. + step_both_sides. + rewrite mulll_comm. step_both_sides. + step_both_sides. + step_both_sides. + + rewrite ProdEquiv.interp_Mul256x256 with (tmp2:=extra_reg) by (congruence || push_value_unused). + + rewrite mulll_comm. step_both_sides. + step_both_sides. + step_both_sides. + rewrite mulhh_comm. step_both_sides. + step_both_sides. + step_both_sides. + step_both_sides. + step_both_sides. + + + rewrite add_comm by (cbn; solve_bounds). step_both_sides. + rewrite addc_comm by (cbn; solve_bounds). step_both_sides. + step_both_sides. + step_both_sides. + step_both_sides. + + cbn; repeat progress rewrite ?reg_eqb_neq, ?reg_eqb_refl by congruence. + reflexivity. + Qed. + + Import Fancy_PreFancy_Equiv. + + Definition interp_equivZZ_256 {s} := + @interp_equivZZ s 256 ltac:(cbv; congruence) 115792089237316195423570985008687907853269984665640564039457584007913129639935 ltac:(reflexivity). + Definition interp_equivZ_256 {s} := + @interp_equivZ s 256 115792089237316195423570985008687907853269984665640564039457584007913129639935 ltac:(reflexivity). + + Local Ltac simplify_op_equiv start_ctx := + cbn - [Fancy.spec PreFancy.interp_ident Fancy.cc_spec]; + repeat match goal with H : start_ctx _ = _ |- _ => rewrite H end; + cbv - [ + Z.add_with_get_carry_full + Z.add_get_carry_full Z.sub_get_borrow_full + Z.le Z.ltb Z.leb Z.geb Z.eqb Z.land Z.shiftr Z.shiftl + Z.add Z.mul Z.div Z.sub Z.modulo Z.testbit Z.pow Z.ones + fst snd]; cbn [fst snd]; + try (replace (2 ^ (256 / 2) - 1) with (Z.ones 128) by reflexivity; rewrite !Z.land_ones by omega); + autorewrite with to_div_mod; rewrite ?Z.mod_mod, <-?Z.testbit_spec' by omega; + repeat match goal with + | H : 0 <= ?x < ?m |- context [?x mod ?m] => rewrite (Z.mod_small x m) by apply H + | |- context [?x rewrite (proj2 (Z.ltb_ge x 0)) by (break_match; Z.zero_bounds) + | _ => rewrite Z.mod_small with (b:=2) by (break_match; omega) + | |- context [ (if Z.testbit ?a ?n then 1 else 0) + ?b + ?c] => + replace ((if Z.testbit a n then 1 else 0) + b + c) with (b + c + (if Z.testbit a n then 1 else 0)) by ring + end. + + Local Ltac solve_nonneg ctx := + match goal with x := (Fancy.spec _ _ _) |- _ => subst x end; + simplify_op_equiv ctx; Z.zero_bounds. + + Local Ltac generalize_result := + let v := fresh "v" in intro v; generalize v; clear v; intro v. + + Local Ltac generalize_result_nonneg ctx := + let v := fresh "v" in + let v_nonneg := fresh "v_nonneg" in + intro v; assert (0 <= v) as v_nonneg; [solve_nonneg ctx |generalize v v_nonneg; clear v v_nonneg; intros v v_nonneg]. + + Local Ltac step ctx := + match goal with + | |- Fancy.interp _ _ _ (Fancy.Instr (Fancy.ADD _) _ _ (Fancy.Instr (Fancy.ADDC _) _ _ _)) _ _ = _ => + apply interp_equivZZ_256; [ simplify_op_equiv ctx | simplify_op_equiv ctx | generalize_result_nonneg ctx] + | _ => apply interp_equivZ_256; [simplify_op_equiv ctx | generalize_result] + | _ => apply interp_equivZZ_256; [ simplify_op_equiv ctx | simplify_op_equiv ctx | generalize_result] + end. + + (* TODO: move this lemma to ZUtil *) + Lemma testbit_neg_eq_if x y n : + 0 <= n -> + 0 <= x < 2 ^ n -> + 0 <= y < 2 ^ n -> + Z.b2z (if (x - y) Z) (* starting register values *) + (lo hi y t1 t2 scratch RegPInv extra_reg : register), (* registers to use in computation *) + NoDup [lo; hi; y; t1; t2; scratch; RegPInv; extra_reg; RegMod; RegZero] -> (* registers must be distinct *) + start_context RegPInv = N' -> (* RegPInv needs to hold the inverse of the modulus *) + start_context RegMod = N -> (* RegMod needs to hold the modulus *) + start_context RegZero = 0 -> (* RegZero needs to hold zero *) + (0 <= start_context lo < R) -> (* value in lo is in bounds (R=2^256) *) + (0 <= start_context hi < R) -> (* value in hi is in bounds (R=2^256) *) + let x := (start_context lo) + R * (start_context hi) in (* x is the input (split into two registers) *) + (0 <= x < R * N) -> (* input precondition *) + (ProdEquiv.interp256 (Prod.MontRed256 lo hi y t1 t2 scratch RegPInv) cc_start_state start_context = (x * R') mod N). + Proof. + intros. subst x. cbv [N R N'] in *. + rewrite <-montred256_prefancy_correct with (dummy_arrow := fun s d _ => DefaultValue.type.default) by auto. + rewrite <-montred256_alloc_equivalent with (errorR := RegZero) (errorP := 1%positive) (extra_reg:=extra_reg) + by (cbv [R]; auto with omega). + cbv [ProdEquiv.interp256]. + cbv [montred256_alloc montred256_prefancy]. + + step start_context; [ reflexivity | ]. + step start_context; [ reflexivity | ]. + step start_context; [ reflexivity | ]. + step start_context; [ reflexivity | reflexivity | ]. + step start_context; [ reflexivity | reflexivity | ]. + step start_context; [ reflexivity | ]. + step start_context; [ reflexivity | ]. + step start_context; [ reflexivity | ]. + step start_context; [ reflexivity | ]. + step start_context; [ reflexivity | reflexivity | ]. + step start_context; [ reflexivity | reflexivity | ]. + step start_context; [ reflexivity | reflexivity | ]. + step start_context; [ reflexivity | reflexivity | ]. + + step start_context; [ reflexivity | reflexivity | ]. + step start_context; [ reflexivity | reflexivity | ]. + step start_context; [ break_innermost_match; Z.ltb_to_lt; omega | ]. + step start_context; [ reflexivity | | ]. + { + let r := eval cbv in (2^256) in replace (2^256) with r by reflexivity. + rewrite !Z.shiftl_0_r, !Z.mod_mod by omega. + apply testbit_neg_eq_if; + let r := eval cbv in (2^256) in replace (2^256) with r by reflexivity; + auto using Z.mod_pos_bound with omega. } + step start_context; [ break_innermost_match; Z.ltb_to_lt; omega | ]. + reflexivity. + Qed. + Import PrintingNotations. Set Printing Width 10000. @@ -10963,15 +12047,15 @@ montred256 = fun var : type -> Type => (λ x : var (type.type_primitive type.Z * expr_let x0 := 79228162514264337593543950337 *₂₅₆ (uint128)(x₁ >> 128) in expr_let x1 := 340282366841710300986003757985643364352 *₂₅₆ ((uint128)(x₁) & 340282366920938463463374607431768211455) in expr_let x2 := 79228162514264337593543950337 *₂₅₆ ((uint128)(x₁) & 340282366920938463463374607431768211455) in - expr_let x3 := ADD_256 ((uint256)(((uint128)(x0) & 340282366920938463463374607431768211455) << 128), x2) in - expr_let x4 := ADD_256 ((uint256)(((uint128)(x1) & 340282366920938463463374607431768211455) << 128), x3₁) in - expr_let x5 := 340282366841710300967557013911933812736 *₂₅₆ (uint128)(x4₁ >> 128) in + expr_let x3 := ADD_256 ((uint256)(((uint128)(x1) & 340282366920938463463374607431768211455) << 128), x2) in + expr_let x4 := ADD_256 ((uint256)(((uint128)(x0) & 340282366920938463463374607431768211455) << 128), x3₁) in + expr_let x5 := 79228162514264337593543950335 *₂₅₆ ((uint128)(x4₁) & 340282366920938463463374607431768211455) in expr_let x6 := 79228162514264337593543950335 *₂₅₆ (uint128)(x4₁ >> 128) in expr_let x7 := 340282366841710300967557013911933812736 *₂₅₆ ((uint128)(x4₁) & 340282366920938463463374607431768211455) in - expr_let x8 := 79228162514264337593543950335 *₂₅₆ ((uint128)(x4₁) & 340282366920938463463374607431768211455) in - expr_let x9 := ADD_256 ((uint256)(((uint128)(x6) & 340282366920938463463374607431768211455) << 128), x8) in - expr_let x10 := ADDC_256 (x9₂, x5, (uint128)(x7 >> 128)) in - expr_let x11 := ADD_256 ((uint256)(((uint128)(x7) & 340282366920938463463374607431768211455) << 128), x9₁) in + expr_let x8 := 340282366841710300967557013911933812736 *₂₅₆ (uint128)(x4₁ >> 128) in + expr_let x9 := ADD_256 ((uint256)(((uint128)(x7) & 340282366920938463463374607431768211455) << 128), x5) in + expr_let x10 := ADDC_256 (x9₂, (uint128)(x7 >> 128), x8) in + expr_let x11 := ADD_256 ((uint256)(((uint128)(x6) & 340282366920938463463374607431768211455) << 128), x9₁) in expr_let x12 := ADDC_256 (x11₂, (uint128)(x6 >> 128), x10₁) in expr_let x13 := ADD_256 (x11₁, x₁) in expr_let x14 := ADDC_256 (x13₂, x12₁, x₂) in @@ -10989,23 +12073,24 @@ montred256 = fun var : type -> Type => (λ x : var (type.type_primitive type.Z * (* mulhl@(y0, RegPInv, $x₁); mulll@(y1, RegPInv, $x₁); - add@(y2, $y1, $y, 128); - add@(y3, $y2, $y0, 128); - mulhh@(y4, RegMod, $y3); + add@(y2, $y1, $y0, 128); + add@(y3, $y2, $y, 128); + mulll@(y4, RegMod, $y3); mullh@(y5, RegMod, $y3); mulhl@(y6, RegMod, $y3); - mulll@(y7, RegMod, $y3); - add@(y8, $y7, $y5, 128); - addc@(y9, carry{$y8}, $y4, $y6, -128); - add@(y10, $y8, $y6, 128); + mulhh@(y7, RegMod, $y3); + add@(y8, $y4, $y6, 128); + addc@(y9, carry{$y8}, $y7, $y6, -128); + add@(y10, $y8, $y5, 128); addc@(y11, carry{$y10}, $y9, $y5, -128); add@(y12, $y10, $x₁, 0); addc@(y13, carry{$y12}, $y11, $x₂, 0); selc@(y14, carry{$y13}, RegZero, RegMod); sub@(y15, $y13, $y14, 0); addm@(y16, $y15, RegZero, RegMod); - ret $y16 + ret $y16 *) + End Montgomery256. (* Extra-specialized ad-hoc pretty-printing *) @@ -11139,6 +12224,39 @@ Module FancyPrintingNotations. *) End FancyPrintingNotations. + +Local Notation "i rd x y ; cont" := (Fancy.Instr i rd (x, y) cont) (at level 40, cont at level 200, format "i rd x y ; '//' cont"). +Local Notation "i rd x y z ; cont" := (Fancy.Instr i rd (x, y, z) cont) (at level 40, cont at level 200, format "i rd x y z ; '//' cont"). + +Import Fancy.Registers. +Import Fancy. +Eval cbv beta iota delta [Prod.MontRed256 Prod.Mul256 Prod.Mul256x256] in Prod.MontRed256. +(* + = fun lo hi y t1 t2 scratch RegPInv : register => + MUL128LL y lo RegPInv; + MUL128UL t1 lo RegPInv; + ADD 128 y y t1; + MUL128LU t1 lo RegPInv; + ADD 128 y y t1; + MUL128LL t1 y RegMod; + MUL128UU t2 y RegMod; + MUL128UL scratch y RegMod; + ADD 128 t1 t1 scratch; + ADDC (-128) t2 t2 scratch; + MUL128LU scratch y RegMod; + ADD 128 t1 t1 scratch; + ADDC (-128) t2 t2 scratch; + ADD 0 lo lo t1; + ADDC 0 hi hi t2; + SELC y RegMod RegZero; + SUB 0 lo hi y; + ADDM lo lo RegZero RegMod; + Ret lo + *) +Import Montgomery256. +Check Montgomery256.prod_montred256_correct. +(* Print Assumptions Montgomery256.prod_montred256_correct. *) + Import FancyPrintingNotations. Local Open Scope expr_scope. @@ -11177,26 +12295,4 @@ expr_let x25 := Z.add_with_get_carry_concrete c.Sell($x26, $x25_lo, RegZero, RegMod); c.Sub($x27, $x24_lo, $x26); c.AddM($ret, $x27, RegZero, RegMod); -*) - -Print Montgomery256.montred256. -(* -c.Mul128x128($x0, c.Lower(RegPinv), c.Upper($x_lo)); -c.Mul128x128($x1, c.Upper(RegPinv), c.Lower($x_lo)); -c.Mul128x128($x2, c.Lower(RegPinv), c.Lower($x_lo)); -c.Add256($x3, (c.Lower($x0) << 128), $x2); -c.Add256($x4, (c.Lower($x1) << 128), $x3_lo); -c.Mul128x128($x5, c.Upper(RegMod), c.Upper($x4_lo)); -c.Mul128x128($x6, c.Lower(RegMod), c.Upper($x4_lo)); -c.Mul128x128($x7, c.Upper(RegMod), c.Lower($x4_lo)); -c.Mul128x128($x8, c.Lower(RegMod), c.Lower($x4_lo)); -c.Add256($x9, (c.Lower($x6) << 128), $x8); -c.Addc256($x10, $x9_hi, $x5, c.Upper($x7)); -c.Add256($x11, (c.Lower($x7) << 128), $x9_lo); -c.Addc256($x12, $x11_hi, c.Upper($x6), $x10_lo); -c.Add256($x13, $x11_lo, $x_lo); -c.Addc256($x14, $x13_hi, $x12_lo, $x_hi); -c.Selc($x15, $x14_hi, RegZero, RegMod); -c.Sub($x16, $x14_lo, $x15); -c.AddM($ret, $x16_lo, RegZero, RegMod); - *) \ No newline at end of file + *) -- cgit v1.2.3