From 456e29884c4995157a318f176153d4b5f5836959 Mon Sep 17 00:00:00 2001 From: jadep Date: Mon, 14 Jan 2019 14:50:23 -0500 Subject: separate toplevel2 into several files; fix up final barrett proof --- _CoqProject | 6 +- src/Fancy/Barrett256.v | 413 ++++++ src/Fancy/Montgomery256.v | 508 ++++++++ src/Fancy/Prod.v | 395 ++++++ src/Fancy/Spec.v | 348 +++++ src/Fancy/Translation.v | 1246 ++++++++++++++++++ src/Toplevel2.v | 3168 +-------------------------------------------- 7 files changed, 2917 insertions(+), 3167 deletions(-) create mode 100644 src/Fancy/Barrett256.v create mode 100644 src/Fancy/Montgomery256.v create mode 100644 src/Fancy/Prod.v create mode 100644 src/Fancy/Spec.v create mode 100644 src/Fancy/Translation.v diff --git a/_CoqProject b/_CoqProject index ded3c9adc..ef3ad2d27 100644 --- a/_CoqProject +++ b/_CoqProject @@ -45,7 +45,6 @@ src/RewriterWf2.v src/SlowPrimeSynthesisExamples.v src/StandaloneHaskellMain.v src/StandaloneOCamlMain.v -src/Toplevel2.v src/UnderLets.v src/UnderLetsProofs.v src/Algebra/Field.v @@ -86,6 +85,11 @@ src/ExtractionHaskell/word_by_word_montgomery.v src/ExtractionOCaml/saturated_solinas.v src/ExtractionOCaml/unsaturated_solinas.v src/ExtractionOCaml/word_by_word_montgomery.v +src/Fancy/Barrett256.v +src/Fancy/Montgomery256.v +src/Fancy/Prod.v +src/Fancy/Spec.v +src/Fancy/Translation.v src/Primitives/EdDSARepChange.v src/Primitives/MxDHRepChange.v src/PushButtonSynthesis/ReificationCache.v diff --git a/src/Fancy/Barrett256.v b/src/Fancy/Barrett256.v new file mode 100644 index 000000000..8d3319519 --- /dev/null +++ b/src/Fancy/Barrett256.v @@ -0,0 +1,413 @@ +(* TODO: prune all these dependencies *) +Require Import Coq.ZArith.ZArith Coq.micromega.Lia. +Require Import Coq.derive.Derive. +Require Import Coq.Bool.Bool. +Require Import Coq.Strings.String. +Require Import Coq.Lists.List. +Require Crypto.Util.Strings.String. +Require Import Crypto.Util.Strings.Decimal. +Require Import Crypto.Util.Strings.HexString. +Require Import QArith.QArith_base QArith.Qround Crypto.Util.QUtil. +Require Import Crypto.Algebra.Ring Crypto.Util.Decidable.Bool2Prop. +Require Import Crypto.Algebra.Ring. +Require Import Crypto.Algebra.SubsetoidRing. +Require Import Crypto.Util.ZRange. +Require Import Crypto.Util.ListUtil.FoldBool. +Require Import Crypto.Util.LetIn. +Require Import Crypto.Arithmetic.PrimeFieldTheorems. +Require Import Crypto.Util.ZUtil.Tactics.LtbToLt. +Require Import Crypto.Util.ZUtil.Tactics.PullPush.Modulo. +Require Import Crypto.Util.ZUtil.Tactics.DivModToQuotRem. +Require Import Crypto.Util.ZUtil.Tactics.ZeroBounds. +Require Import Crypto.Util.Tactics.SplitInContext. +Require Import Crypto.Util.Tactics.SubstEvars. +Require Import Crypto.Util.Tactics.DestructHead. +Require Import Crypto.Util.Tuple. +Require Import Crypto.Util.ListUtil Coq.Lists.List. +Require Import Crypto.Util.Equality. +Require Import Crypto.Util.Tactics.GetGoal. +Require Import Crypto.Util.Tactics.UniquePose. +Require Import Crypto.Util.ZUtil.Rshi. +Require Import Crypto.Util.Option. +Require Import Crypto.Util.Tactics.BreakMatch. +Require Import Crypto.Util.Tactics.SpecializeBy. +Require Import Crypto.Util.ZUtil.Zselect. +Require Import Crypto.Util.ZUtil.AddModulo. +Require Import Crypto.Util.ZUtil.CC. +Require Import Crypto.Util.ZUtil.Modulo. +Require Import Crypto.Util.ZUtil.Notations. +Require Import Crypto.Util.ZUtil.Tactics.RewriteModSmall. +Require Import Crypto.Util.ZUtil.Definitions. +Require Import Crypto.Util.ZUtil.EquivModulo. +Require Import Crypto.Util.ZUtil.Tactics.SplitMinMax. +Require Import Crypto.Util.ErrorT. +Require Import Crypto.Util.Strings.Show. +Require Import Crypto.Util.ZRange.Operations. +Require Import Crypto.Util.ZRange.BasicLemmas. +Require Import Crypto.Util.ZRange.Show. +Require Import Crypto.Arithmetic. +Require Import Crypto.Fancy.PrintingNotations. +Require Import Crypto.Fancy.Prod. +Require Import Crypto.Fancy.Spec. +Require Import Crypto.Fancy.Translation. +Require Crypto.Language. +Require Crypto.UnderLets. +Require Crypto.AbstractInterpretation. +Require Crypto.AbstractInterpretationProofs. +Require Crypto.Rewriter. +Require Crypto.MiscCompilerPasses. +Require Crypto.CStringification. +Require Export Crypto.PushButtonSynthesis. +Require Import Crypto.Util.Notations. +Import ListNotations. Local Open Scope Z_scope. + +Import Associational Positional. + +Import + Crypto.Language + Crypto.UnderLets + Crypto.AbstractInterpretation + Crypto.AbstractInterpretationProofs + Crypto.Rewriter + Crypto.MiscCompilerPasses + Crypto.CStringification. + +Import + Language.Compilers + UnderLets.Compilers + AbstractInterpretation.Compilers + AbstractInterpretationProofs.Compilers + Rewriter.Compilers + MiscCompilerPasses.Compilers + CStringification.Compilers. + +Import Compilers.defaults. +Local Coercion Z.of_nat : nat >-> Z. +Local Coercion QArith_base.inject_Z : Z >-> Q. + +Import Spec.Fancy. +Import ProdEquiv. + +Module Barrett256. + Import LanguageWf.Compilers. + + Definition M := Eval lazy in (2^256-2^224+2^192+2^96-1). + Definition machine_wordsize := 256. + + Derive barrett_red256 + SuchThat (BarrettReduction.rbarrett_red_correctT M machine_wordsize barrett_red256) + As barrett_red256_correct. + Proof. Time solve_rbarrett_red_nocache machine_wordsize. Time Qed. + + Definition muLow := Eval lazy in (2 ^ (2 * machine_wordsize) / M) mod (2^machine_wordsize). + + Lemma barrett_reduce_correct_specialized : + forall (xLow xHigh : Z), + 0 <= xLow < 2 ^ machine_wordsize -> + 0 <= xHigh < M -> + BarrettReduction.barrett_reduce machine_wordsize M muLow 2 2 xLow xHigh = (xLow + 2 ^ machine_wordsize * xHigh) mod M. + Proof. + intros. + apply BarrettReduction.barrett_reduce_correct; cbv [machine_wordsize M muLow] in *; + try omega; + try match goal with + | |- context [weight] => intros; cbv [weight]; autorewrite with zsimplify; auto using Z.pow_mul_r with omega + end; lazy; try split; congruence. + Qed. + + (* + (* TODO: delete if unneeded *) + (* Note: If this is not factored out, then for some reason Qed takes forever in barrett_red256_correct_full. *) + Lemma barrett_red256_correct_proj2 : + forall x y, + ZRange.type.option.is_bounded_by + (t:=base.type.prod base.type.Z base.type.Z) + (Some r[0 ~> 2 ^ machine_wordsize - 1]%zrange, Some r[0 ~> 2 ^ machine_wordsize - 1]%zrange) + (x, y) = true -> + type.app_curried + (expr.Interp (@ident.gen_interp ident.cast_outside_of_range) + barrett_red256) (x, (y, tt)) = + BarrettReduction.barrett_reduce machine_wordsize M + ((2 ^ (2 * machine_wordsize) / M) + mod 2 ^ machine_wordsize) 2 2 x y. + Proof. + intros. + destruct ((proj1 barrett_red256_correct) (x, (y, tt)) (x, (y, tt))). + { cbn; tauto. } + { cbn in *. rewrite andb_true_r. auto. } + { auto. } + Qed. + Lemma barrett_red256_correct_proj2' : + forall x y, + ZRange.type.option.is_bounded_by + (t:=base.type.prod base.type.Z base.type.Z) + (Some r[0 ~> 2 ^ machine_wordsize - 1]%zrange, Some r[0 ~> 2 ^ machine_wordsize - 1]%zrange) + (x, y) = true -> + expr.Interp (@ident.interp) barrett_red256 x y = + BarrettReduction.barrett_reduce machine_wordsize M + ((2 ^ (2 * machine_wordsize) / M) + mod 2 ^ machine_wordsize) 2 2 x y. + Proof. + intros. + erewrite <-barrett_red256_correct_proj2 by assumption. + unfold type.app_curried. exact eq_refl. + Qed. +*) + Strategy -100 [type.app_curried]. + Local Arguments is_bounded_by_bool / . + Lemma barrett_red256_correct_full : + forall (xLow xHigh : Z), + 0 <= xLow < 2 ^ machine_wordsize -> + 0 <= xHigh < M -> + expr.Interp (@ident.interp) barrett_red256 xLow xHigh = (xLow + 2 ^ machine_wordsize * xHigh) mod M. + Proof. + intros. + rewrite <-barrett_reduce_correct_specialized by assumption. + destruct (proj1 barrett_red256_correct (xLow, (xHigh, tt)) (xLow, (xHigh, tt))) as [H1 H2]. + { repeat split. } + { cbn -[Z.pow]. + rewrite !andb_true_iff. + assert (M < 2^machine_wordsize) by (vm_compute; reflexivity). + repeat apply conj; Z.ltb_to_lt; trivial; omega. } + { etransitivity; [ eapply H2 | ]. (* need Strategy -100 [type.app_curried]. for this to be fast *) + generalize BarrettReduction.barrett_reduce; vm_compute; reflexivity. } + Qed. + + Definition barrett_red256_fancy' (xLow xHigh RegMuLow RegMod RegZero error : positive) := + of_Expr 6%positive + (make_consts [(RegMuLow, muLow); (RegMod, M); (RegZero, 0)]) + barrett_red256 + (xLow, (xHigh, tt)) + error. + Derive barrett_red256_fancy + SuchThat (forall xLow xHigh RegMuLow RegMod RegZero, + barrett_red256_fancy xLow xHigh RegMuLow RegMod RegZero = barrett_red256_fancy' xLow xHigh RegMuLow RegMod RegZero) + As barrett_red256_fancy_eq. + Proof. + intros. + lazy - [Fancy.ADD Fancy.ADDC Fancy.SUB Fancy.SUBC + Fancy.MUL128LL Fancy.MUL128LU Fancy.MUL128UL Fancy.MUL128UU + Fancy.RSHI Fancy.SELC Fancy.SELM Fancy.SELL Fancy.ADDM]. + reflexivity. + Qed. + + Local Ltac wf_subgoal := + repeat match goal with + | _ => progress cbn [fst snd] + | |- LanguageWf.Compilers.expr.wf _ _ _ => + econstructor; try solve [econstructor]; [ ] + | |- LanguageWf.Compilers.expr.wf _ _ _ => + solve [econstructor] + | |- In _ _ => auto 50 using in_eq, in_cons + end. + Local Ltac valid_expr_subgoal := + repeat match goal with + | _ => progress intros + | |- context [valid_ident] => econstructor + | |- context[valid_scalar] => econstructor + | |- context [valid_carry] => econstructor + | _ => reflexivity + | |- _ <> None => cbn; congruence + | |- of_prefancy_scalar _ _ _ _ = _ => cbn; solve [eauto] + end. + + (* TODO: don't rely on the C, M, and L flags *) + Lemma barrett_red256_fancy_correct : + forall xLow xHigh error, + 0 <= xLow < 2 ^ machine_wordsize -> + 0 <= xHigh < M -> + let RegZero := 1%positive in + let RegMod := 2%positive in + let RegMuLow := 3%positive in + let RegxHigh := 4%positive in + let RegxLow := 5%positive in + let consts_list := [(RegMuLow, muLow); (RegMod, M); (RegZero, 0)] in + let arg_list := [(RegxHigh, xHigh); (RegxLow, xLow)] in + let ctx := make_ctx (consts_list ++ arg_list) in + let carry_flag := false in (* TODO: don't rely on this value, given it's unused *) + let last_wrote := (fun x : Fancy.CC.code => + match x with + | Fancy.CC.C => RegZero + | _ => RegxHigh (* xHigh needs to have written M; others unused *) + end) in + let cc := make_cc last_wrote ctx carry_flag in + interp Pos.eqb wordmax Fancy.cc_spec (barrett_red256_fancy RegxLow RegxHigh RegMuLow RegMod RegZero error) cc ctx = (xLow + 2 ^ machine_wordsize * xHigh) mod M. + Proof. + intros. + rewrite barrett_red256_fancy_eq. + cbv [barrett_red256_fancy']. + rewrite <-barrett_red256_correct_full by auto. + eapply of_Expr_correct with (x2 := (xLow, (xHigh, tt))). + { cbn; intros; subst RegZero RegMod RegMuLow RegxHigh RegxLow. + intuition; Prod.inversion_prod; subst; cbv. break_innermost_match; congruence. } + { cbn; intros; subst RegZero RegMod RegMuLow RegxHigh RegxLow. + intuition; Prod.inversion_prod; subst; cbv; congruence. } + { cbn; intros; subst RegZero RegMod RegMuLow RegxHigh RegxLow. tauto. } + { cbn; intros; subst RegZero RegMod RegMuLow RegxHigh RegxLow. + intuition; Prod.inversion_prod; subst; cbv; congruence. } + { cbn; intros; subst RegZero RegMod RegMuLow RegxHigh RegxLow. + match goal with |- context [_ mod ?m] => change m with (2 ^ machine_wordsize) end. + assert (M < 2 ^ machine_wordsize) by (cbv; congruence). + assert (0 <= muLow < 2 ^ machine_wordsize) by (split; cbv; congruence). + intuition; Prod.inversion_prod; subst; apply Z.mod_small; omega. } + { cbn; intros; subst RegZero RegMod RegMuLow RegxHigh RegxLow. + match goal with |- context [_ mod ?m] => change m with (2 ^ machine_wordsize) end. + assert (M < 2 ^ machine_wordsize) by (cbv; congruence). + assert (0 <= muLow < 2 ^ machine_wordsize) by (split; cbv; congruence). + intuition; Prod.inversion_prod; subst; apply Z.mod_small; omega. } + { cbn. + repeat match goal with + | _ => apply expr.WfLetIn + | _ => progress wf_subgoal + | _ => econstructor + end. } + { cbn. cbv [muLow M]. + repeat (econstructor; [ solve [valid_expr_subgoal] | intros ]). + econstructor. valid_expr_subgoal. } + { cbn - [barrett_red256]. cbv [id]. + f_equal. + (* TODO(jgross): switch out casts *) + (* might need to use CheckCasts.interp_eqv_without_casts? *) + replace (@ident.gen_interp cast_oor) with (@ident.interp) by admit. + reflexivity. } + Admitted. + + Import Fancy.Registers. + + Definition barrett_red256_alloc' xLow xHigh RegMuLow := + fun errorP errorR => + allocate register + positive Pos.eqb + errorR + (barrett_red256_fancy 1000%positive 1001%positive 1002%positive 1003%positive 1004%positive errorP) + [r2;r3;r4;r5;r6;r7;r8;r9;r10;r5;r11;r6;r12;r13;r14;r15;r16;r17;r18;r19;r20;r21;r22;r23;r24;r25;r26;r27;r28;r29] + (fun n => if n =? 1000 then xLow + else if n =? 1001 then xHigh + else if n =? 1002 then RegMuLow + else if n =? 1003 then RegMod + else if n =? 1004 then RegZero + else errorR). + Derive barrett_red256_alloc + SuchThat (barrett_red256_alloc = barrett_red256_alloc') + As barrett_red256_alloc_eq. + Proof. + intros. + cbv [barrett_red256_alloc' barrett_red256_fancy]. + cbn. subst barrett_red256_alloc. + reflexivity. + Qed. + + Local Ltac solve_bounds := + match goal with + | H : ?a = ?b mod ?c |- 0 <= ?a < ?c => rewrite H; apply Z.mod_pos_bound; omega + | _ => assumption + end. + + Lemma barrett_red256_alloc_equivalent errorP errorR cc_start_state start_context : + forall x xHigh RegMuLow scratchp1 scratchp2 scratchp3 scratchp4 scratchp5 extra_reg, + NoDup [x; xHigh; RegMuLow; scratchp1; scratchp2; scratchp3; scratchp4; scratchp5; extra_reg; RegMod; RegZero] -> + 0 <= start_context x < 2^machine_wordsize -> + 0 <= start_context xHigh < 2^machine_wordsize -> + 0 <= start_context RegMuLow < 2^machine_wordsize -> + ProdEquiv.interp256 (barrett_red256_alloc r0 r1 r30 errorP errorR) cc_start_state + (fun r => if reg_eqb r r0 + then start_context x + else if reg_eqb r r1 + then start_context xHigh + else if reg_eqb r r30 + then start_context RegMuLow + else start_context r) + = ProdEquiv.interp256 (Prod.MulMod x xHigh RegMuLow scratchp1 scratchp2 scratchp3 scratchp4 scratchp5) cc_start_state start_context. + Proof. + intros. + let r := eval compute in (2^machine_wordsize) in + replace (2^machine_wordsize) with r in * by reflexivity. + cbv [Prod.MulMod barrett_red256_alloc]. + + (* Extract proofs that no registers are equal to each other *) + repeat match goal with + | H : NoDup _ |- _ => inversion H; subst; clear H + | H : ~ In _ _ |- _ => cbv [In] in H + | H : ~ (_ \/ _) |- _ => apply Decidable.not_or in H; destruct H + | H : ~ False |- _ => clear H + end. + + step_both_sides. + + (* TODO: To prove equivalence between these two, we need to either relocate the RSHI instructions so they're in the same places or use instruction commutativity to push them down. *) + + Admitted. + + Local Ltac results_equiv := + match goal with + |- ?lhs = ?rhs => + match lhs with + context [spec ?li ?largs ?lcc] => + match rhs with + context [spec ?ri ?rargs ?rcc] => + replace (spec li largs lcc) with (spec ri rargs rcc) + end + end + end. + Local Ltac simplify_cc := + match goal with + |- context [CC.update ?to_write ?result ?cc_spec ?old_state] => + let e := eval cbv -[spec cc_spec CC.cc_l CC.cc_m CC.cc_z CC.cc_c] in + (CC.update to_write result cc_spec old_state) in + change (CC.update to_write result cc_spec old_state) with e + end. + + Local Ltac step := + match goal with + |- interp _ _ _ (Instr ?i ?rd1 ?args1 ?cont1) ?cc1 ?ctx1 = + interp _ _ _ (Instr ?i ?rd2 ?args2 ?cont2) ?cc2 ?ctx2 => + rewrite (interp_step _ _ i rd1 args1 cont1); + rewrite (interp_step _ _ i rd2 args2 cont2) + end; + cbn - [Fancy.interp Fancy.spec cc_spec]; + repeat progress rewrite ?reg_eqb_neq, ?reg_eqb_refl by congruence; + results_equiv; [ remember_single_result; repeat simplify_cc | try reflexivity ]. + + Lemma prod_barrett_red256_correct : + forall (cc_start_state : Fancy.CC.state) (* starting carry flags *) + (start_context : register -> Z) (* starting register values *) + (x xHigh RegMuLow scratchp1 scratchp2 scratchp3 scratchp4 scratchp5 extra_reg : register), (* registers to use in computation *) + NoDup [x; xHigh; RegMuLow; scratchp1; scratchp2; scratchp3; scratchp4; scratchp5; extra_reg; RegMod; RegZero] -> (* registers are unique *) + 0 <= start_context x < 2^machine_wordsize -> + 0 <= start_context xHigh < M -> + start_context RegMuLow = muLow -> + start_context RegMod = M -> + start_context RegZero = 0 -> + cc_start_state.(Fancy.CC.cc_m) = cc_spec CC.M (start_context xHigh) -> + let X := start_context x + 2^machine_wordsize * start_context xHigh in + ProdEquiv.interp256 (Prod.MulMod x xHigh RegMuLow scratchp1 scratchp2 scratchp3 scratchp4 scratchp5) cc_start_state start_context = X mod M. + Proof. + intros. subst X. + assert (0 <= start_context xHigh < 2^machine_wordsize) by (cbv [M] in *; cbn; omega). + let r := (eval compute in (2 ^ machine_wordsize)) in + replace (2^machine_wordsize) with r in * by reflexivity. + + erewrite <-barrett_red256_fancy_correct with (error:=100000%positive) by eauto. + rewrite <-barrett_red256_alloc_equivalent with (errorR := RegZero) (errorP := 1%positive) (extra_reg:=extra_reg) + by (auto; cbv [M muLow] in *; cbn; auto with omega). + + cbv [interp256 Translation.wordmax]. + match goal with + |- context [make_cc ?last_wrote ?ctx ?carry] => + let e := fresh in + let He := fresh in + remember (make_cc last_wrote ctx carry) as e eqn:He; + cbv [make_ctx app make_cc] in He; + cbn [Pos.eqb] in He; autorewrite with zsimplify in He; + subst e + end. + + repeat match goal with + H : context [start_context] |- _ => + rewrite <-H end. + + cbv [barrett_red256_alloc barrett_red256_fancy]. + repeat step. + reflexivity. + Qed. +End Barrett256. diff --git a/src/Fancy/Montgomery256.v b/src/Fancy/Montgomery256.v new file mode 100644 index 000000000..7e635e96f --- /dev/null +++ b/src/Fancy/Montgomery256.v @@ -0,0 +1,508 @@ +(* TODO: prune all these dependencies *) +Require Import Coq.ZArith.ZArith Coq.micromega.Lia. +Require Import Coq.derive.Derive. +Require Import Coq.Bool.Bool. +Require Import Coq.Strings.String. +Require Import Coq.Lists.List. +Require Crypto.Util.Strings.String. +Require Import Crypto.Util.Strings.Decimal. +Require Import Crypto.Util.Strings.HexString. +Require Import QArith.QArith_base QArith.Qround Crypto.Util.QUtil. +Require Import Crypto.Algebra.Ring Crypto.Util.Decidable.Bool2Prop. +Require Import Crypto.Algebra.Ring. +Require Import Crypto.Algebra.SubsetoidRing. +Require Import Crypto.Util.ZRange. +Require Import Crypto.Util.ListUtil.FoldBool. +Require Import Crypto.Util.LetIn. +Require Import Crypto.Arithmetic.PrimeFieldTheorems. +Require Import Crypto.Util.ZUtil.Tactics.LtbToLt. +Require Import Crypto.Util.ZUtil.Tactics.PullPush.Modulo. +Require Import Crypto.Util.ZUtil.Tactics.DivModToQuotRem. +Require Import Crypto.Util.ZUtil.Tactics.ZeroBounds. +Require Import Crypto.Util.Tactics.SplitInContext. +Require Import Crypto.Util.Tactics.SubstEvars. +Require Import Crypto.Util.Tactics.DestructHead. +Require Import Crypto.Util.Tuple. +Require Import Crypto.Util.ListUtil Coq.Lists.List. +Require Import Crypto.Util.Equality. +Require Import Crypto.Util.Tactics.GetGoal. +Require Import Crypto.Util.Tactics.UniquePose. +Require Import Crypto.Util.ZUtil.Rshi. +Require Import Crypto.Util.Option. +Require Import Crypto.Util.Tactics.BreakMatch. +Require Import Crypto.Util.Tactics.SpecializeBy. +Require Import Crypto.Util.ZUtil.Zselect. +Require Import Crypto.Util.ZUtil.AddModulo. +Require Import Crypto.Util.ZUtil.CC. +Require Import Crypto.Util.ZUtil.Modulo. +Require Import Crypto.Util.ZUtil.Notations. +Require Import Crypto.Util.ZUtil.Tactics.RewriteModSmall. +Require Import Crypto.Util.ZUtil.Definitions. +Require Import Crypto.Util.ZUtil.EquivModulo. +Require Import Crypto.Util.ZUtil.Tactics.SplitMinMax. +Require Import Crypto.Util.ErrorT. +Require Import Crypto.Util.Strings.Show. +Require Import Crypto.Util.ZRange.Operations. +Require Import Crypto.Util.ZRange.BasicLemmas. +Require Import Crypto.Util.ZRange.Show. +Require Import Crypto.Arithmetic. +Require Import Crypto.Fancy.Prod. +Require Import Crypto.Fancy.Spec. +Require Import Crypto.Fancy.Translation. +Require Crypto.Language. +Require Crypto.UnderLets. +Require Crypto.AbstractInterpretation. +Require Crypto.AbstractInterpretationProofs. +Require Crypto.Rewriter. +Require Crypto.MiscCompilerPasses. +Require Crypto.CStringification. +Require Export Crypto.PushButtonSynthesis. +Require Import Crypto.Util.Notations. +Import ListNotations. Local Open Scope Z_scope. + +Import Associational Positional. + +Import + Crypto.Language + Crypto.UnderLets + Crypto.AbstractInterpretation + Crypto.AbstractInterpretationProofs + Crypto.Rewriter + Crypto.MiscCompilerPasses + Crypto.CStringification. + +Import + Language.Compilers + UnderLets.Compilers + AbstractInterpretation.Compilers + AbstractInterpretationProofs.Compilers + Rewriter.Compilers + MiscCompilerPasses.Compilers + CStringification.Compilers. + +Import Compilers.defaults. +Local Coercion Z.of_nat : nat >-> Z. +Local Coercion QArith_base.inject_Z : Z >-> Q. +(* Notation "x" := (expr.Var x) (only printing, at level 9) : expr_scope. *) + +Import UnsaturatedSolinas. + +(* TODO : once Barrett is updated & working, fix Montgomery to match *) +(* +Module Montgomery256. + + Definition N := Eval lazy in (2^256-2^224+2^192+2^96-1). + Definition N':= (115792089210356248768974548684794254293921932838497980611635986753331132366849). + Definition R := Eval lazy in (2^256). + Definition R' := 115792089183396302114378112356516095823261736990586219612555396166510339686400. + Definition machine_wordsize := 256. + + Derive montred256 + SuchThat (MontgomeryReduction.rmontred_correctT N R N' machine_wordsize montred256) + As montred256_correct. + Proof. Time solve_rmontred_nocache machine_wordsize. Time Qed. + + Lemma montred'_correct_specialized R' (R'_correct : Z.equiv_modulo N (R * R') 1) : + forall (lo hi : Z), + 0 <= lo < R -> 0 <= hi < R -> 0 <= lo + R * hi < R * N -> + MontgomeryReduction.montred' N R N' (Z.log2 R) 2 2 (lo, hi) = ((lo + R * hi) * R') mod N. + Proof. + intros. + apply MontgomeryReduction.montred'_correct with (T:=lo + R * hi) (R':=R'); + try match goal with + | |- context[R'] => assumption + | |- context [lo] => + try assumption; progress autorewrite with zsimplify cancel_pair; reflexivity + end; lazy; try split; congruence. + Qed. + + (* + (* Note: If this is not factored out, then for some reason Qed takes forever in montred256_correct_full. *) + Lemma montred256_correct_proj2 : + forall xy : type.interp (type.prod type.Z type.Z), + ZRange.type.option.is_bounded_by + (t:=type.prod type.Z type.Z) + (Some r[0 ~> 2 ^ machine_wordsize - 1]%zrange, Some r[0 ~> 2 ^ machine_wordsize - 1]%zrange) + xy = true -> + expr.Interp (@ident.interp) montred256 xy = app_curried (t:=type.arrow (type.prod type.Z type.Z) type.Z) (MontgomeryReduction.montred' N R N' (Z.log2 R) 2 2) xy. + Proof. intros; destruct (montred256_correct xy); assumption. Qed. + Lemma montred256_correct_proj2' : + forall xy : type.interp (type.prod type.Z type.Z), + ZRange.type.option.is_bounded_by + (t:=type.prod type.Z type.Z) + (Some r[0 ~> 2 ^ machine_wordsize - 1]%zrange, Some r[0 ~> 2 ^ machine_wordsize - 1]%zrange) + xy = true -> + expr.Interp (@ident.interp) montred256 xy = MontgomeryReduction.montred' N R N' (Z.log2 R) 2 2 xy. + Proof. intros; rewrite montred256_correct_proj2 by assumption; unfold app_curried; exact eq_refl. Qed. + *) + Local Arguments is_bounded_by_bool / . + Lemma montred256_correct_full R' (R'_correct : Z.equiv_modulo N (R * R') 1) : + forall (lo hi : Z), + 0 <= lo < R -> 0 <= hi < R -> 0 <= lo + R * hi < R * N -> + PreFancy.Interp 256 montred256 (lo, hi) = ((lo + R * hi) * R') mod N. + Proof. + intros. + rewrite <-montred'_correct_specialized by assumption. + destruct (proj1 montred256_correct ((lo, hi), tt) ((lo, hi), tt)) as [H2 H3]. + { repeat split. } + { cbn -[Z.pow]. + rewrite !andb_true_iff. + repeat apply conj; Z.ltb_to_lt; trivial; cbv [R N machine_wordsize] in *; lia. } + { etransitivity; [ eapply H3 | ]. (* need Strategy -100 [type.app_curried]. for this to be fast *) + generalize MontgomeryReduction.montred'; vm_compute; reflexivity. } + Qed. + + (* + (* TODO : maybe move these ok_expr tactics somewhere else *) + Ltac ok_expr_step' := + match goal with + | _ => assumption + | |- _ <= _ <= _ \/ @eq zrange _ _ => + right; lazy; try split; congruence + | |- _ <= _ <= _ \/ @eq zrange _ _ => + left; lazy; try split; congruence + | |- lower r[0~>_]%zrange = 0 => reflexivity + | |- context [PreFancy.ok_ident] => constructor + | |- context [PreFancy.ok_scalar] => constructor; try omega + | |- context [PreFancy.is_halved] => eapply PreFancy.is_halved_constant; [lazy; reflexivity | ] + | |- context [PreFancy.is_halved] => constructor + | |- context [PreFancy.in_word_range] => lazy; reflexivity + | |- context [PreFancy.in_flag_range] => lazy; reflexivity + | |- context [PreFancy.get_range] => + cbn [PreFancy.get_range lower upper fst snd ZRange.map] + | x : type.interp (type.prod _ _) |- _ => destruct x + | |- (_ <=? _)%zrange = true => + match goal with + | |- context [PreFancy.get_range_var] => + cbv [is_tighter_than_bool PreFancy.has_range fst snd upper lower R N] in *; cbn; + apply andb_true_iff; split; apply Z.leb_le + | _ => lazy + end; omega || reflexivity + | |- @eq zrange _ _ => lazy; reflexivity + | |- _ <= _ => cbv [machine_wordsize]; omega + | |- _ <= _ <= _ => cbv [machine_wordsize]; omega + end; intros. + + (* TODO : maybe move these ok_expr tactics somewhere else *) + Ltac ok_expr_step := + match goal with + | |- context [PreFancy.ok_expr] => constructor; cbn [fst snd]; repeat ok_expr_step' + end; intros; cbn [Nat.max].*) + + (* + Lemma montred256_prefancy_correct : + forall (lo hi : Z), + 0 <= lo < R -> 0 <= hi < R -> 0 <= lo + R * hi < R * N -> + @PreFancy.interp machine_wordsize base.type.Z (montred256 _ @ (##lo,##hi)) = ((lo + R * hi) * R') mod N. + Proof. + intros. + + rewrite montred256_prefancy_eq; cbv [montred256_prefancy']. + erewrite PreFancy.of_Expr_correct. + { apply montred256_correct_full; try assumption; reflexivity. } + { reflexivity. } + { lazy; reflexivity. } + { lazy; reflexivity. } + { repeat constructor. } + { cbv [In N N']; intros; intuition; subst; cbv; congruence. } + { assert (340282366920938463463374607431768211455 * 2 ^ 128 <= 2 ^ machine_wordsize - 1) as shiftl_128_ok by (lazy; congruence). + repeat (ok_expr_step; [ ]). + ok_expr_step. + lazy; congruence. + constructor. + constructor. } + { lazy. omega. } + Qed. +*) + + Definition montred256_fancy' (lo hi RegMod RegPInv RegZero error : positive) := + Fancy.of_Expr 3%positive + (fun z => if z =? N then Some RegMod else if z =? N' then Some RegPInv else if z =? 0 then Some RegZero else None) + [N;N'] + montred256 + ((lo, hi)%positive, tt) + error. + Derive montred256_fancy + SuchThat (forall RegMod RegPInv RegZero, + montred256_fancy RegMod RegPInv RegZero = montred256_fancy' RegMod RegPInv RegZero) + As montred256_fancy_eq. + Proof. + intros. + lazy - [Fancy.ADD Fancy.ADDC Fancy.SUB + Fancy.MUL128LL Fancy.MUL128LU Fancy.MUL128UL Fancy.MUL128UU + Fancy.RSHI Fancy.SELC Fancy.SELM Fancy.SELL Fancy.ADDM]. + reflexivity. + Qed. + + Import Fancy.Registers. + + Definition montred256_alloc' lo hi RegPInv := + fun errorP errorR => + Fancy.allocate register + positive Pos.eqb + errorR + (montred256_fancy 1000%positive 1001%positive 1002%positive 1003%positive 1004%positive errorP) + [r2;r3;r4;r5;r6;r7;r8;r9;r10;r11;r12;r13;r14;r15;r16;r17;r18;r19;r20] + (fun n => if n =? 1000 then lo + else if n =? 1001 then hi + else if n =? 1002 then RegMod + else if n =? 1003 then RegPInv + else if n =? 1004 then RegZero + else errorR). + Derive montred256_alloc + SuchThat (montred256_alloc = montred256_alloc') + As montred256_alloc_eq. + Proof. + intros. + cbv [montred256_alloc' montred256_fancy]. + cbn. subst montred256_alloc. + reflexivity. + Qed. + + Import ProdEquiv. + + Local Ltac solve_bounds := + match goal with + | H : ?a = ?b mod ?c |- 0 <= ?a < ?c => rewrite H; apply Z.mod_pos_bound; omega + | _ => assumption + end. + + Lemma montred256_alloc_equivalent errorP errorR cc_start_state start_context : + forall lo hi y t1 t2 scratch RegPInv extra_reg, + NoDup [lo; hi; y; t1; t2; scratch; RegPInv; extra_reg; RegMod; RegZero] -> + 0 <= start_context lo < R -> + 0 <= start_context hi < R -> + 0 <= start_context RegPInv < R -> + ProdEquiv.interp256 (montred256_alloc r0 r1 r30 errorP errorR) cc_start_state + (fun r => if reg_eqb r r0 + then start_context lo + else if reg_eqb r r1 + then start_context hi + else if reg_eqb r r30 + then start_context RegPInv + else start_context r) + = ProdEquiv.interp256 (Prod.MontRed256 lo hi y t1 t2 scratch RegPInv) cc_start_state start_context. + Proof. + intros. cbv [R] in *. + cbv [Prod.MontRed256 montred256_alloc]. + + (* Extract proofs that no registers are equal to each other *) + repeat match goal with + | H : NoDup _ |- _ => inversion H; subst; clear H + | H : ~ In _ _ |- _ => cbv [In] in H + | H : ~ (_ \/ _) |- _ => apply Decidable.not_or in H; destruct H + | H : ~ False |- _ => clear H + end. + + rewrite ProdEquiv.interp_Mul256 with (tmp2:=extra_reg) by (congruence || push_value_unused). + + rewrite mullh_mulhl. step_both_sides. + rewrite mullh_mulhl. step_both_sides. + (* + step_both_sides. + step_both_sides. + + rewrite ProdEquiv.interp_Mul256x256 with (tmp2:=extra_reg) by (congruence || push_value_unused). + + rewrite mulll_comm. step_both_sides. + step_both_sides. + step_both_sides. + rewrite mulhh_comm. step_both_sides. + step_both_sides. + step_both_sides. + step_both_sides. + step_both_sides. + + + rewrite add_comm by (cbn; solve_bounds). step_both_sides. + rewrite addc_comm by (cbn; solve_bounds). step_both_sides. + step_both_sides. + step_both_sides. + step_both_sides. + + cbn; repeat progress rewrite ?reg_eqb_neq, ?reg_eqb_refl by congruence. + reflexivity.*) + Admitted. + + Import Fancy_PreFancy_Equiv. + + Definition interp_equivZZ_256 {s} := + @interp_equivZZ s 256 ltac:(cbv; congruence) 115792089237316195423570985008687907853269984665640564039457584007913129639935 ltac:(reflexivity). + Definition interp_equivZ_256 {s} := + @interp_equivZ s 256 115792089237316195423570985008687907853269984665640564039457584007913129639935 ltac:(lia) ltac:(reflexivity). + + Local Ltac simplify_op_equiv start_ctx := + cbn - [Fancy.spec ident.gen_interp Fancy.cc_spec]; + repeat match goal with H : start_ctx _ = _ |- _ => rewrite H end; + cbv - [ + Z.add_with_get_carry_full + Z.add_get_carry_full Z.sub_get_borrow_full + Z.le Z.ltb Z.leb Z.geb Z.eqb Z.land Z.shiftr Z.shiftl + Z.add Z.mul Z.div Z.sub Z.modulo Z.testbit Z.pow Z.ones + fst snd]; cbn [fst snd]; + try (replace (2 ^ (256 / 2) - 1) with (Z.ones 128) by reflexivity; rewrite !Z.land_ones by omega); + autorewrite with to_div_mod; rewrite ?Z.mod_mod, <-?Z.testbit_spec' by omega; + repeat match goal with + | H : 0 <= ?x < ?m |- context [?x mod ?m] => rewrite (Z.mod_small x m) by apply H + | |- context [?x rewrite (proj2 (Z.ltb_ge x 0)) by (break_match; Z.zero_bounds) + | _ => rewrite Z.mod_small with (b:=2) by (break_match; omega) + | |- context [ (if Z.testbit ?a ?n then 1 else 0) + ?b + ?c] => + replace ((if Z.testbit a n then 1 else 0) + b + c) with (b + c + (if Z.testbit a n then 1 else 0)) by ring + end. + + Local Ltac solve_nonneg ctx := + match goal with x := (Fancy.spec _ _ _) |- _ => subst x end; + simplify_op_equiv ctx; Z.zero_bounds. + + Local Ltac generalize_result := + let v := fresh "v" in intro v; generalize v; clear v; intro v. + + Local Ltac generalize_result_nonneg ctx := + let v := fresh "v" in + let v_nonneg := fresh "v_nonneg" in + intro v; assert (0 <= v) as v_nonneg; [solve_nonneg ctx |generalize v v_nonneg; clear v v_nonneg; intros v v_nonneg]. + + Local Ltac step_abs := + match goal with + | [ |- context G[expr.interp ?ident_interp (expr.Abs ?f) ?x] ] + => let G' := context G[expr.interp ident_interp (f x)] in + change G'; cbv beta + end. + Local Ltac step ctx := + repeat step_abs; + match goal with + | |- Fancy.interp _ _ _ (Fancy.Instr (Fancy.ADD _) _ _ (Fancy.Instr (Fancy.ADDC _) _ _ _)) _ _ = _ => + apply interp_equivZZ_256; [ simplify_op_equiv ctx | simplify_op_equiv ctx | generalize_result_nonneg ctx] + | [ |- _ = expr.interp _ (PreFancy.LetInAppIdentZ _ _ _ _ _ _) ] + => apply interp_equivZ_256; [simplify_op_equiv ctx | generalize_result] + | [ |- _ = expr.interp _ (PreFancy.LetInAppIdentZZ _ _ _ _ _ _) ] + => apply interp_equivZZ_256; [ simplify_op_equiv ctx | simplify_op_equiv ctx | generalize_result] + end. + + Local Ltac break_ifs := + repeat (break_innermost_match_step; Z.ltb_to_lt; try (exfalso; omega); []). + + Local Opaque PreFancy.interp_cast_mod. + + Lemma prod_montred256_correct : + forall (cc_start_state : Fancy.CC.state) (* starting carry flags can be anything *) + (start_context : register -> Z) (* starting register values *) + (lo hi y t1 t2 scratch RegPInv extra_reg : register), (* registers to use in computation *) + NoDup [lo; hi; y; t1; t2; scratch; RegPInv; extra_reg; RegMod; RegZero] -> (* registers must be distinct *) + start_context RegPInv = N' -> (* RegPInv needs to hold the inverse of the modulus *) + start_context RegMod = N -> (* RegMod needs to hold the modulus *) + start_context RegZero = 0 -> (* RegZero needs to hold zero *) + (0 <= start_context lo < R) -> (* low half of the input is in bounds (R=2^256) *) + (0 <= start_context hi < R) -> (* high half of the input is in bounds (R=2^256) *) + let x := (start_context lo) + R * (start_context hi) in (* x is the input (split into two registers) *) + (0 <= x < R * N) -> (* input precondition *) + (ProdEquiv.interp256 (Prod.MontRed256 lo hi y t1 t2 scratch RegPInv) cc_start_state start_context = (x * R') mod N). + Proof. + intros. subst x. cbv [N R N'] in *. + rewrite <-montred256_correct_full by (auto; vm_compute; reflexivity). + rewrite <-montred256_alloc_equivalent with (errorR := RegZero) (errorP := 1%positive) (extra_reg:=extra_reg) + by (cbv [R]; auto with omega). + cbv [ProdEquiv.interp256]. + cbv [montred256_alloc montred256 expr.Interp]. + + (*step start_context; [ break_ifs; reflexivity | ]. + step start_context; [ break_ifs; reflexivity | ]. + step start_context; [ break_ifs; reflexivity | ].*) + (*step start_context; [ break_ifs; reflexivity | ]. + step start_context; [ break_ifs; reflexivity | break_ifs; reflexivity | ]. + step start_context; [ break_ifs; reflexivity | break_ifs; reflexivity | ]. + step start_context; [ reflexivity | ]. + step start_context; [ reflexivity | ]. + step start_context; [ reflexivity | ]. + step start_context; [ reflexivity | ]. + step start_context; [ reflexivity | reflexivity | ]. + step start_context; [ reflexivity | reflexivity | ]. + step start_context; [ reflexivity | reflexivity | ]. + step start_context; [ reflexivity | reflexivity | ]. + + step start_context; [ reflexivity | reflexivity | ]. + step start_context; [ reflexivity | reflexivity | ]. + step start_context; [ break_innermost_match; Z.ltb_to_lt; omega | ]. + step start_context; [ reflexivity | | ]. + { + let r := eval cbv in (2^256) in replace (2^256) with r by reflexivity. + rewrite !Z.shiftl_0_r, !Z.mod_mod by omega. + apply Z.testbit_neg_eq_if; + let r := eval cbv in (2^256) in replace (2^256) with r by reflexivity; + auto using Z.mod_pos_bound with omega. } + step start_context; [ break_innermost_match; Z.ltb_to_lt; omega | ]. + reflexivity. + *) + Admitted. + + Import PrintingNotations. + Set Printing Width 10000. + + Print montred256. +(* +montred256 = fun var : type -> Type => (λ x : var (type.type_primitive type.Z * type.type_primitive type.Z)%ctype, + expr_let x0 := 79228162514264337593543950337 *₂₅₆ (uint128)(x₁ >> 128) in + expr_let x1 := 340282366841710300986003757985643364352 *₂₅₆ ((uint128)(x₁) & 340282366920938463463374607431768211455) in + expr_let x2 := 79228162514264337593543950337 *₂₅₆ ((uint128)(x₁) & 340282366920938463463374607431768211455) in + expr_let x3 := ADD_256 ((uint256)(((uint128)(x1) & 340282366920938463463374607431768211455) << 128), x2) in + expr_let x4 := ADD_256 ((uint256)(((uint128)(x0) & 340282366920938463463374607431768211455) << 128), x3₁) in + expr_let x5 := 79228162514264337593543950335 *₂₅₆ ((uint128)(x4₁) & 340282366920938463463374607431768211455) in + expr_let x6 := 79228162514264337593543950335 *₂₅₆ (uint128)(x4₁ >> 128) in + expr_let x7 := 340282366841710300967557013911933812736 *₂₅₆ ((uint128)(x4₁) & 340282366920938463463374607431768211455) in + expr_let x8 := 340282366841710300967557013911933812736 *₂₅₆ (uint128)(x4₁ >> 128) in + expr_let x9 := ADD_256 ((uint256)(((uint128)(x7) & 340282366920938463463374607431768211455) << 128), x5) in + expr_let x10 := ADDC_256 (x9₂, (uint128)(x7 >> 128), x8) in + expr_let x11 := ADD_256 ((uint256)(((uint128)(x6) & 340282366920938463463374607431768211455) << 128), x9₁) in + expr_let x12 := ADDC_256 (x11₂, (uint128)(x6 >> 128), x10₁) in + expr_let x13 := ADD_256 (x11₁, x₁) in + expr_let x14 := ADDC_256 (x13₂, x12₁, x₂) in + expr_let x15 := SELC (x14₂, 0, 115792089210356248762697446949407573530086143415290314195533631308867097853951) in + expr_let x16 := SUB_256 (x14₁, x15) in + ADDM (x16₁, 0, 115792089210356248762697446949407573530086143415290314195533631308867097853951))%expr + : Expr (type.uncurry (type.type_primitive type.Z * type.type_primitive type.Z -> type.type_primitive type.Z)) +*) + + Import PreFancy. + Import PreFancy.Notations. + Local Notation "'RegMod'" := (expr.Ident (ident.Literal 115792089210356248762697446949407573530086143415290314195533631308867097853951)). + Local Notation "'RegPInv'" := (expr.Ident (ident.Literal 115792089210356248768974548684794254293921932838497980611635986753331132366849)). + Local Open Scope expr_scope. + Local Notation mulhl := (#(fancy_mulhl 256)). + Local Notation mulhh := (#(fancy_mulhh 256)). + Local Notation mulll := (#(fancy_mulll 256)). + Local Notation mullh := (#(fancy_mullh 256)). + Local Notation selc := (#(fancy_selc)). + Local Notation addm := (#(fancy_addm)). + Notation add n := (#(fancy_add 256 n)). + Notation addc n := (#(fancy_addc 256 n)). + + Print montred256. + (* +montred256 = +fun var : type -> Type => +λ x : var (type.base (base.type.type_base base.type.Z * base.type.type_base base.type.Z)%etype), +mulhl@(x0, x₁, RegPInv); +mullh@(x1, x₁, RegPInv); +mulll@(x2, x₁, RegPInv); +(add 128)@(x3, x2, Lower{x1}); +(add 128)@(x4, x3₁, Lower{x0}); +mulll@(x5, RegMod, x4₁); +mullh@(x6, RegMod, x4₁); +mulhl@(x7, RegMod, x4₁); +mulhh@(x8, RegMod, x4₁); +(add 128)@(x9, x5, Lower{x7}); +(addc (-128))@(x10, carry{$x9}, x8, x7); +(add 128)@(x11, x9₁, Lower{x6}); +(addc (-128))@(x12, carry{$x11}, x10₁, x6); +(add 0)@(x13, x11₁, x₁); +(addc 0)@(x14, carry{$x13}, x12₁, x₂); +selc@(x15, (carry{$x14}, RegZero), RegMod); +#(fancy_sub 256 0)@(x16, x14₁, x15); +addm@(x17, (x16₁, RegZero), RegMod); +x17 + : Expr + (type.base (base.type.type_base base.type.Z * base.type.type_base base.type.Z)%etype -> + type.base (base.type.type_base base.type.Z))%ptype + *) + +End Montgomery256. diff --git a/src/Fancy/Prod.v b/src/Fancy/Prod.v new file mode 100644 index 000000000..bba41fa62 --- /dev/null +++ b/src/Fancy/Prod.v @@ -0,0 +1,395 @@ +(* TODO: prune all these dependencies *) +Require Import Coq.ZArith.ZArith Coq.micromega.Lia. +Require Import Coq.derive.Derive. +Require Import Coq.Bool.Bool. +Require Import Coq.Strings.String. +Require Import Coq.Lists.List. +Require Crypto.Util.Strings.String. +Require Import Crypto.Util.Strings.Decimal. +Require Import Crypto.Util.Strings.HexString. +Require Import QArith.QArith_base QArith.Qround Crypto.Util.QUtil. +Require Import Crypto.Algebra.Ring Crypto.Util.Decidable.Bool2Prop. +Require Import Crypto.Algebra.Ring. +Require Import Crypto.Algebra.SubsetoidRing. +Require Import Crypto.Util.ZRange. +Require Import Crypto.Util.ListUtil.FoldBool. +Require Import Crypto.Util.LetIn. +Require Import Crypto.Arithmetic.PrimeFieldTheorems. +Require Import Crypto.Util.ZUtil.Tactics.LtbToLt. +Require Import Crypto.Util.ZUtil.Tactics.PullPush.Modulo. +Require Import Crypto.Util.ZUtil.Tactics.DivModToQuotRem. +Require Import Crypto.Util.ZUtil.Tactics.ZeroBounds. +Require Import Crypto.Util.Tactics.SplitInContext. +Require Import Crypto.Util.Tactics.SubstEvars. +Require Import Crypto.Util.Tactics.DestructHead. +Require Import Crypto.Util.Tuple. +Require Import Crypto.Util.ListUtil Coq.Lists.List. +Require Import Crypto.Util.Equality. +Require Import Crypto.Util.Tactics.GetGoal. +Require Import Crypto.Util.Tactics.UniquePose. +Require Import Crypto.Util.ZUtil.Rshi. +Require Import Crypto.Util.Option. +Require Import Crypto.Util.Tactics.BreakMatch. +Require Import Crypto.Util.Tactics.SpecializeBy. +Require Import Crypto.Util.ZUtil.Zselect. +Require Import Crypto.Util.ZUtil.AddModulo. +Require Import Crypto.Util.ZUtil.CC. +Require Import Crypto.Util.ZUtil.Modulo. +Require Import Crypto.Util.ZUtil.Notations. +Require Import Crypto.Util.ZUtil.Tactics.RewriteModSmall. +Require Import Crypto.Util.ZUtil.Definitions. +Require Import Crypto.Util.ZUtil.EquivModulo. +Require Import Crypto.Util.ZUtil.Tactics.SplitMinMax. +Require Import Crypto.Util.ErrorT. +Require Import Crypto.Util.Strings.Show. +Require Import Crypto.Util.ZRange.Operations. +Require Import Crypto.Util.ZRange.BasicLemmas. +Require Import Crypto.Util.ZRange.Show. +Require Import Crypto.Arithmetic. +Require Import Crypto.Fancy.Spec. +Require Crypto.Language. +Require Crypto.UnderLets. +Require Crypto.AbstractInterpretation. +Require Crypto.AbstractInterpretationProofs. +Require Crypto.Rewriter. +Require Crypto.MiscCompilerPasses. +Require Crypto.CStringification. +Require Import Crypto.Util.Notations. +Import ListNotations. Local Open Scope Z_scope. + +Import Associational Positional. + +Import + Crypto.Language + Crypto.UnderLets + Crypto.AbstractInterpretation + Crypto.AbstractInterpretationProofs + Crypto.Rewriter + Crypto.MiscCompilerPasses + Crypto.CStringification. + +Import + Language.Compilers + UnderLets.Compilers + AbstractInterpretation.Compilers + AbstractInterpretationProofs.Compilers + Rewriter.Compilers + MiscCompilerPasses.Compilers + CStringification.Compilers. + +Import Compilers.defaults. +Local Coercion Z.of_nat : nat >-> Z. +Local Coercion QArith_base.inject_Z : Z >-> Q. +(* Notation "x" := (expr.Var x) (only printing, at level 9) : expr_scope. *) + +Import UnsaturatedSolinas. +Import Spec.Fancy. Import Registers. + +(* TODO : change these modules to sections *) +Module Prod. + Definition Mul256 (out src1 src2 tmp : register) (cont : Fancy.expr) : Fancy.expr := + Instr MUL128LL out (src1, src2) + (Instr MUL128UL tmp (src1, src2) + (Instr (ADD 128) out (out, tmp) + (Instr MUL128LU tmp (src1, src2) + (Instr (ADD 128) out (out, tmp) cont)))). + Definition Mul256x256 (out outHigh src1 src2 tmp : register) (cont : Fancy.expr) : Fancy.expr := + Instr MUL128LL out (src1, src2) + (Instr MUL128UU outHigh (src1, src2) + (Instr MUL128UL tmp (src1, src2) + (Instr (ADD 128) out (out, tmp) + (Instr (ADDC (-128)) outHigh (outHigh, tmp) + (Instr MUL128LU tmp (src1, src2) + (Instr (ADD 128) out (out, tmp) + (Instr (ADDC (-128)) outHigh (outHigh, tmp) cont))))))). + + Definition MontRed256 lo hi y t1 t2 scratch RegPInv : @Fancy.expr register := + Mul256 y lo RegPInv t1 + (Mul256x256 t1 t2 y RegMod scratch + (Instr (ADD 0) lo (lo, t1) + (Instr (ADDC 0) hi (hi, t2) + (Instr SELC y (RegMod, RegZero) + (Instr (SUB 0) lo (hi, y) + (Instr ADDM lo (lo, RegZero, RegMod) + (Ret lo))))))). + + (* Barrett reduction -- this is only the "reduce" part, excluding the initial multiplication. *) + Definition MulMod x xHigh RegMuLow scratchp1 scratchp2 scratchp3 scratchp4 scratchp5 : @Fancy.expr register := + let q1Bottom256 := scratchp1 in + let muSelect := scratchp2 in + let q2 := scratchp3 in + let q2High := scratchp4 in + let q2High2 := scratchp5 in + let q3 := scratchp1 in + let r2 := scratchp2 in + let r2High := scratchp3 in + let maybeM := scratchp1 in + Instr SELM muSelect (RegMuLow, RegZero) + (Instr (RSHI 255) q1Bottom256 (xHigh, x) + (Mul256x256 q2 q2High q1Bottom256 RegMuLow scratchp5 + (Instr (RSHI 255) q2High2 (RegZero, xHigh) + (Instr (ADD 0) q2High (q2High, q1Bottom256) + (Instr (ADDC 0) q2High2 (q2High2, RegZero) + (Instr (ADD 0) q2High (q2High, muSelect) + (Instr (ADDC 0) q2High2 (q2High2, RegZero) + (Instr (RSHI 1) q3 (q2High2, q2High) + (Mul256x256 r2 r2High RegMod q3 scratchp4 + (Instr (SUB 0) muSelect (x, r2) + (Instr (SUBC 0) xHigh (xHigh, r2High) + (Instr SELL maybeM (RegMod, RegZero) + (Instr (SUB 0) q3 (muSelect, maybeM) + (Instr ADDM x (q3, RegZero, RegMod) + (Ret x))))))))))))))). +End Prod. + +(* TODO : move to Fancy *) +Section interp_proofs. + Context {name} (name_eqb : name -> name -> bool) (wordmax : Z). + Let interp := interp name_eqb wordmax cc_spec. + Lemma interp_step i rd args cont cc ctx : + interp (Instr i rd args cont) cc ctx = + let result := spec i (Tuple.map ctx args) cc in + let new_cc := CC.update (writes_conditions i) result cc_spec cc in + let new_ctx := fun n => if name_eqb n rd then result mod wordmax else ctx n in interp cont new_cc new_ctx. + Proof. reflexivity. Qed. + + Lemma interp_state_equiv e : + forall cc ctx cc' ctx', + cc = cc' -> (forall r, ctx r = ctx' r) -> + interp e cc ctx = interp e cc' ctx'. + Proof. + induction e; intros; subst; cbn; [solve[auto]|]. + apply IHe; rewrite Tuple.map_ext with (g:=ctx') by auto; + [reflexivity|]. + intros; break_match; auto. + Qed. +End interp_proofs. + +Module ProdEquiv. + + Definition wordmax := 2^256. + Definition interp256 := Fancy.interp reg_eqb (2^256) cc_spec. + Lemma cc_overwrite_full x1 x2 l1 cc : + CC.update [CC.C; CC.M; CC.L; CC.Z] x2 cc_spec (CC.update l1 x1 cc_spec cc) = CC.update [CC.C; CC.M; CC.L; CC.Z] x2 cc_spec cc. + Proof. + cbv [CC.update]. cbn [CC.cc_c CC.cc_m CC.cc_l CC.cc_z]. + break_match; try match goal with H : ~ In _ _ |- _ => cbv [In] in H; tauto end. + reflexivity. + Qed. + + Definition value_unused r e : Prop := + forall x cc ctx, interp256 e cc ctx = interp256 e cc (fun r' => if reg_eqb r' r then x else ctx r'). + + Lemma value_unused_skip r i rd args cont (Hcont: value_unused r cont) : + r <> rd -> + (~ In r (Tuple.to_list _ args)) -> + value_unused r (Instr i rd args cont). + Proof. + cbv [value_unused interp256] in *; intros. + rewrite !interp_step; cbv zeta. + rewrite Hcont with (x:=x). + match goal with |- ?lhs = ?rhs => + match lhs with context [Tuple.map ?f ?t] => + match rhs with context [Tuple.map ?g ?t] => + rewrite (Tuple.map_ext_In f g) by (intros; cbv [reg_eqb]; break_match; congruence) + end end end. + apply interp_state_equiv; [ congruence | ]. + { intros; cbv [reg_eqb] in *; break_match; congruence. } + Qed. + + Lemma value_unused_overwrite r i args cont : + (~ In r (Tuple.to_list _ args)) -> + value_unused r (Instr i r args cont). + Proof. + cbv [value_unused interp256]; intros; rewrite !interp_step; cbv zeta. + match goal with |- ?lhs = ?rhs => + match lhs with context [Tuple.map ?f ?t] => + match rhs with context [Tuple.map ?g ?t] => + rewrite (Tuple.map_ext_In f g) by (intros; cbv [reg_eqb]; break_match; congruence) + end end end. + apply interp_state_equiv; [ congruence | ]. + { intros; cbv [reg_eqb] in *; break_match; congruence. } + Qed. + + Lemma value_unused_ret r r' : + r <> r' -> + value_unused r (Ret r'). + Proof. + cbv - [reg_dec]; intros. + break_match; congruence. + Qed. + + Ltac remember_results := + repeat match goal with |- context [(spec ?i ?args ?flags) mod ?w] => + let x := fresh "x" in + let y := fresh "y" in + let Heqx := fresh "Heqx" in + remember (spec i args flags) as x eqn:Heqx; + remember (x mod w) as y + end. + + Ltac do_interp_step := + rewrite interp_step; cbn - [interp spec]; + repeat progress rewrite ?reg_eqb_neq, ?reg_eqb_refl by congruence; + remember_results. + + Lemma interp_Mul256 out src1 src2 tmp tmp2 cont cc ctx: + out <> src1 -> + out <> src2 -> + out <> tmp -> + out <> tmp2 -> + src1 <> src2 -> + src1 <> tmp -> + src1 <> tmp2 -> + src2 <> tmp -> + src2 <> tmp2 -> + tmp <> tmp2 -> + value_unused tmp cont -> + value_unused tmp2 cont -> + interp256 (Prod.Mul256 out src1 src2 tmp cont) cc ctx = + interp256 ( + Instr MUL128LU tmp (src1, src2) + (Instr MUL128UL tmp2 (src1, src2) + (Instr MUL128LL out (src1, src2) + (Instr (ADD 128) out (out, tmp2) + (Instr (ADD 128) out (out, tmp) cont))))) cc ctx. + Proof. + intros; cbv [Prod.Mul256 interp256]. + repeat (do_interp_step; cbn [spec MUL128LL MUL128UL MUL128LU ADD] in * ). + + match goal with H : value_unused tmp _ |- _ => erewrite H end. + match goal with H : value_unused tmp2 _ |- _ => erewrite H end. + apply interp_state_equiv. + { rewrite !cc_overwrite_full. + f_equal. subst. lia. } + { intros; cbv [reg_eqb]. + repeat (break_match_step ltac:(fun _ => idtac); try congruence); reflexivity. } + Qed. + + Lemma interp_Mul256x256 out outHigh src1 src2 tmp tmp2 cont cc ctx: + out <> src1 -> + out <> outHigh -> + out <> src2 -> + out <> tmp -> + out <> tmp2 -> + outHigh <> src1 -> + outHigh <> src2 -> + outHigh <> tmp -> + outHigh <> tmp2 -> + src1 <> src2 -> + src1 <> tmp -> + src1 <> tmp2 -> + src2 <> tmp -> + src2 <> tmp2 -> + tmp <> tmp2 -> + value_unused tmp cont -> + value_unused tmp2 cont -> + interp256 (Prod.Mul256x256 out outHigh src1 src2 tmp cont) cc ctx = + interp256 ( + Instr MUL128LL out (src1, src2) + (Instr MUL128LU tmp (src1, src2) + (Instr MUL128UL tmp2 (src1, src2) + (Instr MUL128UU outHigh (src1, src2) + (Instr (ADD 128) out (out, tmp2) + (Instr (ADDC (-128)) outHigh (outHigh, tmp2) + (Instr (ADD 128) out (out, tmp) + (Instr (ADDC (-128)) outHigh (outHigh, tmp) cont)))))))) cc ctx. + Proof. + intros; cbv [Prod.Mul256x256 interp256]. + repeat (do_interp_step; cbn [spec MUL128LL MUL128UL MUL128LU MUL128UU ADD ADDC] in * ). + + match goal with H : value_unused tmp _ |- _ => erewrite H end. + match goal with H : value_unused tmp2 _ |- _ => erewrite H end. + apply interp_state_equiv. + { rewrite !cc_overwrite_full. + f_equal. + subst. cbn - [Z.add Z.modulo Z.testbit Z.mul Z.shiftl Fancy.lower128 Fancy.upper128]. + lia. } + { intros; cbv [reg_eqb]. + repeat (break_match_step ltac:(fun _ => idtac); try congruence); try reflexivity; [ ]. + subst. cbn - [Z.add Z.modulo Z.testbit Z.mul Z.shiftl Fancy.lower128 Fancy.upper128]. + lia. } + Qed. + + Local Ltac prove_comm H := + cbv [interp256]; rewrite !interp_step; cbn - [Fancy.interp]; + intros; rewrite H; try reflexivity. + + Lemma mulll_comm rd x y cont cc ctx : + interp256 (Fancy.Instr Fancy.MUL128LL rd (x, y) cont) cc ctx = + interp256 (Fancy.Instr Fancy.MUL128LL rd (y, x) cont) cc ctx. + Proof. prove_comm Z.mul_comm. Qed. + + Lemma mulhh_comm rd x y cont cc ctx : + interp256 (Fancy.Instr Fancy.MUL128UU rd (x, y) cont) cc ctx = + interp256 (Fancy.Instr Fancy.MUL128UU rd (y, x) cont) cc ctx. + Proof. prove_comm Z.mul_comm. Qed. + + Lemma mullh_mulhl rd x y cont cc ctx : + interp256 (Fancy.Instr Fancy.MUL128LU rd (x, y) cont) cc ctx = + interp256 (Fancy.Instr Fancy.MUL128UL rd (y, x) cont) cc ctx. + Proof. prove_comm Z.mul_comm. Qed. + + Lemma add_comm rd x y cont cc ctx : + 0 <= ctx x < 2^256 -> + 0 <= ctx y < 2^256 -> + interp256 (Fancy.Instr (Fancy.ADD 0) rd (x, y) cont) cc ctx = + interp256 (Fancy.Instr (Fancy.ADD 0) rd (y, x) cont) cc ctx. + Proof. + prove_comm Z.add_comm. + rewrite !(Z.mod_small (ctx _)) by (cbn in *; omega). + reflexivity. + Qed. + + Lemma addc_comm rd x y cont cc ctx : + 0 <= ctx x < 2^256 -> + 0 <= ctx y < 2^256 -> + interp256 (Fancy.Instr (Fancy.ADDC 0) rd (x, y) cont) cc ctx = + interp256 (Fancy.Instr (Fancy.ADDC 0) rd (y, x) cont) cc ctx. + Proof. + intros; + prove_comm (Z.add_comm (ctx x)). + rewrite !(Z.mod_small (ctx _)) by (cbn in *; omega). + reflexivity. + Qed. + + (* Tactics to help prove that something in Fancy is line-by-line equivalent to something in PreFancy *) + Ltac push_value_unused := + repeat match goal with + | |- ~ In _ _ => cbn; intuition; congruence + | _ => apply ProdEquiv.value_unused_overwrite + | _ => apply ProdEquiv.value_unused_skip; [ | congruence | ] + | _ => apply ProdEquiv.value_unused_ret; congruence + end. + + Ltac remember_single_result := + match goal with |- context [(Fancy.spec ?i ?args ?cc) mod ?w] => + let x := fresh "x" in + let y := fresh "y" in + let Heqx := fresh "Heqx" in + remember (Fancy.spec i args cc) as x eqn:Heqx; + remember (x mod w) as y + end. + Ltac step_both_sides := + match goal with |- ProdEquiv.interp256 (Fancy.Instr ?i ?rd1 ?args1 _) _ ?ctx1 = ProdEquiv.interp256 (Fancy.Instr ?i ?rd2 ?args2 _) _ ?ctx2 => + rewrite (interp_step reg_eqb wordmax i rd1 args1); rewrite (interp_step reg_eqb wordmax i rd2 args2); + cbn - [Fancy.interp Fancy.spec]; + repeat progress rewrite ?reg_eqb_neq, ?reg_eqb_refl by congruence; + remember_single_result; + lazymatch goal with + | |- context [Fancy.spec i _ _] => + let Heqa1 := fresh in + let Heqa2 := fresh in + remember (Tuple.map (n:=i.(Fancy.num_source_regs)) ctx1 args1) eqn:Heqa1; + remember (Tuple.map (n:=i.(Fancy.num_source_regs)) ctx2 args2) eqn:Heqa2; + cbn in Heqa1; cbn in Heqa2; + repeat progress rewrite ?reg_eqb_neq, ?reg_eqb_refl in Heqa1 by congruence; + repeat progress rewrite ?reg_eqb_neq, ?reg_eqb_refl in Heqa2 by congruence; + let a1 := match type of Heqa1 with _ = ?a1 => a1 end in + let a2 := match type of Heqa2 with _ = ?a2 => a2 end in + (fail 1 "arguments to " i " do not match; LHS has " a1 " and RHS has " a2) + | _ => idtac + end + end. +End ProdEquiv. + diff --git a/src/Fancy/Spec.v b/src/Fancy/Spec.v new file mode 100644 index 000000000..01147a2a4 --- /dev/null +++ b/src/Fancy/Spec.v @@ -0,0 +1,348 @@ +(* TODO: prune all these dependencies *) +Require Import Coq.ZArith.ZArith Coq.micromega.Lia. +Require Import Coq.derive.Derive. +Require Import Coq.Bool.Bool. +Require Import Coq.Strings.String. +Require Import Coq.Lists.List. +Require Crypto.Util.Strings.String. +Require Import Crypto.Util.Strings.Decimal. +Require Import Crypto.Util.Strings.HexString. +Require Import QArith.QArith_base QArith.Qround Crypto.Util.QUtil. +Require Import Crypto.Algebra.Ring Crypto.Util.Decidable.Bool2Prop. +Require Import Crypto.Algebra.Ring. +Require Import Crypto.Algebra.SubsetoidRing. +Require Import Crypto.Util.ZRange. +Require Import Crypto.Util.ListUtil.FoldBool. +Require Import Crypto.Util.LetIn. +Require Import Crypto.Arithmetic.PrimeFieldTheorems. +Require Import Crypto.Util.ZUtil.Tactics.LtbToLt. +Require Import Crypto.Util.ZUtil.Tactics.PullPush.Modulo. +Require Import Crypto.Util.ZUtil.Tactics.DivModToQuotRem. +Require Import Crypto.Util.ZUtil.Tactics.ZeroBounds. +Require Import Crypto.Util.Tactics.SplitInContext. +Require Import Crypto.Util.Tactics.SubstEvars. +Require Import Crypto.Util.Tactics.DestructHead. +Require Import Crypto.Util.Tuple. +Require Import Crypto.Util.ListUtil Coq.Lists.List. +Require Import Crypto.Util.Equality. +Require Import Crypto.Util.Tactics.GetGoal. +Require Import Crypto.Util.Tactics.UniquePose. +Require Import Crypto.Util.ZUtil.Rshi. +Require Import Crypto.Util.Option. +Require Import Crypto.Util.Tactics.BreakMatch. +Require Import Crypto.Util.Tactics.SpecializeBy. +Require Import Crypto.Util.ZUtil.Zselect. +Require Import Crypto.Util.ZUtil.AddModulo. +Require Import Crypto.Util.ZUtil.CC. +Require Import Crypto.Util.ZUtil.Modulo. +Require Import Crypto.Util.ZUtil.Notations. +Require Import Crypto.Util.ZUtil.Tactics.RewriteModSmall. +Require Import Crypto.Util.ZUtil.Definitions. +Require Import Crypto.Util.ZUtil.EquivModulo. +Require Import Crypto.Util.ZUtil.Tactics.SplitMinMax. +Require Import Crypto.Util.ErrorT. +Require Import Crypto.Util.Strings.Show. +Require Import Crypto.Util.ZRange.Operations. +Require Import Crypto.Util.ZRange.BasicLemmas. +Require Import Crypto.Util.ZRange.Show. +Require Import Crypto.Arithmetic. +Require Crypto.Language. +Require Crypto.UnderLets. +Require Crypto.AbstractInterpretation. +Require Crypto.AbstractInterpretationProofs. +Require Crypto.Rewriter. +Require Crypto.MiscCompilerPasses. +Require Crypto.CStringification. +Require Export Crypto.PushButtonSynthesis. +Require Import Crypto.Util.Notations. +Import ListNotations. Local Open Scope Z_scope. + +Import Associational Positional. + +Import + Crypto.Language + Crypto.UnderLets + Crypto.AbstractInterpretation + Crypto.AbstractInterpretationProofs + Crypto.Rewriter + Crypto.MiscCompilerPasses + Crypto.CStringification. + +Import + Language.Compilers + UnderLets.Compilers + AbstractInterpretation.Compilers + AbstractInterpretationProofs.Compilers + Rewriter.Compilers + MiscCompilerPasses.Compilers + CStringification.Compilers. + +Import Compilers.defaults. +Local Coercion Z.of_nat : nat >-> Z. +Local Coercion QArith_base.inject_Z : Z >-> Q. +(* Notation "x" := (expr.Var x) (only printing, at level 9) : expr_scope. *) + +Import UnsaturatedSolinas. + +Module Fancy. + + Module CC. + Inductive code : Type := + | C : code + | M : code + | L : code + | Z : code + . + + Record state := + { cc_c : bool; cc_m : bool; cc_l : bool; cc_z : bool }. + + Definition code_dec (x y : code) : {x = y} + {x <> y}. + Proof. destruct x, y; try apply (left eq_refl); right; congruence. Defined. + + Definition update (to_write : list code) (result : BinInt.Z) (cc_spec : code -> BinInt.Z -> bool) (old_state : state) + : state := + {| + cc_c := if (In_dec code_dec C to_write) + then cc_spec C result + else old_state.(cc_c); + cc_m := if (In_dec code_dec M to_write) + then cc_spec M result + else old_state.(cc_m); + cc_l := if (In_dec code_dec L to_write) + then cc_spec L result + else old_state.(cc_l); + cc_z := if (In_dec code_dec Z to_write) + then cc_spec Z result + else old_state.(cc_z) + |}. + + End CC. + + Record instruction := + { + num_source_regs : nat; + writes_conditions : list CC.code; + spec : tuple Z num_source_regs -> CC.state -> Z + }. + + Section expr. + Context {name : Type} (name_eqb : name -> name -> bool) (wordmax : Z) (cc_spec : CC.code -> Z -> bool). + + Inductive expr := + | Ret : name -> expr + | Instr (i : instruction) + (rd : name) (* destination register *) + (args : tuple name i.(num_source_regs)) (* source registers *) + (cont : expr) (* next line *) + : expr + . + + Fixpoint interp (e : expr) (cc : CC.state) (ctx : name -> Z) : Z := + match e with + | Ret n => ctx n + | Instr i rd args cont => + let result := i.(spec) (Tuple.map ctx args) cc in + let new_cc := CC.update i.(writes_conditions) result cc_spec cc in + let new_ctx := (fun n => if name_eqb n rd then result mod wordmax else ctx n) in + interp cont new_cc new_ctx + end. + End expr. + + Section ISA. + Import CC. + + Definition cc_spec (x : CC.code) (result : BinInt.Z) : bool := + match x with + | CC.C => Z.testbit result 256 (* carry bit *) + | CC.M => Z.testbit result 255 (* most significant bit *) + | CC.L => Z.testbit result 0 (* least significant bit *) + | CC.Z => result =? 0 (* whether equal to zero *) + end. + + Definition lower128 x := (Z.land x (Z.ones 128)). + Definition upper128 x := (Z.shiftr x 128). + Local Notation "x '[C]'" := (if x.(cc_c) then 1 else 0) (at level 20). + Local Notation "x '[M]'" := (if x.(cc_m) then 1 else 0) (at level 20). + Local Notation "x '[L]'" := (if x.(cc_l) then 1 else 0) (at level 20). + Local Notation "x '[Z]'" := (if x.(cc_z) then 1 else 0) (at level 20). + Local Notation "'int'" := (BinInt.Z). + Local Notation "x << y" := ((x << y) mod (2^256)) : Z_scope. (* truncating left shift *) + + + (* Note: In the specification document, argument order gets a bit + confusing. Like here, r0 is always the first argument "source 0" + and r1 the second. But the specification of MUL128LU is: + (R[RS1][127:0] * R[RS0][255:128]) + + while the specification of SUB is: + (R[RS0] - shift(R[RS1], imm)) + + In the SUB case, r0 is really treated the first argument, but in + MUL128LU the order seems to be reversed; rather than low-high, we + take the high part of the first argument r0 and the low parts of + r1. This is also true for MUL128UL. *) + + Definition ADD (imm : int) : instruction := + {| + num_source_regs := 2; + writes_conditions := [C; M; L; Z]; + spec := (fun '(r0, r1) cc => + r0 + (r1 << imm)) + |}. + + Definition ADDC (imm : int) : instruction := + {| + num_source_regs := 2; + writes_conditions := [C; M; L; Z]; + spec := (fun '(r0, r1) cc => + r0 + (r1 << imm) + cc[C]) + |}. + + Definition SUB (imm : int) : instruction := + {| + num_source_regs := 2; + writes_conditions := [C; M; L; Z]; + spec := (fun '(r0, r1) cc => + r0 - (r1 << imm)) + |}. + + Definition SUBC (imm : int) : instruction := + {| + num_source_regs := 2; + writes_conditions := [C; M; L; Z]; + spec := (fun '(r0, r1) cc => + r0 - (r1 << imm) - cc[C]) + |}. + + + Definition MUL128LL : instruction := + {| + num_source_regs := 2; + writes_conditions := [M; L; Z]; + spec := (fun '(r0, r1) cc => + (lower128 r0) * (lower128 r1)) + |}. + + Definition MUL128LU : instruction := + {| + num_source_regs := 2; + writes_conditions := [M; L; Z]; + spec := (fun '(r0, r1) cc => + (lower128 r1) * (upper128 r0)) (* see note *) + |}. + + Definition MUL128UL : instruction := + {| + num_source_regs := 2; + writes_conditions := [M; L; Z]; + spec := (fun '(r0, r1) cc => + (upper128 r1) * (lower128 r0)) (* see note *) + |}. + + Definition MUL128UU : instruction := + {| + num_source_regs := 2; + writes_conditions := [M; L; Z]; + spec := (fun '(r0, r1) cc => + (upper128 r0) * (upper128 r1)) + |}. + + (* Note : Unlike the other operations, the output of RSHI is + truncated in the specification. This is not strictly necessary, + since the interpretation function truncates the output + anyway. However, it is useful to make the definition line up + exactly with Z.rshi. *) + Definition RSHI (imm : int) : instruction := + {| + num_source_regs := 2; + writes_conditions := [M; L; Z]; + spec := (fun '(r0, r1) cc => + (((2^256 * r0) + r1) >> imm) mod (2^256)) + |}. + + Definition SELC : instruction := + {| + num_source_regs := 2; + writes_conditions := []; + spec := (fun '(r0, r1) cc => + if cc[C] =? 1 then r0 else r1) + |}. + + Definition SELM : instruction := + {| + num_source_regs := 2; + writes_conditions := []; + spec := (fun '(r0, r1) cc => + if cc[M] =? 1 then r0 else r1) + |}. + + Definition SELL : instruction := + {| + num_source_regs := 2; + writes_conditions := []; + spec := (fun '(r0, r1) cc => + if cc[L] =? 1 then r0 else r1) + |}. + + (* TODO : treat the MOD register specially, like CC *) + Definition ADDM : instruction := + {| + num_source_regs := 3; + writes_conditions := [M; L; Z]; + spec := (fun '(r0, r1, MOD) cc => + let ra := r0 + r1 in + if ra >=? MOD + then ra - MOD + else ra) + |}. + + End ISA. + + Module Registers. + Inductive register : Type := + | r0 : register + | r1 : register + | r2 : register + | r3 : register + | r4 : register + | r5 : register + | r6 : register + | r7 : register + | r8 : register + | r9 : register + | r10 : register + | r11 : register + | r12 : register + | r13 : register + | r14 : register + | r15 : register + | r16 : register + | r17 : register + | r18 : register + | r19 : register + | r20 : register + | r21 : register + | r22 : register + | r23 : register + | r24 : register + | r25 : register + | r26 : register + | r27 : register + | r28 : register + | r29 : register + | r30 : register + | RegZero : register (* r31 *) + | RegMod : register + . + + Definition reg_dec (x y : register) : {x = y} + {x <> y}. + Proof. destruct x, y; try (apply left; congruence); right; congruence. Defined. + Definition reg_eqb x y := if reg_dec x y then true else false. + + Lemma reg_eqb_neq x y : x <> y -> reg_eqb x y = false. + Proof. cbv [reg_eqb]; break_match; congruence. Qed. + Lemma reg_eqb_refl x : reg_eqb x x = true. + Proof. cbv [reg_eqb]; break_match; congruence. Qed. + End Registers. +End Fancy. diff --git a/src/Fancy/Translation.v b/src/Fancy/Translation.v new file mode 100644 index 000000000..96817a8be --- /dev/null +++ b/src/Fancy/Translation.v @@ -0,0 +1,1246 @@ +(* TODO: prune all these dependencies *) +Require Import Coq.ZArith.ZArith Coq.micromega.Lia. +Require Import Coq.derive.Derive. +Require Import Coq.Bool.Bool. +Require Import Coq.Strings.String. +Require Import Coq.Lists.List. +Require Crypto.Util.Strings.String. +Require Import Crypto.Util.Strings.Decimal. +Require Import Crypto.Util.Strings.HexString. +Require Import QArith.QArith_base QArith.Qround Crypto.Util.QUtil. +Require Import Crypto.Algebra.Ring Crypto.Util.Decidable.Bool2Prop. +Require Import Crypto.Algebra.Ring. +Require Import Crypto.Algebra.SubsetoidRing. +Require Import Crypto.Util.ZRange. +Require Import Crypto.Util.ListUtil.FoldBool. +Require Import Crypto.Util.LetIn. +Require Import Crypto.Arithmetic.PrimeFieldTheorems. +Require Import Crypto.Util.ZUtil.Tactics.LtbToLt. +Require Import Crypto.Util.ZUtil.Tactics.PullPush.Modulo. +Require Import Crypto.Util.ZUtil.Tactics.DivModToQuotRem. +Require Import Crypto.Util.ZUtil.Tactics.ZeroBounds. +Require Import Crypto.Util.Tactics.SplitInContext. +Require Import Crypto.Util.Tactics.SubstEvars. +Require Import Crypto.Util.Tactics.DestructHead. +Require Import Crypto.Util.Tuple. +Require Import Crypto.Util.ListUtil Coq.Lists.List. +Require Import Crypto.Util.Equality. +Require Import Crypto.Util.Tactics.GetGoal. +Require Import Crypto.Util.Tactics.UniquePose. +Require Import Crypto.Util.ZUtil.Rshi. +Require Import Crypto.Util.Option. +Require Import Crypto.Util.Tactics.BreakMatch. +Require Import Crypto.Util.Tactics.SpecializeBy. +Require Import Crypto.Util.ZUtil.Zselect. +Require Import Crypto.Util.ZUtil.AddModulo. +Require Import Crypto.Util.ZUtil.CC. +Require Import Crypto.Util.ZUtil.Modulo. +Require Import Crypto.Util.ZUtil.Notations. +Require Import Crypto.Util.ZUtil.Tactics.RewriteModSmall. +Require Import Crypto.Util.ZUtil.Definitions. +Require Import Crypto.Util.ZUtil.EquivModulo. +Require Import Crypto.Util.ZUtil.Tactics.SplitMinMax. +Require Import Crypto.Util.ErrorT. +Require Import Crypto.Util.Strings.Show. +Require Import Crypto.Util.ZRange.Operations. +Require Import Crypto.Util.ZRange.BasicLemmas. +Require Import Crypto.Util.ZRange.Show. +Require Import Crypto.Arithmetic. +Require Import Crypto.Fancy.Spec. +Require Crypto.Language. +Require Crypto.UnderLets. +Require Crypto.AbstractInterpretation. +Require Crypto.AbstractInterpretationProofs. +Require Crypto.Rewriter. +Require Crypto.MiscCompilerPasses. +Require Crypto.CStringification. +Require Export Crypto.PushButtonSynthesis. +Require Import Crypto.Util.Notations. +Import ListNotations. Local Open Scope Z_scope. + +Import Associational Positional. + +Import + Crypto.Language + Crypto.UnderLets + Crypto.AbstractInterpretation + Crypto.AbstractInterpretationProofs + Crypto.Rewriter + Crypto.MiscCompilerPasses + Crypto.CStringification. + +Import + Language.Compilers + UnderLets.Compilers + AbstractInterpretation.Compilers + AbstractInterpretationProofs.Compilers + Rewriter.Compilers + MiscCompilerPasses.Compilers + CStringification.Compilers. + +Import Compilers.defaults. +Local Coercion Z.of_nat : nat >-> Z. +Local Coercion QArith_base.inject_Z : Z >-> Q. +(* Notation "x" := (expr.Var x) (only printing, at level 9) : expr_scope. *) + +Import UnsaturatedSolinas. + +Import Spec.Fancy. + +(* TODO: organize this file *) +Section of_prefancy. + Local Notation cexpr := (@Compilers.expr.expr base.type ident.ident). + Local Notation LetInAppIdentZ S D r eidc x f + := (expr.LetIn + (A:=type.base (base.type.type_base base.type.Z)) + (B:=type.base D) + (expr.App + (s:=type.base (base.type.type_base base.type.Z)) + (d:=type.base (base.type.type_base base.type.Z)) + (expr.Ident (ident.Z_cast r)) + (expr.App + (s:=type.base S) + (d:=type.base (base.type.type_base base.type.Z)) + eidc + x)) + f). + Local Notation LetInAppIdentZZ S D r eidc x f + := (expr.LetIn + (A:=type.base (base.type.prod (base.type.type_base base.type.Z) (base.type.type_base base.type.Z))) + (B:=type.base D) + (expr.App + (s:=type.base (base.type.prod (base.type.type_base base.type.Z) (base.type.type_base base.type.Z))) + (d:=type.base (base.type.prod (base.type.type_base base.type.Z) (base.type.type_base base.type.Z))) + (expr.Ident (ident.Z_cast2 r)) + (expr.App + (s:=type.base S) + (d:=type.base (base.type.prod (base.type.type_base base.type.Z) (base.type.type_base base.type.Z))) + eidc + x)) + f). + Context (name : Type) (name_succ : name -> name) (error : name) (consts : Z -> option name). + + Fixpoint base_var (t : base.type) : Type := + match t with + | base.type.Z => name + | base.type.prod a b => base_var a * base_var b + | _ => unit + end. + Fixpoint var (t : type.type base.type) : Type := + match t with + | type.base t => base_var t + | type.arrow s d => var s -> var d + end. + Fixpoint base_error {t} : base_var t + := match t with + | base.type.Z => error + | base.type.prod A B => (@base_error A, @base_error B) + | _ => tt + end. + Fixpoint make_error {t} : var t + := match t with + | type.base _ => base_error + | type.arrow s d => fun _ => @make_error d + end. + + Fixpoint of_prefancy_scalar {t} (s : @cexpr var t) : var t + := match s in expr.expr t return var t with + | Compilers.expr.Var t v => v + | expr.App s d f x => @of_prefancy_scalar _ f (@of_prefancy_scalar _ x) + | expr.Ident t idc + => match idc in ident.ident t return var t with + | ident.Literal base.type.Z v => match consts v with + | Some n => n + | None => error + end + | ident.pair A B => fun a b => (a, b)%core + | ident.fst A B => fun v => fst v + | ident.snd A B => fun v => snd v + | ident.Z_cast r => fun v => v + | ident.Z_cast2 (r1, r2) => fun v => v + | ident.Z_land => fun x y => x + | _ => make_error + end + | expr.Abs s d f => make_error + | expr.LetIn A B x f => make_error + end%expr_pat%etype. + + (* Note : some argument orders are reversed for MUL128LU, MUL128UL, SELC, SELM, and SELL *) + Local Notation tZ := base.type.Z. + Definition of_prefancy_ident {s d : base.type} (idc : ident.ident (s -> d)) + : @cexpr var s -> option {i : instruction & tuple name i.(num_source_regs) } := + match idc in ident.ident t return match t return Type with + | type.arrow (type.base s) (type.base d) + => @cexpr var s + | _ => unit + end + -> option {i : instruction & tuple name i.(num_source_regs) } + with + | ident.fancy_add log2wordmax imm + => fun args : @cexpr var (tZ * tZ) => + if Z.eqb log2wordmax 256 + then Some (existT _ (ADD imm) (of_prefancy_scalar args)) + else None + | ident.fancy_addc log2wordmax imm + => fun args : @cexpr var (tZ * tZ * tZ) => + if Z.eqb log2wordmax 256 + then Some (existT _ (ADDC imm) (of_prefancy_scalar ((#ident.snd @ (#ident.fst @ args)), (#ident.snd @ args)))) + else None + | ident.fancy_sub log2wordmax imm + => fun args : @cexpr var (tZ * tZ) => + if Z.eqb log2wordmax 256 + then Some (existT _ (SUB imm) (of_prefancy_scalar args)) + else None + | ident.fancy_subb log2wordmax imm + => fun args : @cexpr var (tZ * tZ * tZ) => + if Z.eqb log2wordmax 256 + then Some (existT _ (SUBC imm) (of_prefancy_scalar ((#ident.snd @ (#ident.fst @ args)), (#ident.snd @ args)))) + else None + | ident.fancy_mulll log2wordmax + => fun args : @cexpr var (tZ * tZ) => + if Z.eqb log2wordmax 256 + then Some (existT _ MUL128LL (of_prefancy_scalar args)) + else None + | ident.fancy_mullh log2wordmax + => fun args : @cexpr var (tZ * tZ) => + if Z.eqb log2wordmax 256 + then Some (existT _ MUL128LU (of_prefancy_scalar ((#ident.snd @ args), (#ident.fst @ args)))) + else None + | ident.fancy_mulhl log2wordmax + => fun args : @cexpr var (tZ * tZ) => + if Z.eqb log2wordmax 256 + then Some (existT _ MUL128UL (of_prefancy_scalar ((#ident.snd @ args), (#ident.fst @ args)))) + else None + | ident.fancy_mulhh log2wordmax + => fun args : @cexpr var (tZ * tZ) => + if Z.eqb log2wordmax 256 + then Some (existT _ MUL128UU (of_prefancy_scalar args)) + else None + | ident.fancy_rshi log2wordmax imm + => fun args : @cexpr var (tZ * tZ) => + if Z.eqb log2wordmax 256 + then Some (existT _ (RSHI imm) (of_prefancy_scalar args)) + else None + | ident.fancy_selc + => fun args : @cexpr var (tZ * tZ * tZ) => Some (existT _ SELC (of_prefancy_scalar ((#ident.snd @ args), (#ident.snd @ (#ident.fst @ args))))) + | ident.fancy_selm log2wordmax + => fun args : @cexpr var (tZ * tZ * tZ) => + if Z.eqb log2wordmax 256 + then Some (existT _ SELM (of_prefancy_scalar ((#ident.snd @ args), (#ident.snd @ (#ident.fst @ args))))) + else None + | ident.fancy_sell + => fun args : @cexpr var (tZ * tZ * tZ) => Some (existT _ SELL (of_prefancy_scalar ((#ident.snd @ args), (#ident.snd @ (#ident.fst @ args))))) + | ident.fancy_addm + => fun args : @cexpr var (tZ * tZ * tZ) => Some (existT _ ADDM (of_prefancy_scalar args)) + | _ => fun _ => None + end. + + Local Notation "x <- y ; f" := (match y with Some x => f | None => Ret error end). + Definition of_prefancy_step + (of_prefancy : forall (next_name : name) {t} (e : @cexpr var t), @expr name) + (next_name : name) {t} (e : @cexpr var t) : @expr name + := let default _ := (e' <- type.try_transport (@base.try_make_transport_cps) (@cexpr var) t tZ e; + Ret (of_prefancy_scalar e')) in + match e with + | LetInAppIdentZ s d r eidc x f + => idc <- invert_expr.invert_Ident eidc; + instr_args <- @of_prefancy_ident s tZ idc x; + let i : instruction := projT1 instr_args in + let args : tuple name i.(num_source_regs) := projT2 instr_args in + Instr i next_name args (@of_prefancy (name_succ next_name) _ (f next_name)) + | LetInAppIdentZZ s d r eidc x f + => idc <- invert_expr.invert_Ident eidc; + instr_args <- @of_prefancy_ident s (tZ * tZ) idc x; + let i : instruction := projT1 instr_args in + let args : tuple name i.(num_source_regs) := projT2 instr_args in + Instr i next_name args (@of_prefancy (name_succ next_name) _ (f (next_name, next_name))) (* the second argument is for the carry, and it will not be read from directly. *) + | _ => default tt + end. + Fixpoint of_prefancy (next_name : name) {t} (e : @cexpr var t) : @expr name + := @of_prefancy_step of_prefancy next_name t e. + + Section Proofs. + Context (name_eqb : name -> name -> bool). + Context (name_lt : name -> name -> Prop) + (name_lt_trans : forall n1 n2 n3, + name_lt n1 n2 -> name_lt n2 n3 -> name_lt n1 n3) + (name_lt_irr : forall n, ~ name_lt n n) + (name_lt_succ : forall n, name_lt n (name_succ n)) + (name_eqb_eq : forall n1 n2, name_eqb n1 n2 = true -> n1 = n2) + (name_eqb_neq : forall n1 n2, name_eqb n1 n2 = false -> n1 <> n2). + Local Notation wordmax := (2^256). + Local Notation interp := (interp name_eqb wordmax cc_spec). + Local Notation uint256 := r[0~>wordmax-1]%zrange. + Local Notation uint128 := r[0~>(2 ^ (Z.log2 wordmax / 2) - 1)]%zrange. + Definition cast_oor (r : zrange) (v : Z) := v mod (upper r + 1). + Local Notation "'existZ' x" := (existT _ (type.base (base.type.type_base tZ)) x) (at level 200). + Local Notation "'existZZ' x" := (existT _ (type.base (base.type.type_base tZ * base.type.type_base tZ)%etype) x) (at level 200). + Local Notation cinterp := (expr.interp (@ident.gen_interp cast_oor)). + Definition interp_if_Z {t} (e : cexpr t) : option Z := + option_map (expr.interp (@ident.gen_interp cast_oor) (t:=tZ)) + (type.try_transport + (@base.try_make_transport_cps) + _ _ tZ e). + + Lemma interp_if_Z_Some {t} e r : + @interp_if_Z t e = Some r -> + exists e', + (type.try_transport + (@base.try_make_transport_cps) _ _ tZ e) = Some e' /\ + expr.interp (@ident.gen_interp cast_oor) (t:=tZ) e' = r. + Proof. + clear. cbv [interp_if_Z option_map]. + break_match; inversion 1; intros. + subst; eexists. tauto. + Qed. + + Inductive valid_scalar + : @cexpr var (base.type.type_base tZ) -> Prop := + | valid_scalar_literal : + forall v n, + consts v = Some n -> + valid_scalar (expr.Ident (@ident.Literal base.type.Z v)) + | valid_scalar_Var : + forall v, + valid_scalar (expr.App (expr.Ident (ident.Z_cast uint256)) (expr.Var v)) + | valid_scalar_fst : + forall v r2, + valid_scalar + (expr.App (expr.Ident (ident.Z_cast uint256)) + (expr.App (expr.Ident (@ident.fst (base.type.type_base tZ) + (base.type.type_base tZ))) + (expr.App (expr.Ident (ident.Z_cast2 (uint256, r2))) (expr.Var v)))) + . + Inductive valid_carry + : @cexpr var (base.type.type_base tZ) -> Prop := + | valid_carry_0 : consts 0 <> None -> valid_carry (expr.Ident (@ident.Literal base.type.Z 0)) + | valid_carry_1 : consts 1 <> None -> valid_carry (expr.Ident (@ident.Literal base.type.Z 1)) + | valid_carry_snd : + forall v r2, + valid_carry + (expr.App (expr.Ident (ident.Z_cast r[0~>1])) + (expr.App (expr.Ident (@ident.snd (base.type.type_base tZ) + (base.type.type_base tZ))) + (expr.App (expr.Ident (ident.Z_cast2 (r2, r[0~>1]))) (expr.Var v)))) + . + + Fixpoint interp_base (ctx : name -> Z) (cctx : name -> bool) {t} + : base_var t -> base.interp t := + match t as t0 return base_var t0 -> base.interp t0 with + | base.type.type_base tZ => fun n => ctx n + | (base.type.type_base tZ * base.type.type_base tZ)%etype => + fun v => (ctx (fst v), Z.b2z (cctx (snd v))) + | (a * b)%etype => + fun _ => DefaultValue.type.base.default + | _ => fun _ : unit => + DefaultValue.type.base.default + end. + + Definition new_write {d} : var d -> name := + match d with + | type.base (base.type.type_base tZ) => fun r => r + | type.base (base.type.type_base tZ * base.type.type_base tZ)%etype => fst + | _ => fun _ => error + end. + Definition new_cc_to_name (old_cc_to_name : CC.code -> name) (i : instruction) + {d} (new_r : var d) (x : CC.code) : name := + if (in_dec CC.code_dec x (writes_conditions i)) + then new_write new_r + else old_cc_to_name x. + + Inductive valid_ident + : forall {s d}, + (CC.code -> name) -> (* last variables that wrote to each flag *) + (var d -> CC.code -> name) -> (* new last variables that wrote to each flag *) + ident.ident (s->d) -> @cexpr var s -> Prop := + | valid_fancy_add : + forall r imm x y, + valid_scalar x -> + valid_scalar y -> + valid_ident r (new_cc_to_name r (ADD imm)) (ident.fancy_add 256 imm) (x, y)%expr_pat + | valid_fancy_addc : + forall r imm c x y, + (of_prefancy_scalar (t:= base.type.type_base tZ) c = r CC.C) -> + valid_carry c -> + valid_scalar x -> + valid_scalar y -> + valid_ident r (new_cc_to_name r (ADDC imm)) (ident.fancy_addc 256 imm) (c, x, y)%expr_pat + | valid_fancy_sub : + forall r imm x y, + valid_scalar x -> + valid_scalar y -> + valid_ident r (new_cc_to_name r (SUB imm)) (ident.fancy_sub 256 imm) (x, y)%expr_pat + | valid_fancy_subb : + forall r imm c x y, + (of_prefancy_scalar (t:= base.type.type_base tZ) c = r CC.C) -> + valid_carry c -> + valid_scalar x -> + valid_scalar y -> + valid_ident r (new_cc_to_name r (SUBC imm)) (ident.fancy_subb 256 imm) (c, x, y)%expr_pat + | valid_fancy_mulll : + forall r x y, + valid_scalar x -> + valid_scalar y -> + valid_ident r (new_cc_to_name r MUL128LL) (ident.fancy_mulll 256) (x, y)%expr_pat + | valid_fancy_mullh : + forall r x y, + valid_scalar x -> + valid_scalar y -> + valid_ident r (new_cc_to_name r MUL128LU) (ident.fancy_mullh 256) (x, y)%expr_pat + | valid_fancy_mulhl : + forall r x y, + valid_scalar x -> + valid_scalar y -> + valid_ident r (new_cc_to_name r MUL128UL) (ident.fancy_mulhl 256) (x, y)%expr_pat + | valid_fancy_mulhh : + forall r x y, + valid_scalar x -> + valid_scalar y -> + valid_ident r (new_cc_to_name r MUL128UU) (ident.fancy_mulhh 256) (x, y)%expr_pat + | valid_fancy_rshi : + forall r imm x y, + valid_scalar x -> + valid_scalar y -> + valid_ident r (new_cc_to_name r (RSHI imm)) (ident.fancy_rshi 256 imm) (x, y)%expr_pat + | valid_fancy_selc : + forall r c x y, + (of_prefancy_scalar (t:= base.type.type_base tZ) c = r CC.C) -> + valid_carry c -> + valid_scalar x -> + valid_scalar y -> + valid_ident r (new_cc_to_name r SELC) ident.fancy_selc (c, x, y)%expr_pat + | valid_fancy_selm : + forall r c x y, + (of_prefancy_scalar (t:= base.type.type_base tZ) c = r CC.M) -> + valid_scalar c -> + valid_scalar x -> + valid_scalar y -> + valid_ident r (new_cc_to_name r SELM) (ident.fancy_selm 256) (c, x, y)%expr_pat + | valid_fancy_sell : + forall r c x y, + (of_prefancy_scalar (t:= base.type.type_base tZ) c = r CC.L) -> + valid_scalar c -> + valid_scalar x -> + valid_scalar y -> + valid_ident r (new_cc_to_name r SELL) ident.fancy_sell (c, x, y)%expr_pat + | valid_fancy_addm : + forall r x y m, + valid_scalar x -> + valid_scalar y -> + valid_scalar m -> + valid_ident r (new_cc_to_name r ADDM) ident.fancy_addm (x, y, m)%expr_pat + . + + Inductive valid_expr + : forall t, + (CC.code -> name) -> (* the last variables that wrote to each flag *) + @cexpr var t -> Prop := + | valid_LetInZ_loosen : + forall s d idc r rf x f u ia, + valid_ident r rf idc x -> + 0 < u < wordmax -> + (forall x, valid_expr _ (rf x) (f x)) -> + of_prefancy_ident idc x = Some ia -> + (forall cc ctx, + (forall n v, consts v = Some n -> ctx n = v) -> + (forall n, ctx n mod wordmax = ctx n) -> + let args := Tuple.map ctx (projT2 ia) in + spec (projT1 ia) args cc mod wordmax = spec (projT1 ia) args cc mod (u+1)) -> + valid_expr _ r (LetInAppIdentZ s d r[0~>u] (expr.Ident idc) x f) + | valid_LetInZ : + forall s d idc r rf x f, + valid_ident r rf idc x -> + (forall x, valid_expr _ (rf x) (f x)) -> + valid_expr _ r (LetInAppIdentZ s d uint256 (expr.Ident idc) x f) + | valid_LetInZZ : + forall s d idc r rf x f, + valid_ident r rf idc x -> + (forall x : var (type.base (base.type.type_base tZ * base.type.type_base tZ)%etype), + fst x = snd x -> + valid_expr _ (rf x) (f x)) -> + valid_expr _ r (LetInAppIdentZZ s d (uint256, r[0~>1]) (expr.Ident idc) x f) + | valid_Ret : + forall r x, + valid_scalar x -> + valid_expr _ r x + . + + Lemma cast_oor_id v u : 0 <= v <= u -> cast_oor r[0 ~> u] v = v. + Proof. intros; cbv [cast_oor upper]. apply Z.mod_small; omega. Qed. + Lemma cast_oor_mod v u : 0 <= u -> cast_oor r[0 ~> u] v mod (u+1) = v mod (u+1). + Proof. intros; cbv [cast_oor upper]. apply Z.mod_mod; omega. Qed. + + Lemma wordmax_nonneg : 0 <= wordmax. + Proof. cbv; congruence. Qed. + + Lemma of_prefancy_scalar_correct' + (e1 : @cexpr var (type.base (base.type.type_base tZ))) + (e2 : cexpr (type.base (base.type.type_base tZ))) + G (ctx : name -> Z) (cctx : name -> bool) : + valid_scalar e1 -> + LanguageWf.Compilers.expr.wf G e1 e2 -> + (forall n v, consts v = Some n -> In (existZ (n, v)) G) -> + (forall t v1 v2, In (existT _ (type.base t) (v1, v2)) G -> interp_base ctx cctx v1 = v2) -> + (forall v1 v2, In (existZ (v1, v2)) G -> ctx v1 = v2) -> (* implied by above *) + (forall n, ctx n mod wordmax = ctx n) -> + (forall v1 v2, In (existZZ (v1, v2)) G -> ctx (fst v1) = fst v2) -> + (forall v1 v2, In (existZZ (v1, v2)) G -> Z.b2z (cctx (snd v1)) = snd v2) -> + ctx (of_prefancy_scalar e1) = cinterp e2. + Proof. + inversion 1; inversion 1; + cbv [interp_if_Z option_map]; + cbn [of_prefancy_scalar interp_base]; intros. + all: repeat first [ + progress subst + | exfalso; assumption + | progress inversion_sigma + | progress inversion_option + | progress Prod.inversion_prod + | progress LanguageInversion.Compilers.expr.inversion_expr + | progress LanguageInversion.Compilers.expr.invert_subst + | progress LanguageWf.Compilers.expr.inversion_wf_one_constr + | progress LanguageInversion.Compilers.expr.invert_match + | progress destruct_head'_sig + | progress destruct_head'_and + | progress destruct_head'_or + | progress Z.ltb_to_lt + | progress cbv [id] + | progress cbn [fst snd upper lower fst snd eq_rect projT1 projT2 expr.interp ident.interp ident.gen_interp interp_base] in * + | progress HProp.eliminate_hprop_eq + | progress break_innermost_match_hyps + | progress break_innermost_match + | match goal with H : context [_ = cinterp _] |- context [cinterp _] => + rewrite <-H by eauto; try reflexivity end + | solve [eauto using (f_equal2 pair), cast_oor_id, wordmax_nonneg] + | rewrite LanguageWf.Compilers.ident.cast_out_of_bounds_simple_0_mod + | rewrite Z.mod_mod by lia + | rewrite cast_oor_mod by (cbv; congruence) + | lia + | match goal with + H : context[ ?x mod _ = ?x ] |- _ => rewrite H end + | match goal with + | H : context [In _ _ -> _ = _] |- _ => erewrite H by eauto end + | match goal with + | H : forall v1 v2, In _ _ -> ?ctx v1 = v2 |- ?x = ?x mod ?m => + replace m with wordmax by ring; erewrite <-(H _ x) by eauto; solve [eauto] + end + | match goal with + | H : forall v1 v2, In _ _ -> ?ctx (fst v1) = fst v2, + H' : In (existZZ (_,(?x,?y))) _ |- ?x = ?x mod ?m => + replace m with wordmax by ring; + specialize (H _ _ H'); cbn [fst] in H; rewrite <-H; solve [eauto] end + ]. + Qed. + + Lemma of_prefancy_scalar_correct + (e1 : @cexpr var (type.base (base.type.type_base tZ))) + (e2 : cexpr (type.base (base.type.type_base tZ))) + G (ctx : name -> Z) cc : + valid_scalar e1 -> + LanguageWf.Compilers.expr.wf G e1 e2 -> + (forall n v, consts v = Some n -> In (existZ (n, v)) G) -> + (forall t v1 v2, In (existT _ (type.base t) (v1, v2)) G -> interp_base ctx cc v1 = v2) -> + (forall n, ctx n mod wordmax = ctx n) -> + ctx (of_prefancy_scalar e1) = cinterp e2. + Proof. + intros; match goal with H : context [interp_base _ _ _ = _] |- _ => + pose proof (H (base.type.type_base base.type.Z)); + pose proof (H (base.type.type_base base.type.Z * base.type.type_base base.type.Z)%etype); cbn [interp_base] in * + end. + eapply of_prefancy_scalar_correct'; eauto; + match goal with + | H : forall _ _, In _ _ -> (_, _) = _ |- _ => + let v1 := fresh "v" in + let v2 := fresh "v" in + intros v1 v2 ?; rewrite <-(H v1 v2) by auto + end; reflexivity. + Qed. + + Lemma of_prefancy_ident_Some {s d} idc r rf x: + @valid_ident (type.base s) (type.base d) r rf idc x -> + of_prefancy_ident idc x <> None. + Proof. + induction s; inversion 1; intros; + repeat first [ + progress subst + | progress inversion_sigma + | progress cbn [eq_rect projT1 projT2 of_prefancy_ident invert_expr.invert_Ident option_map] in * + | progress Z.ltb_to_lt + | progress break_innermost_match + | progress LanguageInversion.Compilers.type.inversion_type + | progress LanguageInversion.Compilers.expr.inversion_expr + | congruence + ]. + Qed. + + Ltac name_eqb_to_eq := + repeat match goal with + | H : name_eqb _ _ = true |- _ => apply name_eqb_eq in H + | H : name_eqb _ _ = false |- _ => apply name_eqb_neq in H + end. + Ltac inversion_of_prefancy_ident := + match goal with + | H : of_prefancy_ident _ _ = None |- _ => + eapply of_prefancy_ident_Some in H; + [ contradiction | eassumption] + end. + + Local Ltac hammer := + repeat first [ + progress subst + | progress inversion_sigma + | progress inversion_option + | progress inversion_of_prefancy_ident + | progress Prod.inversion_prod + | progress cbv [id] + | progress cbn [eq_rect projT1 projT2 expr.interp ident.interp ident.gen_interp interp_base interp invert_expr.invert_Ident interp_if_Z option_map] in * + | progress LanguageInversion.Compilers.type_beq_to_eq + | progress name_eqb_to_eq + | progress LanguageInversion.Compilers.rewrite_type_transport_correct + | progress HProp.eliminate_hprop_eq + | progress break_innermost_match_hyps + | progress break_innermost_match + | progress LanguageInversion.Compilers.type.inversion_type + | progress LanguageInversion.Compilers.expr.inversion_expr + | solve [auto] + | contradiction + ]. + Ltac prove_Ret := + repeat match goal with + | H : valid_scalar (expr.LetIn _ _) |- _ => + inversion H + | _ => progress cbn [id of_prefancy of_prefancy_step of_prefancy_scalar] + | _ => progress hammer + | H : valid_scalar (expr.Ident _) |- _ => + inversion H; clear H + | |- _ = cinterp ?f (cinterp ?x) => + transitivity + (cinterp (f @ x)%expr); + [ | reflexivity ]; + erewrite <-of_prefancy_scalar_correct by (try reflexivity; eassumption) + end. + + Lemma cast_mod u v : + 0 <= u -> + ident.cast cast_oor r[0~>u] v = v mod (u + 1). + Proof. + intros. + rewrite LanguageWf.Compilers.ident.cast_out_of_bounds_simple_0_mod by auto using cast_oor_id. + cbv [cast_oor upper]. apply Z.mod_mod. omega. + Qed. + + Lemma cc_spec_c v : + Z.b2z (cc_spec CC.C v) = (v / wordmax) mod 2. + Proof. cbv [cc_spec]; apply Z.testbit_spec'. omega. Qed. + + Lemma cc_m_zselect x z nz : + x mod wordmax = x -> + (if (if cc_spec CC.M x then 1 else 0) =? 1 then nz else z) = + Z.zselect (x >> 255) z nz. + Proof. + intro Hx_small. + transitivity (if (Z.b2z (cc_spec CC.M x) =? 1) then nz else z); [ reflexivity | ]. + cbv [cc_spec Z.zselect]. + rewrite Z.testbit_spec', Z.shiftr_div_pow2 by omega. rewrite <-Hx_small. + rewrite Div.Z.div_between_0_if by (try replace (2 * (2 ^ 255)) with wordmax by reflexivity; + auto with zarith). + break_innermost_match; Z.ltb_to_lt; try rewrite Z.mod_small in * by omega; congruence. + Qed. + + Lemma cc_l_zselect x z nz : + (if (if cc_spec CC.L x then 1 else 0) =? 1 then nz else z) = Z.zselect (x &' 1) z nz. + Proof. + transitivity (if (Z.b2z (cc_spec CC.L x) =? 1) then nz else z); [ reflexivity | ]. + transitivity (Z.zselect (x &' Z.ones 1) z nz); [ | reflexivity ]. + cbv [cc_spec Z.zselect]. rewrite Z.testbit_spec', Z.land_ones by omega. + autorewrite with zsimplify_fast. rewrite Zmod_even. + break_innermost_match; Z.ltb_to_lt; congruence. + Qed. + + Lemma b2z_range b : 0<= Z.b2z b < 2. + Proof. cbv [Z.b2z]. break_match; lia. Qed. + + + Lemma of_prefancy_scalar_carry + (c : @cexpr var (type.base (base.type.type_base tZ))) + (e : cexpr (type.base (base.type.type_base tZ))) + G (ctx : name -> Z) cctx : + valid_carry c -> + LanguageWf.Compilers.expr.wf G c e -> + (forall n0, consts 0 = Some n0 -> cctx n0 = false) -> + (forall n1, consts 1 = Some n1 -> cctx n1 = true) -> + (forall t v1 v2, In (existT _ (type.base t) (v1, v2)) G -> interp_base ctx cctx v1 = v2) -> + Z.b2z (cctx (of_prefancy_scalar c)) = cinterp e. + Proof. + inversion 1; inversion 1; intros; hammer; cbn; + repeat match goal with + | H : context [ _ = false] |- Z.b2z _ = 0 => rewrite H; reflexivity + | H : context [ _ = true] |- Z.b2z _ = 1 => rewrite H; reflexivity + | _ => progress LanguageWf.Compilers.expr.inversion_wf_one_constr + | _ => progress cbn [fst snd] + | _ => progress destruct_head'_sig + | _ => progress destruct_head'_and + | _ => progress hammer + | _ => progress LanguageInversion.Compilers.expr.invert_subst + | _ => rewrite cast_mod by (cbv; congruence) + | _ => rewrite Z.mod_mod by omega + | _ => rewrite Z.mod_small by apply b2z_range + | H : (forall _ _ _, In _ _ -> interp_base _ _ _ = _), + H' : In (existZZ (?v, _)) _ |- context [cctx (snd ?v)] => + specialize (H _ _ _ H'); cbn in H + end. + Qed. + + Ltac simplify_ident := + repeat match goal with + | _ => progress intros + | _ => progress cbn [fst snd of_prefancy_ident] in * + | _ => progress LanguageWf.Compilers.expr.inversion_wf_one_constr + | H : { _ | _ } |- _ => destruct H + | H : _ /\ _ |- _ => destruct H + | H : upper _ = _ |- _ => rewrite H + | _ => rewrite cc_spec_c by auto + | _ => rewrite cast_mod by (cbv; congruence) + | H : _ |- _ => + apply LanguageInversion.Compilers.expr.invert_Ident_Some in H + | H : _ |- _ => + apply LanguageInversion.Compilers.expr.invert_App_Some in H + | H : ?P, H' : ?P |- _ => clear H' + | _ => progress hammer + end. + + (* TODO: zero flag is a little tricky, since the value + depends both on the stored variable and the carry if there + is one. For now, since Barrett doesn't use it, we're just + pretending it doesn't exist. *) + Definition cc_good cc cctx ctx r := + CC.cc_c cc = cctx (r CC.C) /\ + CC.cc_m cc = cc_spec CC.M (ctx (r CC.M)) /\ + CC.cc_l cc = cc_spec CC.L (ctx (r CC.L)) /\ + (forall n0 : name, consts 0 = Some n0 -> cctx n0 = false) /\ + (forall n1 : name, consts 1 = Some n1 -> cctx n1 = true). + + Lemma of_prefancy_identZ_loosen_correct {s} idc: + forall (x : @cexpr var _) i ctx G cc cctx x2 r rf f u, + @valid_ident (type.base s) (type_base tZ) r rf idc x -> + LanguageWf.Compilers.expr.wf G (#idc @ x)%expr_pat x2 -> + LanguageWf.Compilers.expr.wf G #(ident.Z_cast r[0~>u]) f -> + 0 < u < wordmax -> + cc_good cc cctx ctx r -> + (forall n v, consts v = Some n -> In (existZ (n, v)) G) -> + (forall t v1 v2, In (existT _ (type.base t) (v1, v2)) G -> interp_base ctx cctx v1 = v2) -> + (forall n, ctx n mod wordmax = ctx n) -> + of_prefancy_ident idc x = Some i -> + (spec (projT1 i) (Tuple.map ctx (projT2 i)) cc mod wordmax = spec (projT1 i) (Tuple.map ctx (projT2 i)) cc mod (u+1)) -> + spec (projT1 i) (Tuple.map ctx (projT2 i)) cc mod wordmax = (cinterp f (cinterp x2)). + Proof. + Time + inversion 1; inversion 1; cbn [of_prefancy_ident]; hammer; (simplify_ident; [ ]). (* TODO : suuuuuper slow *) + all: + rewrite cast_mod by omega; + match goal with + | H : context [spec _ _ _ mod _ = _] |- ?x mod wordmax = _ mod ?m => + replace (x mod wordmax) with (x mod m) by auto + end. + all: cbn - [Z.shiftl wordmax]; cbv [cc_good] in *; destruct_head'_and; + repeat match goal with + | H : CC.cc_c _ = _ |- _ => rewrite H + | H : CC.cc_m _ = _ |- _ => rewrite H + | H : CC.cc_l _ = _ |- _ => rewrite H + | H : CC.cc_z _ = _ |- _ => rewrite H + | H: of_prefancy_scalar _ = ?r ?c |- _ => rewrite <-H + | _ => progress rewrite ?cc_m_zselect, ?cc_l_zselect by auto + | _ => progress rewrite ?Z.add_modulo_correct, ?Z.geb_leb by auto + | |- context [cinterp ?x] => + erewrite of_prefancy_scalar_correct with (e2:=x) by eauto + | |- context [cinterp ?x] => + erewrite <-of_prefancy_scalar_carry with (e:=x) by eauto + | |- context [if _ (of_prefancy_scalar _) then _ else _ ] => + cbv [Z.zselect Z.b2z]; + break_innermost_match; Z.ltb_to_lt; try reflexivity; + congruence + end; try reflexivity. + + { (* RSHI case *) + cbv [Z.rshi]. + rewrite Z.land_ones, Z.shiftl_mul_pow2 by (cbv; congruence). + change (2 ^ Z.log2 wordmax) with wordmax. + break_innermost_match; try congruence; [ ]. autorewrite with zsimplify_fast. + repeat (f_equal; try ring). } + Qed. + Lemma of_prefancy_identZ_correct {s} idc: + forall (x : @cexpr var _) i ctx G cc cctx x2 r rf f, + @valid_ident (type.base s) (type_base tZ) r rf idc x -> + LanguageWf.Compilers.expr.wf G (#idc @ x)%expr_pat x2 -> + LanguageWf.Compilers.expr.wf G #(ident.Z_cast uint256) f -> + cc_good cc cctx ctx r -> + (forall n v, consts v = Some n -> In (existZ (n, v)) G) -> + (forall t v1 v2, In (existT _ (type.base t) (v1, v2)) G -> interp_base ctx cctx v1 = v2) -> + (forall n, ctx n mod wordmax = ctx n) -> + of_prefancy_ident idc x = Some i -> + spec (projT1 i) (Tuple.map ctx (projT2 i)) cc mod wordmax = (cinterp f (cinterp x2)). + Proof. + intros; eapply of_prefancy_identZ_loosen_correct; try eassumption; [ | ]. + { cbn; omega. } { intros; f_equal; ring. } + Qed. + Lemma of_prefancy_identZZ_correct' {s} idc: + forall (x : @cexpr var _) i ctx G cc cctx x2 r rf f, + @valid_ident (type.base s) (type_base (tZ * tZ)) r rf idc x -> + LanguageWf.Compilers.expr.wf G (#idc @ x)%expr_pat x2 -> + LanguageWf.Compilers.expr.wf G #(ident.Z_cast2 (uint256, r[0~>1])) f -> + cc_good cc cctx ctx r -> + (forall n v, consts v = Some n -> In (existZ (n, v)) G) -> + (forall t v1 v2, In (existT _ (type.base t) (v1, v2)) G -> interp_base ctx cctx v1 = v2) -> + (forall n, ctx n mod wordmax = ctx n) -> + of_prefancy_ident idc x = Some i -> + spec (projT1 i) (Tuple.map ctx (projT2 i)) cc mod wordmax = fst (cinterp f (cinterp x2)) /\ + Z.b2z (cc_spec CC.C (spec (projT1 i) (Tuple.map ctx (projT2 i)) cc)) = snd (cinterp f (cinterp x2)). + Proof. + inversion 1; inversion 1; cbn [of_prefancy_ident]; intros; hammer; (simplify_ident; [ ]); + cbn - [Z.div Z.modulo]; cbv [Z.sub_with_borrow Z.add_with_carry]; + cbv [cc_good] in *; destruct_head'_and; autorewrite with zsimplify_fast. + all: repeat match goal with + | H : CC.cc_c _ = _ |- _ => rewrite H + | H: of_prefancy_scalar _ = ?r ?c |- _ => rewrite <-H + | H : LanguageWf.Compilers.expr.wf _ ?x ?e |- context [cinterp ?e] => + erewrite <-of_prefancy_scalar_correct with (e1:=x) (e2:=e) by eauto + | H : LanguageWf.Compilers.expr.wf _ ?x ?e2 |- context [cinterp ?e2] => + erewrite <-of_prefancy_scalar_carry with (c:=x) (e:=e2) by eauto + end. + all: match goal with |- context [(?x << ?n) mod ?m] => + pose proof (Z.mod_pos_bound (x << n) m ltac:(omega)) end. + all:repeat match goal with + | |- context [if _ (of_prefancy_scalar _) then _ else _ ] => + cbv [Z.zselect Z.b2z]; break_innermost_match; Z.ltb_to_lt; try congruence; [ | ] + | _ => rewrite Z.add_opp_r + | _ => rewrite Div.Z.div_sub_small by auto with zarith + | H : forall n, ?ctx n mod wordmax = ?ctx n |- context [?ctx ?m - _] => rewrite <-(H m) + | |- ((?x - ?y - ?c) / _) mod _ = - ((- ?c + ?x - ?y) / _) mod _ => + replace (-c + x - y) with (x - (y + c)) by ring; replace (x - y - c) with (x - (y + c)) by ring + | _ => split + | _ => try apply (f_equal2 Z.modulo); try apply (f_equal2 Z.div); ring + | _ => break_innermost_match; reflexivity + end. + Qed. + Lemma of_prefancy_identZZ_correct {s} idc: + forall (x : @cexpr var _) i ctx G cc cctx x2 r rf f, + @valid_ident (type.base s) (type_base (tZ * tZ)) r rf idc x -> + LanguageWf.Compilers.expr.wf G (#idc @ x)%expr_pat x2 -> + LanguageWf.Compilers.expr.wf G #(ident.Z_cast2 (uint256, r[0~>1])) f -> + cc_good cc cctx ctx r -> + (forall n v, consts v = Some n -> In (existZ (n, v)) G) -> + (forall t v1 v2, In (existT _ (type.base t) (v1, v2)) G -> interp_base ctx cctx v1 = v2) -> + (forall n, ctx n mod wordmax = ctx n) -> + of_prefancy_ident idc x = Some i -> + spec (projT1 i) (Tuple.map ctx (projT2 i)) cc mod wordmax = fst (cinterp f (cinterp x2)). + Proof. apply of_prefancy_identZZ_correct'. Qed. + Lemma of_prefancy_identZZ_correct_carry {s} idc: + forall (x : @cexpr var _) i ctx G cc cctx x2 r rf f, + @valid_ident (type.base s) (type_base (tZ * tZ)) r rf idc x -> + LanguageWf.Compilers.expr.wf G (#idc @ x)%expr_pat x2 -> + LanguageWf.Compilers.expr.wf G #(ident.Z_cast2 (uint256, r[0~>1])) f -> + cc_good cc cctx ctx r -> + (forall n v, consts v = Some n -> In (existZ (n, v)) G) -> + (forall t v1 v2, In (existT _ (type.base t) (v1, v2)) G -> interp_base ctx cctx v1 = v2) -> + (forall n, ctx n mod wordmax = ctx n) -> + of_prefancy_ident idc x = Some i -> + Z.b2z (cc_spec CC.C (spec (projT1 i) (Tuple.map ctx (projT2 i)) cc)) = snd (cinterp f (cinterp x2)). + Proof. apply of_prefancy_identZZ_correct'. Qed. + + Lemma identZZ_writes {s} idc r rf x: + @valid_ident (type.base s) (type_base (tZ * tZ)) r rf idc x -> + forall i, of_prefancy_ident idc x = Some i -> + In CC.C (writes_conditions (projT1 i)). + Proof. + inversion 1; + repeat match goal with + | _ => progress intros + | _ => progress cbn [of_prefancy_ident writes_conditions ADD ADDC SUB SUBC In] in * + | _ => progress hammer; Z.ltb_to_lt + | _ => congruence + end. + Qed. + + (* Common side conditions for cases in of_prefancy_correct *) + Local Ltac side_cond := + repeat match goal with + | _ => progress intros + | _ => progress cbn [In fst snd] in * + | H : _ \/ _ |- _ => destruct H + | [H : forall _ _, In _ ?l -> _, H' : In _ ?l |- _] => + let H'' := fresh in + pose proof H'; apply H in H''; clear H + | H : name_lt ?n ?n |- _ => + specialize (name_lt_irr n); contradiction + | _ => progress hammer + | _ => solve [eauto] + end. + + Lemma interp_base_helper G next_name ctx cctx : + (forall n v2, In (existZ (n, v2)) G -> name_lt n next_name) -> + (forall n v2, In (existZZ (n, v2)) G -> name_lt (fst n) next_name) -> + (forall n v2, In (existZZ (n, v2)) G -> fst n = snd n) -> + (forall t v1 v2, In (existT _ (type.base t) (v1, v2)) G -> interp_base ctx cctx v1 = v2) -> + (forall t v1 v2, In (existT _ (type.base t) (v1, v2)) G -> + t = base.type.type_base tZ + \/ t = (base.type.type_base tZ * base.type.type_base tZ)%etype) -> + forall t v1 v2 x xc, + In (existT (fun t : type => (var t * type.interp base.interp t)%type) (type.base t) (v1, v2)%zrange) + ((existZ (next_name, x)%zrange) :: G) -> + interp_base (fun n : name => if name_eqb n next_name then x else ctx n) + (fun n : name => if name_eqb n next_name then xc else cctx n) v1 = v2. + Proof. + intros. + repeat match goal with + | H: In _ (_ :: _) |- _ => cbn [In] in H; destruct H; [ solve [side_cond] | ] + | H : (forall t _ _, In _ ?G -> (t = _ \/ t = _)), H' : In _ ?G |- _ => + destruct (H _ _ _ H'); subst t + | H : forall _ _ _, In _ ?G -> interp_base _ _ _ = _, H' : In _ G |- _ => specialize (H _ _ _ H') + end; side_cond. + Qed. + + Lemma name_eqb_refl n : name_eqb n n = true. + Proof. case_eq (name_eqb n n); intros; name_eqb_to_eq; auto. Qed. + + Lemma valid_ident_new_cc_to_name s d r rf idc x y n : + @valid_ident (type.base s) (type.base d) r rf idc x -> + of_prefancy_ident idc x = Some y -> + rf n = new_cc_to_name r (projT1 y) n. + Proof. inversion 1; intros; hammer; simplify_ident. Qed. + + Lemma new_cc_to_name_Z_cases r i n x : + new_cc_to_name (d:=base.type.type_base tZ) r i n x + = if in_dec CC.code_dec x (writes_conditions i) + then n else r x. + Proof. reflexivity. Qed. + Lemma new_cc_to_name_ZZ_cases r i n x : + new_cc_to_name (d:=base.type.type_base tZ * base.type.type_base tZ) r i n x + = if in_dec CC.code_dec x (writes_conditions i) + then fst n else r x. + Proof. reflexivity. Qed. + + Lemma cc_good_helper cc cctx ctx r i x next_name : + (forall c, name_lt (r c) next_name) -> + (forall n v, consts v = Some n -> name_lt n next_name) -> + cc_good cc cctx ctx r -> + cc_good (CC.update (writes_conditions i) x cc_spec cc) + (fun n : name => + if name_eqb n next_name + then CC.cc_c (CC.update (writes_conditions i) x cc_spec cc) + else cctx n) + (fun n : name => if name_eqb n next_name then x mod wordmax else ctx n) + (new_cc_to_name (d:=base.type.type_base tZ) r i next_name). + Proof. + cbv [cc_good]; intros; destruct_head'_and. + rewrite !new_cc_to_name_Z_cases. + cbv [CC.update CC.cc_c CC.cc_m CC.cc_l CC.cc_z]. + repeat match goal with + | _ => split; intros + | _ => progress hammer + | H : forall c, name_lt (r c) (r ?c2) |- _ => specialize (H c2) + | H : (forall n v, consts v = Some n -> name_lt _ _), + H' : consts _ = Some _ |- _ => specialize (H _ _ H') + | H : name_lt ?n ?n |- _ => apply name_lt_irr in H; contradiction + | _ => cbv [cc_spec]; rewrite Z.mod_pow2_bits_low by omega + | _ => congruence + end. + Qed. + + Lemma of_prefancy_correct + {t} (e1 : @cexpr var t) (e2 : @cexpr _ t) r : + valid_expr _ r e1 -> + forall G, + LanguageWf.Compilers.expr.wf G e1 e2 -> + forall ctx cc cctx, + cc_good cc cctx ctx r -> + (forall n v, consts v = Some n -> In (existZ (n, v)) G) -> + (forall n v2, In (existZZ (n, v2)) G -> fst n = snd n) -> + (forall t v1 v2, In (existT _ (type.base t) (v1, v2)) G -> interp_base ctx cctx v1 = v2) -> + (forall t v1 v2, In (existT _ (type.base t) (v1, v2)) G -> + t = base.type.type_base tZ + \/ t = (base.type.type_base tZ * base.type.type_base tZ)%etype) -> + (forall n, ctx n mod wordmax = ctx n) -> + forall next_name result, + (forall c : CC.code, name_lt (r c) next_name) -> + (forall n v2, In (existZ (n, v2)) G -> name_lt n next_name) -> + (forall n v2, In (existZZ (n, v2)) G -> name_lt (fst n) next_name) -> + (interp_if_Z e2 = Some result) -> + interp (@of_prefancy next_name t e1) cc ctx = result. + Proof. + induction 1; inversion 1; cbv [interp_if_Z]; + cbn [of_prefancy of_prefancy_step]; intros; + match goal with H : context [interp_base _ _ _ = _] |- _ => + pose proof (H (base.type.type_base base.type.Z)) end; + try solve [prove_Ret]; [ | | ]; hammer; + match goal with + | H : context [interp (of_prefancy _ _) _ _ = _] + |- interp _ ?cc' ?ctx' = _ => + match goal with + | _ : context [LetInAppIdentZ _ _ _ _ _ _] |- _=> + erewrite H with + (G := (existZ (next_name, ctx' next_name)) :: G) + (e2 := _ (ctx' next_name)) + (cctx := (fun n => if name_eqb n next_name then CC.cc_c cc' else cctx n)) + | _ : context [LetInAppIdentZZ _ _ _ _ _ _] |- _=> + erewrite H with + (G := (existZZ ((next_name, next_name), (ctx' next_name, Z.b2z (CC.cc_c cc')))) :: G) + (e2 := _ (ctx' next_name, Z.b2z (CC.cc_c cc'))) + (cctx := (fun n => if name_eqb n next_name then CC.cc_c cc' else cctx n)) + end + end; + repeat match goal with + | _ => progress intros + | _ => rewrite name_eqb_refl in * + | _ => rewrite Z.testbit_spec' in * + | _ => erewrite valid_ident_new_cc_to_name by eassumption + | _ => rewrite new_cc_to_name_Z_cases + | _ => rewrite new_cc_to_name_ZZ_cases + | _ => solve [intros; eapply interp_base_helper; side_cond] + | _ => solve [intros; apply cc_good_helper; eauto] + | _ => reflexivity + | _ => solve [eauto using Z.mod_small, b2z_range] + | _ => progress autorewrite with zsimplify_fast + | _ => progress side_cond + end; [ | | ]. + { cbn - [cc_spec]; cbv [id]; cbn - [cc_spec]. + inversion wf_x; hammer. + erewrite of_prefancy_identZ_loosen_correct by eauto. + reflexivity. } + { cbn - [cc_spec]; cbv [id]; cbn - [cc_spec]. + inversion wf_x; hammer. + erewrite of_prefancy_identZ_correct by eassumption. + reflexivity. } + { cbn - [cc_spec]; cbv [id]; cbn - [cc_spec]. + match goal with H : _ |- _ => pose proof H; eapply identZZ_writes in H; [ | eassumption] end. + inversion wf_x; hammer. + erewrite of_prefancy_identZZ_correct by eassumption. + erewrite of_prefancy_identZZ_correct_carry by eassumption. + rewrite <-surjective_pairing. reflexivity. } + Qed. + End Proofs. +End of_prefancy. + +Section allocate_registers. + Context (reg name : Type) (name_eqb : name -> name -> bool) (error : reg). + Fixpoint allocate (e : @expr name) (reg_list : list reg) (name_to_reg : name -> reg) : @expr reg := + match e with + | Ret n => Ret (name_to_reg n) + | Instr i rd args cont => + match reg_list with + | r :: reg_list' => Instr i r (Tuple.map name_to_reg args) (allocate cont reg_list' (fun n => if name_eqb n rd then r else name_to_reg n)) + | nil => Ret error + end + end. +End allocate_registers. + +Definition test_prog : @expr positive := + Instr (ADD (128)) 3%positive (1, 2)%positive + (Instr (ADDC 0) 4%positive (3,1)%positive + (Ret 4%positive)). + +Definition x1 := 2^256 - 1. +Definition x2 := 2^128 - 1. +Definition wordmax := 2^256. +Definition expected := + let r3' := (x1 + (x2 << 128)) in + let r3 := r3' mod wordmax in + let c := r3' / wordmax in + let r4' := (r3 + x1 + c) in + r4' mod wordmax. +Definition actual := + interp Pos.eqb + (2^256) cc_spec test_prog {|CC.cc_c:=false; CC.cc_m:=false; CC.cc_l:=false; CC.cc_z:=false|} + (fun n => if n =? 1%positive + then x1 + else if n =? 2%positive + then x2 + else 0). +Lemma test_prog_ok : expected = actual. +Proof. reflexivity. Qed. + +Definition of_Expr {t} next_name (consts : Z -> option positive) + (e : expr.Expr t) + (x : type.for_each_lhs_of_arrow (var positive) t) + : positive -> @expr positive := + fun error => + @of_prefancy positive Pos.succ error consts next_name _ (invert_expr.smart_App_curried (e _) x). + +Section Proofs. + Fixpoint var_pairs {t var1 var2} + : type.for_each_lhs_of_arrow var1 t + -> type.for_each_lhs_of_arrow var2 t + -> list {t : Compilers.type base.type.type & (var1 t * var2 t)%type } := + match t as t0 return + (type.for_each_lhs_of_arrow var1 t0 + -> type.for_each_lhs_of_arrow var2 t0 -> _) with + | type.base _ => fun _ _ => nil + | (s -> d)%ptype => + fun x1 x2 => + existT _ _ (fst x1, fst x2) :: var_pairs (snd x1) (snd x2) + end. + + Local Notation existZ := (existT _ (type.base (base.type.type_base base.type.Z))). + Local Notation existZZ := (existT _ (type.base (base.type.type_base base.type.Z * base.type.type_base base.type.Z)%etype)). + + Fixpoint make_ctx (var_list : list (positive * Z)) : positive -> Z := + match var_list with + | [] => fun _ => 0 + | (n, v) :: l' => fun m => if (m =? n)%positive then v else make_ctx l' m + end. + + Definition make_pairs : + list (positive * Z) -> list {t : Compilers.type base.type.type & (var positive t * @type.interp base.type base.interp t)%type } := map (fun x => existZ x). + + Fixpoint make_consts (consts_list : list (positive * Z)) : Z -> option positive := + match consts_list with + | [] => fun _ => None + | (n, v) :: l' => fun x => if x =? v then Some n else make_consts l' x + end. + + Local Ltac ez := + repeat match goal with + | _ => progress intros + | _ => progress subst + | H : _ \/ _ |- _ => destruct H + | H : _ |- _ => rewrite Pos.eqb_eq in H + | H : _ |- _ => rewrite Pos.eqb_neq in H + | _ => progress break_innermost_match + | _ => progress break_match_hyps + | _ => progress inversion_sigma + | _ => progress inversion_option + | _ => progress Prod.inversion_prod + | _ => progress HProp.eliminate_hprop_eq + | _ => progress Z.ltb_to_lt + | _ => reflexivity + | _ => congruence + | _ => solve [eauto] + end. + + + Lemma make_consts_ok consts_list n v : + make_consts consts_list v = Some n -> + In (existZ (n, v)%zrange) (make_pairs consts_list). + Proof. + cbv [make_pairs]; induction consts_list as [|[ ? ? ] ?]; cbn; ez. + Qed. + + Lemma make_pairs_ok consts_list: + forall v1 v2, + In (existZ (v1, v2)%zrange) (make_pairs consts_list) -> + In (v1, v2) consts_list. + Proof. + cbv [make_pairs]. induction consts_list as [| [ n v ] ? ]; cbn; [ tauto | ]. ez. + Qed. + Lemma make_ctx_ok consts_list: + (forall n v1 v2, In (n, v1) consts_list -> + In (n, v2) consts_list -> v1 = v2) -> + forall n v, + In (n, v) consts_list -> + make_ctx consts_list n = v. + Proof. + induction consts_list as [| [ n v ] ? ]; cbn; [ tauto | ]. + repeat match goal with + | _ => progress cbn [eq_rect fst snd] in * + | _ => progress ez + end. + Qed. + + Lemma make_ctx_cases consts_list n : + make_ctx consts_list n = 0 \/ + In (n, make_ctx consts_list n) consts_list. + Proof. induction consts_list; cbn; ez. Qed. + + Lemma only_integers consts_list t v1 v2 : + In (existT (fun t : type => (var positive t * type.interp base.interp t)%type) (type.base t) + (v1, v2)%zrange) (make_pairs consts_list) -> + t = base.type.type_base base.type.Z. + Proof. + induction consts_list; cbn; [ tauto | ]. + destruct 1; congruence || tauto. + Qed. + + Lemma no_pairs consts_list v1 v2 : + In (existZZ (v1, v2)%zrange) (make_pairs consts_list) -> False. + Proof. intro H; apply only_integers in H. congruence. Qed. + + + Definition make_cc last_wrote ctx carry_flag : CC.state := + {| CC.cc_c := carry_flag; + CC.cc_m := cc_spec CC.M (ctx (last_wrote CC.M)); + CC.cc_l := cc_spec CC.L (ctx (last_wrote CC.L)); + CC.cc_z := cc_spec CC.Z (ctx (last_wrote CC.Z) + + (if (last_wrote CC.C =? last_wrote CC.Z)%positive + then wordmax * Z.b2z carry_flag else 0)); + |}. + + + Hint Resolve Pos.lt_trans Pos.lt_irrefl Pos.lt_succ_diag_r Pos.eqb_refl. + Hint Resolve in_or_app. + Hint Resolve make_consts_ok make_pairs_ok make_ctx_ok no_pairs. + (* TODO : probably not all of these preconditions are necessary -- prune them sometime *) + Lemma of_Expr_correct next_name consts_list arg_list error + (carry_flag : bool) + (last_wrote : CC.code -> positive) (* variables which last wrote to each flag; put RegZero if flag empty *) + t (e : Expr t) + (x1 : type.for_each_lhs_of_arrow (var positive) t) + (x2 : type.for_each_lhs_of_arrow _ t) result : + let e1 := (invert_expr.smart_App_curried (e _) x1) in + let e2 := (invert_expr.smart_App_curried (e _) x2) in + let ctx := make_ctx (consts_list ++ arg_list) in + let consts := make_consts consts_list in + let cc := make_cc last_wrote ctx carry_flag in + let G := make_pairs consts_list ++ make_pairs arg_list in + (forall c, last_wrote c < next_name)%positive -> + (forall n v, In (n, v) (consts_list ++ arg_list) -> (n < next_name)%positive) -> + (In (last_wrote CC.C, Z.b2z carry_flag) consts_list) -> + (forall n v1 v2, In (n, v1) (consts_list ++ arg_list) -> + In (n, v2) (consts_list ++ arg_list) -> v1 = v2) (* no duplicate names *) -> + (forall v1 v2, In (v1, v2) consts_list -> v2 mod 2 ^ 256 = v2) -> + (forall v1 v2, In (v1, v2) arg_list -> v2 mod 2 ^ 256 = v2) -> + (LanguageWf.Compilers.expr.wf G e1 e2) -> + valid_expr _ error consts _ last_wrote e1 -> + interp_if_Z e2 = Some result -> + interp Pos.eqb wordmax cc_spec (of_Expr next_name consts e x1 error) cc ctx = result. + Proof. + cbv [of_Expr]; intros. + eapply of_prefancy_correct with (name_lt := Pos.lt) + (cctx := fun n => if (n =? last_wrote CC.C)%positive + then carry_flag + else match make_consts consts_list 1 with + | Some n1 => (n =? n1)%positive + | _ => false + end); + cbv [id]; eauto; + try apply Pos.eqb_neq; intros; + try solve [apply make_ctx_ok; auto; apply make_pairs_ok; + cbv [make_pairs]; rewrite map_app; auto ]; + repeat match goal with + | H : _ |- _ => apply in_app_or in H; destruct H + | H : In _ (make_pairs _) |- context [ _ = base.type.type_base _] => apply only_integers in H + | H : In _ (make_pairs _) |- context [interp_base] => + pose proof (only_integers _ _ _ _ H); subst; cbn [interp_base] + | _ => solve [eauto] + | _ => solve [exfalso; eauto] + end. + (* TODO : clean this up *) + { cbv [cc_good make_cc]; repeat split; intros; + [ rewrite Pos.eqb_refl; reflexivity | | ]; + break_innermost_match; try rewrite Pos.eqb_eq in *; subst; try reflexivity; + repeat match goal with + | H : make_consts _ _ = Some _ |- _ => + apply make_consts_ok, make_pairs_ok in H + | _ => apply Pos.eqb_neq; intro; subst + | _ => inversion_option; congruence + end; + match goal with + | H : In (?n, ?x) consts_list, H': In (?n, ?y) consts_list, + H'' : forall n x y, In (n,x) _ -> In (n,y) _ -> x = y |- _ => + assert (x = y) by (eapply H''; eauto) + end; destruct carry_flag; cbn [Z.b2z] in *; congruence. } + { match goal with |- context [make_ctx ?l ?n] => + let H := fresh in + destruct (make_ctx_cases l n) as [H | H]; + [ rewrite H | apply in_app_or in H; destruct H ] + end; eauto. } + Qed. +End Proofs. diff --git a/src/Toplevel2.v b/src/Toplevel2.v index 592915a2b..4398b8aba 100644 --- a/src/Toplevel2.v +++ b/src/Toplevel2.v @@ -1,3 +1,4 @@ +(* TODO: prune all these dependencies *) Require Import Coq.ZArith.ZArith Coq.micromega.Lia. Require Import Coq.derive.Derive. Require Import Coq.Bool.Bool. @@ -83,3173 +84,8 @@ Local Coercion QArith_base.inject_Z : Z >-> Q. Import UnsaturatedSolinas. -(* TODO: Figure out what examples should go here *) -(* -Module X25519_64. - Definition n := 5%nat. - Definition s := 2^255. - Definition c := [(1, 19)]. - Definition machine_wordsize := 64. - Local Notation tight_bounds := (tight_bounds n s c). - Local Notation loose_bounds := (loose_bounds n s c). - Local Notation prime_bound := (prime_bound s c). - - Derive base_51_relax - SuchThat (rrelax_correctT n s c machine_wordsize base_51_relax) - As base_51_relax_correct. - Proof. Time solve_rrelax machine_wordsize. Time Qed. - Derive base_51_carry_mul - SuchThat (rcarry_mul_correctT n s c machine_wordsize base_51_carry_mul) - As base_51_carry_mul_correct. - Proof. Time solve_rcarry_mul machine_wordsize. Time Qed. - Derive base_51_carry - SuchThat (rcarry_correctT n s c machine_wordsize base_51_carry) - As base_51_carry_correct. - Proof. Time solve_rcarry machine_wordsize. Time Qed. - Derive base_51_add - SuchThat (radd_correctT n s c machine_wordsize base_51_add) - As base_51_add_correct. - Proof. Time solve_radd machine_wordsize. Time Qed. - Derive base_51_sub - SuchThat (rsub_correctT n s c machine_wordsize base_51_sub) - As base_51_sub_correct. - Proof. Time solve_rsub machine_wordsize. Time Qed. - Derive base_51_opp - SuchThat (ropp_correctT n s c machine_wordsize base_51_opp) - As base_51_opp_correct. - Proof. Time solve_ropp machine_wordsize. Time Qed. - Derive base_51_to_bytes - SuchThat (rto_bytes_correctT n s c machine_wordsize base_51_to_bytes) - As base_51_to_bytes_correct. - Proof. Time solve_rto_bytes machine_wordsize. Time Qed. - Derive base_51_from_bytes - SuchThat (rfrom_bytes_correctT n s c machine_wordsize base_51_from_bytes) - As base_51_from_bytes_correct. - Proof. Time solve_rfrom_bytes machine_wordsize. Time Qed. - Derive base_51_encode - SuchThat (rencode_correctT n s c machine_wordsize base_51_encode) - As base_51_encode_correct. - Proof. Time solve_rencode machine_wordsize. Time Qed. - Derive base_51_zero - SuchThat (rzero_correctT n s c machine_wordsize base_51_zero) - As base_51_zero_correct. - Proof. Time solve_rzero machine_wordsize. Time Qed. - Derive base_51_one - SuchThat (rone_correctT n s c machine_wordsize base_51_one) - As base_51_one_correct. - Proof. Time solve_rone machine_wordsize. Time Qed. - Lemma base_51_curve_good - : check_args n s c machine_wordsize (Success tt) = Success tt. - Proof. vm_compute; reflexivity. Qed. - - Definition base_51_good : GoodT n s c machine_wordsize - := Good n s c machine_wordsize - base_51_curve_good - base_51_carry_mul_correct - base_51_carry_correct - base_51_relax_correct - base_51_add_correct - base_51_sub_correct - base_51_opp_correct - base_51_zero_correct - base_51_one_correct - base_51_encode_correct - base_51_to_bytes_correct - base_51_from_bytes_correct. - - Print Assumptions base_51_good. - Import PrintingNotations. - Set Printing Width 80. - Open Scope string_scope. - Local Notation prime_bytes_bounds := (prime_bytes_bounds n s c). - Print base_51_to_bytes. - Print base_51_carry_mul. -(*base_51_carry_mul = -fun var : type -> Type => -(λ x x0 : var (type.base (base.type.list (base.type.type_base base.type.Z))), - expr_let x1 := (uint64)(x[[0]]) *₁₂₈ (uint64)(x0[[0]]) +₁₂₈ - ((uint64)(x[[1]]) *₁₂₈ ((uint64)(x0[[4]]) *₆₄ 19) +₁₂₈ - ((uint64)(x[[2]]) *₁₂₈ ((uint64)(x0[[3]]) *₆₄ 19) +₁₂₈ - ((uint64)(x[[3]]) *₁₂₈ ((uint64)(x0[[2]]) *₆₄ 19) +₁₂₈ - (uint64)(x[[4]]) *₁₂₈ ((uint64)(x0[[1]]) *₆₄ 19)))) in - expr_let x2 := (uint64)(x1 >> 51) +₁₂₈ - ((uint64)(x[[0]]) *₁₂₈ (uint64)(x0[[1]]) +₁₂₈ - ((uint64)(x[[1]]) *₁₂₈ (uint64)(x0[[0]]) +₁₂₈ - ((uint64)(x[[2]]) *₁₂₈ ((uint64)(x0[[4]]) *₆₄ 19) +₁₂₈ - ((uint64)(x[[3]]) *₁₂₈ ((uint64)(x0[[3]]) *₆₄ 19) +₁₂₈ - (uint64)(x[[4]]) *₁₂₈ ((uint64)(x0[[2]]) *₆₄ 19))))) in - expr_let x3 := (uint64)(x2 >> 51) +₁₂₈ - ((uint64)(x[[0]]) *₁₂₈ (uint64)(x0[[2]]) +₁₂₈ - ((uint64)(x[[1]]) *₁₂₈ (uint64)(x0[[1]]) +₁₂₈ - ((uint64)(x[[2]]) *₁₂₈ (uint64)(x0[[0]]) +₁₂₈ - ((uint64)(x[[3]]) *₁₂₈ ((uint64)(x0[[4]]) *₆₄ 19) +₁₂₈ - (uint64)(x[[4]]) *₁₂₈ ((uint64)(x0[[3]]) *₆₄ 19))))) in - expr_let x4 := (uint64)(x3 >> 51) +₁₂₈ - ((uint64)(x[[0]]) *₁₂₈ (uint64)(x0[[3]]) +₁₂₈ - ((uint64)(x[[1]]) *₁₂₈ (uint64)(x0[[2]]) +₁₂₈ - ((uint64)(x[[2]]) *₁₂₈ (uint64)(x0[[1]]) +₁₂₈ - ((uint64)(x[[3]]) *₁₂₈ (uint64)(x0[[0]]) +₁₂₈ - (uint64)(x[[4]]) *₁₂₈ ((uint64)(x0[[4]]) *₆₄ 19))))) in - expr_let x5 := (uint64)(x4 >> 51) +₁₂₈ - ((uint64)(x[[0]]) *₁₂₈ (uint64)(x0[[4]]) +₁₂₈ - ((uint64)(x[[1]]) *₁₂₈ (uint64)(x0[[3]]) +₁₂₈ - ((uint64)(x[[2]]) *₁₂₈ (uint64)(x0[[2]]) +₁₂₈ - ((uint64)(x[[3]]) *₁₂₈ (uint64)(x0[[1]]) +₁₂₈ - (uint64)(x[[4]]) *₁₂₈ (uint64)(x0[[0]]))))) in - expr_let x6 := ((uint64)(x1) & 2251799813685247) +₆₄ (uint64)(x5 >> 51) *₆₄ 19 in - expr_let x7 := (uint64)(x6 >> 51) +₆₄ ((uint64)(x2) & 2251799813685247) in - expr_let x8 := ((uint64)(x6) & 2251799813685247) in - expr_let x9 := ((uint64)(x7) & 2251799813685247) in - expr_let x10 := (uint64)(x7 >> 51) +₆₄ ((uint64)(x3) & 2251799813685247) in - expr_let x11 := ((uint64)(x4) & 2251799813685247) in - expr_let x12 := ((uint64)(x5) & 2251799813685247) in - [x8; x9; x10; x11; x12])%expr - : Expr - (type.base (base.type.list (base.type.type_base base.type.Z)) -> - type.base (base.type.list (base.type.type_base base.type.Z)) -> - type.base (base.type.list (base.type.type_base base.type.Z)))%ptype -*) - Print base_51_sub. - (* -base_51_sub = -fun var : type -> Type => -(λ x x0 : var (type.base (base.type.list (base.type.type_base base.type.Z))), - expr_let x1 := (4503599627370458 +₆₄ (uint64)(x[[0]])) -₆₄ (uint64)(x0[[0]]) in - expr_let x2 := (4503599627370494 +₆₄ (uint64)(x[[1]])) -₆₄ (uint64)(x0[[1]]) in - expr_let x3 := (4503599627370494 +₆₄ (uint64)(x[[2]])) -₆₄ (uint64)(x0[[2]]) in - expr_let x4 := (4503599627370494 +₆₄ (uint64)(x[[3]])) -₆₄ (uint64)(x0[[3]]) in - expr_let x5 := (4503599627370494 +₆₄ (uint64)(x[[4]])) -₆₄ (uint64)(x0[[4]]) in - [x1; x2; x3; x4; x5])%expr - : Expr - (type.base (base.type.list (base.type.type_base base.type.Z)) -> - type.base (base.type.list (base.type.type_base base.type.Z)) -> - type.base (base.type.list (base.type.type_base base.type.Z)))%ptype -*) - - Compute ToString.C.ToFunctionString - true true "" "fecarry_mul" [] base_51_carry_mul - None (Some loose_bounds, (Some loose_bounds, tt)). - (* -void fecarry_mul(uint64_t[5] x1, uint64_t[5] x2, uint64_t[5] x3) { - uint128_t x4 = (((uint128_t)(x1[0]) * (x2[0])) + (((uint128_t)(x1[1]) * ((x2[4]) * 0x13)) + (((uint128_t)(x1[2]) * ((x2[3]) * 0x13)) + (((uint128_t)(x1[3]) * ((x2[2]) * 0x13)) + ((uint128_t)(x1[4]) * ((x2[1]) * 0x13)))))); - uint128_t x5 = ((uint64_t)(x4 >> 51) + (((uint128_t)(x1[0]) * (x2[1])) + (((uint128_t)(x1[1]) * (x2[0])) + (((uint128_t)(x1[2]) * ((x2[4]) * 0x13)) + (((uint128_t)(x1[3]) * ((x2[3]) * 0x13)) + ((uint128_t)(x1[4]) * ((x2[2]) * 0x13))))))); - uint128_t x6 = ((uint64_t)(x5 >> 51) + (((uint128_t)(x1[0]) * (x2[2])) + (((uint128_t)(x1[1]) * (x2[1])) + (((uint128_t)(x1[2]) * (x2[0])) + (((uint128_t)(x1[3]) * ((x2[4]) * 0x13)) + ((uint128_t)(x1[4]) * ((x2[3]) * 0x13))))))); - uint128_t x7 = ((uint64_t)(x6 >> 51) + (((uint128_t)(x1[0]) * (x2[3])) + (((uint128_t)(x1[1]) * (x2[2])) + (((uint128_t)(x1[2]) * (x2[1])) + (((uint128_t)(x1[3]) * (x2[0])) + ((uint128_t)(x1[4]) * ((x2[4]) * 0x13))))))); - uint128_t x8 = ((uint64_t)(x7 >> 51) + (((uint128_t)(x1[0]) * (x2[4])) + (((uint128_t)(x1[1]) * (x2[3])) + (((uint128_t)(x1[2]) * (x2[2])) + (((uint128_t)(x1[3]) * (x2[1])) + ((uint128_t)(x1[4]) * (x2[0]))))))); - uint64_t x9 = ((uint64_t)(x4 & 0x7ffffffffffffUL) + ((uint64_t)(x8 >> 51) * 0x13)); - uint64_t x10 = ((x9 >> 51) + (uint64_t)(x5 & 0x7ffffffffffffUL)); - x3[0] = (x9 & 0x7ffffffffffffUL); - x3[1] = (x10 & 0x7ffffffffffffUL); - x3[2] = ((x10 >> 51) + (uint64_t)(x6 & 0x7ffffffffffffUL)); - x3[3] = (uint64_t)(x7 & 0x7ffffffffffffUL); - x3[4] = (uint64_t)(x8 & 0x7ffffffffffffUL); -} - *) - Compute ToString.C.ToFunctionString - true true "" "fesub" [] base_51_sub - None (Some tight_bounds, (Some tight_bounds, tt)). -(* -void fesub(uint64_t[5] x1, uint64_t[5] x2, uint64_t[5] x3) { - x3[0] = ((0xfffffffffffdaUL + (x1[0])) - (x2[0])); - x3[1] = ((0xffffffffffffeUL + (x1[1])) - (x2[1])); - x3[2] = ((0xffffffffffffeUL + (x1[2])) - (x2[2])); - x3[3] = ((0xffffffffffffeUL + (x1[3])) - (x2[3])); - x3[4] = ((0xffffffffffffeUL + (x1[4])) - (x2[4])); -} -*) -End X25519_64. - -Module P224_64. - Definition s := 2^224. - Definition c := [(2^96, 1); (1,-1)]. - Definition machine_wordsize := 128. - - Derive mulmod - SuchThat (SaturatedSolinas.rmulmod_correctT s c machine_wordsize mulmod) - As mulmod_correct. - Proof. Time solve_rmulmod machine_wordsize. Time Qed. - - Import PrintingNotations. - Open Scope expr_scope. - Set Printing Width 100000. - Set Printing Depth 100000. - - Local Notation "'mul128' '(' x ',' y ')'" := - (#(Z_cast2 (uint128, _)%core) @ (#Z_mul_split @ #(ident.Literal (t:=base.type.Z) 340282366920938463463374607431768211456) @ x @ y))%expr (at level 50) : expr_scope. - Local Notation "'add128' '(' x ',' y ')'" := - (#(Z_cast2 (uint128, bool)%core) @ (#Z_add_get_carry @ #(ident.Literal (t:=base.type.Z) 340282366920938463463374607431768211456) @ x @ y))%expr (at level 50) : expr_scope. - Local Notation "'adc128' '(' c ',' x ',' y ')'" := - (#(Z_cast2 (uint128, bool)%core) @ (#Z_add_with_get_carry @ #(ident.Literal (t:=base.type.Z) 340282366920938463463374607431768211456) @ c @ x @ y))%expr (at level 50) : expr_scope. - Local Notation "'sub128' '(' x ',' y ')'" := - (#(Z_cast2 (uint128, bool)%core) @ (#Z_sub_get_borrow @ #(ident.Literal (t:=base.type.Z) 340282366920938463463374607431768211456) @ x @ y))%expr (at level 50) : expr_scope. - Local Notation "'sbb128' '(' c ',' x ',' y ')'" := - (#(Z_cast2 (uint128, bool)%core) @ (#Z_sub_with_get_borrow @ #(ident.Literal (t:=base.type.Z) 340282366920938463463374607431768211456) @ c @ x @ y))%expr (at level 50) : expr_scope. - Local Notation "'mul64' '(' x ',' y ')'" := - (#(Z_cast2 (uint64, _)%core) @ (#Z_mul_split @ #(ident.Literal (t:=base.type.Z) 18446744073709551616) @ x @ y))%expr (at level 50) : expr_scope. - Local Notation "'add64' '(' x ',' y ')'" := - (#(Z_cast2 (uint64, bool)%core) @ (#Z_add_get_carry @ #(ident.Literal (t:=base.type.Z) 18446744073709551616) @ x @ y))%expr (at level 50) : expr_scope. - Local Notation "'adc64' '(' c ',' x ',' y ')'" := - (#(Z_cast2 (uint64, bool)%core) @ (#Z_add_with_get_carry @ #(ident.Literal (t:=base.type.Z) 18446744073709551616) @ c @ x @ y))%expr (at level 50) : expr_scope. - Local Notation "'adx64' '(' c ',' x ',' y ')'" := - (#(Z_cast bool) @ (#Z_add_with_carry @ c @ x @ y))%expr (at level 50) : expr_scope. - Local Notation "'sub64' '(' x ',' y ')'" := - (#(Z_cast2 (uint64, bool)%core) @ (#Z_sub_get_borrow @ #(ident.Literal (t:=base.type.Z) 18446744073709551616) @ x @ y))%expr (at level 50) : expr_scope. - Local Notation "'sbb64' '(' c ',' x ',' y ')'" := - (#(Z_cast2 (uint64, bool)%core) @ (#Z_sub_with_get_borrow @ #(ident.Literal (t:=base.type.Z) 18446744073709551616) @ c @ x @ y))%expr (at level 50) : expr_scope. - Set Printing Width 1000000. - Print mulmod. -End P224_64. - -Module P192_64. - Definition s := 2^192. - Definition c := [(2^64, 1); (1,1)]. - Definition machine_wordsize := 64. - - Derive mulmod - SuchThat (SaturatedSolinas.rmulmod_correctT s c machine_wordsize mulmod) - As mulmod_correct. - Proof. Time solve_rmulmod machine_wordsize. Time Qed. - - Import PrintingNotations. - Open Scope expr_scope. - Set Printing Width 100000. - Set Printing Depth 100000. - - Local Notation "'mul64' '(' x ',' y ')'" := - (#(Z_cast2 (uint64, _)%core) @ (#Z_mul_split @ #(ident.Literal (t:=base.type.Z) 18446744073709551616) @ x @ y))%expr (at level 50) : expr_scope. - Local Notation "'add64' '(' x ',' y ')'" := - (#(Z_cast2 (uint64, bool)%core) @ (#Z_add_get_carry @ #(ident.Literal (t:=base.type.Z) 18446744073709551616) @ x @ y))%expr (at level 50) : expr_scope. - Local Notation "'adc64' '(' c ',' x ',' y ')'" := - (#(Z_cast2 (uint64, bool)%core) @ (#Z_add_with_get_carry @ #(ident.Literal (t:=base.type.Z) 18446744073709551616) @ c @ x @ y))%expr (at level 50) : expr_scope. - Local Notation "'adx64' '(' c ',' x ',' y ')'" := - (#(Z_cast bool) @ (#Z_add_with_carry @ c @ x @ y))%expr (at level 50) : expr_scope. - - Print mulmod. -(* -mulmod = fun var : type -> Type => λ x x0 : var (type.base (base.type.list (base.type.type_base base.type.Z))), - expr_let x1 := mul64 ((uint64)(x[[2]]), (uint64)(x0[[2]])) in - expr_let x2 := mul64 ((uint64)(x[[2]]), (uint64)(x0[[1]])) in - expr_let x3 := mul64 ((uint64)(x[[2]]), (uint64)(x0[[0]])) in - expr_let x4 := mul64 ((uint64)(x[[1]]), (uint64)(x0[[2]])) in - expr_let x5 := mul64 ((uint64)(x[[1]]), (uint64)(x0[[1]])) in - expr_let x6 := mul64 ((uint64)(x[[1]]), (uint64)(x0[[0]])) in - expr_let x7 := mul64 ((uint64)(x[[0]]), (uint64)(x0[[2]])) in - expr_let x8 := mul64 ((uint64)(x[[0]]), (uint64)(x0[[1]])) in - expr_let x9 := mul64 ((uint64)(x[[0]]), (uint64)(x0[[0]])) in - expr_let x10 := add64 (x1₂, x9₂) in - expr_let x11 := adc64 (x10₂, 0, x8₂) in - expr_let x12 := add64 (x1₁, x10₁) in - expr_let x13 := adc64 (x12₂, 0, x11₁) in - expr_let x14 := add64 (x2₂, x12₁) in - expr_let x15 := adc64 (x14₂, 0, x13₁) in - expr_let x16 := add64 (x4₂, x14₁) in - expr_let x17 := adc64 (x16₂, x1₂, x15₁) in - expr_let x18 := add64 (x2₁, x16₁) in - expr_let x19 := adc64 (x18₂, x1₁, x17₁) in - expr_let x20 := add64 (x1₂, x9₁) in - expr_let x21 := adc64 (x20₂, x3₂, x18₁) in - expr_let x22 := adc64 (x21₂, x2₂, x19₁) in - expr_let x23 := add64 (x2₁, x20₁) in - expr_let x24 := adc64 (x23₂, x4₁, x21₁) in - expr_let x25 := adc64 (x24₂, x4₂, x22₁) in - expr_let x26 := add64 (x3₂, x23₁) in - expr_let x27 := adc64 (x26₂, x5₂, x24₁) in - expr_let x28 := adc64 (x27₂, x3₁, x25₁) in - expr_let x29 := add64 (x4₁, x26₁) in - expr_let x30 := adc64 (x29₂, x7₂, x27₁) in - expr_let x31 := adc64 (x30₂, x5₁, x28₁) in - expr_let x32 := add64 (x5₂, x29₁) in - expr_let x33 := adc64 (x32₂, x6₁, x30₁) in - expr_let x34 := adc64 (x33₂, x6₂, x31₁) in - expr_let x35 := add64 (x7₂, x32₁) in - expr_let x36 := adc64 (x35₂, x8₁, x33₁) in - expr_let x37 := adc64 (x36₂, x7₁, x34₁) in - [x35₁; x36₁; x37₁] - : Expr (type.base (base.type.list (base.type.type_base base.type.Z)) -> type.base (base.type.list (base.type.type_base base.type.Z)) -> type.base (base.type.list (base.type.type_base base.type.Z)))%ptype -*) - -End P192_64. - *) - -(** TODO: Figure out if this belongs here *) -Module PrintingNotations. - Export ident. - (*Global Set Printing Width 100000.*) - Open Scope zrange_scope. - Notation "'uint256'" - := (r[0 ~> 115792089237316195423570985008687907853269984665640564039457584007913129639935]%zrange) : zrange_scope. - Notation "'uint128'" - := (r[0 ~> 340282366920938463463374607431768211455]%zrange) : zrange_scope. - Notation "'uint64'" - := (r[0 ~> 18446744073709551615]) : zrange_scope. - Notation "'uint32'" - := (r[0 ~> 4294967295]) : zrange_scope. - Notation "'bool'" - := (r[0 ~> 1]%zrange) : zrange_scope. - Notation "( range )( ls [[ n ]] )" - := ((#(ident.Z_cast range) @ (ls [[ n ]]))%expr) - (format "( range )( ls [[ n ]] )") : expr_scope. - (*Notation "( range )( v )" := (ident.Z_cast range @@ v)%expr : expr_scope.*) - Notation "x *₂₅₆ y" - := (#(ident.Z_cast uint256) @ (#ident.Z_mul @ x @ y))%expr (at level 40) : expr_scope. - Notation "x *₁₂₈ y" - := (#(ident.Z_cast uint128) @ (#ident.Z_mul @ x @ y))%expr (at level 40) : expr_scope. - Notation "x *₆₄ y" - := (#(ident.Z_cast uint64) @ (#ident.Z_mul @ x @ y))%expr (at level 40) : expr_scope. - Notation "x *₃₂ y" - := (#(ident.Z_cast uint32) @ (#ident.Z_mul @ x @ y))%expr (at level 40) : expr_scope. - Notation "x +₂₅₆ y" - := (#(ident.Z_cast uint256) @ (#ident.Z_add @ x @ y))%expr (at level 50) : expr_scope. - Notation "x +₁₂₈ y" - := (#(ident.Z_cast uint128) @ (#ident.Z_add @ x @ y))%expr (at level 50) : expr_scope. - Notation "x +₆₄ y" - := (#(ident.Z_cast uint64) @ (#ident.Z_add @ x @ y))%expr (at level 50) : expr_scope. - Notation "x +₃₂ y" - := (#(ident.Z_cast uint32) @ (#ident.Z_add @ x @ y))%expr (at level 50) : expr_scope. - Notation "x -₁₂₈ y" - := (#(ident.Z_cast uint128) @ (#ident.Z_sub @ x @ y))%expr (at level 50) : expr_scope. - Notation "x -₆₄ y" - := (#(ident.Z_cast uint64) @ (#ident.Z_sub @ x @ y))%expr (at level 50) : expr_scope. - Notation "x -₃₂ y" - := (#(ident.Z_cast uint32) @ (#ident.Z_sub @ x @ y))%expr (at level 50) : expr_scope. - Notation "( out_t )( v >> count )" - := ((#(ident.Z_cast out_t) @ (#ident.Z_shiftr @ v @ count))%expr) - (format "( out_t )( v >> count )") : expr_scope. - Notation "( out_t )( v << count )" - := ((#(ident.Z_cast out_t) @ (#ident.Z_shiftl @ v @ count))%expr) - (format "( out_t )( v << count )") : expr_scope. - Notation "( range )( v )" - := ((#(ident.Z_cast range) @ $v)%expr) - (format "( range )( v )") : expr_scope. - Notation "( mask & ( out_t )( v ) )" - := ((#(ident.Z_cast out_t) @ (#ident.Z_land @ #(ident.Literal (t:=base.type.Z) mask) @ v))%expr) - (format "( mask & ( out_t )( v ) )") - : expr_scope. - Notation "( ( out_t )( v ) & mask )" - := ((#(ident.Z_cast out_t) @ (#ident.Z_land @ v @ #(ident.Literal (t:=base.type.Z) mask)))%expr) - (format "( ( out_t )( v ) & mask )") - : expr_scope. - - Notation "x" := (#(ident.Z_cast _) @ $x)%expr (only printing, at level 9) : expr_scope. - Notation "x" := (#(ident.Z_cast2 _) @ $x)%expr (only printing, at level 9) : expr_scope. - Notation "v ₁" := (#ident.fst @ $v)%expr (at level 10, format "v ₁") : expr_scope. - Notation "v ₂" := (#ident.snd @ $v)%expr (at level 10, format "v ₂") : expr_scope. - Notation "v ₁" := (#(ident.Z_cast _) @ (#ident.fst @ $v))%expr (at level 10, format "v ₁") : expr_scope. - Notation "v ₂" := (#(ident.Z_cast _) @ (#ident.snd @ $v))%expr (at level 10, format "v ₂") : expr_scope. - Notation "v ₁" := (#(ident.Z_cast _) @ (#ident.fst @ (#(ident.Z_cast2 _) @ $v)))%expr (at level 10, format "v ₁") : expr_scope. - Notation "v ₂" := (#(ident.Z_cast _) @ (#ident.snd @ (#(ident.Z_cast2 _) @ $v)))%expr (at level 10, format "v ₂") : expr_scope. - Notation "x" := (#(ident.Literal x%Z))%expr (only printing) : expr_scope. - - (*Notation "ls [[ n ]]" := (List.nth_default_concrete _ n @@ ls)%expr : expr_scope. - Notation "( range )( v )" := (ident.Z_cast range @@ v)%expr : expr_scope. - Notation "x *₁₂₈ y" - := (ident.Z_cast uint128 @@ (ident.Z.mul (x, y)))%expr (at level 40) : expr_scope. - Notation "( out_t )( v >> count )" - := (ident.Z_cast out_t (ident.Z.shiftr count @@ v)%expr) - (format "( out_t )( v >> count )") : expr_scope. - Notation "( out_t )( v >> count )" - := (ident.Z_cast out_t (ident.Z.shiftr count @@ v)%expr) - (format "( out_t )( v >> count )") : expr_scope. - Notation "v ₁" := (ident.fst @@ v)%expr (at level 10, format "v ₁") : expr_scope. - Notation "v ₂" := (ident.snd @@ v)%expr (at level 10, format "v ₂") : expr_scope.*) - (* - Notation "'ℤ'" - := BoundsAnalysis.type.Z : zrange_scope. - Notation "ls [[ n ]]" := (List.nth n @@ ls)%nexpr : nexpr_scope. - Notation "x *₆₄₋₆₄₋₁₂₈ y" - := (mul uint64 uint64 uint128 @@ (x, y))%nexpr (at level 40) : nexpr_scope. - Notation "x *₆₄₋₆₄₋₆₄ y" - := (mul uint64 uint64 uint64 @@ (x, y))%nexpr (at level 40) : nexpr_scope. - Notation "x *₃₂₋₃₂₋₃₂ y" - := (mul uint32 uint32 uint32 @@ (x, y))%nexpr (at level 40) : nexpr_scope. - Notation "x *₃₂₋₁₂₈₋₁₂₈ y" - := (mul uint32 uint128 uint128 @@ (x, y))%nexpr (at level 40) : nexpr_scope. - Notation "x *₃₂₋₆₄₋₆₄ y" - := (mul uint32 uint64 uint64 @@ (x, y))%nexpr (at level 40) : nexpr_scope. - Notation "x *₃₂₋₃₂₋₆₄ y" - := (mul uint32 uint32 uint64 @@ (x, y))%nexpr (at level 40) : nexpr_scope. - Notation "x +₁₂₈ y" - := (add uint128 uint128 uint128 @@ (x, y))%nexpr (at level 50) : nexpr_scope. - Notation "x +₆₄₋₁₂₈₋₁₂₈ y" - := (add uint64 uint128 uint128 @@ (x, y))%nexpr (at level 50) : nexpr_scope. - Notation "x +₃₂₋₆₄₋₆₄ y" - := (add uint32 uint64 uint64 @@ (x, y))%nexpr (at level 50) : nexpr_scope. - Notation "x +₆₄ y" - := (add uint64 uint64 uint64 @@ (x, y))%nexpr (at level 50) : nexpr_scope. - Notation "x +₃₂ y" - := (add uint32 uint32 uint32 @@ (x, y))%nexpr (at level 50) : nexpr_scope. - Notation "x -₁₂₈ y" - := (sub uint128 uint128 uint128 @@ (x, y))%nexpr (at level 50) : nexpr_scope. - Notation "x -₆₄₋₁₂₈₋₁₂₈ y" - := (sub uint64 uint128 uint128 @@ (x, y))%nexpr (at level 50) : nexpr_scope. - Notation "x -₃₂₋₆₄₋₆₄ y" - := (sub uint32 uint64 uint64 @@ (x, y))%nexpr (at level 50) : nexpr_scope. - Notation "x -₆₄ y" - := (sub uint64 uint64 uint64 @@ (x, y))%nexpr (at level 50) : nexpr_scope. - Notation "x -₃₂ y" - := (sub uint32 uint32 uint32 @@ (x, y))%nexpr (at level 50) : nexpr_scope. - Notation "x" := ({| BoundsAnalysis.type.value := x |}) (only printing) : nexpr_scope. - Notation "( out_t )( v >> count )" - := ((shiftr _ out_t count @@ v)%nexpr) - (format "( out_t )( v >> count )") - : nexpr_scope. - Notation "( out_t )( v << count )" - := ((shiftl _ out_t count @@ v)%nexpr) - (format "( out_t )( v << count )") - : nexpr_scope. - Notation "( ( out_t ) v & mask )" - := ((land _ out_t mask @@ v)%nexpr) - (format "( ( out_t ) v & mask )") - : nexpr_scope. -*) - (* TODO: come up with a better notation for arithmetic with carries - that still distinguishes it from arithmetic without carries? *) - Local Notation "'TwoPow256'" := 115792089237316195423570985008687907853269984665640564039457584007913129639936 (only parsing). - Notation "'ADD_256' ( x , y )" := (#(ident.Z_cast2 (uint256, bool)%core) @ (#ident.Z_add_get_carry @ #(ident.Literal (t:=base.type.Z) TwoPow256) @ x @ y))%expr : expr_scope. - Notation "'ADD_128' ( x , y )" := (#(ident.Z_cast2 (uint128, bool)%core) @ (#ident.Z_add_get_carry @ #(ident.Literal (t:=base.type.Z) TwoPow256) @ x @ y))%expr : expr_scope. - Notation "'ADDC_256' ( x , y , z )" := (#(ident.Z_cast2 (uint256, bool)%core) @ (#ident.Z_add_with_get_carry @ #(ident.Literal (t:=base.type.Z) TwoPow256) @ x @ y @ z))%expr : expr_scope. - Notation "'ADDC_128' ( x , y , z )" := (#(ident.Z_cast2 (uint128, bool)%core) @ (#ident.Z_add_with_get_carry @ #(ident.Literal (t:=base.type.Z) TwoPow256) @ x @ y @ z))%expr : expr_scope. - Notation "'SUB_256' ( x , y )" := (#(ident.Z_cast2 (uint256, bool)%core) @ (#ident.Z_sub_get_borrow @ #(ident.Literal (t:=base.type.Z) TwoPow256) @ x @ y))%expr : expr_scope. - Notation "'SUBB_256' ( x , y , z )" := (#(ident.Z_cast2 (uint256, bool)%core) @ (#ident.Z_sub_with_get_borrow @ #(ident.Literal (t:=base.type.Z) TwoPow256) @ x @ y @ z))%expr : expr_scope. - Notation "'ADDM' ( x , y , z )" := (#(ident.Z_cast uint256) @ (#ident.Z_add_modulo @ x @ y @ z))%expr : expr_scope. - Notation "'RSHI' ( x , y , z )" := (#(ident.Z_cast _) @ (#ident.Z_rshi @ _ @ x @ y @ z))%expr : expr_scope. - Notation "'SELC' ( x , y , z )" := (#(ident.Z_cast uint256) @ (ident.Z_zselect @ x @ y @ z))%expr : expr_scope. - Notation "'SELM' ( x , y , z )" := (#(ident.Z_cast uint256) @ (ident.Z_zselect @ (#(Z_cast bool) @ (#Z_cc_m @ _) @ x) @ y @ z))%expr : expr_scope. - Notation "'SELL' ( x , y , z )" := (#(ident.Z_cast uint256) @ (#ident.Z_zselect @ (#(Z_cast bool) @ (#Z_land @ #(ident.Literal (t:=base.type.Z 1)) @ x)) @ y @ z))%expr : expr_scope. -End PrintingNotations. - -Module Fancy. - - Module CC. - Inductive code : Type := - | C : code - | M : code - | L : code - | Z : code - . - - Record state := - { cc_c : bool; cc_m : bool; cc_l : bool; cc_z : bool }. - - Definition code_dec (x y : code) : {x = y} + {x <> y}. - Proof. destruct x, y; try apply (left eq_refl); right; congruence. Defined. - - Definition update (to_write : list code) (result : BinInt.Z) (cc_spec : code -> BinInt.Z -> bool) (old_state : state) - : state := - {| - cc_c := if (In_dec code_dec C to_write) - then cc_spec C result - else old_state.(cc_c); - cc_m := if (In_dec code_dec M to_write) - then cc_spec M result - else old_state.(cc_m); - cc_l := if (In_dec code_dec L to_write) - then cc_spec L result - else old_state.(cc_l); - cc_z := if (In_dec code_dec Z to_write) - then cc_spec Z result - else old_state.(cc_z) - |}. - - End CC. - - Record instruction := - { - num_source_regs : nat; - writes_conditions : list CC.code; - spec : tuple Z num_source_regs -> CC.state -> Z - }. - - Section expr. - Context {name : Type} (name_eqb : name -> name -> bool) (wordmax : Z) (cc_spec : CC.code -> Z -> bool). - - Inductive expr := - | Ret : name -> expr - | Instr (i : instruction) - (rd : name) (* destination register *) - (args : tuple name i.(num_source_regs)) (* source registers *) - (cont : expr) (* next line *) - : expr - . - - Fixpoint interp (e : expr) (cc : CC.state) (ctx : name -> Z) : Z := - match e with - | Ret n => ctx n - | Instr i rd args cont => - let result := i.(spec) (Tuple.map ctx args) cc in - let new_cc := CC.update i.(writes_conditions) result cc_spec cc in - let new_ctx := (fun n => if name_eqb n rd then result mod wordmax else ctx n) in - interp cont new_cc new_ctx - end. - End expr. - - Section ISA. - Import CC. - - Definition cc_spec (x : CC.code) (result : BinInt.Z) : bool := - match x with - | CC.C => Z.testbit result 256 (* carry bit *) - | CC.M => Z.testbit result 255 (* most significant bit *) - | CC.L => Z.testbit result 0 (* least significant bit *) - | CC.Z => result =? 0 (* whether equal to zero *) - end. - - Local Definition lower128 x := (Z.land x (Z.ones 128)). - Local Definition upper128 x := (Z.shiftr x 128). - Local Notation "x '[C]'" := (if x.(cc_c) then 1 else 0) (at level 20). - Local Notation "x '[M]'" := (if x.(cc_m) then 1 else 0) (at level 20). - Local Notation "x '[L]'" := (if x.(cc_l) then 1 else 0) (at level 20). - Local Notation "x '[Z]'" := (if x.(cc_z) then 1 else 0) (at level 20). - Local Notation "'int'" := (BinInt.Z). - Local Notation "x << y" := ((x << y) mod (2^256)) : Z_scope. (* truncating left shift *) - - - (* Note: In the specification document, argument order gets a bit - confusing. Like here, r0 is always the first argument "source 0" - and r1 the second. But the specification of MUL128LU is: - (R[RS1][127:0] * R[RS0][255:128]) - - while the specification of SUB is: - (R[RS0] - shift(R[RS1], imm)) - - In the SUB case, r0 is really treated the first argument, but in - MUL128LU the order seems to be reversed; rather than low-high, we - take the high part of the first argument r0 and the low parts of - r1. This is also true for MUL128UL. *) - - Definition ADD (imm : int) : instruction := - {| - num_source_regs := 2; - writes_conditions := [C; M; L; Z]; - spec := (fun '(r0, r1) cc => - r0 + (r1 << imm)) - |}. - - Definition ADDC (imm : int) : instruction := - {| - num_source_regs := 2; - writes_conditions := [C; M; L; Z]; - spec := (fun '(r0, r1) cc => - r0 + (r1 << imm) + cc[C]) - |}. - - Definition SUB (imm : int) : instruction := - {| - num_source_regs := 2; - writes_conditions := [C; M; L; Z]; - spec := (fun '(r0, r1) cc => - r0 - (r1 << imm)) - |}. - - Definition SUBC (imm : int) : instruction := - {| - num_source_regs := 2; - writes_conditions := [C; M; L; Z]; - spec := (fun '(r0, r1) cc => - r0 - (r1 << imm) - cc[C]) - |}. - - - Definition MUL128LL : instruction := - {| - num_source_regs := 2; - writes_conditions := [M; L; Z]; - spec := (fun '(r0, r1) cc => - (lower128 r0) * (lower128 r1)) - |}. - - Definition MUL128LU : instruction := - {| - num_source_regs := 2; - writes_conditions := [M; L; Z]; - spec := (fun '(r0, r1) cc => - (lower128 r1) * (upper128 r0)) (* see note *) - |}. - - Definition MUL128UL : instruction := - {| - num_source_regs := 2; - writes_conditions := [M; L; Z]; - spec := (fun '(r0, r1) cc => - (upper128 r1) * (lower128 r0)) (* see note *) - |}. - - Definition MUL128UU : instruction := - {| - num_source_regs := 2; - writes_conditions := [M; L; Z]; - spec := (fun '(r0, r1) cc => - (upper128 r0) * (upper128 r1)) - |}. - - (* Note : Unlike the other operations, the output of RSHI is - truncated in the specification. This is not strictly necessary, - since the interpretation function truncates the output - anyway. However, it is useful to make the definition line up - exactly with Z.rshi. *) - Definition RSHI (imm : int) : instruction := - {| - num_source_regs := 2; - writes_conditions := [M; L; Z]; - spec := (fun '(r0, r1) cc => - (((2^256 * r0) + r1) >> imm) mod (2^256)) - |}. - - Definition SELC : instruction := - {| - num_source_regs := 2; - writes_conditions := []; - spec := (fun '(r0, r1) cc => - if cc[C] =? 1 then r0 else r1) - |}. - - Definition SELM : instruction := - {| - num_source_regs := 2; - writes_conditions := []; - spec := (fun '(r0, r1) cc => - if cc[M] =? 1 then r0 else r1) - |}. - - Definition SELL : instruction := - {| - num_source_regs := 2; - writes_conditions := []; - spec := (fun '(r0, r1) cc => - if cc[L] =? 1 then r0 else r1) - |}. - - (* TODO : treat the MOD register specially, like CC *) - Definition ADDM : instruction := - {| - num_source_regs := 3; - writes_conditions := [M; L; Z]; - spec := (fun '(r0, r1, MOD) cc => - let ra := r0 + r1 in - if ra >=? MOD - then ra - MOD - else ra) - |}. - - End ISA. - - Module Registers. - Inductive register : Type := - | r0 : register - | r1 : register - | r2 : register - | r3 : register - | r4 : register - | r5 : register - | r6 : register - | r7 : register - | r8 : register - | r9 : register - | r10 : register - | r11 : register - | r12 : register - | r13 : register - | r14 : register - | r15 : register - | r16 : register - | r17 : register - | r18 : register - | r19 : register - | r20 : register - | r21 : register - | r22 : register - | r23 : register - | r24 : register - | r25 : register - | r26 : register - | r27 : register - | r28 : register - | r29 : register - | r30 : register - | RegZero : register (* r31 *) - | RegMod : register - . - - Definition reg_dec (x y : register) : {x = y} + {x <> y}. - Proof. destruct x, y; try (apply left; congruence); right; congruence. Defined. - Definition reg_eqb x y := if reg_dec x y then true else false. - - Lemma reg_eqb_neq x y : x <> y -> reg_eqb x y = false. - Proof. cbv [reg_eqb]; break_match; congruence. Qed. - Lemma reg_eqb_refl x : reg_eqb x x = true. - Proof. cbv [reg_eqb]; break_match; congruence. Qed. - End Registers. - - Section of_prefancy. - Local Notation cexpr := (@Compilers.expr.expr base.type ident.ident). - Local Notation LetInAppIdentZ S D r eidc x f - := (expr.LetIn - (A:=type.base (base.type.type_base base.type.Z)) - (B:=type.base D) - (expr.App - (s:=type.base (base.type.type_base base.type.Z)) - (d:=type.base (base.type.type_base base.type.Z)) - (expr.Ident (ident.Z_cast r)) - (expr.App - (s:=type.base S) - (d:=type.base (base.type.type_base base.type.Z)) - eidc - x)) - f). - Local Notation LetInAppIdentZZ S D r eidc x f - := (expr.LetIn - (A:=type.base (base.type.prod (base.type.type_base base.type.Z) (base.type.type_base base.type.Z))) - (B:=type.base D) - (expr.App - (s:=type.base (base.type.prod (base.type.type_base base.type.Z) (base.type.type_base base.type.Z))) - (d:=type.base (base.type.prod (base.type.type_base base.type.Z) (base.type.type_base base.type.Z))) - (expr.Ident (ident.Z_cast2 r)) - (expr.App - (s:=type.base S) - (d:=type.base (base.type.prod (base.type.type_base base.type.Z) (base.type.type_base base.type.Z))) - eidc - x)) - f). - Context (name : Type) (name_succ : name -> name) (error : name) (consts : Z -> option name). - - Fixpoint base_var (t : base.type) : Type := - match t with - | base.type.Z => name - | base.type.prod a b => base_var a * base_var b - | _ => unit - end. - Fixpoint var (t : type.type base.type) : Type := - match t with - | type.base t => base_var t - | type.arrow s d => var s -> var d - end. - Fixpoint base_error {t} : base_var t - := match t with - | base.type.Z => error - | base.type.prod A B => (@base_error A, @base_error B) - | _ => tt - end. - Fixpoint make_error {t} : var t - := match t with - | type.base _ => base_error - | type.arrow s d => fun _ => @make_error d - end. - - Fixpoint of_prefancy_scalar {t} (s : @cexpr var t) : var t - := match s in expr.expr t return var t with - | Compilers.expr.Var t v => v - | expr.App s d f x => @of_prefancy_scalar _ f (@of_prefancy_scalar _ x) - | expr.Ident t idc - => match idc in ident.ident t return var t with - | ident.Literal base.type.Z v => match consts v with - | Some n => n - | None => error - end - | ident.pair A B => fun a b => (a, b)%core - | ident.fst A B => fun v => fst v - | ident.snd A B => fun v => snd v - | ident.Z_cast r => fun v => v - | ident.Z_cast2 (r1, r2) => fun v => v - | ident.Z_land => fun x y => x - | _ => make_error - end - | expr.Abs s d f => make_error - | expr.LetIn A B x f => make_error - end%expr_pat%etype. - - (* Note : some argument orders are reversed for MUL128LU, MUL128UL, SELC, SELM, and SELL *) - Local Notation tZ := base.type.Z. - Definition of_prefancy_ident {s d : base.type} (idc : ident.ident (s -> d)) - : @cexpr var s -> option {i : instruction & tuple name i.(num_source_regs) } := - match idc in ident.ident t return match t return Type with - | type.arrow (type.base s) (type.base d) - => @cexpr var s - | _ => unit - end - -> option {i : instruction & tuple name i.(num_source_regs) } - with - | ident.fancy_add log2wordmax imm - => fun args : @cexpr var (tZ * tZ) => - if Z.eqb log2wordmax 256 - then Some (existT _ (ADD imm) (of_prefancy_scalar args)) - else None - | ident.fancy_addc log2wordmax imm - => fun args : @cexpr var (tZ * tZ * tZ) => - if Z.eqb log2wordmax 256 - then Some (existT _ (ADDC imm) (of_prefancy_scalar ((#ident.snd @ (#ident.fst @ args)), (#ident.snd @ args)))) - else None - | ident.fancy_sub log2wordmax imm - => fun args : @cexpr var (tZ * tZ) => - if Z.eqb log2wordmax 256 - then Some (existT _ (SUB imm) (of_prefancy_scalar args)) - else None - | ident.fancy_subb log2wordmax imm - => fun args : @cexpr var (tZ * tZ * tZ) => - if Z.eqb log2wordmax 256 - then Some (existT _ (SUBC imm) (of_prefancy_scalar ((#ident.snd @ (#ident.fst @ args)), (#ident.snd @ args)))) - else None - | ident.fancy_mulll log2wordmax - => fun args : @cexpr var (tZ * tZ) => - if Z.eqb log2wordmax 256 - then Some (existT _ MUL128LL (of_prefancy_scalar args)) - else None - | ident.fancy_mullh log2wordmax - => fun args : @cexpr var (tZ * tZ) => - if Z.eqb log2wordmax 256 - then Some (existT _ MUL128LU (of_prefancy_scalar ((#ident.snd @ args), (#ident.fst @ args)))) - else None - | ident.fancy_mulhl log2wordmax - => fun args : @cexpr var (tZ * tZ) => - if Z.eqb log2wordmax 256 - then Some (existT _ MUL128UL (of_prefancy_scalar ((#ident.snd @ args), (#ident.fst @ args)))) - else None - | ident.fancy_mulhh log2wordmax - => fun args : @cexpr var (tZ * tZ) => - if Z.eqb log2wordmax 256 - then Some (existT _ MUL128UU (of_prefancy_scalar args)) - else None - | ident.fancy_rshi log2wordmax imm - => fun args : @cexpr var (tZ * tZ) => - if Z.eqb log2wordmax 256 - then Some (existT _ (RSHI imm) (of_prefancy_scalar args)) - else None - | ident.fancy_selc - => fun args : @cexpr var (tZ * tZ * tZ) => Some (existT _ SELC (of_prefancy_scalar ((#ident.snd @ args), (#ident.snd @ (#ident.fst @ args))))) - | ident.fancy_selm log2wordmax - => fun args : @cexpr var (tZ * tZ * tZ) => - if Z.eqb log2wordmax 256 - then Some (existT _ SELM (of_prefancy_scalar ((#ident.snd @ args), (#ident.snd @ (#ident.fst @ args))))) - else None - | ident.fancy_sell - => fun args : @cexpr var (tZ * tZ * tZ) => Some (existT _ SELL (of_prefancy_scalar ((#ident.snd @ args), (#ident.snd @ (#ident.fst @ args))))) - | ident.fancy_addm - => fun args : @cexpr var (tZ * tZ * tZ) => Some (existT _ ADDM (of_prefancy_scalar args)) - | _ => fun _ => None - end. - - Local Notation "x <- y ; f" := (match y with Some x => f | None => Ret error end). - Definition of_prefancy_step - (of_prefancy : forall (next_name : name) {t} (e : @cexpr var t), @expr name) - (next_name : name) {t} (e : @cexpr var t) : @expr name - := let default _ := (e' <- type.try_transport (@base.try_make_transport_cps) (@cexpr var) t tZ e; - Ret (of_prefancy_scalar e')) in - match e with - | LetInAppIdentZ s d r eidc x f - => idc <- invert_expr.invert_Ident eidc; - instr_args <- @of_prefancy_ident s tZ idc x; - let i : instruction := projT1 instr_args in - let args : tuple name i.(num_source_regs) := projT2 instr_args in - Instr i next_name args (@of_prefancy (name_succ next_name) _ (f next_name)) - | LetInAppIdentZZ s d r eidc x f - => idc <- invert_expr.invert_Ident eidc; - instr_args <- @of_prefancy_ident s (tZ * tZ) idc x; - let i : instruction := projT1 instr_args in - let args : tuple name i.(num_source_regs) := projT2 instr_args in - Instr i next_name args (@of_prefancy (name_succ next_name) _ (f (next_name, next_name))) (* the second argument is for the carry, and it will not be read from directly. *) - | _ => default tt - end. - Fixpoint of_prefancy (next_name : name) {t} (e : @cexpr var t) : @expr name - := @of_prefancy_step of_prefancy next_name t e. - - Section Proofs. - Context (name_eqb : name -> name -> bool). - Context (name_lt : name -> name -> Prop) - (name_lt_trans : forall n1 n2 n3, - name_lt n1 n2 -> name_lt n2 n3 -> name_lt n1 n3) - (name_lt_irr : forall n, ~ name_lt n n) - (name_lt_succ : forall n, name_lt n (name_succ n)) - (name_eqb_eq : forall n1 n2, name_eqb n1 n2 = true -> n1 = n2) - (name_eqb_neq : forall n1 n2, name_eqb n1 n2 = false -> n1 <> n2). - Local Notation wordmax := (2^256). - Local Notation interp := (interp name_eqb wordmax cc_spec). - Local Notation uint256 := r[0~>wordmax-1]%zrange. - Local Notation uint128 := r[0~>(2 ^ (Z.log2 wordmax / 2) - 1)]%zrange. - Definition cast_oor (r : zrange) (v : Z) := v mod (upper r + 1). - Local Notation "'existZ' x" := (existT _ (type.base (base.type.type_base tZ)) x) (at level 200). - Local Notation "'existZZ' x" := (existT _ (type.base (base.type.type_base tZ * base.type.type_base tZ)%etype) x) (at level 200). - Local Notation cinterp := (expr.interp (@ident.gen_interp cast_oor)). - Definition interp_if_Z {t} (e : cexpr t) : option Z := - option_map (expr.interp (@ident.gen_interp cast_oor) (t:=tZ)) - (type.try_transport - (@base.try_make_transport_cps) - _ _ tZ e). - - Lemma interp_if_Z_Some {t} e r : - @interp_if_Z t e = Some r -> - exists e', - (type.try_transport - (@base.try_make_transport_cps) _ _ tZ e) = Some e' /\ - expr.interp (@ident.gen_interp cast_oor) (t:=tZ) e' = r. - Proof. - clear. cbv [interp_if_Z option_map]. - break_match; inversion 1; intros. - subst; eexists. tauto. - Qed. - - Inductive valid_scalar - : @cexpr var (base.type.type_base tZ) -> Prop := - | valid_scalar_literal : - forall v n, - consts v = Some n -> - valid_scalar (expr.Ident (@ident.Literal base.type.Z v)) - | valid_scalar_Var : - forall v, - valid_scalar (expr.App (expr.Ident (ident.Z_cast uint256)) (expr.Var v)) - | valid_scalar_fst : - forall v r2, - valid_scalar - (expr.App (expr.Ident (ident.Z_cast uint256)) - (expr.App (expr.Ident (@ident.fst (base.type.type_base tZ) - (base.type.type_base tZ))) - (expr.App (expr.Ident (ident.Z_cast2 (uint256, r2))) (expr.Var v)))) - . - Inductive valid_carry - : @cexpr var (base.type.type_base tZ) -> Prop := - | valid_carry_0 : consts 0 <> None -> valid_carry (expr.Ident (@ident.Literal base.type.Z 0)) - | valid_carry_1 : consts 1 <> None -> valid_carry (expr.Ident (@ident.Literal base.type.Z 1)) - | valid_carry_snd : - forall v r2, - valid_carry - (expr.App (expr.Ident (ident.Z_cast r[0~>1])) - (expr.App (expr.Ident (@ident.snd (base.type.type_base tZ) - (base.type.type_base tZ))) - (expr.App (expr.Ident (ident.Z_cast2 (r2, r[0~>1]))) (expr.Var v)))) - . - - Fixpoint interp_base (ctx : name -> Z) (cctx : name -> bool) {t} - : base_var t -> base.interp t := - match t as t0 return base_var t0 -> base.interp t0 with - | base.type.type_base tZ => fun n => ctx n - | (base.type.type_base tZ * base.type.type_base tZ)%etype => - fun v => (ctx (fst v), Z.b2z (cctx (snd v))) - | (a * b)%etype => - fun _ => DefaultValue.type.base.default - | _ => fun _ : unit => - DefaultValue.type.base.default - end. - - Definition new_write {d} : var d -> name := - match d with - | type.base (base.type.type_base tZ) => fun r => r - | type.base (base.type.type_base tZ * base.type.type_base tZ)%etype => fst - | _ => fun _ => error - end. - Definition new_cc_to_name (old_cc_to_name : CC.code -> name) (i : instruction) - {d} (new_r : var d) (x : CC.code) : name := - if (in_dec CC.code_dec x (writes_conditions i)) - then new_write new_r - else old_cc_to_name x. - - Inductive valid_ident - : forall {s d}, - (CC.code -> name) -> (* last variables that wrote to each flag *) - (var d -> CC.code -> name) -> (* new last variables that wrote to each flag *) - ident.ident (s->d) -> @cexpr var s -> Prop := - | valid_fancy_add : - forall r imm x y, - valid_scalar x -> - valid_scalar y -> - valid_ident r (new_cc_to_name r (ADD imm)) (ident.fancy_add 256 imm) (x, y)%expr_pat - | valid_fancy_addc : - forall r imm c x y, - (of_prefancy_scalar (t:= base.type.type_base tZ) c = r CC.C) -> - valid_carry c -> - valid_scalar x -> - valid_scalar y -> - valid_ident r (new_cc_to_name r (ADDC imm)) (ident.fancy_addc 256 imm) (c, x, y)%expr_pat - | valid_fancy_sub : - forall r imm x y, - valid_scalar x -> - valid_scalar y -> - valid_ident r (new_cc_to_name r (SUB imm)) (ident.fancy_sub 256 imm) (x, y)%expr_pat - | valid_fancy_subb : - forall r imm c x y, - (of_prefancy_scalar (t:= base.type.type_base tZ) c = r CC.C) -> - valid_carry c -> - valid_scalar x -> - valid_scalar y -> - valid_ident r (new_cc_to_name r (SUBC imm)) (ident.fancy_subb 256 imm) (c, x, y)%expr_pat - | valid_fancy_mulll : - forall r x y, - valid_scalar x -> - valid_scalar y -> - valid_ident r (new_cc_to_name r MUL128LL) (ident.fancy_mulll 256) (x, y)%expr_pat - | valid_fancy_mullh : - forall r x y, - valid_scalar x -> - valid_scalar y -> - valid_ident r (new_cc_to_name r MUL128LU) (ident.fancy_mullh 256) (x, y)%expr_pat - | valid_fancy_mulhl : - forall r x y, - valid_scalar x -> - valid_scalar y -> - valid_ident r (new_cc_to_name r MUL128UL) (ident.fancy_mulhl 256) (x, y)%expr_pat - | valid_fancy_mulhh : - forall r x y, - valid_scalar x -> - valid_scalar y -> - valid_ident r (new_cc_to_name r MUL128UU) (ident.fancy_mulhh 256) (x, y)%expr_pat - | valid_fancy_rshi : - forall r imm x y, - valid_scalar x -> - valid_scalar y -> - valid_ident r (new_cc_to_name r (RSHI imm)) (ident.fancy_rshi 256 imm) (x, y)%expr_pat - | valid_fancy_selc : - forall r c x y, - (of_prefancy_scalar (t:= base.type.type_base tZ) c = r CC.C) -> - valid_carry c -> - valid_scalar x -> - valid_scalar y -> - valid_ident r (new_cc_to_name r SELC) ident.fancy_selc (c, x, y)%expr_pat - | valid_fancy_selm : - forall r c x y, - (of_prefancy_scalar (t:= base.type.type_base tZ) c = r CC.M) -> - valid_scalar c -> - valid_scalar x -> - valid_scalar y -> - valid_ident r (new_cc_to_name r SELM) (ident.fancy_selm 256) (c, x, y)%expr_pat - | valid_fancy_sell : - forall r c x y, - (of_prefancy_scalar (t:= base.type.type_base tZ) c = r CC.L) -> - valid_scalar c -> - valid_scalar x -> - valid_scalar y -> - valid_ident r (new_cc_to_name r SELL) ident.fancy_sell (c, x, y)%expr_pat - | valid_fancy_addm : - forall r x y m, - valid_scalar x -> - valid_scalar y -> - valid_scalar m -> - valid_ident r (new_cc_to_name r ADDM) ident.fancy_addm (x, y, m)%expr_pat - . - - Inductive valid_expr - : forall t, - (CC.code -> name) -> (* the last variables that wrote to each flag *) - @cexpr var t -> Prop := - | valid_LetInZ_loosen : - forall s d idc r rf x f u ia, - valid_ident r rf idc x -> - 0 < u < wordmax -> - (forall x, valid_expr _ (rf x) (f x)) -> - of_prefancy_ident idc x = Some ia -> - (forall cc ctx, - (forall n v, consts v = Some n -> ctx n = v) -> - (forall n, ctx n mod wordmax = ctx n) -> - let args := Tuple.map ctx (projT2 ia) in - spec (projT1 ia) args cc mod wordmax = spec (projT1 ia) args cc mod (u+1)) -> - valid_expr _ r (LetInAppIdentZ s d r[0~>u] (expr.Ident idc) x f) - | valid_LetInZ : - forall s d idc r rf x f, - valid_ident r rf idc x -> - (forall x, valid_expr _ (rf x) (f x)) -> - valid_expr _ r (LetInAppIdentZ s d uint256 (expr.Ident idc) x f) - | valid_LetInZZ : - forall s d idc r rf x f, - valid_ident r rf idc x -> - (forall x : var (type.base (base.type.type_base tZ * base.type.type_base tZ)%etype), - fst x = snd x -> - valid_expr _ (rf x) (f x)) -> - valid_expr _ r (LetInAppIdentZZ s d (uint256, r[0~>1]) (expr.Ident idc) x f) - | valid_Ret : - forall r x, - valid_scalar x -> - valid_expr _ r x - . - - Lemma cast_oor_id v u : 0 <= v <= u -> cast_oor r[0 ~> u] v = v. - Proof. intros; cbv [cast_oor upper]. apply Z.mod_small; omega. Qed. - Lemma cast_oor_mod v u : 0 <= u -> cast_oor r[0 ~> u] v mod (u+1) = v mod (u+1). - Proof. intros; cbv [cast_oor upper]. apply Z.mod_mod; omega. Qed. - - Lemma wordmax_nonneg : 0 <= wordmax. - Proof. cbv; congruence. Qed. - - Lemma of_prefancy_scalar_correct' - (e1 : @cexpr var (type.base (base.type.type_base tZ))) - (e2 : cexpr (type.base (base.type.type_base tZ))) - G (ctx : name -> Z) (cctx : name -> bool) : - valid_scalar e1 -> - LanguageWf.Compilers.expr.wf G e1 e2 -> - (forall n v, consts v = Some n -> In (existZ (n, v)) G) -> - (forall t v1 v2, In (existT _ (type.base t) (v1, v2)) G -> interp_base ctx cctx v1 = v2) -> - (forall v1 v2, In (existZ (v1, v2)) G -> ctx v1 = v2) -> (* implied by above *) - (forall n, ctx n mod wordmax = ctx n) -> - (forall v1 v2, In (existZZ (v1, v2)) G -> ctx (fst v1) = fst v2) -> - (forall v1 v2, In (existZZ (v1, v2)) G -> Z.b2z (cctx (snd v1)) = snd v2) -> - ctx (of_prefancy_scalar e1) = cinterp e2. - Proof. - inversion 1; inversion 1; - cbv [interp_if_Z option_map]; - cbn [of_prefancy_scalar interp_base]; intros. - all: repeat first [ - progress subst - | exfalso; assumption - | progress inversion_sigma - | progress inversion_option - | progress Prod.inversion_prod - | progress LanguageInversion.Compilers.expr.inversion_expr - | progress LanguageInversion.Compilers.expr.invert_subst - | progress LanguageWf.Compilers.expr.inversion_wf_one_constr - | progress LanguageInversion.Compilers.expr.invert_match - | progress destruct_head'_sig - | progress destruct_head'_and - | progress destruct_head'_or - | progress Z.ltb_to_lt - | progress cbv [id] - | progress cbn [fst snd upper lower fst snd eq_rect projT1 projT2 expr.interp ident.interp ident.gen_interp interp_base] in * - | progress HProp.eliminate_hprop_eq - | progress break_innermost_match_hyps - | progress break_innermost_match - | match goal with H : context [_ = cinterp _] |- context [cinterp _] => - rewrite <-H by eauto; try reflexivity end - | solve [eauto using (f_equal2 pair), cast_oor_id, wordmax_nonneg] - | rewrite LanguageWf.Compilers.ident.cast_out_of_bounds_simple_0_mod - | rewrite Z.mod_mod by lia - | rewrite cast_oor_mod by (cbv; congruence) - | lia - | match goal with - H : context[ ?x mod _ = ?x ] |- _ => rewrite H end - | match goal with - | H : context [In _ _ -> _ = _] |- _ => erewrite H by eauto end - | match goal with - | H : forall v1 v2, In _ _ -> ?ctx v1 = v2 |- ?x = ?x mod ?m => - replace m with wordmax by ring; erewrite <-(H _ x) by eauto; solve [eauto] - end - | match goal with - | H : forall v1 v2, In _ _ -> ?ctx (fst v1) = fst v2, - H' : In (existZZ (_,(?x,?y))) _ |- ?x = ?x mod ?m => - replace m with wordmax by ring; - specialize (H _ _ H'); cbn [fst] in H; rewrite <-H; solve [eauto] end - ]. - Qed. - - Lemma of_prefancy_scalar_correct - (e1 : @cexpr var (type.base (base.type.type_base tZ))) - (e2 : cexpr (type.base (base.type.type_base tZ))) - G (ctx : name -> Z) cc : - valid_scalar e1 -> - LanguageWf.Compilers.expr.wf G e1 e2 -> - (forall n v, consts v = Some n -> In (existZ (n, v)) G) -> - (forall t v1 v2, In (existT _ (type.base t) (v1, v2)) G -> interp_base ctx cc v1 = v2) -> - (forall n, ctx n mod wordmax = ctx n) -> - ctx (of_prefancy_scalar e1) = cinterp e2. - Proof. - intros; match goal with H : context [interp_base _ _ _ = _] |- _ => - pose proof (H (base.type.type_base base.type.Z)); - pose proof (H (base.type.type_base base.type.Z * base.type.type_base base.type.Z)%etype); cbn [interp_base] in * - end. - eapply of_prefancy_scalar_correct'; eauto; - match goal with - | H : forall _ _, In _ _ -> (_, _) = _ |- _ => - let v1 := fresh "v" in - let v2 := fresh "v" in - intros v1 v2 ?; rewrite <-(H v1 v2) by auto - end; reflexivity. - Qed. - - Lemma of_prefancy_ident_Some {s d} idc r rf x: - @valid_ident (type.base s) (type.base d) r rf idc x -> - of_prefancy_ident idc x <> None. - Proof. - induction s; inversion 1; intros; - repeat first [ - progress subst - | progress inversion_sigma - | progress cbn [eq_rect projT1 projT2 of_prefancy_ident invert_expr.invert_Ident option_map] in * - | progress Z.ltb_to_lt - | progress break_innermost_match - | progress LanguageInversion.Compilers.type.inversion_type - | progress LanguageInversion.Compilers.expr.inversion_expr - | congruence - ]. - Qed. - - Ltac name_eqb_to_eq := - repeat match goal with - | H : name_eqb _ _ = true |- _ => apply name_eqb_eq in H - | H : name_eqb _ _ = false |- _ => apply name_eqb_neq in H - end. - Ltac inversion_of_prefancy_ident := - match goal with - | H : of_prefancy_ident _ _ = None |- _ => - eapply of_prefancy_ident_Some in H; - [ contradiction | eassumption] - end. - - Local Ltac hammer := - repeat first [ - progress subst - | progress inversion_sigma - | progress inversion_option - | progress inversion_of_prefancy_ident - | progress Prod.inversion_prod - | progress cbv [id] - | progress cbn [eq_rect projT1 projT2 expr.interp ident.interp ident.gen_interp interp_base interp invert_expr.invert_Ident interp_if_Z option_map] in * - | progress LanguageInversion.Compilers.type_beq_to_eq - | progress name_eqb_to_eq - | progress LanguageInversion.Compilers.rewrite_type_transport_correct - | progress HProp.eliminate_hprop_eq - | progress break_innermost_match_hyps - | progress break_innermost_match - | progress LanguageInversion.Compilers.type.inversion_type - | progress LanguageInversion.Compilers.expr.inversion_expr - | solve [auto] - | contradiction - ]. - Ltac prove_Ret := - repeat match goal with - | H : valid_scalar (expr.LetIn _ _) |- _ => - inversion H - | _ => progress cbn [id of_prefancy of_prefancy_step of_prefancy_scalar] - | _ => progress hammer - | H : valid_scalar (expr.Ident _) |- _ => - inversion H; clear H - | |- _ = cinterp ?f (cinterp ?x) => - transitivity - (cinterp (f @ x)%expr); - [ | reflexivity ]; - erewrite <-of_prefancy_scalar_correct by (try reflexivity; eassumption) - end. - - Lemma cast_mod u v : - 0 <= u -> - ident.cast cast_oor r[0~>u] v = v mod (u + 1). - Proof. - intros. - rewrite LanguageWf.Compilers.ident.cast_out_of_bounds_simple_0_mod by auto using cast_oor_id. - cbv [cast_oor upper]. apply Z.mod_mod. omega. - Qed. - - Lemma cc_spec_c v : - Z.b2z (cc_spec CC.C v) = (v / wordmax) mod 2. - Proof. cbv [cc_spec]; apply Z.testbit_spec'. omega. Qed. - - Lemma cc_m_zselect x z nz : - x mod wordmax = x -> - (if (if cc_spec CC.M x then 1 else 0) =? 1 then nz else z) = - Z.zselect (x >> 255) z nz. - Proof. - intro Hx_small. - transitivity (if (Z.b2z (cc_spec CC.M x) =? 1) then nz else z); [ reflexivity | ]. - cbv [cc_spec Z.zselect]. - rewrite Z.testbit_spec', Z.shiftr_div_pow2 by omega. rewrite <-Hx_small. - rewrite Div.Z.div_between_0_if by (try replace (2 * (2 ^ 255)) with wordmax by reflexivity; - auto with zarith). - break_innermost_match; Z.ltb_to_lt; try rewrite Z.mod_small in * by omega; congruence. - Qed. - - Lemma cc_l_zselect x z nz : - (if (if cc_spec CC.L x then 1 else 0) =? 1 then nz else z) = Z.zselect (x &' 1) z nz. - Proof. - transitivity (if (Z.b2z (cc_spec CC.L x) =? 1) then nz else z); [ reflexivity | ]. - transitivity (Z.zselect (x &' Z.ones 1) z nz); [ | reflexivity ]. - cbv [cc_spec Z.zselect]. rewrite Z.testbit_spec', Z.land_ones by omega. - autorewrite with zsimplify_fast. rewrite Zmod_even. - break_innermost_match; Z.ltb_to_lt; congruence. - Qed. - - Lemma b2z_range b : 0<= Z.b2z b < 2. - Proof. cbv [Z.b2z]. break_match; lia. Qed. - - - Lemma of_prefancy_scalar_carry - (c : @cexpr var (type.base (base.type.type_base tZ))) - (e : cexpr (type.base (base.type.type_base tZ))) - G (ctx : name -> Z) cctx : - valid_carry c -> - LanguageWf.Compilers.expr.wf G c e -> - (forall n0, consts 0 = Some n0 -> cctx n0 = false) -> - (forall n1, consts 1 = Some n1 -> cctx n1 = true) -> - (forall t v1 v2, In (existT _ (type.base t) (v1, v2)) G -> interp_base ctx cctx v1 = v2) -> - Z.b2z (cctx (of_prefancy_scalar c)) = cinterp e. - Proof. - inversion 1; inversion 1; intros; hammer; cbn; - repeat match goal with - | H : context [ _ = false] |- Z.b2z _ = 0 => rewrite H; reflexivity - | H : context [ _ = true] |- Z.b2z _ = 1 => rewrite H; reflexivity - | _ => progress LanguageWf.Compilers.expr.inversion_wf_one_constr - | _ => progress cbn [fst snd] - | _ => progress destruct_head'_sig - | _ => progress destruct_head'_and - | _ => progress hammer - | _ => progress LanguageInversion.Compilers.expr.invert_subst - | _ => rewrite cast_mod by (cbv; congruence) - | _ => rewrite Z.mod_mod by omega - | _ => rewrite Z.mod_small by apply b2z_range - | H : (forall _ _ _, In _ _ -> interp_base _ _ _ = _), - H' : In (existZZ (?v, _)) _ |- context [cctx (snd ?v)] => - specialize (H _ _ _ H'); cbn in H - end. - Qed. - - Ltac simplify_ident := - repeat match goal with - | _ => progress intros - | _ => progress cbn [fst snd of_prefancy_ident] in * - | _ => progress LanguageWf.Compilers.expr.inversion_wf_one_constr - | H : { _ | _ } |- _ => destruct H - | H : _ /\ _ |- _ => destruct H - | H : upper _ = _ |- _ => rewrite H - | _ => rewrite cc_spec_c by auto - | _ => rewrite cast_mod by (cbv; congruence) - | H : _ |- _ => - apply LanguageInversion.Compilers.expr.invert_Ident_Some in H - | H : _ |- _ => - apply LanguageInversion.Compilers.expr.invert_App_Some in H - | H : ?P, H' : ?P |- _ => clear H' - | _ => progress hammer - end. - - (* TODO: zero flag is a little tricky, since the value - depends both on the stored variable and the carry if there - is one. For now, since Barrett doesn't use it, we're just - pretending it doesn't exist. *) - Definition cc_good cc cctx ctx r := - CC.cc_c cc = cctx (r CC.C) /\ - CC.cc_m cc = cc_spec CC.M (ctx (r CC.M)) /\ - CC.cc_l cc = cc_spec CC.L (ctx (r CC.L)) /\ - (forall n0 : name, consts 0 = Some n0 -> cctx n0 = false) /\ - (forall n1 : name, consts 1 = Some n1 -> cctx n1 = true). - - Lemma of_prefancy_identZ_loosen_correct {s} idc: - forall (x : @cexpr var _) i ctx G cc cctx x2 r rf f u, - @valid_ident (type.base s) (type_base tZ) r rf idc x -> - LanguageWf.Compilers.expr.wf G (#idc @ x)%expr_pat x2 -> - LanguageWf.Compilers.expr.wf G #(ident.Z_cast r[0~>u]) f -> - 0 < u < wordmax -> - cc_good cc cctx ctx r -> - (forall n v, consts v = Some n -> In (existZ (n, v)) G) -> - (forall t v1 v2, In (existT _ (type.base t) (v1, v2)) G -> interp_base ctx cctx v1 = v2) -> - (forall n, ctx n mod wordmax = ctx n) -> - of_prefancy_ident idc x = Some i -> - (spec (projT1 i) (Tuple.map ctx (projT2 i)) cc mod wordmax = spec (projT1 i) (Tuple.map ctx (projT2 i)) cc mod (u+1)) -> - spec (projT1 i) (Tuple.map ctx (projT2 i)) cc mod wordmax = (cinterp f (cinterp x2)). - Proof. - Time - inversion 1; inversion 1; cbn [of_prefancy_ident]; hammer; (simplify_ident; [ ]). (* TODO : suuuuuper slow *) - all: - rewrite cast_mod by omega; - match goal with - | H : context [spec _ _ _ mod _ = _] |- ?x mod wordmax = _ mod ?m => - replace (x mod wordmax) with (x mod m) by auto - end. - all: cbn - [Z.shiftl wordmax]; cbv [cc_good] in *; destruct_head'_and; - repeat match goal with - | H : CC.cc_c _ = _ |- _ => rewrite H - | H : CC.cc_m _ = _ |- _ => rewrite H - | H : CC.cc_l _ = _ |- _ => rewrite H - | H : CC.cc_z _ = _ |- _ => rewrite H - | H: of_prefancy_scalar _ = ?r ?c |- _ => rewrite <-H - | _ => progress rewrite ?cc_m_zselect, ?cc_l_zselect by auto - | _ => progress rewrite ?Z.add_modulo_correct, ?Z.geb_leb by auto - | |- context [cinterp ?x] => - erewrite of_prefancy_scalar_correct with (e2:=x) by eauto - | |- context [cinterp ?x] => - erewrite <-of_prefancy_scalar_carry with (e:=x) by eauto - | |- context [if _ (of_prefancy_scalar _) then _ else _ ] => - cbv [Z.zselect Z.b2z]; - break_innermost_match; Z.ltb_to_lt; try reflexivity; - congruence - end; try reflexivity. - - { (* RSHI case *) - cbv [Z.rshi]. - rewrite Z.land_ones, Z.shiftl_mul_pow2 by (cbv; congruence). - change (2 ^ Z.log2 wordmax) with wordmax. - break_innermost_match; try congruence; [ ]. autorewrite with zsimplify_fast. - repeat (f_equal; try ring). } - Qed. - Lemma of_prefancy_identZ_correct {s} idc: - forall (x : @cexpr var _) i ctx G cc cctx x2 r rf f, - @valid_ident (type.base s) (type_base tZ) r rf idc x -> - LanguageWf.Compilers.expr.wf G (#idc @ x)%expr_pat x2 -> - LanguageWf.Compilers.expr.wf G #(ident.Z_cast uint256) f -> - cc_good cc cctx ctx r -> - (forall n v, consts v = Some n -> In (existZ (n, v)) G) -> - (forall t v1 v2, In (existT _ (type.base t) (v1, v2)) G -> interp_base ctx cctx v1 = v2) -> - (forall n, ctx n mod wordmax = ctx n) -> - of_prefancy_ident idc x = Some i -> - spec (projT1 i) (Tuple.map ctx (projT2 i)) cc mod wordmax = (cinterp f (cinterp x2)). - Proof. - intros; eapply of_prefancy_identZ_loosen_correct; try eassumption; [ | ]. - { cbn; omega. } { intros; f_equal; ring. } - Qed. - Lemma of_prefancy_identZZ_correct' {s} idc: - forall (x : @cexpr var _) i ctx G cc cctx x2 r rf f, - @valid_ident (type.base s) (type_base (tZ * tZ)) r rf idc x -> - LanguageWf.Compilers.expr.wf G (#idc @ x)%expr_pat x2 -> - LanguageWf.Compilers.expr.wf G #(ident.Z_cast2 (uint256, r[0~>1])) f -> - cc_good cc cctx ctx r -> - (forall n v, consts v = Some n -> In (existZ (n, v)) G) -> - (forall t v1 v2, In (existT _ (type.base t) (v1, v2)) G -> interp_base ctx cctx v1 = v2) -> - (forall n, ctx n mod wordmax = ctx n) -> - of_prefancy_ident idc x = Some i -> - spec (projT1 i) (Tuple.map ctx (projT2 i)) cc mod wordmax = fst (cinterp f (cinterp x2)) /\ - Z.b2z (cc_spec CC.C (spec (projT1 i) (Tuple.map ctx (projT2 i)) cc)) = snd (cinterp f (cinterp x2)). - Proof. - inversion 1; inversion 1; cbn [of_prefancy_ident]; intros; hammer; (simplify_ident; [ ]); - cbn - [Z.div Z.modulo]; cbv [Z.sub_with_borrow Z.add_with_carry]; - cbv [cc_good] in *; destruct_head'_and; autorewrite with zsimplify_fast. - all: repeat match goal with - | H : CC.cc_c _ = _ |- _ => rewrite H - | H: of_prefancy_scalar _ = ?r ?c |- _ => rewrite <-H - | H : LanguageWf.Compilers.expr.wf _ ?x ?e |- context [cinterp ?e] => - erewrite <-of_prefancy_scalar_correct with (e1:=x) (e2:=e) by eauto - | H : LanguageWf.Compilers.expr.wf _ ?x ?e2 |- context [cinterp ?e2] => - erewrite <-of_prefancy_scalar_carry with (c:=x) (e:=e2) by eauto - end. - all: match goal with |- context [(?x << ?n) mod ?m] => - pose proof (Z.mod_pos_bound (x << n) m ltac:(omega)) end. - all:repeat match goal with - | |- context [if _ (of_prefancy_scalar _) then _ else _ ] => - cbv [Z.zselect Z.b2z]; break_innermost_match; Z.ltb_to_lt; try congruence; [ | ] - | _ => rewrite Z.add_opp_r - | _ => rewrite Div.Z.div_sub_small by auto with zarith - | H : forall n, ?ctx n mod wordmax = ?ctx n |- context [?ctx ?m - _] => rewrite <-(H m) - | |- ((?x - ?y - ?c) / _) mod _ = - ((- ?c + ?x - ?y) / _) mod _ => - replace (-c + x - y) with (x - (y + c)) by ring; replace (x - y - c) with (x - (y + c)) by ring - | _ => split - | _ => try apply (f_equal2 Z.modulo); try apply (f_equal2 Z.div); ring - | _ => break_innermost_match; reflexivity - end. - Qed. - Lemma of_prefancy_identZZ_correct {s} idc: - forall (x : @cexpr var _) i ctx G cc cctx x2 r rf f, - @valid_ident (type.base s) (type_base (tZ * tZ)) r rf idc x -> - LanguageWf.Compilers.expr.wf G (#idc @ x)%expr_pat x2 -> - LanguageWf.Compilers.expr.wf G #(ident.Z_cast2 (uint256, r[0~>1])) f -> - cc_good cc cctx ctx r -> - (forall n v, consts v = Some n -> In (existZ (n, v)) G) -> - (forall t v1 v2, In (existT _ (type.base t) (v1, v2)) G -> interp_base ctx cctx v1 = v2) -> - (forall n, ctx n mod wordmax = ctx n) -> - of_prefancy_ident idc x = Some i -> - spec (projT1 i) (Tuple.map ctx (projT2 i)) cc mod wordmax = fst (cinterp f (cinterp x2)). - Proof. apply of_prefancy_identZZ_correct'. Qed. - Lemma of_prefancy_identZZ_correct_carry {s} idc: - forall (x : @cexpr var _) i ctx G cc cctx x2 r rf f, - @valid_ident (type.base s) (type_base (tZ * tZ)) r rf idc x -> - LanguageWf.Compilers.expr.wf G (#idc @ x)%expr_pat x2 -> - LanguageWf.Compilers.expr.wf G #(ident.Z_cast2 (uint256, r[0~>1])) f -> - cc_good cc cctx ctx r -> - (forall n v, consts v = Some n -> In (existZ (n, v)) G) -> - (forall t v1 v2, In (existT _ (type.base t) (v1, v2)) G -> interp_base ctx cctx v1 = v2) -> - (forall n, ctx n mod wordmax = ctx n) -> - of_prefancy_ident idc x = Some i -> - Z.b2z (cc_spec CC.C (spec (projT1 i) (Tuple.map ctx (projT2 i)) cc)) = snd (cinterp f (cinterp x2)). - Proof. apply of_prefancy_identZZ_correct'. Qed. - - Lemma identZZ_writes {s} idc r rf x: - @valid_ident (type.base s) (type_base (tZ * tZ)) r rf idc x -> - forall i, of_prefancy_ident idc x = Some i -> - In CC.C (writes_conditions (projT1 i)). - Proof. - inversion 1; - repeat match goal with - | _ => progress intros - | _ => progress cbn [of_prefancy_ident writes_conditions ADD ADDC SUB SUBC In] in * - | _ => progress hammer; Z.ltb_to_lt - | _ => congruence - end. - Qed. - - (* Common side conditions for cases in of_prefancy_correct *) - Local Ltac side_cond := - repeat match goal with - | _ => progress intros - | _ => progress cbn [In fst snd] in * - | H : _ \/ _ |- _ => destruct H - | [H : forall _ _, In _ ?l -> _, H' : In _ ?l |- _] => - let H'' := fresh in - pose proof H'; apply H in H''; clear H - | H : name_lt ?n ?n |- _ => - specialize (name_lt_irr n); contradiction - | _ => progress hammer - | _ => solve [eauto] - end. - - Lemma interp_base_helper G next_name ctx cctx : - (forall n v2, In (existZ (n, v2)) G -> name_lt n next_name) -> - (forall n v2, In (existZZ (n, v2)) G -> name_lt (fst n) next_name) -> - (forall n v2, In (existZZ (n, v2)) G -> fst n = snd n) -> - (forall t v1 v2, In (existT _ (type.base t) (v1, v2)) G -> interp_base ctx cctx v1 = v2) -> - (forall t v1 v2, In (existT _ (type.base t) (v1, v2)) G -> - t = base.type.type_base tZ - \/ t = (base.type.type_base tZ * base.type.type_base tZ)%etype) -> - forall t v1 v2 x xc, - In (existT (fun t : type => (var t * type.interp base.interp t)%type) (type.base t) (v1, v2)%zrange) - ((existZ (next_name, x)%zrange) :: G) -> - interp_base (fun n : name => if name_eqb n next_name then x else ctx n) - (fun n : name => if name_eqb n next_name then xc else cctx n) v1 = v2. - Proof. - intros. - repeat match goal with - | H: In _ (_ :: _) |- _ => cbn [In] in H; destruct H; [ solve [side_cond] | ] - | H : (forall t _ _, In _ ?G -> (t = _ \/ t = _)), H' : In _ ?G |- _ => - destruct (H _ _ _ H'); subst t - | H : forall _ _ _, In _ ?G -> interp_base _ _ _ = _, H' : In _ G |- _ => specialize (H _ _ _ H') - end; side_cond. - Qed. - - Lemma name_eqb_refl n : name_eqb n n = true. - Proof. case_eq (name_eqb n n); intros; name_eqb_to_eq; auto. Qed. - - Lemma valid_ident_new_cc_to_name s d r rf idc x y n : - @valid_ident (type.base s) (type.base d) r rf idc x -> - of_prefancy_ident idc x = Some y -> - rf n = new_cc_to_name r (projT1 y) n. - Proof. inversion 1; intros; hammer; simplify_ident. Qed. - - Lemma new_cc_to_name_Z_cases r i n x : - new_cc_to_name (d:=base.type.type_base tZ) r i n x - = if in_dec CC.code_dec x (writes_conditions i) - then n else r x. - Proof. reflexivity. Qed. - Lemma new_cc_to_name_ZZ_cases r i n x : - new_cc_to_name (d:=base.type.type_base tZ * base.type.type_base tZ) r i n x - = if in_dec CC.code_dec x (writes_conditions i) - then fst n else r x. - Proof. reflexivity. Qed. - - Lemma cc_good_helper cc cctx ctx r i x next_name : - (forall c, name_lt (r c) next_name) -> - (forall n v, consts v = Some n -> name_lt n next_name) -> - cc_good cc cctx ctx r -> - cc_good (CC.update (writes_conditions i) x cc_spec cc) - (fun n : name => - if name_eqb n next_name - then CC.cc_c (CC.update (writes_conditions i) x cc_spec cc) - else cctx n) - (fun n : name => if name_eqb n next_name then x mod wordmax else ctx n) - (new_cc_to_name (d:=base.type.type_base tZ) r i next_name). - Proof. - cbv [cc_good]; intros; destruct_head'_and. - rewrite !new_cc_to_name_Z_cases. - cbv [CC.update CC.cc_c CC.cc_m CC.cc_l CC.cc_z]. - repeat match goal with - | _ => split; intros - | _ => progress hammer - | H : forall c, name_lt (r c) (r ?c2) |- _ => specialize (H c2) - | H : (forall n v, consts v = Some n -> name_lt _ _), - H' : consts _ = Some _ |- _ => specialize (H _ _ H') - | H : name_lt ?n ?n |- _ => apply name_lt_irr in H; contradiction - | _ => cbv [cc_spec]; rewrite Z.mod_pow2_bits_low by omega - | _ => congruence - end. - Qed. - - Lemma of_prefancy_correct - {t} (e1 : @cexpr var t) (e2 : @cexpr _ t) r : - valid_expr _ r e1 -> - forall G, - LanguageWf.Compilers.expr.wf G e1 e2 -> - forall ctx cc cctx, - cc_good cc cctx ctx r -> - (forall n v, consts v = Some n -> In (existZ (n, v)) G) -> - (forall n v2, In (existZZ (n, v2)) G -> fst n = snd n) -> - (forall t v1 v2, In (existT _ (type.base t) (v1, v2)) G -> interp_base ctx cctx v1 = v2) -> - (forall t v1 v2, In (existT _ (type.base t) (v1, v2)) G -> - t = base.type.type_base tZ - \/ t = (base.type.type_base tZ * base.type.type_base tZ)%etype) -> - (forall n, ctx n mod wordmax = ctx n) -> - forall next_name result, - (forall c : CC.code, name_lt (r c) next_name) -> - (forall n v2, In (existZ (n, v2)) G -> name_lt n next_name) -> - (forall n v2, In (existZZ (n, v2)) G -> name_lt (fst n) next_name) -> - (interp_if_Z e2 = Some result) -> - interp (@of_prefancy next_name t e1) cc ctx = result. - Proof. - induction 1; inversion 1; cbv [interp_if_Z]; - cbn [of_prefancy of_prefancy_step]; intros; - match goal with H : context [interp_base _ _ _ = _] |- _ => - pose proof (H (base.type.type_base base.type.Z)) end; - try solve [prove_Ret]; [ | | ]; hammer; - match goal with - | H : context [interp (of_prefancy _ _) _ _ = _] - |- interp _ ?cc' ?ctx' = _ => - match goal with - | _ : context [LetInAppIdentZ _ _ _ _ _ _] |- _=> - erewrite H with - (G := (existZ (next_name, ctx' next_name)) :: G) - (e2 := _ (ctx' next_name)) - (cctx := (fun n => if name_eqb n next_name then CC.cc_c cc' else cctx n)) - | _ : context [LetInAppIdentZZ _ _ _ _ _ _] |- _=> - erewrite H with - (G := (existZZ ((next_name, next_name), (ctx' next_name, Z.b2z (CC.cc_c cc')))) :: G) - (e2 := _ (ctx' next_name, Z.b2z (CC.cc_c cc'))) - (cctx := (fun n => if name_eqb n next_name then CC.cc_c cc' else cctx n)) - end - end; - repeat match goal with - | _ => progress intros - | _ => rewrite name_eqb_refl in * - | _ => rewrite Z.testbit_spec' in * - | _ => erewrite valid_ident_new_cc_to_name by eassumption - | _ => rewrite new_cc_to_name_Z_cases - | _ => rewrite new_cc_to_name_ZZ_cases - | _ => solve [intros; eapply interp_base_helper; side_cond] - | _ => solve [intros; apply cc_good_helper; eauto] - | _ => reflexivity - | _ => solve [eauto using Z.mod_small, b2z_range] - | _ => progress autorewrite with zsimplify_fast - | _ => progress side_cond - end; [ | | ]. - { cbn - [cc_spec]; cbv [id]; cbn - [cc_spec]. - inversion wf_x; hammer. - erewrite of_prefancy_identZ_loosen_correct by eauto. - reflexivity. } - { cbn - [cc_spec]; cbv [id]; cbn - [cc_spec]. - inversion wf_x; hammer. - erewrite of_prefancy_identZ_correct by eassumption. - reflexivity. } - { cbn - [cc_spec]; cbv [id]; cbn - [cc_spec]. - match goal with H : _ |- _ => pose proof H; eapply identZZ_writes in H; [ | eassumption] end. - inversion wf_x; hammer. - erewrite of_prefancy_identZZ_correct by eassumption. - erewrite of_prefancy_identZZ_correct_carry by eassumption. - rewrite <-surjective_pairing. reflexivity. } - Qed. - End Proofs. - End of_prefancy. - - Section allocate_registers. - Context (reg name : Type) (name_eqb : name -> name -> bool) (error : reg). - Fixpoint allocate (e : @expr name) (reg_list : list reg) (name_to_reg : name -> reg) : @expr reg := - match e with - | Ret n => Ret (name_to_reg n) - | Instr i rd args cont => - match reg_list with - | r :: reg_list' => Instr i r (Tuple.map name_to_reg args) (allocate cont reg_list' (fun n => if name_eqb n rd then r else name_to_reg n)) - | nil => Ret error - end - end. - End allocate_registers. - - Definition test_prog : @expr positive := - Instr (ADD (128)) 3%positive (1, 2)%positive - (Instr (ADDC 0) 4%positive (3,1)%positive - (Ret 4%positive)). - - Definition x1 := 2^256 - 1. - Definition x2 := 2^128 - 1. - Definition wordmax := 2^256. - Definition expected := - let r3' := (x1 + (x2 << 128)) in - let r3 := r3' mod wordmax in - let c := r3' / wordmax in - let r4' := (r3 + x1 + c) in - r4' mod wordmax. - Definition actual := - interp Pos.eqb - (2^256) cc_spec test_prog {|CC.cc_c:=false; CC.cc_m:=false; CC.cc_l:=false; CC.cc_z:=false|} - (fun n => if n =? 1%positive - then x1 - else if n =? 2%positive - then x2 - else 0). - Lemma test_prog_ok : expected = actual. - Proof. reflexivity. Qed. - - Definition of_Expr {t} next_name (consts : Z -> option positive) - (e : expr.Expr t) - (x : type.for_each_lhs_of_arrow (var positive) t) - : positive -> @expr positive := - fun error => - @of_prefancy positive Pos.succ error consts next_name _ (invert_expr.smart_App_curried (e _) x). - - Section Proofs. - - Section with_name. - Context (name : Type) (name_eqb : name -> name -> bool) - (name_succ : name -> name) (error : name) - (consts : Z -> option name) (wordmax : Z) - (cc_spec : CC.code -> Z -> bool). - - - Context (reg : Type) (error_reg : reg) (reg_eqb : reg -> reg -> bool). - Context (reg_eqb_refl : forall r, reg_eqb r r = true). - - Inductive error_free : @expr reg -> Prop := - | error_free_Ret : forall r, r <> error_reg -> error_free (Ret r) - | error_free_Instr : forall i rd args cont, - error_free cont -> - error_free (Instr i rd args cont) - . - - Lemma allocate_correct e : - forall cc ctx reg_list name_to_reg, - error_free (allocate reg name name_eqb error_reg e reg_list name_to_reg) -> - interp reg_eqb wordmax cc_spec (allocate reg name name_eqb error_reg e reg_list name_to_reg) cc ctx - = interp name_eqb wordmax cc_spec e cc (fun n : name => ctx (name_to_reg n)). - Proof. - induction e; destruct reg_list; inversion 1; intros; - try reflexivity; try congruence; [ ]. - cbn. rewrite IHe by auto. - rewrite Tuple.map_map. - (* - Need to prove that contexts are equivalent and swapping contexts is OK - *) - (* - TODO : either prove this lemma or devise a good way to - prove case-by-case that the output of allocate is - equivalent to the input. - *) - Admitted. - End with_name. - - Fixpoint var_pairs {t var1 var2} - : type.for_each_lhs_of_arrow var1 t - -> type.for_each_lhs_of_arrow var2 t - -> list {t : Compilers.type base.type.type & (var1 t * var2 t)%type } := - match t as t0 return - (type.for_each_lhs_of_arrow var1 t0 - -> type.for_each_lhs_of_arrow var2 t0 -> _) with - | type.base _ => fun _ _ => nil - | (s -> d)%ptype => - fun x1 x2 => - existT _ _ (fst x1, fst x2) :: var_pairs (snd x1) (snd x2) - end. - - Local Notation existZ := (existT _ (type.base (base.type.type_base base.type.Z))). - Local Notation existZZ := (existT _ (type.base (base.type.type_base base.type.Z * base.type.type_base base.type.Z)%etype)). - - Fixpoint make_ctx (var_list : list (positive * Z)) : positive -> Z := - match var_list with - | [] => fun _ => 0 - | (n, v) :: l' => fun m => if (m =? n)%positive then v else make_ctx l' m - end. - - Definition make_pairs : - list (positive * Z) -> list {t : Compilers.type base.type.type & (var positive t * @type.interp base.type base.interp t)%type } := map (fun x => existZ x). - - Fixpoint make_consts (consts_list : list (positive * Z)) : Z -> option positive := - match consts_list with - | [] => fun _ => None - | (n, v) :: l' => fun x => if x =? v then Some n else make_consts l' x - end. - - Local Ltac ez := - repeat match goal with - | _ => progress intros - | _ => progress subst - | H : _ \/ _ |- _ => destruct H - | H : _ |- _ => rewrite Pos.eqb_eq in H - | H : _ |- _ => rewrite Pos.eqb_neq in H - | _ => progress break_innermost_match - | _ => progress break_match_hyps - | _ => progress inversion_sigma - | _ => progress inversion_option - | _ => progress Prod.inversion_prod - | _ => progress HProp.eliminate_hprop_eq - | _ => progress Z.ltb_to_lt - | _ => reflexivity - | _ => congruence - | _ => solve [eauto] - end. - - - Lemma make_consts_ok consts_list n v : - make_consts consts_list v = Some n -> - In (existZ (n, v)%zrange) (make_pairs consts_list). - Proof. - cbv [make_pairs]; induction consts_list as [|[ ? ? ] ?]; cbn; ez. - Qed. - - Lemma make_pairs_ok consts_list: - forall v1 v2, - In (existZ (v1, v2)%zrange) (make_pairs consts_list) -> - In (v1, v2) consts_list. - Proof. - cbv [make_pairs]. induction consts_list as [| [ n v ] ? ]; cbn; [ tauto | ]. ez. - Qed. - Lemma make_ctx_ok consts_list: - (forall n v1 v2, In (n, v1) consts_list -> - In (n, v2) consts_list -> v1 = v2) -> - forall n v, - In (n, v) consts_list -> - make_ctx consts_list n = v. - Proof. - induction consts_list as [| [ n v ] ? ]; cbn; [ tauto | ]. - repeat match goal with - | _ => progress cbn [eq_rect fst snd] in * - | _ => progress ez - end. - Qed. - - Lemma make_ctx_cases consts_list n : - make_ctx consts_list n = 0 \/ - In (n, make_ctx consts_list n) consts_list. - Proof. induction consts_list; cbn; ez. Qed. - - Lemma only_integers consts_list t v1 v2 : - In (existT (fun t : type => (var positive t * type.interp base.interp t)%type) (type.base t) - (v1, v2)%zrange) (make_pairs consts_list) -> - t = base.type.type_base base.type.Z. - Proof. - induction consts_list; cbn; [ tauto | ]. - destruct 1; congruence || tauto. - Qed. - - Lemma no_pairs consts_list v1 v2 : - In (existZZ (v1, v2)%zrange) (make_pairs consts_list) -> False. - Proof. intro H; apply only_integers in H. congruence. Qed. - - - Definition make_cc last_wrote ctx carry_flag : CC.state := - {| CC.cc_c := carry_flag; - CC.cc_m := cc_spec CC.M (ctx (last_wrote CC.M)); - CC.cc_l := cc_spec CC.L (ctx (last_wrote CC.L)); - CC.cc_z := cc_spec CC.Z (ctx (last_wrote CC.Z) - + (if (last_wrote CC.C =? last_wrote CC.Z)%positive - then wordmax * Z.b2z carry_flag else 0)); - |}. - - - Hint Resolve Pos.lt_trans Pos.lt_irrefl Pos.lt_succ_diag_r Pos.eqb_refl. - Hint Resolve in_or_app. - Hint Resolve make_consts_ok make_pairs_ok make_ctx_ok no_pairs. - (* TODO : probably not all of these preconditions are necessary -- prune them sometime *) - Lemma of_Expr_correct next_name consts_list arg_list error - (carry_flag : bool) - (last_wrote : CC.code -> positive) (* variables which last wrote to each flag; put RegZero if flag empty *) - t (e : Expr t) - (x1 : type.for_each_lhs_of_arrow (var positive) t) - (x2 : type.for_each_lhs_of_arrow _ t) result : - let e1 := (invert_expr.smart_App_curried (e _) x1) in - let e2 := (invert_expr.smart_App_curried (e _) x2) in - let ctx := make_ctx (consts_list ++ arg_list) in - let consts := make_consts consts_list in - let cc := make_cc last_wrote ctx carry_flag in - let G := make_pairs consts_list ++ make_pairs arg_list in - (forall c, last_wrote c < next_name)%positive -> - (forall n v, In (n, v) (consts_list ++ arg_list) -> (n < next_name)%positive) -> - (In (last_wrote CC.C, Z.b2z carry_flag) consts_list) -> - (forall n v1 v2, In (n, v1) (consts_list ++ arg_list) -> - In (n, v2) (consts_list ++ arg_list) -> v1 = v2) (* no duplicate names *) -> - (forall v1 v2, In (v1, v2) consts_list -> v2 mod 2 ^ 256 = v2) -> - (forall v1 v2, In (v1, v2) arg_list -> v2 mod 2 ^ 256 = v2) -> - (LanguageWf.Compilers.expr.wf G e1 e2) -> - valid_expr _ error consts _ last_wrote e1 -> - interp_if_Z e2 = Some result -> - interp Pos.eqb wordmax cc_spec (of_Expr next_name consts e x1 error) cc ctx = result. - Proof. - cbv [of_Expr]; intros. - eapply of_prefancy_correct with (name_lt := Pos.lt) - (cctx := fun n => if (n =? last_wrote CC.C)%positive - then carry_flag - else match make_consts consts_list 1 with - | Some n1 => (n =? n1)%positive - | _ => false - end); - cbv [id]; eauto; - try apply Pos.eqb_neq; intros; - try solve [apply make_ctx_ok; auto; apply make_pairs_ok; - cbv [make_pairs]; rewrite map_app; auto ]; - repeat match goal with - | H : _ |- _ => apply in_app_or in H; destruct H - | H : In _ (make_pairs _) |- context [ _ = base.type.type_base _] => apply only_integers in H - | H : In _ (make_pairs _) |- context [interp_base] => - pose proof (only_integers _ _ _ _ H); subst; cbn [interp_base] - | _ => solve [eauto] - | _ => solve [exfalso; eauto] - end. - (* TODO : clean this up *) - { cbv [cc_good make_cc]; repeat split; intros; - [ rewrite Pos.eqb_refl; reflexivity | | ]; - break_innermost_match; try rewrite Pos.eqb_eq in *; subst; try reflexivity; - repeat match goal with - | H : make_consts _ _ = Some _ |- _ => - apply make_consts_ok, make_pairs_ok in H - | _ => apply Pos.eqb_neq; intro; subst - | _ => inversion_option; congruence - end; - match goal with - | H : In (?n, ?x) consts_list, H': In (?n, ?y) consts_list, - H'' : forall n x y, In (n,x) _ -> In (n,y) _ -> x = y |- _ => - assert (x = y) by (eapply H''; eauto) - end; destruct carry_flag; cbn [Z.b2z] in *; congruence. } - { match goal with |- context [make_ctx ?l ?n] => - let H := fresh in - destruct (make_ctx_cases l n) as [H | H]; - [ rewrite H | apply in_app_or in H; destruct H ] - end; eauto. } - Qed. - - Section expression_equivalence. - Context {name1 name2} - (name1_eqb : name1 -> name1 -> bool) - (name2_eqb : name2 -> name2 -> bool) - (name1_eqb_eq : forall n m, name1_eqb n m = true -> n = m) - (name1_eqb_neq : forall n m, name1_eqb n m = false -> n <> m) - (name2_eqb_eq : forall n m, name2_eqb n m = true -> n = m) - (name2_eqb_neq : forall n m, name2_eqb n m = false -> n <> m). - - (* name1 should only map to a single name2; several name1s might map to the same name2 *) - Inductive in_step : (name1 -> name2) -> expr -> expr -> Prop := - | in_step_ret : - forall M n1 n2, M n1 = n2 -> in_step M (Ret n1) (Ret n2) - | in_step_instr : - forall i M rd1 rd2 args1 args2 e1 e2, - in_step M e1 e2 -> - Tuple.map M args1 = args2 -> (* args correspond with old assignments *) - M rd1 = rd2 -> (* destination register corresponds with new assignment *) - in_step M (Instr i rd1 args1 e1) (Instr i rd2 args2 e2) - . - - Lemma interp_eq M e1 e2 (HM : forall n n', M n = M n' -> n = n') : - in_step M e1 e2 -> - forall cc ctx1 ctx2, - (forall n1, ctx1 n1 = ctx2 (M n1)) -> - interp name1_eqb wordmax cc_spec e1 cc ctx1 = - interp name2_eqb wordmax cc_spec e2 cc ctx2. - Proof. - induction 1; intros; cbn [interp]; [ congruence | ]. - replace (Tuple.map ctx1 args1) with (Tuple.map ctx2 args2) - by (subst args2; rewrite Tuple.map_map; apply Tuple.map_ext_In; intros; - match goal with | H : context [ctx1 _ = ctx2 _] |- _ => rewrite H end; - f_equal; eauto using eq_sym). - apply IHin_step; intros; eauto. - break_innermost_match; - repeat match goal with - | _ => progress subst - | H : _ = true |- _ => apply name1_eqb_eq in H - | H : _ = false |- _ => apply name1_eqb_neq in H - | H : _ = true |- _ => apply name2_eqb_eq in H - | H : _ = false |- _ => apply name2_eqb_neq in H - | H : M _ = M _ |- _ => apply HM in H - end; congruence. - Qed. - End expression_equivalence. - End Proofs. -End Fancy. - -Module Prod. - Import Fancy. Import Registers. - - Definition Mul256 (out src1 src2 tmp : register) (cont : Fancy.expr) : Fancy.expr := - Instr MUL128LL out (src1, src2) - (Instr MUL128UL tmp (src1, src2) - (Instr (ADD 128) out (out, tmp) - (Instr MUL128LU tmp (src1, src2) - (Instr (ADD 128) out (out, tmp) cont)))). - Definition Mul256x256 (out outHigh src1 src2 tmp : register) (cont : Fancy.expr) : Fancy.expr := - Instr MUL128LL out (src1, src2) - (Instr MUL128UU outHigh (src1, src2) - (Instr MUL128UL tmp (src1, src2) - (Instr (ADD 128) out (out, tmp) - (Instr (ADDC (-128)) outHigh (outHigh, tmp) - (Instr MUL128LU tmp (src1, src2) - (Instr (ADD 128) out (out, tmp) - (Instr (ADDC (-128)) outHigh (outHigh, tmp) cont))))))). - - Definition MontRed256 lo hi y t1 t2 scratch RegPInv : @Fancy.expr register := - Mul256 y lo RegPInv t1 - (Mul256x256 t1 t2 y RegMod scratch - (Instr (ADD 0) lo (lo, t1) - (Instr (ADDC 0) hi (hi, t2) - (Instr SELC y (RegMod, RegZero) - (Instr (SUB 0) lo (hi, y) - (Instr ADDM lo (lo, RegZero, RegMod) - (Ret lo))))))). - - (* Barrett reduction -- this is only the "reduce" part, excluding the initial multiplication. *) - Definition MulMod x xHigh RegMuLow scratchp1 scratchp2 scratchp3 scratchp4 scratchp5 : @Fancy.expr register := - let q1Bottom256 := scratchp1 in - let muSelect := scratchp2 in - let q2 := scratchp3 in - let q2High := scratchp4 in - let q2High2 := scratchp5 in - let q3 := scratchp1 in - let r2 := scratchp2 in - let r2High := scratchp3 in - let maybeM := scratchp1 in - Instr SELM muSelect (RegMuLow, RegZero) - (Instr (RSHI 255) q1Bottom256 (xHigh, x) - (Mul256x256 q2 q2High q1Bottom256 RegMuLow scratchp5 - (Instr (RSHI 255) q2High2 (RegZero, xHigh) - (Instr (ADD 0) q2High (q2High, q1Bottom256) - (Instr (ADDC 0) q2High2 (q2High2, RegZero) - (Instr (ADD 0) q2High (q2High, muSelect) - (Instr (ADDC 0) q2High2 (q2High2, RegZero) - (Instr (RSHI 1) q3 (q2High2, q2High) - (Mul256x256 r2 r2High RegMod q3 scratchp4 - (Instr (SUB 0) muSelect (x, r2) - (Instr (SUBC 0) xHigh (xHigh, r2High) - (Instr SELL maybeM (RegMod, RegZero) - (Instr (SUB 0) q3 (muSelect, maybeM) - (Instr ADDM x (q3, RegZero, RegMod) - (Ret x))))))))))))))). -End Prod. - -Module ProdEquiv. - Import Fancy. Import Registers. - - Definition interp256 := Fancy.interp reg_eqb (2^256) cc_spec. - Lemma interp_step i rd args cont cc ctx : - interp256 (Instr i rd args cont) cc ctx = - let result := spec i (Tuple.map ctx args) cc in - let new_cc := CC.update (writes_conditions i) result cc_spec cc in - let new_ctx := fun n => if reg_eqb n rd then result mod wordmax else ctx n in interp256 cont new_cc new_ctx. - Proof. reflexivity. Qed. - - Lemma interp_state_equiv e : - forall cc ctx cc' ctx', - cc = cc' -> (forall r, ctx r = ctx' r) -> - interp256 e cc ctx = interp256 e cc' ctx'. - Proof. - induction e; intros; subst; cbn; [solve[auto]|]. - apply IHe; rewrite Tuple.map_ext with (g:=ctx') by auto; - [reflexivity|]. - intros; break_match; auto. - Qed. - Lemma cc_overwrite_full x1 x2 l1 cc : - CC.update [CC.C; CC.M; CC.L; CC.Z] x2 cc_spec (CC.update l1 x1 cc_spec cc) = CC.update [CC.C; CC.M; CC.L; CC.Z] x2 cc_spec cc. - Proof. - cbv [CC.update]. cbn [CC.cc_c CC.cc_m CC.cc_l CC.cc_z]. - break_match; try match goal with H : ~ In _ _ |- _ => cbv [In] in H; tauto end. - reflexivity. - Qed. - - Definition value_unused r e : Prop := - forall x cc ctx, interp256 e cc ctx = interp256 e cc (fun r' => if reg_eqb r' r then x else ctx r'). - - Lemma value_unused_skip r i rd args cont (Hcont: value_unused r cont) : - r <> rd -> - (~ In r (Tuple.to_list _ args)) -> - value_unused r (Instr i rd args cont). - Proof. - cbv [value_unused] in *; intros. - rewrite !interp_step; cbv zeta. - rewrite Hcont with (x:=x). - match goal with |- ?lhs = ?rhs => - match lhs with context [Tuple.map ?f ?t] => - match rhs with context [Tuple.map ?g ?t] => - rewrite (Tuple.map_ext_In f g) by (intros; cbv [reg_eqb]; break_match; congruence) - end end end. - apply interp_state_equiv; [ congruence | ]. - { intros; cbv [reg_eqb] in *; break_match; congruence. } - Qed. - - Lemma value_unused_overwrite r i args cont : - (~ In r (Tuple.to_list _ args)) -> - value_unused r (Instr i r args cont). - Proof. - cbv [value_unused]; intros; rewrite !interp_step; cbv zeta. - match goal with |- ?lhs = ?rhs => - match lhs with context [Tuple.map ?f ?t] => - match rhs with context [Tuple.map ?g ?t] => - rewrite (Tuple.map_ext_In f g) by (intros; cbv [reg_eqb]; break_match; congruence) - end end end. - apply interp_state_equiv; [ congruence | ]. - { intros; cbv [reg_eqb] in *; break_match; congruence. } - Qed. - - Lemma value_unused_ret r r' : - r <> r' -> - value_unused r (Ret r'). - Proof. - cbv - [reg_dec]; intros. - break_match; congruence. - Qed. - - Ltac remember_results := - repeat match goal with |- context [(spec ?i ?args ?flags) mod ?w] => - let x := fresh "x" in - let y := fresh "y" in - let Heqx := fresh "Heqx" in - remember (spec i args flags) as x eqn:Heqx; - remember (x mod w) as y - end. - - Ltac do_interp_step := - rewrite interp_step; cbn - [interp spec]; - repeat progress rewrite ?reg_eqb_neq, ?reg_eqb_refl by congruence; - remember_results. - - Lemma interp_Mul256 out src1 src2 tmp tmp2 cont cc ctx: - out <> src1 -> - out <> src2 -> - out <> tmp -> - out <> tmp2 -> - src1 <> src2 -> - src1 <> tmp -> - src1 <> tmp2 -> - src2 <> tmp -> - src2 <> tmp2 -> - tmp <> tmp2 -> - value_unused tmp cont -> - value_unused tmp2 cont -> - interp256 (Prod.Mul256 out src1 src2 tmp cont) cc ctx = - interp256 ( - Instr MUL128LU tmp (src1, src2) - (Instr MUL128UL tmp2 (src1, src2) - (Instr MUL128LL out (src1, src2) - (Instr (ADD 128) out (out, tmp2) - (Instr (ADD 128) out (out, tmp) cont))))) cc ctx. - Proof. - intros; cbv [Prod.Mul256]. - repeat (do_interp_step; cbn [spec MUL128LL MUL128UL MUL128LU ADD] in * ). - - match goal with H : value_unused tmp _ |- _ => erewrite H end. - match goal with H : value_unused tmp2 _ |- _ => erewrite H end. - apply interp_state_equiv. - { rewrite !cc_overwrite_full. - f_equal. subst. lia. } - { intros; cbv [reg_eqb]. - repeat (break_match_step ltac:(fun _ => idtac); try congruence); reflexivity. } - Qed. - - Lemma interp_Mul256x256 out outHigh src1 src2 tmp tmp2 cont cc ctx: - out <> src1 -> - out <> outHigh -> - out <> src2 -> - out <> tmp -> - out <> tmp2 -> - outHigh <> src1 -> - outHigh <> src2 -> - outHigh <> tmp -> - outHigh <> tmp2 -> - src1 <> src2 -> - src1 <> tmp -> - src1 <> tmp2 -> - src2 <> tmp -> - src2 <> tmp2 -> - tmp <> tmp2 -> - value_unused tmp cont -> - value_unused tmp2 cont -> - interp256 (Prod.Mul256x256 out outHigh src1 src2 tmp cont) cc ctx = - interp256 ( - Instr MUL128LL out (src1, src2) - (Instr MUL128LU tmp (src1, src2) - (Instr MUL128UL tmp2 (src1, src2) - (Instr MUL128UU outHigh (src1, src2) - (Instr (ADD 128) out (out, tmp2) - (Instr (ADDC (-128)) outHigh (outHigh, tmp2) - (Instr (ADD 128) out (out, tmp) - (Instr (ADDC (-128)) outHigh (outHigh, tmp) cont)))))))) cc ctx. - Proof. - intros; cbv [Prod.Mul256x256]. - repeat (do_interp_step; cbn [spec MUL128LL MUL128UL MUL128LU MUL128UU ADD ADDC] in * ). - - match goal with H : value_unused tmp _ |- _ => erewrite H end. - match goal with H : value_unused tmp2 _ |- _ => erewrite H end. - apply interp_state_equiv. - { rewrite !cc_overwrite_full. - f_equal. - subst. cbn - [Z.add Z.modulo Z.testbit Z.mul Z.shiftl Fancy.lower128 Fancy.upper128]. - lia. } - { intros; cbv [reg_eqb]. - repeat (break_match_step ltac:(fun _ => idtac); try congruence); try reflexivity; [ ]. - subst. cbn - [Z.add Z.modulo Z.testbit Z.mul Z.shiftl Fancy.lower128 Fancy.upper128]. - lia. } - Qed. - - Lemma mulll_comm rd x y cont cc ctx : - ProdEquiv.interp256 (Fancy.Instr Fancy.MUL128LL rd (x, y) cont) cc ctx = ProdEquiv.interp256 (Fancy.Instr Fancy.MUL128LL rd (y, x) cont) cc ctx. - Proof. rewrite !ProdEquiv.interp_step. cbn - [Fancy.interp]. rewrite Z.mul_comm. reflexivity. Qed. - - Lemma mulhh_comm rd x y cont cc ctx : - ProdEquiv.interp256 (Fancy.Instr Fancy.MUL128UU rd (x, y) cont) cc ctx = ProdEquiv.interp256 (Fancy.Instr Fancy.MUL128UU rd (y, x) cont) cc ctx. - Proof. rewrite !ProdEquiv.interp_step. cbn - [Fancy.interp]. rewrite Z.mul_comm. reflexivity. Qed. - - Lemma mullh_mulhl rd x y cont cc ctx : - ProdEquiv.interp256 (Fancy.Instr Fancy.MUL128LU rd (x, y) cont) cc ctx = ProdEquiv.interp256 (Fancy.Instr Fancy.MUL128UL rd (y, x) cont) cc ctx. - Proof. rewrite !ProdEquiv.interp_step. cbn - [Fancy.interp]. rewrite Z.mul_comm. reflexivity. Qed. - - Lemma add_comm rd x y cont cc ctx : - 0 <= ctx x < 2^256 -> - 0 <= ctx y < 2^256 -> - ProdEquiv.interp256 (Fancy.Instr (Fancy.ADD 0) rd (x, y) cont) cc ctx = ProdEquiv.interp256 (Fancy.Instr (Fancy.ADD 0) rd (y, x) cont) cc ctx. - Proof. - intros; rewrite !ProdEquiv.interp_step. cbn - [Fancy.interp]. rewrite Z.add_comm. - rewrite !(Z.mod_small (ctx _)) by (cbn in *; omega). reflexivity. - Qed. - - Lemma addc_comm rd x y cont cc ctx : - 0 <= ctx x < 2^256 -> - 0 <= ctx y < 2^256 -> - ProdEquiv.interp256 (Fancy.Instr (Fancy.ADDC 0) rd (x, y) cont) cc ctx = ProdEquiv.interp256 (Fancy.Instr (Fancy.ADDC 0) rd (y, x) cont) cc ctx. - Proof. - intros; rewrite !ProdEquiv.interp_step. cbn - [Fancy.interp]. rewrite (Z.add_comm (ctx x)). - rewrite !(Z.mod_small (ctx _)) by (cbn in *; omega). reflexivity. - Qed. - - (* Tactics to help prove that something in Fancy is line-by-line equivalent to something in PreFancy *) - Ltac push_value_unused := - repeat match goal with - | |- ~ In _ _ => cbn; intuition; congruence - | _ => apply ProdEquiv.value_unused_overwrite - | _ => apply ProdEquiv.value_unused_skip; [ | congruence | ] - | _ => apply ProdEquiv.value_unused_ret; congruence - end. - - Ltac remember_single_result := - match goal with |- context [(Fancy.spec ?i ?args ?cc) mod ?w] => - let x := fresh "x" in - let y := fresh "y" in - let Heqx := fresh "Heqx" in - remember (Fancy.spec i args cc) as x eqn:Heqx; - remember (x mod w) as y - end. - Ltac step_both_sides := - match goal with |- ProdEquiv.interp256 (Fancy.Instr ?i ?rd1 ?args1 _) _ ?ctx1 = ProdEquiv.interp256 (Fancy.Instr ?i ?rd2 ?args2 _) _ ?ctx2 => - rewrite (ProdEquiv.interp_step i rd1 args1); rewrite (ProdEquiv.interp_step i rd2 args2); - cbn - [Fancy.interp Fancy.spec]; - repeat progress rewrite ?reg_eqb_neq, ?reg_eqb_refl by congruence; - remember_single_result; - lazymatch goal with - | |- context [Fancy.spec i _ _] => - let Heqa1 := fresh in - let Heqa2 := fresh in - remember (Tuple.map (n:=i.(Fancy.num_source_regs)) ctx1 args1) eqn:Heqa1; - remember (Tuple.map (n:=i.(Fancy.num_source_regs)) ctx2 args2) eqn:Heqa2; - cbn in Heqa1; cbn in Heqa2; - repeat progress rewrite ?reg_eqb_neq, ?reg_eqb_refl in Heqa1 by congruence; - repeat progress rewrite ?reg_eqb_neq, ?reg_eqb_refl in Heqa2 by congruence; - let a1 := match type of Heqa1 with _ = ?a1 => a1 end in - let a2 := match type of Heqa2 with _ = ?a2 => a2 end in - (fail 1 "arguments to " i " do not match; LHS has " a1 " and RHS has " a2) - | _ => idtac - end - end. -End ProdEquiv. - -Module Barrett256. - Import LanguageWf.Compilers. - - Definition M := Eval lazy in (2^256-2^224+2^192+2^96-1). - Definition machine_wordsize := 256. - - Derive barrett_red256 - SuchThat (BarrettReduction.rbarrett_red_correctT M machine_wordsize barrett_red256) - As barrett_red256_correct. - Proof. Time solve_rbarrett_red_nocache machine_wordsize. Time Qed. - - Definition muLow := Eval lazy in (2 ^ (2 * machine_wordsize) / M) mod (2^machine_wordsize). - - Lemma barrett_reduce_correct_specialized : - forall (xLow xHigh : Z), - 0 <= xLow < 2 ^ machine_wordsize -> - 0 <= xHigh < M -> - BarrettReduction.barrett_reduce machine_wordsize M muLow 2 2 xLow xHigh = (xLow + 2 ^ machine_wordsize * xHigh) mod M. - Proof. - intros. - apply BarrettReduction.barrett_reduce_correct; cbv [machine_wordsize M muLow] in *; - try omega; - try match goal with - | |- context [weight] => intros; cbv [weight]; autorewrite with zsimplify; auto using Z.pow_mul_r with omega - end; lazy; try split; congruence. - Qed. - - Eval simpl in (type.for_each_lhs_of_arrow (type.interp base.interp) - (type.base (base.type.type_base base.type.Z) -> - type.base (base.type.type_base base.type.Z) -> - type.base (base.type.type_base base.type.Z))%ptype). - - (* Note: If this is not factored out, then for some reason Qed takes forever in barrett_red256_correct_full. *) - Lemma barrett_red256_correct_proj2 : - forall x y, - ZRange.type.option.is_bounded_by - (t:=base.type.prod base.type.Z base.type.Z) - (Some r[0 ~> 2 ^ machine_wordsize - 1]%zrange, Some r[0 ~> 2 ^ machine_wordsize - 1]%zrange) - (x, y) = true -> - type.app_curried - (expr.Interp (@ident.gen_interp ident.cast_outside_of_range) - barrett_red256) (x, (y, tt)) = - BarrettReduction.barrett_reduce machine_wordsize M - ((2 ^ (2 * machine_wordsize) / M) - mod 2 ^ machine_wordsize) 2 2 x y. - Proof. - intros. - destruct ((proj1 barrett_red256_correct) (x, (y, tt)) (x, (y, tt))). - { cbn; tauto. } - { cbn in *. rewrite andb_true_r. auto. } - { auto. } - Qed. - Lemma barrett_red256_correct_proj2' : - forall x y, - ZRange.type.option.is_bounded_by - (t:=base.type.prod base.type.Z base.type.Z) - (Some r[0 ~> 2 ^ machine_wordsize - 1]%zrange, Some r[0 ~> 2 ^ machine_wordsize - 1]%zrange) - (x, y) = true -> - expr.Interp (@ident.interp) barrett_red256 x y = - BarrettReduction.barrett_reduce machine_wordsize M - ((2 ^ (2 * machine_wordsize) / M) - mod 2 ^ machine_wordsize) 2 2 x y. - Proof. - intros. - erewrite <-barrett_red256_correct_proj2 by assumption. - unfold type.app_curried. exact eq_refl. - Qed. - Strategy -100 [type.app_curried]. - Local Arguments is_bounded_by_bool / . - Lemma barrett_red256_correct_full : - forall (xLow xHigh : Z), - 0 <= xLow < 2 ^ machine_wordsize -> - 0 <= xHigh < M -> - expr.Interp (@ident.interp) barrett_red256 xLow xHigh = (xLow + 2 ^ machine_wordsize * xHigh) mod M. - Proof. - intros. - rewrite <-barrett_reduce_correct_specialized by assumption. - destruct (proj1 barrett_red256_correct (xLow, (xHigh, tt)) (xLow, (xHigh, tt))) as [H1 H2]. - { repeat split. } - { cbn -[Z.pow]. - rewrite !andb_true_iff. - assert (M < 2^machine_wordsize) by (vm_compute; reflexivity). - repeat apply conj; Z.ltb_to_lt; trivial; omega. } - { etransitivity; [ eapply H2 | ]. (* need Strategy -100 [type.app_curried]. for this to be fast *) - generalize BarrettReduction.barrett_reduce; vm_compute; reflexivity. } - Qed. - - Definition barrett_red256_fancy' (xLow xHigh RegMuLow RegMod RegZero error : positive) := - Fancy.of_Expr 6%positive - (Fancy.make_consts [(RegMuLow, muLow); (RegMod, M); (RegZero, 0)]) - barrett_red256 - (xLow, (xHigh, tt)) - error. - Derive barrett_red256_fancy - SuchThat (forall xLow xHigh RegMuLow RegMod RegZero, - barrett_red256_fancy xLow xHigh RegMuLow RegMod RegZero = barrett_red256_fancy' xLow xHigh RegMuLow RegMod RegZero) - As barrett_red256_fancy_eq. - Proof. - intros. - lazy - [Fancy.ADD Fancy.ADDC Fancy.SUB Fancy.SUBC - Fancy.MUL128LL Fancy.MUL128LU Fancy.MUL128UL Fancy.MUL128UU - Fancy.RSHI Fancy.SELC Fancy.SELM Fancy.SELL Fancy.ADDM]. - reflexivity. - Qed. - Ltac step := repeat match goal with - | _ => progress cbn [fst snd] - | |- LanguageWf.Compilers.expr.wf _ _ _ => - econstructor; try solve [econstructor]; [ ] - | |- LanguageWf.Compilers.expr.wf _ _ _ => - solve [econstructor] - | |- In _ _ => auto 50 using in_eq, in_cons - end. - - (* TODO(jgross) - There's probably a more general statement to make here about the - correctness of smart_App_curried, but I'm not sure what it is. *) - Lemma interp_smart_App_curried_2 : - forall s1 s2 d (e : Compilers.expr (s1 -> s2 -> type.base d)) - (x1 : @type.interp base.type base.interp s1) - (x2 : @type.interp base.type base.interp s2), - interp (invert_expr.smart_App_curried e (x1, (x2, tt))) = interp e x1 x2. - Admitted. - - Lemma loosen_rshi_subgoal (ctx : positive -> Z) (n z: positive) cc : - ctx z = 0 -> - ctx n mod 2^256 = ctx n -> - Fancy.spec (Fancy.RSHI 255) (Tuple.map (n:=2) ctx (z, n)) cc mod 2 ^ 256 = - Fancy.spec (Fancy.RSHI 255) (Tuple.map (n:=2) ctx (z, n)) cc mod (1+1). - Proof. - intros Hz Hn. cbn [Tuple.map Tuple.map' fst snd]. rewrite Hz, <-Hn. - replace (1+1) with 2 by omega. assert (2 < 2^256) by (cbn; omega). - cbn [Fancy.spec Fancy.RSHI]. autorewrite with zsimplify_fast. - rewrite Z.shiftr_div_pow2 by omega. - match goal with |- context [(?x / ?d) mod _] => - assert (0 <= x / d < 2); - [ | rewrite !(Z.mod_small (x / d)) by omega; reflexivity ] - end. - split; [ solve [Z.zero_bounds] | ]. - apply Z.div_lt_upper_bound; [ cbn; omega | ]. - eapply Z.lt_le_trans; [ apply Z.mod_pos_bound; cbn; omega | ]. - cbn; omega. - Qed. - - (* This expression should have NO ands in it -- search for "&'" should return nothing *) - Print barrett_red256. - - (* TODO: don't rely on the C, M, and L flags *) - Lemma barrett_red256_fancy_correct : - forall xLow xHigh error, - 0 <= xLow < 2 ^ machine_wordsize -> - 0 <= xHigh < M -> - let RegZero := 1%positive in - let RegMod := 2%positive in - let RegMuLow := 3%positive in - let RegxHigh := 4%positive in - let RegxLow := 5%positive in - let consts_list := [(RegMuLow, muLow); (RegMod, M); (RegZero, 0)] in - let arg_list := [(RegxHigh, xHigh); (RegxLow, xLow)] in - let ctx := Fancy.make_ctx (consts_list ++ arg_list) in - let carry_flag := false in (* TODO: don't rely on this value, given it's unused *) - let last_wrote := (fun x : Fancy.CC.code => - match x with - | Fancy.CC.C => RegZero - | _ => RegxHigh (* xHigh needs to have written M; others unused *) - end) in - let cc := Fancy.make_cc last_wrote ctx carry_flag in - Fancy.interp Pos.eqb Fancy.wordmax Fancy.cc_spec (barrett_red256_fancy RegxLow RegxHigh RegMuLow RegMod RegZero error) cc ctx = (xLow + 2 ^ machine_wordsize * xHigh) mod M. - Proof. - intros. - rewrite barrett_red256_fancy_eq. - cbv [barrett_red256_fancy']. - rewrite <-barrett_red256_correct_full by auto. - eapply Fancy.of_Expr_correct with (x2 := (xLow, (xHigh, tt))). - { cbn; intros; subst RegZero RegMod RegMuLow RegxHigh RegxLow. - intuition; Prod.inversion_prod; subst; cbv. break_innermost_match; congruence. } - { cbn; intros; subst RegZero RegMod RegMuLow RegxHigh RegxLow. - intuition; Prod.inversion_prod; subst; cbv; congruence. } - { cbn; intros; subst RegZero RegMod RegMuLow RegxHigh RegxLow. tauto. } - { cbn; intros; subst RegZero RegMod RegMuLow RegxHigh RegxLow. - intuition; Prod.inversion_prod; subst; cbv; congruence. } - { cbn; intros; subst RegZero RegMod RegMuLow RegxHigh RegxLow. - match goal with |- context [_ mod ?m] => change m with (2 ^ machine_wordsize) end. - assert (M < 2 ^ machine_wordsize) by (cbv; congruence). - assert (0 <= muLow < 2 ^ machine_wordsize) by (split; cbv; congruence). - intuition; Prod.inversion_prod; subst; apply Z.mod_small; omega. } - { cbn; intros; subst RegZero RegMod RegMuLow RegxHigh RegxLow. - match goal with |- context [_ mod ?m] => change m with (2 ^ machine_wordsize) end. - assert (M < 2 ^ machine_wordsize) by (cbv; congruence). - assert (0 <= muLow < 2 ^ machine_wordsize) by (split; cbv; congruence). - intuition; Prod.inversion_prod; subst; apply Z.mod_small; omega. } - { cbn. - repeat match goal with - | _ => apply expr.WfLetIn - | _ => progress step - | _ => econstructor - end. } - { cbn. cbv [muLow M]. - Ltac sub := - repeat match goal with - | _ => progress intros - | |- context [Fancy.valid_ident] => econstructor - | |- context[Fancy.valid_scalar] => econstructor - | |- context [Fancy.valid_carry] => econstructor - | _ => reflexivity - | |- _ <> None => cbn; congruence - | |- Fancy.of_prefancy_scalar _ _ _ _ = _ => cbn; solve [eauto] - end. - - admit. - (* TODO: this code is currently broken because there are unexpected redundant ands in the code *) - (* - repeat (econstructor; [ solve [sub] | intros ]). - econstructor. - (* For the too-tight RSHI cast, we have to loosen the bounds *) - eapply Fancy.valid_LetInZ_loosen; try solve [sub]; - [ cbn; omega | | intros; apply loosen_rshi_subgoal; solve [eauto] ]. - repeat (econstructor; [ solve [sub] | intros ]). - econstructor. - { sub. admit. - (* TODO: this is the too-tight RSHI cast *) } - repeat (econstructor; [ solve [sub] | intros ]). - econstructor. sub. *) - - } - { cbn - [barrett_red256]. - cbv [id]. - cbv [expr.Interp]. - replace (@ident.gen_interp Fancy.cast_oor) with (@ident.interp) by admit. (* TODO(jgross): need to be able to say that I can switch out cast_outside_of_range because bounds checking works *) - rewrite <-interp_smart_App_curried_2. - reflexivity. } - Admitted. - - Import Fancy.Registers. - - Definition barrett_red256_alloc' xLow xHigh RegMuLow := - fun errorP errorR => - Fancy.allocate register - positive Pos.eqb - errorR - (barrett_red256_fancy 1000%positive 1001%positive 1002%positive 1003%positive 1004%positive errorP) - [r2;r3;r4;r5;r6;r7;r8;r9;r10;r5;r11;r6;r12;r13;r14;r15;r16;r17;r18;r19;r20;r21;r22;r23;r24;r25;r26;r27;r28;r29] - (fun n => if n =? 1000 then xLow - else if n =? 1001 then xHigh - else if n =? 1002 then RegMuLow - else if n =? 1003 then RegMod - else if n =? 1004 then RegZero - else errorR). - Derive barrett_red256_alloc - SuchThat (barrett_red256_alloc = barrett_red256_alloc') - As barrett_red256_alloc_eq. - Proof. - intros. - cbv [barrett_red256_alloc' barrett_red256_fancy]. - cbn. subst barrett_red256_alloc. - reflexivity. - Qed. - - Set Printing Depth 1000. - Import ProdEquiv. - - Local Ltac solve_bounds := - match goal with - | H : ?a = ?b mod ?c |- 0 <= ?a < ?c => rewrite H; apply Z.mod_pos_bound; omega - | _ => assumption - end. - - Lemma barrett_red256_alloc_equivalent errorP errorR cc_start_state start_context : - forall x xHigh RegMuLow scratchp1 scratchp2 scratchp3 scratchp4 scratchp5 extra_reg, - NoDup [x; xHigh; RegMuLow; scratchp1; scratchp2; scratchp3; scratchp4; scratchp5; extra_reg; RegMod; RegZero] -> - 0 <= start_context x < 2^machine_wordsize -> - 0 <= start_context xHigh < 2^machine_wordsize -> - 0 <= start_context RegMuLow < 2^machine_wordsize -> - ProdEquiv.interp256 (barrett_red256_alloc r0 r1 r30 errorP errorR) cc_start_state - (fun r => if reg_eqb r r0 - then start_context x - else if reg_eqb r r1 - then start_context xHigh - else if reg_eqb r r30 - then start_context RegMuLow - else start_context r) - = ProdEquiv.interp256 (Prod.MulMod x xHigh RegMuLow scratchp1 scratchp2 scratchp3 scratchp4 scratchp5) cc_start_state start_context. - Proof. - intros. - let r := eval compute in (2^machine_wordsize) in - replace (2^machine_wordsize) with r in * by reflexivity. - cbv [Prod.MulMod barrett_red256_alloc]. - - (* Extract proofs that no registers are equal to each other *) - repeat match goal with - | H : NoDup _ |- _ => inversion H; subst; clear H - | H : ~ In _ _ |- _ => cbv [In] in H - | H : ~ (_ \/ _) |- _ => apply Decidable.not_or in H; destruct H - | H : ~ False |- _ => clear H - end. - - step_both_sides. - - (* TODO: To prove equivalence between these two, we need to either relocate the RSHI instructions so they're in the same places or use instruction commutativity to push them down. *) - - Admitted. - - Lemma prod_barrett_red256_correct : - forall (cc_start_state : Fancy.CC.state) (* starting carry flags *) - (start_context : register -> Z) (* starting register values *) - (x xHigh RegMuLow scratchp1 scratchp2 scratchp3 scratchp4 scratchp5 extra_reg : register), (* registers to use in computation *) - NoDup [x; xHigh; RegMuLow; scratchp1; scratchp2; scratchp3; scratchp4; scratchp5; extra_reg; RegMod; RegZero] -> (* registers are unique *) - 0 <= start_context x < 2^machine_wordsize -> - 0 <= start_context xHigh < M -> - start_context RegMuLow = muLow -> - start_context RegMod = M -> - start_context RegZero = 0 -> - cc_start_state.(Fancy.CC.cc_m) = (Z.cc_m (2^256) (start_context xHigh) =? 1) -> - let X := start_context x + 2^machine_wordsize * start_context xHigh in - ProdEquiv.interp256 (Prod.MulMod x xHigh RegMuLow scratchp1 scratchp2 scratchp3 scratchp4 scratchp5) cc_start_state start_context = X mod M. - Proof. - intros. subst X. - assert (0 <= start_context xHigh < 2^machine_wordsize) by (cbv [M] in *; cbn; omega). - let r := (eval compute in (2 ^ machine_wordsize)) in - replace (2^machine_wordsize) with r in * by reflexivity. - cbv [M muLow] in *. - - erewrite <-barrett_red256_fancy_correct with (error:=100000%positive) by eauto. - rewrite <-barrett_red256_alloc_equivalent with (errorR := RegZero) (errorP := 1%positive) (extra_reg:=extra_reg) - by (auto; cbn; auto with omega). - cbv [ProdEquiv.interp256]. - let r := (eval compute in (2 ^ 256)) in - replace (2^256) with r in * by reflexivity. - - cbn - [Fancy.interp Pos.eqb]. - cbv [Fancy.make_cc]. - match goal with |- _ = Fancy.interp _ _ _ _ ?cc _ => - let x := fresh in - set cc as x; cbv [Pos.eqb] in x; subst x - end. - assert (Fancy.CC.cc_m cc_start_state = Fancy.cc_spec Fancy.CC.M (start_context xHigh)) as M_equal. - { match goal with H : Fancy.CC.cc_m _ = _ |- _ => rewrite H end. - cbv [Fancy.cc_spec]. rewrite Z.cc_m_eq, Z.testbit_eqb by omega. - rewrite Z.mod_small by (split; [ solve [Z.zero_bounds] | apply Z.div_lt_upper_bound; cbn; omega ]). - reflexivity. } - rewrite <-M_equal. - - (* strategy to fix flags : - 1) replace state on both sides with a state reflecting dead flags updated to 0; prove that each side ignores those flags and interps remain equal - 2) prove that the M flags are the same and rewrite; now same flags are on both sides - *) - - let dead_flags := constr:([Fancy.CC.C; Fancy.CC.L; Fancy.CC.Z]) in - match goal with - | H : Fancy.CC.cc_m _ = _ - |- _ = Fancy.interp _ _ _ _ ?cc _ => - let x := fresh in - let Hx := fresh in - remember (Fancy.CC.update dead_flags 0 Fancy.cc_spec cc) as x eqn:Hx; - cbv [Fancy.CC.update] in Hx; cbn in Hx; - match goal with - |- ?lhs = ?rhs => - match (eval pattern cc in rhs) with - ?f _ => transitivity (f x); subst x - end - end - end. - - Focus 2. { - (* here's where we need to prove the interps are equal even if I change the dead flags *) - - - cbv [barrett_red256_alloc barrett_red256_fancy]. - - (* - step start_context. - { match goal with H : Fancy.CC.cc_m _ = _ |- _ => rewrite H end. - match goal with |- context [Z.cc_m ?s ?x] => - pose proof (Z.cc_m_small s x ltac:(reflexivity) ltac:(omega)); - let H := fresh in - assert (Z.cc_m s x = 1 \/ Z.cc_m s x = 0) as H by omega; - destruct H as [H | H]; rewrite H in * - end; repeat (change (0 =? 1) with false || change (?x =? ?x) with true || cbv beta iota); - break_innermost_match; Z.ltb_to_lt; try congruence. - all: repeat match goal with - | [ H : context[ident.cast] |- _ ] - => rewrite ident.cast_in_bounds in H - by (cbv [is_bounded_by_bool]; rewrite Bool.andb_true_iff; split; Z.ltb_to_lt; cbn [upper lower]; lia) - end. - all: congruence. } - apply interp_equivZ_256; [ simplify_op_equiv start_context | ]. (* apply manually instead of using [step] to allow a custom bounds proof *) - all: rewrite ?ident.cast_in_bounds - by (cbv [is_bounded_by_bool]; rewrite Bool.andb_true_iff; split; Z.ltb_to_lt; cbn [upper lower]; lia). - { rewrite Z.rshi_correct by omega. - autorewrite with zsimplify_fast. - rewrite Z.shiftr_div_pow2 by omega. - break_innermost_match; Z.ltb_to_lt; try omega. - do 2 f_equal; omega. } - - (* Special case to remember the bound for the output of RSHI *) - let v := fresh "v" in - let v_bound := fresh "v_bound" in - intro v; assert (0 <= v <= 1) as v_bound; [ |generalize v v_bound; clear v v_bound; intros v v_bound]. - { solve_nonneg start_context. autorewrite with zsimplify_fast. - rewrite Z.shiftr_div_pow2 by omega. - rewrite Z.mod_small by admit. - split; [Z.zero_bounds|]. - apply Z.lt_succ_r. - apply Z.div_lt_upper_bound; try lia; admit. } - *) -(* - step start_context. - { rewrite Z.rshi_correct by omega. - rewrite Z.shiftr_div_pow2 by omega. - repeat (f_equal; try ring). } - step start_context; [ reflexivity | ]. - step start_context; [ reflexivity | ]. - step start_context; [ reflexivity | ]. - step start_context; [ reflexivity | ]. - step start_context; [ reflexivity | reflexivity | ]. - step start_context; [ reflexivity | reflexivity | ]. - step start_context; [ reflexivity | reflexivity | ]. - step start_context; [ reflexivity | reflexivity | ]. - step start_context; [ reflexivity | reflexivity | ]. - step start_context; - [ rewrite Z.mod_small with (b:=2) by (rewrite Z.mod_small by omega; omega); (* Here we make use of the bound of RSHI *) - reflexivity - | rewrite Z.mod_small with (b:=2) by (rewrite Z.mod_small by omega; omega); (* Here we make use of the bound of RSHI *) - reflexivity | ]. - step start_context; [ reflexivity | reflexivity | ]. - step start_context; [ reflexivity | reflexivity | ]. - step start_context. - { rewrite Z.rshi_correct by omega. - rewrite Z.shiftr_div_pow2 by omega. - repeat (f_equal; try ring). } - - step start_context; [ reflexivity | ]. - step start_context; [ reflexivity | ]. - step start_context; [ reflexivity | ]. - step start_context; [ reflexivity | ]. - step start_context; [ reflexivity | reflexivity | ]. - step start_context; [ reflexivity | reflexivity | ]. - step start_context; [ reflexivity | reflexivity | ]. - step start_context; [ reflexivity | reflexivity | ]. - - step start_context. - { reflexivity. } - { autorewrite with zsimplify_fast. - match goal with |- context [?x mod ?m] => pose proof (Z.mod_pos_bound x m ltac:(omega)) end. - rewrite <-Z.testbit_neg_eq_if with (n:=256) by (cbn; omega). - reflexivity. } - step start_context. - { reflexivity. } - { autorewrite with zsimplify_fast. - rewrite Z.mod_small with (a:=(if (if _ replace (a - b - c) with (a - (b + c)) by ring end. - match goal with |- context [?x mod ?m] => pose proof (Z.mod_pos_bound x m ltac:(omega)) end. - rewrite <-Z.testbit_neg_eq_if with (n:=256) by (break_innermost_match; cbn; omega). - reflexivity. } - step start_context. - { rewrite Z.bit0_eqb. - match goal with |- context [(?x mod ?m) &' 1] => - replace (x mod m) with (x &' Z.ones 256) by (rewrite Z.land_ones by omega; reflexivity) end. - rewrite <-Z.land_assoc. - rewrite Z.land_ones with (n:=1) by omega. - cbn. - match goal with |- context [?x mod 2] => - let H := fresh in - assert (x mod 2 = 0 \/ x mod 2 = 1) as H - by (pose proof (Z.mod_pos_bound x 2 ltac:(omega)); omega); - destruct H as [H | H]; rewrite H - end; reflexivity. } - step start_context. - { reflexivity. } - { autorewrite with zsimplify_fast. - repeat match goal with |- context [?x mod ?m] => unique pose proof (Z.mod_pos_bound x m ltac:(omega)) end. - rewrite <-Z.testbit_neg_eq_if with (n:=256) by (cbn; omega). - reflexivity. } - step start_context; [ break_innermost_match; Z.ltb_to_lt; omega | ]. - reflexivity. -*) - Admitted. - - Import PrintingNotations. - Set Printing Width 1000. - Open Scope expr_scope. - Print barrett_red256. - (* -barrett_red256 = fun var : type -> Type => λ x : var (type.type_primitive type.Z * type.type_primitive type.Z)%ctype, - expr_let x0 := SELM (x₂, 0, 26959946667150639793205513449348445388433292963828203772348655992835) in - expr_let x1 := RSHI (0, x₂, 255) in - expr_let x2 := RSHI (x₂, x₁, 255) in - expr_let x3 := 79228162514264337589248983038 *₂₅₆ (uint128)(x2 >> 128) in - expr_let x4 := 79228162514264337589248983038 *₂₅₆ ((uint128)(x2) & 340282366920938463463374607431768211455) in - expr_let x5 := 340282366841710300930663525764514709507 *₂₅₆ (uint128)(x2 >> 128) in - expr_let x6 := 340282366841710300930663525764514709507 *₂₅₆ ((uint128)(x2) & 340282366920938463463374607431768211455) in - expr_let x7 := ADD_256 ((uint256)(((uint128)(x5) & 340282366920938463463374607431768211455) << 128), x6) in - expr_let x8 := ADDC_256 (x7₂, (uint128)(x5 >> 128), x3) in - expr_let x9 := ADD_256 ((uint256)(((uint128)(x4) & 340282366920938463463374607431768211455) << 128), x7₁) in - expr_let x10 := ADDC_256 (x9₂, (uint128)(x4 >> 128), x8₁) in - expr_let x11 := ADD_256 (x2, x10₁) in - expr_let x12 := ADDC_128 (x11₂, 0, x1) in - expr_let x13 := ADD_256 (x0, x11₁) in - expr_let x14 := ADDC_128 (x13₂, 0, x12₁) in - expr_let x15 := RSHI (x14₁, x13₁, 1) in - expr_let x16 := 340282366841710300967557013911933812736 *₂₅₆ (uint128)(x15 >> 128) in - expr_let x17 := 79228162514264337593543950335 *₂₅₆ (uint128)(x15 >> 128) in - expr_let x18 := 340282366841710300967557013911933812736 *₂₅₆ ((uint128)(x15) & 340282366920938463463374607431768211455) in - expr_let x19 := 79228162514264337593543950335 *₂₅₆ ((uint128)(x15) & 340282366920938463463374607431768211455) in - expr_let x20 := ADD_256 ((uint256)(((uint128)(x18) & 340282366920938463463374607431768211455) << 128), x19) in - expr_let x21 := ADDC_256 (x20₂, (uint128)(x18 >> 128), x16) in - expr_let x22 := ADD_256 ((uint256)(((uint128)(x17) & 340282366920938463463374607431768211455) << 128), x20₁) in - expr_let x23 := ADDC_256 (x22₂, (uint128)(x17 >> 128), x21₁) in - expr_let x24 := SUB_256 (x₁, x22₁) in - expr_let x25 := SUBB_256 (x24₂, x₂, x23₁) in - expr_let x26 := SELL (x25₁, 0, 115792089210356248762697446949407573530086143415290314195533631308867097853951) in - expr_let x27 := SUB_256 (x24₁, x26) in - ADDM (x27₁, 0, 115792089210356248762697446949407573530086143415290314195533631308867097853951) - : Expr (type.uncurry (type.type_primitive type.Z -> type.type_primitive type.Z -> type.type_primitive type.Z)) - *) - -End Barrett256. - -(* TODO : once Barrett is updated & working, fix Montgomery to match *) -(* -Module Montgomery256. - - Definition N := Eval lazy in (2^256-2^224+2^192+2^96-1). - Definition N':= (115792089210356248768974548684794254293921932838497980611635986753331132366849). - Definition R := Eval lazy in (2^256). - Definition R' := 115792089183396302114378112356516095823261736990586219612555396166510339686400. - Definition machine_wordsize := 256. - - Derive montred256 - SuchThat (MontgomeryReduction.rmontred_correctT N R N' machine_wordsize montred256) - As montred256_correct. - Proof. Time solve_rmontred_nocache machine_wordsize. Time Qed. - - Lemma montred'_correct_specialized R' (R'_correct : Z.equiv_modulo N (R * R') 1) : - forall (lo hi : Z), - 0 <= lo < R -> 0 <= hi < R -> 0 <= lo + R * hi < R * N -> - MontgomeryReduction.montred' N R N' (Z.log2 R) 2 2 (lo, hi) = ((lo + R * hi) * R') mod N. - Proof. - intros. - apply MontgomeryReduction.montred'_correct with (T:=lo + R * hi) (R':=R'); - try match goal with - | |- context[R'] => assumption - | |- context [lo] => - try assumption; progress autorewrite with zsimplify cancel_pair; reflexivity - end; lazy; try split; congruence. - Qed. - - (* - (* Note: If this is not factored out, then for some reason Qed takes forever in montred256_correct_full. *) - Lemma montred256_correct_proj2 : - forall xy : type.interp (type.prod type.Z type.Z), - ZRange.type.option.is_bounded_by - (t:=type.prod type.Z type.Z) - (Some r[0 ~> 2 ^ machine_wordsize - 1]%zrange, Some r[0 ~> 2 ^ machine_wordsize - 1]%zrange) - xy = true -> - expr.Interp (@ident.interp) montred256 xy = app_curried (t:=type.arrow (type.prod type.Z type.Z) type.Z) (MontgomeryReduction.montred' N R N' (Z.log2 R) 2 2) xy. - Proof. intros; destruct (montred256_correct xy); assumption. Qed. - Lemma montred256_correct_proj2' : - forall xy : type.interp (type.prod type.Z type.Z), - ZRange.type.option.is_bounded_by - (t:=type.prod type.Z type.Z) - (Some r[0 ~> 2 ^ machine_wordsize - 1]%zrange, Some r[0 ~> 2 ^ machine_wordsize - 1]%zrange) - xy = true -> - expr.Interp (@ident.interp) montred256 xy = MontgomeryReduction.montred' N R N' (Z.log2 R) 2 2 xy. - Proof. intros; rewrite montred256_correct_proj2 by assumption; unfold app_curried; exact eq_refl. Qed. - *) - Local Arguments is_bounded_by_bool / . - Lemma montred256_correct_full R' (R'_correct : Z.equiv_modulo N (R * R') 1) : - forall (lo hi : Z), - 0 <= lo < R -> 0 <= hi < R -> 0 <= lo + R * hi < R * N -> - PreFancy.Interp 256 montred256 (lo, hi) = ((lo + R * hi) * R') mod N. - Proof. - intros. - rewrite <-montred'_correct_specialized by assumption. - destruct (proj1 montred256_correct ((lo, hi), tt) ((lo, hi), tt)) as [H2 H3]. - { repeat split. } - { cbn -[Z.pow]. - rewrite !andb_true_iff. - repeat apply conj; Z.ltb_to_lt; trivial; cbv [R N machine_wordsize] in *; lia. } - { etransitivity; [ eapply H3 | ]. (* need Strategy -100 [type.app_curried]. for this to be fast *) - generalize MontgomeryReduction.montred'; vm_compute; reflexivity. } - Qed. - - (* - (* TODO : maybe move these ok_expr tactics somewhere else *) - Ltac ok_expr_step' := - match goal with - | _ => assumption - | |- _ <= _ <= _ \/ @eq zrange _ _ => - right; lazy; try split; congruence - | |- _ <= _ <= _ \/ @eq zrange _ _ => - left; lazy; try split; congruence - | |- lower r[0~>_]%zrange = 0 => reflexivity - | |- context [PreFancy.ok_ident] => constructor - | |- context [PreFancy.ok_scalar] => constructor; try omega - | |- context [PreFancy.is_halved] => eapply PreFancy.is_halved_constant; [lazy; reflexivity | ] - | |- context [PreFancy.is_halved] => constructor - | |- context [PreFancy.in_word_range] => lazy; reflexivity - | |- context [PreFancy.in_flag_range] => lazy; reflexivity - | |- context [PreFancy.get_range] => - cbn [PreFancy.get_range lower upper fst snd ZRange.map] - | x : type.interp (type.prod _ _) |- _ => destruct x - | |- (_ <=? _)%zrange = true => - match goal with - | |- context [PreFancy.get_range_var] => - cbv [is_tighter_than_bool PreFancy.has_range fst snd upper lower R N] in *; cbn; - apply andb_true_iff; split; apply Z.leb_le - | _ => lazy - end; omega || reflexivity - | |- @eq zrange _ _ => lazy; reflexivity - | |- _ <= _ => cbv [machine_wordsize]; omega - | |- _ <= _ <= _ => cbv [machine_wordsize]; omega - end; intros. - - (* TODO : maybe move these ok_expr tactics somewhere else *) - Ltac ok_expr_step := - match goal with - | |- context [PreFancy.ok_expr] => constructor; cbn [fst snd]; repeat ok_expr_step' - end; intros; cbn [Nat.max].*) - - (* - Lemma montred256_prefancy_correct : - forall (lo hi : Z), - 0 <= lo < R -> 0 <= hi < R -> 0 <= lo + R * hi < R * N -> - @PreFancy.interp machine_wordsize base.type.Z (montred256 _ @ (##lo,##hi)) = ((lo + R * hi) * R') mod N. - Proof. - intros. - - rewrite montred256_prefancy_eq; cbv [montred256_prefancy']. - erewrite PreFancy.of_Expr_correct. - { apply montred256_correct_full; try assumption; reflexivity. } - { reflexivity. } - { lazy; reflexivity. } - { lazy; reflexivity. } - { repeat constructor. } - { cbv [In N N']; intros; intuition; subst; cbv; congruence. } - { assert (340282366920938463463374607431768211455 * 2 ^ 128 <= 2 ^ machine_wordsize - 1) as shiftl_128_ok by (lazy; congruence). - repeat (ok_expr_step; [ ]). - ok_expr_step. - lazy; congruence. - constructor. - constructor. } - { lazy. omega. } - Qed. -*) - - Definition montred256_fancy' (lo hi RegMod RegPInv RegZero error : positive) := - Fancy.of_Expr 3%positive - (fun z => if z =? N then Some RegMod else if z =? N' then Some RegPInv else if z =? 0 then Some RegZero else None) - [N;N'] - montred256 - ((lo, hi)%positive, tt) - error. - Derive montred256_fancy - SuchThat (forall RegMod RegPInv RegZero, - montred256_fancy RegMod RegPInv RegZero = montred256_fancy' RegMod RegPInv RegZero) - As montred256_fancy_eq. - Proof. - intros. - lazy - [Fancy.ADD Fancy.ADDC Fancy.SUB - Fancy.MUL128LL Fancy.MUL128LU Fancy.MUL128UL Fancy.MUL128UU - Fancy.RSHI Fancy.SELC Fancy.SELM Fancy.SELL Fancy.ADDM]. - reflexivity. - Qed. - - Import Fancy.Registers. - - Definition montred256_alloc' lo hi RegPInv := - fun errorP errorR => - Fancy.allocate register - positive Pos.eqb - errorR - (montred256_fancy 1000%positive 1001%positive 1002%positive 1003%positive 1004%positive errorP) - [r2;r3;r4;r5;r6;r7;r8;r9;r10;r11;r12;r13;r14;r15;r16;r17;r18;r19;r20] - (fun n => if n =? 1000 then lo - else if n =? 1001 then hi - else if n =? 1002 then RegMod - else if n =? 1003 then RegPInv - else if n =? 1004 then RegZero - else errorR). - Derive montred256_alloc - SuchThat (montred256_alloc = montred256_alloc') - As montred256_alloc_eq. - Proof. - intros. - cbv [montred256_alloc' montred256_fancy]. - cbn. subst montred256_alloc. - reflexivity. - Qed. - - Import ProdEquiv. - - Local Ltac solve_bounds := - match goal with - | H : ?a = ?b mod ?c |- 0 <= ?a < ?c => rewrite H; apply Z.mod_pos_bound; omega - | _ => assumption - end. - - Lemma montred256_alloc_equivalent errorP errorR cc_start_state start_context : - forall lo hi y t1 t2 scratch RegPInv extra_reg, - NoDup [lo; hi; y; t1; t2; scratch; RegPInv; extra_reg; RegMod; RegZero] -> - 0 <= start_context lo < R -> - 0 <= start_context hi < R -> - 0 <= start_context RegPInv < R -> - ProdEquiv.interp256 (montred256_alloc r0 r1 r30 errorP errorR) cc_start_state - (fun r => if reg_eqb r r0 - then start_context lo - else if reg_eqb r r1 - then start_context hi - else if reg_eqb r r30 - then start_context RegPInv - else start_context r) - = ProdEquiv.interp256 (Prod.MontRed256 lo hi y t1 t2 scratch RegPInv) cc_start_state start_context. - Proof. - intros. cbv [R] in *. - cbv [Prod.MontRed256 montred256_alloc]. - - (* Extract proofs that no registers are equal to each other *) - repeat match goal with - | H : NoDup _ |- _ => inversion H; subst; clear H - | H : ~ In _ _ |- _ => cbv [In] in H - | H : ~ (_ \/ _) |- _ => apply Decidable.not_or in H; destruct H - | H : ~ False |- _ => clear H - end. - - rewrite ProdEquiv.interp_Mul256 with (tmp2:=extra_reg) by (congruence || push_value_unused). - - rewrite mullh_mulhl. step_both_sides. - rewrite mullh_mulhl. step_both_sides. - (* - step_both_sides. - step_both_sides. - - rewrite ProdEquiv.interp_Mul256x256 with (tmp2:=extra_reg) by (congruence || push_value_unused). - - rewrite mulll_comm. step_both_sides. - step_both_sides. - step_both_sides. - rewrite mulhh_comm. step_both_sides. - step_both_sides. - step_both_sides. - step_both_sides. - step_both_sides. - - - rewrite add_comm by (cbn; solve_bounds). step_both_sides. - rewrite addc_comm by (cbn; solve_bounds). step_both_sides. - step_both_sides. - step_both_sides. - step_both_sides. - - cbn; repeat progress rewrite ?reg_eqb_neq, ?reg_eqb_refl by congruence. - reflexivity.*) - Admitted. - - Import Fancy_PreFancy_Equiv. - - Definition interp_equivZZ_256 {s} := - @interp_equivZZ s 256 ltac:(cbv; congruence) 115792089237316195423570985008687907853269984665640564039457584007913129639935 ltac:(reflexivity). - Definition interp_equivZ_256 {s} := - @interp_equivZ s 256 115792089237316195423570985008687907853269984665640564039457584007913129639935 ltac:(lia) ltac:(reflexivity). - - Local Ltac simplify_op_equiv start_ctx := - cbn - [Fancy.spec ident.gen_interp Fancy.cc_spec]; - repeat match goal with H : start_ctx _ = _ |- _ => rewrite H end; - cbv - [ - Z.add_with_get_carry_full - Z.add_get_carry_full Z.sub_get_borrow_full - Z.le Z.ltb Z.leb Z.geb Z.eqb Z.land Z.shiftr Z.shiftl - Z.add Z.mul Z.div Z.sub Z.modulo Z.testbit Z.pow Z.ones - fst snd]; cbn [fst snd]; - try (replace (2 ^ (256 / 2) - 1) with (Z.ones 128) by reflexivity; rewrite !Z.land_ones by omega); - autorewrite with to_div_mod; rewrite ?Z.mod_mod, <-?Z.testbit_spec' by omega; - repeat match goal with - | H : 0 <= ?x < ?m |- context [?x mod ?m] => rewrite (Z.mod_small x m) by apply H - | |- context [?x rewrite (proj2 (Z.ltb_ge x 0)) by (break_match; Z.zero_bounds) - | _ => rewrite Z.mod_small with (b:=2) by (break_match; omega) - | |- context [ (if Z.testbit ?a ?n then 1 else 0) + ?b + ?c] => - replace ((if Z.testbit a n then 1 else 0) + b + c) with (b + c + (if Z.testbit a n then 1 else 0)) by ring - end. - - Local Ltac solve_nonneg ctx := - match goal with x := (Fancy.spec _ _ _) |- _ => subst x end; - simplify_op_equiv ctx; Z.zero_bounds. - - Local Ltac generalize_result := - let v := fresh "v" in intro v; generalize v; clear v; intro v. - - Local Ltac generalize_result_nonneg ctx := - let v := fresh "v" in - let v_nonneg := fresh "v_nonneg" in - intro v; assert (0 <= v) as v_nonneg; [solve_nonneg ctx |generalize v v_nonneg; clear v v_nonneg; intros v v_nonneg]. - - Local Ltac step_abs := - match goal with - | [ |- context G[expr.interp ?ident_interp (expr.Abs ?f) ?x] ] - => let G' := context G[expr.interp ident_interp (f x)] in - change G'; cbv beta - end. - Local Ltac step ctx := - repeat step_abs; - match goal with - | |- Fancy.interp _ _ _ (Fancy.Instr (Fancy.ADD _) _ _ (Fancy.Instr (Fancy.ADDC _) _ _ _)) _ _ = _ => - apply interp_equivZZ_256; [ simplify_op_equiv ctx | simplify_op_equiv ctx | generalize_result_nonneg ctx] - | [ |- _ = expr.interp _ (PreFancy.LetInAppIdentZ _ _ _ _ _ _) ] - => apply interp_equivZ_256; [simplify_op_equiv ctx | generalize_result] - | [ |- _ = expr.interp _ (PreFancy.LetInAppIdentZZ _ _ _ _ _ _) ] - => apply interp_equivZZ_256; [ simplify_op_equiv ctx | simplify_op_equiv ctx | generalize_result] - end. - - Local Ltac break_ifs := - repeat (break_innermost_match_step; Z.ltb_to_lt; try (exfalso; omega); []). - - Local Opaque PreFancy.interp_cast_mod. - - Lemma prod_montred256_correct : - forall (cc_start_state : Fancy.CC.state) (* starting carry flags can be anything *) - (start_context : register -> Z) (* starting register values *) - (lo hi y t1 t2 scratch RegPInv extra_reg : register), (* registers to use in computation *) - NoDup [lo; hi; y; t1; t2; scratch; RegPInv; extra_reg; RegMod; RegZero] -> (* registers must be distinct *) - start_context RegPInv = N' -> (* RegPInv needs to hold the inverse of the modulus *) - start_context RegMod = N -> (* RegMod needs to hold the modulus *) - start_context RegZero = 0 -> (* RegZero needs to hold zero *) - (0 <= start_context lo < R) -> (* low half of the input is in bounds (R=2^256) *) - (0 <= start_context hi < R) -> (* high half of the input is in bounds (R=2^256) *) - let x := (start_context lo) + R * (start_context hi) in (* x is the input (split into two registers) *) - (0 <= x < R * N) -> (* input precondition *) - (ProdEquiv.interp256 (Prod.MontRed256 lo hi y t1 t2 scratch RegPInv) cc_start_state start_context = (x * R') mod N). - Proof. - intros. subst x. cbv [N R N'] in *. - rewrite <-montred256_correct_full by (auto; vm_compute; reflexivity). - rewrite <-montred256_alloc_equivalent with (errorR := RegZero) (errorP := 1%positive) (extra_reg:=extra_reg) - by (cbv [R]; auto with omega). - cbv [ProdEquiv.interp256]. - cbv [montred256_alloc montred256 expr.Interp]. - - (*step start_context; [ break_ifs; reflexivity | ]. - step start_context; [ break_ifs; reflexivity | ]. - step start_context; [ break_ifs; reflexivity | ].*) - (*step start_context; [ break_ifs; reflexivity | ]. - step start_context; [ break_ifs; reflexivity | break_ifs; reflexivity | ]. - step start_context; [ break_ifs; reflexivity | break_ifs; reflexivity | ]. - step start_context; [ reflexivity | ]. - step start_context; [ reflexivity | ]. - step start_context; [ reflexivity | ]. - step start_context; [ reflexivity | ]. - step start_context; [ reflexivity | reflexivity | ]. - step start_context; [ reflexivity | reflexivity | ]. - step start_context; [ reflexivity | reflexivity | ]. - step start_context; [ reflexivity | reflexivity | ]. - - step start_context; [ reflexivity | reflexivity | ]. - step start_context; [ reflexivity | reflexivity | ]. - step start_context; [ break_innermost_match; Z.ltb_to_lt; omega | ]. - step start_context; [ reflexivity | | ]. - { - let r := eval cbv in (2^256) in replace (2^256) with r by reflexivity. - rewrite !Z.shiftl_0_r, !Z.mod_mod by omega. - apply Z.testbit_neg_eq_if; - let r := eval cbv in (2^256) in replace (2^256) with r by reflexivity; - auto using Z.mod_pos_bound with omega. } - step start_context; [ break_innermost_match; Z.ltb_to_lt; omega | ]. - reflexivity. - *) - Admitted. - - Import PrintingNotations. - Set Printing Width 10000. - - Print montred256. -(* -montred256 = fun var : type -> Type => (λ x : var (type.type_primitive type.Z * type.type_primitive type.Z)%ctype, - expr_let x0 := 79228162514264337593543950337 *₂₅₆ (uint128)(x₁ >> 128) in - expr_let x1 := 340282366841710300986003757985643364352 *₂₅₆ ((uint128)(x₁) & 340282366920938463463374607431768211455) in - expr_let x2 := 79228162514264337593543950337 *₂₅₆ ((uint128)(x₁) & 340282366920938463463374607431768211455) in - expr_let x3 := ADD_256 ((uint256)(((uint128)(x1) & 340282366920938463463374607431768211455) << 128), x2) in - expr_let x4 := ADD_256 ((uint256)(((uint128)(x0) & 340282366920938463463374607431768211455) << 128), x3₁) in - expr_let x5 := 79228162514264337593543950335 *₂₅₆ ((uint128)(x4₁) & 340282366920938463463374607431768211455) in - expr_let x6 := 79228162514264337593543950335 *₂₅₆ (uint128)(x4₁ >> 128) in - expr_let x7 := 340282366841710300967557013911933812736 *₂₅₆ ((uint128)(x4₁) & 340282366920938463463374607431768211455) in - expr_let x8 := 340282366841710300967557013911933812736 *₂₅₆ (uint128)(x4₁ >> 128) in - expr_let x9 := ADD_256 ((uint256)(((uint128)(x7) & 340282366920938463463374607431768211455) << 128), x5) in - expr_let x10 := ADDC_256 (x9₂, (uint128)(x7 >> 128), x8) in - expr_let x11 := ADD_256 ((uint256)(((uint128)(x6) & 340282366920938463463374607431768211455) << 128), x9₁) in - expr_let x12 := ADDC_256 (x11₂, (uint128)(x6 >> 128), x10₁) in - expr_let x13 := ADD_256 (x11₁, x₁) in - expr_let x14 := ADDC_256 (x13₂, x12₁, x₂) in - expr_let x15 := SELC (x14₂, 0, 115792089210356248762697446949407573530086143415290314195533631308867097853951) in - expr_let x16 := SUB_256 (x14₁, x15) in - ADDM (x16₁, 0, 115792089210356248762697446949407573530086143415290314195533631308867097853951))%expr - : Expr (type.uncurry (type.type_primitive type.Z * type.type_primitive type.Z -> type.type_primitive type.Z)) -*) - - Import PreFancy. - Import PreFancy.Notations. - Local Notation "'RegMod'" := (expr.Ident (ident.Literal 115792089210356248762697446949407573530086143415290314195533631308867097853951)). - Local Notation "'RegPInv'" := (expr.Ident (ident.Literal 115792089210356248768974548684794254293921932838497980611635986753331132366849)). - Local Open Scope expr_scope. - Local Notation mulhl := (#(fancy_mulhl 256)). - Local Notation mulhh := (#(fancy_mulhh 256)). - Local Notation mulll := (#(fancy_mulll 256)). - Local Notation mullh := (#(fancy_mullh 256)). - Local Notation selc := (#(fancy_selc)). - Local Notation addm := (#(fancy_addm)). - Notation add n := (#(fancy_add 256 n)). - Notation addc n := (#(fancy_addc 256 n)). - - Print montred256. - (* -montred256 = -fun var : type -> Type => -λ x : var (type.base (base.type.type_base base.type.Z * base.type.type_base base.type.Z)%etype), -mulhl@(x0, x₁, RegPInv); -mullh@(x1, x₁, RegPInv); -mulll@(x2, x₁, RegPInv); -(add 128)@(x3, x2, Lower{x1}); -(add 128)@(x4, x3₁, Lower{x0}); -mulll@(x5, RegMod, x4₁); -mullh@(x6, RegMod, x4₁); -mulhl@(x7, RegMod, x4₁); -mulhh@(x8, RegMod, x4₁); -(add 128)@(x9, x5, Lower{x7}); -(addc (-128))@(x10, carry{$x9}, x8, x7); -(add 128)@(x11, x9₁, Lower{x6}); -(addc (-128))@(x12, carry{$x11}, x10₁, x6); -(add 0)@(x13, x11₁, x₁); -(addc 0)@(x14, carry{$x13}, x12₁, x₂); -selc@(x15, (carry{$x14}, RegZero), RegMod); -#(fancy_sub 256 0)@(x16, x14₁, x15); -addm@(x17, (x16₁, RegZero), RegMod); -x17 - : Expr - (type.base (base.type.type_base base.type.Z * base.type.type_base base.type.Z)%etype -> - type.base (base.type.type_base base.type.Z))%ptype - *) -End Montgomery256. +(* TODO : update these summaries and move them to other files *) Local Notation "i rd x y ; cont" := (Fancy.Instr i rd (x, y) cont) (at level 40, cont at level 200, format "i rd x y ; '//' cont"). Local Notation "i rd x y z ; cont" := (Fancy.Instr i rd (x, y, z) cont) (at level 40, cont at level 200, format "i rd x y z ; '//' cont"). -- cgit v1.2.3