diff options
author | Jason Gross <jgross@mit.edu> | 2018-03-21 17:29:16 -0400 |
---|---|---|
committer | Jason Gross <jasongross9@gmail.com> | 2018-04-04 15:39:34 -0400 |
commit | 8e17c3d75ce9cb9d2c0c3921514e9318776a28de (patch) | |
tree | 78221eefcbc3bab08c4070c5d4e3aba559091ef0 /src | |
parent | c900290d3297ade2cc2e73fe6b322abe52d1715a (diff) |
Stick an uncurry pass in the pipeline
This allows us to (a) consolidate the constant and non-constant
pipelines and (b) vastly simplify the call-with-id-continuation logic.
Diffstat (limited to 'src')
-rw-r--r-- | src/Experiments/SimplyTypedArithmetic.v | 729 |
1 files changed, 155 insertions, 574 deletions
diff --git a/src/Experiments/SimplyTypedArithmetic.v b/src/Experiments/SimplyTypedArithmetic.v index 202e21098..a8a98c37e 100644 --- a/src/Experiments/SimplyTypedArithmetic.v +++ b/src/Experiments/SimplyTypedArithmetic.v @@ -501,11 +501,10 @@ Section mod_ops. Qed. Derive carry_mulmod - SuchThat (forall (fg : list Z * list Z) - (f := fst fg) (g := snd fg) + SuchThat (forall (f g : list Z) (Hf : length f = n) (Hg : length g = n), - (eval weight n (carry_mulmod fg)) mod (s - Associational.eval c) + (eval weight n (carry_mulmod f g)) mod (s - Associational.eval c) = (eval weight n f * eval weight n g) mod (s - Associational.eval c)) As eval_carry_mulmod. Proof. @@ -516,7 +515,7 @@ Section mod_ops. by auto; reflexivity ]. eapply f_equal2; [|trivial]. eapply f_equal. expand_lists (). - subst f g carry_mulmod; reflexivity. + subst carry_mulmod; reflexivity. Qed. Derive carrymod @@ -536,11 +535,10 @@ Section mod_ops. Qed. Derive addmod - SuchThat (forall (fg: list Z * list Z) - (f := fst fg) (g := snd fg) + SuchThat (forall (f g : list Z) (Hf : length f = n) (Hg : length g = n), - (eval weight n (addmod fg)) mod (s - Associational.eval c) + (eval weight n (addmod f g)) mod (s - Associational.eval c) = (eval weight n f + eval weight n g) mod (s - Associational.eval c)) As eval_addmod. Proof. @@ -548,16 +546,15 @@ Section mod_ops. rewrite <-eval_add by auto. eapply f_equal2; [|trivial]. eapply f_equal. expand_lists (). - subst f g addmod; reflexivity. + subst addmod; reflexivity. Qed. Derive submod SuchThat (forall (coef:Z) - (fg: list Z * list Z) - (f := fst fg) (g := snd fg) + (f g : list Z) (Hf : length f = n) (Hg : length g = n), - (eval weight n (submod coef fg)) mod (s - Associational.eval c) + (eval weight n (submod coef f g)) mod (s - Associational.eval c) = (eval weight n f - eval weight n g) mod (s - Associational.eval c)) As eval_submod. Proof. @@ -565,7 +562,7 @@ Section mod_ops. rewrite <-eval_sub with (coef:=coef) by auto. eapply f_equal2; [|trivial]. eapply f_equal. expand_lists (). - subst f g submod; reflexivity. + subst submod; reflexivity. Qed. Derive oppmod @@ -1634,17 +1631,17 @@ Module Ring. is_bounded_by tight_bounds arg = true -> is_bounded_by loose_bounds (Interp_rrelaxv arg) = true /\ Interp_rrelaxv arg = expanding_id n arg) - (carry_mulmod : list Z * list Z -> list Z) + (carry_mulmod : list Z -> list Z -> list Z) (Hcarry_mulmod - : forall fg, - length (fst fg) = n -> length (snd fg) = n -> - (eval (carry_mulmod fg)) mod (s - Associational.eval c) - = (eval (fst fg) * eval (snd fg)) mod (s - Associational.eval c)) + : forall f g, + length f = n -> length g = n -> + (eval (carry_mulmod f g)) mod (s - Associational.eval c) + = (eval f * eval g) mod (s - Associational.eval c)) (Interp_rcarry_mulv : list Z * list Z -> list Z) (HInterp_rcarry_mulv : forall arg, is_bounded_by2 loose_bounds arg = true -> is_bounded_by tight_bounds (Interp_rcarry_mulv arg) = true - /\ Interp_rcarry_mulv arg = carry_mulmod arg) + /\ Interp_rcarry_mulv arg = carry_mulmod (fst arg) (snd arg)) (carrymod : list Z -> list Z) (Hcarrymod : forall f, @@ -1656,28 +1653,28 @@ Module Ring. is_bounded_by loose_bounds arg = true -> is_bounded_by tight_bounds (Interp_rcarryv arg) = true /\ Interp_rcarryv arg = carrymod arg) - (addmod : list Z * list Z -> list Z) + (addmod : list Z -> list Z -> list Z) (Haddmod - : forall fg, - length (fst fg) = n -> length (snd fg) = n -> - (eval (addmod fg)) mod (s - Associational.eval c) - = (eval (fst fg) + eval (snd fg)) mod (s - Associational.eval c)) + : forall f g, + length f = n -> length g = n -> + (eval (addmod f g)) mod (s - Associational.eval c) + = (eval f + eval g) mod (s - Associational.eval c)) (Interp_raddv : list Z * list Z -> list Z) (HInterp_raddv : forall arg, is_bounded_by2 tight_bounds arg = true -> is_bounded_by loose_bounds (Interp_raddv arg) = true - /\ Interp_raddv arg = addmod arg) - (submod : list Z * list Z -> list Z) + /\ Interp_raddv arg = addmod (fst arg) (snd arg)) + (submod : list Z -> list Z -> list Z) (Hsubmod - : forall fg, - length (fst fg) = n -> length (snd fg) = n -> - (eval (submod fg)) mod (s - Associational.eval c) - = (eval (fst fg) - eval (snd fg)) mod (s - Associational.eval c)) + : forall f g, + length f = n -> length g = n -> + (eval (submod f g)) mod (s - Associational.eval c) + = (eval f - eval g) mod (s - Associational.eval c)) (Interp_rsubv : list Z * list Z -> list Z) (HInterp_rsubv : forall arg, is_bounded_by2 tight_bounds arg = true -> is_bounded_by loose_bounds (Interp_rsubv arg) = true - /\ Interp_rsubv arg = submod arg) + /\ Interp_rsubv arg = submod (fst arg) (snd arg)) (oppmod : list Z -> list Z) (Hoppmod : forall f, @@ -1773,10 +1770,12 @@ Module Ring. | [ |- _ = _ :> Z ] => first [ reflexivity | rewrite <- m_eq; reflexivity ] | [ H : context[?x] |- Fdecode ?x = _ ] => rewrite H | [ H : context[?x _] |- Fdecode (?x _) = _ ] => rewrite H + | [ H : context[?x _ _] |- Fdecode (?x _ _) = _ ] => rewrite H | _ => progress cbv [Fdecode] | [ |- _ = _ :> F _ ] => apply F.eq_to_Z_iff | _ => progress autorewrite with push_FtoZ | _ => rewrite m_eq + | [ H : context[?x _ _] |- context[eval (?x _ _)] ] => rewrite H | [ H : context[?x _] |- context[eval (?x _)] ] => rewrite H | [ H : context[?x] |- context[eval ?x] ] => rewrite H | [ |- context[List.length ?x] ] @@ -4187,347 +4186,69 @@ Module Compilers. : @Compilers.Uncurried.expr.default.Expr R := expr.CallWithContinuation (@ident.untranslate _) (@ident.fst) (@ident.snd) e k. - (** It's not clear how to "plug in the identity continuation" - for the CPS'd form of an expression of type [((A -> B) -> C) - -> D]. So we must describe types of at most second order - functions, so that we can write a uniform "plug in the - identity continuation" transformation. *) - Module second_order. - Import Compilers.type. - Module Import type. - Inductive flat_type := - | type_primitive (_ : primitive) - | prod (_ : flat_type) (_ : flat_type) - | list (_ : flat_type). - Inductive argtype := - | flat_arg (_ : flat_type) - | arrow_arg (s : flat_type) (d : argtype) - | prod_arg (_ _ : argtype). - Inductive type := - | flat (_ : flat_type) - | arrow (s : argtype) (d : type). - End type. - - Module Export Coercions. - Coercion type_primitive : primitive >-> flat_type. - Coercion flat_arg : flat_type >-> argtype. - Coercion flat : flat_type >-> type. - End Coercions. - Notation flat_type := flat_type. - Notation argtype := argtype. - Notation type := type. - - Fixpoint flat_to_type (t : flat_type) : Compilers.type.type - := match t with - | type_primitive x => x - | prod A B => Compilers.type.prod (flat_to_type A) (flat_to_type B) - | list A => Compilers.type.list (flat_to_type A) - end. - - Fixpoint arg_to_type (t : argtype) : Compilers.type.type - := match t with - | flat_arg t => flat_to_type t - | arrow_arg s d => Compilers.type.arrow (flat_to_type s) (arg_to_type d) - | prod_arg A B => Compilers.type.prod (arg_to_type A) (arg_to_type B) - end. - - Fixpoint to_type (t : type) : Compilers.type.type - := match t with - | flat t => flat_to_type t - | arrow s d => Compilers.type.arrow (arg_to_type s) (to_type d) - end. - - Fixpoint flat_of_type (t : Compilers.type.type) : option flat_type - := match t with - | Compilers.type.type_primitive x => @Some flat_type x - | Compilers.type.prod A B - => match flat_of_type A, flat_of_type B with - | Some A, Some B => @Some flat_type (prod A B) - | _, _ => None - end - | Compilers.type.arrow s d => None - | Compilers.type.list A - => option_map list (flat_of_type A) - end. - - Fixpoint arg_of_type (t : Compilers.type.type) : option argtype - := match t with - | Compilers.type.prod A B as t - => match flat_of_type t, arg_of_type A, arg_of_type B with - | Some t, _, _ - => @Some argtype t - | None, Some A, Some B - => @Some argtype (prod_arg A B) - | _, _, _ => None - end - | Compilers.type.arrow s d - => match flat_of_type s, arg_of_type d with - | Some s, Some d => Some (arrow_arg s d) - | _, _ => None - end - | Compilers.type.type_primitive _ as t - | Compilers.type.list _ as t - => option_map flat_arg (flat_of_type t) - end. - - Fixpoint of_type (t : Compilers.type.type) : option type - := match t with - | Compilers.type.arrow s d - => match arg_of_type s, of_type d with - | Some s, Some d => Some (arrow s d) - | _, _ => None - end - | Compilers.type.prod _ _ as t - | Compilers.type.type_primitive _ as t - | Compilers.type.list _ as t - => option_map flat (flat_of_type t) - end. - - Fixpoint try_transport_flat_of_type P (t : Compilers.type.type) - : P t -> option { t' : _ & P (flat_to_type t') } - := match t with - | Compilers.type.type_primitive x - => fun v => Some (existT _ (x : flat_type) v) - | Compilers.type.prod A B - => fun v - => match try_transport_flat_of_type (fun a => P (a * _)%ctype) A v with - | Some (existT A v) - => match try_transport_flat_of_type (fun b => P (_ * b)%ctype) B v with - | Some (existT B v) - => Some (existT _ (prod A B) v) - | None => None - end - | None => None - end - | Compilers.type.arrow s d => fun _ => None - | Compilers.type.list A - => fun v - => option_map - (fun '(existT A v) => existT (fun t => P (flat_to_type t)) (list A) v) - (try_transport_flat_of_type (fun a => P (Compilers.type.list a)) A v) - end. - - Fixpoint try_transport_arg_of_type P (t : Compilers.type.type) - : P t -> option { t' : _ & P (arg_to_type t') } - := match t with - | Compilers.type.prod A B as t - => fun v - => match try_transport_flat_of_type P t v with - | Some (existT t v) => Some (existT (fun t' => P (arg_to_type t')) t v) - | None - => match try_transport_arg_of_type (fun a => P (a * _)%ctype) A v with - | Some (existT A v) - => match try_transport_arg_of_type (fun b => P (_ * b)%ctype) B v with - | Some (existT B v) - => Some (existT _ (prod_arg A B) v) - | None => None - end - | None => None - end - end - | Compilers.type.arrow s d - => fun v - => match try_transport_flat_of_type (fun s => P (s -> _)%ctype) s v with - | Some (existT s v) - => match try_transport_flat_of_type (fun d => P (_ -> d)%ctype) d v with - | Some (existT d v) - => Some (existT (fun t' => P (arg_to_type t')) (arrow_arg s d) v) - | None => None - end - | None => None - end - | Compilers.type.type_primitive _ as t - | Compilers.type.list _ as t - => fun v - => option_map - (fun '(existT t v) => existT (fun t => P (arg_to_type t)) (flat_arg t) v) - (try_transport_flat_of_type P t v) - end. - - Fixpoint try_transport_of_type P (t : Compilers.type.type) - : P t -> option { t' : _ & P (to_type t') } - := match t with - | Compilers.type.arrow s d - => fun v - => match try_transport_arg_of_type (fun s => P (s -> _)%ctype) s v with - | Some (existT s v) - => match try_transport_of_type (fun d => P (_ -> d)%ctype) d v with - | Some (existT d v) - => Some (existT (fun t' => P (to_type t')) (arrow s d) v) - | None => None - end - | None => None - end - | Compilers.type.prod _ _ as t - | Compilers.type.type_primitive _ as t - | Compilers.type.list _ as t - => fun v - => option_map - (fun '(existT t v) => existT (fun t => P (to_type t)) (flat t) v) - (try_transport_flat_of_type P t v) - end. - End second_order. - Import second_order.Coercions. - - Fixpoint untranslate_translate_flat - (P : Compilers.type.type -> Type) - {R} - {t : second_order.flat_type} - (e : P (second_order.to_type t)) - {struct t} - : P (type.untranslate R (type.translate (second_order.to_type t))) - := match t return P (second_order.to_type t) - -> P (type.untranslate R (type.translate (second_order.to_type t))) - with - | second_order.type.type_primitive x => id - | second_order.type.prod A B - => fun ab : P (second_order.flat_to_type A * second_order.flat_to_type B)%ctype - => @untranslate_translate_flat - (fun a => P (a * _)%ctype) - R A - (@untranslate_translate_flat - (fun b => P (_ * b)%ctype) - R B - ab) - | second_order.type.list A - => @untranslate_translate_flat - (fun t => P (Compilers.type.list t)) - R A - end e. - - Fixpoint untranslate_translate_flat' - (P : Compilers.type.type -> Type) - {R} - {t : second_order.flat_type} - (e : P (type.untranslate R (type.translate (second_order.to_type t)))) - {struct t} - : P (second_order.to_type t) - := match t return P (type.untranslate R (type.translate (second_order.to_type t))) - -> P (second_order.to_type t) - with - | second_order.type.type_primitive x => id - | second_order.type.prod A B - => fun ab : - (* ignore this line *) P (type.untranslate R (type.translate (second_order.flat_to_type A)) * type.untranslate R (type.translate (second_order.flat_to_type B)))%ctype - => @untranslate_translate_flat' - (fun a => P (a * _)%ctype) - R A - (@untranslate_translate_flat' - (fun b => P (_ * b)%ctype) - R B - ab) - | second_order.type.list A - => @untranslate_translate_flat' - (fun t => P (Compilers.type.list t)) - R A - end e. - - Definition transport_final_codomain_flat P {t} - : P (second_order.flat_to_type t) - -> P (type.final_codomain (second_order.flat_to_type t)) - := match t with - | second_order.type.type_primitive x => id - | second_order.type.prod x x0 => id - | second_order.type.list x => id - end. - - Definition transport_final_codomain_flat' P {t} - : P (type.final_codomain (second_order.flat_to_type t)) - -> P (second_order.flat_to_type t) - := match t with - | second_order.type.type_primitive x => id - | second_order.type.prod x x0 => id - | second_order.type.list x => id + Local Notation iffT A B := ((A -> B) * (B -> A))%type. + (** We can only "plug in the identity continuation" for flat + (arrow-free) types. (Actually, we know how to do it in a + very ad-hoc way for types of at-most second-order functions; + see git history. This is much simpler.) *) + Fixpoint try_untranslate_translate {R} {t} + : option (forall (P : Compilers.type.type -> Type), + iffT (P (type.untranslate R (type.translate t))) (P t)) + := match t return option (forall (P : Compilers.type.type -> Type), + iffT (P (type.untranslate R (type.translate t))) (P t)) with + | Compilers.type.type_primitive x + => Some (fun P => ((fun v => v), (fun v => v))) + | type.arrow s d => None + | Compilers.type.prod A B + => (fA <- (@try_untranslate_translate _ A); + fB <- (@try_untranslate_translate _ B); + Some + (fun P + => let fA := fA (fun A => P (Compilers.type.prod A (type.untranslate R (type.translate B)))) in + let fB := fB (fun B => P (Compilers.type.prod A B)) in + ((fun v => fst fB (fst fA v)), + (fun v => snd fA (snd fB v)))))%option + | Compilers.type.list A + => (fA <- (@try_untranslate_translate R A); + Some (fun P => fA (fun A => P (Compilers.type.list A))))%option end. - Fixpoint untranslate_translate_arg - {var} - {R} - {t : second_order.argtype} - (e : @Compilers.Uncurried.expr.default.expr var (second_order.arg_to_type t)) - {struct t} - : @Compilers.Uncurried.expr.default.expr var (type.untranslate R (type.translate (second_order.arg_to_type t))) - := match t return Compilers.Uncurried.expr.default.expr (second_order.arg_to_type t) - -> Compilers.Uncurried.expr.default.expr (type.untranslate R (type.translate (second_order.arg_to_type t))) - with - | second_order.type.flat_arg t - => untranslate_translate_flat _ - | second_order.type.arrow_arg s d - => fun e' - => Abs (fun v : - (* ignore this line *) var (type.untranslate R (type.translate (second_order.flat_to_type s)) * (type.untranslate R (type.translate (second_order.arg_to_type d)) -> R))%ctype - => (ident.snd @@ Var v) - @ (@untranslate_translate_arg - var R d - (e' @ (untranslate_translate_flat' _ (ident.fst @@ Var v)))))%expr - | second_order.type.prod_arg A B - => fun e' : expr.default.expr (second_order.arg_to_type A * second_order.arg_to_type B) - => ((Abs (fun a => Abs (fun b => (Var a, Var b)))) - @ (@untranslate_translate_arg var R A (ident.fst @@ e')) - @ (@untranslate_translate_arg var R B (ident.snd @@ e')))%expr - end e. - Local Notation "x <-- e1 ; e2" := (expr.splice e1 (fun x => e2%cpsexpr)) : cpsexpr_scope. - Fixpoint call_fun_with_id_continuation' - {var} - {t : second_order.type} - (R := type.final_codomain (second_order.to_type t)) - (e : @expr (fun t0 => - @Uncurried.expr.expr default.ident.ident var (type.untranslate R t0)) - (type.translate (second_order.to_type t))) - {struct t} - : @Compilers.Uncurried.expr.default.expr var (second_order.to_type t) - := match t - return (@expr (fun t0 => - @Uncurried.expr.expr default.ident.ident var (type.untranslate (type.final_codomain (second_order.to_type t)) t0)) - (type.translate (second_order.to_type t))) - -> @Compilers.Uncurried.expr.default.expr var (second_order.to_type t) - with - | second_order.type.flat t - => fun e' - => transport_final_codomain_flat' - _ - (@call_with_continuation - var _ _ e' - (fun e'' => transport_final_codomain_flat _ (untranslate_translate_flat' _ e''))) - | second_order.type.arrow s d - => fun e' : - (* ignore this line *) expr (type.translate (second_order.arg_to_type s) * (type.translate (second_order.to_type d) --->) --->) - => Abs (s:=second_order.arg_to_type s) (d:=second_order.to_type d) - (fun v - => @call_fun_with_id_continuation' - var d - (f <-- e'; + Definition call_fun_with_id_continuation' + {s d} + : option (forall var + (e : @expr _ (type.translate (s -> d))), + @Compilers.Uncurried.expr.default.expr var (s -> d)) + := (fs <- (@try_untranslate_translate _ s); + fd <- (@try_untranslate_translate _ d); + Some + (fun var e + => let P := @Compilers.Uncurried.expr.default.expr var in + Abs + (fun v : var s + => call_with_continuation + ((f <-- e; k <- (λ r, expr.Halt r); - p <- (untranslate_translate_arg (Var v), k); + p <- (snd (fs P) (Var v), k); f @ p)%cpsexpr) - end e. - Definition CallFunWithIdContinuation' - {t : second_order.type} - (e : Expr (type.translate (second_order.to_type t))) - : @Compilers.Uncurried.expr.default.Expr (second_order.to_type t) - := fun var => @call_fun_with_id_continuation' _ t (e _). + (fst (fd P)))))%option. - Definition CallFunWithIdContinuation - {t} - (e : Expr (type.translate t)) - := match second_order.try_transport_of_type (fun t => Expr (type.translate t)) _ - e - as o return match o with None => _ | _ => _ end - with - | Some v => CallFunWithIdContinuation' (projT2 v) - | None => I - end. + Definition call_fun_with_id_continuation + {var} + {s d} (e : @expr _ (type.translate (s -> d))) + : option (@Compilers.Uncurried.expr.default.expr var (s -> d)) + := option_map + (fun f => f _ e) + (@call_fun_with_id_continuation' s d). - Definition CallFunWithIdContinuation_opt - {t} - (e : Expr (type.translate t)) - : option (@Compilers.Uncurried.expr.default.Expr t) - := (e' <- (second_order.try_transport_of_type - (fun t => Expr (type.translate t)) _ - e); - type.try_transport _ _ _ (CallFunWithIdContinuation' (projT2 e')))%option. + Definition CallFunWithIdContinuation + {s d} + (e : Expr (type.translate (s -> d))) + : option (@Compilers.Uncurried.expr.default.Expr (s -> d)) + := option_map + (fun f var => f _ (e _)) + (@call_fun_with_id_continuation' s d). End default. Include default. End CPS. @@ -6404,19 +6125,20 @@ Module test3. (z * z)) in pose v as E. vm_compute in E. - pose (PartialEvaluate false (CPS.CallFunWithIdContinuation (CPS.Translate (canonicalize_list_recursion E)))) as E'. + pose (option_map (PartialEvaluate false) (CPS.CallFunWithIdContinuation (CPS.Translate (canonicalize_list_recursion E)))) as E'. vm_compute in E'. lazymatch (eval cbv delta [E'] in E') with - | (fun var : type -> Type => - (λ x : var (type.type_primitive type.Z), - expr_let x0 := Var x * Var x in - expr_let x1 := Var x0 * Var x0 in - expr_let x2 := Var x1 * Var x1 in - expr_let x3 := Var x2 * Var x2 in - Var x3 * Var x3)%expr) + | (Some + (fun var : type -> Type => + (λ x : var (type.type_primitive type.Z), + expr_let x0 := Var x * Var x in + expr_let x1 := Var x0 * Var x0 in + expr_let x2 := Var x1 * Var x1 in + expr_let x3 := Var x2 * Var x2 in + Var x3 * Var x3)%expr)) => idtac end. - pose (PartialEvaluateWithBounds1 E' (Some r[0~>10]%zrange)) as E'''. + pose (PartialEvaluateWithBounds1 (invert_Some E') (Some r[0~>10]%zrange)) as E'''. lazy in E'''. lazymatch (eval cbv delta [E'''] in E''') with | (fun var : type -> Type => @@ -6441,10 +6163,10 @@ Module test4. (xz :: xz :: nil)) in pose v as E. vm_compute in E. - pose (PartialEvaluate false (CPS.CallFunWithIdContinuation (CPS.Translate (canonicalize_list_recursion E)))) as E'. + pose (option_map (PartialEvaluate false) (CPS.CallFunWithIdContinuation (CPS.Translate (canonicalize_list_recursion E)))) as E'. lazy in E'. clear E. - pose (PartialEvaluateWithBounds1 E' (Some [Some r[0~>10]%zrange],Some [Some r[0~>10]%zrange])) as E''. + pose (PartialEvaluateWithBounds1 (invert_Some E') (Some [Some r[0~>10]%zrange],Some [Some r[0~>10]%zrange])) as E''. lazy in E''. lazymatch (eval cbv delta [E''] in E'') with | (fun var : type -> Type => @@ -6468,7 +6190,7 @@ Module test5. x) in pose v as E. vm_compute in E. - pose (ReassociateSmallConstants.Reassociate (2^8) (PartialEvaluate false (CPS.CallFunWithIdContinuation (CPS.Translate (canonicalize_list_recursion E))))) as E'. + pose (ReassociateSmallConstants.Reassociate (2^8) (PartialEvaluate false (invert_Some (CPS.CallFunWithIdContinuation (CPS.Translate (canonicalize_list_recursion E)))))) as E'. lazy in E'. clear E. lazymatch (eval cbv delta [E'] in E') with @@ -6495,7 +6217,7 @@ Module test6. pose (CPS.CallFunWithIdContinuation (CPS.Translate (canonicalize_list_recursion E))) as E'. lazy in E'. clear E. - pose (PartialEvaluate false E') as E''. + pose (PartialEvaluate false (invert_Some E')) as E''. lazy in E''. lazymatch eval cbv delta [E''] in E'' with | fun var : type -> Type => (λ x : var (type.type_primitive type.Z), Var x)%expr @@ -6650,7 +6372,7 @@ Create HintDb reify_gen_cache. Derive carry_mul_gen SuchThat (forall (limbwidth_num limbwidth_den : Z) - (fg : list Z * list Z) + (f g : list Z) (n : nat) (s : Z) (c : list (Z * Z)) @@ -6658,11 +6380,11 @@ Derive carry_mul_gen (idxs : list nat) (len_idxs : nat), Interp (t:=type.reify_type_of carry_mulmod) - carry_mul_gen limbwidth_num limbwidth_den s c n len_c idxs len_idxs fg - = carry_mulmod limbwidth_num limbwidth_den s c n len_c idxs len_idxs fg) + carry_mul_gen limbwidth_num limbwidth_den s c n len_c idxs len_idxs f g + = carry_mulmod limbwidth_num limbwidth_den s c n len_c idxs len_idxs f g) As carry_mul_gen_correct. Proof. Time cache_reify (). exact admit. (* correctness of initial parts of the pipeline *) Time Qed. -Hint Extern 1 (_ = carry_mulmod _ _ _ _ _ _ _ _ _) => simple apply carry_mul_gen_correct : reify_gen_cache. +Hint Extern 1 (_ = carry_mulmod _ _ _ _ _ _ _ _ _ _) => simple apply carry_mul_gen_correct : reify_gen_cache. Derive carry_gen SuchThat (forall (limbwidth_num limbwidth_den : Z) @@ -6696,14 +6418,14 @@ Hint Extern 1 (_ = encodemod _ _ _ _ _ _ _) => simple apply encode_gen_correct : Derive add_gen SuchThat (forall (limbwidth_num limbwidth_den : Z) - (fg : list Z * list Z) + (f g : list Z) (n : nat), Interp (t:=type.reify_type_of addmod) - add_gen limbwidth_num limbwidth_den n fg - = addmod limbwidth_num limbwidth_den n fg) + add_gen limbwidth_num limbwidth_den n f g + = addmod limbwidth_num limbwidth_den n f g) As add_gen_correct. Proof. cache_reify (). exact admit. (* correctness of initial parts of the pipeline *) Qed. -Hint Extern 1 (_ = addmod _ _ _ _) => simple apply add_gen_correct : reify_gen_cache. +Hint Extern 1 (_ = addmod _ _ _ _ _) => simple apply add_gen_correct : reify_gen_cache. Derive sub_gen SuchThat (forall (limbwidth_num limbwidth_den : Z) (n : nat) @@ -6711,13 +6433,13 @@ Derive sub_gen (c : list (Z * Z)) (len_c : nat) (coef : Z) - (fg : list Z * list Z), + (f g : list Z), Interp (t:=type.reify_type_of submod) - sub_gen limbwidth_num limbwidth_den s c n len_c coef fg - = submod limbwidth_num limbwidth_den s c n len_c coef fg) + sub_gen limbwidth_num limbwidth_den s c n len_c coef f g + = submod limbwidth_num limbwidth_den s c n len_c coef f g) As sub_gen_correct. Proof. cache_reify (). exact admit. (* correctness of initial parts of the pipeline *) Qed. -Hint Extern 1 (_ = submod _ _ _ _ _ _ _ _) => simple apply sub_gen_correct : reify_gen_cache. +Hint Extern 1 (_ = submod _ _ _ _ _ _ _ _ _) => simple apply sub_gen_correct : reify_gen_cache. Derive opp_gen SuchThat (forall (limbwidth_num limbwidth_den : Z) @@ -6773,6 +6495,7 @@ Derive id_gen Proof. cache_reify (). exact admit. (* correctness of initial parts of the pipeline *) Qed. Hint Extern 1 (_ = expanding_id _ _) => simple apply id_gen_correct : reify_gen_cache. +Import Uncurry. Module Pipeline. Import GeneralizeVar. Inductive ErrorMessage := @@ -6808,16 +6531,17 @@ Module Pipeline. expr.Interp (@for_reification.ident.interp) E. Admitted. - Definition BoundsPipeline_0_or_1 + Definition BoundsPipeline (with_dead_code_elimination : bool := true) (with_subst01 : bool) + relax_zrange {t} (E : Expr t) - {t'} - (CheckedPartialEvaluateWithBounds : _ -> _ -> _ + partial.data t') + arg_bounds out_bounds - : ErrorT (Expr t) - := let E := CPS.CallFunWithIdContinuation_opt (CPS.Translate E) in + : ErrorT (Expr (type.uncurry t)) + := let E := expr.Uncurry E in + let E := CPS.CallFunWithIdContinuation (CPS.Translate E) in match E with | Some E => (let E := PartialEvaluate false E in @@ -6832,7 +6556,7 @@ Module Pipeline. let E := FromFlat e in let E := if with_subst01 then Subst01.Subst01 E else E in let E := ReassociateSmallConstants.Reassociate (2^8) E in - let E := CheckedPartialEvaluateWithBounds E out_bounds in + let E := CheckedPartialEvaluateWithBounds1 relax_zrange E arg_bounds out_bounds in match E with | inl E => Success E | inr b @@ -6841,45 +6565,14 @@ Module Pipeline. | None => Error (Type_too_complicated_for_cps t) end. - Definition BoundsPipeline - (with_dead_code_elimination : bool := true) - (with_subst01 : bool) - relax_zrange - {s d} - (E : Expr (s -> d)) - arg_bounds - out_bounds - : ErrorT (Expr (s -> d)) - := BoundsPipeline_0_or_1 - (*with_dead_code_elimination*) - with_subst01 - E - (fun E => CheckedPartialEvaluateWithBounds1 relax_zrange E arg_bounds) - out_bounds. - - Definition BoundsPipelineConst - (with_dead_code_elimination : bool := true) - (with_subst01 : bool) - relax_zrange - {t} - (E : Expr t) - out_bounds - : ErrorT (Expr t) - := BoundsPipeline_0_or_1 - (*with_dead_code_elimination*) - with_subst01 - E - (fun E => CheckedPartialEvaluateWithBounds0 relax_zrange E) - out_bounds. - Lemma BoundsPipeline_correct (with_dead_code_elimination : bool := true) (with_subst01 : bool) relax_zrange (Hrelax : forall r r' z : zrange, (z <=? r)%zrange = true -> relax_zrange r = Some r' -> (z <=? r')%zrange = true) - {s d} - (e : Expr (s -> d)) + {t} + (e : Expr t) arg_bounds out_bounds rv @@ -6887,9 +6580,9 @@ Module Pipeline. : forall arg (Harg : ZRange.type.option.is_bounded_by arg_bounds arg = true), ZRange.type.option.is_bounded_by out_bounds (Interp rv arg) = true - /\ Interp rv arg = Interp e arg. + /\ Interp rv arg = app_curried (Interp e) arg. Proof. - cbv [BoundsPipeline BoundsPipeline_0_or_1 Let_In] in *; + cbv [BoundsPipeline Let_In] in *; repeat match goal with | [ H : match ?x with _ => _ end = Success _ |- _ ] => destruct x eqn:?; cbv beta iota in H; [ | congruence ]; @@ -6913,15 +6606,15 @@ Module Pipeline. Qed. Definition BoundsPipeline_correct_transT - {s d} + {t} arg_bounds out_bounds - (InterpE : type.interp s -> type.interp d) - (rv : Expr (s -> d)) + (InterpE : type.interp t) + (rv : Expr (type.uncurry t)) := forall arg (Harg : ZRange.type.option.is_bounded_by arg_bounds arg = true), ZRange.type.option.is_bounded_by out_bounds (Interp rv arg) = true - /\ Interp rv arg = InterpE arg. + /\ Interp rv arg = app_curried InterpE arg. Lemma BoundsPipeline_correct_trans (with_dead_code_elimination : bool := true) @@ -6930,14 +6623,14 @@ Module Pipeline. (Hrelax : forall r r' z : zrange, (z <=? r)%zrange = true -> relax_zrange r = Some r' -> (z <=? r')%zrange = true) - {s d} - (e : Expr (s -> d)) + {t} + (e : Expr t) arg_bounds out_bounds - (InterpE : type.interp s -> type.interp d) + (InterpE : type.interp t) (InterpE_correct : forall arg (Harg : ZRange.type.option.is_bounded_by arg_bounds arg = true), - Interp e arg = InterpE arg) + app_curried (Interp e) arg = app_curried InterpE arg) rv (Hrv : BoundsPipeline (*with_dead_code_elimination*) with_subst01 relax_zrange e arg_bounds out_bounds = Success rv) : BoundsPipeline_correct_transT arg_bounds out_bounds InterpE rv. @@ -6950,17 +6643,17 @@ Module Pipeline. (with_dead_code_elimination : bool := true) (with_subst01 : bool) relax_zrange - {s d} - (E : for_reification.Expr (s -> d)) + {t} + (E : for_reification.Expr t) arg_bounds out_bounds - : ErrorT (Expr (s -> d)) + : ErrorT (Expr (type.uncurry t)) := let E := PrePipeline E in @BoundsPipeline (*with_dead_code_elimination*) with_subst01 relax_zrange - s d E arg_bounds out_bounds. + t E arg_bounds out_bounds. Lemma BoundsPipeline_full_correct (with_dead_code_elimination : bool := true) @@ -6968,8 +6661,8 @@ Module Pipeline. relax_zrange (Hrelax : forall r r' z : zrange, (z <=? r)%zrange = true -> relax_zrange r = Some r' -> (z <=? r')%zrange = true) - {s d} - (E : for_reification.Expr (s -> d)) + {t} + (E : for_reification.Expr t) arg_bounds out_bounds rv @@ -6977,111 +6670,12 @@ Module Pipeline. : forall arg (Harg : ZRange.type.option.is_bounded_by arg_bounds arg = true), ZRange.type.option.is_bounded_by out_bounds (Interp rv arg) = true - /\ Interp rv arg = for_reification.Interp E arg. + /\ Interp rv arg = app_curried (for_reification.Interp E) arg. Proof. cbv [BoundsPipeline_full] in *. eapply BoundsPipeline_correct_trans; [ eassumption | | eassumption.. ]. intros; erewrite PrePipeline_correct; reflexivity. Qed. - - Lemma BoundsPipelineConst_correct - (with_dead_code_elimination : bool := true) - (with_subst01 : bool) - relax_zrange - (Hrelax : forall r r' z : zrange, - (z <=? r)%zrange = true -> relax_zrange r = Some r' -> (z <=? r')%zrange = true) - {d} - (e : Expr d) - bounds - rv - (Hrv : BoundsPipelineConst (*with_dead_code_elimination*) with_subst01 relax_zrange e bounds = Success rv) - : ZRange.type.option.is_bounded_by bounds (Interp rv) = true - /\ Interp rv = Interp e. - Proof. - cbv [BoundsPipelineConst BoundsPipeline_0_or_1 Let_In] in *; - repeat match goal with - | [ H : match ?x with _ => _ end = Success _ |- _ ] - => destruct x eqn:?; cbv beta iota in H; [ | congruence ]; - let H' := fresh in - inversion H as [H']; clear H; rename H' into H - end. - { intros; - match goal with - | [ H : _ = _ |- _ ] - => eapply CheckedPartialEvaluateWithBounds0_Correct in H; - [ destruct H as [H0 H1] | .. ] - end; - [ - | eassumption || (try reflexivity).. ]. - refine (let H' := admit (* interp correctness *) in - conj _ (eq_trans H' _)); - clearbody H'. - { rewrite H'; eassumption. } - { rewrite H0. - exact admit. (* interp correctness *) } } - Qed. - - Definition BoundsPipelineConst_correct_transT - {t} - out_bounds - (InterpE : type.interp t) - (rv : Expr t) - := ZRange.type.option.is_bounded_by out_bounds (Interp rv) = true - /\ Interp rv = InterpE. - - Lemma BoundsPipelineConst_correct_trans - (with_dead_code_elimination : bool := true) - (with_subst01 : bool) - relax_zrange - (Hrelax - : forall r r' z : zrange, - (z <=? r)%zrange = true -> relax_zrange r = Some r' -> (z <=? r')%zrange = true) - {t} - (e : Expr t) - out_bounds - (InterpE : type.interp t) - (InterpE_correct : Interp e = InterpE) - rv - (Hrv : BoundsPipelineConst (*with_dead_code_elimination*) with_subst01 relax_zrange e out_bounds = Success rv) - : BoundsPipelineConst_correct_transT out_bounds InterpE rv. - Proof. - rewrite <- InterpE_correct. - eapply @BoundsPipelineConst_correct; eassumption. - Qed. - - Definition BoundsPipelineConst_full - (with_dead_code_elimination : bool := true) - (with_subst01 : bool) - relax_zrange - {t} - (E : for_reification.Expr t) - out_bounds - : ErrorT (Expr t) - := let E := PrePipeline E in - @BoundsPipelineConst - (*with_dead_code_elimination*) - with_subst01 - relax_zrange - t E out_bounds. - - Lemma BoundsPipelineConst_full_correct - (with_dead_code_elimination : bool := true) - (with_subst01 : bool) - relax_zrange - (Hrelax : forall r r' z : zrange, - (z <=? r)%zrange = true -> relax_zrange r = Some r' -> (z <=? r')%zrange = true) - {t} - (E : for_reification.Expr t) - out_bounds - rv - (Hrv : BoundsPipelineConst_full (*with_dead_code_elimination*) with_subst01 relax_zrange E out_bounds = Success rv) - : ZRange.type.option.is_bounded_by out_bounds (Interp rv) = true - /\ Interp rv = for_reification.Interp E. - Proof. - cbv [BoundsPipelineConst_full] in *. - eapply BoundsPipelineConst_correct_trans; [ eassumption | | eassumption.. ]. - intros; erewrite PrePipeline_correct; reflexivity. - Qed. End Pipeline. Definition round_up_bitwidth_gen (possible_values : list Z) (bitwidth : Z) : option Z @@ -7176,19 +6770,13 @@ Section rcarry_mul. relax_zrange rop%Expr in_bounds out_bounds). - Notation BoundsPipelineConst rop out_bounds - := (Pipeline.BoundsPipelineConst - (*false*) true - relax_zrange - rop%Expr out_bounds). - Notation BoundsPipeline_correct in_bounds out_bounds op - := (fun rv (rop : Expr (type.reify_type_of op%function)) Hrop + := (fun rv (rop : Expr (type.reify_type_of op)) Hrop => @Pipeline.BoundsPipeline_correct_trans (*false*) true relax_zrange (relax_zrange_gen_good _) - _ _ + _ rop in_bounds out_bounds @@ -7196,19 +6784,6 @@ Section rcarry_mul. Hrop rv) (only parsing). - Notation BoundsPipelineConst_correct out_bounds op - := (fun rv (rop : Expr (type.reify_type_of op)) Hrop - => @Pipeline.BoundsPipelineConst_correct_trans - (*false*) true - relax_zrange - (relax_zrange_gen_good _) - _ - rop%Expr - out_bounds - op - Hrop rv) - (only parsing). - (* N.B. We only need [rcarry_mul] if we want to extract the Pipeline; otherwise we can just use [rcarry_mul_correct] *) Definition rcarry_mul := BoundsPipeline @@ -7260,12 +6835,14 @@ Section rcarry_mul. (encodemod (Qnum limbwidth) (Z.pos (Qden limbwidth)) s c n (List.length c)). Definition rzero_correct - := BoundsPipelineConst_correct + := BoundsPipeline_correct + tt (Some tight_bounds) (zeromod (Qnum limbwidth) (Z.pos (Qden limbwidth)) s c n (List.length c)). Definition rone_correct - := BoundsPipelineConst_correct + := BoundsPipeline_correct + tt (Some tight_bounds) (onemod (Qnum limbwidth) (Z.pos (Qden limbwidth)) s c n (List.length c)). @@ -7363,8 +6940,8 @@ Section rcarry_mul. (Interp raddv) (Interp rsubv) (Interp roppv) - (Interp rzerov) - (Interp ronev) + (Interp rzerov tt) + (Interp ronev tt) (Interp rencodev). Theorem Good : GoodT. @@ -7382,9 +6959,12 @@ Section rcarry_mul. | apply conj | progress intros | progress cbv [onemod zeromod] + | eapply Hrzerov (* to handle diff with whether or not correctness asks for boundedness of tt *) + | eapply Hronev (* to handle diff with whether or not correctness asks for boundedness of tt *) | match goal with | [ |- ?x = ?x ] => reflexivity | [ |- ?x = ?ev ] => is_evar ev; reflexivity + | [ |- ZRange.type.option.is_bounded_by tt tt = true ] => reflexivity end ]. Qed. End make_ring. @@ -7400,6 +6980,9 @@ Proof. cbv [pointwise_relation]; intros; subst; trivial. Qed. Ltac peel_interp_app _ := lazymatch goal with + | [ |- ?R' (?InterpE ?arg) (?f ?arg) ] + => apply fg_equal_rel; [ | reflexivity ]; + try peel_interp_app () | [ |- ?R' (Interp ?ev) (?f ?x) ] => let sv := type of x in let fx := constr:(f x) in @@ -7420,10 +7003,9 @@ Ltac peel_interp_app _ := end ] ] end. Ltac pre_cache_reify _ := + cbv [app_curried]; let arg := fresh "arg" in - (tryif intros arg _ - then apply fg_equal_rel; [ | reflexivity ] - else hnf); + intros arg _; peel_interp_app (); [ lazymatch goal with | [ |- ?R (Interp ?ev) _ ] @@ -7733,7 +7315,6 @@ Module X25519_64. base_51_encode_correct. Print Assumptions base_51_good. - Import PrintingNotations. Print base_51_carry_mul. (*base_51_carry_mul = @@ -8304,12 +7885,12 @@ Module MontgomeryReduction. else res. Notation BoundsPipeline_correct in_bounds out_bounds op - := (fun rv (rop : Expr (type.reify_type_of op%function)) Hrop + := (fun rv (rop : Expr (type.reify_type_of op)) Hrop => @Pipeline.BoundsPipeline_correct_trans false (* subst01 *) relax_zrange (relax_zrange_gen_good _) - _ _ + _ rop in_bounds out_bounds |