diff options
author | Robert Sloan <varomodt@google.com> | 2016-11-08 19:02:15 -0800 |
---|---|---|
committer | Robert Sloan <varomodt@google.com> | 2016-11-08 19:02:15 -0800 |
commit | 6dbb07114f9e463007d80112242117e165c6698f (patch) | |
tree | 1b68801efb430b3423a8cff1fa25719c305bbbcc | |
parent | ea549915c168d1d4440708b75a35ec450648cf8e (diff) | |
parent | c89a77f3b6de068aaf1b8cd2adad73ef64c7fb13 (diff) |
Not quite done with WordUtil lemmas.
87 files changed, 6387 insertions, 2277 deletions
diff --git a/Bedrock/Word.v b/Bedrock/Word.v index 036b3198a..2c518807d 100644 --- a/Bedrock/Word.v +++ b/Bedrock/Word.v @@ -48,8 +48,8 @@ Fixpoint natToWord (sz n : nat) : word sz := Fixpoint wordToN sz (w : word sz) : N := match w with | WO => 0 - | WS false _ w' => 2 * wordToN w' - | WS true _ w' => Nsucc (2 * wordToN w') + | WS false _ w' => N.double (wordToN w') + | WS true _ w' => N.succ_double (wordToN w') end%N. Definition Nmod2 (n : N) : bool := @@ -506,6 +506,8 @@ Theorem wordToN_nat : forall sz (w : word sz), wordToN w = N_of_nat (wordToNat w rewrite N_of_mult. rewrite <- IHw. rewrite Nmult_comm. + rewrite N.succ_double_spec. + rewrite N.add_1_r. reflexivity. rewrite N_of_mult. @@ -1038,12 +1040,12 @@ Proof. induction a; intro b0; rewrite (shatter_word b0); intuition. simpl in H. destruct b; destruct (whd b0); intros. - f_equal. eapply IHa. eapply Nsucc_inj in H. + f_equal. eapply IHa. eapply N.succ_double_inj in H. destruct (wordToN a); destruct (wordToN (wtl b0)); try congruence. destruct (wordToN (wtl b0)); destruct (wordToN a); inversion H. destruct (wordToN (wtl b0)); destruct (wordToN a); inversion H. f_equal. eapply IHa. - destruct (wordToN a); destruct (wordToN (wtl b0)); try congruence. + destruct (wordToN a); destruct (wordToN (wtl b0)); simpl in *; try congruence. Qed. Lemma unique_inverse : forall sz (a b1 b2 : word sz), a ^+ b1 = wzero _ -> @@ -12,7 +12,8 @@ HIDE := $(if $(VERBOSE),,@) .PHONY: coq clean update-_CoqProject cleanall install \ install-coqprime clean-coqprime coqprime \ - specific non-specific + specific non-specific \ + extraction ghc SORT_COQPROJECT = sed 's,[^/]*/,~&,g' | env LC_COLLATE=C sort | sed 's,~,,g' @@ -84,6 +85,31 @@ Makefile.coq: Makefile _CoqProject $(SHOW)'COQ_MAKEFILE -f _CoqProject > $@' $(HIDE)$(COQBIN)coq_makefile -f _CoqProject | sed s'|^\(-include.*\)$$|ifneq ($$(filter-out $(FAST_TARGETS),$$(MAKECMDGOALS)),)~\1~else~ifeq ($$(MAKECMDGOALS),)~\1~endif~endif|g' | tr '~' '\n' | sed s'/^clean:$$/clean::/g' | sed s'/^Makefile: /Makefile-old: /g' | sed s'/^printenv:$$/printenv::/g' > $@ +src/Experiments/Ed25519_noimports.hs: src/Experiments/Ed25519Extraction.vo src/Experiments/Ed25519Extraction.v + +src/Experiments/Ed25519.hs: src/Experiments/Ed25519_noimports.hs src/Experiments/Ed25519_imports.hs + ( cd src/Experiments && \ + < Ed25519_noimports.hs \ + sed "/import qualified Prelude/r Ed25519_imports.hs" | \ + sed 's/ Ed25519_noimports / Ed25519 /g' \ + > Ed25519.hs ) + +src/Experiments/X25519.hs: src/Experiments/X25519_noimports.hs src/Experiments/Ed25519_imports.hs + ( cd src/Experiments && \ + < X25519_noimports.hs \ + sed "/import qualified Prelude/r Ed25519_imports.hs" | \ + sed 's/ X25519_noimports / X25519 /g' \ + > X25519.hs ) + +src/Experiments/Ed25519.o src/Experiments/Ed25519.core: src/Experiments/Ed25519.hs + ( cd src/Experiments && ghc -XStrict -O3 Ed25519.hs -ddump-simpl > Ed25519.core ) + +src/Experiments/X25519.o src/Experiments/X25519.core: src/Experiments/X25519.hs + ( cd src/Experiments && ghc -XStrict -O3 X25519.hs -ddump-simpl > X25519.core ) + +extraction: src/Experiments/Ed25519.hs src/Experiments/X25519.hs +ghc: src/Experiments/Ed25519.core src/Experiments/Ed25519.o src/Experiments/X25519.o src/Experiments/X25519.core + clean:: rm -f Makefile.coq diff --git a/_CoqProject b/_CoqProject index 497b704cb..5b4c06f9c 100644 --- a/_CoqProject +++ b/_CoqProject @@ -6,13 +6,13 @@ src/Algebra.v src/BaseSystem.v src/BaseSystemProofs.v src/EdDSARepChange.v +src/MxDHRepChange.v src/Testbit.v src/Assembly/Bounds.v src/Assembly/Compile.v src/Assembly/Conversions.v src/Assembly/Evaluables.v src/Assembly/GF25519.v -src/Assembly/GF25519BoundedInstantiation.v src/Assembly/HL.v src/Assembly/LL.v src/Assembly/PhoasCommon.v @@ -62,6 +62,7 @@ src/Encoding/ModularWordEncodingTheorems.v src/Encoding/PointEncoding.v src/Encoding/PointEncodingPre.v src/Experiments/Ed25519.v +src/Experiments/Ed25519Extraction.v src/Experiments/ExtrHaskellNats.v src/Experiments/GenericFieldPow.v src/Experiments/MontgomeryCurve.v @@ -73,6 +74,7 @@ src/ModularArithmetic/ModularBaseSystem.v src/ModularArithmetic/ModularBaseSystemList.v src/ModularArithmetic/ModularBaseSystemListProofs.v src/ModularArithmetic/ModularBaseSystemListZOperations.v +src/ModularArithmetic/ModularBaseSystemListZOperationsProofs.v src/ModularArithmetic/ModularBaseSystemOpt.v src/ModularArithmetic/ModularBaseSystemProofs.v src/ModularArithmetic/Pow2Base.v @@ -91,6 +93,7 @@ src/ModularArithmetic/BarrettReduction/ZHandbook.v src/ModularArithmetic/Montgomery/Z.v src/ModularArithmetic/Montgomery/ZBounded.v src/ModularArithmetic/Montgomery/ZProofs.v +src/Reflection/Application.v src/Reflection/CommonSubexpressionElimination.v src/Reflection/Conversion.v src/Reflection/CountLets.v @@ -114,7 +117,6 @@ src/Reflection/WfProofs.v src/Reflection/WfReflective.v src/Reflection/WfReflectiveGen.v src/Reflection/WfRel.v -src/Reflection/WfRelReflective.v src/Reflection/Named/Compile.v src/Reflection/Named/ContextOn.v src/Reflection/Named/DeadCodeElimination.v @@ -125,6 +127,8 @@ src/Reflection/Named/Syntax.v src/Reflection/Z/Interpretations.v src/Reflection/Z/Reify.v src/Reflection/Z/Syntax.v +src/Reflection/Z/Interpretations/Relations.v +src/Reflection/Z/Interpretations/RelationsCombinations.v src/Spec/CompleteEdwardsCurve.v src/Spec/Ed25519.v src/Spec/EdDSA.v @@ -137,11 +141,24 @@ src/Specific/GF1305.v src/Specific/GF25519.v src/Specific/GF25519Bounded.v src/Specific/GF25519BoundedCommon.v -src/Specific/GF25519BoundedCommonWord.v +src/Specific/GF25519Reflective.v src/Specific/SC25519.v src/Specific/FancyMachine256/Barrett.v src/Specific/FancyMachine256/Core.v src/Specific/FancyMachine256/Montgomery.v +src/Specific/GF25519Reflective/Common.v +src/Specific/GF25519Reflective/Reified.v +src/Specific/GF25519Reflective/Reified/Add.v +src/Specific/GF25519Reflective/Reified/CarryAdd.v +src/Specific/GF25519Reflective/Reified/CarryOpp.v +src/Specific/GF25519Reflective/Reified/CarrySub.v +src/Specific/GF25519Reflective/Reified/Freeze.v +src/Specific/GF25519Reflective/Reified/GeModulus.v +src/Specific/GF25519Reflective/Reified/Mul.v +src/Specific/GF25519Reflective/Reified/Opp.v +src/Specific/GF25519Reflective/Reified/Pack.v +src/Specific/GF25519Reflective/Reified/Sub.v +src/Specific/GF25519Reflective/Reified/Unpack.v src/Tactics/VerdiTactics.v src/Tactics/Algebra_syntax/Nsatz.v src/Test/Curve25519SpecTestVectors.v @@ -153,7 +170,9 @@ src/Util/Decidable.v src/Util/Equality.v src/Util/FixCoqMistakes.v src/Util/GlobalSettings.v +src/Util/HList.v src/Util/HProp.v +src/Util/IffT.v src/Util/Isomorphism.v src/Util/IterAssocOp.v src/Util/LetIn.v @@ -163,6 +182,7 @@ src/Util/NatUtil.v src/Util/Notations.v src/Util/NumTheoryUtil.v src/Util/Option.v +src/Util/PartiallyReifiedProp.v src/Util/PointedProp.v src/Util/Prod.v src/Util/Relations.v diff --git a/sha512word.hs b/sha512word.hs new file mode 100644 index 000000000..5c30631a0 --- /dev/null +++ b/sha512word.hs @@ -0,0 +1,22 @@ +module SHA512Word where + +import qualified Data.ByteString.Lazy as B +import qualified Data.Digest.Pure.SHA as SHA +import Data.Bits ((.|.), shiftL, testBit) +import Data.Word (Word8) + +b2i :: Integral a => Bool -> a +b2i b = case b of { True -> 1 ; False -> 0 } + +leBitsToBytes :: [Bool] -> [Word8] +leBitsToBytes [] = [] +leBitsToBytes (a:b:c:d:e:f:g:h:bs) = (b2i a .|. (b2i b `shiftL` 1) .|. (b2i c `shiftL` 2) .|. (b2i d `shiftL` 3) .|. (b2i e `shiftL` 4) .|. (b2i f `shiftL` 5) .|. (b2i g `shiftL` 6) .|. (b2i h `shiftL` 7)) : leBitsToBytes bs +leBitsToBytes bs = error $ "byte must have exactly 8 bits, got " ++ show bs + + +bytesToLEBits :: [Word8] -> [Bool] +bytesToLEBits [] = [] +bytesToLEBits (x:xs) = (x `testBit` 0) : (x `testBit` 1) : (x `testBit` 2) : (x `testBit` 3) : (x `testBit` 4) : (x `testBit` 5) : (x `testBit` 6) : (x `testBit` 7) : bytesToLEBits xs + +h :: [Bool] -> [Bool] +h = bytesToLEBits . B.unpack . SHA.bytestringDigest . SHA.sha512 . B.pack . leBitsToBytes diff --git a/src/Algebra.v b/src/Algebra.v index c017a7456..de6b6f51a 100644 --- a/src/Algebra.v +++ b/src/Algebra.v @@ -378,6 +378,13 @@ Module Group. apply inv_id. Qed. + Lemma inv_zero_zero : forall x, inv x = id -> x = id. + Proof. + intros. + rewrite <-inv_id, <-H0. + symmetry; apply inv_inv. + Qed. + Lemma eq_r_opp_r_inv a b c : a = op c (inv b) <-> op a b = c. Proof. split; intro Hx; rewrite Hx || rewrite <-Hx; diff --git a/src/Assembly/GF25519.v b/src/Assembly/GF25519.v index e1ecfdce0..a904f14b1 100644 --- a/src/Assembly/GF25519.v +++ b/src/Assembly/GF25519.v @@ -7,6 +7,8 @@ Require Import Crypto.ModularArithmetic.ModularBaseSystem. Require Import Crypto.Specific.GF25519. Require Import Crypto.Util.Tuple. +Require InitialRing. + Module GF25519. Definition bits: nat := 64. Definition width: Width bits := W64. @@ -226,7 +228,7 @@ Module GF25519. Module Opp := Pipeline OppExpr. Section Instantiation. - Require Import InitialRing. + Import InitialRing. Definition Binary : Type := NAry 20 (word bits) (@interp_type (word bits) FE). Definition Unary : Type := NAry 10 (word bits) (@interp_type (word bits) FE). diff --git a/src/Assembly/GF25519BoundedInstantiation.v b/src/Assembly/GF25519BoundedInstantiation.v deleted file mode 100644 index 1c6897343..000000000 --- a/src/Assembly/GF25519BoundedInstantiation.v +++ /dev/null @@ -1,139 +0,0 @@ -Require Import Coq.ZArith.ZArith. -Require Import Crypto.Assembly.PhoasCommon. -Require Import Crypto.Assembly.QhasmCommon. -Require Import Crypto.Assembly.Compile. -Require Import Crypto.Assembly.LL. -Require Import Crypto.Assembly.GF25519. -Require Import Crypto.Specific.GF25519. -Require Import Crypto.Specific.GF25519BoundedCommonWord. -Require Import Crypto.Util.Tactics. -Require Import Crypto.Util.LetIn. -Require Import Crypto.Util.Tuple. - -(* Totally fine to edit these definitions; DO NOT change the type signatures at all *) -Section Operations. - Import Assembly.GF25519.GF25519. - Definition wfe: Type := @interp_type (word bits) FE. - - Definition ExprBinOp : Type := GF25519.Binary. - Definition ExprUnOp : Type := GF25519.Unary. - Axiom ExprUnOpFEToZ : Type. - Axiom ExprUnOpWireToFE : Type. - Axiom ExprUnOpFEToWire : Type. - - Local Existing Instance WordEvaluable. - - Definition interp_bexpr' (op: ExprBinOp) (x y: tuple (word bits) 10): tuple (word bits) 10 := - let '(x0, x1, x2, x3, x4, x5, x6, x7, x8, x9) := x in - let '(y0, y1, y2, y3, y4, y5, y6, y7, y8, y9) := y in - op x0 x1 x2 x3 x4 x5 x6 x7 x8 x9 y0 y1 y2 y3 y4 y5 y6 y7 y8 y9. - - Definition interp_uexpr' (op: ExprUnOp) (x: tuple (word bits) 10): tuple (word bits) 10 := - let '(x0, x1, x2, x3, x4, x5, x6, x7, x8, x9) := x in - op x0 x1 x2 x3 x4 x5 x6 x7 x8 x9. - - Definition radd : ExprBinOp := GF25519.add. - Definition rsub : ExprBinOp := GF25519.sub. - Definition rmul : ExprBinOp := GF25519.mul. - Definition ropp : ExprUnOp := GF25519.opp. -End Operations. - -Definition interp_bexpr : ExprBinOp -> Specific.GF25519BoundedCommonWord.fe25519W -> Specific.GF25519BoundedCommonWord.fe25519W -> Specific.GF25519BoundedCommonWord.fe25519W - := interp_bexpr'. -Definition interp_uexpr : ExprUnOp -> Specific.GF25519BoundedCommonWord.fe25519W -> Specific.GF25519BoundedCommonWord.fe25519W - := interp_uexpr'. -Axiom interp_uexpr_FEToZ : ExprUnOpFEToZ -> Specific.GF25519BoundedCommonWord.fe25519W -> Specific.GF25519BoundedCommonWord.word64. -Axiom interp_uexpr_FEToWire : ExprUnOpFEToWire -> Specific.GF25519BoundedCommonWord.fe25519W -> Specific.GF25519BoundedCommonWord.wire_digitsW. -Axiom interp_uexpr_WireToFE : ExprUnOpWireToFE -> Specific.GF25519BoundedCommonWord.wire_digitsW -> Specific.GF25519BoundedCommonWord.fe25519W. -Axiom rfreeze : ExprUnOp. -Axiom rge_modulus : ExprUnOpFEToZ. -Axiom rpack : ExprUnOpFEToWire. -Axiom runpack : ExprUnOpWireToFE. - -Declare Reduction asm_interp - := cbv [id - interp_bexpr interp_uexpr interp_bexpr' interp_uexpr' - radd rsub rmul ropp (*rfreeze rge_modulus rpack runpack*) - GF25519.GF25519.add GF25519.GF25519.sub GF25519.GF25519.mul GF25519.GF25519.opp (* GF25519.GF25519.freeze *) - GF25519.GF25519.bits GF25519.GF25519.FE - QhasmCommon.liftN QhasmCommon.NArgMap Compile.CompileHL.compile LL.LL.under_lets LL.LL.interp LL.LL.interp_arg LL.LL.match_arg_Prod Conversions.LLConversions.convertExpr Conversions.LLConversions.convertArg Conversions.LLConversions.convertVar PhoasCommon.type_rect PhoasCommon.type_rec PhoasCommon.type_ind PhoasCommon.interp_binop LL.LL.uninterp_arg - Evaluables.ezero Evaluables.toT Evaluables.fromT Evaluables.eadd Evaluables.esub Evaluables.emul Evaluables.eshiftr Evaluables.eand Evaluables.eltb Evaluables.eeqb - Evaluables.WordEvaluable Evaluables.ZEvaluable]. - -Definition interp_radd : Specific.GF25519BoundedCommonWord.fe25519W -> Specific.GF25519BoundedCommonWord.fe25519W -> Specific.GF25519BoundedCommonWord.fe25519W - := Eval asm_interp in interp_bexpr radd. -(*Print interp_radd.*) -Definition interp_radd_correct : interp_radd = interp_bexpr radd := eq_refl. -Definition interp_rsub : Specific.GF25519BoundedCommonWord.fe25519W -> Specific.GF25519BoundedCommonWord.fe25519W -> Specific.GF25519BoundedCommonWord.fe25519W - := Eval asm_interp in interp_bexpr rsub. -(*Print interp_rsub.*) -Definition interp_rsub_correct : interp_rsub = interp_bexpr rsub := eq_refl. -Definition interp_rmul : Specific.GF25519BoundedCommonWord.fe25519W -> Specific.GF25519BoundedCommonWord.fe25519W -> Specific.GF25519BoundedCommonWord.fe25519W - := Eval asm_interp in interp_bexpr rmul. -(*Print interp_rmul.*) -Definition interp_rmul_correct : interp_rmul = interp_bexpr rmul := eq_refl. -Definition interp_ropp : Specific.GF25519BoundedCommonWord.fe25519W -> Specific.GF25519BoundedCommonWord.fe25519W - := Eval asm_interp in interp_uexpr ropp. -(*Print interp_ropp.*) -Definition interp_ropp_correct : interp_ropp = interp_uexpr ropp := eq_refl. -Definition interp_rfreeze : Specific.GF25519BoundedCommonWord.fe25519W -> Specific.GF25519BoundedCommonWord.fe25519W - := Eval asm_interp in interp_uexpr rfreeze. -(*Print interp_rfreeze.*) -Definition interp_rfreeze_correct : interp_rfreeze = interp_uexpr rfreeze := eq_refl. - -Definition interp_rge_modulus : Specific.GF25519BoundedCommonWord.fe25519W -> Specific.GF25519BoundedCommonWord.word64 - := Eval asm_interp in interp_uexpr_FEToZ rge_modulus. -Definition interp_rge_modulus_correct : interp_rge_modulus = interp_uexpr_FEToZ rge_modulus := eq_refl. - -Definition interp_rpack : Specific.GF25519BoundedCommonWord.fe25519W -> Specific.GF25519BoundedCommonWord.wire_digitsW - := Eval asm_interp in interp_uexpr_FEToWire rpack. -Definition interp_rpack_correct : interp_rpack = interp_uexpr_FEToWire rpack := eq_refl. - -Definition interp_runpack : Specific.GF25519BoundedCommonWord.wire_digitsW -> Specific.GF25519BoundedCommonWord.fe25519W - := Eval asm_interp in interp_uexpr_WireToFE runpack. -Definition interp_runpack_correct : interp_runpack = interp_uexpr_WireToFE runpack := eq_refl. - -Local Notation binop_correct_and_bounded rop op - := (ibinop_correct_and_bounded (interp_bexpr rop) op) (only parsing). -Local Notation unop_correct_and_bounded rop op - := (iunop_correct_and_bounded (interp_uexpr rop) op) (only parsing). -Local Notation unop_FEToZ_correct rop op - := (iunop_FEToZ_correct (interp_uexpr_FEToZ rop) op) (only parsing). -Local Notation unop_FEToWire_correct_and_bounded rop op - := (iunop_FEToWire_correct_and_bounded (interp_uexpr_FEToWire rop) op) (only parsing). -Local Notation unop_WireToFE_correct_and_bounded rop op - := (iunop_WireToFE_correct_and_bounded (interp_uexpr_WireToFE rop) op) (only parsing). - -Local Ltac start_correct_and_bounded_t op op_expr lem := - intros; hnf in *; destruct_head' prod; simpl in * |- ; - repeat match goal with H : is_bounded _ = true |- _ => unfold_is_bounded_in H end; - repeat match goal with H : wire_digits_is_bounded _ = true |- _ => unfold_is_bounded_in H end; - change op with op_expr; - rewrite <- lem. - -Lemma radd_correct_and_bounded : binop_correct_and_bounded radd carry_add. -Proof. - intros; hnf in *; destruct_head' prod; simpl in * |- . - repeat match goal with H : is_bounded _ = true |- _ => unfold_is_bounded_in H end. -Admitted. -Lemma rsub_correct_and_bounded : binop_correct_and_bounded rsub carry_sub. -Proof. -Admitted. -Lemma rmul_correct_and_bounded : binop_correct_and_bounded rmul mul. -Proof. -Admitted. -Lemma ropp_correct_and_bounded : unop_correct_and_bounded ropp carry_opp. -Proof. -Admitted. -Lemma rfreeze_correct_and_bounded : unop_correct_and_bounded rfreeze freeze. -Proof. -Admitted. -Lemma rge_modulus_correct_and_bounded : unop_FEToZ_correct rge_modulus ge_modulus. -Proof. -Admitted. -Lemma rpack_correct_and_bounded : unop_FEToWire_correct_and_bounded rpack pack. -Proof. -Admitted. -Lemma runpack_correct_and_bounded : unop_WireToFE_correct_and_bounded runpack unpack. -Proof. -Admitted. diff --git a/src/Assembly/WordizeUtil.v b/src/Assembly/WordizeUtil.v index 98e01bc23..b5f246fb1 100644 --- a/src/Assembly/WordizeUtil.v +++ b/src/Assembly/WordizeUtil.v @@ -162,7 +162,7 @@ Section Misc. intros x H. replace (& wones (S n)) with (2 * & (wones n) + N.b2n true)%N - by (simpl; nomega). + by (simpl; rewrite ?N.succ_double_spec; simpl; nomega). rewrite N.testbit_succ_r; reflexivity. Qed. @@ -181,7 +181,7 @@ Section Misc. + replace (& (wones (S (S n)))) with (2 * (& (wones (S n))) + N.b2n true)%N - by (simpl; nomega). + by (simpl; rewrite ?N.succ_double_spec; simpl; nomega). rewrite Nat2N.inj_succ. rewrite N.testbit_succ_r. assumption. @@ -189,7 +189,7 @@ Section Misc. - induction k. + replace (& (wones (S n))) with (2 * (& (wones n)) + N.b2n true)%N - by (simpl; nomega). + by (simpl; rewrite ?N.succ_double_spec; simpl; nomega). rewrite N.testbit_0_r. reflexivity. @@ -203,12 +203,12 @@ Section Misc. try rewrite Pos.succ_pred_double; intuition). replace (& (wones (S n))) with (2 * (& (wones n)) + N.b2n true)%N - by (simpl; nomega). + by (simpl; rewrite ?N.succ_double_spec; simpl; nomega). rewrite N.testbit_succ_r. assumption. Qed. - + Lemma plus_le: forall {n} (x y: word n), (& (x ^+ y) <= &x + &y)%N. Proof. @@ -335,7 +335,7 @@ Section Exp. rewrite <- IHn. simpl; intuition. Qed. - + Lemma Npow2_succ: forall n, (Npow2 (S n) = 2 * (Npow2 n))%N. Proof. intros; simpl; induction (Npow2 n); intuition. Qed. @@ -459,12 +459,7 @@ Section SpecialFunctions. with (N.double (& (wtl x))) by (induction (& (wtl x)); simpl; intuition). - - rewrite N.double_spec. - replace (N.succ (2 * & wtl x)) - with ((2 * (& wtl x)) + 1)%N - by nomega. - rewrite <- N.succ_double_spec. - rewrite N.div2_succ_double. + - rewrite N.div2_succ_double. reflexivity. - induction (& (wtl x)); simpl; intuition. @@ -509,11 +504,13 @@ Section SpecialFunctions. induction k'. + clear IHn; induction x; simpl; intuition. - destruct (& x), b; simpl; intuition. + destruct (& x), b; simpl; intuition. + clear IHk'. shatter x; simpl. + rewrite N.succ_double_spec; simpl. + rewrite kill_match. replace (N.pos (Pos.of_succ_nat k')) with (N.succ (N.of_nat k')) @@ -536,7 +533,7 @@ Section SpecialFunctions. rewrite Nat2N.id; reflexivity. Qed. - + Lemma wordToN_split1: forall {n m} x, & (@split1 n m x) = N.land (& x) (& (wones n)). Proof. @@ -625,7 +622,7 @@ Section SpecialFunctions. rewrite N.shiftr_spec; try apply N_ge_0. replace (k - N.of_nat n + N.of_nat n)%N with k by nomega. rewrite N.land_spec. - induction (N.testbit x k); + induction (N.testbit x k); replace (N.testbit (& wones n) k) with false; simpl; intuition; try apply testbit_wones_false; @@ -648,7 +645,7 @@ Section SpecialFunctions. - rewrite Nat2N.inj_succ. replace (& wones (S x)) with (2 * & (wones x) + N.b2n true)%N - by (simpl; nomega). + by (simpl; rewrite ?N.succ_double_spec; simpl; nomega). replace (N.ones (N.succ _)) with (2 * N.ones (N.of_nat x) + N.b2n true)%N. @@ -734,7 +731,7 @@ Section SpecialFunctions. - propagate_wordToN. rewrite N2Nat.id. reflexivity. - + - rewrite N.land_ones. rewrite N.mod_small; try reflexivity. rewrite <- (N2Nat.id m). @@ -997,4 +994,3 @@ Section TopLevel. Close Scope nword_scope. End TopLevel. - diff --git a/src/CompleteEdwardsCurve/ExtendedCoordinates.v b/src/CompleteEdwardsCurve/ExtendedCoordinates.v index a6d97fd4b..a804317d6 100644 --- a/src/CompleteEdwardsCurve/ExtendedCoordinates.v +++ b/src/CompleteEdwardsCurve/ExtendedCoordinates.v @@ -66,9 +66,26 @@ Module Extended. (let '(X,Y,Z,T) := coordinates P in let iZ := Finv Z in ((X*iZ), (Y*iZ))) _. Definition eq (P Q:point) := E.eq (to_twisted P) (to_twisted Q). - Global Instance DecidableRel_eq : Decidable.DecidableRel eq := _. - Local Hint Unfold from_twisted to_twisted eq : bash. + Definition eq_noinv (P1 P2:point) := + let '(X1, Y1, Z1, _) := coordinates P1 in + let '(X2, Y2, Z2, _) := coordinates P2 in + Z2*X1 = Z1*X2 /\ Z2*Y1 = Z1*Y2. + + Local Hint Unfold from_twisted to_twisted eq eq_noinv : bash. + + Lemma eq_noinv_eq P Q : eq P Q <-> eq_noinv P Q. + Proof. safe_bash; repeat split; safe_bash. Qed. + Global Instance DecidableRel_eq_noinv : Decidable.DecidableRel eq_noinv. + Proof. + intros P Q. + destruct P as [ [ [ [ ] ? ] ? ] ?], Q as [ [ [ [ ] ? ] ? ] ? ]; simpl; exact _. + Defined. + Global Instance DecidableRel_eq : Decidable.DecidableRel eq. + Proof. + intros ? ?. + eapply @Decidable_iff_to_flip_impl; [eapply eq_noinv_eq | exact _]. + Defined. Global Instance Equivalence_eq : Equivalence eq. Proof. split; split; safe_bash. Qed. Global Instance Proper_from_twisted : Proper (E.eq==>eq) from_twisted. Proof. unsafe_bash. Qed. diff --git a/src/EdDSARepChange.v b/src/EdDSARepChange.v index d63e6d14e..20e544437 100644 --- a/src/EdDSARepChange.v +++ b/src/EdDSARepChange.v @@ -106,7 +106,7 @@ Section EdDSA. {Proper_ERepEnc:Proper (ErepEq==>Logic.eq) ERepEnc}. Context {ERepDec : word b -> option Erep} - {ERepDec_correct : forall w, ERepDec w = option_map EToRep (Edec w) }. + {ERepDec_correct : forall w, option_eq ErepEq (ERepDec w) (option_map EToRep (Edec w)) }. Context {SRep SRepEq} `{@Equivalence SRep SRepEq} {S2Rep:F l->SRep}. @@ -118,7 +118,7 @@ Section EdDSA. Context {SRepDec: word b -> option SRep} {SRepDec_correct : forall w, option_eq SRepEq (option_map S2Rep (Sdec w)) (SRepDec w)}. - + Definition verify_using_representation {mlen} (message:word mlen) (pk:word b) (sig:word (b+b)) : { answer | answer = verify' message pk sig }. @@ -159,7 +159,22 @@ Section EdDSA. (H _ (split1 b b sig ++ pk ++ message))) (ErepOpp (s))))) (split1 b b sig)) false (Sdec (split2 b b sig))) - false); rewrite <-(ERepDec_correct pk). + false). + (* rewrite with a complicated proper instance for inline code .. *) + etransitivity; + [| eapply Proper_option_rect_nd_changevalue; + [ + | reflexivity + | eapply ERepDec_correct + ]; + [ repeat match goal with + | |- _ => intro + | |- _ => eapply Proper_option_rect_nd_changebody + | |- _ ?x ?x => reflexivity + | H : _ |- _ => rewrite H; reflexivity + end + ] + ]. etransitivity. Focus 2. { eapply Proper_option_rect_nd_changebody; [intro|reflexivity]. diff --git a/src/Encoding/PointEncoding.v b/src/Encoding/PointEncoding.v index b005ce3df..1160ed83a 100644 --- a/src/Encoding/PointEncoding.v +++ b/src/Encoding/PointEncoding.v @@ -3,7 +3,7 @@ Require Import Coq.Numbers.Natural.Peano.NPeano. Require Import Crypto.ModularArithmetic.PrimeFieldTheorems. Require Import Crypto.Spec.CompleteEdwardsCurve. Require Import Crypto.CompleteEdwardsCurve.CompleteEdwardsCurveTheorems. -Require Import Bedrock.Word. +Require Import Bedrock.Word Crypto.Util.WordUtil. Require Import Crypto.Tactics.VerdiTactics. Require Import Crypto.Util.Option. Require Import Crypto.Util.NatUtil. @@ -15,7 +15,7 @@ Require Crypto.Encoding.PointEncodingPre. Eenc := encode_point Proper_Eenc := Proper_encode_point Edec := Fdecode_point (notation) - eq_enc_E_iff := Fdecode_encode_iff + eq_enc_E_iff := encode_point_decode_point_iff EToRep := point_phi Ahomom := point_phi_homomorphism ERepEnc := Kencode_point @@ -27,16 +27,19 @@ Require Crypto.Encoding.PointEncodingPre. Section PointEncoding. Context {b : nat} {m : Z} {Fa Fd : F m} {prime_m : Znumtheory.prime m} + {two_lt_m : (2 < m)%Z} {bound_check : (Z.to_nat m < 2 ^ b)%nat}. - Definition sign (x : F m) : bool := Z.testbit (F.to_Z x) 0. + Local Infix "++" := Word.combine. + Local Notation bit b := (Word.WS b Word.WO : Word.word 1). + Definition sign (x : F m) : bool := Z.testbit (F.to_Z x) 0. Definition Fencode (x : F m) : word b := NToWord b (Z.to_N (F.to_Z x)). Let Fpoint := @E.point (F m) Logic.eq F.one F.add F.mul Fa Fd. Definition encode_point (P : Fpoint) := - let '(x,y) := E.coordinates P in WS (sign x) (Fencode y). + let '(x,y) := E.coordinates P in Fencode y ++ bit (sign x). Import Morphisms. Lemma Proper_encode_point : Proper (E.eq ==> Logic.eq) encode_point. @@ -68,7 +71,6 @@ Section PointEncoding. {Kcoord_to_point : @E.point K Keq Kone Kadd Kmul Ka Kd -> Kpoint} {Kpoint_to_coord : Kpoint -> (K * K)}. Context {Kp2c_c2p : forall pt : E.point, Tuple.fieldwise (n := 2) Keq (Kpoint_to_coord (Kcoord_to_point pt)) (E.coordinates pt)}. - Check Kp2c_c2p. Context {Kpoint_eq : Kpoint -> Kpoint -> Prop} {Kpoint_add : Kpoint -> Kpoint -> Kpoint}. Context {Kpoint_eq_correct : forall p q, Kpoint_eq p q <-> Tuple.fieldwise (n := 2) Keq (Kpoint_to_coord p) (Kpoint_to_coord q)} {Kpoint_eq_Equivalence : Equivalence Kpoint_eq}. @@ -122,7 +124,8 @@ Section PointEncoding. option_eq Keq (option_map phi (Fdecode w)) (Kdec w)}. Context {Fsqrt : F m -> F m} {phi_sqrt : forall x, Keq (phi (Fsqrt x)) (Ksqrt (phi x))} - {Fsqrt_square : forall x root, eq x (F.mul root root) -> eq (Fsqrt x) root}. + {Fsqrt_square : forall x root, eq x (F.mul root root) -> + eq (F.mul (Fsqrt x) (Fsqrt x)) x}. Lemma point_phi_homomorphism: @Algebra.Monoid.is_homomorphism Fpoint E.eq Fpoint_add @@ -133,7 +136,7 @@ Section PointEncoding. Qed. Definition Kencode_point (P : Kpoint) := - let '(x,y) := Kpoint_to_coord P in WS (Ksign x) (Kenc y). + let '(x,y) := Kpoint_to_coord P in (Kenc y) ++ bit (Ksign x). Lemma Kencode_point_correct : forall P : Fpoint, encode_point P = Kencode_point (point_phi P). @@ -146,7 +149,9 @@ Section PointEncoding. pose proof (Kp2c_c2p x) as A; rewrite Heqp in A; inversion A; cbv [fst snd Tuple.fieldwise'] in * end. cbv [E.coordinates E.ref_phi proj1_sig] in *. - f_equal; rewrite ?H0, ?H1; auto. + apply (f_equal2 (fun a b => a ++ b)); + try apply (f_equal2 (fun a b => WS a b)); + rewrite ?H0, ?H1; auto. Qed. Lemma Proper_Kencode_point : Proper (Kpoint_eq ==> Logic.eq) Kencode_point. @@ -157,7 +162,9 @@ Section PointEncoding. destruct (Kpoint_to_coord x). destruct (Kpoint_to_coord y). simpl in H; destruct H. - f_equal; auto. + apply (f_equal2 (fun a b => a ++ b)); + try apply (f_equal2 (fun a b => WS a b)); + rewrite ?H0, ?H1; auto. Qed. @@ -172,11 +179,11 @@ Section PointEncoding. else Some p else None. - Definition Kdecode_coordinates (w : word (S b)) : option (K * K) := + Definition Kdecode_coordinates (w : word (b + 1)) : option (K * K) := option_rect (fun _ => option (K * K)) - (Kcoordinates_from_y (whd w)) + (Kcoordinates_from_y (wlast w)) None - (Kdec (wtl w)). + (Kdec (winit w)). Lemma onCurve_eq : forall x y, Keq (Kadd (Kmul Ka (Kmul x x)) (Kmul y y)) @@ -194,7 +201,7 @@ Section PointEncoding. | right _ => None end. - Definition Kdecode_point (w : word (S b)) : option Kpoint := + Definition Kdecode_point (w : word (b+1)) : option Kpoint := option_rect (fun _ => option Kpoint) Kpoint_from_xy None (Kdecode_coordinates w). Definition Fencoding : Encoding.CanonicalEncoding (F m) (word b). @@ -434,6 +441,82 @@ Section PointEncoding. + intros. apply Kpoint_from_xy_correct. Qed. + Lemma sign_zero : forall x, x = F.zero -> sign x = false. + Proof. + intros; subst. + reflexivity. + Qed. + + Lemma sign_negb : forall x : F m, x <> F.zero -> + negb (sign x) = sign (F.opp x). + Proof. + intros. + cbv [sign]. + rewrite !Z.bit0_odd. + rewrite F.to_Z_opp. + rewrite F.eq_to_Z_iff in H. + replace (@F.to_Z m F.zero) with 0%Z in H by reflexivity. + rewrite Z.mod_opp_l_nz by (solve [ZUtil.Z.prime_bound] || + rewrite F.mod_to_Z; auto). + rewrite F.mod_to_Z. + rewrite Z.odd_sub. + destruct (ZUtil.Z.prime_odd_or_2 m prime_m) as [? | m_odd]; + [ omega | rewrite m_odd]. + rewrite <-Bool.xorb_true_l; auto. + Qed. + + Lemma Eeq_point_eq : forall x y : option E.point, + option_eq E.eq x y <-> + option_eq + (@PointEncodingPre.point_eq _ eq F.one F.add F.mul Fa Fd) x y. + Proof. + intros. + cbv [option_eq E.eq PointEncodingPre.point_eq + PointEncodingPre.prod_eq]; repeat break_match; + try reflexivity. + cbv [E.coordinates]. + subst. + rewrite Heqp1, Heqp0. + cbv [Tuple.fieldwise Tuple.fieldwise' fst snd]. + tauto. + Qed. + + Lemma enc_canonical_equiv : forall (x_enc : word b) (x : F m), + option_eq eq (Fdecode x_enc) (Some x) -> + Fencode x = x_enc. + Proof. + intros. + cbv [option_eq] in *. + break_match; try discriminate. + subst. + apply (@Encoding.encoding_canonical _ _ Fencoding). + auto. + Qed. + + Lemma encode_point_decode_point_iff : forall P_ P, + encode_point P = P_ <-> + Option.option_eq E.eq (Fdecode_point P_) (Some P). + Proof. + pose proof (@PointEncodingPre.point_encoding_canonical + _ eq F.zero F.one F.opp F.add F.sub F.mul F.div + _ Fa Fd _ Fsqrt Fencoding enc_canonical_equiv + sign sign_zero sign_negb + ) as Hcanonical. + let A := fresh "H" in + match type of Hcanonical with + ?P -> _ => assert P as A by congruence; + specialize (Hcanonical A); clear A end. + intros. + rewrite Eeq_point_eq. + split; intros; subst. + { apply PointEncodingPre.point_encoding_valid; + auto using sign_zero, sign_negb; + congruence. } + { apply Hcanonical. + cbv [option_eq PointEncodingPre.point_eq PointEncodingPre.prod_eq] in H |- *. + break_match; congruence. } + Qed. + End RepChange. End PointEncoding. diff --git a/src/Encoding/PointEncodingPre.v b/src/Encoding/PointEncodingPre.v index 3c3075d4f..8a0d4c849 100644 --- a/src/Encoding/PointEncodingPre.v +++ b/src/Encoding/PointEncodingPre.v @@ -3,10 +3,12 @@ Require Import Coq.Numbers.Natural.Peano.NPeano. Require Import Coq.Program.Equality. Require Import Crypto.CompleteEdwardsCurve.Pre. Require Import Crypto.CompleteEdwardsCurve.CompleteEdwardsCurveTheorems. -Require Import Bedrock.Word. +Require Import Bedrock.Word Crypto.Util.WordUtil. Require Import Crypto.Encoding.ModularWordEncodingTheorems. Require Import Crypto.Util.ZUtil. Require Import Crypto.Algebra. +Require Import Crypto.Util.Option. +Import Morphisms. Require Import Crypto.Spec.Encoding Crypto.Spec.ModularWordEncoding Crypto.Spec.ModularArithmetic. @@ -39,9 +41,13 @@ Section PointEncodingPre. Local Notation solve_for_x2 := (@E.solve_for_x2 F one sub mul div a d). Context {sz : nat} (sz_nonzero : (0 < sz)%nat). - Context {sqrt : F -> F} (sqrt_square : forall x root, x == (root ^2) -> sqrt x == root) + Context {sqrt : F -> F} {Proper_sqrt : Proper (eq ==>eq) sqrt} + (sqrt_square : forall x root, x == (root ^2) -> + (sqrt x *sqrt x == x)) (sqrt_subst : forall x y, x == y -> sqrt x == sqrt y). Context (FEncoding : canonical encoding of F as (word sz)). + Context {enc_canonical_equiv : forall x_enc x, + option_eq eq (dec x_enc) (Some x) -> enc x = x_enc}. Context {sign_bit : F -> bool} (sign_bit_zero : forall x, x == 0 -> Logic.eq (sign_bit x) false) (sign_bit_opp : forall x, x !== 0 -> Logic.eq (negb (sign_bit x)) (sign_bit (opp x))) (sign_bit_subst : forall x y, x == y -> sign_bit x = sign_bit y). @@ -55,7 +61,7 @@ Section PointEncodingPre. pose proof root2_y. apply sqrt_square in root2_y. rewrite root2_y. - symmetry; assumption. + reflexivity. Qed. Lemma solve_onCurve: forall x y : F, onCurve (x,y) -> @@ -85,10 +91,10 @@ Section PointEncodingPre. apply E.solve_correct; eassumption. Qed. - Definition point_enc_coordinates (p : (F * F)) : Word.word (S sz) := let '(x,y) := p in - Word.WS (sign_bit x) (enc y). + Definition point_enc_coordinates (p : (F * F)) : Word.word (sz+1) := let '(x,y) := p in + combine (enc y) (WS (sign_bit x) WO). - Let point_enc (p : point) : Word.word (S sz) := point_enc_coordinates (E.coordinates p). + Let point_enc (p : point) : Word.word (sz+1) := point_enc_coordinates (E.coordinates p). Definition coord_from_y sign (y : F) : option (F * F) := let x2 := solve_for_x2 y in @@ -101,8 +107,8 @@ Section PointEncodingPre. else Some p else None. - Definition point_dec_coordinates (w : word (S sz)) : option (F * F) := - option_rect (fun _ => _) (coord_from_y (whd w)) None (dec (wtl w)). + Definition point_dec_coordinates (w : word (sz+1)) : option (F * F) := + option_rect (fun _ => _) (coord_from_y (wlast w)) None (dec (winit w)). (* Definition of product equality parameterized over equality of underlying types *) Definition prod_eq {A B} eqA eqB (x y : (A * B)) := let (xA,xB) := x in let (yA,yB) := y in @@ -120,15 +126,6 @@ Section PointEncodingPre. unfold prod_eq; intuition. Qed. - Definition option_eq {A} eq (x y : option A) := - match x with - | None => y = None - | Some ax => match y with - | None => False - | Some ay => eq ax ay - end - end. - Lemma option_eq_dec : forall {A eq} (A_eq_dec : forall a a' : A, {eq a a'} + {not (eq a a')}) (x y : option A), {option_eq eq x y} + {not (option_eq eq x y)}. Proof. @@ -227,7 +224,7 @@ Section PointEncodingPre. repeat break_match; subst; try destruct p; congruence || eauto using prod_eq_sym; intuition. Qed. - Opaque option_coordinates_eq option_point_eq point_eq option_eq prod_eq. + Opaque option_coordinates_eq option_point_eq. Ltac inversion_Some_eq := match goal with [H: Some ?x = Some ?y |- _] => inversion H; subst end. @@ -249,11 +246,16 @@ Section PointEncodingPre. | right _ => None end. - Definition point_dec (w : word (S sz)) : option point := + Definition point_dec (w : word (sz+1)) : option point := option_rect (fun _ => option point) point_from_xy None (point_dec_coordinates w). + Lemma bool_neq_negb x y : x <> y <-> x = negb y. + destruct x, y; split; (discriminate||tauto). + Qed. + Lemma point_coordinates_encoding_canonical : forall w p, - point_dec_coordinates w = Some p -> point_enc_coordinates p = w. + option_eq (Tuple.fieldwise (n := 2) eq) (point_dec_coordinates w) (Some p) -> + point_enc_coordinates p = w. Proof. repeat match goal with | |- _ => progress cbv [point_dec_coordinates option_rect @@ -266,37 +268,61 @@ Section PointEncodingPre. (intro A; specialize (sign_bit_zero _ A); congruence)) | p : F * F |- _ => destruct p | |- _ => break_match; try discriminate - | H : Some _ = Some _ |- _ => inversion H; subst; clear H | w : word (S sz) |- WS _ _ = ?w => rewrite (shatter_word w); f_equal + | H : option_eq _ (Some _) (Some _) |- _ => + cbv [option_eq Tuple.fieldwise Tuple.fieldwise' fst snd] in H; + destruct H + | H : Bool.eqb _ _ = _ |- _ => apply Bool.eqb_prop in H + | H : ?b = sign_bit ?x |- sign_bit ?y = ?b => erewrite <-sign_bit_subst by eassumption; instantiate; congruence + | H : ?b <> sign_bit ?x |- sign_bit ?y <> ?b => erewrite <-sign_bit_subst by eassumption; instantiate; congruence | |- sign_bit _ = whd ?w => destruct (whd w) | |- negb _ = false => apply Bool.negb_false_iff | |- _ => solve [auto using Bool.eqb_prop, Bool.eq_true_not_negb, Bool.not_false_is_true, encoding_canonical] - end. + end; + rewrite combine_winit_wlast; split; + try apply (f_equal2 (fun a b => WS a b)); + try solve + [ trivial + | apply enc_canonical_equiv; rewrite Heqo; auto]; + erewrite <-sign_bit_subst by eassumption. + { intuition. } + { apply bool_neq_negb in Heqb0. rewrite <-sign_bit_opp. + { congruence. } + { rewrite Bool.andb_false_iff in *. + unfold not; intro Hx; destruct Heqb; + [apply F_eqb_iff in Hx; congruence + |rewrite (sign_bit_zero _ Hx) in *; simpl negb in *; congruence]. } } Qed. - Lemma inversion_point_dec : forall w x, point_dec w = Some x -> - point_dec_coordinates w = Some (E.coordinates x). + Lemma inversion_point_dec : forall w x, + option_eq point_eq (point_dec w) (Some x) -> + option_eq (Tuple.fieldwise (n := 2) eq) (point_dec_coordinates w) (Some (E.coordinates x)). Proof. unfold point_dec, E.coordinates, point_from_xy, option_rect; intros. break_match; [ | congruence]. destruct p. break_match; [ | congruence ]. - match goal with [ H : Some _ = Some _ |- _ ] => inversion H end. - reflexivity. + destruct x as [xy pf]; destruct xy. + cbv [option_eq point_eq] in *. + simpl in *. + intuition. Qed. - Lemma point_encoding_canonical : forall w x, point_dec w = Some x -> point_enc x = w. + Lemma point_encoding_canonical : forall w x, + option_eq point_eq (point_dec w) (Some x) -> point_enc x = w. Proof. unfold point_enc; intros. apply point_coordinates_encoding_canonical. auto using inversion_point_dec. Qed. - Lemma y_decode : forall p, dec (wtl (point_enc_coordinates p)) = Some (snd p). + + Lemma y_decode : forall p, dec (winit (point_enc_coordinates p)) = Some (snd p). Proof. intros; destruct p. cbv [point_enc_coordinates wtl snd]. + rewrite winit_combine. exact (encoding_valid _). Qed. @@ -347,19 +373,40 @@ Section PointEncodingPre. break_if; [ | congruence]. assert (solve_for_x2 y == (x ^2)) as solve_correct by (symmetry; apply E.solve_correct; assumption). destruct (eq_dec x 0) as [eq_x_0 | neq_x_0]. - + rewrite !sign_bit_zero by - (eauto || (rewrite eq_x_0 in *; rewrite sqrt_square; [ | eauto]; reflexivity)). + + rewrite eq_x_0 in *. + assert (0^2 == 0) as zero_square by apply Ring.mul_0_l. + specialize (sqrt_square _ _ solve_correct). + rewrite solve_correct, zero_square in sqrt_square. + rewrite Ring.zero_product_iff_zero_factor in sqrt_square. + rewrite zero_square in *. + assert (sqrt (solve_for_x2 y) == 0) by (rewrite solve_correct; tauto). + rewrite !sign_bit_zero by (tauto || eauto). + rewrite wlast_combine. rewrite Bool.andb_false_r, Bool.eqb_reflx. apply option_coordinates_eq_iff; split; try reflexivity. - transitivity (sqrt (x ^2)); auto. - apply (sqrt_square); reflexivity. - + rewrite (proj1 (F_eqb_false _ 0)), Bool.andb_false_l by (rewrite sqrt_square; [ | eauto]; assumption). + etransitivity; eauto. + symmetry; eauto. + + assert (0^2 == 0) as zero_square by apply Ring.mul_0_l. + specialize (sqrt_square _ _ solve_correct). + rewrite !solve_correct in *. + symmetry in sqrt_square. + rewrite (proj1 (F_eqb_false _ 0)), Bool.andb_false_l. + Focus 2. { + rewrite !solve_correct in *. + intro. + apply neq_x_0. + rewrite H0, zero_square in sqrt_square. + rewrite Ring.zero_product_iff_zero_factor in sqrt_square. + tauto. } Unfocus. + rewrite wlast_combine. break_if; [ | apply eqb_sign_opp_r in Heqb]; try (apply option_coordinates_eq_iff; split; try reflexivity); try eapply sign_match with (y := solve_for_x2 y); eauto; - try solve [symmetry; auto]; rewrite ?square_opp; auto; - (rewrite sqrt_square; [ | eauto]); try apply Ring.opp_nonzero_nonzero; - assumption. + try solve [symmetry; auto]; rewrite ?square_opp; auto; + intro; apply neq_x_0; rewrite solve_correct in *; + try apply Group.inv_zero_zero in H0; + rewrite H0, zero_square in sqrt_square; + rewrite Ring.zero_product_iff_zero_factor in sqrt_square; tauto. Qed. Lemma point_encoding_valid : forall p, diff --git a/src/Experiments/Ed25519.v b/src/Experiments/Ed25519.v index 3c2c0baf8..20208924d 100644 --- a/src/Experiments/Ed25519.v +++ b/src/Experiments/Ed25519.v @@ -2,6 +2,7 @@ Require Import Coq.omega.Omega. Require Import Coq.Lists.List. Import ListNotations. Require Import Crypto.EdDSARepChange. +Require Import Crypto.MxDHRepChange. Import MxDH. Require Import Crypto.Spec.Ed25519. Require Import Crypto.Util.Decidable. Require Import Crypto.Util.ListUtil. @@ -20,29 +21,39 @@ Local Coercion GF25519BoundedCommon.word64ToZ : GF25519BoundedCommon.word64 >-> Local Coercion GF25519BoundedCommon.proj1_fe25519 : GF25519BoundedCommon.fe25519 >-> GF25519.fe25519. Local Set Printing Coercions. +Local Notation eta x := (fst x, snd x). +Local Notation eta3 x := (eta (fst x), snd x). +Local Notation eta4 x := (eta3 (fst x), snd x). + Context {H: forall n : nat, Word.word n -> Word.word (b + b)}. -Definition feSign (x : GF25519BoundedCommon.fe25519) : bool := +Definition feSign (f : GF25519BoundedCommon.fe25519) : bool := + let x := GF25519Bounded.freeze f in let '(x9, x8, x7, x6, x5, x4, x3, x2, x1, x0) := (x : GF25519.fe25519) in BinInt.Z.testbit x0 0. -(* TODO *) -Context {feSign_correct : forall x, - PointEncoding.sign x = feSign (GF25519BoundedCommon.encode x)}. -Context {Proper_feSign : Proper (GF25519BoundedCommon.eq ==> eq) feSign}. - -Definition a : GF25519BoundedCommon.fe25519 := - Eval vm_compute in GF25519BoundedCommon.encode a. -Definition d : GF25519BoundedCommon.fe25519 := - Eval vm_compute in GF25519BoundedCommon.encode d. -Definition twice_d : GF25519BoundedCommon.fe25519 := - Eval vm_compute in (GF25519Bounded.add d d). +Section Constants. + Import GF25519BoundedCommon. + Definition a' : GF25519BoundedCommon.fe25519 := + Eval vm_compute in GF25519BoundedCommon.encode a. + Definition a : GF25519BoundedCommon.fe25519 := + Eval cbv [a' fe25519_word64ize word64ize andb opt.word64ToZ opt.word64ize opt.Zleb Z.compare CompOpp Pos.compare Pos.compare_cont] in (fe25519_word64ize a'). + Definition d' : GF25519BoundedCommon.fe25519 := + Eval vm_compute in GF25519BoundedCommon.encode d. + Definition d : GF25519BoundedCommon.fe25519 := + Eval cbv [d' fe25519_word64ize word64ize andb opt.word64ToZ opt.word64ize opt.Zleb Z.compare CompOpp Pos.compare Pos.compare_cont] in (fe25519_word64ize d'). + Definition twice_d' : GF25519BoundedCommon.fe25519 := + Eval vm_compute in (GF25519Bounded.add d d). + Definition twice_d : GF25519BoundedCommon.fe25519 := + Eval cbv [twice_d' fe25519_word64ize word64ize andb opt.word64ToZ opt.word64ize opt.Zleb Z.compare CompOpp Pos.compare Pos.compare_cont] in (fe25519_word64ize twice_d'). +End Constants. + Lemma phi_a : GF25519BoundedCommon.eq (GF25519BoundedCommon.encode Spec.Ed25519.a) a. Proof. reflexivity. Qed. Lemma phi_d : GF25519BoundedCommon.eq (GF25519BoundedCommon.encode Spec.Ed25519.d) d. Proof. vm_decide_no_check. Qed. -Let Erep := (@ExtendedCoordinates.Extended.point +Definition Erep := (@ExtendedCoordinates.Extended.point GF25519BoundedCommon.fe25519 GF25519BoundedCommon.eq GF25519BoundedCommon.zero @@ -56,7 +67,7 @@ Let Erep := (@ExtendedCoordinates.Extended.point Local Existing Instance GF25519.homomorphism_F25519_encode. Local Existing Instance GF25519.homomorphism_F25519_decode. -Lemma twedprm_ERep : +Local Instance twedprm_ERep : @CompleteEdwardsCurve.E.twisted_edwards_params GF25519BoundedCommon.fe25519 GF25519BoundedCommon.eq GF25519BoundedCommon.zero GF25519BoundedCommon.one @@ -93,7 +104,7 @@ Proof. reflexivity. Qed. -Let EToRep := +Definition EToRep := PointEncoding.point_phi (Kfield := GF25519Bounded.field25519) (phi_homomorphism := GF25519Bounded.homomorphism_F25519_encode) @@ -102,8 +113,8 @@ Let EToRep := (phi_d := phi_d) (Kcoord_to_point := ExtendedCoordinates.Extended.from_twisted (prm := twedprm_ERep) (field := GF25519Bounded.field25519)). -Let ZNWord sz x := Word.NToWord sz (BinInt.Z.to_N x). -Let WordNZ {sz} (w : Word.word sz) := BinInt.Z.of_N (Word.wordToN w). +Definition ZNWord sz x := Word.NToWord sz (BinInt.Z.to_N x). +Definition WordNZ {sz} (w : Word.word sz) := BinInt.Z.of_N (Word.wordToN w). (* TODO : GF25519.pack does most of the work here, but the spec currently talks @@ -122,7 +133,12 @@ Definition feEnc (x : GF25519BoundedCommon.fe25519) : Word.word 255 := (Word.combine (ZNWord 32 x4) (Word.combine (ZNWord 32 x5) (Word.combine (ZNWord 32 x6) (ZNWord 31 x7))))))). +Check GF25519Bounded.unpack. +Print GF25519BoundedCommon.wire_digits. +Eval compute in GF25519.wire_widths. +Eval compute in (Tuple.from_list 8 GF25519.wire_widths _). +(** TODO(jadep or andreser, from jgross): Is the reversal on the words passed in correct? *) Definition feDec (w : Word.word 255) : option GF25519BoundedCommon.fe25519 := let w0 := Word.split1 32 _ w in let a0 := Word.split2 32 _ w in @@ -138,18 +154,18 @@ Definition feDec (w : Word.word 255) : option GF25519BoundedCommon.fe25519 := let a5 := Word.split2 32 _ a4 in let w6 := Word.split1 32 _ a5 in let w7 := Word.split2 32 _ a5 in - let result := (GF25519Bounded.unpack (GF25519BoundedCommon.word32_to_unbounded_word w0, - GF25519BoundedCommon.word32_to_unbounded_word w1, - GF25519BoundedCommon.word32_to_unbounded_word w2, - GF25519BoundedCommon.word32_to_unbounded_word w3, - GF25519BoundedCommon.word32_to_unbounded_word w4, - GF25519BoundedCommon.word32_to_unbounded_word w5, + let result := (GF25519Bounded.unpack (GF25519BoundedCommon.word31_to_unbounded_word w7, GF25519BoundedCommon.word32_to_unbounded_word w6, - GF25519BoundedCommon.word31_to_unbounded_word w7)) in + GF25519BoundedCommon.word32_to_unbounded_word w5, + GF25519BoundedCommon.word32_to_unbounded_word w4, + GF25519BoundedCommon.word32_to_unbounded_word w3, + GF25519BoundedCommon.word32_to_unbounded_word w2, + GF25519BoundedCommon.word32_to_unbounded_word w1, + GF25519BoundedCommon.word32_to_unbounded_word w0)) in if GF25519BoundedCommon.w64eqb (GF25519Bounded.ge_modulus result) (GF25519BoundedCommon.ZToWord64 1) then None else (Some result). -Let ERepEnc := +Definition ERepEnc := (PointEncoding.Kencode_point (Ksign := feSign) (Kenc := feEnc) @@ -158,8 +174,8 @@ Let ERepEnc := (ExtendedCoordinates.Extended.to_twisted P (field:=GF25519Bounded.field25519))) ). -Let SRep := SC25519.SRep. -Let S2Rep := SC25519.S2Rep. +Definition SRep := SC25519.SRep. +Definition S2Rep := SC25519.S2Rep. (*Let SRep := Tuple.tuple (Word.word 32) 8. Let S2Rep := fun (x : ModularArithmetic.F.F l) => @@ -172,13 +188,13 @@ Let S2Rep := fun (x : ModularArithmetic.F.F l) => Lemma eq_a_minus1 : GF25519BoundedCommon.eq a (GF25519Bounded.opp GF25519BoundedCommon.one). Proof. vm_decide. Qed. -Let ErepAdd := +Definition ErepAdd := (@ExtendedCoordinates.Extended.add _ _ _ _ _ _ _ _ _ _ a d GF25519Bounded.field25519 twedprm_ERep _ eq_a_minus1 twice_d (eq_refl _) ). Local Coercion Z.of_nat : nat >-> Z. -Let ERepSel : bool -> Erep -> Erep -> Erep := fun b x y => if b then y else x. +Definition ERepSel : bool -> Erep -> Erep -> Erep := fun b x y => if b then y else x. Local Existing Instance ExtendedCoordinates.Extended.extended_group. @@ -212,6 +228,60 @@ Proof. reflexivity. Qed. +Lemma ERep_eq_E P Q : + ExtendedCoordinates.Extended.eq (field:=GF25519Bounded.field25519) + (EToRep P) (EToRep Q) + -> CompleteEdwardsCurveTheorems.E.eq P Q. +Proof. + destruct P as [[] HP], Q as [[] HQ]. + cbv [ExtendedCoordinates.Extended.eq EToRep PointEncoding.point_phi CompleteEdwardsCurveTheorems.E.ref_phi CompleteEdwardsCurveTheorems.E.eq CompleteEdwardsCurve.E.coordinates + ExtendedCoordinates.Extended.coordinates + ExtendedCoordinates.Extended.to_twisted + ExtendedCoordinates.Extended.from_twisted + GF25519BoundedCommon.eq ModularBaseSystem.eq + Tuple.fieldwise Tuple.fieldwise' fst snd proj1_sig]. + intro H. + rewrite !GF25519Bounded.mul_correct, !GF25519Bounded.inv_correct, !GF25519BoundedCommon.proj1_fe25519_encode in *. + rewrite !Algebra.Ring.homomorphism_mul in H. + pose proof (Algebra.Field.homomorphism_multiplicative_inverse (H:=GF25519.field25519)) as Hinv; + rewrite Hinv in H by vm_decide; clear Hinv. + let e := constr:((ModularBaseSystem.decode (GF25519BoundedCommon.proj1_fe25519 GF25519BoundedCommon.one))) in + set e as xe; assert (Hone:xe = ModularArithmetic.F.one) by vm_decide; subst xe; rewrite Hone in *; clear Hone. + rewrite <-!(Algebra.field_div_definition(inv:=ModularArithmetic.F.inv)) in H. + rewrite !(Algebra.Field.div_one(one:=ModularArithmetic.F.one)) in H. + pose proof ModularBaseSystemProofs.encode_rep as Hencode; + unfold ModularBaseSystem.rep in Hencode; rewrite !Hencode in H. + assumption. +Qed. + +Module N. + Lemma size_le a b : (a <= b -> N.size a <= N.size b)%N. + Proof. + destruct (dec (a=0)%N), (dec (b=0)%N); subst; auto using N.le_0_l. + { destruct a; auto. } + { rewrite !N.size_log2 by assumption. + rewrite <-N.succ_le_mono. + apply N.log2_le_mono. } + Qed. + + Lemma le_to_nat a b : (a <= b)%N <-> (N.to_nat a <= N.to_nat b)%nat. + Proof. + rewrite <-N.lt_succ_r. + rewrite <-Nat.lt_succ_r. + rewrite <-Nnat.N2Nat.inj_succ. + rewrite <-NatUtil.Nat2N_inj_lt. + rewrite !Nnat.N2Nat.id. + reflexivity. + Qed. + + Lemma size_nat_le a b : (a <= b)%N -> (N.size_nat a <= N.size_nat b)%nat. + Proof. + rewrite !IterAssocOp.Nsize_nat_equiv. + rewrite <-le_to_nat. + apply size_le. + Qed. +End N. + Section SRepERepMul. Import Coq.Setoids.Setoid Coq.Classes.Morphisms Coq.Classes.Equivalence. Import Coq.NArith.NArith Coq.PArith.BinPosDef. @@ -229,6 +299,7 @@ Section SRepERepMul. ll . + Lemma SRepERepMul_correct n P : ExtendedCoordinates.Extended.eq (field:=GF25519Bounded.field25519) (EToRep (CompleteEdwardsCurve.E.mul (n mod (Z.to_nat l))%nat P)) @@ -268,6 +339,30 @@ Section SRepERepMul. vm_decide. vm_compute. reflexivity. } Qed. + + Definition NERepMul : N -> Erep -> Erep := fun x => + IterAssocOp.iter_op + (op:=ErepAdd) + (id:=ExtendedCoordinates.Extended.zero(field:=GF25519Bounded.field25519)(prm:=twedprm_ERep)) + (N.testbit_nat x) + (sel:=ERepSel) + ll + . + Lemma NERepMul_correct n P : + (N.size_nat (N.of_nat n) <= ll) -> + ExtendedCoordinates.Extended.eq (field:=GF25519Bounded.field25519) + (EToRep (CompleteEdwardsCurve.E.mul n P)) + (NERepMul (N.of_nat n) (EToRep P)). + Proof. + rewrite ScalarMult.scalarmult_ext. + unfold NERepMul. + etransitivity; [|symmetry; eapply iter_op_correct]. + 3: intros; reflexivity. + 2: intros; reflexivity. + { rewrite Nat2N.id. + apply (@Group.homomorphism_scalarmult _ _ _ _ _ _ _ _ _ _ _ _ EToRep Ahomom ScalarMult.scalarmult_ref _ ScalarMult.scalarmult_ref _ _ _). } + { assumption. } + Qed. End SRepERepMul. Lemma ZToN_NPow2_lt : forall z n, (0 <= z < 2 ^ Z.of_nat n)%Z -> @@ -281,10 +376,10 @@ Proof. replace (Z.of_nat 2) with 2%Z by reflexivity. omega. Qed. - + Lemma combine_ZNWord : forall sz1 sz2 z1 z2, - (0 <= Z.of_nat sz1)%Z -> - (0 <= Z.of_nat sz2)%Z -> + (0 <= Z.of_nat sz1)%Z -> + (0 <= Z.of_nat sz2)%Z -> (0 <= z1 < 2 ^ (Z.of_nat sz1))%Z -> (0 <= z2 < 2 ^ (Z.of_nat sz2))%Z -> Word.combine (ZNWord sz1 z1) (ZNWord sz2 z2) = @@ -307,24 +402,34 @@ Proof. f_equal. Qed. -Lemma nth_default_B_compat : forall i, +Lemma nth_default_freeze_input_bound_compat : forall i, + (nth_default 0 PseudoMersenneBaseParams.limb_widths i < + GF25519.freeze_input_bound)%Z. +Proof. + pose proof GF25519.freezePreconditions25519. + intros. + destruct (lt_dec i (length PseudoMersenneBaseParams.limb_widths)). + { apply ModularBaseSystemProofs.B_compat. + rewrite nth_default_eq. + auto using nth_In. } + { rewrite nth_default_out_of_bounds by omega. + cbv; congruence. } +Qed. +(* +Lemma nth_default_int_width_compat : forall i, (nth_default 0 PseudoMersenneBaseParams.limb_widths i < GF25519.int_width)%Z. Proof. - assert (@ModularBaseSystemProofs.FreezePreconditions - GF25519.modulus GF25519.params25519 - GF25519.int_width) by - (let A := fresh "H" in - pose proof GF25519.freezePreconditions25519 as A; - inversion A; econstructor; eauto). - intros. - destruct (lt_dec i (length PseudoMersenneBaseParams.limb_widths)). - { apply ModularBaseSystemProofs.B_compat. - rewrite nth_default_eq. - auto using nth_In. } - { rewrite nth_default_out_of_bounds by omega. - cbv; congruence. } + pose proof GF25519.freezePreconditions25519. + intros. + destruct (lt_dec i (length PseudoMersenneBaseParams.limb_widths)). + { apply ModularBaseSystemProofs.int_width_compat. + rewrite nth_default_eq. + auto using nth_In. } + { rewrite nth_default_out_of_bounds by omega. + cbv; congruence. } Qed. +*) Lemma minrep_freeze : forall x, Pow2Base.bounded @@ -344,12 +449,7 @@ Lemma minrep_freeze : forall x, (ModularBaseSystem.encode x))) = 0%Z. Proof. - assert (@ModularBaseSystemProofs.FreezePreconditions - GF25519.modulus GF25519.params25519 - GF25519.int_width) - by (let A := fresh "H" in - pose proof GF25519.freezePreconditions25519 as A; - inversion A; econstructor; eauto). + pose proof GF25519.freezePreconditions25519. intros. match goal with |- appcontext [ModularBaseSystem.freeze _ ?x] => @@ -363,15 +463,17 @@ Proof. eapply Z.lt_le_trans; [ solve [intuition eauto] | ]. match goal with |- appcontext [if ?a then _ else _] => destruct a end. { apply Z.pow_le_mono_r; try omega. - apply Z.lt_le_incl, nth_default_B_compat. } - { transitivity (2 ^ (Z.pred GF25519.int_width))%Z. + apply Z.lt_le_incl. + apply nth_default_freeze_input_bound_compat. } + { transitivity (2 ^ (Z.pred GF25519.freeze_input_bound))%Z. { apply Z.pow_le_mono; try omega. - apply Z.lt_le_pred. - apply nth_default_B_compat. } + apply Z.lt_le_pred. + apply nth_default_freeze_input_bound_compat. } { rewrite Z.shiftr_div_pow2 by (auto using Pow2BaseProofs.nth_default_limb_widths_nonneg, PseudoMersenneBaseParamProofs.limb_widths_nonneg). - rewrite <- Z.pow_sub_r by (try omega; split; auto using Pow2BaseProofs.nth_default_limb_widths_nonneg, PseudoMersenneBaseParamProofs.limb_widths_nonneg, Z.lt_le_incl, nth_default_B_compat). - replace (2 ^ GF25519.int_width)%Z - with (2 ^ (Z.pred GF25519.int_width + 1))%Z by (f_equal; omega). + rewrite <- Z.pow_sub_r by (try omega; split; auto using Pow2BaseProofs.nth_default_limb_widths_nonneg, PseudoMersenneBaseParamProofs.limb_widths_nonneg, Z.lt_le_incl, nth_default_freeze_input_bound_compat). + replace (2 ^ GF25519.freeze_input_bound)%Z + with (2 ^ (Z.pred GF25519.freeze_input_bound + 1))%Z + by (f_equal; omega). rewrite Z.pow_add_r by (omega || (cbv; congruence)). rewrite <-Zplus_diag_eq_mult_2. match goal with |- (?a <= ?a + ?b - ?c)%Z => @@ -390,7 +492,7 @@ Qed. Lemma convert_freezes: forall x, (ModularBaseSystemList.freeze GF25519.int_width (Tuple.to_list - (length PseudoMersenneBaseParams.limb_widths) x)) = + (length PseudoMersenneBaseParams.limb_widths) x)) = (Tuple.to_list (length PseudoMersenneBaseParams.limb_widths) @@ -402,6 +504,11 @@ Proof. rewrite Tuple.to_list_from_list. reflexivity. Qed. +Ltac to_MBSfreeze H := + rewrite GF25519.freeze_correct in H; + rewrite ModularBaseSystemOpt.freeze_opt_correct in H + by (rewrite ?Tuple.length_to_list; reflexivity); + erewrite convert_freezes, Tuple.from_list_default_eq, Tuple.from_list_to_list in H. Lemma bounded_freeze : forall x, Pow2Base.bounded @@ -452,14 +559,6 @@ Proof. apply Z.mul_lt_mono_pos_l; omega. } Qed. -Definition freezePre : - @ModularBaseSystemProofs.FreezePreconditions - GF25519.modulus GF25519.params25519 GF25519.int_width. -Proof. - pose proof GF25519.freezePreconditions25519 as A. - inversion A; econstructor; eauto. -Defined. - Lemma feEnc_correct : forall x, PointEncoding.Fencode x = feEnc (GF25519BoundedCommon.encode x). Proof. @@ -470,9 +569,9 @@ Proof. remember (GF25519.pack x) end. transitivity (ZNWord 255 (Pow2Base.decode_bitwise GF25519.wire_widths (Tuple.to_list 8 w))). { cbv [ZNWord]. - do 2 f_equal. + do 2 apply f_equal. subst w. - pose proof freezePre. + pose proof GF25519.freezePreconditions25519. match goal with |- appcontext [GF25519.freeze ?x ] => let A := fresh "H" in @@ -509,8 +608,8 @@ Proof. rewrite Tuple.to_list_from_list. apply Conversion.convert_bounded. } { destruct w; - repeat match goal with p : _ * Z |- _ => destruct p end. - simpl Tuple.to_list in *. + repeat match goal with p : (_ * Z)%type |- _ => destruct p end. + cbv [Tuple.to_list Tuple.to_list'] in *. rewrite Pow2BaseProofs.bounded_iff in *. (* TODO : Is there a better way to do this? *) pose proof (H0 0). @@ -522,7 +621,7 @@ Proof. pose proof (H0 6). pose proof (H0 7). clear H0. - cbv [GF25519.wire_widths nth_default nth_error] in *. + cbv [GF25519.wire_widths nth_default nth_error value] in *. repeat rewrite combine_ZNWord by (rewrite ?Znat.Nat2Z.inj_add; simpl Z.of_nat; repeat apply lor_shiftl_bounds; omega). cbv - [ZNWord Z.lor Z.shiftl]. rewrite Z.shiftl_0_l. @@ -536,37 +635,109 @@ Lemma initial_bounds : forall x n, nth_default 0 (Tuple.to_list (length PseudoMersenneBaseParams.limb_widths) (GF25519BoundedCommon.proj1_fe25519 x)) n < - 2 ^ GF25519.int_width - - (if PeanoNat.Nat.eq_dec n 0 + 2 ^ GF25519.freeze_input_bound - + (if eq_nat_dec n 0%nat then 0 else - Z.shiftr (2 ^ GF25519.int_width) + Z.shiftr (2 ^ GF25519.freeze_input_bound) (nth_default 0 PseudoMersenneBaseParams.limb_widths - (Init.Nat.pred n))))%Z. + (pred n))))%Z. Proof. intros. cbv [GF25519BoundedCommon.fe25519] in *. - repeat match goal with p : _ * _ |- _ => destruct p end. + repeat match goal with p : (_ * _)%type |- _ => destruct p end. cbv [GF25519BoundedCommon.proj1_fe25519]. cbv [GF25519BoundedCommon.fe25519WToZ GF25519BoundedCommon.proj1_fe25519W PseudoMersenneBaseParams.limb_widths GF25519.params25519 length - Tuple.to_list Tuple.to_list'] in *. - (* TODO (jgross) : this should probably be Ltac'ed *) - assert (n = 0 \/ n = 1 \/ n = 2 \/ n = 3 \/ n = 4 \/ n = 5 \/ n = 6 \/ n = 7 \/ n = 8 \/ n = 9) by omega. - repeat match goal with H : _ \/ _ |- _ => destruct H end; - subst; cbv [nth_default nth_error pred]; - match goal with |- appcontext [if ?x then _ else _] => - destruct x end; try congruence; - cbv - [GF25519BoundedCommon.proj_word Z.le Z.lt] in *; - match goal with - |- appcontext [GF25519BoundedCommon.proj_word ?b] => - let A := fresh "H" in - pose proof (@GF25519BoundedCommon.word_bounded _ _ b) as A; - rewrite Bool.andb_true_iff in A; destruct A end; - rewrite !Z.leb_le in *; - omega. + Tuple.to_list Tuple.to_list' nth_default] in *. + repeat match goal with + | [ |- appcontext[nth_error _ ?n] ] + => is_var n; destruct n; simpl @nth_error; cbv beta iota + end; + simpl in *; unfold Z.pow_pos; simpl; try omega; + match goal with + |- appcontext [GF25519BoundedCommon.proj_word ?b] => + let A := fresh "H" in + pose proof (@GF25519BoundedCommon.word_bounded _ _ b) as A; + rewrite Bool.andb_true_iff in A; destruct A end; + rewrite !Z.leb_le in *; + omega. +Qed. + +Lemma feSign_correct : forall x, + PointEncoding.sign x = feSign (GF25519BoundedCommon.encode x). +Proof. + cbv [PointEncoding.sign feSign]. + intros. + rewrite GF25519Bounded.freeze_correct. + rewrite GF25519BoundedCommon.proj1_fe25519_encode. + match goal with |- appcontext [GF25519.freeze ?x] => + remember (GF25519.freeze x) end. + transitivity (Z.testbit (nth_default 0%Z (Tuple.to_list 10 f) 0) 0). + Focus 2. { + cbv [GF25519.fe25519] in *. + repeat match goal with p : (_ * _)%type |- _ => destruct p end. + simpl. reflexivity. } Unfocus. + + rewrite !Z.bit0_odd. + rewrite <-@Pow2BaseProofs.parity_decode with (limb_widths := PseudoMersenneBaseParams.limb_widths) by (auto using PseudoMersenneBaseParamProofs.limb_widths_nonneg, Tuple.length_to_list; cbv; congruence). + pose proof GF25519.freezePreconditions25519. + match goal with H : _ = GF25519.freeze ?u |- _ => + let A := fresh "H" in let B := fresh "H" in + pose proof (ModularBaseSystemProofs.freeze_rep u x) as A; + match type of A with ?P -> _ => assert P as B by apply ModularBaseSystemProofs.encode_rep end; + specialize (A B); clear B + end. + to_MBSfreeze Heqf. + rewrite <-Heqf in *. + cbv [ModularBaseSystem.rep ModularBaseSystem.decode ModularBaseSystemList.decode] in *. + rewrite <-H1. + rewrite ModularArithmeticTheorems.F.to_Z_of_Z. + rewrite Z.mod_small; [ reflexivity | ]. + pose proof (minrep_freeze x). + apply ModularBaseSystemListProofs.ge_modulus_spec; + try solve [inversion H0; auto using Tuple.length_to_list]; + subst f; intuition auto. + Grab Existential Variables. + apply Tuple.length_to_list. +Qed. + + +Local Instance Proper_feSign : Proper (GF25519BoundedCommon.eq ==> eq) feSign. +Proof. + repeat intro; cbv [feSign]. + rewrite !GF25519Bounded.freeze_correct. + repeat match goal with |- appcontext[GF25519.freeze ?x] => + remember (GF25519.freeze x) end. + assert (Tuple.fieldwise (n := 10) eq f f0). + { pose proof GF25519.freezePreconditions25519. + match goal with H1 : _ = GF25519.freeze ?u, + H2 : _ = GF25519.freeze ?v |- _ => + let A := fresh "H" in + let HP := fresh "H" in + let HQ := fresh "H" in + pose proof (ModularBaseSystemProofs.freeze_canonical + (freeze_pre := GF25519.freezePreconditions25519) u v _ _ eq_refl eq_refl); + match type of A with ?P -> ?Q -> _ => + assert P as HP by apply initial_bounds; + assert Q as HQ by apply initial_bounds end; + specialize (A HP HQ); clear HP HQ end. + cbv [ModularBaseSystem.eq] in *. + to_MBSfreeze Heqf0. + to_MBSfreeze Heqf. + subst. + apply H2. + cbv [GF25519BoundedCommon.eq ModularBaseSystem.eq] in *. + auto. } + { cbv [GF25519.fe25519 ] in *. + repeat match goal with p : (_ * _)%type |- _ => destruct p end. + cbv [Tuple.fieldwise Tuple.fieldwise' fst snd] in *. + intuition congruence. } + Grab Existential Variables. + rewrite Tuple.length_to_list; reflexivity. + rewrite Tuple.length_to_list; reflexivity. Qed. Lemma Proper_pack : @@ -584,10 +755,10 @@ Proof. cbv [GF25519.wire_widths length Tuple.fieldwise Tuple.fieldwise' fst snd] in *; intuition subst; reflexivity. Qed. - + Lemma Proper_feEnc : Proper (GF25519BoundedCommon.eq ==> eq) feEnc. Proof. - pose proof freezePre. + pose proof GF25519.freezePreconditions25519. repeat intro; cbv [feEnc]. rewrite !GF25519Bounded.pack_correct, !GF25519Bounded.freeze_correct. rewrite !GF25519.freeze_correct, !ModularBaseSystemOpt.freeze_opt_correct @@ -599,7 +770,8 @@ Proof. let HP := fresh "H" in let HQ := fresh "H" in pose proof (ModularBaseSystemProofs.freeze_canonical - (freeze_pre := freezePre) x y (ModularBaseSystem.decode x) + (freeze_pre := GF25519.freezePreconditions25519) + x y (ModularBaseSystem.decode x) (ModularBaseSystem.decode y) eq_refl eq_refl); match type of A with ?P -> ?Q -> _ => assert P as HP by apply initial_bounds; @@ -616,7 +788,7 @@ Proof. apply Proper_pack. assumption. } { cbv [length GF25519.wire_digits] in *. - repeat match goal with p : _ * _ |- _ => destruct p end. + repeat match goal with p : (_ * _)%type |- _ => destruct p end. cbv [GF25519.wire_widths length Tuple.fieldwise Tuple.fieldwise' fst snd] in *. repeat match goal with H : _ /\ _ |- _ => destruct H end; subst; reflexivity. } @@ -659,13 +831,13 @@ Proof. tauto. Qed. -Let SRepEnc : SRep -> Word.word b := (fun x => Word.NToWord _ (Z.to_N x)). +Definition SRepEnc : SRep -> Word.word b := (fun x => Word.NToWord _ (Z.to_N x)). Local Instance Proper_SRepERepMul : Proper (SC25519.SRepEq ==> ExtendedCoordinates.Extended.eq (field:=GF25519Bounded.field25519) ==> ExtendedCoordinates.Extended.eq (field:=GF25519Bounded.field25519)) SRepERepMul. unfold SRepERepMul, SC25519.SRepEq. repeat intro. eapply IterAssocOp.Proper_iter_op. - { eapply ExtendedCoordinates.Extended.Proper_add. } + { eapply @ExtendedCoordinates.Extended.Proper_add. } { reflexivity. } { repeat intro; subst; reflexivity. } { unfold ERepSel; repeat intro; break_match; solve [ discriminate | eauto ]. } @@ -677,17 +849,40 @@ Lemma SRepEnc_correct : forall x : ModularArithmetic.F.F l, Senc x = SRepEnc (S2 unfold SRepEnc, Senc, Fencode; intros; f_equal. Qed. -(** TODO: How do we speed up vm_compute here? I think it's spending most of it's time rechecking boundedness... *) -Let ERepB : Erep. - let rB := (eval vm_compute in (proj1_sig (EToRep B))) in - exists rB. cbv [GF25519BoundedCommon.eq ModularBaseSystem.eq Pre.onCurve]. vm_decide_no_check. -Defined. - -Let ERepB_correct : ExtendedCoordinates.Extended.eq (field:=GF25519Bounded.field25519) ERepB (EToRep B). - vm_decide. -Qed. +Section ConstantPoints. + Import GF25519BoundedCommon. + Let proj1_sig_ERepB' := Eval vm_compute in proj1_sig (EToRep B). + Let tmap4 := Eval compute in @Tuple.map 4. Arguments tmap4 {_ _} _ _. + Let proj1_sig_ERepB := Eval cbv [tmap4 proj1_sig_ERepB' fe25519_word64ize word64ize andb opt.word64ToZ opt.word64ize opt.Zleb Z.compare CompOpp Pos.compare Pos.compare_cont] in (tmap4 fe25519_word64ize proj1_sig_ERepB'). + Let proj1_sig_ERepB_correct : proj1_sig_ERepB = proj1_sig (EToRep B). + Proof. vm_cast_no_check (eq_refl proj1_sig_ERepB). Qed. + + Definition ERepB : Erep. + exists (eta4 proj1_sig_ERepB). + cbv [GF25519BoundedCommon.eq ModularBaseSystem.eq Pre.onCurve]. + vm_decide_no_check. + Defined. + + Lemma ERepB_correct : ExtendedCoordinates.Extended.eq (field:=GF25519Bounded.field25519) ERepB (EToRep B). + generalize proj1_sig_ERepB_correct as H; destruct (EToRep B) as [B ?] in |- *. + cbv [proj1_sig] in |- *. intro. subst B. + vm_decide. + Qed. +End ConstantPoints. -Let sign := @EdDSARepChange.sign E +Lemma B_order_l : CompleteEdwardsCurveTheorems.E.eq + (CompleteEdwardsCurve.E.mul (Z.to_nat l) B) + CompleteEdwardsCurve.E.zero. +Proof. + apply ERep_eq_E. + rewrite NERepMul_correct; rewrite (Z_nat_N l). + 2:vm_decide. + apply dec_bool. + vm_cast_no_check (eq_refl true). +(* Time Qed. (* Finished transaction in 1646.167 secs (1645.753u,0.339s) (successful) *) *) +Admitted. + +Definition sign := @EdDSARepChange.sign E (@CompleteEdwardsCurveTheorems.E.eq Fq (@eq Fq) (@ModularArithmetic.F.one q) (@ModularArithmetic.F.add q) (@ModularArithmetic.F.mul q) Spec.Ed25519.a Spec.Ed25519.d) (@CompleteEdwardsCurve.E.add Fq (@eq Fq) (ModularArithmetic.F.of_Z q 0) (@ModularArithmetic.F.one q) @@ -709,7 +904,7 @@ Let sign := @EdDSARepChange.sign E (@ModularArithmetic.F.opp q) (@ModularArithmetic.F.add q) (@ModularArithmetic.F.sub q) (@ModularArithmetic.F.mul q) (@ModularArithmetic.F.inv q) (@ModularArithmetic.F.div q) (@PrimeFieldTheorems.F.field_modulo q prime_q) (@ModularArithmeticTheorems.F.eq_dec q) Spec.Ed25519.a - Spec.Ed25519.d curve_params) b H c n l B Eenc Senc (@ed25519 H) Erep ERepEnc SRep SC25519.SRepDecModL + Spec.Ed25519.d curve_params) b H c n l B Eenc Senc (@ed25519 H B_order_l ) Erep ERepEnc SRep SC25519.SRepDecModL SRepERepMul SRepEnc SC25519.SRepAdd SC25519.SRepMul ERepB SC25519.SRepDecModLShort. Let sign_correct : forall pk sk {mlen} (msg:Word.word mlen), sign pk sk _ msg = EdDSA.sign pk sk msg := @@ -728,7 +923,7 @@ Let sign_correct : forall pk sk {mlen} (msg:Word.word mlen), sign pk sk _ msg = (* B := *) B (* Eenc := *) Eenc (* Senc := *) Senc - (* prm := *) ed25519 + (* prm := *) (ed25519 B_order_l) (* Erep := *) Erep (* ErepEq := *) ExtendedCoordinates.Extended.eq (* ErepAdd := *) ErepAdd @@ -783,7 +978,7 @@ Proof. apply bound_check_255_helper; vm_compute; intuition congruence. Qed. -Let Edec := (@PointEncodingPre.point_dec +Definition Edec := (@PointEncodingPre.point_dec _ eq ModularArithmetic.F.zero ModularArithmetic.F.one @@ -797,10 +992,12 @@ Let Edec := (@PointEncodingPre.point_dec Spec.Ed25519.d _ Fsqrt - (PointEncoding.Fencoding (bound_check := bound_check255)) + (PointEncoding.Fencoding + (two_lt_m := GF25519.modulus_gt_2) + (bound_check := bound_check255)) Spec.Ed25519.sign). -Let Sdec : Word.word b -> option (ModularArithmetic.F.F l) := +Definition Sdec : Word.word b -> option (ModularArithmetic.F.F l) := fun w => let z := (BinIntDef.Z.of_N (Word.wordToN w)) in if ZArith_dec.Z_lt_dec z l @@ -825,11 +1022,11 @@ Proof. | |- _ => rewrite ModularArithmeticTheorems.F.of_Z_to_Z in * | |- _ => rewrite @ModularArithmeticTheorems.F.to_Z_of_Z in * | |- _ => reflexivity - | |- _ => omega + | |- _ => omega end. Qed. -Let SRepDec : Word.word b -> option SRep := fun w => option_map ModularArithmetic.F.to_Z (Sdec w). +Definition SRepDec : Word.word b -> option SRep := fun w => option_map ModularArithmetic.F.to_Z (Sdec w). Lemma SRepDec_correct : forall w : Word.word b, @Option.option_eq SRep SC25519.SRepEq @@ -839,7 +1036,7 @@ Proof. unfold SRepDec, S2Rep, SC25519.S2Rep; intros; reflexivity. Qed. -Let ERepDec := +Definition ERepDec := (@PointEncoding.Kdecode_point _ GF25519BoundedCommon.fe25519 @@ -859,15 +1056,384 @@ Let ERepDec := feDec GF25519Bounded.sqrt ). -Axiom ERepDec_correct : forall w : Word.word b, ERepDec w = @option_map E Erep EToRep (Edec w). +Lemma extended_to_coord_from_twisted: forall pt, + Tuple.fieldwise (n := 2) GF25519BoundedCommon.eq + (extended_to_coord (ExtendedCoordinates.Extended.from_twisted pt)) + (CompleteEdwardsCurve.E.coordinates pt). +Proof. + intros; cbv [extended_to_coord]. + rewrite ExtendedCoordinates.Extended.to_twisted_from_twisted. + reflexivity. +Qed. + +Local Instance Proper_sqrt : + Proper (GF25519BoundedCommon.eq ==> GF25519BoundedCommon.eq) GF25519Bounded.sqrt. +Admitted. + +Lemma WordNZ_split1 : forall {n m} w, + Z.of_N (Word.wordToN (Word.split1 n m w)) = ZUtil.Z.pow2_mod (Z.of_N (Word.wordToN w)) n. +Admitted. + +Lemma WordNZ_split2 : forall {n m} w, + Z.of_N (Word.wordToN (Word.split2 n m w)) = Z.shiftr (Z.of_N (Word.wordToN w)) n. +Admitted. + +Lemma WordNZ_range : forall {n} B w, + (2 ^ Z.of_nat n <= B)%Z -> + (0 <= Z.of_N (@Word.wordToN n w) < B)%Z. +Admitted. + +Lemma WordNZ_range_mono : forall {n} m w, + (Z.of_nat n <= m)%Z -> + (0 <= Z.of_N (@Word.wordToN n w) < 2 ^ m)%Z. +Admitted. + +(* TODO : move to ZUtil *) +Lemma pow2_mod_range : forall a n m, + (n <= m)%Z -> + (0 <= ZUtil.Z.pow2_mod a n < 2 ^ m)%Z. +Admitted. + +(* TODO : move to ZUtil *) +Lemma shiftr_range : forall a n m, + (0 <= a < 2 ^ (n + m))%Z -> + (0 <= Z.shiftr a n < 2 ^ m)%Z. +Admitted. + +Lemma feDec_correct : forall w : Word.word (pred b), + option_eq GF25519BoundedCommon.eq + (option_map GF25519BoundedCommon.encode + (PointEncoding.Fdecode w)) (feDec w). +Proof. + intros; cbv [PointEncoding.Fdecode feDec]. + Print GF25519BoundedCommon.eq. + rewrite <-GF25519BoundedCommon.word64eqb_Zeqb. + rewrite GF25519Bounded.ge_modulus_correct. + rewrite GF25519BoundedCommon.word64ToZ_ZToWord64 by + (rewrite GF25519BoundedCommon.unfold_Pow2_64; + cbv [GF25519BoundedCommon.Pow2_64]; omega). + rewrite GF25519.ge_modulus_correct. + rewrite ModularBaseSystemOpt.ge_modulus_opt_correct. + match goal with + |- appcontext [GF25519Bounded.unpack ?x] => + assert ((Z.of_N (Word.wordToN w)) = BaseSystem.decode (Pow2Base.base_from_limb_widths PseudoMersenneBaseParams.limb_widths) (Tuple.to_list 10 (GF25519BoundedCommon.proj1_fe25519 (GF25519Bounded.unpack x)))) end. + { + rewrite GF25519Bounded.unpack_correct. + rewrite GF25519.unpack_correct, ModularBaseSystemOpt.unpack_correct. + + cbv [GF25519BoundedCommon.proj1_wire_digits + GF25519BoundedCommon.wire_digitsWToZ + GF25519BoundedCommon.proj1_wire_digitsW + GF25519BoundedCommon.app_wire_digits + HList.mapt HList.mapt' + length GF25519.wire_widths + fst snd + ]. + + cbv [GF25519BoundedCommon.proj_word + GF25519BoundedCommon.word31_to_unbounded_word + GF25519BoundedCommon.word32_to_unbounded_word + GF25519BoundedCommon.word_to_unbounded_word + GF25519BoundedCommon.Build_bounded_word + GF25519BoundedCommon.Build_bounded_word' + ]. + rewrite !GF25519BoundedCommon.word64ToZ_ZToWord64 by + (rewrite GF25519BoundedCommon.unfold_Pow2_64; + cbv [GF25519BoundedCommon.Pow2_64]; + apply WordNZ_range; cbv; congruence). + rewrite !WordNZ_split1. + rewrite !WordNZ_split2. + simpl Z.of_nat. + cbv [ModularBaseSystem.eq]. + match goal with + |- appcontext [@ModularBaseSystem.unpack _ _ ?ls _ _ ?t] => + assert (Pow2Base.bounded ls (Tuple.to_list (length ls) t)) end. + { cbv [Pow2Base.bounded length]. + intros. + destruct (lt_dec i 8). + { cbv [Tuple.to_list Tuple.to_list' fst snd]. + assert (i = 0 \/ i = 1 \/ i = 2 \/ i = 3 \/ i = 4 \/ i = 5 \/ i = 6 \/ i = 7) by omega. + repeat match goal with H : (_ \/ _)%type |- _ => destruct H; subst end; + cbv [nth_default nth_error value]; try (apply pow2_mod_range; omega). + repeat apply shiftr_range; apply WordNZ_range_mono; cbv; + congruence. } + { rewrite !nth_default_out_of_bounds + by (rewrite ?Tuple.length_to_list; cbv [length]; omega). + rewrite Z.pow_0_r. omega. } } + cbv [ModularBaseSystem.unpack ModularBaseSystemList.unpack]. + rewrite Tuple.to_list_from_list. + rewrite <-Conversion.convert_correct by (auto || rewrite Tuple.to_list; reflexivity). + rewrite <-Pow2BaseProofs.decode_bitwise_spec by (auto || cbv [In]; intuition omega). + cbv [Tuple.to_list Tuple.to_list' length fst snd Pow2Base.decode_bitwise Pow2Base.decode_bitwise' nth_default nth_error ]. + clear. + apply Z.bits_inj'. + intros. + rewrite Z.shiftl_0_l. + rewrite Z.lor_0_r. + repeat match goal with |- appcontext[@Word.wordToN (?x + ?y) w] => + change (@Word.wordToN (x + y) w) with (@Word.wordToN (pred b) w) end. + assert ( + 0 <= n < 32 \/ + 32 <= n < 64 \/ + 64 <= n < 96 \/ + 96 <= n < 128 \/ + 128 <= n < 160 \/ + 160 <= n < 192 \/ + 192 <= n < 224 \/ + 224 <= n < 256 \/ + 256 <= n)%Z by omega. + repeat match goal with H : (_ \/ _)%type |- _ => destruct H; subst end; + repeat match goal with + | |- _ => rewrite Z.lor_spec + | |- _ => rewrite Z.shiftl_spec by omega + | |- _ => rewrite Z.shiftr_spec by omega + | |- _ => rewrite Z.testbit_neg_r by omega + | |- _ => rewrite ZUtil.Z.testbit_pow2_mod by omega; + VerdiTactics.break_if; try omega + end; + repeat match goal with + | |- _ = (false || _)%bool => rewrite Bool.orb_false_l + | |- ?x = (?x || ?y)%bool => replace y with false; + [ rewrite Bool.orb_false_r; reflexivity | ] + | |- false = (?x || ?y)%bool => replace y with false; + [ rewrite Bool.orb_false_r; + replace x with false; [ reflexivity | ] + | ] + | |- false = Z.testbit _ _ => + rewrite Z.testbit_neg_r by omega; reflexivity + | |- Z.testbit ?w ?n = Z.testbit ?w ?m => + replace m with n by omega; reflexivity + | |- Z.testbit ?w ?n = (Z.testbit ?w ?m || _)%bool => + replace m with n by omega + end; + admit. (* TODO(jadep): there are goal left here on 8.4 *) + } + match goal with + |- option_eq _ (option_map _ (if Z_lt_dec ?a ?b then Some _ else None)) (if (?X =? 1)%Z then None else Some _) => + assert ((a < b)%Z <-> X = 0%Z) end. + { + rewrite ModularBaseSystemListProofs.ge_modulus_spec; + [ | cbv; congruence | rewrite Tuple.length_to_list; reflexivity | ]. + Focus 2. { + rewrite GF25519Bounded.unpack_correct. + rewrite GF25519.unpack_correct, ModularBaseSystemOpt.unpack_correct. + cbv [ModularBaseSystem.unpack]. + rewrite Tuple.to_list_from_list. + cbv [ModularBaseSystemList.unpack]. + apply Conversion.convert_bounded. + } Unfocus. + rewrite <-H0. + intuition; try omega. + apply Znat.N2Z.is_nonneg. + } + + do 2 VerdiTactics.break_if; + [ + match goal with H: ?P, Hiff : ?P <-> ?x = 0%Z |- _ => + let A := fresh "H" in + pose proof ((proj1 Hiff) H) as A; + rewrite A in *; discriminate + end + | | reflexivity | + match goal with + H: ~ ?P, Hiff : ?P <-> ModularBaseSystemList.ge_modulus ?x = 0%Z + |- _ => + exfalso; apply H; apply Hiff; + destruct (ModularBaseSystemListProofs.ge_modulus_01 x) as [Hgm | Hgm]; + rewrite Hgm in *; try discriminate; reflexivity + end ]. + + cbv [option_map option_eq]. + cbv [GF25519BoundedCommon.eq]. + rewrite GF25519BoundedCommon.proj1_fe25519_encode. + cbv [ModularBaseSystem.eq]. + etransitivity. + Focus 2. { + cbv [ModularBaseSystem.decode ModularBaseSystemList.decode]. + cbv [length PseudoMersenneBaseParams.limb_widths GF25519.params25519] in H0 |- *. + rewrite <-H0. + reflexivity. } Unfocus. + apply ModularBaseSystemProofs.encode_rep. + +Qed. + +Lemma Fsqrt_minus1_correct : + ModularArithmetic.F.mul Fsqrt_minus1 Fsqrt_minus1 = + ModularArithmetic.F.opp + (ModularArithmetic.F.of_Z GF25519.modulus 1). +Proof. + replace (Fsqrt_minus1) with (ModularBaseSystem.decode (GF25519.sqrt_m1)) by reflexivity. + rewrite <-ModularBaseSystemProofs.carry_mul_rep by reflexivity. + rewrite <-ModularBaseSystemOpt.carry_mul_opt_correct + with (k_ := GF25519.k_) (c_ := GF25519.c_) by reflexivity. + rewrite <-GF25519.mul_correct. + apply GF25519.sqrt_m1_correct. +Qed. + +Section bounded_by_from_is_bounded. + Local Arguments Z.sub !_ !_. + Local Arguments Z.pow_pos !_ !_ / . + Lemma bounded_by_from_is_bounded + : forall x, GF25519BoundedCommon.is_bounded x = true + -> ModularBaseSystemProofs.bounded_by + x + (ModularBaseSystemProofs.freeze_input_bounds (B := GF25519.freeze_input_bound)). + Proof. + intros x H. + pose proof (GF25519BoundedCommon.is_bounded_to_nth_default _ H) as H'; clear H. + unfold ModularBaseSystemProofs.bounded_by. + intros n pf; specialize (H' n pf). + match goal with + | [ H : (0 <= ?y <= _)%Z |- (0 <= ?x < _)%Z ] + => change y with x in H; generalize dependent x + end. + intros ? H'. + split; [ omega | ]. + eapply Z.le_lt_trans; [ exact (proj2 H') | ]. + unfold ModularBaseSystemProofs.freeze_input_bounds, nth_default, GF25519.freeze_input_bound; simpl in *. + repeat match goal with + | [ |- context[nth_error _ ?n] ] + => is_var n; destruct n; simpl + end; + try (vm_compute; reflexivity); + try omega. + Qed. +End bounded_by_from_is_bounded. + +Lemma bounded_by_encode_freeze : forall x, + ModularBaseSystemProofs.bounded_by + (ModularBaseSystem.encode x) + (ModularBaseSystemProofs.freeze_input_bounds (B := GF25519.freeze_input_bound)). +Proof. + intros; apply bounded_by_from_is_bounded, GF25519BoundedCommon.encode_bounded. +Qed. + +Lemma bounded_by_freeze : forall x, + ModularBaseSystemProofs.bounded_by + (GF25519BoundedCommon.fe25519WToZ (GF25519BoundedCommon.proj1_fe25519W x)) + (ModularBaseSystemProofs.freeze_input_bounds (B := GF25519.freeze_input_bound)). +Proof. + intros; apply bounded_by_from_is_bounded, GF25519BoundedCommon.is_bounded_proj1_fe25519. +Qed. + +Local Ltac prove_bounded_by := + repeat match goal with + | [ |- ModularBaseSystemProofs.bounded_by _ _ ] + => apply bounded_by_from_is_bounded + | [ |- GF25519BoundedCommon.is_bounded + (GF25519BoundedCommon.fe25519WToZ + (GF25519Bounded.mulW _ _)) = true ] + => apply GF25519Bounded.mulW_correct_and_bounded + | [ |- GF25519BoundedCommon.is_bounded + (GF25519BoundedCommon.fe25519WToZ + (GF25519Bounded.powW _ _)) = true ] + => apply GF25519Bounded.powW_correct_and_bounded + | [ |- context[GF25519BoundedCommon.fe25519WToZ (GF25519BoundedCommon.fe25519ZToW _)] ] + => rewrite GF25519BoundedCommon.fe25519WToZ_ZToW + | [ |- GF25519BoundedCommon.is_bounded (ModularBaseSystem.encode _) = true ] + => apply GF25519BoundedCommon.encode_bounded + end. + +Lemma sqrt_correct : forall x : ModularArithmetic.F.F q, + GF25519BoundedCommon.eq + (GF25519BoundedCommon.encode + (PrimeFieldTheorems.F.sqrt_5mod8 Fsqrt_minus1 x)) + (GF25519Bounded.sqrt (GF25519BoundedCommon.encode x)). +Proof. + intros. + cbv [GF25519BoundedCommon.eq]. + rewrite GF25519Bounded.sqrt_correct. + cbv [GF25519Bounded.GF25519sqrt]. + cbv [LetIn.Let_In]. + repeat match goal with (* needed on Coq 8.4, should be the only default everywhere *) + |- context[GF25519BoundedCommon.proj1_fe25519 (GF25519BoundedCommon.encode ?x)] => + rewrite (GF25519BoundedCommon.proj1_fe25519_encode x) + end. + rewrite GF25519.sqrt_correct, ModularBaseSystemOpt.sqrt_5mod8_opt_correct by reflexivity. + cbv [ModularBaseSystem.eq]. + rewrite ModularBaseSystemProofs.encode_rep. + symmetry. + eapply @ModularBaseSystemProofs.sqrt_5mod8_correct; + eauto using GF25519.freezePreconditions25519, ModularBaseSystemProofs.encode_rep, bounded_by_freeze, bounded_by_encode_freeze; + prove_bounded_by; + match goal with + | |- appcontext[GF25519Bounded.powW ?a ?ch] => + let A := fresh "H" in + destruct (GF25519Bounded.powW_correct_and_bounded ch a) as [A ?]; + [ rewrite GF25519BoundedCommon.fe25519WToZ_ZToW; + rewrite <-GF25519BoundedCommon.proj1_fe25519_encode; + apply GF25519BoundedCommon.is_bounded_proj1_fe25519 + | rewrite A; + rewrite GF25519.pow_correct, ModularBaseSystemOpt.pow_opt_correct + by reflexivity] + end;[ solve [f_equiv; apply GF25519BoundedCommon.fe25519WToZ_ZToW; + rewrite <-GF25519BoundedCommon.proj1_fe25519_encode; + apply GF25519BoundedCommon.is_bounded_proj1_fe25519] | ]. + match goal with + | |- appcontext[GF25519Bounded.mulW ?a ?b] => + let A := fresh "H" in + destruct (GF25519Bounded.mulW_correct_and_bounded a b) as [A ?]; + [ auto | auto | rewrite A] + end. + rewrite GF25519.mul_correct, ModularBaseSystemOpt.carry_mul_opt_correct by reflexivity. + rewrite !H0. + rewrite GF25519.pow_correct. + cbv [ModularBaseSystem.eq]. + rewrite ModularBaseSystemProofs.carry_mul_rep by reflexivity. + rewrite ModularBaseSystemProofs.mul_rep by reflexivity. + apply f_equal2; + rewrite ModularBaseSystemOpt.pow_opt_correct; reflexivity. +Qed. + +Lemma ERepDec_correct : forall w : Word.word b, + option_eq ExtendedCoordinates.Extended.eq (ERepDec w) (@option_map E Erep EToRep (Edec w)). +Proof. + exact (@PointEncoding.Kdecode_point_correct + (pred b) _ Spec.Ed25519.a Spec.Ed25519.d _ + GF25519.modulus_gt_2 bound_check255 + _ _ _ _ _ _ _ _ _ _ GF25519Bounded.field25519 + _ _ _ _ _ phi_a phi_d feSign feSign_correct _ + (ExtendedCoordinates.Extended.from_twisted + (field := GF25519Bounded.field25519) + (prm := twedprm_ERep)) + extended_to_coord + extended_to_coord_from_twisted + _ ext_eq_correct _ _ encode_eq_iff + feDec GF25519Bounded.sqrt _ _ feDec_correct + (@PrimeFieldTheorems.F.sqrt_5mod8 _ Fsqrt_minus1) + sqrt_correct + ). +Qed. -Axiom eq_enc_E_iff : forall (P_ : Word.word b) (P : E), +Lemma eq_enc_E_iff : forall (P_ : Word.word b) (P : E), Eenc P = P_ <-> Option.option_eq CompleteEdwardsCurveTheorems.E.eq (Edec P_) (Some P). +Proof. + cbv [Eenc]. + eapply (@PointEncoding.encode_point_decode_point_iff (b-1)); try (exact iff_equivalence || exact curve_params); []. + intros. + apply (@PrimeFieldTheorems.F.sqrt_5mod8_correct GF25519.modulus _ eq_refl Fsqrt_minus1 Fsqrt_minus1_correct). + eexists. + symmetry; eassumption. +Qed. + +Definition verify := @verify E b H B Erep ErepAdd + (@ExtendedCoordinates.Extended.opp GF25519BoundedCommon.fe25519 + GF25519BoundedCommon.eq GF25519BoundedCommon.zero + GF25519BoundedCommon.one GF25519Bounded.opp GF25519Bounded.add + GF25519Bounded.sub GF25519Bounded.mul GF25519Bounded.inv + GF25519BoundedCommon.div a d GF25519Bounded.field25519 twedprm_ERep + (fun x y : GF25519BoundedCommon.fe25519 => + @ModularArithmeticTheorems.F.eq_dec GF25519.modulus + (@ModularBaseSystem.decode GF25519.modulus GF25519.params25519 + (GF25519BoundedCommon.proj1_fe25519 x)) + (@ModularBaseSystem.decode GF25519.modulus GF25519.params25519 + (GF25519BoundedCommon.proj1_fe25519 y)))) EToRep ERepEnc ERepDec + SRep SC25519.SRepDecModL SRepERepMul SRepDec. Let verify_correct : forall {mlen : nat} (msg : Word.word mlen) (pk : Word.word b) - (sig : Word.word (b + b)), verify msg pk sig = true <-> EdDSA.valid msg pk sig := + (sig : Word.word (b + b)), verify _ msg pk sig = true <-> EdDSA.valid msg pk sig := @verify_correct (* E := *) E (* Eeq := *) CompleteEdwardsCurveTheorems.E.eq @@ -883,8 +1449,8 @@ Let verify_correct : (* B := *) B (* Eenc := *) Eenc (* Senc := *) Senc - (* prm := *) ed25519 - (* Proper_Eenc := *) PointEncoding.Proper_encode_point + (* prm := *) (ed25519 B_order_l) + (* Proper_Eenc := *) (PointEncoding.Proper_encode_point (b:=b-1)) (* Edec := *) Edec (* eq_enc_E_iff := *) eq_enc_E_iff (* Sdec := *) Sdec @@ -914,327 +1480,60 @@ Let verify_correct : (* SRepDec := *) SRepDec (* SRepDec_correct := *) SRepDec_correct . -Let both_correct := (@sign_correct, @verify_correct). -Print Assumptions both_correct. - - - - -(*** Extraction *) - - - - -Extraction Language Haskell. -Unset Extraction KeepSingleton. -Set Extraction AutoInline. -Set Extraction Optimize. -Unset Extraction AccessOpaque. - -(** Eq *) - -Extraction Implicit eq_rect [ x y ]. -Extraction Implicit eq_rect_r [ x y ]. -Extraction Implicit eq_rec [ x y ]. -Extraction Implicit eq_rec_r [ x y ]. - -Extract Inlined Constant eq_rect => "". -Extract Inlined Constant eq_rect_r => "". -Extract Inlined Constant eq_rec => "". -Extract Inlined Constant eq_rec_r => "". - -(** Ord *) - -Extract Inductive comparison => - "Prelude.Ordering" ["Prelude.EQ" "Prelude.LT" "Prelude.GT"]. - -(** Bool, sumbool, Decidable *) - -Extract Inductive bool => "Prelude.Bool" ["Prelude.True" "Prelude.False"]. -Extract Inductive sumbool => "Prelude.Bool" ["Prelude.True" "Prelude.False"]. -Extract Inductive Bool.reflect => "Prelude.Bool" ["Prelude.True" "Prelude.False"]. -Extract Inlined Constant Bool.iff_reflect => "". -Extraction Inline Crypto.Util.Decidable.Decidable Crypto.Util.Decidable.dec. - -(* Extract Inlined Constant Equality.bool_beq => *) -(* "((Prelude.==) :: Prelude.Bool -> Prelude.Bool -> Prelude.Bool)". *) -Extract Inlined Constant Bool.bool_dec => - "((Prelude.==) :: Prelude.Bool -> Prelude.Bool -> Prelude.Bool)". -Extract Inlined Constant Sumbool.sumbool_of_bool => "". - -Extract Inlined Constant negb => "Prelude.not". -Extract Inlined Constant orb => "(Prelude.||)". -Extract Inlined Constant andb => "(Prelude.&&)". -Extract Inlined Constant xorb => "Data.Bits.xor". - -(** Comparisons *) - -Extract Inductive comparison => "Prelude.Ordering" [ "Prelude.EQ" "Prelude.LT" "Prelude.GT" ]. -Extract Inductive CompareSpecT => "Prelude.Ordering" [ "Prelude.EQ" "Prelude.LT" "Prelude.GT" ]. - -(** Maybe *) - -Extract Inductive option => "Prelude.Maybe" ["Prelude.Just" "Prelude.Nothing"]. -Extract Inductive sumor => "Prelude.Maybe" ["Prelude.Just" "Prelude.Nothing"]. - -(** Either *) - -Extract Inductive sum => "Prelude.Either" ["Prelude.Left" "Prelude.Right"]. - -(** List *) - -Extract Inductive list => "[]" ["[]" "(:)"]. - -Extract Inlined Constant app => "(Prelude.++)". -Extract Inlined Constant List.map => "Prelude.map". -Extract Constant List.fold_left => "\f l z -> Data.List.foldl f z l". -Extract Inlined Constant List.fold_right => "Data.List.foldr". -Extract Inlined Constant List.find => "Data.List.find". -Extract Inlined Constant List.length => "Data.List.genericLength". - -(** Tuple *) - -Extract Inductive prod => "(,)" ["(,)"]. -Extract Inductive sigT => "(,)" ["(,)"]. - -Extract Inlined Constant fst => "Prelude.fst". -Extract Inlined Constant snd => "Prelude.snd". -Extract Inlined Constant projT1 => "Prelude.fst". -Extract Inlined Constant projT2 => "Prelude.snd". - -Extract Inlined Constant proj1_sig => "". - -(** Unit *) - -Extract Inductive unit => "()" ["()"]. +Lemma Fhomom_inv_zero : + GF25519BoundedCommon.eq + (GF25519BoundedCommon.encode + (@ModularArithmetic.F.inv GF25519.modulus + (ModularArithmetic.F.of_Z GF25519.modulus 0))) + (GF25519Bounded.inv GF25519BoundedCommon.zero). +Proof. + vm_decide_no_check. +Qed. -(** nat *) - -Require Import Crypto.Experiments.ExtrHaskellNats. - -(** positive *) -Require Import BinPos. - -Extract Inductive positive => "Prelude.Integer" [ - "(\x -> 2 Prelude.* x Prelude.+ 1)" - "(\x -> 2 Prelude.* x)" - "1" ] - "(\fI fO fH n -> {- match_on_positive -} - if n Prelude.== 1 then fH () else - if Prelude.odd n - then fI (n `Prelude.div` 2) - else fO (n `Prelude.div` 2))". - -Extract Inlined Constant Pos.succ => "(1 Prelude.+)". -Extract Inlined Constant Pos.add => "(Prelude.+)". -Extract Inlined Constant Pos.mul => "(Prelude.*)". -Extract Inlined Constant Pos.pow => "(Prelude.^)". -Extract Inlined Constant Pos.max => "Prelude.max". -Extract Inlined Constant Pos.min => "Prelude.min". -Extract Inlined Constant Pos.gcd => "Prelude.gcd". -Extract Inlined Constant Pos.land => "(Data.Bits..&.)". -Extract Inlined Constant Pos.lor => "(Data.Bits..|.)". -Extract Inlined Constant Pos.compare => "Prelude.compare". -Extract Inlined Constant Pos.ltb => "(Prelude.<)". -Extract Inlined Constant Pos.leb => "(Prelude.<=)". -Extract Inlined Constant Pos.eq_dec => "(Prelude.==)". -Extract Inlined Constant Pos.eqb => "(Prelude.==)". - -(* XXX: unsound -- overflow in fromIntegral *) -Extract Constant Pos.shiftr => "(\w n -> Data.Bits.shiftR w (Prelude.fromIntegral n))". -Extract Constant Pos.shiftl => "(\w n -> Data.Bits.shiftL w (Prelude.fromIntegral n))". -Extract Constant Pos.testbit => "(\w n -> Data.Bits.testBit w (Prelude.fromIntegral n))". - -Extract Constant Pos.pred => "(\n -> Prelude.max 1 (Prelude.pred n))". -Extract Constant Pos.sub => "(\n m -> Prelude.max 1 (n Prelude.- m))". - -(** N *) - -Extract Inlined Constant N.succ => "(1 Prelude.+)". -Extract Inlined Constant N.add => "(Prelude.+)". -Extract Inlined Constant N.mul => "(Prelude.*)". -Extract Inlined Constant N.pow => "(Prelude.^)". -Extract Inlined Constant N.max => "Prelude.max". -Extract Inlined Constant N.min => "Prelude.min". -Extract Inlined Constant N.gcd => "Prelude.gcd". -Extract Inlined Constant N.lcm => "Prelude.lcm". -Extract Inlined Constant N.land => "(Data.Bits..&.)". -Extract Inlined Constant N.lor => "(Data.Bits..|.)". -Extract Inlined Constant N.lxor => "Data.Bits.xor". -Extract Inlined Constant N.compare => "Prelude.compare". -Extract Inlined Constant N.eq_dec => "(Prelude.==)". -Extract Inlined Constant N.ltb => "(Prelude.<)". -Extract Inlined Constant N.leb => "(Prelude.<=)". -Extract Inlined Constant N.eq_dec => "(Prelude.==)". -Extract Inlined Constant N.odd => "Prelude.odd". -Extract Inlined Constant N.even => "Prelude.even". - -(* XXX: unsound -- overflow in fromIntegral *) -Extract Constant N.shiftr => "(\w n -> Data.Bits.shiftR w (Prelude.fromIntegral n))". -Extract Constant N.shiftl => "(\w n -> Data.Bits.shiftL w (Prelude.fromIntegral n))". -Extract Constant N.testbit => "(\w n -> Data.Bits.testBit w (Prelude.fromIntegral n))". - -Extract Constant N.pred => "(\n -> Prelude.max 0 (Prelude.pred n))". -Extract Constant N.sub => "(\n m -> Prelude.max 0 (n Prelude.- m))". -Extract Constant N.div => "(\n m -> if m Prelude.== 0 then 0 else Prelude.div n m)". -Extract Constant N.modulo => "(\n m -> if m Prelude.== 0 then 0 else Prelude.mod n m)". - -Extract Inductive N => "Prelude.Integer" [ "0" "(\x -> x)" ] - "(\fO fS n -> {- match_on_N -} if n Prelude.== 0 then fO () else fS (n Prelude.- 1))". - -(** Z *) -Require Import ZArith.BinInt. - -Extract Inductive Z => "Prelude.Integer" [ "0" "(\x -> x)" "Prelude.negate" ] - "(\fO fP fN n -> {- match_on_Z -} - if n Prelude.== 0 then fO () else - if n Prelude.> 0 then fP n else - fN (Prelude.negate n))". - -Extract Inlined Constant Z.succ => "(1 Prelude.+)". -Extract Inlined Constant Z.add => "(Prelude.+)". -Extract Inlined Constant Z.sub => "(Prelude.-)". -Extract Inlined Constant Z.opp => "Prelude.negate". -Extract Inlined Constant Z.mul => "(Prelude.*)". -Extract Inlined Constant Z.pow => "(Prelude.^)". -Extract Inlined Constant Z.pow_pos => "(Prelude.^)". -Extract Inlined Constant Z.max => "Prelude.max". -Extract Inlined Constant Z.min => "Prelude.min". -Extract Inlined Constant Z.lcm => "Prelude.lcm". -Extract Inlined Constant Z.land => "(Data.Bits..&.)". -Extract Inlined Constant Z.pred => "Prelude.pred". -Extract Inlined Constant Z.land => "(Data.Bits..&.)". -Extract Inlined Constant Z.lor => "(Data.Bits..|.)". -Extract Inlined Constant Z.lxor => "Data.Bits.xor". -Extract Inlined Constant Z.compare => "Prelude.compare". -Extract Inlined Constant Z.eq_dec => "(Prelude.==)". -Extract Inlined Constant Z_ge_lt_dec => "(Prelude.>=)". -Extract Inlined Constant Z_gt_le_dec => "(Prelude.>)". -Extract Inlined Constant Z.ltb => "(Prelude.<)". -Extract Inlined Constant Z.leb => "(Prelude.<=)". -Extract Inlined Constant Z.gtb => "(Prelude.>)". -Extract Inlined Constant Z.geb => "(Prelude.>=)". -Extract Inlined Constant Z.odd => "Prelude.odd". -Extract Inlined Constant Z.even => "Prelude.even". - -(* XXX: unsound -- overflow in fromIntegral *) -Extract Constant Z.shiftr => "(\w n -> Data.Bits.shiftR w (Prelude.fromIntegral n))". -Extract Constant Z.shiftl => "(\w n -> Data.Bits.shiftL w (Prelude.fromIntegral n))". -Extract Constant Z.testbit => "(\w n -> Data.Bits.testBit w (Prelude.fromIntegral n))". - -Extract Constant Z.div => "(\n m -> if m Prelude.== 0 then 0 else Prelude.div n m)". -Extract Constant Z.modulo => "(\n m -> if m Prelude.== 0 then 0 else Prelude.mod n m)". - -(** Conversions *) - -Extract Inlined Constant Z.of_N => "". -Extract Inlined Constant Z.to_N => "". -Extract Inlined Constant N.to_nat => "". -Extract Inlined Constant N.of_nat => "". -Extract Inlined Constant Z.to_nat => "". -Extract Inlined Constant Z.of_nat => "". -Extract Inlined Constant Z.abs_N => "Prelude.abs". -Extract Inlined Constant Z.abs_nat => "Prelude.abs". -Extract Inlined Constant Pos.pred_N => "Prelude.pred". -Extract Inlined Constant Pos.lxor => "Data.Bits.xor". - -(** Word *) -(* do not annotate every bit of a word with the number of bits after it *) -Extraction Implicit Word.WS [ 2 ]. -Extraction Implicit Word.whd [ 1 ]. -Extraction Implicit Word.wtl [ 1 ]. -Extraction Implicit Word.bitwp [ 2 ]. -Extraction Implicit Word.wand [ 1 ]. -Extraction Implicit Word.wor [ 1 ]. -Extraction Implicit Word.wxor [ 1 ]. -Extraction Implicit Word.wordToN [ 1 ]. -Extraction Implicit Word.wordToNat [ 1 ]. -Extraction Implicit Word.combine [ 1 3 ]. -Extraction Implicit Word.split1 [ 2 ]. -Extraction Implicit Word.split2 [ 2 ]. -Extraction Implicit WordUtil.cast_word [1 2 3]. -Extraction Implicit WordUtil.wfirstn [ 2 4 ]. -Extract Inlined Constant WordUtil.cast_word => "". - -(** Let_In *) -Extraction Inline LetIn.Let_In. - -(* inlining, primarily to reduce polymorphism *) -Extraction Inline dec_eq_Z dec_eq_N dec_eq_sig_hprop. -Extraction Inline Erep SRep ZNWord WordNZ. -Extraction Inline GF25519BoundedCommon.fe25519. -Extraction Inline EdDSARepChange.sign EdDSARepChange.splitSecretPrngCurve. -Extraction Inline Crypto.Util.IterAssocOp.iter_op Crypto.Util.IterAssocOp.test_and_op. -Extraction Inline PointEncoding.Kencode_point. -Extraction Inline ExtendedCoordinates.Extended.point ExtendedCoordinates.Extended.coordinates ExtendedCoordinates.Extended.to_twisted ExtendedCoordinates.Extended.from_twisted ExtendedCoordinates.Extended.add_coordinates ExtendedCoordinates.Extended.add ExtendedCoordinates.Extended.opp ExtendedCoordinates.Extended.zero. (* ExtendedCoordinates.Extended.zero could be precomputed *) -Extraction Inline CompleteEdwardsCurve.E.coordinates CompleteEdwardsCurve.E.zero. - -(* Recursive Extraction sign. *) - (* most of the code we want seems to be below [eq_dec1] and there is other stuff above that *) - (* TODO: remove branching from [sRep] functions *) - -(* fragment of output: - -sign :: Word -> Word -> Prelude.Integer -> Word -> Word -sign pk sk mlen msg = - let { - sp = let {hsk = h b sk} in - (,) - (sRepDecModLShort - (combine n (clearlow n c (wfirstn n ((Prelude.+) b b) hsk)) (Prelude.succ 0) - (wones (Prelude.succ 0)))) (split2 b b hsk)} - in - let {r = sRepDecModL (h ((Prelude.+) b mlen) (combine b (Prelude.snd sp) mlen msg))} in - let {r0 = sRepERepMul r eRepB} in - combine b (eRepEnc r0) b - (sRepEnc - (sRepAdd r - (sRepMul - (sRepDecModL - (h ((Prelude.+) b ((Prelude.+) b mlen)) - (combine b (eRepEnc r0) ((Prelude.+) b mlen) (combine b pk mlen msg)))) (Prelude.fst sp)))) - -sRepERepMul :: SRep0 -> Erep -> Erep -sRepERepMul sc a = - Prelude.snd - (funexp (\state -> - case state of { - (,) i acc -> - let {acc2 = erepAdd acc acc} in - let {acc2a = erepAdd a acc2} in - (\fO fS n -> {- match_on_nat -} if n Prelude.== 0 then fO () else fS (n Prelude.- 1)) - (\_ -> (,) 0 - acc) - (\i' -> (,) i' - (eRepSel ((\w n -> Data.Bits.testBit w (Prelude.fromIntegral n)) sc (of_nat i')) acc2 acc2a)) - i}) ((,) ll - (case ((,) zero_ one_) of { - (,) x y -> (,) ((,) ((,) x y) one_) (mul3 x y)})) ll) - -erepAdd :: (Point0 Fe25519) -> (Point0 Fe25519) -> Point0 Fe25519 -erepAdd p q = - case p of { - (,) y t1 -> - case y of { - (,) y0 z1 -> - case y0 of { - (,) x1 y1 -> - case q of { - (,) y2 t2 -> - case y2 of { - (,) y3 z2 -> - case y3 of { - (,) x2 y4 -> - let {a = mul3 (sub2 y1 x1) (sub2 y4 x2)} in - let {b0 = mul3 (add2 y1 x1) (add2 y4 x2)} in - let {c0 = mul3 (mul3 t1 twice_d) t2} in - let {d = mul3 z1 (add2 z2 z2)} in - let {e = sub2 b0 a} in - let {f = sub2 d c0} in - let {g = add2 d c0} in - let {h0 = add2 b0 a} in - let {x3 = mul3 e f} in - let {y5 = mul3 g h0} in - let {t3 = mul3 e h0} in let {z3 = mul3 f g} in (,) ((,) ((,) x3 y5) z3) t3}}}}}} -*) +Import ModularArithmetic. +Module Spec. + Module X25519. + Definition a : F q := F.of_Z _ 486662. + Definition a24 : F q := ((a - F.of_Z _ 2) / F.of_Z _ 4)%F. + End X25519. +End Spec. + +Section X25519Constants. + Import GF25519BoundedCommon. + Definition a24' : GF25519BoundedCommon.fe25519 := + Eval vm_compute in GF25519BoundedCommon.encode Spec.X25519.a24. + Definition a24 : GF25519BoundedCommon.fe25519 := + Eval cbv [a24' fe25519_word64ize word64ize andb opt.word64ToZ opt.word64ize opt.Zleb Z.compare CompOpp Pos.compare Pos.compare_cont] in (fe25519_word64ize a24'). + Lemma a24_correct : GF25519BoundedCommon.eq + (GF25519BoundedCommon.encode Spec.X25519.a24) + (a24). + Proof. vm_decide_no_check. Qed. +End X25519Constants. + +Definition x25519 (n:N) (x:GF25519BoundedCommon.fe25519) : GF25519BoundedCommon.fe25519 := + @MxDH.montladder GF25519BoundedCommon.fe25519 GF25519BoundedCommon.zero + GF25519BoundedCommon.one GF25519Bounded.add GF25519Bounded.sub + GF25519Bounded.mul GF25519Bounded.inv a24 + (fun (H : bool) + (H0 + H1 : GF25519BoundedCommon.fe25519 * GF25519BoundedCommon.fe25519) + => if H then (H1, H0) else (H0, H1)) 255 (N.testbit_nat n) x. + +Definition x25519_correct' n x : + GF25519BoundedCommon.eq + (GF25519BoundedCommon.encode (MxDH.montladder 255 (N.testbit_nat n) x)) + (MxDH.montladder 255 (N.testbit_nat n) (GF25519BoundedCommon.encode x)) := + MxDHRepChange + (field:=PrimeFieldTheorems.F.field_modulo GF25519.modulus) + (impl_field:=GF25519Bounded.field25519) + (homomorphism_inv_zero:=Fhomom_inv_zero) + (homomorphism_a24:=a24_correct) + (Fcswap_correct:= fun _ _ _ => (reflexivity _)) + (Kcswap_correct:= fun _ _ _ => (reflexivity _)) + (tb2_correct:=fun _ => (reflexivity _)) + 255 _. + +Print Assumptions x25519_correct'. +Let three_correct := (@sign_correct, @verify_correct, x25519_correct'). +Print Assumptions three_correct.
\ No newline at end of file diff --git a/src/Experiments/Ed25519Extraction.v b/src/Experiments/Ed25519Extraction.v new file mode 100644 index 000000000..20a76f17f --- /dev/null +++ b/src/Experiments/Ed25519Extraction.v @@ -0,0 +1,299 @@ +Require Import Crypto.Experiments.Ed25519. +Import Decidable BinNat BinInt ZArith_dec. + +Extraction Language Haskell. +Unset Extraction KeepSingleton. +Set Extraction AutoInline. +Set Extraction Optimize. +Unset Extraction AccessOpaque. + +(** Eq *) + +Extraction Implicit eq_rect [ x y ]. +Extraction Implicit eq_rect_r [ x y ]. +Extraction Implicit eq_rec [ x y ]. +Extraction Implicit eq_rec_r [ x y ]. + +Extract Inlined Constant eq_rect => "". +Extract Inlined Constant eq_rect_r => "". +Extract Inlined Constant eq_rec => "". +Extract Inlined Constant eq_rec_r => "". + +(** Ord *) + +Extract Inductive comparison => + "Prelude.Ordering" ["Prelude.EQ" "Prelude.LT" "Prelude.GT"]. + +(** Bool, sumbool, Decidable *) + +Extract Inductive bool => "Prelude.Bool" ["Prelude.True" "Prelude.False"]. +Extract Inductive sumbool => "Prelude.Bool" ["Prelude.True" "Prelude.False"]. +Extract Inductive Bool.reflect => "Prelude.Bool" ["Prelude.True" "Prelude.False"]. +Extract Inlined Constant Bool.iff_reflect => "". +Extraction Inline Crypto.Util.Decidable.Decidable Crypto.Util.Decidable.dec. + +(* Extract Inlined Constant Equality.bool_beq => *) +(* "((Prelude.==) :: Prelude.Bool -> Prelude.Bool -> Prelude.Bool)". *) +Extract Inlined Constant Bool.bool_dec => + "((Prelude.==) :: Prelude.Bool -> Prelude.Bool -> Prelude.Bool)". + +Extract Inlined Constant Sumbool.sumbool_of_bool => "". + +Extract Inlined Constant negb => "Prelude.not". +Extract Inlined Constant orb => "(Prelude.||)". +Extract Inlined Constant andb => "(Prelude.&&)". +Extract Inlined Constant xorb => "Data.Bits.xor". + +(** Comparisons *) + +Extract Inductive comparison => "Prelude.Ordering" [ "Prelude.EQ" "Prelude.LT" "Prelude.GT" ]. +Extract Inductive CompareSpecT => "Prelude.Ordering" [ "Prelude.EQ" "Prelude.LT" "Prelude.GT" ]. + +(** Maybe *) + +Extract Inductive option => "Prelude.Maybe" ["Prelude.Just" "Prelude.Nothing"]. +Extract Inductive sumor => "Prelude.Maybe" ["Prelude.Just" "Prelude.Nothing"]. + +(** Either *) + +Extract Inductive sum => "Prelude.Either" ["Prelude.Left" "Prelude.Right"]. + +(** List *) + +Extract Inductive list => "[]" ["[]" "(:)"]. + +Extract Inlined Constant app => "(Prelude.++)". +Extract Inlined Constant List.map => "Prelude.map". +Extract Constant List.fold_left => "\f l z -> Data.List.foldl f z l". +Extract Inlined Constant List.fold_right => "Data.List.foldr". +Extract Inlined Constant List.find => "Data.List.find". +Extract Inlined Constant List.length => "Data.List.genericLength". + +(** Tuple *) + +Extract Inductive prod => "(,)" ["(,)"]. +Extract Inductive sigT => "(,)" ["(,)"]. + +Extract Inlined Constant fst => "Prelude.fst". +Extract Inlined Constant snd => "Prelude.snd". +Extract Inlined Constant projT1 => "Prelude.fst". +Extract Inlined Constant projT2 => "Prelude.snd". + +Extract Inlined Constant proj1_sig => "". + +(** Unit *) + +Extract Inductive unit => "()" ["()"]. + +(** nat *) + +Require Import Crypto.Experiments.ExtrHaskellNats. + +(** positive *) +Require Import BinPos. + +Extract Inductive positive => "Prelude.Integer" [ + "(\x -> 2 Prelude.* x Prelude.+ 1)" + "(\x -> 2 Prelude.* x)" + "1" ] + "(\fI fO fH n -> {- match_on_positive -} + if n Prelude.== 1 then fH () else + if Prelude.odd n + then fI (n `Prelude.div` 2) + else fO (n `Prelude.div` 2))". + +Extract Inlined Constant Pos.succ => "(1 Prelude.+)". +Extract Inlined Constant Pos.add => "(Prelude.+)". +Extract Inlined Constant Pos.mul => "(Prelude.*)". +Extract Inlined Constant Pos.pow => "(Prelude.^)". +Extract Inlined Constant Pos.max => "Prelude.max". +Extract Inlined Constant Pos.min => "Prelude.min". +Extract Inlined Constant Pos.gcd => "Prelude.gcd". +Extract Inlined Constant Pos.land => "(Data.Bits..&.)". +Extract Inlined Constant Pos.lor => "(Data.Bits..|.)". +Extract Inlined Constant Pos.compare => "Prelude.compare". +Extract Inlined Constant Pos.ltb => "(Prelude.<)". +Extract Inlined Constant Pos.leb => "(Prelude.<=)". +Extract Inlined Constant Pos.eq_dec => "(Prelude.==)". +Extract Inlined Constant Pos.eqb => "(Prelude.==)". + +(* XXX: unsound -- overflow in fromIntegral *) +Extract Constant Pos.shiftr => "(\w n -> Data.Bits.shiftR w (Prelude.fromIntegral n))". +Extract Constant Pos.shiftl => "(\w n -> Data.Bits.shiftL w (Prelude.fromIntegral n))". +Extract Constant Pos.testbit => "(\w n -> Data.Bits.testBit w (Prelude.fromIntegral n))". + +Extract Constant Pos.pred => "(\n -> Prelude.max 1 (Prelude.pred n))". +Extract Constant Pos.sub => "(\n m -> Prelude.max 1 (n Prelude.- m))". + +(** N *) + +Extract Inlined Constant N.succ => "(1 Prelude.+)". +Extract Inlined Constant N.add => "(Prelude.+)". +Extract Inlined Constant N.mul => "(Prelude.*)". +Extract Inlined Constant N.pow => "(Prelude.^)". +Extract Inlined Constant N.max => "Prelude.max". +Extract Inlined Constant N.min => "Prelude.min". +Extract Inlined Constant N.gcd => "Prelude.gcd". +Extract Inlined Constant N.lcm => "Prelude.lcm". +Extract Inlined Constant N.land => "(Data.Bits..&.)". +Extract Inlined Constant N.lor => "(Data.Bits..|.)". +Extract Inlined Constant N.lxor => "Data.Bits.xor". +Extract Inlined Constant N.compare => "Prelude.compare". +Extract Inlined Constant N.eq_dec => "(Prelude.==)". +Extract Inlined Constant N.ltb => "(Prelude.<)". +Extract Inlined Constant N.leb => "(Prelude.<=)". +Extract Inlined Constant N.eq_dec => "(Prelude.==)". +Extract Inlined Constant N.odd => "Prelude.odd". +Extract Inlined Constant N.even => "Prelude.even". + +(* XXX: unsound -- overflow in fromIntegral *) +Extract Constant N.shiftr => "(\w n -> Data.Bits.shiftR w (Prelude.fromIntegral n))". +Extract Constant N.shiftl => "(\w n -> Data.Bits.shiftL w (Prelude.fromIntegral n))". +Extract Constant N.testbit => "(\w n -> Data.Bits.testBit w (Prelude.fromIntegral n))". +Extract Constant N.testbit_nat => "(\w n -> Data.Bits.testBit w (Prelude.fromIntegral n))". + +Extract Constant N.pred => "(\n -> Prelude.max 0 (Prelude.pred n))". +Extract Constant N.sub => "(\n m -> Prelude.max 0 (n Prelude.- m))". +Extract Constant N.div => "(\n m -> if m Prelude.== 0 then 0 else Prelude.div n m)". +Extract Constant N.modulo => "(\n m -> if m Prelude.== 0 then 0 else Prelude.mod n m)". + +Extract Inductive N => "Prelude.Integer" [ "0" "(\x -> x)" ] + "(\fO fS n -> {- match_on_N -} if n Prelude.== 0 then fO () else fS n)". + +(** Z *) +Require Import ZArith.BinInt. + +Extract Inductive Z => "Prelude.Integer" [ "0" "(\x -> x)" "Prelude.negate" ] + "(\fO fP fN n -> {- match_on_Z -} + if n Prelude.== 0 then fO () else + if n Prelude.> 0 then fP n else + fN (Prelude.negate n))". + +Extract Inlined Constant Z.succ => "(1 Prelude.+)". +Extract Inlined Constant Z.add => "(Prelude.+)". +Extract Inlined Constant Z.sub => "(Prelude.-)". +Extract Inlined Constant Z.opp => "Prelude.negate". +Extract Inlined Constant Z.mul => "(Prelude.*)". +Extract Inlined Constant Z.pow => "(Prelude.^)". +Extract Inlined Constant Z.pow_pos => "(Prelude.^)". +Extract Inlined Constant Z.max => "Prelude.max". +Extract Inlined Constant Z.min => "Prelude.min". +Extract Inlined Constant Z.lcm => "Prelude.lcm". +Extract Inlined Constant Z.land => "(Data.Bits..&.)". +Extract Inlined Constant Z.pred => "Prelude.pred". +Extract Inlined Constant Z.land => "(Data.Bits..&.)". +Extract Inlined Constant Z.lor => "(Data.Bits..|.)". +Extract Inlined Constant Z.lxor => "Data.Bits.xor". +Extract Inlined Constant Z.compare => "Prelude.compare". +Extract Inlined Constant Z.eq_dec => "(Prelude.==)". +Extract Inlined Constant Z_ge_lt_dec => "(Prelude.>=)". +Extract Inlined Constant Z_gt_le_dec => "(Prelude.>)". +Extract Inlined Constant Z.ltb => "(Prelude.<)". +Extract Inlined Constant Z.leb => "(Prelude.<=)". +Extract Inlined Constant Z.gtb => "(Prelude.>)". +Extract Inlined Constant Z.geb => "(Prelude.>=)". +Extract Inlined Constant Z.odd => "Prelude.odd". +Extract Inlined Constant Z.even => "Prelude.even". + +(* XXX: unsound -- overflow in fromIntegral *) +Extract Constant Z.shiftr => "(\w n -> Data.Bits.shiftR w (Prelude.fromIntegral n))". +Extract Constant Z.shiftl => "(\w n -> Data.Bits.shiftL w (Prelude.fromIntegral n))". +Extract Constant Z.testbit => "(\w n -> Data.Bits.testBit w (Prelude.fromIntegral n))". + +Extract Constant Z.div => "(\n m -> if m Prelude.== 0 then 0 else Prelude.div n m)". +Extract Constant Z.modulo => "(\n m -> if m Prelude.== 0 then 0 else Prelude.mod n m)". + +(** Conversions *) + +Extract Inlined Constant Z.of_N => "". +Extract Inlined Constant Z.to_N => "". +Extract Inlined Constant N.to_nat => "". +Extract Inlined Constant N.of_nat => "". +Extract Inlined Constant Z.to_nat => "". +Extract Inlined Constant Z.of_nat => "". +Extract Inlined Constant Z.abs_N => "Prelude.abs". +Extract Inlined Constant Z.abs_nat => "Prelude.abs". +Extract Inlined Constant Pos.pred_N => "Prelude.pred". +Extract Inlined Constant Pos.lxor => "Data.Bits.xor". + +(** Word *) +(* do not annotate every bit of a word with the number of bits after it *) +Extraction Implicit Word.WS [ 2 ]. +Extraction Implicit Word.weqb [ 1 ]. +Extraction Implicit Word.whd [ 1 ]. +Extraction Implicit Word.wtl [ 1 ]. +Extraction Implicit Word.bitwp [ 2 ]. +Extraction Implicit Word.wand [ 1 ]. +Extraction Implicit Word.wor [ 1 ]. +Extraction Implicit Word.wxor [ 1 ]. +Extraction Implicit Word.wordToN [ 1 ]. +Extraction Implicit Word.wordToNat [ 1 ]. +Extraction Implicit Word.combine [ 1 3 ]. +Extraction Implicit Word.split1 [ 2 ]. +Extraction Implicit Word.split2 [ 2 ]. +Extraction Implicit WordUtil.cast_word [1 2 3]. +Extraction Implicit WordUtil.wfirstn [ 2 4 ]. +Extract Inlined Constant WordUtil.cast_word => "". +Extract Inductive Word.word => "[Prelude.Bool]" [ "[]" "(:)" ] + "(\fWO fWS w -> {- match_on_word -} case w of {[] -> fWO (); (b:w') -> fWS b w' } )". + +(** Let_In *) +Extraction Inline LetIn.Let_In. + +(* Word64 *) +Import Crypto.Reflection.Z.Interpretations. +Extract Inlined Constant Word64.word64 => "Data.Word.Word64". +Extract Inlined Constant GF25519BoundedCommon.word64 => "Data.Word.Word64". +Extract Inlined Constant GF25519BoundedCommon.w64eqb => "(Prelude.==)". +Extract Inlined Constant Word64.word64ToZ => "Prelude.fromIntegral". +Extract Inlined Constant GF25519BoundedCommon.word64ToZ => "Prelude.fromIntegral". +Extract Inlined Constant GF25519BoundedCommon.NToWord64 => "Prelude.fromIntegral". +Extract Inlined Constant GF25519BoundedCommon.ZToWord64 => "Prelude.fromIntegral". +Extract Inlined Constant Word64.add => "(Prelude.+)". +Extract Inlined Constant Word64.mul => "(Prelude.*)". +Extract Inlined Constant Word64.sub => "(Prelude.-)". +Extract Inlined Constant Word64.land => "(Data.Bits..&.)". +Extract Inlined Constant Word64.lor => "(Data.Bits..|.)". +Extract Constant Word64.neg => "(\_ w -> Prelude.negate w)". (* FIXME: reification: drop arg1 *) +Extract Constant Word64.shr => "(\w n -> Data.Bits.shiftR w (Prelude.fromIntegral n))". +Extract Constant Word64.shl => "(\w n -> Data.Bits.shiftL w (Prelude.fromIntegral n))". +Extract Constant Word64.cmovle => "(\x y r1 r2 -> if x Prelude.<= y then r1 else r2)". +Extract Constant Word64.cmovne => "(\x y r1 r2 -> if x Prelude.== y then r1 else r2)". + +(* inlining, primarily to reduce polymorphism *) +Extraction Inline dec_eq_Z dec_eq_N dec_eq_sig_hprop. +Extraction Inline Ed25519.Erep Ed25519.SRep Ed25519.ZNWord Ed25519.WordNZ. +Extraction Inline GF25519BoundedCommon.fe25519. +Extraction Inline EdDSARepChange.sign EdDSARepChange.splitSecretPrngCurve. +Extraction Inline Crypto.Util.IterAssocOp.iter_op Crypto.Util.IterAssocOp.test_and_op. +Extraction Inline PointEncoding.Kencode_point. +Extraction Inline ExtendedCoordinates.Extended.point ExtendedCoordinates.Extended.coordinates ExtendedCoordinates.Extended.to_twisted ExtendedCoordinates.Extended.from_twisted ExtendedCoordinates.Extended.add_coordinates ExtendedCoordinates.Extended.add ExtendedCoordinates.Extended.opp ExtendedCoordinates.Extended.zero. (* ExtendedCoordinates.Extended.zero could be precomputed *) +Extraction Inline CompleteEdwardsCurve.E.coordinates CompleteEdwardsCurve.E.zero. +Extraction Inline GF25519BoundedCommon.proj_word GF25519BoundedCommon.Build_bounded_word GF25519BoundedCommon.Build_bounded_word'. +Extraction Inline GF25519BoundedCommon.app_wire_digits GF25519BoundedCommon.wire_digit_bounds_exp. +Extraction Inline Crypto.Util.HList.mapt' Crypto.Util.HList.mapt Crypto.Util.Tuple.map. + +Extraction Implicit Ed25519.H [ 1 ]. +Extract Constant Ed25519.H => +"let { b2i b = case b of { Prelude.True -> 1 ; Prelude.False -> 0 } } in + let { leBitsToBytes [] = [] :: [Data.Word.Word8] ; + leBitsToBytes (a:b:c:d:e:f:g:h:bs) = (b2i a Data.Bits..|. (b2i b `Data.Bits.shiftL` 1) Data.Bits..|. (b2i c `Data.Bits.shiftL` 2) Data.Bits..|. (b2i d `Data.Bits.shiftL` 3) Data.Bits..|. (b2i e `Data.Bits.shiftL` 4) Data.Bits..|. (b2i f `Data.Bits.shiftL` 5) Data.Bits..|. (b2i g `Data.Bits.shiftL` 6) Data.Bits..|. (b2i h `Data.Bits.shiftL` 7)) : leBitsToBytes bs ; + leBitsToBytes bs = Prelude.error ('b':'s':'l':[]) } in + let { bytesToLEBits [] = [] :: [Prelude.Bool] ; + bytesToLEBits (x:xs) = (x `Data.Bits.testBit` 0) : (x `Data.Bits.testBit` 1) : (x `Data.Bits.testBit` 2) : (x `Data.Bits.testBit` 3) : (x `Data.Bits.testBit` 4) : (x `Data.Bits.testBit` 5) : (x `Data.Bits.testBit` 6) : (x `Data.Bits.testBit` 7) : bytesToLEBits xs } in + (bytesToLEBits Prelude.. B.unpack Prelude.. SHA.bytestringDigest Prelude.. SHA.sha512 Prelude.. B.pack Prelude.. leBitsToBytes)". + +(* invW makes ghc -XStrict very slow *) +(* Extract Constant GF25519Bounded.invW => "Prelude.error ('i':'n':'v':'W':[])". *) + +Extraction "src/Experiments/Ed25519_noimports.hs" Ed25519.sign (* Ed25519.verify *). +(* +*Ed25519 Prelude> and (eRepEnc ((sRepERepMul l eRepB))) == False +True +*Ed25519 Prelude> eRepEnc ((sRepERepMul l eRepB) `erepAdd` eRepB) == eRepEnc eRepB +True +*) + +Import Crypto.Spec.MxDH. +Extraction Inline MxDH.ladderstep MxDH.montladder. +Extraction "src/Experiments/X25519_noimports.hs" Crypto.Experiments.Ed25519.x25519.
\ No newline at end of file diff --git a/src/Experiments/Ed25519_imports.hs b/src/Experiments/Ed25519_imports.hs new file mode 100644 index 000000000..726b4b268 --- /dev/null +++ b/src/Experiments/Ed25519_imports.hs @@ -0,0 +1,5 @@ +import qualified Data.List +import qualified Data.Bits +import qualified Data.Word (Word8, Word64) +import qualified Data.ByteString.Lazy as B +import qualified Data.Digest.Pure.SHA as SHA diff --git a/src/Experiments/ExtrHaskellNats.v b/src/Experiments/ExtrHaskellNats.v index ef9fd06d9..3e2974ea1 100644 --- a/src/Experiments/ExtrHaskellNats.v +++ b/src/Experiments/ExtrHaskellNats.v @@ -51,12 +51,18 @@ Module Export Import_NPeano_Nat. Extract Inlined Constant ltb => "(Prelude.<)". Extract Inlined Constant leb => "(Prelude.<=)". Extract Inlined Constant eqb => "(Prelude.==)". + Extract Inlined Constant eq_dec => "(Prelude.==)". Extract Inlined Constant odd => "Prelude.odd". Extract Inlined Constant even => "Prelude.even". Extract Constant pred => "(\n -> Prelude.max 0 (Prelude.pred n))". Extract Constant sub => "(\n m -> Prelude.max 0 (n Prelude.- m))". Extract Constant div => "(\n m -> if m Prelude.== 0 then 0 else Prelude.div n m)". Extract Constant modulo => "(\n m -> if m Prelude.== 0 then 0 else Prelude.mod n m)". + + (* XXX: unsound due to potential overflow in the second argument *) + Extract Constant shiftr => "(\w n -> Data.Bits.shiftR w (Prelude.fromIntegral n))". + Extract Constant shiftl => "(\w n -> Data.Bits.shiftL w (Prelude.fromIntegral n))". + Extract Constant testbit => "(\w n -> Data.Bits.testBit w (Prelude.fromIntegral n))". End Import_NPeano_Nat. @@ -72,17 +78,28 @@ Module Export Import_Init_Nat. Extract Constant div => "(\n m -> if m Prelude.== 0 then 0 else Prelude.div n m)". Extract Constant modulo => "(\n m -> if m Prelude.== 0 then 0 else Prelude.mod n m)". + + (* XXX: unsound due to potential overflow in the second argument *) + Extract Constant shiftr => "(\w n -> Data.Bits.shiftR w (Prelude.fromIntegral n))". + Extract Constant shiftl => "(\w n -> Data.Bits.shiftL w (Prelude.fromIntegral n))". + Extract Constant testbit => "(\w n -> Data.Bits.testBit w (Prelude.fromIntegral n))". End Import_Init_Nat. Module Export Import_PeanoNat_Nat. Import Coq.Arith.PeanoNat.Nat. + Extract Inlined Constant eq_dec => "(Prelude.==)". Extract Inlined Constant add => "(Prelude.+)". Extract Inlined Constant mul => "(Prelude.*)". Extract Inlined Constant max => "Prelude.max". Extract Inlined Constant min => "Prelude.min". Extract Inlined Constant compare => "Prelude.compare". + + (* XXX: unsound due to potential overflow in the second argument *) + Extract Constant shiftr => "(\w n -> Data.Bits.shiftR w (Prelude.fromIntegral n))". + Extract Constant shiftl => "(\w n -> Data.Bits.shiftL w (Prelude.fromIntegral n))". + Extract Constant testbit => "(\w n -> Data.Bits.testBit w (Prelude.fromIntegral n))". End Import_PeanoNat_Nat. Extract Inlined Constant Compare_dec.nat_compare_alt => "Prelude.compare". diff --git a/src/ModularArithmetic/ModularBaseSystem.v b/src/ModularArithmetic/ModularBaseSystem.v index d8303d1a7..9d7ce7c1f 100644 --- a/src/ModularArithmetic/ModularBaseSystem.v +++ b/src/ModularArithmetic/ModularBaseSystem.v @@ -83,10 +83,10 @@ Section ModularBaseSystem. Definition eq (x y : digits) : Prop := decode x = decode y. - Definition freeze B (x : digits) : digits := - from_list (freeze B [[x]]) (length_freeze length_to_list). + Definition freeze int_width (x : digits) : digits := + from_list (freeze int_width [[x]]) (length_freeze length_to_list). - Definition eqb B (x y : digits) : bool := fieldwiseb Z.eqb (freeze B x) (freeze B y). + Definition eqb int_width (x y : digits) : bool := fieldwiseb Z.eqb (freeze int_width x) (freeze int_width y). (* Note : both of the following square root definitions will produce garbage output if the input is not square mod [modulus]. The caller should either provably only call them with square input, @@ -95,13 +95,10 @@ Section ModularBaseSystem. (chain_correct : fold_chain 0%N N.add chain (1%N :: nil) = Z.to_N (modulus / 4 + 1)) (x : digits) : digits := pow x chain. - (* sqrt_5mod8 is parameterized over implementation of [mul] and [pow] because it relies on bounds-checking - for these two functions, which is much easier for simplified implementations than the more generalized - ones defined here. *) - Definition sqrt_5mod8 B mul_ pow_ (chain : list (nat * nat)) + Definition sqrt_5mod8 int_width powx powx_squared (chain : list (nat * nat)) (chain_correct : fold_chain 0%N N.add chain (1%N :: nil) = Z.to_N (modulus / 8 + 1)) (sqrt_minus1 x : digits) : digits := - let b := pow_ x chain in if eqb B (mul_ b b) x then b else mul_ sqrt_minus1 b. + if eqb int_width powx_squared x then powx else mul sqrt_minus1 powx. Import Morphisms. Global Instance eq_Equivalence : Equivalence eq. @@ -109,9 +106,9 @@ Section ModularBaseSystem. split; cbv [eq]; repeat intro; congruence. Qed. - Definition select B (b : Z) (x y : digits) := - add (map (Z.land (neg B b)) x) - (map (Z.land (neg B (Z.lxor b 1))) x). + Definition select int_width (b : Z) (x y : digits) := + add (map (Z.land (neg int_width b)) x) + (map (Z.land (neg int_width (Z.lxor b 1))) x). Context {target_widths} (target_widths_nonneg : forall x, In x target_widths -> 0 <= x) (bits_eq : sum_firstn limb_widths (length limb_widths) = diff --git a/src/ModularArithmetic/ModularBaseSystemListZOperations.v b/src/ModularArithmetic/ModularBaseSystemListZOperations.v index 1d863abbd..09a252a06 100644 --- a/src/ModularArithmetic/ModularBaseSystemListZOperations.v +++ b/src/ModularArithmetic/ModularBaseSystemListZOperations.v @@ -2,6 +2,7 @@ (** We separate these out so that we can depend on them in other files without waiting for ModularBaseSystemList to build. *) Require Import Coq.ZArith.ZArith. +Require Import Crypto.Util.Tuple. Definition cmovl (x y r1 r2 : Z) := if Z.leb x y then r1 else r2. Definition cmovne (x y r1 r2 : Z) := if Z.eqb x y then r1 else r2. @@ -10,3 +11,7 @@ Definition cmovne (x y r1 r2 : Z) := if Z.eqb x y then r1 else r2. neg 1 = 2^64 - 1 (on 64-bit; 2^32-1 on 32-bit, etc.) neg 0 = 0 *) Definition neg (int_width : Z) (b : Z) := if Z.eqb b 1 then Z.ones int_width else 0%Z. + +(** TODO(jadep): Fill in this stub *) +Axiom conditional_subtract_modulus + : forall (limb_count : nat) (int_width : Z) (modulus value : Tuple.tuple Z limb_count), Tuple.tuple Z limb_count. diff --git a/src/ModularArithmetic/ModularBaseSystemListZOperationsProofs.v b/src/ModularArithmetic/ModularBaseSystemListZOperationsProofs.v new file mode 100644 index 000000000..bb833507f --- /dev/null +++ b/src/ModularArithmetic/ModularBaseSystemListZOperationsProofs.v @@ -0,0 +1,13 @@ +Require Import Coq.ZArith.ZArith. +Require Import Crypto.ModularArithmetic.ModularBaseSystemListZOperations. +Require Import Crypto.Util.Tuple. +Require Import Crypto.Util.ZUtil. + +Local Open Scope Z_scope. + +Lemma neg_nonneg : forall x y, 0 <= ModularBaseSystemListZOperations.neg x y. +Proof. Admitted. +Hint Resolve neg_nonneg : zarith. +Lemma neg_upperbound : forall x y, ModularBaseSystemListZOperations.neg x y <= Z.ones x. +Proof. Admitted. +Hint Resolve neg_upperbound : zarith. diff --git a/src/ModularArithmetic/ModularBaseSystemOpt.v b/src/ModularArithmetic/ModularBaseSystemOpt.v index f7f6efad7..155698e56 100644 --- a/src/ModularArithmetic/ModularBaseSystemOpt.v +++ b/src/ModularArithmetic/ModularBaseSystemOpt.v @@ -905,30 +905,15 @@ Section Conversion. End Conversion. -Section with_base. - Context {modulus} (prm : PseudoMersenneBaseParams modulus). - Local Notation base := (Pow2Base.base_from_limb_widths limb_widths). - Local Notation log_cap i := (nth_default 0 limb_widths i). - - Record freezePreconditions int_width := - mkFreezePreconditions { - lt_1_length_base : (1 < length base)%nat; - int_width_pos : 0 < int_width; - int_width_compat : forall w, In w limb_widths -> w < int_width; - c_pos : 0 < c; - c_reduce1 : c * (Z.ones (int_width - log_cap (pred (length base)))) < 2 ^ log_cap 0; - c_reduce2 : c < 2 ^ log_cap 0 - c; - two_pow_k_le_2modulus : 2 ^ k <= 2 * modulus - }. -End with_base. -Local Hint Resolve lt_1_length_base int_width_pos int_width_compat c_pos - c_reduce1 c_reduce2 two_pow_k_le_2modulus. +Local Hint Resolve lt_1_length_limb_widths int_width_pos B_pos B_compat + c_reduce1 c_reduce2. Section Canonicalization. Context `{prm : PseudoMersenneBaseParams} {sc : SubtractionCoefficient} (* allows caller to precompute k and c *) (k_ c_ : Z) (k_subst : k = k_) (c_subst : c = c_) - {int_width} (preconditions : freezePreconditions prm int_width). + {int_width freeze_input_bound} + (preconditions : FreezePreconditions freeze_input_bound int_width). Local Notation digits := (tuple Z (length limb_widths)). Definition carry_full_3_opt_cps_sig @@ -1069,32 +1054,42 @@ Section SquareRoots. End SquareRoot3mod4. - Import Morphisms. - Global Instance eqb_Proper : Proper (Logic.eq ==> eq ==> eq ==> Logic.eq) ModularBaseSystem.eqb. Admitted. - Section SquareRoot5mod8. Context {ec : ExponentiationChain (modulus / 8 + 1)}. Context (sqrt_m1 : digits) (sqrt_m1_correct : rep (mul sqrt_m1 sqrt_m1) (F.opp 1%F)). - Context {int_width} (preconditions : freezePreconditions prm int_width). + Context {int_width freeze_input_bound} + (preconditions : FreezePreconditions freeze_input_bound int_width). - Definition sqrt_5mod8_opt_sig (us : digits) : + Definition sqrt_5mod8_opt_sig (powx powx_squared us : digits) : { vs : digits | - eq vs (sqrt_5mod8 int_width (carry_mul_opt k_ c_) (pow_opt k_ c_ one_) chain chain_correct sqrt_m1 us)}. + eq vs (sqrt_5mod8 int_width powx powx_squared chain chain_correct sqrt_m1 us)}. Proof. - eexists; cbv [sqrt_5mod8]. - let LHS := match goal with |- eq ?LHS ?RHS => LHS end in - let RHS := match goal with |- eq ?LHS ?RHS => RHS end in - let RHSf := match (eval pattern (pow_opt k_ c_ one_ us chain) in RHS) with ?RHSf _ => RHSf end in - change (eq LHS (Let_In (pow_opt k_ c_ one_ us chain) RHSf)). - reflexivity. + cbv [sqrt_5mod8]. + match goal with + |- appcontext[(if ?P then ?t else mul ?a ?b)] => + assert (eq (carry_mul_opt k_ c_ a b) (mul a b)) + by (rewrite carry_mul_opt_correct by auto; + cbv [eq]; rewrite carry_mul_rep, mul_rep; reflexivity) + end. + let RHS := match goal with |- {vs | eq ?vs ?RHS} => RHS end in + let RHSf := match (eval pattern powx in RHS) with ?RHSf _ => RHSf end in + change ({vs | eq vs (Let_In powx RHSf)}). + match goal with + | H : eq (?g powx) (?f powx) + |- {vs | eq vs (Let_In powx (fun x => if ?P then x else ?f x))} => + exists (Let_In powx (fun x => if P then x else g x)) + end. + break_if; try reflexivity. + cbv [Let_In]. + auto. Defined. - Definition sqrt_5mod8_opt us := Eval cbv [proj1_sig sqrt_5mod8_opt_sig] in - proj1_sig (sqrt_5mod8_opt_sig us). + Definition sqrt_5mod8_opt powx powx_squared us := Eval cbv [proj1_sig sqrt_5mod8_opt_sig] in + proj1_sig (sqrt_5mod8_opt_sig powx powx_squared us). - Definition sqrt_5mod8_opt_correct us - : eq (sqrt_5mod8_opt us) (ModularBaseSystem.sqrt_5mod8 int_width _ _ chain chain_correct sqrt_m1 us) - := Eval cbv [proj2_sig sqrt_5mod8_opt_sig] in proj2_sig (sqrt_5mod8_opt_sig us). + Definition sqrt_5mod8_opt_correct powx powx_squared us + : eq (sqrt_5mod8_opt powx powx_squared us) (ModularBaseSystem.sqrt_5mod8 int_width _ _ chain chain_correct sqrt_m1 us) + := Eval cbv [proj2_sig sqrt_5mod8_opt_sig] in proj2_sig (sqrt_5mod8_opt_sig powx powx_squared us). End SquareRoot5mod8. diff --git a/src/ModularArithmetic/ModularBaseSystemProofs.v b/src/ModularArithmetic/ModularBaseSystemProofs.v index f6efdfd87..3b0231191 100644 --- a/src/ModularArithmetic/ModularBaseSystemProofs.v +++ b/src/ModularArithmetic/ModularBaseSystemProofs.v @@ -521,9 +521,11 @@ End CarryProofs. Hint Rewrite @length_carry_and_reduce @length_carry : distr_length. -Class FreezePreconditions `{prm : PseudoMersenneBaseParams} B := +Class FreezePreconditions `{prm : PseudoMersenneBaseParams} B int_width := { lt_1_length_limb_widths : (1 < length limb_widths)%nat; + int_width_pos : 0 < int_width; + B_le_int_width : B <= int_width; B_pos : 0 < B; B_compat : forall w, In w limb_widths -> w < B; (* on the first reduce step, we add at most one bit of width to the first digit *) @@ -940,8 +942,14 @@ Section CanonicalizationProofs. congruence. Qed. + Lemma int_width_compat : forall x, In x limb_widths -> x < int_width. + Proof. + intros. apply B_compat in H. + eapply Z.lt_le_trans; eauto using B_le_int_width. + Qed. + Lemma minimal_rep_freeze : forall u, initial_bounds u -> - minimal_rep (freeze B u). + minimal_rep (freeze int_width u). Proof. repeat match goal with | |- _ => progress (cbv [freeze ModularBaseSystemList.freeze]) @@ -952,12 +960,12 @@ Section CanonicalizationProofs. | |- _ => apply conditional_subtract_lt_modulus | |- _ => apply conditional_subtract_modulus_preserves_bounded | |- bounded _ (carry_full _) => apply bounded_iff - | |- _ => solve [auto using Z.lt_le_incl, B_pos, B_compat, lt_1_length_limb_widths, length_carry_full, length_to_list] + | |- _ => solve [auto using Z.lt_le_incl, int_width_pos, int_width_compat, lt_1_length_limb_widths, length_carry_full, length_to_list] end. Qed. Lemma freeze_decode : forall u, - BaseSystem.decode base (to_list _ (freeze B u)) mod modulus = + BaseSystem.decode base (to_list _ (freeze int_width u)) mod modulus = BaseSystem.decode base (to_list _ u) mod modulus. Proof. repeat match goal with @@ -967,7 +975,7 @@ Section CanonicalizationProofs. | |- _ => rewrite Z.mod_add by (pose proof prime_modulus; prime_bound) | |- _ => rewrite to_list_from_list | |- _ => rewrite conditional_subtract_modulus_spec by - auto using Z.lt_le_incl, B_pos, B_compat, lt_1_length_limb_widths, length_carry_full, length_to_list, ge_modulus_01 + (auto using Z.lt_le_incl, int_width_pos, int_width_compat, lt_1_length_limb_widths, length_carry_full, length_to_list, ge_modulus_01) end. rewrite !decode_mod_Fdecode by auto using length_carry_full, length_to_list. cbv [carry_full]. @@ -986,7 +994,7 @@ Section CanonicalizationProofs. rewrite from_list_to_list; reflexivity. Qed. - Lemma freeze_rep : forall u x, rep u x -> rep (freeze B u) x. + Lemma freeze_rep : forall u x, rep u x -> rep (freeze int_width u) x. Proof. cbv [rep]; intros. apply F.eq_to_Z_iff. @@ -997,7 +1005,7 @@ Section CanonicalizationProofs. Lemma freeze_canonical : forall u v x y, rep u x -> rep v y -> initial_bounds u -> initial_bounds v -> - (x = y <-> fieldwise Logic.eq (freeze B u) (freeze B v)). + (x = y <-> fieldwise Logic.eq (freeze int_width u) (freeze int_width v)). Proof. intros; apply bounded_canonical; auto using freeze_rep, minimal_rep_freeze. Qed. @@ -1017,12 +1025,12 @@ Section SquareRootProofs. then 0 else (2 ^ B) >> (nth_default 0 limb_widths (pred n)))). Definition bounded_by u bounds := - (forall n : nat, + (forall n : nat, (n < length limb_widths)%nat -> 0 <= nth_default 0 (to_list (length limb_widths) u) n < bounds n). Lemma eqb_true_iff : forall u v x y, bounded_by u freeze_input_bounds -> bounded_by v freeze_input_bounds -> - u ~= x -> v ~= y -> (x = y <-> eqb B u v = true). + u ~= x -> v ~= y -> (x = y <-> eqb int_width u v = true). Proof. cbv [eqb freeze_input_bounds]. intros. rewrite fieldwiseb_fieldwise by (apply Z.eqb_eq). @@ -1031,10 +1039,10 @@ Section SquareRootProofs. Lemma eqb_false_iff : forall u v x y, bounded_by u freeze_input_bounds -> bounded_by v freeze_input_bounds -> - u ~= x -> v ~= y -> (x <> y <-> eqb B u v = false). + u ~= x -> v ~= y -> (x <> y <-> eqb int_width u v = false). Proof. intros. - case_eq (eqb B u v). + case_eq (eqb int_width u v). + rewrite <-eqb_true_iff by eassumption; split; intros; congruence || contradiction. + split; intros; auto. @@ -1063,44 +1071,74 @@ Section SquareRootProofs. Context (modulus_5mod8 : modulus mod 8 = 5). Context {ec : ExponentiationChain (modulus / 8 + 1)}. Context (sqrt_m1 : digits) (sqrt_m1_correct : mul sqrt_m1 sqrt_m1 ~= F.opp 1%F). - Context (mul_ : digits -> digits -> digits) - (mul_equiv : forall x y, mul_ x y = mul x y) - {mul_input_bounds : nat -> Z} - (mul_bounded : forall x y, bounded_by x mul_input_bounds -> - bounded_by y mul_input_bounds -> - bounded_by (mul_ x y) freeze_input_bounds). - Context (pow_ : digits -> list (nat * nat) -> digits) - (pow_equiv : forall x is, pow_ x is = pow x is) - {pow_input_bounds : nat -> Z} - (pow_bounded : forall x is, bounded_by x pow_input_bounds -> - bounded_by (pow_ x is) mul_input_bounds). - - Lemma sqrt_5mod8_correct : forall u x, u ~= x -> - bounded_by u pow_input_bounds -> bounded_by u freeze_input_bounds -> - (sqrt_5mod8 B mul_ pow_ chain chain_correct sqrt_m1 u) ~= F.sqrt_5mod8 (decode sqrt_m1) x. + + Lemma sqrt_5mod8_correct : forall u x powx powx_squared, u ~= x -> + bounded_by u freeze_input_bounds -> + bounded_by powx_squared freeze_input_bounds -> + ModularBaseSystem.eq powx (pow u chain) -> + ModularBaseSystem.eq powx_squared (mul powx powx) -> + (sqrt_5mod8 int_width powx powx_squared chain chain_correct sqrt_m1 u) ~= F.sqrt_5mod8 (decode sqrt_m1) x. Proof. + cbv [sqrt_5mod8 F.sqrt_5mod8]. + intros. repeat match goal with | |- _ => progress (cbv [sqrt_5mod8 F.sqrt_5mod8]; intros) | |- _ => rewrite @F.pow_2_r in * | |- _ => rewrite eqb_correct in * by eassumption | |- (if eqb _ ?a ?b then _ else _) ~= (if dec (?c = _) then _ else _) => - assert (a ~= c); rewrite !mul_equiv, pow_equiv in *; - repeat break_if + assert (a ~= c) by + (cbv [rep]; rewrite <-chain_correct, <-pow_rep, <-mul_rep; + eassumption); repeat break_if | |- _ => apply mul_rep; try reflexivity; - rewrite <-chain_correct; apply pow_rep; eassumption - | |- _ => rewrite <-chain_correct; apply pow_rep; eassumption - | H : eqb _ ?a ?b = true |- _ => - rewrite <-(eqb_true_iff a b) in H - by (eassumption || rewrite <-mul_equiv, <-pow_equiv; - apply mul_bounded, pow_bounded; auto) - | H : eqb _ ?a ?b = false |- _ => - rewrite <-(eqb_false_iff a b) in H - by (eassumption || rewrite <-mul_equiv, <-pow_equiv; - apply mul_bounded, pow_bounded; auto) + rewrite <-chain_correct, <-pow_rep; eassumption + | |- _ => rewrite <-chain_correct, <-pow_rep; eassumption + | H : eqb _ ?a ?b = true, H1 : ?b ~= ?y, H2 : ?a ~= ?x |- _ => + rewrite <-(eqb_true_iff a b x y) in H by eassumption + | H : eqb _ ?a ?b = false, H1 : ?b ~= ?y, H2 : ?a ~= ?x |- _ => + rewrite <-(eqb_false_iff a b x y) in H by eassumption | |- _ => congruence end. Qed. End Sqrt5mod8. End SquareRootProofs. + +Section ConversionProofs. + Context `{prm :PseudoMersenneBaseParams}. + Context {target_widths} + (target_widths_nonneg : forall x, In x target_widths -> 0 <= x) + (bits_eq : sum_firstn limb_widths (length limb_widths) = + sum_firstn target_widths (length target_widths)). + Local Notation target_base := (base_from_limb_widths target_widths). + + Lemma pack_rep : forall w, + bounded limb_widths (to_list _ w) -> + bounded target_widths (to_list _ w) -> + rep w (F.of_Z modulus + (BaseSystem.decode + target_base + (to_list _ (pack target_widths_nonneg bits_eq w)))). + Proof. + intros; cbv [pack ModularBaseSystemList.pack rep]. + rewrite Tuple.to_list_from_list. + apply F.eq_to_Z_iff. + rewrite F.to_Z_of_Z. + rewrite <-Conversion.convert_correct; auto using length_to_list. + Qed. + + Lemma unpack_rep : forall w, + bounded target_widths (to_list _ w) -> + rep (unpack target_widths_nonneg bits_eq w) + (F.of_Z modulus (BaseSystem.decode target_base (to_list _ w))). + Proof. + intros; cbv [unpack ModularBaseSystemList.unpack rep]. + apply F.eq_to_Z_iff. + rewrite <-from_list_default_eq with (d := 0). + rewrite <-decode_mod_Fdecode by apply Conversion.length_convert. + rewrite F.to_Z_of_Z. + rewrite <-Conversion.convert_correct; auto using length_to_list. + Qed. + + +End ConversionProofs. diff --git a/src/ModularArithmetic/Pow2BaseProofs.v b/src/ModularArithmetic/Pow2BaseProofs.v index 5eba3582b..70947abdb 100644 --- a/src/ModularArithmetic/Pow2BaseProofs.v +++ b/src/ModularArithmetic/Pow2BaseProofs.v @@ -331,6 +331,22 @@ Section Pow2BaseProofs. intros; apply nth_default_preserves_properties; auto; omega. Qed. Hint Resolve nth_default_limb_widths_nonneg. + Lemma parity_decode : forall x, + (0 < nth_default 0 limb_widths 0) -> + length x = length limb_widths -> + Z.odd (BaseSystem.decode base x) = Z.odd (nth_default 0 x 0). + Proof. + intros. + destruct limb_widths, x; simpl in *; try discriminate; try reflexivity. + rewrite peel_decode, nth_default_cons. + fold (BaseSystem.mul_each (two_p z)). + rewrite <-mul_each_base, mul_each_rep. + rewrite Z.odd_add_mul_even; [ f_equal; ring | ]. + rewrite <-Z.even_spec, two_p_correct. + apply Z.even_pow. + rewrite @nth_default_cons in *; auto. + Qed. + Lemma decode_firstn_pow2_mod : forall us i, (i <= length us)%nat -> length us = length limb_widths -> diff --git a/src/MxDHRepChange.v b/src/MxDHRepChange.v new file mode 100644 index 000000000..c99635baf --- /dev/null +++ b/src/MxDHRepChange.v @@ -0,0 +1,157 @@ +Require Import Crypto.Spec.MxDH. +Require Import Crypto.Algebra. Import Monoid Group Ring Field. +Require Import Crypto.Util.Tuple. + +Section MxDHRepChange. + Context {F Feq Fzero Fone Fopp Fadd Fsub Fmul Finv Fdiv} {field:@Algebra.field F Feq Fzero Fone Fopp Fadd Fsub Fmul Finv Fdiv} {Feq_dec:Decidable.DecidableRel Feq} {Fcswap:bool->F*F->F*F->(F*F)*(F*F)} {Fa24:F} {tb1:nat->bool}. + Context {K Keq Kzero Kone Kopp Kadd Ksub Kmul Kinv Kdiv} {impl_field:@Algebra.field K Keq Kzero Kone Kopp Kadd Ksub Kmul Kinv Kdiv} {Keq_dec:Decidable.DecidableRel Keq} {Kcswap:bool->K*K->K*K->(K*K)*(K*K)} {Ka24:K} {tb2:nat->bool}. + + Context {FtoK} {homom:@Ring.is_homomorphism F Feq Fone Fadd Fmul + K Keq Kone Kadd Kmul FtoK}. + Context {homomorphism_inv_zero:Keq (FtoK (Finv Fzero)) (Kinv Kzero)}. + Context {homomorphism_a24:Keq (FtoK Fa24) Ka24}. + Context {Fcswap_correct:forall b x y, Fcswap b x y = if b then (y,x) else (x,y)}. + Context {Kcswap_correct:forall b x y, Kcswap b x y = if b then (y,x) else (x,y)}. + Context {tb2_correct:forall i, tb2 i = tb1 i}. + + (* TODO: move to algebra *) + Lemma homomorphism_multiplicative_inverse_complete' x : Keq (FtoK (Finv x)) (Kinv (FtoK x)). + Proof. + eapply (homomorphism_multiplicative_inverse_complete). + intro J; rewrite J. rewrite homomorphism_inv_zero, homomorphism_id. + reflexivity. + Qed. + + Ltac t := + repeat ( + rewrite homomorphism_id || + rewrite homomorphism_one || + rewrite homomorphism_a24 || + rewrite homomorphism_sub || + rewrite homomorphism_add || + rewrite homomorphism_mul || + rewrite homomorphism_multiplicative_inverse_complete' || + reflexivity + ). + + Import Crypto.Util.Tactics. + Import List. + Import Coq.Classes.Morphisms. + + Global Instance Proper_ladderstep : + Proper (Keq ==> (fieldwise (n:=2) Keq) ==> fieldwise (n:=2) Keq ==> fieldwise (n:=2) (fieldwise (n:=2) Keq)) (@MxDH.ladderstep K Kadd Ksub Kmul Ka24). + Proof. + cbv [MxDH.ladderstep tuple tuple' fieldwise fieldwise' fst snd]; + repeat intro; destruct_head' prod; destruct_head' and; repeat split; + repeat match goal with [H:Keq ?x ?y |- _ ] => rewrite !H; clear H x end; reflexivity. + Qed. + + Lemma MxLadderstepRepChange u P Q Ku (Ku_correct:Keq (FtoK u) Ku): + fieldwise (n:=2) (fieldwise (n:=2) Keq) + ((Tuple.map (n:=2) (Tuple.map (n:=2) FtoK)) (@MxDH.ladderstep F Fadd Fsub Fmul Fa24 u P Q)) + (@MxDH.ladderstep K Kadd Ksub Kmul Ka24 Ku (Tuple.map (n:=2) FtoK P) (Tuple.map (n:=2) FtoK Q)). + Proof. + destruct P as [? ?], Q as [? ?]; cbv; repeat split; rewrite <-?Ku_correct; t. + Qed. + + Let loopiter_sig F Fzero Fone Fadd Fsub Fmul Finv Fa24 Fcswap b tb u : + { loopiter | @MxDH.montladder F Fzero Fone Fadd Fsub Fmul Finv Fa24 Fcswap b tb u = + let '(_, _, _) := MxDH.downto _ _ loopiter in _ } := exist _ _ eq_refl. + Let loopiter F Fzero Fone Fadd Fsub Fmul Finv Fa24 Fcswap b tb u := + Eval cbv [proj1_sig loopiter_sig] in ( + proj1_sig (loopiter_sig F Fzero Fone Fadd Fsub Fmul Finv Fa24 Fcswap b tb u)). + + Let loopiter_phi s : ((K * K) * (K * K)) * bool := + (Tuple.map (n:=2) (Tuple.map (n:=2) FtoK) (fst s), snd s). + + Let loopiter_eq (a b: (((K * K) * (K * K)) * bool)) := + fieldwise (n:=2) (fieldwise (n:=2) Keq) (fst a) (fst b) /\ snd a = snd b. + + Local Instance Equivalence_loopiter_eq : Equivalence loopiter_eq. + Proof. + unfold loopiter_eq; split; repeat intro; + intuition (reflexivity||symmetry;eauto||etransitivity;symmetry;eauto). + Qed. + + Lemma MxLoopIterRepChange b Fu s i Ku (HKu:Keq (FtoK Fu) Ku) : loopiter_eq + (loopiter_phi (loopiter F Fzero Fone Fadd Fsub Fmul Finv Fa24 Fcswap b tb1 Fu s i)) + (loopiter K Kzero Kone Kadd Ksub Kmul Kinv Ka24 Kcswap b tb2 Ku (loopiter_phi s) i). + Proof. + destruct_head' prod; break_match. + simpl. + rewrite !Fcswap_correct, !Kcswap_correct, tb2_correct in *. + break_match; cbv [loopiter_eq loopiter_phi fst snd]; split; try reflexivity; + (etransitivity;[|etransitivity]; [ | eapply (MxLadderstepRepChange _ _ _ _ HKu) | ]; + match goal with Heqp:_ |- _ => rewrite <-Heqp; reflexivity end). + Qed. + + Global Instance Proper_fold_left {A RA B RB} : + Proper ((RA==>RB==>RA) ==> SetoidList.eqlistA RB ==> RA ==> RA) (@fold_left A B). + Proof. + intros ? ? ? ? ? Hl; induction Hl; repeat intro; [assumption|]. + simpl; cbv [Proper respectful] in *; eauto. + Qed. + + Lemma proj_fold_left {A B L} R {Equivalence_R:@Equivalence B R} (proj:A->B) step step' {Proper_step':(R ==> eq ==> R)%signature step' step'} (xs:list L) init init' (H0:R (proj init) init') (Hproj:forall x i, R (proj (step x i)) (step' (proj x) i)) : + R (proj (fold_left step xs init)) (fold_left step' xs init'). + Proof. + generalize dependent init; generalize dependent init'. + induction xs; [solve [eauto]|]. + repeat intro; simpl; rewrite IHxs by eauto. + f_equiv; eapply Proper_step'; eauto. + Qed. + + Global Instance Proper_downto {T R} {Equivalence_R:@Equivalence T R} : + Proper (R ==> Logic.eq ==> (R==>Logic.eq==>R) ==> R) MxDH.downto. + Proof. + intros s0 s0' Hs0 n' n Hn'; subst n'; generalize dependent s0; generalize dependent s0'. + induction n; repeat intro; [assumption|]. + unfold MxDH.downto; simpl. + eapply Proper_fold_left; try eassumption; try reflexivity. + cbv [Proper respectful] in *; eauto. + Qed. + + Global Instance Proper_loopiter a b c : + Proper (loopiter_eq ==> eq ==> loopiter_eq) (loopiter K Kzero Kone Kadd Ksub Kmul Kinv Ka24 Kcswap a b c). + Proof. + unfold loopiter; intros [? ?] [? ?] [[[] []] ?]; repeat intro ; cbv [fst snd] in * |-; subst. + repeat VerdiTactics.break_match; subst; repeat (VerdiTactics.find_injection; intros; subst). + split; [|reflexivity]. + etransitivity; [|etransitivity]; [ | eapply Proper_ladderstep | ]; + [eapply eq_subrelation; [ exact _ | symmetry; eassumption ] + | reflexivity | | + | eapply eq_subrelation; [exact _ | eassumption ] ]; + rewrite !Kcswap_correct in *; + match goal with [H: (if ?b then _ else _) = _ |- _] => destruct b end; + repeat (VerdiTactics.find_injection; intros; subst); + split; simpl; eauto. + Qed. + + Lemma MxDHRepChange b (x:F) : + Keq + (FtoK (@MxDH.montladder F Fzero Fone Fadd Fsub Fmul Finv Fa24 Fcswap b tb1 x)) + (@MxDH.montladder K Kzero Kone Kadd Ksub Kmul Kinv Ka24 Kcswap b tb2 (FtoK x)). + Proof. + cbv [MxDH.montladder]. + repeat break_match. + assert (Hrel:loopiter_eq (loopiter_phi (p, p0, b0)) (p1, p3, b1)). + { + rewrite <-Heqp0; rewrite <-Heqp. + unfold MxDH.downto. + eapply (proj_fold_left (L:=nat) loopiter_eq loopiter_phi). + { eapply @Proper_loopiter. } + { cbv; repeat split; t. } + { intros; eapply MxLoopIterRepChange; reflexivity. } + } + { destruct_head' prod; destruct Hrel as [[[] []] ?]; simpl in *; subst. + rewrite !Fcswap_correct, !Kcswap_correct in *. + match goal with [H: (if ?b then _ else _) = _ |- _] => destruct b end; + repeat (VerdiTactics.find_injection; intros; subst); + repeat match goal with [H: Keq (FtoK ?x) (?kx)|- _ ] => rewrite <- H end; + t. + } + Grab Existential Variables. + exact 0. + exact 0. + Qed. +End MxDHRepChange.
\ No newline at end of file diff --git a/src/Reflection/Application.v b/src/Reflection/Application.v new file mode 100644 index 000000000..4274b3b3f --- /dev/null +++ b/src/Reflection/Application.v @@ -0,0 +1,150 @@ +Require Import Crypto.Reflection.Syntax. + +Section language. + Context {base_type : Type} + {interp_base_type : base_type -> Type} + {op : flat_type base_type -> flat_type base_type -> Type} + {interp_op : forall src dst, op src dst -> interp_flat_type interp_base_type src -> interp_flat_type interp_base_type dst}. + Fixpoint count_binders (t : type base_type) : nat + := match t with + | Arrow A B => S (count_binders B) + | Tflat _ => 0 + end. + + Fixpoint remove_binders' (n : nat) (t : type base_type) {struct t} : type base_type + := match t, n with + | Tflat _, _ => t + | Arrow _ B, 0 => B + | Arrow A B, S n' + => remove_binders' n' B + end. + + Definition remove_binders (n : nat) (t : type base_type) : type base_type + := match n with + | 0 => t + | S n' => remove_binders' n' t + end. + + Fixpoint remove_all_binders (t : type base_type) : flat_type base_type + := match t with + | Tflat T => T + | Arrow A B => remove_all_binders B + end. + + Fixpoint binders_for' (n : nat) (t : type base_type) (var : base_type -> Type) {struct t} + := match n, t return Type with + | 0, Arrow A B => var A + | S n', Arrow A B => var A * binders_for' n' B var + | _, _ => unit + end%type. + Definition binders_for (n : nat) (t : type base_type) (var : base_type -> Type) + := match n return Type with + | 0 => unit + | S n' => binders_for' n' t var + end. + + Fixpoint all_binders_for (t : type base_type) + := match t return match t with + | Tflat _ => unit + | _ => flat_type base_type + end with + | Tflat T => tt + | Arrow A B + => match B return match B with Tflat _ => _ | _ => _ end -> _ with + | Tflat T => fun _ => Tbase A + | Arrow _ _ => fun T => Tbase A * T + end%ctype (all_binders_for B) + end. + + Definition interp_all_binders_for T var + := match T return Type with + | Tflat _ => unit + | Arrow A B => interp_flat_type var (all_binders_for (Arrow A B)) + end. + + Definition fst_binder {A B var} (args : interp_flat_type var (all_binders_for (Arrow A B))) : var A + := match B return interp_flat_type var (all_binders_for (Arrow A B)) -> var A with + | Tflat _ => fun x => x + | Arrow _ _ => fun x => fst x + end args. + Definition snd_binder {A B var} (args : interp_flat_type var (all_binders_for (Arrow A B))) + : interp_all_binders_for B var + := match B return interp_flat_type var (all_binders_for (Arrow A B)) + -> interp_all_binders_for B var + with + | Tflat _ => fun _ => tt + | Arrow _ _ => fun x => snd x + end args. + + Fixpoint Apply' n {var t} (x : @expr base_type interp_base_type op var t) + : forall (args : binders_for' n t var), + @expr base_type interp_base_type op var (remove_binders' n t) + := match x in (@expr _ _ _ _ t), n return (binders_for' n t var -> @expr _ _ _ _ (remove_binders' n t)) with + | Return _ _ as y, _ => fun _ => y + | Abs _ _ f, 0 => f + | Abs src dst f, S n' => fun args => @Apply' n' var dst (f (fst args)) (snd args) + end. + + Definition Apply n {var t} (x : @expr base_type interp_base_type op var t) + : forall (args : binders_for n t var), + @expr base_type interp_base_type op var (remove_binders n t) + := match n return binders_for n t var -> @expr _ _ _ _ (remove_binders n t) with + | 0 => fun _ => x + | S n' => @Apply' n' var t x + end. + + Fixpoint ApplyAll {var t} (x : @expr base_type interp_base_type op var t) + : forall (args : interp_all_binders_for t var), + @exprf base_type interp_base_type op var (remove_all_binders t) + := match x in @expr _ _ _ _ t + return (forall (args : interp_all_binders_for t var), + @exprf base_type interp_base_type op var (remove_all_binders t)) + with + | Return _ x => fun _ => x + | Abs src dst f => fun args => @ApplyAll var dst (f (fst_binder args)) (snd_binder args) + end. + + Fixpoint ApplyInterped' n {t} {struct t} + : forall (x : interp_type interp_base_type t) + (args : binders_for' n t interp_base_type), + interp_type interp_base_type (remove_binders' n t) + := match t, n return (forall (x : interp_type interp_base_type t) + (args : binders_for' n t interp_base_type), + interp_type interp_base_type (remove_binders' n t)) with + | Tflat _, _ => fun x _ => x + | Arrow s d, 0 => fun x => x + | Arrow s d, S n' => fun f args => @ApplyInterped' n' d (f (fst args)) (snd args) + end. + + Definition ApplyInterped (n : nat) {t} (x : interp_type interp_base_type t) + : forall (args : binders_for n t interp_base_type), + interp_type interp_base_type (remove_binders n t) + := match n return (binders_for n t interp_base_type -> interp_type interp_base_type (remove_binders n t)) with + | 0 => fun _ => x + | S n' => @ApplyInterped' n' t x + end. + + Fixpoint ApplyInterpedAll {t} + : forall (x : interp_type interp_base_type t) + (args : interp_all_binders_for t interp_base_type), + interp_flat_type interp_base_type (remove_all_binders t) + := match t return (forall (x : interp_type _ t) + (args : interp_all_binders_for t _), + interp_flat_type _ (remove_all_binders t)) + with + | Tflat _ => fun x _ => x + | Arrow A B => fun f x => @ApplyInterpedAll B (f (fst_binder x)) (snd_binder x) + end. +End language. + +Arguments all_binders_for {_} !_ / . +Arguments interp_all_binders_for {_} !_ _ / . +Arguments count_binders {_} !_ / . +Arguments binders_for {_} !_ !_ _ / . +Arguments remove_binders {_} !_ !_ / . +(* Work around bug #5175 *) +Arguments Apply {_ _ _ _ _ _} _ _ , {_ _ _} _ {_ _} _ _. +Arguments Apply _ _ _ !_ _ _ !_ !_ / . +Arguments ApplyInterped {_ _ !_ !_} _ _ / . +Arguments ApplyAll {_ _ _ _ !_} !_ _ / . +Arguments ApplyInterpedAll {_ _ !_} _ _ / . diff --git a/src/Reflection/Conversion.v b/src/Reflection/Conversion.v index 6f69b8f99..14dfc5633 100644 --- a/src/Reflection/Conversion.v +++ b/src/Reflection/Conversion.v @@ -93,6 +93,7 @@ Section language. Proof. induction e; repeat match goal with + | _ => progress unfold LetIn.Let_In | _ => reflexivity | _ => progress simpl in * | _ => rewrite_hyp !* diff --git a/src/Reflection/CountLets.v b/src/Reflection/CountLets.v index 927e7a168..8de6e7a2f 100644 --- a/src/Reflection/CountLets.v +++ b/src/Reflection/CountLets.v @@ -33,7 +33,7 @@ Section language. Fixpoint count_lets_genf {t} (e : exprf t) : nat := match e with | LetIn tx _ _ eC - => count_type_let tx + @count_lets_genf _ (eC (SmartVal var mkVar tx)) + => count_type_let tx + @count_lets_genf _ (eC (SmartValf var mkVar tx)) | _ => 0 end. Fixpoint count_lets_gen {t} (e : expr t) : nat diff --git a/src/Reflection/FilterLive.v b/src/Reflection/FilterLive.v index 3c1c3c8f7..446f9195c 100644 --- a/src/Reflection/FilterLive.v +++ b/src/Reflection/FilterLive.v @@ -50,7 +50,7 @@ Section language. | Some n => @filter_live_namesf (prefix ++ repeat dead_name (count_pairs tx))%list remaining' _ - (eC (SmartVal (fun _ => list Name) (fun _ => namesx ++ names_to_list n)%list _)) + (eC (SmartValf (fun _ => list Name) (fun _ => namesx ++ names_to_list n)%list _)) | None => nil end | Pair _ ex _ ey => merge_name_lists (@filter_live_namesf prefix remaining _ ex) diff --git a/src/Reflection/Inline.v b/src/Reflection/Inline.v index bfb3794c9..a42df2b68 100644 --- a/src/Reflection/Inline.v +++ b/src/Reflection/Inline.v @@ -33,12 +33,12 @@ Section language. => match postprocess _ (@inline_const_genf _ ex) in inline_directive t' return (interp_flat_type _ t' -> @exprf var tC) -> @exprf var tC with | default_inline _ ex => match ex in Syntax.exprf _ _ _ t' return (interp_flat_type _ t' -> @exprf var tC) -> @exprf var tC with - | Const _ x => fun eC => eC (SmartConst (op:=op) (var:=var) x) + | Const _ x => fun eC => eC (SmartConstf (op:=op) (var:=var) x) | Var _ x => fun eC => eC (Var x) - | ex => fun eC => LetIn ex (fun x => eC (SmartVarVar x)) + | ex => fun eC => LetIn ex (fun x => eC (SmartVarVarf x)) end | no_inline _ ex - => fun eC => LetIn ex (fun x => eC (SmartVarVar x)) + => fun eC => LetIn ex (fun x => eC (SmartVarVarf x)) | inline _ ex => fun eC => eC ex end (fun x => @inline_const_genf _ (eC x)) | Var _ x => x diff --git a/src/Reflection/InlineInterp.v b/src/Reflection/InlineInterp.v index 7de1cc3a5..27811914c 100644 --- a/src/Reflection/InlineInterp.v +++ b/src/Reflection/InlineInterp.v @@ -25,8 +25,8 @@ Section language. Local Notation wff := (@wff base_type_code interp_base_type op). Local Notation wf := (@wf base_type_code interp_base_type op). - Local Hint Extern 1 => eapply interpf_SmartConst. - Local Hint Extern 1 => eapply interpf_SmartVarVar. + Local Hint Extern 1 => eapply interpf_SmartConstf. + Local Hint Extern 1 => eapply interpf_SmartVarVarf. Local Ltac t_fin := repeat match goal with @@ -71,7 +71,7 @@ Section language. (existT (fun t : base_type_code => (exprf (Syntax.Tbase t) * interp_base_type t)%type) t (x, x')) G -> interpf interp_op x = x') - : interp_type_gen_rel_pointwise interp_flat_type (fun _ => @eq _) + : interp_type_gen_rel_pointwise (fun _ => @eq _) (interp interp_op (inline_const e1)) (interp interp_op e2). Proof. @@ -86,7 +86,7 @@ Section language. Lemma Interp_InlineConst {t} (e : Expr t) (wf : Wf e) - : interp_type_gen_rel_pointwise interp_flat_type (fun _ => @eq _) + : interp_type_gen_rel_pointwise (fun _ => @eq _) (Interp interp_op (InlineConst e)) (Interp interp_op e). Proof. diff --git a/src/Reflection/InlineWf.v b/src/Reflection/InlineWf.v index 781643a8a..dd0fb08a3 100644 --- a/src/Reflection/InlineWf.v +++ b/src/Reflection/InlineWf.v @@ -32,8 +32,8 @@ Section language. Local Hint Constructors Syntax.wff. Local Hint Extern 1 => progress unfold List.In in *. Local Hint Resolve wff_in_impl_Proper. - Local Hint Resolve wff_SmartVar. - Local Hint Resolve wff_SmartConst. + Local Hint Resolve wff_SmartVarf. + Local Hint Resolve wff_SmartConstf. Local Ltac t_fin := repeat first [ intro @@ -42,9 +42,9 @@ Section language. | tauto | progress subst | solve [ auto with nocore - | eapply (@wff_SmartVarVar _ _ _ _ _ _ _ (_ * _)); auto - | eapply wff_SmartConst; eauto with nocore - | eapply wff_SmartVarVar; eauto with nocore ] + | eapply (@wff_SmartVarVarf _ _ _ _ _ _ _ (_ * _)); auto + | eapply wff_SmartConstf; eauto with nocore + | eapply wff_SmartVarVarf; eauto with nocore ] | progress simpl in * | constructor | solve [ eauto ] ]. diff --git a/src/Reflection/InputSyntax.v b/src/Reflection/InputSyntax.v index 15f7515b3..dd9cab21d 100644 --- a/src/Reflection/InputSyntax.v +++ b/src/Reflection/InputSyntax.v @@ -70,7 +70,7 @@ Section language. Fixpoint compilef {t} (e : @exprf (interp_flat_type_gen var) t) : @Syntax.exprf base_type_code interp_base_type op var t := match e in exprf t return @Syntax.exprf _ _ _ _ t with | Const _ x => Syntax.Const x - | Var _ x => Syntax.SmartVar x + | Var _ x => Syntax.SmartVarf x | Op _ _ op args => Syntax.Op op (@compilef _ args) | LetIn _ ex _ eC => Syntax.LetIn (@compilef _ ex) (fun x => @compilef _ (eC x)) | Pair _ ex _ ey => Syntax.Pair (@compilef _ ex) (@compilef _ ey) @@ -96,8 +96,9 @@ Section language. induction e; repeat match goal with | _ => reflexivity + | _ => progress unfold LetIn.Let_In | _ => progress simpl in * - | _ => rewrite interpf_SmartVar + | _ => rewrite interpf_SmartVarf | _ => rewrite <- surjective_pairing | _ => progress rewrite_hyp * | [ |- context[let (x, y) := ?v in _] ] @@ -105,18 +106,23 @@ Section language. end. Qed. - Lemma Compile_correct {t : flat_type} (e : @Expr t) - : Syntax.Interp interp_op (Compile e) = Interp interp_op e. + Lemma Compile_correct {t} (e : @Expr t) + : interp_type_gen_rel_pointwise (fun _ => @eq _) + (Syntax.Interp interp_op (Compile e)) + (Interp interp_op e). Proof. unfold Interp, Compile, Syntax.Interp; simpl. pose (e interp_flat_type) as E. repeat match goal with |- context[e ?f] => change (e f) with E end. - refine match E with - | Abs _ _ _ => fun x : Prop => x (* small inversions *) - | Return _ _ => _ - end. - apply compilef_correct. + clearbody E; clear e. + induction E. + { apply compilef_correct. } + { simpl; auto. } Qed. + + Lemma Compile_flat_correct {t : flat_type} (e : @Expr t) + : Syntax.Interp interp_op (Compile e) = Interp interp_op e. + Proof. exact (@Compile_correct t e). Qed. End compile_correct. End expr_param. End language. @@ -129,3 +135,4 @@ Global Arguments MatchPair {_ _ _ _ _ _} _ {_} _. Global Arguments Pair {_ _ _ _ _} _ {_} _. Global Arguments Return {_ _ _ _ _} _. Global Arguments Abs {_ _ _ _ _ _} _. +Global Arguments Compile {_ _ _ t} _ _. diff --git a/src/Reflection/InterpProofs.v b/src/Reflection/InterpProofs.v index 86e8190ea..880256f1d 100644 --- a/src/Reflection/InterpProofs.v +++ b/src/Reflection/InterpProofs.v @@ -16,10 +16,10 @@ Section language. Let interp_flat_type := interp_flat_type interp_base_type. Context (interp_op : forall src dst, op src dst -> interp_flat_type src -> interp_flat_type dst). - Lemma interpf_SmartVar t v - : Syntax.interpf interp_op (SmartVar (t:=t) v) = v. + Lemma interpf_SmartVarf t v + : Syntax.interpf interp_op (SmartVarf (t:=t) v) = v. Proof. - unfold SmartVar; induction t; + unfold SmartVarf; induction t; repeat match goal with | _ => reflexivity | _ => progress simpl in * @@ -28,11 +28,11 @@ Section language. end. Qed. - Lemma interpf_SmartConst {t t'} v x x' + Lemma interpf_SmartConstf {t t'} v x x' (Hin : List.In (existT (fun t : base_type_code => (exprf base_type_code interp_base_type op (Syntax.Tbase t) * interp_base_type t)%type) t (x, x')) - (flatten_binding_list (t := t') base_type_code (SmartConst v) v)) + (flatten_binding_list (t := t') base_type_code (SmartConstf v) v)) : interpf interp_op x = x'. Proof. clear -Hin. @@ -42,11 +42,11 @@ Section language. intuition (inversion_sigma; inversion_prod; subst; eauto). } Qed. - Lemma interpf_SmartVarVar {t t'} v x x' + Lemma interpf_SmartVarVarf {t t'} v x x' (Hin : List.In (existT (fun t : base_type_code => (exprf base_type_code interp_base_type op (Syntax.Tbase t) * interp_base_type t)%type) t (x, x')) - (flatten_binding_list (t := t') base_type_code (SmartVarVar v) v)) + (flatten_binding_list (t := t') base_type_code (SmartVarVarf v) v)) : interpf interp_op x = x'. Proof. clear -Hin. @@ -56,14 +56,14 @@ Section language. intuition (inversion_sigma; inversion_prod; subst; eauto). } Qed. - Lemma interpf_SmartVarVar_eq {t t'} v v' x x' + Lemma interpf_SmartVarVarf_eq {t t'} v v' x x' (Heq : v = v') (Hin : List.In (existT (fun t : base_type_code => (exprf base_type_code interp_base_type op (Syntax.Tbase t) * interp_base_type t)%type) t (x, x')) - (flatten_binding_list (t := t') base_type_code (SmartVarVar v') v)) + (flatten_binding_list (t := t') base_type_code (SmartVarVarf v') v)) : interpf interp_op x = x'. Proof. - subst; eapply interpf_SmartVarVar; eassumption. + subst; eapply interpf_SmartVarVarf; eassumption. Qed. End language. diff --git a/src/Reflection/InterpWf.v b/src/Reflection/InterpWf.v index ef1168555..c53389b8c 100644 --- a/src/Reflection/InterpWf.v +++ b/src/Reflection/InterpWf.v @@ -72,7 +72,7 @@ Section language. List.In (existT (fun t : base_type_code => (interp_base_type t * interp_base_type t)%type) t (x, x')%core) G -> x = x') (Rwf : wf G e1 e2) - : interp_type_gen_rel_pointwise2 (fun _ => eq) (interp e1) (interp e2). + : interp_type_gen_rel_pointwise (fun _ => eq) (interp e1) (interp e2). Proof. induction Rwf; simpl; repeat intro; simpl in *; subst; eauto. match goal with diff --git a/src/Reflection/Linearize.v b/src/Reflection/Linearize.v index 8dafbdc11..810d9115b 100644 --- a/src/Reflection/Linearize.v +++ b/src/Reflection/Linearize.v @@ -53,11 +53,11 @@ Section language. | Const _ x => Const x | Var _ x => Var x | Op _ _ op args - => under_letsf (@linearizef _ args) (fun args => LetIn (Op op (SmartVar args)) SmartVar) + => under_letsf (@linearizef _ args) (fun args => LetIn (Op op (SmartVarf args)) SmartVarf) | Pair A ex B ey => under_letsf (@linearizef _ ex) (fun x => under_letsf (@linearizef _ ey) (fun y => - SmartVar (t:=Prod A B) (x, y))) + SmartVarf (t:=Prod A B) (x, y))) end. Fixpoint linearize {t} (e : expr t) : expr t diff --git a/src/Reflection/LinearizeInterp.v b/src/Reflection/LinearizeInterp.v index 1da4ac851..3ee3960d5 100644 --- a/src/Reflection/LinearizeInterp.v +++ b/src/Reflection/LinearizeInterp.v @@ -25,12 +25,13 @@ Section language. Local Notation wff := (@wff base_type_code interp_base_type op). Local Notation wf := (@wf base_type_code interp_base_type op). - Local Hint Extern 1 => eapply interpf_SmartConst. - Local Hint Extern 1 => eapply interpf_SmartVarVar. + Local Hint Extern 1 => eapply interpf_SmartConstf. + Local Hint Extern 1 => eapply interpf_SmartVarVarf. Local Ltac t_fin := repeat match goal with | _ => reflexivity + | _ => progress unfold LetIn.Let_In | _ => progress simpl in * | _ => progress intros | _ => progress inversion_sigma @@ -69,7 +70,7 @@ Section language. Proof. clear. induction e; - repeat first [ progress rewrite ?interpf_under_letsf, ?interpf_SmartVar + repeat first [ progress rewrite ?interpf_under_letsf, ?interpf_SmartVarf | progress simpl | t_fin ]. Qed. @@ -77,7 +78,7 @@ Section language. Local Hint Resolve interpf_linearizef. Lemma interp_linearize {t} e - : interp_type_gen_rel_pointwise interp_flat_type (fun _ => @eq _) + : interp_type_gen_rel_pointwise (fun _ => @eq _) (interp interp_op (linearize (t:=t) e)) (interp interp_op e). Proof. @@ -86,7 +87,7 @@ Section language. Qed. Lemma Interp_Linearize {t} (e : Expr t) - : interp_type_gen_rel_pointwise interp_flat_type (fun _ => @eq _) + : interp_type_gen_rel_pointwise (fun _ => @eq _) (Interp interp_op (Linearize e)) (Interp interp_op e). Proof. diff --git a/src/Reflection/LinearizeWf.v b/src/Reflection/LinearizeWf.v index 8a7ebb7af..36c9efecb 100644 --- a/src/Reflection/LinearizeWf.v +++ b/src/Reflection/LinearizeWf.v @@ -161,7 +161,7 @@ Section language. Local Hint Constructors or. Local Hint Extern 1 => progress unfold List.In in *. Local Hint Resolve wff_in_impl_Proper. - Local Hint Resolve wff_SmartVar. + Local Hint Resolve wff_SmartVarf. Lemma wff_linearizef G {t} e1 e2 : @wff var1 var2 G t e1 e2 diff --git a/src/Reflection/Named/EstablishLiveness.v b/src/Reflection/Named/EstablishLiveness.v index b9d283013..b2be2d19b 100644 --- a/src/Reflection/Named/EstablishLiveness.v +++ b/src/Reflection/Named/EstablishLiveness.v @@ -62,7 +62,7 @@ Section language. | LetIn tx n ex _ eC => let lx := @compute_livenessf ctx _ ex prefix in let lx := merge_liveness lx (prefix ++ repeat live (count_pairs tx)) in - let ctx := extend ctx n (SmartVal _ (fun _ => lx) tx) in + let ctx := extend ctx n (SmartValf _ (fun _ => lx) tx) in @compute_livenessf ctx _ eC (prefix ++ repeat dead (count_pairs tx)) | Pair _ ex _ ey => merge_liveness (@compute_livenessf ctx _ ex prefix) diff --git a/src/Reflection/Named/Syntax.v b/src/Reflection/Named/Syntax.v index 0ea950325..e77947693 100644 --- a/src/Reflection/Named/Syntax.v +++ b/src/Reflection/Named/Syntax.v @@ -183,7 +183,7 @@ Global Arguments Var {_ _ _ _ _} _. Global Arguments SmartVar {_ _ _ _ _} _. Global Arguments SmartConst {_ _ _ _ _} _. Global Arguments Op {_ _ _ _ _ _} _ _. -Global Arguments LetIn {_ _ _ _} _ {_} _ {_} _. +Global Arguments LetIn {_ _ _ _} _ _ _ {_} _. Global Arguments Pair {_ _ _ _ _} _ {_} _. Global Arguments Return {_ _ _ _ _} _. Global Arguments Abs {_ _ _ _ _ _} _ _. diff --git a/src/Reflection/Reify.v b/src/Reflection/Reify.v index b10ab66f4..1a2f22eca 100644 --- a/src/Reflection/Reify.v +++ b/src/Reflection/Reify.v @@ -6,6 +6,7 @@ Require Import Crypto.Reflection.Syntax. Require Import Crypto.Reflection.InputSyntax. Require Import Crypto.Util.Tuple. Require Import Crypto.Util.Tactics. +Require Import Crypto.Util.LetIn. Require Import Crypto.Util.Notations. Class reify {varT} (var : varT) {eT} (e : eT) {T : Type} := Build_reify : T. @@ -67,11 +68,11 @@ Inductive reify_result_helper := | reification_unsuccessful. (** Override this to get a faster [reify_op] *) -Ltac base_reify_op op op_head := +Ltac base_reify_op op op_head expr := let r := constr:(_ : reify_op op op_head _ _) in type of r. -Ltac reify_op op op_head := - let t := base_reify_op op op_head in +Ltac reify_op op op_head expr := + let t := base_reify_op op op_head expr in constr:(op_info t). (** Change this with [Ltac reify_debug_level ::= constr:(1).] to get @@ -122,6 +123,10 @@ Ltac reifyf base_type_code interp_base_type op var e := let ex := reify_rec ex in let eC := reify_rec eC in mkLetIn ex eC + | dlet x := ?ex in @?eC x => + let ex := reify_rec ex in + let eC := reify_rec eC in + mkLetIn ex eC | pair ?a ?b => let a := reify_rec a in let b := reify_rec b in @@ -147,7 +152,7 @@ Ltac reifyf base_type_code interp_base_type op var e := let retv := match constr:(Set) with | _ => let retv := reifyf_var x mkVar in constr:(finished_value retv) | _ => let op_head := head x in - reify_op op op_head + reify_op op op_head x | _ => let c := mkConst t x in constr:(finished_value c) | _ => constr:(reification_unsuccessful) @@ -186,6 +191,20 @@ Ltac reifyf base_type_code interp_base_type op var e := let args := let a01 := mkPair a0 a1 in mkPair a01 a2 in mkOp (@Prod _ (@Prod _ a0T a1T) a2T) tR op_code args end + | 4%nat + => lazymatch x with + | ?f ?x0 ?x1 ?x2 ?x3 + => let a0T := (let t := type of x0 in reify_flat_type t) in + let a0 := reify_rec x0 in + let a1T := (let t := type of x1 in reify_flat_type t) in + let a1 := reify_rec x1 in + let a2T := (let t := type of x2 in reify_flat_type t) in + let a2 := reify_rec x2 in + let a3T := (let t := type of x3 in reify_flat_type t) in + let a3 := reify_rec x3 in + let args := let a01 := mkPair a0 a1 in let a012 := mkPair a01 a2 in mkPair a012 a3 in + mkOp (@Prod _ (@Prod _ (@Prod _ a0T a1T) a2T) a3T) tR op_code args + end | _ => cfail2 "Unsupported number of operation arguments in reifyf:"%string nargs end | reification_unsuccessful @@ -230,32 +249,56 @@ Ltac Reify' base_type_code interp_base_type op e := end. Ltac Reify base_type_code interp_base_type op e := let r := Reify' base_type_code interp_base_type op e in - constr:(InputSyntax.Compile base_type_code interp_base_type op r). + constr:(@InputSyntax.Compile base_type_code interp_base_type op _ r). Ltac lhs_of_goal := lazymatch goal with |- ?R ?LHS ?RHS => LHS end. Ltac rhs_of_goal := lazymatch goal with |- ?R ?LHS ?RHS => RHS end. -Ltac Reify_rhs base_type_code interp_base_type op interp_op := +Ltac Reify_rhs_gen Reify prove_interp_compile_correct interp_op try_tac := let rhs := rhs_of_goal in - let RHS := Reify base_type_code interp_base_type op rhs in - transitivity (Syntax.Interp interp_op RHS); + let RHS := Reify rhs in + let RHS' := (eval vm_compute in RHS) in + transitivity (Syntax.Interp interp_op RHS'); [ - | etransitivity; (* first we strip off the [InputSyntax.Compile] - bit; Coq is bad at inferring the type, so we - help it out by providing it *) + | transitivity (Syntax.Interp interp_op RHS); [ lazymatch goal with - | [ |- @Syntax.Interp ?base_type_code ?interp_base_type ?op ?interp_op (@Tflat _ ?t) (@Compile _ _ _ _ ?e) = _ ] - => exact (@InputSyntax.Compile_correct base_type_code interp_base_type op interp_op t e) - end - | ((* now we unfold the interpretation function, including the - parameterized bits; we assume that [hnf] is enough to unfold - the interpretation functions that we're parameterized - over. *) - lazymatch goal with - | [ |- @InputSyntax.Interp ?base_type_code ?interp_base_type ?op ?interp_op ?t ?e = _ ] - => let interp_base_type' := (eval hnf in interp_base_type) in - let interp_op' := (eval hnf in interp_op) in - change interp_base_type with interp_base_type'; - change interp_op with interp_op' + | [ |- ?R ?x ?y ] + => cut (x = y) end; - cbv iota beta delta [InputSyntax.Interp interp_type interp_type_gen interp_flat_type interp interpf]; simplify_projections; reflexivity) ] ]. + [ let H := fresh in + intro H; rewrite H; reflexivity + | apply f_equal; vm_compute; reflexivity ] + | etransitivity; (* first we strip off the [InputSyntax.Compile] + bit; Coq is bad at inferring the type, so we + help it out by providing it *) + [ prove_interp_compile_correct () + | try_tac + ltac:(fun _ + => (* now we unfold the interpretation function, + including the parameterized bits; we assume that + [hnf] is enough to unfold the interpretation + functions that we're parameterized over. *) + abstract ( + lazymatch goal with + | [ |- ?R (@InputSyntax.Interp ?base_type_code ?interp_base_type ?op ?interp_op ?t ?e) _ ] + => let interp_base_type' := (eval hnf in interp_base_type) in + let interp_op' := (eval hnf in interp_op) in + change interp_base_type with interp_base_type'; + change interp_op with interp_op' + end; + cbv iota beta delta [InputSyntax.Interp interp_type interp_type_gen interp_type_gen_hetero interp_flat_type interp interpf]; reflexivity)) ] ] ]. + +Ltac prove_compile_correct := + fun _ => lazymatch goal with + | [ |- @Syntax.Interp ?base_type_code ?interp_base_type ?op ?interp_op (@Tflat _ ?t) (@Compile _ _ _ _ ?e) = _ ] + => exact (@InputSyntax.Compile_flat_correct base_type_code interp_base_type op interp_op t e) + | [ |- interp_type_gen_rel_pointwise _ (@Syntax.Interp ?base_type_code ?interp_base_type ?op ?interp_op ?t (@Compile _ _ _ _ ?e)) _ ] + => exact (@InputSyntax.Compile_correct base_type_code interp_base_type op interp_op t e) + end. + +Ltac Reify_rhs base_type_code interp_base_type op interp_op := + Reify_rhs_gen + ltac:(Reify base_type_code interp_base_type op) + prove_compile_correct + interp_op + ltac:(fun tac => tac ()). diff --git a/src/Reflection/Syntax.v b/src/Reflection/Syntax.v index 37d585ab1..2f99be49c 100644 --- a/src/Reflection/Syntax.v +++ b/src/Reflection/Syntax.v @@ -1,6 +1,8 @@ (** * PHOAS Representation of Gallina *) Require Import Coq.Strings.String Coq.Classes.RelationClasses Coq.Classes.Morphisms. Require Import Crypto.Util.Tuple. +Require Import Crypto.Util.Decidable. +Require Import Crypto.Util.LetIn. Require Import Crypto.Util.Tactics. Require Import Crypto.Util.Notations. @@ -26,32 +28,50 @@ Section language. Notation "A -> B" := (Arrow A B) : ctype_scope. Local Coercion Tbase : base_type_code >-> flat_type. + Fixpoint tuple' T n := + match n with + | O => T + | S n' => (tuple' T n' * T)%ctype + end. + Definition tuple T n := + match n with + | O => T (* default value; no empty tuple *) + | S n' => tuple' T n' + end. + Section interp. Section type. - Context (interp_flat_type : flat_type -> Type). - Fixpoint interp_type_gen (t : type) := - match t with - | Tflat t => interp_flat_type t - | Arrow x y => (interp_flat_type x -> interp_type_gen y)%type - end. - Section rel. - Context (R : forall t, interp_flat_type t -> interp_flat_type t -> Prop). - Fixpoint interp_type_gen_rel_pointwise (t : type) - : interp_type_gen t -> interp_type_gen t -> Prop := + Section hetero. + Context (interp_src_type : base_type_code -> Type). + Context (interp_flat_type : flat_type -> Type). + Fixpoint interp_type_gen_hetero (t : type) := match t with - | Tflat t => R t - | Arrow _ y => fun f g => forall x, interp_type_gen_rel_pointwise y (f x) (g x) + | Tflat t => interp_flat_type t + | Arrow x y => (interp_src_type x -> interp_type_gen_hetero y)%type end. - Global Instance interp_type_gen_rel_pointwise_Reflexive {H : forall t, Reflexive (R t)} - : forall t, Reflexive (interp_type_gen_rel_pointwise t). - Proof. induction t; repeat intro; reflexivity. Qed. - Global Instance interp_type_gen_rel_pointwise_Symmetric {H : forall t, Symmetric (R t)} - : forall t, Symmetric (interp_type_gen_rel_pointwise t). - Proof. induction t; simpl; repeat intro; symmetry; eauto. Qed. - Global Instance interp_type_gen_rel_pointwise_Transitive {H : forall t, Transitive (R t)} - : forall t, Transitive (interp_type_gen_rel_pointwise t). - Proof. induction t; simpl; repeat intro; etransitivity; eauto. Qed. - End rel. + End hetero. + Section homogenous. + Context (interp_flat_type : flat_type -> Type). + Definition interp_type_gen := interp_type_gen_hetero interp_flat_type interp_flat_type. + Section rel. + Context (R : forall t, interp_flat_type t -> interp_flat_type t -> Prop). + Fixpoint interp_type_gen_rel_pointwise (t : type) + : interp_type_gen t -> interp_type_gen t -> Prop := + match t with + | Tflat t => R t + | Arrow _ y => fun f g => forall x, interp_type_gen_rel_pointwise y (f x) (g x) + end. + Global Instance interp_type_gen_rel_pointwise_Reflexive {H : forall t, Reflexive (R t)} + : forall t, Reflexive (interp_type_gen_rel_pointwise t). + Proof. induction t; repeat intro; reflexivity. Qed. + Global Instance interp_type_gen_rel_pointwise_Symmetric {H : forall t, Symmetric (R t)} + : forall t, Symmetric (interp_type_gen_rel_pointwise t). + Proof. induction t; simpl; repeat intro; symmetry; eauto. Qed. + Global Instance interp_type_gen_rel_pointwise_Transitive {H : forall t, Transitive (R t)} + : forall t, Transitive (interp_type_gen_rel_pointwise t). + Proof. induction t; simpl; repeat intro; etransitivity; eauto. Qed. + End rel. + End homogenous. End type. Section flat_type. Context (interp_base_type : base_type_code -> Type). @@ -61,6 +81,33 @@ Section language. | Prod x y => prod (interp_flat_type x) (interp_flat_type y) end. Definition interp_type := interp_type_gen interp_flat_type. + Fixpoint flat_interp_tuple' {T n} : interp_flat_type (tuple' T n) -> Tuple.tuple' (interp_flat_type T) n + := match n return interp_flat_type (tuple' T n) -> Tuple.tuple' (interp_flat_type T) n with + | O => fun x => x + | S n' => fun xy => (@flat_interp_tuple' _ n' (fst xy), snd xy) + end. + Definition flat_interp_tuple {T n} : interp_flat_type (tuple T n) -> Tuple.tuple (interp_flat_type T) n + := match n return interp_flat_type (tuple T n) -> Tuple.tuple (interp_flat_type T) n with + | O => fun _ => tt + | S n' => @flat_interp_tuple' T n' + end. + Fixpoint flat_interp_untuple' {T n} : Tuple.tuple' (interp_flat_type T) n -> interp_flat_type (tuple' T n) + := match n return Tuple.tuple' (interp_flat_type T) n -> interp_flat_type (tuple' T n) with + | O => fun x => x + | S n' => fun xy => (@flat_interp_untuple' _ n' (fst xy), snd xy) + end. + Lemma flat_interp_untuple'_tuple' {T n v} + : @flat_interp_untuple' T n (flat_interp_tuple' v) = v. + Proof. induction n; [ reflexivity | simpl; rewrite IHn; destruct v; reflexivity ]. Qed. + Lemma flat_interp_untuple'_tuple {T n v} + : flat_interp_untuple' (@flat_interp_tuple T (S n) v) = v. + Proof. apply flat_interp_untuple'_tuple'. Qed. + Lemma flat_interp_tuple'_untuple' {T n v} + : @flat_interp_tuple' T n (flat_interp_untuple' v) = v. + Proof. induction n; [ reflexivity | simpl; rewrite IHn; destruct v; reflexivity ]. Qed. + Lemma flat_interp_tuple_untuple' {T n v} + : @flat_interp_tuple T (S n) (flat_interp_untuple' v) = v. + Proof. apply flat_interp_tuple'_untuple'. Qed. Section rel. Context (R : forall t, interp_base_type t -> interp_base_type t -> Prop). Fixpoint interp_flat_type_rel_pointwise (t : flat_type) @@ -76,28 +123,52 @@ Section language. End flat_type. Section rel_pointwise2. Section type. - Context (interp_flat_type1 interp_flat_type2 : flat_type -> Type) - (R : forall t, interp_flat_type1 t -> interp_flat_type2 t -> Prop). - - Fixpoint interp_type_gen_rel_pointwise2 (t : type) - : interp_type_gen interp_flat_type1 t -> interp_type_gen interp_flat_type2 t -> Prop - := match t with - | Tflat t => R t - | Arrow src dst => @respectful_hetero _ _ _ _ (R src) (fun _ _ => interp_type_gen_rel_pointwise2 dst) - end. + Section hetero. + Context (interp_src1 interp_src2 : base_type_code -> Type) + (interp_flat_type1 interp_flat_type2 : flat_type -> Type) + (Rsrc : forall t, interp_src1 t -> interp_src2 t -> Prop) + (R : forall t, interp_flat_type1 t -> interp_flat_type2 t -> Prop). + + Fixpoint interp_type_gen_rel_pointwise2_hetero (t : type) + : interp_type_gen_hetero interp_src1 interp_flat_type1 t + -> interp_type_gen_hetero interp_src2 interp_flat_type2 t + -> Prop + := match t with + | Tflat t => R t + | Arrow src dst => @respectful_hetero _ _ _ _ (Rsrc src) (fun _ _ => interp_type_gen_rel_pointwise2_hetero dst) + end. + End hetero. + Section homogenous. + Context (interp_flat_type1 interp_flat_type2 : flat_type -> Type) + (R : forall t, interp_flat_type1 t -> interp_flat_type2 t -> Prop). + + Definition interp_type_gen_rel_pointwise2 + := interp_type_gen_rel_pointwise2_hetero interp_flat_type1 interp_flat_type2 + interp_flat_type1 interp_flat_type2 + R R. + End homogenous. End type. Section flat_type. - Context (interp_base_type1 interp_base_type2 : base_type_code -> Type) - (R : forall t, interp_base_type1 t -> interp_base_type2 t -> Prop). - Fixpoint interp_flat_type_rel_pointwise2 (t : flat_type) - : interp_flat_type interp_base_type1 t -> interp_flat_type interp_base_type2 t -> Prop - := match t with - | Tbase t => R t - | Prod x y => fun a b => interp_flat_type_rel_pointwise2 x (fst a) (fst b) - /\ interp_flat_type_rel_pointwise2 y (snd a) (snd b) - end. - Definition interp_type_rel_pointwise2 - := interp_type_gen_rel_pointwise2 _ _ interp_flat_type_rel_pointwise2. + Context (interp_base_type1 interp_base_type2 : base_type_code -> Type). + Section gen_prop. + Context (P : Type) + (and : P -> P -> P) + (R : forall t, interp_base_type1 t -> interp_base_type2 t -> P). + + Fixpoint interp_flat_type_rel_pointwise2_gen_Prop (t : flat_type) + : interp_flat_type interp_base_type1 t -> interp_flat_type interp_base_type2 t -> P + := match t with + | Tbase t => R t + | Prod x y => fun a b => and (interp_flat_type_rel_pointwise2_gen_Prop x (fst a) (fst b)) + (interp_flat_type_rel_pointwise2_gen_Prop y (snd a) (snd b)) + end. + End gen_prop. + + Definition interp_flat_type_rel_pointwise2 + := @interp_flat_type_rel_pointwise2_gen_Prop Prop and. + + Definition interp_type_rel_pointwise2 R + := interp_type_gen_rel_pointwise2 _ _ (interp_flat_type_rel_pointwise2 R). End flat_type. End rel_pointwise2. End interp. @@ -145,24 +216,53 @@ Section language. (@smart_interp_flat_map f g h pair A (fst v)) (@smart_interp_flat_map f g h pair B (snd v)) end. - Fixpoint SmartVal {T} (val : forall t : base_type_code, T t) t : interp_flat_type_gen T t + Fixpoint smart_interp_map_hetero {f g g'} + (h : forall x, f x -> g (Tflat (Tbase x))) + (pair : forall A B, g (Tflat A) -> g (Tflat B) -> g (Prod A B)) + (abs : forall A B, (g' A -> g B) -> g (Arrow A B)) + {t} + : interp_type_gen_hetero g' (interp_flat_type_gen f) t -> g t + := match t return interp_type_gen_hetero g' (interp_flat_type_gen f) t -> g t with + | Tflat _ => @smart_interp_flat_map f (fun x => g (Tflat x)) h pair _ + | Arrow A B => fun v => abs _ _ + (fun x => @smart_interp_map_hetero f g g' h pair abs B (v x)) + end. + Fixpoint smart_interp_map {f g} + (h : forall x, f x -> g (Tflat (Tbase x))) + (h' : forall x, g (Tflat (Tbase x)) -> f x) + (pair : forall A B, g (Tflat A) -> g (Tflat B) -> g (Prod A B)) + (abs : forall A B, (g (Tflat (Tbase A)) -> g B) -> g (Arrow A B)) + {t} + : interp_type_gen (interp_flat_type_gen f) t -> g t + := match t return interp_type_gen (interp_flat_type_gen f) t -> g t with + | Tflat _ => @smart_interp_flat_map f (fun x => g (Tflat x)) h pair _ + | Arrow A B => fun v => abs _ _ + (fun x => @smart_interp_map f g h h' pair abs B (v (h' _ x))) + end. + Fixpoint SmartValf {T} (val : forall t : base_type_code, T t) t : interp_flat_type_gen T t := match t return interp_flat_type_gen T t with | Tbase _ => val _ - | Prod A B => (@SmartVal T val A, @SmartVal T val B) + | Prod A B => (@SmartValf T val A, @SmartValf T val B) end. (** [SmartVar] is like [Var], except that it inserts pair-projections and [Pair] as necessary to handle [flat_type], and not just [base_type_code] *) - Definition SmartVar {t} : interp_flat_type_gen var t -> exprf t + Definition SmartVarf {t} : interp_flat_type_gen var t -> exprf t := @smart_interp_flat_map var exprf (fun t => Var) (fun A B x y => Pair x y) t. - Definition SmartVarMap {var var'} (f : forall t, var t -> var' t) {t} + Definition SmartVarfMap {var var'} (f : forall t, var t -> var' t) {t} : interp_flat_type_gen var t -> interp_flat_type_gen var' t := @smart_interp_flat_map var (interp_flat_type_gen var') f (fun A B x y => pair x y) t. - Definition SmartVarVar {t} : interp_flat_type_gen var t -> interp_flat_type_gen exprf t - := SmartVarMap (fun t => Var). - Definition SmartConst {t} : interp_flat_type t -> interp_flat_type_gen exprf t - := SmartVarMap (fun t => Const (t:=t)). + Definition SmartVarMap {var var'} (f : forall t, var t -> var' t) (f' : forall t, var' t -> var t) {t} + : interp_type_gen (interp_flat_type_gen var) t -> interp_type_gen (interp_flat_type_gen var') t + := @smart_interp_map var (interp_type_gen (interp_flat_type_gen var')) f f' (fun A B x y => pair x y) (fun A B f x => f x) t. + Definition SmartVarMap_hetero {vars vars' var var'} (f : forall t, var t -> var' t) (f' : forall t, vars' t -> vars t) {t} + : interp_type_gen_hetero vars (interp_flat_type_gen var) t -> interp_type_gen_hetero vars' (interp_flat_type_gen var') t + := @smart_interp_map_hetero var (interp_type_gen_hetero vars' (interp_flat_type_gen var')) vars f (fun A B x y => pair x y) (fun A B f x => f (f' _ x)) t. + Definition SmartVarVarf {t} : interp_flat_type_gen var t -> interp_flat_type_gen exprf t + := SmartVarfMap (fun t => Var). + Definition SmartConstf {t} : interp_flat_type t -> interp_flat_type_gen exprf t + := SmartVarfMap (fun t => Const (t:=t)). End expr. Definition Expr (t : type) := forall var, @expr var t. @@ -175,7 +275,7 @@ Section language. | Const _ x => x | Var _ x => x | Op _ _ op args => @interp_op _ _ op (@interpf _ args) - | LetIn _ ex _ eC => let x := @interpf _ ex in @interpf _ (eC x) + | LetIn _ ex _ eC => dlet x := @interpf _ ex in @interpf _ (eC x) | Pair _ ex _ ey => (@interpf _ ex, @interpf _ ey) end. Fixpoint interp {t} (e : @expr interp_type t) : interp_type t @@ -256,32 +356,56 @@ Section language. Axiom Wf_admitted : forall {t} (E:Expr t), @Wf t E. End expr_param. End language. +Global Arguments tuple' {_}%type_scope _%ctype_scope _%nat_scope. +Global Arguments tuple {_}%type_scope _%ctype_scope _%nat_scope. Global Arguments Prod {_}%type_scope (_ _)%ctype_scope. Global Arguments Arrow {_}%type_scope (_ _)%ctype_scope. Global Arguments Tbase {_}%type_scope _%ctype_scope. Ltac admit_Wf := apply Wf_admitted. +Scheme Equality for flat_type. +Scheme Equality for type. + +Global Instance dec_eq_flat_type {base_type_code} `{DecidableRel (@eq base_type_code)} + : DecidableRel (@eq (flat_type base_type_code)). +Proof. + repeat intro; hnf; decide equality; apply dec; auto. +Defined. +Global Instance dec_eq_type {base_type_code} `{DecidableRel (@eq base_type_code)} + : DecidableRel (@eq (type base_type_code)). +Proof. + repeat intro; hnf; decide equality; apply dec; typeclasses eauto. +Defined. + Global Arguments Const {_ _ _ _ _} _. Global Arguments Var {_ _ _ _ _} _. -Global Arguments SmartVar {_ _ _ _ _} _. -Global Arguments SmartVal {_} T _ t. -Global Arguments SmartVarVar {_ _ _ _ _} _. -Global Arguments SmartVarMap {_ _ _} _ {_} _. -Global Arguments SmartConst {_ _ _ _ _} _. +Global Arguments SmartVarf {_ _ _ _ _} _. +Global Arguments SmartValf {_} T _ t. +Global Arguments SmartVarVarf {_ _ _ _ _} _. +Global Arguments SmartVarfMap {_ _ _} _ {_} _. +Global Arguments SmartVarMap_hetero {_ _ _ _ _} _ _ {_} _. +Global Arguments SmartVarMap {_ _ _} _ _ {_} _. +Global Arguments SmartConstf {_ _ _ _ _} _. Global Arguments Op {_ _ _ _ _ _} _ _. Global Arguments LetIn {_ _ _ _ _} _ {_} _. Global Arguments Pair {_ _ _ _ _} _ {_} _. Global Arguments Return {_ _ _ _ _} _. Global Arguments Abs {_ _ _ _ _ _} _. +Global Arguments flat_interp_tuple' {_ _ _ _} _. +Global Arguments flat_interp_tuple {_ _ _ _} _. +Global Arguments flat_interp_untuple' {_ _ _ _} _. Global Arguments interp_type_rel_pointwise2 {_ _ _} R {t} _ _. +Global Arguments interp_type_gen_rel_pointwise2_hetero {_ _ _ _ _} Rsrc R {t} _ _. Global Arguments interp_type_gen_rel_pointwise2 {_ _ _} R {t} _ _. +Global Arguments interp_flat_type_rel_pointwise2_gen_Prop {_ _ _ P} and R {t} _ _. Global Arguments interp_flat_type_rel_pointwise2 {_ _ _} R {t} _ _. Global Arguments mapf_interp_flat_type {_ _ _} _ {t} _. +Global Arguments interp_type_gen_hetero {_} _ _ _. Global Arguments interp_type_gen {_} _ _. Global Arguments interp_flat_type {_} _ _. Global Arguments interp_type_rel_pointwise {_} _ _ {_} _ _. -Global Arguments interp_type_gen_rel_pointwise {_} _ _ {_} _ _. +Global Arguments interp_type_gen_rel_pointwise {_ _} _ {_} _ _. Global Arguments interp_flat_type_rel_pointwise {_} _ _ {_} _ _. Global Arguments interp_type {_} _ _. Global Arguments wff {_ _ _ _ _} G {t} _ _. diff --git a/src/Reflection/TestCase.v b/src/Reflection/TestCase.v index 844fea592..a7e2146a6 100644 --- a/src/Reflection/TestCase.v +++ b/src/Reflection/TestCase.v @@ -96,23 +96,20 @@ Lemma base_type_eq_semidec_is_dec : forall t1 t2, Proof. intros t1 t2; destruct t1, t2; simpl; intros; congruence. Qed. -Definition op_beq t1 tR : op t1 tR -> op t1 tR -> option pointed_Prop - := fun x y => match x, y with - | Add, Add => Some trivial - | Add, _ => None - | Mul, Mul => Some trivial - | Mul, _ => None - | Sub, Sub => Some trivial - | Sub, _ => None - end. +Definition op_beq t1 tR : op t1 tR -> op t1 tR -> reified_Prop + := fun x y => match x, y return bool with + | Add, Add => true + | Add, _ => false + | Mul, Mul => true + | Mul, _ => false + | Sub, Sub => true + | Sub, _ => false + end. Lemma op_beq_bl t1 tR (x y : op t1 tR) - : match op_beq t1 tR x y with - | Some p => to_prop p - | None => False - end -> x = y. + : to_prop (op_beq t1 tR x y) -> x = y. Proof. destruct x; simpl; - refine match y with Add => _ | _ => _ end; tauto. + refine match y with Add => _ | _ => _ end; simpl; tauto. Qed. Ltac reflect_Wf := WfReflective.reflect_Wf base_type_eq_semidec_is_dec op_beq_bl. diff --git a/src/Reflection/WfProofs.v b/src/Reflection/WfProofs.v index 1e8eed632..acc72cc76 100644 --- a/src/Reflection/WfProofs.v +++ b/src/Reflection/WfProofs.v @@ -70,14 +70,14 @@ Section language. Local Hint Extern 1 => progress unfold List.In in *. Local Hint Resolve wff_in_impl_Proper. - Lemma wff_SmartVar {t} x1 x2 - : @wff var1 var2 (flatten_binding_list base_type_code x1 x2) t (SmartVar x1) (SmartVar x2). + Lemma wff_SmartVarf {t} x1 x2 + : @wff var1 var2 (flatten_binding_list base_type_code x1 x2) t (SmartVarf x1) (SmartVarf x2). Proof. - unfold SmartVar. + unfold SmartVarf. induction t; simpl; constructor; eauto. Qed. - Local Hint Resolve wff_SmartVar. + Local Hint Resolve wff_SmartVarf. Lemma wff_Const_eta G {A B} v1 v2 : @wff var1 var2 G (Prod A B) (Const v1) (Const v2) @@ -100,30 +100,30 @@ Section language. Local Hint Resolve wff_Const_eta_fst wff_Const_eta_snd. - Lemma wff_SmartConst G {t t'} v1 v2 x1 x2 + Lemma wff_SmartConstf G {t t'} v1 v2 x1 x2 (Hin : List.In (existT (fun t : base_type_code => (@exprf var1 t * @exprf var2 t)%type) t (x1, x2)) - (flatten_binding_list base_type_code (SmartConst v1) (SmartConst v2))) + (flatten_binding_list base_type_code (SmartConstf v1) (SmartConstf v2))) (Hwf : @wff var1 var2 G t' (Const v1) (Const v2)) : @wff var1 var2 G t x1 x2. Proof. induction t'. simpl in *. { intuition (inversion_sigma; inversion_prod; subst; eauto). } - { unfold SmartConst in *; simpl in *. + { unfold SmartConstf in *; simpl in *. apply List.in_app_iff in Hin. intuition (inversion_sigma; inversion_prod; subst; eauto). } Qed. - Local Hint Resolve wff_SmartConst. + Local Hint Resolve wff_SmartConstf. - Lemma wff_SmartVarVar G {t t'} v1 v2 x1 x2 + Lemma wff_SmartVarVarf G {t t'} v1 v2 x1 x2 (Hin : List.In (existT (fun t : base_type_code => (exprf t * exprf t)%type) t (x1, x2)) - (flatten_binding_list base_type_code (SmartVarVar v1) (SmartVarVar v2))) + (flatten_binding_list base_type_code (SmartVarVarf v1) (SmartVarVarf v2))) : @wff var1 var2 (flatten_binding_list base_type_code (t:=t') v1 v2 ++ G) t x1 x2. Proof. revert dependent G; induction t'; intros. simpl in *. { intuition (inversion_sigma; inversion_prod; subst; simpl; eauto). constructor; eauto. } - { unfold SmartVarVar in *; simpl in *. + { unfold SmartVarVarf in *; simpl in *. apply List.in_app_iff in Hin. intuition (inversion_sigma; inversion_prod; subst; eauto). { rewrite <- !List.app_assoc; eauto. } } diff --git a/src/Reflection/WfReflective.v b/src/Reflection/WfReflective.v index d68ed53ac..8a8eef38f 100644 --- a/src/Reflection/WfReflective.v +++ b/src/Reflection/WfReflective.v @@ -45,13 +45,13 @@ - [op_beq_bl : forall t1 tR x y, prop_of_option (op_beq t1 tR x y) -> x = y] for some [op_beq : forall t1 tR, op t1 tR -> op t1 tR - -> option pointed_Prop] *) + -> reified_Prop] *) Require Import Coq.Arith.Arith Coq.Logic.Eqdep_dec. Require Import Crypto.Reflection.Syntax. Require Import Crypto.Reflection.WfReflectiveGen. Require Import Crypto.Util.Notations Crypto.Util.Tactics Crypto.Util.Option Crypto.Util.Sigma Crypto.Util.Prod Crypto.Util.Decidable Crypto.Util.ListUtil. -Require Export Crypto.Util.PointedProp. (* export for the [bool >-> option pointed_Prop] coercion *) +Require Export Crypto.Util.PartiallyReifiedProp. (* export for the [bool >-> reified_Prop] coercion *) Require Export Crypto.Util.FixCoqMistakes. @@ -74,8 +74,8 @@ Section language. [pointed_Prop] internally because we need to talk about equality of things of type [var t], for [var : base_type_code -> Type]. It does not hurt to allow extra generality in [op_beq]. *) - Context (op_beq : forall t1 tR, op t1 tR -> op t1 tR -> option pointed_Prop). - Context (op_beq_bl : forall t1 tR x y, prop_of_option (op_beq t1 tR x y) -> x = y). + Context (op_beq : forall t1 tR, op t1 tR -> op t1 tR -> reified_Prop). + Context (op_beq_bl : forall t1 tR x y, to_prop (op_beq t1 tR x y) -> x = y). Context {var1 var2 : base_type_code -> Type}. Local Notation eP := (fun t => var1 (fst t) * var2 (snd t))%type (only parsing). @@ -90,8 +90,8 @@ Section language. Local Notation exprf := (@exprf base_type_code interp_base_type op). Local Notation expr := (@expr base_type_code interp_base_type op). Local Notation duplicate_type := (@duplicate_type base_type_code var1 var2). - Local Notation reflect_wffT := (@reflect_wffT base_type_code interp_base_type interp_base_type base_type_eq_semidec_transparent op (fun _ => eq) op_beq var1 var2). - Local Notation reflect_wfT := (@reflect_wfT base_type_code interp_base_type interp_base_type base_type_eq_semidec_transparent op (fun _ => eq) op_beq var1 var2). + Local Notation reflect_wffT := (@reflect_wffT base_type_code interp_base_type interp_base_type base_type_eq_semidec_transparent op (fun _ => rEq) op_beq var1 var2). + Local Notation reflect_wfT := (@reflect_wfT base_type_code interp_base_type interp_base_type base_type_eq_semidec_transparent op (fun _ => rEq) op_beq var1 var2). Local Notation flat_type_eq_semidec_transparent := (@flat_type_eq_semidec_transparent base_type_code base_type_eq_semidec_transparent). Local Notation preflatten_binding_list2 := (@preflatten_binding_list2 base_type_code base_type_eq_semidec_transparent var1 var2). Local Notation type_eq_semidec_transparent := (@type_eq_semidec_transparent base_type_code base_type_eq_semidec_transparent). @@ -105,25 +105,25 @@ Section language. Local Ltac handle_op_beq_correct := repeat match goal with - | [ H : op_beq ?t1 ?tR ?x ?y = _ |- _ ] - => let H' := fresh in - pose proof (op_beq_bl t1 tR x y) as H'; rewrite H in H'; clear H + | [ H : to_prop (op_beq ?t1 ?tR ?x ?y) |- _ ] + => apply op_beq_bl in H end. Local Ltac t_step := match goal with - | _ => progress unfold eq_type_and_var, op_beq', flatten_binding_list2, WfReflectiveGen.preflatten_binding_list2, option_map, and_option_pointed_Prop, eq_semidec_and_gen in * + | [ |- True ] => exact I + | _ => progress cbv beta delta [eq_type_and_var op_beq' flatten_binding_list2 WfReflectiveGen.preflatten_binding_list2 option_map eq_semidec_and_gen] in * | _ => progress simpl in * - | _ => progress break_match - | [ H : interp_flat_type_rel_pointwise2 (fun _ => eq) _ _ |- _ ] - => apply interp_flat_type_rel_pointwise2_eq in H | _ => progress subst + | _ => progress break_innermost_match_step | _ => progress inversion_option - | _ => progress inversion_pointed_Prop + | _ => progress inversion_prod + | _ => progress inversion_reified_Prop | _ => congruence | _ => tauto | _ => progress intros | _ => progress handle_op_beq_correct | _ => progress specialize_by tauto + | [ v : ex _ |- _ ] => destruct v | [ v : sigT _ |- _ ] => destruct v | [ v : prod _ _ |- _ ] => destruct v | [ H : forall x x', _ |- wff (flatten_binding_list _ ?x1 ?x2 ++ _)%list _ _ ] @@ -131,6 +131,7 @@ Section language. | [ H : forall x x', _ |- wf (existT _ _ (?x1, ?x2) :: _)%list _ _ ] => specialize (H x1 x2) | [ H : and _ _ |- _ ] => destruct H + | [ H : to_prop (_ /\ _) |- _ ] => apply to_prop_and_reified_Prop in H; destruct H | [ H : context[duplicate_type (_ ++ _)%list] |- _ ] => rewrite duplicate_type_app in H | [ H : context[List.length (duplicate_type _)] |- _ ] @@ -144,25 +145,28 @@ Section language. | [ H : base_type_eq_semidec_transparent _ _ = None |- False ] => eapply duplicate_type_not_in; eassumption | [ H : List.nth_error _ _ = Some _ |- _ ] => apply List.nth_error_In in H | [ H : List.In _ (duplicate_type _) |- _ ] => eapply duplicate_type_in in H; [ | eassumption.. ] - | [ H : context[match _ with _ => _ end] |- _ ] => revert H; progress break_match + | [ H : context[match _ with _ => _ end] |- _ ] => revert H; progress break_innermost_match | [ |- wff _ _ _ ] => constructor | [ |- wf _ _ _ ] => constructor - | _ => progress unfold and_pointed_Prop in * + | _ => progress unfold and_reified_Prop in * end. Local Ltac t := repeat t_step. Fixpoint reflect_wff (G : list (sigT (fun t => var1 t * var2 t)%type)) {t1 t2 : flat_type} (e1 : @exprf (fun t => nat * var1 t)%type t1) (e2 : @exprf (fun t => nat * var2 t)%type t2) {struct e1} - : match reflect_wffT (duplicate_type G) e1 e2, flat_type_eq_semidec_transparent t1 t2 with - | Some reflective_obligation, Some p + : let reflective_obligation := reflect_wffT (duplicate_type G) e1 e2 in + match flat_type_eq_semidec_transparent t1 t2 with + | Some p => to_prop reflective_obligation -> @wff base_type_code interp_base_type op var1 var2 G t2 (eq_rect _ exprf (unnatize_exprf (List.length G) e1) _ p) (unnatize_exprf (List.length G) e2) - | _, _ => True + | None => True end. Proof. + cbv zeta. destruct e1 as [ | | ? ? ? args | tx ex tC eC | ? ex ? ey ], - e2 as [ | | ? ? ? args' | tx' ex' tC' eC' | ? ex' ? ey' ]; simpl; try solve [ exact I ]; + e2 as [ | | ? ? ? args' | tx' ex' tC' eC' | ? ex' ? ey' ]; simpl; + try solve [ break_match; solve [ exact I | intros [] ] ]; [ clear reflect_wff | clear reflect_wff | specialize (reflect_wff G _ _ args args') @@ -187,11 +191,12 @@ Section language. Fixpoint reflect_wf (G : list (sigT (fun t => var1 t * var2 t)%type)) {t1 t2 : type} (e1 : @expr (fun t => nat * var1 t)%type t1) (e2 : @expr (fun t => nat * var2 t)%type t2) - : match reflect_wfT (duplicate_type G) e1 e2, type_eq_semidec_transparent t1 t2 with - | Some reflective_obligation, Some p + : let reflective_obligation := reflect_wfT (duplicate_type G) e1 e2 in + match type_eq_semidec_transparent t1 t2 with + | Some p => to_prop reflective_obligation -> @wf base_type_code interp_base_type op var1 var2 G t2 (eq_rect _ expr (unnatize_expr (List.length G) e1) _ p) (unnatize_expr (List.length G) e2) - | _, _ => True + | None => True end. Proof. destruct e1 as [ t1 e1 | tx tR f ], @@ -219,15 +224,15 @@ Section Wf. (base_type_eq_semidec_transparent : forall t1 t2 : base_type_code, option (t1 = t2)) (base_type_eq_semidec_is_dec : forall t1 t2, base_type_eq_semidec_transparent t1 t2 = None -> t1 <> t2) (op : flat_type base_type_code -> flat_type base_type_code -> Type) - (op_beq : forall t1 tR, op t1 tR -> op t1 tR -> option pointed_Prop) - (op_beq_bl : forall t1 tR x y, prop_of_option (op_beq t1 tR x y) -> x = y) + (op_beq : forall t1 tR, op t1 tR -> op t1 tR -> reified_Prop) + (op_beq_bl : forall t1 tR x y, to_prop (op_beq t1 tR x y) -> x = y) {t : type base_type_code} (e : @Expr base_type_code interp_base_type op t). (** Leads to smaller proofs, but is less generally applicable *) Theorem reflect_Wf_unnatize : (forall var1 var2, - prop_of_option (@reflect_wfT base_type_code interp_base_type interp_base_type base_type_eq_semidec_transparent op (fun _ => eq) op_beq var1 var2 nil t t (e _) (e _))) + to_prop (@reflect_wfT base_type_code interp_base_type interp_base_type base_type_eq_semidec_transparent op (fun _ => rEq) op_beq var1 var2 nil t t (e _) (e _))) -> Wf (fun var => unnatize_expr 0 (e (fun t => (nat * var t)%type))). Proof. intros H var1 var2; specialize (H var1 var2). @@ -241,7 +246,7 @@ Section Wf. Theorem reflect_Wf : (forall var1 var2, unnatize_expr 0 (e (fun t => (nat * var1 t)%type)) = e _ - /\ prop_of_option (@reflect_wfT base_type_code interp_base_type interp_base_type base_type_eq_semidec_transparent op (fun _ => eq) op_beq var1 var2 nil t t (e _) (e _))) + /\ to_prop (@reflect_wfT base_type_code interp_base_type interp_base_type base_type_eq_semidec_transparent op (fun _ => rEq) op_beq var1 var2 nil t t (e _) (e _))) -> Wf e. Proof. intros H var1 var2. @@ -256,8 +261,22 @@ Ltac generalize_reflect_Wf base_type_eq_semidec_is_dec op_beq_bl := | [ |- @Wf ?base_type_code ?interp_base_type ?op ?t ?e ] => generalize (@reflect_Wf_unnatize base_type_code interp_base_type _ base_type_eq_semidec_is_dec op _ op_beq_bl t e) end. -Ltac use_reflect_Wf := vm_compute; let H := fresh in intro H; apply H; clear H. -Ltac fin_reflect_Wf := repeat constructor. +Ltac use_reflect_Wf := + let H := fresh in + intro H; + lazymatch type of H with + | ?A -> ?B + => cut A + end; + [ abstract vm_cast_no_check H + | clear H ]. +Ltac fin_reflect_Wf := + intros; + lazymatch goal with + | [ |- to_prop ?P ] + => replace P with (trueify P) by abstract vm_cast_no_check (eq_refl P) + end; + apply trueify_true. (** The tactic [reflect_Wf] is the main tactic of this file, used to prove [Syntax.Wf] goals *) Ltac reflect_Wf base_type_eq_semidec_is_dec op_beq_bl := diff --git a/src/Reflection/WfReflectiveGen.v b/src/Reflection/WfReflectiveGen.v index 4935134d8..0d961ec97 100644 --- a/src/Reflection/WfReflectiveGen.v +++ b/src/Reflection/WfReflectiveGen.v @@ -50,7 +50,7 @@ Require Import Coq.Arith.Arith Coq.Logic.Eqdep_dec. Require Import Crypto.Reflection.Syntax. Require Import Crypto.Util.Notations Crypto.Util.Tactics Crypto.Util.Option Crypto.Util.Sigma Crypto.Util.Prod Crypto.Util.Decidable Crypto.Util.ListUtil. -Require Export Crypto.Util.PointedProp. (* export for the [bool >-> option pointed_Prop] coercion *) +Require Export Crypto.Util.PartiallyReifiedProp. (* export for the [bool >-> reified_Prop] coercion *) Require Export Crypto.Util.FixCoqMistakes. @@ -67,15 +67,15 @@ Section language. (base_type_eq_semidec_transparent : forall t1 t2 : base_type_code, option (t1 = t2)) (base_type_eq_semidec_is_dec : forall t1 t2, base_type_eq_semidec_transparent t1 t2 = None -> t1 <> t2) (op : flat_type base_type_code -> flat_type base_type_code -> Type) - (R : forall t, interp_base_type1 t -> interp_base_type2 t -> Prop). + (R : forall t, interp_flat_type interp_base_type1 t -> interp_flat_type interp_base_type2 t -> reified_Prop). (** In practice, semi-deciding equality of operators should either return [Some trivial] or [None], and not make use of the generality of [pointed_Prop]. However, we need to use [pointed_Prop] internally because we need to talk about equality of things of type [var t], for [var : base_type_code -> Type]. It does not hurt to allow extra generality in [op_beq]. *) - Context (op_beq : forall t1 tR, op t1 tR -> op t1 tR -> option pointed_Prop). - Context (op_beq_bl : forall t1 tR x y, prop_of_option (op_beq t1 tR x y) -> x = y). + Context (op_beq : forall t1 tR, op t1 tR -> op t1 tR -> reified_Prop). + Context (op_beq_bl : forall t1 tR x y, to_prop (op_beq t1 tR x y) -> x = y). Context {var1 var2 : base_type_code -> Type}. Local Notation eP := (fun t => var1 (fst t) * var2 (snd t))%type (only parsing). @@ -151,13 +151,13 @@ Section language. { rewrite base_type_eq_semidec_transparent_refl; rewrite_hyp !*; reflexivity. } Qed. - Definition op_beq' t1 tR t1' tR' (x : op t1 tR) (y : op t1' tR') : option pointed_Prop + Definition op_beq' t1 tR t1' tR' (x : op t1 tR) (y : op t1' tR') : reified_Prop := match flat_type_eq_semidec_transparent t1 t1', flat_type_eq_semidec_transparent tR tR' with | Some p, Some q => match p in (_ = t1'), q in (_ = tR') return op t1' tR' -> _ with | eq_refl, eq_refl => fun y => op_beq _ _ x y end y - | _, _ => None + | _, _ => rFalse end. (** While [Syntax.wff] is parameterized over a list of [sigT (fun t @@ -178,27 +178,27 @@ Section language. about [Syntax.wff] itself. *) Definition eq_semidec_and_gen {T} (semidec : forall x y : T, option (x = y)) - (t t' : T) (f g : T -> Type) (R : forall t, f t -> g t -> Prop) + (t t' : T) (f g : T -> Type) (R : forall t, f t -> g t -> reified_Prop) (x : f t) (x' : g t') - : option pointed_Prop + : reified_Prop := match semidec t t' with | Some p - => Some (inject (R _ (eq_rect _ f x _ p) x')) - | None => None + => R _ (eq_rect _ f x _ p) x' + | None => rFalse end. (* Here is where we use the generality of [pointed_Prop], to say that two things of type [var1] are equal, and two things of type [var2] are equal. *) - Definition eq_type_and_var : sigT eP -> sigT eP -> option pointed_Prop + Definition eq_type_and_var : sigT eP -> sigT eP -> reified_Prop := fun x y => (eq_semidec_and_gen - base_type_eq_semidec_transparent _ _ var1 var1 (fun _ => eq) (fst (projT2 x)) (fst (projT2 y)) + base_type_eq_semidec_transparent _ _ var1 var1 (fun _ => rEq) (fst (projT2 x)) (fst (projT2 y)) /\ eq_semidec_and_gen - base_type_eq_semidec_transparent _ _ var2 var2 (fun _ => eq) (snd (projT2 x)) (snd (projT2 y)))%option_pointed_prop. + base_type_eq_semidec_transparent _ _ var2 var2 (fun _ => rEq) (snd (projT2 x)) (snd (projT2 y)))%reified_prop. Definition rel_type_and_const (t t' : flat_type) (x : interp_flat_type1 t) (y : interp_flat_type2 t') - : option pointed_Prop + : reified_Prop := eq_semidec_and_gen - flat_type_eq_semidec_transparent _ _ interp_flat_type1 interp_flat_type2 (fun t => interp_flat_type_rel_pointwise2 R) x y. + flat_type_eq_semidec_transparent _ _ interp_flat_type1 interp_flat_type2 R x y. Definition duplicate_type (ls : list (sigT (fun t => var1 t * var2 t)%type)) : list (sigT eP) := List.map (fun v => existT eP (projT1 v, projT1 v) (projT2 v)) ls. @@ -292,71 +292,61 @@ Section language. (e1 : @exprf1 (fun t => nat * var1 t)%type t1) (e2 : @exprf2 (fun t => nat * var2 t)%type t2) {struct e1} - : option pointed_Prop - := match e1, e2 return option _ with + : reified_Prop + := match e1, e2 with | Const t0 x, Const t1 y => match flat_type_eq_semidec_transparent t0 t1 with - | Some p => Some (inject (interp_flat_type_rel_pointwise2 R (eq_rect _ interp_flat_type1 x _ p) y)) - | None => None + | Some p => R _ (eq_rect _ interp_flat_type1 x _ p) y + | None => rFalse end - | Const _ _, _ => None + | Const _ _, _ => rFalse | Var t0 x, Var t1 y => match beq_nat (fst x) (fst y), List.nth_error G (List.length G - S (fst x)) with | true, Some v => eq_type_and_var v (existT _ (t0, t1) (snd x, snd y)) - | _, _ => None + | _, _ => rFalse end - | Var _ _, _ => None + | Var _ _, _ => rFalse | Op t1 tR op args, Op t1' tR' op' args' - => match @reflect_wffT G t1 t1' args args', op_beq' t1 tR t1' tR' op op' with - | Some p, Some q => Some (p /\ q)%pointed_prop - | _, _ => None - end - | Op _ _ _ _, _ => None + => (@reflect_wffT G t1 t1' args args' /\ op_beq' t1 tR t1' tR' op op')%reified_prop + | Op _ _ _ _, _ => rFalse | LetIn tx ex tC eC, LetIn tx' ex' tC' eC' - => match @reflect_wffT G tx tx' ex ex', @flatten_binding_list2 tx tx', flat_type_eq_semidec_transparent tC tC' with - | Some p, Some G0, Some _ - => Some (p /\ inject (forall (x : interp_flat_type var1 tx) (x' : interp_flat_type var2 tx'), - match @reflect_wffT (G0 x x' ++ G)%list _ _ - (eC (snd (natize_interp_flat_type (List.length G) x))) - (eC' (snd (natize_interp_flat_type (List.length G) x'))) with - | Some p => to_prop p - | None => False - end)) - | _, _, _ => None + => let p := @reflect_wffT G tx tx' ex ex' in + match @flatten_binding_list2 tx tx', flat_type_eq_semidec_transparent tC tC' with + | Some G0, Some _ + => p + /\ (∀ (x : interp_flat_type var1 tx) (x' : interp_flat_type var2 tx'), + @reflect_wffT (G0 x x' ++ G)%list _ _ + (eC (snd (natize_interp_flat_type (List.length G) x))) + (eC' (snd (natize_interp_flat_type (List.length G) x')))) + | _, _ => rFalse end - | LetIn _ _ _ _, _ => None + | LetIn _ _ _ _, _ => rFalse | Pair tx ex ty ey, Pair tx' ex' ty' ey' - => match @reflect_wffT G tx tx' ex ex', @reflect_wffT G ty ty' ey ey' with - | Some p, Some q => Some (p /\ q) - | _, _ => None - end - | Pair _ _ _ _, _ => None - end%pointed_prop. + => @reflect_wffT G tx tx' ex ex' /\ @reflect_wffT G ty ty' ey ey' + | Pair _ _ _ _, _ => rFalse + end%reified_prop. Fixpoint reflect_wfT (G : list (sigT (fun t => var1 (fst t) * var2 (snd t))%type)) {t1 t2 : type} (e1 : @expr1 (fun t => nat * var1 t)%type t1) (e2 : @expr2 (fun t => nat * var2 t)%type t2) {struct e1} - : option pointed_Prop - := match e1, e2 return option _ with + : reified_Prop + := match e1, e2 with | Return _ x, Return _ y => reflect_wffT G x y - | Return _ _, _ => None + | Return _ _, _ => rFalse | Abs tx tR f, Abs tx' tR' f' => match @flatten_binding_list2 tx tx', type_eq_semidec_transparent tR tR' with | Some G0, Some _ - => Some (inject (forall (x : interp_flat_type var1 tx) (x' : interp_flat_type var2 tx'), - match @reflect_wfT (G0 x x' ++ G)%list _ _ - (f (snd (natize_interp_flat_type (List.length G) x))) - (f' (snd (natize_interp_flat_type (List.length G) x'))) with - | Some p => to_prop p - | None => False - end)) - | _, _ => None + => ∀ (x : interp_flat_type var1 tx) (x' : interp_flat_type var2 tx'), + @reflect_wfT (G0 x x' ++ G)%list _ _ + (f (snd (natize_interp_flat_type (List.length G) x))) + (f' (snd (natize_interp_flat_type (List.length G) x'))) + | _, _ => rFalse end - | Abs _ _ _, _ => None - end%pointed_prop. + | Abs _ _ _, _ => rFalse + end%reified_prop. End language. Global Arguments reflect_wffT {_ _ _} _ {op} R _ {var1 var2} G {t1 t2} _ _. diff --git a/src/Reflection/WfRelReflective.v b/src/Reflection/WfRelReflective.v deleted file mode 100644 index 0135f6afb..000000000 --- a/src/Reflection/WfRelReflective.v +++ /dev/null @@ -1,166 +0,0 @@ -(** * A reflective Version of [WfRel] proofs *) -(** See [WfReflective.v] and [WfReflectiveGen.v] for comments. *) -Require Import Coq.Arith.Arith Coq.Logic.Eqdep_dec. -Require Import Crypto.Reflection.Syntax. -Require Import Crypto.Reflection.WfRel. -Require Import Crypto.Reflection.WfReflectiveGen. -Require Import Crypto.Util.Notations Crypto.Util.Tactics Crypto.Util.Option Crypto.Util.Sigma Crypto.Util.Prod Crypto.Util.Decidable Crypto.Util.ListUtil. -Require Export Crypto.Util.PointedProp. (* export for the [bool >-> option pointed_Prop] coercion *) -Require Export Crypto.Util.FixCoqMistakes. - - -Section language. - (** To be able to optimize away so much of the [Syntax.wff] proof, - we must be able to decide a few things: equality of base types, - and equality of operator codes. Since we will be casting across - the equality proofs of base types, we require that this - semi-decider give transparent proofs. (This requirement is not - enforced, but it will block [vm_compute] when trying to use the - lemma in this file.) *) - Context (base_type_code : Type). - Context (interp_base_type1 interp_base_type2 : base_type_code -> Type). - Context (base_type_eq_semidec_transparent : forall t1 t2 : base_type_code, option (t1 = t2)). - Context (base_type_eq_semidec_is_dec : forall t1 t2, base_type_eq_semidec_transparent t1 t2 = None -> t1 <> t2). - Context (op : flat_type base_type_code -> flat_type base_type_code -> Type). - Context (R : forall t, interp_base_type1 t -> interp_base_type2 t -> Prop). - (** In practice, semi-deciding equality of operators should either - return [Some trivial] or [None], and not make use of the - generality of [pointed_Prop]. However, we need to use - [pointed_Prop] internally because we need to talk about equality - of things of type [var t], for [var : base_type_code -> Type]. - It does not hurt to allow extra generality in [op_beq]. *) - Context (op_beq : forall t1 tR, op t1 tR -> op t1 tR -> option pointed_Prop). - Context (op_beq_bl : forall t1 tR x y, prop_of_option (op_beq t1 tR x y) -> x = y). - Context {var1 var2 : base_type_code -> Type}. - - Local Notation eP := (fun t => var1 (fst t) * var2 (snd t))%type (only parsing). - - (* convenience notations that fill in some arguments used across the section *) - Local Notation flat_type := (flat_type base_type_code). - Local Notation type := (type base_type_code). - Let Tbase := @Tbase base_type_code. - Local Coercion Tbase : base_type_code >-> Syntax.flat_type. - Local Notation interp_type1 := (interp_type interp_base_type1). - Local Notation interp_type2 := (interp_type interp_base_type2). - Local Notation interp_flat_type1 := (interp_flat_type interp_base_type1). - Local Notation interp_flat_type2 := (interp_flat_type interp_base_type2). - Local Notation exprf1 := (@exprf base_type_code interp_base_type1 op). - Local Notation exprf2 := (@exprf base_type_code interp_base_type2 op). - Local Notation expr1 := (@expr base_type_code interp_base_type1 op). - Local Notation expr2 := (@expr base_type_code interp_base_type2 op). - Local Notation duplicate_type := (@duplicate_type base_type_code var1 var2). - Local Notation reflect_wffT := (@reflect_wffT base_type_code interp_base_type1 interp_base_type2 base_type_eq_semidec_transparent op R op_beq var1 var2). - Local Notation reflect_wfT := (@reflect_wfT base_type_code interp_base_type1 interp_base_type2 base_type_eq_semidec_transparent op R op_beq var1 var2). - Local Notation flat_type_eq_semidec_transparent := (@flat_type_eq_semidec_transparent base_type_code base_type_eq_semidec_transparent). - Local Notation preflatten_binding_list2 := (@preflatten_binding_list2 base_type_code base_type_eq_semidec_transparent var1 var2). - Local Notation type_eq_semidec_transparent := (@type_eq_semidec_transparent base_type_code base_type_eq_semidec_transparent). - - Local Ltac handle_op_beq_correct := - repeat match goal with - | [ H : op_beq ?t1 ?tR ?x ?y = _ |- _ ] - => let H' := fresh in - pose proof (op_beq_bl t1 tR x y) as H'; rewrite H in H'; clear H - end. - Local Ltac t_step := - match goal with - | _ => progress unfold eq_type_and_var, op_beq', flatten_binding_list2, WfReflectiveGen.preflatten_binding_list2, option_map, and_option_pointed_Prop, eq_semidec_and_gen in * - | _ => progress simpl in * - | _ => progress break_match - | _ => progress subst - | _ => progress inversion_option - | _ => progress inversion_pointed_Prop - | _ => congruence - | _ => tauto - | _ => progress intros - | _ => progress handle_op_beq_correct - | _ => progress specialize_by tauto - | [ v : sigT _ |- _ ] => destruct v - | [ v : prod _ _ |- _ ] => destruct v - | [ H : forall x x', _ |- rel_wff _ (flatten_binding_list _ ?x1 ?x2 ++ _)%list _ _ ] - => specialize (H x1 x2) - | [ H : forall x x', _ |- rel_wf _ (existT _ _ (?x1, ?x2) :: _)%list _ _ ] - => specialize (H x1 x2) - | [ H : and _ _ |- _ ] => destruct H - | [ H : context[duplicate_type (_ ++ _)%list] |- _ ] - => rewrite duplicate_type_app in H - | [ H : context[List.length (duplicate_type _)] |- _ ] - => rewrite duplicate_type_length in H - | [ H : context[List.length (_ ++ _)%list] |- _ ] - => rewrite List.app_length in H - | [ |- rel_wff _ _ (unnatize_exprf (fst _) _) (unnatize_exprf (fst _) _) ] - => erewrite length_natize_interp_flat_type1, length_natize_interp_flat_type2; eassumption - | [ |- rel_wf _ _ (unnatize_exprf (fst _) _) (unnatize_exprf (fst _) _) ] - => erewrite length_natize_interp_flat_type1, length_natize_interp_flat_type2; eassumption - | [ H : base_type_eq_semidec_transparent _ _ = None |- False ] => eapply duplicate_type_not_in; eassumption - | [ H : List.nth_error _ _ = Some _ |- _ ] => apply List.nth_error_In in H - | [ H : List.In _ (duplicate_type _) |- _ ] => eapply duplicate_type_in in H; [ | eassumption.. ] - | [ H : context[match _ with _ => _ end] |- _ ] => revert H; progress break_match - | [ |- rel_wff _ _ _ _ ] => constructor - | [ |- rel_wf _ _ _ _ ] => constructor - | _ => progress unfold and_pointed_Prop in * - end. - Local Ltac t := repeat t_step. - Fixpoint reflect_rel_wff (G : list (sigT (fun t => var1 t * var2 t)%type)) - {t1 t2 : flat_type} - (e1 : @exprf1 (fun t => nat * var1 t)%type t1) - (e2 : @exprf2 (fun t => nat * var2 t)%type t2) - {struct e1} - : match reflect_wffT (duplicate_type G) e1 e2, flat_type_eq_semidec_transparent t1 t2 with - | Some reflective_obligation, Some p - => to_prop reflective_obligation - -> @rel_wff base_type_code interp_base_type1 interp_base_type2 op R var1 var2 G t2 (eq_rect _ exprf1 (unnatize_exprf (List.length G) e1) _ p) (unnatize_exprf (List.length G) e2) - | _, _ => True - end. - Proof. - destruct e1 as [ | | ? ? ? args | tx ex tC eC | ? ex ? ey ], - e2 as [ | | ? ? ? args' | tx' ex' tC' eC' | ? ex' ? ey' ]; simpl; try solve [ exact I ]; - [ clear reflect_rel_wff - | clear reflect_rel_wff - | specialize (reflect_rel_wff G _ _ args args') - | pose proof (reflect_rel_wff G _ _ ex ex'); - pose proof (fun x x' - => match preflatten_binding_list2 tx tx' as v return match v with Some _ => _ | None => True end with - | Some G0 - => reflect_rel_wff - (G0 x x' ++ G)%list _ _ - (eC (snd (natize_interp_flat_type (length (duplicate_type G)) x))) - (eC' (snd (natize_interp_flat_type (length (duplicate_type G)) x'))) - | None => I - end); - clear reflect_rel_wff - | pose proof (reflect_rel_wff G _ _ ex ex'); pose proof (reflect_rel_wff G _ _ ey ey'); clear reflect_rel_wff ]. - { t. } - { t. } - { t. } - { t. } - { t. } - Qed. - Fixpoint reflect_rel_wf (G : list (sigT (fun t => var1 t * var2 t)%type)) - {t1 t2 : type} - (e1 : @expr1 (fun t => nat * var1 t)%type t1) - (e2 : @expr2 (fun t => nat * var2 t)%type t2) - : match reflect_wfT (duplicate_type G) e1 e2, type_eq_semidec_transparent t1 t2 with - | Some reflective_obligation, Some p - => to_prop reflective_obligation - -> @rel_wf base_type_code interp_base_type1 interp_base_type2 op R var1 var2 G t2 (eq_rect _ expr1 (unnatize_expr (List.length G) e1) _ p) (unnatize_expr (List.length G) e2) - | _, _ => True - end. - Proof. - destruct e1 as [ t1 e1 | tx tR f ], - e2 as [ t2 e2 | tx' tR' f' ]; simpl; try solve [ exact I ]; - [ clear reflect_rel_wf; - pose proof (@reflect_rel_wff G t1 t2 e1 e2) - | pose proof (fun x x' - => match preflatten_binding_list2 tx tx' as v return match v with Some _ => _ | None => True end with - | Some G0 - => reflect_rel_wf - (G0 x x' ++ G)%list _ _ - (f (snd (natize_interp_flat_type (length (duplicate_type G)) x))) - (f' (snd (natize_interp_flat_type (length (duplicate_type G)) x'))) - | None => I - end); - clear reflect_rel_wf ]. - { t. } - { t. } - Qed. -End language. diff --git a/src/Reflection/Z/Interpretations.v b/src/Reflection/Z/Interpretations.v index 514cca3cb..336376c9e 100644 --- a/src/Reflection/Z/Interpretations.v +++ b/src/Reflection/Z/Interpretations.v @@ -4,11 +4,15 @@ Require Import Coq.ZArith.ZArith. Require Import Coq.NArith.NArith. Require Import Crypto.Reflection.Z.Syntax. Require Import Crypto.Reflection.Syntax. +Require Import Crypto.Reflection.Application. Require Import Crypto.ModularArithmetic.ModularBaseSystemListZOperations. Require Import Crypto.Util.Equality. Require Import Crypto.Util.ZUtil. Require Import Crypto.Util.Option. +Require Crypto.Util.Tuple. +Require Crypto.Util.HList. Require Import Crypto.Util.Bool. +Require Import Crypto.Util.Prod. Require Import Crypto.Util.Tactics. Require Import Crypto.Util.WordUtil. Require Import Bedrock.Word. @@ -27,6 +31,67 @@ Module Z. := interp_op src dst f. End Z. +Module LiftOption. + Section lift_option. + Context (T : Type). + + Definition interp_flat_type (t : flat_type base_type) + := option (interp_flat_type (fun _ => T) t). + + Definition interp_base_type' (t : base_type) + := match t with + | TZ => option T + end. + + Definition of' {t} : Syntax.interp_flat_type interp_base_type' t -> interp_flat_type t + := @smart_interp_flat_map + base_type + interp_base_type' interp_flat_type + (fun t => match t with TZ => fun x => x end) + (fun _ _ x y => match x, y with + | Some x', Some y' => Some (x', y') + | _, _ => None + end) + t. + + Fixpoint to' {t} : interp_flat_type t -> Syntax.interp_flat_type interp_base_type' t + := match t return interp_flat_type t -> Syntax.interp_flat_type interp_base_type' t with + | Tbase TZ => fun x => x + | Prod A B => fun x => (@to' A (option_map (@fst _ _) x), + @to' B (option_map (@snd _ _) x)) + end. + + Definition lift_relation {interp_base_type2} + (R : forall t, T -> interp_base_type2 t -> Prop) + : forall t, interp_base_type' t -> interp_base_type2 t -> Prop + := fun t x y => match of' (t:=Tbase t) x with + | Some x' => R t x' y + | None => True + end. + + Definition Some {t} (x : T) : interp_base_type' t + := match t with + | TZ => Some x + end. + End lift_option. + Global Arguments of' {T t} _. + Global Arguments to' {T t} _. + Global Arguments Some {T t} _. + Global Arguments lift_relation {T _} R _ _ _. + + Section lift_option2. + Context (T U : Type). + Definition lift_relation2 (R : T -> U -> Prop) + : forall t, interp_base_type' T t -> interp_base_type' U t -> Prop + := fun t x y => match of' (t:=Tbase t) x, of' (t:=Tbase t) y with + | Datatypes.Some x', Datatypes.Some y' => R x' y' + | None, None => True + | _, _ => False + end. + End lift_option2. + Global Arguments lift_relation2 {T U} R _ _ _. +End LiftOption. + Module Word64. Definition bit_width : nat := 64. Definition word64 := word bit_width. @@ -38,107 +103,169 @@ Module Word64. Definition ZToWord64 (x : Z) : word64 := NToWord _ (Z.to_N x). + Ltac fold_Word64_Z := + repeat match goal with + | [ |- context G[NToWord bit_width (Z.to_N ?x)] ] + => let G' := context G [ZToWord64 x] in change G' + | [ |- context G[Z.of_N (wordToN ?x)] ] + => let G' := context G [word64ToZ x] in change G' + | [ H : context G[NToWord bit_width (Z.to_N ?x)] |- _ ] + => let G' := context G [ZToWord64 x] in change G' in H + | [ H : context G[Z.of_N (wordToN ?x)] |- _ ] + => let G' := context G [word64ToZ x] in change G' in H + end. + Create HintDb push_word64ToZ discriminated. Hint Extern 1 => progress autorewrite with push_word64ToZ in * : push_word64ToZ. - Definition w64plus : word64 -> word64 -> word64 := @wplus _. - Definition w64minus : word64 -> word64 -> word64 := @wminus _. - Definition w64mul : word64 -> word64 -> word64 := @wmult _. - Definition w64shl : word64 -> word64 -> word64 - := fun x y => NToWord _ (Z.to_N (Z.shiftl (Z.of_N (wordToN x)) (Z.of_N (wordToN y)))). - Definition w64shr : word64 -> word64 -> word64 - := fun x y => NToWord _ (Z.to_N (Z.shiftr (Z.of_N (wordToN x)) (Z.of_N (wordToN y)))). - Definition w64land : word64 -> word64 -> word64 := @wand _. - Definition w64lor : word64 -> word64 -> word64 := @wor _. - Definition w64neg : word64 -> word64 -> word64 (* TODO: FIXME? *) - := fun x y => NToWord _ (Z.to_N (ModularBaseSystemListZOperations.neg (Z.of_N (wordToN x)) (Z.of_N (wordToN x)))). - Definition w64cmovne : word64 -> word64 -> word64 -> word64 -> word64 (* TODO: FIXME? *) - := fun x y z w => NToWord _ (Z.to_N (ModularBaseSystemListZOperations.cmovne (Z.of_N (wordToN x)) (Z.of_N (wordToN x)) (Z.of_N (wordToN z)) (Z.of_N (wordToN w)))). - Definition w64cmovle : word64 -> word64 -> word64 -> word64 -> word64 (* TODO: FIXME? *) - := fun x y z w => NToWord _ (Z.to_N (ModularBaseSystemListZOperations.cmovl (Z.of_N (wordToN x)) (Z.of_N (wordToN x)) (Z.of_N (wordToN z)) (Z.of_N (wordToN w)))). - Infix "+" := w64plus : word64_scope. - Infix "-" := w64minus : word64_scope. - Infix "*" := w64mul : word64_scope. - Infix "<<" := w64shl : word64_scope. - Infix ">>" := w64shr : word64_scope. - Infix "&'" := w64land : word64_scope. + Lemma bit_width_pos : (0 < Z.of_nat bit_width)%Z. + Proof. simpl; omega. Qed. + Local Hint Resolve bit_width_pos : zarith. + + Ltac arith := solve [ omega | auto using bit_width_pos with zarith ]. + + Lemma word64ToZ_bound w : (0 <= word64ToZ w < 2^Z.of_nat bit_width)%Z. + Proof. + pose proof (wordToNat_bound w) as H. + apply Nat2Z.inj_lt in H. + rewrite Zpow_pow2, Z2Nat.id in H by (apply Z.pow_nonneg; omega). + unfold word64ToZ. + rewrite wordToN_nat, nat_N_Z; omega. + Qed. + + Lemma word64ToZ_log_bound w : (0 <= word64ToZ w /\ Z.log2 (word64ToZ w) < Z.of_nat bit_width)%Z. + Proof. + pose proof (word64ToZ_bound w) as H. + destruct (Z_zerop (word64ToZ w)) as [H'|H']. + { rewrite H'; simpl; omega. } + { split; [ | apply Z.log2_lt_pow2 ]; try omega. } + Qed. + Lemma ZToWord64_word64ToZ (x : word64) : ZToWord64 (word64ToZ x) = x. + Proof. + unfold ZToWord64, word64ToZ. + rewrite N2Z.id, NToWord_wordToN. + reflexivity. + Qed. + Hint Rewrite ZToWord64_word64ToZ : push_word64ToZ. + + Lemma word64ToZ_ZToWord64 (x : Z) : (0 <= x < 2^Z.of_nat bit_width)%Z -> word64ToZ (ZToWord64 x) = x. + Proof. + unfold ZToWord64, word64ToZ; intros [H0 H1]. + pose proof H1 as H1'; apply Z2Nat.inj_lt in H1'; [ | omega.. ]. + rewrite <- Z.pow_Z2N_Zpow in H1' by omega. + replace (Z.to_nat 2) with 2%nat in H1' by reflexivity. + rewrite wordToN_NToWord_idempotent, Z2N.id by (omega || auto using bound_check_nat_N). + reflexivity. + Qed. + Hint Rewrite word64ToZ_ZToWord64 using arith : push_word64ToZ. + + Definition add : word64 -> word64 -> word64 := @wplus _. + Definition sub : word64 -> word64 -> word64 := @wminus _. + Definition mul : word64 -> word64 -> word64 := @wmult _. + Definition shl : word64 -> word64 -> word64 := @wordBin N.shiftl _. + Definition shr : word64 -> word64 -> word64 := @wordBin N.shiftr _. + Definition land : word64 -> word64 -> word64 := @wand _. + Definition lor : word64 -> word64 -> word64 := @wor _. + Definition neg : word64 -> word64 -> word64 (* TODO: FIXME? *) + := fun x y => ZToWord64 (ModularBaseSystemListZOperations.neg (word64ToZ x) (word64ToZ y)). + Definition cmovne : word64 -> word64 -> word64 -> word64 -> word64 (* TODO: FIXME? *) + := fun x y z w => ZToWord64 (ModularBaseSystemListZOperations.cmovne (word64ToZ x) (word64ToZ y) (word64ToZ z) (word64ToZ w)). + Definition cmovle : word64 -> word64 -> word64 -> word64 -> word64 (* TODO: FIXME? *) + := fun x y z w => ZToWord64 (ModularBaseSystemListZOperations.cmovl (word64ToZ x) (word64ToZ y) (word64ToZ z) (word64ToZ w)). + Definition conditional_subtract (pred_limb_count : nat) : word64 -> Tuple.tuple word64 (S pred_limb_count) -> Tuple.tuple word64 (S pred_limb_count) -> Tuple.tuple word64 (S pred_limb_count) + := fun x y z => Tuple.map ZToWord64 (@ModularBaseSystemListZOperations.conditional_subtract_modulus + (S pred_limb_count) (word64ToZ x) (Tuple.map word64ToZ y) (Tuple.map word64ToZ z)). + Infix "+" := add : word64_scope. + Infix "-" := sub : word64_scope. + Infix "*" := mul : word64_scope. + Infix "<<" := shl : word64_scope. + Infix ">>" := shr : word64_scope. + Infix "&'" := land : word64_scope. + + (*Local*) Hint Resolve <- Z.log2_lt_pow2_alt : zarith. + Local Hint Resolve eq_refl : zarith. Local Ltac w64ToZ_t := intros; try match goal with - | [ |- ?wordToZ (?op ?x ?y) = _ ] - => cbv [wordToZ op] in * + | [ |- ?wordToZ ?op = _ ] + => let op' := head op in + cbv [wordToZ op'] in * end; autorewrite with push_Zto_N push_Zof_N push_wordToN; try reflexivity. + Local Ltac w64ToZ_extra_t := + repeat first [ reflexivity + | progress cbv [ModularBaseSystemListZOperations.neg ModularBaseSystemListZOperations.cmovne ModularBaseSystemListZOperations.cmovl (*ModularBaseSystemListZOperations.conditional_subtract_modulus*)] in * + | progress break_match + | progress fold_Word64_Z + | progress intros + | progress autorewrite with push_Zto_N push_Zof_N push_wordToN push_word64ToZ ]. + Local Notation bounds_statement wop Zop + := ((0 <= Zop -> Z.log2 Zop < Z.of_nat bit_width -> word64ToZ wop = Zop)%Z). + Local Notation bounds_statement_tuple wop Zop + := ((HList.hlist (fun v => 0 <= v /\ Z.log2 v < Z.of_nat bit_width) Zop -> Tuple.map word64ToZ wop = Zop)%Z). Local Notation bounds_2statement wop Zop := (forall x y, - (0 <= Zop (word64ToZ x) (word64ToZ y) - -> Z.log2 (Zop (word64ToZ x) (word64ToZ y)) < Z.of_nat bit_width - -> word64ToZ (wop x y) = (Zop (word64ToZ x) (word64ToZ y)))%Z). + bounds_statement (wop x y) (Zop (word64ToZ x) (word64ToZ y))). + Local Notation bounds_1_tuple2_statement wop Zop + := (forall x y z, + bounds_statement_tuple (wop x y z) (Zop (word64ToZ x) (Tuple.map word64ToZ y) (Tuple.map word64ToZ z))). + Local Notation bounds_4statement wop Zop + := (forall x y z w, + bounds_statement (wop x y z w) (Zop (word64ToZ x) (word64ToZ y) (word64ToZ z) (word64ToZ w))). Require Import Crypto.Assembly.WordizeUtil. - Lemma word64ToZ_w64plus : bounds_2statement w64plus Z.add. Proof. w64ToZ_t. Qed. - Lemma word64ToZ_w64minus : bounds_2statement w64minus Z.sub. Proof. w64ToZ_t. Qed. - Lemma word64ToZ_w64mul : bounds_2statement w64mul Z.mul. Proof. w64ToZ_t. Qed. - Lemma word64ToZ_w64land : bounds_2statement w64land Z.land. Proof. w64ToZ_t. Qed. - Lemma word64ToZ_w64lor : bounds_2statement w64lor Z.lor. Proof. w64ToZ_t. Qed. + Lemma word64ToZ_add : bounds_2statement add Z.add. Proof. w64ToZ_t. Qed. + Lemma word64ToZ_sub : bounds_2statement sub Z.sub. Proof. w64ToZ_t. Qed. + Lemma word64ToZ_mul : bounds_2statement mul Z.mul. Proof. w64ToZ_t. Qed. - Lemma word64ToZ_w64shl : bounds_2statement w64shl Z.shiftl. + Lemma word64ToZ_shl : bounds_2statement shl Z.shiftl. Proof. - intros x y H H0. - w64ToZ_t. - - destruct (N.eq_dec (Z.to_N (Z.of_N (wordToN x) << Z.of_N (wordToN y))) 0%N) as [e|e]; [ - rewrite e; rewrite wordToN_NToWord; [|apply Npow2_gt0]; - rewrite <- e; rewrite Z2N.id; [reflexivity|assumption] - | apply N.neq_0_lt_0 in e]. - - apply Z.bits_inj_iff'; intros k Hpos. - rewrite Z.shiftl_spec; [|assumption]. - rewrite Z2N.inj_testbit; [|assumption]. - rewrite wordToN_NToWord. - - - rewrite <- N2Z.inj_testbit. - rewrite (Z2N.id k); [|assumption]. - rewrite Z2N.id; [|assumption]. - rewrite Z.shiftl_spec; [reflexivity|assumption]. - - - rewrite Npow2_N. - apply N.log2_lt_pow2; [assumption|]. - apply N2Z.inj_lt. - rewrite nat_N_Z. - refine (Z.le_lt_trans _ _ _ _ H0). - rewrite log2_conv; reflexivity. + w64ToZ_t; w64ToZ_extra_t; unfold word64ToZ, wordBin. + rewrite wordToN_NToWord; [rewrite <- Z.N2Z.inj_shiftl; reflexivity|]. + apply N2Z.inj_lt. + rewrite Z.N2Z.inj_shiftl. + destruct (Z.lt_ge_cases 0 ((word64ToZ x) << (word64ToZ y)))%Z; + [|eapply Z.le_lt_trans; [|apply N2Z.inj_lt, Npow2_gt0]; assumption]. + rewrite Npow2_N, N2Z.inj_pow. + apply Z.log2_lt_pow2; assumption. Qed. - Lemma word64ToZ_w64shr : bounds_2statement w64shr Z.shiftr. + Lemma word64ToZ_shr : bounds_2statement shr Z.shiftr. Proof. - intros x y H H0. - w64ToZ_t. - - destruct (N.eq_dec (Z.to_N (Z.of_N (wordToN x) >> Z.of_N (wordToN y))) 0%N) as [e|e]; [ - rewrite e; rewrite wordToN_NToWord; [|apply Npow2_gt0]; - rewrite <- e; rewrite Z2N.id; [reflexivity|assumption] - | apply N.neq_0_lt_0 in e]. - - apply Z.bits_inj_iff'; intros k Hpos. - rewrite Z.shiftr_spec; [|assumption]. - rewrite Z2N.inj_testbit; [|assumption]. - rewrite wordToN_NToWord. - - - rewrite <- N2Z.inj_testbit. - rewrite (Z2N.id k); [|assumption]. - rewrite Z2N.id; [|assumption]. - rewrite Z.shiftr_spec; [reflexivity|assumption]. - - - rewrite Npow2_N. - apply N.log2_lt_pow2; [assumption|]. - apply N2Z.inj_lt. - rewrite nat_N_Z. - refine (Z.le_lt_trans _ _ _ _ H0). - rewrite log2_conv; reflexivity. + w64ToZ_t; w64ToZ_extra_t; unfold word64ToZ, wordBin. + rewrite wordToN_NToWord; [rewrite <- Z.N2Z.inj_shiftr; reflexivity|]. + apply N2Z.inj_lt. + rewrite Z.N2Z.inj_shiftr. + destruct (Z.lt_ge_cases 0 ((word64ToZ x) >> (word64ToZ y)))%Z; + [|eapply Z.le_lt_trans; [|apply N2Z.inj_lt, Npow2_gt0]; assumption]. + rewrite Npow2_N, N2Z.inj_pow. + apply Z.log2_lt_pow2; assumption. + Qed. + + Lemma word64ToZ_land : bounds_2statement land Z.land. Proof. w64ToZ_t. Qed. + Lemma word64ToZ_lor : bounds_2statement lor Z.lor. Proof. w64ToZ_t. Qed. + Lemma word64ToZ_neg : bounds_2statement neg ModularBaseSystemListZOperations.neg. + Proof. w64ToZ_t; w64ToZ_extra_t. Qed. + Lemma word64ToZ_cmovne : bounds_4statement cmovne ModularBaseSystemListZOperations.cmovne. + Proof. w64ToZ_t; w64ToZ_extra_t. Qed. + Lemma word64ToZ_cmovle : bounds_4statement cmovle ModularBaseSystemListZOperations.cmovl. + Proof. w64ToZ_t; w64ToZ_extra_t. Qed. + Lemma word64ToZ_conditional_subtract pred_limb_count + : bounds_1_tuple2_statement (@conditional_subtract pred_limb_count) + (@ModularBaseSystemListZOperations.conditional_subtract_modulus (S pred_limb_count)). + Proof. + w64ToZ_t; unfold conditional_subtract; w64ToZ_extra_t. + repeat first [ progress w64ToZ_extra_t + | rewrite Tuple.map_map + | rewrite HList.Tuple.map_id_ext + | match goal with + | [ H : HList.hlist _ _ |- HList.hlist _ _ ] + => revert H; apply HList.hlist_impl + end + | apply HList.const ]. Qed. Definition interp_base_type (t : base_type) : Type @@ -152,11 +279,14 @@ Module Word64. | Mul => fun xy => fst xy * snd xy | Shl => fun xy => fst xy << snd xy | Shr => fun xy => fst xy >> snd xy - | Land => fun xy => w64land (fst xy) (snd xy) - | Lor => fun xy => w64lor (fst xy) (snd xy) - | Neg => fun xy => w64neg (fst xy) (snd xy) - | Cmovne => fun xyzw => let '(x, y, z, w) := eta4 xyzw in w64cmovne x y z w - | Cmovle => fun xyzw => let '(x, y, z, w) := eta4 xyzw in w64cmovle x y z w + | Land => fun xy => land (fst xy) (snd xy) + | Lor => fun xy => lor (fst xy) (snd xy) + | Neg => fun xy => neg (fst xy) (snd xy) + | Cmovne => fun xyzw => let '(x, y, z, w) := eta4 xyzw in cmovne x y z w + | Cmovle => fun xyzw => let '(x, y, z, w) := eta4 xyzw in cmovle x y z w + | ConditionalSubtract pred_n + => fun xyz => let '(x, y, z) := eta3 xyz in + flat_interp_untuple' (T:=Tbase TZ) (@conditional_subtract pred_n x (flat_interp_tuple y) (flat_interp_tuple z)) end%word64. Definition of_Z ty : Z.interp_base_type ty -> interp_base_type ty @@ -171,20 +301,28 @@ Module Word64. Module Export Rewrites. Ltac word64_util_arith := omega. - Hint Rewrite word64ToZ_w64plus using word64_util_arith : push_word64ToZ. - Hint Rewrite <- word64ToZ_w64plus using word64_util_arith : pull_word64ToZ. - Hint Rewrite word64ToZ_w64minus using word64_util_arith : push_word64ToZ. - Hint Rewrite <- word64ToZ_w64minus using word64_util_arith : pull_word64ToZ. - Hint Rewrite word64ToZ_w64mul using word64_util_arith : push_word64ToZ. - Hint Rewrite <- word64ToZ_w64mul using word64_util_arith : pull_word64ToZ. - Hint Rewrite word64ToZ_w64shl using word64_util_arith : push_word64ToZ. - Hint Rewrite <- word64ToZ_w64shl using word64_util_arith : pull_word64ToZ. - Hint Rewrite word64ToZ_w64shr using word64_util_arith : push_word64ToZ. - Hint Rewrite <- word64ToZ_w64shr using word64_util_arith : pull_word64ToZ. - Hint Rewrite word64ToZ_w64land using word64_util_arith : push_word64ToZ. - Hint Rewrite <- word64ToZ_w64land using word64_util_arith : pull_word64ToZ. - Hint Rewrite word64ToZ_w64lor using word64_util_arith : push_word64ToZ. - Hint Rewrite <- word64ToZ_w64lor using word64_util_arith : pull_word64ToZ. + Hint Rewrite word64ToZ_add using word64_util_arith : push_word64ToZ. + Hint Rewrite <- word64ToZ_add using word64_util_arith : pull_word64ToZ. + Hint Rewrite word64ToZ_sub using word64_util_arith : push_word64ToZ. + Hint Rewrite <- word64ToZ_sub using word64_util_arith : pull_word64ToZ. + Hint Rewrite word64ToZ_mul using word64_util_arith : push_word64ToZ. + Hint Rewrite <- word64ToZ_mul using word64_util_arith : pull_word64ToZ. + Hint Rewrite word64ToZ_shl using word64_util_arith : push_word64ToZ. + Hint Rewrite <- word64ToZ_shl using word64_util_arith : pull_word64ToZ. + Hint Rewrite word64ToZ_shr using word64_util_arith : push_word64ToZ. + Hint Rewrite <- word64ToZ_shr using word64_util_arith : pull_word64ToZ. + Hint Rewrite word64ToZ_land using word64_util_arith : push_word64ToZ. + Hint Rewrite <- word64ToZ_land using word64_util_arith : pull_word64ToZ. + Hint Rewrite word64ToZ_lor using word64_util_arith : push_word64ToZ. + Hint Rewrite <- word64ToZ_lor using word64_util_arith : pull_word64ToZ. + Hint Rewrite word64ToZ_neg using word64_util_arith : push_word64ToZ. + Hint Rewrite <- word64ToZ_neg using word64_util_arith : pull_word64ToZ. + Hint Rewrite word64ToZ_cmovne using word64_util_arith : push_word64ToZ. + Hint Rewrite <- word64ToZ_cmovne using word64_util_arith : pull_word64ToZ. + Hint Rewrite word64ToZ_cmovle using word64_util_arith : push_word64ToZ. + Hint Rewrite <- word64ToZ_cmovle using word64_util_arith : pull_word64ToZ. + Hint Rewrite word64ToZ_conditional_subtract using word64_util_arith : push_word64ToZ. + Hint Rewrite <- word64ToZ_conditional_subtract using word64_util_arith : pull_word64ToZ. End Rewrites. End Word64. @@ -200,65 +338,108 @@ Module ZBounds. := if ((0 <=? l) && (Z.log2 u <? Word64.bit_width))%Z%bool then Some {| lower := l ; upper := u |} else None. - Definition t_map2 (f : Z -> Z -> Z -> Z -> bounds) (x y : t) + Definition t_map2 (f : bounds -> bounds -> bounds) (x y : t) := match x, y with - | Some (Build_bounds lx ux), Some (Build_bounds ly uy) - => match f lx ly ux uy with + | Some x, Some y + => match f x y with | Build_bounds l u => SmartBuildBounds l u end | _, _ => None end%Z. + Definition t_map4 (f : bounds -> bounds -> bounds -> bounds -> bounds) (x y z w : t) + := match x, y, z, w with + | Some x, Some y, Some z, Some w + => match f x y z w with + | Build_bounds l u + => SmartBuildBounds l u + end + | _, _, _, _ => None + end%Z. - Definition add : t -> t -> t - := t_map2 (fun lx ly ux uy => {| lower := lx + ly ; upper := ux + uy |}). - Definition sub : t -> t -> t - := t_map2 (fun lx ly ux uy => {| lower := lx - uy ; upper := ux - ly |}). - Definition mul : t -> t -> t - := t_map2 (fun lx ly ux uy => {| lower := lx * ly ; upper := ux * uy |}). - Definition shl : t -> t -> t - := t_map2 (fun lx ly ux uy => {| lower := lx << ly ; upper := ux << uy |}). - Definition shr : t -> t -> t - := t_map2 (fun lx ly ux uy => {| lower := lx >> uy ; upper := ux >> ly |}). - Definition land : t -> t -> t - := t_map2 (fun lx ly ux uy => {| lower := 0 ; - upper := 2^(Z.succ (Z.min (Z.log2 ux) (Z.log2 uy))) |}). - Definition lor : t -> t -> t - := t_map2 (fun lx ly ux uy => {| lower := Z.max lx ly; - upper := 2^(Z.succ (Z.max (Z.log2 ux) (Z.log2 uy))) |}). + Definition add' : bounds -> bounds -> bounds + := fun x y => let (lx, ux) := x in let (ly, uy) := y in {| lower := lx + ly ; upper := ux + uy |}. + Definition add : t -> t -> t := t_map2 add'. + Definition sub' : bounds -> bounds -> bounds + := fun x y => let (lx, ux) := x in let (ly, uy) := y in {| lower := lx - uy ; upper := ux - ly |}. + Definition sub : t -> t -> t := t_map2 sub'. + Definition mul' : bounds -> bounds -> bounds + := fun x y => let (lx, ux) := x in let (ly, uy) := y in {| lower := lx * ly ; upper := ux * uy |}. + Definition mul : t -> t -> t := t_map2 mul'. + Definition shl' : bounds -> bounds -> bounds + := fun x y => let (lx, ux) := x in let (ly, uy) := y in {| lower := lx << ly ; upper := ux << uy |}. + Definition shl : t -> t -> t := t_map2 shl'. + Definition shr' : bounds -> bounds -> bounds + := fun x y => let (lx, ux) := x in let (ly, uy) := y in {| lower := lx >> uy ; upper := ux >> ly |}. + Definition shr : t -> t -> t := t_map2 shr'. + + Definition land' : bounds -> bounds -> bounds + := fun x y => let (lx, ux) := x in let (ly, uy) := y in {| lower := 0 ; upper := Z.min ux uy |}. + Definition land : t -> t -> t := t_map2 land'. + Definition lor' : bounds -> bounds -> bounds + := fun x y => let (lx, ux) := x in let (ly, uy) := y in + {| lower := Z.max lx ly; + upper := 2^(Z.max (Z.log2 (ux+1)) (Z.log2 (uy+1))) - 1 |}. + Definition lor : t -> t -> t := t_map2 lor'. + Definition neg' : bounds -> bounds -> bounds + := fun int_width v + => let (lint_width, uint_width) := int_width in + let (lb, ub) := v in + let might_be_one := ((lb <=? 1) && (1 <=? ub))%Z%bool in + let must_be_one := ((lb =? 1) && (ub =? 1))%Z%bool in + if must_be_one + then {| lower := Z.ones lint_width ; upper := Z.ones uint_width |} + else if might_be_one + then {| lower := 0 ; upper := Z.ones uint_width |} + else {| lower := 0 ; upper := 0 |}. Definition neg : t -> t -> t - := t_map2 (fun lint_width lb uint_width ub - => let might_be_one := ((lb <=? 1) && (1 <=? ub))%Z%bool in - let must_be_one := ((lb =? 1) && (ub =? 1))%Z%bool in - if must_be_one - then {| lower := Z.ones lint_width ; upper := Z.ones uint_width |} - else if might_be_one - then {| lower := 0 ; upper := Z.ones uint_width |} - else {| lower := 0 ; upper := 0 |}). - Definition cmovne (x y r1 r2 : t) : t - := match x, y with - | Some (Build_bounds lx ux), Some (Build_bounds ly uy) - => let must_be_equal := ((lx =? ux) && (ly =? uy) && (lx =? ly))%Z%bool in - let might_be_equal := ((lx <=? uy) && (ly <=? ux))%Z%bool in - if must_be_equal - then r1 - else if negb might_be_equal - then r2 - else t_map2 (fun lr1 lr2 ur1 ur2 => {| lower := Z.min lr1 lr2 ; upper := Z.max ur1 ur2 |}) r1 r2 - | _, _ => None - end%Z. - Definition cmovle (x y r1 r2 : t) : t - := match x, y with - | Some (Build_bounds lx ux), Some (Build_bounds ly uy) - => let must_be_le := (ux <=? ly)%Z in - let might_be_le := (lx <=? uy)%Z in - if must_be_le - then r1 - else if negb might_be_le - then r2 - else t_map2 (fun lr1 lr2 ur1 ur2 => {| lower := Z.min lr1 lr2 ; upper := Z.max ur1 ur2 |}) r1 r2 - | _, _ => None - end%Z. + := fun int_width v + => match int_width, v with + | Some (Build_bounds lint_width uint_width as int_width), Some (Build_bounds lb ub as v) + => if ((0 <=? lint_width) && (uint_width <=? Word64.bit_width))%Z%bool + then Some (neg' int_width v) + else None + | _, _ => None + end. + Definition cmovne' (r1 r2 : bounds) : bounds + := let (lr1, ur1) := r1 in let (lr2, ur2) := r2 in {| lower := Z.min lr1 lr2 ; upper := Z.max ur1 ur2 |}. + Definition cmovne (x y r1 r2 : t) : t := t_map4 (fun _ _ => cmovne') x y r1 r2. + Definition cmovle' (r1 r2 : bounds) : bounds + := let (lr1, ur1) := r1 in let (lr2, ur2) := r2 in {| lower := Z.min lr1 lr2 ; upper := Z.max ur1 ur2 |}. + Definition cmovle (x y r1 r2 : t) : t := t_map4 (fun _ _ => cmovle') x y r1 r2. + (** TODO(jadep): Check that this is correct; it computes the bounds, + conditional on the assumption that the entire calculation is + valid. Currently, it says that each limb is upper-bounded by + either the original value less the modulus, or by the smaller of + the original value and the modulus (in the case that the + subtraction is negative). Feel free to substitute any other + bounds you'd like here. *) + Definition conditional_subtract' (pred_n : nat) (int_width : bounds) + (modulus value : Tuple.tuple bounds (S pred_n)) + : Tuple.tuple bounds (S pred_n) + := Tuple.map2 + (fun modulus_bounds value_bounds : bounds + => let (ml, mu) := modulus_bounds in + let (vl, vu) := value_bounds in + {| lower := 0 ; upper := Z.max (Z.min vu mu) (vu - ml) |}) + modulus value. + (** TODO(jadep): Fill me in. This should check that the modulus and + value fit within int_width, that the modulus is of the right + form, and that the value is small enough. *) + Axiom check_conditional_subtract_bounds + : forall (pred_n : nat) (int_width : bounds) + (modulus value : Tuple.tuple bounds (S pred_n)), bool. + Definition conditional_subtract (pred_n : nat) (int_width : t) + (modulus value : Tuple.tuple t (S pred_n)) + : Tuple.tuple t (S pred_n) + := Tuple.push_option + match int_width, Tuple.lift_option modulus, Tuple.lift_option value with + | Some int_width, Some modulus, Some value + => if check_conditional_subtract_bounds pred_n int_width modulus value + then Some (conditional_subtract' pred_n int_width modulus value) + else None + | _, _, _ => None + end. Module Export Notations. Delimit Scope bounds_scope with bounds. @@ -272,9 +453,7 @@ Module ZBounds. End Notations. Definition interp_base_type (ty : base_type) : Type - := match ty with - | TZ => t - end. + := LiftOption.interp_base_type' bounds ty. Definition interp_op {src dst} (f : op src dst) : interp_flat_type interp_base_type src -> interp_flat_type interp_base_type dst := match f in op src dst return interp_flat_type interp_base_type src -> interp_flat_type interp_base_type dst with | Add => fun xy => fst xy + snd xy @@ -287,6 +466,9 @@ Module ZBounds. | Neg => fun xy => neg (fst xy) (snd xy) | Cmovne => fun xyzw => let '(x, y, z, w) := eta4 xyzw in cmovne x y z w | Cmovle => fun xyzw => let '(x, y, z, w) := eta4 xyzw in cmovle x y z w + | ConditionalSubtract pred_n + => fun xyz => let '(x, y, z) := eta3 xyz in + flat_interp_untuple' (T:=Tbase TZ) (@conditional_subtract pred_n x (flat_interp_tuple y) (flat_interp_tuple z)) end%bounds. Definition of_word64 ty : Word64.interp_base_type ty -> interp_base_type ty @@ -307,313 +489,417 @@ Module ZBounds. End ZBounds. Module BoundedWord64. + Local Notation is_bounded_by value lower upper + := ((0 <= lower /\ lower <= Word64.word64ToZ value <= upper /\ Z.log2 upper < Z.of_nat Word64.bit_width)%Z) + (only parsing). Record BoundedWord := { lower : Z ; value : Word64.word64 ; upper : Z ; - in_bounds : (0 <= lower /\ lower <= Word64.word64ToZ value <= upper /\ Z.log2 upper < Z.of_nat Word64.bit_width)%Z }. + in_bounds : is_bounded_by value lower upper }. Bind Scope bounded_word_scope with BoundedWord. Definition t := option BoundedWord. Bind Scope bounded_word_scope with t. Local Coercion Z.of_nat : nat >-> Z. - Definition interp_base_type (ty : base_type) : Type - := match ty with - | TZ => t - end. + Ltac inversion_BoundedWord := + repeat match goal with + | _ => progress subst + | [ H : _ = _ :> BoundedWord |- _ ] + => pose proof (f_equal lower H); + pose proof (f_equal upper H); + pose proof (f_equal value H); + clear H + end. + Definition interp_base_type (ty : base_type) + := LiftOption.interp_base_type' BoundedWord ty. Definition word64ToBoundedWord (x : Word64.word64) : t. Proof. refine (let v := Word64.word64ToZ x in - (if (0 <=? v)%Z as Hl return (0 <=? v)%Z = Hl -> t - then (if (Z.log2 v <? Z.of_nat Word64.bit_width)%Z as Hu return (Z.log2 v <? Z.of_nat Word64.bit_width)%Z = Hu -> _ -> t - then fun Hu Hl => Some {| lower := Word64.word64ToZ x ; value := x ; upper := Word64.word64ToZ x |} - else fun _ _ => None) eq_refl - else fun _ => None) eq_refl). + match Sumbool.sumbool_of_bool (0 <=? v)%Z, Sumbool.sumbool_of_bool (Z.log2 v <? Z.of_nat Word64.bit_width)%Z with + | left Hl, left Hu + => Some {| lower := Word64.word64ToZ x ; value := x ; upper := Word64.word64ToZ x |} + | _, _ => None + end). subst v. abstract (Z.ltb_to_lt; repeat split; (assumption || reflexivity)). Defined. + Definition boundedWordToWord64 (x : t) : Word64.word64 + := match x with + | Some x' => value x' + | None => Word64.ZToWord64 0 + end. + Definition of_word64 ty : Word64.interp_base_type ty -> interp_base_type ty := match ty return Word64.interp_base_type ty -> interp_base_type ty with | TZ => word64ToBoundedWord end. + Definition to_word64 ty : interp_base_type ty -> Word64.interp_base_type ty + := match ty return interp_base_type ty -> Word64.interp_base_type ty with + | TZ => boundedWordToWord64 + end. + + (** XXX FIXME(jgross) This is going to break horribly if we need to support any types other than [Z] *) + Definition to_word64' ty : BoundedWord -> Word64.interp_base_type ty + := match ty return BoundedWord -> Word64.interp_base_type ty with + | TZ => fun x => boundedWordToWord64 (Some x) + end. + + Definition to_Z' ty : BoundedWord -> Z.interp_base_type ty + := fun x => Word64.to_Z _ (to_word64' _ x). + + Definition of_Z ty : Z.interp_base_type ty -> interp_base_type ty + := fun x => of_word64 _ (Word64.of_Z _ x). + Definition to_Z ty : interp_base_type ty -> Z.interp_base_type ty + := fun x => Word64.to_Z _ (to_word64 _ x). Definition BoundedWordToBounds (x : BoundedWord) : ZBounds.bounds := {| ZBounds.lower := lower x ; ZBounds.upper := upper x |}. + Definition to_bounds' : t -> ZBounds.t + := option_map BoundedWordToBounds. + Definition to_bounds ty : interp_base_type ty -> ZBounds.interp_base_type ty := match ty return interp_base_type ty -> ZBounds.interp_base_type ty with - | TZ => option_map BoundedWordToBounds + | TZ => to_bounds' end. - Local Ltac build_binop word_op bounds_op := - refine (fun x y : t - => match x, y with - | Some x, Some y - => match bounds_op (Some (BoundedWordToBounds x)) (Some (BoundedWordToBounds y)) - as bop return bounds_op (Some (BoundedWordToBounds x)) (Some (BoundedWordToBounds y)) = bop -> t - with - | Some (ZBounds.Build_bounds l u) - => let pff := _ in - fun pf => Some {| lower := l ; value := word_op (value x) (value y) ; upper := u; - in_bounds := pff pf |} - | None => fun _ => None - end eq_refl - | _, _ => None - end); - try unfold word_op; try unfold bounds_op; - cbv [ZBounds.t_map2 BoundedWordToBounds ZBounds.SmartBuildBounds]. - - Local Ltac build_4op word_op bounds_op := - refine (fun x y z w : t - => match x, y, z, w with - | Some x, Some y, Some z, Some w - => match bounds_op (Some (BoundedWordToBounds x)) (Some (BoundedWordToBounds y)) - (Some (BoundedWordToBounds z)) (Some (BoundedWordToBounds w)) - as bop return bounds_op (Some (BoundedWordToBounds x)) (Some (BoundedWordToBounds y)) - (Some (BoundedWordToBounds z)) (Some (BoundedWordToBounds w)) - = bop -> t - with - | Some (ZBounds.Build_bounds l u) - => let pff := _ in - fun pf => Some {| lower := l ; value := word_op (value x) (value y) (value z) (value w) ; upper := u; - in_bounds := pff pf |} - | None => fun _ => None - end eq_refl - | _, _, _, _ => None - end); - try unfold word_op; try unfold bounds_op; - cbv [ZBounds.t_map2 BoundedWordToBounds ZBounds.SmartBuildBounds]. + Definition t_map2 + (opW : Word64.word64 -> Word64.word64 -> Word64.word64) + (opB : ZBounds.t -> ZBounds.t -> ZBounds.t) + (pf : forall x y l u, + opB (Some (BoundedWordToBounds x)) (Some (BoundedWordToBounds y)) + = Some {| ZBounds.lower := l ; ZBounds.upper := u |} + -> let val := opW (value x) (value y) in + is_bounded_by val l u) + : t -> t -> t + := fun x y : t + => match x, y with + | Some x, Some y + => match opB (Some (BoundedWordToBounds x)) (Some (BoundedWordToBounds y)) + as bop return opB (Some (BoundedWordToBounds x)) (Some (BoundedWordToBounds y)) = bop -> t + with + | Some (ZBounds.Build_bounds l u) + => fun Heq => Some {| lower := l ; value := opW (value x) (value y) ; upper := u; + in_bounds := pf _ _ _ _ Heq |} + | None => fun _ => None + end eq_refl + | _, _ => None + end. + + Definition t_map4 + (opW : Word64.word64 -> Word64.word64 -> Word64.word64 -> Word64.word64 -> Word64.word64) + (opB : ZBounds.t -> ZBounds.t -> ZBounds.t -> ZBounds.t -> ZBounds.t) + (pf : forall x y z w l u, + opB (Some (BoundedWordToBounds x)) (Some (BoundedWordToBounds y)) (Some (BoundedWordToBounds z)) (Some (BoundedWordToBounds w)) + = Some {| ZBounds.lower := l ; ZBounds.upper := u |} + -> let val := opW (value x) (value y) (value z) (value w) in + is_bounded_by val l u) + : t -> t -> t -> t -> t + := fun x y z w : t + => match x, y, z, w with + | Some x, Some y, Some z, Some w + => match opB (Some (BoundedWordToBounds x)) (Some (BoundedWordToBounds y)) + (Some (BoundedWordToBounds z)) (Some (BoundedWordToBounds w)) + as bop return opB _ _ _ _ = bop -> t + with + | Some (ZBounds.Build_bounds l u) + => fun Heq => Some {| lower := l ; value := opW (value x) (value y) (value z) (value w) ; upper := u; + in_bounds := pf _ _ _ _ _ _ Heq |} + | None => fun _ => None + end eq_refl + | _, _, _, _ => None + end. + + Definition t_map1_tuple2 {n} + (opW : Word64.word64 -> Tuple.tuple Word64.word64 (S n) -> Tuple.tuple Word64.word64 (S n) -> Tuple.tuple Word64.word64 (S n)) + (opB : ZBounds.t -> Tuple.tuple ZBounds.t (S n) -> Tuple.tuple ZBounds.t (S n) -> Tuple.tuple ZBounds.t (S n)) + (pf : forall x y z bs, + Tuple.lift_option + (opB (Some (BoundedWordToBounds x)) (Tuple.push_option (Some (Tuple.map BoundedWordToBounds y))) + (Tuple.push_option (Some (Tuple.map BoundedWordToBounds z)))) + = Some bs + -> let val := opW (value x) (Tuple.map value y) (Tuple.map value z) in + HList.hlist + (fun vlu => let v := fst vlu in + let lu : ZBounds.bounds := snd vlu in + is_bounded_by v (ZBounds.lower lu) (ZBounds.upper lu)) + (Tuple.map2 (fun v (lu : ZBounds.bounds) => (v, lu)) + val bs)) + : t -> Tuple.tuple t (S n) -> Tuple.tuple t (S n) -> Tuple.tuple t (S n) + := fun (x : t) (y z : Tuple.tuple t (S n)) + => Tuple.push_option + match x, Tuple.lift_option y, Tuple.lift_option z with + | Some x, Some y, Some z + => match Tuple.lift_option (opB (Some (BoundedWordToBounds x)) + (Tuple.push_option (Some (Tuple.map BoundedWordToBounds y))) + (Tuple.push_option (Some (Tuple.map BoundedWordToBounds z)))) + as bop return Tuple.lift_option _ = bop -> option (Tuple.tuple _ (S n)) with + | Some bs + => fun Heq + => let v + := HList.mapt + (fun (vlu : Word64.word64 * ZBounds.bounds) pf + => {| lower := ZBounds.lower (snd vlu) ; value := fst vlu ; upper := ZBounds.upper (snd vlu) ; + in_bounds := pf |}) + (pf _ _ _ _ Heq) in + Some v + | None => fun _ => None + end eq_refl + | _, _, _ => None + end. + + Axiom proof_admitted : False. + Local Opaque Word64.bit_width. + Hint Resolve Z.ones_nonneg : zarith. Local Ltac t_start := - repeat first [ progress break_match + repeat first [ match goal with + | [ |- forall x y l u, ?opB (Some (BoundedWordToBounds x)) (Some (BoundedWordToBounds y)) = Some _ -> let val := ?opW (value x) (value y) in _ ] + => try unfold opB; try unfold opW + | [ |- forall x y z w l u, ?opB _ _ _ _ = Some _ -> let val := ?opW (value x) (value y) (value z) (value w) in _ ] + => try unfold opB; try unfold opW + | [ |- appcontext[ZBounds.t_map2 ?op] ] => unfold op + | [ |- appcontext[?op (ZBounds.Build_bounds _ _) (ZBounds.Build_bounds _ _)] ] => unfold op + end + | progress cbv [BoundedWordToBounds ZBounds.SmartBuildBounds cmovne cmovl ModularBaseSystemListZOperations.neg] in * + | progress break_match + | progress break_match_hyps | progress intros | progress subst | progress ZBounds.inversion_bounds | progress inversion_option + | progress Word64.fold_Word64_Z | progress autorewrite with bool_congr_setoid in * | progress destruct_head' and | progress Z.ltb_to_lt | assumption | progress destruct_head' BoundedWord; simpl in * | progress autorewrite with push_word64ToZ - | progress repeat apply conj ]. + | progress repeat apply conj + | solve [ Word64.arith ] + | progress destruct_head' or ]. Ltac ktrans k := do k (etransitivity; [|eassumption]); assumption. Ltac trans' := first [ assumption | ktrans ltac:1 | ktrans ltac:2 ]. + + (** TODO(jadep): Use the bounds lemma here to prove that if each + component of [ret_val] is [Some (l, v, u)], then we can fill in + [pf] and return the tuple of [{| lower := l ; value := v ; upper + := u ; in_bounds := pf |}]. *) + Lemma conditional_subtract_bounded + (pred_n : nat) (x : BoundedWord) + (y z : Tuple.tuple BoundedWord (S pred_n)) + (H : ZBounds.check_conditional_subtract_bounds + pred_n (BoundedWordToBounds x) + (Tuple.map BoundedWordToBounds y) (Tuple.map BoundedWordToBounds z) = true) + : HList.hlist + (fun vlu : Word64.word64 * ZBounds.bounds => + (0 <= ZBounds.lower (snd vlu))%Z /\ + (ZBounds.lower (snd vlu) <= Word64.word64ToZ (fst vlu) <= ZBounds.upper (snd vlu))%Z /\ + (Z.log2 (ZBounds.upper (snd vlu)) < Word64.bit_width)%Z) + (Tuple.map2 (fun v lu => (v, lu)) + (Word64.conditional_subtract + pred_n (value x) (Tuple.map value y) (Tuple.map value z)) + (ZBounds.conditional_subtract' + pred_n (BoundedWordToBounds x) + (Tuple.map BoundedWordToBounds y) (Tuple.map BoundedWordToBounds z))). + Proof. Admitted. + + Local Ltac kill_assumptions := + repeat split; abstract (cbn; assumption). + + (* TODO (rsloan): not entirely sure what's the best way to match on these... *) + Local Ltac apply_update lem lower0 value0 upper0 lower1 value1 upper1 := first + [ apply (lem 64 lower1 value1 upper1 lower0 value0 upper0); kill_assumptions + | apply (lem 64 lower0 value0 upper0 lower1 value1 upper1); kill_assumptions]. + Definition add : t -> t -> t. Proof. - Ltac add_mono := - etransitivity; - [apply Z.add_le_mono_l | apply Z.add_le_mono_r]; - trans'. - - build_binop Word64.w64plus ZBounds.add; t_start; - unfold Word64.word64ToZ; rewrite wordToN_wplus; abstract first - [ add_mono - | transitivity (lower1 + lower0)%Z; [assumption|]; add_mono - | eapply Z.le_lt_trans; [|eassumption]; apply Z.log2_le_mono; add_mono ]. + refine (t_map2 Word64.add ZBounds.add _); t_start; + apply_update @add_valid_update lower0 value0 upper0 lower1 value1 upper1. Defined. Definition sub : t -> t -> t. Proof. - Ltac sub_mono := - etransitivity; - [| apply Z.sub_le_mono_r; eassumption]; first - [ apply Z.sub_le_mono_l; assumption - | apply Z.le_add_le_sub_l; etransitivity; - [|eassumption]; repeat rewrite Z.add_0_r; assumption]. - - build_binop Word64.w64minus ZBounds.sub; t_start; - unfold Word64.word64ToZ; rewrite wordToN_wminus; - apply Z.le_add_le_sub_l in H; abstract first - [ sub_mono - | transitivity (lower1 - lower0)%Z; [assumption|]; sub_mono - | eapply Z.le_lt_trans; [|eassumption]; apply Z.log2_le_mono; sub_mono ]. + refine (t_map2 Word64.sub ZBounds.sub _); t_start; + apply_update @sub_valid_update lower0 value0 upper0 lower1 value1 upper1. Defined. Definition mul : t -> t -> t. Proof. - Ltac mul_mono := - etransitivity; - [apply Z.mul_le_mono_nonneg_l | apply Z.mul_le_mono_nonneg_r]; - trans'. - - build_binop Word64.w64mul ZBounds.mul; t_start; - unfold Word64.word64ToZ; rewrite wordToN_wmult; abstract first - [ mul_mono - | transitivity (lower1 * lower0)%Z; [assumption|]; mul_mono - | eapply Z.le_lt_trans; [|eassumption]; apply Z.log2_le_mono; mul_mono ]. + refine (t_map2 Word64.mul ZBounds.mul _); t_start; + apply_update @mul_valid_update lower0 value0 upper0 lower1 value1 upper1. Defined. + Definition land : t -> t -> t. + Proof. + refine (t_map2 Word64.land ZBounds.land _); t_start; + apply_update @land_valid_update lower0 value0 upper0 lower1 value1 upper1. + Qed. + + Definition lor : t -> t -> t. + Proof. + refine (t_map2 Word64.lor ZBounds.lor _); t_start; + apply_update @lor_valid_update lower0 value0 upper0 lower1 value1 upper1. + Qed. + Definition shl : t -> t -> t. Proof. - Ltac shl_mono := etransitivity; - [apply Z.mul_le_mono_nonneg_l | apply Z.mul_le_mono_nonneg_r]. - - build_binop Word64.w64shl ZBounds.shl; t_start; abstract ( - unfold Word64.word64ToZ; repeat rewrite wordToN_NToWord; repeat rewrite Z2N.id; - rewrite Z.shiftl_mul_pow2 in *; - repeat match goal with - | [|- (0 <= 2 ^ _)%Z ] => apply Z.pow_nonneg - | [|- (0 <= _ * _)%Z ] => apply Z.mul_nonneg_nonneg - | [|- (2 ^ _ <= 2 ^ _)%Z ] => apply Z.pow_le_mono_r - | [|- context[(?a << ?b)%Z]] => rewrite Z.shiftl_mul_pow2 - | [|- (_ < Npow2 _)%N] => - apply N2Z.inj_lt, Z.log2_lt_cancel; simpl; - eapply Z.le_lt_trans; [|eassumption]; apply Z.log2_le_mono; rewrite Z2N.id - - | _ => progress shl_mono - | _ => progress trans' - | _ => progress omega - end). + refine (t_map2 Word64.shl ZBounds.shl _); t_start; + apply_update @shl_valid_update lower0 value0 upper0 lower1 value1 upper1. Defined. Definition shr : t -> t -> t. Proof. - Ltac shr_mono := etransitivity; - [apply Z.div_le_compat_l | apply Z.div_le_mono]. - - assert (forall x, (0 <= x)%Z -> (0 < 2^x)%Z) as gt0. { - intros; rewrite <- (Z2Nat.id x); [|assumption]. - induction (Z.to_nat x) as [|n]; [cbv; auto|]. - eapply Z.lt_le_trans; [eassumption|rewrite Nat2Z.inj_succ]. - apply Z.pow_le_mono_r; [cbv; auto|omega]. - } - - build_binop Word64.w64shr ZBounds.shr; t_start; abstract ( - unfold Word64.word64ToZ; repeat rewrite wordToN_NToWord; repeat rewrite Z2N.id; - rewrite Z.shiftr_div_pow2 in *; - repeat match goal with - | [|- _ /\ _ ] => split - | [|- (0 <= 2 ^ _)%Z ] => apply Z.pow_nonneg - | [|- (0 < 2 ^ ?X)%Z ] => apply gt0 - | [|- (0 <= _ / _)%Z ] => apply Z.div_le_lower_bound; [|rewrite Z.mul_0_r] - | [|- (2 ^ _ <= 2 ^ _)%Z ] => apply Z.pow_le_mono_r - | [|- context[(?a >> ?b)%Z]] => rewrite Z.shiftr_div_pow2 in * - | [|- (_ < Npow2 _)%N] => - apply N2Z.inj_lt, Z.log2_lt_cancel; simpl; - eapply Z.le_lt_trans; [|eassumption]; apply Z.log2_le_mono; rewrite Z2N.id - - | _ => progress shr_mono - | _ => progress trans' - | _ => progress omega - end). + refine (t_map2 Word64.shr ZBounds.shr _); t_start; + apply_update @shr_valid_update lower0 value0 upper0 lower1 value1 upper1. Defined. - Definition land : t -> t -> t. + Definition neg : t -> t -> t. + Proof. refine (t_map2 Word64.neg ZBounds.neg _); abstract t_start. Defined. + + Definition cmovne : t -> t -> t -> t -> t. + Proof. refine (t_map4 Word64.cmovne ZBounds.cmovne _); abstract t_start. Defined. + + Definition cmovle : t -> t -> t -> t -> t. + Proof. refine (t_map4 Word64.cmovle ZBounds.cmovle _); abstract t_start. Defined. + + Definition conditional_subtract (pred_n : nat) + : forall (int_width : t) (modulus val : Tuple.tuple t (S pred_n)), + Tuple.tuple t (S pred_n). Proof. - build_binop Word64.w64land ZBounds.land; t_start; [apply N2Z.is_nonneg|]; - unfold Word64.word64ToZ; repeat rewrite wordToN_NToWord; repeat rewrite Z2N.id; - rewrite wordize_and. - - destruct (Z_ge_dec upper1 upper0) as [g|g]. - - - rewrite Z.min_r; [|abstract (apply Z.log2_le_mono; omega)]. - abstract ( - rewrite (land_intro_ones (wordToN value0)); - rewrite N.land_assoc; - etransitivity; [apply N2Z.inj_le; apply N.lt_le_incl; apply land_lt_Npow2|]; - rewrite N2Z.inj_pow; - apply Z.pow_le_mono; [abstract (split; cbn; [omega|reflexivity])|]; - unfold getBits; rewrite N2Z.inj_succ; - apply -> Z.succ_le_mono; - rewrite <- (N2Z.id (wordToN value0)), <- log2_conv; - apply Z.log2_le_mono; - etransitivity; [eassumption|reflexivity]). - - - rewrite Z.min_l; [|abstract (apply Z.log2_le_mono; omega)]. - abstract ( - rewrite (land_intro_ones (wordToN value1)); - rewrite <- N.land_comm, N.land_assoc; - etransitivity; [apply N2Z.inj_le; apply N.lt_le_incl; apply land_lt_Npow2|]; - rewrite N2Z.inj_pow; - apply Z.pow_le_mono; [abstract (split; cbn; [omega|reflexivity])|]; - unfold getBits; rewrite N2Z.inj_succ; - apply -> Z.succ_le_mono; - rewrite <- (N2Z.id (wordToN value1)), <- log2_conv; - apply Z.log2_le_mono; - etransitivity; [eassumption|reflexivity]). + refine (@t_map1_tuple2 pred_n (@Word64.conditional_subtract _) (@ZBounds.conditional_subtract _) _). + abstract ( + repeat first [ progress unfold ZBounds.conditional_subtract + | rewrite !Tuple.lift_push_option + | progress break_match + | congruence + | progress subst + | progress inversion_option + | intro + | solve [ auto using conditional_subtract_bounded ] ] + ). Defined. - Definition lor : t -> t -> t. + Local Notation binop_correct op opW opB := + (forall x y v, op (Some x) (Some y) = Some v + -> value v = opW (value x) (value y) + /\ Some (BoundedWordToBounds v) = opB (Some (BoundedWordToBounds x)) (Some (BoundedWordToBounds y))) + (only parsing). + + Local Notation op4_correct op opW opB := + (forall x y z w v, op (Some x) (Some y) (Some z) (Some w) = Some v + -> value v = opW (value x) (value y) (value z) (value w) + /\ Some (BoundedWordToBounds v) = opB (Some (BoundedWordToBounds x)) (Some (BoundedWordToBounds y)) + (Some (BoundedWordToBounds z)) (Some (BoundedWordToBounds w))) + (only parsing). + + Local Notation op1_tuple2_correct op opW opB := + (forall x y z v, + Tuple.lift_option (op (Some x) (Tuple.push_option (Some y)) (Tuple.push_option (Some z))) = Some v + -> Tuple.map value v = opW (value x) (Tuple.map value y) (Tuple.map value z) + /\ Some (Tuple.map BoundedWordToBounds v) + = Tuple.lift_option + (opB (Some (BoundedWordToBounds x)) + (Tuple.push_option (Some (Tuple.map BoundedWordToBounds y))) + (Tuple.push_option (Some (Tuple.map BoundedWordToBounds z))))) + (only parsing). + + Lemma t_map2_correct opW opB pf + : binop_correct (t_map2 opW opB pf) opW opB. Proof. - build_binop Word64.w64lor ZBounds.lor; t_start; - unfold Word64.word64ToZ in *; repeat rewrite wordToN_NToWord; repeat rewrite Z2N.id; - rewrite wordize_or. - - - transitivity (Z.max (Z.of_N (wordToN value1)) (Z.of_N (wordToN value0))); - [ abstract (destruct - (Z_ge_dec lower1 lower0) as [l|l], - (Z_ge_dec (Z.of_N (& value1)%w) (Z.of_N (& value0)%w)) as [v|v]; - [ rewrite Z.max_l, Z.max_l | rewrite Z.max_l, Z.max_r - | rewrite Z.max_r, Z.max_l | rewrite Z.max_r, Z.max_r ]; - - try (omega || assumption)) - | ]. - - rewrite <- N2Z.inj_max. - apply Z2N.inj_le; [apply N2Z.is_nonneg|apply N2Z.is_nonneg|]. - repeat rewrite N2Z.id. - - abstract ( - destruct (N.max_dec (wordToN value1) (wordToN value0)) as [v|v]; - rewrite v; - apply N.ldiff_le, N.bits_inj_iff; intros k; - rewrite N.ldiff_spec, N.lor_spec; - induction (N.testbit (wordToN value1)), (N.testbit (wordToN value0)); simpl; - reflexivity). - - - apply Z.lt_le_incl, Z.log2_lt_cancel. - rewrite Z.log2_pow2; [| abstract ( - destruct (Z.max_dec (Z.log2 upper1) (Z.log2 upper0)) as [g|g]; - rewrite g; apply Z.le_le_succ_r, Z.log2_nonneg)]. - - eapply (Z.le_lt_trans _ (Z.log2 (Z.lor _ _)) _). - - + apply Z.log2_le_mono, Z.eq_le_incl. - apply Z.bits_inj_iff'; intros k Hpos. - rewrite Z2N.inj_testbit, Z.lor_spec, N.lor_spec; [|assumption]. - repeat (rewrite <- Z2N.inj_testbit; [|assumption]). - reflexivity. - - + abstract ( - rewrite Z.log2_lor; [|trans'|trans']; - destruct - (Z_ge_dec (Z.of_N (wordToN value1)) (Z.of_N (wordToN value0))) as [g0|g0], - (Z_ge_dec upper1 upper0) as [g1|g1]; - [ rewrite Z.max_l, Z.max_l - | rewrite Z.max_l, Z.max_r - | rewrite Z.max_r, Z.max_l - | rewrite Z.max_r, Z.max_r]; - try apply Z.log2_le_mono; try omega; - apply Z.le_succ_l; - apply -> Z.succ_le_mono; - apply Z.log2_le_mono; - assumption || (etransitivity; [eassumption|]; omega)). - Defined. + intros ??? H. + unfold t_map2 in H; convoy_destruct_in H; destruct_head' ZBounds.bounds; + unfold BoundedWordToBounds in *; + inversion_option; subst; simpl. + eauto. + Qed. - Axiom proof_admitted : False. - Tactic Notation "admit" := abstract case proof_admitted. + Lemma t_map4_correct opW opB pf + : op4_correct (t_map4 opW opB pf) opW opB. + Proof. + intros ????? H. + unfold t_map4 in H; convoy_destruct_in H; destruct_head' ZBounds.bounds; + unfold BoundedWordToBounds in *; + inversion_option; subst; simpl. + eauto. + Qed. - Definition neg : t -> t -> t. + (* TODO: Automate this proof more *) + Lemma t_map1_tuple2_correct {n} opW opB pf + : op1_tuple2_correct (t_map1_tuple2 (n:=n) opW opB pf) opW opB. Proof. - build_binop Word64.w64neg ZBounds.neg; t_start; - admit. - Defined. + intros ???? H. + unfold t_map1_tuple2 in H; unfold BoundedWordToBounds in *. + rewrite !Tuple.lift_push_option in H. + convoy_destruct_in H; [ | congruence ]. + rewrite_hyp *. + inversion_option. + symmetry in H. + pose proof (f_equal (Tuple.map value) H) as H0'. + pose proof (f_equal (Tuple.map BoundedWordToBounds) H) as H1'. + unfold BoundedWordToBounds in *. + rewrite_hyp !*. + rewrite !HList.map_mapt; simpl @lower; simpl @upper; simpl @value. + rewrite <- !HList.map_is_mapt. + rewrite !Tuple.map_map2; simpl @fst; simpl @snd. + rewrite !Tuple.map2_fst, !Tuple.map2_snd, Tuple.map_id, Tuple.map_id_ext + by (intros [? ?]; reflexivity). + eauto. + Qed. - Definition cmovne : t -> t -> t -> t -> t. + Local Notation binop_correct_None op opW opB := + (forall x y, op (Some x) (Some y) = None -> opB (Some (BoundedWordToBounds x)) (Some (BoundedWordToBounds y)) = None) + (only parsing). + + Local Notation op4_correct_None op opW opB := + (forall x y z w, op (Some x) (Some y) (Some z) (Some w) = None + -> opB (Some (BoundedWordToBounds x)) (Some (BoundedWordToBounds y)) + (Some (BoundedWordToBounds z)) (Some (BoundedWordToBounds w)) + = None) + (only parsing). + + Local Notation op1_tuple2_correct_None op opW opB := + (forall x y z, + Tuple.lift_option (op (Some x) (Tuple.push_option (Some y)) (Tuple.push_option (Some z))) = None + -> Tuple.lift_option + (opB (Some (BoundedWordToBounds x)) + (Tuple.push_option (Some (Tuple.map BoundedWordToBounds y))) + (Tuple.push_option (Some (Tuple.map BoundedWordToBounds z)))) + = None) + (only parsing). + + Lemma t_map2_correct_None opW opB pf + : binop_correct_None (t_map2 opW opB pf) opW opB. Proof. - build_4op Word64.w64cmovne ZBounds.cmovne; t_start; - admit. - Defined. + intros ?? H. + unfold t_map2 in H; convoy_destruct_in H; destruct_head' ZBounds.bounds; + unfold BoundedWordToBounds in *; + inversion_option; subst; simpl. + eauto. + Qed. - Definition cmovle : t -> t -> t -> t -> t. + Lemma t_map4_correct_None opW opB pf + : op4_correct_None (t_map4 opW opB pf) opW opB. Proof. - build_4op Word64.w64cmovle ZBounds.cmovle; t_start; - admit. - Defined. + intros ???? H. + unfold t_map4 in H; convoy_destruct_in H; destruct_head' ZBounds.bounds; + unfold BoundedWordToBounds in *; + inversion_option; subst; simpl. + eauto. + Qed. + + Lemma t_map1_tuple2_correct_None {n} opW opB pf + : op1_tuple2_correct_None (t_map1_tuple2 (n:=n) opW opB pf) opW opB. + Proof. + intros ??? H. + unfold t_map1_tuple2 in H; unfold BoundedWordToBounds in *. + rewrite !Tuple.lift_push_option in H. + convoy_destruct_in H; congruence. + Qed. Module Export Notations. Delimit Scope bounded_word_scope with bounded_word. @@ -638,32 +924,28 @@ Module BoundedWord64. | Neg => fun xy => neg (fst xy) (snd xy) | Cmovne => fun xyzw => let '(x, y, z, w) := eta4 xyzw in cmovne x y z w | Cmovle => fun xyzw => let '(x, y, z, w) := eta4 xyzw in cmovle x y z w - end%bounded_word. + | ConditionalSubtract pred_n + => fun xyz => let '(x, y, z) := eta3 xyz in + flat_interp_untuple' (T:=Tbase TZ) (@conditional_subtract pred_n x (flat_interp_tuple y) (flat_interp_tuple z)) + end%bounded_word. End BoundedWord64. -Module Relations. - Definition lift_relation {T} (R : BoundedWord64.BoundedWord -> T -> Prop) : BoundedWord64.t -> T -> Prop - := fun x y => match x with - | Some _ => True - | None => False - end -> match x with - | Some x' => R x' y - | None => True - end. - - Definition related'_Z (x : BoundedWord64.BoundedWord) (y : Z) : Prop - := Word64.word64ToZ (BoundedWord64.value x) = y. - Definition related_Z : BoundedWord64.t -> Z -> Prop := lift_relation related'_Z. - Definition related'_word64 (x : BoundedWord64.BoundedWord) (y : Word64.word64) : Prop - := BoundedWord64.value x = y. - Definition related_word64 : BoundedWord64.t -> Word64.word64 -> Prop := lift_relation related'_word64. - Definition related_bounds (x : BoundedWord64.t) (y : ZBounds.t) : Prop - := match x, y with - | Some x, Some y - => BoundedWord64.lower x = ZBounds.lower y /\ BoundedWord64.upper x = ZBounds.upper y - | Some _, _ - => False - | None, None => True - | None, _ => False - end. -End Relations. +Module ZBoundsTuple. + Definition interp_flat_type (t : flat_type base_type) + := LiftOption.interp_flat_type ZBounds.bounds t. + + Definition of_ZBounds {ty} : Syntax.interp_flat_type ZBounds.interp_base_type ty -> interp_flat_type ty + := @LiftOption.of' ZBounds.bounds ty. + Definition to_ZBounds {ty} : interp_flat_type ty -> Syntax.interp_flat_type ZBounds.interp_base_type ty + := @LiftOption.to' ZBounds.bounds ty. +End ZBoundsTuple. + +Module BoundedWord64Tuple. + Definition interp_flat_type (t : flat_type base_type) + := LiftOption.interp_flat_type BoundedWord64.BoundedWord t. + + Definition of_BoundedWord64 {ty} : Syntax.interp_flat_type BoundedWord64.interp_base_type ty -> interp_flat_type ty + := @LiftOption.of' BoundedWord64.BoundedWord ty. + Definition to_BoundedWord64 {ty} : interp_flat_type ty -> Syntax.interp_flat_type BoundedWord64.interp_base_type ty + := @LiftOption.to' BoundedWord64.BoundedWord ty. +End BoundedWord64Tuple. diff --git a/src/Reflection/Z/Interpretations/Relations.v b/src/Reflection/Z/Interpretations/Relations.v new file mode 100644 index 000000000..457d0d5ad --- /dev/null +++ b/src/Reflection/Z/Interpretations/Relations.v @@ -0,0 +1,623 @@ +Require Import Coq.ZArith.ZArith. +Require Import Coq.micromega.Psatz. +Require Import Crypto.Reflection.Z.Syntax. +Require Import Crypto.Reflection.Syntax. +Require Import Crypto.Reflection.Application. +Require Import Crypto.Reflection.Z.Interpretations. +Require Import Crypto.ModularArithmetic.ModularBaseSystemListZOperationsProofs. +Require Import Crypto.Util.Option. +Require Import Crypto.Util.Bool. +Require Import Crypto.Util.ZUtil. +Require Import Crypto.Util.Prod. +Require Import Crypto.Util.Tactics. + +Definition proj_eq_rel {A B} (proj : A -> B) (x : A) (y : B) : Prop + := proj x = y. +Definition related'_Z {t} (x : BoundedWord64.BoundedWord) (y : Z.interp_base_type t) : Prop + := proj_eq_rel (BoundedWord64.to_Z' _) x y. +Definition related_Z t : BoundedWord64.interp_base_type t -> Z.interp_base_type t -> Prop + := LiftOption.lift_relation (@related'_Z) t. +Definition related'_word64 {t} (x : BoundedWord64.BoundedWord) (y : Word64.interp_base_type t) : Prop + := proj_eq_rel (BoundedWord64.to_word64' _) x y. +Definition related_word64 t : BoundedWord64.interp_base_type t -> Word64.interp_base_type t -> Prop + := LiftOption.lift_relation (@related'_word64) t. +Definition related_bounds t : BoundedWord64.interp_base_type t -> ZBounds.interp_base_type t -> Prop + := LiftOption.lift_relation2 (proj_eq_rel BoundedWord64.BoundedWordToBounds) t. + +Definition related_word64_Z t : Word64.interp_base_type t -> Z.interp_base_type t -> Prop + := proj_eq_rel (Word64.to_Z _). + +Definition related'_word64_bounds : Word64.word64 -> ZBounds.bounds -> Prop + := fun value b => (0 <= ZBounds.lower b /\ ZBounds.lower b <= Word64.word64ToZ value <= ZBounds.upper b /\ Z.log2 (ZBounds.upper b) < Z.of_nat Word64.bit_width)%Z. +Definition related_word64_bounds : Word64.word64 -> ZBounds.t -> Prop + := fun value b => match b with + | Some b => related'_word64_bounds value b + | None => True + end. +Definition related_word64_boundsi (t : base_type) : Word64.interp_base_type t -> ZBounds.interp_base_type t -> Prop + := match t with + | TZ => related_word64_bounds + end. +Definition related_word64_boundsi' (t : base_type) : ZBounds.bounds -> Word64.interp_base_type t -> Prop + := match t return ZBounds.bounds -> Word64.interp_base_type t -> Prop with + | TZ => fun x y => related'_word64_bounds y x + end. + +Local Notation related_op R interp_op1 interp_op2 + := (forall (src dst : flat_type base_type) (op : op src dst) + (sv1 : interp_flat_type _ src) (sv2 : interp_flat_type _ src), + interp_flat_type_rel_pointwise2 R sv1 sv2 -> + interp_flat_type_rel_pointwise2 R (interp_op1 _ _ op sv1) (interp_op2 _ _ op sv2)) + (only parsing). +Local Notation related_const R interp f g + := (forall (t : base_type) (v : interp t), R t (f t v) (g t v)) + (only parsing). + +Local Ltac related_const_t := + let v := fresh in + let t := fresh in + intros t v; destruct t; intros; simpl in *; hnf; simpl; + cbv [BoundedWord64.word64ToBoundedWord related'_Z LiftOption.of' related_Z related_word64 related'_word64 proj_eq_rel] in *; + break_innermost_match; simpl; + first [ tauto + | Z.ltb_to_lt; + pose proof (Word64.word64ToZ_log_bound v); + try omega ]. + +Lemma related_Z_const : related_const related_Z Word64.interp_base_type BoundedWord64.of_word64 Word64.to_Z. +Proof. related_const_t. Qed. +Lemma related_bounds_const : related_const related_bounds Word64.interp_base_type BoundedWord64.of_word64 ZBounds.of_word64. +Proof. related_const_t. Qed. +Lemma related_word64_const : related_const related_word64 Word64.interp_base_type BoundedWord64.of_word64 (fun _ x => x). +Proof. related_const_t. Qed. + +Local Ltac related_word64_op_t_step := + first [ exact I + | reflexivity + | progress intros + | progress inversion_option + | progress ZBounds.inversion_bounds + | progress subst + | progress destruct_head' False + | progress destruct_head' prod + | progress destruct_head' and + | progress destruct_head' option + | progress destruct_head' BoundedWord64.BoundedWord + | progress cbv [related_word64 related_bounds related_Z LiftOption.lift_relation LiftOption.lift_relation2 LiftOption.of' smart_interp_flat_map BoundedWord64.BoundedWordToBounds BoundedWord64.to_bounds'] in * + | progress simpl @fst in * + | progress simpl @snd in * + | progress simpl @BoundedWord64.upper in * + | progress simpl @BoundedWord64.lower in * + | progress break_match + | progress break_match_hyps + | congruence + | match goal with + | [ H : ?op _ _ = Some _ |- _ ] + => let H' := fresh in + rename H into H'; + first [ pose proof (@BoundedWord64.t_map2_correct _ _ _ _ _ _ H') as H; clear H' + | pose proof (@BoundedWord64.t_map4_correct _ _ _ _ _ _ H') as H; clear H' + | pose proof (@BoundedWord64.t_map1_tuple2_correct _ _ _ _ _ _ H') as H; clear H' ]; + simpl in H + | [ H : ?op _ _ = None |- _ ] + => let H' := fresh in + rename H into H'; + first [ pose proof (@BoundedWord64.t_map2_correct_None _ _ _ _ _ H') as H; clear H' + | pose proof (@BoundedWord64.t_map4_correct_None _ _ _ _ _ H') as H; clear H' + | pose proof (@BoundedWord64.t_map1_tuple2_correct_None _ _ _ _ _ H') as H; clear H' ]; + simpl in H + end + | progress cbv [related'_word64 proj_eq_rel BoundedWord64.to_word64' BoundedWord64.boundedWordToWord64 BoundedWord64.value] in * + | match goal with + | [ H : ?op None _ = Some _ |- _ ] => progress simpl in H + | [ H : ?op _ None = Some _ |- _ ] => progress simpl in H + | [ H : ?op (Some _) (Some _) = Some _ |- _ ] => progress simpl in H + | [ H : ?op (Some _) (Some _) = None |- _ ] => progress simpl in H + end ]. +Local Ltac related_word64_op_t := repeat related_word64_op_t_step. + +Lemma related_word64_t_map2 opW opB pf + sv1 sv2 + : interp_flat_type_rel_pointwise2 (t:=Prod (Tbase TZ) (Tbase TZ)) related_word64 sv1 sv2 + -> @related_word64 TZ (BoundedWord64.t_map2 opW opB pf (fst sv1) (snd sv1)) (opW (fst sv2) (snd sv2)). +Proof. + cbv [interp_flat_type BoundedWord64.interp_base_type ZBounds.interp_base_type LiftOption.interp_base_type' interp_flat_type_rel_pointwise2 interp_flat_type_rel_pointwise2_gen_Prop] in *. + related_word64_op_t. +Qed. + +Lemma related_word64_t_map4 opW opB pf + sv1 sv2 + : interp_flat_type_rel_pointwise2 (t:=Prod (Prod (Prod (Tbase TZ) (Tbase TZ)) (Tbase TZ)) (Tbase TZ)) related_word64 sv1 sv2 + -> @related_word64 TZ (BoundedWord64.t_map4 opW opB pf (fst (fst (fst sv1))) (snd (fst (fst sv1))) (snd (fst sv1)) (snd sv1)) + (opW (fst (fst (fst sv2))) (snd (fst (fst sv2))) (snd (fst sv2)) (snd sv2)). +Proof. + cbv [interp_flat_type BoundedWord64.interp_base_type ZBounds.interp_base_type LiftOption.interp_base_type' interp_flat_type_rel_pointwise2 interp_flat_type_rel_pointwise2_gen_Prop] in *. + related_word64_op_t. +Qed. + +Lemma related_tuples_None_left + n T interp_base_type' + (R : forall t, LiftOption.interp_base_type' T t -> interp_base_type' t -> Prop) + (RNone : forall v, R TZ None v) + (v : interp_flat_type interp_base_type' (tuple (Tbase TZ) (S n))) + : interp_flat_type_rel_pointwise2 + R + (flat_interp_untuple' (T:=Tbase TZ) (Tuple.push_option (n:=S n) None)) + v. +Proof. + induction n; simpl; intuition. +Qed. + +Lemma related_tuples_Some_left + n T interp_base_type' + (R : forall t, T -> interp_base_type' t -> Prop) + u + (v : interp_flat_type interp_base_type' (tuple (Tbase TZ) (S n))) + : interp_flat_type_rel_pointwise2 + R + (flat_interp_untuple' (T:=Tbase TZ) u) + v + <-> interp_flat_type_rel_pointwise2 + (LiftOption.lift_relation R) + (flat_interp_untuple' (T:=Tbase TZ) (Tuple.push_option (n:=S n) (Some u))) + v. +Proof. + induction n; [ reflexivity | ]. + simpl in *; rewrite <- IHn; clear IHn. + reflexivity. +Qed. + +Lemma related_tuples_Some_left_ext + {n T interp_base_type'} + {R : forall t, T -> interp_base_type' t -> Prop} + {u v u'} + (H : Tuple.lift_option (flat_interp_tuple (T:=Tbase TZ) (n:=S n) u) = Some u') + : interp_flat_type_rel_pointwise2 + R + (flat_interp_untuple' (T:=Tbase TZ) u') v + <-> interp_flat_type_rel_pointwise2 + (LiftOption.lift_relation R) + u v. +Proof. + induction n. + { simpl in *; subst; reflexivity. } + { destruct_head_hnf' prod. + simpl in H; break_match_hyps; inversion_option; inversion_prod; subst. + simpl; rewrite <- IHn by eassumption; clear IHn. + reflexivity. } +Qed. + +Lemma related_tuples_proj_eq_rel_untuple + {n T interp_base_type'} + {proj : forall t, T -> interp_base_type' t} + {u : Tuple.tuple _ (S n)} {v : Tuple.tuple _ (S n)} + : interp_flat_type_rel_pointwise2 + (fun t => proj_eq_rel (proj t)) + (flat_interp_untuple' (T:=Tbase TZ) u) + (flat_interp_untuple' (T:=Tbase TZ) v) + <-> (Tuple.map (proj _) u = v). +Proof. + induction n; [ reflexivity | ]. + destruct_head_hnf' prod. + simpl @Tuple.tuple. + rewrite !Tuple.map_S, path_prod_uncurried_iff, <- prod_iff_and; unfold fst, snd. + rewrite <- IHn. + reflexivity. +Qed. + +Lemma related_tuples_proj_eq_rel_tuple + {n T interp_base_type'} + {proj : forall t, T -> interp_base_type' t} + {u v} + : interp_flat_type_rel_pointwise2 + (fun t => proj_eq_rel (proj t)) + u v + <-> (Tuple.map (proj _) (flat_interp_tuple (n:=S n) (T:=Tbase TZ) u) + = flat_interp_tuple (T:=Tbase TZ) v). +Proof. + rewrite <- related_tuples_proj_eq_rel_untuple, !flat_interp_untuple'_tuple; reflexivity. +Qed. + +Local Arguments LiftOption.lift_relation2 _ _ _ _ !_ !_ / . +Lemma related_tuples_lift_relation2_untuple' + n T U + (R : T -> U -> Prop) + (t : option (Tuple.tuple T (S n))) + (u : option (Tuple.tuple U (S n))) + : interp_flat_type_rel_pointwise2 + (LiftOption.lift_relation2 R) + (flat_interp_untuple' (T:=Tbase TZ) (Tuple.push_option t)) + (flat_interp_untuple' (T:=Tbase TZ) (Tuple.push_option u)) + <-> LiftOption.lift_relation2 + (interp_flat_type_rel_pointwise2 (fun _ => R)) + TZ + (option_map (flat_interp_untuple' (T:=Tbase TZ)) t) + (option_map (flat_interp_untuple' (T:=Tbase TZ)) u). +Proof. + induction n. + { destruct_head' option; reflexivity. } + { specialize (IHn (option_map (@fst _ _) t) (option_map (@fst _ _) u)). + destruct_head' option; + destruct_head_hnf' prod; + simpl @option_map in *; + simpl @LiftOption.lift_relation2 in *; + try (rewrite <- IHn; reflexivity); + try (simpl @interp_flat_type_rel_pointwise2; tauto). } +Qed. + +Lemma related_tuples_lift_relation2_untuple'_ext + {n T U} + {R : T -> U -> Prop} + {t u} + (H : (exists v, Tuple.lift_option (n:=S n) (flat_interp_tuple (T:=Tbase TZ) t) = Some v) + \/ (exists v, Tuple.lift_option (n:=S n) (flat_interp_tuple (T:=Tbase TZ) u) = Some v)) + : interp_flat_type_rel_pointwise2 + (LiftOption.lift_relation2 R) + t u + <-> LiftOption.lift_relation2 + (interp_flat_type_rel_pointwise2 (fun _ => R)) + TZ + (option_map (flat_interp_untuple' (T:=Tbase TZ)) (Tuple.lift_option (flat_interp_tuple (T:=Tbase TZ) t))) + (option_map (flat_interp_untuple' (T:=Tbase TZ)) (Tuple.lift_option (flat_interp_tuple (T:=Tbase TZ) u))). +Proof. + induction n. + { destruct_head_hnf' option; reflexivity. } + { specialize (IHn (fst t) (fst u)). + lazymatch type of IHn with + | ?T -> _ => let H := fresh in assert (H : T); [ | specialize (IHn H); clear H ] + end. + { destruct_head' or; [ left | right ]; destruct_head' ex; destruct_head_hnf' prod; eexists; + (etransitivity; + [ | first [ refine (f_equal (option_map (@fst _ _)) (_ : _ = Some (_, _))); eassumption + | refine (f_equal (option_map (@snd _ _)) (_ : _ = Some (_, _))); eassumption ] ]); + simpl in *; break_match; simpl in *; congruence. } + destruct_head_hnf' prod; + destruct_head_hnf' option; + simpl @fst in *; simpl @snd in *; + (etransitivity; [ simpl @interp_flat_type_rel_pointwise2 | reflexivity ]); + try solve [ repeat first [ progress simpl in * + | tauto + | congruence + | progress destruct_head ex + | progress destruct_head or + | progress break_match ] ]. } +Qed. + +Lemma lift_option_flat_interp_tuple' + {n T x y} + : (Tuple.lift_option (n:=S n) (A:=T) (flat_interp_tuple' (interp_base_type:=LiftOption.interp_base_type' _) (T:=Tbase TZ) x) = Some y) + <-> (x = flat_interp_untuple' (T:=Tbase TZ) (n:=n) (Tuple.push_option (n:=S n) (Some y))). +Proof. + rewrite Tuple.push_lift_option; generalize (Tuple.push_option (Some y)); intro. + split; intro; subst; + rewrite ?flat_interp_tuple'_untuple', ?flat_interp_untuple'_tuple'; + reflexivity. +Qed. + +Lemma lift_option_None_interp_flat_type_rel_pointwise2_1 + T U n R x y + (H : interp_flat_type_rel_pointwise2 (LiftOption.lift_relation2 R) x y) + (HNone : Tuple.lift_option (A:=T) (n:=S n) (flat_interp_tuple' (T:=Tbase TZ) (n:=n) x) = None) + : Tuple.lift_option (A:=U) (n:=S n) (flat_interp_tuple' (T:=Tbase TZ) (n:=n) y) = None. +Proof. + induction n; [ | specialize (IHn (fst x) (fst y) (proj1 H)) ]; + repeat first [ progress destruct_head_hnf' False + | reflexivity + | progress inversion_option + | progress simpl in * + | progress subst + | progress specialize_by congruence + | progress destruct_head_hnf' prod + | progress destruct_head_hnf' and + | progress destruct_head_hnf' option + | progress break_match + | progress break_match_hyps ]. +Qed. + +Local Arguments LiftOption.lift_relation _ _ _ _ !_ _ / . +Local Arguments LiftOption.of' _ _ !_ / . +Local Arguments BoundedWord64.BoundedWordToBounds !_ / . + +Local Ltac t_map1_tuple2_t_step := + first [ exact I + | reflexivity + | progress destruct_head_hnf' False + | progress subst + | progress destruct_head_hnf' prod + | progress destruct_head_hnf' and + | progress destruct_head_hnf' option + | progress inversion_option + | intro + | apply @related_tuples_None_left; constructor + | apply -> @related_tuples_Some_left + | apply <- @related_tuples_proj_eq_rel_untuple + | apply <- @related_tuples_lift_relation2_untuple' + | match goal with + | [ H : appcontext[LiftOption.lift_relation] |- _ ] + => eapply related_tuples_Some_left_ext in H; [ | eassumption ] + | [ H : appcontext[proj_eq_rel] |- _ ] + => apply -> @related_tuples_proj_eq_rel_tuple in H + | [ H : appcontext[LiftOption.lift_relation2] |- _ ] + => eapply (fun H => proj1 (related_tuples_lift_relation2_untuple'_ext H)) in H; + [ | first [ left; eexists; eassumption | right; eexists; eassumption ] ] + | [ H : Tuple.lift_option ?x = Some _, H' : context[?x] |- _ ] + => setoid_rewrite H in H' + | [ H : proj_eq_rel _ _ _ |- _ ] => hnf in H + | [ H : Tuple.lift_option (flat_interp_tuple' ?x) = Some _ |- _ ] + => is_var x; apply lift_option_flat_interp_tuple' in H + end + | progress rewrite ?HList.map'_mapt', <- ?HList.map_is_mapt' + | progress rewrite ?Tuple.map_map2, ?Tuple.map2_fst, ?Tuple.map2_snd, ?Tuple.map_id + | progress rewrite Tuple.map_id_ext by repeat (reflexivity || intros [] || intro) + | progress rewrite ?flat_interp_tuple_untuple', ?flat_interp_tuple'_untuple' in * + | progress unfold BoundedWord64.t_map1_tuple2, HList.mapt + | progress unfold related_word64, related'_word64, related_bounds in * + | progress simpl @BoundedWord64.to_word64' in * + | progress simpl @fst in * + | progress simpl @snd in * + | progress simpl @option_map in * + | progress simpl @BoundedWord64.BoundedWordToBounds in * + | progress break_match + | progress convoy_destruct + | progress simpl @interp_flat_type_rel_pointwise2 in * + | progress simpl @LiftOption.lift_relation in * + | progress simpl @LiftOption.lift_relation2 in * + | progress simpl @flat_interp_tuple in * + | progress simpl @LiftOption.of' in * + | progress simpl @smart_interp_flat_map in * + | rewrite_hyp <- !*; reflexivity + | solve [ eauto using lift_option_None_interp_flat_type_rel_pointwise2_1 ] + | match goal with + | [ H : LiftOption.lift_relation2 _ _ _ _ |- _ ] => unfold LiftOption.lift_relation2 in H + | [ H : LiftOption.of' _ = _ |- _ ] => unfold LiftOption.of' in H + | [ H : option_map _ _ = _ |- _ ] => unfold option_map in H + end ]. +Local Ltac t_map1_tuple2_t := repeat t_map1_tuple2_t_step. + +Lemma related_word64_t_map1_tuple2 {n} opW opB pf + sv1 sv2 + : interp_flat_type_rel_pointwise2 (t:=Prod (Prod (Tbase TZ) (Syntax.tuple (Tbase TZ) (S n))) (Syntax.tuple (Tbase TZ) (S n))) related_word64 sv1 sv2 + -> interp_flat_type_rel_pointwise2 + (t:=Syntax.tuple (Tbase TZ) (S n)) related_word64 + (Syntax.flat_interp_untuple' (n:=n) (T:=Tbase TZ) (BoundedWord64.t_map1_tuple2 (n:=n) opW opB pf (fst (fst sv1)) (Syntax.flat_interp_tuple (snd (fst sv1))) (Syntax.flat_interp_tuple (snd sv1)))) + (Syntax.flat_interp_untuple' (n:=n) (T:=Tbase TZ) (opW (fst (fst sv2)) (Syntax.flat_interp_tuple (snd (fst sv2))) (Syntax.flat_interp_tuple (snd sv2)))). +Proof. t_map1_tuple2_t. Qed. + +Lemma related_word64_op : related_op related_word64 (@BoundedWord64.interp_op) (@Word64.interp_op). +Proof. + (let op := fresh in intros ?? op; destruct op; simpl); + try first [ apply related_word64_t_map2 + | apply related_word64_t_map4 + | apply related_word64_t_map1_tuple2 ]. +Qed. + +Lemma related_bounds_t_map2 opW opB pf + (HN0 : forall v, opB None v = None) + (HN1 : forall v, opB v None = None) + sv1 sv2 + : interp_flat_type_rel_pointwise2 (t:=Prod (Tbase TZ) (Tbase TZ)) related_bounds sv1 sv2 + -> @related_bounds TZ (BoundedWord64.t_map2 opW opB pf (fst sv1) (snd sv1)) (opB (fst sv2) (snd sv2)). +Proof. + cbv [interp_flat_type BoundedWord64.interp_base_type ZBounds.interp_base_type LiftOption.interp_base_type' interp_flat_type_rel_pointwise2 interp_flat_type_rel_pointwise2_gen_Prop] in *. + related_word64_op_t. +Qed. + +Lemma related_bounds_t_map4 opW opB pf + (HN0 : forall x y z, opB None x y z = None) + (HN1 : forall x y z, opB x None y z = None) + (HN2 : forall x y z, opB x y None z = None) + (HN3 : forall x y z, opB x y z None = None) + sv1 sv2 + : interp_flat_type_rel_pointwise2 (t:=Prod (Prod (Prod (Tbase TZ) (Tbase TZ)) (Tbase TZ)) (Tbase TZ)) related_bounds sv1 sv2 + -> @related_bounds TZ (BoundedWord64.t_map4 opW opB pf (fst (fst (fst sv1))) (snd (fst (fst sv1))) (snd (fst sv1)) (snd sv1)) + (opB (fst (fst (fst sv2))) (snd (fst (fst sv2))) (snd (fst sv2)) (snd sv2)). +Proof. + cbv [interp_flat_type BoundedWord64.interp_base_type ZBounds.interp_base_type LiftOption.interp_base_type' interp_flat_type_rel_pointwise2 interp_flat_type_rel_pointwise2_gen_Prop] in *. + destruct_head prod. + intros; destruct_head' prod. + progress cbv [related_word64 related_bounds related_Z LiftOption.lift_relation LiftOption.lift_relation2 LiftOption.of' smart_interp_flat_map BoundedWord64.BoundedWordToBounds BoundedWord64.to_bounds' proj_eq_rel] in *. + destruct_head' option; destruct_head_hnf' and; destruct_head_hnf' False; subst; + try solve [ simpl; rewrite ?HN0, ?HN1, ?HN2, ?HN3; tauto ]; + []. + related_word64_op_t. +Qed. + +Local Arguments Tuple.lift_option : simpl never. +Local Arguments Tuple.push_option : simpl never. +Local Arguments Tuple.map : simpl never. +Local Arguments Tuple.map2 : simpl never. + +Lemma related_bounds_t_map1_tuple2 {n} opW opB pf + (HN0 : forall x y, opB None x y = Tuple.push_option None) + (HN1 : forall x y z, Tuple.lift_option y = None -> opB x y z = Tuple.push_option None) + (HN2 : forall x y z, Tuple.lift_option z = None -> opB x y z = Tuple.push_option None) + (HN3 : forall x y z, Tuple.lift_option (opB x y z) = None -> opB x y z = Tuple.push_option None) + sv1 sv2 + : interp_flat_type_rel_pointwise2 (t:=Prod (Prod (Tbase TZ) (Syntax.tuple (Tbase TZ) (S n))) (Syntax.tuple (Tbase TZ) (S n))) related_bounds sv1 sv2 + -> interp_flat_type_rel_pointwise2 + (t:=Syntax.tuple (Tbase TZ) (S n)) related_bounds + (Syntax.flat_interp_untuple' (n:=n) (T:=Tbase TZ) (BoundedWord64.t_map1_tuple2 (n:=n) opW opB pf (fst (fst sv1)) (Syntax.flat_interp_tuple (snd (fst sv1))) (Syntax.flat_interp_tuple (snd sv1)))) + (Syntax.flat_interp_untuple' (n:=n) (T:=Tbase TZ) (opB (fst (fst sv2)) (Syntax.flat_interp_tuple (snd (fst sv2))) (Syntax.flat_interp_tuple (snd sv2)))). +Proof. + t_map1_tuple2_t; + try first [ rewrite HN0 by (assumption || t_map1_tuple2_t) + | rewrite HN1 by (assumption || t_map1_tuple2_t) + | rewrite HN2 by (assumption || t_map1_tuple2_t) + | rewrite HN3 by (assumption || t_map1_tuple2_t) ]; + t_map1_tuple2_t. + { repeat match goal with + | [ |- context[HList.mapt' _ ?ls] ] + => not is_var ls; generalize ls; intro + | [ H : Tuple.lift_option _ = Some _ |- _ ] + => apply Tuple.push_lift_option in H; setoid_rewrite H + | _ => progress (break_match_hyps; t_map1_tuple2_t) + end. } + { repeat (break_match_hyps; t_map1_tuple2_t). + rewrite HN3 by (assumption || t_map1_tuple2_t). + t_map1_tuple2_t. } +Qed. + +Local Arguments ZBounds.SmartBuildBounds _ _ / . +Lemma related_bounds_op : related_op related_bounds (@BoundedWord64.interp_op) (@ZBounds.interp_op). +Proof. + let op := fresh in intros ?? op; destruct op; simpl. + { apply related_bounds_t_map2; intros; destruct_head' option; reflexivity. } + { apply related_bounds_t_map2; intros; destruct_head' option; reflexivity. } + { apply related_bounds_t_map2; intros; destruct_head' option; reflexivity. } + { apply related_bounds_t_map2; intros; destruct_head' option; reflexivity. } + { apply related_bounds_t_map2; intros; destruct_head' option; reflexivity. } + { apply related_bounds_t_map2; intros; destruct_head' option; reflexivity. } + { apply related_bounds_t_map2; intros; destruct_head' option; reflexivity. } + { apply related_bounds_t_map2; intros; destruct_head' option; destruct_head' ZBounds.bounds; reflexivity. } + { apply related_bounds_t_map4; intros; destruct_head' option; reflexivity. } + { apply related_bounds_t_map4; intros; destruct_head' option; reflexivity. } + { apply related_bounds_t_map1_tuple2; intros; destruct_head' option; try reflexivity; + unfold ZBounds.conditional_subtract in *; rewrite ?Tuple.lift_push_option in *; + repeat match goal with H : _ |- _ => rewrite !Tuple.lift_push_option in H end; + try reflexivity; + (rewrite_hyp ?* ); + break_match; try reflexivity. } +Qed. + +Local Ltac Word64.Rewrites.word64_util_arith ::= + solve [ autorewrite with Zshift_to_pow; omega + | autorewrite with Zshift_to_pow; nia + | autorewrite with Zshift_to_pow; auto with zarith + | eapply Z.le_lt_trans; [ eapply Z.log2_le_mono | eassumption ]; + autorewrite with Zshift_to_pow; auto using Z.mul_le_mono_nonneg with zarith; + solve [ omega + | nia + | etransitivity; [ eapply Z.div_le_mono | eapply Z.div_le_compat_l ]; + auto with zarith ] + | apply Z.land_nonneg; Word64.Rewrites.word64_util_arith + | eapply Z.le_lt_trans; [ eapply Z.log2_le_mono | eassumption ]; + apply Z.min_case_strong; intros; + first [ etransitivity; [ apply Z.land_upper_bound_l | ]; omega + | etransitivity; [ apply Z.land_upper_bound_r | ]; omega ] + | rewrite Z.log2_lor by omega; + apply Z.max_case_strong; intro; + (eapply Z.le_lt_trans; [ eapply Z.log2_le_mono; eassumption | assumption ]) + | eapply Z.le_lt_trans; [ eapply Z.log2_le_mono, neg_upperbound | ]; + Word64.Rewrites.word64_util_arith + | (progress unfold ModularBaseSystemListZOperations.cmovne, ModularBaseSystemListZOperations.cmovl); break_match; + Word64.Rewrites.word64_util_arith ]. +Local Ltac related_Z_op_t_step := + first [ progress related_word64_op_t_step + | progress cbv [related'_Z proj_eq_rel BoundedWord64.to_Z' BoundedWord64.to_word64' Word64.to_Z BoundedWord64.boundedWordToWord64 BoundedWord64.value] in * + | autorewrite with push_word64ToZ ]. +Local Ltac related_Z_op_t := repeat related_Z_op_t_step. + +Local Notation is_bounded_by value lower upper + := ((0 <= lower /\ lower <= Word64.word64ToZ value <= upper /\ Z.log2 upper < Z.of_nat Word64.bit_width)%Z) + (only parsing). +Local Notation is_in_bounds value bounds + := (is_bounded_by value (ZBounds.lower bounds) (ZBounds.upper bounds)) + (only parsing). + +Lemma related_Z_t_map2 opZ opW opB pf + (H : forall x y bxs bys brs, + Some brs = opB (Some bxs) (Some bys) + -> is_in_bounds x bxs + -> is_in_bounds y bys + -> is_in_bounds (opW x y) brs + -> Word64.word64ToZ (opW x y) = (opZ (Word64.word64ToZ x) (Word64.word64ToZ y))) + sv1 sv2 + : interp_flat_type_rel_pointwise2 (t:=Prod (Tbase TZ) (Tbase TZ)) related_Z sv1 sv2 + -> @related_Z TZ (BoundedWord64.t_map2 opW opB pf (fst sv1) (snd sv1)) (opZ (fst sv2) (snd sv2)). +Proof. + cbv [interp_flat_type BoundedWord64.interp_base_type ZBounds.interp_base_type LiftOption.interp_base_type' interp_flat_type_rel_pointwise2 interp_flat_type_rel_pointwise2_gen_Prop] in *. + related_Z_op_t. + eapply H; eauto. +Qed. + +Lemma related_Z_t_map4 opZ opW opB pf + (H : forall x y z w bxs bys bzs bws brs, + Some brs = opB (Some bxs) (Some bys) (Some bzs) (Some bws) + -> is_in_bounds x bxs + -> is_in_bounds y bys + -> is_in_bounds z bzs + -> is_in_bounds w bws + -> is_in_bounds (opW x y z w) brs + -> Word64.word64ToZ (opW x y z w) = (opZ (Word64.word64ToZ x) (Word64.word64ToZ y) (Word64.word64ToZ z) (Word64.word64ToZ w))) + sv1 sv2 + : interp_flat_type_rel_pointwise2 (t:=(Tbase TZ * Tbase TZ * Tbase TZ * Tbase TZ)%ctype) related_Z sv1 sv2 + -> @related_Z TZ (BoundedWord64.t_map4 opW opB pf (fst (fst (fst sv1))) (snd (fst (fst sv1))) (snd (fst sv1)) (snd sv1)) + (opZ (fst (fst (fst sv2))) (snd (fst (fst sv2))) (snd (fst sv2)) (snd sv2)). +Proof. + cbv [interp_flat_type BoundedWord64.interp_base_type ZBounds.interp_base_type LiftOption.interp_base_type' interp_flat_type_rel_pointwise2 interp_flat_type_rel_pointwise2_gen_Prop] in *. + related_Z_op_t. + eapply H; eauto. +Qed. + +Local Arguments related_Z _ !_ _ / . +Axiom proof_admitted : False. +Tactic Notation "admit" := abstract case proof_admitted. + +Local Arguments related'_Z _ _ _ / . +Lemma related_Z_t_map1_tuple2 n opZ opW opB pf + (H : forall x y z bxs bys bzs brs, + Tuple.push_option (Some brs) = opB (Some bxs) (Tuple.push_option (Some bys)) (Tuple.push_option (Some bzs)) + -> is_in_bounds x bxs + (*-> is_in_bounds y bys + -> is_in_bounds z bzs + -> is_in_bounds (opW x y z) brs*) + -> Tuple.map Word64.word64ToZ (opW x y z) = (opZ (Word64.word64ToZ x) (Tuple.map Word64.word64ToZ y) (Tuple.map Word64.word64ToZ z))) + sv1 sv2 + : interp_flat_type_rel_pointwise2 (t:=(Tbase TZ * Syntax.tuple (Tbase TZ) (S n) * Syntax.tuple (Tbase TZ) (S n))%ctype) related_Z sv1 sv2 + -> interp_flat_type_rel_pointwise2 + related_Z + (flat_interp_untuple' (T:=Tbase TZ) (BoundedWord64.t_map1_tuple2 opW opB pf (fst (fst sv1)) (flat_interp_tuple' (snd (fst sv1))) (flat_interp_tuple' (snd sv1)))) + (flat_interp_untuple' (T:=Tbase TZ) (opZ (fst (fst sv2)) (flat_interp_tuple' (snd (fst sv2))) (flat_interp_tuple' (snd sv2)))). +Proof. + repeat first [ progress simpl in * + | progress intros + | progress destruct_head_hnf' and + | progress destruct_head_hnf' prod + | progress destruct_head_hnf' option + | progress (unfold proj_eq_rel in *; subst) + | apply @related_tuples_None_left; constructor + | apply -> @related_tuples_Some_left + | apply <- @related_tuples_proj_eq_rel_untuple + | apply <- @related_tuples_lift_relation2_untuple' ]. + unfold related_Z. + admit. +Qed. + +Local Ltac related_Z_op_fin_t_step := + first [ progress subst + | progress destruct_head' ZBounds.bounds + | progress destruct_head' and + | progress ZBounds.inversion_bounds + | progress simpl in * |- + | progress break_match_hyps + | congruence + | progress inversion_option + | intro + | progress autorewrite with push_word64ToZ + | match goal with H : andb _ _ = true |- _ => rewrite Bool.andb_true_iff in H end + | progress Z.ltb_to_lt ]. +Local Ltac related_Z_op_fin_t := repeat related_Z_op_fin_t_step. + +Local Opaque Word64.bit_width. + +Lemma related_Z_op : related_op related_Z (@BoundedWord64.interp_op) (@Z.interp_op). +Proof. + let op := fresh in intros ?? op; destruct op; simpl. + { apply related_Z_t_map2; related_Z_op_fin_t. } + { apply related_Z_t_map2; related_Z_op_fin_t. } + { apply related_Z_t_map2; related_Z_op_fin_t. } + { apply related_Z_t_map2; related_Z_op_fin_t. } + { apply related_Z_t_map2; related_Z_op_fin_t. } + { apply related_Z_t_map2; related_Z_op_fin_t. } + { apply related_Z_t_map2; related_Z_op_fin_t. } + { apply related_Z_t_map2; related_Z_op_fin_t. } + { apply related_Z_t_map4; related_Z_op_fin_t. } + { apply related_Z_t_map4; related_Z_op_fin_t. } + { apply related_Z_t_map1_tuple2; related_Z_op_fin_t; + rewrite Word64.word64ToZ_conditional_subtract; try Word64.Rewrites.word64_util_arith. + pose proof BoundedWord64.conditional_subtract_bounded. + admit. (** TODO(jadep or jgross): Fill me in *) } +Qed. + +Create HintDb interp_related discriminated. +Hint Resolve related_Z_op related_bounds_op related_word64_op related_Z_const related_bounds_const related_word64_const : interp_related. diff --git a/src/Reflection/Z/Interpretations/RelationsCombinations.v b/src/Reflection/Z/Interpretations/RelationsCombinations.v new file mode 100644 index 000000000..e8ba22e00 --- /dev/null +++ b/src/Reflection/Z/Interpretations/RelationsCombinations.v @@ -0,0 +1,358 @@ +Require Import Coq.ZArith.ZArith. +Require Import Crypto.Reflection.Z.Syntax. +Require Import Crypto.Reflection.Syntax. +Require Import Crypto.Reflection.Application. +Require Import Crypto.Reflection.Z.Interpretations. +Require Import Crypto.Reflection.Z.Interpretations.Relations. +Require Import Crypto.Util.Option. +Require Import Crypto.Util.Bool. +Require Import Crypto.Util.ZUtil. +Require Import Crypto.Util.Prod. +Require Import Crypto.Util.Tactics. + +Module Relations. + Section lift. + Context {interp_base_type1 interp_base_type2 : base_type -> Type} + (R : forall t, interp_base_type1 t -> interp_base_type2 t -> Prop). + + Definition interp_type_rel_pointwise2_uncurried + {t : type base_type} + := match t return interp_type interp_base_type1 t -> interp_type interp_base_type2 t -> _ with + | Tflat T => fun f g => interp_flat_type_rel_pointwise2 (t:=T) R f g + | Arrow A B + => fun f g + => forall x y, interp_flat_type_rel_pointwise2 R x y + -> interp_flat_type_rel_pointwise2 R (ApplyInterpedAll f x) (ApplyInterpedAll g y) + end. + + Lemma uncurry_interp_type_rel_pointwise2 + {t f g} + : interp_type_rel_pointwise2 (t:=t) R f g + <-> interp_type_rel_pointwise2_uncurried (t:=t) f g. + Proof. + unfold interp_type_rel_pointwise2_uncurried. + induction t as [|A B IHt]; [ reflexivity | ]. + { simpl; unfold Morphisms.respectful_hetero in *; destruct B. + { reflexivity. } + { setoid_rewrite IHt; clear IHt. + split; intro H; intros. + { simpl in *; intuition. } + { eapply (H (_, _) (_, _)); simpl in *; intuition. } } } + Qed. + End lift. + + Section proj. + Context {interp_base_type1 interp_base_type2} + (proj : forall t : base_type, interp_base_type1 t -> interp_base_type2 t). + + Let R {t : flat_type base_type} f g := + SmartVarfMap (t:=t) proj f = g. + + Definition interp_type_rel_pointwise2_uncurried_proj + {t : type base_type} + : interp_type interp_base_type1 t -> interp_type interp_base_type2 t -> Prop + := match t return interp_type interp_base_type1 t -> interp_type interp_base_type2 t -> Prop with + | Tflat T => @R _ + | Arrow A B + => fun f g + => forall x : interp_flat_type interp_base_type1 (all_binders_for (Arrow A B)), + let y := SmartVarfMap proj x in + let fx := ApplyInterpedAll f x in + let gy := ApplyInterpedAll g y in + @R _ fx gy + end. + + Lemma uncurry_interp_type_rel_pointwise2_proj + {t : type base_type} + {f : interp_type interp_base_type1 t} + {g} + : interp_type_rel_pointwise2 (t:=t) (fun t => @R _) f g + -> interp_type_rel_pointwise2_uncurried_proj (t:=t) f g. + Proof. + unfold interp_type_rel_pointwise2_uncurried_proj. + induction t as [t|A B IHt]; simpl; unfold Morphisms.respectful_hetero in *. + { induction t as [t|A IHA B IHB]; simpl; [ solve [ trivial | reflexivity ] | ]. + intros [HA HB]. + specialize (IHA _ _ HA); specialize (IHB _ _ HB). + unfold R in *. + repeat first [ progress destruct_head_hnf' prod + | progress simpl in * + | progress subst + | reflexivity ]. } + { destruct B; intros H ?; apply IHt, H; clear IHt; + repeat first [ reflexivity + | progress simpl in * + | progress unfold R, LiftOption.of' in * + | progress break_match ]. } + Qed. + End proj. + + Section proj_option. + Context {interp_base_type1 : Type} {interp_base_type2 : base_type -> Type} + (proj_option : forall t, interp_base_type1 -> interp_base_type2 t). + + Let R {t : flat_type base_type} f g := + let f' := LiftOption.of' (t:=t) f in + match f' with + | Some f' => SmartVarfMap proj_option f' = g + | None => True + end. + + Definition interp_type_rel_pointwise2_uncurried_proj_option + {t : type base_type} + : interp_type (LiftOption.interp_base_type' interp_base_type1) t -> interp_type interp_base_type2 t -> Prop + := match t return interp_type (LiftOption.interp_base_type' interp_base_type1) t -> interp_type interp_base_type2 t -> Prop with + | Tflat T => @R _ + | Arrow A B + => fun f g + => forall x : interp_flat_type (fun _ => interp_base_type1) (all_binders_for (Arrow A B)), + let y := SmartVarfMap proj_option x in + let fx := ApplyInterpedAll f (LiftOption.to' (Some x)) in + let gy := ApplyInterpedAll g y in + @R _ fx gy + end. + + Lemma uncurry_interp_type_rel_pointwise2_proj_option + {t : type base_type} + {f : interp_type (LiftOption.interp_base_type' interp_base_type1) t} + {g} + : interp_type_rel_pointwise2 (t:=t) (fun t => @R _) f g + -> interp_type_rel_pointwise2_uncurried_proj_option (t:=t) f g. + Proof. + unfold interp_type_rel_pointwise2_uncurried_proj_option. + induction t as [t|A B IHt]; simpl; unfold Morphisms.respectful_hetero in *. + { induction t as [t|A IHA B IHB]; simpl; [ solve [ trivial | reflexivity ] | ]. + intros [HA HB]. + specialize (IHA _ _ HA); specialize (IHB _ _ HB). + unfold R in *. + repeat first [ progress destruct_head_hnf' prod + | progress simpl in * + | progress unfold LiftOption.of' in * + | progress break_match + | progress break_match_hyps + | progress inversion_prod + | progress inversion_option + | reflexivity + | progress intuition subst ]. } + { destruct B; intros H ?; apply IHt, H; clear IHt. + { repeat first [ progress simpl in * + | progress unfold R, LiftOption.of' in * + | progress break_match + | reflexivity ]. } + { simpl in *; break_match; reflexivity. } } + Qed. + End proj_option. + + Section proj_option2. + Context {interp_base_type1 : Type} {interp_base_type2 : Type} + (proj : interp_base_type1 -> interp_base_type2). + + Let R {t : flat_type base_type} f g := + let f' := LiftOption.of' (t:=t) f in + let g' := LiftOption.of' (t:=t) g in + match f', g' with + | Some f', Some g' => SmartVarfMap (fun _ => proj) f' = g' + | None, None => True + | Some _, _ => False + | None, _ => False + end. + + Definition interp_type_rel_pointwise2_uncurried_proj_option2 + {t : type base_type} + : interp_type (LiftOption.interp_base_type' interp_base_type1) t -> interp_type (LiftOption.interp_base_type' interp_base_type2) t -> Prop + := match t return interp_type (LiftOption.interp_base_type' interp_base_type1) t -> interp_type (LiftOption.interp_base_type' interp_base_type2) t -> Prop with + | Tflat T => @R _ + | Arrow A B + => fun f g + => forall x : interp_flat_type (fun _ => interp_base_type1) (all_binders_for (Arrow A B)), + let y := SmartVarfMap (fun _ => proj) x in + let fx := ApplyInterpedAll f (LiftOption.to' (Some x)) in + let gy := ApplyInterpedAll g (LiftOption.to' (Some y)) in + @R _ fx gy + end. + + Lemma uncurry_interp_type_rel_pointwise2_proj_option2 + {t : type base_type} + {f : interp_type (LiftOption.interp_base_type' interp_base_type1) t} + {g : interp_type (LiftOption.interp_base_type' interp_base_type2) t} + : interp_type_rel_pointwise2 (t:=t) (fun t => @R _) f g + -> interp_type_rel_pointwise2_uncurried_proj_option2 (t:=t) f g. + Proof. + unfold interp_type_rel_pointwise2_uncurried_proj_option2. + induction t as [t|A B IHt]; simpl; unfold Morphisms.respectful_hetero in *. + { induction t as [t|A IHA B IHB]; simpl; [ solve [ trivial | reflexivity ] | ]. + intros [HA HB]. + specialize (IHA _ _ HA); specialize (IHB _ _ HB). + unfold R in *. + repeat first [ progress destruct_head_hnf' prod + | progress simpl in * + | progress unfold LiftOption.of' in * + | progress break_match + | progress break_match_hyps + | progress inversion_prod + | progress inversion_option + | reflexivity + | progress intuition subst ]. } + { destruct B; intros H ?; apply IHt, H; clear IHt. + { repeat first [ progress simpl in * + | progress unfold R, LiftOption.of' in * + | progress break_match + | reflexivity ]. } + { simpl in *; break_match; reflexivity. } } + Qed. + End proj_option2. + + Section proj_from_option2. + Context {interp_base_type0 : Type} {interp_base_type1 interp_base_type2 : base_type -> Type} + (proj01 : forall t, interp_base_type0 -> interp_base_type1 t) + (proj02 : forall t, interp_base_type0 -> interp_base_type2 t) + (proj : forall t, interp_base_type1 t -> interp_base_type2 t). + + Let R {t : flat_type base_type} f g := + proj_eq_rel (SmartVarfMap proj (t:=t)) f g. + + Definition interp_type_rel_pointwise2_uncurried_proj_from_option2 + {t : type base_type} + : interp_type (LiftOption.interp_base_type' interp_base_type0) t -> interp_type interp_base_type1 t -> interp_type interp_base_type2 t -> Prop + := match t return interp_type _ t -> interp_type _ t -> interp_type _ t -> Prop with + | Tflat T => fun f0 f g => match LiftOption.of' f0 with + | Some _ => True + | None => False + end -> @R _ f g + | Arrow A B + => fun f0 f g + => forall x : interp_flat_type (fun _ => interp_base_type0) (all_binders_for (Arrow A B)), + let x' := SmartVarfMap proj01 x in + let y' := SmartVarfMap proj x' in + let fx := ApplyInterpedAll f x' in + let gy := ApplyInterpedAll g y' in + let f0x := LiftOption.of' (ApplyInterpedAll f0 (LiftOption.to' (Some x))) in + match f0x with + | Some _ => True + | None => False + end + -> @R _ fx gy + end. + + Lemma uncurry_interp_type_rel_pointwise2_proj_from_option2 + {t : type base_type} + {f0} + {f : interp_type interp_base_type1 t} + {g : interp_type interp_base_type2 t} + (proj012 : forall t x, proj t (proj01 t x) = proj02 t x) + : interp_type_rel_pointwise2 (t:=t) (LiftOption.lift_relation (fun t => proj_eq_rel (proj01 t))) f0 f + -> interp_type_rel_pointwise2 (t:=t) (LiftOption.lift_relation (fun t => proj_eq_rel (proj02 t))) f0 g + -> interp_type_rel_pointwise2_uncurried_proj_from_option2 (t:=t) f0 f g. + Proof. + unfold interp_type_rel_pointwise2_uncurried_proj_from_option2. + induction t as [t|A B IHt]; simpl; unfold Morphisms.respectful_hetero in *. + { induction t as [t|A IHA B IHB]; simpl. + { cbv [LiftOption.lift_relation proj_eq_rel R]. + break_match; try tauto; intros; subst. + apply proj012. } + { intros [HA HB] [HA' HB'] Hrel. + specialize (IHA _ _ _ HA HA'); specialize (IHB _ _ _ HB HB'). + unfold R, proj_eq_rel in *. + repeat first [ progress destruct_head_hnf' prod + | progress simpl in * + | progress unfold LiftOption.of' in * + | progress break_match + | progress break_match_hyps + | progress inversion_prod + | progress inversion_option + | reflexivity + | progress intuition subst ]. } } + { destruct B; intros H0 H1 ?; apply IHt; clear IHt; first [ apply H0 | apply H1 ]; + repeat first [ progress simpl in * + | progress unfold R, LiftOption.of', proj_eq_rel, LiftOption.lift_relation in * + | break_match; rewrite <- ?proj012; reflexivity ]. } + Qed. + End proj_from_option2. + Global Arguments uncurry_interp_type_rel_pointwise2_proj_from_option2 {_ _ _ _ _} proj {t f0 f g} _ _ _. + + Section proj1_from_option2. + Context {interp_base_type0 interp_base_type1 : Type} {interp_base_type2 : base_type -> Type} + (proj01 : interp_base_type0 -> interp_base_type1) + (proj02 : forall t, interp_base_type0 -> interp_base_type2 t) + (R : forall t, interp_base_type1 -> interp_base_type2 t -> Prop). + + Definition interp_type_rel_pointwise2_uncurried_proj1_from_option2 + {t : type base_type} + : interp_type (LiftOption.interp_base_type' interp_base_type0) t -> interp_type (LiftOption.interp_base_type' interp_base_type1) t -> interp_type interp_base_type2 t -> Prop + := match t return interp_type _ t -> interp_type _ t -> interp_type _ t -> Prop with + | Tflat T => fun f0 f g => match LiftOption.of' f0 with + | Some _ => True + | None => False + end -> match LiftOption.of' f with + | Some f' => interp_flat_type_rel_pointwise2 (@R) f' g + | None => True + end + | Arrow A B + => fun f0 f g + => forall x : interp_flat_type (fun _ => interp_base_type0) (all_binders_for (Arrow A B)), + let x' := SmartVarfMap (fun _ => proj01) x in + let y' := SmartVarfMap proj02 x in + let fx := LiftOption.of' (ApplyInterpedAll f (LiftOption.to' (Some x'))) in + let gy := ApplyInterpedAll g y' in + let f0x := LiftOption.of' (ApplyInterpedAll f0 (LiftOption.to' (Some x))) in + match f0x with + | Some _ => True + | None => False + end + -> match fx with + | Some fx' => interp_flat_type_rel_pointwise2 (@R) fx' gy + | None => True + end + end. + + Lemma uncurry_interp_type_rel_pointwise2_proj1_from_option2 + {t : type base_type} + {f0} + {f : interp_type (LiftOption.interp_base_type' interp_base_type1) t} + {g : interp_type interp_base_type2 t} + (proj012R : forall t x, @R _ (proj01 x) (proj02 t x)) + : interp_type_rel_pointwise2 (t:=t) (LiftOption.lift_relation2 (proj_eq_rel proj01)) f0 f + -> interp_type_rel_pointwise2 (t:=t) (LiftOption.lift_relation (fun t => proj_eq_rel (proj02 t))) f0 g + -> interp_type_rel_pointwise2_uncurried_proj1_from_option2 (t:=t) f0 f g. + Proof. + unfold interp_type_rel_pointwise2_uncurried_proj1_from_option2. + induction t as [t|A B IHt]; simpl; unfold Morphisms.respectful_hetero in *. + { induction t as [t|A IHA B IHB]; simpl. + { cbv [LiftOption.lift_relation proj_eq_rel LiftOption.lift_relation2]. + break_match; try tauto; intros; subst. + apply proj012R. } + { intros [HA HB] [HA' HB'] Hrel. + specialize (IHA _ _ _ HA HA'); specialize (IHB _ _ _ HB HB'). + unfold proj_eq_rel in *. + repeat first [ progress destruct_head_hnf' prod + | progress simpl in * + | progress unfold LiftOption.of' in * + | progress break_match + | progress break_match_hyps + | progress inversion_prod + | progress inversion_option + | reflexivity + | progress intuition subst ]. } } + { destruct B; intros H0 H1 ?; apply IHt; clear IHt; first [ apply H0 | apply H1 ]; + repeat first [ progress simpl in * + | progress unfold R, LiftOption.of', proj_eq_rel, LiftOption.lift_relation in * + | break_match; reflexivity ]. } + Qed. + End proj1_from_option2. + Global Arguments uncurry_interp_type_rel_pointwise2_proj1_from_option2 {_ _ _ _ _} R {t f0 f g} _ _ _. + + Section combine_related. + Lemma related_flat_transitivity {interp_base_type1 interp_base_type2 interp_base_type3} + {R1 : forall t : base_type, interp_base_type1 t -> interp_base_type2 t -> Prop} + {R2 : forall t : base_type, interp_base_type1 t -> interp_base_type3 t -> Prop} + {R3 : forall t : base_type, interp_base_type2 t -> interp_base_type3 t -> Prop} + {t x y z} + : (forall t a b c, (R1 t a b : Prop) -> (R2 t a c : Prop) -> (R3 t b c : Prop)) + -> interp_flat_type_rel_pointwise2 (t:=t) R1 x y + -> interp_flat_type_rel_pointwise2 (t:=t) R2 x z + -> interp_flat_type_rel_pointwise2 (t:=t) R3 y z. + Proof. + intro HRel; induction t; simpl; intuition eauto. + Qed. + End combine_related. +End Relations. diff --git a/src/Reflection/Z/Reify.v b/src/Reflection/Z/Reify.v index 44689a2c3..6734d7d01 100644 --- a/src/Reflection/Z/Reify.v +++ b/src/Reflection/Z/Reify.v @@ -1,11 +1,15 @@ Require Import Coq.ZArith.ZArith. Require Import Crypto.ModularArithmetic.ModularBaseSystemListZOperations. +Require Import Crypto.Reflection.InputSyntax. Require Import Crypto.Reflection.Z.Syntax. +Require Import Crypto.Reflection.WfReflective. Require Import Crypto.Reflection.Reify. Require Import Crypto.Reflection.Inline. +Require Import Crypto.Reflection.InlineInterp. Require Import Crypto.Reflection.Linearize. +Require Import Crypto.Reflection.LinearizeInterp. -Ltac base_reify_op op op_head ::= +Ltac base_reify_op op op_head extra ::= lazymatch op_head with | @Z.add => constr:(reify_op op op_head 2 Add) | @Z.mul => constr:(reify_op op op_head 2 Mul) @@ -17,12 +21,35 @@ Ltac base_reify_op op op_head ::= | @ModularBaseSystemListZOperations.neg => constr:(reify_op op op_head 2 Neg) | @ModularBaseSystemListZOperations.cmovne => constr:(reify_op op op_head 4 Cmovne) | @ModularBaseSystemListZOperations.cmovl => constr:(reify_op op op_head 4 Cmovle) + | @ModularBaseSystemListZOperations.conditional_subtract_modulus + => lazymatch extra with + | @ModularBaseSystemListZOperations.conditional_subtract_modulus ?limb_count _ _ _ + => lazymatch (eval compute in limb_count) with + | 0 => fail 1 "Cannot handle empty limb-list in reifying conditional_subtract_modulus" + | S ?pred_limb_count => constr:(reify_op op op_head 3 (ConditionalSubtract pred_limb_count)) + | ?climb_count => fail 1 "Cannot handle non-ground length of limb-list in reifying conditional_subtract_modulus" "(" limb_count "which computes to" climb_count ")" + end + | _ => fail 100 "Anomaly: In Reflection.Z.base_reify_op: head is conditional_subtract_modulus but body is wrong:" extra + end end. Ltac base_reify_type T ::= lazymatch T with | Z => TZ end. -Ltac Reify' e := Reify.Reify' base_type interp_base_type op e. +Ltac Reify' e := Reflection.Reify.Reify' base_type interp_base_type op e. Ltac Reify e := - let v := Reify.Reify base_type interp_base_type op e in - constr:((*Inline _*) ((*CSE _*) ((*InlineConst*) (Linearize v)))). + let v := Reflection.Reify.Reify base_type interp_base_type op e in + constr:((*Inline _*) ((*CSE _*) (InlineConst (Linearize v)))). +Ltac prove_InlineConst_Linearize_Compile_correct := + fun _ + => lazymatch goal with + | [ |- Syntax.interp_type_gen_rel_pointwise _ (@Syntax.Interp ?base_type_code ?interp_base_type ?op ?interp_op ?t (InlineConst (Linearize _))) _ ] + => etransitivity; + [ apply (@Interp_InlineConst base_type_code interp_base_type op interp_op t); + reflect_Wf base_type_eq_semidec_is_dec op_beq_bl + | etransitivity; + [ apply (@Interp_Linearize base_type_code interp_base_type op interp_op t) + | prove_compile_correct () ] ] + end. +Ltac Reify_rhs := + Reflection.Reify.Reify_rhs_gen Reify prove_InlineConst_Linearize_Compile_correct interp_op ltac:(fun tac => tac ()). diff --git a/src/Reflection/Z/Syntax.v b/src/Reflection/Z/Syntax.v index 7b87934d6..5657fff32 100644 --- a/src/Reflection/Z/Syntax.v +++ b/src/Reflection/Z/Syntax.v @@ -4,6 +4,9 @@ Require Import Crypto.Reflection.Syntax. Require Import Crypto.ModularArithmetic.ModularBaseSystemListZOperations. Require Import Crypto.Util.Equality. Require Import Crypto.Util.ZUtil. +Require Import Crypto.Util.HProp. +Require Import Crypto.Util.Decidable. +Require Import Crypto.Util.PartiallyReifiedProp. Export Syntax.Notations. Local Set Boolean Equality Schemes. @@ -15,11 +18,19 @@ Definition interp_base_type (v : base_type) : Type := | TZ => Z end. +Global Instance dec_eq_base_type : DecidableRel (@eq base_type) + := base_type_eq_dec. +Global Instance dec_eq_flat_type : DecidableRel (@eq (flat_type base_type)) := _. +Global Instance dec_eq_type : DecidableRel (@eq (type base_type)) := _. + Local Notation tZ := (Tbase TZ). Local Notation eta x := (fst x, snd x). Local Notation eta3 x := (eta (fst x), snd x). Local Notation eta4 x := (eta3 (fst x), snd x). +Axiom proof_admitted : False. +Local Notation admit := (match proof_admitted with end). + Inductive op : flat_type base_type -> flat_type base_type -> Type := | Add : op (tZ * tZ) tZ | Sub : op (tZ * tZ) tZ @@ -30,7 +41,12 @@ Inductive op : flat_type base_type -> flat_type base_type -> Type := | Lor : op (tZ * tZ) tZ | Neg : op (tZ * tZ) tZ | Cmovne : op (tZ * tZ * tZ * tZ) tZ -| Cmovle : op (tZ * tZ * tZ * tZ) tZ. +| Cmovle : op (tZ * tZ * tZ * tZ) tZ +| ConditionalSubtract (pred_limb_count : nat) + : op (tZ (* int_width *) + * Syntax.tuple tZ (S pred_limb_count) (* modulus *) + * Syntax.tuple tZ (S pred_limb_count) (* value *)) + (Syntax.tuple tZ (S pred_limb_count)). Definition interp_op src dst (f : op src dst) : interp_flat_type interp_base_type src -> interp_flat_type interp_base_type dst := match f in op src dst return interp_flat_type interp_base_type src -> interp_flat_type interp_base_type dst with @@ -44,4 +60,95 @@ Definition interp_op src dst (f : op src dst) : interp_flat_type interp_base_typ | Neg => fun xy => ModularBaseSystemListZOperations.neg (fst xy) (snd xy) | Cmovne => fun xyzw => let '(x, y, z, w) := eta4 xyzw in cmovne x y z w | Cmovle => fun xyzw => let '(x, y, z, w) := eta4 xyzw in cmovl x y z w + | ConditionalSubtract pred_n + => fun xyz => let '(x, y, z) := eta3 xyz in + flat_interp_untuple' (T:=tZ) (@ModularBaseSystemListZOperations.conditional_subtract_modulus (S pred_n) x (flat_interp_tuple y) (flat_interp_tuple z)) end%Z. + +Definition base_type_eq_semidec_transparent (t1 t2 : base_type) + : option (t1 = t2) + := Some (match t1, t2 return t1 = t2 with + | TZ, TZ => eq_refl + end). +Lemma base_type_eq_semidec_is_dec t1 t2 : base_type_eq_semidec_transparent t1 t2 = None -> t1 <> t2. +Proof. + unfold base_type_eq_semidec_transparent; congruence. +Qed. + +Definition op_beq_hetero {t1 tR t1' tR'} (f : op t1 tR) (g : op t1' tR') : reified_Prop + := match f, g return bool with + | Add, Add => true + | Add, _ => false + | Sub, Sub => true + | Sub, _ => false + | Mul, Mul => true + | Mul, _ => false + | Shl, Shl => true + | Shl, _ => false + | Shr, Shr => true + | Shr, _ => false + | Land, Land => true + | Land, _ => false + | Lor, Lor => true + | Lor, _ => false + | Neg, Neg => true + | Neg, _ => false + | Cmovne, Cmovne => true + | Cmovne, _ => false + | Cmovle, Cmovle => true + | Cmovle, _ => false + | ConditionalSubtract n, ConditionalSubtract m => NatUtil.nat_beq n m + | ConditionalSubtract _, _ => false + end. + +Definition op_beq t1 tR (f g : op t1 tR) : reified_Prop + := Eval cbv [op_beq_hetero] in op_beq_hetero f g. + +Definition op_beq_hetero_type_eq {t1 tR t1' tR'} f g : to_prop (@op_beq_hetero t1 tR t1' tR' f g) -> t1 = t1' /\ tR = tR'. +Proof. + destruct f, g; simpl; try solve [ repeat constructor | intros [] ]. + unfold op_beq_hetero; simpl. + match goal with + | [ |- context[to_prop (reified_Prop_of_bool ?x)] ] + => destruct (Sumbool.sumbool_of_bool x) as [P|P] + end. + { apply NatUtil.internal_nat_dec_bl in P; subst; repeat constructor. } + { intro H'; exfalso; rewrite P in H'; exact H'. } +Defined. + +Definition op_beq_hetero_type_eqs {t1 tR t1' tR'} f g : to_prop (@op_beq_hetero t1 tR t1' tR' f g) -> t1 = t1' + := fun H => let (p, q) := @op_beq_hetero_type_eq t1 tR t1' tR' f g H in p. +Definition op_beq_hetero_type_eqd {t1 tR t1' tR'} f g : to_prop (@op_beq_hetero t1 tR t1' tR' f g) -> tR = tR' + := fun H => let (p, q) := @op_beq_hetero_type_eq t1 tR t1' tR' f g H in q. + +Definition op_beq_hetero_eq {t1 tR t1' tR'} f g + : forall pf : to_prop (@op_beq_hetero t1 tR t1' tR' f g), + eq_rect + _ (fun src => op src tR') + (eq_rect _ (fun dst => op t1 dst) f _ (op_beq_hetero_type_eqd f g pf)) + _ (op_beq_hetero_type_eqs f g pf) + = g. +Proof. + destruct f, g; simpl; try solve [ reflexivity | intros [] ]. + { unfold op_beq_hetero, op_beq_hetero_type_eqs, op_beq_hetero_type_eqd; simpl. + intro pf; edestruct Sumbool.sumbool_of_bool. + { simpl; + lazymatch goal with + | [ |- context[NatUtil.internal_nat_dec_bl ?x ?y ?pf] ] + => generalize dependent (NatUtil.internal_nat_dec_bl x y pf); intro; subst + end; + reflexivity. } + { match goal with + | [ |- context[False_ind _ ?pf] ] + => case pf + end. } } +Qed. + +Lemma op_beq_bl : forall t1 tR x y, to_prop (op_beq t1 tR x y) -> x = y. +Proof. + intros ?? f g H. + pose proof (op_beq_hetero_eq f g H) as H'. + generalize dependent (op_beq_hetero_type_eqd f g H). + generalize dependent (op_beq_hetero_type_eqs f g H). + intros; eliminate_hprop_eq; simpl in *; assumption. +Qed. diff --git a/src/Spec/Ed25519.v b/src/Spec/Ed25519.v index a8e95cf9d..aa904fc7e 100644 --- a/src/Spec/Ed25519.v +++ b/src/Spec/Ed25519.v @@ -61,22 +61,23 @@ Section Ed25519. (F.of_Z q 15112221349535400772501151409588531511454012693041857206046113283949847762202, F.of_Z q 4 / F.of_Z q 5). - Definition Fencode {b : nat} {m} : F m -> Word.word b := + Local Infix "++" := Word.combine. + Local Notation bit b := (Word.WS b Word.WO : Word.word 1). + + Definition Fencode {len} {m} : F m -> Word.word len := fun x : F m => (Word.NToWord _ (BinIntDef.Z.to_N (F.to_Z x))). Definition sign (x : F q) : bool := BinIntDef.Z.testbit (F.to_Z x) 0. Definition Eenc : E -> Word.word b := fun P => - let '(x,y) := E.coordinates P in Word.WS (sign x) (Fencode y). - Definition Senc : Fl -> Word.word b := Fencode. - - (* TODO(andreser): prove this after we have fast scalar multplication *) - Axiom B_order_l : CompleteEdwardsCurveTheorems.E.eq (BinInt.Z.to_nat l * B)%E E.zero. + let '(x,y) := E.coordinates P in Fencode (len:=b-1) y ++ bit (sign x). + Definition Senc : Fl -> Word.word b := Fencode (len:=b). Require Import Crypto.Util.Decidable. Definition ed25519 : + CompleteEdwardsCurveTheorems.E.eq (BinInt.Z.to_nat l * B)%E E.zero -> (* TODO: prove this earlier than Experiments/Ed25519? *) EdDSA (E:=E) (Eadd:=E.add) (Ezero:=E.zero) (EscalarMult:=E.mul) (B:=B) (Eopp:=Crypto.CompleteEdwardsCurve.CompleteEdwardsCurveTheorems.E.opp) (* TODO: move defn *) (Eeq:=Crypto.CompleteEdwardsCurve.CompleteEdwardsCurveTheorems.E.eq) (* TODO: move defn *) (l:=l) (b:=b) (n:=n) (c:=c) (Eenc:=Eenc) (Senc:=Senc) (H:=H). - Proof. split; try exact _; try exact B_order_l; vm_decide. Qed. + Proof. split; try (assumption || exact _); vm_decide. Qed. End Ed25519.
\ No newline at end of file diff --git a/src/Spec/MxDH.v b/src/Spec/MxDH.v index d637836e4..0829c46f7 100644 --- a/src/Spec/MxDH.v +++ b/src/Spec/MxDH.v @@ -44,7 +44,7 @@ Module MxDH. (* from RFC7748 *) ((X4, Z4), (X5, Z5)) end. - Context {S:Type} {testbit:S->nat->bool} {cswap:bool->F*F->F*F->(F*F)*(F*F)}. + Context {cswap:bool->F*F->F*F->(F*F)*(F*F)}. Fixpoint downto_iter (i:nat) : list nat := match i with @@ -59,22 +59,22 @@ Module MxDH. (* from RFC7748 *) (* Ideally, we would verify that this corresponds to x coordinate multiplication *) - Definition montladder bound (s:S) (u:F) := - let '(_, _, P1, P2, swap) := + Definition montladder bound (testbit:nat->bool) (u:F) := + let '(P1, P2, swap) := downto - (s, u, (1, 0), (u, 1), false) + ((1, 0), (u, 1), false) bound (fun state i => - let '(s, x, P1, P2, swap) := state in - let s_i := testbit s i in + let '(P1, P2, swap) := state in + let s_i := testbit i in let swap := xor swap s_i in let '(P1, P2) := cswap swap P1 P2 in let swap := s_i in - let '(P1, P2) := ladderstep x P1 P2 in - (s, x, P1, P2, swap) + let '(P1, P2) := ladderstep u P1 P2 in + (P1, P2, swap) ) in let '((x, z), _) := cswap swap P1 P2 in - x/z. + x * Finv z. End MontgomeryLadderKeyExchange. End MxDH. diff --git a/src/Specific/FancyMachine256/Barrett.v b/src/Specific/FancyMachine256/Barrett.v index fd880b440..f3258fe60 100644 --- a/src/Specific/FancyMachine256/Barrett.v +++ b/src/Specific/FancyMachine256/Barrett.v @@ -85,8 +85,8 @@ Section reflected. Context (m μ : Z) (props : fancy_machine.arithmetic ops). - Let result (v : tuple fancy_machine.W 2) := Syntax.Interp (interp_op _) rexpression_simple m μ (fst v) (snd v). - Let assembled_result (v : tuple fancy_machine.W 2) : fancy_machine.W := Core.Interp compiled_syntax m μ (fst v) (snd v). + Let result (v : Tuple.tuple fancy_machine.W 2) := Syntax.Interp (interp_op _) rexpression_simple m μ (fst v) (snd v). + Let assembled_result (v : Tuple.tuple fancy_machine.W 2) : fancy_machine.W := Core.Interp compiled_syntax m μ (fst v) (snd v). Theorem sanity : result = expression ops m μ. Proof. @@ -108,7 +108,7 @@ Section reflected. (H3 : b^(k - offset) <= m + 1) (H4 : 0 <= m < 2^(k + offset)) (H5 : 0 <= b^(2 * k) / m < b^(k + offset)) - (v : tuple fancy_machine.W 2) + (v : Tuple.tuple fancy_machine.W 2) (H6 : 0 <= decode v < b^(2 * k)). Theorem correctness : fancy_machine.decode (result v) = decode v mod m. Proof. diff --git a/src/Specific/FancyMachine256/Core.v b/src/Specific/FancyMachine256/Core.v index 207237db7..eb443a8e3 100644 --- a/src/Specific/FancyMachine256/Core.v +++ b/src/Specific/FancyMachine256/Core.v @@ -125,7 +125,7 @@ Section reflection. Definition Inline {t} e := @InlineConstGen base_type interp_base_type op postprocess t e. End reflection. -Ltac base_reify_op op op_head ::= +Ltac base_reify_op op op_head expr ::= lazymatch op_head with | @Interface.ldi => constr:(reify_op op op_head 1 OPldi) | @Interface.shrd => constr:(reify_op op op_head 3 OPshrd) diff --git a/src/Specific/FancyMachine256/Montgomery.v b/src/Specific/FancyMachine256/Montgomery.v index fcf14afe2..b6f2da64a 100644 --- a/src/Specific/FancyMachine256/Montgomery.v +++ b/src/Specific/FancyMachine256/Montgomery.v @@ -76,9 +76,9 @@ Section reflected. Context (modulus m' : Z) (props : fancy_machine.arithmetic ops). - Let result (v : tuple fancy_machine.W 2) := Syntax.Interp (interp_op _) rexpression_simple modulus m' (fst v) (snd v). + Let result (v : Tuple.tuple fancy_machine.W 2) := Syntax.Interp (interp_op _) rexpression_simple modulus m' (fst v) (snd v). - Let assembled_result (v : tuple fancy_machine.W 2) : fancy_machine.W := Core.Interp compiled_syntax modulus m' (fst v) (snd v). + Let assembled_result (v : Tuple.tuple fancy_machine.W 2) : fancy_machine.W := Core.Interp compiled_syntax modulus m' (fst v) (snd v). Theorem sanity : result = expression ops modulus m'. Proof. @@ -100,7 +100,7 @@ Section reflected. (H2 : 0 <= m' < 2^256) (H3 : 2^256 * R' ≡ 1) (H4 : modulus * m' ≡₂₅₆ -1) - (v : tuple fancy_machine.W 2) + (v : Tuple.tuple fancy_machine.W 2) (H5 : 0 <= decode v <= 2^256 * modulus). Theorem correctness : fancy_machine.decode (result v) = (decode v * R') mod modulus. diff --git a/src/Specific/GF1305.v b/src/Specific/GF1305.v index 72184cf07..d31a2319f 100644 --- a/src/Specific/GF1305.v +++ b/src/Specific/GF1305.v @@ -44,7 +44,7 @@ Instance carryChain : CarryChain limb_widths. contradiction H. Defined. -Definition freezePreconditions1305 : freezePreconditions params1305 int_width. +Definition freezePreconditions1305 : FreezePreconditions int_width int_width. Proof. constructor; compute_preconditions. Defined. diff --git a/src/Specific/GF25519.v b/src/Specific/GF25519.v index a39755eee..2c0365fd2 100644 --- a/src/Specific/GF25519.v +++ b/src/Specific/GF25519.v @@ -14,6 +14,7 @@ Require Import Crypto.Util.LetIn. Require Import Crypto.Util.Notations. Require Import Crypto.Util.Decidable. Require Import Crypto.Algebra. +Require Crypto.Spec.Ed25519. Import ListNotations. Require Import Coq.ZArith.ZArith Coq.ZArith.Zpower Coq.ZArith.ZArith Coq.ZArith.Znumtheory. Local Open Scope Z. @@ -21,8 +22,9 @@ Local Open Scope Z. (* BEGIN precomputation. *) Definition modulus : Z := Eval compute in 2^255 - 19. -Lemma prime_modulus : prime modulus. Admitted. -Definition int_width := 32%Z. +Definition prime_modulus : prime modulus := Crypto.Spec.Ed25519.prime_q. +Definition int_width := 64%Z. +Definition freeze_input_bound := 32%Z. Instance params25519 : PseudoMersenneBaseParams modulus. construct_params prime_modulus 10%nat 255. @@ -46,7 +48,7 @@ Instance carryChain : CarryChain limb_widths. contradiction H. Defined. -Definition freezePreconditions25519 : freezePreconditions params25519 int_width. +Definition freezePreconditions25519 : FreezePreconditions freeze_input_bound int_width. Proof. constructor; compute_preconditions. Defined. @@ -584,23 +586,22 @@ Proof. exact (proj2_sig (eqb_sig f' g')). Qed. -Definition sqrt_sig (f : fe25519) : - { f' : fe25519 | f' = sqrt_5mod8_opt (int_width := int_width) k_ c_ one_ sqrt_m1 f}. +Definition sqrt_sig (powf powf_squared f : fe25519) : + { f' : fe25519 | f' = sqrt_5mod8_opt (int_width := int_width) k_ c_ sqrt_m1 powf powf_squared f}. Proof. eexists. cbv [sqrt_5mod8_opt int_width]. - rewrite <- pow_correct. apply Proper_Let_In_nd_changebody; [reflexivity|intro]. set_evars. rewrite <-!mul_correct, <-eqb_correct. subst_evars. reflexivity. Defined. -Definition sqrt (f : fe25519) : fe25519 - := Eval cbv beta iota delta [proj1_sig sqrt_sig] in proj1_sig (sqrt_sig f). +Definition sqrt (powf powf_squared f : fe25519) : fe25519 + := Eval cbv beta iota delta [proj1_sig sqrt_sig] in proj1_sig (sqrt_sig powf powf_squared f). -Definition sqrt_correct (f : fe25519) - : sqrt f = sqrt_5mod8_opt k_ c_ one_ sqrt_m1 f - := Eval cbv beta iota delta [proj2_sig sqrt_sig] in proj2_sig (sqrt_sig f). +Definition sqrt_correct (powf powf_squared f : fe25519) + : sqrt powf powf_squared f = sqrt_5mod8_opt k_ c_ sqrt_m1 powf powf_squared f + := Eval cbv beta iota delta [proj2_sig sqrt_sig] in proj2_sig (sqrt_sig powf powf_squared f). Definition pack_simpl_sig (f : fe25519) : { f' | f' = pack_opt params25519 wire_widths_nonneg bits_eq f }. diff --git a/src/Specific/GF25519Bounded.v b/src/Specific/GF25519Bounded.v index d9194610b..2d83b8dbd 100644 --- a/src/Specific/GF25519Bounded.v +++ b/src/Specific/GF25519Bounded.v @@ -7,7 +7,7 @@ Require Import Crypto.ModularArithmetic.ModularBaseSystemProofs. Require Import Crypto.ModularArithmetic.ModularBaseSystemOpt. Require Import Crypto.Specific.GF25519. Require Import Crypto.Specific.GF25519BoundedCommon. -(*Require Import Crypto.Assembly.GF25519BoundedInstantiation.*) +Require Import Crypto.Specific.GF25519Reflective. Require Import Bedrock.Word Crypto.Util.WordUtil. Require Import Coq.Lists.List Crypto.Util.ListUtil. Require Import Crypto.Tactics.VerdiTactics. @@ -45,11 +45,15 @@ Local Ltac define_unop_WireToFE f opW blem := abstract bounded_wire_digits_t opW blem. Local Opaque Let_In. -(*Local Arguments interp_radd / _ _. +Local Opaque Z.add Z.sub Z.mul Z.shiftl Z.shiftr Z.land Z.lor Z.eqb NToWord64. +Local Arguments interp_radd / _ _. Local Arguments interp_rsub / _ _. Local Arguments interp_rmul / _ _. Local Arguments interp_ropp / _. Local Arguments interp_rfreeze / _. +Local Arguments interp_rge_modulus / _. +Local Arguments interp_rpack / _. +Local Arguments interp_runpack / _. Definition addW (f g : fe25519W) : fe25519W := Eval simpl in interp_radd f g. Definition subW (f g : fe25519W) : fe25519W := Eval simpl in interp_rsub f g. Definition mulW (f g : fe25519W) : fe25519W := Eval simpl in interp_rmul f g. @@ -57,15 +61,7 @@ Definition oppW (f : fe25519W) : fe25519W := Eval simpl in interp_ropp f. Definition freezeW (f : fe25519W) : fe25519W := Eval simpl in interp_rfreeze f. Definition ge_modulusW (f : fe25519W) : word64 := Eval simpl in interp_rge_modulus f. Definition packW (f : fe25519W) : wire_digitsW := Eval simpl in interp_rpack f. -Definition unpackW (f : wire_digitsW) : fe25519W := Eval simpl in interp_runpack f.*) -Definition addW (f g : fe25519W) : fe25519W := Eval cbv beta delta [carry_add] in carry_add f g. -Definition subW (f g : fe25519W) : fe25519W := Eval cbv beta delta [carry_sub] in carry_sub f g. -Definition mulW (f g : fe25519W) : fe25519W := Eval cbv beta delta [mul] in mul f g. -Definition oppW (f : fe25519W) : fe25519W := Eval cbv beta delta [carry_opp] in carry_opp f. -Definition freezeW (f : fe25519W) : fe25519W := Eval cbv beta delta [freeze] in freeze f. -Definition ge_modulusW (f : fe25519W) : word64 := Eval cbv beta delta [ge_modulus] in ge_modulus f. -Definition packW (f : fe25519W) : wire_digitsW := Eval cbv beta delta [pack] in pack f. -Definition unpackW (f : wire_digitsW) : fe25519W := Eval cbv beta delta [unpack] in unpack f. +Definition unpackW (f : wire_digitsW) : fe25519W := Eval simpl in interp_runpack f. Local Transparent Let_In. Definition powW (f : fe25519W) chain := fold_chain_opt (proj1_fe25519W one) mulW chain [f]. @@ -78,21 +74,21 @@ Local Ltac port_correct_and_bounded pre_rewrite opW interp_rop rop_cb := intros; apply rop_cb; assumption. Lemma addW_correct_and_bounded : ibinop_correct_and_bounded addW carry_add. -Proof. (*port_correct_and_bounded interp_radd_correct addW interp_radd radd_correct_and_bounded. Qed.*) Admitted. +Proof. port_correct_and_bounded interp_radd_correct addW interp_radd radd_correct_and_bounded. Qed. Lemma subW_correct_and_bounded : ibinop_correct_and_bounded subW carry_sub. -Proof. (*port_correct_and_bounded interp_rsub_correct subW interp_rsub rsub_correct_and_bounded. Qed.*) Admitted. +Proof. port_correct_and_bounded interp_rsub_correct subW interp_rsub rsub_correct_and_bounded. Qed. Lemma mulW_correct_and_bounded : ibinop_correct_and_bounded mulW mul. -Proof. (*port_correct_and_bounded interp_rmul_correct mulW interp_rmul rmul_correct_and_bounded. Qed.*) Admitted. +Proof. port_correct_and_bounded interp_rmul_correct mulW interp_rmul rmul_correct_and_bounded. Qed. Lemma oppW_correct_and_bounded : iunop_correct_and_bounded oppW carry_opp. -Proof. (*port_correct_and_bounded interp_ropp_correct oppW interp_ropp ropp_correct_and_bounded. Qed.*) Admitted. +Proof. port_correct_and_bounded interp_ropp_correct oppW interp_ropp ropp_correct_and_bounded. Qed. Lemma freezeW_correct_and_bounded : iunop_correct_and_bounded freezeW freeze. -Proof. (*port_correct_and_bounded interp_rfreeze_correct freezeW interp_rfreeze rfreeze_correct_and_bounded. Qed.*) Admitted. +Proof. port_correct_and_bounded interp_rfreeze_correct freezeW interp_rfreeze rfreeze_correct_and_bounded. Qed. Lemma ge_modulusW_correct : iunop_FEToZ_correct ge_modulusW ge_modulus. -Proof. (*port_correct_and_bounded interp_rge_modulus_correct ge_modulusW interp_rge_modulus rge_modulus_correct_and_bounded. Qed.*) Admitted. +Proof. port_correct_and_bounded interp_rge_modulus_correct ge_modulusW interp_rge_modulus rge_modulus_correct_and_bounded. Qed. Lemma packW_correct_and_bounded : iunop_FEToWire_correct_and_bounded packW pack. -Proof. (*port_correct_and_bounded interp_rpack_correct packW interp_rpack rpack_correct_and_bounded. Qed.*) Admitted. +Proof. port_correct_and_bounded interp_rpack_correct packW interp_rpack rpack_correct_and_bounded. Qed. Lemma unpackW_correct_and_bounded : iunop_WireToFE_correct_and_bounded unpackW unpack. -Proof. (*port_correct_and_bounded interp_runpack_correct unpackW interp_runpack runpack_correct_and_bounded. Qed.*) Admitted. +Proof. port_correct_and_bounded interp_runpack_correct unpackW interp_runpack runpack_correct_and_bounded. Qed. Lemma powW_correct_and_bounded chain : iunop_correct_and_bounded (fun x => powW x chain) (fun x => pow x chain). Proof. @@ -100,8 +96,7 @@ Proof. intro x; intros; apply (fold_chain_opt_gen fe25519WToZ is_bounded [x]). { reflexivity. } { reflexivity. } - { intros; progress rewrite <- ?mul_correct, - <- ?(fun X Y => proj1 (mulW_correct_and_bounded _ _ X Y)) by assumption. + { intros; progress rewrite <- (fun X Y => proj1 (mulW_correct_and_bounded _ _ X Y)) by assumption. apply mulW_correct_and_bounded; assumption. } { intros; rewrite (fun X Y => proj1 (mulW_correct_and_bounded _ _ X Y)) by assumption; reflexivity. } { intros [|?]; autorewrite with simpl_nth_default; @@ -131,11 +126,8 @@ Proof. Defined. Definition fieldwisebW (f g : fe25519W) : bool := - Eval cbv beta iota delta [proj1_sig fieldwisebW_sig] in - let '(f0, f1, f2, f3, f4, f5, f6, f7, f8, f9) := f in - let '(g0, g1, g2, g3, g4, g5, g6, g7, g8, g9) := g in - proj1_sig (fieldwisebW_sig (f0, f1, f2, f3, f4, f5, f6, f7, f8, f9) - (g0, g1, g2, g3, g4, g5, g6, g7, g8, g9)). + Eval cbv [proj1_sig fieldwisebW_sig appify2 app_fe25519W] in + appify2 (fun f g => proj1_sig (fieldwisebW_sig f g)) f g. Lemma fieldwisebW_correct f g : fieldwisebW f g = GF25519.fieldwiseb (fe25519WToZ f) (fe25519WToZ g). @@ -167,11 +159,8 @@ Proof. Defined. Definition eqbW (f g : fe25519W) : bool := - Eval cbv beta iota delta [proj1_sig eqbW_sig] in - let '(f0, f1, f2, f3, f4, f5, f6, f7, f8, f9) := f in - let '(g0, g1, g2, g3, g4, g5, g6, g7, g8, g9) := g in - proj1_sig (eqbW_sig (f0, f1, f2, f3, f4, f5, f6, f7, f8, f9) - (g0, g1, g2, g3, g4, g5, g6, g7, g8, g9)). + Eval cbv [proj1_sig eqbW_sig appify2 app_fe25519W] in + appify2 (fun f g => proj1_sig (eqbW_sig f g)) f g. Lemma eqbW_correct f g : is_bounded (fe25519WToZ f) = true @@ -183,28 +172,32 @@ Proof. exact (proj2_sig (eqbW_sig f' g')). Qed. +(* TODO(jgross): use NToWord or such for this constant too *) Definition sqrt_m1W : fe25519W := Eval vm_compute in fe25519ZToW sqrt_m1. +Definition GF25519sqrt x : GF25519.fe25519 := + dlet powx := powW (fe25519ZToW x) (chain GF25519.sqrt_ec) in + GF25519.sqrt (fe25519WToZ powx) (fe25519WToZ (mulW powx powx)) x. + Definition sqrtW_sig - : { sqrtW | iunop_correct_and_bounded sqrtW GF25519.sqrt }. + : { sqrtW | iunop_correct_and_bounded sqrtW GF25519sqrt }. Proof. eexists. - unfold GF25519.sqrt. - intros; set_evars; rewrite <- (fun pf => proj1 (powW_correct_and_bounded _ _ pf)) by assumption; subst_evars. - match goal with - | [ |- context G[dlet x := fe25519WToZ ?v in @?f x] ] - => let G' := context G[dlet x := v in f (fe25519WToZ x)] in - cut G'; cbv beta; - [ cbv [Let_In]; exact (fun x => x) | ] - end. + unfold GF25519sqrt, GF25519.sqrt. + intros. + rewrite !fe25519ZToW_WToZ. split. { etransitivity. Focus 2. { apply Proper_Let_In_nd_changebody_eq; intros. set_evars. + match goal with (* unfold the first dlet ... in, but only if it's binding a var *) + | [ |- ?x = dlet y := fe25519WToZ ?z in ?f ] + => is_var z; change (x = match fe25519WToZ z with y => f end) + end. change sqrt_m1 with (fe25519WToZ sqrt_m1W). - rewrite <- !(fun X Y => proj1 (mulW_correct_and_bounded _ _ X Y)), <- eqbW_correct, (pull_bool_if fe25519WToZ) + rewrite <- (fun X Y => proj1 (mulW_correct_and_bounded sqrt_m1W a X Y)), <- eqbW_correct, (pull_bool_if fe25519WToZ) by repeat match goal with | _ => progress subst | [ |- is_bounded (fe25519WToZ ?op) = true ] @@ -232,11 +225,10 @@ Proof. Defined. Definition sqrtW (f : fe25519W) : fe25519W := - Eval cbv beta iota delta [proj1_sig sqrtW_sig] in - let '(f0, f1, f2, f3, f4, f5, f6, f7, f8, f9) := f in - proj1_sig sqrtW_sig (f0, f1, f2, f3, f4, f5, f6, f7, f8, f9). + Eval cbv [proj1_sig sqrtW_sig app_fe25519W] in + app_fe25519W f (proj1_sig sqrtW_sig). -Lemma sqrtW_correct_and_bounded : iunop_correct_and_bounded sqrtW GF25519.sqrt. +Lemma sqrtW_correct_and_bounded : iunop_correct_and_bounded sqrtW GF25519sqrt. Proof. intro f. set (f' := f). @@ -310,7 +302,7 @@ Lemma pow_correct (f : fe25519) chain : proj1_fe25519 (pow f chain) = GF25519.po Proof. op_correct_t pow (powW_correct_and_bounded chain). Qed. Lemma inv_correct (f : fe25519) : proj1_fe25519 (inv f) = GF25519.inv (proj1_fe25519 f). Proof. op_correct_t inv invW_correct_and_bounded. Qed. -Lemma sqrt_correct (f : fe25519) : proj1_fe25519 (sqrt f) = GF25519.sqrt (proj1_fe25519 f). +Lemma sqrt_correct (f : fe25519) : proj1_fe25519 (sqrt f) = GF25519sqrt (proj1_fe25519 f). Proof. op_correct_t sqrt sqrtW_correct_and_bounded. Qed. Import Morphisms. diff --git a/src/Specific/GF25519BoundedCommon.v b/src/Specific/GF25519BoundedCommon.v index 511370c63..f9ee444dc 100644 --- a/src/Specific/GF25519BoundedCommon.v +++ b/src/Specific/GF25519BoundedCommon.v @@ -15,41 +15,224 @@ Require Import Crypto.Util.Tactics. Require Import Crypto.Util.LetIn. Require Import Crypto.Util.Notations. Require Import Crypto.Util.Decidable. +Require Import Crypto.Util.HList. +Require Import Crypto.Util.Tuple. Require Import Crypto.Algebra. Import ListNotations. Require Import Coq.ZArith.ZArith Coq.ZArith.Zpower Coq.ZArith.ZArith Coq.ZArith.Znumtheory. Local Open Scope Z. -(* BEGIN aliases for word extraction *) -Definition word64 := Z. -Coercion word64ToZ (x : word64) : Z - := x. -Coercion ZToWord64 (x : Z) : word64 := x. -Definition w64eqb (x y : word64) := Z.eqb x y. - -Lemma word64eqb_Zeqb x y : (word64ToZ x =? word64ToZ y)%Z = w64eqb x y. -Proof. reflexivity. Qed. +(* BEGIN common curve-specific definitions *) +Definition bit_width : nat := 64%nat. +Local Notation b_of exp := (0, 2^exp + 2^(exp-3))%Z (only parsing). (* max is [(0, 2^(exp+2) + 2^exp + 2^(exp-1) + 2^(exp-3) + 2^(exp-4) + 2^(exp-5) + 2^(exp-6) + 2^(exp-10) + 2^(exp-12) + 2^(exp-13) + 2^(exp-14) + 2^(exp-15) + 2^(exp-17) + 2^(exp-23) + 2^(exp-24))%Z] *) +Definition bounds_exp : tuple Z length_fe25519 + := Eval compute in + Tuple.from_list length_fe25519 limb_widths eq_refl. +Definition bounds : tuple (Z * Z) length_fe25519 + := Eval compute in + Tuple.map (fun e => b_of e) bounds_exp. +Definition wire_digit_bounds_exp : tuple Z (length wire_widths) + := Eval compute in Tuple.from_list _ wire_widths eq_refl. +Definition wire_digit_bounds : tuple (Z * Z) (length wire_widths) + := Eval compute in Tuple.map (fun e => (0,2^e-1)%Z) wire_digit_bounds_exp. +(* END common curve-specific definitions *) +(* BEGIN aliases for word extraction *) +Definition word64 := Word.word bit_width. +Coercion word64ToZ (x : word64) : Z := Z.of_N (wordToN x). +Coercion ZToWord64 (x : Z) : word64 := NToWord _ (Z.to_N x). +Definition NToWord64 : N -> word64 := NToWord _. +Definition word64ize (x : word64) : word64 + := Eval cbv [wordToN N.succ_double N.double] in NToWord64 (wordToN x). +Definition w64eqb (x y : word64) := weqb x y. + +Global Arguments NToWord64 : simpl never. Arguments word64 : simpl never. +Arguments bit_width : simpl never. Global Opaque word64. +Global Opaque bit_width. (* END aliases for word extraction *) +(* BEGIN basic types *) +Module Type WordIsBounded. + Parameter is_boundedT : forall (lower upper : Z), word64 -> bool. + Parameter Build_is_boundedT : forall {lower upper} {proj_word : word64}, + andb (lower <=? proj_word)%Z (proj_word <=? upper)%Z = true -> is_boundedT lower upper proj_word = true. + Parameter project_is_boundedT : forall {lower upper} {proj_word : word64}, + is_boundedT lower upper proj_word = true -> andb (lower <=? proj_word)%Z (proj_word <=? upper)%Z = true. +End WordIsBounded. + +Module Import WordIsBoundedDefault : WordIsBounded. + Definition is_boundedT : forall (lower upper : Z), word64 -> bool + := fun lower upper proj_word => andb (lower <=? proj_word)%Z (proj_word <=? upper)%Z. + Definition Build_is_boundedT {lower upper} {proj_word : word64} + : andb (lower <=? proj_word)%Z (proj_word <=? upper)%Z = true -> is_boundedT lower upper proj_word = true + := fun x => x. + Definition project_is_boundedT {lower upper} {proj_word : word64} + : is_boundedT lower upper proj_word = true -> andb (lower <=? proj_word)%Z (proj_word <=? upper)%Z = true + := fun x => x. +End WordIsBoundedDefault. + +Definition bounded_word (lower upper : Z) + := { proj_word : word64 | is_boundedT lower upper proj_word = true }. +Local Notation word_of exp := (bounded_word (fst (b_of exp)) (snd (b_of exp))). +Local Notation unbounded_word sz := (bounded_word 0 (2^sz-1)%Z). + +Local Opaque word64. +Definition fe25519W := Eval cbv (*-[word64]*) in (tuple word64 length_fe25519). +Definition wire_digitsW := Eval cbv (*-[word64]*) in (tuple word64 (length wire_widths)). +Definition fe25519 := + Eval cbv -[bounded_word Z.pow Z.sub Z.add] in + hlist (fun e => word_of e) bounds_exp. +Definition wire_digits := + Eval cbv -[bounded_word Z.pow Z.sub Z.add] in + hlist (fun e => unbounded_word e) wire_digit_bounds_exp. + +Definition is_bounded_gen {n} (x : tuple Z n) (bounds : tuple (Z * Z) n) : bool + := let res := Tuple.map2 + (fun bounds v => + let '(lower, upper) := bounds in + (lower <=? v) && (v <=? upper))%bool%Z + bounds x in + List.fold_right andb true (Tuple.to_list _ res). + +Definition is_bounded (x : Specific.GF25519.fe25519) : bool + := is_bounded_gen (n:=length_fe25519) x bounds. + +Definition wire_digits_is_bounded (x : Specific.GF25519.wire_digits) : bool + := is_bounded_gen (n:=length wire_widths) x wire_digit_bounds. + +(* END basic types *) + +Section generic_destructuring. + Fixpoint app_on' A n : forall T (f : tuple' A n) (P : forall x : tuple' A n, T x), T f + := match n return forall T (f : tuple' A n) (P : forall x : tuple' A n, T x), T f with + | O => fun T v P => P v + | S n' => fun T v P => let '(v, x) := v in app_on' A n' _ v (fun v => P (v, x)) + end. + Definition app_on {A n} : forall {T} (f : tuple A n) (P : forall x : tuple A n, T x), T f + := match n return forall T (f : tuple A n) (P : forall x : tuple A n, T x), T f with + | O => fun T v P => P v + | S n' => @app_on' A n' + end. + Lemma app_on'_correct {A n T} f (P : forall x : tuple' A n, T x) : app_on' A n T f P = P f. + Proof. + induction n; simpl in *; destruct_head' prod; [ reflexivity | exact (IHn _ _ (fun t => P (t, _))) ]. + Qed. + Lemma app_on_correct {A n T} f (P : forall x : tuple A n, T x) : app_on f P = P f. + Proof. destruct n; [ reflexivity | apply app_on'_correct ]. Qed. + + Fixpoint app_on_h' A F n : forall ts T (f : @hlist' A n F ts) (P : forall x : @hlist' A n F ts, T x), T f + := match n return forall ts T (f : @hlist' A n F ts) (P : forall x : @hlist' A n F ts, T x), T f with + | O => fun ts T v P => P v + | S n' => fun ts T v P => let '(v, x) := v in app_on_h' A F n' _ _ v (fun v => P (v, x)) + end. + Definition app_on_h {A F n} : forall ts T (f : @hlist A n F ts) (P : forall x : @hlist A n F ts, T x), T f + := match n return forall ts T (f : @hlist A n F ts) (P : forall x : @hlist A n F ts, T x), T f with + | O => fun ts T v P => P v + | S n' => @app_on_h' A F n' + end. + Lemma app_on_h'_correct {A F n ts T} f P : @app_on_h' A F n ts T f P = P f. + Proof. + induction n; simpl in *; destruct_head' prod; [ reflexivity | exact (IHn _ _ _ (fun h => P (h, f))) ]. + Qed. + Lemma app_on_h_correct {A} F {n} ts {T} f P : @app_on_h A F n ts T f P = P f. + Proof. destruct n; [ reflexivity | apply app_on_h'_correct ]. Qed. + + Definition app_wire_digitsW_dep {A T} (P : forall x : tuple A (length wire_widths), T x) + : forall (f : tuple A (length wire_widths)), T f + := Eval compute in fun f => @app_on A (length wire_widths) T f P. + Definition app_wire_digitsW {A T} (f : tuple A (length wire_widths)) (P : tuple A (length wire_widths) -> T) + := Eval compute in @app_wire_digitsW_dep A (fun _ => T) P f. + Definition app_fe25519W_dep {A T} (P : forall x : tuple A length_fe25519, T x) + : forall (f : tuple A length_fe25519), T f + := Eval compute in fun f => @app_on A length_fe25519 T f P. + Definition app_fe25519W {A T} (f : tuple A length_fe25519) (P : tuple A length_fe25519 -> T) + := Eval compute in @app_fe25519W_dep A (fun _ => T) P f. + Definition app_fe25519_dep {T} (P : forall x : fe25519, T x) + : forall f : fe25519, T f + := Eval compute in fun f => @app_on_h _ (fun e => word_of e) length_fe25519 bounds_exp T f P. + Definition app_fe25519 {T} (f : fe25519) (P : hlist (fun e => word_of e) bounds_exp -> T) + := Eval compute in @app_fe25519_dep (fun _ => T) P f. + Definition app_wire_digits_dep {T} (P : forall x : wire_digits, T x) + : forall f : wire_digits, T f + := Eval compute in fun f => @app_on_h _ (fun e => unbounded_word e) (length wire_widths) wire_digit_bounds_exp T f P. + Definition app_wire_digits {T} (f : wire_digits) (P : hlist (fun e => unbounded_word e) wire_digit_bounds_exp -> T) + := Eval compute in @app_wire_digits_dep (fun _ => T) P f. + + Definition app_wire_digitsW_dep_correct {A T} f P : @app_wire_digitsW_dep A T P f = P f + := app_on_correct f P. + Definition app_wire_digitsW_correct {A T} f P : @app_wire_digitsW A T f P = P f + := @app_wire_digitsW_dep_correct A (fun _ => T) f P. + Definition app_fe25519W_dep_correct {A T} f P : @app_fe25519W_dep A T P f = P f + := app_on_correct f P. + Definition app_fe25519W_correct {A T} f P : @app_fe25519W A T f P = P f + := @app_fe25519W_dep_correct A (fun _ => T) f P. + Definition app_fe25519_dep_correct {T} f P : @app_fe25519_dep T P f = P f + := app_on_h_correct (fun e => word_of e) bounds_exp f P. + Definition app_fe25519_correct {T} f P : @app_fe25519 T f P = P f + := @app_fe25519_dep_correct (fun _ => T) f P. + Definition app_wire_digits_dep_correct {T} f P : @app_wire_digits_dep T P f = P f + := app_on_h_correct (fun e => unbounded_word e) wire_digit_bounds_exp f P. + Definition app_wire_digits_correct {T} f P : @app_wire_digits T f P = P f + := @app_wire_digits_dep_correct (fun _ => T) f P. + + Definition appify2 {T} (op : fe25519W -> fe25519W -> T) (f g : fe25519W) := + app_fe25519W f (fun f0 => (app_fe25519W g (fun g0 => op f0 g0))). + + Lemma appify2_correct : forall {T} op f g, @appify2 T op f g = op f g. + Proof. + intros. cbv [appify2]. + etransitivity; apply app_fe25519W_correct. + Qed. +End generic_destructuring. + +Definition eta_fe25519W_sig (x : fe25519W) : { v : fe25519W | v = x }. +Proof. + eexists; symmetry. + repeat (etransitivity; [ apply surjective_pairing | apply f_equal2 ]); reflexivity. +Defined. +Definition eta_fe25519W (x : fe25519W) : fe25519W + := Eval cbv [proj1_sig eta_fe25519W_sig] in proj1_sig (eta_fe25519W_sig x). +Definition eta_wire_digitsW_sig (x : wire_digitsW) : { v : wire_digitsW | v = x }. +Proof. + eexists; symmetry. + repeat (etransitivity; [ apply surjective_pairing | apply f_equal2 ]); reflexivity. +Defined. +Definition eta_wire_digitsW (x : wire_digitsW) : wire_digitsW + := Eval cbv [proj1_sig eta_wire_digitsW_sig] in proj1_sig (eta_wire_digitsW_sig x). + +Local Transparent word64. +Lemma word64ize_id x : word64ize x = x. +Proof. apply NToWord_wordToN. Qed. +Local Opaque word64. + +Lemma word64eqb_Zeqb x y : (word64ToZ x =? word64ToZ y)%Z = w64eqb x y. +Proof. apply wordeqb_Zeqb. Qed. + Local Arguments Z.pow_pos !_ !_ / . -Lemma ZToWord64ToZ x : 0 <= x < 2^64 -> word64ToZ (ZToWord64 x) = x. +Lemma word64ToZ_ZToWord64 x : 0 <= x < 2^Z.of_nat bit_width -> word64ToZ (ZToWord64 x) = x. Proof. intros; unfold word64ToZ, ZToWord64. - rewrite ?wordToN_NToWord_idempotent, ?N2Z.id, ?Z2N.id - by (omega || apply N2Z.inj_lt; rewrite ?N2Z.id, ?Z2N.id by omega; simpl in *; omega). + rewrite wordToN_NToWord_idempotent, Z2N.id + by (omega || apply N2Z.inj_lt; rewrite <- ?(N_nat_Z (Npow2 _)), ?Npow2_nat, ?Zpow_pow2, ?N2Z.id, ?Z2N.id, ?Z2Nat.id by omega; omega). reflexivity. Qed. +Lemma ZToWord64_word64ToZ x : ZToWord64 (word64ToZ x) = x. +Proof. + intros; unfold word64ToZ, ZToWord64. + rewrite N2Z.id, NToWord_wordToN; reflexivity. +Qed. (* BEGIN precomputation. *) -Local Notation b_of exp := (0, 2^exp + 2^(exp-3))%Z (only parsing). (* max is [(0, 2^(exp+2) + 2^exp + 2^(exp-1) + 2^(exp-3) + 2^(exp-4) + 2^(exp-5) + 2^(exp-6) + 2^(exp-10) + 2^(exp-12) + 2^(exp-13) + 2^(exp-14) + 2^(exp-15) + 2^(exp-17) + 2^(exp-23) + 2^(exp-24))%Z] *) -Record bounded_word (lower upper : Z) := - Build_bounded_word' - { proj_word :> word64; - word_bounded : andb (lower <=? proj_word)%Z (proj_word <=? upper)%Z = true }. + +Definition proj_word {lower upper} (v : bounded_word lower upper) := Eval cbv [proj1_sig] in proj1_sig v. +Definition word_bounded {lower upper} (v : bounded_word lower upper) + : andb (lower <=? proj_word v)%Z (proj_word v <=? upper)%Z = true + := project_is_boundedT (proj2_sig v). +Definition Build_bounded_word' {lower upper} proj_word word_bounded : bounded_word lower upper + := exist _ proj_word (Build_is_boundedT word_bounded). Arguments proj_word {_ _} _. Arguments word_bounded {_ _} _. Arguments Build_bounded_word' {_ _} _ _. @@ -61,22 +244,20 @@ Definition Build_bounded_word {lower upper} (proj_word : word64) (word_bounded : | true => fun _ => eq_refl | false => fun x => x end word_bounded). -Local Notation word_of exp := (bounded_word (fst (b_of exp)) (snd (b_of exp))). -Local Notation unbounded_word sz := (bounded_word 0 (2^sz-1)%Z). -Lemma word_to_unbounded_helper {x e : nat} : (x < pow2 e)%nat -> (Z.of_nat e <= 64)%Z -> ((0 <=? word64ToZ (ZToWord64 (Z.of_nat x))) && (word64ToZ (ZToWord64 (Z.of_nat x)) <=? 2 ^ (Z.of_nat e) - 1))%bool = true. +Lemma word_to_unbounded_helper {x e : nat} : (x < pow2 e)%nat -> (Z.of_nat e <= Z.of_nat bit_width)%Z -> ((0 <=? word64ToZ (ZToWord64 (Z.of_nat x))) && (word64ToZ (ZToWord64 (Z.of_nat x)) <=? 2 ^ (Z.of_nat e) - 1))%bool = true. Proof. rewrite pow2_id; intro H; apply Nat2Z.inj_lt in H; revert H. rewrite Z.pow_Zpow; simpl Z.of_nat. intros H H'. - assert (2^Z.of_nat e <= 2^64) by auto with zarith. - rewrite !ZToWord64ToZ by omega. + assert (2^Z.of_nat e <= 2^Z.of_nat bit_width) by auto with zarith. + rewrite ?word64ToZ_ZToWord64 by omega. match goal with | [ |- context[andb ?x ?y] ] => destruct x eqn:?, y eqn:?; try reflexivity; Z.ltb_to_lt end; intros; omega. Qed. -Definition word_to_unbounded_word {sz} (x : word sz) : (Z.of_nat sz <=? 64)%Z = true -> unbounded_word (Z.of_nat sz). +Definition word_to_unbounded_word {sz} (x : word sz) : (Z.of_nat sz <=? Z.of_nat bit_width)%Z = true -> unbounded_word (Z.of_nat sz). Proof. refine (fun pf => Build_bounded_word (Z.of_N (wordToN x)) _). abstract (rewrite wordToN_nat, nat_N_Z; Z.ltb_to_lt; apply (word_to_unbounded_helper (wordToNat_bound x)); simpl; omega). @@ -85,248 +266,310 @@ Definition word32_to_unbounded_word (x : word 32) : unbounded_word 32. Proof. apply (word_to_unbounded_word x); reflexivity. Defined. Definition word31_to_unbounded_word (x : word 31) : unbounded_word 31. Proof. apply (word_to_unbounded_word x); reflexivity. Defined. -Definition bounds : list (Z * Z) - := Eval compute in - [b_of 25; b_of 26; b_of 25; b_of 26; b_of 25; b_of 26; b_of 25; b_of 26; b_of 25; b_of 26]. -Definition wire_digit_bounds : list (Z * Z) - := Eval compute in - List.repeat (0, 2^32-1)%Z 7 ++ ((0,2^31-1)%Z :: nil). Local Opaque word64. -Definition fe25519W := Eval cbv (*-[word64]*) in (tuple word64 (length limb_widths)). -Definition wire_digitsW := Eval cbv (*-[word64]*) in (tuple word64 8). +Declare Reduction app_tuple_map := cbv [app_wire_digitsW app_fe25519W app_fe25519 HList.mapt HList.mapt' Tuple.map on_tuple List.map List.app length_fe25519 List.length wire_widths Tuple.from_list Tuple.from_list' Tuple.to_list Tuple.to_list' fst snd]. Definition fe25519WToZ (x : fe25519W) : Specific.GF25519.fe25519 - := let '(x0, x1, x2, x3, x4, x5, x6, x7, x8, x9) := x in - (x0 : Z, x1 : Z, x2 : Z, x3 : Z, x4 : Z, x5 : Z, x6 : Z, x7 : Z, x8 : Z, x9 : Z). + := Eval app_tuple_map in + app_fe25519W x (Tuple.map (fun v : word64 => v : Z)). Definition fe25519ZToW (x : Specific.GF25519.fe25519) : fe25519W - := let '(x0, x1, x2, x3, x4, x5, x6, x7, x8, x9) := x in - (x0 : word64, x1 : word64, x2 : word64, x3 : word64, x4 : word64, x5 : word64, x6 : word64, x7 : word64, x8 : word64, x9 : word64). + := Eval app_tuple_map in + app_fe25519W x (Tuple.map (fun v : Z => v : word64)). Definition wire_digitsWToZ (x : wire_digitsW) : Specific.GF25519.wire_digits - := let '(x0, x1, x2, x3, x4, x5, x6, x7) := x in - (x0 : Z, x1 : Z, x2 : Z, x3 : Z, x4 : Z, x5 : Z, x6 : Z, x7 : Z). + := Eval app_tuple_map in + app_wire_digitsW x (Tuple.map (fun v : word64 => v : Z)). Definition wire_digitsZToW (x : Specific.GF25519.wire_digits) : wire_digitsW - := let '(x0, x1, x2, x3, x4, x5, x6, x7) := x in - (x0 : word64, x1 : word64, x2 : word64, x3 : word64, x4 : word64, x5 : word64, x6 : word64, x7 : word64). -Definition fe25519 := - Eval cbv [fst snd] in - let sanity := eq_refl : length bounds = length limb_widths in - (word_of 25 * word_of 26 * word_of 25 * word_of 26 * word_of 25 * word_of 26 * word_of 25 * word_of 26 * word_of 25 * word_of 26)%type. -Definition wire_digits := - Eval cbv [fst snd Tuple.tuple Tuple.tuple'] in - (unbounded_word 32 * unbounded_word 32 * unbounded_word 32 * unbounded_word 32 - * unbounded_word 32 * unbounded_word 32 * unbounded_word 32 * unbounded_word 31)%type. + := Eval app_tuple_map in + app_wire_digitsW x (Tuple.map (fun v : Z => v : word64)). +Definition fe25519W_word64ize (x : fe25519W) : fe25519W + := Eval app_tuple_map in + app_fe25519W x (Tuple.map word64ize). +Definition wire_digitsW_word64ize (x : wire_digitsW) : wire_digitsW + := Eval app_tuple_map in + app_wire_digitsW x (Tuple.map word64ize). + +(** TODO: Turn this into a lemma to speed up proofs *) +Ltac unfold_is_bounded_in H := + unfold is_bounded, wire_digits_is_bounded, is_bounded_gen, fe25519WToZ, wire_digitsWToZ in H; + cbv [to_list length bounds wire_digit_bounds from_list from_list' map2 on_tuple2 to_list' ListUtil.map2 List.map fold_right List.rev List.app length_fe25519 List.length wire_widths] in H; + rewrite ?Bool.andb_true_iff in H. + +Ltac unfold_is_bounded := + unfold is_bounded, wire_digits_is_bounded, is_bounded_gen, fe25519WToZ, wire_digitsWToZ; + cbv [to_list length bounds wire_digit_bounds from_list from_list' map2 on_tuple2 to_list' ListUtil.map2 List.map fold_right List.rev List.app length_fe25519 List.length wire_widths]; + rewrite ?Bool.andb_true_iff. + +Local Transparent bit_width. +Definition Pow2_64 := Eval compute in 2^Z.of_nat bit_width. +Definition unfold_Pow2_64 : 2^Z.of_nat bit_width = Pow2_64 := eq_refl. +Local Opaque bit_width. + +Local Ltac prove_lt_bit_width := + rewrite unfold_Pow2_64; cbv [Pow2_64]; omega. + +Lemma fe25519ZToW_WToZ (x : fe25519W) : fe25519ZToW (fe25519WToZ x) = x. +Proof. + hnf in x; destruct_head' prod; cbv [fe25519WToZ fe25519ZToW]. + rewrite !ZToWord64_word64ToZ; reflexivity. +Qed. + +Lemma fe25519WToZ_ZToW x : is_bounded x = true -> fe25519WToZ (fe25519ZToW x) = x. +Proof. + hnf in x; destruct_head' prod; cbv [fe25519WToZ fe25519ZToW]. + intro H. + unfold_is_bounded_in H; destruct_head' and. + Z.ltb_to_lt. + rewrite !word64ToZ_ZToWord64 by prove_lt_bit_width. + reflexivity. +Qed. + +Lemma fe25519W_word64ize_id x : fe25519W_word64ize x = x. +Proof. + hnf in x; destruct_head' prod. + cbv [fe25519W_word64ize]; + repeat apply f_equal2; apply word64ize_id. +Qed. +Lemma wire_digitsW_word64ize_id x : wire_digitsW_word64ize x = x. +Proof. + hnf in x; destruct_head' prod. + cbv [wire_digitsW_word64ize]; + repeat apply f_equal2; apply word64ize_id. +Qed. + +Definition uncurry_unop_fe25519W {T} (op : fe25519W -> T) + := Eval cbv (*-[word64]*) in Tuple.uncurry (n:=length_fe25519) op. +Definition curry_unop_fe25519W {T} op : fe25519W -> T + := Eval cbv (*-[word64]*) in fun f => app_fe25519W f (Tuple.curry (n:=length_fe25519) op). +Definition uncurry_binop_fe25519W {T} (op : fe25519W -> fe25519W -> T) + := Eval cbv (*-[word64]*) in uncurry_unop_fe25519W (fun f => uncurry_unop_fe25519W (op f)). +Definition curry_binop_fe25519W {T} op : fe25519W -> fe25519W -> T + := Eval cbv (*-[word64]*) in appify2 (fun f => curry_unop_fe25519W (curry_unop_fe25519W op f)). + +Definition uncurry_unop_wire_digitsW {T} (op : wire_digitsW -> T) + := Eval cbv (*-[word64]*) in Tuple.uncurry (n:=length wire_widths) op. +Definition curry_unop_wire_digitsW {T} op : wire_digitsW -> T + := Eval cbv (*-[word64]*) in fun f => app_wire_digitsW f (Tuple.curry (n:=length wire_widths) op). + + Definition proj1_fe25519W (x : fe25519) : fe25519W - := let '(x0, x1, x2, x3, x4, x5, x6, x7, x8, x9) := x in - (proj_word x0, proj_word x1, proj_word x2, proj_word x3, proj_word x4, - proj_word x5, proj_word x6, proj_word x7, proj_word x8, proj_word x9). + := Eval app_tuple_map in + app_fe25519 x (HList.mapt (fun _ => (@proj_word _ _))). Coercion proj1_fe25519 (x : fe25519) : Specific.GF25519.fe25519 := fe25519WToZ (proj1_fe25519W x). -Definition is_bounded (x : Specific.GF25519.fe25519) : bool - := let res := Tuple.map2 - (fun bounds v => - let '(lower, upper) := bounds in - (lower <=? v) && (v <=? upper))%bool%Z - (Tuple.from_list _ (List.rev bounds) eq_refl) x in - List.fold_right andb true (Tuple.to_list _ res). Lemma is_bounded_proj1_fe25519 (x : fe25519) : is_bounded (proj1_fe25519 x) = true. Proof. - refine (let '(Build_bounded_word' x0 p0, Build_bounded_word' x1 p1, Build_bounded_word' x2 p2, Build_bounded_word' x3 p3, Build_bounded_word' x4 p4, - Build_bounded_word' x5 p5, Build_bounded_word' x6 p6, Build_bounded_word' x7 p7, Build_bounded_word' x8 p8, Build_bounded_word' x9 p9) - as x := x return is_bounded (proj1_fe25519 x) = true in - _). - cbv [is_bounded proj1_fe25519 proj1_fe25519W fe25519WToZ to_list length bounds from_list from_list' map2 on_tuple2 to_list' ListUtil.map2 List.map List.rev List.app proj_word]. + revert x; refine (app_fe25519_dep _); intro x. + hnf in x; destruct_head' prod; destruct_head' bounded_word. + cbv [is_bounded proj1_fe25519 proj1_fe25519W fe25519WToZ to_list length bounds from_list from_list' map2 on_tuple2 to_list' ListUtil.map2 List.map List.rev List.app proj_word length_fe25519 is_bounded_gen]. apply fold_right_andb_true_iff_fold_right_and_True. cbv [fold_right List.map]. cbv beta in *. - repeat split; assumption. + repeat split; auto using project_is_boundedT. Qed. Definition proj1_wire_digitsW (x : wire_digits) : wire_digitsW - := let '(x0, x1, x2, x3, x4, x5, x6, x7) := x in - (proj_word x0, proj_word x1, proj_word x2, proj_word x3, proj_word x4, - proj_word x5, proj_word x6, proj_word x7). + := app_wire_digits x (HList.mapt (fun _ => proj_word)). Coercion proj1_wire_digits (x : wire_digits) : Specific.GF25519.wire_digits := wire_digitsWToZ (proj1_wire_digitsW x). -Definition wire_digits_is_bounded (x : Specific.GF25519.wire_digits) : bool - := let res := Tuple.map2 - (fun bounds v => - let '(lower, upper) := bounds in - (lower <=? v) && (v <=? upper))%bool%Z - (Tuple.from_list _ (List.rev wire_digit_bounds) eq_refl) x in - List.fold_right andb true (Tuple.to_list _ res). Lemma is_bounded_proj1_wire_digits (x : wire_digits) : wire_digits_is_bounded (proj1_wire_digits x) = true. Proof. - refine (let '(Build_bounded_word' x0 p0, Build_bounded_word' x1 p1, Build_bounded_word' x2 p2, Build_bounded_word' x3 p3, Build_bounded_word' x4 p4, - Build_bounded_word' x5 p5, Build_bounded_word' x6 p6, Build_bounded_word' x7 p7) - as x := x return wire_digits_is_bounded (proj1_wire_digits x) = true in - _). - cbv [wire_digits_is_bounded proj1_wire_digits proj1_wire_digitsW wire_digitsWToZ to_list length wire_digit_bounds from_list from_list' map2 on_tuple2 to_list' ListUtil.map2 List.map List.rev List.app proj_word]. + revert x; refine (app_wire_digits_dep _); intro x. + hnf in x; destruct_head' prod; destruct_head' bounded_word. + cbv [wire_digits_is_bounded proj1_wire_digits proj1_wire_digitsW wire_digitsWToZ to_list length wire_digit_bounds from_list from_list' map2 on_tuple2 to_list' ListUtil.map2 List.map List.rev List.app proj_word is_bounded_gen wire_widths HList.mapt HList.mapt' app_wire_digits fst snd]. apply fold_right_andb_true_iff_fold_right_and_True. cbv [fold_right List.map]. cbv beta in *. - repeat split; assumption. + repeat split; auto using project_is_boundedT. Qed. -(** TODO: Turn this into a lemma to speed up proofs *) -Ltac unfold_is_bounded_in H := - unfold is_bounded, wire_digits_is_bounded, fe25519WToZ, wire_digitsWToZ in H; - cbv [to_list length bounds wire_digit_bounds from_list from_list' map2 on_tuple2 to_list' ListUtil.map2 List.map fold_right List.rev List.app] in H; - rewrite !Bool.andb_true_iff in H. - -Definition Pow2_64 := Eval compute in 2^64. -Definition unfold_Pow2_64 : 2^64 = Pow2_64 := eq_refl. - -Definition exist_fe25519W (x : fe25519W) : is_bounded (fe25519WToZ x) = true -> fe25519. -Proof. - refine (let '(x0, x1, x2, x3, x4, x5, x6, x7, x8, x9) as x := x return is_bounded (fe25519WToZ x) = true -> fe25519 in - fun H => (fun H' => (Build_bounded_word x0 _, Build_bounded_word x1 _, Build_bounded_word x2 _, Build_bounded_word x3 _, Build_bounded_word x4 _, - Build_bounded_word x5 _, Build_bounded_word x6 _, Build_bounded_word x7 _, Build_bounded_word x8 _, Build_bounded_word x9 _)) - (let H' := proj1 (@fold_right_andb_true_iff_fold_right_and_True _) H in - _)); - [ - | | | | | | | | | - | clearbody H'; clear H x; - unfold_is_bounded_in H'; - exact H' ]; - destruct_head and; auto; - rewrite_hyp !*; reflexivity. -Defined. - -Definition exist_fe25519' (x : Specific.GF25519.fe25519) : is_bounded x = true -> fe25519. -Proof. - intro H; apply (exist_fe25519W (fe25519ZToW x)). +Local Ltac make_exist_W' x app_W_dep := + let H := fresh in + revert x; refine (@app_W_dep _ _ _); intros x H; + let x' := fresh in + set (x' := x); + cbv [tuple tuple' length_fe25519 List.length wire_widths] in x; + destruct_head' prod; + let rec do_refine v H := + first [ let v' := (eval cbv [snd fst] in (snd v)) in + refine (_, Build_bounded_word v' _); + [ do_refine (fst v) (proj2 H) | subst x'; abstract exact (proj1 H) ] + | let v' := (eval cbv [snd fst] in v) in + refine (Build_bounded_word v' _); subst x'; abstract exact (proj1 H) ] in + let H' := constr:(proj1 (@fold_right_andb_true_iff_fold_right_and_True _) H) in + let T := type of H' in + let T := (eval cbv [id + List.fold_right List.map List.length List.app ListUtil.map2 List.rev + Tuple.to_list Tuple.to_list' Tuple.from_list Tuple.from_list' Tuple.map2 Tuple.on_tuple2 + fe25519 bounds fe25519WToZ length_fe25519 + wire_digits wire_digit_bounds wire_digitsWToZ wire_widths] in T) in + let H' := constr:(H' : T) in + let v := (eval unfold x' in x') in + do_refine v H'. +Local Ltac make_exist'' x exist_W ZToW := + let H := fresh in + intro H; apply (exist_W (ZToW x)); abstract ( - hnf in x; destruct_head prod; + hnf in x; destruct_head' prod; + let H' := fresh in pose proof H as H'; unfold_is_bounded_in H; - destruct_head and; + destruct_head' and; simpl in *; Z.ltb_to_lt; - rewrite ?ZToWord64ToZ by (simpl; omega); + rewrite ?word64ToZ_ZToWord64 by prove_lt_bit_width; assumption ). -Defined. - -Definition exist_fe25519 (x : Specific.GF25519.fe25519) : is_bounded x = true -> fe25519. -Proof. - refine (let '(x0, x1, x2, x3, x4, x5, x6, x7, x8, x9) as x := x return is_bounded x = true -> fe25519 in - fun H => _). - let v := constr:(exist_fe25519' (x0, x1, x2, x3, x4, x5, x6, x7, x8, x9) H) in +Local Ltac make_exist' x app_W_dep exist'' exist_W ZToW := + let H := fresh in + revert x; refine (@app_W_dep _ _ _); intros x H; + let x' := fresh in + set (x' := x) in *; + cbv [tuple tuple' length_fe25519 List.length wire_widths] in x; + destruct_head' prod; let rec do_refine v := - first [ let v' := (eval cbv [exist_fe25519W fe25519ZToW exist_fe25519' proj_word Build_bounded_word snd fst] in (proj_word v)) in - refine (Build_bounded_word v' _); abstract exact (word_bounded v) - | let v' := (eval cbv [exist_fe25519W fe25519ZToW exist_fe25519' proj_word Build_bounded_word snd fst] in (proj_word (snd v))) in + first [ let v' := (eval cbv [exist_W ZToW exist'' proj_word Build_bounded_word Build_bounded_word' snd fst] in (proj_word v)) in + refine (Build_bounded_word v' _); subst x'; abstract exact (word_bounded v) + | let v' := (eval cbv [exist_W ZToW exist'' proj_word Build_bounded_word Build_bounded_word' snd fst] in (proj_word (snd v))) in refine (_, Build_bounded_word v' _); - [ do_refine (fst v) | abstract exact (word_bounded (snd v)) ] ] in + [ do_refine (fst v) | subst x'; abstract exact (word_bounded (snd v)) ] ] in + let v := (eval unfold x' in (exist'' x' H)) in do_refine v. -Defined. + +Definition exist_fe25519W' (x : fe25519W) : is_bounded (fe25519WToZ x) = true -> fe25519. +Proof. make_exist_W' x (@app_fe25519W_dep). Defined. +Definition exist_fe25519W (x : fe25519W) : is_bounded (fe25519WToZ x) = true -> fe25519 + := Eval cbv [app_fe25519W_dep exist_fe25519W' fe25519ZToW] in exist_fe25519W' x. +Definition exist_fe25519'' (x : Specific.GF25519.fe25519) : is_bounded x = true -> fe25519. +Proof. make_exist'' x exist_fe25519W fe25519ZToW. Defined. +Definition exist_fe25519' (x : Specific.GF25519.fe25519) : is_bounded x = true -> fe25519. +Proof. make_exist' x (@app_fe25519W_dep) exist_fe25519'' exist_fe25519W fe25519ZToW. Defined. +Definition exist_fe25519 (x : Specific.GF25519.fe25519) : is_bounded x = true -> fe25519 + := Eval cbv [exist_fe25519' exist_fe25519W exist_fe25519' app_fe25519 app_fe25519W_dep] in + exist_fe25519' x. Lemma proj1_fe25519_exist_fe25519W x pf : proj1_fe25519 (exist_fe25519W x pf) = fe25519WToZ x. -Proof. - revert pf. - refine (let '(x0, x1, x2, x3, x4, x5, x6, x7, x8, x9) as x := x return forall pf : is_bounded (fe25519WToZ x) = true, proj1_fe25519 (exist_fe25519W x pf) = fe25519WToZ x in - fun pf => _). - reflexivity. -Qed. +Proof. now hnf in x; destruct_head' prod. Qed. Lemma proj1_fe25519W_exist_fe25519 x pf : proj1_fe25519W (exist_fe25519 x pf) = fe25519ZToW x. -Proof. - revert pf. - refine (let '(x0, x1, x2, x3, x4, x5, x6, x7, x8, x9) as x := x return forall pf : is_bounded x = true, proj1_fe25519W (exist_fe25519 x pf) = fe25519ZToW x in - fun pf => _). - reflexivity. -Qed. +Proof. now hnf in x; destruct_head' prod. Qed. Lemma proj1_fe25519_exist_fe25519 x pf : proj1_fe25519 (exist_fe25519 x pf) = x. Proof. - revert pf. - refine (let '(x0, x1, x2, x3, x4, x5, x6, x7, x8, x9) as x := x return forall pf : is_bounded x = true, proj1_fe25519 (exist_fe25519 x pf) = x in - fun pf => _). - cbv [proj1_fe25519 exist_fe25519 proj1_fe25519W fe25519WToZ proj_word Build_bounded_word]. + hnf in x; destruct_head' prod. + cbv [proj1_fe25519 exist_fe25519 proj1_fe25519W fe25519WToZ proj_word Build_bounded_word Build_bounded_word']. unfold_is_bounded_in pf. - destruct_head and. + destruct_head' and. Z.ltb_to_lt. - rewrite ?ZToWord64ToZ by (rewrite unfold_Pow2_64; cbv [Pow2_64]; omega). + rewrite ?word64ToZ_ZToWord64 by prove_lt_bit_width. reflexivity. Qed. -Definition exist_wire_digitsW (x : wire_digitsW) : wire_digits_is_bounded (wire_digitsWToZ x) = true -> wire_digits. +Definition exist_wire_digitsW' (x : wire_digitsW) + : wire_digits_is_bounded (wire_digitsWToZ x) = true -> wire_digits. +Proof. make_exist_W' x (@app_wire_digitsW_dep). Defined. +Definition exist_wire_digitsW (x : wire_digitsW) + : wire_digits_is_bounded (wire_digitsWToZ x) = true -> wire_digits + := Eval cbv [app_wire_digitsW_dep exist_wire_digitsW' wire_digitsZToW] in exist_wire_digitsW' x. +Definition exist_wire_digits'' (x : Specific.GF25519.wire_digits) + : wire_digits_is_bounded x = true -> wire_digits. +Proof. make_exist'' x exist_wire_digitsW wire_digitsZToW. Defined. +Definition exist_wire_digits' (x : Specific.GF25519.wire_digits) + : wire_digits_is_bounded x = true -> wire_digits. +Proof. make_exist' x (@app_wire_digitsW_dep) exist_wire_digits'' exist_wire_digitsW wire_digitsZToW. Defined. +Definition exist_wire_digits (x : Specific.GF25519.wire_digits) + : wire_digits_is_bounded x = true -> wire_digits + := Eval cbv [exist_wire_digits' exist_wire_digitsW exist_wire_digits' app_wire_digits app_wire_digitsW_dep] in + exist_wire_digits' x. + +Lemma proj1_wire_digits_exist_wire_digitsW x pf : proj1_wire_digits (exist_wire_digitsW x pf) = wire_digitsWToZ x. +Proof. now hnf in x; destruct_head' prod. Qed. +Lemma proj1_wire_digitsW_exist_wire_digits x pf : proj1_wire_digitsW (exist_wire_digits x pf) = wire_digitsZToW x. +Proof. now hnf in x; destruct_head' prod. Qed. +Lemma proj1_wire_digits_exist_wire_digits x pf : proj1_wire_digits (exist_wire_digits x pf) = x. Proof. - refine (let '(x0, x1, x2, x3, x4, x5, x6, x7) as x := x return wire_digits_is_bounded (wire_digitsWToZ x) = true -> wire_digits in - fun H => (fun H' => (Build_bounded_word x0 _, Build_bounded_word x1 _, Build_bounded_word x2 _, Build_bounded_word x3 _, Build_bounded_word x4 _, - Build_bounded_word x5 _, Build_bounded_word x6 _, Build_bounded_word x7 _)) - (let H' := proj1 (@fold_right_andb_true_iff_fold_right_and_True _) H in - _)); - [ - | | | | | | | - | clearbody H'; clear H x; - unfold_is_bounded_in H'; - exact H' ]; - destruct_head and; auto; - rewrite_hyp !*; reflexivity. -Defined. + hnf in x; destruct_head' prod. + cbv [proj1_wire_digits exist_wire_digits proj1_wire_digitsW wire_digitsWToZ proj_word Build_bounded_word Build_bounded_word' app_wire_digits HList.mapt HList.mapt' length wire_widths fst snd]. + unfold_is_bounded_in pf. + destruct_head' and. + Z.ltb_to_lt. + rewrite ?word64ToZ_ZToWord64 by prove_lt_bit_width. + reflexivity. +Qed. -Definition exist_wire_digits' (x : Specific.GF25519.wire_digits) : wire_digits_is_bounded x = true -> wire_digits. +Module opt. + Definition word64ToZ := Eval vm_compute in word64ToZ. + Definition word64ToN := Eval vm_compute in @wordToN bit_width. + Definition NToWord64 := Eval vm_compute in NToWord64. + Definition bit_width := Eval vm_compute in bit_width. + Definition Zleb := Eval cbv [Z.leb] in Z.leb. + Definition andb := Eval vm_compute in andb. + Definition word64ize := Eval vm_compute in word64ize. +End opt. + +Local Transparent bit_width. +Local Ltac do_change lem := + match lem with + | context L[andb (?x <=? ?y)%Z (?y <=? ?z)] + => let x' := (eval vm_compute in x) in + let z' := (eval vm_compute in z) in + lazymatch y with + | word64ToZ (word64ize ?v) + => let y' := constr:(opt.word64ToZ (opt.word64ize v)) in + let L' := context L[andb (opt.Zleb x' y') (opt.Zleb y' z')] in + do_change L' + end + | _ => lem + end. +Definition fe25519_word64ize (x : fe25519) : fe25519. Proof. - intro H; apply (exist_wire_digitsW (wire_digitsZToW x)). - abstract ( - hnf in x; destruct_head prod; - pose proof H as H'; - unfold_is_bounded_in H; - destruct_head and; - Z.ltb_to_lt; - rewrite ?ZToWord64ToZ by (simpl; omega); - assumption - ). + set (x' := x). + hnf in x; destruct_head' prod. + let lem := constr:(exist_fe25519W (fe25519W_word64ize (proj1_fe25519W x'))) in + let lem := (eval cbv [proj1_fe25519W x' fe25519W_word64ize proj_word exist_fe25519W Build_bounded_word' Build_bounded_word] in lem) in + let lem := do_change lem in + refine (lem _); + change (is_bounded (fe25519WToZ (fe25519W_word64ize (proj1_fe25519W x'))) = true); + abstract (rewrite fe25519W_word64ize_id; apply is_bounded_proj1_fe25519). Defined. - -Definition exist_wire_digits (x : Specific.GF25519.wire_digits) : wire_digits_is_bounded x = true -> wire_digits. +Definition wire_digits_word64ize (x : wire_digits) : wire_digits. Proof. - refine (let '(x0, x1, x2, x3, x4, x5, x6, x7) as x := x return wire_digits_is_bounded x = true -> wire_digits in - fun H => _). - let v := constr:(exist_wire_digits' (x0, x1, x2, x3, x4, x5, x6, x7) H) in - let rec do_refine v := - first [ let v' := (eval cbv [exist_wire_digitsW wire_digitsZToW exist_wire_digits' proj_word Build_bounded_word snd fst] in (proj_word v)) in - refine (Build_bounded_word v' _); abstract exact (word_bounded v) - | let v' := (eval cbv [exist_wire_digitsW wire_digitsZToW exist_wire_digits' proj_word Build_bounded_word snd fst] in (proj_word (snd v))) in - refine (_, Build_bounded_word v' _); - [ do_refine (fst v) | abstract exact (word_bounded (snd v)) ] ] in - do_refine v. + set (x' := x). + hnf in x; destruct_head' prod. + let lem := constr:(exist_wire_digitsW (wire_digitsW_word64ize (proj1_wire_digitsW x'))) in + let lem := (eval cbv [proj1_wire_digitsW x' wire_digitsW_word64ize proj_word exist_wire_digitsW Build_bounded_word Build_bounded_word'] in lem) in + let lem := do_change lem in + let lem := (eval cbv [word64ize opt.word64ize andb Z.leb Z.compare CompOpp Pos.compare] in lem) in + refine (lem _); + change (wire_digits_is_bounded (wire_digitsWToZ (wire_digitsW_word64ize (proj1_wire_digitsW x'))) = true); + abstract (rewrite wire_digitsW_word64ize_id; apply is_bounded_proj1_wire_digits). Defined. -Lemma proj1_wire_digits_exist_wire_digitsW x pf : proj1_wire_digits (exist_wire_digitsW x pf) = wire_digitsWToZ x. -Proof. - revert pf. - refine (let '(x0, x1, x2, x3, x4, x5, x6, x7) as x := x return forall pf : wire_digits_is_bounded (wire_digitsWToZ x) = true, proj1_wire_digits (exist_wire_digitsW x pf) = wire_digitsWToZ x in - fun pf => _). - reflexivity. -Qed. -Lemma proj1_wire_digitsW_exist_wire_digits x pf : proj1_wire_digitsW (exist_wire_digits x pf) = wire_digitsZToW x. +Lemma is_bounded_to_nth_default x (H : is_bounded x = true) + : forall n : nat, + (n < length limb_widths)%nat + -> (0 <= nth_default 0 (Tuple.to_list length_fe25519 x) n <= + snd (b_of (nth_default (-1) limb_widths n)))%Z. Proof. - revert pf. - refine (let '(x0, x1, x2, x3, x4, x5, x6, x7) as x := x return forall pf : wire_digits_is_bounded x = true, proj1_wire_digitsW (exist_wire_digits x pf) = wire_digitsZToW x in - fun pf => _). - reflexivity. -Qed. -Lemma proj1_wire_digits_exist_wire_digits x pf : proj1_wire_digits (exist_wire_digits x pf) = x. -Proof. - revert pf. - refine (let '(x0, x1, x2, x3, x4, x5, x6, x7) as x := x return forall pf : wire_digits_is_bounded x = true, proj1_wire_digits (exist_wire_digits x pf) = x in - fun pf => _). - cbv [proj1_wire_digits exist_wire_digits proj1_wire_digitsW wire_digitsWToZ proj_word Build_bounded_word]. - unfold_is_bounded_in pf. - destruct_head and. + hnf in x; destruct_head' prod. + unfold_is_bounded_in H; destruct_head' and. Z.ltb_to_lt. - rewrite ?ZToWord64ToZ by (rewrite unfold_Pow2_64; cbv [Pow2_64]; omega). - reflexivity. + unfold nth_default; simpl. + intros. + repeat match goal with + | [ |- context[nth_error _ ?x] ] + => is_var x; destruct x; simpl + end; + omega. Qed. (* END precomputation *) (* Precompute constants *) -Definition one := Eval vm_compute in exist_fe25519 Specific.GF25519.one_ eq_refl. +Definition one' := Eval vm_compute in exist_fe25519 Specific.GF25519.one_ eq_refl. +Definition one := Eval cbv [one' fe25519_word64ize word64ize andb opt.word64ToZ opt.word64ize opt.Zleb Z.compare CompOpp Pos.compare Pos.compare_cont] in fe25519_word64ize one'. -Definition zero := Eval vm_compute in exist_fe25519 Specific.GF25519.zero_ eq_refl. +Definition zero' := Eval vm_compute in exist_fe25519 Specific.GF25519.zero_ eq_refl. +Definition zero := Eval cbv [zero' fe25519_word64ize word64ize andb opt.word64ToZ opt.word64ize opt.Zleb Z.compare CompOpp Pos.compare Pos.compare_cont] in fe25519_word64ize zero'. Lemma fold_chain_opt_gen {A B} (F : A -> B) is_bounded ls id' op' id op chain (Hid_bounded : is_bounded (F id') = true) @@ -365,17 +608,21 @@ Proof. pose proof (bounded_encode x). generalize dependent (encode x). intro t; compute in t; intros. - destruct_head prod. + destruct_head' prod. unfold Pow2Base.bounded in H. - pose proof (H 0%nat); pose proof (H 1%nat); pose proof (H 2%nat); - pose proof (H 3%nat); pose proof (H 4%nat); pose proof (H 5%nat); - pose proof (H 6%nat); pose proof (H 7%nat); pose proof (H 8%nat); - pose proof (H 9%nat); clear H. + cbv [nth_default Tuple.to_list Tuple.to_list' List.length limb_widths params25519] in H. + repeat match type of H with + | context[nth_error (cons _ _) _] + => let H' := fresh in + pose proof (H O) as H'; specialize (fun i => H (S i)); simpl @nth_error in H, H'; + cbv beta iota in H' + end. + clear H. simpl in *. cbv [Z.pow_pos Z.mul Pos.mul Pos.iter nth_default nth_error value] in *. unfold is_bounded. apply fold_right_andb_true_iff_fold_right_and_True. - cbv [is_bounded proj1_fe25519 to_list length bounds from_list from_list' map2 on_tuple2 to_list' ListUtil.map2 List.map List.rev List.app proj_word fold_right]. + cbv [is_bounded proj1_fe25519 to_list length bounds from_list from_list' map2 on_tuple2 to_list' ListUtil.map2 List.map List.rev List.app proj_word fold_right length_fe25519]. repeat split; rewrite !Bool.andb_true_iff, !Z.leb_le; omega. Qed. @@ -387,12 +634,28 @@ Definition decode (x : fe25519) : F modulus Lemma proj1_fe25519_encode x : proj1_fe25519 (encode x) = ModularBaseSystem.encode x. -Proof. reflexivity. Qed. +Proof. + cbv [encode]. + generalize (encode_bounded x); generalize (ModularBaseSystem.encode x). + intros y pf; intros; hnf in y; destruct_head_hnf' prod. + cbv [proj1_fe25519 exist_fe25519 proj1_fe25519W Build_bounded_word Build_bounded_word' fe25519WToZ proj_word]. + unfold_is_bounded_in pf. + destruct_head' and. + Z.ltb_to_lt. + rewrite ?word64ToZ_ZToWord64 by prove_lt_bit_width. + reflexivity. +Qed. Lemma decode_exist_fe25519 x pf : decode (exist_fe25519 x pf) = ModularBaseSystem.decode x. Proof. - hnf in x; destruct_head' prod; reflexivity. + hnf in x; destruct_head' prod. + cbv [decode proj1_fe25519 exist_fe25519 proj1_fe25519W Build_bounded_word Build_bounded_word' fe25519WToZ proj_word]. + unfold_is_bounded_in pf. + destruct_head' and. + Z.ltb_to_lt. + rewrite ?word64ToZ_ZToWord64 by prove_lt_bit_width. + reflexivity. Qed. Definition div (f g : fe25519) : fe25519 diff --git a/src/Specific/GF25519BoundedCommonWord.v b/src/Specific/GF25519BoundedCommonWord.v deleted file mode 100644 index 9328d4527..000000000 --- a/src/Specific/GF25519BoundedCommonWord.v +++ /dev/null @@ -1,414 +0,0 @@ -Require Import Crypto.BaseSystem. -Require Import Crypto.ModularArithmetic.PrimeFieldTheorems. -Require Import Crypto.ModularArithmetic.PseudoMersenneBaseParams. -Require Import Crypto.ModularArithmetic.PseudoMersenneBaseParamProofs. -Require Import Crypto.ModularArithmetic.ModularBaseSystem. -Require Import Crypto.ModularArithmetic.ModularBaseSystemProofs. -Require Import Crypto.ModularArithmetic.ModularBaseSystemOpt. -Require Import Crypto.Specific.GF25519. -Require Import Bedrock.Word Crypto.Util.WordUtil. -Require Import Coq.Lists.List Crypto.Util.ListUtil. -Require Import Crypto.Tactics.VerdiTactics. -Require Import Crypto.Util.ZUtil. -Require Import Crypto.Util.Tuple. -Require Import Crypto.Util.Tactics. -Require Import Crypto.Util.LetIn. -Require Import Crypto.Util.Notations. -Require Import Crypto.Util.Decidable. -Require Import Crypto.Algebra. -Import ListNotations. -Require Import Coq.ZArith.ZArith Coq.ZArith.Zpower Coq.ZArith.ZArith Coq.ZArith.Znumtheory. -Local Open Scope Z. - -(* BEGIN aliases for word extraction *) -Definition word64 := Word.word 64. -Coercion word64ToZ (x : word64) : Z - := Z.of_N (wordToN x). -Coercion ZToWord64 (x : Z) : word64 := NToWord _ (Z.to_N x). -Definition w64eqb (x y : word64) := weqb x y. - -Lemma word64eqb_Zeqb x y : (word64ToZ x =? word64ToZ y)%Z = w64eqb x y. -Proof. apply wordeqb_Zeqb. Qed. - -(* END aliases for word extraction *) - -Local Arguments Z.pow_pos !_ !_ / . -Lemma ZToWord64ToZ x : 0 <= x < 2^64 -> word64ToZ (ZToWord64 x) = x. -Proof. - intros; unfold word64ToZ, ZToWord64. - rewrite ?wordToN_NToWord_idempotent, ?N2Z.id, ?Z2N.id - by (omega || apply N2Z.inj_lt; rewrite ?N2Z.id, ?Z2N.id by omega; simpl in *; omega). - reflexivity. -Qed. - -(* BEGIN precomputation. *) -Local Notation b_of exp := (0, 2^exp + 2^(exp-3))%Z (only parsing). (* max is [(0, 2^(exp+2) + 2^exp + 2^(exp-1) + 2^(exp-3) + 2^(exp-4) + 2^(exp-5) + 2^(exp-6) + 2^(exp-10) + 2^(exp-12) + 2^(exp-13) + 2^(exp-14) + 2^(exp-15) + 2^(exp-17) + 2^(exp-23) + 2^(exp-24))%Z] *) -Record bounded_word (lower upper : Z) := - Build_bounded_word' - { proj_word :> word64; - word_bounded : andb (lower <=? proj_word)%Z (proj_word <=? upper)%Z = true }. -Arguments proj_word {_ _} _. -Arguments word_bounded {_ _} _. -Arguments Build_bounded_word' {_ _} _ _. -Definition Build_bounded_word {lower upper} (proj_word : word64) (word_bounded : andb (lower <=? proj_word)%Z (proj_word <=? upper)%Z = true) - : bounded_word lower upper - := Build_bounded_word' - proj_word - (match andb (lower <=? proj_word)%Z (proj_word <=? upper)%Z as b return b = true -> b = true with - | true => fun _ => eq_refl - | false => fun x => x - end word_bounded). -Local Notation word_of exp := (bounded_word (fst (b_of exp)) (snd (b_of exp))). -Local Notation unbounded_word sz := (bounded_word 0 (2^sz-1)%Z). -Lemma word_to_unbounded_helper {x e : nat} : (x < pow2 e)%nat -> (Z.of_nat e <= 64)%Z -> ((0 <=? word64ToZ (ZToWord64 (Z.of_nat x))) && (word64ToZ (ZToWord64 (Z.of_nat x)) <=? 2 ^ (Z.of_nat e) - 1))%bool = true. -Proof. - rewrite pow2_id; intro H; apply Nat2Z.inj_lt in H; revert H. - rewrite Z.pow_Zpow; simpl Z.of_nat. - intros H H'. - assert (2^Z.of_nat e <= 2^64) by auto with zarith. - rewrite !ZToWord64ToZ by omega. - match goal with - | [ |- context[andb ?x ?y] ] - => destruct x eqn:?, y eqn:?; try reflexivity; Z.ltb_to_lt - end; - intros; omega. -Qed. -Definition word_to_unbounded_word {sz} (x : word sz) : (Z.of_nat sz <=? 64)%Z = true -> unbounded_word (Z.of_nat sz). -Proof. - refine (fun pf => Build_bounded_word (Z.of_N (wordToN x)) _). - abstract (rewrite wordToN_nat, nat_N_Z; Z.ltb_to_lt; apply (word_to_unbounded_helper (wordToNat_bound x)); simpl; omega). -Defined. -Definition word32_to_unbounded_word (x : word 32) : unbounded_word 32. -Proof. apply (word_to_unbounded_word x); reflexivity. Defined. -Definition word31_to_unbounded_word (x : word 31) : unbounded_word 31. -Proof. apply (word_to_unbounded_word x); reflexivity. Defined. -Definition bounds : list (Z * Z) - := Eval compute in - [b_of 25; b_of 26; b_of 25; b_of 26; b_of 25; b_of 26; b_of 25; b_of 26; b_of 25; b_of 26]. -Definition wire_digit_bounds : list (Z * Z) - := Eval compute in - List.repeat (0, 2^32-1)%Z 7 ++ ((0,2^31-1)%Z :: nil). - -Definition fe25519W := Eval cbv -[word64] in (tuple word64 (length limb_widths)). -Definition wire_digitsW := Eval cbv -[word64] in (tuple word64 8). -Definition fe25519WToZ (x : fe25519W) : Specific.GF25519.fe25519 - := let '(x0, x1, x2, x3, x4, x5, x6, x7, x8, x9) := x in - (x0 : Z, x1 : Z, x2 : Z, x3 : Z, x4 : Z, x5 : Z, x6 : Z, x7 : Z, x8 : Z, x9 : Z). -Definition fe25519ZToW (x : Specific.GF25519.fe25519) : fe25519W - := let '(x0, x1, x2, x3, x4, x5, x6, x7, x8, x9) := x in - (x0 : word64, x1 : word64, x2 : word64, x3 : word64, x4 : word64, x5 : word64, x6 : word64, x7 : word64, x8 : word64, x9 : word64). -Definition wire_digitsWToZ (x : wire_digitsW) : Specific.GF25519.wire_digits - := let '(x0, x1, x2, x3, x4, x5, x6, x7) := x in - (x0 : Z, x1 : Z, x2 : Z, x3 : Z, x4 : Z, x5 : Z, x6 : Z, x7 : Z). -Definition wire_digitsZToW (x : Specific.GF25519.wire_digits) : wire_digitsW - := let '(x0, x1, x2, x3, x4, x5, x6, x7) := x in - (x0 : word64, x1 : word64, x2 : word64, x3 : word64, x4 : word64, x5 : word64, x6 : word64, x7 : word64). -Definition fe25519 := - Eval cbv [fst snd] in - let sanity := eq_refl : length bounds = length limb_widths in - (word_of 25 * word_of 26 * word_of 25 * word_of 26 * word_of 25 * word_of 26 * word_of 25 * word_of 26 * word_of 25 * word_of 26)%type. -Definition wire_digits := - Eval cbv [fst snd Tuple.tuple Tuple.tuple'] in - (unbounded_word 32 * unbounded_word 32 * unbounded_word 32 * unbounded_word 32 - * unbounded_word 32 * unbounded_word 32 * unbounded_word 32 * unbounded_word 31)%type. -Definition proj1_fe25519W (x : fe25519) : fe25519W - := let '(x0, x1, x2, x3, x4, x5, x6, x7, x8, x9) := x in - (proj_word x0, proj_word x1, proj_word x2, proj_word x3, proj_word x4, - proj_word x5, proj_word x6, proj_word x7, proj_word x8, proj_word x9). -Coercion proj1_fe25519 (x : fe25519) : Specific.GF25519.fe25519 - := fe25519WToZ (proj1_fe25519W x). -Definition is_bounded (x : Specific.GF25519.fe25519) : bool - := let res := Tuple.map2 - (fun bounds v => - let '(lower, upper) := bounds in - (lower <=? v) && (v <=? upper))%bool%Z - (Tuple.from_list _ (List.rev bounds) eq_refl) x in - List.fold_right andb true (Tuple.to_list _ res). - -Lemma is_bounded_proj1_fe25519 (x : fe25519) : is_bounded (proj1_fe25519 x) = true. -Proof. - refine (let '(Build_bounded_word' x0 p0, Build_bounded_word' x1 p1, Build_bounded_word' x2 p2, Build_bounded_word' x3 p3, Build_bounded_word' x4 p4, - Build_bounded_word' x5 p5, Build_bounded_word' x6 p6, Build_bounded_word' x7 p7, Build_bounded_word' x8 p8, Build_bounded_word' x9 p9) - as x := x return is_bounded (proj1_fe25519 x) = true in - _). - cbv [is_bounded proj1_fe25519 proj1_fe25519W fe25519WToZ to_list length bounds from_list from_list' map2 on_tuple2 to_list' ListUtil.map2 List.map List.rev List.app proj_word]. - apply fold_right_andb_true_iff_fold_right_and_True. - cbv [fold_right List.map]. - cbv beta in *. - repeat split; assumption. -Qed. - -Definition proj1_wire_digitsW (x : wire_digits) : wire_digitsW - := let '(x0, x1, x2, x3, x4, x5, x6, x7) := x in - (proj_word x0, proj_word x1, proj_word x2, proj_word x3, proj_word x4, - proj_word x5, proj_word x6, proj_word x7). -Coercion proj1_wire_digits (x : wire_digits) : Specific.GF25519.wire_digits - := wire_digitsWToZ (proj1_wire_digitsW x). -Definition wire_digits_is_bounded (x : Specific.GF25519.wire_digits) : bool - := let res := Tuple.map2 - (fun bounds v => - let '(lower, upper) := bounds in - (lower <=? v) && (v <=? upper))%bool%Z - (Tuple.from_list _ (List.rev wire_digit_bounds) eq_refl) x in - List.fold_right andb true (Tuple.to_list _ res). - -Lemma is_bounded_proj1_wire_digits (x : wire_digits) : wire_digits_is_bounded (proj1_wire_digits x) = true. -Proof. - refine (let '(Build_bounded_word' x0 p0, Build_bounded_word' x1 p1, Build_bounded_word' x2 p2, Build_bounded_word' x3 p3, Build_bounded_word' x4 p4, - Build_bounded_word' x5 p5, Build_bounded_word' x6 p6, Build_bounded_word' x7 p7) - as x := x return wire_digits_is_bounded (proj1_wire_digits x) = true in - _). - cbv [wire_digits_is_bounded proj1_wire_digits proj1_wire_digitsW wire_digitsWToZ to_list length wire_digit_bounds from_list from_list' map2 on_tuple2 to_list' ListUtil.map2 List.map List.rev List.app proj_word]. - apply fold_right_andb_true_iff_fold_right_and_True. - cbv [fold_right List.map]. - cbv beta in *. - repeat split; assumption. -Qed. - -(** TODO: Turn this into a lemma to speed up proofs *) -Ltac unfold_is_bounded_in H := - unfold is_bounded, wire_digits_is_bounded, fe25519WToZ, wire_digitsWToZ in H; - cbv [to_list length bounds wire_digit_bounds from_list from_list' map2 on_tuple2 to_list' ListUtil.map2 List.map fold_right List.rev List.app] in H; - rewrite !Bool.andb_true_iff in H. - -Definition Pow2_64 := Eval compute in 2^64. -Definition unfold_Pow2_64 : 2^64 = Pow2_64 := eq_refl. - -Definition exist_fe25519W (x : fe25519W) : is_bounded (fe25519WToZ x) = true -> fe25519. -Proof. - refine (let '(x0, x1, x2, x3, x4, x5, x6, x7, x8, x9) as x := x return is_bounded (fe25519WToZ x) = true -> fe25519 in - fun H => (fun H' => (Build_bounded_word x0 _, Build_bounded_word x1 _, Build_bounded_word x2 _, Build_bounded_word x3 _, Build_bounded_word x4 _, - Build_bounded_word x5 _, Build_bounded_word x6 _, Build_bounded_word x7 _, Build_bounded_word x8 _, Build_bounded_word x9 _)) - (let H' := proj1 (@fold_right_andb_true_iff_fold_right_and_True _) H in - _)); - [ - | | | | | | | | | - | clearbody H'; clear H x; - unfold_is_bounded_in H'; - exact H' ]; - destruct_head and; auto; - rewrite_hyp !*; reflexivity. -Defined. - -Definition exist_fe25519' (x : Specific.GF25519.fe25519) : is_bounded x = true -> fe25519. -Proof. - intro H; apply (exist_fe25519W (fe25519ZToW x)). - abstract ( - hnf in x; destruct_head prod; - pose proof H as H'; - unfold_is_bounded_in H; - destruct_head and; - Z.ltb_to_lt; - rewrite !ZToWord64ToZ by (simpl; omega); - assumption - ). -Defined. - -Definition exist_fe25519 (x : Specific.GF25519.fe25519) : is_bounded x = true -> fe25519. -Proof. - refine (let '(x0, x1, x2, x3, x4, x5, x6, x7, x8, x9) as x := x return is_bounded x = true -> fe25519 in - fun H => _). - let v := constr:(exist_fe25519' (x0, x1, x2, x3, x4, x5, x6, x7, x8, x9) H) in - let rec do_refine v := - first [ let v' := (eval cbv [exist_fe25519W fe25519ZToW exist_fe25519' proj_word Build_bounded_word snd fst] in (proj_word v)) in - refine (Build_bounded_word v' _); abstract exact (word_bounded v) - | let v' := (eval cbv [exist_fe25519W fe25519ZToW exist_fe25519' proj_word Build_bounded_word snd fst] in (proj_word (snd v))) in - refine (_, Build_bounded_word v' _); - [ do_refine (fst v) | abstract exact (word_bounded (snd v)) ] ] in - do_refine v. -Defined. - -Lemma proj1_fe25519_exist_fe25519W x pf : proj1_fe25519 (exist_fe25519W x pf) = fe25519WToZ x. -Proof. - revert pf. - refine (let '(x0, x1, x2, x3, x4, x5, x6, x7, x8, x9) as x := x return forall pf : is_bounded (fe25519WToZ x) = true, proj1_fe25519 (exist_fe25519W x pf) = fe25519WToZ x in - fun pf => _). - reflexivity. -Qed. -Lemma proj1_fe25519W_exist_fe25519 x pf : proj1_fe25519W (exist_fe25519 x pf) = fe25519ZToW x. -Proof. - revert pf. - refine (let '(x0, x1, x2, x3, x4, x5, x6, x7, x8, x9) as x := x return forall pf : is_bounded x = true, proj1_fe25519W (exist_fe25519 x pf) = fe25519ZToW x in - fun pf => _). - reflexivity. -Qed. -Lemma proj1_fe25519_exist_fe25519 x pf : proj1_fe25519 (exist_fe25519 x pf) = x. -Proof. - revert pf. - refine (let '(x0, x1, x2, x3, x4, x5, x6, x7, x8, x9) as x := x return forall pf : is_bounded x = true, proj1_fe25519 (exist_fe25519 x pf) = x in - fun pf => _). - cbv [proj1_fe25519 exist_fe25519 proj1_fe25519W fe25519WToZ proj_word Build_bounded_word]. - unfold_is_bounded_in pf. - destruct_head and. - Z.ltb_to_lt. - rewrite !ZToWord64ToZ by (rewrite unfold_Pow2_64; cbv [Pow2_64]; omega). - reflexivity. -Qed. - -Definition exist_wire_digitsW (x : wire_digitsW) : wire_digits_is_bounded (wire_digitsWToZ x) = true -> wire_digits. -Proof. - refine (let '(x0, x1, x2, x3, x4, x5, x6, x7) as x := x return wire_digits_is_bounded (wire_digitsWToZ x) = true -> wire_digits in - fun H => (fun H' => (Build_bounded_word x0 _, Build_bounded_word x1 _, Build_bounded_word x2 _, Build_bounded_word x3 _, Build_bounded_word x4 _, - Build_bounded_word x5 _, Build_bounded_word x6 _, Build_bounded_word x7 _)) - (let H' := proj1 (@fold_right_andb_true_iff_fold_right_and_True _) H in - _)); - [ - | | | | | | | - | clearbody H'; clear H x; - unfold_is_bounded_in H'; - exact H' ]; - destruct_head and; auto; - rewrite_hyp !*; reflexivity. -Defined. - -Definition exist_wire_digits' (x : Specific.GF25519.wire_digits) : wire_digits_is_bounded x = true -> wire_digits. -Proof. - intro H; apply (exist_wire_digitsW (wire_digitsZToW x)). - abstract ( - hnf in x; destruct_head prod; - pose proof H as H'; - unfold_is_bounded_in H; - destruct_head and; - Z.ltb_to_lt; - rewrite !ZToWord64ToZ by (simpl; omega); - assumption - ). -Defined. - -Definition exist_wire_digits (x : Specific.GF25519.wire_digits) : wire_digits_is_bounded x = true -> wire_digits. -Proof. - refine (let '(x0, x1, x2, x3, x4, x5, x6, x7) as x := x return wire_digits_is_bounded x = true -> wire_digits in - fun H => _). - let v := constr:(exist_wire_digits' (x0, x1, x2, x3, x4, x5, x6, x7) H) in - let rec do_refine v := - first [ let v' := (eval cbv [exist_wire_digitsW wire_digitsZToW exist_wire_digits' proj_word Build_bounded_word snd fst] in (proj_word v)) in - refine (Build_bounded_word v' _); abstract exact (word_bounded v) - | let v' := (eval cbv [exist_wire_digitsW wire_digitsZToW exist_wire_digits' proj_word Build_bounded_word snd fst] in (proj_word (snd v))) in - refine (_, Build_bounded_word v' _); - [ do_refine (fst v) | abstract exact (word_bounded (snd v)) ] ] in - do_refine v. -Defined. - -Lemma proj1_wire_digits_exist_wire_digitsW x pf : proj1_wire_digits (exist_wire_digitsW x pf) = wire_digitsWToZ x. -Proof. - revert pf. - refine (let '(x0, x1, x2, x3, x4, x5, x6, x7) as x := x return forall pf : wire_digits_is_bounded (wire_digitsWToZ x) = true, proj1_wire_digits (exist_wire_digitsW x pf) = wire_digitsWToZ x in - fun pf => _). - reflexivity. -Qed. -Lemma proj1_wire_digitsW_exist_wire_digits x pf : proj1_wire_digitsW (exist_wire_digits x pf) = wire_digitsZToW x. -Proof. - revert pf. - refine (let '(x0, x1, x2, x3, x4, x5, x6, x7) as x := x return forall pf : wire_digits_is_bounded x = true, proj1_wire_digitsW (exist_wire_digits x pf) = wire_digitsZToW x in - fun pf => _). - reflexivity. -Qed. -Lemma proj1_wire_digits_exist_wire_digits x pf : proj1_wire_digits (exist_wire_digits x pf) = x. -Proof. - revert pf. - refine (let '(x0, x1, x2, x3, x4, x5, x6, x7) as x := x return forall pf : wire_digits_is_bounded x = true, proj1_wire_digits (exist_wire_digits x pf) = x in - fun pf => _). - cbv [proj1_wire_digits exist_wire_digits proj1_wire_digitsW wire_digitsWToZ proj_word Build_bounded_word]. - unfold_is_bounded_in pf. - destruct_head and. - Z.ltb_to_lt. - rewrite !ZToWord64ToZ by (rewrite unfold_Pow2_64; cbv [Pow2_64]; omega). - reflexivity. -Qed. - -(* END precomputation *) - -(* Precompute constants *) - -Definition one := Eval vm_compute in exist_fe25519 Specific.GF25519.one_ eq_refl. - -Definition zero := Eval vm_compute in exist_fe25519 Specific.GF25519.zero_ eq_refl. - -Lemma fold_chain_opt_gen {A B} (F : A -> B) is_bounded ls id' op' id op chain - (Hid_bounded : is_bounded (F id') = true) - (Hid : id = F id') - (Hop_bounded : forall x y, is_bounded (F x) = true - -> is_bounded (F y) = true - -> is_bounded (op (F x) (F y)) = true) - (Hop : forall x y, is_bounded (F x) = true - -> is_bounded (F y) = true - -> op (F x) (F y) = F (op' x y)) - (Hls_bounded : forall n, is_bounded (F (nth_default id' ls n)) = true) - : F (fold_chain_opt id' op' chain ls) - = fold_chain_opt id op chain (List.map F ls) - /\ is_bounded (F (fold_chain_opt id' op' chain ls)) = true. -Proof. - rewrite !fold_chain_opt_correct. - revert dependent ls; induction chain as [|x xs IHxs]; intros. - { pose proof (Hls_bounded 0%nat). - destruct ls; simpl; split; trivial; congruence. } - { destruct x; simpl; unfold Let_In; simpl. - rewrite (fun ls pf => proj1 (IHxs ls pf)) at 1; simpl. - { do 2 f_equal. - rewrite <- Hop, Hid by auto. - rewrite !map_nth_default_always. - split; try reflexivity. - apply (IHxs (_::_)). - intros [|?]; autorewrite with simpl_nth_default; auto. - rewrite <- Hop; auto. } - { intros [|?]; simpl; - autorewrite with simpl_nth_default; auto. - rewrite <- Hop; auto. } } -Qed. - -Lemma encode_bounded x : is_bounded (encode x) = true. -Proof. - pose proof (bounded_encode x). - generalize dependent (encode x). - intro t; compute in t; intros. - destruct_head prod. - unfold Pow2Base.bounded in H. - pose proof (H 0%nat); pose proof (H 1%nat); pose proof (H 2%nat); - pose proof (H 3%nat); pose proof (H 4%nat); pose proof (H 5%nat); - pose proof (H 6%nat); pose proof (H 7%nat); pose proof (H 8%nat); - pose proof (H 9%nat); clear H. - simpl in *. - cbv [Z.pow_pos Z.mul Pos.mul Pos.iter nth_default nth_error value] in *. - unfold is_bounded. - apply fold_right_andb_true_iff_fold_right_and_True. - cbv [is_bounded proj1_fe25519 to_list length bounds from_list from_list' map2 on_tuple2 to_list' ListUtil.map2 List.map List.rev List.app proj_word fold_right]. - repeat split; rewrite !Bool.andb_true_iff, !Z.leb_le; omega. -Qed. - -Definition encode (x : F modulus) : fe25519 - := exist_fe25519 (encode x) (encode_bounded x). - -Definition decode (x : fe25519) : F modulus - := ModularBaseSystem.decode (proj1_fe25519 x). - -Definition div (f g : fe25519) : fe25519 - := exist_fe25519 (div (proj1_fe25519 f) (proj1_fe25519 g)) (encode_bounded _). - -Definition eq (f g : fe25519) : Prop := eq (proj1_fe25519 f) (proj1_fe25519 g). - - -Notation ibinop_correct_and_bounded irop op - := (forall x y, - is_bounded (fe25519WToZ x) = true - -> is_bounded (fe25519WToZ y) = true - -> fe25519WToZ (irop x y) = op (fe25519WToZ x) (fe25519WToZ y) - /\ is_bounded (fe25519WToZ (irop x y)) = true) (only parsing). -Notation iunop_correct_and_bounded irop op - := (forall x, - is_bounded (fe25519WToZ x) = true - -> fe25519WToZ (irop x) = op (fe25519WToZ x) - /\ is_bounded (fe25519WToZ (irop x)) = true) (only parsing). -Notation iunop_FEToZ_correct irop op - := (forall x, - is_bounded (fe25519WToZ x) = true - -> word64ToZ (irop x) = op (fe25519WToZ x)) (only parsing). -Notation iunop_FEToWire_correct_and_bounded irop op - := (forall x, - is_bounded (fe25519WToZ x) = true - -> wire_digitsWToZ (irop x) = op (fe25519WToZ x) - /\ wire_digits_is_bounded (wire_digitsWToZ (irop x)) = true) (only parsing). -Notation iunop_WireToFE_correct_and_bounded irop op - := (forall x, - wire_digits_is_bounded (wire_digitsWToZ x) = true - -> fe25519WToZ (irop x) = op (wire_digitsWToZ x) - /\ is_bounded (fe25519WToZ (irop x)) = true) (only parsing). diff --git a/src/Specific/GF25519Reflective.v b/src/Specific/GF25519Reflective.v new file mode 100644 index 000000000..4405eefa9 --- /dev/null +++ b/src/Specific/GF25519Reflective.v @@ -0,0 +1,119 @@ +Require Import Crypto.BaseSystem. +Require Import Crypto.ModularArithmetic.PrimeFieldTheorems. +Require Import Crypto.ModularArithmetic.PseudoMersenneBaseParams. +Require Import Crypto.ModularArithmetic.PseudoMersenneBaseParamProofs. +Require Import Crypto.ModularArithmetic.ModularBaseSystem. +Require Import Crypto.ModularArithmetic.ModularBaseSystemProofs. +Require Import Crypto.ModularArithmetic.ModularBaseSystemOpt. +Require Export Crypto.Specific.GF25519. +Require Import Crypto.Specific.GF25519BoundedCommon. +Require Import Crypto.Reflection.Reify. +Require Import Crypto.Reflection.Syntax. +Require Import Crypto.Reflection.MapInterp. +Require Import Crypto.Reflection.Z.Interpretations. +Require Crypto.Reflection.Z.Interpretations.Relations. +Require Import Crypto.Reflection.Z.Interpretations.RelationsCombinations. +Require Import Crypto.Reflection.Z.Reify. +Require Import Crypto.Reflection.Z.Syntax. +Require Import Crypto.Specific.GF25519Reflective.Common. +Require Import Crypto.Specific.GF25519Reflective.Reified. +Require Import Bedrock.Word Crypto.Util.WordUtil. +Require Import Coq.Lists.List Crypto.Util.ListUtil. +Require Import Crypto.Tactics.VerdiTactics. +Require Import Crypto.Util.ZUtil. +Require Import Crypto.Util.Tuple. +Require Import Crypto.Util.Tactics. +Require Import Crypto.Util.LetIn. +Require Import Crypto.Util.Notations. +Require Import Crypto.Util.Bool. +Require Import Crypto.Util.Decidable. +Require Import Crypto.Algebra. +Import ListNotations. +Require Import Coq.ZArith.ZArith Coq.ZArith.Zpower Coq.ZArith.ZArith Coq.ZArith.Znumtheory. +Local Open Scope Z. + +Definition radd : ExprBinOp := Eval vm_compute in rcarry_addW. +Definition rsub : ExprBinOp := Eval vm_compute in rcarry_subW. +Definition rmul : ExprBinOp := Eval vm_compute in rmulW. +Definition ropp : ExprUnOp := Eval vm_compute in rcarry_oppW. +Definition rfreeze : ExprUnOp := Eval vm_compute in rfreezeW. +Definition rge_modulus : ExprUnOpFEToZ := Eval vm_compute in rge_modulusW. +Definition rpack : ExprUnOpFEToWire := Eval vm_compute in rpackW. +Definition runpack : ExprUnOpWireToFE := Eval vm_compute in runpackW. + +Definition rword64ize {t} (x : Expr t) : Expr t + := MapInterp (fun t => match t with TZ => word64ize end) x. + +Declare Reduction asm_interp + := cbv beta iota delta + [id + interp_bexpr interp_uexpr interp_uexpr_FEToWire interp_uexpr_FEToZ interp_uexpr_WireToFE + radd rsub rmul ropp rfreeze rge_modulus rpack runpack + curry_binop_fe25519W curry_unop_fe25519W curry_unop_wire_digitsW + Word64.interp_op Word64.interp_base_type + Z.interp_op Z.interp_base_type + Z.Syntax.interp_op Z.Syntax.interp_base_type + mapf_interp_flat_type map_interp Word64.interp_base_type MapInterp mapf_interp word64ize rword64ize + Interp interp interp_flat_type interpf interp_flat_type fst snd]. +Ltac asm_interp + := cbv beta iota delta + [id + interp_bexpr interp_uexpr interp_uexpr_FEToWire interp_uexpr_FEToZ interp_uexpr_WireToFE + radd rsub rmul ropp rfreeze rge_modulus rpack runpack + curry_binop_fe25519W curry_unop_fe25519W curry_unop_wire_digitsW + Word64.interp_op Word64.interp_base_type + Z.interp_op Z.interp_base_type + Z.Syntax.interp_op Z.Syntax.interp_base_type + mapf_interp_flat_type map_interp Word64.interp_base_type MapInterp mapf_interp word64ize rword64ize + Interp interp interp_flat_type interpf interp_flat_type fst snd]. + + +Definition interp_radd : Specific.GF25519BoundedCommon.fe25519W -> Specific.GF25519BoundedCommon.fe25519W -> Specific.GF25519BoundedCommon.fe25519W + := Eval asm_interp in interp_bexpr (rword64ize radd). +(*Print interp_radd.*) +Definition interp_radd_correct : interp_radd = interp_bexpr radd := eq_refl. +Definition interp_rsub : Specific.GF25519BoundedCommon.fe25519W -> Specific.GF25519BoundedCommon.fe25519W -> Specific.GF25519BoundedCommon.fe25519W + := Eval asm_interp in interp_bexpr (rword64ize rsub). +(*Print interp_rsub.*) +Definition interp_rsub_correct : interp_rsub = interp_bexpr rsub := eq_refl. +Definition interp_rmul : Specific.GF25519BoundedCommon.fe25519W -> Specific.GF25519BoundedCommon.fe25519W -> Specific.GF25519BoundedCommon.fe25519W + := Eval asm_interp in interp_bexpr (rword64ize rmul). +(*Print interp_rmul.*) +Definition interp_rmul_correct : interp_rmul = interp_bexpr rmul := eq_refl. +Definition interp_ropp : Specific.GF25519BoundedCommon.fe25519W -> Specific.GF25519BoundedCommon.fe25519W + := Eval asm_interp in interp_uexpr (rword64ize ropp). +(*Print interp_ropp.*) +Definition interp_ropp_correct : interp_ropp = interp_uexpr ropp := eq_refl. +Definition interp_rfreeze : Specific.GF25519BoundedCommon.fe25519W -> Specific.GF25519BoundedCommon.fe25519W + := Eval asm_interp in interp_uexpr (rword64ize rfreeze). +(*Print interp_rfreeze.*) +Definition interp_rfreeze_correct : interp_rfreeze = interp_uexpr rfreeze := eq_refl. + +Definition interp_rge_modulus : Specific.GF25519BoundedCommon.fe25519W -> Specific.GF25519BoundedCommon.word64 + := Eval asm_interp in interp_uexpr_FEToZ (rword64ize rge_modulus). +Definition interp_rge_modulus_correct : interp_rge_modulus = interp_uexpr_FEToZ rge_modulus := eq_refl. + +Definition interp_rpack : Specific.GF25519BoundedCommon.fe25519W -> Specific.GF25519BoundedCommon.wire_digitsW + := Eval asm_interp in interp_uexpr_FEToWire (rword64ize rpack). +Definition interp_rpack_correct : interp_rpack = interp_uexpr_FEToWire rpack := eq_refl. + +Definition interp_runpack : Specific.GF25519BoundedCommon.wire_digitsW -> Specific.GF25519BoundedCommon.fe25519W + := Eval asm_interp in interp_uexpr_WireToFE (rword64ize runpack). +Definition interp_runpack_correct : interp_runpack = interp_uexpr_WireToFE runpack := eq_refl. + +Lemma radd_correct_and_bounded : binop_correct_and_bounded radd carry_add. +Proof. exact rcarry_addW_correct_and_bounded. Qed. +Lemma rsub_correct_and_bounded : binop_correct_and_bounded rsub carry_sub. +Proof. exact rcarry_subW_correct_and_bounded. Qed. +Lemma rmul_correct_and_bounded : binop_correct_and_bounded rmul mul. +Proof. exact rmulW_correct_and_bounded. Qed. +Lemma ropp_correct_and_bounded : unop_correct_and_bounded ropp carry_opp. +Proof. exact rcarry_oppW_correct_and_bounded. Qed. +Lemma rfreeze_correct_and_bounded : unop_correct_and_bounded rfreeze freeze. +Proof. exact rfreezeW_correct_and_bounded. Qed. +Lemma rge_modulus_correct_and_bounded : unop_FEToZ_correct rge_modulus ge_modulus. +Proof. exact rge_modulusW_correct_and_bounded. Qed. +Lemma rpack_correct_and_bounded : unop_FEToWire_correct_and_bounded rpack pack. +Proof. exact rpackW_correct_and_bounded. Qed. +Lemma runpack_correct_and_bounded : unop_WireToFE_correct_and_bounded runpack unpack. +Proof. exact runpackW_correct_and_bounded. Qed. diff --git a/src/Specific/GF25519Reflective/Common.v b/src/Specific/GF25519Reflective/Common.v new file mode 100644 index 000000000..80932d4df --- /dev/null +++ b/src/Specific/GF25519Reflective/Common.v @@ -0,0 +1,548 @@ +Require Export Coq.ZArith.ZArith. +Require Export Coq.Strings.String. +Require Export Crypto.Specific.GF25519. +Require Import Crypto.Specific.GF25519BoundedCommon. +Require Import Crypto.Reflection.Reify. +Require Import Crypto.Reflection.Syntax. +Require Import Crypto.Reflection.Z.Interpretations. +Require Crypto.Reflection.Z.Interpretations.Relations. +Require Import Crypto.Reflection.Z.Interpretations.RelationsCombinations. +Require Import Crypto.Reflection.Z.Reify. +Require Export Crypto.Reflection.Z.Syntax. +Require Import Crypto.Reflection.InterpWfRel. +Require Import Crypto.Reflection.Application. +Require Import Crypto.Reflection.MapInterp. +Require Import Crypto.Reflection.MapInterpWf. +Require Import Crypto.Reflection.WfReflective. +Require Import Crypto.Util.LetIn. +Require Import Crypto.Util.ZUtil. +Require Import Crypto.Util.Tactics. +Require Import Crypto.Util.Notations. + +Notation Expr := (Expr base_type Word64.interp_base_type op). + +Local Ltac make_type_from uncurried_op := + let T := (type of uncurried_op) in + let T := (eval compute in T) in + let rT := reify_type T in + exact rT. + +Definition ExprBinOpT : type base_type. +Proof. make_type_from (uncurry_binop_fe25519 carry_add). Defined. +Definition ExprUnOpT : type base_type. +Proof. make_type_from (uncurry_unop_fe25519 carry_opp). Defined. +Definition ExprUnOpFEToZT : type base_type. +Proof. make_type_from (uncurry_unop_fe25519 ge_modulus). Defined. +Definition ExprUnOpWireToFET : type base_type. +Proof. make_type_from (uncurry_unop_wire_digits unpack). Defined. +Definition ExprUnOpFEToWireT : type base_type. +Proof. make_type_from (uncurry_unop_fe25519 pack). Defined. +Definition ExprBinOp : Type := Expr ExprBinOpT. +Definition ExprUnOp : Type := Expr ExprUnOpT. +Definition ExprUnOpFEToZ : Type := Expr ExprUnOpFEToZT. +Definition ExprUnOpWireToFE : Type := Expr ExprUnOpWireToFET. +Definition ExprUnOpFEToWire : Type := Expr ExprUnOpFEToWireT. + +Local Ltac bounds_from_list ls := + lazymatch (eval hnf in ls) with + | (?x :: nil)%list => constr:(Some {| ZBounds.lower := fst x ; ZBounds.upper := snd x |}) + | (?x :: ?xs)%list => let bs := bounds_from_list xs in + constr:((Some {| ZBounds.lower := fst x ; ZBounds.upper := snd x |}, bs)) + end. + +Local Ltac make_bounds ls := + compute; + let v := bounds_from_list (List.rev ls) in + let v := (eval compute in v) in + exact v. + +Definition ExprBinOp_bounds : interp_all_binders_for ExprBinOpT ZBounds.interp_base_type. +Proof. make_bounds (Tuple.to_list _ bounds ++ Tuple.to_list _ bounds)%list. Defined. +Definition ExprUnOp_bounds : interp_all_binders_for ExprUnOpT ZBounds.interp_base_type. +Proof. make_bounds (Tuple.to_list _ bounds). Defined. +Definition ExprUnOpFEToZ_bounds : interp_all_binders_for ExprUnOpFEToZT ZBounds.interp_base_type. +Proof. make_bounds (Tuple.to_list _ bounds). Defined. +Definition ExprUnOpFEToWire_bounds : interp_all_binders_for ExprUnOpFEToWireT ZBounds.interp_base_type. +Proof. make_bounds (Tuple.to_list _ bounds). Defined. +Definition ExprUnOpWireToFE_bounds : interp_all_binders_for ExprUnOpWireToFET ZBounds.interp_base_type. +Proof. make_bounds (Tuple.to_list _ wire_digit_bounds). Defined. + +Definition interp_bexpr : ExprBinOp -> Specific.GF25519BoundedCommon.fe25519W -> Specific.GF25519BoundedCommon.fe25519W -> Specific.GF25519BoundedCommon.fe25519W + := fun e => curry_binop_fe25519W (Interp (@Word64.interp_op) e). +Definition interp_uexpr : ExprUnOp -> Specific.GF25519BoundedCommon.fe25519W -> Specific.GF25519BoundedCommon.fe25519W + := fun e => curry_unop_fe25519W (Interp (@Word64.interp_op) e). +Definition interp_uexpr_FEToZ : ExprUnOpFEToZ -> Specific.GF25519BoundedCommon.fe25519W -> Specific.GF25519BoundedCommon.word64 + := fun e => curry_unop_fe25519W (Interp (@Word64.interp_op) e). +Definition interp_uexpr_FEToWire : ExprUnOpFEToWire -> Specific.GF25519BoundedCommon.fe25519W -> Specific.GF25519BoundedCommon.wire_digitsW + := fun e => curry_unop_fe25519W (Interp (@Word64.interp_op) e). +Definition interp_uexpr_WireToFE : ExprUnOpWireToFE -> Specific.GF25519BoundedCommon.wire_digitsW -> Specific.GF25519BoundedCommon.fe25519W + := fun e => curry_unop_wire_digitsW (Interp (@Word64.interp_op) e). + +Notation binop_correct_and_bounded rop op + := (ibinop_correct_and_bounded (interp_bexpr rop) op) (only parsing). +Notation unop_correct_and_bounded rop op + := (iunop_correct_and_bounded (interp_uexpr rop) op) (only parsing). +Notation unop_FEToZ_correct rop op + := (iunop_FEToZ_correct (interp_uexpr_FEToZ rop) op) (only parsing). +Notation unop_FEToWire_correct_and_bounded rop op + := (iunop_FEToWire_correct_and_bounded (interp_uexpr_FEToWire rop) op) (only parsing). +Notation unop_WireToFE_correct_and_bounded rop op + := (iunop_WireToFE_correct_and_bounded (interp_uexpr_WireToFE rop) op) (only parsing). + +Ltac rexpr_cbv := + lazymatch goal with + | [ |- { rexpr | interp_type_gen_rel_pointwise _ (Interp _ (t:=?T) rexpr) (?uncurry ?oper) } ] + => let operf := head oper in + let uncurryf := head uncurry in + try cbv delta [T]; try cbv delta [oper]; + try cbv beta iota delta [uncurryf] + end; + cbv beta iota delta [interp_flat_type Z.interp_base_type interp_base_type zero_]. + +Ltac reify_sig := + rexpr_cbv; eexists; Reify_rhs; reflexivity. + +Local Notation rexpr_sig T uncurried_op := + { rexprZ + | interp_type_gen_rel_pointwise (fun _ => Logic.eq) (Interp interp_op (t:=T) rexprZ) uncurried_op } + (only parsing). + +Notation rexpr_binop_sig op := (rexpr_sig ExprBinOpT (uncurry_binop_fe25519 op)) (only parsing). +Notation rexpr_unop_sig op := (rexpr_sig ExprUnOpT (uncurry_unop_fe25519 op)) (only parsing). +Notation rexpr_unop_FEToZ_sig op := (rexpr_sig ExprUnOpFEToZT (uncurry_unop_fe25519 op)) (only parsing). +Notation rexpr_unop_FEToWire_sig op := (rexpr_sig ExprUnOpFEToWireT (uncurry_unop_fe25519 op)) (only parsing). +Notation rexpr_unop_WireToFE_sig op := (rexpr_sig ExprUnOpWireToFET (uncurry_unop_wire_digits op)) (only parsing). + +Notation correct_and_bounded_genT ropW'v ropZ_sigv + := (let ropW' := ropW'v in + let ropZ_sig := ropZ_sigv in + let ropW := MapInterp (fun _ x => x) ropW' in + let ropZ := MapInterp Word64.to_Z ropW' in + let ropBounds := MapInterp ZBounds.of_word64 ropW' in + let ropBoundedWord64 := MapInterp BoundedWord64.of_word64 ropW' in + ropZ = proj1_sig ropZ_sig + /\ interp_type_rel_pointwise2 Relations.related_Z (Interp (@BoundedWord64.interp_op) ropBoundedWord64) (Interp (@Z.interp_op) ropZ) + /\ interp_type_rel_pointwise2 Relations.related_bounds (Interp (@BoundedWord64.interp_op) ropBoundedWord64) (Interp (@ZBounds.interp_op) ropBounds) + /\ interp_type_rel_pointwise2 Relations.related_word64 (Interp (@BoundedWord64.interp_op) ropBoundedWord64) (Interp (@Word64.interp_op) ropW)) + (only parsing). + +Local Ltac args_to_bounded_helper v := + lazymatch v with + | (?x, ?xs) + => args_to_bounded_helper x; [ .. | args_to_bounded_helper xs ] + | ?w + => try refine (_, _); [ refine {| BoundedWord64.value := w |} | .. ] + end. + +Local Ltac make_args x := + let x' := fresh "x'" in + pose (x : id _) as x'; + cbv [fe25519W wire_digitsW] in x; destruct_head' prod; + cbv [fst snd] in *; + simpl @fe25519WToZ in *; + simpl @wire_digitsWToZ in *; + let T := fresh in + evar (T : Type); + cut T; subst T; + [ let H := fresh in + intro H; + let xv := (eval hnf in x') in + args_to_bounded_helper xv; + [ instantiate; + destruct_head' and; + match goal with + | [ H : ?T |- _ ] + => is_evar T; + refine (let c := proj1 H in _); (* work around broken evars in Coq 8.4 *) + lazymatch goal with H := proj1 _ |- _ => refine H end + end.. ] + | instantiate; + repeat match goal with H : is_bounded _ = true |- _ => unfold_is_bounded_in H end; + repeat match goal with H : wire_digits_is_bounded _ = true |- _ => unfold_is_bounded_in H end; + destruct_head' and; + Z.ltb_to_lt; + repeat first [ eexact I + | apply conj; + [ repeat apply conj; [ | eassumption | eassumption | ]; + instantiate; vm_compute; [ refine (fun x => match x with eq_refl => I end) | reflexivity ] + | ] ] ]. + +Local Ltac app_tuples x y := + let tx := type of x in + lazymatch (eval hnf in tx) with + | prod _ _ => let xs := app_tuples (snd x) y in + constr:((fst x, xs)) + | _ => constr:((x, y)) + end. + +Class is_evar {T} (x : T) := make_is_evar : True. +Hint Extern 0 (is_evar ?e) => is_evar e; exact I : typeclass_instances. + +Definition unop_args_to_bounded (x : fe25519W) (H : is_bounded (fe25519WToZ x) = true) + : interp_flat_type (fun _ => BoundedWord64.BoundedWord) (all_binders_for ExprUnOpT). +Proof. make_args x. Defined. +Definition unopWireToFE_args_to_bounded (x : wire_digitsW) (H : wire_digits_is_bounded (wire_digitsWToZ x) = true) + : interp_flat_type (fun _ => BoundedWord64.BoundedWord) (all_binders_for ExprUnOpWireToFET). +Proof. make_args x. Defined. +Definition binop_args_to_bounded (x : fe25519W * fe25519W) + (H : is_bounded (fe25519WToZ (fst x)) = true) + (H' : is_bounded (fe25519WToZ (snd x)) = true) + : interp_flat_type (fun _ => BoundedWord64.BoundedWord) (all_binders_for ExprBinOpT). +Proof. + let v := app_tuples (unop_args_to_bounded (fst x) H) (unop_args_to_bounded (snd x) H') in + exact v. +Defined. + +Ltac assoc_right_tuple x so_far := + let t := type of x in + lazymatch (eval hnf in t) with + | prod _ _ => let so_far := assoc_right_tuple (snd x) so_far in + assoc_right_tuple (fst x) so_far + | _ => lazymatch so_far with + | @None => x + | _ => constr:((x, so_far)) + end + end. + +Local Ltac make_bounds_prop bounds orig_bounds := + let bounds' := fresh "bounds'" in + let bounds_bad := fresh "bounds_bad" in + rename bounds into bounds_bad; + let boundsv := assoc_right_tuple bounds_bad (@None) in + pose boundsv as bounds; + pose orig_bounds as bounds'; + repeat (refine (match fst bounds' with + | Some bounds' => let (l, u) := fst bounds in + let (l', u') := bounds' in + ((l' <=? l) && (u <=? u'))%Z%bool + | None => false + end && _)%bool; + destruct bounds' as [_ bounds'], bounds as [_ bounds]); + try exact (match bounds' with + | Some bounds' => let (l, u) := bounds in + let (l', u') := bounds' in + ((l' <=? l) && (u <=? u'))%Z%bool + | None => false + end). + + +Definition unop_bounds_good (bounds : interp_flat_type (fun _ => ZBounds.bounds) (remove_all_binders ExprUnOpT)) : bool. +Proof. make_bounds_prop bounds ExprUnOp_bounds. Defined. +Definition binop_bounds_good (bounds : interp_flat_type (fun _ => ZBounds.bounds) (remove_all_binders ExprBinOpT)) : bool. +Proof. make_bounds_prop bounds ExprUnOp_bounds. Defined. +Definition unopFEToWire_bounds_good (bounds : interp_flat_type (fun _ => ZBounds.bounds) (remove_all_binders ExprUnOpFEToWireT)) : bool. +Proof. make_bounds_prop bounds ExprUnOpWireToFE_bounds. Defined. +Definition unopWireToFE_bounds_good (bounds : interp_flat_type (fun _ => ZBounds.bounds) (remove_all_binders ExprUnOpWireToFET)) : bool. +Proof. make_bounds_prop bounds ExprUnOp_bounds. Defined. +(* TODO FIXME(jgross?, andreser?): Is every function returning a single Z a boolean function? *) +Definition unopFEToZ_bounds_good (bounds : interp_flat_type (fun _ => ZBounds.bounds) (remove_all_binders ExprUnOpFEToZT)) : bool. +Proof. + refine (let (l, u) := bounds in ((0 <=? l) && (u <=? 1))%Z%bool). +Defined. + +(* FIXME TODO(jgross): This is a horrible tactic. We should unify the + various kinds of correct and boundedness, and abstract in Gallina + rather than Ltac *) + +Local Ltac t_correct_and_bounded ropZ_sig Hbounds H0 H1 args := + let Heq := fresh "Heq" in + let Hbounds0 := fresh "Hbounds0" in + let Hbounds1 := fresh "Hbounds1" in + let Hbounds2 := fresh "Hbounds2" in + pose proof (proj2_sig ropZ_sig) as Heq; + cbv [interp_bexpr interp_uexpr interp_uexpr_FEToWire interp_uexpr_FEToZ interp_uexpr_WireToFE + curry_binop_fe25519W curry_unop_fe25519W curry_unop_wire_digitsW + curry_binop_fe25519 curry_unop_fe25519 curry_unop_wire_digits + uncurry_binop_fe25519W uncurry_unop_fe25519W uncurry_unop_wire_digitsW + uncurry_binop_fe25519 uncurry_unop_fe25519 uncurry_unop_wire_digits + ExprBinOpT ExprUnOpFEToWireT ExprUnOpT ExprUnOpFEToZT ExprUnOpWireToFET + interp_type_gen_rel_pointwise interp_type_gen_rel_pointwise] in *; + cbv zeta in *; + simpl @fe25519WToZ; simpl @wire_digitsWToZ; + rewrite <- Heq; clear Heq; + destruct Hbounds as [Heq Hbounds]; + change interp_op with (@Z.interp_op) in *; + change interp_base_type with (@Z.interp_base_type) in *; + rewrite <- Heq; clear Heq; + destruct Hbounds as [ Hbounds0 [Hbounds1 Hbounds2] ]; + pose proof (fun pf => Relations.uncurry_interp_type_rel_pointwise2_proj_from_option2 Word64.to_Z pf Hbounds2 Hbounds0) as Hbounds_left; + pose proof (fun pf => Relations.uncurry_interp_type_rel_pointwise2_proj1_from_option2 Relations.related_word64_boundsi' pf Hbounds1 Hbounds2) as Hbounds_right; + specialize_by repeat first [ progress intros + | reflexivity + | assumption + | progress destruct_head' base_type + | progress destruct_head' BoundedWord64.BoundedWord + | progress destruct_head' and + | progress repeat apply conj ]; + specialize (Hbounds_left args H0); + specialize (Hbounds_right args H0); + cbv beta in *; + lazymatch type of Hbounds_right with + | match ?e with _ => _ end + => lazymatch type of H1 with + | match ?e' with _ => _ end + => change e' with e in H1; destruct e eqn:?; [ | exfalso; assumption ] + end + end; + repeat match goal with x := _ |- _ => subst x end; + cbv [id + binop_args_to_bounded unop_args_to_bounded unopWireToFE_args_to_bounded + Relations.proj_eq_rel interp_flat_type_rel_pointwise2 SmartVarfMap interp_flat_type smart_interp_flat_map Application.all_binders_for fst snd BoundedWord64.to_word64' BoundedWord64.boundedWordToWord64 BoundedWord64.value Application.ApplyInterpedAll Application.fst_binder Application.snd_binder interp_flat_type_rel_pointwise2_gen_Prop Relations.related_word64_boundsi' Relations.related'_word64_bounds ZBounds.upper ZBounds.lower Application.remove_all_binders Word64.to_Z] in Hbounds_left, Hbounds_right; + match goal with + | [ |- fe25519WToZ ?x = _ /\ _ ] + => destruct x; destruct_head_hnf' prod + | [ |- wire_digitsWToZ ?x = _ /\ _ ] + => destruct x; destruct_head_hnf' prod + | [ |- _ = _ ] + => exact Hbounds_left + end; + change word64ToZ with Word64.word64ToZ in *; + (split; [ exact Hbounds_left | ]); + cbv [interp_flat_type] in *; + cbv [fst snd + binop_bounds_good unop_bounds_good unopFEToWire_bounds_good unopWireToFE_bounds_good unopFEToZ_bounds_good + ExprUnOp_bounds ExprBinOp_bounds ExprUnOpFEToWire_bounds ExprUnOpFEToZ_bounds ExprUnOpWireToFE_bounds] in H1; + destruct_head' ZBounds.bounds; + unfold_is_bounded_in H1; + simpl @fe25519WToZ; simpl @wire_digitsWToZ; + unfold_is_bounded; + destruct_head' and; + Z.ltb_to_lt; + change Word64.word64ToZ with word64ToZ in *; + repeat apply conj; Z.ltb_to_lt; try omega; try reflexivity. + +Local Opaque Interp. +Lemma ExprBinOp_correct_and_bounded + ropW op (ropZ_sig : rexpr_binop_sig op) + (Hbounds : correct_and_bounded_genT ropW ropZ_sig) + (H0 : forall xy + (xy := (eta_fe25519W (fst xy), eta_fe25519W (snd xy))) + (Hxy : is_bounded (fe25519WToZ (fst xy)) = true + /\ is_bounded (fe25519WToZ (snd xy)) = true), + let Hx := let (Hx, Hy) := Hxy in Hx in + let Hy := let (Hx, Hy) := Hxy in Hy in + let args := binop_args_to_bounded xy Hx Hy in + match LiftOption.of' + (ApplyInterpedAll (Interp (@BoundedWord64.interp_op) (MapInterp BoundedWord64.of_word64 ropW)) + (LiftOption.to' (Some args))) + with + | Some _ => True + | None => False + end) + (H1 : forall xy + (xy := (eta_fe25519W (fst xy), eta_fe25519W (snd xy))) + (Hxy : is_bounded (fe25519WToZ (fst xy)) = true + /\ is_bounded (fe25519WToZ (snd xy)) = true), + let Hx := let (Hx, Hy) := Hxy in Hx in + let Hy := let (Hx, Hy) := Hxy in Hy in + let args := binop_args_to_bounded (fst xy, snd xy) Hx Hy in + let x' := SmartVarfMap (fun _ : base_type => BoundedWord64.BoundedWordToBounds) args in + match LiftOption.of' + (ApplyInterpedAll (Interp (@ZBounds.interp_op) (MapInterp ZBounds.of_word64 ropW)) (LiftOption.to' (Some x'))) + with + | Some bounds => binop_bounds_good bounds = true + | None => False + end) + : binop_correct_and_bounded (MapInterp (fun _ x => x) ropW) op. +Proof. + intros x y Hx Hy. + pose x as x'; pose y as y'. + hnf in x, y; destruct_head' prod. + specialize (H0 (x', y') (conj Hx Hy)). + specialize (H1 (x', y') (conj Hx Hy)). + let args := constr:(binop_args_to_bounded (x', y') Hx Hy) in + t_correct_and_bounded ropZ_sig Hbounds H0 H1 args. +Qed. + +Lemma ExprUnOp_correct_and_bounded + ropW op (ropZ_sig : rexpr_unop_sig op) + (Hbounds : correct_and_bounded_genT ropW ropZ_sig) + (H0 : forall x + (x := eta_fe25519W x) + (Hx : is_bounded (fe25519WToZ x) = true), + let args := unop_args_to_bounded x Hx in + match LiftOption.of' + (ApplyInterpedAll (Interp (@BoundedWord64.interp_op) (MapInterp BoundedWord64.of_word64 ropW)) + (LiftOption.to' (Some args))) + with + | Some _ => True + | None => False + end) + (H1 : forall x + (x := eta_fe25519W x) + (Hx : is_bounded (fe25519WToZ x) = true), + let args := unop_args_to_bounded x Hx in + let x' := SmartVarfMap (fun _ : base_type => BoundedWord64.BoundedWordToBounds) args in + match LiftOption.of' + (ApplyInterpedAll (Interp (@ZBounds.interp_op) (MapInterp ZBounds.of_word64 ropW)) (LiftOption.to' (Some x'))) + with + | Some bounds => unop_bounds_good bounds = true + | None => False + end) + : unop_correct_and_bounded (MapInterp (fun _ x => x) ropW) op. +Proof. + intros x Hx. + pose x as x'. + hnf in x; destruct_head' prod. + specialize (H0 x' Hx). + specialize (H1 x' Hx). + let args := constr:(unop_args_to_bounded x' Hx) in + t_correct_and_bounded ropZ_sig Hbounds H0 H1 args. +Qed. + +Lemma ExprUnOpFEToWire_correct_and_bounded + ropW op (ropZ_sig : rexpr_unop_FEToWire_sig op) + (Hbounds : correct_and_bounded_genT ropW ropZ_sig) + (H0 : forall x + (x := eta_fe25519W x) + (Hx : is_bounded (fe25519WToZ x) = true), + let args := unop_args_to_bounded x Hx in + match LiftOption.of' + (ApplyInterpedAll (Interp (@BoundedWord64.interp_op) (MapInterp BoundedWord64.of_word64 ropW)) + (LiftOption.to' (Some args))) + with + | Some _ => True + | None => False + end) + (H1 : forall x + (x := eta_fe25519W x) + (Hx : is_bounded (fe25519WToZ x) = true), + let args := unop_args_to_bounded x Hx in + let x' := SmartVarfMap (fun _ : base_type => BoundedWord64.BoundedWordToBounds) args in + match LiftOption.of' + (ApplyInterpedAll (Interp (@ZBounds.interp_op) (MapInterp ZBounds.of_word64 ropW)) (LiftOption.to' (Some x'))) + with + | Some bounds => unopFEToWire_bounds_good bounds = true + | None => False + end) + : unop_FEToWire_correct_and_bounded (MapInterp (fun _ x => x) ropW) op. +Proof. + intros x Hx. + pose x as x'. + hnf in x; destruct_head' prod. + specialize (H0 x' Hx). + specialize (H1 x' Hx). + let args := constr:(unop_args_to_bounded x' Hx) in + t_correct_and_bounded ropZ_sig Hbounds H0 H1 args. +Qed. + +Lemma ExprUnOpWireToFE_correct_and_bounded + ropW op (ropZ_sig : rexpr_unop_WireToFE_sig op) + (Hbounds : correct_and_bounded_genT ropW ropZ_sig) + (H0 : forall x + (x := eta_wire_digitsW x) + (Hx : wire_digits_is_bounded (wire_digitsWToZ x) = true), + let args := unopWireToFE_args_to_bounded x Hx in + match LiftOption.of' + (ApplyInterpedAll (Interp (@BoundedWord64.interp_op) (MapInterp BoundedWord64.of_word64 ropW)) + (LiftOption.to' (Some args))) + with + | Some _ => True + | None => False + end) + (H1 : forall x + (x := eta_wire_digitsW x) + (Hx : wire_digits_is_bounded (wire_digitsWToZ x) = true), + let args := unopWireToFE_args_to_bounded x Hx in + let x' := SmartVarfMap (fun _ : base_type => BoundedWord64.BoundedWordToBounds) args in + match LiftOption.of' + (ApplyInterpedAll (Interp (@ZBounds.interp_op) (MapInterp ZBounds.of_word64 ropW)) (LiftOption.to' (Some x'))) + with + | Some bounds => unopWireToFE_bounds_good bounds = true + | None => False + end) + : unop_WireToFE_correct_and_bounded (MapInterp (fun _ x => x) ropW) op. +Proof. + intros x Hx. + pose x as x'. + hnf in x; destruct_head' prod. + specialize (H0 x' Hx). + specialize (H1 x' Hx). + let args := constr:(unopWireToFE_args_to_bounded x' Hx) in + t_correct_and_bounded ropZ_sig Hbounds H0 H1 args. +Qed. + +Lemma ExprUnOpFEToZ_correct_and_bounded + ropW op (ropZ_sig : rexpr_unop_FEToZ_sig op) + (Hbounds : correct_and_bounded_genT ropW ropZ_sig) + (H0 : forall x + (x := eta_fe25519W x) + (Hx : is_bounded (fe25519WToZ x) = true), + let args := unop_args_to_bounded x Hx in + match LiftOption.of' + (ApplyInterpedAll (Interp (@BoundedWord64.interp_op) (MapInterp BoundedWord64.of_word64 ropW)) + (LiftOption.to' (Some args))) + with + | Some _ => True + | None => False + end) + (H1 : forall x + (x := eta_fe25519W x) + (Hx : is_bounded (fe25519WToZ x) = true), + let args := unop_args_to_bounded x Hx in + let x' := SmartVarfMap (fun _ : base_type => BoundedWord64.BoundedWordToBounds) args in + match LiftOption.of' + (ApplyInterpedAll (Interp (@ZBounds.interp_op) (MapInterp ZBounds.of_word64 ropW)) (LiftOption.to' (Some x'))) + with + | Some bounds => unopFEToZ_bounds_good bounds = true + | None => False + end) + : unop_FEToZ_correct (MapInterp (fun _ x => x) ropW) op. +Proof. + intros x Hx. + pose x as x'. + hnf in x; destruct_head' prod. + specialize (H0 x' Hx). + specialize (H1 x' Hx). + let args := constr:(unop_args_to_bounded x' Hx) in + t_correct_and_bounded ropZ_sig Hbounds H0 H1 args. +Qed. + +Ltac rexpr_correct := + let ropW' := fresh in + let ropZ_sig := fresh in + intros ropW' ropZ_sig; + let wf_ropW := fresh "wf_ropW" in + assert (wf_ropW : Wf ropW') by (subst ropW' ropZ_sig; reflect_Wf base_type_eq_semidec_is_dec op_beq_bl); + cbv zeta; repeat apply conj; + [ vm_compute; reflexivity + | apply @InterpRelWf; + [ | apply @RelWfMapInterp, wf_ropW ].. ]; + auto with interp_related. + +Notation rword_of_Z rexprZ_sig := (MapInterp Word64.of_Z (proj1_sig rexprZ_sig)) (only parsing). + +Notation compute_bounds opW bounds + := (ApplyInterpedAll (Interp (@ZBounds.interp_op) (MapInterp (@ZBounds.of_word64) opW)) bounds) + (only parsing). + + +Module Export PrettyPrinting. + Inductive bounds_on := overflow | in_range (lower upper : Z). + + Definition ZBounds_to_bounds_on + := fun t : base_type + => match t return ZBounds.interp_base_type t -> match t with TZ => bounds_on end with + | TZ => fun x => match x with + | Some {| ZBounds.lower := l ; ZBounds.upper := u |} + => in_range l u + | None + => overflow + end + end. + + Fixpoint no_overflow {t} : interp_flat_type (fun t => match t with TZ => bounds_on end) t -> bool + := match t return interp_flat_type (fun t => match t with TZ => bounds_on end) t -> bool with + | Tbase TZ => fun v => match v with + | overflow => false + | in_range _ _ => true + end + | Prod x y => fun v => andb (@no_overflow _ (fst v)) (@no_overflow _ (snd v)) + end. + + (** This gives a slightly easier to read version of the bounds *) + Notation compute_bounds_for_display opW bounds + := (SmartVarfMap ZBounds_to_bounds_on (compute_bounds opW bounds)) (only parsing). + Notation sanity_check opW bounds + := (eq_refl true <: no_overflow (SmartVarfMap ZBounds_to_bounds_on (compute_bounds opW bounds)) = true) (only parsing). +End PrettyPrinting. diff --git a/src/Specific/GF25519Reflective/Reified.v b/src/Specific/GF25519Reflective/Reified.v new file mode 100644 index 000000000..98edc1282 --- /dev/null +++ b/src/Specific/GF25519Reflective/Reified.v @@ -0,0 +1,13 @@ +(** We split the reification up into separate files, one operation per + file, so that it can run in parallel. *) +Require Export Crypto.Specific.GF25519Reflective.Reified.Add. +Require Export Crypto.Specific.GF25519Reflective.Reified.CarryAdd. +Require Export Crypto.Specific.GF25519Reflective.Reified.Sub. +Require Export Crypto.Specific.GF25519Reflective.Reified.CarrySub. +Require Export Crypto.Specific.GF25519Reflective.Reified.Mul. +Require Export Crypto.Specific.GF25519Reflective.Reified.Opp. +Require Export Crypto.Specific.GF25519Reflective.Reified.CarryOpp. +Require Export Crypto.Specific.GF25519Reflective.Reified.Freeze. +Require Export Crypto.Specific.GF25519Reflective.Reified.GeModulus. +Require Export Crypto.Specific.GF25519Reflective.Reified.Pack. +Require Export Crypto.Specific.GF25519Reflective.Reified.Unpack. diff --git a/src/Specific/GF25519Reflective/Reified/Add.v b/src/Specific/GF25519Reflective/Reified/Add.v new file mode 100644 index 000000000..36357fcb7 --- /dev/null +++ b/src/Specific/GF25519Reflective/Reified/Add.v @@ -0,0 +1,11 @@ +Require Import Crypto.Specific.GF25519Reflective.Common. + +Definition raddZ_sig : rexpr_binop_sig add. Proof. reify_sig. Defined. +Definition raddW := Eval vm_compute in rword_of_Z raddZ_sig. +Lemma raddW_correct_and_bounded_gen : correct_and_bounded_genT raddW raddZ_sig. +Proof. rexpr_correct. Qed. +Definition radd_output_bounds := Eval vm_compute in compute_bounds raddW ExprBinOp_bounds. + +Local Open Scope string_scope. +Compute ("Add", compute_bounds_for_display raddW ExprBinOp_bounds). +(*Compute ("Add overflows? ", sanity_check raddW ExprBinOp_bounds).*) diff --git a/src/Specific/GF25519Reflective/Reified/CarryAdd.v b/src/Specific/GF25519Reflective/Reified/CarryAdd.v new file mode 100644 index 000000000..0ff563a8c --- /dev/null +++ b/src/Specific/GF25519Reflective/Reified/CarryAdd.v @@ -0,0 +1,16 @@ +Require Import Crypto.Specific.GF25519Reflective.Common. + +Definition rcarry_addZ_sig : rexpr_binop_sig carry_add. Proof. reify_sig. Defined. +Definition rcarry_addW := Eval vm_compute in rword_of_Z rcarry_addZ_sig. +Lemma rcarry_addW_correct_and_bounded_gen : correct_and_bounded_genT rcarry_addW rcarry_addZ_sig. +Proof. rexpr_correct. Qed. +Definition rcarry_add_output_bounds := Eval vm_compute in compute_bounds rcarry_addW ExprBinOp_bounds. +Local Obligation Tactic := intros; vm_compute; constructor. +Program Definition rcarry_addW_correct_and_bounded + := ExprBinOp_correct_and_bounded + rcarry_addW carry_add rcarry_addZ_sig rcarry_addW_correct_and_bounded_gen + _ _. + +Local Open Scope string_scope. +Compute ("Carry_Add", compute_bounds_for_display rcarry_addW ExprBinOp_bounds). +(*Compute ("Carry_Add overflows? ", sanity_check rcarry_addW ExprBinOp_bounds).*) diff --git a/src/Specific/GF25519Reflective/Reified/CarryOpp.v b/src/Specific/GF25519Reflective/Reified/CarryOpp.v new file mode 100644 index 000000000..4c21fbeb8 --- /dev/null +++ b/src/Specific/GF25519Reflective/Reified/CarryOpp.v @@ -0,0 +1,16 @@ +Require Import Crypto.Specific.GF25519Reflective.Common. + +Definition rcarry_oppZ_sig : rexpr_unop_sig carry_opp. Proof. reify_sig. Defined. +Definition rcarry_oppW := Eval vm_compute in rword_of_Z rcarry_oppZ_sig. +Lemma rcarry_oppW_correct_and_bounded_gen : correct_and_bounded_genT rcarry_oppW rcarry_oppZ_sig. +Proof. rexpr_correct. Qed. +Definition rcarry_opp_output_bounds := Eval vm_compute in compute_bounds rcarry_oppW ExprUnOp_bounds. +Local Obligation Tactic := intros; vm_compute; constructor. +Program Definition rcarry_oppW_correct_and_bounded + := ExprUnOp_correct_and_bounded + rcarry_oppW carry_opp rcarry_oppZ_sig rcarry_oppW_correct_and_bounded_gen + _ _. + +Local Open Scope string_scope. +Compute ("Carry_Opp", compute_bounds_for_display rcarry_oppW ExprUnOp_bounds). +(*Compute ("Carry_Opp overflows? ", sanity_check rcarry_oppW ExprUnOp_bounds).*) diff --git a/src/Specific/GF25519Reflective/Reified/CarrySub.v b/src/Specific/GF25519Reflective/Reified/CarrySub.v new file mode 100644 index 000000000..3acfb1f45 --- /dev/null +++ b/src/Specific/GF25519Reflective/Reified/CarrySub.v @@ -0,0 +1,16 @@ +Require Import Crypto.Specific.GF25519Reflective.Common. + +Definition rcarry_subZ_sig : rexpr_binop_sig carry_sub. Proof. reify_sig. Defined. +Definition rcarry_subW := Eval vm_compute in rword_of_Z rcarry_subZ_sig. +Lemma rcarry_subW_correct_and_bounded_gen : correct_and_bounded_genT rcarry_subW rcarry_subZ_sig. +Proof. rexpr_correct. Qed. +Definition rcarry_sub_output_bounds := Eval vm_compute in compute_bounds rcarry_subW ExprBinOp_bounds. +Local Obligation Tactic := intros; vm_compute; constructor. +Program Definition rcarry_subW_correct_and_bounded + := ExprBinOp_correct_and_bounded + rcarry_subW carry_sub rcarry_subZ_sig rcarry_subW_correct_and_bounded_gen + _ _. + +Local Open Scope string_scope. +Compute ("Carry_Sub", compute_bounds_for_display rcarry_subW ExprBinOp_bounds). +(*Compute ("Carry_Sub overflows? ", sanity_check rcarry_subW ExprBinOp_bounds).*) diff --git a/src/Specific/GF25519Reflective/Reified/Freeze.v b/src/Specific/GF25519Reflective/Reified/Freeze.v new file mode 100644 index 000000000..e3ecc62c8 --- /dev/null +++ b/src/Specific/GF25519Reflective/Reified/Freeze.v @@ -0,0 +1,18 @@ +Require Import Crypto.Specific.GF25519Reflective.Common. + +Definition rfreezeZ_sig : rexpr_unop_sig freeze. Proof. reify_sig. Defined. +Definition rfreezeW := Eval vm_compute in rword_of_Z rfreezeZ_sig. +Lemma rfreezeW_correct_and_bounded_gen : correct_and_bounded_genT rfreezeW rfreezeZ_sig. +Proof. rexpr_correct. Qed. +Definition rfreeze_output_bounds := Eval vm_compute in compute_bounds rfreezeW ExprUnOp_bounds. +Local Obligation Tactic := intros; vm_compute; constructor. +Axiom proof_admitted : False. +(** XXX TODO: Fix bounds analysis on freeze *) +Definition rfreezeW_correct_and_bounded + := ExprUnOp_correct_and_bounded + rfreezeW freeze rfreezeZ_sig rfreezeW_correct_and_bounded_gen + match proof_admitted with end match proof_admitted with end. + +Local Open Scope string_scope. +Compute ("Freeze", compute_bounds_for_display rfreezeW ExprUnOp_bounds). +(*Compute ("Freeze overflows? ", sanity_check rfreezeW ExprUnOp_bounds).*) diff --git a/src/Specific/GF25519Reflective/Reified/GeModulus.v b/src/Specific/GF25519Reflective/Reified/GeModulus.v new file mode 100644 index 000000000..73ee6904a --- /dev/null +++ b/src/Specific/GF25519Reflective/Reified/GeModulus.v @@ -0,0 +1,16 @@ +Require Import Crypto.Specific.GF25519Reflective.Common. + +Definition rge_modulusZ_sig : rexpr_unop_FEToZ_sig ge_modulus. Proof. reify_sig. Defined. +Definition rge_modulusW := Eval vm_compute in rword_of_Z rge_modulusZ_sig. +Lemma rge_modulusW_correct_and_bounded_gen : correct_and_bounded_genT rge_modulusW rge_modulusZ_sig. +Proof. rexpr_correct. Qed. +Definition rge_modulus_output_bounds := Eval vm_compute in compute_bounds rge_modulusW ExprUnOpFEToZ_bounds. +Local Obligation Tactic := intros; vm_compute; constructor. +Program Definition rge_modulusW_correct_and_bounded + := ExprUnOpFEToZ_correct_and_bounded + rge_modulusW ge_modulus rge_modulusZ_sig rge_modulusW_correct_and_bounded_gen + _ _. + +Local Open Scope string_scope. +Compute ("Ge_Modulus", compute_bounds_for_display rge_modulusW ExprUnOpFEToZ_bounds). +(*Compute ("Ge_Modulus overflows? ", sanity_check rge_modulusW ExprUnOpFEToZ_bounds).*) diff --git a/src/Specific/GF25519Reflective/Reified/Mul.v b/src/Specific/GF25519Reflective/Reified/Mul.v new file mode 100644 index 000000000..a206f02a1 --- /dev/null +++ b/src/Specific/GF25519Reflective/Reified/Mul.v @@ -0,0 +1,16 @@ +Require Import Crypto.Specific.GF25519Reflective.Common. + +Definition rmulZ_sig : rexpr_binop_sig mul. Proof. reify_sig. Defined. +Definition rmulW := Eval vm_compute in rword_of_Z rmulZ_sig. +Lemma rmulW_correct_and_bounded_gen : correct_and_bounded_genT rmulW rmulZ_sig. +Proof. rexpr_correct. Qed. +Definition rmul_output_bounds := Eval vm_compute in compute_bounds rmulW ExprBinOp_bounds. +Local Obligation Tactic := intros; vm_compute; constructor. +Program Definition rmulW_correct_and_bounded + := ExprBinOp_correct_and_bounded + rmulW mul rmulZ_sig rmulW_correct_and_bounded_gen + _ _. + +Local Open Scope string_scope. +Compute ("Mul", compute_bounds_for_display rmulW ExprBinOp_bounds). +(*Compute ("Mul overflows? ", sanity_check rmulW ExprBinOp_bounds).*) diff --git a/src/Specific/GF25519Reflective/Reified/Opp.v b/src/Specific/GF25519Reflective/Reified/Opp.v new file mode 100644 index 000000000..907771b14 --- /dev/null +++ b/src/Specific/GF25519Reflective/Reified/Opp.v @@ -0,0 +1,11 @@ +Require Import Crypto.Specific.GF25519Reflective.Common. + +Definition roppZ_sig : rexpr_unop_sig opp. Proof. reify_sig. Defined. +Definition roppW := Eval vm_compute in rword_of_Z roppZ_sig. +Lemma roppW_correct_and_bounded_gen : correct_and_bounded_genT roppW roppZ_sig. +Proof. rexpr_correct. Qed. +Definition ropp_output_bounds := Eval vm_compute in compute_bounds roppW ExprUnOp_bounds. + +Local Open Scope string_scope. +Compute ("Opp", compute_bounds_for_display roppW ExprUnOp_bounds). +(*Compute ("Opp overflows? ", sanity_check roppW ExprUnOp_bounds).*) diff --git a/src/Specific/GF25519Reflective/Reified/Pack.v b/src/Specific/GF25519Reflective/Reified/Pack.v new file mode 100644 index 000000000..a7cf4fc13 --- /dev/null +++ b/src/Specific/GF25519Reflective/Reified/Pack.v @@ -0,0 +1,16 @@ +Require Import Crypto.Specific.GF25519Reflective.Common. + +Definition rpackZ_sig : rexpr_unop_FEToWire_sig pack. Proof. reify_sig. Defined. +Definition rpackW := Eval vm_compute in rword_of_Z rpackZ_sig. +Lemma rpackW_correct_and_bounded_gen : correct_and_bounded_genT rpackW rpackZ_sig. +Proof. rexpr_correct. Qed. +Definition rpack_output_bounds := Eval vm_compute in compute_bounds rpackW ExprUnOpFEToWire_bounds. +Local Obligation Tactic := intros; vm_compute; constructor. +Program Definition rpackW_correct_and_bounded + := ExprUnOpFEToWire_correct_and_bounded + rpackW pack rpackZ_sig rpackW_correct_and_bounded_gen + _ _. + +Local Open Scope string_scope. +Compute ("Pack", compute_bounds_for_display rpackW ExprUnOpFEToWire_bounds). +(*Compute ("Pack overflows? ", sanity_check rpackW ExprUnOpFEToWire_bounds).*) diff --git a/src/Specific/GF25519Reflective/Reified/Sub.v b/src/Specific/GF25519Reflective/Reified/Sub.v new file mode 100644 index 000000000..9b684248d --- /dev/null +++ b/src/Specific/GF25519Reflective/Reified/Sub.v @@ -0,0 +1,11 @@ +Require Import Crypto.Specific.GF25519Reflective.Common. + +Definition rsubZ_sig : rexpr_binop_sig sub. Proof. reify_sig. Defined. +Definition rsubW := Eval vm_compute in rword_of_Z rsubZ_sig. +Lemma rsubW_correct_and_bounded_gen : correct_and_bounded_genT rsubW rsubZ_sig. +Proof. rexpr_correct. Qed. +Definition rsub_output_bounds := Eval vm_compute in compute_bounds rsubW ExprBinOp_bounds. + +Local Open Scope string_scope. +Compute ("Sub", compute_bounds_for_display rsubW ExprBinOp_bounds). +(*Compute ("Sub overflows? ", sanity_check rsubW ExprBinOp_bounds).*) diff --git a/src/Specific/GF25519Reflective/Reified/Unpack.v b/src/Specific/GF25519Reflective/Reified/Unpack.v new file mode 100644 index 000000000..027eedf39 --- /dev/null +++ b/src/Specific/GF25519Reflective/Reified/Unpack.v @@ -0,0 +1,16 @@ +Require Import Crypto.Specific.GF25519Reflective.Common. + +Definition runpackZ_sig : rexpr_unop_WireToFE_sig unpack. Proof. reify_sig. Defined. +Definition runpackW := Eval vm_compute in rword_of_Z runpackZ_sig. +Lemma runpackW_correct_and_bounded_gen : correct_and_bounded_genT runpackW runpackZ_sig. +Proof. rexpr_correct. Qed. +Definition runpack_output_bounds := Eval vm_compute in compute_bounds runpackW ExprUnOpWireToFE_bounds. +Local Obligation Tactic := intros; vm_compute; constructor. +Program Definition runpackW_correct_and_bounded + := ExprUnOpWireToFE_correct_and_bounded + runpackW unpack runpackZ_sig runpackW_correct_and_bounded_gen + _ _. + +Local Open Scope string_scope. +Compute ("Unpack", compute_bounds_for_display runpackW ExprUnOpWireToFE_bounds). +(*Compute ("Unpack overflows? ", sanity_check runpackW ExprUnOpWireToFE_bounds).*) diff --git a/src/Specific/GF25519Reflective/Reified/rebuild-reified.py b/src/Specific/GF25519Reflective/Reified/rebuild-reified.py new file mode 100755 index 000000000..76ac2c91b --- /dev/null +++ b/src/Specific/GF25519Reflective/Reified/rebuild-reified.py @@ -0,0 +1,39 @@ +#!/usr/bin/env python2.7 +from __future__ import with_statement + +for name, opkind in ([(name, 'BinOp') for name in ('Add', 'Carry_Add', 'Sub', 'Carry_Sub', 'Mul')] + + [(name, 'UnOp') for name in ('Opp', 'Carry_Opp', 'Freeze')] + + [('Ge_Modulus', 'UnOp_FEToZ'), ('Pack', 'UnOp_FEToWire'), ('Unpack', 'UnOp_WireToFE')]): + lname = name.lower() + lopkind = opkind.replace('UnOp', 'unop').replace('BinOp', 'binop') + uopkind = opkind.replace('_', '') + extra = '' + if name in ('Carry_Add', 'Carry_Sub', 'Mul', 'Carry_Opp', 'Pack', 'Unpack', 'Ge_Modulus'): + extra = r"""Local Obligation Tactic := intros; vm_compute; constructor. +Program Definition r%(lname)sW_correct_and_bounded + := Expr%(uopkind)s_correct_and_bounded + r%(lname)sW %(lname)s r%(lname)sZ_sig r%(lname)sW_correct_and_bounded_gen + _ _. +""" % locals() + elif name == 'Freeze': + extra = r"""Local Obligation Tactic := intros; vm_compute; constructor. +Axiom proof_admitted : False. +(** XXX TODO: Fix bounds analysis on freeze *) +Definition r%(lname)sW_correct_and_bounded + := Expr%(uopkind)s_correct_and_bounded + r%(lname)sW %(lname)s r%(lname)sZ_sig r%(lname)sW_correct_and_bounded_gen + match proof_admitted with end match proof_admitted with end. +""" % locals() + with open(name.replace('_', '') + '.v', 'w') as f: + f.write(r"""Require Import Crypto.Specific.GF25519Reflective.Common. + +Definition r%(lname)sZ_sig : rexpr_%(lopkind)s_sig %(lname)s. Proof. reify_sig. Defined. +Definition r%(lname)sW := Eval vm_compute in rword_of_Z r%(lname)sZ_sig. +Lemma r%(lname)sW_correct_and_bounded_gen : correct_and_bounded_genT r%(lname)sW r%(lname)sZ_sig. +Proof. rexpr_correct. Qed. +Definition r%(lname)s_output_bounds := Eval vm_compute in compute_bounds r%(lname)sW Expr%(uopkind)s_bounds. +%(extra)s +Local Open Scope string_scope. +Compute ("%(name)s", compute_bounds_for_display r%(lname)sW Expr%(uopkind)s_bounds). +(*Compute ("%(name)s overflows? ", sanity_check r%(lname)sW Expr%(uopkind)s_bounds).*) +""" % locals()) diff --git a/src/Test/Curve25519SpecTestVectors.v b/src/Test/Curve25519SpecTestVectors.v index 511998d48..15ca468c1 100644 --- a/src/Test/Curve25519SpecTestVectors.v +++ b/src/Test/Curve25519SpecTestVectors.v @@ -6,7 +6,7 @@ Definition F := F (2^255 - 19). Definition a : F := F.of_Z _ 486662. Definition a24 : F := ((a - F.of_Z _ 2) / F.of_Z _ 4)%F. Definition cswap {T} (swap:bool) (a b:T) := if swap then (b, a) else (a, b). -Definition monty : BinNums.N -> F -> F := @MxDH.montladder F F.zero F.one F.add F.sub F.mul F.div a24 BinNums.N BinNat.N.testbit_nat cswap 255. +Definition monty s : F -> F := @MxDH.montladder F F.zero F.one F.add F.sub F.mul F.inv a24 cswap 255 (BinNat.N.testbit_nat s). Example one_basepoint : F.to_Z (monty 1 (F.of_Z _ 9)) = 9%Z. Proof. vm_decide_no_check. Qed. diff --git a/src/Util/Bool.v b/src/Util/Bool.v index 7b94c503e..a59cf53f0 100644 --- a/src/Util/Bool.v +++ b/src/Util/Bool.v @@ -51,3 +51,9 @@ Definition pull_bool_if_dep {A B} (f : forall b : bool, A b -> B b) (b : bool) ( Definition pull_bool_if {A B} (f : A -> B) (b : bool) (x : A) (y : A) : (if b then f x else f y) = f (if b then x else y) := @pull_bool_if_dep (fun _ => A) (fun _ => B) (fun _ => f) b x y. + +Definition reflect_iff_gen {P b} : reflect P b -> forall b' : bool, (if b' then P else ~P) <-> b = b'. +Proof. + intros H; apply reflect_iff in H; intro b'; destruct b, b'; + intuition congruence. +Qed. diff --git a/src/Util/Decidable.v b/src/Util/Decidable.v index a6954663b..b01fe3627 100644 --- a/src/Util/Decidable.v +++ b/src/Util/Decidable.v @@ -111,6 +111,11 @@ Global Instance dec_le_Z : DecidableRel BinInt.Z.le := ZArith_dec.Z_le_dec. Global Instance dec_gt_Z : DecidableRel BinInt.Z.gt := ZArith_dec.Z_gt_dec. Global Instance dec_ge_Z : DecidableRel BinInt.Z.ge := ZArith_dec.Z_ge_dec. +Global Instance dec_match_pair {A B} {P : A -> B -> Prop} {x : A * B} + {HD : Decidable (P (fst x) (snd x))} + : Decidable (let '(a, b) := x in P a b) | 1. +Proof. destruct x; assumption. Defined. + Lemma not_not P {d:Decidable P} : not (not P) <-> P. Proof. destruct (dec P); intuition. Qed. diff --git a/src/Util/HList.v b/src/Util/HList.v new file mode 100644 index 000000000..aacefe8f3 --- /dev/null +++ b/src/Util/HList.v @@ -0,0 +1,96 @@ +Require Import Coq.Classes.Morphisms. +Require Import Coq.Relations.Relation_Definitions. +Require Import Coq.Lists.List. +Require Import Crypto.Util.Decidable. +Require Import Crypto.Util.ListUtil. +Require Import Crypto.Util.Tuple. +Require Export Crypto.Util.FixCoqMistakes. + +Fixpoint hlist' T n (f : T -> Type) : tuple' T n -> Type := + match n return tuple' _ n -> Type with + | 0 => fun T => f T + | S n' => fun Ts => (hlist' T n' f (fst Ts) * f (snd Ts))%type + end. +Global Arguments hlist' {T n} f _. + +Definition hlist {T n} (f : T -> Type) : forall (Ts : tuple T n), Type := + match n return tuple _ n -> Type with + | 0 => fun _ => unit + | S n' => @hlist' T n' f + end. + +Fixpoint const' {T n F xs} (v : forall x, F x) : @hlist' T n F xs + := match n return forall xs, @hlist' T n F xs with + | 0 => fun _ => v _ + | S n' => fun _ => (@const' T n' F _ v, v _) + end xs. +Definition const {T n F xs} (v : forall x, F x) : @hlist T n F xs + := match n return forall xs, @hlist T n F xs with + | 0 => fun _ => tt + | S n' => fun xs => @const' T n' F xs v + end xs. + +(* tuple map *) +Fixpoint mapt' {n A F B} (f : forall x : A, F x -> B) : forall {ts : tuple' A n}, hlist' F ts -> tuple' B n + := match n return forall ts : tuple' A n, hlist' F ts -> tuple' B n with + | 0 => fun ts v => f _ v + | S n' => fun ts v => (@mapt' n' A F B f _ (fst v), f _ (snd v)) + end. +Definition mapt {n A F B} (f : forall x : A, F x -> B) + : forall {ts : tuple A n}, hlist F ts -> tuple B n + := match n return forall ts : tuple A n, hlist F ts -> tuple B n with + | 0 => fun ts v => tt + | S n' => @mapt' n' A F B f + end. + +Lemma map'_mapt' {n A F B C} (g : B -> C) (f : forall x : A, F x -> B) + {ts : tuple' A n} (ls : hlist' F ts) + : Tuple.map (n:=S n) g (mapt' f ls) = mapt' (fun x v => g (f x v)) ls. +Proof. + induction n as [|n IHn]; [ reflexivity | ]. + { simpl @mapt' in *. + rewrite <- IHn. + rewrite Tuple.map_S; reflexivity. } +Qed. + +Lemma map_mapt {n A F B C} (g : B -> C) (f : forall x : A, F x -> B) + {ts : tuple A n} (ls : hlist F ts) + : Tuple.map g (mapt f ls) = mapt (fun x v => g (f x v)) ls. +Proof. + destruct n as [|n]; [ reflexivity | ]. + apply map'_mapt'. +Qed. + +Lemma map_is_mapt {n A F B} (f : A -> B) {ts : tuple A n} (ls : hlist F ts) + : Tuple.map f ts = mapt (fun x _ => f x) ls. +Proof. + destruct n as [|n]; [ reflexivity | ]. + induction n as [|n IHn]; [ reflexivity | ]. + { unfold mapt in *; simpl @mapt' in *. + rewrite <- IHn; clear IHn. + rewrite <- (@Tuple.map_S n _ _ f); destruct ts; reflexivity. } +Qed. + +Lemma map_is_mapt' {n A F B} (f : A -> B) {ts : tuple A (S n)} (ls : hlist' F ts) + : Tuple.map f ts = mapt' (fun x _ => f x) ls. +Proof. apply (@map_is_mapt (S n)). Qed. + + +Lemma hlist'_impl {n A F G} (xs:tuple' A n) + : (hlist' (fun x => F x -> G x) xs) -> (hlist' F xs -> hlist' G xs). +Proof. + induction n; simpl in *; intuition. +Defined. + +Lemma hlist_impl {n A F G} (xs:tuple A n) + : (hlist (fun x => F x -> G x) xs) -> (hlist F xs -> hlist G xs). +Proof. + destruct n; [ constructor | apply hlist'_impl ]. +Defined. + +Module Tuple. + Lemma map_id_ext {n A} (f : A -> A) (xs:tuple A n) + : hlist (fun x => f x = x) xs -> Tuple.map f xs = xs. + Proof. + Admitted. +End Tuple. diff --git a/src/Util/IffT.v b/src/Util/IffT.v new file mode 100644 index 000000000..f4eaa9d53 --- /dev/null +++ b/src/Util/IffT.v @@ -0,0 +1,10 @@ +Require Import Coq.Classes.RelationClasses. +Notation iffT A B := (((A -> B) * (B -> A)))%type. +Notation iffTp := (fun A B => inhabited (iffT A B)). + +Global Instance iffTp_Reflexive : Reflexive iffTp | 1. +Proof. repeat constructor; intro; assumption. Defined. +Global Instance iffTp_Symmetric : Symmetric iffTp | 1. +Proof. repeat (intros [?] || intro); constructor; tauto. Defined. +Global Instance iffTp_Transitive : Transitive iffTp | 1. +Proof. repeat (intros [?] || intro); constructor; tauto. Defined. diff --git a/src/Util/IterAssocOp.v b/src/Util/IterAssocOp.v index 773fea8fd..d630698e7 100644 --- a/src/Util/IterAssocOp.v +++ b/src/Util/IterAssocOp.v @@ -243,7 +243,7 @@ Proof. | _ => solve [ reflexivity | congruence | eauto 99 ] | _ => progress eapply (Proper_funexp (R:=(fun nt NT => Logic.eq (fst nt) (fst NT) /\ R (snd nt) (snd NT)))) | _ => progress eapply Proper_test_and_op - | _ => progress eapply conj + | _ => progress split | _ => progress (cbv [fst snd pointwise_relation respectful] in * ) | _ => intro end. diff --git a/src/Util/PartiallyReifiedProp.v b/src/Util/PartiallyReifiedProp.v new file mode 100644 index 000000000..ef1567bd8 --- /dev/null +++ b/src/Util/PartiallyReifiedProp.v @@ -0,0 +1,165 @@ +(** * Propositions with a distinguished representation of [True], [False], [and], [or], and [impl] *) +(** This allows for something between [bool] and [Prop], where we can + computationally reduce things like [True /\ True], but can still + express equality of types. *) +Require Import Coq.Setoids.Setoid. +Require Import Coq.Program.Tactics. +Require Import Crypto.Util.Notations. +Require Import Crypto.Util.Tactics. + +Delimit Scope reified_prop_scope with reified_prop. +Inductive reified_Prop := rTrue | rFalse | rAnd (x y : reified_Prop) | rOr (x y : reified_Prop) | rImpl (x y : reified_Prop) | rForall {T} (f : T -> reified_Prop) | rEq {T} (x y : T) | inject (_ : Prop). +Bind Scope reified_prop_scope with reified_Prop. + +Fixpoint to_prop (x : reified_Prop) : Prop + := match x with + | rTrue => True + | rFalse => False + | rAnd x y => to_prop x /\ to_prop y + | rOr x y => to_prop x \/ to_prop y + | rImpl x y => to_prop x -> to_prop y + | @rForall _ f => forall x, to_prop (f x) + | @rEq _ x y => x = y + | inject x => x + end. + +Coercion reified_Prop_of_bool (x : bool) : reified_Prop + := if x then rTrue else rFalse. + +Definition and_reified_Prop (x y : reified_Prop) : reified_Prop + := match x, y with + | rTrue, y => y + | x, rTrue => x + | rFalse, y => rFalse + | x, rFalse => rFalse + | rEq T a b, rEq T' a' b' => rEq (a, a') (b, b') + | x', y' => rAnd x' y' + end. +Definition or_reified_Prop (x y : reified_Prop) : reified_Prop + := match x, y with + | rTrue, y => rTrue + | x, rTrue => rTrue + | rFalse, y => y + | x, rFalse => x + | x', y' => rOr x' y' + end. +Definition impl_reified_Prop (x y : reified_Prop) : reified_Prop + := match x, y with + | rTrue, y => y + | x, rTrue => rTrue + | rFalse, y => rTrue + | rImpl x rFalse, rFalse => x + | x', y' => rImpl x' y' + end. + +Infix "/\" := and_reified_Prop : reified_prop_scope. +Infix "\/" := or_reified_Prop : reified_prop_scope. +Infix "->" := impl_reified_Prop : reified_prop_scope. +Infix "=" := rEq : reified_prop_scope. +Notation "~ P" := (P -> rFalse)%reified_prop : reified_prop_scope. +Notation "∀ x .. y , P" := (rForall (fun x => .. (rForall (fun y => P%reified_prop)) .. )) + (at level 200, x binder, y binder, right associativity) : reified_prop_scope. + +Definition reified_Prop_eq (x y : reified_Prop) + := match x, y with + | rTrue, _ => y = rTrue + | rFalse, _ => y = rFalse + | rAnd x0 x1, rAnd y0 y1 + => x0 = y0 /\ x1 = y1 + | rAnd _ _, _ => False + | rOr x0 x1, rOr y0 y1 + => x0 = y0 /\ x1 = y1 + | rOr _ _, _ => False + | rImpl x0 x1, rImpl y0 y1 + => x0 = y0 /\ x1 = y1 + | rImpl _ _, _ => False + | @rForall Tx fx, @rForall Ty fy + => exists pf : Tx = Ty, + forall x, fx x = fy (eq_rect _ (fun t => t) x _ pf) + | rForall _ _, _ => False + | @rEq Tx x0 x1, @rEq Ty y0 y1 + => exists pf : Tx = Ty, + eq_rect _ (fun t => t) x0 _ pf = y0 + /\ eq_rect _ (fun t => t) x1 _ pf = y1 + | rEq _ _ _, _ => False + | inject x, inject y => x = y + | inject _, _ => False + end. + +Section rel. + Local Ltac t := + cbv; + repeat (match goal with |- forall x, _ => intro end (* work around broken Ltac [match] in 8.4 that diverges on things under binders *) + || break_match + || break_match_hyps + || intro + || (simpl in * ) + || intuition try congruence + || (exists eq_refl) + || eauto + || (subst * ) + || apply conj + || destruct_head' ex + || solve [ apply reflexivity + | apply symmetry; eassumption + | eapply transitivity; eassumption ] ). + + Global Instance Reflexive_reified_Prop_eq : Reflexive reified_Prop_eq. + Proof. t. Qed. + Global Instance Symmetric_reified_Prop_eq : Symmetric reified_Prop_eq. + Proof. t. Qed. + Global Instance Transitive_reified_Prop_eq : Transitive reified_Prop_eq. + Proof. t. Qed. + Global Instance Equivalence_reified_Prop_eq : Equivalence reified_Prop_eq. + Proof. split; exact _. Qed. +End rel. + +Definition reified_Prop_leq_to_eq (x y : reified_Prop) : x = y -> reified_Prop_eq x y. +Proof. intro; subst; simpl; reflexivity. Qed. + +Ltac inversion_reified_Prop_step := + let do_on H := apply reified_Prop_leq_to_eq in H; unfold reified_Prop_eq in H in + match goal with + | [ H : False |- _ ] => solve [ destruct H ] + | [ H : (_ = _ :> reified_Prop) /\ (_ = _ :> reified_Prop) |- _ ] => destruct H + | [ H : ?x = ?x :> reified_Prop |- _ ] => clear H + | [ H : exists pf : _ = _ :> Type, forall x, _ = _ :> reified_Prop |- _ ] + => destruct H as [? H]; subst; simpl @eq_rect in H + | [ H : ?x = _ :> reified_Prop |- _ ] => is_var x; subst x + | [ H : _ = ?y :> reified_Prop |- _ ] => is_var y; subst y + | [ H : rTrue = rFalse |- _ ] => solve [ inversion H ] + | [ H : rFalse = rTrue |- _ ] => solve [ inversion H ] + | [ H : rTrue = _ |- _ ] => do_on H; progress subst + | [ H : rFalse = _ |- _ ] => do_on H; progress subst + | [ H : rAnd _ _ = _ |- _ ] => do_on H + | [ H : rOr _ _ = _ |- _ ] => do_on H + | [ H : rImpl _ _ = _ |- _ ] => do_on H + | [ H : rForall _ = _ |- _ ] => do_on H + | [ H : rEq _ _ = _ |- _ ] => do_on H + | [ H : inject _ = _ |- _ ] => do_on H + end. +Ltac inversion_reified_Prop := repeat inversion_reified_Prop_step. + +Lemma to_prop_and_reified_Prop x y : to_prop (x /\ y) <-> (to_prop x /\ to_prop y). +Proof. + destruct x, y; simpl; try tauto. + { split; intro H; inversion H; subst; repeat split. } +Qed. + +(** Remove all possibly false terms in a reified prop *) +Fixpoint trueify (p : reified_Prop) : reified_Prop + := match p with + | rTrue => rTrue + | rFalse => rTrue + | rAnd x y => rAnd (trueify x) (trueify y) + | rOr x y => rOr (trueify x) (trueify y) + | rImpl x y => rImpl x (trueify y) + | rForall T f => rForall (fun x => trueify (f x)) + | rEq T x y => rEq x x + | inject x => inject True + end. + +Lemma trueify_true : forall p, to_prop (trueify p). +Proof. + induction p; simpl; auto. +Qed. diff --git a/src/Util/Prod.v b/src/Util/Prod.v index b83aea68f..bcd9404a6 100644 --- a/src/Util/Prod.v +++ b/src/Util/Prod.v @@ -5,6 +5,8 @@ between two such pairs, or when we want such an equality, we have a systematic way of reducing such equalities to equalities at simpler types. *) +Require Import Coq.Classes.Morphisms. +Require Import Crypto.Util.IffT. Require Import Crypto.Util.Equality. Require Import Crypto.Util.GlobalSettings. @@ -68,6 +70,31 @@ Section prod. Definition path_prod_ind {A B u v} (P : u = v :> @prod A B -> Prop) := path_prod_rec P. End prod. +Lemma prod_iff_and (A B : Prop) : (A /\ B) <-> (A * B). +Proof. repeat (intros [? ?] || intro || split); assumption. Defined. + +Global Instance iff_prod_Proper + : Proper (iff ==> iff ==> iff) (fun A B => prod A B). +Proof. repeat intro; tauto. Defined. +Global Instance iff_iffTp_prod_Proper + : Proper (iff ==> iffTp ==> iffTp) (fun A B => prod A B) | 1. +Proof. + intros ?? [?] ?? [?]; constructor; tauto. +Defined. +Global Instance iffTp_iff_prod_Proper + : Proper (iffTp ==> iff ==> iffTp) (fun A B => prod A B) | 1. +Proof. + intros ?? [?] ?? [?]; constructor; tauto. +Defined. +Global Instance iffTp_iffTp_prod_Proper + : Proper (iffTp ==> iffTp ==> iffTp) (fun A B => prod A B) | 1. +Proof. + intros ?? [?] ?? [?]; constructor; tauto. +Defined. +Hint Extern 2 (Proper _ prod) => apply iffTp_iffTp_prod_Proper : typeclass_instances. +Hint Extern 2 (Proper _ (fun A => prod A)) => refine iff_iffTp_prod_Proper : typeclass_instances. +Hint Extern 2 (Proper _ (fun A B => prod A B)) => refine iff_prod_Proper : typeclass_instances. + (** ** Useful Tactics *) (** *** [inversion_prod] *) Ltac simpl_proj_pair_in H := diff --git a/src/Util/Tactics.v b/src/Util/Tactics.v index 2f9f6c59f..128fdcfd0 100644 --- a/src/Util/Tactics.v +++ b/src/Util/Tactics.v @@ -109,6 +109,12 @@ Ltac break_match_hyps_when_head_step T := constr_eq T T'). Ltac break_match_when_head T := repeat break_match_when_head_step T. Ltac break_match_hyps_when_head T := repeat break_match_hyps_when_head_step T. +Ltac break_innermost_match_step := + break_match_step ltac:(fun v => lazymatch v with + | appcontext[match _ with _ => _ end] => fail + | _ => idtac + end). +Ltac break_innermost_match := repeat break_innermost_match_step. Ltac free_in x y := idtac; @@ -481,3 +487,35 @@ Ltac idtac_context := idtac_goal; lazymatch goal with |- ?G => idtac "Context:" G end; fail). + +(** Destruct the convoy pattern ([match e as x return x = e -> _ with _ => _ end eq_refl] *) +Ltac convoy_destruct_gen T change_in := + let e' := fresh in + let H' := fresh in + match T with + | context G[?f eq_refl] + => match f with + | match ?e with _ => _ end + => pose e as e'; + match f with + | context F[e] + => let F' := context F[e'] in + first [ pose (eq_refl : e = e') as H'; + let G' := context G[F' H'] in + change_in G'; + clearbody H' e' + | pose (eq_refl : e' = e) as H'; + let G' := context G[F' H'] in + change_in G'; + clearbody H' e' ] + end + end; + destruct e' + end. + +Ltac convoy_destruct_in H := + let T := type of H in + convoy_destruct_gen T ltac:(fun T' => change T' in H). +Ltac convoy_destruct := + let T := get_goal in + convoy_destruct_gen T ltac:(fun T' => change T'). diff --git a/src/Util/Tuple.v b/src/Util/Tuple.v index 2e9f7b0ad..4d97c7857 100644 --- a/src/Util/Tuple.v +++ b/src/Util/Tuple.v @@ -1,6 +1,9 @@ Require Import Coq.Classes.Morphisms. Require Import Coq.Relations.Relation_Definitions. Require Import Coq.Lists.List. +Require Import Crypto.Util.Option. +Require Import Crypto.Util.Prod. +Require Import Crypto.Util.Tactics. Require Import Crypto.Util.Decidable. Require Import Crypto.Util.ListUtil. Require Export Crypto.Util.FixCoqMistakes. @@ -17,6 +20,19 @@ Definition tuple T n : Type := | S n' => tuple' T n' end. +Definition tl' {T n} : tuple' T (S n) -> tuple' T n := @fst _ _. +Definition tl {T n} : tuple T (S n) -> tuple T n := + match n with + | O => fun _ => tt + | S n' => @tl' T n' + end. +Definition hd' {T n} : tuple' T n -> T := + match n with + | O => fun x => x + | S n' => @snd _ _ + end. +Definition hd {T n} : tuple T (S n) -> T := @hd' _ _. + Fixpoint to_list' {T} (n:nat) {struct n} : tuple' T n -> list T := match n with | 0 => fun x => (x::nil)%list @@ -136,6 +152,13 @@ Definition on_tuple {A B} (f:list A -> list B) Definition map {n A B} (f:A -> B) (xs:tuple A n) : tuple B n := on_tuple (List.map f) (fun _ => eq_trans (map_length _ _)) xs. +Lemma map_S {n A B} (f:A -> B) (xs:tuple' A n) (x:A) + : map (n:=S (S n)) f (xs, x) = (map (n:=S n) f xs, f x). +Proof. + unfold map, on_tuple. + simpl @List.map. +Admitted. + Definition on_tuple2 {A B C} (f : list A -> list B -> list C) {a b c : nat} (Hlength : forall la lb, length la = a -> length lb = b -> length (f la lb) = c) (ta:tuple A a) (tb:tuple B b) : tuple C c @@ -145,6 +168,103 @@ Definition on_tuple2 {A B C} (f : list A -> list B -> list C) {a b c : nat} Definition map2 {n A B C} (f:A -> B -> C) (xs:tuple A n) (ys:tuple B n) : tuple C n := on_tuple2 (map2 f) (fun la lb pfa pfb => eq_trans (@map2_length _ _ _ _ la lb) (eq_trans (f_equal2 _ pfa pfb) (Min.min_idempotent _))) xs ys. +Lemma map_map2 {n A B C D} (f:A -> B -> C) (g:C -> D) (xs:tuple A n) (ys:tuple B n) + : map g (map2 f xs ys) = map2 (fun a b => g (f a b)) xs ys. +Proof. +Admitted. + +Lemma map2_fst {n A B C} (f:A -> C) (xs:tuple A n) (ys:tuple B n) + : map2 (fun a b => f a) xs ys = map f xs. +Proof. +Admitted. + +Lemma map2_snd {n A B C} (f:B -> C) (xs:tuple A n) (ys:tuple B n) + : map2 (fun a b => f b) xs ys = map f ys. +Proof. +Admitted. + +Lemma map_id {n A} (xs:tuple A n) + : map (fun x => x) xs = xs. +Proof. +Admitted. + +Lemma map_id_ext {n A} (f : A -> A) (xs:tuple A n) + : (forall x, f x = x) -> map f xs = xs. +Proof. +Admitted. + +Lemma map_map {n A B C} (g : B -> C) (f : A -> B) (xs:tuple A n) + : map g (map f xs) = map (fun x => g (f x)) xs. +Proof. +Admitted. + +Section monad. + Context (M : Type -> Type) (bind : forall X Y, M X -> (X -> M Y) -> M Y) (ret : forall X, X -> M X). + Fixpoint lift_monad' {n A} {struct n} + : tuple' (M A) n -> M (tuple' A n) + := match n return tuple' (M A) n -> M (tuple' A n) with + | 0 => fun t => t + | S n' => fun xy => bind _ _ (@lift_monad' n' _ (fst xy)) (fun x' => bind _ _ (snd xy) (fun y' => ret _ (x', y'))) + end. + Fixpoint push_monad' {n A} {struct n} + : M (tuple' A n) -> tuple' (M A) n + := match n return M (tuple' A n) -> tuple' (M A) n with + | 0 => fun t => t + | S n' => fun xy => (@push_monad' n' _ (bind _ _ xy (fun xy' => ret _ (fst xy'))), + bind _ _ xy (fun xy' => ret _ (snd xy'))) + end. + Definition lift_monad {n A} + : tuple (M A) n -> M (tuple A n) + := match n return tuple (M A) n -> M (tuple A n) with + | 0 => ret _ + | S n' => @lift_monad' n' A + end. + Definition push_monad {n A} + : M (tuple A n) -> tuple (M A) n + := match n return M (tuple A n) -> tuple (M A) n with + | 0 => fun _ => tt + | S n' => @push_monad' n' A + end. +End monad. +Local Notation option_bind + := (fun A B (x : option A) f => match x with + | Some x' => f x' + | None => None + end). +Definition lift_option {n A} (xs : tuple (option A) n) : option (tuple A n) + := lift_monad option option_bind (@Some) xs. +Definition push_option {n A} (xs : option (tuple A n)) : tuple (option A) n + := push_monad option option_bind (@Some) xs. + +Lemma lift_push_option {n A} (xs : option (tuple A (S n))) : lift_option (push_option xs) = xs. +Proof. + simpl in *. + induction n; [ reflexivity | ]. + simpl in *; rewrite IHn; clear IHn. + destruct xs as [ [? ?] | ]; reflexivity. +Qed. + +Lemma push_lift_option {n A} {xs : tuple (option A) (S n)} {v} + : lift_option xs = Some v <-> xs = push_option (Some v). +Proof. + simpl in *. + induction n; [ reflexivity | ]. + specialize (IHn (fst xs) (fst v)). + repeat first [ progress destruct_head_hnf' prod + | progress destruct_head_hnf' and + | progress destruct_head_hnf' iff + | progress destruct_head_hnf' option + | progress inversion_option + | progress inversion_prod + | progress subst + | progress break_match + | progress simpl in * + | progress specialize_by exact eq_refl + | reflexivity + | split + | intro ]. +Qed. + Fixpoint fieldwise' {A B} (n:nat) (R:A->B->Prop) (a:tuple' A n) (b:tuple' B n) {struct n} : Prop. destruct n; simpl @tuple' in *. { exact (R a b). } diff --git a/src/Util/WordUtil.v b/src/Util/WordUtil.v index 36fd21d28..24160d83e 100644 --- a/src/Util/WordUtil.v +++ b/src/Util/WordUtil.v @@ -34,6 +34,15 @@ Proof. auto. Qed. +Lemma Z_land_le : forall x y, (0 <= x)%Z -> (Z.land x y <= x)%Z. +Proof. + intros; apply Z.ldiff_le; [assumption|]. + rewrite Z.ldiff_land, Z.land_comm, Z.land_assoc. + rewrite <- Z.land_0_l with (a := y); f_equal. + rewrite Z.land_comm, Z.land_lnot_diag. + reflexivity. +Qed. + Lemma wordToN_NToWord_idempotent : forall sz n, (n < Npow2 sz)%N -> wordToN (NToWord sz n) = n. Proof. @@ -364,3 +373,289 @@ Proof. Qed. Hint Rewrite @wordToN_wor using word_util_arith : push_wordToN. Hint Rewrite <- @wordToN_wor using word_util_arith : pull_wordToN. + +Local Notation bound n lower value upper := ( + (0 <= lower)%Z + /\ (lower <= Z.of_N (@wordToN n value))%Z + /\ (Z.of_N (@wordToN n value) <= upper)%Z). + +Definition valid_update n lowerF valueF upperF : Prop := + forall lower0 value0 upper0 + lower1 value1 upper1, + + bound n lower0 value0 upper0 + -> bound n lower1 value1 upper1 + -> (0 <= lowerF lower0 upper0 lower1 upper1)%Z + -> (Z.log2 (upperF lower0 upper0 lower1 upper1) < Z.of_nat n)%Z + -> bound n (lowerF lower0 upper0 lower1 upper1) + (valueF value0 value1) + (upperF lower0 upper0 lower1 upper1). + +Local Ltac add_mono := + etransitivity; [| apply Z.add_le_mono_r; eassumption]; omega. + +Lemma add_valid_update: forall n, + valid_update n + (fun l0 u0 l1 u1 => l0 + l1)%Z + (@wplus n) + (fun l0 u0 l1 u1 => u0 + u1)%Z. +Proof. + unfold valid_update; intros until upper1; intros B0 B1. + destruct B0 as [? B0], B1 as [? B1], B0, B1. + repeat split; [add_mono| |]; ( + rewrite wordToN_wplus; [add_mono|add_mono|]; + eapply Z.le_lt_trans; [| eassumption]; + apply Z.log2_le_mono; add_mono). +Qed. + +Local Ltac sub_mono := + etransitivity; + [| apply Z.sub_le_mono_r]; eauto; + first [ reflexivity + | apply Z.sub_le_mono_l; assumption + | apply Z.le_add_le_sub_l; etransitivity; [|eassumption]; + repeat rewrite Z.add_0_r; assumption]. + +Lemma sub_valid_update: forall n, + valid_update n + (fun l0 u0 l1 u1 => l0 - u1)%Z + (@wminus n) + (fun l0 u0 l1 u1 => u0 - l1)%Z. +Proof. + unfold valid_update; intros until upper1; intros B0 B1. + destruct B0 as [? B0], B1 as [? B1], B0, B1. + repeat split; [sub_mono| |]; ( + rewrite wordToN_wminus; [sub_mono|omega|]; + eapply Z.le_lt_trans; [apply Z.log2_le_mono|eassumption]; sub_mono). +Qed. + +Local Ltac mul_mono := + etransitivity; [|apply Z.mul_le_mono_nonneg_r]; + repeat first + [ eassumption + | reflexivity + | apply Z.mul_le_mono_nonneg_l + | rewrite Z.mul_0_l + | omega]. + +Lemma mul_valid_update: forall n, + valid_update n + (fun l0 u0 l1 u1 => l0 * l1)%Z + (@wmult n) + (fun l0 u0 l1 u1 => u0 * u1)%Z. +Proof. + unfold valid_update; intros until upper1; intros B0 B1. + destruct B0 as [? B0], B1 as [? B1], B0, B1. + repeat split; [mul_mono| |]; ( + rewrite wordToN_wmult; [mul_mono|mul_mono|]; + eapply Z.le_lt_trans; [| eassumption]; + apply Z.log2_le_mono; mul_mono). +Qed. + +Local Ltac solve_land_ge0 := + apply Z.land_nonneg; left; etransitivity; [|eassumption]; assumption. + +Local Ltac land_mono := + first [assumption | etransitivity; [|eassumption]; assumption]. + +Lemma land_valid_update: forall n, + valid_update n + (fun l0 u0 l1 u1 => 0)%Z + (@wand n) + (fun l0 u0 l1 u1 => Z.min u0 u1)%Z. +Proof. + unfold valid_update; intros until upper1; intros B0 B1. + destruct B0 as [? B0], B1 as [? B1], B0, B1. + repeat split; [reflexivity| |]. + + - rewrite wordToN_wand; [solve_land_ge0|solve_land_ge0|]. + eapply Z.le_lt_trans; [apply Z.log2_land; land_mono|]; + eapply Z.le_lt_trans; [| eassumption]; + repeat match goal with + | [|- context[Z.min ?a ?b]] => + destruct (Z.min_dec a b) as [g|g]; rewrite g; clear g + end; apply Z.log2_le_mono; try assumption. + + admit. admit. + + - rewrite wordToN_wand; [|solve_land_ge0|]. + eapply Z.le_lt_trans; [apply Z.log2_land; land_mono|]; + match goal with + | [|- (Z.min ?a ?b < _)%Z] => + destruct (Z.min_dec a b) as [g|g]; rewrite g; clear g + end. + + . + (* +[apply N2Z.is_nonneg|]; + unfold Word64.word64ToZ; repeat rewrite wordToN_NToWord; repeat rewrite Z2N.id; + rewrite wordize_and. + + destruct (Z_ge_dec upper1 upper0) as [g|g]. + + - rewrite Z.min_r; [|abstract (apply Z.log2_le_mono; omega)]. + abstract ( + rewrite (land_intro_ones (wordToN value0)); + rewrite N.land_assoc; + etransitivity; [apply N2Z.inj_le; apply N.lt_le_incl; apply land_lt_Npow2|]; + rewrite N2Z.inj_pow; + apply Z.pow_le_mono; [abstract (split; cbn; [omega|reflexivity])|]; + unfold getBits; rewrite N2Z.inj_succ; + apply -> Z.succ_le_mono; + rewrite <- (N2Z.id (wordToN value0)), <- log2_conv; + apply Z.log2_le_mono; + etransitivity; [eassumption|reflexivity]). + + - rewrite Z.min_l; [|abstract (apply Z.log2_le_mono; omega)]. + abstract ( + rewrite (land_intro_ones (wordToN value1)); + rewrite <- N.land_comm, N.land_assoc; + etransitivity; [apply N2Z.inj_le; apply N.lt_le_incl; apply land_lt_Npow2|]; + rewrite N2Z.inj_pow; + apply Z.pow_le_mono; [abstract (split; cbn; [omega|reflexivity])|]; + unfold getBits; rewrite N2Z.inj_succ; + apply -> Z.succ_le_mono; + rewrite <- (N2Z.id (wordToN value1)), <- log2_conv; + apply Z.log2_le_mono; + etransitivity; [eassumption|reflexivity]). + +*) +Admitted. + +Lemma lor_valid_update: forall n, + valid_update n + (fun l0 u0 l1 u1 => Z.max l0 l1)%Z + (@wor n) + (fun l0 u0 l1 u1 => 2^(Z.max (Z.log2 (u0+1)) (Z.log2 (u1+1))) - 1)%Z. +Proof. +(* unfold Word64.word64ToZ in *; repeat rewrite wordToN_NToWord; repeat rewrite Z2N.id; + rewrite wordize_or. + + - transitivity (Z.max (Z.of_N (wordToN value1)) (Z.of_N (wordToN value0))); + [ abstract (destruct + (Z_ge_dec lower1 lower0) as [l|l], + (Z_ge_dec (Z.of_N (& value1)%w) (Z.of_N (& value0)%w)) as [v|v]; + [ rewrite Z.max_l, Z.max_l | rewrite Z.max_l, Z.max_r + | rewrite Z.max_r, Z.max_l | rewrite Z.max_r, Z.max_r ]; + + try (omega || assumption)) + | ]. + + rewrite <- N2Z.inj_max. + apply Z2N.inj_le; [apply N2Z.is_nonneg|apply N2Z.is_nonneg|]. + repeat rewrite N2Z.id. + + abstract ( + destruct (N.max_dec (wordToN value1) (wordToN value0)) as [v|v]; + rewrite v; + apply N.ldiff_le, N.bits_inj_iff; intros k; + rewrite N.ldiff_spec, N.lor_spec; + induction (N.testbit (wordToN value1)), (N.testbit (wordToN value0)); simpl; + reflexivity). + + - apply Z.lt_le_incl, Z.log2_lt_cancel. + rewrite Z.log2_pow2; [| abstract ( + destruct (Z.max_dec (Z.log2 upper1) (Z.log2 upper0)) as [g|g]; + rewrite g; apply Z.le_le_succ_r, Z.log2_nonneg)]. + + eapply (Z.le_lt_trans _ (Z.log2 (Z.lor _ _)) _). + + + apply Z.log2_le_mono, Z.eq_le_incl. + apply Z.bits_inj_iff'; intros k Hpos. + rewrite Z2N.inj_testbit, Z.lor_spec, N.lor_spec; [|assumption]. + repeat (rewrite <- Z2N.inj_testbit; [|assumption]). + reflexivity. + + + abstract ( + rewrite Z.log2_lor; [|trans'|trans']; + destruct + (Z_ge_dec (Z.of_N (wordToN value1)) (Z.of_N (wordToN value0))) as [g0|g0], + (Z_ge_dec upper1 upper0) as [g1|g1]; + [ rewrite Z.max_l, Z.max_l + | rewrite Z.max_l, Z.max_r + | rewrite Z.max_r, Z.max_l + | rewrite Z.max_r, Z.max_r]; + try apply Z.log2_le_mono; try omega; + apply Z.le_succ_l; + apply -> Z.succ_le_mono; + apply Z.log2_le_mono; + assumption || (etransitivity; [eassumption|]; omega)). +*) +Admitted. + +Lemma shr_valid_update: forall n, + valid_update n + (fun l0 u0 l1 u1 => Z.shiftr l0 u1)%Z + (@wordBin N.shiftr n) + (fun l0 u0 l1 u1 => Z.shiftr u0 l1)%Z. +Proof. + (* + Ltac shr_mono := etransitivity; + [apply Z.div_le_compat_l | apply Z.div_le_mono]. + + assert (forall x, (0 <= x)%Z -> (0 < 2^x)%Z) as gt0. { + intros; rewrite <- (Z2Nat.id x); [|assumption]. + induction (Z.to_nat x) as [|n]; [cbv; auto|]. + eapply Z.lt_le_trans; [eassumption|rewrite Nat2Z.inj_succ]. + apply Z.pow_le_mono_r; [cbv; auto|omega]. + } + + build_binop Word64.w64shr ZBounds.shr; t_start; abstract ( + unfold Word64.word64ToZ; repeat rewrite wordToN_NToWord; repeat rewrite Z2N.id; + rewrite Z.shiftr_div_pow2 in *; + repeat match goal with + | [|- _ /\ _ ] => split + | [|- (0 <= 2 ^ _)%Z ] => apply Z.pow_nonneg + | [|- (0 < 2 ^ ?X)%Z ] => apply gt0 + | [|- (0 <= _ / _)%Z ] => apply Z.div_le_lower_bound; [|rewrite Z.mul_0_r] + | [|- (2 ^ _ <= 2 ^ _)%Z ] => apply Z.pow_le_mono_r + | [|- context[(?a >> ?b)%Z]] => rewrite Z.shiftr_div_pow2 in * + | [|- (_ < Npow2 _)%N] => + apply N2Z.inj_lt, Z.log2_lt_cancel; simpl; + eapply Z.le_lt_trans; [|eassumption]; apply Z.log2_le_mono; rewrite Z2N.id + + | _ => progress shr_mono + | _ => progress trans' + | _ => progress omega + end). + +*) +Admitted. + +Lemma shl_valid_update: forall n, + valid_update n + (fun l0 u0 l1 u1 => Z.shiftl l0 l1)%Z + (@wordBin N.shiftl n) + (fun l0 u0 l1 u1 => Z.shiftl u0 u1)%Z. +Proof. + (* + Ltac shl_mono := etransitivity; + [apply Z.mul_le_mono_nonneg_l | apply Z.mul_le_mono_nonneg_r]. + + build_binop Word64.w64shl ZBounds.shl; t_start; abstract ( + unfold Word64.word64ToZ; repeat rewrite wordToN_NToWord; repeat rewrite Z2N.id; + rewrite Z.shiftl_mul_pow2 in *; + repeat match goal with + | [|- (0 <= 2 ^ _)%Z ] => apply Z.pow_nonneg + | [|- (0 <= _ * _)%Z ] => apply Z.mul_nonneg_nonneg + | [|- (2 ^ _ <= 2 ^ _)%Z ] => apply Z.pow_le_mono_r + | [|- context[(?a << ?b)%Z]] => rewrite Z.shiftl_mul_pow2 + | [|- (_ < Npow2 _)%N] => + apply N2Z.inj_lt, Z.log2_lt_cancel; simpl; + eapply Z.le_lt_trans; [|eassumption]; apply Z.log2_le_mono; rewrite Z2N.id + + | _ => progress shl_mono + | _ => progress trans' + | _ => progress omega + end). + +*) +Admitted. + + +Axiom wlast : forall sz, word (sz+1) -> bool. Arguments wlast {_} _. +Axiom winit : forall sz, word (sz+1) -> word sz. Arguments winit {_} _. +Axiom combine_winit_wlast : forall {sz} a b (c:word (sz+1)), + @combine sz a 1 b = c <-> a = winit c /\ b = (WS (wlast c) WO). +Axiom winit_combine : forall sz a b, @winit sz (combine a b) = a. +Axiom wlast_combine : forall sz a b, @wlast sz (combine a (WS b WO)) = b.
\ No newline at end of file diff --git a/src/Util/ZUtil.v b/src/Util/ZUtil.v index 5417c3407..d6ffa4a53 100644 --- a/src/Util/ZUtil.v +++ b/src/Util/ZUtil.v @@ -4,6 +4,7 @@ Require Import Coq.Structures.Equalities. Require Import Coq.omega.Omega Coq.micromega.Psatz Coq.Numbers.Natural.Peano.NPeano Coq.Arith.Arith. Require Import Crypto.Util.NatUtil. Require Import Crypto.Util.Tactics. +Require Import Crypto.Util.Bool. Require Import Crypto.Util.Notations. Require Import Coq.Lists.List. Require Export Crypto.Util.FixCoqMistakes. @@ -21,6 +22,8 @@ Hint Extern 1 => nia : nia. Hint Extern 1 => omega : omega. Hint Resolve Z.log2_nonneg Z.div_small Z.mod_small Z.pow_neg_r Z.pow_0_l Z.pow_pos_nonneg Z.lt_le_incl Z.pow_nonzero Z.div_le_upper_bound Z_div_exact_full_2 Z.div_same Z.div_lt_upper_bound Z.div_le_lower_bound Zplus_minus Zplus_gt_compat_l Zplus_gt_compat_r Zmult_gt_compat_l Zmult_gt_compat_r Z.pow_lt_mono_r Z.pow_lt_mono_l Z.pow_lt_mono Z.mul_lt_mono_nonneg Z.div_lt_upper_bound Z.div_pos Zmult_lt_compat_r Z.pow_le_mono_r Z.pow_le_mono_l Z.div_lt : zarith. Hint Resolve (fun a b H => proj1 (Z.mod_pos_bound a b H)) (fun a b H => proj2 (Z.mod_pos_bound a b H)) (fun a b pf => proj1 (Z.pow_gt_1 a b pf)) : zarith. +Hint Resolve (fun n m => proj1 (Z.pred_le_mono n m)) : zarith. +Hint Resolve (fun a b => proj2 (Z.lor_nonneg a b)) : zarith. Ltac zutil_arith := solve [ omega | lia | auto with nocore ]. Ltac zutil_arith_more_inequalities := solve [ zutil_arith | auto with zarith ]. @@ -1083,6 +1086,21 @@ Module Z. inversion H; trivial. Qed. + Lemma ones_le x y : x <= y -> Z.ones x <= Z.ones y. + Proof. + rewrite !Z.ones_equiv; auto with zarith. + Qed. + Hint Resolve ones_le : zarith. + + Lemma geb_spec0 : forall x y : Z, Bool.reflect (x >= y) (x >=? y). + Proof. + intros x y; pose proof (Zge_cases x y) as H; destruct (Z.geb x y); constructor; omega. + Qed. + Lemma gtb_spec0 : forall x y : Z, Bool.reflect (x > y) (x >? y). + Proof. + intros x y; pose proof (Zgt_cases x y) as H; destruct (Z.gtb x y); constructor; omega. + Qed. + Ltac ltb_to_lt_with_hyp H lem := let H' := fresh in rename H into H'; @@ -1090,6 +1108,10 @@ Module Z. rewrite H' in H; clear H'. + Ltac ltb_to_lt_in_goal b' lem := + refine (proj1 (@reflect_iff_gen _ _ lem b') _); + cbv beta iota. + Ltac ltb_to_lt := repeat match goal with | [ H : (?x <? ?y) = ?b |- _ ] @@ -1102,6 +1124,16 @@ Module Z. => ltb_to_lt_with_hyp H (Zge_cases x y) | [ H : (?x =? ?y) = ?b |- _ ] => ltb_to_lt_with_hyp H (eqb_cases x y) + | [ |- (?x <? ?y) = ?b ] + => ltb_to_lt_in_goal b (Z.ltb_spec0 x y) + | [ |- (?x <=? ?y) = ?b ] + => ltb_to_lt_in_goal b (Z.leb_spec0 x y) + | [ |- (?x >? ?y) = ?b ] + => ltb_to_lt_in_goal b (Z.gtb_spec0 x y) + | [ |- (?x >=? ?y) = ?b ] + => ltb_to_lt_in_goal b (Z.geb_spec0 x y) + | [ |- (?x =? ?y) = ?b ] + => ltb_to_lt_in_goal b (Z.eqb_spec x y) end. Ltac compare_to_sgn := @@ -2054,6 +2086,57 @@ Module Z. Qed. Hint Resolve shiftr_nonneg_le : zarith. + Lemma log2_pred_pow2_full a : Z.log2 (Z.pred (2^a)) = Z.max 0 (Z.pred a). + Proof. + destruct (Z_dec 0 a) as [ [?|?] | ?]. + { rewrite Z.log2_pred_pow2 by assumption. + apply Z.max_case_strong; omega. } + { autorewrite with zsimplify; simpl. + apply Z.max_case_strong; omega. } + { subst; compute; reflexivity. } + Qed. + Hint Rewrite log2_pred_pow2_full : zsimplify. + + Lemma ones_lt_pow2 x y : 0 <= x <= y -> Z.ones x < 2^y. + Proof. + rewrite Z.ones_equiv, Z.lt_pred_le. + auto with zarith. + Qed. + Hint Resolve ones_lt_pow2 : zarith. + + Lemma log2_ones_full x : Z.log2 (Z.ones x) = Z.max 0 (Z.pred x). + Proof. + rewrite Z.ones_equiv, log2_pred_pow2_full; reflexivity. + Qed. + Hint Rewrite log2_ones_full : zsimplify. + + Lemma log2_ones_lt x y : 0 < x <= y -> Z.log2 (Z.ones x) < y. + Proof. + rewrite log2_ones_full; apply Z.max_case_strong; omega. + Qed. + Hint Resolve log2_ones_lt : zarith. + + Lemma log2_ones_le x y : 0 <= x <= y -> Z.log2 (Z.ones x) <= y. + Proof. + rewrite log2_ones_full; apply Z.max_case_strong; omega. + Qed. + Hint Resolve log2_ones_le : zarith. + + Lemma log2_ones_lt_nonneg x y : 0 < y -> x <= y -> Z.log2 (Z.ones x) < y. + Proof. + rewrite log2_ones_full; apply Z.max_case_strong; omega. + Qed. + Hint Resolve log2_ones_lt_nonneg : zarith. + + Lemma log2_lt_pow2_alt a b : 0 < b -> a < 2^b <-> Z.log2 a < b. + Proof. + destruct (Z_lt_le_dec 0 a); auto using Z.log2_lt_pow2; []. + rewrite Z.log2_nonpos by omega. + split; auto with zarith; []. + intro; eapply le_lt_trans; [ eassumption | ]. + auto with zarith. + Qed. + Lemma simplify_twice_sub_sub x y : 2 * x - (x - y) = x + y. Proof. lia. Qed. Hint Rewrite simplify_twice_sub_sub : zsimplify. @@ -2828,6 +2911,44 @@ for name in names: Module RemoveEquivModuloInstances (dummy : Nop). Global Remove Hints equiv_modulo_Reflexive equiv_modulo_Symmetric equiv_modulo_Transitive mul_mod_Proper add_mod_Proper sub_mod_Proper opp_mod_Proper modulo_equiv_modulo_Proper eq_to_ProperProxy : typeclass_instances. End RemoveEquivModuloInstances. + + Module N2Z. + Require Import Coq.NArith.NArith. + + Lemma inj_shiftl: forall x y, Z.of_N (N.shiftl x y) = Z.shiftl (Z.of_N x) (Z.of_N y). + Proof. + intros. + apply Z.bits_inj_iff'; intros k Hpos. + rewrite Z2N.inj_testbit; [|assumption]. + rewrite Z.shiftl_spec; [|assumption]. + + assert ((Z.to_N k) >= y \/ (Z.to_N k) < y)%N as g by ( + unfold N.ge, N.lt; induction (N.compare (Z.to_N k) y); [left|auto|left]; + intro H; inversion H). + + destruct g as [g|g]; + [ rewrite N.shiftl_spec_high; [|apply N2Z.inj_le; rewrite Z2N.id|apply N.ge_le] + | rewrite N.shiftl_spec_low]; try assumption. + + - rewrite <- N2Z.inj_testbit; f_equal. + rewrite N2Z.inj_sub, Z2N.id; [reflexivity|assumption|apply N.ge_le; assumption]. + + - apply N2Z.inj_lt in g. + rewrite Z2N.id in g; [symmetry|assumption]. + apply Z.testbit_neg_r; omega. + Qed. + + Lemma inj_shiftr: forall x y, Z.of_N (N.shiftr x y) = Z.shiftr (Z.of_N x) (Z.of_N y). + Proof. + intros. + apply Z.bits_inj_iff'; intros k Hpos. + rewrite Z2N.inj_testbit; [|assumption]. + rewrite Z.shiftr_spec, N.shiftr_spec; [|apply N2Z.inj_le; rewrite Z2N.id|]; try assumption. + rewrite <- N2Z.inj_testbit; f_equal. + rewrite N2Z.inj_add; f_equal. + apply Z2N.id; assumption. + Qed. + End N2Z. End Z. Module Export BoundsTactics. |