From a5ef5f1ab5517742721e394a40e3289f20a809d6 Mon Sep 17 00:00:00 2001 From: Jason Gross Date: Mon, 22 Jan 2018 13:19:04 -0500 Subject: Use Derive plugin, do two passes of partial reduction --- src/Experiments/SimplyTypedArithmetic.v | 284 ++++++++++++++++++++------------ 1 file changed, 175 insertions(+), 109 deletions(-) diff --git a/src/Experiments/SimplyTypedArithmetic.v b/src/Experiments/SimplyTypedArithmetic.v index f63eb9c22..7010069dd 100644 --- a/src/Experiments/SimplyTypedArithmetic.v +++ b/src/Experiments/SimplyTypedArithmetic.v @@ -1,11 +1,13 @@ (* Following http://adam.chlipala.net/theses/andreser.pdf chapter 3 *) Require Import Coq.ZArith.ZArith Coq.micromega.Lia Crypto.Algebra.Nsatz. +Require Import Coq.derive.Derive. Require Import Crypto.Util.Tactics.UniquePose Crypto.Util.Decidable. Require Import Crypto.Util.Tuple Crypto.Util.Prod Crypto.Util.LetIn. Require Import Crypto.Util.ListUtil Coq.Lists.List Crypto.Util.NatUtil. Require Import QArith.QArith_base QArith.Qround Crypto.Util.QUtil. Require Import Crypto.Algebra.Ring Crypto.Util.Decidable.Bool2Prop. Require Import Crypto.Util.Tactics.RunTacticAsConstr. +Require Import Crypto.Util.Tactics.Head. Require Import Crypto.Util.Notations. Import ListNotations. Local Open Scope Z_scope. @@ -2291,102 +2293,129 @@ Open Scope RT_expr_scope. Require Import AdmitAxiom. -(*Definition w (i:nat) : Z := 2^Qceiling((25+1/2)*i).*) -Definition w (i:nat) : Z := 2^Qceiling(51*i). -Example base_51_carry_mul (*(f0 f1 f2 f3 f4 f5 f6 f7 f8 f9 g0 g1 g2 g3 g4 g5 g6 g7 g8 g9 : Z) - (f:=(f0 :: f1 :: f2 :: f3 :: f4 :: f5 :: f6 :: f7 :: f8 :: f9 :: nil)%list) - (g:=(f0 :: f1 :: f2 :: f3 :: f4 :: f5 :: f6 :: f7 :: f8 :: f9 :: nil)%list)*) (fg : list Z * list Z) - (f := fst fg) (g := snd fg) - (n:=5%nat) - (Hf : length f = n) (Hg : length g = n) - : { fg : list Z | (eval w n fg) mod (2^255-19) - = (eval w n f * eval w n g) mod (2^255-19) }. - (* manually assign names to limbs for pretty-printing *) - eexists ?[fg]. - erewrite <-eval_mulmod with (s:=2^255) (c:=[(1,19)]) - by (try assumption; try eapply pow_ceil_mul_nat_nonzero; vm_decide). -(* eval w ?fg mod (2 ^ 255 - 19) = *) -(* eval w *) -(* (mulmod w (2^255) [(1, 19)] (f9,f8,f7,f6,f5,f4,f3,f2,f1,f0) *) -(* (g9,g8,g7,g6,g5,g4,g3,g2,g1,g0)) mod (2^255 - 19) *) - etransitivity; (* work around [rewrite] being stupid about evars *) - [ - | rewrite <- eval_chained_carries with (s:=2^255) (c:=[(1,19)]) (idxs:=(seq 0 n ++ [0; 1])%list%nat) (modulo:=fun x y => Z.modulo x y) (div:=fun x y => Z.div x y) - by (try assumption; auto using Z.div_mod; try (intros; eapply pow_ceil_mul_nat_divide_nonzero); try eapply pow_ceil_mul_nat_nonzero; try vm_decide); - reflexivity ]. + +Example test1 : True. +Proof. + let v := Reify ((fun x => 2^x) 255)%Z in + pose v as E. + vm_compute in E. + pose (PartialReduce (canonicalize_list_recursion E)) as E'. + vm_compute in E'. + lazymatch (eval cbv delta [E'] in E') with + | (fun var => AppIdent (ident.primitive ?v) TT) => idtac + end. + constructor. +Qed. +Example test2 : True. +Proof. + let v := Reify (fun y : Z + => (fun k : Z * Z -> Z * Z + => dlet_nd x := (y * y)%RT in + dlet_nd z := (x * x)%RT in + k (z, z)) + (fun v => v)) in + pose v as E. + vm_compute in E. + pose (PartialReduce (canonicalize_list_recursion E)) as E'. + vm_compute in E'. + lazymatch (eval cbv delta [E'] in E') with + | (fun var : type -> Type => + (λ x : var (type.type_primitive type.Z), + expr_let x0 := (Var x * Var x)%RT_expr in + expr_let x1 := (Var x0 * Var x0)%RT_expr in + (Var x1, Var x1))%expr) => idtac + end. + constructor. +Qed. +Example test3 : True. +Proof. + let v := Reify (fun y : Z + => dlet_nd x := dlet_nd x := (y * y)%RT in + (x * x)%RT in + dlet_nd z := dlet_nd z := (x * x)%RT in + (z * z)%RT in + (z * z)%RT) in + pose v as E. + vm_compute in E. + pose (CPS.CallFunWithIdContinuation (CPS.Translate (canonicalize_list_recursion E))) as E'. + vm_compute in E'. + pose (PartialReduce E') as E''. + lazy in E''. + lazymatch (eval cbv delta [E''] in E'') with + | (fun var : type -> Type => + (λ x : var (type.type_primitive type.Z), + expr_let x0 := Var x * Var x in + expr_let x1 := Var x0 * Var x0 in + expr_let x2 := Var x1 * Var x1 in + expr_let x3 := Var x2 * Var x2 in + Var x3 * Var x3)%RT_expr%expr) + => idtac + end. + constructor. +Qed. + +Axiom admit : forall {T}, T. + +Derive carry_mul_gen + SuchThat (forall (w : nat -> Z) + (fg : list Z * list Z) + (f := fst fg) (g := snd fg) + (n : nat) + (Hf : length f = n) + (Hg : length g = n) + (s : Z) + (c : list (Z * Z)) + (len_c : nat) + (Hc : length c = len_c) + (idxs : list nat) + (len_idxs : nat) + (Hidxs : length idxs = len_idxs) + (Hw0_1 : w 0%nat = 1) + (Hw_nz : forall i : nat, w i <> 0) + (Hw_div_nz : forall i : nat, w (S i) / w i <> 0) + (Hsc_nz : s - Associational.eval c <> 0) + (Hs_nz : s <> 0) + (Hn_nz : n <> 0%nat), + let fg' := carry_mul_gen w fg n s c len_c idxs len_idxs in + (eval w n fg') mod (s - Associational.eval c) + = (eval w n f * eval w n g) mod (s - Associational.eval c)) + As carry_mul_gen_correct. +Proof. + intros; subst carry_mul_gen. + erewrite <-eval_mulmod with (s:=s) (c:=c) + by (try assumption; try reflexivity). + (* eval w n (fg' w fg n s c len_c) mod (s - Associational.eval c) = + eval w n (mulmod w s c n f g) mod (s - Associational.eval c) *) + etransitivity; + [ | rewrite <- eval_chained_carries with (s:=s) (c:=c) (idxs:=idxs) (modulo:=fun x y => Z.modulo x y) (div:=fun x y => Z.div x y) + by (try assumption; auto using Z.div_mod); reflexivity ]. eapply f_equal2; [|trivial]. eapply f_equal. -(* ?fg = *) -(* mulmod w (2 ^ 255) [(1, 19)] (f9, f8, f7, f6, f5, f4, f3, f2, f1, f0) *) -(* (g9, g8, g7, g6, g5, g4, g3, g2, g1, g0) *) - (*cbv [f g].*) - cbv [w Qceiling Qfloor Qopp Qnum Qdiv Qplus inject_Z Qmult Qinv Qden Pos.mul]. - let ev := match goal with |- ?ev = _ => ev end in - set (e := ev). - rewrite <- (expand_list_correct n (-1)%Z f), <- (expand_list_correct n (-1)%Z g) by assumption; subst e. + erewrite <- (expand_list_correct _ (-1)%Z f), + <- (expand_list_correct _ (-1)%Z g), + <- (expand_list_correct _ 0%nat idxs), + <- (expand_list_correct _ (-1,-1)%Z c) + by eassumption. + pose (idxs, len_idxs, n, s, c, len_c, w, fg) as args. + subst f g. + change fg with (snd args). + change w with (snd (fst args)). + change len_c with (snd (fst (fst args))). + change c with (snd (fst (fst (fst args)))). + change s with (snd (fst (fst (fst (fst args))))). + change n with (snd (fst (fst (fst (fst (fst args)))))). + change len_idxs with (snd (fst (fst (fst (fst (fst (fst args))))))). + change idxs with (fst (fst (fst (fst (fst (fst (fst args))))))). + remember args as args' eqn:Hargs. etransitivity. Focus 2. - { subst f g. - repeat match goal with H : _ |- _ => clear H end; revert fg. + { subst fg'. + repeat match goal with H : _ |- _ => clear H end; revert args'. lazymatch goal with - | [ |- forall fg, ?ev = @?RHS fg ] - => refine (fun fg => f_equal (fun F => F fg) (_ : _ = RHS)) + | [ |- forall args, ?ev = @?RHS args ] + => refine (fun args => f_equal (fun F => F args) (_ : _ = RHS)) end. - cbv [n expand_list expand_list_helper]. - cbv delta [chained_carries carry carry_reduce Associational.carry carryterm mulmod w to_associational mul to_associational reduce from_associational add_to_nth zeros place split]. - Locate Ltac Reify. - assert True. - { let v := Reify ((fun x => 2^x) 255)%Z in - pose v as E. - vm_compute in E. - pose (PartialReduce (canonicalize_list_recursion E)) as E'. - vm_compute in E'. - lazymatch (eval cbv delta [E'] in E') with - | (fun var => AppIdent (ident.primitive ?v) TT) => idtac - end. - constructor. } - assert True. - { let v := Reify (fun y : Z - => (fun k : Z * Z -> Z * Z - => dlet_nd x := (y * y)%RT in - dlet_nd z := (x * x)%RT in - k (z, z)) - (fun v => v)) in - pose v as E. - vm_compute in E. - pose (PartialReduce (canonicalize_list_recursion E)) as E'. - vm_compute in E'. - lazymatch (eval cbv delta [E'] in E') with - | (fun var : type -> Type => - (λ x : var (type.type_primitive type.Z), - expr_let x0 := (Var x * Var x)%RT_expr in - expr_let x1 := (Var x0 * Var x0)%RT_expr in - (Var x1, Var x1))%expr) => idtac - end. - constructor. } - assert True. - { let v := Reify (fun y : Z - => dlet_nd x := dlet_nd x := (y * y)%RT in - (x * x)%RT in - dlet_nd z := dlet_nd z := (x * x)%RT in - (z * z)%RT in - (z * z)%RT) in - pose v as E. - vm_compute in E. - pose (CPS.CallFunWithIdContinuation (CPS.Translate (canonicalize_list_recursion E))) as E'. - vm_compute in E'. - pose (PartialReduce E') as E''. - lazy in E''. - lazymatch (eval cbv delta [E''] in E'') with - | (fun var : type -> Type => - (λ x : var (type.type_primitive type.Z), - expr_let x0 := Var x * Var x in - expr_let x1 := Var x0 * Var x0 in - expr_let x2 := Var x1 * Var x1 in - expr_let x3 := Var x2 * Var x2 in - Var x3 * Var x3)%RT_expr%expr) - => idtac - end. - constructor. } + cbv [expand_list expand_list_helper]. + cbv delta [chained_carries carry carry_reduce Associational.carry carryterm mulmod to_associational mul to_associational reduce from_associational add_to_nth zeros place split]. Reify_rhs (). reflexivity. } Unfocus. @@ -2396,26 +2425,63 @@ Example base_51_carry_mul (*(f0 f1 f2 f3 f4 f5 f6 f7 f8 f9 g0 g1 g2 g3 g4 g5 g6 Time let E' := constr:(PartialReduce (CPS.CallFunWithIdContinuation (CPS.Translate (canonicalize_list_recursion E)))) in let E' := (eval lazy in E') in pose E' as E''. - transitivity (Interp E'' fg); [ clear E | admit ]. + transitivity (Interp E'' (fst (fst args'), (fun '((i, k) : nat * (Z -> list Z)) => k (w i)), snd args')); [ clear E | exact admit ]. + subst args' args; cbn [fst snd]. + subst fg'. reflexivity. - (*cbv -[runtime_mul runtime_add]; cbv [runtime_mul runtime_add]. - ring_simplify_subterms.*) -(* ?fg = - (f0*g9+ f1*g8+ f2*g7+ f3*g6+ f4*g5+ f5*g4+ f6*g3+ f7*g2+ f8*g1+ f9*g0, - f0*g8+ 2*f1*g7+ f2*g6+ 2*f3*g5+ f4*g4+ 2*f5*g3+ f6*g2+ 2*f7*g1+ f8*g0+ 38*f9*g9, - f0*g7+ f1*g6+ f2*g5+ f3*g4+ f4*g3+ f5*g2+ f6*g1+ f7*g0+ 19*f8*g9+ 19*f9*g8, - f0*g6+ 2*f1*g5+ f2*g4+ 2*f3*g3+ f4*g2+ 2*f5*g1+ f6*g0+ 38*f7*g9+ 19*f8*g8+ 38*f9*g7, - f0*g5+ f1*g4+ f2*g3+ f3*g2+ f4*g1+ f5*g0+ 19*f6*g9+ 19*f7*g8+ 19*f8*g7+ 19*f9*g6, - f0*g4+ 2*f1*g3+ f2*g2+ 2*f3*g1+ f4*g0+ 38*f5*g9+ 19*f6*g8+ 38*f7*g7+ 19*f8*g6+ 38*f9*g5, - f0*g3+ f1*g2+ f2*g1+ f3*g0+ 19*f4*g9+ 19*f5*g8+ 19*f6*g7+ 19*f7*g6+ 19*f8*g5+ 19*f9*g4, - f0*g2+ 2*f1*g1+ f2*g0+ 38*f3*g9+ 19*f4*g8+ 38*f5*g7+ 19*f6*g6+ 38*f7*g5+ 19*f8*g4+ 38*f9*g3, - f0*g1+ f1*g0+ 19*f2*g9+ 19*f3*g8+ 19*f4*g7+ 19*f5*g6+ 19*f6*g5+ 19*f7*g4+ 19*f8*g3+ 19*f9*g2, - f0*g0+ 38*f1*g9+ 19*f2*g8+ 38*f3*g7+ 19*f4*g6+ 38*f5*g5+ 19*f6*g4+ 38*f7*g3+ 19*f8*g2+ 38*f9*g1) *) - (*trivial.*) -Defined. +Qed. + +(*Definition w (i:nat) : Z := 2^Qceiling((25+1/2)*i).*) +Definition w (i:nat) : Z := 2^Qceiling(51*i). +Derive base_51_carry_mul + SuchThat (forall + (*(f0 f1 f2 f3 f4 f5 f6 f7 f8 f9 g0 g1 g2 g3 g4 g5 g6 g7 g8 g9 : Z) + (f:=(f0 :: f1 :: f2 :: f3 :: f4 :: f5 :: f6 :: f7 :: f8 :: f9 :: nil)%list) + (g:=(f0 :: f1 :: f2 :: f3 :: f4 :: f5 :: f6 :: f7 :: f8 :: f9 :: nil)%list)*) + (fg : list Z * list Z) + (f := fst fg) (g := snd fg) + (n:=5%nat) + (Hf : length f = n) (Hg : length g = n) + (fg' := base_51_carry_mul fg), + (eval w n fg') mod (2^255-19) + = (eval w n f * eval w n g) mod (2^255-19)) + As base_51_carry_mul_correct. +Proof. + intros; subst f g fg'. + erewrite <- carry_mul_gen_correct with (s:=2^255) (c:=[(1, 19)]) (idxs:=(seq 0 n ++ [0; 1])%list%nat) + by (cbn [length seq n List.app]; try reflexivity; try assumption; + try (intros; eapply pow_ceil_mul_nat_divide_nonzero); + try eapply pow_ceil_mul_nat_nonzero; + (apply dec_bool; vm_compute; reflexivity)). + cbn [length seq n List.app]. + cbv [w Qceiling Qfloor Qopp Qnum Qdiv Qplus inject_Z Qmult Qinv Qden Pos.mul]. + cbv [Associational.eval fold_right map fst snd]. + apply f_equal2; [ | reflexivity ]; apply f_equal. + cbv [carry_mul_gen n]. + lazymatch goal with + | [ |- ?ev = expr.Interp (@ident.interp) ?e (?args, fg) ] + => let rargs := Reify args in + let rargs := constr:(canonicalize_list_recursion rargs) in + transitivity (expr.Interp + (@ident.interp) + (fun var + => λ (FG : var (type.list type.Z * type.list type.Z)%ctype), + (e var @ (rargs var, Var FG)))%expr fg) + end. + 2:cbv [expr.interp expr.Interp ident.interp]; exact admit. + let e := match goal with |- _ = expr.Interp _ ?e _ => e end in + set (E := e). + cbv [canonicalize_list_recursion canonicalize_list_recursion.expr.transfer canonicalize_list_recursion.ident.transfer] in E. + Time let E' := constr:(PartialReduce E) in + let E' := (eval lazy in E') in + pose E' as E''. + transitivity (Interp E'' fg); [ clear E | exact admit ]. + subst base_51_carry_mul. + reflexivity. +Qed. Import ident. -Eval cbv [proj1_sig base_51_carry_mul] in (fun fg Hf Hg => proj1_sig (base_51_carry_mul fg Hf Hg)). +Print base_51_carry_mul. (* = fun (fg : list Z * list Z) (_ : length (Datatypes.fst fg) = 5%nat) (_ : length (Datatypes.snd fg) = 5%nat) => expr.Interp (@interp) @@ -2630,4 +2696,4 @@ Eval cbv [proj1_sig base_51_carry_mul] in (fun fg Hf Hg => proj1_sig (base_51_ca expr_let y82 := Z.div @@ (y81, 2251799813685248) in expr_let y83 := Z.modulo @@ (y81, 2251799813685248) in y74 :: y83 :: y82 + y36 :: y46 :: y55 :: [])%expr) fg -*) \ No newline at end of file +*) -- cgit v1.2.3